# ST Pathology GNN Network

This notebook contains the full analysis of the ST Pathology framework. We start with the module that predicts the Heidelberg classifier subgroups from spatially resolved transcriptomics using 3-hop subgraphs (with 1- and 2-hop reductions for comparison).


In [None]:
## Install PyTorch Geometric
# The running torch version is exported as the TORCH environment variable so the shell can
# expand ${TORCH} in the wheel URLs below (torch-scatter / torch-sparse builds must match torch).
# PyG itself is installed from the GitHub main branch.
# NOTE(review): nothing is version-pinned, so re-runs may pick up different PyG releases.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

The input data are saved on Google Drive, so we first mount the drive.

In [None]:
# Mount Google Drive so the data, scripts and models under "/content/drive/My Drive/GIN_Train" are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
## Import the network architectures
# Make the project scripts stored on Drive importable, then pull in the training helpers
# (presumably RunTrainingGIN comes from GIN_ClassPred_V2_All_Functions), reduceNN and runQC.
# NOTE(review): star imports hide where each name comes from; prefer explicit imports.
import sys
sys.path.append('/content/drive/My Drive/GIN_Train/Script')

from GIN_ClassPred_V2_All_Functions import *
from reduceNN import *
from runQC import *


## Import data and run quality control

In [None]:
# Load the pickled training subgraphs.
# weights_only=False: since torch 2.6, torch.load defaults to weights_only=True, which rejects
# pickled non-tensor objects (the graph objects stored here presumably are such). Only load
# trusted files this way, because unpickling can execute arbitrary code.
graph_train = torch.load("/content/drive/My Drive/GIN_Train/Data/Graph_Class_train.pt", weights_only=False)

In [None]:
# Quality-control filter from the star-imported runQC module; its return value is used directly as the graph list.
# NOTE(review): a later cell redefines runQC to return (graphs, indices). Re-running this cell after
# that one would bind graph_train to a tuple.
graph_train = runQC(graph_train)

In [None]:
# Build reduced-neighbourhood variants of the training subgraphs with the project helper reduceNN
# (presumably keeps only nodes within `hop` hops of the centre spot; confirm in reduceNN.py).
graph_NN1 = reduceNN(graph_train, hop=1)
graph_NN2 = reduceNN(graph_train, hop=2)
# The QC-filtered subgraphs are used as the 3-hop set. This is an alias, not a copy.
graph_NN3 = graph_train

In [None]:
# Train one GIN classifier per neighbourhood size (1-, 2-, 3-hop) with identical hyper-parameters.
# Each model is saved to Drive as soon as it finishes training.
def _train_and_save_gin(train_graphs, save_path):
  """Train a GIN classifier on `train_graphs`, pickle the whole model to `save_path`, and return it."""
  trained = RunTrainingGIN(train_graphs, num_classes=11, epochs=200, batch_size=1500)
  torch.save(trained, save_path)
  return trained

model_NN1 = _train_and_save_gin(graph_NN1, '/content/drive/My Drive/GIN_Train/Model/Class_model_NN1.pth')
model_NN2 = _train_and_save_gin(graph_NN2, '/content/drive/My Drive/GIN_Train/Model/Class_model_NN2.pth')
model_NN3 = _train_and_save_gin(graph_NN3, '/content/drive/My Drive/GIN_Train/Model/Class_model_NN3.pth')

In [None]:
# Validation
# Load the validation subgraphs and derive the same 1/2/3-hop variants as for training.
# weights_only=False: torch>=2.6 defaults to weights_only=True, which rejects pickled non-tensor
# objects such as the stored graphs. Only use it with trusted files (unpickling can run code).
graph_val = torch.load("/content/drive/My Drive/GIN_Train/Data/Graph_Class_val.pt", weights_only=False)
graph_val = runQC(graph_val)
graph_val_NN1 = reduceNN(graph_val, hop=1)
graph_val_NN2 = reduceNN(graph_val, hop=2)
graph_val_NN3 = graph_val


## Run Validation

In [None]:
import os
import torch
import numpy as np
import pandas as pd
import networkx as nx
import torch
import torch_geometric.utils as utils
import matplotlib as PL
from tqdm import tqdm
import sklearn
from sklearn import preprocessing
import matplotlib.pyplot as plt

import torch
from torch.nn import BatchNorm1d, Linear
from torch_geometric.nn import GATConv
from torch_geometric.data import DataLoader
from torch_geometric.utils import add_self_loops
from torch_geometric.data import Data
import torch.nn.functional as F
from torch_geometric.nn import GINConv
from torch.nn import Linear
import torch.optim as optim
from torch_geometric.nn import global_mean_pool
import torch.nn as nn

