# DermoSegDiff Training Notebook

This notebook provides a complete pipeline for training DermoSegDiff on custom skin lesion segmentation data.

## Data Structure Expected:
```
data_dir/
├── train/
│   ├── images/ (*.jpg)
│   └── masks/ (*.png - binary 0,255)
├── val/
│   ├── images/ (*.jpg)
│   └── masks/ (*.png - binary 0,255)
└── test/
    └── images/ (*.jpg - no masks)
```

In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import yaml
from PIL import Image
import glob

# Make the project sources under `src` importable from this notebook.
sys.path.append('src')

# Pick the compute device once; later cells reuse the `device` variable.
has_cuda = torch.cuda.is_available()
device = "cuda" if has_cuda else "cpu"
print(f"Using device: {device}")

# Report GPU details only when CUDA is actually usable.
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"CUDA Version: {torch.version.cuda}")

## Configuration Setup

Define the training configuration for your custom dataset.

In [None]:
# Configuration for custom dataset.
# Every later cell reads its settings from this single nested dict.
config = {
    "dataset": {
        "name": "custom",
        "class_name": "CustomDatasetFast",
        "data_dir": "/path/to/your/data",  # placeholder; overwritten by the prompt below
        "input_size": 256,
        "img_channels": 3,
        "msk_channels": 1,
        "add_boundary_mask": False,
        "add_boundary_dist": False,
    },
    "model": {
        "name": "dermosegdiff_custom",
        "class": "DermoSegDiff",  # looked up by name via globals() when the model is built
        "save_dir": "weights",
        "params": {
            "dim_x": 128,
            "dim_g": 64,
            "channels_x": 1,
            "channels_g": 3,
            "dim_mults": [1, 2, 4, 8],
            "resnet_block_groups": 4,
        },
    },
    "training": {
        "epochs": 100,
        "optimizer": {
            "name": "AdamW",
            "params": {
                "lr": 1e-4,
                "betas": [0.9, 0.999],
                "weight_decay": 1e-4,
            },
        },
        # Keyword arguments forwarded to ReduceLROnPlateau.
        "scheduler": {
            "factor": 0.5,
            "patience": 10,
            "min_lr": 1e-7,
        },
        # The training loop reads this key to build the TensorBoard tag;
        # without it the loop fails with a KeyError after the first epoch.
        "loss_name": "mse",
        "loss": {
            "mse": {"weight": 1.0},
        },
        "ema": {
            "use": True,
            "params": {
                "beta": 0.9999,
                "update_every": 10,
            },
        },
        # NOTE(review): key spelling ("intial") kept as-is; presumably it matches
        # what the project code reads - confirm before renaming.
        "intial_weights": {
            "use": False,
            "file_path": "",
        },
    },
    "diffusion": {
        # Keyword arguments forwarded to ForwardSchedule.
        "schedule": {
            "timesteps": 1000,
            "mode": "cosine",
            "beta_start": 0.0001,
            "beta_end": 0.02,
        },
    },
    "data_loader": {
        "train": {
            "batch_size": 4,
            "shuffle": True,
            "num_workers": 4,
            "drop_last": True,
        },
        "validation": {
            "batch_size": 4,
            "shuffle": False,
            "num_workers": 4,
            "drop_last": False,
        },
        "test": {
            "batch_size": 1,
            "shuffle": False,
            "num_workers": 4,
            "drop_last": False,
        },
    },
    "run": {
        "device": device,
        "continue_training": False,
        "writer_dir": "runs",
    },
    "testing": {
        "ensemble": 1,
        # NOTE(review): key spelling ("weigths") kept as-is; presumably it matches
        # what the project code reads - confirm before renaming.
        "model_weigths": {
            "overload": False,
            "file_path": "",
        },
        "result_imgs": {
            "save": True,
            "dir": "results",
        },
    },
}

# IMPORTANT: Change this to your actual data directory.
# Pasted paths often carry stray whitespace or surrounding quotes, and "~" is
# not expanded by glob/os.path.join, so normalise the answer before storing it.
_raw_data_dir = input("Enter your data directory path: ")
config["dataset"]["data_dir"] = os.path.expanduser(_raw_data_dir.strip().strip("'\""))

# Warn early rather than silently reporting zero images in the next cell.
if not os.path.isdir(config["dataset"]["data_dir"]):
    print(f"WARNING: '{config['dataset']['data_dir']}' is not an existing directory.")

