In [7]:
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from tqdm import tqdm

# ----------------------------------------------------
# CONFIG
# ----------------------------------------------------
# Position in this list is the integer label used throughout the notebook
# (0 = sar, 1 = rgb, 2 = falsecolor); it also sizes the classifier head.
CLASS_NAMES = [
    "sar",
    "rgb",
    "falsecolor",
]

# Evaluation data root (one "<class>_quads" sub-folder is expected per class)
# and the fine-tuned weights to evaluate.
ROOT_FOLDER = "/home/gaurav/scratch/interiit/EarthMind-Bench/img/test"
CHECKPOINT_PATH = "/home/gaurav/scratch/interiit/gaurav/checkpoint/best_model_3classes_450.pt"

BATCH_SIZE = 64
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# ----------------------------------------------------
# LOAD MODEL
# ----------------------------------------------------
def load_model(checkpoint):
    """Build the ResNet-50 classifier and load fine-tuned weights into it.

    The backbone is created without pretrained weights and its ``fc`` layer is
    replaced by ``Linear(in_features, 512) -> BatchNorm1d -> ReLU ->
    Dropout(0.3) -> Linear(512, len(CLASS_NAMES))``. ``load_state_dict`` is
    strict by default, so this head must match the one stored in the
    checkpoint.

    Args:
        checkpoint: Path (or file-like object) accepted by ``torch.load``.
            It must contain a plain state dict (a mapping of parameter name
            to tensor), not a pickled model object.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    model = models.resnet50(weights=None)
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.BatchNorm1d(512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, len(CLASS_NAMES))
    )

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load. It
    # also silences the FutureWarning torch emits when the flag is omitted.
    state = torch.load(checkpoint, map_location=DEVICE, weights_only=True)
    # Checkpoints saved from a torch.compile()-wrapped model carry an
    # "_orig_mod." marker in every key; drop it so the names match the
    # uncompiled module built above.
    state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
    model.load_state_dict(state)

    model.to(DEVICE)
    model.eval()
    return model


# ----------------------------------------------------
# DATASET (RECURSIVE IMAGE SEARCH)
# ----------------------------------------------------
class QuadFolderDataset(Dataset):
    """Image dataset labelled by the ``<class>_quads`` folder a file lives under.

    For every entry of ``CLASS_NAMES`` the folder ``<root_dir>/<class>_quads``
    is walked recursively; each ``.png`` / ``.jpg`` / ``.jpeg`` file (matched
    case-insensitively) becomes one sample whose label is the index of the
    class in ``CLASS_NAMES``.

    Attributes are set in ``__init__``:
        samples: list of ``(file_path, class_index)`` tuples, in a
            deterministic (sorted) traversal order.
        transform: torchvision pipeline applied to every image in
            ``__getitem__``.
    """

    def __init__(self, root_dir):
        """Index all images below ``root_dir``.

        Args:
            root_dir: Directory that contains one ``<class>_quads`` folder per
                entry of ``CLASS_NAMES``.

        Raises:
            RuntimeError: If any ``<class>_quads`` folder is missing.
        """
        # Local import so this class does not rely on a notebook-level import.
        import warnings

        self.samples = []
        image_ext = (".png", ".jpg", ".jpeg")

        for idx, cls in enumerate(CLASS_NAMES):
            base_folder = os.path.join(root_dir, cls + "_quads")
            if not os.path.isdir(base_folder):
                raise RuntimeError(f"Missing folder: {base_folder}")

            count_before = len(self.samples)

            # Recursive walk. os.walk yields entries in filesystem order, so
            # sort directories (in place, which steers the traversal) and file
            # names to make the sample order reproducible across machines.
            for root, dirs, files in os.walk(base_folder):
                dirs.sort()
                for fname in sorted(files):
                    if fname.lower().endswith(image_ext):
                        self.samples.append((
                            os.path.join(root, fname),
                            idx
                        ))

            # An existing but empty class folder would otherwise go unnoticed
            # and only show up later as a zero-support row in the metrics.
            if len(self.samples) == count_before:
                warnings.warn(
                    f"No images found for class '{cls}' under {base_folder}",
                    stacklevel=2,
                )

        print(f"Loaded {len(self.samples)} images across: {CLASS_NAMES}")

        # Resize to 256x256, convert to a float tensor, then normalise each
        # channel with the mean/std commonly used for ImageNet-pretrained
        # torchvision models.
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image_tensor, class_index, file_path)`` where ``image_tensor``
            is the transformed RGB image.
        """
        path, label = self.samples[idx]
        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the context manager exits.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        img = self.transform(img)
        return img, label, path


