# Graph Neural Network
---
> Graph neural network model for vertex dynamics and tension prediction

**To-Do**👷🚧

- *Training loop*:
    - [ ] Training loop w/ validation set error monitoring, and best model saving
    - [ ] Combine `Message` and `AggregateUpdate` into a graph layer `GraphBlock` (a more general block/model that can be composed into a deep residual network), i.e. an "AddGN" block of the form `AddGN(x) = f(x) + x` (wherever possible, give all blocks this residual form).
- [ ] *Prediction stage*: read \{test, val, train\} data and predict w/ saving.
- [ ] Ablation dataset (*real*).
- [ ] Larger simul-n dataset.

**DOING**🛠

1. Validation-data `dataset` object and dataloader.

**Node-to-Cell Encoding/Pooling Layer**:
1. Initialise node-to-cell edge attr-s with the (source) node attr-s: `x[node2cell_index[0]]`.
1. Compute node-to-cell edge attr-s using an MLP: `e_n2c = MLP( x[node2cell_index[0]] )`.
1. Aggregate node-to-cell edge attr-s into cell attr-s: `x_cell = Aggregate(e_n2c)`.
1. Encode the aggregated `x_cell` into new cell attr-s with a cell encoder MLP: `h_cell = MLP_Cell_encoder( x_cell )`.

```python
n2c_model = mlp(...) # "message", just node-wise MLP
cell_aggr = Aggregate()
cell_enc = mlp(...)

e_n2c = n2c_model(data.x)[data.node2cell_index[0]]
x_cell = cell_aggr(data.cell_pressures.size(0), data.node2cell_index, e_n2c)
h_cell = cell_enc(x_cell)
```

**Examples**:
- General "Message Passing" schemes: a nice example for composite graph layer –"meta layer" consisting of "edge", "node" and "global" layers [link](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.meta.MetaLayer)

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
from os import path

from torch_geometric.data import Data, DataLoader
import torch_geometric.transforms as T

from torch_geometric.utils import to_networkx, from_networkx
import networkx as nx
from simgnn.datautils import load_array, load_graph

import matplotlib.pyplot as plt
import matplotlib
plt.style.use('ggplot')
%matplotlib inline
matplotlib.rcParams['figure.figsize'] = (10,10) # use larger for presentation
matplotlib.rcParams['font.size']= 14 # use 14 for presentation

In [3]:
from simgnn.datasets import VertexDynamics, CellData
from simgnn.nn import mlp, Message, AggregateUpdate, Aggregate
from simgnn.transforms import Pos2Vec, ScaleVelocity
# from torch_geometric.utils import to_undirected as T_undir

In [4]:
# Default floating-point precision used throughout the notebook.
dtype = torch.float32
# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Defaults:\n |-device: {device}\n |-dtype : {dtype}')

Defaults:
 |-device: cpu
 |-dtype : torch.float32


## Hara Movies

In [5]:
!ls 'simgnn_data/hara_movies/raw/Seg_001/'

edge_Length.npy     edges_index.npy     node2cell_index.npy vtx_pos.npy


In [6]:
# Load the four raw arrays of movie "Seg_001" (edge lengths, edge index,
# vertex positions and node-to-cell index), in that order.
l_ij, seg001_edge_index, seg001_vtx, seg001_n2c = (
    load_array('simgnn_data/hara_movies/raw/Seg_001/' + file_name)
    for file_name in ('edge_Length.npy', 'edges_index.npy',
                      'vtx_pos.npy', 'node2cell_index.npy')
)

# Report the shape of every loaded array, one per line.
print('\n'.join(f'{label} : {array.shape}'
                for label, array in (('l_ij', l_ij),
                                     ('seg001_edge_index', seg001_edge_index),
                                     ('seg001_vtx', seg001_vtx),
                                     ('seg001_n2c', seg001_n2c))))

l_ij : (61, 225, 1)
seg001_edge_index : (2, 225)
seg001_vtx : (61, 164, 2)
seg001_n2c : (2, 492)


In [14]:
# int64 tensor copy of the node-to-cell index, laid out contiguously in memory
n2c_ = torch.from_numpy(seg001_n2c).long().contiguous()
# sanity check: should display True
n2c_.is_contiguous()

True

In [41]:
# for s,t in zip(seg001_vtx[60][seg001_edge_index[0]],seg001_vtx[60][seg001_edge_index[1]]):
#     plt.plot([s[0],t[0]],[s[1],t[1]],'gray')

