In [1]:
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
from collections import OrderedDict

from src.Dataset import HiC_Dataset
from src.layers.WEGATConv import Deep_WEGATv2_Conv
from src.layers.utils import PositionalEncoding

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss as CEL
from torch.nn import Parameter, Linear, Sequential, Dropout, BatchNorm1d
from torch.utils.data import random_split
from torch.optim.lr_scheduler import OneCycleLR

import torch_geometric as tgm
from torch_geometric.data import DataLoader
from torch_geometric.nn import TopKPooling as TKP
from torch_geometric.nn import global_max_pool

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

DATASET = "Data/test_dset_18features_custom_norm.pt"
NUMEPOCHS = 10
NUMCLASSES = 3
BATCHSIZE = 500
LEARNING_RATE = 0.00005
WEIGHT_DECAY = 5e-4
MANUAL_SEED = 40
TRAIN_FRACTION = 0.7

# Modules

In [2]:
'''
UTILITY FUNCTIONS
'''
def get_middle_features(x,
                        numnodes = 51
                       ):
    mid = int((numnodes-1)/2)
    idxs = torch.arange(mid, x.shape[0], numnodes)
    return x[idxs,:]
    
'''
WEIGHTED EDGE GRAPH ATTENTION MODULE
'''
class WEGATModule(torch.nn.Module):
    def __init__(self,
                 hidden_channels=20,
                 numchip = 18,
                 numedge = 2,
                 heads = 4,
                 num_graph_convs = 6,
                 embedding_layers = 5,
                 num_fc = 8,
                 fc_channels = [15,15,15,10,10,10,5,2],
                 num_prom_fc = 10,
                 prom_fc_channels = [15,15,15,15,10,10,10,10,5,2],
                 positional_encoding = True,
                 pos_embedding_dropout = 0.1,
                 fc_dropout = 0.5,
                 conv_dropout = 0.1,
                 numclasses = NUMCLASSES
                ):
        if isinstance(fc_channels,int):
            fc_channels = [fc_channels]*num_fc
        elif len(fc_channels) != num_fc:
            print("number of fully connected channels must match the number of fully connected layers")
            raise

        if num_graph_convs < 1:
            print("need at least one graph convolution")
            raise
        num_graph_convs = int(num_graph_convs)

        if isinstance(prom_fc_channels,int):
            prom_fc_channels = [prom_fc_channels]*num_prom_fc
        elif len(prom_fc_channels) != num_prom_fc:
            raise

        super().__init__()
        torch.manual_seed(12345)
        #dropout layer
        self.dropout = Dropout(p=fc_dropout)

        #number of input chip features
        self.numchip = numchip

        #Whether to apply positional encoding to nodes
        self.positional_encoding = positional_encoding
        if positional_encoding:
            self.posencoder = PositionalEncoding(hidden_channels,
                                                 dropout=pos_embedding_dropout,
                                                 identical_sizes = True
                                                )

        #initial embeddding layer
        embedding = []
        embedding.append(Linear(numchip,
                                hidden_channels)
                        )
        embedding.append(torch.nn.Dropout(p=fc_dropout))
        embedding.append(torch.nn.ReLU())
        for idx in torch.arange(embedding_layers - 1):
            embedding.append(Linear(hidden_channels,
                                    hidden_channels)
                            )
            embedding.append(torch.nn.Dropout(p=fc_dropout))
            embedding.append(BatchNorm1d(hidden_channels))
            embedding.append(torch.nn.ReLU())
        self.embedding = Sequential(*embedding)

        #graph convolution layers
        enc = Deep_WEGATv2_Conv(node_in_channels = hidden_channels,
                                node_out_channels = hidden_channels,
                                edge_in_channels = numedge,
                                edge_out_channels = numedge,
                                heads = heads,
                                node_dropout = conv_dropout,
                                edge_dropout = conv_dropout
                               )
        gconv = [enc]
        encdec = Deep_WEGATv2_Conv(node_in_channels = hidden_channels,
                                   node_out_channels = hidden_channels,
                                   edge_in_channels = numedge,
                                   edge_out_channels = numedge,
                                   heads = heads,
                                   node_dropout = conv_dropout,
                                   edge_dropout = conv_dropout
                                    )
        for idx in np.arange(num_graph_convs-1):
            gconv.append(encdec)

        self.gconv = Sequential(*gconv)

        #fully connected channels
        fc_channels = [hidden_channels]+fc_channels
        lin = []
        for idx in torch.arange(num_fc):
            lin.append(Linear(fc_channels[idx],fc_channels[idx+1]))
            lin.append(torch.nn.Dropout(p=fc_dropout))
            lin.append(BatchNorm1d(fc_channels[idx+1]))
            lin.append(torch.nn.ReLU())
        self.lin = Sequential(*lin)
        self.num_fc = num_fc

        #fully connected promoter channels
        prom_fc_channels = [numchip]+prom_fc_channels
        linprom = []
        for idx in torch.arange(num_prom_fc):
            linprom.append(Linear(prom_fc_channels[idx],prom_fc_channels[idx+1]))
            linprom.append(torch.nn.Dropout(p=fc_dropout))
            lin.append(BatchNorm1d(prom_fc_channels[idx+1]))
            linprom.append(torch.nn.ReLU())
        self.linprom = Sequential(*linprom)
        self.num_prom_fc = num_prom_fc

        #final readout function
        self.readout = Linear(prom_fc_channels[-1]+fc_channels[-1], numclasses)

    def forward(self,
                batch):
        batch.prom_x = batch.prom_x.view(-1,self.numchip).float()
        batch.edge_attr[torch.isnan(batch.edge_attr)] = 0
        batch.x[torch.isnan(batch.x)] = 0
        batch.prom_x[torch.isnan(batch.prom_x)] = 0
        
        #initial dropout and embedding
        batch.x = self.dropout(batch.x)
        batch.x = self.embedding(batch.x.float())

        #positional encoding
        if self.positional_encoding:
            batch.x = self.posencoder(batch.x,
                                      batch.batch)

        #graph convolutions
        batch = self.gconv(batch)

        #extracting node of interest from graph
        x = get_middle_features(batch.x)

        # 3. Apply fully connected linear layers to graph
        x = self.lin(x)

        # 3. Apply fully connected linear layers to promoter
        prom_x = self.linprom(batch.prom_x)

        r_x = torch.cat([x,prom_x],
                        dim = 1)
        
        # 4. Apply readout layers
        x = self.readout(r_x)

        return x

