# Google Colab

In [None]:
from google.colab import drive
drive.mount('/content/drive')
import os
os.chdir('/content/drive/MyDrive/NeuraVisionLib/')
# ! git pull

# Walk up the directory tree until the folder containing `dataloaders` is the
# working directory, so that the project-local imports below resolve.
while 'dataloaders' not in os.listdir():
    current_dir = os.getcwd()
    # At the filesystem root the parent equals the directory itself; without
    # this guard the loop would spin forever when `dataloaders` is missing.
    if os.path.dirname(current_dir) == current_dir:
        raise FileNotFoundError(
            "Could not find a 'dataloaders' directory in any parent of the project folder."
        )
    os.chdir('../')
# os.listdir()

In [None]:
import os
# Sanity check: list the dataset directory (relative to the current working
# directory) to confirm the path used for DATA_DIR resolves.
os.listdir('../../../MyDrive/mass_dataset/')

# Main

In [None]:
from rasterio.errors import NotGeoreferencedWarning
import warnings
# Suppress every UserWarning as well as rasterio's NotGeoreferencedWarning for
# the rest of the session. NOTE(review): the UserWarning filter is broad and
# will also hide warnings emitted by other libraries.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models import resnet50
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score, jaccard_score
from dataloaders.mass_roads_dataloader import MassRoadsDataset, custom_collate_fn
from models.MULDE.models import MLPs, ScoreOrLogDensityNetwork

# --- Hyperparameters ---
# DATA_DIR = '/home/ri/Desktop/Projects/Datasets/Mass_Roads/dataset/'
# Dataset root, relative to the working directory selected in the setup cell.
DATA_DIR = '../../../MyDrive/mass_dataset/'
# Checkpoint written after every epoch, and the copy kept for the lowest
# validation loss.
CHECKPOINT_PATH = 'log_density_segmentation_checkpoint.pth'
BEST_MODEL_PATH = 'best_log_density_segmentation_model.pth'
BATCH_SIZE = 4
LEARNING_RATE = 1e-4
EPOCHS = 300
# Number of batches whose gradients are accumulated before one optimizer step.
ACCUMULATION_STEPS = 8
# Patch geometry passed to the dataset; WINDOW_SIZE also fixes the MLP output
# size (WINDOW_SIZE * WINDOW_SIZE) and the reshape in forward_pass.
WINDOW_SIZE = 128
WINDOW_STRIDE = 64
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Range and count of the noise levels appended to the image features.
SIGMA_MIN = 0.01
SIGMA_MAX = 0.3
NUM_NOISE_LEVELS = 10
USE_AUTOMATIC_MIXED_PRECISION = True  # Set to True to use torch.cuda.amp.autocast

# Generate noise levels
# NUM_NOISE_LEVELS values spaced evenly in log10 between SIGMA_MIN and
# SIGMA_MAX (inclusive), stored as a 1-D tensor on DEVICE.
noise_levels = torch.logspace(
    start=torch.log10(torch.tensor(SIGMA_MIN)),
    end=torch.log10(torch.tensor(SIGMA_MAX)),
    steps=NUM_NOISE_LEVELS
).to(DEVICE)

# --- Dataset Preparation ---
# One dataset per split, all cut into windows with the same size and stride.
train_dataset = MassRoadsDataset(root_dir=DATA_DIR, split='train', window_size=WINDOW_SIZE, stride=WINDOW_STRIDE)
val_dataset = MassRoadsDataset(root_dir=DATA_DIR, split='val', window_size=WINDOW_SIZE, stride=WINDOW_STRIDE)
test_dataset = MassRoadsDataset(root_dir=DATA_DIR, split='test', window_size=WINDOW_SIZE, stride=WINDOW_STRIDE)

# Only the training loader shuffles; all loaders use the project's collate
# function. NOTE(review): the training/validation loops skip batches with zero
# elements, so custom_collate_fn can presumably return empty tensors — confirm.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=custom_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=custom_collate_fn)

