In [8]:
#!pip install dgl
#!DGLBACKEND=pytorch
#!export $DGLBACKEND
#import os
#os.environ["DGLBACKEND"] = "pytorch"
#print(os.environ["DGLBACKEND"])
import dgl
import numpy as np
#import tensorflow as tf
import torch as th

#!pip install forgi
import forgi
import forgi.graph.bulge_graph as fgb
import forgi.threedee as ft
import forgi.threedee.model.coarse_grain as ftmc


import matplotlib.pyplot as plt
import networkx as nx

In [6]:
th.__version__

'1.10.0+cpu'


Ideas: 
*   load coarse grain representation with forgi,
*   use forgi to form a graph with nodes labeled as s/i/o/.. and twist, length, angle,...
*   use that graph to feed into model
*   use dgl.save_graph() to store a graph, so the structure can be used for several steps?
*   use forgi.threedee.model.coarse_grain.CoarseGrainRNA.rotate() to rotate cg RNAs and see if the classification changes

TODO:
*  build dataloader
*      build model
*      simulate batches to train (while testing)
*    future --> find where ernwin writes/stores output of structure for each n steps



In [89]:
#Graph Building

#load coarse grain file
def load_cg_file(file): 
    cg = ftmc.CoarseGrainRNA.from_bg_file(file) 
    coord_dict = dict(cg.coords)
    twist_dict = dict(cg.twists)

    # Get elements and neighbours:
    connections = {}
    for elem in cg.sorted_element_iterator():
        neighbours = cg.connections(elem)
        if elem not in connections:
            connections[elem] = cg.connections(elem)
    return coord_dict, twist_dict, connections

def build_dgl_graph(coord_dict, twist_dict, connections):
    #dictionary to convert type
    type_transl = {
        "h": [1, 0, 0, 0, 0, 0],
        "i": [0, 1, 0, 0, 0, 0],
        "m": [0, 0, 1, 0, 0, 0],
        "s": [0, 0, 0, 1, 0, 0],
        "f": [0, 0, 0, 0, 1, 0],
        "t": [0, 0, 0, 0, 0, 1]
    } 

    #encode nodes numerically for dgl graph
    num_graph = {}
    elem_count = {}
    for num, n in enumerate(sorted(connections)):
        num_graph[n] = num
        if n[0] not in elem_count:
            elem_count[n[0]] = 1
        else:
            elem_count[n[0]] += 1

    #build graph and edges
    u = []
    v = []
    for node in connections:
        for c in connections[node]:
            u.append(num_graph[node])
            v.append(num_graph[c])

    graph = dgl.graph((th.tensor(u), th.tensor(v)))

    #initialise node attributes
    graph.ndata["type"] = th.zeros(graph.num_nodes(), 6, dtype=th.float64)
    graph.ndata["coord"] = th.zeros(graph.num_nodes(), 6, dtype=th.float64) #seperate coords into 2 sets of 3, so that the information of start and end is added?
    graph.ndata["twist"] = th.zeros(graph.num_nodes(), 6, dtype=th.float64)

    for elem in connections:
        graph.ndata["type"][num_graph[elem]] = th.tensor(type_transl[elem[0]], dtype=th.float64) 
        graph.ndata["coord"][num_graph[elem]] = th.tensor(np.concatenate(coord_dict[elem]), dtype=th.float64)
        if elem in twist_dict:
            graph.ndata["twist"][num_graph[elem]] = th.tensor(np.concatenate(twist_dict[elem]), dtype=th.float64)
  
    return graph



In [90]:
#Graph Dataset Class
#TODO: adapt, so it can stand alone

from dgl.data import DGLDataset
class CGDataset(DGLDataset):
    def __init__(self):
        super().__init__(name="cgRNA")
  
    def process(self):
        self.graphs = []
        self.labels = [12.722, 4.891, 22.918]

        for struc in ["6cu1.cg", "2mis.cg", "1p5p.cg"]:
            coord_dict, twist_dict, connections = load_cg_file(struc)
            self.graphs.append(build_dgl_graph(coord_dict, twist_dict, connections))

        self.labels = th.tensor(self.labels)
  
    def __getitem__(self, i):
        return self.graphs[i], self.labels[i]

    def __len__(self):
        return len(self.graphs)


In [91]:
g_list = []
glabels ={"rmsd": th.tensor([12.722, 4.891, 22.918])} #done with: compare_RNA.py data/6CU1.pdb /home/mescalin/mgeyer/3d_classifier/6cu1.cg

for struc in ["6cu1.cg", "2mis.cg", "1p5p.cg"]:
    coord_dict, twist_dict, connections = load_cg_file(struc)
    graph = build_dgl_graph(coord_dict, twist_dict, connections)
    g_list.append(graph)

