In [None]:
import torch
import torch.nn as nn
from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn, FasterRCNN_MobileNet_V3_Large_FPN_Weights
from torchvision.datasets import CocoDetection
from torch.utils.data import DataLoader
from pycocotools.cocoeval import COCOeval
import torchvision
import torch.nn.functional as F

import copy
import json
import os
import time
from datetime import datetime

In [None]:
from bloc_diag_model.BlocDiagBoxHead import BlocDiagBoxHead
from utils.train_utils import train_one_epoch

In [None]:
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Load pretrained weights and model.
# `weights` is kept as its own variable because later cells call
# `weights.transforms()` to obtain the matching preprocessing.
weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT
model = fasterrcnn_mobilenet_v3_large_fpn(weights=weights)

In [None]:
# Swap the pretrained box head for a BlocDiagBoxHead while reusing the
# pretrained fc6/fc7 parameters.
in_features = model.roi_heads.box_head.fc6.in_features  # NOTE(review): not used in the cells shown here
num_classes = 91  # NOTE(review): not used in the cells shown here


# Keep a handle on the pretrained head so its weights can be copied below.
original_box_head = model.roi_heads.box_head

# NOTE(review): the argument semantics are defined in BlocDiagBoxHead (not
# shown here). The block lists are numerically consistent with the layer
# sizes (16 * 784 = 12544 and 16 * 64 = 1024) - presumably 16 diagonal blocks
# for fc6 and a single dense block for fc7; confirm against the class.
custom_box_head = BlocDiagBoxHead([12544, 1024, 1024], [[64]*16, [1024]], [[784]*16, [1024]], True)

# Initialise the new head's fc6/fc7 from the pretrained head's state dicts.
custom_box_head.fc6.load_state_dict(original_box_head.fc6.state_dict())
custom_box_head.fc7.load_state_dict(original_box_head.fc7.state_dict())

# Install the custom head on the detector, moved to the selected device.
model.roi_heads.box_head = custom_box_head.to(device)

##### Sum of the L1 penalties of all off-block-diagonal components

In [None]:
# Display the custom box head's off-diagonal loss before any training.
# NOTE(review): presumably the L1 penalty on off-block-diagonal weights, as
# the heading above describes - confirm in BlocDiagBoxHead.get_off_diag_loss.
model.roi_heads.box_head.get_off_diag_loss()

In [None]:
# Move the whole detector (not only the new box head) to the selected device
# and switch it to training mode for the fine-tuning loop below.
model.to(device)
model.train()

In [None]:
# Load the transform used during pretraining.
# It is applied to the images only (not to the targets) when the dataset
# list is built in the cells below.
transform = weights.transforms()

In [None]:
# Data from https://cocodataset.org/#download
# Paths are relative to the notebook's working directory: training images
# live in <data_dir>/train2017 and the instance annotations in
# <data_dir>/annotations.
data_dir = "data/coco"
train_img_folder = os.path.join(data_dir, "train2017")
train_ann_file = os.path.join(data_dir, "annotations/instances_train2017.json")

In [None]:
# Build an in-memory training set from the first N COCO training samples.
# NOTE(review): every transformed image is kept in the `dataset` list, so
# memory use grows linearly with N.
N = 20000
subset_indices = list(range(N))
dataset_raw = torch.utils.data.Subset(CocoDetection(train_img_folder, train_ann_file), subset_indices)
dataset = []
for i, coco_idx in enumerate(dataset_raw.indices):
    img, target = dataset_raw[i]
    # Look up the COCO image id through the underlying (un-subsetted) dataset.
    img_id = dataset_raw.dataset.ids[coco_idx]
    # Tag each annotation dict of this image with the image id.
    for t in target:
        t["image_id"] = img_id
    # Only the image is transformed; the annotations are stored as loaded.
    transformed_img = transform(img)
    dataset.append((transformed_img, target))

In [None]:
def coco_collate_fn(batch):
    """Regroup a batch of per-sample tuples into per-field tuples.

    Args:
        batch: iterable of samples, each an iterable of fields
            (here ``(image, target)`` pairs).

    Returns:
        A tuple with one inner tuple per field, e.g.
        ``((img0, img1, ...), (tgt0, tgt1, ...))``. Nothing is stacked, so
        the fields may differ in size from sample to sample. An empty batch
        yields ``()``.
    """
    per_field = zip(*batch)
    return tuple(per_field)

In [None]:
# Shuffled training batches of 8 samples; coco_collate_fn keeps images and
# targets as tuples instead of stacking them into tensors.
data_loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=coco_collate_fn)

In [None]:
# AdamW over all model parameters (the whole detector, not only the new box
# head), with learning rate 1e-4 and weight decay 1e-4.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

In [None]:
# Fine-tune the model for a fixed number of epochs, then save its weights.
n_total_epoch = 3
lambda_offdiag = 1e-4  # passed to train_one_epoch as `lambda_offdiag`
# The timestamp is taken once so the metrics file and the checkpoint of this
# run share the same suffix.
timestamp_start = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

# Create the output directory up front: otherwise a fresh checkout only fails
# at torch.save, after all epochs have already been trained.
os.makedirs('./saved_models', exist_ok=True)

# The metrics path does not change between epochs, so build it once.
metrics_path = os.path.join('./saved_models',
                            f"train_metrics__lambda_offdiag={lambda_offdiag}__n_total_epoch={n_total_epoch}__{timestamp_start}.csv")

for epoch_num in range(n_total_epoch):
    start_time = time.time()

    # epoch_num is passed 1-based.
    train_one_epoch(model, data_loader, optimizer, epoch_num=epoch_num+1, n_total_epoch=n_total_epoch,
                    device=device,
                    metrics_path=metrics_path,
                    lambda_offdiag=lambda_offdiag)

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Epoch {epoch_num+1}/{n_total_epoch} completed in {elapsed_time:.2f} seconds.")