# for s,t in zip(seg001_vtx[0][seg001_edge_index[0]],seg001_vtx[0][seg001_edge_index[1]]):
#     plt.plot([s[0],t[0]],[s[1],t[1]],'r')
# plt.axis([0,511,511,0])

In [42]:
class HaraMovie(VertexDynamics):
    '''
    Vertex dynamics dataset built from segmented movies ("Hara movies").

    Every movie is a folder inside `root/raw` that is expected to contain the files read by
    `process()`: "vtx_pos.npy" (vertex positions), "edges_index.npy" (edge indices) and
    "node2cell_index.npy" (node-to-cell index pairs).
    '''
    def __init__(self, root, window_size=5, transform=None, pre_transform=None):
        '''
        Assumes `root` dir contains folder named `raw` with all vertex dynamics results
        for tracing vertex trajectories and building graphs.
        - Velocities are approximated as 1st order differences of positions `x` in subsequent frames:
          `velocity(T+0) = x(T+1) - x(T+0)`.
        - Use `pre_transform` for normalising and pre-processing dataset(s).

        Arg-s:
        - root : path to a root directory that contains folder with raw dataset(s) in a folder named "raw".
        Raw datasets should be placed into separate folders each containing outputs from a single movie.
        E.g. root contains ["raw", "processed", ...], and in folder "raw/" we should have ["mov1", "mov2", ...]
        - window_size : number of past velocities to be used as node features
        `[x(T+0)-x(T-1), x(T-1)-x(T-2),..., x(T-window_size+1)-x(T-window_size)]`, where `x(T)` is node position at time `T`.
        - transform :  transform(s) for graph datasets (e.g. from torch_geometric.transforms ), used in parent class' loading method.
        - pre_transform : transform(s) for data pre-processing (resulting graphs are saved in "processed" folder)
        and used as this dataset's sample graphs.
        '''
        super(HaraMovie, self).__init__(root, window_size, transform, pre_transform)

    @property
    def processed_file_names(self):
        '''
        Return list of pytorch-geometric data files in `root/processed` folder (`self.processed_dir`).

        File names have the form "data_{movie_folder_name}_{t}.pt", with `t` running over the
        processed frames of the movie.
        '''
        # "last_idx" : last index of window in "vertex velocity" (for features)
        # last_idx=T-(2+window_size) --> num of processed frames: num_of_frames=last_idx+1
        file_names = []
        for raw_path in self.raw_paths:
            mov_name = path.basename(raw_path)
            # total number of frames T is the leading dimension of the position array
            num_total_frames = load_array(path.join(raw_path, 'vtx_pos.npy')).shape[0]
            num_processed_frames = num_total_frames - (2 + self.window_size) + 1
            file_names.extend('data_{}_{}.pt'.format(mov_name, t)
                              for t in range(num_processed_frames))
        return file_names

    def process(self):
        '''
        Read raw movie folders and prepare the tensors needed for building graph samples.

        NOTE: work in progress -- the per-frame `CellData` construction and saving step is still
        disabled, so this method currently only loads and prepares the tensors of each movie.

        Assumptions:
        - the parent class init runs _process() and initialises all the required dir-s.
        - cell graph topology and number of nodes is constant w.r.t. to frames.
        '''
        for raw_path in self.raw_paths:
            # a movie in "raw_path"

            # Load node positions from raw_path and convert to (windowed) node attrib-s and targets.
            node_pos, X_node, Y_node = self.pos2nodeXY(pos_path = path.join(raw_path,'vtx_pos.npy') )

            # edge indices
            edge_index = torch.from_numpy( load_array( path.join( raw_path, 'edges_index.npy'))
                                         ).contiguous()

            # cell-to-node and node-to-cell "edge indices"
            node2cell_index = torch.from_numpy( load_array( path.join( raw_path, 'node2cell_index.npy'))
                                              ).contiguous() # node_id-cell_id pairs
            cell2node_index = node2cell_index[[1,0]].contiguous() # cell_id-node_id pairs

            mov_name = path.basename(raw_path) # folder name for the files
            N_nodes = node_pos.size(1) # assume constant w.r.t. "t"
            # Movie folders have no "graph_dict.pkl" loaded here, so the number of cells is
            # inferred from the largest cell id in the second (cell id) row of `node2cell_index`;
            # assumes cell ids start at 0 and num_of_cells is constant w.r.t. "t".
            if node2cell_index.numel() > 0:
                N_cells = int(node2cell_index[1].max()) + 1
            else:
                N_cells = 0
            
