In [None]:
import json
import os
from pathlib import Path

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as gnn
from scipy.linalg import block_diag
from torch.optim import Adam
from torch_geometric.data import Data, DataLoader
from tqdm.auto import tqdm

In [None]:
def get_couples(structure):
    """
    Return the base-pair index couples encoded by a dot-bracket structure.

    Each "(" is matched with its closing ")" using a stack, so nested pairs are
    resolved innermost-first. Every pair is emitted in both directions, which
    makes the result directly usable as a symmetric edge list.

    Parameters
    ----------
    structure : str
        Dot-bracket string; only "(" and ")" are interpreted, every other
        character is treated as unpaired.

    Returns
    -------
    list of tuple(int, int)
        ``(open_idx, close_idx)`` followed by ``(close_idx, open_idx)`` for
        each pair, in the order the closing brackets appear.

    Raises
    ------
    ValueError
        If a ")" has no unmatched "(" before it, or a "(" is never closed.
    """
    open_stack = []
    couples = []
    for idx, char in enumerate(structure):
        if char == "(":
            open_stack.append(idx)
        elif char == ")":
            if not open_stack:
                # would otherwise pair with a stale/undefined opener
                raise ValueError(f"Unmatched ')' at position {idx} in structure")
            open_idx = open_stack.pop()
            couples.append((open_idx, idx))
            couples.append((idx, open_idx))
    if open_stack:
        raise ValueError(f"Unmatched '(' at position(s) {open_stack} in structure")
    return couples


def build_matrix(couples, size):
    """
    Dense adjacency matrix for an RNA of length ``size``.

    Sequential neighbours (i, i+1) and (i+1, i) are marked with 1. Every
    pair listed in ``couples`` is marked with 2 in both directions, which
    overrides a neighbour mark where the two coincide.
    """
    # superdiagonal + subdiagonal = links between consecutive bases
    adjacency = np.eye(size, k=1) + np.eye(size, k=-1)
    for left, right in couples:
        adjacency[left, right] = adjacency[right, left] = 2
    return adjacency


In [None]:
def seq2nodes(sequence, loop_type):
    """
    One-hot node features for each nucleotide.

    The first 4 columns encode the base (A, G, U, C) and the next 7 encode the
    loop type (E, H, M, I, X, S, B). Positions are paired with ``zip``, so if
    ``loop_type`` is shorter than ``sequence`` the trailing rows stay all-zero.
    Unknown symbols raise ``KeyError``.
    """
    base_index = {'A': 0, 'G': 1, 'U': 2, 'C': 3}
    loop_index = {'E': 0, 'H': 1, 'M': 2, 'I': 3, 'X': 4, 'S': 5, 'B': 6}
    n_positions = len(sequence)
    base_onehot = np.zeros((n_positions, len(base_index)))
    loop_onehot = np.zeros((n_positions, len(loop_index)))
    for pos, (base, loop) in enumerate(zip(sequence, loop_type)):
        base_onehot[pos, base_index[base]] = 1
        loop_onehot[pos, loop_index[loop]] = 1
    return np.concatenate([base_onehot, loop_onehot], axis=-1)

def seq2edge_index(structure):
    """
    Build the nucleotide-level edge list for a dot-bracket structure.

    Returns
    -------
    edge_index : np.ndarray, shape (2, n_edges), int
        Base-pair edges (both directions, sorted), then forward neighbour
        edges (i -> i+1), then backward neighbour edges (i+1 -> i).
    edges_type : np.ndarray, shape (n_edges,)
        1 for base-pair edges, 2 for neighbour edges.
    """
    pairs = sorted(set(get_couples(structure)))
    # reshape keeps a (2, 0) array when the structure has no pairs; a plain
    # np.array([]).T would be 1-D and break shape[1] and the concatenation
    couples = np.array(pairs, dtype=int).reshape(-1, 2).T
    n = len(structure)
    neig = np.array([np.arange(0, n - 1), np.arange(1, n)])
    neig2 = neig[::-1, ::]
    edge_index = np.concatenate([couples, neig, neig2], axis=1)
    edges_type = np.array([1] * couples.shape[1] + [2] * neig.shape[1] * 2)
    return edge_index, edges_type

def edge_index2features(edge_index, edges_type, node_features):
    """
    Per-edge features for the nucleotide graph.

    Columns: [type == 1, type == 2, dst == src + 1, src == dst + 1], as floats.
    ``node_features`` is accepted for interface compatibility but is not used.
    """
    src = edge_index[0]
    dst = edge_index[1]
    type_onehot = np.stack([edges_type == 1, edges_type == 2], axis=1).astype(float)
    direction = np.stack([dst - src == 1, src - dst == 1], axis=1).astype(int)
    return np.concatenate([type_onehot, direction], axis=-1)

