In [None]:
!pip install altair vega_datasets

In [8]:
import pandas as pd
import altair as alt

# Destilación de última capa

In [26]:

def load_KD_data(folder="./Cifar10/ResNet101/exp1/students", exclude_students=("EfficientNetB0",)):
    """Load the distillation summary table of one experiment.

    Reads ``<folder>/summary.csv`` and drops the rows whose ``student`` is in
    ``exclude_students``. When the table has no ``temp`` column, the
    ``distillation`` column (entries formatted as ``<name>,T-<temperature>``)
    is split into the distillation name and a float ``temp`` column.

    Args:
        folder: directory that contains ``summary.csv``.
        exclude_students: iterable of student names to leave out of the table.

    Returns:
        A new, independent ``DataFrame``. Rows whose ``distillation`` entry has
        no ``",T-"`` part get ``NaN`` as temperature.
    """
    data = pd.read_csv(folder + "/summary.csv")
    # .copy() so the column assignments below modify an independent frame rather
    # than a slice of the original (avoids SettingWithCopyWarning).
    data = data[~data["student"].isin(exclude_students)].copy()
    if "temp" not in data.columns:
        distillation = data["distillation"]
        # Parse the temperature before the name column is overwritten.
        data["temp"] = distillation.str.split(",T-").str[1].astype(float)
        data["distillation"] = distillation.str.split(",").str[0]
    return data


def plot_KD(data,phase,field):
    """Scatter one KD metric against the distillation temperature.

    Args:
        data: table as returned by ``load_KD_data``; it needs ``temp``,
            ``student``, ``distillation``, the ``<phase>_<field>`` column and
            every column listed in the tooltip.
        phase: ``'train'`` or ``'test'``; selects the ``<phase>_<field>`` column
            and the matching reference accuracies.
        field: ``'acc'``, ``'eval'`` or ``'loss'`` (keys of ``field_dict``).

    Returns:
        An Altair chart (600x400). For ``field == 'acc'`` the scatter is layered
        over horizontal rules at the hard-coded ``ce_train``/``ce_test`` accuracies
        of each model (presumably plain cross-entropy training; TODO confirm).
    """
    detalle = ['test_acc', 'test_teacher/acc', 'test_loss', 'test_eval',
               'train_acc', 'train_teacher/acc', 'train_loss', 'train_eval',
               'distillation', 'temp']

    field_dict = {'acc': "Accuracy [%]", 'eval': "Perdida Cross Entropy",
                  'loss': "Perdida de Destilación"}

    bar = alt.Chart(data).mark_point().encode(
        # `temp` is plotted as discrete (ordinal) categories. Vega-Lite discards a
        # log scale type on an ordinal field with a warning, so none is requested.
        alt.X('temp:O', title="Temperatura"),
        # Losses span orders of magnitude, hence a log axis for them.
        alt.Y('%s_%s' % (phase, field),
              scale=alt.Scale(zero=False, type='log' if field in ['loss', 'eval'] else 'linear'),
              title=field_dict[field]),
        shape=alt.Shape('distillation', legend=alt.Legend(title="Destilación")),
        color=alt.Color('student', legend=alt.Legend(title="Modelo")),
        size=alt.value(50),
        tooltip=detalle
    ).interactive()

    if field == 'acc':
        accs = {'Model': ['MobileNet', 'ResNet18', 'ResNet101'],
                'ce_train': [95.73, 98.15, 98.52],
                'ce_test': [87.8, 90.58, 90.68]}
        df = pd.DataFrame(accs)

        # Layered charts share the colour scale, presumably so each rule takes the
        # colour of the `student` with the same name.
        aggregates = alt.Chart(df).mark_rule(opacity=0.5).encode(
            y='ce_%s:Q' % phase,
            color='Model:N',
            size=alt.value(2))

        return (aggregates + bar).properties(width=600, height=400)
    return bar.properties(width=600, height=400)
    
def load_and_plot_KD(folder="./Cifar10/ResNet101/exp1/students",phase='test',field='acc'):
    """Load ``<folder>/summary.csv`` via ``load_KD_data`` and chart it with ``plot_KD``."""
    return plot_KD(load_KD_data(folder), phase, field)



# Training-set accuracy (default field 'acc') vs. temperature for experiment 2.
load_and_plot_KD(folder="./Cifar10/ResNet101/exp2/students",phase='train')

