<a href="https://colab.research.google.com/github/jrsykes/CocoaReader/blob/master/2step_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
# Mount Google Drive at /content/drive so files kept there are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
!pip install torchmetrics
from torchmetrics.classification import MulticlassPrecision, MulticlassRecall, MulticlassF1Score
from torch.utils.data import DataLoader
from torchvision import transforms

In [None]:
import numpy as np
from tqdm import tqdm
from PIL import Image
import os
!pip install loguru
from loguru import logger
# import wandb
import random

In [23]:
# Copy the SSH private key kept on Google Drive into this runtime's ~/.ssh directory.
!mkdir -p ~/.ssh
!cp /content/drive/MyDrive/colab_ssh/id_ed25519 ~/.ssh/id_ed25519
# Restrict the key file to owner read/write (mode 600).
!chmod 600 ~/.ssh/id_ed25519

# Add GitHub to known hosts
# NOTE(review): the scanned host key is appended without being compared against
# GitHub's published fingerprints, and ">>" adds another copy on every re-run.
!ssh-keyscan -t rsa github.com >> ~/.ssh/known_hosts


# github.com:22 SSH-2.0-2e51c3195


In [24]:
!GIT_SSH_COMMAND="ssh -i ~/.ssh/id_ed25519 -o IdentitiesOnly=yes" git clone git@github.com:Fairfield-Vision/RT-DETR_cocoa.git


