## Run Configuration

In [None]:
# Mount Google Drive at /content/drive; every dataset and project path used
# further down is rooted under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# To be able to import .py scripts stored under src/
import sys
sys.path.append('/content/drive/MyDrive/Colab Notebooks/cmpe593/term-project/src')
import os
import torch
from datetime import datetime


# Run parameters
model_type = "custom_FADE"  # TODO: pick from ("custom_FADE", "baseline")
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

##############################################################################
# PATHS
##############################################################################
base_path = "/content/drive/MyDrive"
project_dir = os.path.join(base_path, "Colab Notebooks", "cmpe593", "term-project")

## Dataset paths
dataset_split = "train2017"  # TODO: pick from ("train2017", "val2017")
coco_base_dir = os.path.join(base_path, 'coco')
images_dir = os.path.join(coco_base_dir, 'images', dataset_split)
annotations_path = os.path.join(coco_base_dir,
                                'annotations',
                                f'instances_{dataset_split}.json')

## Artifact paths (checkpoints, plots, metrics)
# Every run writes into its own timestamped folder so earlier results are
# never overwritten.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_dir = os.path.join(project_dir, 'results', timestamp)
checkpoint_dir = os.path.join(results_dir, 'checkpoints')
plot_dir = os.path.join(results_dir, 'plots')
metrics_dir = os.path.join(results_dir, 'metrics')


# Ensure artifact paths exist
# NOTE: the loop variable is deliberately not called `dir`; at notebook top
# level that would rebind the builtin `dir()` for every later cell.
for artifact_dir in (results_dir, checkpoint_dir, plot_dir, metrics_dir):
  try:
    os.makedirs(artifact_dir)
  except FileExistsError:
    print(f'Directory {artifact_dir} already exists.')
  else:
    print(f'Directory created at {artifact_dir}')
##############################################################################

## Load data and model

In [None]:
from dataops.data_loader import get_data_loader
from dataops.utils import get_num_classes

# Get data loader and number of classes inside the dataset.
# The loader is built from the image directory and annotation file selected
# in the run configuration.
# NOTE(review): `train=True` presumably enables training-time behaviour and
# `subset_size=10000` presumably caps how many samples are used — confirm
# against dataops.data_loader.get_data_loader.
dataloader = get_data_loader(images_dir, annotations_path, train=True, subset_size=10000)
# NOTE(review): the exact meaning of the returned pair is defined in
# dataops.utils.get_num_classes — TODO confirm. `num_classes` is recomputed
# from `category_ids` in the hyperparameter cell further down.
num_classes, category_ids = get_num_classes(dataloader)

In [None]:
from modelops import model_loader

# Build the model variant named by `model_type` for the selected device.
# NOTE(review): `eval_mode=False` presumably leaves the model in training
# mode — confirm against modelops.model_loader.load_model. train_model()
# calls model.train() at the start of every epoch regardless.
model = model_loader.load_model(
    model_type=model_type,
    device=device,
    eval_mode=False
)

## Configure Model and Training

### Set parameters and optimizer

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

# Set hyperparameters
# NOTE: this overrides the `num_classes` value returned by get_num_classes().
num_classes = len(category_ids) + 1  # +1 for background
quick_experiment = False
hyperparam_key = 'quick_experiment' if quick_experiment else 'longer_experiment'

hyperparameters = {
    # Quick experiment is for sanity check
    'quick_experiment': {
        'lr_scheduler': None,
        'lr': 0.01,
        'momentum': 0.9,
        'weight_decay': 0.0005,
        'num_epochs': 3,
    },

    'longer_experiment': {
        'lr_scheduler': 'scheduledLRonly',
        'init_lr': 0.01,
        'momentum': 0.9,
        'weight_decay': 0.0005,
        'num_epochs': 20,
    },
}

# Extract hyperparameters
active_hparams = hyperparameters[hyperparam_key]
# The presets name the learning rate differently ('init_lr' vs 'lr').
# Accept either key so that switching `quick_experiment` on does not raise
# a KeyError before training even starts.
lr = active_hparams.get('init_lr', active_hparams.get('lr'))
if lr is None:
    raise KeyError(
        f"Preset '{hyperparam_key}' defines neither 'init_lr' nor 'lr'"
    )
