# DSM to DTM Semantic Segmentation with UperNet

This notebook clones the DSM2DTM project from GitHub and runs the training on Kaggle.

## 1. Clone GitHub Repository and Install Dependencies

In [None]:
# Clone your GitHub repository
# Update this URL to your actual GitHub repository
!git clone https://github.com/yourusername/dsm2dtm.git
%cd dsm2dtm

# Install requirements
!pip install -r requirements.txt

## 2. Import Libraries and Project Modules

In [None]:
import os
import sys
import torch
from transformers import ConvNextConfig, UperNetConfig, UperNetForSemanticSegmentation
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

# Import project modules
from dataloader import DSMPatchFolderDataset
from torch.utils.data import DataLoader, random_split
from utils import normalize_patch, calculate_patch_positions, is_valid_patch

# Import training functions
sys.path.append('.')
from train import calculate_miou, train_loop, validate_loop, get_train_val_loaders

## 3. Setup Dataset Path

**Important:** Upload your data to Kaggle as a dataset, attach it to this notebook, and update `dataset_root` below to match its path.

In [None]:
# Configuration
# Update this path to your Kaggle dataset
dataset_root = "/kaggle/input/your-dataset-name/datasets"

# Training parameters
batch_size = 8
patch_size = 256
num_epochs = 100
nodata_val = -9999
nodata_threshold = 0.1
number_of_classes = 2  # Binary segmentation (DTM vs non-DTM)

# Seed PyTorch's global RNG so repeated runs start from the same random state.
# NOTE(review): this only covers code that draws from torch's global RNG
# (presumably the train/val split and weight init) - confirm against train.py.
seed = 42
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Check if dataset exists. isdir (rather than exists) is used because the
# listing below would raise if dataset_root pointed at a regular file.
if os.path.isdir(dataset_root):
    print(f"Dataset found at: {dataset_root}")
    print(f"Contents: {os.listdir(dataset_root)}")
else:
    print(f"Dataset not found at: {dataset_root}")
    print("Please update the dataset_root path to your Kaggle dataset location.")

## 4. Create Data Loaders

In [None]:
# Create the train/validation loaders through the project's
# get_train_val_loaders helper. The keyword-only tuning knobs are grouped
# in one dict so they are easy to spot and adjust.
loader_options = {
    "val_split": 0.2,
    "num_workers": 2,  # Reduced for Kaggle compatibility
}
train_loader, val_loader = get_train_val_loaders(
    dataset_root, batch_size, patch_size, nodata_val, nodata_threshold,
    **loader_options,
)

# Report the number of batches each loader yields.
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## 5. Initialize Model

In [None]:
# Build a UperNet segmentation model on top of a ConvNeXt backbone.
# The backbone is asked to expose stage1 through stage4 as output features.
stage_names = [f"stage{idx}" for idx in range(1, 5)]
backbone_config = ConvNextConfig(out_features=stage_names)
config = UperNetConfig(
    backbone_config=backbone_config,
    num_labels=number_of_classes,
)

# Instantiate the model from the config and move it to the selected device.
model = UperNetForSemanticSegmentation(config)
model.to(device)

# Total parameter count, summed over every parameter tensor.
print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")

## 6. Setup Training Components

In [None]:
# Directory the best checkpoint is written to, and the best validation
# mIoU seen so far (starts at zero, so only a positive mIoU triggers a save).
save_path = "pretrain"
best_miou = 0.0

# AdamW over all model parameters.
optimizer = AdamW(params=model.parameters(), lr=5e-5)

# Cross-entropy loss; target pixels labelled 255 are ignored.
criterion = CrossEntropyLoss(ignore_index=255)

## 7. Training Loop

In [None]:
# Run num_epochs rounds of the project's train_loop followed by
# validate_loop, keeping only the checkpoint whose validation mIoU is the
# highest seen so far.
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Training
    train_loop(train_loader, model, criterion, optimizer, device)

    # Validation
    val_loss, val_acc, val_miou = validate_loop(val_loader, model, criterion, device, number_of_classes)

    # Save best model (strict improvement only; ties keep the earlier checkpoint)
    if val_miou > best_miou:
        best_miou = val_miou
        # exist_ok=True replaces a separate exists() check: a single call,
        # with no window between the check and the creation.
        os.makedirs(save_path, exist_ok=True)
        model.save_pretrained(save_path)
        print(f"New best model saved at: {save_path} with mIoU: {val_miou:.4f}")
    else:
        print(f"mIoU did not improve ({val_miou:.4f} <= {best_miou:.4f}), model not saved.")

print(f"\nTraining completed! Best mIoU: {best_miou:.4f}")

## 8. Test Inference (Optional)

In [None]:
# Bring in the raster inference helper from the project's infer.py.
from infer import inference_on_raster

# Reload the best checkpoint from save_path, move it to the selected device
# and switch it to evaluation mode. This rebinds `model` to the reloaded copy.
model = (
    UperNetForSemanticSegmentation
    .from_pretrained(save_path)
    .to(device)
    .eval()
)

# Example usage (uncomment and modify paths as needed)
# input_path = "/path/to/test/image.tif"
# output_path = "prediction_output.tif"
# if os.path.exists(input_path):
#     inference_on_raster(model, input_path, output_path, patch_size=256, nodata_val=-9999)

print("Inference functions are available for testing.")

## 9. Save Results Summary

In [None]:
# Console recap of the run.
print("Training completed successfully!")
print(f"Best model saved in: {save_path}")
print(f"Best validation mIoU: {best_miou:.4f}")

# Show what the checkpoint directory contains, if it was created.
if os.path.exists(save_path):
    saved_files = os.listdir(save_path)
    print(f"Saved model files: {saved_files}")

# Write the same key figures to a plain-text summary in the working directory.
summary_lines = [
    "Training completed\n",
    f"Best validation mIoU: {best_miou:.4f}\n",
    f"Total epochs: {num_epochs}\n",
    f"Batch size: {batch_size}\n",
    f"Patch size: {patch_size}\n",
    f"Number of classes: {number_of_classes}\n",
]
with open('training_results.txt', 'w') as f:
    f.writelines(summary_lines)

print("Results saved to training_results.txt")