In [1]:
from src.dataset import COCO2CLIPDataset, PACCO2CLIPDataset
from src.metrics import get_image_and_text_tensor
from src.train_util import get_name, do_train
from src.model import CrossAttentionModule, MLPs
from src.eval_util import do_eval
from src.plots_util import plot_losses, plot_values, bcolors

import cv2
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm
import yaml 

import json
import pandas as pd
import numpy as np


# Device used when building, training and reloading models in the helpers below.
# Override it without editing the notebook via the TORCH_DEVICE environment
# variable (e.g. TORCH_DEVICE=cpu or TORCH_DEVICE=cuda:0); defaults to 'cuda:7'.
device = os.environ.get('TORCH_DEVICE', 'cuda:7')

# Utilities

In [2]:
def plot_results(results, warmup=True, limit=None):
    """Plot the loss curves and retrieval metrics stored in a results dict.

    Args:
        results: mapping with keys 'losses', 'mean_rank_sums' and 'rsums'.
            'losses' maps 'train_loss', 'val_loss' and 'additional_val_loss'
            to series (the additional one may be None). 'mean_rank_sums'
            (keys 'mean_rank_sums' / 'clip_mean_rank_sums') and 'rsums'
            (keys 'rsums' / 'clip_rsums') may each be None, in which case
            that plot is skipped.
        warmup: forwarded to ``plot_losses`` and ``plot_values``.
        limit: if not None, every series is truncated to its first ``limit``
            entries before plotting. The caller's ``results`` is left untouched.
    """
    if limit is not None:
        # Build a truncated copy rather than mutating the caller's nested dicts,
        # and pass None fields through: the plots below explicitly allow
        # 'mean_rank_sums' / 'rsums' to be None, so calling .items() on them
        # would crash.
        results = {
            field: None if series_by_name is None else {
                name: None if series is None else series[:limit]
                for name, series in series_by_name.items()
            }
            for field, series_by_name in results.items()
        }

    losses = results['losses']
    if losses['additional_val_loss'] is not None:
        labels = ["Training Loss PACCO", "Validation Loss PACCO", 'Validation Loss COCO']
    else:
        labels = ["Training Loss COCO", "Validation Loss COCO", 'Validation Loss COCO']
    plot_losses(losses['train_loss'], losses['val_loss'], additional_val_losses=losses['additional_val_loss'], labels=labels, plot=True, warmup=warmup)

    mean_rank_sums = results['mean_rank_sums']
    if mean_rank_sums is not None:
        plot_values([mean_rank_sums['mean_rank_sums'], mean_rank_sums['clip_mean_rank_sums']], ['Current Model', 'CLIP B/16'], 'Mean Rank Sum', 'Mean Rank sum on PACCO Validation set', warmup=warmup)

    rsums = results['rsums']
    if rsums is not None:
        plot_values([rsums['rsums'], rsums['clip_rsums']], ['Current Model', 'CLIP B/16'], 'rsum', 'rsum on COCO', warmup=warmup)
        
def train_test_save(model_name, train_dataset, val_dataset, test1k_imm=None, test1k_txt=None, model=None, additional_val_dataset=None, warmup=True):
    """Train the model described by ``model_name``, checkpoint it, then evaluate it.

    ``model_name`` encodes both configs: the text before the first ``_`` names
    ``configs/train/<train_config>.yaml`` and everything after the first ``_``
    names ``configs/model/<model_config>.yaml``. The trained weights are written
    to ``checkpoints/<model_name>.pth``.

    Args:
        model_name: ``"<train_config>_<model_config>"`` identifier.
        train_dataset: training set forwarded to ``do_train``.
        val_dataset: validation set forwarded to ``do_train``.
        test1k_imm: test input forwarded unchanged to ``do_eval``.
        test1k_txt: test input forwarded unchanged to ``do_eval``.
        model: optional already-built model to train. When None, a
            ``CrossAttentionModule`` (if the model config contains
            ``num_attention_layers``) or an ``MLPs`` is built from the model
            config and moved to the global ``device``.
        additional_val_dataset: optional extra validation set for ``do_train``.
        warmup: forwarded to ``do_train``.

    Returns:
        The model returned by ``do_train``.
    """
    out_path = os.path.join('checkpoints', f'{model_name}.pth')

    train_config = model_name.split('_')[0]
    model_config = model_name[model_name.find('_') + 1:]
    train_config_path = f"configs/train/{train_config}.yaml"
    model_config_path = f"configs/model/{model_config}.yaml"
    config = {}
    with open(train_config_path, 'r') as config_file:
        config['train'] = yaml.safe_load(config_file)
    with open(model_config_path, 'r') as config_file:
        config['model'] = yaml.safe_load(config_file)
    print(f"Configuration loaded!\n{json.dumps(config, indent=2)}")
    print("-" * 80)
    print(f"{bcolors.BOLD}{bcolors.UNDERLINE}{model_name}{bcolors.ENDC}")
    print("-" * 80)

    if model is None:
        # The presence of 'num_attention_layers' selects the architecture.
        if 'num_attention_layers' in config['model']:
            model = CrossAttentionModule.from_config(config['model'])
        else:
            model = MLPs.from_config(config['model'])
        model.to(device)

    model = do_train(model, train_dataset, val_dataset, config['train'], plot=True, loss_path=None, additional_val_dataset=additional_val_dataset, warmup=warmup)

    # Persist the weights before evaluating, so an exception raised during
    # evaluation does not throw away the result of a (long) training run.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    torch.save(model.state_dict(), out_path)

    do_eval(model, model_name, test1k_imm, test1k_txt)
    return model