momentum = active_hparams['momentum']
weight_decay = active_hparams['weight_decay']
num_epochs = active_hparams['num_epochs']

# Only parameters that require gradients are handed to the optimizer
params = [p for p in model.parameters() if p.requires_grad]

# Define the optimizer
optimizer = optim.SGD(
    params,
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay
)

### Save params for debugging purposes

In [None]:
import json

# Destination of this run's hyperparameter snapshot
hyperparams_path = os.path.join(results_dir, "hyperparameters.json")

# Record which model flavour and concrete class the settings belong to
active_hparams['model_type'] = '-'.join([model_type, model.__class__.__name__])

# Serialize first, then write the finished document in one go
serialized_hparams = json.dumps(active_hparams, indent=4)
with open(hyperparams_path, "w") as f:
    f.write(serialized_hparams)

print(f"Hyperparameters saved to {hyperparams_path}")

### Define training loop

In [None]:
import time
import shutil
import datetime
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR


def train_model(
    model,
    train_loader,
    optimizer,
    device,
    num_epochs,
    drive_checkpoint_dir,
    lr_scheduler=None,
    val_loader=None,
):
  """Train `model`, checkpointing and logging after every epoch.

  For each epoch the function iterates `train_loader`, sums the loss dict
  returned by `model(images, targets)`, back-propagates and steps the
  optimizer. Batches for which the forward pass raises `AssertionError` are
  skipped and their target boxes are recorded for a report at the end.
  After each epoch a summary is appended to `epoch_losses.txt` and the model
  state dict is saved; after the last epoch a final state dict and a loss
  plot (written to the notebook-level `plot_dir`) are saved.

  Args:
    model: Callable module; `model(images, targets)` must return a dict
      whose values are loss tensors.
    train_loader: Iterable yielding `(images, targets, _)` triples, where
      `targets` is a sequence of dicts of tensors containing a 'boxes' key.
    optimizer: Optimizer whose first param group carries the logged 'lr'.
    device: Device the images and target tensors are moved to.
    num_epochs: Number of passes over `train_loader`.
    drive_checkpoint_dir: Directory for checkpoints and `epoch_losses.txt`.
    lr_scheduler: Optional scheduler, stepped once at the end of each epoch.
    val_loader: Accepted for interface compatibility; currently unused.

  Returns:
    None.
  """
  # Imported locally on purpose: different notebook cells bind the name
  # `datetime` to the class and to the module, so `datetime.timedelta` would
  # depend on which cell ran last.
  from datetime import timedelta
  # Imported up front so a missing plotting backend fails before training
  # rather than after the last epoch.
  import matplotlib.pyplot as plt

  epoch_losses = []
  error_batches = []

  # Make sure the output locations exist before any training time is spent
  os.makedirs(drive_checkpoint_dir, exist_ok=True)
  os.makedirs(plot_dir, exist_ok=True)

  # File paths for saving loss data
  epoch_loss_file = os.path.join(drive_checkpoint_dir, "epoch_losses.txt")

  for epoch in range(num_epochs):
    start_time = time.time()
    model.train()
    epoch_loss = 0.0
    batch_count = 0      # every batch seen, used for progress/error reporting
    trained_batches = 0  # batches that actually contributed to the loss
    # Learning rate in effect for this epoch (captured before the scheduler steps)
    epoch_lr = optimizer.param_groups[0]['lr']

    print(f"Epoch {epoch + 1}/{num_epochs}: Learning Rate = {epoch_lr}")
    print('-' * 20)

    # Training loop
    for images, targets, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}", unit="batch"):
      batch_count += 1

      # Move images and targets to the device
      images = [img.to(device) for img in images]
      targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

      try:
        # Forward pass
        loss_dict = model(images, targets)
      except AssertionError as e:
        # Skip the offending batch but keep its boxes for the final report
        print(f"\n[ERROR] Bounding box error in batch {batch_count}!")
        print(f"Assertion message: {e}")
        error_batches.append({
            "epoch": epoch + 1,
            "batch_idx": batch_count,
            "boxes": [tgt['boxes'].cpu().tolist() for tgt in targets],
        })
        continue

      # Compute total loss
      losses = sum(loss for loss in loss_dict.values())
      loss_value = losses.item()
      epoch_loss += loss_value
      trained_batches += 1

      # Backward pass and optimization
      optimizer.zero_grad()
      losses.backward()
      optimizer.step()

      # Display intermediate losses
      if batch_count % 250 == 0:
        print(f"Batch {batch_count}, Loss: {loss_value:.4f}")

    # Step the learning rate scheduler
    if lr_scheduler is not None:
      print(f"At epoch {epoch + 1} stepping from learning rate: {epoch_lr:.6f}")
      lr_scheduler.step()
      print(f"At epoch {epoch + 1} stepping to learning rate: {optimizer.param_groups[0]['lr']:.6f}")

    epoch_duration = time.time() - start_time
    # Average over the batches that were actually trained on; skipped batches
    # add nothing to `epoch_loss` and must not dilute the mean. An epoch with
    # no usable batch is reported as NaN instead of dividing by zero.
    if trained_batches:
      average_epoch_loss = epoch_loss / trained_batches
    else:
      average_epoch_loss = float('nan')
    epoch_losses.append(average_epoch_loss)

    # Save epoch summary
    epoch_summary = (
        f"Epoch [{epoch + 1}/{num_epochs}] completed in {timedelta(seconds=int(epoch_duration))}\n"
        f"Learning Rate: {epoch_lr:.6f}\n"
        f"Average Loss: {average_epoch_loss:.4f}\n"
    )

    # Print to console
    print(epoch_summary.strip())

    # Write to file
    with open(epoch_loss_file, "a") as f:
      f.write(epoch_summary)

    # Save checkpoint
    drive_checkpoint_path = os.path.join(drive_checkpoint_dir, f'model_epoch_{epoch + 1}.pth')
    torch.save(model.state_dict(), drive_checkpoint_path)
    print(f"Checkpoint saved to Google Drive at {drive_checkpoint_path}")

  # =========================
  # Training loop completes
  # =========================
  print("Training loop finished.")

  # Save the model state dictionary
  drive_final_checkpoint_path = os.path.join(drive_checkpoint_dir, 'model_final.pth')
  torch.save(model.state_dict(), drive_final_checkpoint_path)
  print(f"Final model saved to Google Drive at {drive_final_checkpoint_path}")

  # Save and display the training loss plot
  plot_filename_drive = os.path.join(plot_dir, 'training_loss_plot.png')
  plt.figure()
  plt.plot(range(1, len(epoch_losses) + 1), epoch_losses, marker='o')
  plt.xlabel('Epoch')
  plt.ylabel('Average Loss')
  plt.title('Training Loss Over Epochs')
  plt.grid(True)
  plt.savefig(plot_filename_drive)
  print(f"Training loss plot saved to Google Drive at {plot_filename_drive}")
  plt.show()

  if error_batches:
    print("\nSome batches triggered bounding box assertions. Their info:")
    for info in error_batches:
      print(f"Epoch {info['epoch']}, Batch {info['batch_idx']}, Boxes: {info['boxes']}")
  else:
    print("No bounding box errors encountered!")

