# Imports

In [None]:
import sys
print(f"Python version is {sys.version}")

import torch
print(f"Torch version is {torch.__version__}")

import os
import time
import logging

from tqdm import tqdm

In [None]:
import torchvision
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models.detection import FasterRCNN
from torchvision.models.resnet import resnet50
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops import FeaturePyramidNetwork, MultiScaleRoIAlign
from torchvision.models.detection.rpn import AnchorGenerator

# Helpers, utility

In [None]:
# Ensure forward pass stability of the model
def test_model(input_type="real"):
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model.to(device)
  model.eval()

  # Change to random noise if prompted
  if input_type == "noise":
    sample_input = torch.randn(3, 800, 800)
    sample_input = sample_input.to(device)
  elif input_type == "real": # TODO: Implement
    pass

  with torch.no_grad():
      outputs = model([sample_input])

  print(f'Model output for input type: {input_type}')
  print(outputs)

# Smoke-test the forward pass using the random-noise input path
test_model(input_type="noise")

## Inspect a sample

In [None]:
def inspect_sample(dataset, index):
    """Print the image and annotation fields of `dataset[index]` for a visual check.

    The sample is unpacked as an (image, target) pair. The target is read as a
    mapping with the keys 'boxes', 'labels', 'image_id', 'area' and 'iscrowd'.
    """
    img, target = dataset[index]

    # Image tensor: expected layout [C, H, W] and dtype torch.float32
    print(f"Image shape: {img.shape}")
    print(f"Image dtype: {img.dtype}")

    # Which annotation fields are present
    print(f"\nTarget keys: {target.keys()}")

    # Boxes and labels are reported the same way: values first, then shape
    for title, key in (("Boxes", "boxes"), ("Labels", "labels")):
        field = target[key]
        print(f"\n{title}: {field}")
        print(f"{title} shape: {field.shape}")

    # Remaining fields; only the first one starts a new paragraph
    for lead, title, key in (
        ("\n", "Image ID", "image_id"),
        ("", "Area", "area"),
        ("", "Iscrowd", "iscrowd"),
    ):
        print(f"{lead}{title}: {target[key]}")

# Print the fields of the sample at index 0 of `train_dataset`
inspect_sample(train_dataset, 0)

In [None]:
# TODO(debug): helper for checking that bounding boxes are valid and lie within the image bounds (needs `random` to be imported before it is re-enabled)
# def validate_bboxes_subset(dataset, sample_size=2000):
#     """
#     Checks a random subset of the dataset to ensure bounding boxes are valid.
#     sample_size: how many samples to validate
#     """
#     num_samples = len(dataset)
#     sampled_indices = random.sample(range(num_samples), min(sample_size, num_samples))

#     invalid_samples = []
#     for idx in tqdm(sampled_indices, desc="Validating bounding boxes (subset)", unit="sample"):
#         img, target = dataset[idx]

#         # get image dimensions
#         _, img_height, img_width = img.shape
#         boxes = target['boxes'].cpu()

#         if boxes.size(0) == 0:
#             continue

#         widths = boxes[:, 2] - boxes[:, 0]
#         heights = boxes[:, 3] - boxes[:, 1]

#         invalid_mask = (
#             (widths <= 0) |
#             (heights <= 0) |
#             (boxes[:, 0] < 0) |
#             (boxes[:, 1] < 0) |
#             (boxes[:, 2] > img_width) |
#             (boxes[:, 3] > img_height)
#         )

#         if invalid_mask.any():
#             invalid_samples.append(idx)

#     if invalid_samples:
#         print("Found invalid bounding boxes in these samples:", invalid_samples)
#     else:
#         print(f"\nNo invalid bounding boxes found in this random subset of {len(sampled_indices)} samples.")

#     return invalid_samples


# invalid_samples = validate_bboxes_subset(train_dataset, sample_size=2000)


