In [25]:
import itertools
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

In [137]:
amino_acids = "arndcqeghilkmfpstwyv"
counter = itertools.count()
aa2index = {a: next(counter) for a in amino_acids}

def sequence_one_hot_encoder(indexer, sequence):
    dim = len(indexer)
    one_hot_encoded = np.zeros((len(sequence), dim+1))
    for i, aa in enumerate(sequence):
        index = indexer.get(aa, dim)
        one_hot_encoded[i, index] = 1.0
    return torch.tensor(one_hot_encoded, dtype=torch.float)

def pack_batch(batch):
    #lengths = [x.size(0) for x in batch]   # get the length of each sequence in the batch\
    #print(f"sum of lengths {sum(lengths)}")
    #print(f"max of lengths {max(lengths)}")
    #padded = nn.utils.rnn.pad_sequence(batch, batch_first=True)  # padd all sequences
    #b, s, n = padded.shape
    #print(f"padded shape {padded.shape}")
        
    # pack padded sequece
    #padded = nn.utils.rnn.pack_padded_sequence(padded, lengths=lengths, batch_first=True, enforce_sorted=False)
    packed = nn.utils.rnn.pack_sequence(batch, enforce_sorted=False)
        
    return packed

def batchify(batch):
    transposed_data = list(zip(*batch))
    batch1, batch2, labels = transposed_data
    
    return pack_batch(batch1), pack_batch(batch2), torch.tensor(labels)

class SequenceDataset(Dataset):
    def __init__(self, fpath, encoder):
        self.encoder = encoder
        with open(fpath, "r") as fin:
            self.lines = fin.readlines()

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        line = self.lines[idx]
        sline = line.strip().split("\t")
        return self.encoder(sline[0]), self.encoder(sline[1]), int(sline[4])

In [138]:
train_data = SequenceDataset("../../data/train_set_0.tsv", lambda x: sequence_one_hot_encoder(aa2index, x))
train_dataloader = DataLoader(train_data, batch_size=None, batch_sampler=None)
#test_dataloader = DataLoader(test_data, batch_size=64)

In [139]:
for batch in train_dataloader:
    print(batch[0].shape)
    print(batch[2])
    break

torch.Size([89, 21])
1


In [153]:
class SequenceEmbedder(nn.Module):
    def __init__(self, input_dim, hidden_lstm_units=512, n_lstm_layers=1):
        super(SequenceEmbedder, self).__init__()
        
        self.lstm = nn.LSTM(
            input_size=hidden_lstm_units,
            hidden_size=hidden_lstm_units,
            num_layers=n_lstm_layers,
            batch_first=True,
            bidirectional=True,
        ) 
        
        self.input_stack = nn.Sequential(
            nn.Linear(input_dim, hidden_lstm_units, dtype=float),
            nn.ReLU()
        )
        
        #self.output_stack = nn.Sequential(
        #    nn.Linear(),
        #    nn.ReLU()
        #)

    def forward(self, batch):
        lstm_inputs = []
        for sequence1, sequence2, label in batch:
            lstm_inputs.append(self.input_stack(sequence1))
            
        print(len(lstm_inputs))
        print(lstm_inputs[0].shape)
        lstm_inputs = pack_batch(lstm_inputs)
        print(lstm_inputs.data.shape)
        
        return self.lstm(lstm_inputs)
        

model = SequenceEmbedder(21)
model.float()

SequenceEmbedder(
  (lstm): LSTM(512, 512, batch_first=True, bidirectional=True)
  (input_stack): Sequential(
    (0): Linear(in_features=21, out_features=512, bias=True)
    (1): ReLU()
  )
)

In [158]:
batch = []
for i, x in enumerate(train_dataloader):
    batch.append(x)
    if i>=63:
        break
out, (final_hidden_state, final_cell_state) = model.forward(batch)
print(out.data.shape)
print(final_hidden_state.shape)

64
torch.Size([89, 512])
torch.Size([11525, 512])
torch.Size([11525, 1024])
torch.Size([2, 64, 512])


In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [None]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")