In [None]:
import os
import sys
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

sys.path.append('..')
from models.unet_model import UNet

In [None]:
# Paths. SAR inputs and RGB-encoded ground-truth masks are paired by file name
# (the dataset joins the same name onto both directories).
DATA_DIR = '../data_lcc'
SAR_DIR = os.path.join(DATA_DIR, 'sar')
MASK_DIR = os.path.join(DATA_DIR, 'ground_truth')
WEIGHTS_DIR = '../weights'
PREDICTIONS_DIR = '../predictions_lcc/unet'

os.makedirs(WEIGHTS_DIR, exist_ok=True)
os.makedirs(PREDICTIONS_DIR, exist_ok=True)

# Seed every RNG that influences training (weight init, DataLoader shuffling)
# so that Restart & Run All reproduces the same model.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 8
EPOCHS = 25
LEARNING_RATE = 1e-4
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
N_CHANNELS = 3

# RGB color in the ground-truth PNGs -> original land-cover label.
COLOR_MAP = {
    (65, 155, 223): 1,   # 0x419bdf -> Water
    (57, 125, 73): 2,    # 0x397d49 -> Trees
    (122, 135, 198): 4,  # 0x7a87c6 -> Flooded Vegetation
    (228, 150, 53): 5,   # 0xe49635 -> Crops
    (196, 40, 27): 7,    # 0xc4281b -> Built Area
    (165, 155, 143): 8,  # 0xa59b8f -> Bare Ground
    (168, 235, 255): 9,  # 0xa8ebff -> Snow/Ice
    (97, 97, 97): 10,    # 0x616161 -> Clouds
    (227, 226, 195): 11, # 0xe3e2c3 -> Rangeland
}

# Derive the label list from COLOR_MAP so the two cannot drift apart.
assert len(set(COLOR_MAP.values())) == len(COLOR_MAP), "COLOR_MAP labels must be unique"
CLASS_LABELS = sorted(COLOR_MAP.values())
# Original (non-contiguous) label -> contiguous class index used as the training target.
LABEL_TO_INDEX = {label: i for i, label in enumerate(CLASS_LABELS)}
# NOTE: despite the name, keys are original labels (1, 2, 4, ...), not class indices.
INDEX_TO_COLOR = {v: k for k, v in COLOR_MAP.items()}

print(f"Using device: {DEVICE}")
print(f"Number of classes: {len(CLASS_LABELS)}")

In [None]:
class LandCoverDataset(Dataset):
    """Pairs SAR images with RGB-encoded land-cover masks.

    Each item is ``(image, mask)``:
    - ``image`` is the SAR file opened as RGB and passed through ``transform``
      if one is given.
    - ``mask`` is an int64 tensor of class indices (positions in
      ``CLASS_LABELS``), decoded via ``COLOR_MAP``.

    Args:
        image_dir: Directory containing the SAR images.
        mask_dir: Directory containing the ground-truth masks. A mask must have
            the same file name as its image.
        image_ids: File names (relative to both directories) in this split.
        transform: Optional callable applied to the PIL image only.
        unlabeled_index: Class index written for mask pixels whose color is not
            in ``COLOR_MAP``. The default 0 is the index of the first entry in
            ``CLASS_LABELS`` (Water), so such pixels are trained as that class.
            Pass -100 (the default ``ignore_index`` of ``nn.CrossEntropyLoss``)
            to exclude them from the loss instead.
    """

    def __init__(self, image_dir, mask_dir, image_ids, transform=None, unlabeled_index=0):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_ids = image_ids
        self.transform = transform
        self.unlabeled_index = unlabeled_index

    def __len__(self):
        return len(self.image_ids)

    def _rgb_to_mask(self, rgb_mask):
        """Convert an (H, W, 3) RGB array into an (H, W) int64 tensor of class indices."""
        # Start from the "unknown color" value so unmatched pixels are explicit.
        mask = np.full(rgb_mask.shape[:2], self.unlabeled_index, dtype=np.int64)
        for color, label in COLOR_MAP.items():
            mask[np.all(rgb_mask == color, axis=-1)] = LABEL_TO_INDEX[label]
        return torch.from_numpy(mask)

    def __getitem__(self, idx):
        img_name = self.image_ids[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name)

        # Context managers close the file handles deterministically instead of
        # leaving it to garbage collection (relevant with parallel DataLoader workers).
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask_rgb = np.array(mask_file.convert("RGB"))

        mask = self._rgb_to_mask(mask_rgb)

        if self.transform:
            image = self.transform(image)

        return image, mask

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