## Visualize the Image and Bounding Boxes

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def visualize_image_with_boxes(img, target):
    """Show an image tensor with its bounding boxes outlined in red.

    Args:
        img: Image tensor in [C, H, W] layout; it is permuted to [H, W, C]
            for matplotlib.
        target: Mapping whose 'boxes' entry is a tensor with one
            (xmin, ymin, xmax, ymax) row per box.
    """
    # detach() and cpu() make this work for tensors that live on the GPU or
    # carry autograd history; calling .numpy() directly raises for both.
    img_np = img.detach().cpu().permute(1, 2, 0).numpy()

    _, ax = plt.subplots(1, figsize=(12, 9))
    ax.imshow(img_np)

    # Plain Python floats, so matplotlib never receives tensor objects
    boxes = target['boxes'].detach().cpu().tolist()
    for xmin, ymin, xmax, ymax in boxes:
        rect = patches.Rectangle(
            (xmin, ymin), xmax - xmin, ymax - ymin,
            linewidth=1, edgecolor='r', facecolor='none'
        )
        ax.add_patch(rect)

    plt.show()

# Fetch a random sample from the dataset
# `random` is imported here because no earlier cell imports it.
import random

# randrange excludes the upper bound; randint(0, len(...)) is inclusive and
# could return len(train_dataset), which is one past the last valid index.
random_index = random.randrange(len(train_dataset))
img, target = train_dataset[random_index]

# Visualize the sample
visualize_image_with_boxes(img, target)

## Test-Fetch Batches

In [None]:
# Fetch multiple batches to ensure consistency
# A single iterator is advanced five times. Calling next(iter(train_loader))
# inside the loop built a brand-new iterator on every pass, which restarts the
# loader each time and, for an unshuffled loader, returns the same first batch.
from itertools import islice

for images, targets in islice(train_loader, 5):
    print(f"Number of images: {len(images)}")

In [None]:
# Pull one batch and report its size, then the shape / keys of its first two items
images, targets = next(iter(train_loader))

print(f"Number of images: {len(images)}")
for position in (0, 1):
    print(f"Image {position} shape: {images[position].shape}")
for position in (0, 1):
    print(f"Target {position} keys: {targets[position].keys()}")

# main.py

## Run configuration

In [None]:
# Mount Google Drive at /content/drive so that the paths under
# /content/drive/MyDrive used by the following cells are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# To be able to import .py scripts stored under src/
import sys
sys.path.append('/content/drive/MyDrive/Colab Notebooks/cmpe593/term-project/src')
import os
import torch
from datetime import datetime


# Run parameters
model_type = "custom_FADE" # TODO: custom_FADE or "baseline"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

##############################################################################
# PATHS
##############################################################################
base_path = "/content/drive/MyDrive"
project_dir = os.path.join(base_path, "Colab Notebooks", "cmpe593", "term-project")

## Dataset paths
dataset_split = "train2017" # TODO: pick from ("train2017", "val2017")
coco_base_dir = base_path + '/coco'
images_dir = os.path.join(coco_base_dir, 'images', dataset_split)
annotations_file = 'instances_' + dataset_split + '.json'
annotations_path = os.path.join(coco_base_dir, 'annotations', annotations_file)

## Artifact paths (checkpoints, plots, metrics)
# Every run gets its own results folder, named after the start time
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
results_dir = os.path.join(project_dir, 'results', timestamp)
checkpoint_dir, plot_dir, metrics_dir = (
    os.path.join(results_dir, leaf) for leaf in ('checkpoints', 'plots', 'metrics')
)


# Ensure artifact paths exist
# The loop variable is deliberately not called `dir`: at notebook top level
# that name would shadow the builtin dir() for every later cell.
for artifact_dir in (results_dir, checkpoint_dir, plot_dir, metrics_dir):
    already_present = os.path.isdir(artifact_dir)
    # exist_ok=True removes the check-then-create race; the flag recorded
    # above is only used to choose the message.
    os.makedirs(artifact_dir, exist_ok=True)
    if already_present:
        print(f'Directory {artifact_dir} already exists.')
    else:
        print(f'Directory created at {artifact_dir}')
