In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline



import os, math
import joblib
from joblib import Parallel, delayed
from collections import Counter, defaultdict

import sys
# Put the notebook's parent directory on the import path (only once), presumably
# so the project-local `src` package imported below can be resolved.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)
    
import numpy as np
import pandas as pd
import cvxpy as cp
import networkx as nx
import torch

from sklearn.preprocessing import normalize
from sklearn.cluster import KMeans


import pytorch_lightning as pl
# Seed Python's `random`, NumPy and PyTorch in one call so the stochastic steps
# further down (training, k-means initialisation, random sampling) start from a
# fixed state. The top-level `pl.seed_everything` is used because the
# `pl.utilities.seed.seed_everything` path was deprecated and is not available
# in newer Lightning releases.
pl.seed_everything(seed=0)

import matplotlib
from matplotlib import rc, ticker
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter, FormatStrFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable

rc('text', usetex=True)
rc('font', family='serif')

# OWN MODULES
from src.data.data_module import UNOSDataModule, UKRegDataModule, UNOS2UKRegDataModule
from src.models.organsync import OrganSync_Network

In [None]:
from pathlib import Path

# DATA PARAMS
data = 'UKReg'
batch_size = 256
synth = False

root_data_dir = Path("datasets").absolute()

# LOAD DATA
# Each known dataset key maps to (project name, processed sub-directory,
# data-module class). Any other value of `data` falls back to the UKReg
# configuration.
dataset_options = {
    'UNOS': ('organsync-net', 'processed_UNOS', UNOSDataModule),
    'U2U': ('organsync-net-u2u', 'processed_UNOS2UKReg_no_split', UNOS2UKRegDataModule),
}
ukreg_options = ('organsync-net-ukreg', 'processed_UKReg', UKRegDataModule)

project, processed_subdir, data_module_cls = dataset_options.get(data, ukreg_options)
data_dir = root_data_dir / processed_subdir
dm = data_module_cls(data_dir, batch_size=batch_size, is_synth=synth)
dm.prepare_data()

# Both stages are set up here because later cells read the test and the
# training dataloaders.
for stage in ('test', 'fit'):
    dm.setup(stage=stage)



In [None]:
from pytorch_lightning import Trainer

# Prepare the data module and set up its 'fit' stage before the input width is
# read from it.
dm.prepare_data()
dm.setup(stage='fit')

# OPTIMISATION SETTINGS
lr = .005
gamma = 0.9
lambd = 0.5
weight_decay = 1e-3
epochs = 30
# NOTE(review): `dm` already exists at this point, so rebinding `batch_size`
# only changes this variable, not the data module -- confirm whether 128 was
# meant to reach the dataloaders.
batch_size = 128

# NETWORK SHAPE
num_hidden_layers = 1
hidden_dim = 16
output_dim = 8
activation_type = 'relu'
dropout_prob = 0.0

# Not referenced again in this cell.
control = False
is_synth = False
test_size = 0.05

# CONSTRUCT MODEL
input_dim = dm.size(1)
network_kwargs = dict(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    num_hidden_layers=num_hidden_layers,
    output_dim=output_dim,
    lr=lr,
    gamma=gamma,
    lambd=lambd,
    weight_decay=weight_decay,
    activation_type=activation_type,
    dropout_prob=dropout_prob,
)
model = OrganSync_Network(**network_kwargs)
# Cast the network's parameters to float64.
model = model.double()

# TRAIN NETWORK
trainer = Trainer(max_epochs=epochs, callbacks=[])
trainer.fit(model, datamodule=dm)

# TEST NETWORK
trainer.test(datamodule=dm)


In [None]:
# REPRESENTATION FROM DATA
resolution_k  = 10    # number of k-means clusters
n_per_cluster = 50    # patients drawn per cluster further down
n = 1000              # size of the random subset of U used further down

# NOTE(review): the training tensors are taken from `.dataset.dataset`, i.e. one
# level below the dataloader's dataset (presumably the dataset underneath a
# subset wrapper) -- confirm this is the intended set of rows.
X, O, Y, _ = dm.train_dataloader().dataset.dataset.tensors
X_t, O_t, Y_t, _ = dm.test_dataloader().dataset.tensors

# U: representations of the training tensors, u: representations of the test
# tensors. Gradients are not needed for either.
with torch.no_grad():
    U, u = (
        model.representation(torch.cat(features, dim=1))
        for features in ((X, O), (X_t, O_t))
    )

