In [None]:
# (Re)load the autoreload extension and switch it to mode 2, so edits to
# imported project modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline



import os, math
import joblib
from joblib import Parallel, delayed
import sys
from pathlib import Path

# Put the directory three levels above the current working directory on
# sys.path -- presumably so the project-local packages imported further down
# (`organsync`, `experiments`) resolve when the notebook runs from its folder.
dir_path = Path.cwd().absolute()
module_path = str(dir_path.parent.parent.parent)

# Append only once, so re-running the cell does not grow sys.path.
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)
    
import numpy as np
import pandas as pd
import cvxpy as cp
import torch

import pytorch_lightning as pl
import wandb

import matplotlib
from matplotlib import rc
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from pathlib import Path

# OWN MODULES
from organsync.models.organsync_network import OrganSync_Network

from experiments.data.utils import get_data_tuples
from experiments.data.data_module import UNOSDataModule, UKRegDataModule, UNOS2UKRegDataModule


# Global matplotlib style: typeset figure text with LaTeX, using a serif font.
rc('text', usetex=True)
rc('font', family='serif')
# Seed torch's global RNG; the torch.randint draws in later cells use it, so
# the sampled rows depend on this seed and on the order the cells are run in.
torch.manual_seed(0)

In [None]:
# Experiment configuration: which dataset to load, the loader batch size, and
# the value forwarded to each data module as `is_synth`.
data = 'U2U'
batch_size = 256
synth=True

# Pick the data module, its data directory, and the `project` name that the
# checkpoint-restore cell later interpolates into the W&B run path.
if data == 'UNOS':
    project = 'organsync-net'
    data_dir = '../data/processed'
    dm = UNOSDataModule(data_dir, batch_size=batch_size, is_synth=synth)
    # NOTE(review): unlike the other two branches, this one never calls
    # dm.prepare_data() -- confirm UNOSDataModule does not need it before setup().
elif data == 'U2U':
    project = 'organsync-net-u2u'
    data_dir = '../data/processed_UNOS2UKReg_no_split'
    dm = UNOS2UKRegDataModule(data_dir, batch_size=batch_size, is_synth=synth, control=False)
    dm.prepare_data()
else:
    # Catch-all: any value other than 'UNOS' / 'U2U' (typos included) selects
    # the UK registry data module.
    project = 'organsync-net-ukreg'
    data_dir = '../data/processed_UKReg/clinical_ukeld_2_ukeld'
    dm = UKRegDataModule(data_dir, batch_size=batch_size, is_synth=synth)
    dm.prepare_data()
    

In [None]:
# Build the fit and test splits.
dm.setup(stage='fit')
dm.setup(stage='test')

# Row counts, read from the second tensor of the dataset behind each loader.
# NOTE(review): train and val both go through `.dataset.dataset`; if those
# loaders wrap subsets of one shared dataset, that dataset is counted twice
# here -- confirm against the data module.
n_train_rows = dm.train_dataloader().dataset.dataset.tensors[1].size(0)
n_val_rows = dm.val_dataloader().dataset.dataset.tensors[1].size(0)
n_test_rows = dm.test_dataloader().dataset.tensors[1].size(0)

# Displayed as the cell output.
n_train_rows + n_val_rows + n_test_rows

In [None]:
# Show the data module's `_train_processed` attribute as the cell output.
dm._train_processed

In [None]:
# W&B run id of the trained network to load.
model_id = '30iz1r27'

# Fetch the checkpoint file from the W&B run, overwriting any local copy.
# `wandb` comes from the notebook's import cell; without that import this
# line raises NameError on a fresh kernel.
params = wandb.restore('organsync_net.ckpt.ckpt', run_path=f'jeroenbe/{project}/{model_id}', replace=True)
if params is None:
    # wandb.restore is documented to return None when the file cannot be
    # found; fail here with a clear message instead of an AttributeError on
    # `params.name` below.
    raise FileNotFoundError(
        f"Checkpoint 'organsync_net.ckpt.ckpt' not found in W&B run 'jeroenbe/{project}/{model_id}'"
    )

# Rebuild the network from the downloaded checkpoint and cast it to float64,
# matching the `.double()` inputs built in the later cells.
model = OrganSync_Network.load_from_checkpoint(params.name).double()
trainer = pl.Trainer()
#trainer.test(model, datamodule=dm)

In [None]:
# Run both setup stages again (an earlier cell makes the same two calls in the
# opposite order).
# NOTE(review): if setup() re-splits the data randomly, this changes the
# splits -- confirm whether this cell is still needed.
dm.setup(stage='test')
dm.setup(stage='fit')

In [None]:
# Size of the dataset underlying the training loader.
train_length = len(dm.train_dataloader().dataset.dataset)

