# Document aligner project
## Basic Model Implementation
- Working on milestone 1
- Implementing a simple encoder decoder model

In [2]:
from dataset_loader import get_dataloaders, visualize_batch
from reconstruction_model import DocumentReconstructionModel


import os
import glob
from typing import Dict, List, Tuple, Optional, Callable
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import numpy as np


In [3]:
# Create dataloaders
train_loader, val_loader = get_dataloaders(
    data_dir='/home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep',
    batch_size=8,
    train_split=0.8,
    img_size=(512, 512), #(256, 256)
)

# Visualize samples
sample_batch = next(iter(train_loader))
# visualize_batch(sample_batch, num_samples=4)

Found 2486 samples in /home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep
Train samples: 1988, Val samples: 498


# Model description
- TODO

Training:
- [ ] Additional metrics (PSNR, SSIM)
- [ ] Learning rate scheduling
- [ ] Gradient clipping
- [ ] Mixed precision training
- [ ] Logging to tensorboard/wandb

Main training loop
- [ ] Implement better model architecture
- [ ] Try different loss functions
- [ ] Add learning rate scheduling
- [ ] Implement early stopping
- [ ] Add visualization and logging
- [ ] Experiment with data augmentation
- [ ] Use pretrained model from HuggingFace
- [ ] Enable MASKED LOSSES
- [ ] Use DEPTH
- [ ] Use UV
- [ ] Use BORDER 

In [4]:
# # Training loop 

# def train_one_epoch(
#     model: nn.Module, 
#     dataloader: DataLoader,
#     criterion: nn.Module,
#     optimizer: torch.optim.Optimizer, 
#     devie: torch.device, 
#     epoch: int
# ) -> float:
#     """ 
#     Train for one epoch

#     *** CHECK TODO LIST ***
#     """
#     model.train()
#     total_loss = 0.0

#     # Loop through each batch 
#     for batch_idx, batch in enumerate(dataloader):
#         # Move data to device 
#         rgb = batch['rgb'].to(device)
#         ground_truth = batch['ground_truth'].to(device)

#         # Load mask if using masked loss
#         mask = batch.get('border', None)
#         if mask is not None:
#             mask = mask.to(device)

#         # Forward pass
#         optimizer.zero_grad()

## Training Loop

In [7]:
# Main training loop 

def main():
    """
    Main training loop 

    *** CHECK TODO LIST ***

    """
    
    # Configuration
    DATA_DIR = '/home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep'
    BATCH_SIZE = 8
    NUM_EPOCHS = 50
    LEARNING_RATE = 1e-4
    IMG_SIZE = (512, 512)

    # Set device 
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create dataloaders
    train_loader, val_loader = get_dataloaders(
        data_dir=DATA_DIR,
        batch_size=BATCH_SIZE,
        img_size=IMG_SIZE,
        use_depth=False, # TODO: True if using depth info
        use_uv=False, # TODO: True if using UV maps
        use_border=False # TODO: True if using border masks for better training
    )

    # Visualize a batch (testing)
    sample_batch = next(iter(train_loader))
    print(f"Batch RGB shape: {sample_batch['rgb'].shape}")
    print(f"Batch GT shape: {sample_batch['ground_truth'].shape}")
    if 'border' in sample_batch:
        print(f"Batch border mask shape: {sample_batch['border'].shape}")
    # visualize_batch(sample_batch) # Troubleshooting / sanity check 


    # Create model
    model = DocumentReconstructionModel().to(device)
    print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")

    # TODO: TRY DIFFERENT LOSS FUNCTIONS
    # Option 1: simple losses (baseline, just to get it working)
    criterion = nn.MSELoss() # Basic L2 Loss - sensitive to lighting
    # criterion = nn.L1Loss() # L1 loss also sensitive to lighting

    # Option 2: SSIM loss (RECOMMENDED - structure instead of lighting)
    # criterion = SSIMLoss() # Might need a pip install

    # Option 3: Masked losses (doc pixel focus)
    # Make sure to do use_border = True above **
    criterion = MaskedL1Loss(use_mask=True)
    criterion = MaskedMSELoss(use_mask=True)
    

In [8]:
# Execute main 

print("Document Reconstruction Dataset Loader")
print("="*50)

# Quick test 
train_loader, val_loader = get_dataloaders(
    data_dir = '/home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep',
    batch_size = 4, 
    img_size=(512, 512)
)

print("\nDataset loaded successfully!")

# # Visualize sample batch 
# print("\nVisualizing a sample batch ...")
# sample_batch = next(iter(train_loader))
# print(f"Batch shape - RGB: {sample_batch['rgb'].shape}, Ground Truth: {sample_batch['ground_truth'].shape}")
# visualize_batch(sample_batch, num_samples=min(4, sample_batch['rgb'].shape[0]))

# Main training loop 
main()

Document Reconstruction Dataset Loader
Found 2486 samples in /home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep
Train samples: 1988, Val samples: 498

Dataset loaded successfully!
Using device: cuda
Found 2486 samples in /home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep
Train samples: 1988, Val samples: 498
Batch RGB shape: torch.Size([8, 3, 512, 512])
Batch GT shape: torch.Size([8, 3, 512, 512])

Model parameters: 483,267


In [9]:
# Execute main 

print("Document Reconstruction Dataset Loader")
print("="*50)

# Quick test 
try: 
    train_loader, val_loader = get_dataloaders(
        data_dir = '/home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep',
        batch_size = 4, 
        img_size=(512, 512)
    )

    print("\nDataset loaded successfully!")

    # # Visualize sample batch 
    # print("\nVisualizing a sample batch ...")
    # sample_batch = next(iter(train_loader))
    # print(f"Batch shape - RGB: {sample_batch['rgb'].shape}, Ground Truth: {sample_batch['ground_truth'].shape}")
    # visualize_batch(sample_batch, num_samples=min(4, sample_batch['rgb'].shape[0]))

    # Main training loop 
    main()

except Exception as e:
    print(f"\nError loading dataset: {e}")
    print("Check that data directory exists and contains required files")

Document Reconstruction Dataset Loader
Found 2486 samples in /home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep
Train samples: 1988, Val samples: 498

Dataset loaded successfully!
Using device: cuda
Found 2486 samples in /home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep
Train samples: 1988, Val samples: 498
Batch RGB shape: torch.Size([8, 3, 512, 512])
Batch GT shape: torch.Size([8, 3, 512, 512])

Error loaded dataset: name 'DocumentReconstructionModel' is not defined
Check that data directory exists and contains required files