def RunEvaluationGINClass1(graph, model):
  """Run a trained GIN classifier over each subgraph (one at a time) and collect its outputs.

  Args:
    graph: iterable of graph objects supporting ``.to(device)`` (PyG ``Data`` presumably).
    model: trained module whose forward returns ``(latent, class_out)``.

  Returns:
    Tuple ``(latent_space, logits, predicted)`` of numpy arrays:
      latent_space: ``latent`` averaged over dim 0 per subgraph, stacked along axis 0.
      logits: all ``class_out`` arrays concatenated along axis 0.
      predicted: arg-max over dim 1 of each ``class_out``, concatenated.
  """
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print("Running on:", device)

  model.eval()
  model.to(device)

  latent_space = []
  class_out_logits = []
  class_out_list = []

  # Inference only: no_grad avoids building an autograd graph for every subgraph,
  # which saves memory and time. Outputs were detached anyway.
  with torch.no_grad():
    for data in tqdm(graph, desc="Eval"):
      # NOTE(review): PyG's Data.to() presumably moves the stored graph in place,
      # so the input list may end up on `device`.
      latent, class_out = model(data.to(device))

      ## Latent space
      latent_space.append(latent.mean(dim=0, keepdim=True).cpu().numpy())

      ## Status
      class_out_logits.append(class_out.cpu().numpy())
      class_out_list.append(torch.argmax(class_out, dim=1).cpu().numpy())

  return(np.concatenate(latent_space), np.concatenate(class_out_logits), np.concatenate(class_out_list))

In [None]:
# Peek at one validation subgraph.
# A bare `data` here raised NameError on a fresh kernel, because `data` only exists
# as a loop variable inside the evaluation functions.
graph_val_NN1[0]

In [None]:
# Evaluate each GIN model on the validation subgraphs reduced to the matching neighbourhood size.
val_NN1, val_NN2, val_NN3 = [
    RunEvaluationGINClass1(val_graphs, trained_model)
    for val_graphs, trained_model in ((graph_val_NN1, model_NN1),
                                      (graph_val_NN2, model_NN2),
                                      (graph_val_NN3, model_NN3))
]

In [None]:
# Ground-truth class of every validation subgraph, in the order the evaluation loop visits them.
gt = [subgraph.Class.detach().cpu().numpy() for subgraph in tqdm(graph_val_NN1)]

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix

# Per-subgraph validation metrics for the 1-hop GIN (macro = unweighted mean over classes).
labels = np.array(gt)
predicted = val_NN1[2]

print("Accuracy:", accuracy_score(labels, predicted))
for metric_title, metric_fn in (("Precision:", precision_score),
                                ("Recall:", recall_score),
                                ("F1 Score:", f1_score)):
    print(metric_title, metric_fn(labels, predicted, average='macro'))
print("Confusion Matrix:\n", confusion_matrix(labels, predicted))

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Confusion matrix of the 1-hop GIN on the validation subgraphs.
# labels/predicted are recomputed here, so the figure does not depend on which metrics cell ran last.
labels = np.array(gt)
predicted = val_NN1[2]
cm = confusion_matrix(labels, predicted)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='g', cmap='Blues', ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix: True vs Predicted (GIN, 1-hop)')
plt.show()

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix

# Per-subgraph validation metrics for the 2-hop GIN (macro = unweighted mean over classes).
labels = np.array(gt)
predicted = val_NN2[2]

print("Accuracy:", accuracy_score(labels, predicted))
for metric_title, metric_fn in (("Precision:", precision_score),
                                ("Recall:", recall_score),
                                ("F1 Score:", f1_score)):
    print(metric_title, metric_fn(labels, predicted, average='macro'))
print("Confusion Matrix:\n", confusion_matrix(labels, predicted))

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Confusion matrix of the 2-hop GIN on the validation subgraphs.
# labels/predicted are recomputed here, so the figure does not depend on which metrics cell ran last.
labels = np.array(gt)
predicted = val_NN2[2]
cm = confusion_matrix(labels, predicted)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='g', cmap='Blues', ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix: True vs Predicted (GIN, 2-hop)')
plt.show()

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix

# Per-subgraph validation metrics for the 3-hop GIN (macro = unweighted mean over classes).
labels = np.array(gt)
predicted = val_NN3[2]

print("Accuracy:", accuracy_score(labels, predicted))
for metric_title, metric_fn in (("Precision:", precision_score),
                                ("Recall:", recall_score),
                                ("F1 Score:", f1_score)):
    print(metric_title, metric_fn(labels, predicted, average='macro'))
print("Confusion Matrix:\n", confusion_matrix(labels, predicted))

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Confusion matrix of the 3-hop GIN on the validation subgraphs.
# labels/predicted are recomputed here, so the figure does not depend on which metrics cell ran last.
labels = np.array(gt)
predicted = val_NN3[2]
cm = confusion_matrix(labels, predicted)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='g', cmap='Blues', ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix: True vs Predicted (GIN, 3-hop)')
plt.show()

## Run Prediction from Gene Expression Alone (Linear Baseline without Graph Neighbourhood)

In [None]:
## Linear Network

