In [1]:
# Cell 1: Imports
import os
from datetime import datetime
from typing import Tuple

import numpy as np
import satlaspretrain_models
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from dataLoaderAugment import SatelliteSegmentationDataset
from preprocessing import ProcessData
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

  Referenced from: <FB2FD416-6C4D-3621-B677-61F07C02A3C5> /Users/martin/anaconda3/envs/geospatial/lib/python3.9/site-packages/torchvision/image.so
  warn(


## Loss function

In [2]:
class DiceLoss(nn.Module):
    def __init__(self, multiclass=False):
        super(DiceLoss, self).__init__()
        self.multiclass = multiclass
        
    def forward(self, pred, target):
        pred = torch.softmax(pred, dim=1)
        
        if self.multiclass:
            target = F.one_hot(target, num_classes=pred.shape[1]).permute(0, 3, 1, 2).float()
            
        numerator = 2 * (pred * target).sum(dim=(2, 3))
        denominator = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
        
        dice_score = 1 - (numerator + 1) / (denominator + 1)
        return dice_score.mean()

class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        
    def forward(self, pred, target):
        ce_loss = F.cross_entropy(pred, target, weight=self.alpha, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        return focal_loss.mean()

class CombinedLoss(nn.Module):
    def __init__(self, alpha=None):
        super().__init__()
        self.focal = FocalLoss(alpha=alpha, gamma=2.0)
        self.dice = DiceLoss(multiclass=True)
        
    def forward(self, pred, target):
        focal_loss = self.focal(pred, target)
        dice_loss = self.dice(pred, target)
        return 0.5 * focal_loss + 0.5 * dice_loss

## Dataloader

In [3]:
def create_dataloaders(
    train_images: np.ndarray,
    train_labels: np.ndarray,
    val_images: np.ndarray,
    val_labels: np.ndarray,
    batch_size: int = 8,
    patch_size: int = 256,
    patch_stride: int = 128,
    num_workers: int = 4
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """
    Create train and validation dataloaders with the enhanced dataset.
    """
    train_dataset = SatelliteSegmentationDataset(
        images=train_images,
        labels=train_labels,
        patch_size=patch_size,
        patch_stride=patch_stride,
        augment=True,
        max_patches_per_image=32
    )
    
    val_dataset = SatelliteSegmentationDataset(
        images=val_images,
        labels=val_labels,
        patch_size=patch_size,
        patch_stride=patch_size,
        augment=False
    )
    
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )
    
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    
    return train_loader, val_loader

## training function

In [4]:
def train_model(
    model,
    train_images,
    train_labels,
    val_images,
    val_labels,
    num_epochs=20,
    batch_size=4,
    patch_size=128,
    patch_stride=64,
    learning_rate=1e-4,
    device='cpu',
    save_dir='../models'
):
    """
    Complete training function with optimized components for satellite image segmentation.
    
    Args:
        model: The pre-trained model to fine-tune
        train_images: Training images array (N, C, H, W)
        train_labels: Training labels array (N, H, W)
        val_images: Validation images array
        val_labels: Validation labels array
        num_epochs: Number of training epochs
        batch_size: Batch size for training
        patch_size: Size of image patches
        patch_stride: Stride for patch extraction
        learning_rate: Initial learning rate
        device: Device to train on ('cuda' or 'cpu')
        save_dir: Directory to save model checkpoints
    """
    # Create dataloaders
    train_loader, val_loader = create_dataloaders(
        train_images=train_images,
        train_labels=train_labels,
        val_images=val_images,
        val_labels=val_labels,
        batch_size=batch_size,
        patch_size=patch_size,
        patch_stride=patch_stride
    )
    
    # Initialize optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=0.1
    )
    
    # Initialize learning rate scheduler
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=learning_rate,
        epochs=num_epochs,
        steps_per_epoch=len(train_loader),
        pct_start=0.2,  # 10% warmup
        div_factor=25,  # initial_lr = max_lr/10
        final_div_factor=1000  # final_lr = initial_lr/100
    )
    
    # Initialize loss function
    class_weights = torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0]).to(device)
    criterion = CombinedLoss(alpha=class_weights)
    
    # Setup tensorboard and checkpoints
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter(f'runs/experiment_{timestamp}')
    os.makedirs(save_dir, exist_ok=True)
    
    best_val_f1 = 0.0
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        epoch_loss = 0
        num_batches = 0
        
        progress_bar = tqdm.tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        
        for batch_idx, (data, target) in enumerate(progress_bar):
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output[0], target)
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            scheduler.step()
            
            # Update metrics
            epoch_loss += loss.item()
            num_batches += 1
            
            # Log learning rate
            current_lr = scheduler.get_last_lr()[0]
            writer.add_scalar('Training/LR', current_lr, 
                            epoch * len(train_loader) + batch_idx)
            
            # Update progress bar
            progress_bar.set_postfix({
                'batch_loss': f'{loss.item():.4f}',
                'avg_loss': f'{epoch_loss/num_batches:.4f}',
                'lr': f'{current_lr:.6f}'
            })
            
            # Log batch-level metrics
            writer.add_scalar('Training/BatchLoss', loss.item(), 
                            epoch * len(train_loader) + batch_idx)
        
        # Log epoch-level metrics
        avg_epoch_loss = epoch_loss / num_batches
        writer.add_scalar('Training/EpochLoss', avg_epoch_loss, epoch)
        
        # Validation phase
        model.eval()
        val_loss = 0
        val_predictions = []
        val_targets = []
        
        with torch.no_grad():
            for val_data, val_target in tqdm.tqdm(val_loader, desc='Validation'):
                val_data, val_target = val_data.to(device), val_target.to(device)
                
                val_output = model(val_data)
                batch_loss = criterion(val_output[0], val_target)
                val_loss += batch_loss.item()
                
                # Collect predictions and targets
                pred = val_output[0].argmax(dim=1).cpu().numpy()
                val_predictions.extend(pred.flatten())
                val_targets.extend(val_target.cpu().numpy().flatten())
        
        # Calculate validation metrics
        avg_val_loss = val_loss / len(val_loader)
        val_f1 = f1_score(val_targets, val_predictions, average='weighted')
        
        # Log validation metrics
        writer.add_scalar('Validation/Loss', avg_val_loss, epoch)
        writer.add_scalar('Validation/F1', val_f1, epoch)
        
        print(f'\nEpoch {epoch+1}/{num_epochs}:')
        print(f'Training Loss: {avg_epoch_loss:.4f}')
        print(f'Validation Loss: {avg_val_loss:.4f}')
        print(f'Validation F1: {val_f1:.4f}')
        
        # Save best model
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_f1': val_f1,
                'training_loss': avg_epoch_loss,
                'val_loss': avg_val_loss,
            }, f'{save_dir}/best_model_{timestamp}.pth')
        
        # Save regular checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_f1': val_f1,
                'training_loss': avg_epoch_loss,
                'val_loss': avg_val_loss,
            }, f'{save_dir}/checkpoint_epoch_{epoch+1}_{timestamp}.pth')
    
    writer.close()
    return model

