## Model Performance Analysis

The purpose of this notebook is to take a closer look at what happens when we continually train an image classification model that learns embeddings instead of one-hot labels. The motivation comes from the results of a simple experiment in which we trained an AlexNet image classification model on BERT embeddings.
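The notebook does not say which loss or decoding rule the embedding-target model uses. The following is a minimal sketch that assumes a cosine-distance regression loss and nearest-label-embedding decoding. Both choices are assumptions and are not taken from the experiment code. `cosine_loss` and `decode` are hypothetical helper names, and the 3-d vectors are toy stand-ins for BERT or word2vec label embeddings.

```python
import numpy as np


def _normalize(x):
    """Scale each row to unit L2 norm."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def cosine_loss(pred, target):
    """Mean (1 - cosine similarity) between predicted and target label embeddings."""
    return float(np.mean(1.0 - np.sum(_normalize(pred) * _normalize(target), axis=1)))


def decode(pred, label_emb):
    """Predict the class whose label embedding is most cosine-similar to each output."""
    return np.argmax(_normalize(pred) @ _normalize(label_emb).T, axis=1)


# Toy 3-d "label embeddings" for 4 classes (stand-ins for BERT/word2vec vectors).
label_emb = np.array([[1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0],
                      [0.0, 0.0, 1.0],
                      [1.0, 1.0, 0.0]])
# Hypothetical network outputs for 3 images.
pred = np.array([[0.9, 0.1, 0.0],
                 [0.1, 0.2, 2.0],
                 [0.6, 0.5, 0.0]])
labels = np.array([0, 2, 3])

print(cosine_loss(pred, label_emb[labels]))  # small, since outputs point near their targets
print(decode(pred, label_emb))               # [0 2 3]
```

Under these assumptions, a one-hot model corresponds to `label_emb` being the identity matrix, and `decode` then reduces to the usual argmax over the outputs.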

### Experiment details

We trained the model incrementally, 5 classes at a time (50 epochs). We ran 100-way validation at train time and evaluated each 5-class task separately at test time. The whole experiment was repeated 5 times to obtain the final accuracy values. The logs below come from the first run (`run_0`).
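One plausible reading of this protocol is sketched here. It assumes that the 100 CIFAR-100 classes are split into 20 tasks of 5. In "100-way" accuracy the model may pick any class. In "per 5-class task" accuracy it may only pick among the 5 classes of the task being tested. `make_tasks`, `accuracy`, and the toy scores are hypothetical, because the notebook does not show the actual class order or masking used in the experiment.

```python
import numpy as np

NUM_CLASSES = 100      # CIFAR-100
CLASSES_PER_TASK = 5   # 100 / 5 = 20 tasks


def make_tasks(num_classes=NUM_CLASSES, per_task=CLASSES_PER_TASK, seed=None):
    """Split class ids into tasks of `per_task` classes; shuffle the order if `seed` is given."""
    order = np.arange(num_classes)
    if seed is not None:
        order = np.random.default_rng(seed).permutation(num_classes)
    return order.reshape(-1, per_task)


def accuracy(scores, labels, allowed=None):
    """Top-1 accuracy in percent; if `allowed` is given, only those classes can be predicted."""
    if allowed is not None:
        masked = np.full_like(scores, -np.inf)
        masked[:, allowed] = scores[:, allowed]
        scores = masked
    return float(np.mean(np.argmax(scores, axis=1) == labels) * 100)


tasks = make_tasks()
print(tasks.shape)  # (20, 5)

# Two toy test images from task 0 (classes 0-4), with made-up class scores.
scores = np.zeros((2, NUM_CLASSES))
scores[0, 0], scores[0, 50] = 0.5, 0.9  # a class from another task scores highest
scores[1, 3] = 0.8
labels = np.array([0, 3])

print(accuracy(scores, labels))                    # 100-way: 50.0
print(accuracy(scores, labels, allowed=tasks[0]))  # within task 0: 100.0
```

The toy case shows why per-task numbers can be much higher than 100-way numbers. A confusion with a class from another task counts as an error only in the 100-way setting.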

### Baseline

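```python
import utils
import warnings

warnings.filterwarnings('ignore')
# Load the logged metrics for run 0 of the baseline model
out = utils.cleaned_up_json("/nethome/bdevnani3/raid/continual/cifar100_alexnet_base_icl_itl/run_0/logs_run_id_0.json")

# Color-code the accuracy matrix (Styler.applymap is deprecated in pandas >= 2.1; use .style.map there)
out['acc'].style.applymap(utils.color)
```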

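| After task / Eval task | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 71.6 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 1 | 23.2 | 73.4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 2 | 21.4 | 18.4 | 81.6 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 3 | 25.2 | 23.4 | 10.8 | 82.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 4 | 17.0 | 23.6 | 35.0 | 23.6 | 79.6 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 5 | 25.4 | 25.2 | 20.0 | 21.8 | 20.0 | 87.8 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 6 | 23.8 | 15.8 | 21.4 | 21.6 | 22.2 | 31.6 | 80.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 7 | 18.2 | 20.8 | 23.0 | 18.0 | 27.8 | 19.0 | 26.6 | 69.8 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 8 | 24.8 | 26.0 | 28.0 | 30.8 | 24.4 | 24.2 | 19.0 | 23.8 | 78.8 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 9 | 17.4 | 22.2 | 25.6 | 30.6 | 18.0 | 26.0 | 24.0 | 24.0 | 23.2 | 73.6 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |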


```python
out['avg_acc']
out['gem_bwt']  # only this last expression is displayed, so the output below is the BWT
```

'-56.294736561022304'
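The `gem_bwt` key suggests the backward-transfer metric from the GEM paper (Lopez-Paz and Ranzato, 2017). Let $R_{i,j}$ be the accuracy on task $j$ after training on task $i$, and let $T$ be the number of tasks. Those definitions are $\text{ACC} = \frac{1}{T}\sum_{i=1}^{T} R_{T,i}$ and $\text{BWT} = \frac{1}{T-1}\sum_{i=1}^{T-1} (R_{T,i} - R_{i,i})$. The following is a minimal NumPy sketch, assuming `utils` computes the metrics this way (its code is not shown). The matrices in this notebook list only 10 task rows against 20 task columns, so these functions will not reproduce the logged numbers when applied to them.

```python
import numpy as np


def avg_accuracy(R):
    """ACC: mean accuracy over all T tasks after training on the last task (last row of R)."""
    R = np.asarray(R, dtype=float)
    return float(R[-1].mean())


def backward_transfer(R):
    """BWT (GEM): average of R[T-1, i] - R[i, i] over the first T-1 tasks (0-indexed)."""
    R = np.asarray(R, dtype=float)
    T = R.shape[0]
    return float(np.mean(R[-1, : T - 1] - np.diag(R)[: T - 1]))


# Top-left 3x3 corner of the baseline table, treated as a complete 3-task run
# purely to exercise the functions (it is NOT the logged 20-task result).
R = [[71.6, 0.0, 0.0],
     [23.2, 73.4, 0.0],
     [21.4, 18.4, 81.6]]
print(round(avg_accuracy(R), 2))       # 40.47
print(round(backward_transfer(R), 2))  # -52.6
```

A strongly negative BWT, as in the baseline run, means that accuracy on earlier tasks ends up far below what it was right after those tasks were learned.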

### BERT

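```python
import utils
import warnings

warnings.filterwarnings('ignore')
# Load the logged metrics for run 0 of the model trained on BERT label embeddings
out = utils.cleaned_up_json("/nethome/bdevnani3/raid/continual/cifar100_alexnet_bert_icl_itl/run_0/logs_run_id_0.json")

# Color-code the accuracy matrix (Styler.applymap is deprecated in pandas >= 2.1; use .style.map there)
out['acc'].style.applymap(utils.color)
```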

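| After task / Eval task | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 77.6 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 1 | 22.6 | 77.4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 2 | 22.2 | 33.0 | 87.8 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 3 | 20.8 | 41.2 | 22.8 | 86.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 4 | 29.8 | 40.2 | 41.6 | 37.8 | 85.6 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 5 | 26.2 | 28.2 | 38.4 | 34.0 | 25.2 | 93.2 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 6 | 18.6 | 28.0 | 35.8 | 20.2 | 27.2 | 37.6 | 90.8 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 7 | 22.6 | 24.6 | 21.6 | 25.2 | 20.6 | 20.8 | 21.8 | 74.2 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 8 | 27.0 | 23.6 | 22.6 | 27.4 | 29.6 | 13.6 | 18.6 | 24.8 | 82.2 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 9 | 33.8 | 16.2 | 19.0 | 23.2 | 30.0 | 31.8 | 33.6 | 16.0 | 39.4 | 78.6 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |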


```python
out['avg_acc']
```

'31.269999'

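### word2vec

```python
import utils
import warnings

warnings.filterwarnings('ignore')
# Load the logged metrics for run 0 of the model trained on word2vec label embeddings
out = utils.cleaned_up_json("/nethome/bdevnani3/raid/continual/cifar100_alexnet_word2vec_icl_itl/run_0/logs_run_id_0.json")

# Color-code the accuracy matrix (Styler.applymap is deprecated in pandas >= 2.1; use .style.map there)
out['acc'].style.applymap(utils.color)
```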

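| After task / Eval task | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 75.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 1 | 32.0 | 73.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 2 | 29.2 | 39.8 | 85.6 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 3 | 25.8 | 41.4 | 35.8 | 83.8 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 4 | 30.6 | 29.8 | 32.6 | 34.4 | 84.8 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 5 | 23.8 | 28.6 | 36.2 | 44.0 | 43.6 | 90.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 6 | 26.4 | 28.4 | 46.8 | 33.4 | 29.2 | 56.4 | 85.4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 7 | 21.4 | 23.2 | 23.8 | 36.4 | 33.4 | 35.8 | 49.2 | 72.4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 8 | 35.0 | 33.2 | 23.4 | 43.8 | 29.8 | 30.8 | 31.0 | 31.2 | 81.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 9 | 25.0 | 24.8 | 29.4 | 25.6 | 40.4 | 23.8 | 36.6 | 23.0 | 39.8 | 77.4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |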


```python
out['avg_acc']
# out['gem_bwt']
```

'31.539999'