# 5-fold cross validation with optimized parameters 

In [2]:
# Standard library imports
import os
import sys
import random
from collections import Counter

# Third-party library imports
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import classification_report, accuracy_score
from sklearn.model_selection import StratifiedShuffleSplit
from glob import glob
import matplotlib.pyplot as plt
import seaborn as sns


# Local module imports
os.chdir('C:/Users/Adminn/Documents/GitHub/CEG/src')
from Graph_builder import *  # Import graph-building utilities
from CellECMGraphs_multiple import *  # Import Cell-ECM graph utilities
from Helper_functions import *  # Import helper functions

# Seed both the NumPy and the standard-library random generators so that
# repeated runs of the notebook draw the same random numbers.
np.random.seed(42)
random.seed(42)

# Input locations. The globbed file lists are sorted so that their order is
# the same on every run.
panel_path = "C:/Users/Adminn/Downloads/panel.csv"  # Panel CSV (with ECM column)
full_stack_img_dir = np.array(sorted(glob('D:/raw/*/img/*')))  # IMC image files
cell_data_dir = np.array(
    sorted(glob("C:/Users/Adminn/Desktop/PhD/cell_ECM_graphs/data/cell_data/*"))
)  # Cell data files (cell types / centroids)

In [3]:
# Build the cell-ECM graphs using the parameters found in the previous tuning run

# Folder handed to Cell_ECM_Graphs as its save location
save_folder = 'C:/Users/Adminn/Documents/GitHub/Cell_ECM_Graphs_Publication/Notebooks/figure_11/fig11_results/'

# Optimal graph-building parameters from Bayesian optimization
Dmax_CC = 50  # Maximum distance for cell-cell connections
Dmax_CE = 0  # Maximum distance for cell-ECM connections
N_graph = 3  # Number of nearest neighbours in graph building
K_clsf = 29  # Number of K for KNN classifier

# Create the graph builder from the image stack, panel and cell data paths
ceg = Cell_ECM_Graphs(
    full_stack_img_path=full_stack_img_dir,
    panel_path=panel_path,
    cell_data_path=cell_data_dir,
    save_folder=save_folder,
)

# Build one cell-ECM graph per region of interest (ROI)
ceg.build_multiple_graphs(
    Dmax_CC=Dmax_CC,
    Dmax_CE=Dmax_CE,
    interaction_k=N_graph,
)

# Cluster the ECM patches of all ROIs together
ceg.joint_ecm_clustering()


Building Cell-ECM-Graphs...
ROI 0 complete.
ROI 1 complete.
ROI 2 complete.
ROI 3 complete.
ROI 4 complete.
ROI 5 complete.
ROI 6 complete.
ROI 7 complete.
ROI 8 complete.
ROI 9 complete.
ROI 10 complete.
ROI 11 complete.
ROI 12 complete.
ROI 13 complete.
ROI 14 complete.
ROI 15 complete.
ROI 16 complete.
ROI 17 complete.
ROI 18 complete.
ROI 19 complete.
ROI 20 complete.
ROI 21 complete.
ROI 22 complete.
ROI 23 complete.
ROI 24 complete.
ROI 25 complete.
ROI 26 complete.
ROI 27 complete.
ROI 28 complete.
ROI 29 complete.
ROI 30 complete.
ROI 31 complete.
ROI 32 complete.
ROI 33 complete.
ROI 34 complete.
ROI 35 complete.
Clustering all ECM patches together ... 


In [None]:
def embeddings_neighbourhood_feature_vector(ceg_dict, g_type):
    """Build a neighbourhood-composition feature vector for every graph node.

    Each node is described by how many of its direct neighbours carry each
    label that occurs in the graph. Nodes whose identifier contains ``'cell'``
    are labelled with their ``cell_type`` attribute; all other nodes are
    treated as ECM nodes and labelled with the string form of their
    ``ecm_labels`` attribute.

    Parameters
    ----------
    ceg_dict : object
        Object exposing the graphs as attributes ``cell_G`` (used for
        ``'cellgraph'``) and ``G`` (used for ``'cellecmgraph'``). The graph
        must provide ``nodes(data=True)`` and ``neighbors(node)``; node
        identifiers must be strings.
    g_type : str
        Either ``'cellgraph'`` or ``'cellecmgraph'``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(features, labels, cell_or_ecm)`` with one entry per node, in graph
        node order. ``features`` holds float neighbour counts with one column
        per label, ordered like ``np.unique(labels)``. ``cell_or_ecm``
        contains ``'cell'`` or ``'ecm'`` for each node.

    Raises
    ------
    ValueError
        If ``g_type`` is not one of the two supported graph types.
    """
    if g_type == 'cellgraph':
        graph = ceg_dict.cell_G
    elif g_type == 'cellecmgraph':
        graph = ceg_dict.G
    else:
        raise ValueError(
            f"Unknown g_type {g_type!r}; expected 'cellgraph' or 'cellecmgraph'."
        )

    # First pass: record each node's row position, its label and its kind.
    node_position = {}
    all_node_labels = []
    cell_or_ecm = []
    for position, (node, attrs) in enumerate(graph.nodes(data=True)):
        node_position[node] = position
        if 'cell' in node:
            all_node_labels.append(attrs['cell_type'])
            cell_or_ecm.append('cell')
        else:
            # Stringify here so ECM labels are comparable with cell-type
            # names regardless of how the clustering step stored them.
            all_node_labels.append(str(attrs['ecm_labels']))
            cell_or_ecm.append('ecm')

    # label_index[i] is the feature column belonging to the label of node i.
    unique_labels, label_index = np.unique(all_node_labels, return_inverse=True)
    n_labels = len(unique_labels)

    # Second pass: count neighbour labels per node.
    all_node_features = []
    for node in node_position:
        neighbour_rows = np.fromiter(
            (node_position[neigh] for neigh in graph.neighbors(node)),
            dtype=np.intp,
        )
        counts = np.bincount(label_index[neighbour_rows], minlength=n_labels)
        all_node_features.append(counts.astype(float))

    return np.array(all_node_features), np.array(all_node_labels), np.array(cell_or_ecm)

