In [1]:
import torch
import torch_geometric

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
rfam_dir = "../rfam/data/raw/processed/release-14.8"
rfams = ["RF00001","RF00174","RF00169","RF00050"]

from RNARepLearn.datasets import CombinedRfamDataset, SingleRfamDataset
#dataset = CombinedRfamDataset(rfam_dir, rfams, "Under300", 15, 300)
dataset = SingleRfamDataset(rfam_dir, "RF00001", 15)

train_size = int(0.8 * len(dataset))
test_size = int(0.5 * (len(dataset)-train_size))
val_size = len(dataset)-train_size-test_size
train_dataset, test_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, test_size, val_size])

from torch_geometric.loader import DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

Processing...
Done!


In [9]:
from RNARepLearn.modules import LinearEmbedding, RPINetEncoder, AttentionDecoder
layers = []
layers.append(LinearEmbedding(4, 32))
layers.append(RPINetEncoder(32, 32, 7))
layers.append(AttentionDecoder(32, 4))
model = torch.nn.Sequential(*layers)

In [10]:
##Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model = model.double()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

n_epochs = 10

In [11]:
from RNARepLearn.train import MaskedTraining
training = MaskedTraining(model, 10, 15, writer)

In [13]:
training.run(train_loader)

NameError: name 'writer' is not defined

In [6]:
from RNARepLearn.utils import mask_batch, reconstruct_bpp

##Training

train_hist = {}
train_hist["loss"]=[]
train_hist["nucleotide_loss"]=[]
train_hist["edge_loss"]=[]
model.train()
cel_loss = torch.nn.CrossEntropyLoss()
kl_loss = torch.nn.KLDivLoss(reduction='batchmean')

for epoch in range(n_epochs):
    for idx, batch in enumerate(train_loader):
        true_x = torch.clone(batch.x)
        true_edges = torch.clone(batch.edge_weight)

        nuc_mask, edge_mask = mask_batch(batch,15)
        batch.to(device)
        optimizer.zero_grad()
        
        nucs, bpp = model(batch)

        node_loss = cel_loss(nucs.cpu()[nuc_mask],true_x[nuc_mask])
        edge_loss = kl_loss(bpp.cpu()[nuc_mask].log() , torch.tensor(reconstruct_bpp(batch.edge_index.cpu(), true_edges, (len(bpp),len(bpp)))[nuc_mask]))
        
        loss = node_loss + edge_loss
        
        writer.add_scalar("Loss/train", loss, epoch)
        writer.add_scalar("Loss_nodes/train", node_loss, epoch)
        writer.add_scalar("Loss_edges/train", edge_loss, epoch)

        loss.backward()
        optimizer.step()
        
        node_accuracy = int((nucs.cpu()[nuc_mask].argmax(dim=1)==true_x[nuc_mask].argmax(dim=1)).sum()) / len(nuc_mask)
        
        writer.add_scalar("Node_Accuracy/train", node_accuracy, epoch)
        
        if val_loader is not None:
            model.eval()
            node_accuracy_val = []
            for batch in val_loader:
                true_x = torch.clone(batch.x)
                nuc_mask, edge_mask = mask_batch(batch,15)
                batch.to(device)
                optimizer.zero_grad()
        
                nucs, bpp = model(batch)
                node_accuracy_val.append(int((nucs.cpu()[nuc_mask].argmax(dim=1)==true_x[nuc_mask].argmax(dim=1)).sum()) / len(nuc_mask))
            node_accuracy_val = sum(node_accuracy_val)/len(node_accuracy_val)
            writer.add_scalar("Node_Accuracy/val", node_accuracy, epoch)
            
            model.train()
        
        
        
        
        if idx % 10 == 0:
            print('\r[Epoch %4d/%4d] [Batch %4d/%4d] Loss: % 2.2e Nucleotide-Loss: % 2.2e Edge-Loss: % 2.2e Node_accuracy % 2.2e' % (epoch + 1, n_epochs, 
                                                                idx + 1, len(train_loader), 
                                                                loss.item(),node_loss.item(),edge_loss.item(),node_accuracy))
writer.flush()
        

0.036535288038247946
[Epoch    1/  10] [Batch    1/ 125] Loss:  8.97e+00 Nucleotide-Loss:  1.39e+00 Edge-Loss:  7.58e+00 Node_accuracy  3.82e-02
0.03759276537977665
0.036995275016855074
0.03703332897827453
0.03775433171650231
0.04065079386125073
0.04041120630054664
0.041047311539076925
0.042415546119421076
0.0413196451662169
0.04058328658888613
[Epoch    1/  10] [Batch   11/ 125] Loss:  8.05e+00 Nucleotide-Loss:  1.38e+00 Edge-Loss:  6.67e+00 Node_accuracy  4.12e-02
0.041015826707139505
0.04063228458886548
0.04120530592338296


KeyboardInterrupt: 