# Run everything

In [5]:
if __name__ == "__main__":
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Initialize ProcessData and model
    data = ProcessData()
    data.prepared_data, data.labels = data.load_preprocessed_data()
    data.prepared_data = data.prepared_data[:, :9, :, :]
    
    # Create train/val split
    train_sample = 20
    val_sample = 15
    
    # Initialize model
    weights_manager = satlaspretrain_models.Weights()
    model = weights_manager.get_pretrained_model(
        "Sentinel2_SwinT_SI_MS",
        fpn=True,
        head=satlaspretrain_models.Head.SEGMENT,
        num_categories=5,
        device='cpu'
    )
    model = model.to(device)
    
    # Train the model
    trained_model = train_model(
        model=model,
        train_images=data.prepared_data[:train_sample],
        train_labels=data.labels[:train_sample],
        val_images=data.prepared_data[train_sample:train_sample + val_sample],
        val_labels=data.labels[train_sample:train_sample + val_sample],
        num_epochs=20,
        batch_size=8,
        device=device
    )

Loaded preprocessed data from /Users/martin/Desktop/inf367project/INF367A-DeforestationDrivers


  weights = torch.load(weights_file, map_location=torch.device('cpu'))
Epoch 1/20: 100%|██████████| 80/80 [02:16<00:00,  1.71s/it, batch_loss=0.6870, avg_loss=0.6625, lr=0.000018]
Validation: 100%|██████████| 79/79 [01:07<00:00,  1.17it/s]



