In [None]:
from utils import *
import torch # type: ignore
import torch.nn.functional as F # type: ignore
from models import SAGEnorm, GATnorm, GraphTransformernorm
import os
from torch.optim.lr_scheduler import StepLR # type: ignore
from datetime import date
import seaborn as sns
from scipy.stats import mode
import networkx as nx
import numpy as np # type: ignore
import matplotlib.pyplot as plt # type: ignore

### Load GT focal and graph

In [None]:
# Seed via the project helper set_seed (star-imported from utils), load the graph data
# with load_trans(), and build the train/test/val/full-dataset loaders from it.
# The helpers' internals are not visible in this notebook.
set_seed(222)
H = load_trans()
train_loader, test_loader, val_loader, dataset_loader = mask_and_batch_tran_interpret(H)

def create_model_loss(config, loss_type="bce", alpha=None, gamma=None):
    """Build (and print) the GNN described by ``config``.

    Args:
        config: dict with "model_type" ('gat', 'graphsage' or 'graphtransformer'),
            "hidden_size", "num_layers", "dropout", "activation_function" and, for
            'gat' / 'graphtransformer', "num_heads".
        loss_type, alpha, gamma: forwarded unchanged to the model constructor.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``config["model_type"]`` is not a supported type (previously this
            surfaced as an UnboundLocalError on ``model``).
    """
    model_type = config["model_type"]
    loss_kwargs = dict(loss_type=loss_type, alpha=alpha, gamma=gamma)
    if model_type == 'gat':
        model = GATnorm(config["hidden_size"], config["num_layers"], config["dropout"], config["activation_function"], config["num_heads"], **loss_kwargs)
    elif model_type == 'graphsage':
        model = SAGEnorm(config["hidden_size"], config["num_layers"], config["dropout"], config["activation_function"], **loss_kwargs)
    elif model_type == 'graphtransformer':
        # Resolved at call time, so a later redefinition of GraphTransformernorm is picked up.
        model = GraphTransformernorm(config["hidden_size"], config["num_layers"], config["dropout"], config["activation_function"], config["num_heads"], **loss_kwargs)
    else:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected 'gat', 'graphsage' or 'graphtransformer'"
        )
    print(model)
    return model

# Everything in this notebook runs on CPU. The commented expression below is the automatic
# GPU choice, currently disabled; load_checkpoint also defaults to device='cpu'.
device = 'cpu'
#'cuda' if torch.cuda.is_available() else 'cpu'
device

In [None]:
# loading special loss models
def load_checkpoint(basemodel_path, checkpoint_path, test_loader=None, load_state_dicts=True, loss_type='focal', alpha=0.75, gamma=1, device='cpu'):
    """Rebuild a trained model together with its optimizer and LR scheduler.

    The architecture and optimizer settings come from ``base_model["config"]``; the weights
    and the optimizer/scheduler state come from the checkpoint file.

    Args:
        basemodel_path: file written with ``torch.save`` holding a ``"config"`` dict.
        checkpoint_path: file holding ``"config"`` (with ``"loss_type"`` and ``"alpha"``),
            ``"model_state_dict"``, ``"optimizer_state_dict"`` and ``"scheduler_state_dict"``.
        test_loader: unused; kept so existing calls keep working.
        load_state_dicts: when False, return a freshly initialised model/optimizer/scheduler.
        loss_type, alpha, gamma: forwarded to ``create_model_loss``. NOTE: these arguments are
            used even if the checkpoint's config records different values (those are only printed).
        device: device the loaded tensors, the model and the optimizer state are placed on.

    Returns:
        (model in eval mode, optimizer, StepLR scheduler)
    """
    # map_location lets checkpoints saved on a GPU be restored on a CPU-only machine.
    # SECURITY: torch.load unpickles its input - only load trusted checkpoint files.
    base_model = torch.load(basemodel_path, map_location=device)
    print(base_model["config"])

    checkpoint = torch.load(checkpoint_path, map_location=device)
    print(checkpoint["config"]["loss_type"], checkpoint["config"]["alpha"])

    model_loaded = create_model_loss(base_model["config"], loss_type=loss_type, alpha=alpha, gamma=gamma)
    if load_state_dicts:
        model_loaded.load_state_dict(checkpoint["model_state_dict"])

    optimizer = set_optim(base_model["config"], model_loaded)
    if load_state_dicts:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        # Make sure every tensor in the optimizer state lives on the target device.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    # step_size/gamma are only initial values; load_state_dict below overwrites them.
    scheduler = StepLR(optimizer, step_size=1, gamma=0.95)
    if load_state_dicts:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    model_loaded.to(device)
    model_loaded.eval()
    return model_loaded, optimizer, scheduler