#             for t in range(node_pos.size(0)):
#                 data = CellData(num_nodes = N_nodes,
#                                 num_cells = N_cells,
#                                 edge_index = edge_index,
#                                 node2cell_index = node2cell_index, cell2node_index = cell2node_index,
#                                 pos = node_pos[t], x = X_node[t], y = Y_node[t],
#                                 cell_pressures = cell_presrs[t],
#                                 edge_tensions = edge_tensns[t]
#                                )
#                 if self.pre_filter is not None and not self.pre_filter(data):
#                     continue
#                 if self.pre_transform is not None:
#                     data = self.pre_transform(data)
#                 torch.save(data, path.join(self.processed_dir, 'data_{}_{}.pt'.format(sim_name, t)))



In [44]:
hara_movs = HaraMovie('simgnn_data/hara_movies/')

Processing...
Done!


In [66]:
# NOTE(review): the recorded expression was cut off after "hara_movs." (a syntax error);
# `window_size` is assumed because the recorded output (5) equals the constructor default -- confirm.
hara_movs.window_size

5

In [64]:
# hara_movs.processed_file_names

## Training w/ Synthetic Data

`CellData` prop-s (graph data objects):
- `x` : `(#nodes, WindowWidth, 2)` *node features*
- `y` : `(#nodes, 2)` *node targets (velocities)*.
- `pos` : `(#nodes, 2)` *node positions*.
- `edge_attr` : `(#edges, 2)` or `(#edges, #edge_features)` *edge features  (relative Cartesian positions of connected nodes)*.
- `edge_index` : `(2, #edges)` *edge indices*.
- `edge_tensions` : `(#edges,)` *edge targets (line tensions)*.
- `node2cell_index` : `(2, #cell2node_edges)`, `node2cell`-> *first row is node indices and second row is cell indices*.
- `cell2node_index` : `(2, #cell2node_edges)`, `cell2node`-> *first row is cell indices and second row is node indices*.
- `cell_pressures` : `(#cells,)` *cell targets (cell pressures)*.

In [5]:
# test = VertexDynamics('../../../dataDIR/simgnn_data/test/')

In [7]:
# Normalisation: for simulated data
Tnorm = T.Compose([Pos2Vec(scale=20*0.857) ,
                           ScaleVelocity(0.857)])
# training dataset
vtxdata = VertexDynamics('../../../dataDIR/simgnn_data/train/', transform=Tnorm)

In [12]:
print('',vtxdata,'\n',vtxdata[0])

 VertexDynamics(95) 
 CellData(cell2node_index=[2, 600], cell_pressures=[100], edge_attr=[339, 2], edge_index=[2, 339], edge_tensions=[339], node2cell_index=[2, 600], pos=[240, 2], x=[240, 5, 2], y=[240, 2])


In [7]:
data = vtxdata[0]
data

CellData(cell2node_index=[2, 600], cell_pressures=[100], edge_index=[2, 339], edge_tensions=[339], node2cell_index=[2, 600], pos=[240, 2], x=[240, 5, 2], y=[240, 2])

In [2]:
# data.is_undirected()

In [1]:
# nx.draw(to_networkx(data),pos=dict(enumerate(data.pos.numpy())), node_size=60)

In [44]:
# Mini-batches of two graphs. To track which graph of the mini-batch an entry of some
# variable belongs to, add the variable's key to "follow_batch":
loader = DataLoader(vtxdata, batch_size=2,follow_batch=['cell_pressures','edge_index'])
# this tracks batch ids in "cell_pressures_batch" and "edge_index_batch" in addition to
# the node batch ids ("batch")
batch = next(iter(loader))
# (disabled) Draw both graphs of the mini-batch, selecting nodes/edges by their batch ids.
# NOTE(review): the hard-coded 240 presumably is the node count of the first graph, used to
# shift the second graph's edge indices back to a zero-based range -- confirm before re-enabling.
# nx.draw( to_networkx(
#     CellData(num_nodes = torch.sum(batch.batch==0).item(),
#              edge_index = batch.edge_index[:,batch.edge_index_batch==0])),
#     pos=dict(enumerate(batch.pos[batch.batch==0].numpy())),
#     node_size=60, node_color='r',edge_color='r')
# nx.draw( to_networkx(
#     CellData(num_nodes = torch.sum(batch.batch==1).item(),
#              edge_index = batch.edge_index[:,batch.edge_index_batch==1]-240)),
#     pos=dict(enumerate(batch.pos[batch.batch==1].numpy() +1.5)),
#     node_size=60)
# display the batch object (last expression of the cell)
batch