# Cluster the training representations; `fit` returns the fitted estimator.
cluster = KMeans(n_clusters=resolution_k).fit(U)

print('Size of c(U):', Counter(cluster.labels_))

In [None]:
# SELECT PATIENTS FROM CLUSTERS
# this is done on the test set
#
# Result: `patients`, an integer array of shape (resolution_k, n_per_cluster).
# Row `label` holds positions into the test representations `u` of patients
# that the fitted k-means model assigns to cluster `label`.

cluster_labels = np.arange(0, resolution_k, 1)
test_cluster_labels = cluster.predict(u)

# Rows are collected in a list and joined once at the end instead of re-copying
# the whole array with `np.append` on every iteration. The empty seed array
# keeps the result's shape and dtype well defined.
sampled_rows = [np.empty((0, n_per_cluster), dtype=int)]

for label in cluster_labels:
    patients_of_label = np.where(test_cluster_labels == label)[0]
    if len(patients_of_label) == 0:
        # Without this check `np.random.randint(0, 0, ...)` fails with an error
        # that does not say which cluster is the problem.
        raise ValueError(
            f'No test patient was assigned to cluster {label}; '
            'cannot sample patients for it (try a smaller resolution_k).'
        )

    # Drawn with replacement, so a cluster with fewer than `n_per_cluster` test
    # patients still yields a full row (with repeats).
    patients_in_label = patients_of_label[np.random.randint(0, len(patients_of_label), (n_per_cluster,))]

    sampled_rows.append(patients_in_label.reshape(1, -1))

patients = np.concatenate(sampled_rows, axis=0)

In [None]:
# PER CLUSTER BUILD u
# One (1, n_per_cluster, dim) slab of test representations per row of
# `patients`, joined along axis 0 onto an empty float array whose last
# dimension is model.output_dim.
per_cluster_slabs = [u[ps].view(1, n_per_cluster, -1) for ps in patients]
u_per_cluster = np.concatenate(
    [np.empty((0, n_per_cluster, model.output_dim)), *per_cluster_slabs],
    axis=0,
)


In [None]:
# a_s per model -> (resolution_k, n_per_cluster, n) => per (cluster, patient, a_s)
# Empty container for the weights; its last dimension is n, the subset size.
A_s = np.zeros((0, n_per_cluster, n))

# Random subset of n rows of U. `torch.randint` draws with replacement, so a
# row of U may appear more than once in `U_limited`.
U_limited_indices = torch.randint(low=0, high=U.shape[0], size=(n,))
U_limited = U.index_select(0, U_limited_indices)

# k-means cluster label of every row in the subset.
U_labels = cluster.predict(U_limited)

In [None]:
# PER CLUSTER COMPUTE a
# NOTE: this cell comprises the bulk of 
#   the computation; might run long.

lambda_ = .1

def convex_opt(u, U, lambd):
    """Express one representation as a convex combination of reference rows.

    Solves, with the SCS solver,

        minimise_a   ||a @ U - u||_2^2 + lambd * ||a||_1
        subject to   0 <= a <= 1,  sum(a) == 1

    NOTE(review): under these constraints ||a||_1 is always 1, so the L1 term
    only adds a constant to the objective.

    Args:
        u: target vector; its length must match the second dimension of `U`.
        U: 2-D array-like with one reference representation per row.
        lambd: weight of the L1 term.

    Returns:
        The value of `a` found by the solver, with one weight per row of `U`.

    Raises:
        RuntimeError: if the solver finishes without a value for `a`.
    """
    a = cp.Variable(U.shape[0])

    objective = cp.Minimize(cp.norm2(a@U - u)**2 + lambd * cp.norm1(a))
    constraints = [0 <= a, a <= 1, cp.sum(a) == 1]
    prob = cp.Problem(objective, constraints)

    _ = prob.solve(warm_start=True, solver=cp.SCS)

    if a.value is None:
        # Returning None would only surface later as an unrelated-looking
        # shape or type error in the code that consumes the weights.
        raise RuntimeError(
            f'convex_opt: solver returned no solution (status: {prob.status})'
        )

    return a.value

print('-- STARTING --')

# `A_s` is assembled from scratch here, so re-running this cell rebuilds the
# array instead of appending a second copy of every cluster to the result of
# the previous run. Final shape: (n_clusters, n_per_cluster, len(U_limited)).
weights_per_cluster = [np.empty((0, n_per_cluster, len(U_limited)))]

