In [1]:
import sys, os, glob, re
import datetime
from collections import OrderedDict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sb
from torch.utils.data import DataLoader, TensorDataset
import torch
import wandb
import yaml

# The notebook runs from a sub-directory of the repository; its parent is the
# repository root, used both for imports and for locating wandb run files.
main_path = os.path.dirname(os.getcwd())
save_plots = False  # NOTE(review): not read by any cell shown here

# Register the repository root as an *absolute* path so a later os.chdir does
# not change what it resolves to; the membership check keeps re-runs of this
# cell from appending duplicates to sys.path.
if main_path not in sys.path:
    sys.path.append(main_path)
from src.dataset_large import ProtoPlanetaryDisks
from src.ae_model import ConvLinTrans_AE

In [2]:
def load_model_list(ID='zg3r4orb', device='cpu'):
    """Load a trained ConvLinTrans_AE checkpoint and its W&B run config.

    Searches ``<main_path>/wandb/run-*-<ID>/`` for a directory containing both
    ``model.pt`` and ``config.yaml`` and loads them from that same directory.

    Parameters
    ----------
    ID : str
        Weights & Biases run ID.
    device : str or torch.device
        Device the weights are mapped to and the model is moved to.

    Returns
    -------
    model : ConvLinTrans_AE
        Model with the checkpoint weights loaded, in eval mode, on ``device``.
    conf : dict
        The ``value`` of every config.yaml entry whose key does not contain
        'wandb', plus ``date`` (run timestamp parsed from the directory name,
        or '' when the name carries none) and ``ID``.

    Raises
    ------
    FileNotFoundError
        If no run directory holds both ``model.pt`` and ``config.yaml``.
    """
    # Pick model and config from the *same* run directory so they always match.
    candidates = sorted(glob.glob('%s/wandb/run-*-%s' % (main_path, ID)))
    run_dirs = [d for d in candidates
                if os.path.isfile(os.path.join(d, 'model.pt'))
                and os.path.isfile(os.path.join(d, 'config.yaml'))]
    if not run_dirs:
        raise FileNotFoundError(
            'No run directory with model.pt and config.yaml found for %s '
            'under %s/wandb' % (ID, main_path))
    run_dir = run_dirs[0]
    fname = os.path.join(run_dir, 'model.pt')
    config_f = os.path.join(run_dir, 'config.yaml')

    with open(config_f, 'r') as f:
        conf = yaml.safe_load(f)
    # Each entry is stored as {'value': ..., ...}; wandb bookkeeping keys are dropped.
    conf = {k: v['value'] for k, v in conf.items() if 'wandb' not in k}
    # Directories named run-<digits>_<digits>-<ID> carry the run timestamp;
    # downloaded copies (run--<ID>) do not.
    match = re.match(r'run-(\d+_\d+)-', os.path.basename(run_dir))
    conf['date'] = match.group(1) if match else ''
    conf['ID'] = ID

    print('Loading from... \n', fname)

    model = ConvLinTrans_AE(latent_dim=conf['latent_dim'],
                            img_dim=187,
                            in_ch=1,
                            kernel=conf['kernel_size'],
                            n_conv_blocks=conf['conv_blocks'])

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(fname, map_location=device)
    # Checkpoints saved from a wrapped model (e.g. nn.DataParallel) prefix the
    # keys with 'module.'; strip it so they match the bare model's keys.
    prefix = 'module.'
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items())
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    print('Is model in cuda? ', next(model.parameters()).is_cuda)

    return model, conf


