## Load data

In [1]:
import sys
 
# Prepend the directory two levels above the kernel's working directory to the
# module search path — presumably so the project-local `dataAnalysis` package
# imported in the following cells can be resolved (verify the repo layout).
sys.path.insert(0, "../../")

In [2]:
from dataAnalysis.DataAnalysis import DataAnalysis
import pandas as pd

# Read the CSV (first row is the header); the path is relative to the kernel's working directory.
data = pd.read_csv(r"../../extdata/sbcdata.csv", header=0)
# Wrap the raw frame in the project's DataAnalysis helper. Its internals are not visible here;
# the saved cell output suggests that construction also reports the case counts per partition.
data_analysis = DataAnalysis(data)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.data['Label'] = self.data['Diagnosis']


Training: 
Assessable data are 528101 cases and 1015074 CBCs
Control data are 527038 cases and 1013548 CBCs
Sepsis data are 1488 cases and 1526 CBCs
$$$$$$$$$$$$$$$$$$$$
Testing: 


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.data['Label'] = self.data['Diagnosis']


Controls: 365794, Sepsis: 490
Assessable data are 180494 cases and 366284 CBCs
Control data are 180157 cases and 365794 CBCs
Sepsis data are 472 cases and 490 CBCs


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.data['Label'] = self.data['Diagnosis']


Controls: 437629, Sepsis: 448
Assessable data are 157922 cases and 438077 CBCs
Control data are 180157 cases and 437629 CBCs
Sepsis data are 438 cases and 448 CBCs


In [3]:
import torch

# Convert every partition exposed by `data_analysis` into tensors:
# labels become int64 (torch.long), features become float32 (torch.float).
y_train = torch.tensor(data_analysis.get_y_train(), dtype=torch.long)
X_train = torch.tensor(data_analysis.get_X_train(), dtype=torch.float)

y_test = torch.tensor(data_analysis.get_y_test(), dtype=torch.long)
X_test = torch.tensor(data_analysis.get_X_test(), dtype=torch.float)

# The "gw" partition is a second held-out set provided by DataAnalysis; it is kept
# separate from the first test set throughout this notebook.
y_gw_test = torch.tensor(data_analysis.get_y_gw(), dtype=torch.long)
X_gw_test = torch.tensor(data_analysis.get_X_gw(), dtype=torch.float)

In [4]:
# Stack the three partitions row-wise in the fixed order train -> test -> gw test.
# The boolean masks defined in the next section are built in this same order.
y_all = torch.cat([y_train, y_test, y_gw_test], dim=0)
X_all = torch.cat([X_train, X_test, X_gw_test], dim=0)

## Train/Validation/Test splits

In [5]:
def true_indices_like(tensor):
    """Return a 1-D boolean mask of all ``True`` with one entry per row of ``tensor``."""
    n_rows = tensor.shape[0]
    return torch.ones(n_rows, dtype=torch.bool)

def false_indices_like(tensor):
    """Return a 1-D boolean mask of all ``False`` with one entry per row of ``tensor``."""
    n_rows = tensor.shape[0]
    return torch.zeros(n_rows, dtype=torch.bool)