def seq2edges(structure, node_features):
    """
    Nucleotide-level edges of ``structure`` together with their feature rows.

    Returns ``(edge_index, edge_features)`` built by ``seq2edge_index`` and
    ``edge_index2features``.
    """
    index_and_types = seq2edge_index(structure)
    features = edge_index2features(*index_and_types, node_features)
    return index_and_types[0], features

def cg2edges(cg_graph, node2idx):
    """
    Edges linking coarse-grained ("bungle") nodes to nucleotides and to each other.

    For every nucleotide position covered by a coarse node's segments, two edges
    are added: coarse -> nucleotide (feature [1, 0, 0]) and nucleotide -> coarse
    (feature [0, 1, 0]). Each pair in ``cg_graph['edges']`` adds a single
    coarse -> coarse edge (feature [0, 0, 1]) in the listed direction.

    Parameters
    ----------
    cg_graph : dict
        Must contain ``'nodes'`` (node name -> list of segments, each unpacked
        into ``range``, e.g. ``[start, end)``) and ``'edges'`` (iterable of
        ``(node_1, node_2)`` name pairs).
    node2idx : dict
        Maps coarse node names to their row index in the node feature matrix.

    Returns
    -------
    indexes : np.ndarray, shape (2, n_edges), int
    features : np.ndarray, shape (n_edges, 3), int
    """
    indexes = []
    features = []
    for node_name, segments in cg_graph['nodes'].items():
        node_idx = node2idx[node_name]
        for seg in segments:
            for idx in range(*seg):
                indexes.append((node_idx, idx))
                features.append([1, 0, 0])
                indexes.append((idx, node_idx))
                features.append([0, 1, 0])
    for node_1, node_2 in cg_graph['edges']:
        indexes.append((node2idx[node_1], node2idx[node_2]))
        features.append([0, 0, 1])
    # reshape keeps the documented 2-D shapes even when there are no edges,
    # so downstream concatenate/block_diag calls still line up
    indexes = np.array(indexes, dtype=int).reshape(-1, 2).T
    features = np.array(features, dtype=int).reshape(-1, 3)
    return indexes, features

def create_edges(structure, node_features, cg_graph, node2idx):
    """
    Full edge list and edge features for the combined nucleotide + coarse graph.

    Nucleotide edges and coarse-grained edges are stacked (nucleotide edges
    first). Their type features are combined block-diagonally, so each edge's
    feature block from the other family is zero. The feature rows of each
    edge's source and destination node are then appended.

    Returns
    -------
    edge_index : np.ndarray, shape (2, n_edges)
    edge_features : np.ndarray, shape (n_edges, n_type_cols + 2 * node_dim)
    """
    nuc_index, nuc_feats = seq2edges(structure, node_features)
    cg_index, cg_feats = cg2edges(cg_graph, node2idx)
    edge_index = np.concatenate([nuc_index, cg_index], axis=1)
    type_block = block_diag(nuc_feats, cg_feats)
    src_feats = node_features[edge_index[0]]
    dst_feats = node_features[edge_index[1]]
    return edge_index, np.concatenate([type_block, src_feats, dst_feats], axis=1)

def bungle_features(cg_nodes):
    """
    Feature rows for coarse-grained ("bungle") nodes.

    Each row is a one-hot of the node type, read from the first character of
    the node name (f, t, s, i, m, h). It is followed by one column with the
    total span ``seg[1] - seg[0]`` summed over the node's segments.
    """
    type_dict = {'f': 0, 't': 1, 's': 2, 'i': 3, 'm': 4, 'h': 5}
    n_types = len(type_dict)
    features = np.zeros((len(cg_nodes), n_types + 1))
    # iterating the array yields writable row views
    for row, (name, segments) in zip(features, cg_nodes.items()):
        row[type_dict[name[0]]] = 1
        row[n_types] = sum(seg[1] - seg[0] for seg in segments)
    return features

