### COSMOS

In [None]:
import torch
import os
from utils.data_loader import CosmosTestDataset
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, average_precision_score
import utils.misc as misc

def collate_fn(batch):
    """Regroup a batch of per-sample tuples into one list per field.

    Nothing is stacked or converted: the i-th returned list simply holds the
    i-th element of every sample, in batch order. This keeps variable-sized
    items (e.g. a different number of boxes per image) usable in a batch.

    The number of fields is taken from the samples themselves instead of
    being fixed at five, because the consumers of the loader in this file
    unpack six fields per batch (the last one being ``bert_score``).

    Args:
        batch: Sequence of equally long sample tuples.

    Returns:
        A tuple containing one list per sample field. An empty ``batch``
        yields an empty tuple.
    """
    return tuple(list(field) for field in zip(*batch))

# Preprocessing applied to every full image: convert to a tensor, then
# normalise each channel with the given mean / standard deviation.
transform_full = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Test split of the COSMOS annotations; image paths are resolved against "data".
test_dataset = CosmosTestDataset(
    json_file="data/cosmos_anns_acm/cosmos_anns_acm/acm_anns/public_test_acm.json",
    img_dir="data",
    transform_full=transform_full,
)

# One sample per batch, in dataset order, regrouped by the custom collate_fn.
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    collate_fn=collate_fn,
)
        

In [None]:
import matplotlib.pyplot as plt
import random

def visualize_results(preds, gts, dataloader, k=5):
    """Display up to ``k`` randomly chosen samples with their predictions.

    For every drawn index the image is plotted with a title comparing the
    predicted and ground-truth label, and the two captions plus the BERT
    score are printed below it.

    Args:
        preds: Predicted labels, indexable by sample position; a truthy
            value is rendered as "Out-of-context".
        gts: Ground-truth labels aligned with ``preds``.
        dataloader: Source of the samples. Either an indexable object
            (e.g. a dataset) or an object exposing one through a
            ``dataset`` attribute (e.g. a ``DataLoader``). Each sample must
            unpack into ``(image, bboxes, caption1, caption2, label,
            bert_score)``.
        k: Maximum number of samples to show. It is clamped to
            ``[0, len(preds)]`` because ``random.sample`` raises
            ``ValueError`` when asked for more items than are available.
    """
    k = max(0, min(k, len(preds)))
    indices = random.sample(range(len(preds)), k)

    # A DataLoader cannot be indexed; fall back to the dataset it wraps.
    if hasattr(dataloader, "__getitem__"):
        samples = dataloader
    else:
        samples = getattr(dataloader, "dataset", dataloader)

    for idx in indices:
        # Boxes and the stored label are not needed for the plot; the label
        # shown comes from ``gts``.
        image, _bboxes, caption1, caption2, _label, bert_score = samples[idx]
        pred_label = preds[idx]
        gt_label = gts[idx]

        # Plot image; permute moves the first axis last, i.e. (C, H, W) -> (H, W, C).
        plt.imshow(image.permute(1, 2, 0))
        plt.axis('off')
        plt.title(f"Prediction: {'Out-of-context' if pred_label else 'In-context'} | "
                  f"Ground Truth: {'Out-of-context' if gt_label else 'In-context'} | "
                  f"{'✓' if pred_label == gt_label else '✗'}")
        plt.show()

        # Print captions
        print(f"Caption 1: {caption1}")
        print(f"Caption 2: {caption2}")
        print(f"BERT Score: {bert_score:.4f}")
        print("-" * 60)
        
        

In [None]:
from networks.cosmos import CosmosFullModel


class CosmosTest:
    """Evaluate a ``CosmosFullModel`` checkpoint on a test data loader.

    The model is restored from a checkpoint, run over every batch of the
    loader without gradients, and each sample is classified as in-context
    (0) or out-of-context (1) from the overlap of the boxes each caption
    scores highest and the textual similarity of the two captions.
    """

    def __init__(self, load_path, dataloader, device):
        """Build the model and restore its weights.

        Args:
            load_path: Path of a checkpoint whose ``'model_state_dict'``
                entry holds the model weights.
            dataloader: Iterable yielding ``(image, bboxes, caption1,
                caption2, label, bert_score)`` batches.
            device: Device the model is created on and the checkpoint is
                mapped to.
        """
        self.device = device
        self.dataloader = dataloader
        # Honour the requested device instead of a hard-coded "cuda", so the
        # evaluation also works on machines without a GPU.
        # NOTE(review): 300 is presumably the embedding size -- confirm
        # against CosmosFullModel.
        self.model = CosmosFullModel(300, device)
        # map_location keeps tensors saved on another device loadable here.
        checkpoint = torch.load(load_path, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

    def get_prediction(self, caption1_scores, caption2_scores, bboxes, bert_score,
                       overlap_threshold=0.5, similarity_threshold=0.5):
        """Classify one sample as in-context (0) or out-of-context (1).

        A sample is out-of-context only when the top boxes of both captions
        overlap and the captions' similarity score is below
        ``similarity_threshold``.

        Args:
            caption1_scores: Scores of the first caption against ``bboxes``.
            caption2_scores: Scores of the second caption against ``bboxes``.
            bboxes: Candidate boxes of the image.
            bert_score: Similarity score of the two captions.
            overlap_threshold: Threshold forwarded to
                ``misc.is_bbox_overlap``.
            similarity_threshold: Minimum ``bert_score`` for overlapping
                captions to count as in-context.

        Returns:
            ``0`` for in-context, ``1`` for out-of-context.
        """
        caption1_bboxes = misc.top_bbox_from_scores(bboxes, caption1_scores)
        caption2_bboxes = misc.top_bbox_from_scores(bboxes, caption2_scores)
        if not misc.is_bbox_overlap(caption1_bboxes, caption2_bboxes, overlap_threshold):
            # The captions point at different boxes: treated as in-context.
            return 0
        return 0 if bert_score >= similarity_threshold else 1

    @staticmethod
    def _scalar_label(label):
        """Return the ground-truth label of the first sample as a plain scalar."""
        # A list-based collate function yields a list of labels, while default
        # collation yields a tensor; accept both.
        if isinstance(label, (list, tuple)):
            label = label[0]
        return label.item() if hasattr(label, "item") else label

    def run_test(self):
        """Run the model over the loader and report classification metrics.

        Only the first sample of every batch is scored, so the loader is
        expected to use a batch size of 1.

        Returns:
            A pair ``(all_preds, all_labels)`` with one entry per batch.
        """
        self.model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for image, bboxes, caption1, caption2, label, bert_score in tqdm(self.dataloader, desc="Test Batch:"):
                object_embeddings, match_embeddings, diff_embeddings = self.model(image, bboxes, caption1, caption2)
                caption1_scores, caption2_scores = misc.get_scores(object_embeddings, match_embeddings, diff_embeddings)

                pred = self.get_prediction(caption1_scores[0], caption2_scores[0], bboxes[0], bert_score[0])

                all_preds.append(pred)
                all_labels.append(self._scalar_label(label))

        accuracy = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds)
        # Computed from the hard 0/1 predictions, not from continuous scores.
        ap = average_precision_score(all_labels, all_preds)

        tqdm.write(f"Test Metrics: Accuracy = {accuracy}, F1 Score = {f1}, Average Precision = {ap}")

        return all_preds, all_labels