# Lightning Net

In [8]:
'''
LIGHTNING NET
'''
class LitWEGATNet(pl.LightningModule):
    def __init__(self,
                 module,
                 train_loader,
                 val_loader,
                 learning_rate,
                 numsteps,
                 criterion
                ):
        super().__init__()
        self.module = module
        self.learning_rate = learning_rate
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.numsteps = numsteps
        self.criterion = criterion

    def train_dataloader(self):
        return self.train_loader
    
    def validation_dataloader(self):
        return self.test_loader

    def shared_step(self, batch):
        pred = self.module(batch).squeeze()
        loss = self.criterion(pred, batch.y)
        return loss, pred

    def customlog(self, name, loss, pred):
        self.log(f'{name}_loss', loss)
        
    def training_step(self, batch, batch_idx):
        loss, pred = self.shared_step(batch)
        self.customlog('train',loss, pred)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, pred = self.shared_step(batch)
        self.customlog('val',loss, pred)
        return loss

    def test_step(self, batch, batch_idx):
        loss, pred = self.shared_step(batch)
        self.customlog('test',loss, pred)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(),
                                     lr=self.learning_rate)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': OneCycleLR(optimizer,
                                        max_lr = 10*self.learning_rate,
                                        total_steps = self.numsteps
                                       )
            }
        }



# Data Loaders

In [4]:
'''
CONSTRUCTING THE DATALOADERS
'''
print("Loading in memory datasets")
dset = torch.load(DATASET,map_location=torch.device('cpu'))