##############################################################################

## Load data and model

In [None]:
from dataops.data_loader import get_data_loader
from torch.utils.data import Subset


def get_num_classes(dataloader):
    """
    Retrieve the category information of the dataset behind a dataloader.

    Any number of `torch.utils.data.Subset` wrappers around the dataset is
    unwrapped first, so a subset of a subset is handled as well.

    Args:
        dataloader (DataLoader): PyTorch DataLoader object whose underlying
            dataset exposes a `cat_ids` attribute.

    Returns:
        tuple: `(number_of_categories, category_ids)`, where `category_ids`
        is the dataset's `cat_ids` attribute, returned unchanged.
    """
    # Walk down to the original dataset; Subset keeps it in `.dataset`
    dataset = dataloader.dataset
    while isinstance(dataset, Subset):
        dataset = dataset.dataset

    # Retrieve category IDs
    category_ids = dataset.cat_ids
    num_categories = len(category_ids)
    print(f"Number of categories: {num_categories}")
    return num_categories, category_ids


# Build the data loader (train=True, subset=True, subset_size=10000), then read
# the category count and category IDs from the dataset behind it
dataloader = get_data_loader(
    images_dir,
    annotations_path,
    train=True,
    subset=True,
    subset_size=10000,
)
num_classes, category_ids = get_num_classes(dataloader)

In [None]:
from modelops import model_loader

# Load the model variant named by `model_type`, passing the target `device`.
# NOTE(review): eval_mode=False presumably leaves the model in training mode;
# confirm against modelops.model_loader.load_model.
model = model_loader.load_model(
    model_type=model_type,
    device=device,
    eval_mode=False
)

## Modeling

### Set parameters and optimizer

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR, StepLR

# Set hyperparameters
# Overwrites the category count returned by get_num_classes with a value that
# includes one extra class for the background.
num_classes = len(category_ids) + 1  # +1 for background
# Pick the experiment profile that drives this run
quick_experiment = False
hyperparam_key = 'longer_experiment'
if quick_experiment:
    hyperparam_key = 'quick_experiment'

hyperparameters = dict(
    # Quick experiment is for sanity check
    quick_experiment=dict(
        lr_scheduler=None,
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0005,
        num_epochs=3,
    ),
    longer_experiment=dict(
        lr_scheduler='warmup + stepLR@ep3',
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0005,
        num_epochs=10,
    ),
)

# Unpack the active profile into the names used by the cells below
active_hparams = hyperparameters[hyperparam_key]
lr, momentum, weight_decay, num_epochs = (
    active_hparams[name]
    for name in ('lr', 'momentum', 'weight_decay', 'num_epochs')
)

# Hand only the parameters that still require gradients to the optimizer
params = list(filter(lambda p: p.requires_grad, model.parameters()))

# SGD optimizer; the single learning rate applies to all parameters
optimizer = optim.SGD(params, lr=lr, momentum=momentum, weight_decay=weight_decay)

def warmup_schedule(epoch, warmup_epochs=3):
    """Return the learning-rate multiplier for a linear warm-up.

    Args:
        epoch: Zero-based epoch index.
        warmup_epochs: Number of epochs over which the multiplier ramps up.
            Defaults to 3.

    Returns:
        `(epoch + 1) / warmup_epochs` while `epoch < warmup_epochs`, so the
        multiplier reaches 1.0 on the last warm-up epoch; 1.0 afterwards.
    """
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs  # Start from a small value and scale up
    return 1.0  # Maintain full learning rate after warm-up

# Apply the scheduler
# Attaches a LambdaLR driven by warmup_schedule to the optimizer defined above.
lr_scheduler = LambdaLR(optimizer, lr_lambda=warmup_schedule)

# Print learning rate for debugging
# NOTE(review): this preview calls step() on the real `lr_scheduler`, so by the
# end of this cell the scheduler has already advanced `num_epochs` steps and
# the optimizer's learning rate has been updated along the way. Confirm that
# this is intended before relying on this scheduler during training.
print("Learning Rate Scheduler Test:")
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}: LR = {optimizer.param_groups[0]['lr']:.6f}")
    lr_scheduler.step()