# Draw 500 row indices uniformly at random (with replacement) and fetch the
# corresponding (X, O, Y, delt) tensors.
# NOTE(review): `.dataset.dataset` bypasses the loader's own dataset wrapper;
# if that wrapper is a train/val subset, validation rows can be drawn too.
train_pool = dm.train_dataloader().dataset.dataset
sampled_rows = torch.randint(0, train_length, (500,))
X, O, Y, delt = train_pool[sampled_rows]

# The network takes X and O concatenated along the feature axis, as float64.
catted = torch.cat((X, O), dim=1).double()

# Representations of the sampled training pairs; no gradients are needed.
with torch.no_grad():
    U = model.representation(catted)

In [None]:
# Number of test pairs to reconstruct.
n = 50
test_length = len(dm.test_dataloader().dataset)

with torch.no_grad():
    # Draw n test rows uniformly at random (with replacement).
    test_pool = dm.test_dataloader().dataset
    sampled_test_rows = torch.randint(0, test_length, (n,))
    X_new, O_new, Y_, delt_ = test_pool[sampled_test_rows]

    # Same input layout as for the training pairs: [X | O] as float64.
    new_pair = torch.cat((X_new, O_new), dim=1).double()
    u = model.representation(new_pair)  # synth_u
    

In [None]:
# Eleven evenly spaced penalty weights covering [0, 0.2].
lambd = np.linspace(start=0, stop=0.2, num=11)
# Sweep results, keyed by a 'lambda: <value>' string.
result = {}


def convex_opt(u, lambd):
    """Reconstruct one representation as a convex combination of the rows of ``U``.

    Solves, with the SCS solver::

        minimise   ||a @ U - u||_2^2 + lambd * ||a||_1
        subject to 0 <= a <= 1,  sum(a) == 1

    Reads the notebook globals ``U`` (reference representations, one per row)
    and ``Y`` (the outcomes belonging to those rows).

    NOTE(review): under these constraints ``a`` is non-negative and sums to
    one, so ``||a||_1`` is always exactly 1 and the penalty only adds the
    constant ``lambd`` to the objective -- confirm the formulation is intended.

    Args:
        u: the representation vector to reconstruct.
        lambd: weight of the L1 term.

    Returns:
        A tuple ``(a, y_hat)``: the solved weight vector (one weight per row
        of ``U``) and the scalar ``a @ Y``.

    Raises:
        RuntimeError: if the solver finishes without producing a value for ``a``.
    """
    a = cp.Variable(U.shape[0])

    objective = cp.Minimize(cp.norm2(a@U - u)**2 + lambd * cp.norm1(a))
    constraints = [0 <= a, a <= 1, cp.sum(a) == 1]
    prob = cp.Problem(objective, constraints)

    prob.solve(warm_start=True, solver=cp.SCS)

    # cvxpy leaves `a.value` as None when no solution was found (e.g. an
    # infeasible status); report that explicitly instead of failing with an
    # opaque TypeError on the matrix product below.
    if a.value is None:
        raise RuntimeError(f'convex_opt: solver returned no solution (status: {prob.status})')

    return a.value, (a.value @ Y.numpy()).item()

# For every penalty weight, solve one reconstruction problem per test
# representation, fanned out over all available CPU cores.
n_workers = joblib.cpu_count()
for lam in lambd:
    jobs = (delayed(convex_opt)(u_, lam) for u_ in u)
    res = Parallel(n_jobs=n_workers)(jobs)
    # One list of (weights, prediction) tuples per penalty weight.
    result[f'lambda: {lam}'] = res

In [None]:
RES = result  # alias (same dict object) used by the aggregation cell below

In [None]:
# Summarise the sweep: for every penalty weight, compute the prediction error
# and how many reference rows carry weight in the solved combinations.

# A weight counts as "active" when it exceeds this threshold.
NONZERO_TOL = 1e-5
# Number of largest weights per query kept for the top-k summary.
TOP_K = 50

# Flatten the test outcomes so the subtraction below is element-wise. If Y_
# has shape (n, 1), subtracting it from the (n,) predictions would silently
# broadcast to an (n, n) matrix of all pairwise differences.
y_true = Y_.numpy().reshape(-1)
# Outcomes mapped back through the data module's std/mean.
y_true_rescaled = y_true * dm.std + dm.mean

rmse_means, rmse_stds = [], []
diff_means, diff_stds = [], []
active_means, active_stds = [], []
top_means, top_stds = [], []