In [22]:
def omniplot(folder="./Cifar10/ResNet101/exp1/students", scale='log'):
    """Scatter-plot matrix of every pair of KD metrics for one experiment.

    Args:
        folder: experiment directory passed to ``load_KD_data``.
        scale: Vega-Lite scale type used for both axes of every panel
            (default ``'log'``; use e.g. ``'linear'`` if a metric can be <= 0).

    Returns:
        A repeated Altair chart with one 150x150 panel per (row, column) pair.
    """
    data = load_KD_data(folder)

    detalle = ['test_acc', 'test_loss', 'test_eval',
               'train_acc', 'train_loss', 'train_eval', 'temp']
    # `base` only applies to log scales, and a log domain can never include 0,
    # so each option is set only where it has an effect.
    if scale == 'log':
        axis_scale = alt.Scale(type='log', base=10)
    else:
        axis_scale = alt.Scale(type=scale, zero=False)

    chart = alt.Chart(data).mark_point().encode(
        alt.X(alt.repeat("column"), type='quantitative', scale=axis_scale),
        alt.Y(alt.repeat("row"), type='quantitative', scale=axis_scale),
        shape='student',
        color='distillation'
    ).properties(
        width=150,
        height=150
    ).repeat(
        row=detalle,
        column=detalle
    )
    return chart

# Pairwise metric matrix for experiment 3.
omniplot(folder="./Cifar10/ResNet101/exp3/students")

## Destilación con features usando KD

In [23]:
# List the student run directories stored under Cifar10/ResNet101/students.
!ls Cifar10/ResNet101/students

