In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, ConcatDataset
import torchvision.transforms.functional as TF
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from sklearn.model_selection import StratifiedShuffleSplit
from torchmetrics.detection.mean_ap import MeanAveragePrecision
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm import tqdm
from PIL import Image

# Pick the compute backend in order of preference: CUDA GPU, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

# Disable HuggingFace tokenizers' internal parallelism; the DataLoaders use worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Fix every RNG the notebook draws from: NumPy's global state (used by the augmentations)
# and torch on CPU and all visible GPUs.
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [None]:
# Root folder that contains dataset_<NAME>/train/train.csv and dataset_<NAME>/test/test.csv.
# Defaults to the Kaggle input mount; export DATASET_ROOT to run the notebook elsewhere.
DATASET_ROOT = os.environ.get("DATASET_ROOT", "/kaggle/input/mammography")

In [None]:
class MammographyDataset(Dataset):
    """Mammography images with at most one lesion box each, formatted for Grounding DINO.

    Each row of ``df`` must provide ``image_name`` and ``pathology`` ("BENIGN" or
    "MALIGNANT"); ``xmin``/``ymin``/``xmax``/``ymax`` (original-image pixels) are read for
    every row, with missing values replaced by 0. Only MALIGNANT rows produce a box;
    BENIGN rows produce an empty annotation list.

    Args:
        df: Annotation table. It is re-indexed into a new frame, so the caller's frame is
            not modified.
        img_dir: Directory that ``image_name`` is relative to.
        train: If True, ``__getitem__`` applies random augmentations.
        strong_aug: With ``train``, additionally apply rotation and photometric jitter.
        target_size: ``(width, height)`` every image is resized to before augmentation.
    """

    def __init__(self, df, img_dir, train=False, strong_aug=True, target_size=(800, 981)):
        self.data = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.train = train
        self.strong_aug = strong_aug
        self.target_width, self.target_height = target_size
        self.label_map = {"BENIGN": 0, "MALIGNANT": 1}
        self.data["label"] = self.data["pathology"].map(self.label_map)
        for col in ["xmin", "ymin", "xmax", "ymax"]:
            if col in self.data.columns:
                self.data[col] = self.data[col].fillna(0)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(PIL RGB image, annotation dict)`` for row ``idx``.

        The annotation dict has the COCO-detection layout passed to the processor:
        ``{"image_id": 0, "annotations": [...]}``; a malignant row contributes one entry whose
        ``bbox`` is an ``[x, y, w, h]`` float32 tensor in resized (and augmented) pixels.
        """
        row = self.data.iloc[idx]
        img_path = os.path.join(self.img_dir, row["image_name"])
        # The context manager closes the source file once the RGB copy exists.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        original_width, original_height = image.size
        target_width, target_height = self.target_width, self.target_height

        image = image.resize((target_width, target_height), Image.BILINEAR)
        scale_x = target_width / original_width
        scale_y = target_height / original_height

        scaled_xmin = row["xmin"] * scale_x
        scaled_ymin = row["ymin"] * scale_y
        scaled_xmax = row["xmax"] * scale_x
        scaled_ymax = row["ymax"] * scale_y

        label = torch.tensor(row["label"], dtype=torch.long)

        if self.train:
            image, scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax = self._apply_augmentations(
                image, scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax, target_width, target_height
            )

        if label == 1:
            bbox_xywh = torch.tensor(
                [
                    scaled_xmin,
                    scaled_ymin,
                    scaled_xmax - scaled_xmin,
                    scaled_ymax - scaled_ymin,
                ],
                dtype=torch.float32,
            )
            annotations = {
                "image_id": 0,
                "annotations": [
                    {
                        "bbox": bbox_xywh,
                        "area": float(bbox_xywh[2] * bbox_xywh[3]),
                        "category_id": 0,
                        "iscrowd": 0,
                    }
                ],
            }
        else:
            annotations = {"image_id": 0, "annotations": []}

        return image, annotations

    @staticmethod
    def _hflip_box(box, img_width):
        """Mirror an ``(xmin, ymin, xmax, ymax)`` box across the image's vertical center line."""
        xmin, ymin, xmax, ymax = box
        return (img_width - xmax, ymin, img_width - xmin, ymax)

    @staticmethod
    def _affine_box(box, scale_x, scale_y, offset_x, offset_y):
        """Scale a box per axis, then translate it by ``(offset_x, offset_y)``."""
        xmin, ymin, xmax, ymax = box
        return (
            xmin * scale_x + offset_x,
            ymin * scale_y + offset_y,
            xmax * scale_x + offset_x,
            ymax * scale_y + offset_y,
        )

    @staticmethod
    def _clip_box(box, img_width, img_height):
        """Clamp box coordinates into ``[0, img_width] x [0, img_height]``."""
        xmin, ymin, xmax, ymax = box
        return (
            max(0, min(xmin, img_width)),
            max(0, min(ymin, img_height)),
            max(0, min(xmax, img_width)),
            max(0, min(ymax, img_height)),
        )

    @staticmethod
    def _rotate_box(box, angle, img_width, img_height):
        """Axis-aligned bounds of ``box`` after ``TF.rotate(image, angle)`` about the image center.

        ``TF.rotate`` turns the picture counter-clockwise for a positive angle; since image y
        grows downwards, that is a rotation by ``-angle`` in the usual (y-up) math convention.
        """
        xmin, ymin, xmax, ymax = box
        center_x, center_y = img_width / 2, img_height / 2
        angle_rad = -angle * np.pi / 180
        cos_a, sin_a = np.cos(angle_rad), np.sin(angle_rad)
        xs, ys = [], []
        for x, y in ((xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)):
            dx, dy = x - center_x, y - center_y
            xs.append(dx * cos_a - dy * sin_a + center_x)
            ys.append(dx * sin_a + dy * cos_a + center_y)
        return (min(xs), min(ys), max(xs), max(ys))

    def _apply_augmentations(self, image, xmin, ymin, xmax, ymax, img_width, img_height):
        """Randomly augment ``image`` and move the ``(xmin, ymin, xmax, ymax)`` box to match.

        Always eligible (each with p=0.5): horizontal flip, and a 0.95-1.05 zoom that keeps the
        canvas at ``img_width`` x ``img_height`` by center-cropping or zero-padding.
        With ``strong_aug`` (each with p=0.5): rotation within +/-5 degrees and
        brightness / contrast / gamma jitter.

        Returns:
            ``(image, xmin, ymin, xmax, ymax)`` after augmentation.
        """
        box = (xmin, ymin, xmax, ymax)

        if np.random.random() > 0.5:
            image = TF.hflip(image)
            box = self._hflip_box(box, img_width)

        if np.random.random() > 0.5:
            scale_factor = np.random.uniform(0.95, 1.05)
            new_width = int(img_width * scale_factor)
            new_height = int(img_height * scale_factor)

            image = TF.resize(image, (new_height, new_width), interpolation=TF.InterpolationMode.BILINEAR)

            if scale_factor > 1:
                left = (new_width - img_width) // 2
                top = (new_height - img_height) // 2
                image = TF.crop(image, top, left, img_height, img_width)
                offset_x, offset_y = -left, -top
            else:
                padding_left = (img_width - new_width) // 2
                padding_top = (img_height - new_height) // 2
                padding_right = img_width - new_width - padding_left
                padding_bottom = img_height - new_height - padding_top
                image = TF.pad(image, (padding_left, padding_top, padding_right, padding_bottom), fill=0)
                offset_x, offset_y = padding_left, padding_top

            # int() truncation makes the real zoom per axis differ slightly from scale_factor,
            # so scale the box by the ratio the image was actually resized with.
            box = self._affine_box(
                box, new_width / img_width, new_height / img_height, offset_x, offset_y
            )
            box = self._clip_box(box, img_width, img_height)

        if self.strong_aug:
            if np.random.random() > 0.5:
                angle = np.random.uniform(-5, 5)
                image = TF.rotate(image, angle, interpolation=TF.InterpolationMode.BILINEAR, fill=0)
                box = self._rotate_box(box, angle, img_width, img_height)
                box = self._clip_box(box, img_width, img_height)

            if np.random.random() > 0.5:
                brightness_factor = np.random.uniform(0.85, 1.15)
                image = TF.adjust_brightness(image, brightness_factor)

            if np.random.random() > 0.5:
                contrast_factor = np.random.uniform(0.85, 1.15)
                image = TF.adjust_contrast(image, contrast_factor)

            if np.random.random() > 0.5:
                gamma = np.random.uniform(0.9, 1.1)
                image = TF.adjust_gamma(image, gamma)

        return (image, *box)

