In [1]:
from collections import defaultdict
import json
import os
import evaluate
import pandas as pd

import logging
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

In [2]:
def get_image_id(composit_image_id):
    composit_image_id = composit_image_id.split('_')
    if len(composit_image_id) > 2:
        image_id = composit_image_id[-2]
    else:
        image_id = composit_image_id[-1]
    return image_id

def get_split(composit_image_id):
    composit_image_id = composit_image_id.split('_')
    if len(composit_image_id) > 2:
        composit_split = composit_image_id[:-2]
    else:
        composit_split = composit_image_id[:-1]
    
    split = '_'.join(composit_split)
    return split

def group_results(raw_results):
    results = {}
    for element in raw_results:
        image_id = get_image_id(element['image_id'])
        split = get_split(element['image_id'])
        results[image_id] = element['caption']
    return results

In [3]:
import re

def pre_caption(caption, max_words=0):
    caption = re.sub(
        r"([.!\"()*#:;~])",
        ' ',
        caption.lower(),
    )
    caption = re.sub(
        r"\s{2,}",
        ' ',
        caption,
    )
    caption = caption.rstrip('\n')
    caption = caption.strip(' ')

    #truncate caption
    caption_words = caption.split(' ')
    if max_words and len(caption_words)>max_words:
        caption = ' '.join(caption_words[:max_words])

    return caption

def load_testsets(path_text_data, list_skills=None):
    list_skills = list_skills if list_skills else []

    test_datasets = {'all' : pd.read_csv(path_text_data + 'Caption_all.tsv', sep='\t')}
    test_datasets.update({
        'test_%s' %skill : pd.read_csv(path_text_data + 'Caption_testing_%s.tsv'%skill, sep='\t') 
        for skill in list_skills if os.path.isfile(path_text_data + 'Caption_testing_%s.tsv'%skill)})
    test_datasets['test'] = test_datasets['all'][test_datasets['all'].split == 'test']
    test_datasets['val'] = test_datasets['all'][test_datasets['all'].split == 'val']
    del test_datasets['all']

    for split, dataset in test_datasets.items():
        dataset['caption'] = dataset['caption'].apply(pre_caption)
        dataset['image_id'] = dataset['image_ID'].astype(str)

    for split in test_datasets:
        test_datasets[split] = test_datasets[split][['image_id','caption']]

    return test_datasets

def to_grouped_dict(dataset):
    dict_dataset = dataset.to_dict(orient='records')
    grouped_dataset = defaultdict(list)
    for element in dict_dataset:
        grouped_dataset[element['image_id']].append(element['caption'])

    return grouped_dataset

def build_lists_for_evaluation(results, test_dataset):
    predictions = []
    references = []
    for image_id in test_dataset:
        predictions.append(results[image_id])
        references.append(test_dataset[image_id])
    return predictions, references

### Load References

In [4]:
path_text_data = './data/'
list_skills = ['color','counting','gender']
test_datasets = load_testsets(path_text_data, list_skills=list_skills)
test_datasets['test'].head()

Unnamed: 0,image_id,caption
125,1007129816,the man with pierced ears is wearing glasses a...
126,1007129816,a man with glasses is wearing a beer can croch...
127,1007129816,a man with gauges and glasses is wearing a bli...
128,1007129816,a man in an orange hat starring at something
129,1007129816,a man wears an orange hat and glasses


In [5]:
for split in test_datasets:
    test_datasets[split] = to_grouped_dict(test_datasets[split])

### Load Results

In [6]:
from tqdm.auto import trange, tqdm
from pathlib import Path

In [7]:
from nltk.translate.meteor_score import meteor_score
import numpy as np
cap_metrics = evaluate.combine(['bleu', 'rouge'])

def compute_metrics(predictions, references):
    metrics = cap_metrics.compute(predictions=predictions, references=references)
    for i in range(4):
        metrics[f'bleu{i+1}'] = metrics['precisions'][i]
    metrics['meteor'] = np.mean([meteor_score(hypothesis=p, references=rs) for p, rs in zip(predictions, references)])
    return metrics

