In [1]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
import os
import shutil
import time

# --- Configuration ---
# Class labels. Order matters: single_shot_predict() pairs entry i with output
# index i of the classifier head via zip(), and len(CLASS_NAMES) sets the size
# of the final linear layer.
# NOTE(review): presumably this must match the label order used at training time — confirm.
CLASS_NAMES = ["false_color", "ndvi", "SARV2_resized", "swir", "true_color", "urban"]
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint file read with torch.load() inside single_shot_predict().
# NOTE(review): absolute, machine-specific path; the lowercase name also differs
# from the UPPER_SNAKE_CASE used by the other constants in this cell.
local_path = "/home/gaurav/scratch/interiit/gaurav/checkpoint/current_training_model6.pt"
# Image that single_shot_predict() classifies (absolute, machine-specific path).
IMAGE_PATH = "/home/gaurav/scratch/interiit/EarthMind-Bench/img/test/rgb/img/dfc2023_test_P_0614.png"



In [2]:
def single_shot_predict(image_path=None, checkpoint_path=None):
    """Classify a single image and return per-class probabilities.

    Builds a ResNet-50 with a custom classification head, loads the
    checkpoint weights, preprocesses one image and runs a single forward
    pass (no warm-up). The wall-clock time of the whole call is printed.

    Args:
        image_path: Path of the image to classify. Defaults to the
            module-level ``IMAGE_PATH`` when ``None``.
        checkpoint_path: Path of the checkpoint to load. Defaults to the
            module-level ``local_path`` when ``None``.

    Returns:
        dict[str, float]: Maps each name in ``CLASS_NAMES`` to its softmax
        probability.

    Raises:
        Whatever ``torch.load``, ``load_state_dict`` (strict key matching)
        or ``PIL.Image.open`` raise for a missing/incompatible checkpoint or
        an unreadable image; nothing is caught here.
    """
    total_start = time.time()

    # Fall back to the notebook-level configuration so a bare call keeps working.
    if image_path is None:
        image_path = IMAGE_PATH
    if checkpoint_path is None:
        checkpoint_path = local_path

    # 1. Architecture: ResNet-50 backbone without pretrained weights, with the
    #    stock fc layer replaced by a two-layer head sized to CLASS_NAMES.
    model = models.resnet50(weights=None)
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 512),
        nn.BatchNorm1d(512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, len(CLASS_NAMES))
    )

    # 2. Load weights
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # a trusted source. Passing weights_only=True would restrict unpickling, but
    # it is not enabled here because the checkpoint contents are not visible —
    # confirm it holds only tensors/plain containers before switching.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    # Accept either a wrapper dict with a 'state_dict' entry or a bare state dict.
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    # Strip the "_orig_mod." key prefix (presumably left by torch.compile at
    # training time) so the keys match this uncompiled model.
    clean_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(clean_state_dict)

    # 3. Optimization: FP16 weights on GPU only
    if DEVICE.type == 'cuda':
        model.half()
    model.to(DEVICE)
    # eval() makes BatchNorm use running statistics and disables Dropout, which
    # is required for a batch of one.
    model.eval()

    # 4. Preprocess
    # NOTE(review): resize and normalization constants presumably mirror the
    # training pipeline — confirm.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Image.open is lazy and keeps the file open; the context manager closes it
    # deterministically. convert() loads the pixels and returns a separate image,
    # so `image` stays usable after the block.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(DEVICE)
    if DEVICE.type == 'cuda':
        # Input dtype must match the half-precision weights.
        input_tensor = input_tensor.half()

    # 5. Predict (No Warmup)
    with torch.no_grad():
        outputs = model(input_tensor)
        # Cast logits to float32 before softmax so the probabilities are not
        # rounded to half precision on the GPU path.
        probs = torch.nn.functional.softmax(outputs.float(), dim=1)

    # Result: copying to the CPU also waits for any pending GPU work, so the
    # timing below covers the forward pass.
    probs_list = probs.cpu().numpy().flatten()
    result = {k: float(v) for k, v in zip(CLASS_NAMES, probs_list)}

    print(f"Total Execution Time: {time.time() - total_start:.4f}s")
    return result



In [3]:
# Run the classifier once and report the most probable class.
probs = single_shot_predict()
# Select the (name, probability) pair with the highest probability; on ties the
# first entry in dict order wins, same as taking max over the keys.
top_class, top_prob = max(probs.items(), key=lambda entry: entry[1])
print(f"Result: {top_class} ({top_prob*100:.2f}%)")

  checkpoint = torch.load(local_path, map_location=DEVICE)


Total Execution Time: 3.4238s
Result: true_color (77.20%)
