In [26]:
CONFIG = {
    "dataset_root": "/home/gaurav/Desktop/GeoAI/small_geo/check_data", 
    # under this path expect folders:
    #   sar_quads/
    #   rgb_quads/
    #   falsecolor_quads/

    "classes": ["sar", "rgb", "falsecolor"],

    "epochs": 2,
    "batch_size": 64,
    "learning_rate": 1e-4,
    "weight_decay": 1e-5,

    "use_augmentation": True,

    # checkpoints
    "load_checkpoint": False,
    "checkpoint_load_path": "",
    "checkpoint_save_path": "/home/gaurav/scratch/interiit/gaurav/checkpoint/best_model_3classes_450_all_data.pt",
}


In [27]:
import gc
import torch

def cleanup():
    gc.collect()
    torch.cuda.empty_cache()

In [28]:
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset

class ThreeClassDataset(Dataset):
    def __init__(self, root, classes, transform=None):
        self.root = root
        self.classes = classes
        self.transform = transform

        self.samples = []   # list of (filepath, label)
        image_ext = (".png", ".jpg", ".jpeg", ".tif", ".tiff")

        print("\nScanning folders recursively...\n")

        for label, cls in enumerate(classes):
            folder = os.path.join(root, cls + "_quads")
            if not os.path.isdir(folder):
                raise ValueError(f"Folder missing: {folder}")

            count = 0

            # recursively walk inside this class folder
            for r, d, files in os.walk(folder):
                for f in files:
                    if f.lower().endswith(image_ext):
                        full_path = os.path.join(r, f)
                        self.samples.append((full_path, label))
                        count += 1

            print(f"-> {cls}: {count} images")

        print(f"\nTotal images loaded: {len(self.samples)}")
        print(f"Classes: {classes}\n")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert("RGB")

        if self.transform:
            img = self.transform(img)

        return img, torch.tensor(label, dtype=torch.long)


In [29]:
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

test_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])


In [30]:
from torch.utils.data import random_split, DataLoader

full_ds = ThreeClassDataset(
    root=CONFIG["dataset_root"],
    classes=CONFIG["classes"],
    transform=train_transform
)


val_size = int(0.2 * len(full_ds))
test_size = int(0.1 * len(full_ds))
train_size = len(full_ds) - val_size - test_size

