# Spatio-Temporal Adaptive Fusion Transformer (STAFT) - Variant B Training

This notebook implements the training pipeline for **Variant B** (SK-ResNeXt-50 encoder + U-Net++ decoder with deep supervision) for land-cover classification, i.e. per-pixel semantic segmentation of 4-channel imagery into 24 classes.

## 1. Setup Environment

In [None]:
import os

from google.colab import drive

# Attach Google Drive so the dataset and checkpoints survive Colab restarts.
drive.mount('/content/drive')

# Make sure the checkpoint folder exists; no error if it is already there.
os.makedirs(
    '/content/drive/MyDrive/STAFT_Project/checkpoints',
    exist_ok=True,
)

In [None]:
# Install necessary libraries
# The leading "!" runs this line as a shell command from the notebook cell.
# timm provides the backbone implementations used by segmentation-models-pytorch;
# albumentations is not imported in the cells shown here — presumably intended
# for the optional `transform` of the dataset (TODO confirm).
!pip install segmentation-models-pytorch albumentations timm

## 2. Define Model (Variant B)

In [None]:
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp

class VariantB(nn.Module):
    """
    Variant B - SK-ResNeXt-50 encoder + dense (U-Net++) decoder.

    Thin wrapper around ``smp.UnetPlusPlus`` with ImageNet-pretrained weights.

    Args:
        in_channels: Number of input image channels.
        classes: Number of output segmentation classes.
        deep_supervision: Kept for API compatibility. ``smp.UnetPlusPlus``
            does not accept a ``deep_supervision`` argument and always returns
            a single logits tensor, so this flag is only stored on the instance
            (``self.deep_supervision``); a warning is emitted when it is True.
    """

    def __init__(self, in_channels=4, classes=24, deep_supervision=True):
        super().__init__()
        import warnings

        self.deep_supervision = deep_supervision
        if deep_supervision:
            warnings.warn(
                "smp.UnetPlusPlus has no deep-supervision heads; VariantB returns "
                "a single logits tensor and the deep_supervision flag is ignored.",
                stacklevel=2,
            )

        # U-Net++ with SK-ResNeXt-50 encoder. smp registers this backbone under
        # the "timm-" prefix; the bare timm model name is not a valid smp encoder key.
        self.model = smp.UnetPlusPlus(
            encoder_name="timm-skresnext50_32x4d",
            encoder_weights="imagenet",
            in_channels=in_channels,
            classes=classes,
            decoder_use_batchnorm=True,
        )

    def forward(self, x):
        """Return the per-pixel class logits produced by the wrapped U-Net++."""
        return self.model(x)

class DeepSupervisionLoss(nn.Module):
    """Cross-entropy + multiclass Dice loss, averaged over deep-supervision heads.

    Args:
        weights: Optional per-class weight tensor forwarded to
            ``nn.CrossEntropyLoss(weight=...)``.
        ignore_index: Optional label value excluded from both loss terms.
            ``None`` maps to CrossEntropyLoss's default of -100 and disables
            ignoring in the Dice term.
    """

    def __init__(self, weights=None, ignore_index=None):
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss(
            weight=weights,
            ignore_index=ignore_index if ignore_index is not None else -100,
        )
        self.dice_loss = smp.losses.DiceLoss(mode='multiclass', ignore_index=ignore_index)

    def _single_loss(self, logits, target):
        """Return CE + Dice for one prediction tensor."""
        return self.ce_loss(logits, target) + self.dice_loss(logits, target)

    def forward(self, outputs, target):
        """Compute the loss for one prediction or a list/tuple of predictions.

        For a list/tuple, the per-head losses are averaged with equal weight;
        every head is compared against the same ``target`` (so each head is
        assumed to share the target's spatial size — TODO confirm for new models).

        Raises:
            ValueError: if ``outputs`` is an empty list/tuple.
        """
        if isinstance(outputs, (list, tuple)):
            if not outputs:
                raise ValueError("DeepSupervisionLoss received an empty list of outputs")
            return sum(self._single_loss(o, target) for o in outputs) / len(outputs)
        return self._single_loss(outputs, target)

## 3. Dataset Class

In [None]:
from torch.utils.data import Dataset
import numpy as np
import glob
import cv2