### Save params for debugging purposes

In [None]:
import json

# Tag the active profile with the model variant and its class name
active_hparams['model_type'] = f'{model_type}-{model.__class__.__name__}'

# The JSON file sits directly inside this run's results directory
hyperparams_path = os.path.join(results_dir, "hyperparameters.json")
with open(hyperparams_path, "w") as hparams_file:
    json.dump(active_hparams, hparams_file, indent=4)

print(f"Hyperparameters saved to {hyperparams_path}")

In [None]:
import time
import shutil
import datetime
from tqdm import tqdm


def _write_loss_log(path, values):
    """Write one loss value per line to `path`, replacing any previous content."""
    with open(path, "w", encoding="utf-8") as log_file:
        log_file.writelines(f"{value}\n" for value in values)


def train_model(
    model,
    train_loader,
    optimizer,
    device,
    num_epochs,
    drive_checkpoint_dir,
    lr_scheduler=None,
    val_loader=None,
):
    """Train `model` for `num_epochs` epochs, checkpointing after every epoch.

    Each batch from `train_loader` is unpacked as `(images, targets, extra)`;
    the third element is ignored. The model is called as
    `model(images, targets)` and must return a dict of loss tensors, which are
    summed into the value that is back-propagated. A batch whose forward pass
    raises `AssertionError` is recorded and skipped.

    Args:
        model: Model to train; it is switched to train mode every epoch.
        train_loader: Iterable of `(images, targets, extra)` batches.
        optimizer: Optimizer that updates the model parameters.
        device: Device the images and target tensors are moved to.
        num_epochs: Number of epochs to run.
        drive_checkpoint_dir: Directory that receives the per-epoch and final
            checkpoints plus `epoch_losses.txt` and `batch_losses.txt`.
        lr_scheduler: Optional scheduler stepped once per epoch. After the
            epoch with index 3 it is replaced by a
            `StepLR(step_size=3, gamma=0.1)` on the same optimizer, even if
            None was passed.
        val_loader: Currently unused.

    Side effects:
        Saves checkpoints and loss logs to `drive_checkpoint_dir`, saves the
        loss plot into the module-level `plot_dir`, and prints progress.
    """
    # Imported locally: other cells bind the bare name `datetime` to both the
    # class and the module, so the global name depends on cell execution order.
    from datetime import timedelta

    epoch_losses = []
    batch_losses = []   # Track batch-level losses
    error_batches = []  # To store details of failing batches for debug

    # File paths for saving loss data
    epoch_loss_file = os.path.join(drive_checkpoint_dir, "epoch_losses.txt")
    batch_loss_file = os.path.join(drive_checkpoint_dir, "batch_losses.txt")

    # Initialize StepLR scheduler for dynamic adjustment after warm-up
    post_warmup_scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        start_time = time.time()
        batch_count = 0
        trained_batches = 0  # Batches that actually contributed a loss

        print(f"Epoch {epoch + 1}/{num_epochs}")
        print('-' * 20)

        # Training loop
        for images, targets, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}", unit="batch"):
            batch_count += 1

            # Move images and targets to the device
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            try:
                # Forward pass
                loss_dict = model(images, targets)
            except AssertionError:
                # Keep the offending boxes for the report at the end of training
                print(f"\n[ERROR] Bounding box error in batch {batch_count}!")
                error_batches.append({
                    "epoch": epoch + 1,
                    "batch_idx": batch_count,
                    "boxes": [tgt['boxes'].cpu().tolist() for tgt in targets],
                })
                continue

            # Compute total loss
            losses = sum(loss_dict.values())
            loss_value = losses.item()
            epoch_loss += loss_value
            batch_losses.append(loss_value)  # Track batch loss
            trained_batches += 1

            # Backward pass and optimization
            optimizer.zero_grad()
            losses.backward()

            # Clip gradients to a total norm of 10 and report the norm
            total_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
            print(f"Batch {batch_count}, Gradient Norm: {total_grad_norm:.6f}")
            optimizer.step()

            # Display intermediate losses
            if batch_count % 200 == 0:
                print(f"Batch {batch_count}, Loss: {loss_value:.4f}")

        # Transition from warm-up to StepLR after the warm-up phase
        if epoch == 3:
            print(f"Transitioning from warm-up to StepLR at learning rate: {optimizer.param_groups[0]['lr']:.6f}")
            lr_scheduler = post_warmup_scheduler

        # Step the learning rate scheduler
        if lr_scheduler is not None:
            lr_scheduler.step()

        current_lr = optimizer.param_groups[0]['lr']
        epoch_duration = time.time() - start_time
        # Average over the batches that produced a loss. Dividing by
        # len(train_loader) understated the loss whenever batches were skipped
        # and raised ZeroDivisionError for an empty loader.
        if trained_batches:
            average_epoch_loss = epoch_loss / trained_batches
        else:
            average_epoch_loss = float('nan')
        epoch_losses.append(average_epoch_loss)

        # Save epoch summary
        print(f"Epoch [{epoch + 1}/{num_epochs}] completed in {str(timedelta(seconds=int(epoch_duration)))}")
        print(f"Average Loss: {average_epoch_loss:.4f}")
        print(f"Current Learning Rate: {current_lr:.6f}")

        # Persist the loss history every epoch so an interrupted run keeps it
        _write_loss_log(epoch_loss_file, epoch_losses)
        _write_loss_log(batch_loss_file, batch_losses)

        # Save checkpoint
        drive_checkpoint_path = os.path.join(drive_checkpoint_dir, f'model_epoch_{epoch + 1}.pth')
        torch.save(model.state_dict(), drive_checkpoint_path)
        print(f"Checkpoint saved to Google Drive at {drive_checkpoint_path}")

    # =========================
    # Training loop completes
    # =========================
    print("Training loop finished.")

    # Save the model state dictionary
    drive_final_checkpoint_path = os.path.join(drive_checkpoint_dir, 'model_final.pth')
    torch.save(model.state_dict(), drive_final_checkpoint_path)
    print(f"Final model saved to Google Drive at {drive_final_checkpoint_path}")

    # Save and display the training loss plot
    import matplotlib.pyplot as plt

    # `plot_dir` is the module-level directory set up by the run configuration
    plot_filename_drive = os.path.join(plot_dir, 'training_loss_plot.png')
    plt.figure()
    plt.plot(range(1, num_epochs + 1), epoch_losses, marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Average Loss')
    plt.title('Training Loss Over Epochs')
    plt.grid(True)
    plt.savefig(plot_filename_drive)
    print(f"Training loss plot saved to Google Drive at {plot_filename_drive}")
    plt.show()

    if error_batches:
        print("\nSome batches triggered bounding box assertions. Their info:")
        for info in error_batches:
            print(f"Epoch {info['epoch']}, Batch {info['batch_idx']}, Boxes: {info['boxes']}")
    else:
        print("No bounding box errors encountered!")

In [None]:
# Build a fresh warm-up scheduler for this run, and only when the active
# hyperparameter profile declares one ('lr_scheduler' is None for the quick
# profile). Previously no scheduler was passed, so training ran without the
# warm-up phase the 'longer_experiment' profile describes. A new LambdaLR is
# created here so the schedule starts from its first warm-up step regardless
# of how far any earlier scheduler on this optimizer has been advanced.
run_lr_scheduler = None
if active_hparams['lr_scheduler'] is not None:
    run_lr_scheduler = LambdaLR(optimizer, lr_lambda=warmup_schedule)

train_model(
    model=model,
    train_loader=dataloader,
    optimizer=optimizer,
    device=device,
    num_epochs=num_epochs,
    drive_checkpoint_dir=checkpoint_dir,
    lr_scheduler=run_lr_scheduler,
    val_loader=None
)