def node_classification_knn(ceg, g_type, N_NEIGHBORS):
    """Classify cell nodes from their neighbourhood composition with KNN.

    Neighbourhood feature vectors are obtained from
    ``embeddings_neighbourhood_feature_vector``; only cell nodes are kept as
    samples. Classes with fewer than two samples are dropped, the remaining
    data is split 80/20 with a stratified shuffle split (``random_state=42``),
    and a ``KNeighborsClassifier`` with the ``'cityblock'`` metric is fitted
    on the training part and evaluated on the test part.

    Parameters
    ----------
    ceg : object
        Graph container passed through unchanged to
        ``embeddings_neighbourhood_feature_vector``.
    g_type : str
        ``'cellgraph'`` or ``'cellecmgraph'`` use all neighbour-count columns
        of that graph. ``'ecmgraph'`` uses the cell-ECM graph but keeps only
        the columns that count ECM neighbours.
    N_NEIGHBORS : int
        Number of neighbours used by the KNN classifier.

    Returns
    -------
    pandas.DataFrame
        Columns ``Accuracy``, ``Precision``, ``Recall`` and ``F1-Score``.
        Rows come from scikit-learn's ``classification_report`` (one per
        class plus its summary rows); ``Accuracy`` is filled only for the
        per-class rows and is NaN on the summary rows.
    """
    # Local import: this classifier is not among the file-level imports.
    from sklearn.neighbors import KNeighborsClassifier

    if g_type == 'ecmgraph':
        x, y, cell_or_ecm = embeddings_neighbourhood_feature_vector(ceg, 'cellecmgraph')
        is_cell = cell_or_ecm == 'cell'
        # Feature columns are ordered like np.unique(y). Keep exactly the
        # columns whose label occurs on an ECM node instead of assuming that
        # the first three columns are the ECM clusters.
        ecm_columns = np.isin(np.unique(y), y[~is_cell])
        x = x[is_cell][:, ecm_columns]  # ECM-neighbour counts for cells
        y = y[is_cell]  # labels for cells
    else:
        x, y, cell_or_ecm = embeddings_neighbourhood_feature_vector(ceg, g_type)
        is_cell = cell_or_ecm == 'cell'
        x = x[is_cell].astype(int)
        y = y[is_cell]

    # Stratified splitting needs at least two samples per class, so drop
    # classes that are rarer than that.
    unique_classes, class_counts = np.unique(y, return_counts=True)
    keep = np.isin(y, unique_classes[class_counts >= 2])
    x_filtered = x[keep]
    y_filtered = y[keep]

    # Single stratified 80/20 train-test split.
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    train_index, test_index = next(splitter.split(x_filtered, y_filtered))
    X_train, X_test = x_filtered[train_index], x_filtered[test_index]
    y_train, y_test = y_filtered[train_index], y_filtered[test_index]

    # Train the KNN model and predict the held-out cells.
    knn = KNeighborsClassifier(n_neighbors=N_NEIGHBORS, metric='cityblock')
    knn.fit(X_train, y_train)
    y_pred = knn.predict(X_test)

    # Precision / recall / F1 per class plus the report's summary rows.
    report = classification_report(y_test, y_pred, output_dict=True)
    report_df = pd.DataFrame(report).transpose().drop(columns='support')

    # Accuracy restricted to the test samples of each class. np.unique gives
    # a deterministic class order, unlike iterating over a set.
    accuracy_per_class = {}
    for cls in np.unique(y_test):
        idx = y_test == cls
        accuracy_per_class[cls] = accuracy_score(y_test[idx], y_pred[idx])
    accuracy_per_class = pd.DataFrame(accuracy_per_class, index=[0]).T

    # Put the per-class accuracy in front of the report columns.
    report_df = pd.concat((accuracy_per_class, report_df), axis=1)
    report_df.columns = ['Accuracy', 'Precision', 'Recall', 'F1-Score']
    return report_df