def show_results(model_name=None, test1k_imm=None, test1k_txt=None, show_loss=True, plotting_results=False, evaluate=True, limit=None, warmup=True):
    """Print a saved model's configuration and display / re-evaluate its artifacts.

    The configs are resolved from ``model_name`` the same way as in training:
    the text before the first ``_`` names ``configs/train/<...>.yaml`` and the
    rest names ``configs/model/<...>.yaml``.

    Args:
        model_name: ``"<train_config>_<model_config>"`` identifier (required
            despite the ``None`` default).
        test1k_imm: test input forwarded unchanged to ``do_eval``.
        test1k_txt: test input forwarded unchanged to ``do_eval``.
        show_loss: display ``checkpoints/loss/<model_name>.jpg`` if it exists.
        plotting_results: load ``checkpoints/results/<model_name>.pt`` and
            plot it with ``plot_results``.
        evaluate: rebuild the model from its config, load
            ``checkpoints/<model_name>.pth`` onto ``device`` and run ``do_eval``.
        limit: forwarded to ``plot_results``.
        warmup: forwarded to ``plot_results``.

    Raises:
        ValueError: if ``model_name`` is None.
    """
    if model_name is None:
        raise ValueError("show_results requires a model_name of the form '<train_config>_<model_config>'")

    print("-" * 80)
    print(f"{bcolors.BOLD}{bcolors.UNDERLINE}{model_name}{bcolors.ENDC}")
    print("-" * 80)

    train_config = model_name.split('_')[0]
    model_config = model_name[model_name.find('_') + 1:]
    train_config_path = f"configs/train/{train_config}.yaml"
    model_config_path = f"configs/model/{model_config}.yaml"
    config = {}
    with open(train_config_path, 'r') as config_file:
        config['train'] = yaml.safe_load(config_file)
    with open(model_config_path, 'r') as config_file:
        config['model'] = yaml.safe_load(config_file)
    print(f"Configuration loaded!\n{json.dumps(config, indent=2)}")

    if plotting_results:
        results = torch.load(f'checkpoints/results/{model_name}.pt')
        plot_results(results, warmup, limit)
    if show_loss:
        print("Training losses:")
        loss_image_path = f'checkpoints/loss/{model_name}.jpg'
        loss = cv2.imread(loss_image_path)
        if loss is None:
            # cv2.imread returns None (it does not raise) for a missing or
            # unreadable file; train_test_save trains with loss_path=None, so
            # presumably such models have no saved loss plot.
            print(f"No loss plot found at {loss_image_path}, skipping.")
        else:
            plt.imshow(cv2.cvtColor(loss, cv2.COLOR_BGR2RGB))
            plt.axis('off')
            plt.show()
    if evaluate:
        # The presence of 'num_attention_layers' selects the architecture.
        if 'num_attention_layers' in config['model']:
            model = CrossAttentionModule.from_config(config['model'])
        else:
            model = MLPs.from_config(config['model'])
        model.load_state_dict(torch.load(f"checkpoints/{model_name}.pth", map_location=device))
        model.to(device)

        do_eval(model, model_name, test1k_imm, test1k_txt)

# Dataset Loading

In [3]:
# COCO train / validation splits with precomputed ViT-B/16 features.
train_dataset, val_dataset = (
    COCO2CLIPDataset('./features/ViT-B-16/train.json'),
    COCO2CLIPDataset('./features/ViT-B-16/val.json'),
)

Loading dataset...
Dataset loaded!
Loading dataset...
Dataset loaded!


In [4]:
# PACCO training sets, one per difficulty level as named by the variables
# (N_attributes.pt / shuffle_negatives.pt), plus the hard validation set.
(
    hard_train_dataset,
    medium_train_dataset,
    easy_train_dataset,
    trivial_train_dataset,
    hard_val_dataset,
) = (
    PACCO2CLIPDataset('./fg-ovd_feature_extraction/training_sets/1_attributes.pt'),
    PACCO2CLIPDataset('./fg-ovd_feature_extraction/training_sets/2_attributes.pt'),
    PACCO2CLIPDataset('./fg-ovd_feature_extraction/training_sets/3_attributes.pt'),
    PACCO2CLIPDataset('./fg-ovd_feature_extraction/training_sets/shuffle_negatives.pt'),
    PACCO2CLIPDataset('./fg-ovd_feature_extraction/val_sets/1_attributes.pt'),
)

Loading dataset...
Dataset loaded!
Loading dataset...
Dataset loaded!


In [5]:
# Test tensors for the 1k test split, intended as the test1k_imm / test1k_txt
# arguments of train_test_save and show_results.
(test1k_imm,
 test1k_txt) = get_image_and_text_tensor(
    'features/ViT-B-16/test1k.json'
)

100%|██████████| 1000/1000 [00:00<00:00, 4198.44it/s]