In [None]:
def grounding_dino_collate_fn(batch):
    """Transpose a list of ``(image, annotations)`` samples into two parallel lists.

    Images are kept as a plain list (not stacked) so the processor can resize and pad them
    itself. Returns ``(images, annotations)``.
    """
    image_col, annotation_col = zip(*batch)
    return [*image_col], [*annotation_col]

In [None]:
def _load_train_val_split(train_name, val_size=0.2, seed=42):
    """Read ``dataset_<train_name>/train/train.csv`` and split it, stratified on ``pathology``.

    Returns:
        ``(train_df, val_df, img_dir)``, both frames re-indexed from 0.
    """
    img_dir = f"{DATASET_ROOT}/dataset_{train_name}/train"
    full_df = pd.read_csv(f"{img_dir}/train.csv")
    # Same header normalisation as create_test_dataloaders; a no-op for headers that are
    # already lower_snake_case (image_name, pathology, xmin, ...).
    full_df.columns = [c.strip().lower().replace(" ", "_") for c in full_df.columns]
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=val_size, random_state=seed)
    train_idx, val_idx = next(splitter.split(full_df, full_df["pathology"]))
    return (
        full_df.iloc[train_idx].reset_index(drop=True),
        full_df.iloc[val_idx].reset_index(drop=True),
        img_dir,
    )


def create_train_dataloaders(
    train_name_a,
    train_name_b,
    labeled_fraction=0.1,
    batch_size=2,
    val_batch_size=4,
    num_workers=4,
):
    """Build the supervised, unsupervised and validation loaders for semi-supervised training.

    Dataset A is fully labeled. Dataset B's training split is divided (stratified on
    ``pathology``) into a labeled ``labeled_fraction`` share and an unlabeled remainder.
    Both datasets hold out 20% of their train CSV for validation.

    Args:
        train_name_a: Suffix of the fully labeled dataset folder (``dataset_<name>``).
        train_name_b: Suffix of the partially labeled dataset folder.
        labeled_fraction: Share of B's training split treated as labeled; must be in (0, 1).
        batch_size: Batch size of the supervised and unsupervised training loaders.
        val_batch_size: Batch size of the validation loader.
        num_workers: Worker processes for every loader.

    Returns:
        ``(supervised_train_loader, unsupervised_train_loader, val_loader)``. The supervised
        loader draws with replacement, weighted by inverse class frequency; the unsupervised
        loader uses weak (flip/zoom) augmentation.

    Raises:
        ValueError: If ``labeled_fraction`` is not strictly between 0 and 1.
    """
    if not 0 < labeled_fraction < 1:
        raise ValueError(f"labeled_fraction must be in (0, 1), got {labeled_fraction}")

    train_df_a, val_df_a, train_img_dir_a = _load_train_val_split(train_name_a)
    train_df_b, val_df_b, train_img_dir_b = _load_train_val_split(train_name_b)

    labeled_splitter = StratifiedShuffleSplit(
        n_splits=1, test_size=1 - labeled_fraction, random_state=42
    )
    labeled_idx_b, unlabeled_idx_b = next(
        labeled_splitter.split(train_df_b, train_df_b["pathology"])
    )
    train_df_b_labeled = train_df_b.iloc[labeled_idx_b].reset_index(drop=True)
    train_df_b_unlabeled = train_df_b.iloc[unlabeled_idx_b].reset_index(drop=True)

    supervised_train_dataset = ConcatDataset([
        MammographyDataset(train_df_a, train_img_dir_a),
        MammographyDataset(train_df_b_labeled, train_img_dir_b),
    ])
    unsupervised_train_dataset = MammographyDataset(
        train_df_b_unlabeled, train_img_dir_b, train=True, strong_aug=False
    )
    val_dataset = ConcatDataset([
        MammographyDataset(val_df_a, train_img_dir_a),
        MammographyDataset(val_df_b, train_img_dir_b),
    ])

    # Row order of combined_df matches the ConcatDataset above, so weights[i] belongs to sample i.
    combined_df = pd.concat([train_df_a, train_df_b_labeled], ignore_index=True)
    class_counts = combined_df["pathology"].value_counts()
    weights = (1.0 / combined_df["pathology"].map(class_counts)).to_numpy()
    sampler = WeightedRandomSampler(weights=weights, num_samples=len(combined_df), replacement=True)

    supervised_train_loader = DataLoader(
        supervised_train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=grounding_dino_collate_fn,
        sampler=sampler,
    )

    unsupervised_train_loader = DataLoader(
        unsupervised_train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=grounding_dino_collate_fn,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=val_batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=grounding_dino_collate_fn,
    )

    return supervised_train_loader, unsupervised_train_loader, val_loader

In [None]:
def create_test_dataloaders():
    """Build one evaluation DataLoader per held-out test split.

    For each of datasets A, B and C, reads ``dataset_<X>/test/test.csv`` under
    ``DATASET_ROOT``, normalises its headers to lower_snake_case (e.g. "Image Name" ->
    "image_name"), and wraps it in a non-augmented ``MammographyDataset``.

    Returns:
        Tuple ``(loader_A, loader_B, loader_C)`` with batch size 4 and no shuffling.
    """

    def _loader_for(split_name):
        split_dir = f"{DATASET_ROOT}/dataset_{split_name}/test"
        split_df = pd.read_csv(f"{split_dir}/test.csv")
        split_df.columns = [c.strip().lower().replace(" ", "_") for c in split_df.columns]
        return DataLoader(
            MammographyDataset(split_df, split_dir),
            batch_size=4,
            shuffle=False,
            num_workers=4,
            collate_fn=grounding_dino_collate_fn,
        )

    return tuple(_loader_for(split_name) for split_name in ("A", "B", "C"))

