In [54]:
import torch 
import pandas as pd
import os
from sklearn import preprocessing
import torch.nn.functional as F
import pickle
import matplotlib.pyplot as plt

In [55]:
# Build test-set labels. The encoder is fit on every scene label in meta.csv,
# so the integer class ids are the same whichever split is evaluated.
le = preprocessing.LabelEncoder()
test_df = pd.read_csv('split_setup/test.csv', sep='\t')
test_files = test_df['filename'].values.reshape(-1)

path = 'path to meta.csv here'  # TODO: point this at the dataset's meta.csv
meta = pd.read_csv(path, sep="\t")
test_labels_meta = torch.from_numpy(le.fit_transform(meta['scene_label'].to_numpy()))

# Positional row numbers (not index labels) of the test clips in meta.csv.
# These are used to index rows of the logits tensors, which presumably were
# saved in meta.csv row order -- TODO confirm against the export script.
test_mask = meta['filename'].isin(test_files).to_numpy()
test_indices = test_mask.nonzero()[0].tolist()
assert len(test_indices) == len(set(test_files)), \
    "some test.csv filenames were not found (or are duplicated) in meta.csv"
test_labels = test_labels_meta[test_indices]

In [56]:
def load_logits(logits_files, subset, logits_dir='logits'):
    """Load saved logits tensors, one per model name.

    For each name in ``logits_files`` this reads
    ``<logits_dir>/<name>_<subset>.pt`` and casts the result to float32.

    Args:
        logits_files: iterable of file-name prefixes (model/run names).
        subset: suffix appended after an underscore. It is formatted with an
            f-string, so ints such as 5 or 100 work.
        logits_dir: directory holding the ``.pt`` files (default ``'logits'``).

    Returns:
        list of float32 tensors on the CPU, in the order of ``logits_files``.

    Note:
        ``torch.load`` unpickles its input; only point it at trusted files.
    """
    logits = []
    for logits_file in logits_files:
        file_path = os.path.join(logits_dir, f'{logits_file}_{subset}.pt')
        # map_location='cpu': tensors saved during a GPU run would otherwise be
        # restored onto the GPU (or fail to load on a CPU-only machine).
        logits.append(torch.load(file_path, map_location='cpu').float())
    return logits

def ensemble_logits(logits):
    """Combine per-model logits into a single ensemble output.

    Args:
        logits: non-empty list of same-shaped tensors, one per model
            (assumed (n_samples, n_classes) -- TODO confirm). Not modified.

    Returns:
        With one model: ``log_softmax`` of its logits over the last dim.
        With several: the element-wise mean of the raw logits.
        ``F.cross_entropy`` applies ``log_softmax`` itself, and ``log_softmax``
        is shift-invariant, so the loss and argmax computed downstream do not
        depend on which of the two forms is returned.

    Raises:
        ValueError: if ``logits`` is empty.
    """
    if not logits:
        raise ValueError("ensemble_logits needs at least one logits tensor")

    if len(logits) == 1:
        return F.log_softmax(logits[0], dim=-1)

    # Accumulate into a clone: an in-place `+=` directly on logits[0] would
    # overwrite the caller's first tensor.
    ensemble = logits[0].clone()
    for logit in logits[1:]:
        ensemble += logit
    return ensemble / len(logits)

def ensemble_test_stats(logits_files, subset):
    """Score an ensemble of saved logits on the test clips.

    Loads the logits for every name in ``logits_files`` via ``load_logits``,
    combines them with ``ensemble_logits``, and evaluates the rows selected by
    the module-level ``test_indices`` against ``test_labels``.

    Returns:
        (acc, loss, ensemble): 0-dim tensors with the top-1 accuracy (a
        fraction) and the mean per-sample cross-entropy over the test rows,
        plus the full combined tensor (all rows, not only the test ones).
    """
    combined = ensemble_logits(load_logits(logits_files, subset))
    scores = combined[test_indices]

    per_sample_loss = F.cross_entropy(scores, test_labels, reduction="none")
    mean_loss = per_sample_loss.mean()

    predicted = scores.max(dim=1).indices
    hits = (predicted == test_labels).sum()
    accuracy = hits / len(scores)

    return accuracy, mean_loss, combined