class LinearExp(torch.nn.Module):
    """Graph-free baseline: a one-hidden-layer MLP applied to the expression matrix ``data.y``.

    Args:
        num_features_exp: length of each expression vector (``data.y.shape[1]``).
        hidden_channels: width of the hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, num_features_exp, hidden_channels, num_classes):
        super().__init__()
        # Linear -> ReLU -> Dropout(0.5) -> Linear. Keep this order: saved weights are keyed by
        # position (mlp_class.0 / mlp_class.3).
        self.mlp_class = torch.nn.Sequential(
            torch.nn.Linear(num_features_exp, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(hidden_channels, num_classes),
        )

    def forward(self, data):
        """Return class logits computed from ``data.y`` only; edges and node features are ignored."""
        return self.mlp_class(data.y)


def RunTrainingLinear(graph, hidden_channels = 256, num_classes=11, epochs = 50,learning_rate = 0.001, batch_size=32):
  """Train the expression-only LinearExp baseline with Adam and cross-entropy.

  Args:
    graph: non-empty list of graph objects with ``y`` (2-D expression features) and ``Class`` (label).
    hidden_channels: hidden width of the MLP.
    num_classes: number of output classes.
    epochs: number of passes over ``graph``.
    learning_rate: Adam learning rate.
    batch_size: graphs per mini-batch.

  Returns:
    The trained LinearExp model, left on the training device. Also shows a scatter plot of the
    mean training loss per epoch.

  Raises:
    ValueError: if ``graph`` is empty.
  """
  if len(graph) == 0:
    raise ValueError("RunTrainingLinear needs at least one graph")

  # Feature width is read from the first graph (previously graph[1], which required >= 2 graphs).
  num_features_exp = graph[0].y.shape[1]

  model = LinearExp(num_features_exp, hidden_channels, num_classes=num_classes)
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)

  model.train()
  loader = DataLoader(graph, batch_size=batch_size, shuffle=True)

  criterion = torch.nn.CrossEntropyLoss()

  epoch_loss_list = []

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print("Running on:", device)
  model = model.to(device)

  for epoch in tqdm(range(epochs), desc="Training"):
    # Reset every epoch. Previously the running sum was initialised once, so each epoch's value
    # also contained the previous epoch's mean.
    epoch_loss = 0.0
    for data in loader:
        optimizer.zero_grad()
        class_out = model(data.to(device))

        # Class labels are stored as floats; CrossEntropyLoss needs integer class ids.
        loss = criterion(class_out, data.Class.long())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    epoch_loss /= len(loader)
    epoch_loss_list.append(epoch_loss)

  import matplotlib.pyplot as plt
  plt.close()
  plt.scatter(range(len(epoch_loss_list)), epoch_loss_list)
  plt.xlabel("Epoch")
  plt.ylabel("Mean training loss")
  plt.show()
  plt.close()

  return(model)



In [None]:
# Train the expression-only linear baseline on the (unreduced) training subgraphs.
model_lin = RunTrainingLinear(graph_train, epochs=50, num_classes=11, batch_size=2000)

In [None]:
def RunEvaluationGINClassLin(graph, model):
  """Run the LinearExp baseline over each subgraph (one at a time) and collect its outputs.

  Args:
    graph: iterable of graph objects with a ``y`` attribute and ``.to(device)``.
    model: trained LinearExp (its forward returns class logits only).

  Returns:
    Tuple ``(logits, predicted)``: concatenated logits and the arg-max class id of each logit row.
  """
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print("Running on:", device)

  model.eval()
  model.to(device)

  class_out_logits = []
  class_out_list = []

  # Inference only: skip autograd bookkeeping.
  with torch.no_grad():
    for data in tqdm(graph, desc="Eval"):
      class_out = model(data.to(device))

      ## Status
      class_out_logits.append(class_out.cpu().numpy())
      class_out_list.append(torch.argmax(class_out, dim=1).cpu().numpy())

  return(np.concatenate(class_out_logits), np.concatenate(class_out_list))

In [None]:
# Evaluate the linear baseline on the full (unreduced) validation subgraphs -> (logits, predicted ids).
val_lin = RunEvaluationGINClassLin(graph_val, model_lin)

In [None]:
 # Validation metrics for the expression-only linear baseline (val_lin was computed on graph_val).
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

predicted = val_lin[1]
labels = np.array([subgraph.Class.detach().cpu().numpy() for subgraph in graph_val])

print("Accuracy:", accuracy_score(labels, predicted))
print("Precision:", precision_score(labels, predicted, average='macro'))
print("Recall:", recall_score(labels, predicted, average='macro'))
print("F1 Score:", f1_score(labels, predicted, average='macro'))
print("Confusion Matrix:\n", confusion_matrix(labels, predicted))

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Confusion matrix of the expression-only linear baseline. Previously this reused `labels` and
# `predicted` left over from the 3-hop GIN metrics cell, so it plotted the wrong model.
# val_lin was computed on graph_val, so the ground truth is read from the same list.
labels = np.array([subgraph.Class.detach().cpu().numpy() for subgraph in graph_val])
predicted = val_lin[1]
cm = confusion_matrix(labels, predicted)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='g', cmap='Blues', ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix: True vs Predicted (linear baseline)')
plt.show()

## Try Different Architectures

In [None]:
## Import the network architectures
# NOTE(review): the GAN class and RunGAN1 defined in the next cell shadow any same-named objects
# that GAN_V1 may export via this star import. Confirm which version is intended.
import sys
sys.path.append('/content/drive/My Drive/GIN_Train/Script')
from GAN_V1 import *


In [None]:
import os
import torch
import numpy as np
import pandas as pd
import networkx as nx
import torch
import torch_geometric.utils as utils
import matplotlib as PL
from tqdm import tqdm
import sklearn
from sklearn import preprocessing
import matplotlib.pyplot as plt
from torch_geometric.nn import global_mean_pool
import torch
from torch.nn import BatchNorm1d, Linear
from torch_geometric.nn import GATConv
from torch_geometric.loader import DataLoader
from torch_geometric.loader import DataLoader
from torch_geometric.utils import add_self_loops
from torch_geometric.data import Data
import torch.nn.functional as F
from torch_geometric.nn import GINConv
from torch.nn import Linear
import torch.optim as optim
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import GATConv
import torch
import torch.nn as nn
import torch
from torch_geometric.nn import GATConv, global_mean_pool, LayerNorm
from torch.nn import Linear

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, softmax
import torch.nn.functional as F


class GAN(torch.nn.Module):
    """Two-layer graph attention (GATConv) classifier for spatial subgraphs.

    Despite its name, this is a graph attention network, not a generative adversarial network.
    Node features ``data.x`` pass through two multi-head GATConv layers, each followed by
    LeakyReLU, BatchNorm and dropout. A linear "merge" layer follows, then mean pooling per
    graph and an MLP head producing ``num_classes`` logits.

    Args:
        num_features_exp: number of input node features (``data.x.shape[1]``).
        hidden_channels: width of the latent space and of the MLP head.
        num_classes: number of output classes.
        heads: attention heads per GATConv layer (default 5, as before).

    The concatenated GAT output has ``(hidden_channels // heads) * heads`` channels. BatchNorm and
    the merge layer are sized from that value, so ``hidden_channels`` no longer has to be divisible
    by ``heads``. Previously, e.g. hidden_channels=256 crashed in bn1 with a shape mismatch.
    For divisible values (such as the default 255 used by RunGAN1) the layers are unchanged.
    """

    def __init__(self, num_features_exp, hidden_channels, num_classes, heads=5):
        super(GAN, self).__init__()

        if hidden_channels < heads:
            raise ValueError(
                f"hidden_channels ({hidden_channels}) must be >= heads ({heads})")

        # Attention GAT Conv Layers: heads are concatenated, so each layer outputs
        # per_head_hidden_channels * heads features.
        per_head_hidden_channels = hidden_channels // heads
        gat_out_channels = per_head_hidden_channels * heads
        self.conv1_exp = GATConv(num_features_exp, per_head_hidden_channels, heads=heads)
        self.conv2_exp = GATConv(gat_out_channels, per_head_hidden_channels, heads=heads)

        # Batch norm layers, sized to the actual GAT output width.
        self.bn1 = torch.nn.BatchNorm1d(gat_out_channels)
        self.bn2 = torch.nn.BatchNorm1d(gat_out_channels)
        self.dropout = torch.nn.Dropout(0.5)  # regularisation after each GAT block

        # Latent space: project the GAT output to hidden_channels.
        self.merge = Linear(gat_out_channels, hidden_channels)

        # Initiate weights (redone below by _init_weights; kept so random-number consumption is unchanged).
        torch.nn.init.xavier_uniform_(self.merge.weight.data)

        # MLP Prediction Class, applied to the per-graph pooled latent vector.
        self.mlp_class = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(hidden_channels, num_classes)
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every torch.nn.Linear submodule."""
        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)

    def forward(self, data):
        """Return ``(node_latent, class_logits, attention_weights_1, attention_weights_2)``.

        ``class_logits`` has one row per graph in ``data.batch``. When ``data.batch`` is None
        (a single un-batched graph), global_mean_pool presumably pools all nodes into one row.
        """
        exp, edge_index = data.x, data.edge_index

        # GATConv layers require edge_index to be long type
        edge_index = edge_index.long()

        x_exp, attention_weights_1 = self.conv1_exp(exp, edge_index, return_attention_weights=True)
        x_exp = F.leaky_relu(x_exp)
        x_exp = self.dropout(self.bn1(x_exp))

        x_exp, attention_weights_2 = self.conv2_exp(x_exp, edge_index, return_attention_weights=True)
        x_exp = F.leaky_relu(x_exp)
        x_exp = self.dropout(self.bn2(x_exp))

        x = self.merge(x_exp)
        x = F.leaky_relu(x)

        class_out = self.mlp_class(global_mean_pool(x, data.batch))

        return x, class_out, attention_weights_1, attention_weights_2

