# Faster RCNN ResNet50

## Dataset setup

In [1]:
import albumentations.pytorch
import datasets as ds
import utils
import torch as t
import torchvision as tv
import torchvision.tv_tensors as tvt
import albumentations as A

# utils.set_seed(42)  # left disabled; presumably seeds the RNGs -- confirm in utils

# Each annotation file yields one (x, y) pair; the pairs are handed to
# CustomImageDataset below as its first and second positional argument.
x, y = ds.extract_all('./datasets/090/annotations.xml')
x1, y1 = ds.extract_all('./datasets/190/annotations.xml')
x2, y2 = ds.extract_all('./datasets/30/annotations.xml')
x3, y3 = ds.extract_all('./datasets/60/annotations.xml')
x4, y4 = ds.extract_all('./datasets/90/annotations.xml')

# This set is kept out of training and used only for validation.
x5, y5 = ds.extract_all('./datasets/clear1/annotations.xml')

# Augmentation pipeline for the training split only; the validation dataset
# is built without a transform.
train_transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        A.Affine(
            translate_percent={"x": 0.0625, "y": 0.0625},
            scale=(0.9, 1.1),
            rotate=(-15, 15),
            p=0.5,
        ),
        albumentations.pytorch.ToTensorV2(),
    ],
    bbox_params=ds.albumentations_params,
)

train_ds = ds.CustomImageDataset(
    [*x, *x1, *x2, *x3, *x4],
    [*y, *y1, *y2, *y3, *y4],
    transform=train_transform,
)

val_ds = ds.CustomImageDataset(x5, y5)

dl_train = t.utils.data.DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=utils.unroller)
# Validation does not benefit from shuffling; a fixed order keeps the
# evaluation pass deterministic from epoch to epoch.
dl_val = t.utils.data.DataLoader(val_ds, batch_size=4, shuffle=False, collate_fn=utils.unroller)


## Model setup

In [2]:

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')

# Start from the COCO-pretrained detector.
# NOTE: no separate backbone weights are requested. The builder's keyword for
# that is `weights_backbone`, and torchvision discards it whenever `weights`
# is set because the detector checkpoint already contains the backbone.
model = tv.models.detection.fasterrcnn_resnet50_fpn(weights='COCO_V1').to(device)

# Replace the pre-trained box head with a freshly initialised one sized for
# this task (num_classes counts the background class in torchvision's
# detection models, so 2 presumably means background + one object class).
model.roi_heads.box_predictor = tv.models.detection.faster_rcnn.FastRCNNPredictor(
    in_channels=model.roi_heads.box_predictor.cls_score.in_features,
    num_classes=2
).to(device)

# Optimise only the parameters that are not frozen.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = t.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
# Multiplies the learning rate by 0.1 every 3 scheduler steps.
lr_scheduler = t.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)



## Training

In [3]:
#from tqdm.notebook import tqdm
from tqdm.auto import tqdm
import torchmetrics as tm

# epoch metrics
best_val_loss = float('inf')
best_val_map = 0.0
metric = tm.detection.mean_ap.MeanAveragePrecision("xyxy")

# early stopping
patience = 5
counter = 0
best_model_state = None

# hyperparams
num_epochs = 100

for epoch in range(num_epochs):
    # ---- TRAIN ----
    model.train()
    train_loss = 0.0

    dl_train_tqdm = tqdm(dl_train, desc=f"Train Epoch {epoch + 1}", leave=True)

    for images, targets in dl_train_tqdm:
        images = [img.to(device) for img in images]
        # `tgt` rather than `t`, so the torch alias `t` is not shadowed.
        targets = [{k: v.to(device) for k, v in tgt.items()} for tgt in targets]

        # In train mode the detector returns a dict of loss terms.
        loss_dict = model(images, targets)
        total_loss = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        batch_loss = total_loss.item()
        train_loss += batch_loss

        dl_train_tqdm.set_postfix(loss=batch_loss)

    train_loss /= len(dl_train)

    # Advance the learning-rate schedule once per epoch; without this call
    # the StepLR scheduler never changes the learning rate.
    lr_scheduler.step()

    # ---- VALIDATION ----
    val_loss = 0.0
    metric.reset()

    with t.no_grad():
        for images, targets in dl_val:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in tgt.items()} for tgt in targets]

            # loss: torchvision detection models only return the loss dict in
            # train mode, so switch back for every batch (the eval() call
            # below would otherwise stick for all following batches).
            model.train()
            loss_dict = model(images, targets)
            # Accumulate a Python float, not a device tensor.
            val_loss += sum(loss for loss in loss_dict.values()).item()

            # predictions: eval mode returns the detections used for mAP.
            model.eval()
            pred = model(images)
            metric.update(pred, targets)

    val_loss /= len(dl_val)  # avg loss

    val_metrics = metric.compute()
    val_map = val_metrics["map"].item()
    best_val_map = max(best_val_map, val_map)

    print(f"epoch={epoch + 1}; train_loss={train_loss:.4f}; val_loss={val_loss:.4f}; val_map={val_map:.4f}")

    # ---- Early Stopping ----
    if val_loss < best_val_loss:
        best_val_loss = val_loss

        # state_dict() hands back references to the live tensors, which keep
        # changing as training continues; clone them to get a real snapshot.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print(f"Early stopping at epoch {epoch + 1}")
            model.load_state_dict(best_model_state)
            break


Train Epoch 1:   0%|          | 0/81 [00:00<?, ?it/s]

epoch=1; train_loss=0.1418; val_loss=0.0916; val_map=0.3429


Train Epoch 2:   0%|          | 0/81 [00:00<?, ?it/s]

KeyboardInterrupt: 