In [99]:
import itertools
import numpy as np
import torch
from torch import nn
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

In [101]:
t1 = torch.tensor([[1, 2,3]])
t2 = torch.tensor([[1], [2],[3]])
torch.matmul(t2, t1)

tensor([[1, 2, 3],
        [2, 4, 6],
        [3, 6, 9]])

In [100]:
t1 = torch.tensor([[1, 2],
                   [3, 4],
                  [0, 0]])
t2 = torch.tensor([[1, 2],
                   [3, 4],
                  [5, 6]])

r1 = t1[:2,:].repeat(3,1)
r2 = t2[:3,:].repeat(2,1)
a = -(r1-r2).abs().sum()/(2*3)
a

tensor(-3.3333)

In [53]:
t = torch.tensor([[[1, 2],
                   [3, 4]]])
print(t.shape)
print(torch.flatten(t, start_dim=0, end_dim=1).shape)

torch.Size([1, 2, 2])
torch.Size([2, 2])


In [2]:
t = torch.tensor([[1, 2], [30, 40],[5,6],[70,80],[9,10]])
t.gather(0, torch.tensor([[0, 0],[2,2],[4,4],[1,1],[3,3]]))

tensor([[ 1,  2],
        [ 5,  6],
        [ 9, 10],
        [30, 40],
        [70, 80]])

In [3]:
from torch.nn.utils.rnn import pack_sequence
a = torch.tensor([1,2,3])
b = torch.tensor([4,5])
c = torch.tensor([6])
pack_sequence([c, a, b], enforce_sorted=False)

PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=tensor([1, 2, 0]), unsorted_indices=tensor([2, 0, 1]))

In [4]:
x = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b
print(z.grad_fn)

<AddBackward0 object at 0x10dbaeca0>


In [250]:
amino_acids = "arndcqeghilkmfpstwyv"
counter = itertools.count()
aa2index = {a: next(counter) for a in amino_acids}
dummy2index = {"a": 0, "b": 1}

def sequence_one_hot_encoder(indexer, sequence):
    dim = len(indexer)
    one_hot_encoded = np.zeros((len(sequence), dim+1))
    for i, aa in enumerate(sequence):
        index = indexer.get(aa, dim)
        one_hot_encoded[i, index] = 1.0
    return torch.tensor(one_hot_encoded, dtype=torch.float)

def pack_batch(batch):
    lengths = [x.size(0) for x in batch]   # get the length of each sequence in the batch\
    #print(f"sum of lengths {sum(lengths)}")
    #print(f"max of lengths {max(lengths)}")
    padded = nn.utils.rnn.pad_sequence(batch, batch_first=True)  # padd all sequences
    #b, s, n = padded.shape
    #print(f"padded shape {padded.shape}")
        
    # pack padded sequece
    packed = nn.utils.rnn.pack_padded_sequence(padded, lengths=lengths, batch_first=True, enforce_sorted=False)
    #packed = nn.utils.rnn.pack_sequence(batch, enforce_sorted=False)
        
    return packed, lengths

def batchify(batch):
    transposed_data = list(zip(*batch))
    batch1, batch2, labels = transposed_data
    
    return pack_batch(batch1), pack_batch(batch2), torch.stack(labels, dim=0)

class SequenceDataset(Dataset):
    def __init__(self, fpath, encoder):
        self.encoder = encoder
        with open(fpath, "r") as fin:
            self.lines = fin.readlines()

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        line = self.lines[idx]
        sline = line.strip().split("\t")
        label_tensor = torch.tensor(np.ones((1,4)))
        label = int(sline[4])
        label_tensor[0,label:] = 0
        return self.encoder(sline[0]), self.encoder(sline[1]), label_tensor
    