class LandCoverDataset(Dataset):
    """Land-cover segmentation dataset (currently serving random dummy samples).

    Args:
        root_dir: Dataset root; only used by the commented-out real loader.
        split: Sub-folder name such as 'train'; only used by the real loader.
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            'image' and 'mask' entries.

    Each item is ``(image, mask)``: a float32 ``(C, H, W)`` tensor and an int64
    mask tensor of class indices, as ``nn.CrossEntropyLoss`` requires.
    """

    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform

        # TODO: Update these paths to match your folder structure on Drive
        # Example: /content/drive/MyDrive/STAFT_Project/data/train/images/*.tif
        # self.image_paths = sorted(glob.glob(os.path.join(root_dir, split, 'images', '*.tif')))
        # self.mask_paths = sorted(glob.glob(os.path.join(root_dir, split, 'masks', '*.tif')))

        # --- DUMMY DATA FOR TESTING (Remove when real data is ready) ---
        self.image_paths = ["dummy"] * 100
        self.mask_paths = ["dummy"] * 100
        # ----------------------------------------------------------------

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx`` as torch tensors."""
        # --- REAL LOADING LOGIC (Uncomment this) ---
        # img_path = self.image_paths[idx]
        # mask_path = self.mask_paths[idx]
        # image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # 4 channels
        # mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # NOTE(review): cv2 returns raw integer arrays (not normalized) and
        # OpenCV orders colour channels as BGR(A) — confirm the band order and
        # add scaling (e.g. in `transform`) before training on real data.

        # --- DUMMY LOGIC (Remove this) ---
        image = np.random.rand(256, 256, 4).astype(np.float32)
        mask = np.random.randint(0, 24, (256, 256)).astype(np.uint8)
        # ---------------------------------

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        if not isinstance(image, torch.Tensor):
            # HWC numpy -> CHW tensor. ascontiguousarray guards against
            # negative-stride views (e.g. from flips), which from_numpy rejects.
            image = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1)
        if not isinstance(mask, torch.Tensor):
            mask = torch.from_numpy(np.ascontiguousarray(mask))

        # Cast unconditionally: tensors coming out of a transform (e.g.
        # albumentations' ToTensorV2) keep their source dtype, such as uint8
        # masks, but the model needs float inputs and CrossEntropyLoss needs
        # int64 targets.
        return image.float(), mask.long()

## 4. Training Loop

In [None]:
import time
from torch.utils.data import DataLoader, random_split
from tqdm.notebook import tqdm
import torch.optim as optim

# Configuration
DATA_DIR = '/content/drive/MyDrive/STAFT_Project/data'
CHECKPOINT_DIR = '/content/drive/MyDrive/STAFT_Project/checkpoints'
BATCH_SIZE = 8
EPOCHS = 20
LR = 5e-4
SEED = 42  # fixes the train/val split so it is identical across notebook reruns
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {DEVICE}")

# Prepare Data
dataset = LandCoverDataset(root_dir=DATA_DIR, split='train')
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# A seeded generator keeps the same samples in the validation set when the
# notebook is re-run (e.g. after a Colab disconnect); with an unseeded split a
# reloaded checkpoint would be validated on images it was trained on.
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)
)

# pin_memory speeds up host-to-GPU copies and is only useful on CUDA.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
    pin_memory=DEVICE.type == 'cuda',
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2,
    pin_memory=DEVICE.type == 'cuda',
)

# Initialize Model
model = VariantB(in_channels=4, classes=24, deep_supervision=True).to(DEVICE)
# Moving the criterion keeps any class-weight tensor on the same device as the logits.
criterion = DeepSupervisionLoss().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR)
# `verbose` is deprecated for LR schedulers since PyTorch 2.2; the current LR
# can be read from optimizer.param_groups[0]['lr'] instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# Training Function
def train_epoch(model, loader):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Uses the notebook-level ``optimizer``, ``criterion`` and ``DEVICE``.

    Raises:
        ValueError: if ``loader`` yields no batches (e.g. an empty training split).
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    pbar = tqdm(loader, desc="Training", leave=False)

    for images, masks in pbar:
        images, masks = images.to(DEVICE), masks.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
        # Use our own counter: tqdm only syncs pbar.n on display refreshes,
        # so averaging by pbar.n + 1 can use a stale denominator.
        pbar.set_postfix({'loss': running_loss / num_batches})

    if num_batches == 0:
        raise ValueError("train_epoch() received an empty loader; check the train/val split sizes.")
    return running_loss / num_batches

def validate(model, loader):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Uses the notebook-level ``criterion`` and ``DEVICE``.

    Returns:
        ``(mean batch loss, pixel accuracy)``. Accuracy is counted over every
        pixel of the final prediction; no label is ignored.

    Raises:
        ValueError: if ``loader`` yields no batches (e.g. an empty validation split).
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, masks in tqdm(loader, desc="Validating", leave=False):
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            outputs = model(images)
            running_loss += criterion(outputs, masks).item()
            num_batches += 1

            # For a list/tuple of deep-supervision outputs, index 0 is assumed
            # to be the final head — TODO confirm for the model in use.
            final = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
            preds = final.argmax(dim=1)
            correct += (preds == masks).sum().item()
            total += preds.numel()

    if num_batches == 0:
        raise ValueError("validate() received an empty loader; check the train/val split sizes.")
    return running_loss / num_batches, correct / total

# Main Loop
best_loss = float('inf')
best_ckpt_path = os.path.join(CHECKPOINT_DIR, 'best_model_variant_b.pth')
last_ckpt_path = os.path.join(CHECKPOINT_DIR, 'last_model_variant_b.pth')

print("Starting Training...")
for epoch in range(EPOCHS):
    train_loss = train_epoch(model, train_loader)
    val_loss, val_acc = validate(model, val_loader)
    # LR this epoch trained with, read before the scheduler may lower it; logged
    # so ReduceLROnPlateau reductions are visible without the scheduler's verbose flag.
    current_lr = optimizer.param_groups[0]['lr']

    print(
        f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | "
        f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | LR: {current_lr:.2e}"
    )

    # Plateau scheduler: lowers the LR when the validation loss stops improving.
    scheduler.step(val_loss)

    # Keep the weights with the lowest validation loss seen so far.
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), best_ckpt_path)
        print("Saved Best Model!")

    # Always overwrite the most recent weights so an interrupted run keeps its progress.
    torch.save(model.state_dict(), last_ckpt_path)