print("graphs")
print(g_list)

# save_graphs, label is rmsd
dgl.save_graphs("cg_graphs.dgl", g_list, labels=glabels)
gs, ls = dgl.load_graphs("cg_graphs.dgl")
print(gs)
print(ls)

graphs
[Graph(num_nodes=15, num_edges=30,
      ndata_schemes={'type': Scheme(shape=(6,), dtype=torch.float64), 'coord': Scheme(shape=(6,), dtype=torch.float64), 'twist': Scheme(shape=(6,), dtype=torch.float64)}
      edata_schemes={}), Graph(num_nodes=4, num_edges=6,
      ndata_schemes={'type': Scheme(shape=(6,), dtype=torch.float64), 'coord': Scheme(shape=(6,), dtype=torch.float64), 'twist': Scheme(shape=(6,), dtype=torch.float64)}
      edata_schemes={}), Graph(num_nodes=14, num_edges=26,
      ndata_schemes={'type': Scheme(shape=(6,), dtype=torch.float64), 'coord': Scheme(shape=(6,), dtype=torch.float64), 'twist': Scheme(shape=(6,), dtype=torch.float64)}
      edata_schemes={})]
[Graph(num_nodes=15, num_edges=30,
      ndata_schemes={'twist': Scheme(shape=(6,), dtype=torch.float64), 'coord': Scheme(shape=(6,), dtype=torch.float64), 'type': Scheme(shape=(6,), dtype=torch.float64)}
      edata_schemes={}), Graph(num_nodes=4, num_edges=6,
      ndata_schemes={'twist': Scheme(shape=(6

In [92]:
#Dataloading
import dgl.dataloading as dtl

dataset = CGDataset()
#graph, label = dataset[0]
#print(graph, label)


dataloader = dtl.pytorch.GraphDataLoader(dataset, batch_size=1, shuffle=True) #add randomisation as in Defining Data Loader from https://docs.dgl.ai/tutorials/blitz/5_graph_classification.html



In [93]:
#Model
from dgl.utils import expand_as_pair
from dgl.nn import GraphConv
import torch.nn.functional as F

# feed the 3 different node attributes one after the other though the first layer? like in https://discuss.dgl.ai/t/getting-started-with-multiple-node-features-in-homogenous-graph/919/2
# condense the 3 node attributes down to 1? see point above


#Coarse Grain RNA Classifier Model
class CG_Classifier(th.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super(CG_Classifier, self).__init__()
        
        #self._in_src_feats, self._in_dst_feats = expand_as_pair(in_dim)
        
        self.conv1 = GraphConv(in_dim, hidden_dim, activation=F.relu)
        self.conv2 = GraphConv(hidden_dim, hidden_dim, activation=F.relu)
        
        #add pooling layer
        
        self.classify = th.nn.Linear(hidden_dim, 1)

    def forward(self, g, n_types, n_coord, n_twist): #WIP
        
        nt = n_types
        print(nt)
        nt = self.conv1(g, nt)
        print(nt)
        nt = self.conv2(g, nt)
        print(nt)
        g.ndata['nt'] = nt
        nt = dgl.mean_nodes(g, 'nt')
        
        nc = n_coord
        nc = self.conv1(g, nc)
        nc = self.conv2(g, nc)
        
        nw = n_twist
        nw = self.conv1(g, nw)
        nw = self.conv2(g, nw)
        
        tcw = th.cat((nt, nc, nw), 1)
        
        tcw_mean = tcw.mean(dim=0)
        return self.classify(tcw_mean)
    

In [82]:
#Training
#from tqdm import tqdm

model = CG_Classifier(
    in_dim=6,
    hidden_dim=3,
    out_dim=1
)

optimizer = th.optim.Adam(model.parameters(), lr=0.01)
model.train()

epochs = 3

for epoch in range(epochs):
    for batched_g, label in dataloader:#tqdm(dataloader):
        print(batched_g)
        n_types = batched_g.ndata["type"]
        n_coord = batched_g.ndata["coord"]
        n_twist = batched_g.ndata["twist"]

        prediction = model(batched_g, n_types, n_coord, n_twist)
        print(prediction)


Graph(num_nodes=15, num_edges=30,
      ndata_schemes={'type': Scheme(shape=(6,), dtype=torch.float64), 'coord': Scheme(shape=(6,), dtype=torch.float64), 'twist': Scheme(shape=(6,), dtype=torch.float64)}
      edata_schemes={})
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0.]], dtype=torch.float64)


RuntimeError: expected scalar type Double but found Float