train_ds, val_ds, test_ds = random_split(
    full_ds, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

# For val/test, override their transforms
val_ds.dataset.transform = test_transform
test_ds.dataset.transform = test_transform

train_loader = DataLoader(train_ds, batch_size=CONFIG["batch_size"], shuffle=True, num_workers=2)
val_loader   = DataLoader(val_ds, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=2)
test_loader  = DataLoader(test_ds, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=2)



Scanning folders recursively...

-> sar: 10322 images
-> rgb: 11802 images
-> falsecolor: 11802 images

Total images loaded: 33926
Classes: ['sar', 'rgb', 'falsecolor']



In [None]:
# print(next(iter(train_loader)))

# # input_data, batch_labels = next(iter(train_loader))

# # print(input_data.shape)
# # print(batch_labels.shape)

# # Grab the raw output batch
# raw_batch = next(iter(train_loader))

# # Check how many items are in the batch
# print(f"Number of items in the batch: {len(raw_batch)}")

# # Unpack into three variables
# input_data, batch_labels = next(iter(train_loader))
# # OR
# input_data, batch_labels = next(iter(train_loader))

# print(input_data.shape)
# print(batch_labels.shape)


[tensor([[[[ 0.2967,  0.2967,  0.2967,  ...,  0.7248,  0.7077,  0.7077],
          [ 0.2796,  0.2967,  0.2967,  ...,  0.7077,  0.7077,  0.6906],
          [ 0.2624,  0.2796,  0.3138,  ...,  0.7419,  0.7419,  0.7077],
          ...,
          [ 0.9303,  0.9646,  0.9988,  ..., -0.1486, -0.1486, -0.1486],
          [ 1.0673,  1.0844,  1.1358,  ..., -0.1657, -0.1657, -0.1486],
          [ 1.1358,  1.1529,  1.1872,  ..., -0.0972, -0.0972, -0.0801]],

         [[ 0.3803,  0.3803,  0.3803,  ...,  0.7654,  0.7829,  0.7829],
          [ 0.3627,  0.3803,  0.3803,  ...,  0.7479,  0.7654,  0.7654],
          [ 0.3277,  0.3452,  0.3803,  ...,  0.7304,  0.7304,  0.7479],
          ...,
          [ 0.8004,  0.7829,  0.7654,  ..., -0.0574, -0.0399, -0.0399],
          [ 0.7479,  0.7654,  0.7654,  ..., -0.0749, -0.0749, -0.0399],
          [ 0.7479,  0.7479,  0.7304,  ..., -0.0049, -0.0049,  0.0301]],

         [[ 0.5136,  0.5136,  0.5136,  ...,  0.9842,  0.9494,  0.9494],
          [ 0.4962,  0.5136, 

In [32]:
import torch.nn as nn
from torchvision import models

def build_model(num_classes):
    model = models.resnet50(weights="IMAGENET1K_V2")
    in_features = model.fc.in_features

    model.fc = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.BatchNorm1d(512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, num_classes)
    )
    return model


In [33]:
import torch
import torch.optim as optim
from torch.amp import autocast, GradScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_model(num_classes=len(CONFIG["classes"])).to(device)

optimizer = optim.Adam(model.parameters(),
                       lr=CONFIG["learning_rate"],
                       weight_decay=CONFIG["weight_decay"])

criterion = nn.CrossEntropyLoss()
scaler = GradScaler()
best_val_acc = 0.0


In [34]:
from tqdm import tqdm

for epoch in range(CONFIG["epochs"]):
    model.train()
    train_correct, train_total, train_loss = 0, 0, 0

    # ---- TRAIN LOOP WITH TQDM ----
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']} - Train", leave=False)

    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()

        with autocast("cuda"):
            outputs = model(imgs)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item() * imgs.size(0)
        train_correct += (outputs.argmax(1) == labels).sum().item()
        train_total += len(labels)

        # live stats in progress bar
        pbar.set_postfix({
            "loss": f"{train_loss/train_total:.4f}",
            "acc": f"{train_correct/train_total:.4f}"
        })

    # ---- VALIDATION LOOP WITH TQDM ----
    model.eval()
    val_correct, val_total, val_loss = 0, 0, 0

    with torch.no_grad():
        pbar_val = tqdm(val_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']} - Val", leave=False)

        for imgs, labels in pbar_val:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * imgs.size(0)
            val_correct += (outputs.argmax(1) == labels).sum().item()
            val_total += len(labels)

            pbar_val.set_postfix({
                "loss": f"{val_loss/val_total:.4f}",
                "acc": f"{val_correct/val_total:.4f}"
            })

    # ---- Epoch Summary ----
    train_acc = train_correct / train_total
    val_acc   = val_correct / val_total

    print(f"Epoch {epoch+1}/{CONFIG['epochs']} | "
          f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), CONFIG["checkpoint_save_path"])
        print("  Saved best model!")


                                                                                             

Epoch 1/2 | Train Acc: 0.9694 | Val Acc: 0.9857
  Saved best model!


                                                                                             

Epoch 2/2 | Train Acc: 0.9868 | Val Acc: 0.9857




In [35]:
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

model.load_state_dict(torch.load(CONFIG["checkpoint_save_path"]))
model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for imgs, labels in tqdm(test_loader, desc="Testing"):
        imgs = imgs.to(device)
        outputs = model(imgs)
        preds = outputs.argmax(1).cpu().numpy()
        all_preds.append(preds)
        all_labels.append(labels.numpy())

y_pred = np.concatenate(all_preds)
y_true = np.concatenate(all_labels)

print("\nTest Accuracy:", (y_pred == y_true).mean())
print("\nClassification Report:")
print(classification_report(y_true, y_pred, target_names=CONFIG["classes"]))

cm = confusion_matrix(y_true, y_pred)
print("\nConfusion Matrix:\n", cm)


  model.load_state_dict(torch.load(CONFIG["checkpoint_save_path"]))
Testing: 100%|██████████| 53/53 [00:05<00:00,  9.66it/s]



Test Accuracy: 0.9823113207547169

Classification Report:
              precision    recall  f1-score   support

         sar       1.00      1.00      1.00      1040
         rgb       0.97      0.98      0.97      1176
  falsecolor       0.98      0.97      0.97      1176

    accuracy                           0.98      3392
   macro avg       0.98      0.98      0.98      3392
weighted avg       0.98      0.98      0.98      3392


Confusion Matrix:
 [[1039    1    0]
 [   0 1156   20]
 [   0   39 1137]]


In [None]:
# cleanup()