# GIN < FC

In [67]:
import numpy as np
import pickle
import pandas as pd 
import matplotlib.pyplot as plt
from enum import Enum, auto
import typing
from typing import List, Union
import copy

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from torch_geometric.nn import GINConv, global_add_pool
from torch_geometric.data import Data, DataLoader

from torchinfo import summary
from sklearn.model_selection import StratifiedKFold

## Load data

In [2]:
DATA_FOLDER = '../data'
PICKLE_FOLDER = '../pickles'

In [3]:
df_metadata = pd.read_csv(f'{DATA_FOLDER}/patients-cleaned.csv', index_col=0)

In [4]:
df_metadata.head(2)

Unnamed: 0,age,sex,target
0,24.75,1,0
1,27.667,1,0


### Select connectivity dataset

In [5]:
THRESHOLD = 0.1                                          # 0.01, 0.05, 0.1, 0.15
N = 20                                                   # 3, 5, 7, 10, 15, 20, 40
CORR_TYPE = 'pearson'                                    # 'pearson', 'spearman', 'partial-pearson'
THRESHOLD_METHOD = 'abs-group-avg-diff'                  # 'abs-sample-diff', 'abs-group-avg-diff'
THRESHOLD_TYPE = 'max'                                   # 'min', 'max' or for kNN 'small', 'large'
KNN = False                                              # Whether all or only top N neigbors are taken

In [6]:
fc_folder = f'{PICKLE_FOLDER}/fc-{CORR_TYPE}{"-knn" if KNN else ""}-{THRESHOLD_METHOD}'

# Try Gini or SGD.
# fc_folder = f'{PICKLE_FOLDER}/fc-{CORR_TYPE}-gini'
# fc_folder = f'{PICKLE_FOLDER}/fc-{CORR_TYPE}-sgd'

In [7]:
fc_file_binary = f'{fc_folder}/{THRESHOLD_TYPE}-{f"knn-{N}" if KNN else f"th-{THRESHOLD}"}-binary.pickle'
fc_file_real = f'{fc_folder}/{THRESHOLD_TYPE}-{f"knn-{N}" if KNN else f"th-{THRESHOLD}"}-real.pickle'

# fc_file_binary = f'{fc_folder}/binary.pickle'
# fc_file_real = f'{fc_folder}/real.pickle'

In [8]:
with open(fc_file_binary, 'rb') as f:
    edge_index_matrix = pickle.load(f)

In [9]:
with open(fc_file_real, 'rb') as f:
    fc_matrix = pickle.load(f)

In [10]:
total_samples, total_brain_regions, _ = edge_index_matrix.shape
edge_index_matrix.shape

(190, 90, 90)

In [11]:
fc_matrix.shape

(190, 90, 90)

## Split data

In [12]:
with open(f'{PICKLE_FOLDER}/test-indices.pickle', 'rb') as f:
    test_indices = pickle.load(f)
    
train_indices = list(set(range(total_samples)) - set(test_indices))

In [13]:
train_targets = df_metadata.iloc[train_indices]["target"].reset_index(drop=True)

In [14]:
print(f'Train set size: {len(train_indices)}')
print(f'Test set size: {len(test_indices)}')

Train set size: 140
Test set size: 50


## Prepare data

In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

### `Data` object fields

- `data.x`: Node feature matrix with shape `[num_nodes, num_node_features]`

- `data.edge_index`: Graph connectivity in COO format with shape `[2, num_edges]` and type `torch.long`

- `data.edge_attr`: Edge feature matrix with shape `[num_edges, num_edge_features]`

- `data.y`: Target to train against (may have arbitrary shape), e.g., node-level targets of shape `[num_nodes, *]` or graph-level targets of shape `[1, *]`

- `data.pos`: Node position matrix with shape `[num_nodes, num_dimensions]`

### Select node features

- onehot
- correlations

In [16]:
# Each nodes contains its row from FC matrix.
def correlations_in_nodes(i):
    return torch.from_numpy(fc_matrix[i]).to(torch.float32)

In [17]:
# Each brain region is onehot encoded. See GIN for phenotype paper.
def onehot_in_nodes(i):
    return torch.diag(torch.ones(total_brain_regions))

In [18]:
features_in_nodes = correlations_in_nodes
num_features_in_nodes = total_brain_regions   

### Create dataset

In [19]:
dataset = [Data(
    x=features_in_nodes(i),  
    edge_index=torch.from_numpy(np.asarray(np.nonzero(edge_index_matrix[i]))).to(torch.int64),
    # y=torch.tensor([[1, 0]]  if target == 0 else [[0, 1]], dtype=torch.int64)
    y=torch.tensor([target], dtype=torch.int64)
).to(device) for target, i in zip(train_targets, train_indices)]

In [20]:
print(f'True train data: {len(dataset)}')