def plot_patches(sat_patches, map_patches, n_patches=1):
    """Show satellite patches next to their road-map patches.

    One row is drawn per patch: the satellite image on the left and the map on
    the right. Previously a single pair of axes was reused for every patch, so
    with ``n_patches > 1`` only the last patch was visible.

    Args:
        sat_patches: CPU tensor indexed by patch along dim 0; each patch is
            transposed from (C, H, W) to (H, W, C) before display.
        map_patches: CPU tensor indexed by patch along dim 0; each patch is
            displayed as a grayscale image.
        n_patches: Maximum number of patches to display.
    """
    n_rows = min(n_patches, sat_patches.shape[0])
    if n_rows <= 0:
        # Nothing to draw (empty batch or non-positive request).
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, 2, figsize=(12, 6 * n_rows), squeeze=False)

    for i in range(n_rows):
        sat_patch = sat_patches[i].numpy().transpose(1, 2, 0)
        map_patch = map_patches[i].numpy()

        axes[i, 0].imshow(sat_patch)
        axes[i, 0].set_title(f'Satellite Patch {i+1}')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(map_patch, cmap='gray')
        axes[i, 1].set_title(f'Map Patch {i+1}')
        axes[i, 1].axis('off')

    plt.show()

# Visual sanity check: print the tensor shapes and plot the first patch of the
# first four training batches (the loop breaks after the batch with i == 3).
for i, (sat_patches, map_patches) in enumerate(train_loader):
    print(f"Batch {i+1}:")
    print(f"Sat Patches Shape: {sat_patches.shape}")
    print(f"Map Patches Shape: {map_patches.shape}")

    plot_patches(sat_patches, map_patches, n_patches=1)
    if i > 2:
        break

# --- Model Definition ---
# Frozen feature extractor: ResNet-50 without its last two children, followed
# by a (1, 1) adaptive average pool; forward_pass flattens the result into one
# feature vector per patch.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` — confirm against the installed version.
resnet = resnet50(pretrained=True)
resnet = nn.Sequential(
    *list(resnet.children())[:-2],
    nn.AdaptiveAvgPool2d((1, 1))
).to(DEVICE)
# The backbone is not trained; only the MLP below receives gradients.
for param in resnet.parameters():
    param.requires_grad = False

# Input is the 2048 backbone features plus one noise-level scalar; the output
# has one value per pixel of a WINDOW_SIZE x WINDOW_SIZE patch.
mlp = MLPs(
    input_dim=2048 + 1,
    output_dim=WINDOW_SIZE * WINDOW_SIZE,
    units=[4096, 4096],
    layernorm=True,
    dropout=0.1
)
log_density_model = ScoreOrLogDensityNetwork(mlp, score_network=False).to(DEVICE)

optimizer = optim.AdamW(log_density_model.parameters(), lr=LEARNING_RATE)
# The learning rate is multiplied by 0.1 every 10 scheduler steps (train_model
# steps the scheduler once per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
# Gradient scaler for mixed precision; None when AMP is switched off.
scaler = torch.cuda.amp.GradScaler() if USE_AUTOMATIC_MIXED_PRECISION else None

# --- Loss Functions ---
def dice_loss(pred, target, smooth=1e-6):
    """Return the soft Dice loss between predicted probabilities and targets.

    The Dice coefficient is computed over all elements of both tensors at once
    (not per sample); ``smooth`` keeps the ratio defined when both are all zero.

    Args:
        pred: Tensor of predicted probabilities.
        target: Tensor of the same shape as ``pred``.
        smooth: Small constant added to numerator and denominator.

    Returns:
        Scalar tensor ``1 - dice``.
    """
    intersection = (pred * target).sum()
    return 1 - (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)


def combined_loss(pred, target):
    """Return binary cross-entropy plus Dice loss for probability predictions.

    ``binary_cross_entropy`` / ``nn.BCELoss`` raise an error when called inside
    a CUDA autocast region and do not accept mismatched half/float inputs, so
    the loss is always evaluated in float32 with autocast disabled. This makes
    the function safe to call from inside or outside an autocast block.

    Args:
        pred: Tensor of probabilities in [0, 1] (i.e. already passed through
            a sigmoid).
        target: Tensor of the same shape as ``pred`` with values in [0, 1].

    Returns:
        Scalar float32 tensor ``bce + dice``.
    """
    with torch.autocast(device_type="cuda", enabled=False):
        pred = pred.float()
        target = target.float()
        bce = nn.functional.binary_cross_entropy(pred, target)
        dice = dice_loss(pred, target)
        return bce + dice