In [None]:
def plot_metrics(
    train_name,
    num_epochs,
    train_losses,
    val_losses,
    train_map50_scores,
    val_map50_scores,
    train_map75_scores,
    val_map75_scores,
    train_map_scores,
    val_map_scores,
    sup_train_losses,
    unsup_train_losses,
):
    """Plot, save and summarise the per-epoch curves of one training run.

    Draws a 2x2 figure (supervised/unsupervised/validation loss, mAP@50, mAP@75, mAP),
    saves it to ``training_metrics_<train_name>.png`` at 300 dpi, writes every series to
    ``training_metrics_<train_name>.csv`` and prints the last-epoch values.
    Each series is expected to hold exactly ``num_epochs`` values.
    """
    epochs_range = range(1, num_epochs + 1)
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # One entry per panel: (axis, [(series, fmt, legend label), ...], y label, title).
    panels = [
        (
            axes[0, 0],
            [
                (sup_train_losses, "b-o", "Supervised Loss"),
                (unsup_train_losses, "g-^", "Unsupervised Loss"),
                (val_losses, "r-s", "Validation Loss"),
            ],
            "Loss (per sample)",
            f"Training Losses (Sup/Unsup) and Validation Loss - {train_name}",
        ),
        (
            axes[0, 1],
            [
                (train_map50_scores, "b-o", "Training mAP@50"),
                (val_map50_scores, "r-s", "Validation mAP@50"),
            ],
            "mAP@50",
            f"mAP@50 (IoU=0.5) - {train_name}",
        ),
        (
            axes[1, 0],
            [
                (train_map75_scores, "b-o", "Training mAP@75"),
                (val_map75_scores, "r-s", "Validation mAP@75"),
            ],
            "mAP@75",
            f"mAP@75 (IoU=0.75) - {train_name}",
        ),
        (
            axes[1, 1],
            [
                (train_map_scores, "b-o", "Training mAP"),
                (val_map_scores, "r-s", "Validation mAP"),
            ],
            "mAP",
            f"Mean Average Precision (mAP) - {train_name}",
        ),
    ]

    for ax, curves, y_label, title in panels:
        for series, fmt, legend_label in curves:
            ax.plot(epochs_range, series, fmt, label=legend_label, linewidth=2, markersize=8)
        ax.set_xlabel("Epoch", fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight="bold")
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f"training_metrics_{train_name}.png", dpi=300, bbox_inches="tight")
    plt.show()

    metrics_df = pd.DataFrame(
        {
            "epoch": list(epochs_range),
            "train_loss": train_losses,
            "sup_train_loss": sup_train_losses,
            "unsup_train_loss": unsup_train_losses,
            "val_loss": val_losses,
            "train_map50": train_map50_scores,
            "val_map50": val_map50_scores,
            "train_map75": train_map75_scores,
            "val_map75": val_map75_scores,
            "train_map": train_map_scores,
            "val_map": val_map_scores,
        }
    )
    metrics_df.to_csv(f"training_metrics_{train_name}.csv", index=False)
    print(f"\n✓ Metrics saved to training_metrics_{train_name}.csv")

    summary_rule = "=" * 60
    print("\n" + summary_rule)
    print(f"FINAL TRAINING SUMMARY - {train_name}")
    print(summary_rule)
    # Losses are reported with 6 decimals, mAP scores with 4.
    summary_rows = (
        ("Final Training Loss (Total)", train_losses, ".6f"),
        ("Final Supervised Loss", sup_train_losses, ".6f"),
        ("Final Unsupervised Loss", unsup_train_losses, ".6f"),
        ("Final Validation Loss", val_losses, ".6f"),
        ("Final Training mAP@50", train_map50_scores, ".4f"),
        ("Final Validation mAP@50", val_map50_scores, ".4f"),
        ("Final Training mAP@75", train_map75_scores, ".4f"),
        ("Final Validation mAP@75", val_map75_scores, ".4f"),
        ("Final Training mAP", train_map_scores, ".4f"),
        ("Final Validation mAP", val_map_scores, ".4f"),
    )
    for caption, series, spec in summary_rows:
        print(f"{caption}: {series[-1]:{spec}}")

In [None]:
class CoOpContext(nn.Module):
    """Learnable prompt embeddings (CoOp-style) initialised from a text prompt.

    ``initial_prompt`` is tokenised with ``processor.tokenizer`` and looked up in the model's
    word-embedding table. The resulting ``(seq_len, dim)`` matrix becomes the trainable
    parameter ``all_embeddings``. The first token and the last two tokens are kept at their
    initial values. For a BERT-style tokenizer these are presumably [CLS], the trailing "."
    and [SEP]; TODO confirm for other tokenizers.

    Freezing is enforced twice:
      * a gradient hook zeroes the gradient of the frozen rows, and
      * ``forward`` takes the frozen rows from a pristine, non-persistent buffer. This matters
        because AdamW's decoupled weight decay shrinks every parameter row whether or not its
        gradient is zero.

    Args:
        device: Device the prompt is tokenised on; the embeddings inherit the model's device.
        processor: HF processor exposing ``.tokenizer``.
        model: Grounding DINO model exposing
            ``model.model.text_backbone.embeddings.word_embeddings``.
        initial_prompt: Text whose token embeddings seed the context.
    """

    def __init__(self, device, processor, model, initial_prompt):
        super().__init__()

        tokens = processor.tokenizer(
            initial_prompt, return_tensors="pt", padding=False
        ).to(device)
        input_ids = tokens["input_ids"]

        with torch.no_grad():
            text_embeds = model.model.text_backbone.embeddings.word_embeddings(
                input_ids
            )

        full_embeds = text_embeds[0, :, :]

        self.all_embeddings = nn.Parameter(full_embeds.clone())
        self.initial_token_ids = input_ids[0].clone()

        seq_len = full_embeds.shape[0]
        trainable_mask = torch.ones(seq_len, dtype=torch.bool, device=full_embeds.device)
        trainable_mask[0] = False
        trainable_mask[-2] = False
        trainable_mask[-1] = False
        # Non-persistent buffers follow .to()/.cuda() but stay out of state_dict, so saved
        # checkpoints keep exactly the keys they had before.
        self.register_buffer("trainable_mask", trainable_mask, persistent=False)
        self.register_buffer("frozen_embeddings", full_embeds.detach().clone(), persistent=False)

        def freeze_hook(grad):
            mask = self.trainable_mask.unsqueeze(-1).to(grad.device)
            return grad * mask

        self.hook_handle = self.all_embeddings.register_hook(freeze_hook)

    def forward(self):
        """Return the ``(seq_len, dim)`` prompt embeddings, frozen rows at their initial values."""
        mask = self.trainable_mask.unsqueeze(-1)
        return torch.where(mask, self.all_embeddings, self.frozen_embeddings)

In [None]:
def move_labels_to_device(labels, device):
    """Return new per-image label dicts with every tensor value moved to ``device``.

    Non-tensor values are copied through unchanged; the input dicts are not modified.
    """

    def _moved(value):
        return value.to(device) if torch.is_tensor(value) else value

    return [{key: _moved(value) for key, value in label.items()} for label in labels]

In [None]:
def transform_boxes(boxes, transforms_applied, img_shape):
    """Map ``xyxy`` boxes through the geometric transforms recorded during augmentation.

    Args:
        boxes: ``(N, 4)`` tensor of ``[x1, y1, x2, y2]`` pixel boxes (a sequence of length-4
            tensors is also accepted). Returned unchanged when empty.
        transforms_applied: ``(name, params)`` pairs in application order. Only
            ``("rotate", angle_in_degrees)`` is handled; other entries are ignored.
        img_shape: ``(height, width)`` of the image the boxes live in.

    Returns:
        ``(N, 4)`` tensor holding, for each box, the axis-aligned bounds of its rotated corners,
        clipped to the image.
    """
    if len(boxes) == 0:
        return boxes
    if not torch.is_tensor(boxes):
        boxes = torch.stack([torch.as_tensor(b) for b in boxes])

    h, w = img_shape
    cx, cy = w / 2, h / 2

    for transform_name, params in transforms_applied:
        if transform_name != "rotate":
            continue

        # TF.rotate turns the picture counter-clockwise for a positive angle. Image y grows
        # downwards, so that is a rotation by -angle in the usual (y-up) math convention --
        # the same convention MammographyDataset._apply_augmentations uses.
        theta = -np.radians(params)
        cos_t = float(np.cos(theta))
        sin_t = float(np.sin(theta))

        x1, y1, x2, y2 = boxes.unbind(-1)
        # Corner offsets from the rotation center, shape (N, 4).
        dx = torch.stack([x1, x2, x2, x1], dim=-1) - cx
        dy = torch.stack([y1, y1, y2, y2], dim=-1) - cy
        rx = dx * cos_t - dy * sin_t + cx
        ry = dx * sin_t + dy * cos_t + cy

        boxes = torch.stack(
            [
                rx.amin(dim=-1).clamp(0, w),
                ry.amin(dim=-1).clamp(0, h),
                rx.amax(dim=-1).clamp(0, w),
                ry.amax(dim=-1).clamp(0, h),
            ],
            dim=-1,
        )

    return boxes

