In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from PIL import Image
import shutil

# -------------------------------
# CONFIG
# -------------------------------
CLASS_NAMES = ["false_color", "ndvi", "SARV2_resized", "swir", "true_color", "urban"]
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_PATH = "/home/gaurav/scratch/interiit/gaurav/checkpoint/current_training_model6.pt"
ROOT_DATASET = "/home/gaurav/scratch/interiit/GAURAV_BIG_DATA/spacenet6"        # <-- SET THIS
SORTED_OUTPUT_FOLDER = "predicted_sorted"      # optional


# -------------------------------------------
# LOAD MODEL (same architecture as training)
# -------------------------------------------
model = models.resnet50(weights=None)
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, len(CLASS_NAMES))
)

state = torch.load(MODEL_PATH, map_location=DEVICE)
state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
model.load_state_dict(state)

if DEVICE.type == "cuda":
    model = model.half()

model.to(DEVICE)
model.eval()


# -------------------------------------------
# PREPROCESSING PIPELINE
# -------------------------------------------
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])


# -------------------------------------------
# 1. Predict a SINGLE image
# -------------------------------------------
def predict_image(image_path):
    img = Image.open(image_path)
    if img.mode != "RGB":
        img = img.convert("RGB")

    tensor = transform(img).unsqueeze(0).to(DEVICE)
    if DEVICE.type == "cuda":
        tensor = tensor.half()

    with torch.no_grad():
        outputs = model(tensor)
        probs = torch.softmax(outputs, dim=1)[0]

    top_idx = probs.argmax().item()
    top_class = CLASS_NAMES[top_idx]
    top_prob = float(probs[top_idx])

    return top_class, top_prob


# -------------------------------------------
# 2. Predict an ENTIRE folder recursively
# -------------------------------------------
def predict_folder(root_dir):
    predictions = []  # (path, predicted_class, prob)

    for r, d, files in os.walk(root_dir):
        for f in files:
            if f.lower().endswith((".png", ".jpg", ".jpeg")):
                full_path = os.path.join(r, f)
                cls, pr = predict_image(full_path)
                predictions.append((full_path, cls, pr))

    return predictions


# -------------------------------------------
# 3. Count predictions per class
# -------------------------------------------
def count_predictions(pred_list):
    counts = {cls: 0 for cls in CLASS_NAMES}
    for _, cls, _ in pred_list:
        counts[cls] += 1
    return counts


# -------------------------------------------
# 4. OPTIONAL: Save images sorted by prediction
# -------------------------------------------
def save_by_prediction(pred_list, out_root=SORTED_OUTPUT_FOLDER):
    os.makedirs(out_root, exist_ok=True)

    for cls in CLASS_NAMES:
        os.makedirs(os.path.join(out_root, cls), exist_ok=True)

    for path, cls, _ in pred_list:
        fname = os.path.basename(path)
        dest = os.path.join(out_root, cls, fname)
        shutil.copy(path, dest)


# -------------------------------------------
# RUN EVERYTHING
# -------------------------------------------
if __name__ == "__main__":
    print("\nRunning predictions on folder:", ROOT_DATASET)

    preds = predict_folder(ROOT_DATASET)

    print("\nPrediction Results (first 10):")
    for p in preds[:10]:
        print(p)

    # Count classes
    counts = count_predictions(preds)
    print("\nPrediction counts:")
    for cls, n in counts.items():
        print(f"{cls}: {n}")

    print(f"\nTotal images processed: {len(preds)}")

    # Sort into folders (optional)
    # save_by_prediction(preds)
    # print(f"\nSorted images saved to: {SORTED_OUTPUT_FOLDER}")


  state = torch.load(MODEL_PATH, map_location=DEVICE)



Running predictions on folder: /home/gaurav/scratch/interiit/GAURAV_BIG_DATA/spacenet6

Prediction Results (first 10):
('/home/gaurav/scratch/interiit/GAURAV_BIG_DATA/spacenet6/falsecolor_quads/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_20190822083256_20190822083600_tile_10927_BL.png', 'false_color', 0.4326171875)
('/home/gaurav/scratch/interiit/GAURAV_BIG_DATA/spacenet6/falsecolor_quads/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_20190804125251_20190804125541_tile_6883_BR.png', 'false_color', 0.60400390625)
('/home/gaurav/scratch/interiit/GAURAV_BIG_DATA/spacenet6/falsecolor_quads/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_20190823094036_20190823094408_tile_10253_BL.png', 'false_color', 0.8994140625)
('/home/gaurav/scratch/interiit/GAURAV_BIG_DATA/spacenet6/falsecolor_quads/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_20190822083256_20190822083600_tile_10923_BR.png', 'false_color', 0.62841796875)
('/home/gaurav/scratch/interiit/GAURAV_BIG_DATA/spacenet6/falsecolor_quads/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_201