# --- Save and Load Model ---
def save_model(log_density_model, optimizer, epoch, loss, path):
    """Serialize a training checkpoint to ``path`` and report it on stdout.

    The checkpoint is a dict with the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict`` and ``loss``.

    Args:
        log_density_model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch value recorded in the checkpoint.
        loss: Loss value recorded in the checkpoint.
        path: Destination file path.
    """
    checkpoint = dict(
        epoch=epoch,
        model_state_dict=log_density_model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    )
    torch.save(checkpoint, path)
    print(f"Model saved at {path}")

def load_checkpoint(path):
    """Restore the global model and optimizer from a checkpoint, if one exists.

    The checkpoint tensors are mapped onto ``DEVICE`` while loading, so a file
    written on a GPU runtime can be resumed on a CPU-only runtime (and vice
    versa) instead of failing during deserialization.

    Args:
        path: Checkpoint file written by ``save_model``.

    Returns:
        ``(epoch, loss)`` as stored in the checkpoint, or ``(0, None)`` when
        ``path`` does not exist. Note that the module-level
        ``log_density_model`` and ``optimizer`` are updated in place.
    """
    if not os.path.exists(path):
        return 0, None

    checkpoint = torch.load(path, map_location=DEVICE)
    log_density_model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']

# --- Training and Validation Loops ---
def forward_pass(satellite_patches, noise_levels, resnet, log_density_model):
    """Run the backbone and the density head for every (patch, noise level) pair.

    Each patch's flattened backbone feature vector is repeated once per noise
    level, the noise level is appended as an extra column, and the head output
    is reshaped to ``(batch, num_noise_levels, 1, WINDOW_SIZE, WINDOW_SIZE)``.

    Args:
        satellite_patches: Batch of image patches fed to ``resnet``.
        noise_levels: 1-D tensor of noise levels (on the same device).
        resnet: Feature extractor.
        log_density_model: Head mapping ``features + noise level`` to
            ``WINDOW_SIZE * WINDOW_SIZE`` values per row.

    Returns:
        Tensor of raw (pre-sigmoid) predictions with the shape given above.
    """
    embeddings = torch.flatten(resnet(satellite_patches), start_dim=1)
    n_samples = embeddings.size(0)
    n_levels = len(noise_levels)

    # Rows are ordered sample-major: [s0 x all levels, s1 x all levels, ...].
    tiled_embeddings = torch.repeat_interleave(embeddings, n_levels, dim=0)
    # Tiling the level vector once per sample yields the matching row order.
    level_column = noise_levels.repeat(n_samples, 1).view(-1, 1)

    head_input = torch.cat([tiled_embeddings, level_column], dim=1)
    head_output = log_density_model(head_input)
    return head_output.view(n_samples, n_levels, 1, WINDOW_SIZE, WINDOW_SIZE)

def _apply_accumulated_step(optimizer, scaler, use_amp):
    """Apply the gradients accumulated so far, then clear them."""
    if use_amp:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()
    optimizer.zero_grad()