In [None]:
# Placeholder paths - replace before running. The base-model file must contain a "config"
# entry (used to rebuild the architecture); the checkpoint must contain "config" plus the
# model/optimizer/scheduler state dicts. load_checkpoint's defaults apply here
# (loss_type='focal', alpha=0.75, gamma=1, device='cpu').
# Re-run this cell after redefining GraphTransformernorm in the "Full dataset" section, so
# that (for model_type 'graphtransformer') model_loaded is built from the redefined class.
basemodel_path = r"BASEMODEL_PATH"
checkpoint_path = r"CHECKPOINT_PATH" 
model_loaded, optimizer, scheduler = load_checkpoint(basemodel_path, checkpoint_path)

### Interpretability

In [None]:
import pickle as pkl
# Pickled similarity graph. Below it is queried as a networkx-style graph whose nodes carry a
# 'patient_id' attribute and whose edges may carry a 'weight' (similarity) attribute.
# SECURITY: pickle.load can execute arbitrary code - only open trusted files.
graph_filepath = r"SIMILARITY_GRAPH_PATH"

with open(graph_filepath, 'rb') as file:
            G = pkl.load(file)

type(G)

In [None]:
def get_patient_ids(loader):
    """Return the unmasked node (patient) ids of ``loader`` as a NumPy array.

    Re-seeds with 22 first (return_results seeds 22 the same way before calling
    get_unmasked_node_ids), presumably so the id order is reproducible.
    """
    set_seed(22)
    unmasked_ids = get_unmasked_node_ids(loader)
    return np.array(unmasked_ids)

def find_node_by_patient_id(patient_id, graph):
    """Return the first node whose 'patient_id' attribute equals ``patient_id``, else None.

    Linear scan over ``graph.nodes(data=True)``; a node without the attribute matches
    ``patient_id=None``.
    """
    matching_nodes = (
        node
        for node, attrs in graph.nodes(data=True)
        if attrs.get('patient_id') == patient_id
    )
    return next(matching_nodes, None)

def get_attention_weights(model, loader):
    """Run one forward pass and return the model's last-layer attention as NumPy arrays.

    Only the first batch of ``loader`` is used (the loaders here presumably yield the whole
    graph as one batch - confirm). The model must return
    ``(logits, (edge_index, attention_weights))``; the stock model classes need to be
    modified to do so.

    Returns:
        (edge_index_np, attention_scores_np): edge_index transposed to one
        ``[source, target]`` row per edge, and the matching attention scores.
    """
    model.eval()
    with torch.inference_mode():
        first_batch = next(iter(loader))
        _logits, attention = model(first_batch)
        edge_index, scores = attention

    edge_rows = edge_index.detach().cpu().numpy().T
    score_array = scores.detach().cpu().numpy()
    return edge_rows, score_array