print('Data object')
print(f'Edge index: {dataset[0].edge_index.shape}')
print(f'Node features: {dataset[0].x.shape}')
print(f'Target: {dataset[0].y.shape}')

True train data: 140
Data object
Edge index: torch.Size([2, 7716])
Node features: torch.Size([90, 90])
Target: torch.Size([1])


## Define GIN < FC architectures

In [50]:
class ConnectivityEmbedding(nn.Module):
    """
    Learns connectivity between nodes. For each node a weighted combination of all the nodes is learned.

    Input: [batch_size, num_nodes, num_features]
    Output: [batch_size, num_nodes, num_features]
    """
    def __init__(self, size, dropout: 0.0):
        super(ConnectivityEmbedding, self).__init__()
        # Initialize with fully connected graph.
        self.fc_matrix = nn.Parameter(torch.ones(size, size), requires_grad=True)
        self.dropout = nn.Dropout(p=dropout)

    def toggle_gradients(self, requires_grad):
        self.fc_matrix.requires_grad = requires_grad


    def forward(self, x):
        # There is no non-linearity since we are just combining nodes.
        return self.dropout(torch.matmul(self.fc_matrix, x))


In [52]:
class ConnectivityMLP(nn.Module):
    """
    Runs node features through MLP.

    Input: [batch_size, num_nodes, num_in_features]
    Output: [batch_size, num_nodes, num_out_features]
    """
    def __init__(self, size_in, size_out, dropout):
        super(ConnectivityMLP, self).__init__()
        self.fc = nn.Linear(size_in, size_out)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        return F.relu(self.dropout(self.fc(x)))

In [54]:
class ConnectivityMode(Enum):
    """
    Determines how is connectivity matrix obtained.

    FIXED: Use handmade FC matrix.
    START: Learn FC matrix only on raw input features.
    SINGLE: Learn FC matrix on raw input features as well as all subsequent feature mapping layers.
    MULTIPLE: Learn new FC matrix before every feature mapping layer.
    """
    FIXED = auto()
    START = auto()
    SINGLE = auto()
    MULTIPLE = auto()


In [53]:
class ConnectivitySublayer(nn.Module):
    """
    Combines neighborhood connectivity with MLP transformation.

    Input: [batch_size, num_nodes, num_in_features]
    Output: [batch_size, num_nodes, num_out_features]
    """
    def __init__(self, id: int, size_in: int, size_out: int, dropout: float, mode: ConnectivityMode, **mode_kwargs):
        super(ConnectivitySublayer, self).__init__()
        # Create new FC matrix for this sublayer.
        if mode == ConnectivityMode.MULTIPLE:
            self.fc_matrix = ConnectivityEmbedding(size_in, dropout=mode_kwargs['dropout'])
        # Used passed in FC matrix.
        else:
            self.fc_matrix = mode_kwargs['fc_matrix']

        # Feature mapping layer.
        self.mlp = ConnectivityMLP(size_in, size_out, dropout)

    def forward(self, x):
        # Aggregate feature vectors based on connectivity neighborhood.
        x = self.fc_matrix(x)
        # Map features.
        x = self.mlp(x)
        return x       


In [76]:
class ConnectivityDenseNet(nn.Module):
    """
    Emulates Graph isomorphism network using a fully connected alternative.
    """
    def __init__(
        self, 
        num_nodes: int, 
        mode: ConnectivityMode, 
        num_in_features: int, 
        num_hidden_features: Union[int, List[int]],
        dropout: float = 0.5,
        connectivity_dropout: float = 0.0, 
        num_sublayers: int = -1,
        readout: str = 'add', 
        **mode_kwargs
    ):
        super(ConnectivityDenseNet, self).__init__()

        self.fc_matrix = None
        # Set passed in FC matrix.
        if mode == ConnectivityMode.FIXED:
            self.fc_matrix = mode_kwargs['fc_matrix']
        # Create single FC matrix that will be learned only at the beggining.
        # or
        # Create single FC matrix that will be learned throughout.
        elif (mode == ConnectivityMode.START) or (mode == ConnectivityMode.SINGLE):
            self.fc_matrix = ConnectivityEmbedding(num_nodes, dropout=connectivity_dropout)
        # Else `ConnectivityMode.MULTIPLE`, let each sublayer create its own FC matrix.
        self.mode_kwargs = {
            'fc_matrix': self.fc_matrix,
            'dropout': connectivity_dropout
        }

        # Prepare feature mapping dimensions.
        if type(num_hidden_features) is int:
            num_out_features = np.repeat(num_hidden_features, num_sublayers)
        num_in_features = copy.copy(num_out_features)
        num_out_features[0] = num_nodes

        # Create model stacked from sublayers: connectivity + feature mapping.
        self.sublayers = nn.ModuleList([
            ConnectivitySublayer(
                size_in, size_out, dropout=dropout, mode=mode, mode_kwargs=self.mode_kwargs
            ) for size_in, size_out in zip(num_in_features, num_out_features)
        ])

        # Classification head.
        self.readout = readout
        self.fc = nn.Linear(num_out_features[-1], 2)
        

    def forward(self, x):
        # Run sample through model.
        for sublayer in self.sublayers:
            x = sublayer(x)

        # Binary classification head.
        # Readout across nodes.
        if self.readout == 'add':
            x = torch.sum(x, dim=1)
        elif self.readout == 'mean':
            x = torch.mean(x, dim=1)
        if self.readout == 'max':
            x = torch.max(x, dim=1)

        # Return binary logits.
        return self.fc(x)
        