def apply_strong_augmentation(image, boxes=None):
    """Apply random photometric jitter and a small rotation to a PIL image.

    Each of brightness (0.7-1.3), contrast (0.7-1.3) and gamma (0.8-1.2) is applied with
    probability 0.5, followed, with probability 0.5, by a rotation within +/-10 degrees.

    Args:
        image: PIL image.
        boxes: Optional ``xyxy`` boxes to carry through the rotation.

    Returns:
        ``(image, boxes)`` when ``boxes`` is a non-empty collection, otherwise just ``image``.
    """
    # Table-driven so the RNG is consumed in the fixed order: coin flip, then factor.
    photometric_ops = (
        (TF.adjust_brightness, 0.7, 1.3),
        (TF.adjust_contrast, 0.7, 1.3),
        (TF.adjust_gamma, 0.8, 1.2),
    )
    for op, low, high in photometric_ops:
        if np.random.random() > 0.5:
            image = op(image, np.random.uniform(low, high))

    recorded = []
    if np.random.random() > 0.5:
        rotation = np.random.uniform(-10, 10)
        image = TF.rotate(image, rotation, interpolation=TF.InterpolationMode.BILINEAR, fill=0)
        recorded.append(('rotate', rotation))

    if boxes is None or len(boxes) == 0:
        return image

    width, height = image.size
    return image, transform_boxes(boxes, recorded, (height, width))