# ----------------------------------------------------
# LOAD DATA
# ----------------------------------------------------
dataset = QuadFolderDataset(ROOT_FOLDER)
# shuffle=False keeps predictions aligned with dataset.samples order.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)


# ----------------------------------------------------
# INFERENCE + METRICS WITH TQDM
# ----------------------------------------------------
model = load_model(CHECKPOINT_PATH)

# Per-image accumulators, all in the same order as the loader yields samples.
all_preds = []   # predicted class index (plain int)
all_labels = []  # ground-truth class index (plain int)
all_paths = []   # source file path
all_probs = []   # softmax probabilities, one numpy row per image

with torch.no_grad():
    for images, labels, paths in tqdm(loader, desc="Evaluating", ncols=120):
        images = images.to(DEVICE)

        outputs = model(images)
        probs = torch.softmax(outputs, dim=1)
        preds = torch.argmax(probs, dim=1)

        # The DataLoader collates the integer labels into a tensor; extending
        # a list with a tensor would store 0-d tensors. tolist() yields plain
        # Python ints for both labels and predictions so the two lists have a
        # consistent element type for scikit-learn and for `results`.
        all_preds.extend(preds.tolist())
        all_labels.extend(labels.tolist())
        all_paths.extend(paths)
        all_probs.extend(probs.cpu().numpy())


# ----------------------------------------------------
# METRICS
# ----------------------------------------------------
# Coerce to plain ints so the metrics work whether the accumulators hold
# Python ints, numpy scalars or 0-d tensors.
y_true = [int(v) for v in all_labels]
y_pred = [int(v) for v in all_preds]

# Pin the label set explicitly. Without it scikit-learn infers the classes
# from the data, so a class missing from both labels and predictions would
# shrink the confusion matrix and make classification_report raise because
# target_names no longer matches the inferred classes.
label_ids = list(range(len(CLASS_NAMES)))

accuracy = accuracy_score(y_true, y_pred)
cm = confusion_matrix(y_true, y_pred, labels=label_ids)
report = classification_report(
    y_true,
    y_pred,
    labels=label_ids,
    target_names=CLASS_NAMES,
    zero_division=0,  # report 0.0 for classes with no support instead of warning
)

print("\n==============================")
print("ðŸ“Œ FINAL TEST METRICS")
print("==============================\n")

print(f"Accuracy: {accuracy:.4f}\n")

print("Classification Report:")
print(report)

print("\nConfusion Matrix:")
print(cm)

# Optional: save results
# One (path, true_label, predicted_label, probabilities) tuple per image.
results = list(zip(all_paths, all_labels, all_preds, all_probs))


Loaded 4174 images across: ['sar', 'rgb', 'falsecolor']


  state = torch.load(checkpoint, map_location=DEVICE)
Evaluating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 66/66 [01:32<00:00,  1.40s/it]


ðŸ“Œ FINAL TEST METRICS

Accuracy: 0.8479

Classification Report:
              precision    recall  f1-score   support

         sar       1.00      0.77      0.87      2087
         rgb       0.82      0.92      0.87      2087
  falsecolor       0.00      0.00      0.00         0

    accuracy                           0.85      4174
   macro avg       0.61      0.57      0.58      4174
weighted avg       0.91      0.85      0.87      4174


Confusion Matrix:
[[1610  429   48]
 [   0 1929  158]
 [   0    0    0]]



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:
# import os
# import torch
# import torch.nn as nn
# from torchvision import models, transforms
# from torch.utils.data import Dataset, DataLoader
# from PIL import Image
# import numpy as np
# from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
# from tqdm import tqdm

# # ----------------------------------------------------
# # CONFIG
# # ----------------------------------------------------
# CLASS_NAMES = ["sar", "rgb", "falsecolor"]

