Using this as a resource: https://suzyahyah.github.io/pytorch/2019/07/01/DataLoader-Pad-Pack-Sequence.html

In [1]:
import numpy as np
import pandas as pd
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqFeature import SeqFeature, FeatureLocation

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader

In [2]:
def get_list_of_codons(dna_seq):
    """Split a nucleotide sequence into consecutive 3-character chunks.

    Chunks are taken from the start of ``dna_seq``; when its length is not
    a multiple of three, the final chunk is shorter than three characters.
    """
    return [dna_seq[start:start + 3] for start in range(0, len(dna_seq), 3)]


# quick sanity check on an 18-nucleotide sequence
assert get_list_of_codons('ATGCCCGGGAAATTTTAG') == ['ATG', 'CCC', 'GGG', 'AAA', 'TTT', 'TAG']

In [3]:
# Walk every CDS feature of the GenBank file that carries a 'translation'
# qualifier, collecting its protein, its codons, and the symbol vocabularies.
host_proteins = []    # protein strings, each with '*' appended
host_codons = []      # codon lists, parallel to host_proteins

u_aas = set()         # every amino-acid symbol seen
u_codons = set()      # every codon seen
all_codons = []       # same codon lists as host_codons
initial_states = []   # first codon of every CDS
emissions = {}        # codon -> amino acid at the same position (last one seen wins)
for record in SeqIO.parse('GCF_000009045.1_ASM904v1_genomic.gbff', "genbank"):
    for feature in record.features:
        if feature.type == 'CDS' and 'translation' in feature.qualifiers:
            # '*' is appended to the annotated translation -- presumably so
            # the trailing stop codon has a matching protein symbol
            protein = feature.qualifiers['translation'][0] + '*'
            host_proteins.append(protein)

            aas = set(protein)
            codon = get_list_of_codons(str(feature.extract(record.seq)))

            host_codons.append(codon)
            all_codons.append(codon)
            initial_states.append(codon[0])
            u_aas.update(aas)
            u_codons.update(codon)
            for i, cdn in enumerate(codon):
                emissions[cdn] = protein[i]

# Index 0 of each vocabulary is reserved for the padding symbol. The remaining
# symbols are sorted so that every run assigns the same index to the same
# symbol; iterating a set of strings directly gives an order that can change
# from one interpreter session to the next.
lu_aas = ['0'] + sorted(u_aas)
lu_codons = ['PAD'] + sorted(u_codons)

In [4]:
def encode_seq(seq_obj, seqtype='dna'):
    """Encode a sequence as a NumPy array of vocabulary indices.

    Args:
        seq_obj: the sequence to encode. With ``seqtype='dna'`` it is split
            into codons by ``get_list_of_codons``; otherwise it is read one
            character at a time.
        seqtype: ``'dna'`` looks symbols up in ``lu_codons``; any other value
            looks them up in ``lu_aas``.

    Returns:
        An integer array holding one vocabulary index per symbol.

    Raises:
        ValueError: if a symbol is missing from the selected vocabulary.
    """
    if seqtype == 'dna':
        encdr = lu_codons
        symbols = get_list_of_codons(seq_obj)
    else:
        encdr = lu_aas
        symbols = list(seq_obj)
    # dtype=int keeps the result an integer array even for an empty sequence,
    # which NumPy would otherwise create as float64
    return np.array([encdr.index(s) for s in symbols], dtype=int)


# Self-checks: the first and last symbols must map to their vocabulary index.
test_aa = 'MENILD0'
test_nuc = 'AAAAAAATAAGATAGPAD'
encoded_aa = encode_seq(test_aa, seqtype='prot')
assert encoded_aa[0] == lu_aas.index(test_aa[0]) and \
       encoded_aa[-1] == lu_aas.index(test_aa[-1])
encoded_nuc = encode_seq(test_nuc, seqtype='dna')
assert encoded_nuc[0] == lu_codons.index(test_nuc[0:3]) and \
       encoded_nuc[-1] == lu_codons.index(test_nuc[-3:])

def decode_seq(num_array, seqtype='dna'):
    """Turn an iterable of vocabulary indices back into a sequence string.

    ``seqtype='dna'`` reads symbols from ``lu_codons``; any other value reads
    them from ``lu_aas``. The looked-up symbols are concatenated in order.
    """
    vocabulary = lu_codons if seqtype == 'dna' else lu_aas
    return ''.join([vocabulary[index] for index in num_array])


# encoding followed by decoding must give back the original sequence
round_trip_nuc = decode_seq(encode_seq(test_nuc))
assert round_trip_nuc == test_nuc
round_trip_aa = decode_seq(encode_seq(test_aa, seqtype='prot'), seqtype='prot')
assert round_trip_aa == test_aa

#### PyTorch

In [5]:
# Picture each sequence as a stack of one-hot vectors, i.e. a tensor of shape (seq_len, num_characters).
# With a 300-character dictionary, pad three sequences of lengths 25, 22 and 15 into one tensor.
a, b, c = (torch.ones(seq_len, 300) for seq_len in (25, 22, 15))
d = pad_sequence([a, b, c])
d.size()

torch.Size([25, 3, 300])

