In [None]:
!pip install dgl
!export DGLBACKEND=pytorch
import dgl
import numpy as np
#import tensorflow as tf
import torch as th

!pip install forgi
import forgi
import forgi.graph.bulge_graph as fgb
import forgi.threedee as ft
import forgi.threedee.model.coarse_grain as ftmc


import matplotlib.pyplot as plt
import networkx as nx

In [None]:
th.__version__

'1.10.0+cu111'


Ideas: 
*   load coarse grain representation with forgi,
*   use forgi to form a graph with nodes labeled as s/i/o/.. and twist, length, angle,...
*   use that graph to feed into model
*   use dgl.save_graph() to store a graph, so the structure can be used for several steps?
*   use forgi.threedee.model.coarse_grain.CoarseGrainRNA.rotate() to rotate cg RNAs and see if the classification changes

TODO:
*  build dataloader
*      build model
*      simulate batches to train (while testing)
*    future --> find where ernwin writes/stores output of structure for each n steps



In [None]:
#Graph Building

#load coarse grain file
def load_cg_file(file): 
  cg = ftmc.CoarseGrainRNA.from_bg_file(file) 
  coord_dict = dict(cg.coords)
  twist_dict = dict(cg.twists)

  # Get elements and neighbours:
  connections = {}
  for elem in cg.sorted_element_iterator():
    neighbours = cg.connections(elem)
    if elem not in connections:
        connections[elem] = cg.connections(elem)
  return coord_dict, twist_dict, connections

def build_dgl_graph(coord_dict, twist_dict, connections):
  #dictionary to convert type
  type_transl = {"h": 0, "i": 1, "m": 2, "s": 3, "f": 4, "t": 5}

  #encode nodes numerically for dgl graph
  num_graph = {}
  elem_count = {}
  for num, n in enumerate(sorted(connections)):
    num_graph[n] = num
    if n[0] not in elem_count:
      elem_count[n[0]] = 1
    else:
      elem_count[n[0]] += 1

  #build graph and edges
  u = []
  v = []
  for node in connections:
    for c in connections[node]:
      u.append(num_graph[node])
      v.append(num_graph[c])

  graph = dgl.graph((th.tensor(u), th.tensor(v)))

  #initialise node attributes
  graph.ndata["type"] = th.zeros(graph.num_nodes(), 1, dtype=int)
  graph.ndata["coord"] = th.zeros(graph.num_nodes(), 6, dtype=th.float64)
  graph.ndata["twist"] = th.zeros(graph.num_nodes(), 6, dtype=th.float64)

  for elem in connections:
    graph.ndata["type"][num_graph[elem]] = type_transl[elem[0]] 
    graph.ndata["coord"][num_graph[elem]] = th.tensor(np.concatenate(coord_dict[elem]), dtype=th.float64)
    if elem in twist_dict:
      graph.ndata["twist"][num_graph[elem]] = th.tensor(np.concatenate(twist_dict[elem]), dtype=th.float64)
  
  return graph



In [None]:
#Graph Dataset Class
#TODO: adapt, so it can stand alone

from dgl.data import DGLDataset
class CGDataset(DGLDataset):
  def __init__(self):
    super().__init__(name="cgRNA")
  
  def process(self):
    self.graphs = []
    self.labels = [12.722, 4.891, 22.918]

    for struc in ["6cu1.cg", "2mis.cg", "1p5p.cg"]:
      coord_dict, twist_dict, connections = load_cg_file(struc)
      self.graphs.append(build_dgl_graph(coord_dict, twist_dict, connections))

    self.labels = th.tensor(self.labels)
  
  def __getitem__(self, i):
    return self.graphs[i], self.labels[i]

  def __len__(self):
    return len(self.graphs)


In [None]:
g_list = []
glabels ={"glabel": th.tensor([12.722, 4.891, 22.918])} #done with: compare_RNA.py data/6CU1.pdb /home/mescalin/mgeyer/3d_classifier/6cu1.cg
for struc in ["6cu1.cg", "2mis.cg", "1p5p.cg"]:
  coord_dict, twist_dict, connections = load_cg_file(struc)
  graph = build_dgl_graph(coord_dict, twist_dict, connections)
  g_list.append(graph)

print("graphs")
print(g_list)