# ROOT_FOLDER = "/home/gaurav/scratch/interiit/GAURAV_BIG_DATA/spacenet6"    # contains *_quads folders
# CHECKPOINT_PATH = "/content/drive/MyDrive/Inter IIT/best_model_3classes.pt"

# BATCH_SIZE = 64
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# # ----------------------------------------------------
# # MODEL: Must match training architecture
# # ----------------------------------------------------
# def load_model(checkpoint):
#     model = models.resnet50(weights=None)
#     in_features = model.fc.in_features
#     model.fc = nn.Sequential(
#         nn.Linear(in_features, 512),
#         nn.BatchNorm1d(512),
#         nn.ReLU(),
#         nn.Dropout(0.3),
#         nn.Linear(512, len(CLASS_NAMES))
#     )

#     state = torch.load(checkpoint, map_location=DEVICE)
#     state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
#     model.load_state_dict(state)

#     model.to(DEVICE)
#     model.eval()
#     return model


# # ----------------------------------------------------
# # DATASET
# # ----------------------------------------------------
# class QuadFolderDataset(Dataset):
#     def __init__(self, root_dir):
#         self.samples = []
#         image_ext = (".png", ".jpg", ".jpeg")

#         for idx, cls in enumerate(CLASS_NAMES):
#             folder = os.path.join(root_dir, cls + "_quads")
#             if not os.path.isdir(folder):
#                 raise RuntimeError(f"Missing folder: {folder}")

#             for fname in os.listdir(folder):
#                 if fname.lower().endswith(image_ext):
#                     self.samples.append((os.path.join(folder, fname), idx))

#         print(f"Loaded {len(self.samples)} images for evaluation.")

#         self.transform = transforms.Compose([
#             transforms.Resize((256, 256)),
#             transforms.ToTensor(),
#             transforms.Normalize(
#                 mean=[0.485, 0.456, 0.406],
#                 std=[0.229, 0.224, 0.225],
#             ),
#         ])

#     def __len__(self):
#         return len(self.samples)

#     def __getitem__(self, idx):
#         path, label = self.samples[idx]
#         img = Image.open(path).convert("RGB")
#         img = self.transform(img)
#         return img, label, path


# # ----------------------------------------------------
# # LOAD DATA
# # ----------------------------------------------------
# dataset = QuadFolderDataset(ROOT_FOLDER)
# loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


# # ----------------------------------------------------
# # INFERENCE + TQDM + METRICS
# # ----------------------------------------------------
# model = load_model(CHECKPOINT_PATH)

# all_preds = []
# all_labels = []
# all_paths = []
# all_probs = []

# print("\nRunning inference...")
# with torch.no_grad():
#     for images, labels, paths in tqdm(loader, desc="Evaluating", unit="batch"):
#         images = images.to(DEVICE)

#         outputs = model(images)
#         probs = torch.softmax(outputs, dim=1)
#         preds = torch.argmax(probs, dim=1)

#         all_preds.extend(preds.cpu().numpy())
#         all_labels.extend(labels.numpy())
#         all_paths.extend(paths)
#         all_probs.extend(probs.cpu().numpy())


# # ----------------------------------------------------
# # METRICS
# # ----------------------------------------------------
# accuracy = accuracy_score(all_labels, all_preds)
# cm = confusion_matrix(all_labels, all_preds)
# report = classification_report(all_labels, all_preds, target_names=CLASS_NAMES)

# print("\n==============================")
# print("ðŸ“Œ FINAL TEST METRICS")
# print("==============================")

# print(f"\nAccuracy: {accuracy:.4f}\n")

# print("Classification Report:")
# print(report)

# print("Confusion Matrix:")
# print(cm)

# # Structured results if needed
# results = list(zip(all_paths, all_labels, all_preds, all_probs))


In [9]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os

