In [None]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from voc import get_dataloader
from main_utils import set_seed
from model_factory import get_model
from ema import RobustEMA
import os
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm


set_seed(42)

In [None]:
SIZE = (224, 224)
CONFIDENCE_THRESHOLD = 0.7
BATCH_SIZE = 4
CHECKPOINT_DIR = "./checkpoints"
METRIC_SUPERVISED = ["loss_classifier", "loss_box_reg", "loss_objectness", "loss_rpn_box_reg"]
METRICS_UNSUPERVISED = ["loss_classifier", "loss_objectness"]
LAMBDA_UNSUPERVISED = 5.0

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def scale_to_01(image, **kwargs):
    return image.astype('float32') / 255.0

weak_augmentations = A.Compose([
    A.Resize(SIZE[0], SIZE[1]),         
    A.HorizontalFlip(p=0.5),
    A.Lambda(image=scale_to_01), 
    ToTensorV2(),
], bbox_params=A.BboxParams(format='pascal_voc'))

strong_augmentations = A.Compose(
        [
            A.Resize(SIZE[0], SIZE[1]),
            A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8),
            A.GaussianBlur(blur_limit=(3, 7), sigma_limit=(0.1, 2.0), p=0.5),
            A.CoarseDropout(num_holes_range=(3, 3), hole_height_range=(0.05, 0.1),
                             hole_width_range=(0.05, 0.1), p=0.5),
            A.Lambda(image=scale_to_01), 
            ToTensorV2(),
        ],
        bbox_params=A.BboxParams(format='pascal_voc')
    )

test_transforms = A.Compose([
    A.Resize(SIZE[0], SIZE[1]),
    A.Lambda(image=scale_to_01), 
    ToTensorV2(), 
], bbox_params=A.BboxParams(format='pascal_voc'))


dt_train_labeled = get_dataloader("trainval", "2007", BATCH_SIZE, transform=weak_augmentations)
dt_train_unlabeled_weakaug = get_dataloader("trainval", "2012", BATCH_SIZE, transform=weak_augmentations) 
dt_train_unlabeled_strongaug = get_dataloader("trainval", "2012", BATCH_SIZE, transform=strong_augmentations)
dt_test = get_dataloader("test", "2007", BATCH_SIZE, transform=test_transforms, shuffle=False)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def plot_losses(history, save_dir=None, filename="loss_plot.png"):
    sns.set_theme(style="whitegrid")
    epochs = range(1, len(history["total"]) + 1)
    plt.figure(figsize=(8, 5))

    for comp in METRIC_SUPERVISED:
        plt.plot(epochs, history[f"{comp}_supervised"], label=f"Train {comp}_supervised", linewidth=2)
    for comp in METRIC_SUPERVISED:
        plt.plot(epochs, history[f"{comp}_unsupervised"], label=f"Train {comp}_unsupervised", linewidth=2)
        
    plt.plot(epochs, history["total"], label="Train total", linewidth=2)

    plt.title("Training Loss Components Over Epochs")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.tight_layout()
    plt.show()

    # Save to disk
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        out_path = os.path.join(save_dir, f"{filename}_{epochs}")
        plt.savefig(out_path, dpi=300, bbox_inches="tight")
        print(f"[INFO] Plot saved to: {out_path}")



SyntaxError: invalid syntax. Perhaps you forgot a comma? (1180880994.py, line 40)

In [None]:
def load_checkpoint(checkpoint_path, optimizer=None, device='cuda'):
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    model = get_model(device=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model weights loaded from {checkpoint_path}")

    if optimizer and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Optimizer state loaded from {checkpoint_path}")

    epoch = checkpoint.get('epoch', 0)
    print(f"Resuming from epoch {epoch}")
    return model, optimizer, epoch

def train_burn_in(model, optimizer, dt_train_labeled, device):
    model.train()
    train_batches = 0
    history = {key : 0 for key in METRIC_SUPERVISED}

    for images, targets in tqdm(dt_train_labeled, desc="Training"):
        # if train_batches == 5: break
        for target in targets:
            target["boxes"] = target["boxes"].to(device)
            target["labels"] = target["labels"].to(device)
        images = images.to(device)
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        for k, v in loss_dict.items():
            history[k] += v.item()

        history["total"] += loss.item()
        train_batches += 1
    for key in history:
        history[key] = history[key] / train_batches
    return history


def save_checkpoint(model, optimizer, epoch, path):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, path)
    print(f"Checkpoint saved at {path}")