def add_bungle_nodes(x, cg_graph):
    """
    Append coarse-grained node rows to the nucleotide feature matrix.

    A constant-1 column is appended to both the nucleotide features ``x`` and
    the coarse node features. The two blocks are then combined
    block-diagonally, so nucleotide and coarse feature columns never overlap.

    Returns
    -------
    features : np.ndarray
        Nucleotide rows first, then one row per coarse node.
    node2idx : dict
        Coarse node name -> row index in ``features`` (numbered after the
        nucleotide rows).
    """
    def with_ones_column(mat):
        return np.concatenate([mat, np.ones((mat.shape[0], 1))], axis=1)

    nuc_block = with_ones_column(x)
    cg_nodes = cg_graph['nodes']
    offset = nuc_block.shape[0]
    node2idx = {name: offset + i for i, name in enumerate(cg_nodes)}
    cg_block = with_ones_column(bungle_features(cg_nodes))
    return block_diag(nuc_block, cg_block), node2idx
    

def build_data(df, cg_graphs, target_cols=None, error_cols=None):
    """
    Convert RNA records into a list of torch_geometric ``Data`` graphs.

    Each graph has one node per nucleotide, followed by one node per
    coarse-grained ("bungle") node from ``cg_graphs[id]``.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs columns ``id``, ``sequence``, ``structure``, ``seq_scored`` and
        ``predicted_loop_type``, plus every column in ``target_cols`` /
        ``error_cols``. Those cells are stacked with ``np.stack``, so they
        presumably hold per-position lists (TODO confirm for new datasets).
    cg_graphs : mapping
        Record ``id`` -> coarse graph dict with ``'nodes'`` and ``'edges'``.
    target_cols, error_cols : list of str, optional
        Target columns and their matching error columns (same length, paired
        by position). Leave both empty for unlabelled data; ``y`` is then None.

    Returns
    -------
    list of Data
        When targets are given, ``y`` has shape ``(n_nodes, n_targets, 2)``.
        ``y[..., 0]`` is the target and ``y[..., 1]`` its error, NaN-padded
        beyond the provided positions (including all coarse nodes).
        ``is_nuc`` is a bool mask of nucleotide nodes, and ``seq_scored`` is
        copied from the record.
    """
    target_cols = list(target_cols or [])
    error_cols = list(error_cols or [])
    assert len(error_cols) == len(target_cols)
    # Rows of df[[]].values are empty arrays, never None, so unlabelled data
    # must be detected from the requested columns, not from the row values.
    has_targets = len(target_cols) > 0

    records = df[['id', 'sequence', 'structure', 'seq_scored', 'predicted_loop_type']].values
    data_list = []
    for (id_, sequence, structure, seq_scored, loop_type), targets, errors in zip(
            tqdm(records), df[target_cols].values, df[error_cols].values):
        cg_graph = cg_graphs[id_]
        x = seq2nodes(sequence, loop_type)
        x, node2idx = add_bungle_nodes(x, cg_graph)
        edge_index, edge_features = create_edges(structure, x, cg_graph, node2idx)
        n_nodes = x.shape[0]

        if has_targets:
            # (n_targets, positions) -> (positions, n_targets), then NaN-pad rows to n_nodes
            targets = np.stack(targets).T
            errors = np.stack(errors).T
            targets = np.pad(targets, ((0, n_nodes - targets.shape[0]), (0, 0)), constant_values=np.nan)
            errors = np.pad(errors, ((0, n_nodes - errors.shape[0]), (0, 0)), constant_values=np.nan)
            y = torch.stack([torch.FloatTensor(targets), torch.FloatTensor(errors)], dim=2)
        else:
            y = None

        data = Data(x=torch.FloatTensor(x),
                    edge_index=torch.LongTensor(edge_index),
                    edge_attr=torch.FloatTensor(edge_features),
                    y=y)
        data.seq_scored = seq_scored
        is_nuc = torch.zeros(n_nodes, dtype=torch.bool)
        is_nuc[:len(sequence)] = True
        data.is_nuc = is_nuc
        data_list.append(data)
    # all records must share one seq_scored value (vacuously true for an empty df)
    assert len({data.seq_scored for data in data_list}) <= 1
    return data_list

def simple_graph(bpp, sequence, structure, loop_type, targets, errors):
    """
    Dense-adjacency variant of the graph builder (work in progress).

    Returns
    -------
    x : np.ndarray
        Base + loop-type one-hot node features from ``seq2nodes``.
    matrix : np.ndarray
        ``build_matrix`` adjacency for ``structure`` (1 = neighbour, 2 = pair).

    Notes
    -----
    TODO: ``bpp``, ``targets`` and ``errors`` are accepted but not used yet.
    """
    x = seq2nodes(sequence, loop_type)
    matrix = build_matrix(get_couples(structure), len(structure))
    return x, matrix


