In [110]:
import pandas as pd
from sklearn.preprocessing import QuantileTransformer

def count_cbc_cases(data):
    comp_data = data.query("~(WBC.isnull() & HGB.isnull() & MCV.isnull() & PLT.isnull() & RBC.isnull())",
                           engine='python')
    unique_data = comp_data.drop_duplicates(subset=["Id", "Center"])
    return len(unique_data)


def count_cbc(data):
    comp_data = data.query("~(WBC.isnull() & HGB.isnull() & MCV.isnull() & PLT.isnull() & RBC.isnull())",
                           engine='python')
    return len(comp_data)


class Features:
    def __init__(self, data):
        unique_data = data.drop_duplicates(subset=["Id", "Center", "Time"], keep=False)
        non_icu_unique_data = unique_data.query("~(Sender.str.contains('ICU')) & ~(~SecToIcu.isnull() & SecToIcu < 0)",
                                                engine='python')
        first_non_icu_unique_data = non_icu_unique_data.query("Episode == 1 ", engine='python')
        complete_first_non_icu_unique_data = first_non_icu_unique_data.query("~(WBC.isnull() | HGB.isnull() | "
                                                                             "MCV.isnull() | PLT.isnull() | "
                                                                             "RBC.isnull())", engine='python')
        sirs_complete_first_non_icu_unique_data = complete_first_non_icu_unique_data.query("Diagnosis != 'SIRS'",
                                                                                           engine='python')
        sirs_complete_first_non_icu_unique_data = \
            sirs_complete_first_non_icu_unique_data.query("(Diagnosis == 'Control') | ((Diagnosis == 'Sepsis') & ("
                                                          "~TargetIcu.isnull() & "
                                                          "TargetIcu.str.contains('MICU')))",
                                                                                           engine='python')
        self.data = sirs_complete_first_non_icu_unique_data
        self.data['Label'] = self.data['Diagnosis']
        self.data['W'] = self.data['Sex'] == "W"
        self.data['M'] = self.data['Sex'] == "M"

        control_filter = (self.data["Diagnosis"] == 'Control') | \
                         ((self.data["SecToIcu"] > 3600 * 6) & (
                                     ~self.data["TargetIcu"].isnull() & self.data["TargetIcu"]
                                     .str.contains('MICU', na=False)))
        sepsis_filter = (self.data["Diagnosis"] == 'Sepsis') & \
                        (self.data["SecToIcu"] <= 3600 * 6) & \
                        (self.data["TargetIcu"].str.contains('MICU', na=False))
        self.data.loc[control_filter, "Label"] = "Control"
        self.data.loc[sepsis_filter, "Label"] = "Sepsis"
        self.data["Label"] = self.data["Label"] == "Sepsis"

        self.control_data = self.data.loc[control_filter]
        self.sepsis_data = self.data.loc[sepsis_filter]

    def get_x(self):
        feature_columns = ["Age","Sex", "HGB", "PLT", "RBC", "WBC", "MCV"]
        return self.data.loc[:, feature_columns].replace(to_replace='W', value=1).replace(to_replace='M', value=0)#QuantileTransformer(n_quantiles=100).fit_transform(

    def get_y(self):
        return (self.data["Label"] == "Sepsis").astype(int) #self.data.loc[:, "Label"]

    def get_control_data(self):
        return self.control_data

    def get_sepsis_data(self):
        return self.sepsis_data

    def get_data(self):
        return self.data.sample(frac=1).reset_index()


In [111]:
class Training(Features):
    def __init__(self, data):
        leipzig_training_data = data.query("Center == 'Leipzig' & Set == 'Training'")
        Features.__init__(self, leipzig_training_data)


In [112]:
class Validation(Features):
    def __init__(self, data):
        leipzig_validation_data = data.query("Center == 'Leipzig' & Set == 'Validation'")
        Features.__init__(self, leipzig_validation_data)

In [113]:
class GreifswaldValidation(Features):
    def __init__(self, data):
        greifswald_validation_data = data.query("Center == 'Greifswald' & Set == 'Validation'")
        Features.__init__(self, greifswald_validation_data)