def get_attention_scores(patient_id, graph, model, loader, attention=None):
    """Return last-layer attention scores on the edges leaving one patient's node.

    The node is located with find_node_by_patient_id. For each graph neighbour, every
    entry of the model's edge_index whose (source, target) equals (node, neighbour)
    contributes one record. NOTE(review): this assumes graph node keys coincide with the
    node indices used in the model's edge_index - confirm.

    Args:
        patient_id: value of the 'patient_id' node attribute to look up.
        graph: networkx-style graph whose nodes carry a 'patient_id' attribute.
        model, loader: used to run get_attention_weights when ``attention`` is None.
        attention: optional precomputed ``(edge_index_np, attention_scores_np)`` pair from
            get_attention_weights; pass it to avoid one forward pass per call.

    Returns:
        list of dicts with keys 'source_node_index', 'target_node_index',
        'source_patient_id', 'target_patient_id', 'attention_score';
        [] when the patient is not in the graph.
    """
    node_index = find_node_by_patient_id(patient_id, graph)
    if node_index is None:
        print(f"Patient ID {patient_id} not found in the graph.")
        return []
    if attention is None:
        attention = get_attention_weights(model, loader)
    edge_index_np, attention_scores_np = attention

    info = []
    neighbors = list(graph.neighbors(node_index))
    print(f"Patient ID: {patient_id}, Node index: {node_index}")
    print(f"Neighbors: {neighbors}")

    # edge_index is directed: only keep entries where node_index is the SOURCE (order matters).
    # Loop-invariant, so computed once instead of once per neighbour.
    from_source = edge_index_np[:, 0] == node_index

    for neighbor in neighbors:
        neighbor_patient_id = graph.nodes[neighbor].get('patient_id')
        edge = (node_index, neighbor)
        mask = from_source & (edge_index_np[:, 1] == neighbor)
        if not mask.any():
            continue

        print(f"Considering edge: {edge}, Neighbor patient ID: {neighbor_patient_id}")
        attention_score_indices = np.where(mask)[0]
        print(f"Attention score indices: {attention_score_indices}")

        for attention_score_index in attention_score_indices:
            # .item() requires a single score per edge (e.g. a single-head final layer).
            score = attention_scores_np[attention_score_index].item()
            print(f"Appending: source node index: {node_index}, target node index: {neighbor}, Attention score: {score}")
            info.append({
                'source_node_index': node_index,
                'target_node_index': neighbor,
                'source_patient_id': patient_id,
                'target_patient_id': neighbor_patient_id,
                'attention_score': score,
            })

    return info


def profile_attention(results, loader, model, graph):
    """Attach attention scores, probabilities, predictions and labels to each profile.

    Args:
        results: dict profile name -> list of instance dicts as built by return_results
            ('index', 'patient_id', 'prob', 'pred', 'label').
        loader: loader whose unmasked patients are profiled; its first batch is also fed to
            the model (once) to obtain the attention weights.
        model: model returning ``(logits, (edge_index, attention_weights))``.
        graph: networkx-style graph whose nodes carry a 'patient_id' attribute.

    Instances carrying a 'patient_id' are identified by it and skipped when that patient is
    not among ``loader``'s unmasked node ids. ``instance['index']`` indexes the id list of
    the loader ``results`` was built from, which need not be ``loader``. Instances without a
    'patient_id' fall back to ``get_patient_ids(loader)[instance['index']]``.

    Returns:
        dict profile name -> {'patient_ids', 'attention_weights', 'probs', 'preds', 'labels'}.
    """
    patient_ids = get_patient_ids(loader)
    loader_patient_ids = set(patient_ids.tolist())
    attention = None  # computed once, lazily, and shared by all patients
    attention_profiles = {}

    for key, instances in results.items():
        profile = attention_profiles.setdefault(key, {
            'patient_ids': [],
            'attention_weights': [],
            'probs': [],
            'preds': [],
            'labels': [],
        })

        for instance in instances:
            if 'patient_id' in instance:
                patient_id = instance['patient_id']
                if patient_id not in loader_patient_ids:
                    continue  # instance belongs to a split not covered by `loader`
            else:
                patient_id = patient_ids[instance['index']]

            if attention is None:
                attention = get_attention_weights(model, loader)
            attention_scores = get_attention_scores(patient_id, graph, model, loader, attention=attention)

            profile['patient_ids'].append(patient_id)
            profile['attention_weights'].extend(attention_scores)
            profile['probs'].append(instance['prob'])
            profile['preds'].append(instance['pred'])
            profile['labels'].append(instance['label'])

    return attention_profiles


def aggregate_attention_profiles(attention_profiles):
    """Map each profile name to the mean of its attention scores (0 when it has none)."""
    means = {}
    for profile_name, profile in attention_profiles.items():
        scores = [entry['attention_score'] for entry in profile['attention_weights']]
        means[profile_name] = np.mean(scores) if len(scores) > 0 else 0
    return means

def _summary_stats(values):
    """Return (mean, std, mode, median) of ``values``, or four zeros when it is empty."""
    if not values:
        return 0, 0, 0, 0
    # np.atleast_1d handles both the array result of older SciPy and the scalar result that
    # scipy.stats.mode returns by default since SciPy 1.11.
    most_common = np.atleast_1d(mode(values).mode)[0]
    return np.mean(values), np.std(values), most_common, np.median(values)