## Initiate Training

In [None]:
def lr_schedule(epoch):
    """Return the learning-rate multiplier for a zero-based epoch index.

    The index is shifted to a one-based epoch number and matched against a
    piecewise-constant table; any epoch outside the table gets 0.001.
    """
    epoch_number = epoch + 1

    # (one-based epoch numbers, multiplier) pairs, checked in order
    plateaus = (
        ((1,), 1.0),
        ((2, 3), 0.4),
        ((4, 5), 0.1),
        ((6, 7), 0.03),
        ((8, 9), 0.01),
    )
    for epoch_numbers, multiplier in plateaus:
        if epoch_number in epoch_numbers:
            return multiplier

    return 0.001


# Apply the scheduler only when the active preset asks for one: the
# quick sanity-check preset sets 'lr_scheduler' to None, which previously
# was ignored and the epoch-based schedule was attached anyway.
lr_scheduler = (
    LambdaLR(optimizer, lr_lambda=lr_schedule)
    if active_hparams.get('lr_scheduler')
    else None
)

# Launch training; checkpoints and the per-epoch loss log go to `checkpoint_dir`
train_model(
    model=model,
    train_loader=dataloader,
    optimizer=optimizer,
    device=device,
    num_epochs=num_epochs,
    drive_checkpoint_dir=checkpoint_dir,
    lr_scheduler=lr_scheduler,
    val_loader=None
)