Batch(batch=[480], cell2node_index=[2, 1200], cell_pressures=[200], cell_pressures_batch=[200], edge_attr=[678, 2], edge_index=[2, 678], edge_index_batch=[678], edge_tensions=[678], node2cell_index=[2, 1200], pos=[480, 2], x=[480, 5, 2], y=[480, 2])

## Training

**Training Loop**

In [None]:
# useful functions for model training and saving, etc.
import time
import copy
import torch

def train_model(model,
                optimizer,
                data_loaders,
                num_epochs = 5,
                loss_func = torch.nn.CrossEntropyLoss(),
                device = torch.device('cpu'),
                scheduler = None,
                return_best = False,
               classifier=None):
    '''
    Train a classification model and monitor loss/accuracy on training and validation data.

    Every epoch consists of a 'train' pass (with parameter updates) followed by a 'val' pass
    (no gradient tracking, no updates).

    Arg-s:
    - model : model used for the forward pass, `outputs = model(inputs)`; outputs are expected
      to have shape (batch, #classes).
    - optimizer : optimizer that updates the trainable parameters.
    - data_loaders : dict with keys 'train' and 'val'. Each value must be iterable over batches
      (mappings with keys 'image' and 'label') and expose a `.dataset` attribute supporting `len()`.
    - num_epochs : number of epochs.
    - loss_func : callable `loss_func(outputs, labels)` returning a scalar tensor. The running
      loss weights it by the batch size, i.e. it assumes a batch-averaged loss.
    - device : device the input and label batches are moved to (the model itself is NOT moved).
    - scheduler : optional LR scheduler; `scheduler.step()` is called after every training batch.
    - return_best : if True, the weights with the best validation accuracy are loaded into
      `model` before returning.
    - classifier : optional sub-module (transfer learning). If given, only this module is
      switched between train/eval modes; the forward pass still uses `model`.

    Returns:
    - model : the trained model (same object as the input `model`).
    - curve_data : dict with per-epoch lists 'trainLosses', 'trainAccs', 'valLosses', 'valAccs'
      (python floats) and 'total_epochs'.
    '''
    # model states/modes
    model_states = ['train', 'val']
    # transfer learning: only the classifier's mode is toggled
    training_model = model if classifier is None else classifier

    curve_data = {'trainLosses':[],
                 'trainAccs':[],
                 'valLosses':[],
                 'valAccs':[],
                 'total_epochs':num_epochs}

    time_start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict()) if return_best else None
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1} ---', end=' ')

        for state in model_states:
            is_training = (state == 'train')
            # training mode for 'train', evaluation mode for 'val'
            training_model.train(is_training)

            running_loss = 0.0
            running_corrects = 0 # plain python int, no device-bound tensors are kept

            for samples in data_loaders[state]:
                inputs = samples['image'].to(device)
                labels = samples['label'].to(device)

                # set grad accumulator to zero
                optimizer.zero_grad()

                # grad tracking is disabled in the 'val' state
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs) # output:(batch, #classes)
                    _, preds = torch.max(outputs, 1) # preds:(batch,)
                    loss = loss_func(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0) # weighted loss
                running_corrects += torch.sum(preds == labels.detach()).item()

                # apply LR schedule (once per training batch)
                if is_training and scheduler is not None:
                    scheduler.step()

            num_samples = len(data_loaders[state].dataset)
            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects / num_samples

            curve_data[f'{state}Losses'].append(epoch_loss)
            curve_data[f'{state}Accs'].append(epoch_acc)

            print(f'{state} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}',end=' || ')

            # remember the best validation accuracy (and weights, if requested)
            if state == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                if return_best:
                    # deep copy, otherwise later updates would overwrite the stored weights
                    best_model_wts = copy.deepcopy(model.state_dict())
        print(f'{time.time() - time_start:.0f}s')
    time_elapsed = time.time() - time_start
    print(f'Training complete in {time_elapsed//60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f} (return best:{return_best})')

    if return_best:
        # load best model weights
        model.load_state_dict(best_model_wts)

    return model, curve_data