In [None]:
# NOTE(review): not referenced in any visible cell.
MAP2D_FOLDER = '../data/nsp_distances_angles2/'
# Targets to predict; ERROR_COLS[i] is the error column paired with TARGET_COLS[i].
TARGET_COLS = ['reactivity', 'deg_Mg_pH10', 'deg_Mg_50C']
ERROR_COLS = ['reactivity_error', 'deg_error_Mg_pH10', 'deg_error_Mg_50C']
# Never commit credentials: export NEPTUNE_API_TOKEN in the environment before
# running the logger cell (None when unset). Any token previously committed in
# this notebook is public and should be revoked.
NEPTUNE_API_TOKEN = os.environ.get('NEPTUNE_API_TOKEN')

In [None]:
class MyDataModule(pl.LightningDataModule):
    """
    Lightning data module serving pre-built RNA graphs.

    ``setup`` converts the train and validation dataframes into lists of graph
    objects with ``build_data`` (using ``TARGET_COLS`` / ``ERROR_COLS``). The
    dataloaders batch those lists and shuffle only the training split.

    Parameters
    ----------
    df_train, df_val : pandas.DataFrame
        Records for the training and validation splits.
    cg_graphs : mapping
        Record ``id`` -> coarse-grained graph dict, forwarded to ``build_data``.
    batch_size : int
        Graphs per batch for both loaders.
    """

    def __init__(self, df_train, df_val, cg_graphs, batch_size):
        super().__init__()
        self.df_train = df_train
        self.df_val = df_val
        self.cg_graphs = cg_graphs
        self.batch_size = batch_size

    def _to_graphs(self, frame):
        # both splits are labelled with the same target/error columns
        return build_data(frame, self.cg_graphs, target_cols=TARGET_COLS, error_cols=ERROR_COLS)

    def setup(self, stage=None):
        self.train_data_list = self._to_graphs(self.df_train)
        self.val_data_list = self._to_graphs(self.df_val)

    def train_dataloader(self):
        return DataLoader(self.train_data_list, shuffle=True, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_data_list, batch_size=self.batch_size)

In [None]:
def compute_MCRMSE(pred, y, columns, per_column=False, prefix=''):
    """
    Mean column-wise RMSE between predictions and targets.

    Only rows whose first target ``y[:, 0, 0]`` is not NaN are scored
    (``build_data`` NaN-pads unscored positions and coarse-grained nodes).
    A 1e-6 term inside the square root keeps the gradient finite at zero error.

    Parameters
    ----------
    pred : torch.Tensor, shape (n_nodes, n_targets)
    y : torch.Tensor, shape (n_nodes, n_targets, k)
        ``y[..., 0]`` holds the targets; the remaining slices are ignored here.
    columns : sequence of str
        Target names used to label the per-column metrics.
    per_column : bool
        False -> return the scalar MCRMSE tensor; True -> return a dict.
    prefix : str
        Optional prefix for per-column keys (``"{prefix}_mcrmse_{col}"``).

    Returns
    -------
    torch.Tensor or dict
        The dict maps each per-column key to its RMSE, and ``"mcrmse"``
        (never prefixed) to their mean.
    """
    scored = ~torch.isnan(y[:, 0, 0])
    truth = y[:, :, 0][scored]
    squared_err = (pred[scored] - truth) ** 2
    col_rmse = torch.sqrt(squared_err.mean(dim=0) + 1e-6)
    if per_column:
        head = f"{prefix}_mcrmse_" if prefix else "mcrmse_"
        metrics = {f"{head}{col}": value for col, value in zip(columns, col_rmse)}
        metrics["mcrmse"] = col_rmse.mean()
        return metrics
    return col_rmse.mean()
    