def RunGAN1(graph,num_classes, hidden_channels = 255, epochs = 50, learning_rate = 0.001, batch_size=16, weight_decay=0.01):
  """Train the GAN (graph attention) subgraph classifier with AdamW and cross-entropy.

  Args:
    graph: list of graph objects with ``x``, ``edge_index`` and ``Class``.
    num_classes: number of output classes.
    hidden_channels: latent width passed to GAN (255 = 51 features x 5 heads).
    epochs: number of passes over ``graph``.
    learning_rate, weight_decay: AdamW hyper-parameters.
    batch_size: graphs per mini-batch. The last incomplete batch is dropped (drop_last=True).

  Returns:
    The trained GAN model, left on the training device. Also shows a scatter plot of the
    mean training loss per epoch.

  Raises:
    ValueError: if ``graph`` holds fewer than ``batch_size`` subgraphs. With drop_last=True the
      loader would yield no batches, nothing would train, and the loss average would divide by zero.
  """
  if len(graph) < batch_size:
    raise ValueError(
      f"RunGAN1 needs at least batch_size={batch_size} graphs (got {len(graph)}) because drop_last=True")

  num_features_exp = graph[0].x.shape[1]

  model = GAN(num_features_exp, hidden_channels, num_classes)
  optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

  model.train()
  # drop_last avoids a final batch of size 1, which the BatchNorm in the MLP head cannot train on.
  loader = DataLoader(graph, batch_size=batch_size, shuffle=True, drop_last=True)

  criterion = torch.nn.CrossEntropyLoss()

  epoch_loss_list = []

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print("Running on:", device)
  model = model.to(device)

  for epoch in tqdm(range(epochs), desc="Training"):
    # Reset every epoch. Previously the previous epoch's mean leaked into the next value.
    epoch_loss = 0.0
    for data in loader:
        optimizer.zero_grad()
        _, class_out, _, _ = model(data.to(device))

        # Class labels are stored as floats; CrossEntropyLoss needs integer class ids.
        loss = criterion(class_out, data.Class.long())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    epoch_loss /= len(loader)
    epoch_loss_list.append(epoch_loss)

  import matplotlib.pyplot as plt
  plt.close()
  plt.scatter(range(len(epoch_loss_list)), epoch_loss_list)
  plt.xlabel("Epoch")
  plt.ylabel("Mean training loss")
  plt.show()
  plt.close()

  return(model)