In [57]:
# Single-model baseline: accuracy (in %) and mean test loss for each split
# value (the split is used as the suffix of the logits file name).
splits = [5, 10, 25, 50, 100]
files = ['mn04_as_dir_fms_lr2']
for split in splits:
    acc, loss, _ = ensemble_test_stats(files, split)
    print(f"{round(acc.item() * 100, 2)} {round(loss.item(), 3)}")

47.64 1.538
50.55 1.452
53.96 1.419
56.4 1.45
59.45 1.425


In [64]:
# Candidate teacher runs. Names follow "<arch>_as_dir_fms<suffix>" for the
# mn*/dymn* families; the remaining runs are listed explicitly.
teachers = [
    *(f'{arch}_as_dir_fms{suffix}'
      for arch in ('mn04', 'mn05', 'mn10')
      for suffix in ('', '_lr', '_lr2', '_early', '_early2')),
    *(f'{arch}_as_dir_fms{suffix}'
      for arch in ('dymn04', 'dymn10')
      for suffix in ('', '_lr')),
    'passt_dir_fms',
    'passt_dir_fms_lr_ws',
    'cpr_128_dir_fms',
]



In [73]:
def calc_best_ensemble(teachers, split, evaluate=None):
    """Greedy forward selection of the ensemble with the lowest loss.

    Starting from an empty ensemble, each round tries adding every teacher not
    yet selected and keeps the single addition whose loss beats the best loss
    seen so far. The search stops when no addition lowers the loss or when all
    teachers have been selected.

    Args:
        teachers: candidate logits-file prefixes.
        split: forwarded to ``evaluate`` as its second argument.
        evaluate: callable ``(ensemble_names, split) -> (acc, loss, _)``.
            Defaults to ``ensemble_test_stats``. ``acc``/``loss`` may be
            0-dim tensors or plain numbers.

    Returns:
        (acc, loss, selected): accuracy and loss of the best ensemble as Python
        floats, and the chosen teachers in the order they were added.

    Raises:
        ValueError: if ``teachers`` is empty, or if no candidate produced a
            loss that compares below ``inf`` (e.g. every loss is NaN).
    """
    if evaluate is None:
        evaluate = ensemble_test_stats
    if not teachers:
        raise ValueError("calc_best_ensemble needs at least one teacher")

    selected = []
    best_acc = None
    best_loss = float('inf')  # any real loss beats this on the first round
    for _ in range(len(teachers)):
        round_best = None
        for teacher in teachers:
            if teacher in selected:
                continue
            acc, loss, _ = evaluate(selected + [teacher], split)
            # Compare against the best loss over all rounds so far, so a teacher
            # is only added if it improves on the current ensemble.
            if loss < best_loss:
                round_best = teacher
                best_loss = loss
                best_acc = acc
        if round_best is None:
            break
        selected.append(round_best)

    if best_acc is None:
        raise ValueError("no candidate ensemble produced a comparable loss (NaN losses?)")
    return float(best_acc), float(best_loss), selected

In [92]:
# Greedy forward selection on split 100, keeping the ensemble with the lowest test loss.
acc, loss, best_teachers = calc_best_ensemble(teachers, split=100)

In [93]:
# Accuracy (a fraction, not %) and mean cross-entropy of the selected ensemble.
acc, loss

(0.6844339370727539, 0.861076295375824)

In [94]:
# Teachers chosen by the greedy search, in the order they were added.
best_teachers

['mn10_as_dir_fms_early2',
 'passt_dir_fms',
 'cpr_128_dir_fms',
 'mn05_as_dir_fms_early2',
 'mn05_as_dir_fms',
 'mn04_as_dir_fms_early2']