In [6]:
# https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
class ProteinsSet(Dataset):
    """Dataset of (protein, codon sequence) pairs served as one-hot tensors.

    Args:
        list_of_proteins: proteins, each an iterable of amino-acid symbols
            (for example a string).
        list_of_codons: codon sequences, each a sequence of codon symbols;
            item ``i`` is paired with protein ``i``.
        codon_list: codon vocabulary; a symbol's position is its one-hot column.
        aa_list: amino-acid vocabulary; a symbol's position is its one-hot column.
    """

    def __init__(self, list_of_proteins, list_of_codons, codon_list, aa_list):
        self.prot_collection = list_of_proteins
        self.codon_collection = list_of_codons
        self.lu_codons = codon_list
        self.lu_aas = aa_list
        # symbol -> column tables, built once so that encoding a sequence does
        # not scan the vocabulary list again for every symbol
        self._codon_index = self._build_index(codon_list)
        self._aa_index = self._build_index(aa_list)

    @staticmethod
    def _build_index(vocabulary):
        """Map each symbol to the position of its first occurrence."""
        index = {}
        for position, symbol in enumerate(vocabulary):
            # setdefault keeps the first position, matching list.index
            index.setdefault(symbol, position)
        return index

    def __len__(self):
        """Return the number of proteins in the dataset."""
        return len(self.prot_collection)

    def __getitem__(self, idx):
        """Return ``(one-hot protein, one-hot codons)`` for item ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        encd_prot = self.__encode__(self.prot_collection[idx], 'prot')
        encd_codn = self.__encode__(self.codon_collection[idx], 'dna')
        return (encd_prot, encd_codn)

    def __encode__(self, seq_obj, seqtype):
        """One-hot encode ``seq_obj`` as a ``(len(symbols), vocabulary size)`` tensor.

        ``seqtype='dna'`` treats ``seq_obj`` as a ready-made sequence of codons;
        any other value splits it into single characters and uses the
        amino-acid vocabulary.

        Raises:
            ValueError: if a symbol is not in the selected vocabulary.
        """
        if seqtype != 'dna':
            encdr = self.lu_aas
            lookup = self._aa_index
            symbols = list(seq_obj)
        else:
            encdr = self.lu_codons
            lookup = self._codon_index
            symbols = seq_obj
        t = torch.zeros(len(symbols), len(encdr))
        for i, s in enumerate(symbols):
            try:
                column = lookup[s]
            except KeyError:
                # same exception type that list.index raised for unknown symbols
                raise ValueError(f"{s!r} is not in the {seqtype} vocabulary") from None
            t[i, column] = 1
        return t

def pad_collate(batch):
    """Collate ``(x, y)`` tensor pairs into two zero-padded batch tensors.

    Args:
        batch: list of 2-tuples of tensors, one tuple per sample.

    Returns:
        ``(xx_pad, yy_pad, x_lens, y_lens)``: the padded tensors, with the
        sequence dimension first and the batch dimension second, followed by
        the unpadded length of every x and every y.
    """
    # split the list of pairs into one list per member: https://stackoverflow.com/a/8081590
    xx, yy = (list(column) for column in zip(*batch))
    print(xx[0].shape)
    x_lens = list(map(len, xx))
    y_lens = list(map(len, yy))
    xx_pad = pad_sequence(xx, padding_value=0, batch_first=False)
    yy_pad = pad_sequence(yy, padding_value=0, batch_first=False)
    print([xx_pad.size()])
    return xx_pad, yy_pad, x_lens, y_lens

# Wrap the parsed genome in a dataset and serve it in shuffled batches that
# pad_collate pads to a common length.
batch_size = 16
ps = ProteinsSet(
    list_of_proteins=host_proteins,
    list_of_codons=host_codons,
    codon_list=lu_codons,
    aa_list=lu_aas,
)
data_loader = DataLoader(ps, batch_size=batch_size, shuffle=True, collate_fn=pad_collate)

In [7]:
# Pull a single batch to experiment with. next() raises StopIteration right
# here if the loader is empty, instead of leaving `batch` undefined for the
# cells below.
batch = next(iter(data_loader))

torch.Size([210, 22])
[torch.Size([519, 16, 22])]


In [8]:
# xx_pad is expected to be (max(x_lens), 16, 22): its first dimension must
# equal the longest protein in the batch
ex_dim = max(batch[2])
assert batch[0].shape[0] == ex_dim

In [9]:
# Why pack? https://stackoverflow.com/a/56211056
# The one-hot proteins are fed to the GRU directly, so its input width is the
# size of the amino-acid vocabulary.
embedding_dim = len(lu_aas)
h_dim = 100
n_layers = 2
x_embed, x_lens = batch[0], batch[2]
rnn = nn.GRU(
    input_size=embedding_dim,
    hidden_size=h_dim,
    num_layers=n_layers,
    batch_first=False,
)

# enforce_sorted=False: the batch is not ordered by sequence length
x_packed = pack_padded_sequence(x_embed, x_lens, batch_first=False, enforce_sorted=False)

In [10]:
# Size the initial hidden state from the batch actually drawn (dimension 1 of
# the padded tensor) rather than the batch_size constant: the loader does not
# drop a short final batch, and the constant would then no longer match.
h0 = torch.zeros(n_layers, x_embed.size(1), h_dim)
output_packed, h1 = rnn(x_packed, h0)

In [11]:
# unpack the GRU output back into a padded tensor plus the per-sequence lengths
(output_padded, output_lengths) = pad_packed_sequence(
    output_packed,
    batch_first=False,
)

In [12]:
# display the shape of the padded GRU output
output_padded.shape

torch.Size([519, 16, 100])