In [None]:
# Train one graph-attention ("GAN") classifier per neighbourhood size and save each to Drive.
# NOTE: this rebinds model_NN1..3, replacing the GIN models trained earlier in the session.
# NOTE(review): the three file names are not uniform (GAN_NN1 / GAN_N2 / GANN_N3). Any cell that
# reloads these models must use exactly these names.
model_NN1 = RunGAN1(graph_NN1, num_classes=11, epochs=100, batch_size=1500)
torch.save(model_NN1, '/content/drive/My Drive/GIN_Train/Model/Class_model_GAN_NN1.pth')


model_NN2 = RunGAN1(graph_NN2, num_classes=11,epochs=100,batch_size=1500)
torch.save(model_NN2, '/content/drive/My Drive/GIN_Train/Model/Class_model_GAN_N2.pth')


model_NN3 = RunGAN1(graph_NN3, num_classes=11,epochs=100,batch_size=1500)
torch.save(model_NN3, '/content/drive/My Drive/GIN_Train/Model/Class_model_GANN_N3.pth')

In [None]:
# Reload the graph-attention models. The file names mirror exactly what the training cell saves
# (Class_model_GAN_NN1 / Class_model_GAN_N2 / Class_model_GANN_N3). The previous GANNN2 / GANNN3
# names pointed at files that cell never writes.
# weights_only=False is needed on torch>=2.6 to unpickle whole nn.Module objects saved with
# torch.save(model); only load trusted files this way. map_location='cpu' lets GPU-trained models
# load in CPU-only sessions; the evaluation helpers move them to the active device.
model_NN1 = torch.load('/content/drive/My Drive/GIN_Train/Model/Class_model_GAN_NN1.pth', map_location='cpu', weights_only=False)
model_NN2 = torch.load('/content/drive/My Drive/GIN_Train/Model/Class_model_GAN_N2.pth', map_location='cpu', weights_only=False)
model_NN3 = torch.load('/content/drive/My Drive/GIN_Train/Model/Class_model_GANN_N3.pth', map_location='cpu', weights_only=False)

Load validation data

In [None]:
# Reload the validation subgraphs and derive the 1/2/3-hop variants.
# weights_only=False: torch>=2.6 defaults to weights_only=True, which rejects pickled non-tensor
# objects such as the stored graphs. Only use it with trusted files.
graph_val = torch.load("/content/drive/My Drive/GIN_Train/Data/Graph_Class_val.pt", weights_only=False)
graph_val = runQC(graph_val)
graph_val_NN1 = reduceNN(graph_val, hop=1)
graph_val_NN2 = reduceNN(graph_val, hop=2)
graph_val_NN3 = graph_val

In [None]:
# Inspect one reduced (1-hop) validation subgraph.
graph_val_NN1[1]

In [None]:
import os
import torch
import numpy as np
import pandas as pd
import networkx as nx
import torch
import torch_geometric.utils as utils
import matplotlib as PL
from tqdm import tqdm
import sklearn
from sklearn import preprocessing
import matplotlib.pyplot as plt

import torch
from torch.nn import BatchNorm1d, Linear
from torch_geometric.nn import GATConv
from torch_geometric.data import DataLoader
from torch_geometric.utils import add_self_loops
from torch_geometric.data import Data
import torch.nn.functional as F
from torch_geometric.nn import GINConv
from torch.nn import Linear
import torch.optim as optim
from torch_geometric.nn import global_mean_pool
import torch.nn as nn



