# Document aligner project
## Basic Model Implementation
- Working on milestone 1
- Implementing a simple encoder-decoder model

In [None]:
from dataset_loader import get_dataloaders, visualize_batch
from reconstruction_model import DocumentReconstructionModel
from reconstruction_model import MaskedL1Loss, MaskedMSELoss, SSIMLoss, UVReconstructionLoss

import time 
import os
import glob
from typing import Dict, List, Tuple, Optional, Callable
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import numpy as np


In [None]:
# Build the train / validation loaders for the synthetic pitch-sweep renders.
# train_split=0.8 is forwarded as-is; presumably the remainder becomes the
# validation set (confirm in dataset_loader.get_dataloaders).
train_loader, val_loader = get_dataloaders(
    img_size=(512, 512),  # alternative resolution: (256, 256)
    train_split=0.8,
    batch_size=8,
    data_dir='/home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep',
)

# Smoke test: fetch a single batch from the training loader.
sample_batch = next(iter(train_loader))
# Uncomment to eyeball the batch:
# visualize_batch(sample_batch, num_samples=4)

# Model description
- TODO

Training:
- [ ] Additional metrics (PSNR, SSIM)
- [ ] Learning rate scheduling
- [ ] Gradient clipping
- [ ] Mixed precision training
- [ ] Logging to tensorboard/wandb

Main training loop
- [ ] Implement better model architecture
- [ ] Try different loss functions
- [ ] Add learning rate scheduling
- [ ] Implement early stopping
- [ ] Add visualization and logging
- [ ] Experiment with data augmentation
- [ ] Use pretrained model from HuggingFace
- [ ] Enable MASKED LOSSES
- [ ] Use DEPTH
- [ ] Use UV
- [ ] Use BORDER
- [ ] Try different OPTIMIZERS

## Training and validation functions

In [None]:
# Training loop 

def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    profile: bool = False,
) -> float:
    """
    Train ``model`` for one pass over ``dataloader`` and return the mean batch loss.

    Every batch is read as a dict with the keys ``'rgb'`` (model input) and
    ``'ground_truth'`` (target); an optional ``'border'`` entry is used as the
    mask for the masked / UV criteria.

    Args:
        model: Network to optimize; switched to training mode here.
        dataloader: Iterable of batch dicts as described above.
        criterion: Loss module. ``MaskedL1Loss`` / ``MaskedMSELoss`` are called
            as ``criterion(output, target, mask)``; ``UVReconstructionLoss`` is
            called with keyword arguments and its ``'total'`` entry is used;
            anything else (``SSIMLoss``, ``nn.MSELoss``, ``nn.L1Loss``, ...)
            is called as ``criterion(output, target)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        epoch: Index of the current epoch; only used in the profiling output.
        profile: When True, print a per-stage timing line every 20 batches.
            CUDA is synchronized before each timestamp only in this mode and
            only when ``device`` is a CUDA device, so CPU runs never touch
            ``torch.cuda``.

    Returns:
        Sum of the per-batch loss values divided by the number of batches.

    Raises:
        ValueError: If ``dataloader`` yields no batches.

    *** CHECK TODO LIST ***
    """
    model.train()

    # Synchronizing is only meaningful (and only legal) on a CUDA device, and
    # only worth the stall when the timings are actually going to be printed.
    sync_cuda = profile and torch.device(device).type == 'cuda'

    def _timestamp() -> float:
        """Return the wall-clock time, draining queued CUDA work first if profiling."""
        if sync_cuda:
            torch.cuda.synchronize()
        return time.time()

    running_loss = 0.0
    num_batches = 0

    for batch_idx, batch in enumerate(dataloader):
        t0 = time.time()

        # Move data to device
        rgb = batch['rgb'].to(device)
        ground_truth = batch['ground_truth'].to(device)

        # The border mask is optional; it stays None when the loader omits it.
        mask = batch.get('border', None)
        if mask is not None:
            mask = mask.to(device)

        t1 = _timestamp()

        # Forward pass
        optimizer.zero_grad()
        output = model(rgb)

        # Compute loss; the call signature depends on the criterion type.
        if isinstance(criterion, (MaskedL1Loss, MaskedMSELoss)):
            # NOTE(review): mask may be None here if the loader was built
            # without border masks -- confirm the masked losses accept that.
            loss = criterion(output, ground_truth, mask)
        elif isinstance(criterion, UVReconstructionLoss):
            losses = criterion(pred_image=output, target_image=ground_truth, mask=mask)
            loss = losses['total']
        else:
            # SSIMLoss and the standard losses (MSE, L1) share this signature.
            loss = criterion(output, ground_truth)

        t2 = _timestamp()

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        t3 = _timestamp()

        running_loss += loss.item()
        num_batches += 1

        if profile and batch_idx % 20 == 0:
            print(
                f"Epoch {epoch} batch {batch_idx}: "
                f"h2d={(t1 - t0) * 1000:.1f}ms | "
                f"fwd+loss={(t2 - t1) * 1000:.1f}ms | "
                f"bwd+step={(t3 - t2) * 1000:.1f}ms | "
                f"loss={loss.item():.4f}"
            )

    if num_batches == 0:
        raise ValueError(
            "train_one_epoch received a dataloader that yielded no batches; "
            "cannot compute an average training loss"
        )

    return running_loss / num_batches