# ----------------------------------------------------
# CONFIG
# ----------------------------------------------------
# Index in this list is the class id produced by the classifier head.
CLASS_NAMES = [
    "sar",
    "rgb",
    "falsecolor",
]
# Weights used for single-image prediction in this cell.
CHECKPOINT_PATH = "/home/gaurav/scratch/interiit/gaurav/checkpoint/best_model_3classes_450_all_data.pt"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# ----------------------------------------------------
# LOAD MODEL
# ----------------------------------------------------
def load_model(checkpoint_path):
    """Build the ResNet-50 classifier and load fine-tuned weights into it.

    The backbone is created without pretrained weights and its ``fc`` layer is
    replaced by ``Linear(in_features, 512) -> BatchNorm1d -> ReLU ->
    Dropout(0.3) -> Linear(512, len(CLASS_NAMES))``. ``load_state_dict`` is
    strict by default, so this head must match the one stored in the
    checkpoint.

    Args:
        checkpoint_path: Path (or file-like object) accepted by
            ``torch.load``. It must contain a plain state dict (a mapping of
            parameter name to tensor), not a pickled model object.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    model = models.resnet50(weights=None)
    in_features = model.fc.in_features

    model.fc = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.BatchNorm1d(512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, len(CLASS_NAMES)),
    )

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load. It
    # also silences the FutureWarning torch emits when the flag is omitted.
    state = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
    # Checkpoints saved from a torch.compile()-wrapped model carry an
    # "_orig_mod." marker in every key; drop it so the names match the
    # uncompiled module built above.
    state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
    model.load_state_dict(state)

    model.to(DEVICE)
    model.eval()
    return model


# ----------------------------------------------------
# PREPROCESSING PIPELINE
# ----------------------------------------------------
# Same steps as the evaluation dataset's transform: fixed 256x256 resize,
# conversion to a float tensor, then per-channel normalisation.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


# ----------------------------------------------------
# PREDICT ONE IMAGE
# ----------------------------------------------------
def predict_image(image_path, model):
    """Classify a single image file.

    Args:
        image_path: Path to the image on disk.
        model: Network that maps a ``(1, 3, 256, 256)`` batch to one logit per
            entry of ``CLASS_NAMES``. It is called as-is under
            ``torch.no_grad()``; the caller is expected to have put it in
            eval mode (``load_model`` does).

    Returns:
        ``(predicted_class, confidence, probs)``: the winning class name, its
        softmax probability as a float, and the full 1-D numpy array of
        per-class probabilities.

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing regular file.
    """
    # isfile (not exists) so that a directory path is rejected here with a
    # clear message instead of failing inside Image.open.
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    # convert() returns a fully loaded copy, so the file handle can be
    # released as soon as the context manager exits.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")
    # unsqueeze(0) adds the batch dimension the model expects.
    img_tensor = preprocess(img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output = model(img_tensor)
        probs = torch.softmax(output, dim=1).cpu().numpy().flatten()

    predicted_idx = int(probs.argmax())
    predicted_class = CLASS_NAMES[predicted_idx]
    confidence = float(probs[predicted_idx])

    return predicted_class, confidence, probs


# ----------------------------------------------------
# EXAMPLE USAGE
# ----------------------------------------------------
if __name__ == "__main__":
    # Demo: load the checkpoint once, classify one image, print a summary.
    model = load_model(CHECKPOINT_PATH)

    # Image to classify -- update this path as needed.
    img_path = "/home/gaurav/scratch/interiit/GAURAV_BIG_DATA/SAR_BIG/MMRS_SAR/data/detection/SSDD/images/000006.jpg"

    pred_class, conf, probs = predict_image(img_path, model)

    # One print call; sep="\n" puts every item on its own line.
    print(
        "\n==============================",
        "ðŸ“Œ SINGLE IMAGE PREDICTION",
        "==============================",
        f"Image: {img_path}",
        f"Predicted class:  {pred_class}",
        f"Confidence:       {conf:.4f}",
        f"All probabilities: {probs}",
        sep="\n",
    )


  state = torch.load(checkpoint_path, map_location=DEVICE)



ðŸ“Œ SINGLE IMAGE PREDICTION
Image: /home/gaurav/scratch/interiit/GAURAV_BIG_DATA/SAR_BIG/MMRS_SAR/data/detection/SSDD/images/000006.jpg
Predicted class:  sar
Confidence:       0.9987
All probabilities: [9.9866223e-01 9.4086351e-04 3.9688803e-04]