def compute_statistics(attention_profiles, graph):
    """Summarise node degree and edge similarity of the source patients of each profile.

    For every profile, each distinct source patient appearing in its 'attention_weights'
    contributes, once, the degree of its graph node and the 'weight' attribute of every
    edge to its neighbours (0 when the attribute is missing). Source patients that are not
    in the graph are ignored.

    Args:
        attention_profiles: dict profile name -> dict with an 'attention_weights' list of
            records holding 'source_patient_id'.
        graph: networkx-style graph (nodes(data=True), degree, neighbors, edges).

    Returns:
        dict profile name -> {'average_degree', 'degree_std', 'degree_mode', 'degree_median',
        'average_similarity', 'similarity_std', 'similarity_mode', 'similarity_median'};
        every value is 0 for a profile without usable data.
    """
    # Build the patient_id -> node lookup once. The first matching node wins, like the linear
    # scan in find_node_by_patient_id.
    node_by_patient = {}
    for node, data in graph.nodes(data=True):
        node_by_patient.setdefault(data.get('patient_id'), node)

    statistics = {}
    for key, profiles in attention_profiles.items():
        degrees = []
        similarities = []
        seen_sources = set()
        for profile in profiles['attention_weights']:
            source_patient_id = profile['source_patient_id']
            # attention_weights holds one record per edge; count each source patient only once
            # so a node is not weighted by its own number of attention records.
            if source_patient_id in seen_sources:
                continue
            seen_sources.add(source_patient_id)

            node_index = node_by_patient.get(source_patient_id)
            if node_index is None:
                continue
            degrees.append(graph.degree[node_index])
            for neighbor in graph.neighbors(node_index):
                similarities.append(graph.edges[node_index, neighbor].get('weight', 0))

        degree_mean, degree_std, degree_mode, degree_median = _summary_stats(degrees)
        similarity_mean, similarity_std, similarity_mode, similarity_median = _summary_stats(similarities)

        statistics[key] = {
            'average_degree': degree_mean,
            'degree_std': degree_std,
            'degree_mode': degree_mode,
            'degree_median': degree_median,
            'average_similarity': similarity_mean,
            'similarity_std': similarity_std,
            'similarity_mode': similarity_mode,
            'similarity_median': similarity_median,
        }
    return statistics

def visualize_statistics(statistics):
    """Print each profile's statistics, one value per line with five decimals."""
    # (caption, statistics key) in print order
    rows = (
        ("Average degree", 'average_degree'),
        ("Degree sd", 'degree_std'),
        ("Degree mode", 'degree_mode'),
        ("Degree median", 'degree_median'),
        ("Avg similarity", 'average_similarity'),
        ("Similarity sd", 'similarity_std'),
        ("Similarity mode", 'similarity_mode'),
        ("Similarity median", 'similarity_median'),
    )
    for key, stats in statistics.items():
        print(f"Statistics for {key}:")
        for caption, field in rows:
            print(f"  {caption}: {stats[field]:.5f}")

def visualize_attention_profiles(aggregated_profiles):
    """Bar chart of per-profile mean attention, sorted ascending, with value labels.

    Args:
        aggregated_profiles: dict profile name -> mean attention score (e.g. the output of
            aggregate_attention_profiles). An empty dict prints a message instead of plotting.
    """
    if not aggregated_profiles:
        print("No attention profiles to plot.")
        return

    labels = list(aggregated_profiles.keys())
    weights = list(aggregated_profiles.values())
    order = np.argsort(weights)
    sorted_labels = [labels[i] for i in order]
    sorted_weights = [weights[i] for i in order]

    fig, ax = plt.subplots(figsize=(15, 8))
    bar_positions = np.arange(len(sorted_labels))
    ax.bar(bar_positions, sorted_weights, 0.4, label='Attention weights')

    # value label just below the top of each bar
    for pos, weight in zip(bar_positions, sorted_weights):
        ax.text(pos, weight * 0.95, f'{weight:.2f}', ha='center', va='bottom', fontsize=10)

    ax.set_title('Average attention weights')
    ax.set_xlabel('Profiles')
    ax.set_ylabel('Attention weight')
    ax.set_xticks(bar_positions)
    ax.set_xticklabels(sorted_labels)
    top = max(sorted_weights)
    # ylim(0, 0) (all-zero profiles) is a degenerate range; keep autoscaling in that case.
    if top > 0:
        ax.set_ylim(0, top * 1.1)
    ax.legend()
    plt.show()