def RunEvaluationGAN(graph, model):
  """Run a trained GAN (graph attention) classifier over each subgraph and collect its outputs.

  This restores the missing function header: the cell previously held only an indented fragment
  of the body (an IndentationError), although RunEvaluationGAN is called in the next cell.
  It mirrors RunEvaluationGINClass1, but unpacks the four outputs of GAN.forward
  (latent, class_out, attention_weights_1, attention_weights_2) and ignores the attention weights.

  Args:
    graph: iterable of graph objects supporting ``.to(device)``.
    model: trained GAN model.

  Returns:
    Tuple ``(latent_space, logits, predicted)`` of numpy arrays:
      latent_space: node latents averaged over dim 0 per subgraph, stacked along axis 0.
      logits: all ``class_out`` arrays concatenated along axis 0.
      predicted: arg-max over dim 1 of each ``class_out``, concatenated.
  """
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print("Running on:", device)

  model.eval()
  model.to(device)

  latent_space = []
  class_out_logits = []
  class_out_list = []

  # Inference only: skip autograd bookkeeping.
  with torch.no_grad():
    for data in tqdm(graph, desc="Eval"):
      latent, class_out, _, _ = model(data.to(device))

      ## Latent space
      latent_space.append(latent.mean(dim=0, keepdim=True).cpu().numpy())

      ## Status
      class_out_logits.append(class_out.cpu().numpy())
      class_out_list.append(torch.argmax(class_out, dim=1).cpu().numpy())

  return(np.concatenate(latent_space), np.concatenate(class_out_logits), np.concatenate(class_out_list))

In [None]:
# Evaluate each graph-attention model on the matching validation variant.
# RunEvaluationGAN is the evaluation helper for these models; it is expected from the evaluation
# cell above or possibly from the GAN_V1 star import (confirm).
val_NN1 = RunEvaluationGAN(graph_val_NN1, model_NN1)
val_NN2 = RunEvaluationGAN(graph_val_NN2, model_NN2)
val_NN3 = RunEvaluationGAN(graph_val_NN3, model_NN3)

In [None]:
# Ground-truth class of every validation subgraph, in the order the evaluation loop visits them.
gt = [subgraph.Class.detach().cpu().numpy() for subgraph in tqdm(graph_val_NN1)]

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix

# Per-subgraph validation metrics for the 1-hop graph-attention model (macro = unweighted class mean).
labels = np.array(gt)
predicted = val_NN1[2]

print("Accuracy:", accuracy_score(labels, predicted))
for metric_title, metric_fn in (("Precision:", precision_score),
                                ("Recall:", recall_score),
                                ("F1 Score:", f1_score)):
    print(metric_title, metric_fn(labels, predicted, average='macro'))
print("Confusion Matrix:\n", confusion_matrix(labels, predicted))

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix

# Per-subgraph validation metrics for the 2-hop graph-attention model (macro = unweighted class mean).
labels = np.array(gt)
predicted = val_NN2[2]

print("Accuracy:", accuracy_score(labels, predicted))
for metric_title, metric_fn in (("Precision:", precision_score),
                                ("Recall:", recall_score),
                                ("F1 Score:", f1_score)):
    print(metric_title, metric_fn(labels, predicted, average='macro'))
print("Confusion Matrix:\n", confusion_matrix(labels, predicted))

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix

# Per-subgraph validation metrics for the 3-hop graph-attention model (macro = unweighted class mean).
labels = np.array(gt)
predicted = val_NN3[2]

print("Accuracy:", accuracy_score(labels, predicted))
for metric_title, metric_fn in (("Precision:", precision_score),
                                ("Recall:", recall_score),
                                ("F1 Score:", f1_score)):
    print(metric_title, metric_fn(labels, predicted, average='macro'))
print("Confusion Matrix:\n", confusion_matrix(labels, predicted))

## Different Data Split: Held-Out Test Set of Patients (note: the split below uses `test_size=0.5`, not 0.2; confirm the intended fraction)

In [None]:
# Read the patient table and split patients (not subgraphs) into train / test, stratified by class.
import pandas as pd
from sklearn.model_selection import train_test_split

SPLIT_SEED = 42   # fixed so the patient split, and every downstream metric, is reproducible
TEST_SIZE = 0.5   # fraction of patients held out; NOTE(review): confirm the intended fraction

df = pd.read_csv("/content/drive/My Drive/GIN_Train/Data/df_subgraph_train_data_split.csv")
df['class_n'] = df['class_n'].astype(int)
train, test = train_test_split(df, test_size=TEST_SIZE, stratify=df["class_n"], random_state=SPLIT_SEED)

# Per-subgraph table (pat_index, class_n, index_subgraph, ...) used to select graphs per patient.
df_subgraph = pd.read_csv("/content/drive/My Drive/GIN_Train/Data/df_subgraph_train_data.csv")

In [None]:
# Size of the per-subgraph table: (rows, columns).
print(df_subgraph.shape)

In [None]:
import torch
# Reload the full training subgraph list; the patient split below indexes into it.
# weights_only=False: torch>=2.6 defaults to weights_only=True, which rejects pickled non-tensor
# objects such as the stored graphs. Only use it with trusted files.
graph_train = torch.load("/content/drive/My Drive/GIN_Train/Data/Graph_Class_train.pt", weights_only=False)

In [None]:
## Import the network architectures
# Re-import the project helpers (RunTrainingGIN presumably, reduceNN, runQC) for the patient-split section.
# NOTE(review): star imports hide where each name comes from, and sys.path grows on every re-run.
import sys
sys.path.append('/content/drive/My Drive/GIN_Train/Script')

from GIN_ClassPred_V2_All_Functions import *
from reduceNN import *
from runQC import *

In [None]:
# Number of subgraphs before any patient filtering.
len(graph_train)