def pipeline_burn_in(epochs, dt_train_labeled, device, checkpoint_every):

    model = get_model(device=device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    history = {key : [] for key in METRIC_SUPERVISED}

    for epoch in range(epochs):
        print(f"\n==================== Epoch {epoch+1}/{epochs} ====================\n")
        train_history = train_burn_in(model, optimizer, dt_train_labeled, device)
        lr_scheduler.step(train_history["total"])
        for key, val in train_history.items():
            history[key].append(val)
        plot_losses(history)
        if (epoch + 1) % checkpoint_every == 0 or (epoch + 1) == epochs:
            checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_epoch_{epoch+1}.pth")
            save_checkpoint(model, optimizer, epoch + 1, checkpoint_path)

# pipeline_burn_in(50, dt_train_labeled, device, 3)

In [5]:
images, labels = next(iter(dt_train_unlabeled_weakaug))
from torchvision.ops import batched_nms
NMS_IOU = 0.5
def generate_pseudo_labels(model : torch.nn.Module, images : torch.Tensor, device):
    model.eval()
    with torch.no_grad():
        images = images.to(device)
        outputs = model(images, None)
        for output in outputs:
            boxes  = output["boxes"]
            labels = output["labels"]
            scores = output["scores"]

            keep_nms = batched_nms(
                boxes, scores, labels,
                iou_threshold=NMS_IOU
            )
            boxes  = boxes[keep_nms]
            labels = labels[keep_nms]
            scores = scores[keep_nms]

            boxes_to_keep = scores > CONFIDENCE_THRESHOLD        
            boxes  = boxes[boxes_to_keep]
            labels = labels[boxes_to_keep]
            scores = scores[boxes_to_keep]

            output["boxes"]  = boxes
            output["labels"] = labels
            output["scores"] = scores
        return outputs       
    
# model, optimizer, epoch = load_checkpoint(checkpoint_path=checkpoint_path, optimizer=None, device=device)
# generate_pseudo_labels(model, images, device)

In [None]:
def train_semi_supervised_one_epoch(teacher : RobustEMA, student, optimizer, dt_labeled, dt_weak, dt_strong):
    student.train()
    train_batches = 0
    history = {}
    for key in METRIC_SUPERVISED:
        history[f"{key}_supervised"] = 0
    for key in METRICS_UNSUPERVISED:
        history[f"{key}_unsupervised"] = 0
    history["total"] = 0

    for (img_labeled, targets_labeled), (img_weak, _), (img_strong, _) in zip(dt_labeled, dt_weak, dt_strong):
        if train_batches == 5: break
        # SHOULD REPLACE THE TRANSFORMATION OF HORIZONTAL FLIP WITH SOMETHING PHOTOMETRIC
        weak_targets = generate_pseudo_labels(teacher.ema, img_weak, device)
        
        for target in weak_targets:
            target["boxes"] = target["boxes"].to(device)
            target["labels"] = target["labels"].to(device)
        img_strong = img_strong.to(device)
        loss_dict_unsupervised = student(img_strong, weak_targets)

        for target in targets_labeled:
            target["boxes"] = target["boxes"].to(device)
            target["labels"] = target["labels"].to(device)
        img_labeled = img_labeled.to(device)
        loss_dict_supervised = student(img_labeled, targets_labeled)

        optimizer.zero_grad()
        loss = sum(loss_dict_supervised.values()) + LAMBDA_UNSUPERVISED * (loss_dict_unsupervised["loss_classifier"] + loss_dict_unsupervised["loss_objectness"])
        loss.backward()
        optimizer.step()

        teacher.update(student)
        for k in METRICS_UNSUPERVISED:
            history[f"{k}_unsupervised"] += loss_dict_unsupervised[k]
        for k in METRIC_SUPERVISED:
            history[f"{k}_supervised"] += loss_dict_supervised[k] 

        history["total"] += loss.item()
        train_batches += 1
    for key in history:
        history[key] = history[key] / train_batches
    return history


def run_semi_supervised_pipeline(checkpoint_path, epochs, dt_labeled, dt_weak, dt_strong, dt_test):
    student, _, _ = load_checkpoint(checkpoint_path=checkpoint_path, optimizer=None, device=device)
    teacher = RobustEMA(student)
    optimizer = torch.optim.SGD(student.parameters(), lr=1e-2, momentum=0.9)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    history = {}
    for key in METRIC_SUPERVISED:
        history[f"{key}_supervised"] = []
    for key in METRICS_UNSUPERVISED:
        history[f"{key}_unsupervised"] = []
    
    for epoch in range(epochs):
        print(f"\n==================== Epoch {epoch+1}/{epochs} ====================\n")
        train_history = train_semi_supervised_one_epoch(teacher, student, optimizer, dt_labeled, dt_weak, dt_strong)
        lr_scheduler.step(train_history["total"])
        for key, val in train_history.items():
            history[key].append(val)
        plot_losses(history)

checkpoint_path="checkpoints/checkpoint_epoch_42.pth"
run_semi_supervised_pipeline(checkpoint_path, 10, dt_train_labeled, dt_train_unlabeled_weakaug, dt_train_unlabeled_strongaug, dt_test)


Model weights loaded from checkpoints/checkpoint_epoch_42.pth
Resuming from epoch 42




OutOfMemoryError: CUDA out of memory. Tried to allocate 158.00 MiB. GPU 0 has a total capacity of 5.63 GiB of which 130.06 MiB is free. Including non-PyTorch memory, this process has 4.84 GiB memory in use. Of the allocated memory 4.28 GiB is allocated by PyTorch, and 446.30 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)