def plot_attention_distribution(attention_profiles):
    """Overlay one filled KDE of attention scores per profile.

    Args:
        attention_profiles: dict profile name -> dict with an 'attention_weights' list of
            records holding 'attention_score'. Profiles without scores are skipped.

    NOTE: a later cell redefines this name with a neighbour-label variant; re-running
    callers after that cell uses the other definition.
    """
    fig, ax = plt.subplots(figsize=(15, 10))

    for key, profiles in attention_profiles.items():
        attention_scores = [p['attention_score'] for p in profiles['attention_weights']]
        if not attention_scores:
            continue  # nothing to estimate a density from
        # `fill` replaces the deprecated `shade` argument of seaborn.kdeplot.
        sns.kdeplot(attention_scores, label=key, fill=True, ax=ax)

    ax.set_title('Attention score distribution by profile')
    ax.set_xlabel('Attention score')
    ax.set_ylabel('Density')
    ax.legend()
    plt.show()

def detailed_report(attention_profiles, graph, patient_ids=None):
    """Print degree/similarity statistics per profile, optionally restricted to some patients.

    Args:
        attention_profiles: dict as produced by profile_attention.
        graph: graph passed on to compute_statistics.
        patient_ids: optional iterable of patient ids. When given, only attention records
            whose source patient is in it are kept, together with the matching
            patient_ids/probs/preds/labels entries. Profiles left without attention
            records are dropped.
    """
    if patient_ids is not None:
        wanted = set(patient_ids)  # O(1) membership instead of scanning a list per record
        filtered_attention_profiles = {}
        for key, profile in attention_profiles.items():
            filtered_weights = [p for p in profile['attention_weights'] if p['source_patient_id'] in wanted]
            if not filtered_weights:
                continue
            # positions of the wanted patients; shared by the parallel per-patient lists
            keep = [i for i, pid in enumerate(profile['patient_ids']) if pid in wanted]
            filtered_attention_profiles[key] = {
                'patient_ids': [profile['patient_ids'][i] for i in keep],
                'attention_weights': filtered_weights,
                'probs': [profile['probs'][i] for i in keep],
                'preds': [profile['preds'][i] for i in keep],
                'labels': [profile['labels'][i] for i in keep],
            }
        attention_profiles = filtered_attention_profiles

    statistics = compute_statistics(attention_profiles, graph)
    visualize_statistics(statistics)

In [None]:
# return ALL instances from each profile: index, id, prob, pred, label
def return_results(probs, preds, labels, loader):
    """Group every node into TP / TN / FP / FN buckets by comparing predictions and labels.

    Args:
        probs, preds, labels: NumPy arrays aligned with the node order of
            ``get_unmasked_node_ids(loader)`` (as returned by get_predictions - assumption).
        loader: loader whose unmasked node ids give each index its patient id.

    Returns:
        dict with keys 'True positive (TP)', 'True negative (TN)', 'False positive (FP)',
        'False negative (FN)' (in that order), each a list of
        {'index', 'patient_id', 'prob', 'pred', 'label'} dicts.
    """
    set_seed(22)  # get_patient_ids seeds 22 the same way, presumably to fix the id order
    results = {}
    patient_ids = get_unmasked_node_ids(loader)

    # (profile name, predicted class, true class)
    outcomes = (
        ('True positive (TP)', 1, 1),
        ('True negative (TN)', 0, 0),
        ('False positive (FP)', 1, 0),
        ('False negative (FN)', 0, 1),
    )
    for name, predicted, actual in outcomes:
        indices = np.where((preds == predicted) & (labels == actual))[0]
        results[name] = [{
            'index': idx,
            'patient_id': patient_ids[idx],
            'prob': probs[idx],
            'pred': preds[idx],
            'label': labels[idx]
        } for idx in indices]

    return results