def train_one_epoch(train_loader, resnet, log_density_model, noise_levels, optimizer, device, scaler, accumulation_steps):
    """Train the density head for one epoch with gradient accumulation.

    Fixes relative to the previous version:
      * The returned loss is the mean *unscaled* batch loss, so it is directly
        comparable to the validation loss (it used to be divided by
        ``accumulation_steps``).
      * Only non-empty batches are counted, both for the average and for the
        accumulation schedule, and gradients still pending at the end of the
        epoch are always applied even when the last batch was skipped.
      * The loss is evaluated in float32 outside the autocast region, because
        binary cross-entropy on probabilities is not allowed under autocast.

    Args:
        train_loader: Iterable of ``(satellite_patches, road_maps)`` batches.
        resnet: Frozen feature extractor (kept in eval mode).
        log_density_model: Trainable head.
        noise_levels: 1-D tensor of noise levels passed to ``forward_pass``.
        optimizer: Optimizer for ``log_density_model``.
        device: Device the batches are moved to.
        scaler: ``GradScaler`` used when mixed precision is enabled, else None.
        accumulation_steps: Number of batches per optimizer step.

    Returns:
        Mean loss over the processed batches (``nan`` if none were processed).
    """
    resnet.eval()
    log_density_model.train()
    # Mixed precision needs both the global switch and a scaler.
    use_amp = USE_AUTOMATIC_MIXED_PRECISION and scaler is not None

    epoch_loss = 0.0
    processed_batches = 0
    pending_batches = 0  # batches back-propagated since the last optimizer step

    optimizer.zero_grad()
    for satellite_patches, road_maps in tqdm(train_loader, desc="Training"):
        if satellite_patches.numel() == 0 or road_maps.numel() == 0:
            continue

        satellite_patches = satellite_patches.to(device)
        # assumes road maps are 8-bit masks (0..255) shaped (B, H, W) — TODO confirm
        road_maps = (road_maps.to(device).float() / 255.0).unsqueeze(1)

        with torch.cuda.amp.autocast(enabled=use_amp):
            predictions = forward_pass(satellite_patches, noise_levels, resnet, log_density_model)
        # Only the output for the first (smallest) noise level is supervised.
        predictions = predictions[:, 0, :, :, :].float()
        loss = combined_loss(torch.sigmoid(predictions), road_maps)

        # Scale only the gradient contribution, not the reported loss.
        scaled_loss = loss / accumulation_steps
        if use_amp:
            scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()

        epoch_loss += loss.item()
        processed_batches += 1
        pending_batches += 1

        if pending_batches == accumulation_steps:
            _apply_accumulated_step(optimizer, scaler, use_amp)
            pending_batches = 0

    # Flush a trailing, incomplete accumulation group.
    if pending_batches:
        _apply_accumulated_step(optimizer, scaler, use_amp)

    if processed_batches == 0:
        return float("nan")
    return epoch_loss / processed_batches

def validate_one_epoch(val_loader, resnet, log_density_model, noise_levels, device, scaler):
    """Compute the mean validation loss without updating any weights.

    The forward pass runs under autocast when mixed precision is enabled, but
    the loss itself is evaluated in float32 outside the autocast region,
    because binary cross-entropy on probabilities is not allowed under
    autocast. The average is taken over the batches actually processed, so
    skipped empty batches no longer dilute it.

    Args:
        val_loader: Iterable of ``(satellite_patches, road_maps)`` batches.
        resnet: Frozen feature extractor.
        log_density_model: Head being evaluated.
        noise_levels: 1-D tensor of noise levels passed to ``forward_pass``.
        device: Device the batches are moved to.
        scaler: Unused; kept so the signature mirrors ``train_one_epoch``.

    Returns:
        Mean loss over the processed batches, or ``nan`` when there were none
        (``nan`` never compares as an improvement over the best loss).
    """
    resnet.eval()
    log_density_model.eval()
    val_loss = 0.0
    processed_batches = 0

    with torch.no_grad():
        for satellite_patches, road_maps in tqdm(val_loader, desc="Validation"):
            if satellite_patches.numel() == 0 or road_maps.numel() == 0:
                continue

            satellite_patches = satellite_patches.to(device)
            # assumes road maps are 8-bit masks (0..255) shaped (B, H, W) — TODO confirm
            road_maps = (road_maps.to(device).float() / 255.0).unsqueeze(1)

            with torch.cuda.amp.autocast(enabled=USE_AUTOMATIC_MIXED_PRECISION):
                predictions = forward_pass(satellite_patches, noise_levels, resnet, log_density_model)
            # Only the output for the first noise level is scored.
            predictions = predictions[:, 0, :, :, :].float()
            loss = combined_loss(torch.sigmoid(predictions), road_maps)

            val_loss += loss.item()
            processed_batches += 1

    if processed_batches == 0:
        return float("nan")
    return val_loss / processed_batches