# save_graphs, label is rmsd
dgl.save_graphs("cg_graphs.dgl", g_list, labels=glabels)
gs, ls = dgl.load_graphs("cg_graphs.dgl")
print(gs)
print(ls)

graphs
[Graph(num_nodes=15, num_edges=30,
      ndata_schemes={'type': Scheme(shape=(1,), dtype=torch.int64), 'coord': Scheme(shape=(6,), dtype=torch.float64), 'twist': Scheme(shape=(6,), dtype=torch.float64)}
      edata_schemes={}), Graph(num_nodes=4, num_edges=6,
      ndata_schemes={'type': Scheme(shape=(1,), dtype=torch.int64), 'coord': Scheme(shape=(6,), dtype=torch.float64), 'twist': Scheme(shape=(6,), dtype=torch.float64)}
      edata_schemes={}), Graph(num_nodes=14, num_edges=26,
      ndata_schemes={'type': Scheme(shape=(1,), dtype=torch.int64), 'coord': Scheme(shape=(6,), dtype=torch.float64), 'twist': Scheme(shape=(6,), dtype=torch.float64)}
      edata_schemes={})]
[Graph(num_nodes=15, num_edges=30,
      ndata_schemes={'twist': Scheme(shape=(6,), dtype=torch.float64), 'type': Scheme(shape=(1,), dtype=torch.int64), 'coord': Scheme(shape=(6,), dtype=torch.float64)}
      edata_schemes={}), Graph(num_nodes=4, num_edges=6,
      ndata_schemes={'twist': Scheme(shape=(6,), dtyp

In [None]:
#Dataloading
import dgl.dataloading as dtl

dataset = CGDataset()
#graph, label = dataset[0]
#print(graph, label)


dataloader = dtl.pytorch.GraphDataLoader(dataset, batch_size=1)



In [94]:
#Model
'''
#tensorflow mock up model
def td_model():
    model = tf.keras.Sequential([
        dgl.nn.tensorflow.conv.GraphConv(10, 8, weight=True, bias=True),
        dgl.nn.tensorflow.conv.GraphConv(8, 6, weight=True, bias=True),
        dgl.nn.tensorflow.conv.GraphConv(6, 4, weight=True, bias=True),
        tf.layers.Dense(1, activation=None)
    ])
'''
# from https://docs.dgl.ai/tutorials/blitz/5_graph_classification.html
# TODO adapt

from dgl.nn import GraphConv

class GCN(th.nn.Module): # feed the 3 different node attributes one after the other though the first layer. https://discuss.dgl.ai/t/getting-started-with-multiple-node-features-in-homogenous-graph/919/2
  def __init__(self, in_feats, h_feats): # condense the 3 node attributes down to 1? see point above
    super(GCN, self).__init__()
    self.conv1 = GraphConv(in_feats, h_feats)
    self.conv2 = GraphConv(h_feats, 1)

    def forward(self, g, in_feat):
      h = self.conv1(g, in_feat)
      h = F.relu(h)
      h = self.conv2(g, h)
      g.ndata['h'] = h
      return dgl.mean_nodes(g, 'h')

#Coarse Grain RNA Classifier Model

class CG_Classifier(th.nn.Module):
  def __init__(self, in_dim, hidden_dim, out_dim):
    super(CG_Classifier, self).__init__()

    self.conv1 = GraphConv(in_dim, hidden_dim)
    self.conv2 = GraphConv(hidden_dim, hidden_dim)
    
    self.classify = th.nn.Linear(hidden_dim, 1)

  def forward(self, g, n_types, n_coord, n_twist):
    pass

In [None]:
#Training
# from https://docs.dgl.ai/tutorials/blitz/5_graph_classification.html
# TODO adapt

model = GCN(3 , 16)
optimizer = th.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(2):
    for batched_graph, labels in dataloader:
        pred = model(batched_graph, batched_graph.ndata['attr'].float())
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

num_correct = 0
num_tests = 0
for batched_graph, labels in dataloader:
    pred = model(batched_graph, batched_graph.ndata['attr'].float())
    num_correct += (pred.argmax(1) == labels).sum().item()
    num_tests += len(labels)

print('Test accuracy:', num_correct / num_tests)


KeyError: ignored