In [114]:
# TODO Implement greifswald validation
class DataAnalysis:
    def __init__(self, data):
        self.training = Training(data)
        print("Training: ")
        print(f"Assessable data are {count_cbc_cases(self.training.get_data())} cases "
              f"and {count_cbc(self.training.get_data())} CBCs")
        print(f"Control data are {count_cbc_cases(self.training.get_control_data())} cases "
              f"and {count_cbc(self.training.get_control_data())} CBCs")
        print(f"Sepsis data are {count_cbc_cases(self.training.get_sepsis_data())} cases "
              f"and {count_cbc(self.training.get_sepsis_data())} CBCs")
        print(20 * "$")
        print("Testing: ")
        self.validation = Validation(data)
        print(f"Controls: {self.validation.get_control_data().shape[0]},"
              f" Sepsis: {self.validation.get_sepsis_data().shape[0]}")
        print(f"Assessable data are {count_cbc_cases(self.validation.get_data())} cases "
              f"and {count_cbc(self.validation.get_data())} CBCs")
        print(f"Control data are {count_cbc_cases(self.validation.get_control_data())} cases "
              f"and {count_cbc(self.validation.get_control_data())} CBCs")
        print(f"Sepsis data are {count_cbc_cases(self.validation.get_sepsis_data())} cases "
              f"and {count_cbc(self.validation.get_sepsis_data())} CBCs")

        self.greifswald_vaidation = GreifswaldValidation(data)
        print(f"Controls: {self.greifswald_vaidation.get_control_data().shape[0]},"
              f" Sepsis: {self.greifswald_vaidation.get_sepsis_data().shape[0]}")
        print(f"Assessable data are {count_cbc_cases(self.greifswald_vaidation.get_data())} cases "
              f"and {count_cbc(self.greifswald_vaidation.get_data())} CBCs")
        print(f"Control data are {count_cbc_cases(self.validation.get_control_data())} cases "
              f"and {count_cbc(self.greifswald_vaidation.get_control_data())} CBCs")
        print(f"Sepsis data are {count_cbc_cases(self.greifswald_vaidation.get_sepsis_data())} cases "
              f"and {count_cbc(self.greifswald_vaidation.get_sepsis_data())} CBCs")

    def get_training_data(self):
        return self.training.get_data()

    def get_testing_data(self):
        return self.validation.get_data()


In [115]:
data = pd.read_csv(r"extdata/sbcdata.csv", header=0)
# data = pd.read_csv(r"extdata/sbc_small.csv", header=0)
data_analysis = DataAnalysis(data)

training = data_analysis.get_training_data()
testing = data_analysis.get_testing_data()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.data['Label'] = self.data['Diagnosis']
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.data['W'] = self.data['Sex'] == "W"
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.data['M'] = self.data['Sex'] == "M"
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .l

Training: 
Assessable data are 528101 cases and 1015074 CBCs
Control data are 527038 cases and 1013548 CBCs
Sepsis data are 1488 cases and 1526 CBCs
$$$$$$$$$$$$$$$$$$$$
Testing: 


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.data['Label'] = self.data['Diagnosis']
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.data['W'] = self.data['Sex'] == "W"
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.data['M'] = self.data['Sex'] == "M"
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .l

Controls: 365794, Sepsis: 490
Assessable data are 180494 cases and 366284 CBCs
Control data are 180157 cases and 365794 CBCs
Sepsis data are 472 cases and 490 CBCs


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.data['Label'] = self.data['Diagnosis']
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.data['W'] = self.data['Sex'] == "W"
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.data['M'] = self.data['Sex'] == "M"
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .l

Controls: 437629, Sepsis: 448
Assessable data are 157922 cases and 438077 CBCs
Control data are 180157 cases and 437629 CBCs
Sepsis data are 438 cases and 448 CBCs


## Constants

In [116]:
HGB_COLUMN_NAME = "HGB"
WBC_COLUMN_NAME = "WBC"
RBC_COLUMN_NAME = "RBC"
MCV_COLUMN_NAME = "MCV"
PLT_COLUMN_NAME = "PLT"

SEX_COLUMN_NAME = "Sex"
W_COLUMN_NAME = "W"
M_COLUMN_NAME = "M"
AGE_COLUMN_NAME = "Age"
PATIENT_NAME = "PATIENT"
EDGE_TYPE = "HAS"
LABEL_COLUMN_NAME = "Label"