# Save only the state dict; the evaluation section loads it into a model that
# has the same architecture.
model_save_path = os.path.join('./saved_models',
                               f"model_BlocDiagBoxHead__lambda_offdiag={lambda_offdiag}__n_total_epoch={n_total_epoch}__{timestamp_start}.pt")
torch.save(model.state_dict(), model_save_path)

#### Model evaluation

In [None]:
# Re-select the device so the evaluation cells do not depend on the value set
# in the training section: GPU when CUDA is available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# These three values select which saved checkpoint is evaluated: they are
# substituted into the checkpoint file name built in the next cell.
# The timestamp is hard-coded, so it refers to one specific earlier run.
n_total_epoch = 3
lambda_offdiag = 1e-4
model_timestamp_start = '2025-04-28_14-16-04'

In [None]:
# Rebuild the checkpoint path with the same naming scheme the training
# section uses when saving, but with the hard-coded run timestamp.
model_save_path = os.path.join('./saved_models',
                               f"model_BlocDiagBoxHead__lambda_offdiag={lambda_offdiag}__n_total_epoch={n_total_epoch}__{model_timestamp_start}.pt")

In [None]:
# Load the saved state dict into the existing `model` object, which must
# already have the BlocDiagBoxHead architecture from the cells above.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_save_path, map_location=device))
model.to(device)

In [None]:
from utils.eval_utils import evaluate_model

In [None]:
# Load the transform used during pretraining.
# NOTE: relies on `weights` defined in the model-loading cell near the top of
# the notebook; this section does not redefine it.
transform = weights.transforms()

In [None]:
# Validation images and instance annotations, relative to the notebook's
# working directory.
data_dir = "data/coco"
val_img_folder = os.path.join(data_dir, "val2017")
val_ann_file = os.path.join(data_dir, "annotations/instances_val2017.json")

In [None]:
def coco_collate_fn(batch):
    """Transpose a batch of samples into one tuple per field.

    Behaves the same as the definition in the training section.

    Args:
        batch: iterable of samples, each an iterable of fields
            (here ``(image, target)`` pairs).

    Returns:
        A tuple of tuples, one per field, e.g.
        ``((img0, img1, ...), (tgt0, tgt1, ...))``; ``()`` for an empty batch.
    """
    per_field = zip(*batch)
    return tuple(per_field)

In [None]:
# Load the validation dataset without a transform; the transform is applied
# manually, to images only, in the next cell.
# NOTE: this rebinds `dataset_raw`, which held the training subset earlier.
dataset_raw = CocoDetection(val_img_folder, val_ann_file)

In [None]:
# Build the in-memory validation set: one (transformed image, annotations)
# pair per COCO image id.
# NOTE(review): every transformed image is kept in the `dataset` list, so
# memory use grows with the size of the validation set.
dataset = []
for i, img_id in enumerate(dataset_raw.ids):
    img, target = dataset_raw[i]
    # Tag each annotation dict of this image with the image id.
    for t in target:
        t["image_id"] = img_id
    # Only the image is transformed; the annotations are stored as loaded.
    transformed_img = transform(img)
    dataset.append((transformed_img, target))

# One image per batch, in dataset order.
val_data_loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=coco_collate_fn)

In [None]:
# Evaluate using pycocotools.
# Take the COCO object held by the validation dataset as the ground truth, so
# detections are scored against the same annotation file that was loaded.
coco_gt = dataset_raw.coco

In [None]:
# File that evaluate_model writes the detections to and that
# coco_gt.loadRes reads back further down.
# NOTE(review): the extension is .csv, but the evaluation cell describes the
# output as JSON - confirm the intended format and extension.
# NOTE(review): nothing visible here creates ./model_eval_results - confirm
# that evaluate_model does.
eval_file_path = os.path.join('./model_eval_results', f"coco_val_results__{model_timestamp_start}.csv")

In [None]:
# Eval model and save results to json.
# Runs the detector over the validation loader; the results are written to
# `eval_file_path`, which the next cell loads with coco_gt.loadRes.
# NOTE(review): the model was last put in training mode by model.train() -
# presumably evaluate_model switches it to eval mode; confirm in
# utils/eval_utils.
evaluate_model(model, val_data_loader, device=device,
               output_path=eval_file_path)

In [None]:
# Score the saved detections against the ground truth with the COCO
# bounding-box metrics.
print(eval_file_path)
coco_dt = coco_gt.loadRes(eval_file_path)
# Named `coco_eval` rather than `eval` so the Python builtin `eval` is not
# shadowed for the rest of the kernel session.
coco_eval = COCOeval(coco_gt, coco_dt, iouType='bbox')
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Visualise the first fully connected layer of the box head as a heatmap.
fc6_layer = model.roi_heads.box_head.fc6

# Copy the weight matrix to host memory as a NumPy array
# (expected shape: [1024, 12544], i.e. [output features, input features]).
fc6_weight = fc6_layer.weight.data.cpu().numpy()

# Explicit figure/axes interface; the colour scale is clipped to +/-0.01.
fig, ax = plt.subplots(figsize=(14, 6))
heatmap = ax.imshow(fc6_weight, aspect='auto', interpolation='nearest', cmap='viridis', vmin=-0.01, vmax=0.01)
fig.colorbar(heatmap, ax=ax, label="Weight Value")
ax.set_title("Weight Matrix")
ax.set_xlabel("Input Features")
ax.set_ylabel("Output Features")
plt.show()