def split(train_features, train_fraction=0.8):
    """Split the rows of ``train_features`` positionally into train and validation masks.

    The first ``round(n_rows * train_fraction)`` rows form the training part and the
    remaining rows the validation part. No shuffling is performed, so the split
    depends on the row order of the input.

    Args:
        train_features: Tensor whose first dimension indexes the samples.
        train_fraction: Share of rows assigned to training (default 0.8, the value
            that was previously hard-coded).

    Returns:
        dict with the boolean masks ``"train"`` and ``"val"``, each of length
        ``train_features.shape[0]``; the two masks are complementary.

    Raises:
        ValueError: If ``train_fraction`` lies outside ``[0, 1]``.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be within [0, 1], got {train_fraction}")

    n_rows = train_features.shape[0]
    max_index = round(n_rows * train_fraction)

    # Build the mask directly as bool instead of converting a float tensor afterwards.
    train = torch.zeros(n_rows, dtype=torch.bool)
    train[:max_index] = True

    return {
        "train": train,
        "val": ~train,  # everything that is not training is validation
    }
# Positional train/validation split of the training partition only.
train_data = split(X_train)

# Every mask covers all rows of the concatenated data (train, test, gw test — in
# that order) and is True only inside the partition it selects.
train_mask = torch.cat([train_data["train"], false_indices_like(X_test), false_indices_like(X_gw_test)])
val_mask = torch.cat([train_data["val"], false_indices_like(X_test), false_indices_like(X_gw_test)])
test_l_mask = torch.cat([false_indices_like(X_train), true_indices_like(X_test), false_indices_like(X_gw_test)])
test_gw_mask = torch.cat([false_indices_like(X_train), false_indices_like(X_test), true_indices_like(X_gw_test)])

## Graph construction

In [6]:
import torch
import numpy as np
from dataAnalysis.Constants import *

def to_tensor(df):
    """Convert the values of a pandas object into a float32 tensor.

    The underlying NumPy array is converted in one step. Wrapping it in ``list``
    first makes PyTorch build the tensor from a list of arrays, which is very slow
    for large frames. ``torch.tensor`` always copies, so the result never shares
    memory with ``df``.

    Args:
        df: Object exposing a numeric ``.values`` array (e.g. DataFrame or Series).

    Returns:
        ``torch.float32`` tensor with the same shape as ``df.values``.
    """
    return torch.tensor(df.values, dtype=torch.float32)

def get_quantil_tensor():
    """Return the fixed quantile levels used to bin a feature.

    The twelve levels are denser towards both tails (2.5 %, 5 %, ... 97.5 %) and end
    at 1.0, so the last bin is bounded by the maximum of the data.

    Returns:
        1-D ``torch.float32`` tensor with 12 ascending levels in ``(0, 1]``.
    """
    # The evenly spaced `torch.arange` grid that used to be computed here was
    # overwritten immediately and never used, so it has been removed.
    return torch.tensor(
        [0.025, 0.05, 0.1, 0.2, 0.35, 0.5, 0.65, 0.8, 0.9, 0.95, 0.975, 1],
        dtype=torch.float32,
    )

def get_quantiles(tensor):
    """Evaluate ``tensor`` at the quantile levels returned by ``get_quantil_tensor``.

    Returns one quantile value per level.
    """
    levels = get_quantil_tensor()
    return tensor.quantile(levels)

def normalize(tensor):
    """Standardize ``tensor`` along dimension 0 (zero mean, unit standard deviation).

    Columns whose standard deviation is zero or undefined (constant column, or a
    single row) are only centered. Dividing them by their standard deviation would
    produce NaN/inf values that propagate through every later computation.

    Args:
        tensor: Floating point tensor; statistics are taken over dimension 0.

    Returns:
        Tensor of the same shape as ``tensor``.
    """
    mean = torch.mean(tensor, dim = 0)
    std = torch.std(tensor, dim = 0)
    # `std > 0` is False for 0 and for NaN, so both cases fall back to a divisor of 1.
    safe_std = torch.where(std > 0, std, torch.ones_like(std))
    return (tensor - mean) / safe_std

def get_quantile_indices(tensor, quantiles):
    """Assign every element of a 1-D ``tensor`` to the quantile bin it falls into.

    Bin ``i`` contains the positions whose value lies in
    ``(quantiles[i-1], quantiles[i]]``. The first bin has no lower bound, so values
    at or below the first quantile are never dropped, whatever their sign.
    Values above the last quantile and NaNs belong to no bin.

    Args:
        tensor: 1-D tensor of feature values.
        quantiles: 1-D tensor of ascending upper bin bounds.

    Returns:
        List with one 1-D ``torch.long`` tensor of ascending positions per bin,
        located on the same device as ``tensor``.
    """
    quantile_indices = []
    lower = float("-inf")
    for upper in quantiles.reshape(-1):
        # One combined boolean mask replaces the former NumPy set intersection,
        # which required CPU tensors and an extra sort.
        in_bin = (tensor > lower) & (tensor <= upper)
        quantile_indices.append(in_bin.nonzero(as_tuple=True)[0])
        lower = upper
    return quantile_indices


def create_node_features(node_type, quantiles):
    """Build the feature matrix of the quantile nodes of one measurement.

    Each node corresponds to one bin and is described by three numbers:
    ``[lower bound, upper bound, quantile level]``. The lower bound of the first
    node is 0.

    Args:
        node_type: Name of the measurement; currently unused.
        quantiles: 1-D tensor with the upper bound of every bin. It must not have
            more entries than ``get_quantil_tensor()`` returns levels.

    Returns:
        Tensor with one row per bin and three columns.
    """
    # The levels do not change between iterations, so fetch them once instead of
    # rebuilding the tensor for every node.
    levels = get_quantil_tensor()

    nodes_features = []
    lower = 0.0
    for i in range(quantiles.nelement()):
        upper = quantiles[i].item()
        nodes_features.append([lower, upper, levels[i].item()])
        lower = upper
    return torch.tensor(nodes_features)

def create_edge_features_to_patient(node_type, quantile_indices):
    """Build the ``edge_index`` connecting patients to their quantile nodes.

    Row 0 holds the source (patient) positions and row 1 the target node index,
    i.e. the number ``i`` of the bin whose member list contains the patient.

    Everything is kept in ``torch.long``. Building the targets as float32 would
    promote the stacked tensor to float32 and round indices above 2**24.

    Args:
        node_type: Name of the measurement; currently unused.
        quantile_indices: Sequence of 1-D index tensors, one per bin.

    Returns:
        ``torch.long`` tensor of shape ``[2, num_edges]``; ``[2, 0]`` when no bins
        are given.
    """
    if len(quantile_indices) == 0:
        return torch.empty((2, 0), dtype=torch.long)

    sources = []
    targets = []
    for node_index, members in enumerate(quantile_indices):
        members = members.to(torch.long)
        sources.append(members)
        targets.append(
            torch.full((members.nelement(),), node_index, dtype=torch.long, device=members.device)
        )

    # Concatenate once at the end instead of growing the tensors inside the loop.
    return torch.stack([torch.cat(sources), torch.cat(targets)])

def add_features_and_edges(graph):
    """Attach quantile nodes and patient edges for the selected measurements.

    For each of the HGB, WBC, RBC, MCV and PLT columns of the patient feature
    matrix, the column is binned by its quantiles. The bins become nodes of a node
    type named after the column, and every patient is linked to the bin its value
    falls into. ``graph`` is modified in place; all other entries of ``FEATURES``
    are skipped.
    """
    selected = (HGB_COLUMN_NAME, WBC_COLUMN_NAME, RBC_COLUMN_NAME, MCV_COLUMN_NAME, PLT_COLUMN_NAME)
    for column, feature_name in enumerate(FEATURES):
        if feature_name in selected:
            # Column position in the patient matrix follows the order of FEATURES.
            values = graph[PATIENT_NAME].x[:, column]
            cut_points = get_quantiles(values)
            members = get_quantile_indices(values, cut_points)
            graph[feature_name].x = create_node_features(feature_name, cut_points)
            edge_key = (PATIENT_NAME, EDGE_TYPE, feature_name)
            graph[edge_key].edge_index = create_edge_features_to_patient(feature_name, members)

In [7]:
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
from imblearn.under_sampling import RandomUnderSampler

def normalize(tensor):
    """Standardize ``tensor`` along dimension 0 (zero mean, unit standard deviation).

    NOTE: this cell redefines ``normalize`` and therefore shadows the function of
    the same name defined in the graph-construction cell above.

    Columns whose standard deviation is zero or undefined (constant column, or a
    single row) are only centered. Dividing them by their standard deviation would
    produce NaN/inf values that propagate through every later computation.

    Args:
        tensor: Floating point tensor; statistics are taken over dimension 0.

    Returns:
        Tensor of the same shape as ``tensor``.
    """
    mean = torch.mean(tensor, dim = 0)
    std = torch.std(tensor, dim = 0)
    # `std > 0` is False for 0 and for NaN, so both cases fall back to a divisor of 1.
    safe_std = torch.where(std > 0, std, torch.ones_like(std))
    return (tensor - mean) / safe_std

# Assemble the heterogeneous graph: every row of X_all becomes one PATIENT node.
graph = HeteroData()
graph[PATIENT_NAME].x = X_all
# Quantile nodes and patient edges are derived from the raw (not yet normalized)
# patient features, so this call has to happen before the normalization below.
add_features_and_edges(graph)
graph[PATIENT_NAME].y = y_all
# Masks selecting the train / validation / test rows among the patient nodes.
graph[PATIENT_NAME].train_mask = train_mask
graph[PATIENT_NAME].val_mask = val_mask
graph[PATIENT_NAME].test_l_mask = test_l_mask
graph[PATIENT_NAME].test_gw_mask = test_gw_mask
# Add the reverse edge types (the `rev_HAS` relations visible in the printed graph).
graph = T.ToUndirected()(graph)
# Standardize the patient features only after the quantile nodes were built.
# NOTE(review): mean and std are computed over all rows, i.e. including the
# validation and test rows — confirm that this is intended.
graph[PATIENT_NAME].x = normalize(graph[PATIENT_NAME].x)

In [8]:
# Display the assembled graph: node types, edge types and tensor shapes.
graph

HeteroData(
  [1mPATIENT[0m={
    x=[1819435, 7],
    y=[1819435],
    train_mask=[1819435],
    val_mask=[1819435],
    test_l_mask=[1819435],
    test_gw_mask=[1819435]
  },
  [1mHGB[0m={ x=[12, 3] },
  [1mWBC[0m={ x=[12, 3] },
  [1mRBC[0m={ x=[12, 3] },
  [1mMCV[0m={ x=[12, 3] },
  [1mPLT[0m={ x=[12, 3] },
  [1m(PATIENT, HAS, HGB)[0m={ edge_index=[2, 1819435] },
  [1m(PATIENT, HAS, WBC)[0m={ edge_index=[2, 1819435] },
  [1m(PATIENT, HAS, RBC)[0m={ edge_index=[2, 1819435] },
  [1m(PATIENT, HAS, MCV)[0m={ edge_index=[2, 1819435] },
  [1m(PATIENT, HAS, PLT)[0m={ edge_index=[2, 1819435] },
  [1m(HGB, rev_HAS, PATIENT)[0m={ edge_index=[2, 1819435] },
  [1m(WBC, rev_HAS, PATIENT)[0m={ edge_index=[2, 1819435] },
  [1m(RBC, rev_HAS, PATIENT)[0m={ edge_index=[2, 1819435] },
  [1m(MCV, rev_HAS, PATIENT)[0m={ edge_index=[2, 1819435] },
  [1m(PLT, rev_HAS, PATIENT)[0m={ edge_index=[2, 1819435] }
)

### Model definition

In [10]:
import torch.nn.functional as F
from torch_geometric.nn import HANConv, HGTConv, FiLMConv, Linear
import torch
from dataAnalysis.Constants import FEATURES


class HetGraphNeuralNetwork(torch.nn.Module):
    """Heterogeneous GNN: per-type linear projection, two HANConv layers, linear head.

    Every node type is first projected to ``hidden_dim`` by its own linear layer
    (followed by ReLU). The two ``HANConv`` layers then operate on the whole
    graph, and the final linear layer is applied to the PATIENT nodes only.

    Args:
        metadata: Graph metadata as returned by ``HeteroData.metadata()``.
        node_types: Names of the node types that receive an input projection.
        hidden_dim: Width of the hidden representations.
        out_channels: Number of outputs per PATIENT node.
    """

    def __init__(self, metadata,node_types, hidden_dim = 128, out_channels = 1):
        super().__init__()

        # One input projection per node type; -1 lets the layer infer its input
        # size on the first forward pass.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in node_types:
            self.lin_dict[node_type] = Linear(-1, hidden_dim)
        self.lin_end = Linear(-1, out_channels)
        self.conv1 = HANConv(-1,hidden_dim, metadata)
        self.conv_end = HANConv(hidden_dim, hidden_dim, metadata)


    def forward(self, x_dict, edge_index):
        """Compute the outputs for all node types.

        Args:
            x_dict: Mapping node type -> feature matrix. It is not modified.
            edge_index: Mapping edge type -> edge index (an ``edge_index_dict``).

        Returns:
            Mapping node type -> tensor. The PATIENT entry holds the outputs of the
            final linear layer, the other entries the output of the last
            convolution.
        """
        # Project into a new dict so the caller's `x_dict` is left untouched.
        hidden = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }
        hidden = self.conv1(hidden, edge_index)
        hidden = self.conv_end(hidden, edge_index)
        # PATIENT_NAME is not imported in this cell; it comes from the wildcard
        # import of dataAnalysis.Constants in the graph-construction cell.
        hidden[PATIENT_NAME] = self.lin_end(hidden[PATIENT_NAME])

        return hidden

## Shift data to device

In [11]:
# Prefer the second GPU as before, but fall back to the first one on single-GPU
# machines, where "cuda:1" is not a valid device, and to the CPU without CUDA.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

graph = graph.to(device)

# Class weight for the positive (sepsis) class: ratio of controls to sepsis cases
# within the training rows. The small constant avoids a division by zero.
sepsis_cases = torch.count_nonzero(graph[PATIENT_NAME].y[train_mask])
control_cases = graph[PATIENT_NAME].y[train_mask].size(dim=0) - sepsis_cases
WEIGHT = (control_cases / (sepsis_cases + 1e-10))
WEIGHT = WEIGHT.to(device)

print("Data shifted to the device " + str(device))



Data shifted to the device cuda:1


## Model-Wrapper Class

In [12]:
import torch 
from torch_geometric.nn import to_hetero
from tqdm.notebook import tqdm

class ModelWrapper():
    """Train a ``HetGraphNeuralNetwork`` on a graph with early stopping.

    The class relies on notebook globals defined in earlier cells: ``device``,
    ``train_mask``, ``val_mask``, ``WEIGHT``, ``PATIENT_NAME`` and ``F``
    (``torch.nn.functional``).

    Training runs full-batch for at most ``MAX_EPOCHS`` epochs. It stops early once
    the validation loss has increased ``BREAKING_THRESHOLD`` times in a row, each
    time compared with the previous epoch.
    """

    def __init__(self, graph):
        """Create the model and optimizer for ``graph`` and reset the loss history."""
        self.LEARNING_RATE = 3e-4
        self.MAX_EPOCHS = 10000

        self.graph = graph
        model = HetGraphNeuralNetwork(graph.metadata(),graph.node_types, hidden_dim = 64, out_channels=1)
        self.model = model.to(device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.LEARNING_RATE,betas=(0.9, 0.999), eps=1e-08)

        # Start at +inf so the very first validation loss is not counted as an
        # increase (with a start value of 0 every positive loss would be).
        self.last_loss = float("inf")
        self.increased_loss = 0
        self.BREAKING_THRESHOLD = 10
        self.val_loss = []
        self.train_loss = []

    def _record_val_loss(self, loss_value):
        """Store a validation loss and update the count of consecutive increases."""
        self.val_loss.append(loss_value)
        if loss_value > self.last_loss:
            self.increased_loss += 1
        else:
            self.increased_loss = 0
        self.last_loss = loss_value

    def validate(self):
        """Compute the weighted BCE loss on the validation rows and record it."""
        with torch.inference_mode():
            self.model.eval()
            out = self.model(self.graph.x_dict, self.graph.edge_index_dict)[PATIENT_NAME]
            # squeeze(-1) drops only the output-channel dimension, so a single
            # selected row still yields a 1-D tensor that matches the target shape.
            loss = F.binary_cross_entropy_with_logits(out[val_mask].squeeze(-1), self.graph[PATIENT_NAME].y[val_mask].type(torch.float32),
                                                      pos_weight=WEIGHT)
            self._record_val_loss(loss.item())

    def train(self):
        """Run full-batch training until MAX_EPOCHS or until early stopping triggers."""
        for epoch in tqdm(range(self.MAX_EPOCHS)):
            self.model.train()
            self.optimizer.zero_grad()
            out = self.model(self.graph.x_dict, self.graph.edge_index_dict)[PATIENT_NAME]
            loss = F.binary_cross_entropy_with_logits(out[train_mask].squeeze(-1), self.graph[PATIENT_NAME].y[train_mask].type(torch.float32),
                                                      pos_weight=WEIGHT)
            self.train_loss.append(loss.item())
            loss.backward()
            self.optimizer.step()
            self.validate()

            if self.increased_loss >= self.BREAKING_THRESHOLD:
                print(f"Breaked at {str(epoch)}")
                break


    def get_model(self):
        """Return the wrapped model."""
        return self.model

    def plot_loss(self):
        """Plot the recorded training and validation losses per epoch."""
        # matplotlib is not imported anywhere in the visible notebook, so import it
        # here; the epoch axis is derived from the recorded history because no
        # `epochs` attribute is ever set on the instance.
        import matplotlib.pyplot as plt

        epochs = range(1, len(self.train_loss) + 1)
        plt.plot(epochs, self.train_loss, 'g', label='Training loss')
        plt.plot(epochs, self.val_loss, 'y', label='Validation loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.show()

In [None]:
# Train one model on the graph and report the wall-clock training time in seconds.
model_wrapper = ModelWrapper(graph)
import time 
# NOTE(review): the saved output of this cell shows "HGT-Conv", so it appears to
# stem from an earlier version of the cell than the label printed here.
print("HAN-Conv")
start = time.time()
model_wrapper.train()
print(time.time() - start )
# Keep the trained network in the global `model`; the prediction helpers in the
# evaluation cells read this global.
model = model_wrapper.get_model()

HGT-Conv


  0%|          | 0/10000 [00:00<?, ?it/s]

In [None]:
# Display the graph again after training (structure and tensor shapes).
graph

## Error evaluation

In [26]:
import time

from dataAnalysis.Metrics import Evaluation


def predict_proba(graph, mask):
    """Return class probabilities of shape ``[n_selected, 2]`` for the masked patients.

    Column 0 is the probability of the negative class, column 1 that of the
    positive class. Uses the global ``model``.
    """
    with torch.inference_mode():
        model.eval()
        logits = model(graph.x_dict, graph.edge_index_dict)[PATIENT_NAME]
        # reshape(-1) keeps a 1-D tensor even when the mask selects a single row.
        scores = torch.sigmoid(logits[mask].reshape(-1))
    return torch.stack((1 - scores, scores), dim=1)


def predict(graph, mask):
    """Return the rounded positive-class probability (0. or 1.) per masked patient."""
    pred = torch.round(predict_proba(graph, mask)[:, 1])
    return pred


# Train `number_of_iterations` independent models and collect their metrics.
# The helpers and the import above are defined once instead of on every iteration.
number_of_iterations = 100
dataframes = []
for i in range(number_of_iterations):
    # The previous iteration moved the graph to the CPU for evaluation.
    graph = graph.to(device)

    start = time.time()
    model_wrapper = ModelWrapper(graph)
    model_wrapper.train()
    print(time.time() - start)

    # Rebind the global `model` that the prediction helpers read.
    model = model_wrapper.get_model()
    model.predict_proba = predict_proba
    model.predict = predict

    graph = graph.cpu()
    model = model.cpu()
    evaluation = Evaluation(y_test, y_gw_test, X_test, X_gw_test)
    evaluation.set_test_args([graph, test_l_mask])
    evaluation.set_gw_args([graph, test_gw_mask])

    df = evaluation.get_df_metrics(model)
    dataframes.append(df)
for df in dataframes:
    print(df)

  0%|          | 0/1 [00:00<?, ?it/s]

1.4528789520263672


KeyboardInterrupt: 

## Evaluation

In [None]:
def predict_proba(graph, mask):
    """Return class probabilities for the patients selected by ``mask``.

    Uses the global ``model``, which is switched to evaluation mode.

    Args:
        graph: Graph providing ``x_dict`` and ``edge_index_dict``.
        mask: Boolean mask over the patient nodes.

    Returns:
        Tensor of shape ``[n_selected, 2]``; column 0 is the probability of the
        negative class, column 1 that of the positive class.
    """
    with torch.inference_mode():
        model.eval()
        logits = model(graph.x_dict, graph.edge_index_dict)[PATIENT_NAME]
        # reshape(-1) keeps a 1-D tensor even when the mask selects a single row;
        # squeeze() would collapse that case to a 0-d tensor and break the result shape.
        scores = torch.sigmoid(logits[mask].reshape(-1))
    return torch.stack((1 - scores, scores), dim=1)

def predict(graph, mask):
    """Return hard predictions: the rounded positive-class probability per masked patient."""
    positive_scores = predict_proba(graph, mask)[:, 1]
    return positive_scores.round()

# Attach the helpers as plain instance attributes (not bound methods), so they are
# called as model.predict(graph, mask) without an implicit `self`.
model.predict_proba = predict_proba
model.predict = predict

In [None]:
from dataAnalysis.Metrics import Evaluation

# Move graph and model to the CPU before evaluating.
graph = graph.cpu()
model = model.cpu()
evaluation = Evaluation(y_test, y_gw_test, X_test, X_gw_test)
# Arguments for the two test partitions — presumably forwarded to
# model.predict / model.predict_proba; verify against dataAnalysis.Metrics.Evaluation.
evaluation.set_test_args([graph, test_l_mask])
evaluation.set_gw_args([graph, test_gw_mask])

In [None]:
# Plot the confusion matrix, then display the metrics table as the cell output.
evaluation.plot_confusion_matrix(model)
evaluation.get_df_metrics(model)