## Functions for discretization of nodes

In [123]:
import torch
import numpy as np

def to_tensor(df):
    return torch.Tensor(list(df.values))

def get_quantil_tensor():
#     number_of_quantiles = 20
#     q = torch.arange(0, 1, 1/number_of_quantiles)
    q = torch.Tensor([0.025,0.05, 0.1, 0.2, 0.35, 0.5, 0.65, 0.8, 0.9, 0.95, 0.975, 1])
    return q

def get_quantiles(tensor):
    q = get_quantil_tensor() 
    return torch.quantile(tensor, q)

def normalize(tensor):
    mean = torch.mean(tensor, dim = 0)
    std = torch.std(tensor, dim = 0)
    mean_diff = tensor - mean
    return mean_diff / std

def get_quantile_indices(tensor, quantiles):
    quantile_indices = []
    all_indices = torch.Tensor([])
    prev_quantile = 0
    indices_control = torch.arange(0, tensor.shape[0])
    for i in range(quantiles.nelement()):
        indices_u = (tensor >= prev_quantile).nonzero(as_tuple=True)[0] # (tensor > prev_quantile and tensor <= quantiles[i]).nonzero(as_tuple=True)[0]
        indices_o = (tensor < quantiles[i]).nonzero(as_tuple=True)[0]
        indices = torch.from_numpy(np.intersect1d(indices_u, indices_o))
        quantile_indices.append(indices)
        prev_quantile = quantiles[i]
    final_indices = (tensor >= prev_quantile).nonzero(as_tuple=True)[0] 
    quantile_indices.append(final_indices)
    return quantile_indices


def create_node_features(node_type, quantiles):
    nodes_features = []
    prev_quantile = torch.Tensor([0])
    for i in range(quantiles.nelement()):
        node_features = [prev_quantile.item(), quantiles[i].item(), get_quantil_tensor()[i].item()]
        prev_quantile = quantiles[i]
        nodes_features.append(node_features)
    nodes_features.append([prev_quantile.item(), prev_quantile.item(), 1])
    return torch.tensor(nodes_features)

def create_edge_features_to_patient(node_type, quantile_indices):
    source_edge_list = None
    target_edge_list = None
    for i in range(len(quantile_indices)):
        target_edges = torch.ones((quantile_indices[i].nelement())) * i
        source_edges = quantile_indices[i]
        source_edge_list = source_edges if source_edge_list is None else torch.concat((source_edge_list, source_edges))
        target_edge_list = target_edges if target_edge_list is None else torch.concat((target_edge_list, target_edges))
    return torch.stack([source_edge_list, target_edge_list]).type(torch.long)

def add_features_and_edges(graph, node_type, dataframe, rus_indices):
    node_values = torch.Tensor(dataframe[node_type])[rus_indices] if rus_indices is not None else torch.Tensor(dataframe[node_type])
    node_quantiles = get_quantiles(node_values)
    quantile_indices = get_quantile_indices(node_values, node_quantiles)
    graph[node_type].x = create_node_features(node_type, node_quantiles)
    graph[PATIENT_NAME, EDGE_TYPE, node_type].edge_index = create_edge_features_to_patient(node_type, quantile_indices)

## Graph construction

In [129]:
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
from imblearn.under_sampling import RandomUnderSampler

graph = HeteroData()
rus = RandomUnderSampler(random_state=42)
training_x = training.loc[:, (AGE_COLUMN_NAME, W_COLUMN_NAME, M_COLUMN_NAME, HGB_COLUMN_NAME, RBC_COLUMN_NAME, WBC_COLUMN_NAME, PLT_COLUMN_NAME, MCV_COLUMN_NAME)]
training_y = training.loc[:, LABEL_COLUMN_NAME]
x_t, y_t = rus.fit_resample(training_x, training_y)
rus_indices = rus.sample_indices_

graph[PATIENT_NAME].x = to_tensor(training_x)[rus_indices]
add_features_and_edges(graph, WBC_COLUMN_NAME, training, rus_indices)
add_features_and_edges(graph, RBC_COLUMN_NAME, training, rus_indices)
add_features_and_edges(graph, PLT_COLUMN_NAME, training, rus_indices)
add_features_and_edges(graph, MCV_COLUMN_NAME, training, rus_indices)
add_features_and_edges(graph, HGB_COLUMN_NAME, training, rus_indices)
graph[PATIENT_NAME].y = to_tensor(training_y)[rus_indices]
labels = to_tensor(training_y)
print(labels.count_nonzero())
print(labels.shape)
print(training_y.to_numpy()[training_y.to_numpy().nonzero()].shape)