In [None]:
import numpy as np
# Subgraph rows that belong to the training patients (copied so later edits don't touch df_subgraph).
pat_train = train["pat_index"].to_numpy()
filtered_df = df_subgraph.loc[df_subgraph['pat_index'].isin(pat_train)].copy()

In [None]:
# Display the training-patient subgraph table (full frame; consider .head() for large tables).
filtered_df

In [None]:
# Distinct original class ids among the training patients.
np.unique(np.asarray(filtered_df["class_n"]))

In [None]:
## Relabel Data
# Map the training patients' original class ids (sorted) onto consecutive ids 0..5.
original_vector = np.unique(filtered_df["class_n"].to_numpy())
new_labels = [0, 1, 2, 3, 4, 5]
mapping = dict(zip(original_vector, new_labels))

relabelled_vector = [mapping[original_id] for original_id in filtered_df["class_n"].to_numpy()]

print(len(relabelled_vector))
print(filtered_df.shape)

In [None]:
# Store the consecutive class ids in filtered_df (modifies the frame in place).
filtered_df.loc[:, "class_n"] = relabelled_vector

In [None]:
# Positions of the training-patient subgraphs in graph_train. index_subgraph appears to be 1-based
# (hence -1), e.g. exported from R; TODO confirm.
sub_train = np.asarray(filtered_df["index_subgraph"])-1

In [None]:
# Pick the training-patient subgraphs and overwrite their Class with the relabelled ids.
# NOTE: if graph_train is a plain list (presumably), these are the same objects, so graph_train is modified too.
graph_train_pat = [graph_train[pos] for pos in sub_train]

class_new = np.asarray(filtered_df["class_n"])
for pos in tqdm(range(len(graph_train_pat))):
  label_int8 = np.asarray(class_new[pos], dtype="int8")
  graph_train_pat[pos].Class = torch.as_tensor(label_int8, dtype=torch.float)


In [None]:
# QC with the star-imported runQC, whose return value is used directly as the graph list.
# NOTE(review): a later cell redefines runQC to return (graphs, indices). Re-running this cell after
# that one would bind graph_train_pat to a tuple.
graph_train_pat=runQC(graph_train_pat)

Preprocess Test Set

In [None]:
# Test patients and their subgraph rows (class ids still in the original encoding).
# Uses its own name instead of overwriting pat_train, which holds the training patients.
pat_test = np.asarray(test["pat_index"])
filtered_df = df_subgraph[df_subgraph['pat_index'].isin(pat_test)].copy()

original_vector = np.unique(np.asarray(filtered_df["class_n"]))
original_vector

In [None]:
import os
import torch
import numpy as np
from tqdm import tqdm

def runQC(graph,nr_nodes=15, hop=3):
  """Quality-control filter for subgraphs that also reports which input positions were kept.

  A subgraph is kept only if
    1. it has at least ``nr_nodes`` nodes, and
    2. the maximum of its ``neighborhood_index`` equals ``hop``. Presumably this means the
       subgraph reaches the full ``hop``-hop neighbourhood; confirm against how
       neighborhood_index is built.

  NOTE: this redefinition shadows the runQC star-imported from the runQC module. It returns a
  tuple, whereas earlier cells use runQC's return value directly as the graph list.

  Args:
    graph: sequence of graph objects with ``num_nodes`` and a tensor ``neighborhood_index``.
    nr_nodes: minimum node count.
    hop: required maximum neighbourhood index (default 3, as before).

  Returns:
    ``(kept_graphs, kept_positions)``, where kept_positions index into the input ``graph``.
    Both are empty lists for empty input (previously np.hstack([]) raised ValueError).
  """
  # 1) Remove subgraphs with fewer than nr_nodes nodes.
  big_enough = [pos for pos, subgraph in enumerate(graph) if subgraph.num_nodes >= nr_nodes]

  # 2) Keep only subgraphs whose neighbourhood reaches exactly `hop`.
  kept_positions = [
    pos for pos in big_enough
    if graph[pos].neighborhood_index.max().detach().cpu().numpy() == hop
  ]

  return([graph[pos] for pos in kept_positions], kept_positions)

In [None]:
# Relabel the held-out patients with the SAME class mapping as the training patients.
# Previously the mapping was rebuilt from the test patients' own classes. That silently shifts every
# id above any class missing from the test split, so test ids would no longer match training ids.
train_rows = df_subgraph['pat_index'].isin(np.asarray(train["pat_index"]))
train_classes = np.unique(np.asarray(df_subgraph.loc[train_rows, "class_n"]))
mapping = {original: new for new, original in enumerate(train_classes)}

# Start from the untouched subgraph table, so re-running this cell is idempotent.
filtered_df = df_subgraph[df_subgraph['pat_index'].isin(np.asarray(test["pat_index"]))].copy()
unseen = sorted(set(np.unique(np.asarray(filtered_df["class_n"]))) - set(mapping))
assert not unseen, f"test patients carry classes never seen in training: {unseen}"
filtered_df["class_n"] = [mapping[item] for item in np.asarray(filtered_df["class_n"])]

# index_subgraph looks 1-based (hence -1), the same convention as for the training subgraphs.
sub_test = np.asarray(filtered_df["index_subgraph"])-1
graph_test_pat = [graph_train[i] for i in sub_test]