def train_model(resnet, log_density_model, train_loader, val_loader, noise_levels, optimizer, scheduler, device, epochs, accumulation_steps, checkpoint_path, best_model_path):
    """Run the full training loop with per-epoch checkpointing and resume.

    Fixes relative to the previous version:
      * Checkpoints store the index of the epoch that just *finished*, so a
        resumed run now continues with the following epoch instead of
        repeating the last completed one.
      * On resume, the best validation loss is restored from the existing best
        model file, so the first resumed epoch no longer overwrites a better
        model unconditionally.

    Args:
        resnet: Frozen feature extractor.
        log_density_model: Trainable head.
        train_loader: Training batches.
        val_loader: Validation batches.
        noise_levels: 1-D tensor of noise levels.
        optimizer: Optimizer for ``log_density_model``.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        device: Device used for the batches.
        epochs: Total number of epochs (including already completed ones).
        accumulation_steps: Batches per optimizer step.
        checkpoint_path: File rewritten after every epoch.
        best_model_path: File rewritten whenever the validation loss improves.

    Note:
        The gradient scaler is taken from the module-level ``scaler``; it is
        not a parameter of this function. The scheduler state is not part of
        the checkpoint, so its step counter restarts on resume.
    """
    start_epoch, last_loss = load_checkpoint(checkpoint_path)
    best_val_loss = float('inf')

    # load_checkpoint returns a loss of None only when no checkpoint exists.
    if last_loss is not None:
        start_epoch += 1  # the stored epoch is already finished
        if os.path.exists(best_model_path):
            # Only the recorded loss is needed, so load onto the CPU.
            best_val_loss = torch.load(best_model_path, map_location='cpu')['loss']
        print(f"Resuming at epoch {start_epoch + 1} (best validation loss so far: {best_val_loss:.4f})")

    for epoch in range(start_epoch, epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        train_loss = train_one_epoch(train_loader, resnet, log_density_model, noise_levels, optimizer, device, scaler, accumulation_steps)
        val_loss = validate_one_epoch(val_loader, resnet, log_density_model, noise_levels, device, scaler)

        print(f"Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")
        scheduler.step()

        save_model(log_density_model, optimizer, epoch, val_loss, path=checkpoint_path)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_model(log_density_model, optimizer, epoch, val_loss, path=best_model_path)
            print(f"Best model updated with Validation Loss: {best_val_loss:.4f}")

    print("Training complete.")

# --- Evaluation Function ---
def evaluate(model, resnet, dataloader, noise_levels, device):
    """Evaluate the model on a dataloader and print pixel-wise metrics.

    Fixes relative to the previous version:
      * The target gets a channel dimension before the loss is computed, so it
        matches the ``(B, 1, H, W)`` prediction shape (binary cross-entropy
        rejects mismatched shapes).
      * Predictions are converted to float32 before the sigmoid/loss, so
        half-precision autocast outputs are not mixed with float32 targets.
      * ``resnet`` is put in eval mode, so the result does not depend on a
        previous call to the training loop.
      * Targets are binarised at 0.5 before the scikit-learn metrics, and the
        loss is averaged over the batches actually processed.

    Args:
        model: Head to evaluate.
        resnet: Feature extractor.
        dataloader: Iterable of ``(satellite_patches, road_maps)`` batches.
        noise_levels: 1-D tensor of noise levels passed to ``forward_pass``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, predictions, ground_truth, satellite_patches)`` where
        the last three are lists with one CPU tensor per processed batch:
        probabilities, scaled road maps (shape as delivered by the loader)
        and the input patches. If no batch was processed the loss is ``nan``
        and the lists are empty.
    """
    resnet.eval()
    model.eval()
    total_loss = 0.0
    processed_batches = 0
    predictions_list = []
    ground_truth_list = []
    satellite_patches_list = []

    with torch.no_grad():
        for satellite_patches, road_maps in tqdm(dataloader, desc="Evaluating"):
            if satellite_patches.numel() == 0 or road_maps.numel() == 0:
                continue

            satellite_patches = satellite_patches.to(device)
            # assumes road maps are 8-bit masks (0..255) shaped (B, H, W) — TODO confirm
            road_maps = road_maps.to(device).float() / 255.0

            with torch.cuda.amp.autocast(enabled=USE_AUTOMATIC_MIXED_PRECISION):
                predictions = forward_pass(satellite_patches, noise_levels, resnet, model)
            # Use the predictions for noise level 0, in full precision.
            probabilities = torch.sigmoid(predictions[:, 0, :, :, :].float())

            # Compute loss against a target with an explicit channel dimension.
            loss = combined_loss(probabilities, road_maps.unsqueeze(1))
            total_loss += loss.item()
            processed_batches += 1

            predictions_list.append(probabilities.cpu())
            ground_truth_list.append(road_maps.cpu())
            satellite_patches_list.append(satellite_patches.cpu())

    if processed_batches == 0:
        print("No non-empty batches were available for evaluation.")
        return float("nan"), predictions_list, ground_truth_list, satellite_patches_list

    # Compute metrics on binarised predictions and targets.
    all_preds = torch.cat([p.flatten() for p in predictions_list]).numpy() > 0.5
    all_targets = torch.cat([t.flatten() for t in ground_truth_list]).numpy() > 0.5
    precision = precision_score(all_targets, all_preds)
    recall = recall_score(all_targets, all_preds)
    f1 = f1_score(all_targets, all_preds)
    iou = jaccard_score(all_targets, all_preds)

    print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}, IoU: {iou:.4f}")

    return total_loss / processed_batches, predictions_list, ground_truth_list, satellite_patches_list

