In [None]:
# Widen the notebook's main container to the full browser width.
# `display` and `HTML` are imported from the public `IPython.display` module;
# the `IPython.core.display` path used previously is deprecated for these names.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
import torch
import datetime as dt
import os
import time
import copy
import gzip
import json
import imp

import data
import train
import stats
import plot

# Network training

In [None]:
from train import train_network, get_net_from_method

# Stats

In [None]:
from stats import calc_datapoint_statistics, calc_global_statistics

# Import data

In [None]:
from data import load_dataset, compute_idx_splits, scale_to_standard, get_dir_files, load_method_dict, compute_pca_projections

# Additional plotting

In [None]:
from plot import plot_results

# Main loop

In [None]:
""" Timestamp of the format: hour:minute:second """                  
def timestamp(dt_obj):
    """Format a datetime as ``year_month_day_hour_minute_second``.

    The six fields are joined with underscores and are NOT zero-padded
    (e.g. 2020-01-02 03:04:05 becomes ``"2020_1_2_3_4_5"``).

    Args:
        dt_obj: Object exposing integer ``year``, ``month``, ``day``, ``hour``,
            ``minute`` and ``second`` attributes (e.g. ``datetime.datetime``).

    Returns:
        str: The formatted timestamp.
    """
    return "%d_%d_%d_%d_%d_%d" % (dt_obj.year, dt_obj.month, dt_obj.day, dt_obj.hour, dt_obj.minute, dt_obj.second)

In [None]:
# Registry of every dataset identifier this notebook knows about.
# 'kin8nm' is listed here because it is part of `large_datasets` below; without
# it the size groups and the registry would disagree.
available_datasets = {'boston', 'concrete', 'energy', 'abalone', 'naval',
                      'power', 'protein', 'wine_red', 'yacht', 'year',
                      'california', 'diabetes', 'superconduct', 'kin8nm',
                      'toy_modulated', 'toy_hf'}

# Size groups; together small/large/very_large cover `available_datasets`.
toy_datasets = {'toy_hf','toy_modulated'}
small_datasets = {'toy_hf','yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red'}
large_datasets = {'toy_modulated', 'kin8nm', 'abalone', 'naval', 'power', 'superconduct', 'protein', 'california'}
very_large_datasets = {'year'}

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Split modes and uncertainty methods known to this notebook.
available_splits = {
    'random_folds', 'single_random_split', 'single_label_split',
    'label_folds', 'single_pca_split', 'pca_folds',
}
available_methods = {
    'vanilla', 'de', 'pu', 'mc_wd=0.000001', 'pu_de', 'mc_pu',
    'swag', 'evidential', 'concrete_dropout', 'new_wdrop_exact_l=5',
}

# Output location: one directory per run, tagged with identifier + start time.
dt_now = dt.datetime.now()
exp_ident = 'TEST'  # TODO: SPECIFY IDENTIFIER HERE
exp_dir = './experiment_results/%s_%s' % (exp_ident, timestamp(dt_now))
reuse_exp_dir = False

# Base parameters
n_output = 1

# Network parameters; copied per dataset before use.
net_params = dict(
    n_output=n_output,
    layer_width=100,
    num_layers=2,
    nonlinearity=torch.nn.ReLU(),  # tanh, sigmoid
    init_corrcoef=0.0,
    de_components=5,
)

# Training parameters; copied per dataset and partly overridden for large ones.
train_params = dict(
    device='cpu',  # torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    drop_bool=True,
    drop_bool_ll=True,
    drop_p=0.1,
    num_epochs=45,  # 00,
    batch_size=100,
    learning_rate=0.001,
    loss_func=torch.nn.MSELoss(reduction='mean'),
    weight_decay=0,
    loss_params=[5, 1, False],
)

# What to run.
use_splits = available_splits
datasets = [
    'yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone',
    'kin8nm', 'power', 'naval', 'california', 'superconduct', 'protein', 'year',
]
methods = [
    'swag', 'de', 'pu', 'evidential', 'pu_de', 'mc_pu',
    'mc_wd=0.000001', 'concrete_dropout', 'new_wdrop_exact_l=5',
]

# Train/validation/test proportions; the validation set is optional.
with_valset = False
if with_valset:
    folds_with_val, val_perc, train_perc = True, 0.2, 0.7
else:
    folds_with_val, val_perc, train_perc = False, 0., 0.8

# When True, statistics are additionally collected through the training
# epoch callback.
eval_epochs = False


