In [2]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.io import read_image
from torchvision.models import vit_b_16  # Example ViT model

In [40]:
# Maps each AQI category name to its integer class id. Ids are consecutive,
# starting at 0, in the order the names are listed below.
LABEL_DICT = {
    category: class_id
    for class_id, category in enumerate(
        (
            "a_Good",
            "b_Moderate",
            "c_Unhealthy_for_Sensitive_Groups",
            "d_Unhealthy",
            "e_Very_Unhealthy",
            "f_Severe",
        )
    )
}

In [3]:
def load_model(model_path="/home/sagemaker-user/mingxi/models/241027_ViT_reg_finetune.pth"):
    """Build the ViT-B/16 AQI regressor and load fine-tuned weights into it.

    The stock classification head is replaced by a single-output linear layer,
    so the network produces one continuous score per image.

    Args:
        model_path: Path to a saved ``state_dict`` for this exact architecture.

    Returns:
        The model with the checkpoint weights loaded, on CPU, in eval mode.
        Call ``.train()`` on it if it is to be fine-tuned further.
    """
    # No pretrained download: load_state_dict (strict by default) overwrites
    # every parameter below, so ImageNet weights would be discarded anyway.
    model = vit_b_16(weights=None)
    model.heads = nn.Linear(model.heads[0].in_features, 1)
    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
    # host. The model itself is built on CPU, so the weights end up there anyway.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    # This notebook only runs inference: disable training-only behaviour (dropout).
    model.eval()
    return model

In [5]:
def image_to_tensor(image_path):
    """Load an image file and preprocess it into a model-ready batch of one.

    Preprocessing steps:
    1. Resize to 224x224.
    2. Convert to a float tensor.
    3. Normalize with the standard ImageNet per-channel mean/std.

    Args:
        image_path: Path to any image file PIL can open.

    Returns:
        Tensor of shape ``(1, 3, 224, 224)``.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # The context manager releases the file handle that a lazy Image.open keeps.
    # convert("RGB") guarantees 3 channels; grayscale/RGBA/CMYK inputs would
    # otherwise break the 3-channel Normalize or the model's input layer.
    with Image.open(image_path) as img:
        rgb_img = img.convert("RGB")

    # Add a leading batch dimension of size 1.
    return transform(rgb_img).unsqueeze(0)

In [41]:
# Instantiate the regressor once; the prediction cells below reuse this object.
model = load_model()

In [47]:
def aqi_prediction(image_path,model):
    """Predict the AQI category of a single image with the regression ViT.

    The model emits one continuous score. The score is rounded to the nearest
    integer class id and clipped to the valid id range. That id is then mapped
    back to its ``LABEL_DICT`` key.

    Args:
        image_path: Path to the image file to score.
        model: Single-output regression model, e.g. the one returned by
            ``load_model()``; assumes output shape ``(batch, 1)``.

    Returns:
        Tuple ``(raw_score, class_id, class_name)`` with these parts:
        - the unrounded model output;
        - the rounded class id, clipped to the ids present in ``LABEL_DICT``;
        - the matching category name.
    """
    # Look classes up by their declared id rather than by dict insertion order.
    id_to_label = {class_id: name for name, class_id in LABEL_DICT.items()}
    lowest_id, highest_id = min(id_to_label), max(id_to_label)

    # Put the input on the same device as the model (CPU or GPU).
    device = next(model.parameters()).device
    image_tensor = image_to_tensor(image_path).to(device)

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        prediction = model(image_tensor)
    raw_score = prediction[0].item()

    # The regression head is unbounded. Clip so out-of-range scores map to the
    # nearest end class. Otherwise a negative index silently wraps
    # (e.g. -1 -> 'f_Severe') and a too-large one raises IndexError.
    rounded_prediction = min(max(round(raw_score), lowest_id), highest_id)
    predicted_class = id_to_label[rounded_prediction]
    print(f'The predicted AQI class is: {predicted_class}')
    return raw_score, rounded_prediction, predicted_class

In [48]:
# Score one sample image; the result is (raw score, class id, class name).
predicted_aqi = aqi_prediction(
    image_path="/home/sagemaker-user/data/w210_gasp_data/img/1.JPG",
    model=model,
)

The predicted AQI class is: d_Unhealthy


In [49]:
predicted_aqi  # (raw regression score, rounded class id, AQI class name)

(2.6209523677825928, 3, 'd_Unhealthy')