In [None]:
import matplotlib 
import seaborn as sns
%matplotlib inline
# plt.style.use(['science','no-latex'])
# Global default font size (8pt) for every figure produced below.
matplotlib.rcParams.update({'font.size':8})
# Star import: besides fastai itself, this is what provides `pd`, `np` and `plt`
# to the rest of the notebook (none of them is imported explicitly before use);
# presumably via fastai's own re-exports -- confirm if upgrading fastai.
from fastai.vision.all import *
def get_df(runs):
    """Flatten run objects into one DataFrame with one row per run.

    Columns are 'id', then every entry of ``run.config``, then every entry of
    ``run.summary._json_dict``; keys starting with '_' are dropped from config
    and summary.

    Args:
        runs: any iterable of run objects (e.g. the result of ``api.runs``),
            including a one-shot generator.

    Returns:
        pandas.DataFrame aligned row-wise with ``runs``.
    """
    # Materialise once: the three column groups below each iterate the runs,
    # and a one-shot iterable would otherwise leave config/summary empty.
    runs = list(runs)
    id_df = pd.DataFrame({'id': [run.id for run in runs]})
    config_df = pd.DataFrame.from_records(
        [{k: v for k, v in run.config.items() if not k.startswith('_')} for run in runs])
    summary_df = pd.DataFrame.from_records(
        [{k: v for k, v in run.summary._json_dict.items() if not k.startswith('_')} for run in runs])
    return pd.concat([id_df, config_df, summary_df], axis=1)


In [None]:
# IFP distribution
IFP_distribution = pd.read_csv('results/IFP_dis.csv')
IFP_distribution1 = pd.read_csv('results/IFP_scream_color_dis.csv')


def _plot_ifp_kde(ax, frame, dataset, group_col, label_fn=None):
    """Overlay one filled KDE of ``frame['IFP']`` per value of ``group_col``.

    Only rows whose 'dataset' column equals ``dataset`` are used. The panel is
    titled with the dataset name and gets a legend keyed by the group value,
    passed through ``label_fn`` when one is given.
    """
    ax.set_title(dataset)
    subset = frame[frame['dataset'] == dataset]
    for key, group in subset.groupby(group_col):
        label = key if label_fn is None else label_fn(key)
        # `shade` is the pre-0.11 seaborn spelling of `fill`; kept for older seaborn.
        sns.kdeplot(group['IFP'], label=label, shade=True, ax=ax)
    ax.legend()


fig, axes = plt.subplots(2, 2, figsize=(6.84, 4), dpi=100)
axes = axes.flatten()
for ax, dataset in zip(axes[:3], ['smallnorb', 'dsprites_full', 'color_dsprites']):
    _plot_ifp_kde(ax, IFP_distribution, dataset, 'action')
    ax.set_xlim(0, None)

# The fourth panel comes from a separate file and is grouped by factor, not action.
_plot_ifp_kde(axes[3], IFP_distribution1, 'scream_dsprites', 'factor', label_fn=int)
axes[3].set_xlim(None, 200)

# Lay out only after titles, axis labels and legends exist so space is reserved for them.
fig.tight_layout(pad=1.08, h_pad=2)
fig.savefig('pics/IFP_distribution.pdf')

In [None]:
# Information Diffusion, NMI2
r = api.run('erow/dlib/23kjiyy2')
mi_history = np.array(r.history(keys=['discrete_mi'])['discrete_mi'].values.tolist())
# Normalise MI by log of each factor's cardinality (shape, scale, orientation, posX, posY).
nmi = mi_history / np.log([[[3, 6, 40, 32, 32]]])
# Sort along axis 1 so index -2 is the second-largest value per factor.
nmi.sort(1)
# Stored under a new name: assigning to `NMI2` would overwrite the NMI2 callable
# that the comparison cell passes to `Series.map`.
nmi2_curve = nmi[:, -2].max(1)
plt.plot(nmi2_curve);


In [None]:
# Information leakage
IL_FACTOR_NAMES = ['shape', 'scale', 'orientation', 'posX', 'posY']

run = api.run('erow/dlib/2qidd1r3')
his = run.history(keys=['discrete_mi'])
print('loss:', run.config['loss'])

# Select the column by name (as elsewhere) rather than by position in the history frame.
mi = np.array(his['discrete_mi'].values.tolist())
# The normalisation and the 2-D heatmap slices below require (steps, latents, factors).
assert mi.ndim == 3 and mi.shape[-1] == len(IL_FACTOR_NAMES), mi.shape
nmi = mi / np.log([[[3, 6, 40, 32, 32]]])
# Sort along axis 1 so index -1 / -2 are the largest / second-largest per factor.
nmi.sort(1)