all_files = sorted([f for f in os.listdir(SAR_DIR) if f.endswith('.png')])

# Fail fast if an image has no ground-truth mask, rather than crashing inside a
# DataLoader worker partway through an epoch.
missing_masks = [f for f in all_files if not os.path.isfile(os.path.join(MASK_DIR, f))]
if missing_masks:
    raise FileNotFoundError(
        f"{len(missing_masks)} SAR image(s) have no mask in {MASK_DIR}, e.g. {missing_masks[:5]}"
    )

# A dedicated RNG seeded with 42 yields the same permutation as
# random.seed(42) + random.shuffle, without resetting the global `random` state.
random.Random(42).shuffle(all_files)

n_files = len(all_files)
train_split = int(n_files * 0.8)
val_split = int(n_files * 0.9)

train_ids = all_files[:train_split]
val_ids = all_files[train_split:val_split]
test_ids = all_files[val_split:]

# Loss/accuracy averaging divides by batch and pixel counts, so an empty split
# would otherwise surface later as a ZeroDivisionError.
for split_label, split_ids in (("train", train_ids), ("validation", val_ids), ("test", test_ids)):
    if not split_ids:
        raise ValueError(f"The {split_label} split is empty ({n_files} PNG files found in {SAR_DIR}).")

train_dataset = LandCoverDataset(SAR_DIR, MASK_DIR, train_ids, transform=transform)
val_dataset = LandCoverDataset(SAR_DIR, MASK_DIR, val_ids, transform=transform)
test_dataset = LandCoverDataset(SAR_DIR, MASK_DIR, test_ids, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Total images: {n_files}")
print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")
print(f"Test set size: {len(test_dataset)}")

In [None]:
def train_fn(loader, model, optimizer, loss_fn):
    """Run one training epoch and return the mean of the per-batch losses.

    Args:
        loader: Iterable of ``(data, targets)`` batches; must support ``len()``.
        model: Module being optimised; switched to train mode here.
        optimizer: Optimizer that steps ``model``'s parameters.
        loss_fn: Criterion called as ``loss_fn(predictions, targets)``.

    Returns:
        ``sum(batch losses) / len(loader)``.
    """
    progress = tqdm(loader, desc="Training")
    model.train()
    running_loss = 0

    for inputs, labels in progress:
        inputs = inputs.to(device=DEVICE)
        labels = labels.to(device=DEVICE, dtype=torch.long)

        # Clearing gradients before the forward pass is equivalent to clearing
        # them right before backward(): the forward pass does not touch .grad.
        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        # Copy the scalar to the host once and reuse it.
        loss_value = batch_loss.item()
        running_loss += loss_value
        progress.set_postfix(loss=loss_value)

    return running_loss / len(loader)


def evaluate_model(loader, model, loss_fn, split_name="Validation"):
    """Compute mean loss and pixel accuracy of ``model`` over ``loader``.

    The model is put in eval mode and evaluated without gradient tracking.

    Args:
        loader: Iterable of ``(data, targets)`` batches (e.g. a DataLoader).
        model: Segmentation model; predictions are the argmax over dim 1 of its output.
        loss_fn: Criterion called as ``loss_fn(predictions, targets)``.
        split_name: Label used in the printed summary (e.g. "Validation", "Test").

    Returns:
        ``(avg_loss, accuracy)`` as Python floats: the mean of the per-batch losses
        and the percentage of pixels whose prediction equals the target.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    num_correct = 0
    num_pixels = 0
    total_loss = 0.0
    num_batches = 0
    model.eval()

    with torch.no_grad():
        loop = tqdm(loader, desc="Evaluating")
        for data, targets in loop:
            data = data.to(device=DEVICE)
            targets = targets.to(device=DEVICE, dtype=torch.long)

            predictions = model(data)
            loss = loss_fn(predictions, targets)
            total_loss += loss.item()
            num_batches += 1

            preds = torch.argmax(predictions, dim=1)
            # .item() keeps the running counts as plain Python ints rather than
            # accumulating a device tensor that would leak into the return value.
            num_correct += (preds == targets).sum().item()
            num_pixels += preds.numel()

            loop.set_postfix(accuracy=f"{100.0 * num_correct / num_pixels:.2f}%")

    if num_batches == 0:
        raise ValueError("evaluate_model received an empty loader")

    accuracy = 100.0 * num_correct / num_pixels
    avg_loss = total_loss / num_batches
    print(f"{split_name} Accuracy: {accuracy:.2f}%")
    return avg_loss, accuracy

In [None]:
model = UNet(n_channels=N_CHANNELS, n_classes=len(CLASS_LABELS)).to(DEVICE)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# The checkpoint is overwritten only when validation loss improves, so the file
# always holds the best epoch seen so far.
model_save_path = os.path.join(WEIGHTS_DIR, "unet.pth")
best_val_loss = float('inf')

print("Starting training...")

for epoch in range(EPOCHS):
    epoch_num = epoch + 1
    print(f"\n--- Epoch {epoch_num}/{EPOCHS} ---")

    train_loss = train_fn(train_loader, model, optimizer, loss_fn)
    val_loss, val_accuracy = evaluate_model(val_loader, model, loss_fn)

    print(
        f"Epoch {epoch_num}: Train Loss={train_loss:.4f}, "
        f"Val Loss={val_loss:.4f}, Val Accuracy={val_accuracy:.2f}%"
    )

    improved = val_loss < best_val_loss
    if not improved:
        continue

    best_val_loss = val_loss
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

print("\n--- Training Finished ---")

In [None]:
def mask_to_rgb(mask_tensor, class_map=None):
    """Render a 2-D mask of class indices as an RGB ``PIL.Image``.

    Args:
        mask_tensor: 2-D tensor of class indices on any device; it is moved to the
            CPU and converted to NumPy.
        class_map: Mapping ``class index -> original label``, where each label is a
            key of ``INDEX_TO_COLOR``. Defaults to ``dict(enumerate(CLASS_LABELS))``,
            the inverse of the ``LABEL_TO_INDEX`` encoding used for training targets.

    Returns:
        A ``PIL.Image`` built from an (H, W, 3) uint8 array. Pixels whose index is
        not a key of ``class_map`` stay black (0, 0, 0).
    """
    if class_map is None:
        class_map = dict(enumerate(CLASS_LABELS))

    mask_np = mask_tensor.cpu().numpy()
    rgb_image = np.zeros((mask_np.shape[0], mask_np.shape[1], 3), dtype=np.uint8)

    # The label comes from class_map's values, so a caller-supplied mapping is honored.
    for class_idx, original_label in class_map.items():
        rgb_image[mask_np == class_idx] = INDEX_TO_COLOR[original_label]

    return Image.fromarray(rgb_image)

print(f"Loading best model from {model_save_path}")
# map_location lets the checkpoint load on whatever DEVICE this kernel has.
# weights_only=True restricts unpickling to tensors and plain containers, which is
# all a state_dict needs; full unpickling could execute arbitrary code.
model.load_state_dict(torch.load(model_save_path, map_location=DEVICE, weights_only=True))

test_loss, test_accuracy = evaluate_model(test_loader, model, loss_fn)
print(f"\n--- Test Set Performance ---")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Pixel Accuracy: {test_accuracy:.2f}%")


print(f"\nSaving test predictions to {PREDICTIONS_DIR}...")
# Class index -> original label, built once rather than once per image.
index_to_label = dict(enumerate(CLASS_LABELS))
model.eval()
with torch.no_grad():
    for sample_idx, (x, _) in enumerate(tqdm(test_dataset, desc="Saving Predictions")):
        x = x.unsqueeze(0).to(DEVICE)

        preds = torch.argmax(model(x), dim=1).squeeze(0)

        pred_rgb = mask_to_rgb(preds, index_to_label)

        # Iterating the dataset indexes test_ids in order, so sample_idx recovers the file name.
        original_filename = test_ids[sample_idx]
        pred_rgb.save(os.path.join(PREDICTIONS_DIR, original_filename))

print("Predictions saved successfully.")