In [8]:
base_dir = 'BLIP/output/'
exp_names = [
    str(dir_.stem) for dir_ in Path(base_dir).glob('*') if str(dir_.stem) not in ['saved_exps', '.gitignore']]

In [9]:
exp_names = [exp for exp in exp_names if 'aae' in exp or 'base' in exp]
exp_names

['caption_base_flickr',
 'caption_flickr_aae_color',
 'caption_flickr_aae_counting',
 'caption_flickr_aae_gender',
 'caption_flickr_aae_color+counting+gender']

In [10]:
metrics = {}
for exp_name in exp_names:
    result_dir = Path(f'{base_dir}/{exp_name}/result')
    pbar = tqdm(list(result_dir.glob('test_epoch*.json')))
    metrics[exp_name] = defaultdict(list)
    for result_path in pbar:
        if 'rank' in str(result_path):
            continue
        epoch = int(result_path.stem.replace('test_epoch', ''))
        with result_path.open() as fp:
            raw_results = json.load(fp)

        results = group_results(raw_results)
        for split in test_datasets:
            if split == 'val':
                continue
            predictions, references = build_lists_for_evaluation(results, test_datasets[split])

            pbar.set_description(f"{exp_name}: {split}")
            computed_metric = compute_metrics(predictions=predictions,
                                              references=references)

            metrics[exp_name][split].append(computed_metric)

  0%|          | 0/30 [00:00<?, ?it/s]

INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using defa

  0%|          | 0/36 [00:00<?, ?it/s]

INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using defa

  0%|          | 0/27 [00:00<?, ?it/s]

INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using defa

  0%|          | 0/36 [00:00<?, ?it/s]

INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using defa

In [11]:
for exp_name in exp_names:
    result_dir = Path(f'{base_dir}/{exp_name}/result')
    pbar = tqdm(list(result_dir.glob('val_epoch*.json')))
    for result_path in pbar:
        if 'rank' in str(result_path):
            continue
        epoch = int(result_path.stem.replace('val_epoch', ''))
        with result_path.open() as fp:
            raw_results = json.load(fp)

        results = group_results(raw_results)

        split = 'val'
        predictions, references = build_lists_for_evaluation(results, test_datasets[split])

        pbar.set_description(f"{exp_name}: {split}")
        computed_metric = compute_metrics(predictions=predictions,
                                          references=references)

        metrics[exp_name][split].append(computed_metric)

  0%|          | 0/30 [00:00<?, ?it/s]

INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.


  0%|          | 0/36 [00:00<?, ?it/s]

INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.


  0%|          | 0/27 [00:00<?, ?it/s]

INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.


  0%|          | 0/36 [00:00<?, ?it/s]

INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.


  0%|          | 0/27 [00:00<?, ?it/s]

INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.


In [12]:
for exp_name in metrics:
    print(f'Epochs registered for experiment {exp_name}:')
    for split in metrics[exp_name]:
        print(f'{split:12.12s}:', len(metrics[exp_name][split]))

Epochs registered for experiment caption_base_flickr:
test_color  : 10
test_countin: 10
test_gender : 10
test        : 10
val         : 10
Epochs registered for experiment caption_flickr_aae_color:
test_color  : 12
test_countin: 12
test_gender : 12
test        : 12
val         : 12
Epochs registered for experiment caption_flickr_aae_counting:
test_color  : 9
test_countin: 9
test_gender : 9
test        : 9
val         : 9
Epochs registered for experiment caption_flickr_aae_gender:
test_color  : 12
test_countin: 12
test_gender : 12
test        : 12
val         : 12
Epochs registered for experiment caption_flickr_aae_color+counting+gender:
test_color  : 9
test_countin: 9
test_gender : 9
test        : 9
val         : 9


### Display Results

In [13]:
def bold(text):
    BOLD = '\033[1m'
    END = '\033[0m'
    return BOLD + text + END