print("Configuration loaded successfully!")

## Data Verification

Let's verify your data structure and visualize some samples.

In [None]:
# Verify data structure: count the files of every split and preview a few
# training image/mask pairs.
data_dir = config["dataset"]["data_dir"]


def _sorted_glob(split, kind, pattern):
    """Return the sorted paths matching `pattern` under data_dir/split/kind."""
    # glob returns files in arbitrary order; sorting keeps the listing
    # reproducible and makes positional pairing meaningful.
    return sorted(glob.glob(os.path.join(data_dir, split, kind, pattern)))


train_images = _sorted_glob("train", "images", "*.jpg")
train_masks = _sorted_glob("train", "masks", "*.png")
val_images = _sorted_glob("val", "images", "*.jpg")
val_masks = _sorted_glob("val", "masks", "*.png")
test_images = _sorted_glob("test", "images", "*.jpg")

print("Data Statistics:")
print(f"Training Images: {len(train_images)}")
print(f"Training Masks: {len(train_masks)}")
print(f"Validation Images: {len(val_images)}")
print(f"Validation Masks: {len(val_masks)}")
print(f"Test Images: {len(test_images)}")

if len(train_images) != len(train_masks):
    print("WARNING: number of training images and masks differ.")
if len(val_images) != len(val_masks):
    print("WARNING: number of validation images and masks differ.")

# Visualize some samples
# Pair each image with the mask that shares its file stem; indexing two
# independent listings by position can show a mask next to the wrong image.
mask_by_stem = {Path(p).stem: p for p in train_masks}

n_cols = 3
fig, axes = plt.subplots(2, n_cols, figsize=(15, 10))

for i in range(n_cols):
    # Hide the frame of every cell so unused columns stay blank.
    axes[0, i].axis('off')
    axes[1, i].axis('off')

    if i >= len(train_images):
        continue

    img_path = train_images[i]
    mask_path = mask_by_stem.get(Path(img_path).stem)
    if mask_path is None and i < len(train_masks):
        # Stems differ (e.g. masks carry a suffix): fall back to sorted order.
        mask_path = train_masks[i]

    with Image.open(img_path) as img_file:
        axes[0, i].imshow(img_file.convert('RGB'))
    axes[0, i].set_title(f"Train Image {i+1}")

    if mask_path is None:
        axes[1, i].set_title(f"Train Mask {i+1} (not found)")
        continue

    with Image.open(mask_path) as mask_file:
        axes[1, i].imshow(mask_file.convert('L'), cmap='gray')
    axes[1, i].set_title(f"Train Mask {i+1}")

plt.tight_layout()
plt.show()

## Import Required Modules

Import all necessary modules for training.

In [None]:
# Project-local modules; presumably resolved through the 'src' entry that the
# first cell appended to sys.path - confirm against the repository layout.
from utils.helper_funcs import (
    get_model_path,
    get_conf_name,
    print_config,
)
# The star import is deliberate: the model class is later looked up by its
# configured name through globals(), so it must be present in this namespace.
from models import *
from forward.forward_schedules import ForwardSchedule
from forward.forward_process import ForwardProcess
from torch.optim import lr_scheduler, AdamW
from train_validate import train, validate
from loaders.dataloaders import get_dataloaders
from torch.utils.tensorboard import SummaryWriter
from common.logging import get_logger
from ema_pytorch import EMA
import warnings
# Suppress every Python warning from here on (applies to all later cells too).
warnings.filterwarnings('ignore')

## Setup Training Components

Initialize model, optimizer, scheduler, and other training components.

In [None]:
# Make sure the checkpoint and TensorBoard output folders exist.
for out_dir in (config["model"]["save_dir"], config["run"]["writer_dir"]):
    Path(out_dir).mkdir(parents=True, exist_ok=True)

# Logger named after the model; the configuration is echoed through it.
model_name = f"{config['model']['name']}"
logger = get_logger(filename=model_name, dir="logs")
print_config(config, logger)

# TensorBoard writer under <writer_dir>/<model name>.
writer = SummaryWriter(f'{config["run"]["writer_dir"]}/{model_name}')

# Identifier derived from the configuration; reused for the checkpoint path.
ID = get_conf_name(config)
print(f"Configuration ID: {ID}")

# Forward (noising) process built from the diffusion schedule settings.
forward_schedule = ForwardSchedule(**config["diffusion"]["schedule"])
forward_process = ForwardProcess(forward_schedule)