def _local_stats(net, method, X_train, y_train, X_test, y_test, method_dict,
                 X_val=None, y_val=None, train_projections=None, test_projections=None, val_projections=None, epoch=None):
    """Compute per-datapoint statistics for one method and store them in ``method_dict``.

    With ``epoch`` given, the statistics are stored one level deeper, in
    ``method_dict[epoch]`` (created on demand). Otherwise
    ``method_dict[method]`` is set to ``[df_train, df_test]`` or, when the
    module-level ``with_valset`` flag is on, ``[df_train, df_val, df_test]``.

    The list passed as ``iso_reg`` to the train call (or to the validation
    call when ``with_valset`` is on) is the same object passed to the test
    call. Presumably ``calc_datapoint_statistics`` fills it on first use and
    reuses it afterwards -- confirm in ``stats.py``.

    Each non-``None`` ``*_projections`` argument is written to the
    ``'pca0_projection'`` column of the matching frame.

    Returns:
        None. Results are written into ``method_dict``.
    """
    if epoch is not None:
        # Redirect into the per-epoch sub-dictionary and compute there.
        per_epoch = method_dict.setdefault(epoch, dict())
        _local_stats(net, method, X_train, y_train, X_test, y_test, per_epoch,
                     X_val=X_val, y_val=y_val,
                     train_projections=train_projections,
                     test_projections=test_projections,
                     val_projections=val_projections,
                     epoch=None)
        return

    n_train = len(X_train)

    def _stats_for(features, targets, shared_iso_reg):
        # One statistics frame for a single data partition.
        return calc_datapoint_statistics(net=net, data=[features, targets], method=method,
                                         iso_reg=shared_iso_reg, train_N=n_train)

    shared_iso_reg = []
    frames = [_stats_for(X_train, y_train, shared_iso_reg)]
    partition_projections = [train_projections]

    if with_valset:
        # A fresh list for validation; the test call below then shares it.
        shared_iso_reg = []
        frames.append(_stats_for(X_val, y_val, shared_iso_reg))
        partition_projections.append(val_projections)

    frames.append(_stats_for(X_test, y_test, shared_iso_reg))
    partition_projections.append(test_projections)

    for frame, projection in zip(frames, partition_projections):
        if projection is not None:
            frame['pca0_projection'] = projection

    method_dict[method] = frames

def _global_stats(method_dict):
    """Aggregate per-datapoint frames into global statistics per method.

    Args:
        method_dict: Mapping ``method -> [df_train, (df_val,) df_test]`` as
            filled by ``_local_stats``.

    Returns:
        dict: ``method -> list`` holding one ``calc_global_statistics`` result
        per partition (three entries when ``with_valset`` is on, else two).
    """
    n_partitions = 3 if with_valset else 2
    return {
        name: [calc_global_statistics(frames[pos], n_output=n_output) for pos in range(n_partitions)]
        for name, frames in method_dict.items()
    }
            

def _method_dict_to_json(method_dict):
    """Convert every frame in ``method_dict`` to its JSON string.

    Args:
        method_dict: Mapping ``method -> list`` whose entries are either a
            single object with ``to_json()`` or a list of such objects.

    Returns:
        dict: Same keys and nesting, with each frame replaced by the result of
        its ``to_json()`` call.
    """
    def _encode(entry):
        # A list entry holds several frames; anything else is one frame.
        if isinstance(entry, list):
            return [frame.to_json() for frame in entry]
        return entry.to_json()

    return {
        name: [_encode(entry) for entry in entries]
        for name, entries in method_dict.items()
    }

def _read_splits_from_exp_dir(exp_dir, dataset_id):
    """Load the index splits stored by an earlier run of this notebook.

    Reads the gzip-compressed JSON ``data_dict`` files that ``get_dir_files``
    lists for ``dataset_id``.

    Args:
        exp_dir: Experiment directory to read from.
        dataset_id: Dataset whose stored splits are wanted.

    Returns:
        dict: ``split_mode -> list`` of per-fold tuples, ordered by sorted fold
        key. Each tuple is ``(train_idxs, test_idxs)`` or, when ``with_valset``
        is on, ``(train_idxs, val_idxs, test_idxs)``. The indices are whatever
        the JSON file holds.
    """
    stored_files = get_dir_files(exp_dir, dataset_id)['data_dict']

    if with_valset:
        idx_keys = ('train_idxs', 'val_idxs', 'test_idxs')
    else:
        idx_keys = ('train_idxs', 'test_idxs')

    splits = {}
    for split_mode, fold_files in stored_files.items():
        fold_splits = []
        for fold_idx in sorted(fold_files):
            with gzip.open(fold_files[fold_idx]) as f:
                stored = json.load(f)
            fold_splits.append(tuple(stored[key] for key in idx_keys))
        splits[split_mode] = fold_splits

    return splits

