In [66]:
import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import get_dataloaders, MAX_SEQ_LENGTH, vocab_size
from torch.autograd import Variable
import time
# Compute device shared by the model and its inputs. Pinned to the CPU; for
# GPU execution, replace the right-hand side with
# torch.device("cuda" if torch.cuda.is_available() else "cpu").
device = "cpu"

class RNAPairLSTM(nn.Module):
    """Bidirectional LSTM that emits one output vector per input position.

    Each sequence is run through a stacked bidirectional LSTM; at every
    position the concatenated forward/backward hidden states
    (``2 * hidden_dim`` features) are projected to ``output_dim`` by a
    linear layer.

    Args:
        input_dim: size of the feature vector at each position.
        hidden_dim: hidden size of each LSTM direction.
        output_dim: size of the per-position output vector.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
        super(RNAPairLSTM, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers

        # The attribute names ``lstm`` and ``fc`` are the state_dict key
        # prefixes; renaming them breaks loading existing checkpoints.
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True)
        # Bidirectional => forward and backward states are concatenated.
        self.fc = nn.Linear(2 * hidden_dim, output_dim)

    def forward(self, input):
        """Map a batch of sequences to per-position outputs.

        Args:
            input: tensor of shape (batch, seq_len, input_dim)
                (the LSTM is built with ``batch_first=True``).

        Returns:
            Tensor of shape (batch, seq_len, output_dim).
        """
        # Zero initial hidden/cell states, one slice per (layer, direction):
        # shape (2 * num_layers, batch, hidden_dim). ``new_zeros`` inherits the
        # input's device and dtype, so this works wherever the model and its
        # inputs live instead of relying on a module-level device setting.
        state_shape = (2 * self.num_layers, input.size(0), self.hidden_dim)
        h_0 = input.new_zeros(state_shape)
        c_0 = input.new_zeros(state_shape)

        output, _ = self.lstm(input, (h_0, c_0))
        return self.fc(output)

In [67]:
import torch
import numpy as np
import random
import os
from dataloader import get_dataloaders, MAX_SEQ_LENGTH, vocabulary, vocab_size
# from lstm import RNAPairLSTM

# Checkpoint path and architecture hyperparameters. These must match the
# values the checkpoint was trained with, otherwise load_state_dict (strict
# by default) raises on missing/mismatched parameters.
model_file = "./model2.pth"
input_dim = vocab_size  # One-hot encoded input size
hidden_dim = 128
output_dim = vocab_size  # One-hot encoded output size
num_layers = 2
num_epochs = 100  # NOTE(review): unused in this inference cell — presumably a training leftover
model = RNAPairLSTM(input_dim, hidden_dim, output_dim, num_layers).to(device)
# map_location remaps tensors that were saved on another device (e.g. a GPU)
# onto `device`, so the checkpoint also loads on CPU-only machines.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load(model_file, map_location=device))
model.eval()

RNAPairLSTM(
  (lstm): LSTM(7, 128, num_layers=2, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=256, out_features=7, bias=True)
)

In [68]:
# Index -> token lookup: position i holds the i-th key of `vocabulary`.
# Presumably this is the same order used for the one-hot indices — confirm in dataloader.
vocab = [*vocabulary.keys()]
def outputs_to_seq(outputs, tokens=None, pad_token='P'):
    """Decode per-position scores into a list of vocabulary tokens.

    Takes the argmax over the last dimension of ``outputs`` (model scores or
    one-hot encodings both work), maps each index to its token, and drops
    everything from the first ``pad_token`` onward.

    Args:
        outputs: tensor whose last dimension indexes the vocabulary. Size-1
            dimensions are squeezed after the argmax, so this expects a single
            sequence (e.g. a batch of size 1).
        tokens: index-to-token list; defaults to the module-level ``vocab``.
        pad_token: token that marks the start of padding.

    Returns:
        List of tokens before the first ``pad_token``, or the whole decoded
        sequence when no padding token is present.
    """
    if tokens is None:
        tokens = vocab
    idx = outputs.argmax(dim=-1).squeeze()
    # A length-1 sequence squeezes down to a 0-d tensor, which is not iterable.
    if idx.dim() == 0:
        idx = idx.unsqueeze(0)
    # Map each index to its vocabulary token.
    rna = [tokens[int(i)] for i in idx]
    # Strip padding; a full-length sequence has no pad token to cut at.
    if pad_token in rna:
        rna = rna[:rna.index(pad_token)]
    return rna


In [72]:
from dataloader import get_dataloaders, MAX_SEQ_LENGTH, vocab_size
# Qualitative check: decode a few random examples and compare the model's
# predicted seq2 with the ground truth. batch_size must stay 1: the reshapes
# below and outputs_to_seq assume a single sequence per batch.
batch_size = 1
NUM_SAMPLES = 10  # examples drawn from each split
SEED = None  # set to an int (e.g. 0) for a reproducible selection

train_loader, dev_loader, test_loader = get_dataloaders(batch_size=batch_size)
if SEED is not None:
    random.seed(SEED)
# Materialise each loader and draw NUM_SAMPLES random (seq1, seq2) batches.
train_samples = random.sample(list(train_loader), NUM_SAMPLES)
dev_samples = random.sample(list(dev_loader), NUM_SAMPLES)
test_samples = random.sample(list(test_loader), NUM_SAMPLES)

criterion = nn.CrossEntropyLoss()

# Print the original seq1 and seq2 together with the predicted seq2.
with torch.no_grad():  # inference only: don't build the autograd graph
    for seq1, seq2 in train_samples:
        outputs = model(seq1)
        # Flatten logits and targets to (MAX_SEQ_LENGTH, vocab). seq2 is
        # presumably one-hot (float), which CrossEntropyLoss treats as
        # per-position class probabilities — confirm in dataloader.
        loss = criterion(outputs.reshape(MAX_SEQ_LENGTH, -1), seq2.reshape(MAX_SEQ_LENGTH, -1))
        print("train loss: ", loss.item())
        print("seq1: ", outputs_to_seq(seq1))
        print("seq2: ", outputs_to_seq(seq2))
        # Reuse the logits from the loss computation instead of a second forward pass.
        predicted_seq2 = outputs_to_seq(outputs.reshape(MAX_SEQ_LENGTH, -1))
        print("predict seq2: ", predicted_seq2)
        print()

train loss:  0.029913874343037605
seq1:  ['G', 'G', 'G', 'A', 'E']
seq2:  ['T', 'C', 'T', 'C', 'E']
predict seq2:  ['T', 'C', 'C', 'C', 'E']

train loss:  0.07538177818059921
seq1:  ['G', 'G', 'G', 'G', 'A', 'E']
seq2:  ['T', 'C', 'C', 'C', 'C', 'E']
predict seq2:  ['T', 'C', 'C', 'C', 'C', 'E', 'E', 'E']

train loss:  0.06187400221824646
seq1:  ['C', 'T', 'C', 'G', 'G', 'T', 'C', 'E']
seq2:  ['G', 'A', 'C', 'C', 'G', 'G', 'G', 'E']
predict seq2:  ['G', 'A', 'C', 'G', 'G', 'G', 'G', 'E', 'E']

train loss:  0.8472760915756226
seq1:  ['G', 'A', 'C', 'A', 'G', 'A', 'G', 'T', 'G', 'A', 'G', 'G', 'C', 'T', 'C', 'C', 'A', 'T', 'C', 'T', 'T', 'G', 'G', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'T', 'E']
seq2:  ['G', 'T', 'T', 'T', 'T', 'T', 'A', 'A', 'A', 'A', 'A', 'A', 'C', 'A', 'G', 'C', 'G', 'A', 'C', 'A', 'G', 'G', 'G', 'T', 'C', 'T', 'C', 'T', 'C', 'T', 'C', 'T', 'G', 'T', 'C', 'E']
predict seq2:  ['G', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T