In [15]:
import albumentations.pytorch
import customdataset as ds
import torch as t
import torchvision as tv
import torchvision.tv_tensors as tvt
import albumentations as A


# Parse the annotation file of each capture folder. Every call yields a pair of
# iterables that are later concatenated (presumably samples and their
# annotations -- confirm against customdataset.extract_all).
x, y = ds.extract_all('./datasets/090/annotations.xml')
x1, y1 = ds.extract_all('./datasets/190/annotations.xml')
x2, y2 = ds.extract_all('./datasets/30/annotations.xml')
x3, y3 = ds.extract_all('./datasets/60/annotations.xml')
x4, y4 = ds.extract_all('./datasets/90/annotations.xml')

# The 'clear1' folder is kept apart and used only for validation.
x5, y5 = ds.extract_all('./datasets/clear1/annotations.xml')

# Augmentation pipeline for training: fixed-size resize, random flip, colour
# jitter and a mild random affine warp, then conversion to a torch tensor.
train_transform = A.Compose(
    [
        A.Resize(1024, 1024),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        A.ShiftScaleRotate(
            shift_limit=0.0625,  # fraction of image height/width
            scale_limit=0.1,     # +/- 10% scale
            rotate_limit=15,     # degrees
            border_mode=0,       # fill outside with 0 (black)
            p=0.5,
        ),
        albumentations.pytorch.ToTensorV2(),
    ],
    bbox_params=ds.albumentations_params,
)

# Concatenate the five training folders into one flat list each.
train_inputs = [item for part in (x, x1, x2, x3, x4) for item in part]
train_labels = [item for part in (y, y1, y2, y3, y4) for item in part]
train_ds = ds.CustomImageDataset(train_inputs, train_labels, transform=train_transform)

# NOTE(review): the validation set is built without any transform, so it relies
# on CustomImageDataset's default behaviour -- confirm that it yields tensors in
# the same format as the training set.
val_ds = ds.CustomImageDataset(x5, y5)

In [16]:

def collate_detection_batch(batch):
    """Turn a list of (image, target) samples into an (images, targets) pair of tuples.

    Detection samples cannot be stacked by the default collate function, so the
    batch is simply transposed. A named function (rather than a lambda) also
    stays picklable if the loaders are later given worker processes.
    """
    return tuple(zip(*batch))


device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')
dl_train = t.utils.data.DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_detection_batch)
# Validation order does not need to be randomised; a fixed order keeps the
# per-epoch validation batches identical and the metric comparable.
dl_val = t.utils.data.DataLoader(val_ds, batch_size=4, shuffle=False, collate_fn=collate_detection_batch)

# The keyword for backbone weights is `weights_backbone`. A misspelled keyword is
# not a parameter of the builder. When detector `weights` are given, torchvision
# loads the full COCO checkpoint and the backbone weights are overridden by it.
model = tv.models.detection.fasterrcnn_resnet50_fpn(weights='COCO_V1', weights_backbone='IMAGENET1K_V2')

# replace the pre-trained head with a new one
# (num_classes counts the background class, so 2 means one object class)
model.roi_heads.box_predictor = tv.models.detection.faster_rcnn.FastRCNNPredictor(
    in_channels=model.roi_heads.box_predictor.cls_score.in_features,
    num_classes=2
)

# The training loop moves every batch to `device`; the model has to live there
# as well, and moving it before the optimizer is created is the usual order.
model.to(device)


In [17]:
# Hand only the parameters that still require gradients to the optimizer.
params = list(filter(lambda param: param.requires_grad, model.parameters()))

# SGD with momentum and weight decay.
optimizer = t.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# Multiply the learning rate by 0.1 every 3 scheduler steps.
lr_scheduler = t.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1,
)


In [18]:
# Train with early stopping on the validation loss and keep the best weights.
# NOTE(review): `num_epochs` is not defined in any visible cell -- confirm it is
# set earlier in the notebook, otherwise this cell fails on a fresh kernel.
best_val_loss = float('inf')
patience = 5          # stop if val_loss doesn't improve for 5 epochs
counter = 0
best_model_state = None

for epoch in range(num_epochs):
    # ---- TRAIN ----
    model.train()
    train_loss_epoch = 0
    for images, targets in dl_train:
        images = [img.to(device) for img in images]
        # `target` (not `t`) so the loop variable does not shadow the torch alias.
        targets = [{k: v.to(device) for k, v in target.items()} for target in targets]

        loss_dict = model(images, targets)
        total_loss = sum(loss_dict.values())

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        train_loss_epoch += total_loss.item()
    train_loss_epoch /= len(dl_train)

    # Advance the learning-rate schedule once per epoch; without this call the
    # scheduler created alongside the optimizer never takes effect.
    lr_scheduler.step()

    # ---- VALIDATION ----
    # torchvision detection models return the loss dict only in training mode;
    # in eval mode they return a list of predictions, which has no `.values()`.
    # The model is therefore left in training mode and gradients are disabled
    # instead. (The pretrained builder is expected to use frozen batch-norm
    # layers, so no statistics are updated here -- confirm for custom backbones.)
    val_loss_epoch = 0
    with t.no_grad():
        for images, targets in dl_val:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in target.items()} for target in targets]

            loss_dict = model(images, targets)
            total_loss = sum(loss_dict.values())
            val_loss_epoch += total_loss.item()
    val_loss_epoch /= len(dl_val)

    print(f"Epoch {epoch+1} | Train Loss: {train_loss_epoch:.4f} | Val Loss: {val_loss_epoch:.4f}")

    # ---- Early Stopping Check ----
    if val_loss_epoch < best_val_loss:
        best_val_loss = val_loss_epoch
        # state_dict() returns references to the live tensors, so clone them;
        # otherwise later optimizer steps would overwrite the saved snapshot.
        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

# Restore the best weights whether training stopped early or ran out of epochs.
if best_model_state is not None:
    model.load_state_dict(best_model_state)

KeyboardInterrupt: 