In [None]:
def train_one_dataset(model, processor, sup_train_name, unsup_train_name, num_epochs):
    print(f"\n{'#'*70}")
    print(f"#  SEMI-SUPERVISED TRAINING")
    print(f"{'#'*70}\n")

    sup_train_loader, unsup_train_loader, val_loader = create_train_dataloaders(sup_train_name, unsup_train_name)

    initial_prompt = "malignant tumor cancer."
    conf_threshold = 0.1
    lambda_u = 1
    
    print(f"{'='*70}")
    print(f"  Configuration")
    print(f"{'='*70}")
    print(f"  Supervised Dataset   : {sup_train_name}")
    print(f"  Unsupervised Dataset : {unsup_train_name}")
    print(f"  Epochs               : {num_epochs}")
    print(f"  Base Prompt          : '{initial_prompt}'")
    print(f"  Lambda_u (target)    : {lambda_u}")
    print(f"  Confidence Threshold : {conf_threshold}")
    print(f"{'='*70}\n")

    context_module = CoOpContext(
        device=device,
        processor=processor,
        model=model,
        initial_prompt=initial_prompt,
    )

    print(f"  Initial Context Vector:\n {context_module()}\n")
    print(f"  Context Vector Shape: {context_module().shape}\n")

    lr = 5e-4
    optimizer = optim.AdamW(context_module.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_epochs,
        eta_min=1e-4
    )

    train_losses = []
    sup_train_losses = []
    unsup_train_losses = []
    val_losses = []
    train_map50_scores = []
    train_map75_scores = []
    train_map_scores = []
    val_map50_scores = []
    val_map75_scores = []
    val_map_scores = []
    best_val_loss = float("inf")

    train_map_metric = MeanAveragePrecision(iou_thresholds=[0.5, 0.75]).to(device)
    val_map_metric = MeanAveragePrecision(iou_thresholds=[0.5, 0.75]).to(device)
    
    train_map_metric.warn_on_many_detections = False
    val_map_metric.warn_on_many_detections = False

    for epoch in range(num_epochs):
        print(f"\n{'='*70}")
        print(f"  Epoch {epoch + 1}/{num_epochs}")
        print(f"{'='*70}\n")

        if epoch < 5:
            current_lambda_u = 0.0
        else:
            current_lambda_u = lambda_u * min(1.0, (epoch + 1 - 5) / (num_epochs - 5))
            
        print(f"  Current Lambda_u: {current_lambda_u:.4f}\n")

        model.eval()
        context_module.train()
        total_train_loss = 0.0
        total_sup_loss = 0.0
        total_unsup_loss = 0.0
        sup_batches = 0
        unsup_batches = 0
        train_map_metric.reset()

        sup_iter = iter(sup_train_loader)
        unsup_iter = iter(unsup_train_loader)
        
        max_iters = max(len(sup_train_loader), len(unsup_train_loader))
        
        train_pbar = tqdm(range(max_iters), desc="  Training  ")
        
        for _ in train_pbar:
            use_supervised = np.random.random() > 0.5
            
            if use_supervised:
                try:
                    images, annotations = next(sup_iter)
                except StopIteration:
                    sup_iter = iter(sup_train_loader)
                    images, annotations = next(sup_iter)
                
                batch_size = len(images)
                inputs = processor(
                    images=images, annotations=annotations, return_tensors="pt"
                ).to(device)
                del inputs["pixel_mask"]

                context_expanded = context_module().unsqueeze(0).expand(batch_size, -1, -1)
                context_token_ids = context_module.initial_token_ids.unsqueeze(0).expand(batch_size, -1)

                prompt_enc = {}
                prompt_enc["input_ids"] = context_token_ids
                prompt_enc["inputs_embeds"] = context_expanded
                labels = move_labels_to_device(inputs["labels"], device)
                inputs["labels"] = labels

                inputs = inputs | prompt_enc

                outputs = model(**inputs)
                loss_dict = outputs.loss_dict
                weight_dict = {
                    "loss_ce": 2.0,
                    "loss_bbox": model.config.bbox_loss_coefficient,
                    "loss_giou": model.config.giou_loss_coefficient,
                }
                enc_weight_dict = {k + "_enc": v for k, v in weight_dict.items()}
                weight_dict.update(enc_weight_dict)
                weight_dict["loss_ce_enc"] = 0
                weight_dict["loss_bbox_enc"] = 0
                weight_dict["loss_giou_enc"] = 0
                loss = sum(
                    loss_dict[k] * weight_dict[k]
                    for k in loss_dict.keys()
                    if k in weight_dict
                )

                total_train_loss += loss.item()
                total_sup_loss += loss.item()
                sup_batches += 1

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    target_sizes = [img.size[::-1] for img in images]
                    results = processor.post_process_grounded_object_detection(
                        outputs,
                        inputs.input_ids,
                        threshold=0,
                        text_threshold=0,
                        target_sizes=target_sizes,
                    )

                    preds = []
                    targets = []

                    for res, anno in zip(results, annotations):
                        preds.append(
                            {
                                "boxes": res["boxes"],
                                "scores": res["scores"],
                                "labels": torch.zeros(
                                    len(res["boxes"]), dtype=torch.long, device=device
                                ),
                            }
                        )

                        if len(anno["annotations"]) > 0:
                            gt_boxes = torch.stack([a["bbox"] for a in anno["annotations"]])
                            gt_boxes_xyxy = torch.zeros_like(gt_boxes)
                            gt_boxes_xyxy[:, 0] = gt_boxes[:, 0]
                            gt_boxes_xyxy[:, 1] = gt_boxes[:, 1]
                            gt_boxes_xyxy[:, 2] = gt_boxes[:, 0] + gt_boxes[:, 2]
                            gt_boxes_xyxy[:, 3] = gt_boxes[:, 1] + gt_boxes[:, 3]

                            targets.append(
                                {
                                    "boxes": gt_boxes_xyxy.to(device),
                                    "labels": torch.zeros(
                                        len(gt_boxes), dtype=torch.long, device=device
                                    ),
                                }
                            )
                        else:
                            targets.append(
                                {
                                    "boxes": torch.empty((0, 4), device=device),
                                    "labels": torch.empty(
                                        0, dtype=torch.long, device=device
                                    ),
                                }
                            )

                    train_map_metric.update(preds, targets)

                train_pbar.set_postfix({"type": "sup", "loss": f"{loss.item():.4f}"})
            
            else:
                try:
                    images_weak, _ = next(unsup_iter)
                except StopIteration:
                    unsup_iter = iter(unsup_train_loader)
                    images_weak, _ = next(unsup_iter)
                
                batch_size = len(images_weak)
                
                with torch.no_grad():                    
                    inputs_weak = processor(
                        images=images_weak, return_tensors="pt"
                    ).to(device)
                    del inputs_weak["pixel_mask"]
                    
                    context_expanded = context_module().unsqueeze(0).expand(batch_size, -1, -1)
                    context_token_ids = context_module.initial_token_ids.unsqueeze(0).expand(batch_size, -1)
                    
                    prompt_enc = {}
                    prompt_enc["input_ids"] = context_token_ids
                    prompt_enc["inputs_embeds"] = context_expanded
                    
                    inputs_weak_with_prompt = inputs_weak | prompt_enc
                    
                    outputs_weak = model(**inputs_weak_with_prompt)
                    
                    target_sizes = [img.size[::-1] for img in images_weak]
                    results_weak = processor.post_process_grounded_object_detection(
                        outputs_weak,
                        inputs_weak_with_prompt.input_ids,
                        threshold=conf_threshold,
                        text_threshold=conf_threshold,
                        target_sizes=target_sizes,
                    )
                    
                    pseudo_annotations = []
                    images_strong = []

                    for idx, res in enumerate(results_weak):
                        high_conf_mask = res["scores"] > conf_threshold
                        pseudo_boxes = res["boxes"][high_conf_mask]
                        pseudo_scores = res["scores"][high_conf_mask]
                        
                        if len(pseudo_boxes) > 0:
                            if len(pseudo_boxes) > 5:
                                top_k_indices = torch.topk(pseudo_scores, k=5).indices
                                pseudo_boxes = pseudo_boxes[top_k_indices]
                            
                            img_strong, pseudo_boxes_transformed = apply_strong_augmentation(
                                images_weak[idx], 
                                boxes=pseudo_boxes
                            )
                            
                            pseudo_boxes_xywh = torch.zeros_like(pseudo_boxes_transformed)
                            pseudo_boxes_xywh[:, 0] = pseudo_boxes_transformed[:, 0]
                            pseudo_boxes_xywh[:, 1] = pseudo_boxes_transformed[:, 1]
                            pseudo_boxes_xywh[:, 2] = pseudo_boxes_transformed[:, 2] - pseudo_boxes_transformed[:, 0]
                            pseudo_boxes_xywh[:, 3] = pseudo_boxes_transformed[:, 3] - pseudo_boxes_transformed[:, 1]
                            
                            annotations_list = []
                            for box in pseudo_boxes_xywh:
                                area = float(box[2] * box[3])
                                if area > 0:
                                    annotations_list.append({
                                        "bbox": box.cpu().tolist(),
                                        "area": area,
                                        "category_id": 0,
                                        "iscrowd": 0,
                                    })
                            
                            pseudo_annotations.append({
                                "image_id": idx,
                                "annotations": annotations_list
                            })
                        else:
                            img_strong = apply_strong_augmentation(images_weak[idx])
                            pseudo_annotations.append({
                                "image_id": idx,
                                "annotations": []
                            })
                        
                        images_strong.append(img_strong)
                
                inputs_strong = processor(
                    images=images_strong, annotations=pseudo_annotations, return_tensors="pt"
                ).to(device)
                del inputs_strong["pixel_mask"]
                
                context_expanded = context_module().unsqueeze(0).expand(batch_size, -1, -1)
                context_token_ids = context_module.initial_token_ids.unsqueeze(0).expand(batch_size, -1)
                
                prompt_enc = {}
                prompt_enc["input_ids"] = context_token_ids
                prompt_enc["inputs_embeds"] = context_expanded
                labels = move_labels_to_device(inputs_strong["labels"], device)
                inputs_strong["labels"] = labels
                
                inputs_strong = inputs_strong | prompt_enc
                
                outputs_strong = model(**inputs_strong)
                loss_dict = outputs_strong.loss_dict
                weight_dict = {
                    "loss_ce": 2.0,
                    "loss_bbox": model.config.bbox_loss_coefficient,
                    "loss_giou": model.config.giou_loss_coefficient,
                }
                enc_weight_dict = {k + "_enc": v for k, v in weight_dict.items()}
                weight_dict.update(enc_weight_dict)
                weight_dict["loss_ce_enc"] = 0
                weight_dict["loss_bbox_enc"] = 0
                weight_dict["loss_giou_enc"] = 0
                unsup_loss = sum(
                    loss_dict[k] * weight_dict[k]
                    for k in loss_dict.keys()
                    if k in weight_dict
                )
                
                weighted_unsup_loss = current_lambda_u * unsup_loss
                
                total_train_loss += weighted_unsup_loss.item()
                total_unsup_loss += unsup_loss.item()
                unsup_batches += 1
                
                optimizer.zero_grad()
                weighted_unsup_loss.backward()
                optimizer.step()
                
                num_pseudo_boxes = sum(len(anno["annotations"]) for anno in pseudo_annotations)
                train_pbar.set_postfix({
                    "type": "unsup", 
                    "loss": f"{unsup_loss.item():.4f}",
                })

        avg_train_loss = total_train_loss / max_iters
        avg_sup_loss = total_sup_loss / sup_batches if sup_batches > 0 else 0
        avg_unsup_loss = total_unsup_loss / unsup_batches if unsup_batches > 0 else 0
        
        train_losses.append(avg_train_loss)
        sup_train_losses.append(avg_sup_loss)
        unsup_train_losses.append(avg_unsup_loss)

        train_metrics = train_map_metric.compute()
        train_map50_scores.append(train_metrics["map_50"].item())
        train_map75_scores.append(train_metrics["map_75"].item())
        train_map_scores.append(train_metrics["map"].item())

        print(f"\n  Training Metrics:")
        print(f"    Total Loss        : {avg_train_loss:.6f}")
        print(f"    Supervised Loss   : {avg_sup_loss:.6f} ({sup_batches} batches)")
        print(f"    Unsupervised Loss : {avg_unsup_loss:.6f} ({unsup_batches} batches)")
        print(f"    mAP@50            : {train_metrics['map_50']:.4f}")
        print(f"    mAP@75            : {train_metrics['map_75']:.4f}")
        print(f"    mAP (Overall)     : {train_metrics['map']:.4f}\n")

        torch.cuda.empty_cache()

        model.eval()
        context_module.eval()
        total_val_loss = 0.0
        val_map_metric.reset()

        val_pbar = tqdm(val_loader, desc="  Validation")
        with torch.no_grad():
            for images, annotations in val_pbar:
                batch_size = len(images)
                inputs = processor(
                    images=images, annotations=annotations, return_tensors="pt"
                ).to(device)
                del inputs["pixel_mask"]

                context_expanded = (
                    context_module().unsqueeze(0).expand(batch_size, -1, -1)
                )
                context_token_ids = context_module.initial_token_ids.unsqueeze(
                    0
                ).expand(batch_size, -1)

                prompt_enc = {}
                prompt_enc["input_ids"] = context_token_ids
                prompt_enc["inputs_embeds"] = context_expanded
                labels = move_labels_to_device(inputs["labels"], device)
                inputs["labels"] = labels

                inputs = inputs | prompt_enc

                outputs = model(**inputs)
                loss_dict = outputs.loss_dict
                weight_dict = {
                    "loss_ce": 2.0,
                    "loss_bbox": model.config.bbox_loss_coefficient,
                    "loss_giou": model.config.giou_loss_coefficient,
                }
                enc_weight_dict = {k + "_enc": v for k, v in weight_dict.items()}
                weight_dict.update(enc_weight_dict)
                weight_dict["loss_ce_enc"] = 0
                weight_dict["loss_bbox_enc"] = 0
                weight_dict["loss_giou_enc"] = 0
                loss = sum(
                    loss_dict[k] * weight_dict[k]
                    for k in loss_dict.keys()
                    if k in weight_dict
                )

                total_val_loss += loss.item()

                target_sizes = [img.size[::-1] for img in images]
                results = processor.post_process_grounded_object_detection(
                    outputs,
                    inputs.input_ids,
                    threshold=0,
                    text_threshold=0,
                    target_sizes=target_sizes,
                )

                preds = []
                targets = []

                for res, anno in zip(results, annotations):
                    preds.append(
                        {
                            "boxes": res["boxes"],
                            "scores": res["scores"],
                            "labels": torch.zeros(
                                len(res["boxes"]), dtype=torch.long, device=device
                            ),
                        }
                    )

                    if len(anno["annotations"]) > 0:
                        gt_boxes = torch.stack([a["bbox"] for a in anno["annotations"]])
                        gt_boxes_xyxy = torch.zeros_like(gt_boxes)
                        gt_boxes_xyxy[:, 0] = gt_boxes[:, 0]
                        gt_boxes_xyxy[:, 1] = gt_boxes[:, 1]
                        gt_boxes_xyxy[:, 2] = gt_boxes[:, 0] + gt_boxes[:, 2]
                        gt_boxes_xyxy[:, 3] = gt_boxes[:, 1] + gt_boxes[:, 3]

                        targets.append(
                            {
                                "boxes": gt_boxes_xyxy.to(device),
                                "labels": torch.zeros(
                                    len(gt_boxes), dtype=torch.long, device=device
                                ),
                            }
                        )
                    else:
                        targets.append(
                            {
                                "boxes": torch.empty((0, 4), device=device),
                                "labels": torch.empty(
                                    0, dtype=torch.long, device=device
                                ),
                            }
                        )

                val_map_metric.update(preds, targets)

                val_pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        avg_val_loss = total_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)

        val_metrics = val_map_metric.compute()
        val_map50_scores.append(val_metrics["map_50"].item())
        val_map75_scores.append(val_metrics["map_75"].item())
        val_map_scores.append(val_metrics["map"].item())

        print(f"\n  Validation Metrics:")
        print(f"    Loss         : {avg_val_loss:.6f}")
        print(f"    mAP@50       : {val_metrics['map_50']:.4f}")
        print(f"    mAP@75       : {val_metrics['map_75']:.4f}")
        print(f"    mAP (Overall): {val_metrics['map']:.4f}\n")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss

            torch.save(
                {
                    "context_vectors": context_module.all_embeddings.data.clone(),
                },
                f"best_context_vectors_{sup_train_name}_{unsup_train_name}.pth",
            )

            print(f"  {'*'*66}")
            print(f"  *** BEST MODEL SAVED! (Val Loss: {best_val_loss:.6f}) ***")
            print(f"  {'*'*66}\n")

        scheduler.step()

    print(f"\n{'='*70}")
    print(f"  Training Complete")
    print(f"{'='*70}")
    print(f"  Best Validation Loss: {best_val_loss:.6f}")
    print(f"{'='*70}\n")

    plot_metrics(
        f"{sup_train_name}_{unsup_train_name}",
        num_epochs,
        train_losses,
        val_losses,
        train_map50_scores,
        val_map50_scores,
        train_map75_scores,
        val_map75_scores,
        train_map_scores,
        val_map_scores,
        sup_train_losses,
        unsup_train_losses,
    )

