In [None]:
# Copyright (C) 2024 Istituto Italiano di Tecnologia.  All rights reserved.
#
# This work is licensed under the LICENSE file 
# located at the root directory.

In [1]:
# Imports
%load_ext autoreload

from importlib import import_module
from collections import defaultdict
from pathlib import Path
import os

import matplotlib.pyplot as plt
import pandas as pd

from vlpers.utils.evaluation import get_ranks, mAP, mAP_at, mRR, recall_at
import vlpers.utils.logging as logging
from vlpers.utils.logging import logger
from vlpers.utils import misc
from pprint import pformat

# Move the working directory one level up; the relative paths used in this
# notebook (e.g. 'logs/weighted_average') are resolved against it.
# NOTE(review): presumably the parent is the repository root -- confirm.
os.chdir('..')
# Replace os.chdir with a no-op for the rest of the kernel session, so any
# later call made through the os module -- including a re-run of this cell --
# leaves the working directory untouched.
# NOTE(review): presumably only meant to make this cell safe to re-run; it also
# silently disables os.chdir for every library that goes through `os` -- confirm.
os.chdir = lambda x: None

### Config

In [None]:
from configs.retrieval import Config as cfg
from configs.methods.baselines.image_text import Method

# Name of the dataset: selects the configs.datasets.<dataset> module loaded below
# and is also used as the experiment sub-directory name.
dataset = 'conconchi'

# Attach the dataset-specific Data config and the imported Method to the shared config.
cfg.Data = import_module(f'configs.datasets.{dataset}').Data
cfg.Method = Method

# At this point exp_dir is only the dataset name; it becomes a full path further down.
cfg.Logging.log_dir = Path('logs/weighted_average')
cfg.Logging.exp_dir = dataset

# NOTE(review): 10 presumably is the DEBUG level of the standard logging module -- confirm.
logger.setLevel(10)
misc.set_reproducibility()
    
# Experiment directory <log_dir>/<dataset>; created (with parents) if it does not exist yet.
exp_dir = (cfg.Logging.log_dir / cfg.Logging.exp_dir)
exp_dir.mkdir(parents=True, exist_ok=True)

# From here on cfg.Logging.exp_dir holds the full experiment path, not just the dataset name.
cfg.Logging.exp_dir = Path(exp_dir)
logging.enable_file_handler(exp_dir / 'logs.txt')

# Store the configuration in the experiment directory, run the git workspace check
# against it, and write the whole configuration to the log.
misc.save_config(cfg.Logging.exp_dir, cfg)
misc.git_check_workspace(make_patch=cfg.Git.make_patch, path=exp_dir.resolve())
logger.info(f'\n{pformat(cfg.to_dict())}\n')

# Instantiate the method under evaluation; every following cell uses this object.
method = cfg.Method()

### Learn Concepts

In [3]:
# Learn the embeddings of the new concepts (skipped when cfg.Method.load_concepts is set)
concept_dataset = cfg.Data.LearnDS(
    image_transform=method.image_transform,
    text_transform=method.text_transform,
)
if not cfg.Method.load_concepts:
    method.set_mode(cfg.Method.Mode.LEARNING)
    # Each batch unpacks to (images, <unused>, concepts)
    for images, _, concepts in concept_dataset.dl:
        method.learn_concepts(images, concepts)

    # Store the learned concepts only when a log dir is set and saving is enabled
    if cfg.Logging.log_dir and cfg.Logging.save_concepts:
        method.save_concepts(cfg.Logging.exp_dir)

### Image Pool

In [4]:
# Fill the image pool (skipped when cfg.Method.load_image_pool is set)
method.set_mode(cfg.Method.Mode.TESTING)
image_dataset = cfg.Data.ImagePoolDS(
    image_transform=method.image_transform,
    text_transform=method.text_transform,
)
if not cfg.Method.load_image_pool:
    # Only the first element of every batch (the images) is used
    for images, *_ in image_dataset.dl:
        method.add_image_pool(images)

    # Store the pool only when a log dir is set and saving is enabled
    if cfg.Logging.log_dir and cfg.Logging.save_image_pool:
        method.save_image_pool(cfg.Logging.exp_dir)

### Evaluation

In [None]:
# Sweep method.alpha over 0.0, 0.1, ..., 1.0 and record the mean mAP obtained for each value.
weighted_avg = defaultdict(list)

