In [1]:
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from typing import List, Tuple

# ---------------------------------------------------
# CONFIG
# ---------------------------------------------------
CLASS_NAMES = ["sar", "rgb", "falsecolor"]

CHECKPOINT_PATH = "/home/gaurav/scratch/interiit/gaurav/checkpoint/best_model_3classes_450_all_data.pt"  # <-- update this
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------------------------------
# MODEL DEFINITION (MUST MATCH TRAINING HEAD)
# ---------------------------------------------------
def load_model(num_classes=3, checkpoint_path=CHECKPOINT_PATH):
    model = models.resnet50(weights=None)
    in_features = model.fc.in_features

    model.fc = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.BatchNorm1d(512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, num_classes)
    )

    # Load weights
    state = torch.load(checkpoint_path, map_location=DEVICE)
    state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
    model.load_state_dict(state)

    model.to(DEVICE)
    model.eval()
    return model

# ---------------------------------------------------
# TRANSFORM (MATCH TRAINING)
# ---------------------------------------------------
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# ---------------------------------------------------
# SINGLE IMAGE PREDICTION
# ---------------------------------------------------
def predict_image(model, image_path: str) -> Tuple[str, float]:
    img = Image.open(image_path).convert("RGB")
    x = transform(img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        outputs = model(x)
        probs = torch.softmax(outputs, dim=1)[0].cpu().numpy()

    idx = probs.argmax()
    return CLASS_NAMES[idx], float(probs[idx])

# ---------------------------------------------------
# FOLDER PREDICTION (RECURSIVE)
# ---------------------------------------------------
def predict_folder(model, folder_path: str):
    image_extensions = (".jpg", ".jpeg", ".png", ".tif", ".tiff")
    
    results = []  # list of (image_path, class, confidence)

    for root, _, files in os.walk(folder_path):
        for f in files:
            if f.lower().endswith(image_extensions):
                p = os.path.join(root, f)
                cls, prob = predict_image(model, p)
                results.append((p, cls, prob))

    return results

# ---------------------------------------------------
# OPTIONAL: SORT OUTPUTS INTO CLASS FOLDERS
# ---------------------------------------------------
def sort_predictions(results, output_root="sorted_predictions"):
    os.makedirs(output_root, exist_ok=True)

    for cls in CLASS_NAMES:
        os.makedirs(os.path.join(output_root, cls), exist_ok=True)

    for path, cls, prob in results:
        fname = os.path.basename(path)
        dst = os.path.join(output_root, cls, fname)
        Image.open(path).save(dst)

    print(f"Sorted outputs saved to: {output_root}")


# ---------------------------------------------------
# MAIN USAGE EXAMPLES
# ---------------------------------------------------
if __name__ == "__main__":
    print("Loading model...")
    model = load_model()

    # Example 1 — predict a single image
    # test_image = "/home/gaurav/scratch/interiit/GAURAV_BIG_DATA/SAR_BIG/MMRS_SAR/data/detection/SARV2/images/2__1__0___384.png"
    # if os.path.exists(test_image):
    #     cls, prob = predict_image(model, test_image)
    #     print(f"Image: {test_image}")
    #     print(f"Predicted: {cls} (conf {prob:.4f})")

    # Example 2 — predict a folder
    folder = "/home/gaurav/scratch/interiit/gaurav/dataset"
    if os.path.isdir(folder):
        results = predict_folder(model, folder)
        print("\nFolder Results (first 10):")
        for r in results[:]:
            print(r)

        # Optional: save sorted results
        # sort_predictions(results)


Loading model...


  state = torch.load(checkpoint_path, map_location=DEVICE)



Folder Results (first 10):
('/home/gaurav/scratch/interiit/gaurav/dataset/true_sample_10_512.png', 'rgb', 0.9999666213989258)
('/home/gaurav/scratch/interiit/gaurav/dataset/swir_sample_18_512.png', 'falsecolor', 0.5035607218742371)
('/home/gaurav/scratch/interiit/gaurav/dataset/i_sample_18_512.png', 'rgb', 0.5754244327545166)
('/home/gaurav/scratch/interiit/gaurav/dataset/swir_sample_16_512.png', 'falsecolor', 0.549708366394043)
('/home/gaurav/scratch/interiit/gaurav/dataset/true_sample_19_512.png', 'rgb', 0.9986940026283264)
('/home/gaurav/scratch/interiit/gaurav/dataset/urban_sample_10_512.png', 'rgb', 0.6027699112892151)
('/home/gaurav/scratch/interiit/gaurav/dataset/IR_ship_111.png', 'rgb', 0.5665589570999146)
('/home/gaurav/scratch/interiit/gaurav/dataset/urban_sample_19_512.png', 'falsecolor', 0.9960899353027344)
('/home/gaurav/scratch/interiit/gaurav/dataset/false_sample_19_512.png', 'falsecolor', 0.999323844909668)
('/home/gaurav/scratch/interiit/gaurav/dataset/urban_sample_16

In [10]:
def cleanup():
    import gc
    gc.collect()
    torch.cuda.empty_cache()

In [11]:
cleanup()