In [None]:
def test_one(model, processor, context_module, test_loader, dataset_name):
    """Evaluate the detector with the learned prompt context on one test split.

    Every batch is run through ``model`` with the text prompt replaced by the
    embeddings produced by ``context_module``. The weighted detection loss (same
    weights as the training/validation loops) and a single-class box mAP at IoU
    0.5 / 0.75 are accumulated, printed, and returned.

    Args:
        model: Grounding-DINO-style ``transformers`` detector; it must expose
            ``config.bbox_loss_coefficient`` / ``config.giou_loss_coefficient``
            and return ``loss_dict`` when called with ``labels``.
        processor: matching processor, used for preprocessing and for
            ``post_process_grounded_object_detection``.
        context_module: prompt module; calling it yields the context embeddings
            and ``initial_token_ids`` holds the token ids passed alongside them.
        test_loader: iterable of ``(images, annotations)`` batches. ``images``
            are PIL images; each annotation dict has an ``"annotations"`` list
            whose ``"bbox"`` entries are ``[x, y, w, h]`` tensors.
        dataset_name: label for the progress bar and the printed report.

    Returns:
        dict with ``"loss"`` (mean weighted loss per batch; 0.0 for an empty
        loader) and ``"map_50"``, ``"map_75"``, ``"map"`` as Python floats.
    """
    model.eval()
    context_module.eval()

    test_map_metric = MeanAveragePrecision(iou_thresholds=[0.5, 0.75]).to(device)
    test_map_metric.warn_on_many_detections = False

    # Loss weights are loop-invariant. They mirror the training objective: the
    # encoder ("*_enc") terms are zero-weighted and any other keys are ignored.
    weight_dict = {
        "loss_ce": 2.0,
        "loss_bbox": model.config.bbox_loss_coefficient,
        "loss_giou": model.config.giou_loss_coefficient,
        "loss_ce_enc": 0,
        "loss_bbox_enc": 0,
        "loss_giou_enc": 0,
    }

    total_test_loss = 0.0
    num_batches = 0
    test_pbar = tqdm(test_loader, desc=f"Testing {dataset_name}")

    with torch.no_grad():
        for images, annotations in test_pbar:
            batch_size = len(images)

            inputs = processor(
                images=images, annotations=annotations, return_tensors="pt"
            ).to(device)
            # NOTE(review): pixel_mask is dropped exactly as in the training loop;
            # confirm the model is meant to run without it.
            del inputs["pixel_mask"]
            inputs["labels"] = move_labels_to_device(inputs["labels"], device)

            # Substitute the learned context for a tokenised text prompt,
            # broadcasting the single prompt over the whole batch.
            inputs = inputs | {
                "input_ids": context_module.initial_token_ids.unsqueeze(0).expand(
                    batch_size, -1
                ),
                "inputs_embeds": context_module().unsqueeze(0).expand(
                    batch_size, -1, -1
                ),
            }

            outputs = model(**inputs)
            loss = sum(
                value * weight_dict[key]
                for key, value in outputs.loss_dict.items()
                if key in weight_dict
            )
            total_test_loss += loss.item()
            num_batches += 1

            # threshold=0 so (presumably) no prediction is filtered before mAP.
            results = processor.post_process_grounded_object_detection(
                outputs,
                inputs.input_ids,
                threshold=0,
                text_threshold=0,
                # PIL .size is (w, h); the processor expects (h, w).
                target_sizes=[img.size[::-1] for img in images],
            )

            preds = []
            targets = []
            for res, anno in zip(results, annotations):
                preds.append(
                    {
                        "boxes": res["boxes"],
                        "scores": res["scores"],
                        "labels": torch.zeros(
                            len(res["boxes"]), dtype=torch.long, device=device
                        ),
                    }
                )

                if len(anno["annotations"]) > 0:
                    gt_xywh = torch.stack([a["bbox"] for a in anno["annotations"]])
                    # COCO [x, y, w, h] -> [x1, y1, x2, y2], the default box
                    # format of MeanAveragePrecision.
                    gt_xyxy = torch.cat(
                        [gt_xywh[:, :2], gt_xywh[:, :2] + gt_xywh[:, 2:]], dim=1
                    )
                else:
                    gt_xyxy = torch.empty((0, 4), device=device)

                targets.append(
                    {
                        "boxes": gt_xyxy.to(device),
                        "labels": torch.zeros(
                            len(gt_xyxy), dtype=torch.long, device=device
                        ),
                    }
                )

            test_map_metric.update(preds, targets)
            test_pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    # Count batches ourselves so an empty loader does not divide by zero.
    avg_test_loss = total_test_loss / num_batches if num_batches else 0.0
    test_metrics = test_map_metric.compute()

    print(f"\n{'='*70}")
    print(f"  {dataset_name} Results")
    print(f"{'='*70}")
    print(f"  Average Loss       : {avg_test_loss:.6f}")
    print(f"  mAP@50            : {test_metrics['map_50']:.4f}")
    print(f"  mAP@75            : {test_metrics['map_75']:.4f}")
    print(f"  mAP (Overall)     : {test_metrics['map']:.4f}")
    print(f"{'='*70}\n")

    return {
        "loss": avg_test_loss,
        "map_50": test_metrics["map_50"].item(),
        "map_75": test_metrics["map_75"].item(),
        "map": test_metrics["map"].item(),
    }