for i, u_s_in_cluster in enumerate(u_per_cluster):
    # One independent optimisation per sampled patient, spread over all cores.
    a = Parallel(n_jobs=joblib.cpu_count())(delayed(convex_opt)(u_, U_limited, lambda_) for u_ in u_s_in_cluster)
    weights_per_cluster.append(np.array(a).reshape(1, n_per_cluster, -1))
    print(f'---- finished cluster {i}')

A_s = np.concatenate(weights_per_cluster, axis=0)
print('-- FINISHED --')


In [None]:
# BUILD MATRIX FROM a
#    INFO: row i describes cluster i. Before normalisation, entry (i, j) counts
#      how many above-threshold weights of cluster i's sampled patients fall on
#      reference representations labelled j. After the row-wise L1
#      normalisation each row holds the corresponding shares.
threshold = 1e-2

# Zero-initialised on purpose: only the columns of labels that actually occur
# are written in the loop below, so every other entry must already be 0.
# (`np.empty` would leave arbitrary memory contents in those entries.)
M = np.zeros((resolution_k, resolution_k))

filtered = np.where(A_s >= threshold, A_s, np.zeros(A_s.shape))

# Cluster label of every reference row, repeated once per sampled patient so it
# lines up element-wise with one (n_per_cluster, n) slice of `filtered`.
# It does not depend on the cluster, so it is built once outside the loop.
sample_U = np.repeat(U_labels[np.newaxis, :], n_per_cluster, axis=0)

for i, r in enumerate(M):
    # Labels of all reference rows that received an above-threshold weight.
    sample = sample_U[filtered[i,:,:].astype(bool)]

    unique, counts = np.unique(sample, return_counts=True)
    label_distribution = dict(zip(unique, counts))

    M[i, list(label_distribution.keys())] = list(label_distribution.values())

M = normalize(M, axis=1, norm='l1')

In [None]:
# PLOT MATRIX (XO)
def plot_matrix(m, require_colorbar, title):
    """Draw a 2-D matrix as a heat-map and return the figure.

    Args:
        m: 2-D array-like to display; rows go on the y axis, columns on the x axis.
        require_colorbar: when truthy, a colourbar is attached to the right of
            the image with ticks every 0.1 from 0 up to (excluding) `m.max()`.
        title: text placed above the image.

    Returns:
        The matplotlib figure holding the plot.

    Side effects:
        Sets the global rc parameter `axes.linewidth` to 4.5.
    """
    rc('axes', linewidth= 4.5)

    m = np.asarray(m)
    # Tick positions come from the matrix being plotted, not from a notebook
    # global, so the function also works for matrices of another size.
    row_ticks = np.arange(m.shape[0])
    col_ticks = np.arange(m.shape[1])

    fig, ax = plt.subplots(figsize=(5,5))

    im = ax.imshow(m)

    # Positions are pinned with set_*ticks before the labels are set, so each
    # label stays attached to its own index. Fixed labels combined with a
    # non-fixed locator are assigned by tick order and can end up shifted.
    ax.set_yticks(row_ticks)
    ax.set_yticklabels(row_ticks, fontsize=25)
    ax.set_xticks(col_ticks)
    ax.set_xticklabels(col_ticks, fontsize=25)

    ax.set_ylabel('composed of',  fontsize=20)
    ax.set_xlabel('contributes to',  fontsize=20)

    ax.tick_params(length=10, width=2)

    ax.set_title(title, fontsize=25)

    if require_colorbar:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size='10%', pad=.2)

        # The tick range follows the plotted matrix `m`, so a matrix loaded
        # from disk gets ticks matching its own values.
        colorbar_ticks = np.arange(0, m.max(), .1)
        fig.colorbar(im, cax=cax, ticks=colorbar_ticks, format='%.1f')
        cax.tick_params(length=5, width=1, labelsize=15)

    return fig

# Plot the composition matrix `M`, titled with the dataset key, and keep the
# returned figure in `f`.
f = plot_matrix(M, True, data)

In [None]:
# SAVE FIGURE
# SAVE RESULTS
# Both outputs are written relative to the current working directory and are
# named after the dataset key. `np.save` appends the `.npy` extension to the
# matrix file name.
fig_detail=data


f.savefig(f'{fig_detail}_u_composition.pdf', bbox_inches = "tight")
np.save(f'{fig_detail}_M', M)