# --- Plot Predictions and Ground Truth ---
def _iter_prediction_samples(predictions, ground_truth, satellite_images):
    """Yield (prediction, ground truth, satellite image) one sample at a time.

    Accepts either per-sample entries (3-D satellite images) or per-batch
    entries (4-D satellite tensors), in which case each batch is unrolled
    along its first dimension.
    """
    for pred, gt, sat in zip(predictions, ground_truth, satellite_images):
        if sat.dim() == 4:
            yield from zip(pred, gt, sat)
        else:
            yield pred, gt, sat


def plot_predictions(predictions, ground_truth, satellite_images, n_samples=5):
    """Plot satellite image, prediction and ground truth side by side.

    The lists returned by ``evaluate`` hold one tensor per *batch*; the
    previous implementation indexed them as if they held one tensor per
    sample, which fails on the 4-D satellite batches. Batched entries are now
    unrolled into individual samples; per-sample entries still work.

    Args:
        predictions: Sequence of CPU prediction tensors (per sample or batch).
        ground_truth: Sequence of CPU ground-truth tensors, aligned with
            ``predictions``.
        satellite_images: Sequence of CPU image tensors, (C, H, W) per sample
            or (B, C, H, W) per batch.
        n_samples: Maximum number of samples to plot.
    """
    samples = _iter_prediction_samples(predictions, ground_truth, satellite_images)
    for shown, (pred, gt, sat) in enumerate(samples):
        if shown >= n_samples:
            break

        # .float() so half-precision tensors convert cleanly for plotting.
        pred = pred.squeeze().float().numpy()
        gt = gt.squeeze().float().numpy()
        sat_img = sat.permute(1, 2, 0).float().numpy()

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(sat_img)
        axes[0].set_title('Satellite Image')
        axes[0].axis('off')

        axes[1].imshow(pred, cmap='gray')
        axes[1].set_title('Prediction')
        axes[1].axis('off')

        axes[2].imshow(gt, cmap='gray')
        axes[2].set_title('Ground Truth')
        axes[2].axis('off')

        plt.show()