print(rus.sample_indices_.shape)

tensor(1526)
torch.Size([1015074])
(1526,)
(3052,)


  return torch.Tensor(list(df.values))


In [130]:
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T

test_graph = HeteroData()

test_graph[PATIENT_NAME].x = to_tensor(testing.loc[:, (AGE_COLUMN_NAME, W_COLUMN_NAME, M_COLUMN_NAME, HGB_COLUMN_NAME, RBC_COLUMN_NAME, WBC_COLUMN_NAME, PLT_COLUMN_NAME, MCV_COLUMN_NAME)])
add_features_and_edges(test_graph, WBC_COLUMN_NAME, testing, None)
add_features_and_edges(test_graph, RBC_COLUMN_NAME, testing, None)
add_features_and_edges(test_graph, PLT_COLUMN_NAME, testing, None)
add_features_and_edges(test_graph, MCV_COLUMN_NAME, testing, None)
add_features_and_edges(test_graph, HGB_COLUMN_NAME, testing, None)
test_graph[PATIENT_NAME].y = to_tensor(testing.loc[:, LABEL_COLUMN_NAME])

  return torch.Tensor(list(df.values))


In [131]:
transform = T.RandomNodeSplit(num_test=0, num_val=.2)
graph = transform(graph)
graph = T.ToUndirected()(graph)
test_graph = T.ToUndirected()(test_graph)

In [143]:
print(graph)
print(test_graph)