vals = []
for d in dset:
    v = d.y.item()
    vals.append(v)
down, up = np.percentile(vals, 10),np.percentile(vals,90)

classes = ('downregulated','insignificant change','upregulated')
nums = [0,0,0]
for d in dset:
    v = d.y.item()
    if v <= down:
        idx = 0
    elif v >= up:
        idx = 2
    else:
        idx = 1
    d.y = idx
    nums[idx] += 1

criterion = CEL(weight = Tensor([np.max(nums)/item for item in nums]))    
numdatapoints = len(dset)
trainsize = int(numdatapoints*TRAIN_FRACTION)
train_dset, val_dset = random_split(dset,
                                    [trainsize, numdatapoints-trainsize],
                                    generator=torch.Generator().manual_seed(MANUAL_SEED)
                                   )

print("Loaded in memory datasets")
train_loader = DataLoader(train_dset, 
                              batch_size=BATCHSIZE,
                              num_workers=20
                             )
val_loader = DataLoader(val_dset, 
                             batch_size=BATCHSIZE,
                            num_workers=20
                           )

Loading in memory datasets
Loaded in memory datasets


In [10]:
NUMCHIP = dset[0].x.shape[1]
NUMEDGE = dset[0].edge_attr.shape[1]
module = WEGATModule(hidden_channels = 20,
                         numchip = NUMCHIP,
                         numedge = NUMEDGE,
                         embedding_layers = 4,
                         positional_encoding = True,
                         pos_embedding_dropout = 0,
                         fc_dropout = 0.01,
                         conv_dropout = 0.01
                        )
for dat in val_loader:
    print(module(dat))
    

tensor([[ 0.1053,  0.1748, -0.3319],
        [ 0.0829,  0.2445, -0.3212],
        [ 0.1354,  0.0813, -0.3463],
        ...,
        [ 0.1603,  0.0040, -0.3583],
        [ 0.0598,  0.3160, -0.3101],
        [ 0.0567,  0.3258, -0.3086]], grad_fn=<AddmmBackward>)
tensor([[ 0.0229,  0.4741,  0.1653],
        [ 0.1530,  0.0431, -0.1814],
        [ 0.0143,  0.5117,  0.2854],
        ...,
        [ 0.1481,  0.0418, -0.3524],
        [ 0.0567,  0.3258, -0.3086],
        [ 0.1367,  0.0867, -0.2475]], grad_fn=<AddmmBackward>)


KeyboardInterrupt: 

# Training

In [9]:
NUMCHIP = dset[0].x.shape[1]
NUMEDGE = dset[0].edge_attr.shape[1]

module = WEGATModule(hidden_channels = 20,
                     numchip = NUMCHIP,
                     numedge = NUMEDGE,
                     embedding_layers = 4,
                     positional_encoding = True,
                     pos_embedding_dropout = 0.05,
                     fc_dropout = 0.05,
                     conv_dropout = 0.05
                        )
Net = LitWEGATNet(module,
                  train_loader,
                  val_loader,
                  LEARNING_RATE,
                  50000,
                  criterion
                 )
    

tb_logger = pl_loggers.TensorBoardLogger('runs',
                                         name = 'printing_test',
                                         version = 0
                                        )
trainer = pl.Trainer(gpus=0,
                     max_epochs=100,
                     progress_bar_refresh_rate=1,
                     #logger=tb_logger,
                     auto_lr_find=False)

trainer.fit(Net, train_loader, val_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | module    | WEGATModule      | 18.5 K
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
18.5 K    Trainable params
0         Non-trainable params
18.5 K    Total params
0.074     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]



# Model Analysis

In [3]:
# Start tensorboard.
%reload_ext tensorboard
%tensorboard --logdir runs/ --port=8081

ERROR: Failed to launch TensorBoard (exited with 255).
Contents of stderr:
TensorFlow installation not found - running with reduced feature set.
E0519 13:54:06.474670 47233012180864 program.py:311] TensorBoard could not bind to port 8081, it was already in use
ERROR: TensorBoard could not bind to port 8081, it was already in use