In [None]:
def _print_nearest_vocab_tokens(model, processor, context_module, k=5):
    """Print the ``k`` vocabulary tokens most cosine-similar to each trainable context vector.

    Runs under ``torch.no_grad`` and inside its own function. The normalised
    vocabulary matrix, and any autograd graph over the trainable
    ``all_embeddings``, are therefore freed before evaluation starts.
    """
    with torch.no_grad():
        vocab = model.model.text_backbone.embeddings.word_embeddings.weight
        trainable = context_module.all_embeddings[context_module.trainable_mask]
        similarities = F.normalize(trainable, p=2, dim=1) @ F.normalize(
            vocab, p=2, dim=1
        ).T

        for i, row in enumerate(similarities):
            top_sims, top_indices = torch.topk(row, k=k)
            print(f"\nTrainable Context Vector {i}:")
            for sim, idx in zip(top_sims, top_indices):
                word = processor.tokenizer.decode([idx.item()])
                print(f"  {word} ({sim.item():.4f})")


def test_all_datasets(model, processor, train_name):
    """Load the best context vectors for ``train_name`` and evaluate them on test sets A, B and C.

    Reads ``best_context_vectors_{train_name}.pth`` (the file name pattern used
    when saving during training). It prints the nearest vocabulary tokens of
    each trainable context vector, then runs ``test_one`` on each of the three
    loaders returned by ``create_test_dataloaders``.

    Args:
        model: frozen detector shared with training.
        processor: matching processor.
        train_name: checkpoint suffix, e.g. ``"A_B"`` for supervised A /
            unsupervised B.
    """
    print(f"\n{'#'*70}")
    print(f"#  TESTING MODEL - Trained on Dataset {train_name}")
    print(f"{'#'*70}\n")

    initial_prompt = "malignant tumor cancer."

    context_module = CoOpContext(
        device=device,
        processor=processor,
        model=model,
        initial_prompt=initial_prompt,
    )

    def freeze_hook(grad):
        # Zero the gradient of non-trainable prompt positions.
        # NOTE(review): only no-grad evaluation runs here, so this never fires.
        mask = context_module.trainable_mask.unsqueeze(-1).to(grad.device)
        return grad * mask

    context_module.all_embeddings.register_hook(freeze_hook)

    print(f"{'='*70}")
    print(f"  Loading Best Context Vectors from Dataset {train_name}")
    print(f"{'='*70}")

    checkpoint = torch.load(
        f"best_context_vectors_{train_name}.pth", map_location=device
    )
    context_module.all_embeddings.data = checkpoint["context_vectors"]
    print(f"  ✓ Context vectors loaded successfully!")
    print(f"  Context Vector: {context_module.all_embeddings}")
    print(f"  Shape: {context_module.all_embeddings.shape}")
    print(f"{'='*70}\n")

    _print_nearest_vocab_tokens(model, processor, context_module)

    test_loader_a, test_loader_b, test_loader_c = create_test_dataloaders()
    for suffix, loader in (
        ("A", test_loader_a),
        ("B", test_loader_b),
        ("C", test_loader_c),
    ):
        test_one(
            model,
            processor,
            context_module,
            loader,
            dataset_name=f"Test Dataset {suffix}",
        )