[34mMobileNet-KD[m[m   [34mMobileNet-KDCE[m[m [34mResNet18-KD[m[m    [34mResNet18-KDCE[m[m


In [34]:
import altair as alt


In [40]:
def load_data(folder):
    """Load the feature-distillation summary table of one experiment.

    Reads ``<folder>/summary.csv``, replaces every missing value (in every
    column) with ``1`` and adds a ``'student last_layer'`` column that joins
    ``student`` and ``last_layer`` with a comma.

    Args:
        folder: directory that contains ``summary.csv``.

    Returns:
        The augmented ``DataFrame``.
    """
    source = pd.read_csv(folder + "/summary.csv")
    # NOTE(review): blanket fill over all columns; presumably aimed at runs
    # without a `layer` value. Confirm it is intended for the metric columns too.
    source = source.fillna(1)
    # Vectorised concatenation. astype(str) keeps this working when `last_layer`
    # was parsed as bool/number (or filled with 1 above), where `str + value`
    # would raise TypeError.
    source['student last_layer'] = (
        source['student'].astype(str) + "," + source['last_layer'].astype(str)
    )
    return source

def plot_feats(data,phase,field):
    """Scatter one feature-distillation metric against the distilled layer.

    Args:
        data: table as returned by ``load_data``; it needs ``layer``,
            ``student``, ``distillation``, the ``<phase>_<field>`` column and
            every column listed in the tooltip.
        phase: ``'train'`` or ``'test'``; selects the ``<phase>_<field>`` column
            and the matching reference accuracies.
        field: ``'acc'``, ``'eval'`` or ``'loss'`` (keys of ``field_dict``).

    Returns:
        An Altair chart (600x400). For ``field == 'acc'`` the scatter is layered
        over horizontal rules at the hard-coded ``ce_train``/``ce_test`` accuracies
        of each model (presumably plain cross-entropy training; TODO confirm).
    """
    detalle = ['test_acc', 'test_teacher/acc', 'test_loss', 'test_eval',
               'train_acc', 'train_teacher/acc', 'train_loss', 'train_eval',
               'distillation', 'last_layer', 'layer']

    field_dict = {'acc': "Accuracy [%]", 'eval': "Perdida Cross Entropy",
                  'loss': "Perdida de Destilación"}

    bar = alt.Chart(data).mark_point().encode(
        # `layer` is plotted as discrete (ordinal) categories. Vega-Lite discards a
        # log scale type on an ordinal field with a warning, so none is requested.
        alt.X('layer:O', title="Capa"),
        # Losses span orders of magnitude, hence a log axis for them.
        alt.Y('%s_%s' % (phase, field),
              scale=alt.Scale(zero=False, type='log' if field in ['loss', 'eval'] else 'linear'),
              title=field_dict[field]),
        # The shape channel takes an alt.Shape definition (not alt.Color).
        shape=alt.Shape('distillation', legend=alt.Legend(title="Destilación")),
        color=alt.Color('student', legend=alt.Legend(title="Modelo")),
        size=alt.value(50),
        tooltip=detalle
    ).interactive()

    if field == 'acc':
        accs = {'Model': ['MobileNet', 'ResNet18', 'ResNet101'],
                'ce_train': [95.73, 98.15, 98.52],
                'ce_test': [87.8, 90.58, 90.68]}
        df = pd.DataFrame(accs)

        # Layered charts share the colour scale, presumably so each rule takes the
        # colour of the `student` with the same name.
        aggregates = alt.Chart(df).mark_rule(opacity=0.5).encode(
            y='ce_%s:Q' % phase,
            color='Model:N',
            size=alt.value(2))

        return (aggregates + bar).properties(width=600, height=400)
    return bar.properties(width=600, height=400)

def loadNplotFeats(folder="./Cifar10/ResNet101/exp7/students",phase="test",field="acc"):
    """Load ``<folder>/summary.csv`` with ``load_data`` and chart it with ``plot_feats``."""
    frame = load_data(folder)
    return plot_feats(frame, phase, field)

In [41]:
# Default view: test accuracy per distilled layer for experiment 7.
loadNplotFeats()

In [52]:
def feat_omniplot(folder="./Cifar10/ResNet101/exp7/students",scale='log'):
    """Scatter-plot matrix of every pair of feature-distillation metrics.

    Args:
        folder: experiment directory passed to ``load_data``.
        scale: Vega-Lite scale type used for both axes of every panel
            (default ``'log'``; use e.g. ``'linear'`` if a metric can be <= 0).

    Returns:
        A repeated Altair chart with one 150x150 panel per (row, column) pair.
    """
    data = load_data(folder)

    detalle = ['test_acc', 'test_loss', 'test_eval',
               'train_acc', 'train_loss', 'train_eval']
    # `base` only applies to log scales, and a log domain can never include 0,
    # so each option is set only where it has an effect.
    if scale == 'log':
        axis_scale = alt.Scale(type='log', base=10)
    else:
        axis_scale = alt.Scale(type=scale, zero=False)

    chart = alt.Chart(data).mark_point().encode(
        alt.X(alt.repeat("column"), type='quantitative', scale=axis_scale),
        alt.Y(alt.repeat("row"), type='quantitative', scale=axis_scale),
        color='student',
        shape='distillation'
    ).properties(
        width=150,
        height=150
    ).repeat(
        row=detalle,
        column=detalle
    )
    return chart

# Pairwise metric matrix of the feature-distillation runs in experiment 4.
feat_omniplot(folder="./Cifar10/ResNet101/exp4/students")

In [56]:
# Peek at the first rows of the default (exp1) KD summary.
load_KD_data().head()

Unnamed: 0.1,Unnamed: 0,test_acc,test_teacher/acc,test_loss,test_eval,train_acc,train_teacher/acc,train_loss,train_eval,epoch,student,distillation,temp
0,0,90.32,90.68,2467.345938,0.493469,98.152,98.274,286.173518,0.057235,99,ResNet18,KD_CE,100.0
1,1,90.28,90.68,0.251264,0.450828,98.36,98.288,0.032126,0.049576,99,ResNet18,KD_CE,1.0
2,2,90.24,90.68,228784.270391,0.457569,98.288,98.298,26333.361176,0.052667,83,ResNet18,KD_CE,1000.0
3,3,90.02,90.68,24.829716,0.496368,98.334,98.364,2.722346,0.054231,78,ResNet18,KD_CE,10.0
4,4,90.0,90.68,6.059729,0.483544,98.126,98.334,0.735777,0.057665,98,ResNet18,KD_CE,5.0


# Dataset Merging

In [107]:

def load_KD_data(folder="./Cifar10/ResNet101/exp1/students", exclude_students=("EfficientNetB0",)):
    """Load the distillation summary table of one experiment.

    Reads ``<folder>/summary.csv`` and drops the rows whose ``student`` is in
    ``exclude_students``. When the table has no ``temp`` column, the
    ``distillation`` column (entries formatted as ``<name>,T-<temperature>``)
    is split into the distillation name and a float ``temp`` column.

    Args:
        folder: directory that contains ``summary.csv``.
        exclude_students: iterable of student names to leave out of the table.

    Returns:
        A new, independent ``DataFrame``. Rows whose ``distillation`` entry has
        no ``",T-"`` part get ``NaN`` as temperature.
    """
    data = pd.read_csv(folder + "/summary.csv")
    # .copy() so the column assignments below modify an independent frame rather
    # than a slice of the original (avoids SettingWithCopyWarning).
    data = data[~data["student"].isin(exclude_students)].copy()
    if "temp" not in data.columns:
        distillation = data["distillation"]
        # Parse the temperature before the name column is overwritten.
        data["temp"] = distillation.str.split(",T-").str[1].astype(float)
        data["distillation"] = distillation.str.split(",").str[0]
    return data

# Columns of the exp2/exp3 summaries that are not needed for the comparison.
drop_cols = [
    "test_teacher/acc", "train_teacher/acc",
    "lr", "epochs", "train_batch_size",
    "teacher", "test_batch_size", "resume", "epoch",
]

d1 = load_KD_data(folder="./Cifar10/ResNet101/exp1/students").drop(columns=["test_teacher/acc", "train_teacher/acc"])
d2 = load_KD_data(folder="./Cifar10/ResNet101/exp2/students").drop(columns=drop_cols)
d3 = load_KD_data(folder="./Cifar10/ResNet101/exp3/students").drop(columns=drop_cols)

# Join the three runs on the configuration keys: exp3 columns get `_R01`,
# exp2 columns get `_R1`, and exp1 columns keep their plain names.
db = (
    d3.merge(d2, on=["student", "distillation", "temp"], suffixes=("_R01", "_R1"))
      .drop(columns=["Unnamed: 0_R01", "Unnamed: 0_R1"])
      .merge(d1, on=["student", "distillation", "temp"])
      .drop(columns=["epoch", "Unnamed: 0"])
)

In [111]:
# Scatter-plot matrix of test accuracy across the merged runs:
# no suffix = exp1, `_R1` = exp2, `_R01` = exp3.
detalle=['test_acc' + s for s in ["","_R1",'_R01']]

alt.Chart(db).mark_point().encode(
    # `base` only affects log scales, so it is not set on these linear axes.
    alt.X(alt.repeat("column"), type='quantitative', scale=alt.Scale(zero=False, type='linear')),
    alt.Y(alt.repeat("row"), type='quantitative', scale=alt.Scale(zero=False, type='linear')),
    color='student',
    shape='distillation'
).properties(
    width=150,
    height=150
).repeat(
    row=detalle,
    column=detalle
)


In [108]:
# First rows of the merged exp1/exp2/exp3 table.
db.head()

Unnamed: 0,test_acc_R01,test_loss_R01,test_eval_R01,train_acc_R01,train_loss_R01,train_eval_R01,student,distillation,temp,test_acc_R1,...,test_eval_R1,train_acc_R1,train_loss_R1,train_eval_R1,test_acc,test_loss,test_eval,train_acc,train_loss,train_eval
0,89.33,2462.161829,0.492432,97.648,332.814397,0.066563,ResNet18,KD_CE,100.0,10.0,...,2.302861,9.84,11517.09,2.303418,90.32,2467.345938,0.493469,98.152,286.173518,0.057235
1,90.02,0.187363,0.337151,98.002,0.058635,0.066467,ResNet18,KD_CE,1.0,10.0,...,2.326473,10.0,1.246259,2.328405,90.28,0.251264,0.450828,98.36,0.032126,0.049576
2,89.34,243786.745078,0.487573,97.604,33371.804589,0.066744,ResNet18,KD_CE,1000.0,10.0,...,2.303176,9.934,1152305.0,2.304609,90.24,228784.270391,0.457569,98.288,26333.361176,0.052667
3,90.05,21.899282,0.437634,97.632,3.35101,0.066709,ResNet18,KD_CE,10.0,10.0,...,2.302697,9.946,115.1815,2.303356,90.02,24.829716,0.496368,98.334,2.722346,0.054231
4,89.06,5.887921,0.469171,97.784,0.793954,0.06162,ResNet18,KD_CE,5.0,13.19,...,2.302721,9.984,28.82868,2.303233,90.0,6.059729,0.483544,98.126,0.735777,0.057665


In [104]:
# The accuracy columns compared in the scatter-plot matrix.
detalle

['test_acc', 'test_acc_R1', 'test_acc_R01']

In [None]:
# Test accuracy of the `_R01` (exp3) runs in the merged table.
db["test_acc_R01"]