HeteroData(
  [1mPATIENT[0m={
    x=[3052, 8],
    y=[3052],
    train_mask=[3052],
    val_mask=[3052],
    test_mask=[3052]
  },
  [1mWBC[0m={ x=[13, 3] },
  [1mRBC[0m={ x=[13, 3] },
  [1mPLT[0m={ x=[13, 3] },
  [1mMCV[0m={ x=[13, 3] },
  [1mHGB[0m={ x=[13, 3] },
  [1m(PATIENT, HAS, WBC)[0m={ edge_index=[2, 3052] },
  [1m(PATIENT, HAS, RBC)[0m={ edge_index=[2, 3052] },
  [1m(PATIENT, HAS, PLT)[0m={ edge_index=[2, 3052] },
  [1m(PATIENT, HAS, MCV)[0m={ edge_index=[2, 3052] },
  [1m(PATIENT, HAS, HGB)[0m={ edge_index=[2, 3052] },
  [1m(WBC, rev_HAS, PATIENT)[0m={ edge_index=[2, 3052] },
  [1m(RBC, rev_HAS, PATIENT)[0m={ edge_index=[2, 3052] },
  [1m(PLT, rev_HAS, PATIENT)[0m={ edge_index=[2, 3052] },
  [1m(MCV, rev_HAS, PATIENT)[0m={ edge_index=[2, 3052] },
  [1m(HGB, rev_HAS, PATIENT)[0m={ edge_index=[2, 3052] }
)
HeteroData(
  [1mPATIENT[0m={
    x=[366284, 8],
    y=[366284]
  },
  [1mWBC[0m={ x=[13, 3] },
  [1mRBC[0m={ x=[13, 3] },
  [1mPLT[0

### Model and optimizer defintion

In [144]:
from torch_geometric.nn import GATConv, to_hetero,Linear
from torch_geometric.nn.conv import HANConv

from torchmetrics import AUROC

class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
#         self.conv1 = GATConv((-1,-1), 16, add_self_loops=False)
#         self.conv2 = GATConv((-1,-1), 8,  add_self_loops=False)
#         self.conv3 = GATConv((-1,-1), 2,  add_self_loops=False)
        self.conv1 = HANConv(-1, 32, graph.metadata())
        self.conv2 = HANConv(-1, 16, graph.metadata())
        self.conv3 = HANConv(-1, 8, graph.metadata())
        self.lin_end = Linear(8, 1)
#         self.batchnorm_1 = torch.nn.BatchNorm1d(16)
#         self.batchnorm_2 = torch.nn.BatchNorm1d(8)

    def forward(self, graph):
        x, edge_index = graph.x_dict, graph.edge_index_dict
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        x = torch.relu(x)
        x = self.conv3(x, edge_index)
        x = torch.relu(x)
        x = self.lin_end(x[PATIENT_NAME])
        return x # torch.log_softmax(x[PATIENT_NAME], dim=-1)#torch.log_softmax(x[PATIENT_NAME], dim=-1) ##TODO change this 
    

### Sample weight calulcation for loss

In [145]:
# sepsis_cases = torch.count_nonzero(graph[PATIENT_NAME].y)
# control_cases = graph[PATIENT_NAME].y.size(dim=0) - sepsis_cases
# control_weight = sepsis_cases / (1*(control_cases + sepsis_cases))
# sepsis_weight = control_cases*1 / (control_cases + sepsis_cases)
# class_weights = torch.tensor([control_weight, sepsis_weight ]) 

In [146]:
import torch_geometric.transforms as T
import torch.nn.functional as F
from torch_geometric.loader import NeighborLoader, ImbalancedSampler
import matplotlib.pyplot as plt

train_mask = graph[PATIENT_NAME].train_mask
val_mask = graph[PATIENT_NAME].val_mask
auroc_metric = AUROC(task="binary")
model = GNN()
model = to_hetero(model, graph.metadata(), aggr='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
epochs = range(10)

# all_weights = torch.clone(graph[PATIENT_NAME].y) * sepsis_weight + control_weight
loss_values = []

def train():
#     weight = torch.clone(graph[PATIENT_NAME].y) * torch.max(all_weights) + torch.min(all_weights)
    model.train()
    optimizer.zero_grad()
    out = model(graph)
#     class_weight = torch.clone(graph[PATIENT_NAME].y) * class_weights[1] + class_weights[0]
    loss = F.binary_cross_entropy_with_logits(torch.squeeze(out), graph[PATIENT_NAME].y.type(torch.float))
    print(loss.item())
    loss_values.append(loss.item())
    loss.backward()
    optimizer.step()
    
    
@torch.no_grad()
def validate_unseen():
    model.eval()
    pred = model(test_graph)
    pred_patient = torch.round(torch.sigmoid(torch.squeeze(pred)))
    print(pred_patient)
    auroc = auroc_metric(pred_patient, test_graph[PATIENT_NAME].y)
    print(f"AUROC: {auroc.item():.4f}")
    correct = (pred_patient == test_graph[PATIENT_NAME].y).sum()
    acc = int(correct) / int(test_graph[PATIENT_NAME].x.size(dim=0))
    print(f'Accuracy: {acc:.4f}')
    
for epoch in epochs:
    print(epoch)
    train()
    validate_unseen()
plt.plot(epochs, loss_values, 'g', label='Training loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
#     if epoch % 1 == 0:
#         test()
#     print("validate")
    

0


Traceback (most recent call last):
  File "/home/dwalke/.local/lib/python3.10/site-packages/torch/fx/graph_module.py", line 267, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/dwalke/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "<eval_with_key>.3", line 12, in forward
    getattr_1__PATIENT = graph__PATIENT.x_dict
AttributeError: 'NoneType' object has no attribute 'x_dict'

Call using an FX-traced Module, line 12 of the traced Module's generated forward function:
    graph__HGB = graph_dict.get('HGB', None);  graph_dict = None
    getattr_1__PATIENT = graph__PATIENT.x_dict

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    getattr_1__WBC = graph__WBC.x_dict

    getattr_1__RBC = graph__RBC.x_dict



AttributeError: 'NoneType' object has no attribute 'x_dict'

In [None]:
# import sklearn.model_selection.StratifiedKFold as skfold

# skf = skfold(n_splits=10)
# for i, (train_index, test_index) in enumerate()

In [28]:
import time
def sleeper(minutes):
    for i in range(minutes):
        time.sleep(60)
        print("Still sleeping and waiting for you so you dont have to reconnect everything")
sleeper(120)

Still sleeping and waiting for you so you dont have to reconnect everything
Still sleeping and waiting for you so you dont have to reconnect everything


KeyboardInterrupt: 