In [None]:
# Reload a previously saved matrix from `./{name}_M.npy` and plot it again
# under a different title.
# NOTE(review): the title always carries the '(semi-synth.)' suffix, whatever
# the data configuration is -- confirm this is intended.
fig_name=f'{data} (semi-synth.)'
name=f'{data}'

M_l = np.load(f'./{name}_M.npy')
f = plot_matrix(M_l, True, fig_name)
# NOTE(review): `name` is the dataset key, so this writes to
# '{data}_u_composition.pdf' -- the same path used when the figure for `M`
# is saved -- and replaces that file.
f.savefig(f'{name}_u_composition.pdf', bbox_inches = "tight")

In [None]:
# SINGLE EXAMPLE
# Decompose one test patient's representation into weighted training
# representations and list the rows that contribute.
threshold = 1e-2

# GET LATEST PATIENT (in regyr)
tmp = dm._test_processed.copy(deep=True)
tmp.loc[:,dm.real_cols] = dm.scaler.inverse_transform(tmp[dm.real_cols])
# Hard-coded position in the test set; the commented-out expression is the
# lookup the heading above refers to.
patient_index = 2266#tmp.regyr.argmax()

# NOTE(review): this rebinds the notebook-level names `X` and `O` to the
# features of a single patient.
X, O, _, d = dm.test_dataloader().dataset[patient_index]
row = dm._test_processed.iloc[[patient_index]].copy(deep=True)
row[dm.real_cols] = dm.scaler.inverse_transform(row[dm.real_cols])

with torch.no_grad():
    u_single = model.representation(torch.cat((X, O), dim=0).view(1, -1).double())

a = convex_opt(u_single.flatten(), U_limited, lambda_)
a_filtered = np.where(a >= threshold, a, np.zeros(a.shape))
a_filtered_indices = np.nonzero(a_filtered)[0]

# `a` has one weight per row of `U_limited`, and `U_limited` is
# `U[U_limited_indices]`. Positions in `a` therefore have to be translated back
# through `U_limited_indices` before they can be used to look up rows of the
# training table; using them directly would always pick from the first
# len(U_limited) training rows.
# Assumes `dm._train_processed` is row-aligned with the tensors `U` was
# computed from -- TODO confirm.
contributor_rows = np.asarray(U_limited_indices)[a_filtered_indices]

contributors = dm._train_processed.iloc[contributor_rows].copy(deep=True)
contributors.loc[:, dm.real_cols] = dm.scaler.inverse_transform(contributors[dm.real_cols])

contributors['contribution'] = a_filtered[a_filtered_indices]


# Exponentiate the two serum columns (presumably stored in log-space -- TODO
# confirm) so they can be read in their original units.
row.SERUM_BILIRUBIN = np.exp(row.SERUM_BILIRUBIN)
row.SERUM_CREATININE = np.exp(row.SERUM_CREATININE)

contributors.SERUM_BILIRUBIN = np.exp(contributors.SERUM_BILIRUBIN)
contributors.SERUM_CREATININE = np.exp(contributors.SERUM_CREATININE)

In [None]:
# EXAMPLE
# The selected test patient, with `dm.real_cols` mapped back through the scaler.
row

In [None]:
# CONTRIBUTORS
# Rows whose weight passed the threshold, largest contribution first.
contributors.sort_values(by='contribution', ascending=False)

In [None]:
# Contribution-weighted sum of the contributors' `Y` values (dot product of the
# two columns).
contributors.contribution.to_numpy() @ contributors.Y.to_numpy()

In [None]:
# Test table in original units: a deep copy of the processed test frame whose
# `dm.real_cols` are mapped back through the scaler.
D = dm._test_processed.copy(deep=True)
D[dm.real_cols] = dm.scaler.inverse_transform(D[dm.real_cols])

# Exponentiate the two serum columns (presumably stored in log-space -- TODO
# confirm).
for serum_col in ('SERUM_BILIRUBIN', 'SERUM_CREATININE'):
    D[serum_col] = np.exp(D[serum_col])

In [None]:
# Test rows with SERUM_SODIUM exactly 140 and regyr equal to 2020.
D[(D.SERUM_SODIUM == 140) & (D.regyr == 2020)]

In [None]:
# Look up the row whose index label is 18004 (`.loc` is label-based, not
# positional).
D.loc[18004]