# Generator of IDRNN

Import necessary libraries

In [1]:
import torch
from torch.utils.data import DataLoader
from model_VAE import VAECNN
from dataloader import AminoAcidDataset


Initialize the network, load the trained weights, and switch to evaluation mode

In [2]:
# Build the VAE with 21 input channels, 21 output channels and a 512-dimensional latent space.
net = VAECNN(input_size=21, output_size=21, latent_dim=512)
# Restore the trained weights from 'vae_model.pt', remapping all tensors to the CPU so the
# notebook also runs on machines without a GPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
net.load_state_dict(torch.load('vae_model.pt',map_location=torch.device('cpu')))
# Set the model to evaluation mode (affects the Dropout / BatchNorm1d layers shown in the
# repr below). As the last expression of the cell, this call also displays the module repr.
net.eval()

VAECNN(
  (encoder): EncoderCNN(
    (conv1): Conv1d(21, 32, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv3): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (tanh): Tanh()
    (relu): LeakyReLU(negative_slope=0.01)
    (fc_mean): Linear(in_features=128, out_features=512, bias=True)
    (fc_logvar): Linear(in_features=128, out_features=512, bias=True)
  )
  (decoder): DecoderCNN(
    (fc): Linear(in_features=512, out_features=128, bias=True)
    (bn_fc): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (variational_dropout): Dropout(p=0.3, inplace=False)
    (deconv1): Conv1d(128, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    (deconv2): Conv1d(64, 32, kernel_size=(3,), stride=(1,), padding=(1,))
    (deconv3): Conv1d(32, 21, kernel_size=(3,), stride=(1,), padding=(1,))
    (bn_final): BatchNorm1d(21, eps=1e-05, momentum=0.1, affine=True, track_ru

Helper function to decode the model's one-hot (or per-class score) output into amino acid sequences

In [3]:
# Function to convert one-hot encoding to amino acid sequences
def one_hot_to_sequence(one_hot, pad_index=0):
    """Decode a batch of one-hot (or per-class score) tensors into amino acid strings.

    The class with the highest value along dim 1 is taken at every position, so the
    input may be an exact one-hot encoding or any per-class scores/probabilities.

    Args:
        one_hot: 3-D tensor laid out as (batch, classes, positions); the reduction
            runs over dim 1.
        pad_index: class index whose positions are dropped from the decoded string.
            Defaults to 0, which reproduces the historical behaviour of this helper.
            NOTE(review): index 0 is also "A" in the alphabet below, so with the
            default every position decoded as alanine is silently omitted. Pass
            ``None`` to keep every position, or the dataset's real padding index
            once it has been confirmed.

    Returns:
        A list with one string per batch element. Index 20 decodes to the empty
        string, and any index beyond the alphabet decodes to '-'.
    """
    amino_acids = ["A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y", ""]
    _, max_indices = torch.max(one_hot, dim=1)
    sequences = []
    # .tolist() converts the whole index tensor to plain Python ints once, instead of
    # creating a 0-dim tensor (and tensor comparisons) for every single position.
    for indices in max_indices.tolist():
        sequence = "".join(
            amino_acids[idx] if idx < len(amino_acids) else '-'
            for idx in indices
            if idx != pad_index
        )
        sequences.append(sequence)
    return sequences

# Sequence Generation from Latent Space

In [4]:
# Number of sequences to generate and the length passed to the generator.
num_seq = 12
max_sequence_length = 1000

# Generation is inference only, so run it without autograd tracking (consistent with
# the reconstruction cell, which also wraps the forward pass in torch.no_grad()).
# NOTE(review): no RNG seed is set in this notebook; if generate_sequence samples
# randomly, the printed sequences will differ between runs -- consider torch.manual_seed.
with torch.no_grad():
    samp = net.generate_sequence(num_seq, max_sequence_length)

# Decoded strings can be shorter than max_sequence_length: one_hot_to_sequence drops
# positions whose argmax is 0 and decodes index 20 to an empty string.
for i, sequence in enumerate(one_hot_to_sequence(samp)):
    print(f"Generated {i+1}: {sequence}")

Generated 1: TNPRPHLMGEKHRVFVQEHQEVDQQRINQGWLMDCQPHQGVCVEGHCGGQTGIGDHVNHGHHIDVHFVHFRGTQGGHQVQGCPFPEPGCGSHDVHPGMCYMNQPPLYLLNVMGSCGYSDVYLQYLDIGPGDQFYVVTPTVFGHVQEEHVVDVCHVCWDGDCEEWWVGDHPQGCVDEPPGDQHGNFVPPPVNWHNMHQGGQYHCGHDICPCINGVPPVDGVDDGGPRRRTGYPNCQMPPVGGHDKYRHPGQGPPVPEEPHCGPHCHDVLGTQVGTHVVCHGWPPPYPGMVPEGGCDQNMHYNDIVHMDLGGSVEQPGYHYLGHQPHCDGHLTGQEHFDPCKDQQQDVDVIQNQVLPSDCDHRGPQVNDHPHIDEPFVQGQQDVEFDGPIQPPHHQMWIQTWVFWVHFFQHDGVDPFQQDEYFRHTGWVFHGVPDGHQDYCSEGVTVHCEVPDKLTWQKGNCDKVTFPGCGPPPNQDLCEQPPTPEPPQVGSSNQCFQHEPDDGVPEHPHGFVPGVIDVGGCEQVEVDFQGFHVGKQVPPPVGVRREWQSGYDIGHIWVNHVCGKGSYPPVPVRLFNKVRWNHTHMFPVPWDGVFQMGWGMQHDVGDQGQSHGFGPQYYDDQVHTGSCYGWQDYCHPPPEKGHYWDWYHIHDGPGHVHHDWLMVCVPQVTGMMPEDEPYGRWHPEGMQFHCLIVYMVPPVCGHHGQGGQELGHWGVGQREHGCKGCHNCGNGVHSFVHGGQPDHIMVHHPWTHCNFVGHVFHCKKGKHVHGHPPPHEFTQHDHHHVYMFPPEP
Generated 2: PQHGWHNCCYGKVQTWNVYYPVVHGQWFIDFHQDYDVCPPRWIDQEHVFWGHHHYRTYGWGDHVWDVTDRQYGYPYRTHHPQCFGPHHGWNPGCNVRHFWGVSFQDFNQKGDQCDNQVHHQGWWSWNFHQDVDIQGQGHGLYGHQVDKDIHPPGMHGHGPVEPPGPCNVHVDILKGHDVPHN

# Generation from original dataset

In [5]:

# Path of the CSV file handed to AminoAcidDataset.
input_csv_file = "data.csv"

# Same value as in the generation cell above; re-assigned here so the dataset
# parameters are visible in one place.
max_sequence_length = 1000
num_amino_acids = 20

dataset = AminoAcidDataset(input_csv_file, max_sequence_length, num_amino_acids)
# Check the trained model on a single example drawn from this dataset.

test_batch_size = 1  # one example per batch
# shuffle=True means a different example is drawn on each run (no seed is set here).
test_dataloader = DataLoader(dataset, batch_size=test_batch_size, shuffle=True)

# Pull the first batch from the loader; each batch unpacks into three items.
# example_len and example_target are not used further in this cell.
example_seq, example_len, example_target = next(iter(test_dataloader))

# Perform inference on the example without autograd tracking.
with torch.no_grad():
    # The model returns a 3-tuple; only the first element (the reconstruction) is used.
    recon_output, _, _ = net(example_seq)
    # Swap dims 1 and 2 so the class axis is dim 1, as one_hot_to_sequence expects
    # (presumably (batch, length, classes) -> (batch, classes, length) -- TODO confirm).
    # Softmax over dim 1 does not change the argmax used for decoding.
    recon_output = torch.softmax(recon_output.permute(0,2,1), dim=1)

# Convert the output to amino acid sequences
amino_acid_sequences = one_hot_to_sequence(recon_output)

# Print the reconstructed sequences
for i, sequence in enumerate(amino_acid_sequences):
    print(f"Example {i+1}: {sequence}")

# Decode the model input the same way (dims 1 and 2 swapped first) so the
# reconstruction can be compared against the original sequence.
original_input = one_hot_to_sequence(example_seq.transpose(1,2))
for i, sequence in enumerate(original_input):
    print(f"Original {i+1}: {sequence}")

Example 1: METEIDGYITCDNELSPEREHSNMIDLTSSTPNGQHSPSHMTSTNSVKLEMQSDEECDRKPLSREDEIRGHDEGSSLEEPLIESSEVDNRKVQELQGEGGIRLPNG
Original 1: METEIDGYITCDNELSPEREHSNMIDLTSSTPNGQHSPSHMTSTNSVKLEMQSDEECDRKPLSREDEIRGHDEGSSLEEPLIESSEVDNRKVQELQGEGGIRLPNG