In [28]:
class MLP(torch.nn.Module):
    def __init__(self, inchan, outchan):
        super(MLP, self).__init__()
        
        self.fc = Linear(inchan, outchan)
        self.activation = ReLU()
        
    def forward(self, x):
        x = self.activation(self.fc(x))
        
        return x

In [44]:
class GIN(torch.nn.Module):
    
    def __init__(self, depth, hidchan=total_brain_regions, eps=0.):
        super(GIN, self).__init__()

        self.convs = torch.nn.ModuleList([GINConv(MLP(hidchan, hidchan), eps=eps) for _ in range(depth)])
        self.final_fc = Linear(hidchan, 2)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Message passing.
        for conv in self.convs:
            x = conv(x, edge_index)

        # Readout.
        x = global_add_pool(x, batch)

        # Final FC.
        x = self.final_fc(x)
        
        return x

In [45]:
# Architecture FC.
# summary(FC(depth=3))

In [46]:
# Architecture GIN.
summary(GIN(depth=3))

Layer (type:depth-idx)                   Param #
GIN                                      --
├─ModuleList: 1-1                        --
│    └─GINConv: 2-1                      --
│    │    └─MLP: 3-1                     8,190
│    └─GINConv: 2-2                      --
│    │    └─MLP: 3-2                     8,190
│    └─GINConv: 2-3                      --
│    │    └─MLP: 3-3                     8,190
├─Linear: 1-2                            182
Total params: 24,752
Trainable params: 24,752
Non-trainable params: 0

## Evaluation

In [47]:
def evaluation_metrics(predicted, labels):
    pred_positives = predicted == 1
    label_positives = labels == 1

    tp = (pred_positives & label_positives).sum().item()
    tn = (~pred_positives & ~label_positives).sum().item()
    fp = (pred_positives & ~label_positives).sum().item()
    fn = (~pred_positives & label_positives).sum().item()

    return tp, tn, fp, fn

## Train model

In [48]:
NUM_FOLDS = 3

In [49]:
skf = StratifiedKFold(n_splits=NUM_FOLDS, random_state=42, shuffle=True)

In [50]:
# Settings.
EPOCHS = 200
LR = 0.001
MOMENTUM = 0.5
OPTIMIZER = 'adam'
LOSS = 'ce'
BATCH_SIZE = 2

VALIDATE_FREQ = 10

DEPTH = 3
EPS = 0.2

STEP_SIZE = 50
GAMMA = 0.5

WEIGHT_DECAY = 0.0001

settings_str = f'bs={BATCH_SIZE},e={EPOCHS},lr={LR},mom={MOMENTUM},opt={OPTIMIZER},loss={LOSS},step={STEP_SIZE},gamma={GAMMA},wd={WEIGHT_DECAY},eps={EPS}'

In [51]:
# Experiment folder.
EXP_FOLDER = 'runs/fc-vs-gin'

In [52]:
# Experiment.
EXP_ID = 1