def evaluate_encoder(model, dataloader, conf,
                     force=False, device='cpu'):
    """Run the encoder over ``dataloader`` and collect per-sample latent statistics.

    If cached ``latent_space_{mu,std,labels}.txt`` files exist under
    ``<main_path>/wandb/run--<ID>/`` and ``force`` is False, they are loaded
    instead of running the model.

    Parameters
    ----------
    model : torch.nn.Module
        Must expose ``encoder(data[, label=..., phy=...])`` returning a
        ``(mu, logvar)`` pair. NOTE(review): ConvLinTrans_AE is a plain AE —
        confirm its encoder really returns two tensors.
    dataloader : iterable
        Yields ``(data, label, onehot, pp)`` batches.
    conf : dict
        Run config. ``ID`` is required; ``label_dim`` / ``physics_dim``
        (treated as 0 when absent) choose how the encoder is conditioned.
    force : bool
        Re-run the encoder even when cached files exist.
    device : str or torch.device
        Device each batch is moved to.

    Returns
    -------
    mu_df, std_df : pandas.DataFrame
        One row per sample and one column per latent dimension, plus a
        ``class`` column holding the sample labels.

    Raises
    ------
    ValueError
        If ``label_dim`` or ``physics_dim`` is negative.
    """
    run_dir = '%s/wandb/run--%s' % (main_path, conf['ID'])
    fname_mu = '%s/latent_space_mu.txt' % run_dir
    fname_std = '%s/latent_space_std.txt' % run_dir
    fname_lbs = '%s/latent_space_labels.txt' % run_dir

    label_dim = conf.get('label_dim', 0)
    physics_dim = conf.get('physics_dim', 0)

    cached = all(os.path.exists(f) for f in (fname_mu, fname_std, fname_lbs))
    if cached and not force:
        print('Loading from files...')
        mu = np.loadtxt(fname_mu, ndmin=2)
        std = np.loadtxt(fname_std, ndmin=2)
        labels = np.loadtxt(fname_lbs, dtype=str, ndmin=1)

    else:
        # Validate once up front instead of silently reusing stale outputs.
        if label_dim < 0 or physics_dim < 0:
            raise ValueError('Check conditional dimension...')

        print('Evaluating Encoder...')
        time_start = datetime.datetime.now()

        mu, logvar, labels = [], [], []
        # Inference only: skip building autograd graphs for every batch.
        with torch.no_grad():
            for data, label, onehot, pp in dataloader:
                data = data.to(device)
                if label_dim > 0 and physics_dim > 0:
                    mu_, logvar_ = model.encoder(data,
                                                 label=onehot.to(device),
                                                 phy=pp.to(device))
                elif label_dim > 0:
                    mu_, logvar_ = model.encoder(data, label=onehot.to(device))
                else:
                    mu_, logvar_ = model.encoder(data)
                mu.extend(mu_.detach().cpu().numpy())
                logvar.extend(logvar_.detach().cpu().numpy())
                labels.extend(label)
                torch.cuda.empty_cache()
        mu = np.array(mu)
        std = np.exp(0.5 * np.array(logvar))

        # Caching is disabled; uncomment to persist results for the load branch.
        #np.savetxt(fname_mu, mu)
        #np.savetxt(fname_std, std)
        #np.savetxt(fname_lbs, np.asarray(labels), fmt='%s')
        elap_time = datetime.datetime.now() - time_start
        print('Elapsed time  : %.2f s' % (elap_time.total_seconds()))
        print('##'*20)

    mu_df = pd.DataFrame(mu)
    std_df = pd.DataFrame(std)

    mu_df['class'] = labels
    std_df['class'] = labels

    return mu_df, std_df


In [3]:
ID = '0ip71wfh'   # Weights & Biases run ID of the model to analyse
gpu = False       # set True to use cuda:0 when it is available
rnd_seed = 13     # seed for the train/val split and the global RNGs

# Seed the global RNGs so shuffled loaders and any random sampling reproduce
# on Restart & Run All.
np.random.seed(rnd_seed)
torch.manual_seed(rnd_seed)

In [4]:
run_dir = '%s/wandb/run--%s' % (main_path, ID)

# Fetch each run artefact that is not already present locally. Checking the
# files individually also repairs a partially completed earlier download.
missing_files = [f for f in ('model.pt', 'config.yaml')
                 if not os.path.exists('%s/%s' % (run_dir, f))]
if missing_files:
    print('Downloading files from Weight & Biases')

    api = wandb.Api()
    run = api.run('deep_ppd/PPD-AE/%s' % (ID))
    for f in missing_files:
        run.file(f).download(replace=True, root='%s/' % run_dir)

device = torch.device("cuda:0" if torch.cuda.is_available() and gpu else "cpu")

# Load Model 

In [5]:
# Pass the selected device so the `gpu` switch actually places the model on cuda.
model, config = load_model_list(ID=ID, device=device)
config

Loading from... 
 /home/jorgemarpa/Astro/PPDAE/wandb/run--0ip71wfh/model.pt
Is model in cuda?  False


{'batch_size': 64,
 'comment': 'overfitting',
 'cond': 'F',
 'conv_blocks': 4,
 'data': 'PPD',
 'dropout': 0.2,
 'dry_run': False,
 'early_stop': False,
 'kernel_size': 3,
 'latent_dim': 8,
 'lr': 0.001,
 'lr_sch': 'step',
 'machine': 'exalearn',
 'model_name': 'ConvLinTrans_AE',
 'n_train_params': 4609667,
 'num_epochs': 100,
 'physics_dim': 0,
 'rnd_seed': 13,
 'date': '',
 'ID': '0ip71wfh'}

# Load Dataset 

In [6]:
# NOTE(review): meaning of transform / img_norm is defined in
# src.dataset_large (not visible here) — confirm before changing them.
dataset = ProtoPlanetaryDisks(
    machine='exalearn',
    transform=False,
    img_norm=True,
)

In [7]:
# 80/20 train/validation split, reproducible through rnd_seed.
train_loader, val_loader = dataset.get_dataloader(
    random_seed=rnd_seed,
    val_split=.2,
    batch_size=128,
    shuffle=True,
)

In [54]:
# Unshuffled loaders over the held-out images and their parameters, built
# with identical settings (presumably aligned row-for-row — confirm).
img_test_loader, par_test_loader = [
    DataLoader(split, batch_size=100, drop_last=False)
    for split in (dataset.imgs_test, dataset.par_test)
]

In [9]:
tuple(len(loader) for loader in (train_loader, val_loader))

(500, 125)

In [None]:
# List the dataset's public attributes (e.g. imgs_test / par_test used above).
[attr for attr in dir(dataset) if not attr.startswith('_')]