def _plot_nmi_heatmap(values, path, dpi):
    """Annotated heatmap of ``values`` with factor-name x ticks, saved to ``path``."""
    fig, ax = plt.subplots(figsize=(3, 2), dpi=dpi)
    sns.heatmap(values, annot=True, fmt='.2f', annot_kws={"size": 4}, ax=ax)
    ax.set_xticklabels(IL_FACTOR_NAMES, rotation=30)
    ax.set_ylabel('1e4')
    fig.savefig(path)


# Every second logged step; largest (NMI1) and second-largest (NMI2) per factor.
_plot_nmi_heatmap(nmi[::2, -1], 'pics/IL_NMI1.pdf', dpi=200)
_plot_nmi_heatmap(nmi[::2, -2], 'pics/IL_NMI2.pdf', dpi=100)

In [None]:
# gamma
# Provenance: gamma.csv was presumably exported from this wandb query (kept for reference):
# runs =api.runs('public/fractionVAE',{'tags':'gamma','config.base':"70"})
# runs
data = pd.read_csv('gamma.csv').values
fig, ax = plt.subplots(figsize=(3.3, 2), dpi=200)
# Column 0 labels the rows (the gamma values, per the y label); columns 1-5 are
# heat-mapped, one column per factor named on the x axis.
sns.heatmap(data[:, 1:6], yticklabels=data[:, 0], annot=True, fmt='.2f', ax=ax)
ax.set_xticklabels(['shape', 'scale', 'orientation', 'posX', 'posY'], rotation=30)
# Raw string: in a plain literal `\g` is an invalid escape sequence
# (DeprecationWarning since Python 3.6, SyntaxWarning since 3.12).
ax.set_ylabel(r'$\gamma$')
fig.savefig('pics/gamma_NMI2.pdf')


In [None]:
# annealing test, without supervision
annealing_test = pd.read_csv('results/annealing_test.csv')

# Own name instead of the shared global `df`, which other cells rely on.
anneal_df = annealing_test[annealing_test['dataset'] == 'dsprites_full']
mi = anneal_df['MI'].to_numpy()

fig, ax = plt.subplots()
ax.grid(True)
# MI increment between consecutive rows, plotted at the earlier row's beta.
ax.plot(anneal_df['beta'].iloc[:-1], np.diff(mi))
fig.savefig('dsprites_anneal.pdf')

# udr
import numpy as np
def compute_gaussian_kl(z_mean, z_logvar):
    """KL( N(z_mean, exp(z_logvar)) || N(0, 1) ) element-wise, averaged over axis 0.

    Args:
        z_mean: array of Gaussian means.
        z_logvar: array of Gaussian log-variances, broadcastable against ``z_mean``.

    Returns:
        ``mean over axis 0 of 0.5 * (mu**2 + exp(logvar) - logvar - 1)``
        (axis 0 presumably being the batch axis -- confirm with callers).
    """
    mean_sq = np.square(z_mean)
    variance = np.exp(z_logvar)
    kl_terms = 0.5 * (mean_sq + variance - z_logvar - 1)
    return np.mean(kl_terms, axis=0)
def representation_function(encoder):
    """Wrap ``encoder`` into a function returning ``(mu, kl)`` for an input batch.

    ``encoder(x)`` must return a ``(mu, logvar)`` pair; ``kl`` is
    ``compute_gaussian_kl(mu, logvar)``.
    """
    def _representation_function(x):
        mean, log_variance = encoder(x)
        return mean, compute_gaussian_kl(mean, log_variance)
    return _representation_function
# Fixed seed so UDR scores are reproducible across re-runs. Each model group
# gets a fresh RandomState with the same seed, so all groups start from the
# same random state.
UDR_SEED = 0
results = {}
# `models`, `dataset` and `udr` are defined outside this cell.
for name, encoders in models.items():
    mean_reps = [representation_function(encoder) for encoder in encoders]
    results[name] = udr.compute_udr_sklearn(dataset,
                                            mean_reps,
                                            np.random.RandomState(UDR_SEED))
    

In [None]:
# fluctuation
import matplotlib.ticker as mtick

# Checkpoint model-{k}.pt is plotted at k * ITERS_PER_CHECKPOINT iterations.
ITERS_PER_CHECKPOINT = 11520
checkpoint_ids = range(1, 20, 2)
order = [f'model-{i}.pt' for i in checkpoint_ids]
itrs = [i * ITERS_PER_CHECKPOINT for i in checkpoint_ids]

metric_name = ['beta VAE', 'MIG', 'DCI disentanglement']
metrics = ['eval_accuracy', 'discrete_mig', 'disentanglement']
# `constant`, `tc` and `annealedVAE` (checkpoint name -> metric dict) come from outside this cell.
model_results = [constant, tc, annealedVAE]
model_labels = [r'$\beta$-VAE', r'$\beta$-TCVAE', 'AnnealedVAE']