class DummyDataset(Dataset):
    def __init__(self, encoder):
        self.encoder = encoder
        self.lines = [
            ("ba", "bba","a.1.1.1","a.2.1.1",2),
            ("aba","aab","a.1.1.1","a.1.1.1",4),
        ]

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        sline = self.lines[idx]
        label_tensor = torch.tensor(np.ones((1,4)))
        label = int(sline[4])
        label_tensor[0,label:] = 0        
        return self.encoder(sline[0]), self.encoder(sline[1]), label_tensor

In [251]:
dummy_data = DummyDataset(lambda x: sequence_one_hot_encoder(dummy2index, x))
#dummy_dataloader = DataLoader(dummy_data, batch_size=None, batch_sampler=None)
dummy_dataloader = DataLoader(dummy_data, batch_size=2, collate_fn=batchify)

In [252]:
def uniform_alignment(sequence_embedding1, length1, sequence_embedding2, length2):
    r1 = sequence_embedding1[:length1,:].repeat(length2,1)
    r2 = sequence_embedding2[:length2,:].repeat(length1,1)
    #print(r1)
    #print(r2)
    return -(r1-r2).abs().sum()/(length1*length2)

class OrdinalRegression(nn.Module):
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.coefficients = torch.tensor(np.ones((1, n_classes))/n_classes, requires_grad=True, dtype=torch.float)    
        self.bias = torch.tensor(np.zeros((1, n_classes)), requires_grad=True, dtype=torch.float)
        
    def forward(self, batch):
        if batch.size(1) != 1:
            raise Exception("second dimension of input should be 1")
        expanded_bias = self.bias.expand(batch.size(0),-1)
        return torch.sigmoid(torch.matmul(batch,F.relu(self.coefficients))+expanded_bias)
    
def structural_similarity_loss(predictions, labels):
    loss_matrix = torch.mul(torch.log(predictions), labels) + torch.mul(torch.log(1.0-predictions), 1-labels)
    return -torch.mean(torch.sum(loss_matrix, dim=1))