def get_macro_metrics(cg_final, name):
    """Collapse an averaged classification report into a one-row summary.

    Parameters
    ----------
    cg_final : pandas.DataFrame
        Report with columns ``Precision``, ``Recall`` and ``F1-Score`` and
        with rows labelled ``'accuracy'`` and ``'weighted avg'``.
    name : str
        Index label given to the single output row.

    Returns
    -------
    pandas.DataFrame
        One row with columns ``Accuracy``, ``Precision``, ``Recall`` and
        ``F1-Score``. ``Accuracy`` is read from the ``Precision`` column of
        the ``'accuracy'`` row; the other three values are read from the
        ``'weighted avg'`` row.
    """
    precision_column = cg_final['Precision']
    summary_row = [
        precision_column['accuracy'],
        precision_column['weighted avg'],
        cg_final['Recall']['weighted avg'],
        cg_final['F1-Score']['weighted avg'],
    ]
    return pd.DataFrame(
        [summary_row],
        index=[name],
        columns=['Accuracy', 'Precision', 'Recall', 'F1-Score'],
    )

   cg_accuracy = []
    cg_precision = []
    cg_recall = []
    cg_f1_scores = []

    ceg_accuracy = []
    ceg_precision = []
    ceg_recall = []
    ceg_f1_scores = []

    ecm_accuracy = []
    ecm_precision = []
    ecm_recall = []
    ecm_f1_scores = []

    cg_reports = []
    ceg_reports = []
    ecm_reports = []

    for i in range(len(ceg.ceg_dict)):

            cell_report = node_classification_knn(ceg.ceg_dict[i], 'cellgraph', k_classifier)
            cell_ecm_report = node_classification_knn(ceg.ceg_dict[i], 'cellecmgraph',k_classifier)
            ecm_report = node_classification_knn(ceg.ceg_dict[i], 'ecmgraph', k_classifier)
            
            cell_acc = cell_report.loc['accuracy'].Precision
            cell_pre = cell_report.loc['weighted avg'].Precision
            cell_rec = cell_report.loc['weighted avg'].Recall
            cell_f1 = cell_report.loc['weighted avg', 'F1-Score']
            cg_reports.append(cell_report)

            cg_accuracy.append(cell_acc)
            cg_precision.append(cell_pre)
            cg_recall.append(cell_rec)
            cg_f1_scores.append(cell_f1)

            ecm_acc = ecm_report.loc['accuracy'].Precision
            ecm_pre = ecm_report.loc['weighted avg'].Precision
            ecm_rec = ecm_report.loc['weighted avg'].Recall
            ecm_f1 = ecm_report.loc['weighted avg', 'F1-Score']
            ecm_reports.append(ecm_report)

            ecm_accuracy.append(ecm_acc)
            ecm_precision.append(ecm_pre)
            ecm_recall.append(ecm_rec)
            ecm_f1_scores.append(ecm_f1)
            
            ceg_reports.append(cell_ecm_report)
            cell_ecm_acc = cell_ecm_report.loc['accuracy'].Precision
            cell_ecm_pre = cell_ecm_report.loc['weighted avg'].Precision
            cell_ecm_rec = cell_ecm_report.loc['weighted avg'].Recall
            cell_ecm_f1 = cell_ecm_report.loc['weighted avg', 'F1-Score']


            ceg_accuracy.append(cell_ecm_acc)
            ceg_precision.append(cell_ecm_pre)
            ceg_recall.append(cell_ecm_rec)
            ceg_f1_scores.append(cell_ecm_f1)

    cg_reports_df = pd.concat(cg_reports)
    cg_final = cg_reports_df.groupby(cg_reports_df.index).mean().round(4)
    cg_final.to_csv(save_folder+'cg_celltypes.csv')

    ecm_reports_df = pd.concat(ecm_reports)
    ecm_final = ecm_reports_df.groupby(ecm_reports_df.index).mean().round(4)
    ecm_final.to_csv(save_folder+'ecm_celltypes.csv')

    ceg_reports_df = pd.concat(ceg_reports)
    ceg_final = ceg_reports_df.groupby(ceg_reports_df.index).mean().round(4)
    ceg_final.to_csv(save_folder+'ceg_celltypes.csv')
    
    # Create a one-row summary table for each graph type
    cg_macro_metrics = get_macro_metrics(cg_final, 'Cell Graphs')
    ceg_macro_metrics = get_macro_metrics(ceg_final, 'Cell-ECM Graphs')
    ecm_macro_metrics = get_macro_metrics(ecm_final, 'ECM Graphs')

    final_macro_metrics = pd.concat((cg_macro_metrics, ceg_macro_metrics, ecm_macro_metrics))
    final_macro_metrics.to_csv(save_folder+'macro_metrics.csv')
    cell_f1 = final_macro_metrics.iloc[0,3]
    cell_ecm_f1 = final_macro_metrics.iloc[1,3]