In [14]:
display_name = {
    'caption_base_flickr': 'Base',
    
    'caption_flickr_augmented_c': 'Col',
    'caption_flickr_augmented_counting': 'Cnt',
    'caption_augmented_flickr': 'Gen',

    'caption_flickr_augmented_color+counting': 'Col+Cnt',
    'caption_flickr_augmented_c+g': 'Col+Gen',
    'caption_flickr_augmented_counting+gender': 'Cnt+Gen',
    'caption_flickr_augmented_color+counting+gender': 'C+C+G',

    
    'caption_flickr_inpaiting_color': 'ICol',
    'caption_flickr_inpaiting_counting': 'ICnt',
    'caption_flickr_inpaiting_gender': 'IGen',

    'caption_flickr_inpaiting_color+counting': 'ICol+Cnt',
    'caption_flickr_inpaiting_color+gender': 'ICol+Gen',
    'caption_flickr_inpaiting_counting+gender': 'ICnt+Gen',
    'caption_flickr_inpaiting_color+counting+gender': 'IC+C+G'
    
    'caption_flickr_aae_color': 'ACol',
    'caption_flickr_aae_counting': 'ACnt',
    'caption_flickr_aae_gender': 'AGen',
    'caption_flickr_aae_color+counting+gender': 'AC+C+G'

}
relevant_metrics = ['bleu', 'bleu1', 'bleu2', 'bleu3', 'bleu4', 'rouge1', 'rouge2', 'rougeL', 'rougeLsum', 'meteor']

SyntaxError: invalid syntax (<ipython-input-14-782b1b20fd30>, line 23)

In [None]:
import torch
best_epochs = {}
for exp_name in exp_names:
    ckpt_path = f'{base_dir}/{exp_name}/checkpoint_best.pth'
    if not os.path.exists(ckpt_path):
        continue
    ckpt = torch.load(ckpt_path)
    best_epochs[exp_name] = ckpt['epoch']

In [None]:
best_epochs

In [None]:
# for split in ['val', 'test', 'test_gender', 'test_color']:
# for exp_name in ['caption_base_flickr','caption_augmented_flickr','caption_flickr_augmented_c','caption_flickr_augmented_c+g']:

In [None]:
exp_names

In [None]:
sort_test = ['val', 'test', 'test_color', 'test_counting', 'test_gender']
sort_name = list(display_name.keys())

selected_exp_names = [name for name in exp_names if '+' not in name] # non composite train sets
# selected_exp_names = exp_names                                       # all train sets

In [None]:
for split in sorted(test_datasets.keys(), key=sort_test.index):
    heading = f'{split:11.11s}'

    for exp_name in sorted(selected_exp_names, key=sort_name.index):
        exp_display_name = display_name.get(exp_name, exp_name)
        heading += f'{exp_display_name:>9.8s}'
    print(bold(heading))

    for m in relevant_metrics:
        row = f'{m:10.10s}:'
        highest_metric = float('-inf')

        for exp_name in sorted(selected_exp_names, key=sort_name.index):
            if exp_name not in best_epochs:
                continue
            best_epoch = best_epochs[exp_name]
            instance_metrics = metrics[exp_name][split][best_epoch]   
            highest_metric = max(highest_metric, metrics[exp_name][split][best_epoch][m])

        for exp_name in sorted(selected_exp_names, key=sort_name.index):
            if exp_name not in best_epochs:
                continue
            best_epoch = best_epochs[exp_name]
            instance_metrics = metrics[exp_name][split][best_epoch]
            metric_str = f'{instance_metrics[m]:9.4f}'
            if instance_metrics[m] == highest_metric:
                metric_str = bold(metric_str)
            row += metric_str
            
        print(row)
        
    print()

In [None]:
import matplotlib
import matplotlib.pyplot as plt

In [None]:
for exp_name in exp_names:
    val_bleu4 = [m['bleu4'] for m in metrics[exp_name][split]]
    plt.plot(val_bleu4, label=display_name[exp_name])

plt.show()