Epoch 1/20:
Training Loss: 0.6625
Validation Loss: 0.6949
Validation F1: 0.3094


Epoch 2/20: 100%|██████████| 80/80 [02:06<00:00,  1.58s/it, batch_loss=0.6749, avg_loss=0.6225, lr=0.000052]
Validation: 100%|██████████| 79/79 [01:10<00:00,  1.12it/s]



Epoch 2/20:
Training Loss: 0.6225
Validation Loss: 0.6808
Validation F1: 0.2826


Epoch 3/20: 100%|██████████| 80/80 [02:26<00:00,  1.83s/it, batch_loss=0.6414, avg_loss=0.6256, lr=0.000086]
Validation: 100%|██████████| 79/79 [01:09<00:00,  1.14it/s]



Epoch 3/20:
Training Loss: 0.6256
Validation Loss: 0.6913
Validation F1: 0.2819


Epoch 4/20: 100%|██████████| 80/80 [02:06<00:00,  1.58s/it, batch_loss=0.9397, avg_loss=0.6147, lr=0.000100]
Validation: 100%|██████████| 79/79 [01:08<00:00,  1.16it/s]



Epoch 4/20:
Training Loss: 0.6147
Validation Loss: 0.8104
Validation F1: 0.0788


Epoch 5/20: 100%|██████████| 80/80 [02:15<00:00,  1.69s/it, batch_loss=0.7428, avg_loss=0.6322, lr=0.000099]
Validation: 100%|██████████| 79/79 [01:07<00:00,  1.17it/s]



Epoch 5/20:
Training Loss: 0.6322
Validation Loss: 0.8104
Validation F1: 0.0788


Epoch 6/20: 100%|██████████| 80/80 [02:22<00:00,  1.78s/it, batch_loss=0.6478, avg_loss=0.6369, lr=0.000096]
Validation: 100%|██████████| 79/79 [01:10<00:00,  1.12it/s]



Epoch 6/20:
Training Loss: 0.6369
Validation Loss: 0.8104
Validation F1: 0.0788


Epoch 7/20: 100%|██████████| 80/80 [02:24<00:00,  1.80s/it, batch_loss=0.6707, avg_loss=0.6231, lr=0.000092]
Validation: 100%|██████████| 79/79 [01:11<00:00,  1.11it/s]



Epoch 7/20:
Training Loss: 0.6231
Validation Loss: 0.8104
Validation F1: 0.0788


Epoch 8/20: 100%|██████████| 80/80 [02:28<00:00,  1.86s/it, batch_loss=0.6764, avg_loss=0.6360, lr=0.000085]
Validation: 100%|██████████| 79/79 [01:08<00:00,  1.15it/s]



Epoch 8/20:
Training Loss: 0.6360
Validation Loss: 0.6821
Validation F1: 0.2858


Epoch 9/20: 100%|██████████| 80/80 [02:33<00:00,  1.92s/it, batch_loss=0.7649, avg_loss=0.6113, lr=0.000078]
Validation: 100%|██████████| 79/79 [01:09<00:00,  1.14it/s]



Epoch 9/20:
Training Loss: 0.6113
Validation Loss: 0.7155
Validation F1: 0.0990


Epoch 10/20: 100%|██████████| 80/80 [02:28<00:00,  1.85s/it, batch_loss=0.6994, avg_loss=0.6096, lr=0.000069]
Validation: 100%|██████████| 79/79 [01:10<00:00,  1.13it/s]



