In [1]:
import os
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
run = 'imagenette'
dataset = 'Imagenette'
first_split = pd.concat([pd.read_csv(f'predictions/{dataset}/{run}/first-split/{f}') for f in os.listdir(f'predictions/{dataset}/{run}/first-split')])
second_split = pd.concat([pd.read_csv(f'predictions/{dataset}/{run}/second-split/{f}') for f in os.listdir(f'predictions/{dataset}/{run}/second-split')])
first_split

Unnamed: 0,image_id,stage,epoch,label,prediction,loss
0,3246,first-split,0,3.0,3.0,2.017901
1,4682,first-split,0,4.0,4.0,1.745118
2,206,first-split,0,0.0,0.0,1.749493
3,488,first-split,0,0.0,0.0,1.719831
4,8770,first-split,0,9.0,9.0,1.664168
...,...,...,...,...,...,...
4730,8058,first-split,12,8.0,8.0,0.006124
4731,978,first-split,12,1.0,1.0,0.016846
4732,6746,first-split,12,7.0,7.0,0.023937
4733,8376,first-split,12,8.0,8.0,0.002773


## First Split Learning Time

aka the first epoch at which example x is classified correctly for the rest of (stage 1) training

In [3]:
def get_fslt(df):
    df = df.sort_values('epoch', ascending=False)
    incorrect = df[df['prediction'] != df['label']]
    if len(incorrect) == 0:
        return 0
    else:
        return incorrect.iloc[0]['epoch'] + 1 # +1 because we want the first epoch where the model was correct

fslt = first_split.groupby('image_id', as_index=False).apply(get_fslt)
fslt.columns = ['image_id', 'fslt']

## Second Split Forgetting Time

aka the first epoch where example x is never classified correctly

In [4]:
def get_ssft(df):
    df = df.sort_values('epoch', ascending=False)
    correct = df[df['prediction'] == df['label']]
    if len(correct) == 0:
        return 0
    else:
        return correct.iloc[0]['epoch'] + 1 # +1 because we want the first epoch where the model was incorrect

ssft = second_split.groupby('image_id', as_index=False).apply(get_ssft)
ssft.columns = ['image_id', 'ssft']

In [5]:
# check that if example x has a fslt of num_epochs (i.e. it was never correct), then it has a ssft of 0 (i.e. it was never correct)
len(first_split['epoch'].unique())

30

In [6]:
results = pd.merge(fslt, ssft, on='image_id')

In [7]:
results

Unnamed: 0,image_id,fslt,ssft
0,0,1,10
1,2,0,10
2,4,0,10
3,6,0,10
4,8,7,9
...,...,...,...
4730,9460,0,10
4731,9462,0,10
4732,9464,0,10
4733,9466,0,10