for a in range(11):
    method.alpha = (a / 10)

    # Retrieve images
    eval_dataset = cfg.Data.EvalDS(image_transform=method.image_transform, text_transform=method.text_transform)

    # Lookup table indexed by ID_GTS. It does not depend on the batch, so it is built
    # once per sweep step instead of once per ground-truth id.
    gt_id_lookup = eval_dataset.map.reset_index().set_index(['ID_GTS'])

    results = defaultdict(list)
    for batch in eval_dataset.dl:
        gt, labels, concepts = batch

        scores = method.retrieve(descriptions=labels, concepts=concepts)

        # Compute the per-query metrics
        ranks, rank_ids = get_ranks(scores, gt)

        # Entries equal to -1 are dropped (NOTE(review): presumably padding -- confirm);
        # the remaining ids are translated through the ID_GTS lookup table.
        ranks = [[rank for rank in gt_ranks if rank != -1] for gt_ranks in ranks.tolist()]
        rank_ids = [[gt_id_lookup.loc[gt_id].item() for gt_id in gt_ids if gt_id != -1] for gt_ids in rank_ids.tolist()]

        results['Ranks'] += ranks
        results['Rank_ids'] += rank_ids
        results['mRR'] += mRR(scores, gt, avg=False).tolist()
        results['mAP'] += mAP(scores, gt, avg=False).tolist()
        for k in [1, 5, 10]:
            results[f'R@{k}'] += recall_at(scores, gt, k=k, avg=False).tolist()
            results[f'mAP@{k}'] += mAP_at(scores, gt, k=k, avg=False).tolist()

    # Put the per-query metrics next to the evaluation dataframe (column-wise concat)
    labels = eval_dataset.df
    results = pd.DataFrame.from_dict(results)
    results = pd.concat([labels, results], axis=1)

    # Mean mAP over all queries for this alpha, computed once and reused below
    mean_ap = results[["mAP"]].mean().item()
    weighted_avg['alpha'] += [method.alpha]
    weighted_avg['mAP'] += [mean_ap]

    logger.info(f'alpha: {method.alpha} mAP: {mean_ap:.2f}')

# One row per alpha value, saved into the experiment directory
misc.save_results(cfg.Logging.exp_dir, pd.DataFrame.from_dict(weighted_avg))


### Graphs

In [11]:
pd.set_option('display.precision', 2)

# Load one results table per entry of logs/weighted_average, keyed by the entry name.
# NOTE(review): results.csv is presumably the file written by misc.save_results -- confirm.
methods = {}
# sorted(): glob yields entries in arbitrary order, and plot_methods assigns colours
# and legend entries in dict order, so sorting keeps the figure reproducible.
for p in sorted(Path('logs/weighted_average').glob('*')):
    results_file = p / 'results.csv'
    # Skip stray files and experiment directories that have no results.csv yet
    if not results_file.is_file():
        continue
    methods[p.name] = pd.read_csv(results_file, sep=';')

In [12]:
color = ['#008fd5', '#fc4f30', '#e5ae38', '#8b8b8b', '#810f7c']

def plot_methods(methods, contains='', name_list=None):
    """Plot mAP (in percent) against alpha, one curve per selected method.

    Args:
        methods: mapping from a name to a table with 'alpha' and 'mAP' columns.
        contains: only names containing this substring are plotted
            (the default '' matches every name).
        name_list: optional name, or collection of names, to restrict the plot to;
            a falsy value disables this filter.

    The figure is written to 'weighted.svg' in the current working directory
    and then shown.
    """
    if isinstance(name_list, str):
        name_list = [name_list]

    # Keep the iteration order of `methods`; both filters must pass.
    selected = [
        name for name in methods
        if (not name_list or name in name_list) and contains in name
    ]

    for i, name in enumerate(selected):
        df = methods[name]
        # Wrap around the palette so that plotting more curves than there are
        # colours reuses colours instead of raising IndexError.
        plt.plot(df['alpha'], df['mAP'] * 100, label=name, color=color[i % len(color)])

    # Add labels and legend
    plt.xticks([n/10 for n in range(11)])
    plt.xlim(-0.05, 1.35)
    plt.xlabel('α')
    plt.ylabel('mAP [%]')
    plt.legend(loc=2)

    # Save, then show the plot
    plt.savefig('weighted.svg')
    plt.show()
    

In [None]:
import seaborn as sns
# Seaborn default theme with the 'whitegrid' axes style, then one curve per listed dataset
sns.set_theme(style='whitegrid')
plot_methods(
    methods,
    name_list=['circo', 'conconchi', 'cirr', 'deepfashion', 'fashioniq'],
)