In [None]:
class MyGNNLighting(pl.LightningModule):
    """
    Three-layer CGConv graph network predicting per-node targets.

    Node and edge features are each projected by a linear layer. They then pass
    through three ``CGConv`` layers (sigmoid after the first) and a final
    linear head with ``output_dim`` outputs per node.

    Parameters
    ----------
    input_size : int
        Width of ``data.x``.
    hidden_size : int
        Node embedding width used by all CGConv layers.
    edge_dim : int
        Width of ``data.edge_attr``.
    edge_hidden_dim : int
        Projected edge-feature width fed to CGConv.
    output_dim : int
        Number of predicted targets per node.
    seq_len, seq_scored : int
        Stored on the module; not used by the current forward pass.
    lr : float
        Adam learning rate (default matches the previous hard-coded value).
    """

    def __init__(self, input_size, hidden_size, edge_dim, edge_hidden_dim, output_dim=3, seq_len=107,
                 seq_scored=68, lr=1e-2):
        super().__init__()
        self.seq_len = seq_len
        self.lr = lr
        self.edge_mlp = nn.Linear(edge_dim, edge_hidden_dim)
        self.node_mlp = nn.Linear(input_size, hidden_size)
        # Not used in forward(); kept so parameter names (and init order) match
        # existing checkpoints.
        self.lstm = nn.LSTM(input_size, int(hidden_size / 2), bidirectional=True, batch_first=True)
        self.gcn = gnn.CGConv(hidden_size, edge_hidden_dim)
        self.gcn2 = gnn.CGConv(hidden_size, edge_hidden_dim)
        self.gcn3 = gnn.CGConv(hidden_size, edge_hidden_dim)
        self.mlp = nn.Linear(hidden_size, output_dim)
        self.seq_scored = seq_scored

    def forward(self, data):
        x = self.node_mlp(data.x.float())
        edge_attr = self.edge_mlp(data.edge_attr)
        x = torch.sigmoid(self.gcn(x, data.edge_index, edge_attr))
        x = self.gcn2(x, data.edge_index, edge_attr)
        x = self.gcn3(x, data.edge_index, edge_attr)
        return self.mlp(x)

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(self(batch), batch.y)
        # The Result object itself must be returned, otherwise the values logged
        # on it are dropped; `minimize` marks the tensor to backpropagate.
        result = pl.TrainResult(minimize=loss)
        result.log("train_mcrmse", loss, on_step=True, logger=True)
        return result

    def validation_step(self, batch, batch_idx):
        # Raw predictions/targets are aggregated in validation_epoch_end so the
        # metric is computed over the whole validation set, not per batch.
        return self(batch), batch.y

    def validation_epoch_end(self, validation_step_outputs):
        pred, y = zip(*validation_step_outputs)
        pred = torch.cat(pred, dim=0)
        y = torch.cat(y, dim=0)
        metrics = self.compute_loss(pred, y, per_column=True, prefix='val')
        result = pl.EvalResult(checkpoint_on=metrics['mcrmse'])
        result.log_dict(metrics, on_step=False, on_epoch=True, logger=True, prog_bar=False)
        return result

    def compute_loss(self, pred, y, columns=TARGET_COLS, per_column=False, prefix=''):
        return compute_MCRMSE(pred, y, columns, per_column, prefix=prefix)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.lr)
        

In [None]:
# Raw training records stored as line-delimited JSON (one record per line).
df = pd.read_json(Path('../data') / 'train.json', lines=True)

In [None]:
# Keep only records whose SN_filter flag equals 1 (presumably the dataset's
# signal-to-noise quality flag — confirm against the data description).
df = df.loc[df['SN_filter'] == 1]

In [None]:
# Ordered (unshuffled) split: the first 1000 filtered records train the model,
# and the remainder is used as the validation set.
df_train, df_test = df.iloc[:1000], df.iloc[1000:]

In [None]:
# FIXME: `cg_graph` is not defined in any visible cell, so this fails on a fresh
# kernel. It must map each record `id` to a coarse-grained graph dict with
# 'nodes' and 'edges' keys (as consumed by build_data / cg2edges); load it in
# an explicit cell above.
data = MyDataModule(df_train=df_train, df_val=df_test, cg_graphs=cg_graph, batch_size=16)

In [None]:
# Build the graph lists eagerly: the model cell below sizes its input layers
# from data.train_data_list.
data.setup(stage=None)

In [None]:
# Grab one validation batch for inspection. Unlike a `for ...: break` loop,
# this fails loudly (StopIteration) instead of leaving `batch` stale or unset
# when the loader is empty.
batch = next(iter(data.val_dataloader()))
batch

In [None]:
# Input widths come from the first training graph; 100 node / 50 edge hidden units.
module = MyGNNLighting(
    input_size=data.train_data_list[0].x.shape[1],
    hidden_size=100,
    edge_dim=data.train_data_list[0].edge_attr.shape[1],
    edge_hidden_dim=50,
)

In [None]:
# Neptune run logger; authentication uses NEPTUNE_API_TOKEN from the config cell.
logger = pl.loggers.neptune.NeptuneLogger(
    NEPTUNE_API_TOKEN,
    "gottalottarock/openVaccine",
    experiment_name="test",
)

In [None]:
# 40 epochs, metrics sent to the Neptune logger.
trainer = pl.Trainer(max_epochs=40, logger=logger)

In [None]:
# Pass the data module by keyword: in the Lightning versions this notebook
# targets, fit()'s second positional parameter is a train dataloader, and
# relying on it being re-interpreted as a datamodule is version-dependent.
trainer.fit(module, datamodule=data)