In [None]:
def _plot_detections(img, res, anno, top_k=5):
    """Show one image with its ``top_k`` best predictions (blue) and ground truth (red).

    Args:
        img: image passed to ``imshow``; predictions were rescaled to its size.
        res: one element of ``post_process_grounded_object_detection`` output,
            with ``"boxes"`` as ``[x1, y1, x2, y2]``, ``"scores"`` and
            optionally ``"text_labels"``.
        anno: annotation dict; its optional ``"annotations"`` list holds COCO
            ``[x, y, w, h]`` ``"bbox"`` tensors.
        top_k: maximum number of predictions to draw.
    """
    order = res["scores"].argsort(descending=True)[:top_k]
    boxes = res["boxes"][order]
    scores = res["scores"][order]
    text_labels = res.get("text_labels", None)
    if text_labels is not None:
        text_labels = [text_labels[i] for i in order.cpu().numpy()]
    gt_annotations = anno.get("annotations", [])

    fig, (ax_img, ax_legend) = plt.subplots(
        1, 2, figsize=(16, 10), gridspec_kw={"width_ratios": [5, 1]}
    )

    ax_img.imshow(img)
    ax_img.set_title(
        f"Predictions: {len(boxes)} | GT: {len(gt_annotations)}",
        fontsize=14,
        fontweight="bold",
    )
    ax_img.axis("off")

    # Ground-truth boxes are (x, y, w, h), which Rectangle takes directly.
    for ann in gt_annotations:
        bbox = ann["bbox"].cpu().numpy()
        ax_img.add_patch(
            patches.Rectangle(
                (bbox[0], bbox[1]),
                bbox[2],
                bbox[3],
                linewidth=3,
                edgecolor="red",
                facecolor="none",
            )
        )
        ax_img.text(
            bbox[0],
            bbox[1] - 5,
            "GT",
            color="white",
            fontsize=10,
            fontweight="bold",
            bbox=dict(facecolor="red", alpha=0.9, edgecolor="none", pad=1),
        )

    # Predictions are (x1, y1, x2, y2); they are numbered by rank for the legend.
    legend_text = []
    for rank, (box, score) in enumerate(zip(boxes, scores), start=1):
        x1, y1, x2, y2 = box.cpu().numpy()
        ax_img.add_patch(
            patches.Rectangle(
                (x1, y1),
                x2 - x1,
                y2 - y1,
                linewidth=3,
                edgecolor="blue",
                facecolor="none",
            )
        )
        ax_img.text(
            x1,
            y1 - 5,
            str(rank),
            color="white",
            fontsize=14,
            fontweight="bold",
            bbox=dict(facecolor="blue", alpha=0.9, edgecolor="none", pad=2),
        )

        if text_labels is not None and rank - 1 < len(text_labels):
            label_str = text_labels[rank - 1]
        else:
            label_str = "N/A"
        legend_text.append(
            f"{rank}. Score: {score.cpu().item():.3f}\n   Label: {label_str}"
        )

    ax_legend.axis("off")
    ax_legend.set_xlim(0, 1)
    ax_legend.set_ylim(0, 1)
    ax_legend.text(
        0.05,
        0.95,
        "Predictions",
        fontsize=16,
        fontweight="bold",
        verticalalignment="top",
    )

    y_pos = 0.88
    for text in legend_text:
        ax_legend.text(
            0.05,
            y_pos,
            text,
            fontsize=11,
            verticalalignment="top",
            family="monospace",
        )
        y_pos -= 0.12

    if gt_annotations:
        gt_y_pos = y_pos - 0.05
        ax_legend.text(
            0.05,
            gt_y_pos,
            "Ground Truth",
            fontsize=14,
            fontweight="bold",
            verticalalignment="top",
            color="red",
        )
        ax_legend.text(
            0.05,
            gt_y_pos - 0.08,
            f"GT. {len(gt_annotations)} lesion(s)",
            fontsize=11,
            verticalalignment="top",
            family="monospace",
            color="red",
        )

    plt.tight_layout()
    plt.show()
    # Figures are created in a loop; release each one explicitly.
    plt.close(fig)


def visualize_one(
    model, processor, context_module, test_loader, max_visualizations=5, top_k=5
):
    """Plot predictions of the prompted detector for the first images of ``test_loader``.

    Args:
        model: frozen detector.
        processor: matching processor (preprocessing and post-processing).
        context_module: prompt module; calling it yields the context embeddings
            and ``initial_token_ids`` holds the matching token ids.
        test_loader: iterable of ``(images, annotations)`` batches.
        max_visualizations: total number of images (one figure each) to show,
            independent of the loader's batch size.
        top_k: number of highest-scoring predictions drawn per image.
    """
    if max_visualizations <= 0:
        return

    model.eval()
    context_module.eval()

    shown = 0
    with torch.no_grad():
        for images, annotations in test_loader:
            batch_size = len(images)

            # Only predictions are needed, so neither annotations nor labels are
            # passed to the model and no loss / matching is computed.
            inputs = processor(images=images, return_tensors="pt").to(device)
            inputs.pop("pixel_mask", None)  # dropped as in training/evaluation

            inputs = inputs | {
                "input_ids": context_module.initial_token_ids.unsqueeze(0).expand(
                    batch_size, -1
                ),
                "inputs_embeds": context_module().unsqueeze(0).expand(
                    batch_size, -1, -1
                ),
            }
            outputs = model(**inputs)

            results = processor.post_process_grounded_object_detection(
                outputs,
                inputs.input_ids,
                threshold=0.1,
                text_threshold=0.1,
                # PIL .size is (w, h); the processor expects (h, w).
                target_sizes=[img.size[::-1] for img in images],
            )

            for img, res, anno in zip(images, results, annotations):
                _plot_detections(img, res, anno, top_k=top_k)
                shown += 1
                if shown >= max_visualizations:
                    return

In [None]:
def visualize_all_datasets(model, processor, train_name):
    """Load the context vectors saved for ``train_name`` and plot predictions on test sets A, B and C.

    Args:
        model: frozen detector.
        processor: matching processor.
        train_name: checkpoint suffix; the file read is
            ``best_context_vectors_{train_name}.pth``.
    """
    hashes = "#" * 70
    rule = "=" * 70

    print(f"\n{hashes}")
    print(f"#  Visualize MODEL - Trained on Dataset {train_name}")
    print(f"{hashes}\n")

    context_module = CoOpContext(
        device=device,
        processor=processor,
        model=model,
        initial_prompt="malignant tumor cancer.",
    )

    def mask_frozen_grad(grad):
        # Zero the gradient of prompt positions that are not trainable.
        keep = context_module.trainable_mask.unsqueeze(-1).to(grad.device)
        return grad * keep

    context_module.all_embeddings.register_hook(mask_frozen_grad)

    print(rule)
    print(f"  Loading Best Context Vectors from Dataset {train_name}")
    print(rule)

    checkpoint_path = f"best_context_vectors_{train_name}.pth"
    saved = torch.load(checkpoint_path, map_location=device)
    context_module.all_embeddings.data = saved["context_vectors"]
    print("  ✓ Context vectors loaded successfully!")

    loader_a, loader_b, loader_c = create_test_dataloaders()
    for letter, loader in (("A", loader_a), ("B", loader_b), ("C", loader_c)):
        print(f"\n{rule}")
        print(f"  Visualizing Test Dataset {letter}")
        print(f"{rule}\n")
        visualize_one(model, processor, context_module, loader)

In [None]:
# Load the pretrained Grounding DINO detector and freeze every weight;
# the checkpoints written during training hold only the prompt context vectors.
model_id = "IDEA-Research/grounding-dino-base"

print("\n" + "#" * 70)
print("#  INITIALIZING MODEL")
print("#" * 70 + "\n")
print(f"  Model: {model_id}")

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
model.requires_grad_(False)
model.eval()

print("  Model loaded and frozen successfully!\n")

In [None]:
# Semi-supervised run: labelled dataset A plus pseudo-labelled dataset B.
train_one_dataset(
    model=model,
    processor=processor,
    sup_train_name="A",
    unsup_train_name="B",
    num_epochs=50,
)

In [None]:
# Evaluate the "A_B" context vectors on every test dataset.
test_all_datasets(
    model=model,
    processor=processor,
    train_name="A_B",
)

In [None]:
# Plot sample predictions of the "A_B" context vectors.
visualize_all_datasets(
    model=model,
    processor=processor,
    train_name="A_B",
)

In [None]:
# Semi-supervised run: labelled dataset A plus pseudo-labelled dataset C.
train_one_dataset(
    model=model,
    processor=processor,
    sup_train_name="A",
    unsup_train_name="C",
    num_epochs=50,
)

In [None]:
# Evaluate the "A_C" context vectors on every test dataset.
test_all_datasets(
    model=model,
    processor=processor,
    train_name="A_C",
)

In [None]:
# Plot sample predictions of the "A_C" context vectors.
visualize_all_datasets(
    model=model,
    processor=processor,
    train_name="A_C",
)