In [None]:
# --- Start Training ---
# Runs (or resumes, if CHECKPOINT_PATH exists) training with the objects and
# hyperparameters defined in the cells above.
train_model(
    resnet=resnet,
    log_density_model=log_density_model,
    train_loader=train_loader,
    val_loader=val_loader,
    noise_levels=noise_levels,
    optimizer=optimizer,
    scheduler=scheduler,
    device=DEVICE,
    epochs=EPOCHS,
    accumulation_steps=ACCUMULATION_STEPS,
    checkpoint_path=CHECKPOINT_PATH,
    best_model_path=BEST_MODEL_PATH,
)


In [None]:
# --- Evaluate on Test Set ---
# NOTE: despite its name, `val_loss` holds the loss on the *test* loader here.
# The model evaluated is whatever is currently in memory (the weights after
# the last training epoch), not the file saved at BEST_MODEL_PATH.
val_loss, predictions, ground_truth, satellite_patches = evaluate(
    model=log_density_model,
    resnet=resnet,
    dataloader=test_loader,
    noise_levels=noise_levels,
    device=DEVICE,
)
print(f"Test Loss: {val_loss:.4f}")

# --- Plot Predictions ---
plot_predictions(predictions, ground_truth, satellite_patches, n_samples=5)


In [None]:
# Epoch 1/300: 100%|██████████| 277/277 [53:31<00:00, 11.59s/it]
# Epoch 1/300, Loss: 0.1431
# Checkpoint saved at log_density_segmentation_checkpoint.pth
# Epoch 2/300: 100%|██████████| 277/277 [27:57<00:00,  6.06s/it]
# Epoch 2/300, Loss: 0.1259
# Checkpoint saved at log_density_segmentation_checkpoint.pth
# Epoch 3/300: 100%|██████████| 277/277 [27:26<00:00,  5.94s/it]
# Epoch 3/300, Loss: 0.1256
# Checkpoint saved at log_density_segmentation_checkpoint.pth
# Epoch 4/300: 100%|██████████| 277/277 [26:52<00:00,  5.82s/it]
# Epoch 4/300, Loss: 0.1255
# Checkpoint saved at log_density_segmentation_checkpoint.pth
# Epoch 5/300: 100%|██████████| 277/277 [27:13<00:00,  5.90s/it]
# Epoch 5/300, Loss: 0.1254
# Checkpoint saved at log_density_segmentation_checkpoint.pth
# Epoch 6/300: 100%|██████████| 277/277 [27:23<00:00,  5.93s/it]
# Epoch 6/300, Loss: 0.1254
# Checkpoint saved at log_density_segmentation_checkpoint.pth
# Epoch 7/300: 100%|██████████| 277/277 [27:14<00:00,  5.90s/it]
# Epoch 7/300, Loss: 0.1254
# Checkpoint saved at log_density_segmentation_checkpoint.pth
# Epoch 8/300: 100%|██████████| 277/277 [1:03:25<00:00, 13.74s/it]
# Epoch 8/300, Loss: 0.1254
# Checkpoint saved at log_density_segmentation_checkpoint.pth
# Epoch 9/300: 100%|██████████| 277/277 [28:32<00:00,  6.18s/it]
# Epoch 9/300, Loss: 0.1254
# Checkpoint saved at log_density_segmentation_checkpoint.pth
# Epoch 10/300: 100%|██████████| 277/277 [28:31<00:00,  6.18s/it]
# Epoch 10/300, Loss: 0.1254
# Checkpoint saved at log_density_segmentation_checkpoint.pth
# Epoch 11/300: 100%|██████████| 277/277 [28:24<00:00,  6.15s/it]
# Epoch 11/300, Loss: 0.1254
# Checkpoint saved at log_density_segmentation_checkpoint.pth
# Epoch 12/300: 100%|██████████| 277/277 [28:17<00:00,  6.13s/it]
# Epoch 12/300, Loss: 0.1254
# Checkpoint saved at log_density_segmentation_checkpoint.pth
# Epoch 13/300: 100%|██████████| 277/277 [28:30<00:00,  6.18s/it]
# Epoch 13/300, Loss: 0.1254

In [None]:
# import torch
# import gc
# gc.collect()
# torch.cuda.empty_cache()