In [292]:
class SequenceEmbedder(nn.Module):
    def __init__(self, n_classes, input_dim, hidden_lstm_units=512, n_lstm_layers=1, output_dim=100, bidirectional=True):
        super(SequenceEmbedder, self).__init__()
        
        self.n_classes = n_classes
        self.input_dim = input_dim
        self.hidden_lstm_units = hidden_lstm_units
        self.n_lstm_layers = n_lstm_layers
        self.bidirectional = bidirectional
        self.output_dim = output_dim
        
        self.rnn = nn.LSTM(#TODO: adjust between LSTM/GRU/RNN
            input_size=hidden_lstm_units,
            hidden_size=hidden_lstm_units,
            num_layers=n_lstm_layers,
            batch_first=True,
            bidirectional=bidirectional,
            bias=False,
            #nonlinearity="relu" #TODO: adjust between LSTM/GRU/RNN
        ) 
        self.fix_rnn_input_parameters()
        
        self.input_stack = nn.Sequential(
            nn.Linear(input_dim, hidden_lstm_units, dtype=torch.float),
            nn.ReLU()
        )
        
        self.D = 2 if self.bidirectional else 1
        self.output_stack = nn.Sequential(
            nn.Linear(self.D*hidden_lstm_units, output_dim, dtype=torch.float),
            nn.ReLU()
        )
        
    def _initialize_parameters(self):
        """
        Just to initialize a simple RNN with fixed values for debugging and understanding PyTorch forward flow
        Important note: will not work with LSTM or GRU, will not work with bias in RNN (or bidirectionality)
        """
        ones = np.ones((self.hidden_lstm_units, self.input_dim))
        for rowi in range(ones.shape[0]):
            for coli in range(ones.shape[1]):
                ones[rowi,coli] = rowi+coli
        self.input_stack[0].weight = Parameter(torch.tensor(
            ones,
            dtype=torch.float
        ))
        self.input_stack[0].bias = Parameter(torch.tensor(0, dtype=torch.float))
        
        for k in range(self.n_lstm_layers):
            ones = np.ones(getattr(self.rnn,f"weight_hh_l{k}").shape)
            setattr(self.rnn,f"weight_hh_l{k}",Parameter(torch.tensor(ones, dtype=torch.float)))

        ones = np.ones(self.output_stack[0].weight.shape)   
        self.output_stack[0].weight = Parameter(torch.tensor(
            ones,
            dtype=torch.float
        ))
        self.output_stack[0].bias = Parameter(torch.tensor(0, dtype=torch.float)) 
        
    def fix_rnn_input_parameters(self):
        #TODO: adjust between LSTM/GRU/RNN
        eye = np.repeat(np.eye(self.hidden_lstm_units), 4, axis=0)
        setattr(self.rnn, 
                "weight_ih_l0",
                Parameter(torch.tensor(eye, dtype=torch.float), requires_grad=False))  
        
    def embed_sequence(self, batch, lengths):
        #unpack packed sequences in batch by padding: shape L*input_dim -> shape B*T*input_dim
        #where L is the sum of all sequence lenghts in the batch
        #where B is the batch size
        #where T is the length of the longest sequence in the batch
        batch_padded, _ = nn.utils.rnn.pad_packed_sequence(batch, batch_first=True)
        
        #relu transformation of input and packing sequences in batch for recurrent layer
        #shape B*T*input_dim -> L*hidden_lstm_units
        rnn_inputs = self.input_stack(batch_padded)
        #print(f"{rnn_inputs}")
        rnn_inputs = nn.utils.rnn.pack_padded_sequence(rnn_inputs, lengths=lengths, batch_first=True, enforce_sorted=False)
        #print(f"{rnn_inputs}")
        
        #apply recurrent layers and unpack by padding the output of the last recurrent layer
        #lstm_h_0 = torch.tensor(np.zeros((self.D*self.n_lstm_layers, batch.batch_sizes[0], self.hidden_lstm_units)), dtype=torch.float)
        #lstm_c_0 = torch.tensor(np.zeros((self.D*self.n_lstm_layers, batch.batch_sizes[0], self.hidden_lstm_units)), dtype=torch.float)
        rnn_output = self.rnn(rnn_inputs)
        rnn_out_unpacked, _ = nn.utils.rnn.pad_packed_sequence(rnn_output[0], batch_first=True)
        
        #feed output of recurrent layer to Relu unit
        return self.output_stack(rnn_out_unpacked)        
            
    def forward(self, batch):
        (batch1, lengths1), (batch2, lengths2), labels = batch
        
        sequence_embeddings1 = self.embed_sequence(batch1, lengths1)
        sequence_embeddings2 = self.embed_sequence(batch2, lengths2)
        
        batch_size = sequence_embeddings1.size(0)
        chunks1 = sequence_embeddings1.chunk(batch_size, 0)
        chunks2 = sequence_embeddings2.chunk(batch_size, 0)
        
        alignment_scores = []
        for embedding1, length1, embedding2, length2 in zip(chunks1,
                                          lengths1,
                                         chunks2,
                                         lengths2):
            flattened1 = torch.flatten(embedding1, start_dim=0, end_dim=1)
            flattened2 = torch.flatten(embedding2, start_dim=0, end_dim=1)
            alignment_score = uniform_alignment(flattened1, length1, flattened2, length2)
            alignment_scores.append(alignment_score)
        
        stacked_scores = torch.tensor(alignment_scores).reshape((len(alignment_scores),1))
        
        stacked_predictions = OrdinalRegression(self.n_classes).forward(stacked_scores)
        
        #TODO: labels should be processed somewhere else
        return stacked_predictions, labels
        

In [303]:
model = SequenceEmbedder(4, 3, hidden_lstm_units=2, output_dim=2, bidirectional=True)
#model._initialize_parameters()
model.float()