def _store_model(net_dict, net, method_identifier):
    """Save deep copies of a model's state dict(s) into ``net_dict``.

    Args:
        net_dict: Dictionary updated in place.
        net: A single object with ``state_dict()``, or a ``list`` of such
            objects.
        method_identifier: Key for a single network. For a list, member ``i``
            is stored under ``'<method_identifier>_sub=<i>'``.
    """
    if not isinstance(net, list):
        net_dict[method_identifier] = copy.deepcopy(net.state_dict())
        return

    for member_idx, member in enumerate(net):
        member_key = '%s_sub=%d' % (method_identifier, member_idx)
        net_dict[member_key] = copy.deepcopy(member.state_dict())
            
# When an existing experiment directory is reused and no explicit dataset list
# is given, fall back to every known dataset that has an entry in that directory.
if reuse_exp_dir and datasets is None:
    datasets = [ds for ds in os.listdir(exp_dir) if ds in available_datasets]
    print("Reusing datasets: ", datasets)

start_ = time.time()
for dataset_id in datasets:

    X, y = load_dataset(dataset_id)
    n_feat = X.shape[1]

    # Per-dataset copies so the overrides below never leak into the base
    # configuration used for the next dataset.
    net_params_ = dict(net_params)
    train_params_ = dict(train_params)

    if dataset_id in very_large_datasets:
        fold_idxs = [0, 3, 5, 7, 9]
        split_idxs = [spl for spl in use_splits if spl in ['random_folds', 'label_folds', 'pca_folds']]
        train_params_['num_epochs'] = 150
        train_params_['batch_size'] = 500

    elif dataset_id in large_datasets:
        fold_idxs = [0, 3, 5, 7, 9]
        split_idxs = [spl for spl in available_splits if spl in use_splits]
        train_params_['num_epochs'] = 150
    else:
        fold_idxs = list(range(10))
        split_idxs = [spl for spl in available_splits if spl in use_splits]

    # Try to reuse the splits stored by an earlier run; if this dataset has no
    # stored files, fall back to computing fresh splits.
    reuse_exp_dir_dataset = reuse_exp_dir
    if reuse_exp_dir_dataset:
        try:
            splits = _read_splits_from_exp_dir(exp_dir, dataset_id)
            projections = compute_pca_projections(X)
        except FileNotFoundError:
            reuse_exp_dir_dataset = False

    if not reuse_exp_dir_dataset:
        splits = compute_idx_splits(X, y, fold_idxs=fold_idxs, splits=split_idxs, train_perc=train_perc, val_perc=val_perc, folds_with_val=folds_with_val) # use 10-folds
        projections = splits['projections']

    for split_mode in splits:

        # 'projections' rides along in the splits dict but is not a split mode.
        if split_mode == 'projections':
            continue

        # A single split arrives as one tuple; wrap it so it is handled like a
        # one-element list of folds.
        folds = splits[split_mode]
        if isinstance(folds, tuple) and (len(folds) in [2, 3]):
            folds = [folds]

        # `fold_split` (not `split_idxs`) so the list of split modes computed
        # above is not shadowed by the per-fold index tuple.
        for fold_idx, fold_split in enumerate(folds):

            if with_valset:
                train_idxs, val_idxs, test_idxs = fold_split
            else:
                train_idxs, test_idxs = fold_split

            identifier = 'dataset=%s_splitmode=%s_foldidx=%d' % (dataset_id, split_mode, fold_idx)

            X_train = X[train_idxs]
            X_test = X[test_idxs]
            y_train = y[train_idxs]
            y_test = y[test_idxs]
            train_projections, test_projections = projections[train_idxs], projections[test_idxs]

            # Validation data stays None without a validation set, so it can be
            # forwarded unconditionally to `_local_stats` below.
            X_val = y_val = val_projections = None
            if with_valset:
                X_val = X[val_idxs]
                y_val = y[val_idxs]
                val_projections = projections[val_idxs]
                X_train, y_train, X_test, y_test, X_val, y_val, X_scaler, y_scaler = scale_to_standard(X_train, y_train, X_test, y_test, X_val, y_val)
            else:
                X_train, y_train, X_test, y_test, _, _, X_scaler, y_scaler = scale_to_standard(X_train, y_train, X_test, y_test)

            # choose a bunch of uncertainty methods and train the respective models
            method_dict = {}
            method_dict_epochs = {}
            net_dict = {}
            for method in methods:

                method_identifier = '%s_method=%s' % (identifier, method)
                print(method_identifier)

                net = get_net_from_method(method, n_feat, len(X_train), net_params_, train_params_, n_output=n_output)
                print(net_params_, train_params_)

                if eval_epochs:
                    # The callback forwards the validation data as well;
                    # without it `_local_stats` would receive X_val=None when
                    # `with_valset` is on.
                    train_network(net=net, data=[X_train, y_train], train_params=train_params_, method=method,
                                  epoch_callback=lambda net_, epoch: _local_stats(
                                      net_, method, X_train, y_train, X_test, y_test, method_dict_epochs,
                                      X_val=X_val, y_val=y_val,
                                      train_projections=train_projections, test_projections=test_projections,
                                      val_projections=val_projections, epoch=epoch))
                else:
                    train_network(net=net, data=[X_train, y_train], train_params=train_params_, method=method)

                _local_stats(net, method, X_train, y_train, X_test, y_test, method_dict,
                             X_val=X_val, y_val=y_val,
                             train_projections=train_projections, test_projections=test_projections,
                             val_projections=val_projections)

                _store_model(net_dict, net, method_identifier)

            exp_dataset_dir = '%s/%s' % (exp_dir, dataset_id)
            os.makedirs(exp_dataset_dir, exist_ok=True)

            # The split/scaler description is only written for fresh splits;
            # in reuse mode the stored file is kept.
            if not reuse_exp_dir_dataset:
                # 'y_mean' stores the target scaler's mean (it previously
                # duplicated `scale_`).
                data_dict = {'X_mean': X_scaler.mean_.tolist(), 'X_scale': X_scaler.scale_.tolist(), 'y_mean': y_scaler.mean_.tolist(), 'y_scale': y_scaler.scale_.tolist(),
                             'train_idxs': train_idxs.tolist(), 'test_idxs': test_idxs.tolist()}
                if with_valset:
                    data_dict['val_idxs'] = val_idxs.tolist()

                with gzip.open('%s/data_dict_%s.json.zip' % (exp_dataset_dir, identifier), 'wt', encoding='ascii') as fp:
                    json.dump(data_dict, fp)

            # In reuse mode, merge the new results into the stored ones; new
            # results win for methods present in both.
            if reuse_exp_dir_dataset:
                dir_files = get_dir_files(exp_dir, dataset_id)
                prev_method_dict = load_method_dict(dir_files, split_mode, folds=[fold_idx])[0]
                prev_method_dict.update(method_dict)
                method_dict = prev_method_dict

            method_dict_json = _method_dict_to_json(method_dict)
            with gzip.open('%s/method_dict_%s.json.zip' % (exp_dataset_dir, identifier), 'wt', encoding='ascii') as fp:
                json.dump(method_dict_json, fp)

            if n_output == 1:
                if not reuse_exp_dir_dataset:
                    plot_results(method_dict, '%s/%s.png' % (exp_dataset_dir, identifier), with_valset=with_valset)

            # store global statistics for the different methods (one entry per data partition)
            global_stats = _global_stats(method_dict)
            with gzip.open('%s/global_stats_%s.json.zip' % (exp_dataset_dir, identifier), 'wt', encoding='ascii') as fp:
                json.dump(global_stats, fp)

            if eval_epochs:
                if reuse_exp_dir_dataset:
                    dir_files = get_dir_files(exp_dir, dataset_id)
                    # NOTE(review): the non-epoch call above passes `split_mode`
                    # as the second argument and indexes the result with [0];
                    # this call passes `dataset_id` and uses the result directly.
                    # Confirm against `load_method_dict` which one is intended.
                    prev_method_dict_epochs = load_method_dict(dir_files, dataset_id, folds=[fold_idx], epochs=True)
                    for epoch in method_dict_epochs:
                        if epoch in prev_method_dict_epochs:
                            prev_method_dict_epochs[epoch].update(method_dict_epochs[epoch])
                        else:
                            prev_method_dict_epochs[epoch] = method_dict_epochs[epoch]
                    method_dict_epochs = prev_method_dict_epochs

                global_stats_epochs = {epoch: _global_stats(method_dict_epochs[epoch]) for epoch in method_dict_epochs}
                with gzip.open('%s/global_stats_epochs_%s.json.zip' % (exp_dataset_dir, identifier), 'wt', encoding='ascii') as fp:
                    json.dump(global_stats_epochs, fp)

            # In reuse mode, merge the new state dicts into the stored ones.
            # `torch.load` unpickles the file, so only load files written by
            # this notebook.
            if reuse_exp_dir_dataset:
                dir_files = get_dir_files(exp_dir, dataset_id)
                prev_net_dict = torch.load(dir_files['model'][split_mode][fold_idx])
                prev_net_dict.update(net_dict)
                net_dict = prev_net_dict

            torch.save(net_dict, '%s/model_%s.pt' % (exp_dataset_dir, identifier))

# Total wall-clock time of the whole run, in seconds.
print(time.time() - start_)