# Training and validation loaders.
loaders = get_dataloaders(config, ["tr", "vl"])
tr_dataloader, vl_dataloader = loaders

for label, loader in (("Training", tr_dataloader), ("Validation", vl_dataloader)):
    print(f"{label} batches: {len(loader)}")

In [None]:
# Resolve the model class by its configured name and build it on the device.
Net = globals()[config["model"]["class"]]
model = Net(**config["model"]["params"])
model = model.to(device)

# Report the parameter count through the logger.
total_params = sum(param.numel() for param in model.parameters())
logger.info(f"Number of model parameters: {total_params:,}")

# Optimizer and plateau-based learning-rate scheduler (minimising the loss).
optimizer_kwargs = config["training"]["optimizer"]["params"]
optimizer = AdamW(model.parameters(), **optimizer_kwargs)

scheduler_kwargs = config["training"]["scheduler"]
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, "min", **scheduler_kwargs)

# Optional exponential moving average of the model weights.
ema_cfg = config["training"]["ema"]
ema = None
if ema_cfg["use"]:
    ema = EMA(model=model, **ema_cfg["params"])
    ema.to(device)

print("Model and training components initialized successfully!")

## Training Loop

Start the training process.

In [None]:
# Training variables
epochs = config["training"]["epochs"]
start_epoch = 0
# `np.Inf` was removed in NumPy 2.0; `np.inf` works on every NumPy version.
best_vl_loss = np.inf

# Name used in the TensorBoard tag. The notebook config may not define
# "loss_name", so fall back to the names of the configured loss terms
# instead of failing with a KeyError after the first (expensive) epoch.
loss_name = config["training"].get("loss_name") or "+".join(config["training"]["loss"])

print(f"Starting training for {epochs} epochs...")

for epoch in range(start_epoch, epochs):
    print(f"\n=== Epoch {epoch+1}/{epochs} ===")

    # Training: one pass over the training loader.
    tr_losses, model = train(
        model,
        tr_dataloader,
        forward_process,
        device,
        optimizer,
        ema=ema,
        cfg=config,
        extra={"skip_steps": 10, "prefix": f"ep:{epoch+1}/{epochs}"},
        logger=logger
    )

    # Validation: evaluate the EMA weights when EMA is enabled, otherwise the raw model.
    vl_losses = validate(
        ema.ema_model if ema is not None else model,
        vl_dataloader,
        forward_process,
        device,
        cfg=config,
        vl_runs=3,
        logger=logger
    )

    # Calculate average losses (first entry of every per-step record).
    # Cast to plain Python floats so the checkpoint holds no NumPy scalars,
    # which `torch.load` rejects under its `weights_only=True` default.
    tr_loss = float(np.mean([l[0] for l in tr_losses]))
    vl_loss = float(np.mean([l[0] for l in vl_losses]))

    # Log to tensorboard
    writer.add_scalars(
        f"Loss/train vs validation/{loss_name}",
        {"Train": tr_loss, "Validation": vl_loss},
        epoch,
    )

    # Update scheduler with the monitored validation loss.
    scheduler.step(vl_loss)

    print(f"Epoch {epoch+1}: Train Loss = {tr_loss:.6f}, Val Loss = {vl_loss:.6f}")

    # Save best model: only the checkpoint with the lowest validation loss is kept.
    if best_vl_loss > vl_loss:
        print(f">>> New best model! Previous: {best_vl_loss:.6f}, Current: {vl_loss:.6f}")
        best_vl_loss = vl_loss

        model_path = get_model_path(name=ID, dir=config["model"]["save_dir"])
        checkpoint = {
            "model": model.state_dict(),
            "epoch": epoch,
            "epochs": epochs,
            "optimizer": optimizer.state_dict(),
            "ema": ema.state_dict() if ema is not None else None,
            "vl_loss": vl_loss,
        }
        torch.save(checkpoint, model_path)
        print(f"Model saved to: {model_path}")

writer.flush()
writer.close()
print("\nTraining completed!")

## Testing and Prediction

Load the best model and make predictions on test data.

In [None]:
from reverse.reverse_process import sample
from modules.transforms import DiffusionTransform
import torchvision.transforms as T