def get_predictions(dataset_loader, model=None, target_device=None):
    """Run the project's ``test`` routine and return (probs, preds, labels) as NumPy arrays.

    Args:
        dataset_loader: loader to evaluate.
        model: model to evaluate; defaults to the notebook global ``model_loaded``.
        target_device: device passed to ``test``; defaults to the notebook global ``device``.
            Passing both explicitly avoids depending on hidden kernel state.
    """
    if model is None:
        model = model_loaded
    if target_device is None:
        target_device = device
    set_seed(22)
    _, _, probs, preds, labels = test(dataset_loader, model, target_device)
    return np.array(probs), np.array(preds), np.array(labels)

# Predictions for ALL instances (train, test and val) via dataset_loader.
# Each result's 'index' refers to the node-id order of dataset_loader, so results must
# be paired with dataset_loader whenever those indices are used again.
probs,preds,labels = get_predictions(dataset_loader)
results = return_results(probs, preds, labels, dataset_loader)
results

In [None]:
# Cache the test-loader attention profiles (expensive to compute). The pickle is written to a
# temporary file first and then moved into place, so an interrupted dump cannot leave a
# truncated cache that later runs would try to load. NOTE: the cache is not keyed on the
# model/checkpoint - delete the file after changing either. pickle.load is only safe on
# trusted files.
if os.path.exists('attention_profiles.pkl'):
    with open('attention_profiles.pkl', 'rb') as pickle_file:
        attention_profiles = pkl.load(pickle_file)
else:
    attention_profiles = profile_attention(results, test_loader, model_loaded, G) # this is with test loader. later, full dataset+filtering
    with open('attention_profiles.pkl.tmp', 'wb') as pickle_file:
        pkl.dump(attention_profiles, pickle_file)
    os.replace('attention_profiles.pkl.tmp', 'attention_profiles.pkl')

In [None]:
# Degree / similarity statistics for each outcome profile (TP, TN, FP, FN).
detailed_report(attention_profiles, G)

In [None]:
# KDE of attention scores per profile. NOTE: this calls whichever plot_attention_distribution
# is currently defined; a later cell redefines the name with a different meaning.
plot_attention_distribution(attention_profiles)

## Full dataset

Compute attention profiles on the full graph (all splits), then filter them down to test-set patients and plot them.

In [None]:
import torch
from torch.nn import ModuleList, Linear, BatchNorm1d
from torch_geometric.nn import TransformerConv

#rerun "checkpoint load"
class GraphTransformernorm(torch.nn.Module):
    """TransformerConv node classifier whose forward also returns last-layer attention.

    Notebook-local redefinition that shadows the GraphTransformernorm imported from
    ``models``. create_model_loss resolves that name at call time, so the checkpoint-loading
    cell must be re-run after this cell for model_loaded to use this class.
    NOTE(review): assumes the layer layout matches models.GraphTransformernorm so the
    checkpoint state_dict loads - confirm.

    forward(data) reads ``data.node_feature`` / ``data.edge_index`` and returns
    ``(logits, attention_weights)``: one logit per node (post_mp outputs 1 feature) and what
    the final TransformerConv returns for ``return_attention_weights=True``.
    """

    def __init__(self, hidden_size, num_layers, dropout, activation_function, num_heads, loss_type="bce", alpha=0, gamma=0):
        super().__init__()
        # Seeded before any layer is built so parameter initialisation is reproducible.
        torch.manual_seed(22)
        self.alpha = alpha
        self.gamma = gamma
        self.num_layers = num_layers
        self.dropout = dropout
        self.num_heads = num_heads
        self.loss_type = loss_type
        self.input_size = 300  # node-feature width expected by the first layer

        if activation_function == 'leaky_relu':
            self.activation = F.leaky_relu
        elif activation_function == 'relu':
            self.activation = F.relu
        else:
            raise ValueError("Unsupported activation function")

        concat_width = hidden_size * self.num_heads
        self.convs = ModuleList()
        self.bns = ModuleList()

        # Multi-head layers (heads concatenated): the input layer plus num_layers - 2 hidden
        # layers, at least one in total. Build order matters for the seeded initialisation.
        in_width = self.input_size
        for _ in range(max(1, self.num_layers - 1)):
            self.convs.append(TransformerConv(in_width, hidden_size, heads=self.num_heads))
            self.bns.append(BatchNorm1d(concat_width))
            in_width = concat_width

        # Final layer: single head, no concatenation -> hidden_size features.
        self.convs.append(TransformerConv(concat_width, hidden_size, heads=1, concat=False))
        self.bns.append(BatchNorm1d(hidden_size))

        self.post_mp = Linear(hidden_size, 1)

    def forward(self, data):
        x, edge_index, _batch = data.node_feature, data.edge_index, data.batch  # batch is unused

        *hidden_layers, (last_conv, last_bn) = zip(self.convs, self.bns)
        for conv, bn in hidden_layers:
            x = self.activation(bn(conv(x, edge_index)))
            x = F.dropout(x, p=self.dropout, training=self.training)

        x, attention_weights = last_conv(x, edge_index, return_attention_weights=True)
        return self.post_mp(last_bn(x)), attention_weights