# Validation function 

def validate(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device
) -> float:
    """
    Evaluate ``model`` on ``dataloader`` and return the mean batch loss.

    Batches are read as dicts with the keys ``'rgb'`` (input) and
    ``'ground_truth'`` (target); an optional ``'border'`` entry is passed as
    the mask to the masked / UV criteria. The loss is computed with the same
    criterion-specific call signature that training uses.

    Args:
        model: Network to evaluate; switched to eval mode here (it is not
            switched back afterwards).
        dataloader: Iterable of batch dicts as described above.
        criterion: Loss module used as the validation metric.
        device: Device the batch tensors are moved to.

    Returns:
        Sum of the per-batch loss values divided by the number of batches.

    Raises:
        ValueError: If ``dataloader`` yields no batches.

    TODO: MODIFY to add more metrics (PSNR, SSIM, etc)
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            ground_truth = batch['ground_truth'].to(device)

            # The border mask is optional; it stays None when the loader omits it.
            mask = batch.get('border', None)
            if mask is not None:
                mask = mask.to(device)

            output = model(rgb)

            # Compute loss; the call signature depends on the criterion type.
            if isinstance(criterion, (MaskedL1Loss, MaskedMSELoss)):
                loss = criterion(output, ground_truth, mask)
            elif isinstance(criterion, UVReconstructionLoss):
                losses = criterion(pred_image=output, target_image=ground_truth, mask=mask)
                loss = losses['total']
            else:
                # SSIMLoss and the standard losses (MSE, L1) share this signature.
                loss = criterion(output, ground_truth)

            running_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError(
            "validate received a dataloader that yielded no batches; "
            "cannot compute an average validation loss"
        )

    return running_loss / num_batches


## Training Loop

In [None]:
# Main training loop 

def main():
    """
    Run the full training experiment using the configuration hard-coded below.

    Steps, in order:
      1. Select CUDA when available, otherwise CPU.
      2. Build the train / validation loaders with ``get_dataloaders``.
      3. Print the tensor shapes of one training batch as a sanity check.
      4. Instantiate ``DocumentReconstructionModel`` and report its parameter count.
      5. Train for ``NUM_EPOCHS`` epochs with ``SSIMLoss`` and Adam, validating
         after every epoch.
      6. Write the model's ``state_dict`` to ``best_model.pth`` (relative to the
         current working directory) each time the validation loss improves.

    Progress is reported with ``print``; nothing is returned.

    *** CHECK TODO LIST ***
    """

    # Configuration
    DATA_DIR = '/home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep'
    BATCH_SIZE = 8
    NUM_EPOCHS = 50
    LEARNING_RATE = 1e-4
    IMG_SIZE = (256, 256) # (512, 512) Using smaller images for faster training for now

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create dataloaders
    train_loader, val_loader = get_dataloaders(
        data_dir=DATA_DIR,
        batch_size=BATCH_SIZE,
        img_size=IMG_SIZE,
        use_depth=False, # TODO: True if using depth info
        use_uv=False, # TODO: True if using UV maps
        use_border=False # TODO: True if using border masks for better training
    )

    # Inspect one batch (testing)
    sample_batch = next(iter(train_loader))
    print(f"Batch RGB shape: {sample_batch['rgb'].shape}")
    print(f"Batch GT shape: {sample_batch['ground_truth'].shape}")
    if 'border' in sample_batch:
        print(f"Batch border mask shape: {sample_batch['border'].shape}")
    # visualize_batch(sample_batch) # Troubleshooting / sanity check

    # Create model
    model = DocumentReconstructionModel().to(device)
    print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")

    # TODO: TRY DIFFERENT LOSS FUNCTIONS
    # Option 1: simple losses (baseline, just to get it working)
    # criterion = nn.MSELoss() # Basic L2 Loss - sensitive to lighting
    # criterion = nn.L1Loss() # L1 loss also sensitive to lighting

    # Option 2: SSIM loss (RECOMMENDED - structure instead of lighting)
    criterion = SSIMLoss() # Might need a pip install

    # Option 3: Masked losses (doc pixel focus)
    # Make sure to set use_border=True above **
    # criterion = MaskedL1Loss(use_mask=True)
    # criterion = MaskedMSELoss(use_mask=True)

    # Option 4: Combined loss with UV supervision (GET TO THIS EVENTUALLY)
    # NOTE: need to set use_uv=True and use_border=True above
    # criterion = UVReconstructionLoss(
    #     reconstruction_weight=1.0,
    #     uv_weight=0.5,
    #     smoothness_weight=0.01,
    #     use_mask=True,
    #     loss_type='ssim' # Use SSIM for geometric recon
    # )

    # TODO: TRY DIFFERENT OPTIMIZERS
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

    # Training loop
    best_val_loss = float('inf')
    # Per-epoch loss history; collected for later plotting/inspection but not
    # consumed anywhere in this function yet.
    train_losses = []
    val_losses = []

    for epoch in range(NUM_EPOCHS):
        # Epoch timing check
        start_epoch = time.time()

        print(f"\n{'='*50}")
        print(f"EPOCH {epoch+1}/{NUM_EPOCHS}")
        print(f"{'='*50}")

        # TRAIN
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device, epoch)
        print(f"Train loss: {train_loss:.4f}")
        train_losses.append(train_loss)

        # VALIDATE
        val_loss = validate(model, val_loader, criterion, device)
        print(f"Val Loss: {val_loss:.4f}")
        val_losses.append(val_loss)

        # Save the best model (checkpoint only on strict improvement)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            # The format spec must sit inside the braces; the previous
            # "{val_loss}.4f" printed the raw float followed by a literal ".4f".
            print(f"Saved best model with val loss {val_loss:.4f}")

        epoch_time = time.time() - start_epoch
        print(f"Epoch time: {epoch_time:.4f}s")

    print("\nTraining complete!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    

In [None]:
# Entry cell: print a banner, check that the dataset loads, then train.

print("Document Reconstruction Dataset Loader", "=" * 50, sep="\n")

# Loader smoke test at the reduced resolution (swap in (512, 512) for full size).
train_loader, val_loader = get_dataloaders(
    img_size=(256, 256),
    batch_size=4,
    data_dir='/home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep',
)

print("\nDataset loaded successfully!")

# Optional visual check of one batch (disabled):
# print("\nVisualizing a sample batch ...")
# sample_batch = next(iter(train_loader))
# print(f"Batch shape - RGB: {sample_batch['rgb'].shape}, Ground Truth: {sample_batch['ground_truth'].shape}")
# visualize_batch(sample_batch, num_samples=min(4, sample_batch['rgb'].shape[0]))

# Kick off the full training run.
main()

In [None]:
# Execute main, with a guarded dataset-loading check first

print("Document Reconstruction Dataset Loader")
print("="*50)

# Quick test: only the dataset load sits inside the try, so the
# "Error loading dataset" message is printed for loader failures only.
# Errors raised inside main() propagate with their full traceback instead of
# being mislabelled as dataset problems.
try:
    train_loader, val_loader = get_dataloaders(
        data_dir = '/home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep',
        batch_size = 4,
        img_size=(512, 512)
    )
except Exception as e:
    print(f"\nError loading dataset: {e}")
    print("Check that data directory exists and contains required files")
else:
    print("\nDataset loaded successfully!")

    # # Visualize sample batch
    # print("\nVisualizing a sample batch ...")
    # sample_batch = next(iter(train_loader))
    # print(f"Batch shape - RGB: {sample_batch['rgb'].shape}, Ground Truth: {sample_batch['ground_truth'].shape}")
    # visualize_batch(sample_batch, num_samples=min(4, sample_batch['rgb'].shape[0]))

    # Main training loop (runs only when the quick load above succeeded)
    main()