SequenceEmbedder(
  (rnn): LSTM(2, 2, bias=False, batch_first=True, bidirectional=True)
  (input_stack): Sequential(
    (0): Linear(in_features=3, out_features=2, bias=True)
    (1): ReLU()
  )
  (output_stack): Sequential(
    (0): Linear(in_features=4, out_features=2, bias=True)
    (1): ReLU()
  )
)

In [304]:
for x in dummy_dataloader:
    print(x)
    out = model(x)
    print(out)
print(structural_similarity_loss(out[0],out[1]))

((PackedSequence(data=tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.]]), batch_sizes=tensor([2, 2, 1]), sorted_indices=tensor([1, 0]), unsorted_indices=tensor([1, 0])), [2, 3]), (PackedSequence(data=tensor([[0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.]]), batch_sizes=tensor([2, 2, 2]), sorted_indices=tensor([0, 1]), unsorted_indices=tensor([0, 1])), [3, 3]), tensor([[[1., 1., 0., 0.]],

        [[1., 1., 1., 1.]]], dtype=torch.float64))
(tensor([[0.4976, 0.4976, 0.4976, 0.4976],
        [0.4971, 0.4971, 0.4971, 0.4971]], grad_fn=<SigmoidBackward0>), tensor([[[1., 1., 0., 0.]],

        [[1., 1., 1., 1.]]], dtype=torch.float64))
tensor(1.3916, dtype=torch.float64, grad_fn=<NegBackward0>)


In [295]:
train_data = SequenceDataset("../../data/train_set_0.tsv", lambda x: sequence_one_hot_encoder(aa2index, x))
train_dataloader = DataLoader(train_data, batch_size=None, batch_sampler=None)
#test_dataloader = DataLoader(test_data, batch_size=64)

In [305]:
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batchi, batch in enumerate(dataloader):
        # Compute prediction and loss
        out = model(batch)
        loss = loss_fn(out[0], out[1])

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batchi % 100 == 0:
            loss, current = loss.item(), batchi * batch[0][0].batch_sizes[0]
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [306]:
model = SequenceEmbedder(4, 21, hidden_lstm_units=512, output_dim=100, bidirectional=True)
#model._initialize_parameters()
model.float()

sequence_data = SequenceDataset("../../data/train_set_0.tsv", lambda x: sequence_one_hot_encoder(aa2index, x))
#dummy_dataloader = DataLoader(dummy_data, batch_size=None, batch_sampler=None)
sequence_dataloader = DataLoader(sequence_data, batch_size=64, collate_fn=batchify)

learning_rate = 0.001
loss_fn = structural_similarity_loss

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(sequence_dataloader, model, loss_fn, optimizer)
    #test_loop(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 41.056584  [    0/100000]
loss: 41.267984  [ 6400/100000]
loss: 41.311211  [12800/100000]
loss: 41.232817  [19200/100000]
loss: 41.083286  [25600/100000]
loss: 40.701659  [32000/100000]
loss: 41.186644  [38400/100000]
loss: 40.855392  [44800/100000]
loss: 41.383181  [51200/100000]
loss: 40.630810  [57600/100000]
loss: 40.769735  [64000/100000]
loss: 40.753051  [70400/100000]
loss: 41.190194  [76800/100000]
loss: 41.461689  [83200/100000]
loss: 41.282957  [89600/100000]
loss: 41.322677  [96000/100000]
Epoch 2
-------------------------------
loss: 41.056584  [    0/100000]
loss: 41.267984  [ 6400/100000]
loss: 41.311211  [12800/100000]
loss: 41.232817  [19200/100000]
loss: 41.083286  [25600/100000]
loss: 40.701659  [32000/100000]
loss: 41.186644  [38400/100000]
loss: 40.855392  [44800/100000]
loss: 41.383181  [51200/100000]
loss: 40.630810  [57600/100000]
loss: 40.769735  [64000/100000]
loss: 40.753051  [70400/100000]
loss: 41.190194  [76800/

KeyboardInterrupt: 