class_new = np.asarray(filtered_df["class_n"])
for i in tqdm(range(len(graph_test_pat))):
  graph_test_pat[i].Class = torch.as_tensor(np.asarray(class_new[i], dtype="int8"), dtype=torch.float)

# Keep only subgraphs passing QC, plus the matching table rows. runQC here is the tuple-returning
# version defined in the cell above.
graph_test_pat, index = runQC(graph_test_pat)
filtered_df = filtered_df.iloc[index].copy()




In [None]:
# Sanity check: relabelled class ids vs. the values of the table's "Class" column among test subgraphs.
print(np.unique(np.asarray(filtered_df["class_n"])))
print(np.unique(np.asarray(filtered_df["Class"])))


Change Class in graph dataset

In [None]:
# Train a 3-hop GIN on the training patients only; num_classes=6 matches the six relabelled classes.
model_NN3_split = RunTrainingGIN(graph_train_pat, num_classes=6,epochs=50,batch_size=1500)
torch.save(model_NN3_split, '/content/drive/My Drive/GIN_Train/Model/Class_model_NN3_split.pth')

In [None]:
# Evaluate the split model on the held-out patients' subgraphs.
# NOTE(review): RunEvaluationGINClass is not defined in this notebook (only RunEvaluationGINClass1
# is). Presumably it comes from the GIN_ClassPred_V2_All_Functions star import; confirm.
val = RunEvaluationGINClass(graph_test_pat, model_NN3_split)

In [None]:
# Ground-truth (relabelled) class of every held-out subgraph, in evaluation order.
gt = [subgraph.Class.detach().cpu().numpy() for subgraph in tqdm(graph_test_pat)]

In [None]:
# Classes present in the held-out ground truth.
np.unique(gt)


In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix

# Per-subgraph metrics of the split model on the held-out patients (macro = unweighted class mean).
labels = np.array(gt)
predicted = val[2]

print("Accuracy:", accuracy_score(labels, predicted))
for metric_title, metric_fn in (("Precision:", precision_score),
                                ("Recall:", recall_score),
                                ("F1 Score:", f1_score)):
    print(metric_title, metric_fn(labels, predicted, average='macro'))
print("Confusion Matrix:\n", confusion_matrix(labels, predicted))

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Per-subgraph confusion matrix of the split model on the held-out patients.
# labels/predicted are recomputed here, so the figure does not depend on which metrics cell ran last.
labels = np.array(gt)
predicted = val[2]
cm = confusion_matrix(labels, predicted)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='g', cmap='Blues', ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix: True vs Predicted (3-hop GIN, patient split, per subgraph)')
plt.show()

In [None]:
## Prediction on patient level
# Attach each held-out subgraph's predicted class to its table row. Assumes filtered_df rows are in
# the same order as graph_test_pat; both were filtered with the same runQC positions.
filtered_df["prediction"] = predicted
filtered_df



In [None]:
# Share (%) of each patient's held-out subgraphs assigned to each predicted class.
grouped = (filtered_df.groupby(['pat_index', 'prediction'])
           .size()
           .reset_index(name='counts')
           .copy())
total_counts = grouped.groupby('pat_index')['counts'].transform('sum')
grouped['percentage'] = grouped['counts'].div(total_counts).mul(100)
grouped

In [None]:
# Number of held-out subgraphs per (patient, true class) pair; one row per patient is expected.
class_df = filtered_df.groupby(['pat_index', 'class_n']).size().reset_index(name='counts').copy()
class_df

In [None]:
# Patient-level call: the predicted class covering the largest share of a patient's subgraphs
# (idxmax keeps the first class on ties).
max_percentage_idx = grouped.groupby('pat_index')['percentage'].idxmax()
max_percentage_rows = grouped.loc[max_percentage_idx].copy()

# Ground truth per patient, aligned by pat_index rather than by row position. Previously the
# class_df column was pasted positionally, which silently misaligns if row orders ever differ and
# fails if a patient has more than one class row.
classes_per_patient = class_df.groupby('pat_index')['class_n'].nunique()
assert (classes_per_patient == 1).all(), "some patients carry more than one ground-truth class"
patient_gt = class_df.drop_duplicates('pat_index').set_index('pat_index')['class_n']
max_percentage_rows["GT"] = max_percentage_rows['pat_index'].map(patient_gt).astype("int8")
max_percentage_rows

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Patient-level confusion matrix: majority predicted class vs. the patient's true class.
cm = confusion_matrix(max_percentage_rows["GT"], max_percentage_rows["prediction"])

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='g', cmap='Reds', ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix: True vs Predicted (3-hop GIN, patient split, per patient)')
plt.show()

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix

# Patient-level metrics: one majority-vote prediction per held-out patient (macro = unweighted class mean).
labels = max_percentage_rows["GT"]
predicted = max_percentage_rows["prediction"]

print("Accuracy:", accuracy_score(labels, predicted))
for metric_title, metric_fn in (("Precision:", precision_score),
                                ("Recall:", recall_score),
                                ("F1 Score:", f1_score)):
    print(metric_title, metric_fn(labels, predicted, average='macro'))
print("Confusion Matrix:\n", confusion_matrix(labels, predicted))