In [44]:
from typing import List
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import HTMLResponse
import torch
from PIL import Image
import torchvision.transforms as transforms
from torchvision import models
import io

data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(128, 128)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomResizedCrop((128, 128), scale=(0.8, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=2),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

def load_model(model_class, exp_name, models_dir='models'):
    """
    Load a previously saved model.
    
    Args:
    - model_class: The class of the model to be loaded.
    - exp_name: Name of the experiment or model.
    - models_dir: Directory where models are saved.
    
    Returns:
    - model: The loaded model.
    """
    model_path = os.path.join(models_dir, f"{exp_name}.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model '{exp_name}' not found in directory '{models_dir}'")
    
    checkpoint = torch.load(model_path)
    config = checkpoint['config']
    
    model = model_class(config)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model loaded from: {model_path}")
    return model

def get_model():
    model = models.vit_b_16(pretrained=False)
    return model

def load_model_dpt():
    model_path = r"D:\ml projects\ViT\models\fine_tuned_vit_b_16_on_intel_image_dataset.pt"
    try:
        fine_tuned = get_model()
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        fine_tuned.load_state_dict(state_dict)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        fine_tuned.to(device)
        fine_tuned.eval()
        return fine_tuned
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

def pre_process_dpt(image_bytes, transform):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image = Image.open(io.BytesIO(image_bytes))
    input_batch = transform(image).to(device)
    return input_batch, image

@torch.no_grad()
def predict_image(image, model, classes=('buildings', 'forest', 'glacier', 'mountain', 'sea', 'street'), device="cuda"):
    model.eval()
    image = image.unsqueeze(0).to(device)
    model = model.to(device)
    logits = model(image)
    predicted_class = torch.argmax(logits, dim=1).item()
    return classes[predicted_class]

app = FastAPI()
model = load_model_dpt()






Error loading model: Error(s) in loading state_dict for VisionTransformer:
	Unexpected key(s) in state_dict: "classifier.0.weight", "classifier.0.bias". 
	size mismatch for heads.head.weight: copying a param with shape torch.Size([6, 768]) from checkpoint, the shape in current model is torch.Size([1000, 768]).
	size mismatch for heads.head.bias: copying a param with shape torch.Size([6]) from checkpoint, the shape in current model is torch.Size([1000]).


In [50]:
import torch
import torchvision.models as models
import os

# Step 1: Define the model architecture
def get_model(num_classes=6):
    model = models.vit_b_16(weights=None)  # Change 'pretrained' to 'weights=None'
    model.heads[-1] = torch.nn.Linear(in_features=model.heads[-1].in_features, out_features=num_classes, bias=True)
    return model

# Step 2: Load the Model with Flexible State_dict Loading
def load_model(model_path, num_classes=6):
    # Initialize the model architecture with the correct number of classes
    model = get_model(num_classes=num_classes)

    # Load the state_dict
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    # Filter out keys from state_dict that do not belong to the model
    model_state_dict = model.state_dict()
    state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}

    # Load the filtered state_dict
    model.load_state_dict(state_dict)

    # If using GPU, move the model to CUDA device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Set the model to evaluation mode
    model.eval()

    return model

# Step 3: Path where the model is saved
model_path = './models/fine_tuned_vit_b_16_on_intel_image_dataset.pt'

# Step 4: Load the model
model = load_model(model_path, num_classes=6)

# Optional: Verify if the model loaded successfully
if model:
    print("Model loaded successfully.")
else:
    print("Failed to load the model.")


Model loaded successfully.


In [51]:
print(model)

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

In [None]:
@app.post("/uploadfiles/")
async def create_upload_files(files: List[UploadFile] = File(...)):
    if model is None:
        return {"error": "Model failed to load. Please check the model path and try again."}
    
    for file in files:
        contents = await file.read()
        input_batch, _ = pre_process_dpt(contents, data_transform)
        class_type = predict_image(input_batch, model)
        return {"class_type": class_type}

@app.get("/")
async def main():
    content = """
    <body>
        <h3>Upload an image to get classification</h3>
        <form action="/uploadfiles/" enctype="multipart/form-data" method="post">
            <input name="files" type="file" multiple>
            <input type="submit">
        </form>
    </body>
    """
    return HTMLResponse(content=content)