# Load best model
model_path = get_model_path(name=ID, dir=config["model"]["save_dir"])
# The checkpoint is a plain dict written by this notebook (weights, optimizer
# state, epoch counters and the validation loss). Recent PyTorch releases
# default to `weights_only=True`, which refuses checkpoints that contain
# NumPy scalars, so request full unpickling explicitly.
# NOTE(review): only do this for checkpoints you created yourself; the
# `weights_only` argument requires PyTorch >= 1.13.
checkpoint = torch.load(model_path, map_location=device, weights_only=False)

# Prefer the EMA weights when EMA is enabled and its state was saved;
# otherwise restore the raw model weights.
if ema is not None and checkpoint.get("ema") is not None:
    ema.load_state_dict(checkpoint["ema"])
    model = ema.ema_model
else:
    model.load_state_dict(checkpoint["model"])

model.to(device)
model.eval()

print(f"Best model loaded from: {model_path}")
print(f"Best validation loss: {checkpoint['vl_loss']:.6f}")

In [None]:
# Test loader plus the transform helper used to turn model output back into arrays.
te_dataloader = get_dataloaders(config, "te")
side = config["dataset"]["input_size"]
DT = DiffusionTransform((side, side))

# Folder that receives one PNG mask per test sample.
output_dir = "predicted_masks"
Path(output_dir).mkdir(parents=True, exist_ok=True)

n_batches = len(te_dataloader)
print(f"Starting prediction on {n_batches} test images...")

with torch.no_grad():
    for step, batch in enumerate(te_dataloader):
        done = step + 1
        batch_ids = batch["id"]
        batch_imgs = batch["image"].to(device)

        # Run the reverse process; the last entry of the result is used as the prediction.
        samples = sample(
            forward_schedule,
            model,
            images=batch_imgs,
            out_channels=1,
            desc=f"Predicting {done}/{n_batches}",
        )
        preds = samples[-1][:, :1, :, :].to(device)

        # Threshold each prediction at zero and write it out as an 8-bit image.
        for i, pred in enumerate(preds):
            pred_binary = (pred > 0).float()

            to_numpy = DT.get_reverse_transform_to_numpy()
            pred_np = to_numpy(pred_binary)[:, :, 0]
            mask_u8 = (pred_np * 255).astype(np.uint8)

            output_path = os.path.join(output_dir, f"{batch_ids[i]}.png")
            Image.fromarray(mask_u8).save(output_path)

        # Progress message every ten batches.
        if done % 10 == 0:
            print(f"Processed {done}/{n_batches} batches")

print(f"\nPrediction completed! Results saved to: {output_dir}")

## Visualization of Results

Display some prediction results.

In [None]:
# Visualize some results: the first test images (sorted by path) above their
# predicted masks.
test_images = sorted(
    glob.glob(os.path.join(config["dataset"]["data_dir"], "test", "images", "*.jpg"))
)

n_cols = 5
fig, axes = plt.subplots(2, n_cols, figsize=(20, 8))

# Hide every frame up front so columns without a sample stay blank.
for ax in axes.ravel():
    ax.axis('off')

for i, img_path in enumerate(test_images[:n_cols]):
    # Load original image
    # splitext drops only the final extension, unlike str.replace, which would
    # also remove ".jpg" occurring elsewhere in the file name.
    img_name = os.path.splitext(os.path.basename(img_path))[0]
    with Image.open(img_path) as img_file:
        img = img_file.convert('RGB')

    # Load predicted mask; show an all-black placeholder when none was written.
    pred_path = os.path.join(output_dir, f"{img_name}.png")
    if os.path.exists(pred_path):
        with Image.open(pred_path) as mask_file:
            pred_mask = mask_file.convert('L')
    else:
        pred_mask = Image.new('L', img.size, 0)

    axes[0, i].imshow(img)
    axes[0, i].set_title(f"Test Image {i+1}")

    axes[1, i].imshow(pred_mask, cmap='gray')
    axes[1, i].set_title(f"Predicted Mask {i+1}")

plt.tight_layout()
plt.show()

print("Visualization completed!")

## Summary

Training and prediction completed successfully!

### Files Created:
- **Model weights**: Saved in `weights/` directory
- **Predicted masks**: Saved in `predicted_masks/` directory
- **Training logs**: TensorBoard event files under `runs/` (view with `tensorboard --logdir runs`); the text logger is configured with `dir="logs"`

### Next Steps:
1. Review the predicted masks in the output folder
2. Evaluate the model performance if you have ground truth for validation
3. Fine-tune hyperparameters if needed
4. Use the trained model for inference on new data