# Combined Training Notebook: RetinaNet + U-Net

This notebook provides training implementations for both RetinaNet (object detection) and U-Net (semantic segmentation) models.

## Common Setup and Imports

In [None]:
import os
import numpy as np
import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import glob
import cv2
import copy
import random
import collections
from natsort import natsorted
from tqdm import trange
from PIL import Image

# Restrict the kernel to GPU 0 *before* the first CUDA query below:
# torch.cuda.get_device_name() initializes CUDA, and CUDA ignores any later
# change to CUDA_VISIBLE_DEVICES (e.g. the assignment in the RetinaNet
# configuration cell). setdefault keeps a value already exported by the user.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

# Check GPU availability
cuda_available = torch.cuda.is_available()
print('CUDA available: {}'.format(cuda_available))
if cuda_available:
    gpu_count = torch.cuda.device_count()
    print('GPU count: {}'.format(gpu_count))
    for i in range(gpu_count):
        print(f'GPU {i}: {torch.cuda.get_device_name(i)}')

## 1. RetinaNet Training

### RetinaNet Setup and Configuration

In [None]:
# RetinaNet imports
from retinanet import model
from retinanet.dataloader import CocoDataset, CSVDataset, collater, Resizer, AspectRatioBasedSampler, Augmenter, Normalizer
from retinanet import coco_eval, csv_eval

# RetinaNet Configuration
# CUDA reads CUDA_VISIBLE_DEVICES only once, when it initializes in this
# process (torch.cuda.get_device_name in the GPU check above already triggers
# that), so assigning it afterwards only takes effect after a kernel restart.
_previous_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
if torch.cuda.is_initialized() and _previous_visible_devices != '0':
    print('Warning: CUDA is already initialized in this kernel; '
          'CUDA_VISIBLE_DEVICES=0 only takes effect after a restart.')

RETINANET_CONFIG = {
    'LEARNING_RATE': 1e-5,
    'EPOCHS': 100,
    'BATCH_SIZE': 4,
    'DEPTH': 50,  # ResNet depth (18, 34, 50, 101, 152)
    'PRETRAINED': True,
    'MODEL_SAVE_PATH': './retinanet_weights/',
    'DATASET_TYPE': 'csv'  # or 'coco'
}

# Create save directory (exist_ok avoids a check-then-create race)
os.makedirs(RETINANET_CONFIG['MODEL_SAVE_PATH'], exist_ok=True)

In [None]:
# U-Net imports
from Unet.trainer import train, val
from Unet.loss import dice_loss, dice
from Unet.Unet import UNet
from Unet.preprocessing import *
from Unet.datagenerater import Dental_Single_Data_Generator
from Unet.utils import *
from Unet.progressbar import Bar

# U-Net Configuration
UNET_CONFIG = {
    'IMAGE_SIZE': (512, 512),
    'N_CLASSES': 14,  # Updated for 14 landmark points
    'TRAIN_BATCH': 4,
    'TEST_BATCH': 1,
    'EPOCHS': 50,
    'LEARNING_RATE': 5e-4,
    'SEED': 42,
    'MODEL_SAVE_PATH': './unet_weights/',
    'ENCODER_NAME': 'vgg16',  # or 'timm-tf_efficientnet_lite4'
    'USE_ATTENTION': True,
    'NUM_LANDMARKS': 14  # Total number of landmark points
}

# Set random seeds for reproducibility. NumPy's global RNG is seeded as well,
# in case the augmentation transforms draw from it.
random.seed(UNET_CONFIG['SEED'])
np.random.seed(UNET_CONFIG['SEED'])
torch.manual_seed(UNET_CONFIG['SEED'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(UNET_CONFIG['SEED'])

# Create save directory (exist_ok avoids a check-then-create race)
os.makedirs(UNET_CONFIG['MODEL_SAVE_PATH'], exist_ok=True)

print(f"U-Net configuration set with image size: {UNET_CONFIG['IMAGE_SIZE']}")
print(f"Number of landmark classes: {UNET_CONFIG['N_CLASSES']}")

In [None]:
# Dataset locations for RetinaNet; edit these to point at your data.
# The CSV entries are used when DATASET_TYPE == 'csv', coco_path when it is 'coco'.
RETINANET_DATA_PATHS = dict(
    csv_train='./dataset/annotations_train.csv',
    csv_val='./dataset/annotations_val.csv',
    csv_classes='./dataset/classes.csv',
    coco_path='./dataset/coco/',  # if using COCO format
)

print("RetinaNet dataset paths configured. "
      "Update the paths above to match your dataset location.")

In [None]:
def _build_unet_transforms():
    """Return the (train, val) transform pipelines shared by the U-Net loaders.

    Training applies the random augmentations from Unet.preprocessing before
    ToTensor; validation only applies ToTensor.
    """
    transform_train = transforms.Compose([
        Gamma_2D(),
        Shift_2D(),
        RandomBrightness(),
        Rotation_2D(),
        RandomSharp(),
        RandomBlur(),
        RandomNoise(),
        Invert(),
        RandomClahe(),
        ToTensor(),
    ])
    transform_val = transforms.Compose([
        ToTensor(),
    ])
    return transform_train, transform_val


def create_unet_dataloaders_per_landmark(x_train, y_train, x_val, y_val, config):
    """
    Create one (train, val) DataLoader pair per landmark point.

    Args:
        x_train, y_train: training image paths and label paths (as produced by
            prepare_unet_dataset).
        x_val, y_val: validation image paths and label paths.
        config: dict providing 'NUM_LANDMARKS', 'IMAGE_SIZE', 'TRAIN_BATCH'
            and 'TEST_BATCH'.

    Returns:
        (train_loaders, val_loaders): lists of length config['NUM_LANDMARKS'];
        index i holds the loaders for landmark i (landmark_num=i).
    """
    transform_train, transform_val = _build_unet_transforms()

    train_loaders = []
    val_loaders = []
    for landmark_idx in range(config['NUM_LANDMARKS']):
        trainset = Dental_Single_Data_Generator(
            config['IMAGE_SIZE'], x_train, y_train,
            landmark_num=landmark_idx, mode="train", transform=transform_train
        )
        # Validation also uses mode="train", presumably so the generator returns
        # ground-truth masks — TODO confirm against Dental_Single_Data_Generator.
        valset = Dental_Single_Data_Generator(
            config['IMAGE_SIZE'], x_val, y_val,
            landmark_num=landmark_idx, mode="train", transform=transform_val
        )

        train_loaders.append(DataLoader(trainset, batch_size=config['TRAIN_BATCH'], shuffle=True))
        val_loaders.append(DataLoader(valset, batch_size=config['TEST_BATCH'], shuffle=False))

    return train_loaders, val_loaders


# Original name kept for existing callers. NOTE: the "U-Net Data Transforms and
# Loaders" cell in Section 2 redefines `create_unet_dataloaders` as a
# single-landmark version, so after Run All use the `_per_landmark` name.
create_unet_dataloaders = create_unet_dataloaders_per_landmark


# Alternative: Single multi-class dataset approach
def create_unet_dataloaders_multiclass(x_train, y_train, x_val, y_val, config):
    """
    Create a single (train, val) DataLoader pair covering all landmarks.

    Args:
        x_train, y_train, x_val, y_val: image / label paths for each split.
        config: dict providing 'IMAGE_SIZE', 'TRAIN_BATCH' and 'TEST_BATCH'.

    Returns:
        (trainloader, valloader)
    """
    transform_train, transform_val = _build_unet_transforms()

    # landmark_num=-1 is meant to request all landmarks at once (per this
    # notebook's data-format notes) — confirm the generator supports it.
    trainset = Dental_Single_Data_Generator(
        config['IMAGE_SIZE'], x_train, y_train,
        landmark_num=-1, mode="train", transform=transform_train
    )
    valset = Dental_Single_Data_Generator(
        config['IMAGE_SIZE'], x_val, y_val,
        landmark_num=-1, mode="train", transform=transform_val
    )

    trainloader = DataLoader(trainset, batch_size=config['TRAIN_BATCH'], shuffle=True)
    valloader = DataLoader(valset, batch_size=config['TEST_BATCH'], shuffle=False)

    return trainloader, valloader

# Create data loaders (uncomment when dataset is ready)
# Option 1: Separate models for each landmark
# train_loaders, val_loaders = create_unet_dataloaders_per_landmark(x_train, y_train, x_val, y_val, UNET_CONFIG)

# Option 2: Single multi-class model
# trainloader, valloader = create_unet_dataloaders_multiclass(x_train, y_train, x_val, y_val, UNET_CONFIG)

In [None]:
def create_retinanet_dataloaders(config, data_paths):
    """
    Build RetinaNet train/val datasets and aspect-ratio-batched DataLoaders.

    Args:
        config: dict with 'DATASET_TYPE' ('csv' or 'coco') and 'BATCH_SIZE';
            optional 'NUM_WORKERS' (default 0, as before).
        data_paths: dict with 'csv_train', 'csv_val' and 'csv_classes' for CSV
            datasets, or 'coco_path' for COCO.

    Returns:
        (dataloader_train, dataloader_val, dataset_train, dataset_val)

    Raises:
        ValueError: if config['DATASET_TYPE'] is neither 'csv' nor 'coco'
            (previously this surfaced as an UnboundLocalError).
    """
    dataset_type = config['DATASET_TYPE']
    if dataset_type not in ('csv', 'coco'):
        raise ValueError(f"Unsupported DATASET_TYPE {dataset_type!r}; expected 'csv' or 'coco'")
    num_workers = config.get('NUM_WORKERS', 0)

    # Augmenter is applied to the training split only.
    train_transform = transforms.Compose([Normalizer(), Augmenter(), Resizer()])
    val_transform = transforms.Compose([Normalizer(), Resizer()])

    if dataset_type == 'csv':
        dataset_train = CSVDataset(
            train_file=data_paths['csv_train'],
            class_list=data_paths['csv_classes'],
            transform=train_transform
        )
        dataset_val = CSVDataset(
            train_file=data_paths['csv_val'],
            class_list=data_paths['csv_classes'],
            transform=val_transform
        )
    else:  # 'coco'
        dataset_train = CocoDataset(
            data_paths['coco_path'],
            set_name='train2017',
            transform=train_transform
        )
        dataset_val = CocoDataset(
            data_paths['coco_path'],
            set_name='val2017',
            transform=val_transform
        )

    sampler_train = AspectRatioBasedSampler(dataset_train, batch_size=config['BATCH_SIZE'], drop_last=False)
    dataloader_train = DataLoader(dataset_train, num_workers=num_workers, collate_fn=collater, batch_sampler=sampler_train)

    # Validation is batched one image at a time.
    sampler_val = AspectRatioBasedSampler(dataset_val, batch_size=1, drop_last=False)
    dataloader_val = DataLoader(dataset_val, num_workers=num_workers, collate_fn=collater, batch_sampler=sampler_val)

    return dataloader_train, dataloader_val, dataset_train, dataset_val

# Create data loaders (uncomment when you have dataset paths configured)
# dataloader_train, dataloader_val, dataset_train, dataset_val = create_retinanet_dataloaders(RETINANET_CONFIG, RETINANET_DATA_PATHS)
# print(f'Num training images: {len(dataset_train)}')
# print(f'Num validation images: {len(dataset_val)}')
# print(f'Num classes: {dataset_train.num_classes()}')

In [None]:
def _run_unet_epoch(model, loader, device, optimizer=None):
    """
    Run one pass of `model` over `loader` and return (mean loss, mean dice).

    With an `optimizer` the pass trains (model.train(), gradients, one optimizer
    step per batch); without one it evaluates (model.eval(), no gradients).
    Model outputs go through a sigmoid before `dice_loss` / `dice`.

    Raises:
        ValueError: if `loader` yields no batches (the means would be NaN).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    losses = []
    dice_scores = []
    with torch.set_grad_enabled(training):
        for sample in loader:
            images = sample['image'].to(device)
            masks = sample['landmarks'].to(device)

            outputs = torch.sigmoid(model(images))
            loss = dice_loss(outputs, masks)
            dice_score = dice(outputs, masks)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            losses.append(loss.item())
            dice_scores.append(dice_score.item())

    if not losses:
        raise ValueError('DataLoader yielded no batches; check the dataset split and batch size.')
    return float(np.mean(losses)), float(np.mean(dice_scores))


def _save_unet_state(model, save_path):
    """Save `model`'s state dict to `save_path`, unwrapping a `.module` wrapper if present."""
    target = model.module if hasattr(model, 'module') else model
    torch.save(target.state_dict(), save_path)


def _print_unet_epoch(epoch, total_epochs, train_metrics, val_metrics):
    """Print (loss, dice) metrics on every 5th epoch and on the final epoch."""
    if epoch % 5 == 0 or epoch == total_epochs - 1:
        print(f'Epoch {epoch+1}/{total_epochs}:')
        print(f'  Train - Loss: {train_metrics[0]:.5f}, Dice: {train_metrics[1]:.5f}')
        print(f'  Val   - Loss: {val_metrics[0]:.5f}, Dice: {val_metrics[1]:.5f}')


def _fit_unet(model, trainloader, valloader, config, save_path, device):
    """
    Train `model` with Adam for config['EPOCHS'] epochs, saving its weights to
    `save_path` whenever the mean validation loss improves.

    Returns:
        (train_losses, val_losses, best_val_loss): per-epoch mean losses and
        the lowest validation loss seen.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=config['LEARNING_RATE'])
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')

    for epoch in range(config['EPOCHS']):
        train_metrics = _run_unet_epoch(model, trainloader, device, optimizer)
        val_metrics = _run_unet_epoch(model, valloader, device)

        train_losses.append(train_metrics[0])
        val_losses.append(val_metrics[0])
        _print_unet_epoch(epoch, config['EPOCHS'], train_metrics, val_metrics)

        if val_metrics[0] < best_val_loss:
            best_val_loss = val_metrics[0]
            _save_unet_state(model, save_path)
            if epoch % 5 == 0:
                print(f'  Best model saved: {save_path} (Val Loss: {best_val_loss:.5f})')

    return train_losses, val_losses, best_val_loss


def train_unet_multiple_landmarks(train_loaders, val_loaders, config):
    """
    Train separate U-Net models for each landmark point (following the original approach).

    Args:
        train_loaders, val_loaders: equal-length lists; index i holds the
            loaders for landmark i.
        config: dict with 'LEARNING_RATE', 'EPOCHS' and 'MODEL_SAVE_PATH', plus
            whatever create_unet_model reads. 'N_CLASSES' is overridden to 1,
            because each per-landmark loader carries a single landmark.

    Returns:
        (models, training_histories). The models hold their final-epoch
        weights; the best weights (lowest val loss) are saved to
        MODEL_SAVE_PATH/<landmark_idx>/weight.pth.

    Raises:
        ValueError: if the two loader lists differ in length.
    """
    if len(train_loaders) != len(val_loaders):
        raise ValueError(
            f'Got {len(train_loaders)} train loaders but {len(val_loaders)} val loaders')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_landmarks = len(train_loaders)
    # One output channel per landmark model. The first UNET_CONFIG cell sets
    # N_CLASSES=14, which is meant for the multi-class approach.
    landmark_config = dict(config, N_CLASSES=1)

    models = []
    training_histories = []
    for landmark_idx, (trainloader, valloader) in enumerate(zip(train_loaders, val_loaders)):
        print(f"\n{'='*60}")
        print(f"Training U-Net for Landmark {landmark_idx + 1}/{num_landmarks}")
        print(f"{'='*60}")

        model = create_unet_model(landmark_config)
        model.to(device)

        landmark_dir = os.path.join(config['MODEL_SAVE_PATH'], str(landmark_idx))
        os.makedirs(landmark_dir, exist_ok=True)
        save_path = os.path.join(landmark_dir, 'weight.pth')

        train_losses, val_losses, best_val_loss = _fit_unet(
            model, trainloader, valloader, config, save_path, device)

        print(f'Landmark {landmark_idx + 1} training completed! Best val loss: {best_val_loss:.5f}')

        models.append(model)
        training_histories.append({
            'train_losses': train_losses,
            'val_losses': val_losses,
            'best_val_loss': best_val_loss
        })

    return models, training_histories


def train_unet_multiclass(trainloader, valloader, config):
    """
    Train a single U-Net model for all landmark points (multi-class approach).

    Args:
        trainloader, valloader: loaders whose 'landmarks' hold all landmark
            masks (assumed shape [batch, num_landmarks, H, W] — TODO confirm).
        config: dict with 'LEARNING_RATE', 'EPOCHS' and 'MODEL_SAVE_PATH', plus
            whatever create_unet_model reads; 'NUM_LANDMARKS' is only used for
            logging and falls back to 'N_CLASSES'.

    Returns:
        (model, train_losses, val_losses). The model holds its final-epoch
        weights; the best weights are saved to
        MODEL_SAVE_PATH/best_multiclass_unet.pth.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = create_unet_model(config)
    model.to(device)

    num_landmarks = config.get('NUM_LANDMARKS', config['N_CLASSES'])
    print(f"Training multi-class U-Net for {num_landmarks} landmarks")
    print(f"{'='*60}")

    os.makedirs(config['MODEL_SAVE_PATH'], exist_ok=True)
    save_path = os.path.join(config['MODEL_SAVE_PATH'], 'best_multiclass_unet.pth')

    train_losses, val_losses, best_val_loss = _fit_unet(
        model, trainloader, valloader, config, save_path, device)

    print(f'Multi-class training completed! Best validation loss: {best_val_loss:.5f}')

    return model, train_losses, val_losses


# Train U-Net models (uncomment when ready)
# Option 1: Train separate models for each landmark (original approach)
# models, histories = train_unet_multiple_landmarks(train_loaders, val_loaders, UNET_CONFIG)

# Option 2: Train single multi-class model
# model, train_losses, val_losses = train_unet_multiclass(trainloader, valloader, UNET_CONFIG)

In [None]:
def _plot_landmark_grid(models, image_batch, num_rows, num_cols):
    """Plot one input image (batch of size 1) and each model's sigmoid prediction for it."""
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 4, num_rows * 4), squeeze=False)
    axes = axes.ravel()

    axes[0].imshow(image_batch[0].cpu().numpy().squeeze(), cmap='gray')
    axes[0].set_title('Input Image')

    for landmark_idx, model in enumerate(models):
        prediction = torch.sigmoid(model(image_batch)).cpu().numpy().squeeze()
        ax = axes[landmark_idx + 1]
        ax.imshow(prediction, cmap='gray')
        ax.set_title(f'Landmark {landmark_idx + 1}')

    # Hide ticks on every panel, including unused trailing subplots.
    for ax in axes:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


def visualize_unet_predictions(models, val_loaders, config, num_samples=2):
    """
    Visualize U-Net predictions for multiple landmark models.

    For up to `num_samples` validation images from val_loaders[0], read across
    batches if needed, shows the input followed by every landmark model's
    sigmoid output in a grid of at most 5 columns.

    Args:
        models: per-landmark models; index i predicts landmark i.
        val_loaders: per-landmark validation loaders. Only val_loaders[0] is
            read (the input images are presumably identical across landmarks —
            TODO confirm).
        config: kept for interface compatibility; the number of panels follows
            len(models).
        num_samples: maximum number of images to plot.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for model in models:
        model.eval()

    num_landmarks = len(models)
    num_cols = min(5, num_landmarks + 1)  # input image + landmarks, at most 5 panels per row
    num_rows = (num_landmarks + num_cols) // num_cols  # == ceil((num_landmarks + 1) / num_cols)

    shown = 0
    with torch.no_grad():
        for sample_data in val_loaders[0]:
            images = sample_data['image'].to(device)
            for sample_idx in range(images.shape[0]):
                if shown >= num_samples:
                    return
                _plot_landmark_grid(models, images[sample_idx:sample_idx + 1], num_rows, num_cols)
                shown += 1


def visualize_unet_multiclass_predictions(model, valloader, config, num_samples=2):
    """
    Visualize U-Net predictions for the multi-class model.

    For the first image of each of the first `num_samples` batches, draws one
    row per landmark: input, ground truth, prediction, and input + prediction.

    Args:
        model: multi-class model returning one channel per landmark.
        valloader: yields dicts with 'image' and 'landmarks'.
        config: 'NUM_LANDMARKS' (falling back to 'N_CLASSES') sets the row count.
        num_samples: number of batches to visualize.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    num_landmarks = config.get('NUM_LANDMARKS', config['N_CLASSES'])

    with torch.no_grad():
        for i, sample in enumerate(valloader):
            if i >= num_samples:
                break

            images = sample['image'].to(device)
            masks = sample['landmarks'].to(device)
            outputs = torch.sigmoid(model(images))

            image = images[0].cpu().numpy().squeeze()
            ground_truth = masks[0].cpu().numpy()  # assumed [num_landmarks, H, W] — TODO confirm
            predictions = outputs[0].cpu().numpy()  # [num_landmarks, H, W] if one channel per landmark

            # squeeze=False keeps a 2-D axes array even when num_landmarks == 1.
            fig, axes = plt.subplots(num_landmarks, 4, figsize=(16, num_landmarks * 4), squeeze=False)

            for row in range(num_landmarks):
                panels = (
                    (image, f'Input Image (Landmark {row + 1})'),
                    (ground_truth[row], f'Ground Truth {row + 1}'),
                    (predictions[row], f'Prediction {row + 1}'),
                    (image + predictions[row], f'Overlay {row + 1}'),
                )
                for col, (data, title) in enumerate(panels):
                    ax = axes[row, col]
                    ax.imshow(data, cmap='gray')
                    ax.set_title(title)
                    ax.axis('off')

            plt.tight_layout()
            plt.show()


def plot_training_history_multiple_landmarks(histories, config):
    """
    Plot and summarize training history for the per-landmark models.

    Args:
        histories: list of dicts with 'train_losses', 'val_losses' and
            'best_val_loss', one per landmark.
        config: kept for interface compatibility; the number of bars follows
            len(histories).
    """
    fig, axes = plt.subplots(2, 1, figsize=(12, 10))

    for i, history in enumerate(histories):
        axes[0].plot(history['train_losses'], label=f'Train Loss - Landmark {i+1}', alpha=0.7)
        axes[0].plot(history['val_losses'], label=f'Val Loss - Landmark {i+1}', alpha=0.7, linestyle='--')

    axes[0].set_title('Training and Validation Loss for All Landmarks')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    axes[0].grid(True, alpha=0.3)

    best_losses = [h['best_val_loss'] for h in histories]
    axes[1].bar(range(1, len(best_losses) + 1), best_losses)
    axes[1].set_title('Best Validation Loss per Landmark')
    axes[1].set_xlabel('Landmark Index')
    axes[1].set_ylabel('Best Validation Loss')
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print("\nTraining Summary:")
    print("=" * 50)
    for i, history in enumerate(histories):
        print(f"Landmark {i+1:2d}: Best Val Loss = {history['best_val_loss']:.5f}")
    print(f"\nAverage Best Val Loss: {np.mean(best_losses):.5f}")
    print(f"Std Best Val Loss: {np.std(best_losses):.5f}")


# Visualization functions (uncomment when models are trained)
# Option 1: Visualize multiple landmark models
# visualize_unet_predictions(models, val_loaders, UNET_CONFIG)
# plot_training_history_multiple_landmarks(histories, UNET_CONFIG)

# Option 2: Visualize multi-class model
# visualize_unet_multiclass_predictions(model, valloader, UNET_CONFIG)

## Usage Instructions

### To train RetinaNet:
1. Update the `RETINANET_DATA_PATHS` dictionary with your dataset paths
2. Uncomment the RetinaNet data loading and model creation cells
3. Run the RetinaNet training cell

### To train U-Net for 14 landmarks:

**Option 1: Separate models for each landmark (Original approach)**
- This approach trains 14 separate U-Net models, one for each landmark
- Each model specializes in detecting a specific landmark point
- Follows the original codebase structure

**Option 2: Single multi-class model**
- This approach trains one U-Net model that outputs all 14 landmarks simultaneously
- More efficient in terms of memory and training time
- Requires landmarks to be formatted as multi-channel masks

### Steps to train U-Net:
1. Update the `UNET_DATA_PATHS` dictionary with your dataset paths
2. Choose your training approach (Option 1 or Option 2)
3. Uncomment the appropriate data loading cells
4. Run the U-Net training cell

### Configuration Details:
- **NUM_LANDMARKS**: Set to 14 for your landmark points
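- **N_CLASSES**: Set to 14 for the multi-class approach, or 1 for individual landmark models. Note that `UNET_CONFIG` is defined twice in this notebook: `N_CLASSES=14` in the first U-Net cell and `N_CLASSES=1` in the Section 2 cell. When running top to bottom, the later definition wins.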
- **IMAGE_SIZE**: Adjust based on your input image dimensions
- **BATCH_SIZE** (RetinaNet) / **TRAIN_BATCH** and **TEST_BATCH** (U-Net): Adjust based on your GPU memory

### Model Saving:
- **Separate models**: Saves each landmark model in `./unet_weights/{landmark_idx}/weight.pth`
- **Multi-class model**: Saves in `./unet_weights/best_multiclass_unet.pth`

### Notes:
- Make sure you have the required dependencies installed
- The original codebase uses the separate models approach (Option 1)
- Adjust hyperparameters in the configuration dictionaries as needed
- Monitor GPU memory usage and adjust batch sizes if necessary
- Both approaches save a checkpoint whenever the validation loss improves during training

### Expected Data Format:
- Images: PNG format in the specified image directory
- Labels: NumPy (.npy) format containing landmark annotations
- The data generator's `landmark_num` parameter selects which landmarks to load: 0-13 for an individual landmark, -1 for all landmarks

In [None]:
def train_retinanet(retinanet, dataloader_train, dataloader_val, config, dataset_type='csv'):
    """
    Train RetinaNet, evaluate after every epoch, and checkpoint to config['MODEL_SAVE_PATH'].

    Args:
        retinanet: the RetinaNet model, optionally wrapped (e.g. in
            DataParallel); its `.module` is used for freeze_bn and per-epoch
            checkpoints when present.
        dataloader_train: yields dicts with 'img' and 'annot'.
        dataloader_val: validation DataLoader or dataset. The evaluators receive
            the underlying dataset (`dataloader_val.dataset` when available),
            since they index images and annotations directly — TODO confirm
            against the installed retinanet.csv_eval / coco_eval.
        config: dict with 'LEARNING_RATE', 'EPOCHS' and 'MODEL_SAVE_PATH'.
        dataset_type: 'csv' (mean AP computed, best model tracked) or 'coco'.

    Raises:
        ValueError: if dataset_type is neither 'csv' nor 'coco'.
    """
    if dataset_type not in ('csv', 'coco'):
        raise ValueError(f"Unsupported dataset_type {dataset_type!r}; expected 'csv' or 'coco'")

    core = getattr(retinanet, 'module', retinanet)  # unwrap DataParallel-style wrappers
    eval_dataset = getattr(dataloader_val, 'dataset', dataloader_val)
    os.makedirs(config['MODEL_SAVE_PATH'], exist_ok=True)

    optimizer = optim.Adam(retinanet.parameters(), lr=config['LEARNING_RATE'])
    # `verbose=` is deprecated for LR schedulers in recent PyTorch; LR changes
    # are logged manually below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)

    loss_hist = collections.deque(maxlen=500)  # running window for the logged loss
    best_map = 0.0
    use_cuda = torch.cuda.is_available()

    for epoch_num in range(config['EPOCHS']):
        # Training phase
        retinanet.train()
        core.freeze_bn()

        epoch_loss = []

        for iter_num, data in enumerate(dataloader_train):
            try:
                optimizer.zero_grad()

                images = data['img'].cuda().float() if use_cuda else data['img'].float()
                classification_loss, regression_loss = retinanet([images, data['annot']])

                classification_loss = classification_loss.mean()
                regression_loss = regression_loss.mean()
                loss = classification_loss + regression_loss

                if bool(loss == 0):
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.1)
                optimizer.step()

                loss_hist.append(float(loss))
                epoch_loss.append(float(loss))

                if iter_num % 10 == 0:
                    print(f'Epoch: {epoch_num} | Iteration: {iter_num} | '
                          f'Classification loss: {float(classification_loss):.5f} | '
                          f'Regression loss: {float(regression_loss):.5f} | '
                          f'Running loss: {np.mean(loss_hist):.5f}')

                del classification_loss, regression_loss

            except Exception as e:
                # Deliberately best-effort: log the failing batch and keep training.
                print(f'Error in training iteration: {e}')
                continue

        # Validation phase
        print('Evaluating dataset...')

        if dataset_type == 'coco':
            coco_eval.evaluate_coco(eval_dataset, retinanet)
        else:
            mAP = csv_eval.evaluate(eval_dataset, retinanet)

            # evaluate() presumably returns {label: (average_precision, num_annotations)};
            # iterating the dict itself would yield the integer labels.
            per_class = mAP.values() if isinstance(mAP, dict) else mAP
            average_precisions = [ap[0] for ap in per_class]
            mean_map = float(np.mean(average_precisions)) if average_precisions else 0.0
            print(f'Mean mAP: {mean_map:.4f}')

            if mean_map > best_map:
                best_map = mean_map
                model_path = os.path.join(config['MODEL_SAVE_PATH'], f'best_retinanet_epoch_{epoch_num}.pt')
                torch.save(core, model_path)
                print(f'Best model saved: {model_path} (mAP: {best_map:.4f})')

        # Update learning rate; skip when no iteration succeeded (mean would be NaN).
        if epoch_loss:
            previous_lrs = [group['lr'] for group in optimizer.param_groups]
            scheduler.step(np.mean(epoch_loss))
            for group_idx, (old_lr, group) in enumerate(zip(previous_lrs, optimizer.param_groups)):
                if group['lr'] != old_lr:
                    print(f'Reducing learning rate of group {group_idx} to {group["lr"]:.4e}.')
        else:
            print(f'Epoch {epoch_num}: no successful training iterations; skipping LR scheduler step.')

        # Save per-epoch checkpoint (whole unwrapped module, pickled by torch.save)
        checkpoint_path = os.path.join(config['MODEL_SAVE_PATH'], f'retinanet_epoch_{epoch_num}.pt')
        torch.save(core, checkpoint_path)

    # Save final model. This saves `retinanet` as passed in (including any wrapper),
    # unlike the per-epoch checkpoints above, which save the unwrapped module.
    retinanet.eval()
    final_model_path = os.path.join(config['MODEL_SAVE_PATH'], 'retinanet_final.pt')
    torch.save(retinanet, final_model_path)
    print(f'Training completed! Final model saved: {final_model_path}')

# Train RetinaNet (uncomment when ready)
# train_retinanet(retinanet, dataloader_train, dataloader_val, RETINANET_CONFIG, RETINANET_CONFIG['DATASET_TYPE'])

## 2. U-Net Training

### U-Net Setup and Configuration

In [None]:
# U-Net imports
from Unet.trainer import train, val
from Unet.loss import dice_loss, dice
from Unet.Unet import UNet
from Unet.preprocessing import *
from Unet.datagenerater import Dental_Single_Data_Generator
from Unet.utils import *
from Unet.progressbar import Bar

# U-Net Configuration
# NOTE: this replaces the UNET_CONFIG defined in the earlier U-Net setup cell;
# on Run All, this definition is the one in effect.
UNET_CONFIG = {
    'IMAGE_SIZE': (512, 512),
    'N_CLASSES': 1,  # one output channel per model; use 14 for the multi-class approach
    'TRAIN_BATCH': 4,
    'TEST_BATCH': 1,
    'EPOCHS': 50,
    'LEARNING_RATE': 5e-4,
    'SEED': 42,
    'MODEL_SAVE_PATH': './unet_weights/',
    'ENCODER_NAME': 'vgg16',  # or 'timm-tf_efficientnet_lite4'
    'USE_ATTENTION': True,
    # The per-landmark / multi-class helpers read this key, so keep it after
    # this cell replaces the earlier configuration.
    'NUM_LANDMARKS': 14
}

# Set random seeds for reproducibility. NumPy's global RNG is seeded as well,
# in case the augmentation transforms draw from it.
random.seed(UNET_CONFIG['SEED'])
np.random.seed(UNET_CONFIG['SEED'])
torch.manual_seed(UNET_CONFIG['SEED'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(UNET_CONFIG['SEED'])

# Create save directory (exist_ok avoids a check-then-create race)
os.makedirs(UNET_CONFIG['MODEL_SAVE_PATH'], exist_ok=True)

print(f"U-Net configuration set with image size: {UNET_CONFIG['IMAGE_SIZE']}")

### U-Net Dataset Configuration

In [None]:
# Configure your U-Net dataset paths
UNET_DATA_PATHS = {
    'image_path': './dataset/images/',
    'label_path': './dataset/labels/',
    'train_split': 0.8  # 80% for training, 20% for validation
}

def prepare_unet_dataset(data_paths):
    """
    Pair U-Net images with their label files and split them into train/val sets.

    Each `*.npy` file in data_paths['label_path'] is matched to a `*.png` file
    in data_paths['image_path']:
      1. by identical file stem (e.g. ``12.npy`` -> ``12.png``), otherwise
      2. by the first image, in natural sort order, whose file name contains
         the label stem (the previous matching rule).
    Labels with no matching image are skipped and their count is reported.

    The split is deterministic (natural sort order, no shuffling): the first
    `train_split` fraction of the pairs is used for training, the rest for
    validation.

    Returns:
        (x_train, y_train, x_val, y_val): lists of image / label file paths.
    """
    image_files = natsorted(glob.glob(os.path.join(data_paths['image_path'], '*.png')))
    label_files = natsorted(glob.glob(os.path.join(data_paths['label_path'], '*.npy')))

    # Exact-stem lookup avoids substring false positives (label "2" -> "12.png")
    # and replaces the O(images x labels) scan in the common case.
    images_by_stem = {}
    for img in image_files:
        images_by_stem.setdefault(os.path.splitext(os.path.basename(img))[0], img)

    matched_pairs = []
    unmatched = 0
    for label_file in label_files:
        base_name = os.path.splitext(os.path.basename(label_file))[0]
        image = images_by_stem.get(base_name)
        if image is None:
            # Fall back to substring matching, e.g. label "12" -> "12_xray.png".
            image = next((img for img in image_files if base_name in os.path.basename(img)), None)
        if image is None:
            unmatched += 1
            continue
        matched_pairs.append((image, label_file))

    if unmatched:
        print(f'Warning: {unmatched} label file(s) had no matching image and were skipped')

    # Split into train/val
    split_idx = int(len(matched_pairs) * data_paths['train_split'])
    train_pairs = matched_pairs[:split_idx]
    val_pairs = matched_pairs[split_idx:]

    x_train = [img for img, _ in train_pairs]
    y_train = [label for _, label in train_pairs]
    x_val = [img for img, _ in val_pairs]
    y_val = [label for _, label in val_pairs]

    print(f'Training samples: {len(x_train)}')
    print(f'Validation samples: {len(x_val)}')

    return x_train, y_train, x_val, y_val

# Prepare dataset (uncomment when you have dataset paths configured)
# x_train, y_train, x_val, y_val = prepare_unet_dataset(UNET_DATA_PATHS)

### U-Net Data Transforms and Loaders

In [None]:
def create_unet_dataloaders(x_train, y_train, x_val, y_val, config):
    """
    Build the training and validation DataLoaders for U-Net.

    Training samples pass through a chain of augmentations from
    ``Unet.preprocessing`` followed by ``ToTensor``. Validation samples only get
    ``ToTensor``. Both datasets are ``Dental_Single_Data_Generator`` instances
    built with ``config['IMAGE_SIZE']``, ``landmark_num=0`` and ``mode="train"``.

    Args:
        x_train, y_train: training image / label file paths.
        x_val, y_val: validation image / label file paths.
        config: dict providing ``'IMAGE_SIZE'``, ``'TRAIN_BATCH'`` and ``'TEST_BATCH'``.

    Returns:
        ``(trainloader, valloader)``. Only the training loader shuffles.
    """
    # Augmentations applied (in this order) to training samples only.
    augmentations = [
        Gamma_2D(),
        Shift_2D(),
        RandomBrightness(),
        Rotation_2D(),
        RandomSharp(),
        RandomBlur(),
        RandomNoise(),
        Invert(),
        RandomClahe(),
    ]
    train_pipeline = transforms.Compose(augmentations + [ToTensor()])
    val_pipeline = transforms.Compose([ToTensor()])

    image_size = config['IMAGE_SIZE']
    trainset = Dental_Single_Data_Generator(
        image_size, x_train, y_train,
        landmark_num=0, mode="train", transform=train_pipeline
    )
    # NOTE(review): the validation set is also built with mode="train"; confirm
    # this is intended for Dental_Single_Data_Generator.
    valset = Dental_Single_Data_Generator(
        image_size, x_val, y_val,
        landmark_num=0, mode="train", transform=val_pipeline
    )

    return (
        DataLoader(trainset, batch_size=config['TRAIN_BATCH'], shuffle=True),
        DataLoader(valset, batch_size=config['TEST_BATCH'], shuffle=False),
    )

# Create data loaders (uncomment when dataset is ready)
# trainloader, valloader = create_unet_dataloaders(x_train, y_train, x_val, y_val, UNET_CONFIG)

### U-Net Model Setup

In [None]:
def create_unet_model(config):
    """
    Build a U-Net, optionally restore saved weights, and move it to the GPU.

    If ``segmentation_models_pytorch`` is importable, an ``smp.Unet`` is built with
    ``config['ENCODER_NAME']`` as the encoder. It uses ``scse`` decoder attention
    when ``config['USE_ATTENTION']`` is truthy. Otherwise the local ``UNet``
    implementation is used. Both take 1 input channel and produce
    ``config['N_CLASSES']`` output channels.

    If any ``*.pth`` files exist in ``config['MODEL_SAVE_PATH']``, the last one in
    natural-sort order is loaded as a state dict. Loading is best-effort: on
    failure a message is printed and the freshly initialised weights are kept.

    When CUDA is available the model is moved to the GPU. It is wrapped in
    ``torch.nn.DataParallel`` if more than one GPU is visible.

    Returns:
        The (possibly ``DataParallel``-wrapped) model.
    """
    try:
        # Try to use segmentation_models_pytorch
        import segmentation_models_pytorch as smp

        # smp.Unet's default decoder_attention_type is None (no attention), so a
        # single call covers both configurations.
        model = smp.Unet(
            encoder_name=config['ENCODER_NAME'],
            decoder_attention_type='scse' if config['USE_ATTENTION'] else None,
            in_channels=1,
            classes=config['N_CLASSES']
        )
        print(f'Created U-Net with {config["ENCODER_NAME"]} encoder')

    except ImportError:
        # Fallback to basic U-Net implementation
        model = UNet(n_channels=1, n_classes=config['N_CLASSES'])
        print('Created basic U-Net model')

    # Load pretrained weights if available
    weight_files = glob.glob(os.path.join(config['MODEL_SAVE_PATH'], '*.pth'))
    if weight_files:
        latest_weight = natsorted(weight_files)[-1]
        try:
            # map_location='cpu' lets checkpoints saved from GPU tensors load on
            # CPU-only machines; the model is moved to the GPU below if available.
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            state_dict = torch.load(latest_weight, map_location='cpu')
            model.load_state_dict(state_dict)
            print(f'Loaded pretrained weights from: {latest_weight}')
        except Exception as e:
            print(f'Could not load weights: {e}')

    # Move to GPU if available
    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            print(f"Using {torch.cuda.device_count()} GPUs for training")
            model = torch.nn.DataParallel(model)
        model = model.cuda()

    return model

# Create model (uncomment when ready)
# unet_model = create_unet_model(UNET_CONFIG)

### U-Net Training Loop

In [None]:
def _run_unet_epoch(model, loader, device, optimizer=None):
    """
    Run one pass over ``loader``; trains when ``optimizer`` is given, otherwise evaluates.

    Each batch dict's ``'image'`` and ``'landmarks'`` tensors are moved to
    ``device``. The model output goes through a sigmoid and is scored with
    ``dice_loss`` and ``dice``.

    Returns:
        ``(mean_loss, mean_dice)`` over batches, or ``(nan, nan)`` if the loader
        yielded no batches.
    """
    training = optimizer is not None
    model.train(training)
    losses, scores = [], []
    # Gradients are only needed for the training pass.
    with torch.set_grad_enabled(training):
        for sample in loader:
            images = sample['image'].to(device)
            masks = sample['landmarks'].to(device)

            outputs = torch.sigmoid(model(images))
            loss = dice_loss(outputs, masks)
            dice_score = dice(outputs, masks)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            losses.append(loss.item())
            scores.append(dice_score.item())

    if not losses:
        # Avoid numpy's "mean of empty slice" warning; NaN signals "no data".
        return float('nan'), float('nan')
    return np.mean(losses), np.mean(scores)


def _plot_unet_history(train_losses, val_losses, train_dices, val_dices, save_dir):
    """Plot per-epoch loss and Dice curves side by side and save them as training_history.png in save_dir."""
    epochs = range(1, len(train_losses) + 1)
    fig, (ax_loss, ax_dice) = plt.subplots(1, 2, figsize=(12, 4))

    ax_loss.plot(epochs, train_losses, 'b-', label='Train Loss')
    ax_loss.plot(epochs, val_losses, 'r-', label='Val Loss')
    ax_loss.set(title='Training and Validation Loss', xlabel='Epoch', ylabel='Loss')
    ax_loss.legend()

    ax_dice.plot(epochs, train_dices, 'b-', label='Train Dice')
    ax_dice.plot(epochs, val_dices, 'r-', label='Val Dice')
    ax_dice.set(title='Training and Validation Dice', xlabel='Epoch', ylabel='Dice')
    ax_dice.legend()

    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, 'training_history.png'))
    plt.show()


def train_unet(model, trainloader, valloader, config):
    """
    Train U-Net with Adam and Dice loss, checkpointing the best validation loss.

    Each epoch runs one training pass and one validation pass. Whenever the
    mean validation loss improves, the model's state dict is saved to
    ``<MODEL_SAVE_PATH>/best_unet.pth``, unwrapping ``DataParallel`` via
    ``.module`` if present. After training, loss and Dice curves are plotted
    and saved as ``training_history.png``.

    Args:
        model: network to train; moved to CUDA if available, otherwise CPU.
        trainloader, valloader: yield dicts with ``'image'`` and ``'landmarks'``.
        config: dict providing ``'LEARNING_RATE'``, ``'EPOCHS'`` and ``'MODEL_SAVE_PATH'``.

    Returns:
        ``(model, train_losses, val_losses)``. ``model`` holds the weights after
        the final epoch (not necessarily the best ones, which are on disk).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=config['LEARNING_RATE'])

    # Per-epoch history
    train_losses, val_losses = [], []
    train_dices, val_dices = [], []
    best_val_loss = float('inf')
    save_path = os.path.join(config['MODEL_SAVE_PATH'], 'best_unet.pth')

    for epoch in range(config['EPOCHS']):
        avg_train_loss, avg_train_dice = _run_unet_epoch(model, trainloader, device, optimizer)
        avg_val_loss, avg_val_dice = _run_unet_epoch(model, valloader, device)

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        train_dices.append(avg_train_dice)
        val_dices.append(avg_val_dice)

        print(f'Epoch {epoch+1}/{config["EPOCHS"]}:')
        print(f'  Train - Loss: {avg_train_loss:.5f}, Dice: {avg_train_dice:.5f}')
        print(f'  Val   - Loss: {avg_val_loss:.5f}, Dice: {avg_val_dice:.5f}')

        if np.isnan(avg_val_loss):
            # NaN never compares smaller, so no checkpoint would ever be written.
            print('  Warning: validation loader produced no batches; checkpoint not updated.')

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
            torch.save(state_dict, save_path)
            print(f'  Best model saved: {save_path} (Val Loss: {best_val_loss:.5f})')

        print('-' * 50)

    _plot_unet_history(train_losses, val_losses, train_dices, val_dices, config['MODEL_SAVE_PATH'])

    print(f'Training completed! Best validation loss: {best_val_loss:.5f}')

    return model, train_losses, val_losses

# Train U-Net (uncomment when ready)
# trained_unet, train_losses, val_losses = train_unet(unet_model, trainloader, valloader, UNET_CONFIG)

### Visualization Functions

In [None]:
def visualize_unet_predictions(model, dataloader, config, num_samples=4):
    """
    Show input, ground truth, prediction and an overlay for U-Net outputs.

    For each of the first ``num_samples`` batches of ``dataloader``, only the
    first item of the batch is drawn. Multi-channel masks and predictions are
    collapsed to a single (H, W) map by taking the per-pixel maximum over
    channels so that ``imshow`` can render them; this assumes channel-first
    (C, H, W) tensors -- TODO confirm for the dataset's ``'landmarks'``.

    Args:
        model: trained network; inputs are sent to the device of its parameters.
        dataloader: yields dicts with ``'image'`` and ``'landmarks'`` tensors.
        config: unused; kept for interface compatibility.
        num_samples: number of batches to visualize.
    """
    # Use the model's own device so inputs never end up on a different device
    # than the weights (e.g. a CPU model on a CUDA-capable machine).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()

    def _to_2d(tensor):
        """Convert one sample tensor to a numpy array imshow can display."""
        array = tensor.detach().cpu().numpy().squeeze()
        # imshow accepts only (H, W), (H, W, 3) or (H, W, 4); a leading channel
        # axis such as N_CLASSES maps is collapsed with a max-projection.
        if array.ndim == 3:
            array = array.max(axis=0)
        return array

    with torch.no_grad():
        for batch_idx, sample in enumerate(dataloader):
            if batch_idx >= num_samples:
                break

            images = sample['image'].to(device)
            outputs = torch.sigmoid(model(images))

            image = _to_2d(images[0])
            mask = _to_2d(sample['landmarks'][0])
            prediction = _to_2d(outputs[0])

            fig, axes = plt.subplots(1, 4, figsize=(16, 4))
            panels = [(image, 'Input Image'), (mask, 'Ground Truth'), (prediction, 'Prediction')]
            for ax, (data, title) in zip(axes, panels):
                ax.imshow(data, cmap='gray')
                ax.set_title(title)
                ax.axis('off')

            # Overlay: prediction heat-map alpha-blended over the grayscale input.
            axes[3].imshow(image, cmap='gray')
            axes[3].imshow(prediction, cmap='jet', alpha=0.4)
            axes[3].set_title('Overlay')
            axes[3].axis('off')

            fig.tight_layout()
            plt.show()
            # Free the figure; many figures in a loop otherwise accumulate memory.
            plt.close(fig)

# Visualize predictions (uncomment when model is trained)
# visualize_unet_predictions(trained_unet, valloader, UNET_CONFIG)

## Usage Instructions

### To train RetinaNet:
1. Update the `RETINANET_DATA_PATHS` dictionary with your dataset paths
2. Uncomment the RetinaNet data loading and model creation cells
3. Run the RetinaNet training cell

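### To train U-Net:
1. Update the `UNET_DATA_PATHS` dictionary with your dataset paths (images as `*.png`, labels as `*.npy`).
2. Uncomment the calls at the end of the U-Net dataset, data-loader, and model cells (`prepare_unet_dataset`, `create_unet_dataloaders`, `create_unet_model`), then run those cells in order.
3. Uncomment the `train_unet(...)` call in the U-Net training cell and run it.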

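### Notes:
- Make sure you have the required dependencies installed.
- Adjust hyperparameters in the configuration dictionaries as needed.
- Monitor GPU memory usage and reduce batch sizes if you run out of memory.
- Both models save weights during training. The U-Net cell keeps only the checkpoint with the lowest validation loss (`best_unet.pth` in `UNET_CONFIG['MODEL_SAVE_PATH']`).
- `create_unet_model` automatically loads the last `*.pth` file (in natural-sort order) found in `UNET_CONFIG['MODEL_SAVE_PATH']`. Empty that directory to train from scratch.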