In [None]:
import sys
import os
import matplotlib.pyplot as plt
import torch
import numpy as np

In [None]:
# Root of the saved experiment results; gen_plot reads
# <init_stat_dir>/<dataset>/<model>_<param>/metric.pt from here.
init_stat_dir = "../saved/translation"

In [None]:
# Global matplotlib styling for every figure in this notebook.
tick_size = 23      # tick-label and base font size (points)
caption_size = 20   # legend font size (points)
plt.rcParams.update({
    # Previously "xtick.labelsize" was listed twice (the second entry silently
    # overwrote the first); the y-axis tick size was clearly intended.
    "xtick.labelsize": tick_size,
    "ytick.labelsize": tick_size,
    "font.size": tick_size,
    "font.family": "sans-serif",
})

In [None]:
def load_n(stats, name, class_list=None, rm=None, rep=None):
    """Collect one per-class statistic into a tensor of shape (entries, classes, 1).

    Args:
        stats: indexable by class id; ``stats[c][name]`` is a sequence of values
            (presumably scalars such as parameter counts -- confirm), one per index i.
        name: key of the statistic to read from each ``stats[c]``.
        class_list: class ids to read, in column order. Defaults to ``range(10)``.
        rm: indices i to drop from every class's sequence. Defaults to none.
        rep: indices i whose value is appended a second time (applied even if
            the index is also in ``rm``). Defaults to none.

    Returns:
        Tensor of the default float dtype with shape (entries, len(class_list), 1).
        All classes must yield the same number of entries, or ``torch.cat`` raises.
    """
    # None sentinels avoid shared mutable default arguments.
    if class_list is None:
        class_list = range(10)
    rm = [] if rm is None else rm
    rep = [] if rep is None else rep

    tracker = []
    for classes in class_list:
        cls_tracker = []
        for i, mat in enumerate(stats[classes][name]):
            if i not in rm:
                cls_tracker.append(mat)
            if i in rep:
                cls_tracker.append(mat)
        # (entries,) -> (entries, 1, 1). The double unsqueeze is intentional:
        # it keeps the historical return shape, so callers that sum over dim=1
        # still get (entries, 1).
        tracker.append(torch.Tensor(cls_tracker).unsqueeze(1).unsqueeze(1))
    return torch.cat(tracker, dim=1)


def load_data(stats, name, class_list=None, rm=None, rep=None, mean=False):
    """Collect one per-class statistic into a tensor of shape (entries, classes).

    Args:
        stats: indexable by class id; ``stats[c][name]`` is a sequence of values,
            one per index i.
        name: key of the statistic to read from each ``stats[c]``.
        class_list: class ids to read, in column order. Defaults to ``range(10)``.
        rm: indices i to drop from every class's sequence. Defaults to none.
        rep: indices i whose value is appended a second time (applied even if
            the index is also in ``rm``). Defaults to none.
        mean: if True, each entry is reduced with ``torch.mean`` (entries must be
            tensors); otherwise entries are used as-is and must be scalars.

    Returns:
        Tensor of the default float dtype with shape (entries, len(class_list)).
        All classes must yield the same number of entries, or ``torch.cat`` raises.
    """
    # None sentinels avoid shared mutable default arguments.
    if class_list is None:
        class_list = range(10)
    rm = [] if rm is None else rm
    rep = [] if rep is None else rep

    tracker = []
    for classes in class_list:
        cls_tracker = []
        for i, mat in enumerate(stats[classes][name]):
            keep = i not in rm
            dup = i in rep
            if not (keep or dup):
                # Entries that are neither kept nor repeated are never reduced.
                continue
            value = torch.mean(mat) if mean else mat
            if keep:
                cls_tracker.append(value)
            if dup:
                # Repeat the same value (mean or raw). Previously this always
                # appended torch.mean(mat), even with mean=False.
                cls_tracker.append(value)
        tracker.append(torch.Tensor(cls_tracker).unsqueeze(1))
    return torch.cat(tracker, dim=1)

## Investigation of J: communication cost vs. average Wasserstein-2 distance

In [None]:

def gen_plot(model_list, param_list, dataset, lim1=None, lim2=None, label=None,
             n_clients=5, verbose=True):
    """Plot average Wasserstein-2 distance against communication cost.

    For every model in ``model_list`` and every parameter string in
    ``param_list``, this loads
    ``<init_stat_dir>/<dataset>/<model>_<param>/metric.pt`` and draws one curve:
    x is the ``'nparams'`` statistic summed over classes and divided by
    ``n_clients``; y is the per-entry mean of ``'wd'``, averaged over classes.

    Args:
        model_list: model names (prefix of each run directory).
        param_list: parameter strings (suffix of each run directory). Each model
            gets one curve per entry.
        dataset: dataset sub-directory under ``init_stat_dir``.
        lim1: optional ``(bottom, top)`` passed to ``Axes.set_ylim``.
        lim2: unused. Kept for backward compatibility with an earlier
            two-panel layout.
        label: optional legend values, aligned with ``param_list``. Defaults to
            the parameter strings themselves.
        n_clients: divisor applied to the summed parameter count (was a
            hard-coded 5, commented "per client").
        verbose: print the x tensor and the x/y shapes for each curve.
    """
    if label is None:
        label = param_list

    fig, axs = plt.subplots(1, 1, figsize=(6, 4))
    for model in model_list:
        # 'inb' runs keep every entry; other models drop entry 0.
        rm = [] if model == 'inb' else [0]
        for i, param in enumerate(param_list):
            stat_path = os.path.join(init_stat_dir, f'{dataset}/{model}_{param}', 'metric.pt')
            # NOTE(review): torch.load unpickles its input; only use it on trusted result files.
            stats = torch.load(stat_path)

            n = load_n(stats, 'nparams', rm=rm)
            wd = load_data(stats, 'wd', mean=True, rm=rm)

            n = torch.sum(n, dim=1) / n_clients  # per client
            wd = torch.mean(wd, dim=1)           # average over classes

            if verbose:
                print(n)
                print('n:', n.shape)
                print('wd:', wd.shape)

            axs.plot(n, wd, label=f'J={label[i]}')

    # Axis decoration is identical for every curve, so set it once.
    axs.set_title('HistIndAEINB: J')
    axs.set_xlabel('Communication Cost \n(# params sent)')
    axs.set_ylabel('Avg. Wasserstein-2')
    axs.grid(True, alpha=1)
    axs.legend(prop={'size': caption_size}, fancybox=True, framealpha=0.3)
    axs.set_ylim(lim1)

    plt.show()


### RMNIST

In [None]:
model_list = ['histindaeinb']
J_list = [200, 100, 50, 30, 10]
# Run-directory suffix per J; only the third field varies between runs.
param_list = ['10_10_{}_500'.format(J) for J in J_list]
dataset = 'rmnist'
# Pass lim1=(bottom, top) to gen_plot to fix the y-axis range if needed.
gen_plot(model_list, param_list, dataset, label=J_list)

### RFMNIST

In [None]:
model_list = ['histindaeinb']
# NOTE(review): only J in {50, 30, 10} is plotted for RFMNIST. A list that also
# contained 200 and 100 used to be assigned here, but it was immediately
# overwritten. Confirm whether those runs exist for this dataset.
J_list = [50, 30, 10]
# Run-directory suffix per J; only the third field varies between runs.
param_list = [f'10_10_{J}_500' for J in J_list]
dataset = 'rfmnist'
# Pass lim1=(bottom, top) to gen_plot to fix the y-axis range if needed.
gen_plot(model_list, param_list, dataset, label=J_list)