fig, axes = plt.subplots(1, len(metrics), figsize=(6.75, 2), dpi=200)
for ax, title, metric in zip(axes, metric_name, metrics):
    ax.grid()
    ax.set_title(title)
    ax.set_xlabel('Iterations')
    ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.2f'))
    ax.xaxis.get_major_formatter().set_powerlimits((0, 1))
    ax.set_xlim(0, itrs[-1])
    for model, label in zip(model_results, model_labels):
        ax.plot(itrs, [model[f][metric] for f in order], label=label)

axes[0].legend()
fig.savefig('pics/metric_fluctuation.pdf')


In [None]:
# metrics
matplotlib.rcParams['xtick.minor.visible'] = False
metrics = ['evaluation_results.eval_accuracy', 'evaluation_results.discrete_mig', 'evaluation_results.disentanglement']
metric_name = ['beta VAE', 'MIG', 'DCI disentanglement']
# Each block of this many consecutive rows becomes one violin
# (presumably one group of models per setting -- TODO confirm against `best_model`).
GROUP_SIZE = 50

# `df` and `best_model` come from outside this cell; select the rows once.
selected = df.iloc[best_model]
fig, axes = plt.subplots(1, len(metrics), figsize=(6.75, 2), dpi=200)
for ax, metric, title in zip(axes, metrics, metric_name):
    scores = np.clip(selected[metric].values, 0, 1)
    ax.grid(axis='y')
    sns.violinplot(data=scores.reshape(-1, GROUP_SIZE).T, ax=ax)
    ax.set_title(title)

# Lay out after plotting so titles and tick labels are accounted for.
fig.tight_layout(pad=0.1)
fig.savefig('pics/unstable_models.pdf')

In [None]:
# comparison
# wandb `_step` is multiplied by this to give iterations on the x axis
# (presumably metrics were logged every 50 iterations -- TODO confirm).
STEP_SCALE = 50


def _plot_history(ax, run, key, transform=None, **plot_kw):
    """Plot the logged series ``key`` of a wandb run on ``ax`` against ``_step * STEP_SCALE``.

    ``transform``, when given, is applied element-wise via ``Series.map`` first.
    Returns the created Line2D so legends can combine lines from twin axes.
    """
    hist = run.history(keys=[key])
    values = hist[key] if transform is None else hist[key].map(transform)
    (line,) = ax.plot(hist['_step'] * STEP_SCALE, values, **plot_kw)
    return line


runs = list(map(api.run, ['erow/dlib/a7w65vl4', 'erow/dlib/3vebm4of', 'erow/dlib/9odv037v']))
deft_run, annealed_run, cascade_run = runs

# Figure 1: MI during training together with each method's control signal.
fig, axes = plt.subplots(1, 3, sharex=True, figsize=(6.85, 2), dpi=200)

ax = axes[0]
_plot_history(ax, annealed_run, 'MI', label='MI')
_plot_history(ax, annealed_run, 'kl_loss', label='KL')
ax.legend()
ax.set_title('AnnealedVAE')

# Secondary-axis lines get an explicit colour ('C1'): a twin axis has its own
# colour cycle and would otherwise reuse MI's colour. Legends are built from the
# real line handles instead of a dummy point on the main axis (which also
# stretched that axis' y-limits and showed a mismatched colour).
ax = axes[1]
mi_line = _plot_history(ax, cascade_run, 'MI', label='MI')
ax2 = ax.twinx()
# NOTE(review): `steps` and `var` are not defined in this cell; they must come
# from an earlier cell. TODO: compute them here so Restart & Run All works.
(var_line,) = ax2.plot(steps[:-1], var[:-1], color='C1', label='Vars')
ax.legend(handles=[mi_line, var_line])
ax.set_title('CascadeVAE')

ax = axes[2]
mi_line = _plot_history(ax, deft_run, 'MI', label='MI')
ax2 = ax.twinx()
beta_line = _plot_history(ax2, deft_run, 'beta', color='C1', label='beta')
ax.legend(handles=[mi_line, beta_line], loc=3)
ax.set_title('DEFT')

# Lay out after plotting so titles and labels are accounted for.
fig.tight_layout()
fig.savefig('pics/comparison.pdf')

# Figure 2: NMI1 (left) and NMI2 (right) of every run, labelled by config['model'].
# `NMI1` / `NMI2` are callables defined outside this cell.
fig, axes = plt.subplots(1, 2, sharex=True, figsize=(6.85, 2), dpi=200)
for ax, nmi_fn in zip(axes, [NMI1, NMI2]):
    for r in runs:
        _plot_history(ax, r, 'discrete_mi', transform=nmi_fn, label=r.config['model'])
    ax.legend(prop={'size': 5})
axes[0].set_xticks(np.linspace(0, 70000, 5))
fig.savefig('pics/NMI_comparison.pdf')