In [53]:
for kfold, (train_index, val_index) in enumerate(skf.split(np.zeros(len(train_targets)), train_targets)):

    # Init TB writer.
    experiment_str = f'id={EXP_ID:03d},fold={kfold},{settings_str}'
    writer_FC = SummaryWriter(f"../{EXP_FOLDER}/FC/{experiment_str}")
    writer_GIN = SummaryWriter(f"../{EXP_FOLDER}/GIN/{experiment_str}")

    # Init models.
    net_GIN = GIN(depth=DEPTH, eps=EPS).to(device)
    optimizr_GIN = torch.optim.Adam(net_GIN.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    criterion_GIN = torch.nn.CrossEntropyLoss()

    net_FC = FC(depth=DEPTH).to(device)
    optimizr_FC = torch.optim.Adam(net_FC.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    criterion_FC = torch.nn.CrossEntropyLoss()

    # Save architecture.
    with open(f"../{EXP_FOLDER}/GIN/{experiment_str}/architecture", 'w', encoding="utf-8") as f:
        f.write(fc_folder + '\n')
        f.write(fc_file_binary + '\n')
        f.write(fc_file_real + '\n')
        f.write(features_in_nodes.__str__() + '\n')
        f.write('\n'.join(experiment_str.split(',')) + '\n\n')
        f.write(net_FC.__str__() + '\n\n')
        f.write(net_GIN.__str__() + '\n\n')
        f.write(str(summary(net_FC)))
        f.write(str(summary(net_GIN)))

    # Prepare data.
    X_train = [dataset[i] for i in train_index]
    X_val = [dataset[i] for i in val_index]
    
    trainloader = DataLoader(X_train, batch_size=BATCH_SIZE, shuffle=True)
    valloader = DataLoader(X_val, batch_size=BATCH_SIZE, shuffle=False)

    # Train.
    for epoch in range(EPOCHS):
        running_loss_GIN = 0.
        running_loss_FC = 0.
        
        net_FC.train()
        net_GIN.train()

        for data in trainloader:
            
            optimizr_GIN.zero_grad()
            optimizr_FC.zero_grad()

            outputs_GIN = net_GIN(data)
            outputs_FC = net_FC(data)

            loss_GIN = criterion_GIN(outputs_GIN, data.y)
            loss_FC = criterion_FC(outputs_FC, data.y)

            loss_GIN.backward()
            loss_FC.backward()

            optimizr_GIN.step()
            optimizr_FC.step()

            running_loss_GIN += loss_GIN.item()
            running_loss_FC += loss_FC.item()

        epoch_loss_GIN = running_loss_GIN / len(trainloader)
        epoch_loss_FC = running_loss_FC / len(trainloader)

        writer_GIN.add_scalar('training loss', epoch_loss_GIN, epoch)
        writer_FC.add_scalar('training loss', epoch_loss_FC, epoch)

        running_loss_GIN = 0.
        running_loss_FC = 0.

        # Evaluate epoch.
        tp_GIN, tn_GIN, fp_GIN, fn_GIN = 0, 0, 0, 0
        tp_FC, tn_FC, fp_FC, fn_FC = 0, 0, 0, 0
        total = 0

        net_GIN.eval()
        net_FC.eval()
        for data in valloader:
            optimizr_GIN.zero_grad()
            optimizr_FC.zero_grad()

            outputs_GIN = net_GIN(data)
            outputs_FC = net_GIN(data)

            loss_GIN = criterion_GIN(outputs_GIN, data.y)
            loss_FC = criterion_FC(outputs_FC, data.y)

            running_loss_GIN += loss_GIN.item()
            running_loss_FC += loss_FC.item()

            if (epoch+1) % VALIDATE_FREQ == 0:
                predicted_GIN = outputs_GIN.argmax(dim=1)
                predicted_FC = outputs_FC.argmax(dim=1)

                # labels = torch.nonzero(data.y, as_tuple=True)[1]
                labels = data.y.view(-1)

                # Update.
                _tp_GIN, _tn_GIN, _fp_GIN, _fn_GIN = evaluation_metrics(predicted_GIN, labels)
                _tp_FC, _tn_FC, _fp_FC, _fn_FC = evaluation_metrics(predicted_FC, labels)

                tp_GIN += _tp_GIN; tn_GIN += _tn_GIN; fp_GIN += _fp_GIN; fn_GIN += _fn_GIN
                tp_FC += _tp_FC; tn_FC += _tn_FC; fp_FC += _fp_FC; fn_FC += _fn_FC

                total += data.y.size(0)

        epoch_loss_GIN = running_loss_GIN / len(valloader)
        epoch_loss_FC = running_loss_FC / len(valloader)

        writer_GIN.add_scalar('validation loss', epoch_loss_GIN, epoch)
        writer_FC.add_scalar('validation loss', epoch_loss_FC, epoch)

        if (epoch+1) % VALIDATE_FREQ == 0:
            writer_GIN.add_scalar('validation accuracy', (tp_GIN + tn_GIN) / total, epoch)
            writer_GIN.add_scalar('validation precision', tp_GIN / (tp_GIN + fp_GIN) if (tp_GIN + fp_GIN) > 0 else 0, epoch)
            writer_GIN.add_scalar('validation recall', tp_GIN / (tp_GIN + fn_GIN), epoch)
    
            writer_FC.add_scalar('validation accuracy', (tp_FC + tn_FC) / total, epoch)
            writer_FC.add_scalar('validation precision', tp_FC / (tp_FC + fp_FC) if (tp_FC + fp_FC) > 0 else 0, epoch)
            writer_FC.add_scalar('validation recall', tp_FC / (tp_FC + fn_FC), epoch)

    # Single fold during exploration.
    #break

print('Finished training')

TypeError: linear(): argument 'input' (position 1) must be Tensor, not tuple