In [None]:
# Cache the full-dataset attention profiles. They are computed BEFORE any file is opened for
# writing, and the pickle goes to a temporary file that is then moved into place. A failure
# therefore cannot leave an empty/truncated pickle that the exists() check would later
# mistake for a valid cache. NOTE: the cache is not keyed on the model/checkpoint - delete
# the file after changing either. pickle.load is only safe on trusted files.
if os.path.exists('attention_profiles_fulldataset.pkl'):
    with open('attention_profiles_fulldataset.pkl', 'rb') as pickle_file:
        attention_profiles_fulldataset = pkl.load(pickle_file)
else:
    attention_profiles_fulldataset = profile_attention(results, dataset_loader, model_loaded, G)
    with open('attention_profiles_fulldataset.pkl.tmp', 'wb') as pickle_file:
        pkl.dump(attention_profiles_fulldataset, pickle_file)
    os.replace('attention_profiles_fulldataset.pkl.tmp', 'attention_profiles_fulldataset.pkl')

In [None]:
# Patient ids appearing in the test-loader profiles computed earlier (952 in the last run).
test_patient_ids = [pid for profiles in attention_profiles.values() for pid in profiles['patient_ids']] # getting ids from the test set
print(test_patient_ids, len(test_patient_ids)) #952

In [None]:
# Spot check: is this particular patient id among the test ids?
11745 in test_patient_ids

In [None]:
def filter_attention_profiles_by_test_ids(attention_profiles_fulldataset, test_patient_ids, results, graph):
    """Restrict full-dataset attention profiles to test patients and attach neighbour labels.

    Args:
        attention_profiles_fulldataset: dict profile name -> profile dict as built by
            profile_attention ('patient_ids', 'attention_weights', ...).
        test_patient_ids: iterable of patient ids to keep.
        results: dict profile name -> list of instance dicts with 'patient_id', 'prob',
            'pred' and 'label' (as built by return_results).
        graph: networkx-style graph whose nodes carry a 'patient_id' attribute.

    Returns:
        dict profile name -> {'patient_ids', 'attention_weights', 'probs', 'preds', 'labels',
        'neighbor_labels'} for every profile that keeps at least one attention record.
        probs/preds/labels follow the order of ``results``. neighbor_labels[i] lists the
        labels of ALL graph neighbours of the source node of attention_weights[i]. It is an
        empty list when the source is not in the graph, so the two lists stay index-aligned.
    """
    test_ids = set(test_patient_ids)
    all_instances = [item for sublist in results.values() for item in sublist]

    # Lookup tables built once instead of linear scans per record/neighbour. The first
    # occurrence wins, matching the behaviour of a `next(...)` / first-match scan.
    label_by_patient = {}
    for instance in all_instances:
        label_by_patient.setdefault(instance['patient_id'], instance['label'])
    node_by_patient = {}
    patient_by_node = {}
    for node, data in graph.nodes(data=True):
        node_patient_id = data.get('patient_id')
        patient_by_node[node] = node_patient_id
        node_by_patient.setdefault(node_patient_id, node)

    filtered_profiles = {}
    for key, profile in attention_profiles_fulldataset.items():
        # keep records whose source node is a test patient
        filtered_weights = [p for p in profile['attention_weights'] if p['source_patient_id'] in test_ids]
        if not filtered_weights:
            continue

        profile_test_patient_ids = [pid for pid in profile['patient_ids'] if pid in test_ids]
        profile_id_set = set(profile_test_patient_ids)
        matching_instances = [inst for inst in all_instances if inst['patient_id'] in profile_id_set]

        neighbor_labels = []
        for weight in filtered_weights:
            source_node = node_by_patient.get(weight['source_patient_id'])
            labels_here = []
            if source_node is not None:
                for neighbor in graph.neighbors(source_node):
                    neighbor_label = label_by_patient.get(patient_by_node.get(neighbor))
                    if neighbor_label is not None:
                        labels_here.append(neighbor_label)
            # Always append so neighbor_labels[i] stays aligned with attention_weights[i].
            neighbor_labels.append(labels_here)

        filtered_profiles[key] = {
            'patient_ids': profile_test_patient_ids,
            'attention_weights': filtered_weights,
            'probs': [inst['prob'] for inst in matching_instances],
            'preds': [inst['pred'] for inst in matching_instances],
            'labels': [inst['label'] for inst in matching_instances],
            'neighbor_labels': neighbor_labels,
        }
    return filtered_profiles