Epoch 10/20:
Training Loss: 0.6096
Validation Loss: 0.6972
Validation F1: 0.3450


Epoch 11/20: 100%|██████████| 80/80 [02:34<00:00,  1.93s/it, batch_loss=0.6845, avg_loss=0.5988, lr=0.000060]
Validation: 100%|██████████| 79/79 [01:10<00:00,  1.11it/s]



Epoch 11/20:
Training Loss: 0.5988
Validation Loss: 0.7178
Validation F1: 0.2991


Epoch 12/20: 100%|██████████| 80/80 [02:36<00:00,  1.96s/it, batch_loss=0.6459, avg_loss=0.5816, lr=0.000050]
Validation: 100%|██████████| 79/79 [01:21<00:00,  1.03s/it]



Epoch 12/20:
Training Loss: 0.5816
Validation Loss: 0.6899
Validation F1: 0.3465


Epoch 13/20: 100%|██████████| 80/80 [03:08<00:00,  2.36s/it, batch_loss=0.5820, avg_loss=0.5717, lr=0.000040]
Validation: 100%|██████████| 79/79 [01:07<00:00,  1.18it/s]



Epoch 13/20:
Training Loss: 0.5717
Validation Loss: 0.7088
Validation F1: 0.3210


Epoch 14/20: 100%|██████████| 80/80 [02:29<00:00,  1.86s/it, batch_loss=0.5560, avg_loss=0.5598, lr=0.000031]
Validation: 100%|██████████| 79/79 [01:08<00:00,  1.16it/s]



Epoch 14/20:
Training Loss: 0.5598
Validation Loss: 0.6928
Validation F1: 0.3497


Epoch 15/20: 100%|██████████| 80/80 [02:19<00:00,  1.74s/it, batch_loss=0.6246, avg_loss=0.5565, lr=0.000022]
Validation: 100%|██████████| 79/79 [01:08<00:00,  1.16it/s]



Epoch 15/20:
Training Loss: 0.5565
Validation Loss: 0.7249
Validation F1: 0.2938


Epoch 16/20: 100%|██████████| 80/80 [02:38<00:00,  1.98s/it, batch_loss=0.5972, avg_loss=0.5551, lr=0.000015]
Validation: 100%|██████████| 79/79 [01:12<00:00,  1.09it/s]



Epoch 16/20:
Training Loss: 0.5551
Validation Loss: 0.6534
Validation F1: 0.4022


Epoch 17/20: 100%|██████████| 80/80 [02:27<00:00,  1.84s/it, batch_loss=0.6514, avg_loss=0.5460, lr=0.000008]
Validation: 100%|██████████| 79/79 [01:07<00:00,  1.17it/s]



Epoch 17/20:
Training Loss: 0.5460
Validation Loss: 0.6635
Validation F1: 0.3903


Epoch 18/20: 100%|██████████| 80/80 [02:31<00:00,  1.90s/it, batch_loss=0.5600, avg_loss=0.5490, lr=0.000004]
Validation: 100%|██████████| 79/79 [01:08<00:00,  1.15it/s]



Epoch 18/20:
Training Loss: 0.5490
Validation Loss: 0.6566
Validation F1: 0.3998


Epoch 19/20: 100%|██████████| 80/80 [02:35<00:00,  1.94s/it, batch_loss=0.5645, avg_loss=0.5487, lr=0.000001]
Validation: 100%|██████████| 79/79 [01:08<00:00,  1.15it/s]



Epoch 19/20:
Training Loss: 0.5487
Validation Loss: 0.6634
Validation F1: 0.3914


Epoch 20/20: 100%|██████████| 80/80 [02:35<00:00,  1.95s/it, batch_loss=0.5760, avg_loss=0.5475, lr=0.000000]
Validation: 100%|██████████| 79/79 [01:09<00:00,  1.14it/s]



Epoch 20/20:
Training Loss: 0.5475
Validation Loss: 0.6607
Validation F1: 0.3949