@torch.no_grad()
def predict_check(data_loaders, model,device=torch.device('cpu')):
    '''
    Run prediction on datasets using dataloader
    - data_loaders: data loaders (dict of torch.utils.data.DataLoader objects) with
                    keys 'train' and 'val', for training and validation data loaders respectively
                    (other keys are accepted as well). Batches must be mappings with keys
                    'image' and 'label'.
                    ! DISABLE SHUFFLING in both datasets in order to preserve order of IDs
    - model: model used for prediction; it is switched to evaluation mode, temporarily moved
             to `device`, and moved back to its original device before returning
             (also if prediction fails).
    - device: device, e.g. "torch.device('cuda')"

    Returns `(losses, accuracies, pred_labels)`, dicts keyed by loader name:
    - losses: sample-weighted mean cross-entropy loss (float).
    - accuracies: fraction of correctly predicted labels (float).
    - pred_labels: list of predicted class indices in loader order.
    '''
    loss_func = torch.nn.CrossEntropyLoss()
    losses = {'train':0, 'val': 0}
    accuracies = {'train':0, 'val':0}
    pred_labels = {'train':[],'val':[]}

    model.eval()
    default_device = next(model.parameters()).device
    model.to(device)
    try:
        for loader_type in data_loaders:
            print(f'Loading: {loader_type}')
            total_loss = 0.0
            num_correct = 0 # python int: avoids integer-tensor division when averaging
            predicted = []
            for samples in data_loaders[loader_type]:
                inputs = samples['image'].to(device) # input images
                labels = samples['label'].to(device) # labels
                # predict
                outputs = model(inputs) # output:(batch, #classes)
                _, preds = torch.max(outputs, 1) # preds:(batch,)
                loss = loss_func(outputs, labels)
                total_loss += loss.item() * inputs.size(0) # weighted loss
                num_correct += torch.sum(preds == labels).item()
                predicted.extend(preds.cpu().tolist())

            num_samples = len(data_loaders[loader_type].dataset)
            losses[loader_type] = total_loss / num_samples
            accuracies[loader_type] = num_correct / num_samples
            pred_labels[loader_type] = predicted
    finally:
        # always restore the model's original device
        model.to(default_device)
    print('Losses:',losses)
    print('Accuracies:', accuracies)

    return losses, accuracies, pred_labels




In [None]:
@torch.no_grad()
def predict_test(root_path, model, transform, batch_size=4, device=torch.device('cpu')):
    '''Run prediction on test images.
    - root_path: path to the folder with test images ("*.png" files named "<integer ID>.png")
    - model: model used for prediction (switched to evaluation mode and moved to `device`)
    - transform: callable applied to every loaded image; must return a tensor, the tensors of
                 a batch are stacked along a new first dimension
    - batch_size: batch size for processing
    - device: device, e.g. "torch.device('cuda')"

    Returns a dict `{'ID': [...], 'Label': [...]}` with the image IDs (sorted in ascending
    order) and the predicted class index of every image. Both lists are empty if the folder
    contains no "*.png" files.
    '''
    import glob  # local import: `glob` is not imported in the visible part of this file
    # NOTE(review): `imread` must be provided by the surrounding module; it is not imported
    # in the visible part of this file -- confirm which image reader is intended.

    model.eval()
    model.to(device)
    # list of test image files
    test_image_names = [path.basename(imgname) for imgname in glob.glob(path.join(root_path,'*.png'))]
    # sort according to image ID number
    test_image_names.sort(key=lambda x: int(x.split('.')[0]))
    ID = [int(imgname.split('.')[0]) for imgname in test_image_names]
    N_samples = len(test_image_names)

    # preview of the found files: first and last three names, or all names if there are only a few
    if N_samples > 6:
        preview = ([f'"{name}"' for name in test_image_names[:3]] + ['. . .'] +
                   [f'"{name}"' for name in test_image_names[-3:]])
    else:
        preview = [f'"{name}"' for name in test_image_names]
    print('\n'.join([f'Found {N_samples} images in test dataset folder: {path.normpath(root_path)}']
                    + preview) + '\n')

    # iter over batches
    pred_labels = []
    N_batches = N_samples//batch_size + (1 if N_samples%batch_size else 0)
    print(f'Processing {N_batches} test batches in total (batch_size={batch_size}).')
    for first_idx in range(0, N_samples, batch_size):
        batch_names = test_image_names[first_idx:first_idx+batch_size]
        # read and transform images
        img_batch = torch.stack([transform( imread(path.join(root_path,imgname)) )
                                 for imgname in batch_names],dim=0)
        img_batch = img_batch.to(device)
        # predict
        outputs = model(img_batch) # output:(batch, #classes)
        _, preds = torch.max(outputs, 1) # preds:(batch,)
        pred_labels.extend(preds.cpu().tolist())
    print('Done.')
    return {'ID': ID, 'Label': pred_labels}