Cloning into 'RT-DETR_cocoa'...
remote: Enumerating objects: 343, done.[K
remote: Counting objects: 100% (20/20), done.[K
remote: Compressing objects: 100% (13/13), done.[K
remote: Total 343 (delta 7), reused 18 (delta 7), pack-reused 323 (from 1)[K
Receiving objects: 100% (343/343), 139.71 MiB | 28.70 MiB/s, done.
Resolving deltas: 100% (44/44), done.


In [25]:
import sys
# Put the cloned repository directory on the import path for the project imports below.
sys.path.append('/content/RT-DETR_cocoa')

from HiveSight_utils import create_rtdetr_model, load_partial_weights, NPYDataset
from HiveSight_utils import export_to_onnx, draw_bboxes_on_image, custom_collate_fn


In [26]:
from rtdetrv2_pytorch.src.zoo.rtdetr.rtdetr_criterion import RTDETRCriterion
from rtdetrv2_pytorch.src.zoo.rtdetr.matcher import HungarianMatcher

In [27]:
# Add train_script.log as an extra loguru sink (rotation="500 MB"); the value
# echoed below the cell is what logger.add returned.
logger.add("train_script.log", rotation="500 MB")

1

In [28]:
def freeze_layers(model, freeze_backbone=True):
    """Toggle ``requires_grad`` on the four sub-modules of ``model``.

    ``model.backbone`` and ``model.regression_head`` are frozen when
    ``freeze_backbone`` is True and trainable when it is False;
    ``model.encoder`` and ``model.decoder`` receive the opposite setting.
    """
    # (sub-module, trainable?) pairs, applied in the listed order.
    plan = (
        (model.backbone, not freeze_backbone),
        (model.encoder, freeze_backbone),
        (model.decoder, freeze_backbone),
        (model.regression_head, not freeze_backbone),
    )
    for submodule, trainable in plan:
        for weight in submodule.parameters():
            weight.requires_grad = trainable

In [29]:
def filter_bbox_samples(dataset):
    """Return a list of the samples in ``dataset`` whose third element is non-empty."""
    kept = []
    for item in dataset:
        boxes = item[2]
        if len(boxes) > 0:
            kept.append(item)
    return kept

In [None]:
def _evaluate_classification(model, loader, metrics, device):
    """Score the model's class predictions on ``loader``.

    Every metric in ``metrics`` (a ``{name: torchmetrics metric}`` dict) is
    reset, updated with the predictions for each batch, and then computed.

    Returns:
        dict mapping each metric name to its computed value as a float.

    The model is left in eval mode; the caller switches back with
    ``model.train()``.
    """
    for metric in metrics.values():
        metric.reset()
    model.eval()
    with torch.no_grad():
        for images, labels, _ in loader:
            images, labels = images.to(device), labels.to(device)
            # NOTE(review): assumes the model returns the same 3-tuple in eval
            # mode as it does in train mode — confirm against create_rtdetr_model.
            _, _, pred_labels = model(images)
            for metric in metrics.values():
                metric.update(pred_labels, labels)
    return {name: metric.compute().item() for name, metric in metrics.items()}


def train(config=None):
    """Train the model in two phases, each with its own early-stopping rule.

    Phase 1 calls ``freeze_layers(model, freeze_backbone=False)`` and optimises
    the cross-entropy classification loss. After every epoch the macro
    precision/recall/F1 are measured on the validation split, and the phase
    ends as soon as validation F1 fails to improve on the best value so far.

    Phase 2 keeps only training samples that carry bounding boxes, calls
    ``freeze_layers(model, freeze_backbone=True)`` and optimises the summed
    RT-DETR criterion losses. It ends as soon as the epoch-mean ``loss_giou``
    fails to improve on the best value so far.

    Args:
        config: hyperparameters, either an object with attributes (e.g.
            ``wandb.config``) or a plain ``dict``. Required entries:
            ``learning_rate``, ``cost_class_weight``, ``cost_bbox_weight``,
            ``cost_giou_weight``, ``loss_ce_weight``, ``loss_bbox_weight``,
            ``loss_giou_weight``.

    Raises:
        ValueError: if ``config`` is None.
    """
    if config is None:
        raise ValueError(
            "train() needs a config providing learning_rate and the cost_*/loss_* weights"
        )
    if isinstance(config, dict):
        # Hyperparameters are read as attributes below; a plain dict would
        # otherwise fail with AttributeError on the first access.
        from types import SimpleNamespace
        config = SimpleNamespace(**config)

    # NOTE(review): these are absolute paths on a specific machine; they need to
    # exist in whatever runtime executes this function.
    image_root = "/users/jrs596/scratch/EC25/data_test"
    n_classes = len(os.listdir(image_root))  # one class per directory entry of image_root
    annotation_root = "/users/jrs596/longship/yolo_annotations"
    RT_DETR_weights = "/users/jrs596/scratch/TORCH_HOME/hub/checkpoints/rtdetrv2_r18vd_120e_coco_rerun_48.1.pth"
    batch_size = 100
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Matching costs for the Hungarian matcher and loss weights for the criterion.
    weight_dict = {
        "cost_class": config.cost_class_weight,
        "cost_bbox": config.cost_bbox_weight,
        "cost_giou": config.cost_giou_weight
    }
    matcher = HungarianMatcher(weight_dict=weight_dict, use_focal_loss=False)
    rtdetr_criterion = RTDETRCriterion(
        matcher=matcher,
        weight_dict={
            "loss_ce": config.loss_ce_weight,
            "loss_bbox": config.loss_bbox_weight,
            "loss_giou": config.loss_giou_weight
        },
        losses=['labels', 'boxes', 'cardinality'],
        alpha=0.25,
        gamma=2.0,
        eos_coef=0.1,
        num_classes=1
    )
    rtdetr_criterion.to(device)
    classification_criterion = torch.nn.CrossEntropyLoss()
    precision = MulticlassPrecision(num_classes=n_classes, average='macro').to(device)
    recall = MulticlassRecall(num_classes=n_classes, average='macro').to(device)
    f1_score = MulticlassF1Score(num_classes=n_classes, average='macro').to(device)
    cls_metrics = {"precision": precision, "recall": recall, "f1": f1_score}

    logger.info("Creating datasets and data loaders")
    dataset = NPYDataset(image_root, annotation_root)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)

    model = create_rtdetr_model(num_classes_bb=2, num_queries=300, n_classes=n_classes)
    # Only the encoder and decoder are initialised from the checkpoint's EMA weights.
    checkpoint = torch.load(RT_DETR_weights, map_location="cpu", weights_only=True)
    state_dict = checkpoint['ema']['module']
    load_partial_weights(model.encoder, state_dict, "encoder")
    load_partial_weights(model.decoder, state_dict, "decoder")
    model.to(device)
    best_f1_score, best_giou = 0.0, float("inf")

    # Phase 1: Train Backbone & Classification Head
    logger.info("Phase 1: Training classification head with frozen transformer")
    freeze_layers(model, freeze_backbone=False)
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
    phase_1_running = True

    while phase_1_running:
        model.train()
        for images, labels, _ in tqdm(train_loader, desc="Phase 1 Training"):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            _, _, pred_labels = model(images)
            loss = classification_criterion(pred_labels, labels)
            loss.backward()
            optimizer.step()

        # The stop criterion must be fed real predictions: computing the metric
        # without any update() would make this loop exit after a single epoch.
        scores = _evaluate_classification(model, val_loader, cls_metrics, device)
        logger.info(
            "Phase 1 validation: precision={:.4f} recall={:.4f} f1={:.4f}",
            scores["precision"], scores["recall"], scores["f1"],
        )
        f1 = scores["f1"]
        if f1 <= best_f1_score:
            phase_1_running = False
        else:
            best_f1_score = f1

    # Phase 2: Train Transformer for Bounding Box Detection (only samples with bounding boxes)
    train_dataset = filter_bbox_samples(train_dataset)
    if len(train_dataset) == 0:
        # A shuffling DataLoader cannot be built over an empty dataset.
        logger.warning("Phase 2 skipped: no training samples with bounding boxes")
        return
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)

    logger.info("Phase 2: Training transformer with frozen classification head")
    freeze_layers(model, freeze_backbone=True)
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
    phase_2_running = True

    while phase_2_running:
        model.train()
        giou_sum, n_batches = 0.0, 0
        for images, _, bboxes in tqdm(train_loader, desc="Phase 2 Training"):
            images = images.to(device)
            optimizer.zero_grad()
            pred_logits, pred_boxes, _ = model(images)
            # NOTE(review): bboxes are passed exactly as collated; if they hold
            # CPU tensors they may need moving to `device` — confirm against
            # custom_collate_fn.
            loss_dict = rtdetr_criterion({"pred_logits": pred_logits, "pred_boxes": pred_boxes}, bboxes)
            loss = sum(loss_dict.values())
            loss.backward()
            optimizer.step()
            giou_sum += loss_dict["loss_giou"].item()
            n_batches += 1

        # Judge the epoch on its mean GIoU loss rather than on whichever batch
        # happened to come last.
        giou = giou_sum / n_batches
        logger.info("Phase 2 epoch mean loss_giou={:.4f}", giou)
        if giou >= best_giou:
            phase_2_running = False
        else:
            best_giou = giou

#     wandb.finish()

```python
if __name__ == "__main__":
    wandb.init(project="HiveSight-RT-DETR_cocoa", config={
        "learning_rate": 1e-4,
        "cost_class_weight": 1.0,
        "cost_bbox_weight": 5.0,
        "cost_giou_weight": 2.0,
        "loss_ce_weight": 1.0,
        "loss_bbox_weight": 5.0,
        "loss_giou_weight": 2.0
    })
    train(config=wandb.config)
```

In [None]:
# Hyperparameters handed to train(): the optimiser learning rate, the matcher
# cost weights (cost_*) and the criterion loss weights (loss_*).
config = dict(
    learning_rate=1e-4,
    cost_class_weight=1.0,
    cost_bbox_weight=5.0,
    cost_giou_weight=2.0,
    loss_ce_weight=1.0,
    loss_bbox_weight=5.0,
    loss_giou_weight=2.0,
)

In [None]:
# Run the two-phase training with the hyperparameters defined in the previous cell.
train(config=config)