# Restrict the full-dataset profiles to the test patients; each record also gets the labels
# of its source node's neighbours.
filtered_profiles = filter_attention_profiles_by_test_ids(attention_profiles_fulldataset, test_patient_ids, results, G)
print(filtered_profiles)

In [None]:
import pprint
# Pretty-printed view of the filtered profiles (can be very long).
pprint.pprint(filtered_profiles)

In [None]:
def plot_attention_distribution(filtered_profiles):
    """For each profile, plot KDEs of attention scores grouped by neighbour label.

    Args:
        filtered_profiles: output of filter_attention_profiles_by_test_ids; uses
            'attention_weights' and the index-aligned 'neighbor_labels'.

    NOTE: this redefines (shadows) the earlier plot_attention_distribution.
    NOTE(review): each attention score (one per source->target edge) is added once for the
    label of EVERY neighbour of the source node, not only for the target's label. Confirm
    this pairing is intended.
    """
    label_colors = {0: '#3A49FF', 1: "#FFF800"}

    for key, profile in filtered_profiles.items():
        scores_by_label = {}
        # zip keeps only records that have a matching neighbour-label entry
        for weight, labels_of_neighbors in zip(profile['attention_weights'], profile['neighbor_labels']):
            for neighbor_label in labels_of_neighbors:
                scores_by_label.setdefault(neighbor_label, []).append(weight['attention_score'])

        fig, ax = plt.subplots(figsize=(10, 6))
        for neighbor_label, scores in scores_by_label.items():
            # `fill` replaces seaborn's deprecated `shade`; unknown labels get a default colour.
            sns.kdeplot(scores, label=f'Neighbor label: {neighbor_label}', fill=True,
                        color=label_colors.get(neighbor_label), ax=ax)

        ax.set_title(f'Attention score distribution for {key}')
        ax.set_xlabel('Attention score')
        ax.set_ylabel('Density')
        if scores_by_label:
            ax.legend()
        plt.show()
        plt.close(fig)

# Uses the neighbour-label version of plot_attention_distribution defined in this cell.
plot_attention_distribution(filtered_profiles)

In [None]:
def plot_neighbor_label_distribution(filtered_profiles):
    """Draw one bar chart per profile counting how often each neighbour label (0 or 1) occurs.

    Counts run over every list in ``profile['neighbor_labels']``; a label other than 0/1
    raises KeyError.
    """
    bar_colors = {0: '#636EFA', 1: "#E8E337"}

    for key, profile in filtered_profiles.items():
        counts = {0: 0, 1: 0}
        flat_labels = (lbl for group in profile['neighbor_labels'] for lbl in group)
        for lbl in flat_labels:
            counts[lbl] += 1

        plt.figure(figsize=(10, 6))
        plt.bar(counts.keys(), counts.values(),
                color=[bar_colors[lbl] for lbl in counts.keys()], alpha=0.7)
        plt.xticks([0, 1], ['Label 0', 'Label 1'])
        plt.title(f'Neighbor label distribution for {key}')
        plt.xlabel('Neighbor label')
        plt.ylabel('Count')
        plt.show()
# Count of neighbour labels (0/1) per profile.
plot_neighbor_label_distribution(filtered_profiles)

In [None]:
# Statistics from the full-dataset profiles, restricted to the test-set patients.
detailed_report(attention_profiles_fulldataset, G, patient_ids=test_patient_ids)