In [1]:
import numpy as np
import torch
import os
import pandas as pd
import torch_geometric
import pickle
import tqdm
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [3]:
rfam_dir = "../rfam/data/raw/processed/release-14.8"
rfams = ["RF00001","RF00174","RF00169","RF00050"]

In [4]:
from RNARepLearn.datasets import CombinedRfamDataset, SingleRfamDataset
#dataset = CombinedRfamDataset(rfam_dir, rfams, "Under300", 15, 300)
dataset = SingleRfamDataset(rfam_dir, "RF00001", 15)

Processing...
Done!


In [5]:
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

In [6]:
from torch_geometric.loader import DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [8]:
##Model
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, Linear

class Encoder(torch.nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv1 = GCNConv(input_channels, output_channels)
        self.conv2 = GCNConv(output_channels,output_channels)

        
    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        #build representation, encode
        # input -> V_N,D
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x,edge_index)
        x = F.relu(x)
        
        return x
    
class AttentionDecoder(torch.nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.key_projection = Linear(input_channels, input_channels)
        self.query_projection = Linear(input_channels, input_channels)
        self.nuc_projection = Linear(input_channels, output_channels)
        
    def forward(self, x):
        keys = self.key_projection(x)
        queries = self.query_projection(x)
        
        nucleotides = self.nuc_projection(x)
        
        dotprod = torch.matmul(queries,keys.T)
        
        return F.softmax(nucleotides, dim=1), F.softmax(dotprod, dim=1)

In [9]:
layers = []
layers.append(Encoder(4,64))
layers.append(AttentionDecoder(64,4))
model = torch.nn.Sequential(*layers)

In [10]:
##Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model = model.double()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

n_epochs = 10

In [12]:
from RNARepLearn.utils import mask_batch, reconstruct_bpp
##Training
train_hist = {}
train_hist["loss"]=[]
train_hist["nucleotide_loss"]=[]
train_hist["edge_loss"]=[]
model.train()
cel_loss = torch.nn.CrossEntropyLoss()
kl_loss = torch.nn.KLDivLoss(reduction='batchmean')

for epoch in range(n_epochs):
    for idx, batch in enumerate(train_loader):
        
        true_x = torch.clone(batch.x)
        true_edges = torch.clone(batch.edge_attr)

        nuc_mask, edge_mask = mask_batch(batch,15)
        batch.to(device)
        optimizer.zero_grad()
        
        nucs, bpp = model(batch)

        node_loss = cel_loss(nucs.cpu()[nuc_mask],true_x[nuc_mask])
        edge_loss = kl_loss(bpp.cpu()[nuc_mask].log() , torch.tensor(reconstruct_bpp(batch.edge_index.cpu(), true_edges, (len(bpp),len(bpp)))[nuc_mask]))
        
        loss = node_loss + edge_loss
        
        writer.add_scalar("Loss/train", loss, epoch)
        writer.add_scalar("Loss_nodes/train", node_loss, epoch)
        writer.add_scalar("Loss_edges/train", edge_loss, epoch)

        loss.backward()
        optimizer.step()
        train_hist["loss"].append(loss.item())
        train_hist["nucleotide_loss"].append(node_loss.item())
        train_hist["edge_loss"].append(edge_loss.item())
        
        
        if idx % 10 == 0:
            print('\r[Epoch %4d/%4d] [Batch %4d/%4d] Loss: % 2.2e Nucleotide-Loss: % 2.2e Edge-Loss: % 2.2e' % (epoch + 1, n_epochs, 
                                                                idx + 1, len(train_loader), 
                                                                loss.item(),node_loss.item(),edge_loss.item()))
writer.flush()

[Epoch    1/  10] [Batch    1/ 125] Loss:  9.12e+00 Nucleotide-Loss:  1.39e+00 Edge-Loss:  7.74e+00
[Epoch    1/  10] [Batch   11/ 125] Loss:  8.88e+00 Nucleotide-Loss:  1.22e+00 Edge-Loss:  7.65e+00
[Epoch    1/  10] [Batch   21/ 125] Loss:  8.79e+00 Nucleotide-Loss:  1.13e+00 Edge-Loss:  7.66e+00
[Epoch    1/  10] [Batch   31/ 125] Loss:  8.66e+00 Nucleotide-Loss:  1.07e+00 Edge-Loss:  7.60e+00
[Epoch    1/  10] [Batch   41/ 125] Loss:  8.38e+00 Nucleotide-Loss:  9.21e-01 Edge-Loss:  7.46e+00
[Epoch    1/  10] [Batch   51/ 125] Loss:  8.16e+00 Nucleotide-Loss:  8.49e-01 Edge-Loss:  7.31e+00
[Epoch    1/  10] [Batch   61/ 125] Loss:  7.92e+00 Nucleotide-Loss:  8.04e-01 Edge-Loss:  7.12e+00
[Epoch    1/  10] [Batch   71/ 125] Loss:  7.88e+00 Nucleotide-Loss:  7.84e-01 Edge-Loss:  7.09e+00
[Epoch    1/  10] [Batch   81/ 125] Loss:  7.86e+00 Nucleotide-Loss:  7.80e-01 Edge-Loss:  7.08e+00
[Epoch    1/  10] [Batch   91/ 125] Loss:  7.80e+00 Nucleotide-Loss:  7.77e-01 Edge-Loss:  7.02e+00


In [14]:
model.eval()
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)
testbatch = next(iter(test_loader))

true_x = torch.clone(testbatch.x)
true_edges = torch.clone(testbatch.edge_attr)


nuc_mask, edge_mask = mask_batch(testbatch,15)

testbatch.to(device)
nucs, bpp = model(testbatch)

TypeError: clone(): argument 'input' (position 1) must be Tensor, not DataBatch

In [26]:
correct = (torch.argmax(true_x,dim=1)[nuc_mask] == torch.argmax(nucs.cpu(),dim=1)[nuc_mask]).sum()
acc = int(correct) / len(nucs[nuc_mask])
acc

1.0

In [None]:
import RNARepLearn.visualize as vis

for sample in testbatch:
    real_bpp = torch.tensor(reconstruct_bpp(sample.edge_index, true_edges, (len(bpp),len(bpp))))
    vis.compare_bpps(real_bpp, bpp.cpu())
    break

AttributeError: 'tuple' object has no attribute 'edge_index'