for key in RES.keys():
    # RES[key] is a list of (weights, prediction) tuples, one per test query.
    a_s = np.array([pair[0] for pair in RES[key]])
    y_ = np.array([pair[1] for pair in RES[key]], dtype=float)
    y_diff = y_ * dm.std + dm.mean

    # Per-sample absolute error, normalised and rescaled. The `rmse` names
    # are kept for the cells below; these are absolute errors, not a root
    # of a mean square.
    rmse = np.abs(y_ - y_true)
    rmse_diff = np.abs(y_diff - y_true_rescaled)

    rmse_means.append(rmse.mean())
    rmse_stds.append(rmse.std())

    diff_means.append(rmse_diff.mean())
    # Spread of this lambda's per-sample errors. The original took the std
    # of the running array of means (`synth_diff`) here instead.
    diff_stds.append(rmse_diff.std())

    # Number of weights above the threshold, per test query.
    active_counts = np.count_nonzero(a_s > NONZERO_TOL, axis=1)
    active_means.append(active_counts.mean())
    active_stds.append(active_counts.std())

    # Mean/std over the TOP_K largest weights of every query, pooled together.
    top_k = np.partition(a_s, -TOP_K, axis=1)[:, -TOP_K:]
    top_means.append(top_k.mean())
    top_stds.append(top_k.std())

# Same names the printing and plotting cells read, one entry per lambda.
synth_rmse = np.array(rmse_means)
synth_std = np.array(rmse_stds)
synth_diff = np.array(diff_means)
synth_diff_std = np.array(diff_stds)
avg_a = np.array(active_means)
std_a = np.array(active_stds)
avg_top_a = np.array(top_means)
std_top_a = np.array(top_stds)
    
    

In [None]:
# Mean error and its standard error per lambda, then the mean number of
# active weights and its standard error (std / sqrt(n)).
root_n = np.sqrt(n)
for row in (synth_rmse, synth_std / root_n, '---', avg_a, std_a / root_n):
    print(row)

In [None]:
def my_formatter(x, pos):
    """Tick formatter that drops the leading zero of fractional values.

    Values are rendered with the ``g`` format. When the absolute value lies
    strictly between 0 and 1 and the rendered text starts with ``0.`` (or
    ``-0.``), that leading zero is removed: 0.7 -> ``.7``, -0.4 -> ``-.4``.
    Everything else -- 0, values of magnitude >= 1, and fractions that ``g``
    renders in scientific notation such as ``1e-05`` -- is returned unchanged.

    Args:
        x: the tick value.
        pos: the tick position; part of the FuncFormatter callback signature
            and unused here.

    Returns:
        The formatted tick label.
    """
    val_str = '{:g}'.format(x)
    if 0 < abs(x) < 1:
        # Strip only a genuine leading "0."; removing the first "0" anywhere
        # in the string would corrupt exponent forms ("1e-05" -> "1e-5").
        if val_str.startswith('0.'):
            return val_str[1:]
        if val_str.startswith('-0.'):
            return '-' + val_str[2:]
    return val_str

# Formatter that renders x tick values without their leading zero.
major_formatter = FuncFormatter(my_formatter)


# FIGURE: error (left axis, solid blue) and number of active weights (right
# axis, dashed green) against the penalty weight, with standard-error bars.
fig, ax1 = plt.subplots(figsize=(8, 6))
ax2 = ax1.twinx()

line_kwargs = dict(capsize=5, linewidth=5, marker='D', elinewidth=1.5, markersize=15)

blue = '#0066CC'
green = '#00FF80'

# Standard error of the mean over the n test queries.
root_n = np.sqrt(n)
ax1.errorbar(lambd, synth_rmse, synth_std / root_n, color=blue, **line_kwargs)
ax2.errorbar(lambd, avg_a, std_a / root_n, color=green, linestyle='--', **line_kwargs)

# Left axis. The range and ticks are hard-coded for this particular run.
rmse_ticks = [0, 5e-5, 1e-4]
ax1.set_ylim([0, 1e-04])
ax1.set_yticks(rmse_ticks)
ax1.set_yticklabels(rmse_ticks, fontsize=25, color=blue)
ax1.set_ylabel('n-RMSE', fontsize=30)
ax1.set_xlabel(r'$\lambda$', fontsize=30)
ax1.set_xticks(lambd)
ax1.set_xticklabels(lambd, fontsize=25)
# Installed after set_xticklabels, so this formatter decides the label text.
ax1.xaxis.set_major_formatter(major_formatter)
ax1.set_title('UNOS (semi-synth.)', fontsize=35)

# Right axis. Ticks are hard-coded for this particular run.
count_ticks = np.linspace(250, 400, 3)
ax2.set_ylabel(r'$|\mathbf{a}_\mathtt{>1e-5}|$', fontsize=30)
ax2.set_yticks(count_ticks)
ax2.set_yticklabels(count_ticks.astype(int), fontsize=25, color='#00CC66')

# Written to the working directory, named after the selected dataset.
fig.savefig(f'{data}_synth_a.pdf', bbox_inches="tight")