### Notes

First, we need to choose a library for handling the data. I will encode sequences numerically and want the option to work with them in that form, to extend them with additional information (BPP, physical properties, distances, etc.), and to treat them as graphs. PyTorch (with tensors as the data type) seems the best fit, and it also offers PyTorch Geometric for graph data. TensorFlow/Keras is another option, though graphs seem somewhat harder to handle there.

TO DO:
- set up first NN with X as input and y (coordinates) as output
- incorporate MSA

### Prepare data (X & y)
For now, these are prepared as tensors of one-hot-encoded sequences (padded to the same length) and tensors of coordinates. MSAs are not yet considered.
Update: since an embedding layer is now used, the sequences are instead converted to tensors of integer indices. The one-hot encoding code is kept below for now.

In [11]:
import pandas as pd
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torch.nn import Module, MSELoss
from torch.optim import Adam

train_seq = pd.read_csv("../toy_data/train_sequences.csv")
train_lbl = pd.read_csv("../toy_data/train_labels.csv")

Index mapping:

In [29]:
nts = ['G', 'U', 'C', 'A', 'X', '-']
mapping = {nt: idx+1 for idx, nt in enumerate(nts)} # indices start at 1; 0 is reserved for padding


def tokenise_seq(seq, mapping=mapping):
    seq_idx = [mapping[nt] for nt in seq]
    seq_idx = torch.tensor(seq_idx)
    return seq_idx

X_list = [tokenise_seq(seq) for seq in train_seq['sequence']]
X_tensor = pad_sequence(X_list, batch_first=True)

X_tensor[0] # QC


tensor([1, 1, 1, 2, 1, 3, 2, 3, 4, 1, 2, 4, 3, 1, 4, 1, 4, 1, 1, 4, 4, 3, 3, 1,
        3, 4, 3, 3, 3, 0, 0, 0, 0, 0, 0])

One-hot encoding:

In [19]:
# # X: One-hot encode sequence and convert to tensor

# #seq = train_seq['sequence'][0] # to test

# nts = ['G', 'U', 'C', 'A', 'X', '-']
# mapping = {nt: idx for idx, nt in enumerate(nts)}

# def one_hot_encode_seq(seq):

#     ohe_seq = []

#     for nt in seq:
#         binary_l = [0] * len(nts)
#         binary_l[mapping[nt]] = 1
#         ohe_seq.append(binary_l)
    
#     ohe_torch = torch.tensor(ohe_seq, dtype=torch.float32)
#     return ohe_torch

# X_list = [one_hot_encode_seq(seq) for seq in train_seq['sequence']]
# X_tensor = pad_sequence(X_list, batch_first=True) # pad sequences to same length
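
If the one-hot representation is needed again, it can probably be derived from the index tensor rather than built by hand. A minimal sketch, assuming the 1-based `mapping` used for the embedding (0 reserved for padding); the toy sequences and the name `X_onehot` are illustrative only:

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

nts = ['G', 'U', 'C', 'A', 'X', '-']
mapping = {nt: idx + 1 for idx, nt in enumerate(nts)}  # 0 is reserved for padding

toy_seqs = ["GGAU", "GC"]  # made-up sequences, not from the dataset
X_list = [torch.tensor([mapping[nt] for nt in seq]) for seq in toy_seqs]
X_tensor = pad_sequence(X_list, batch_first=True)  # shape (2, 4), padded with 0

# One class per index, including the padding class 0 -> last dim is len(nts) + 1
X_onehot = F.one_hot(X_tensor, num_classes=len(nts) + 1).float()

# Drop the padding channel so padded positions become all-zero rows,
# which is what padding hand-built one-hot vectors would give
X_onehot = X_onehot[..., 1:]
print(X_onehot.shape)  # torch.Size([2, 4, 6])
```

After dropping the padding channel, `G` ends up in channel 0, which matches the 0-based mapping of the commented-out encoder.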


In [30]:
# y: Convert coordinates to tensor

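train_lbl['base_ID'] = train_lbl['ID'].str.rsplit('_', n=1).str[0]

y_list = []
for idx in train_lbl['base_ID'].unique():

    # Select only the rows belonging to this structure. Matching on base_ID exactly
    # avoids pulling in rows of another structure whose ID starts with the same prefix,
    # and indexing into `rows` (not the full frame) avoids always reading from row 0.
    rows = train_lbl[train_lbl['base_ID'] == idx]
    coords = rows.iloc[:, 3:6].to_numpy(dtype='float32') # coordinate columns

    y_list.append(torch.tensor(coords))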
    
y_tensor = pad_sequence(y_list, batch_first=True)

y_tensor.size()[0:2] == X_tensor.size()[0:2] # check that it's formatted correctly 

True

In [31]:
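# Make a padding mask

lengths = torch.tensor([len(seq) for seq in X_list])
positions = torch.arange(X_tensor.size(1))

# attn_mask: True at padded positions (the orientation nn.MultiheadAttention
# expects for its key_padding_mask argument)
attn_mask = positions.unsqueeze(0) >= lengths.unsqueeze(1)

# padding_mask: True at real residues, False at padding
padding_mask = ~attn_mask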
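
One possible way to use such a mask, sketched under the assumption that it is passed to `nn.MultiheadAttention` as `key_padding_mask` (which expects `True` at padded positions, i.e. the orientation of `attn_mask` above). The tensor sizes and names below are illustrative:

```python
import torch
from torch import nn

torch.manual_seed(0)
batch, max_len, hidden = 2, 5, 8
lengths = torch.tensor([5, 3])  # made-up true lengths of two padded sequences

# True where a position is padding
key_padding_mask = torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)

# batch_first=True lets the layer take (batch, length, hidden) directly
attn = nn.MultiheadAttention(hidden, num_heads=2, batch_first=True)
x = torch.randn(batch, max_len, hidden)

out, weights = attn(x, x, x, key_padding_mask=key_padding_mask)

# Largest attention weight that any query of the second sequence puts on a padded key
print(weights[1, :, 3:].abs().max())
```

With the mask in place, padded keys receive no attention weight, so real residues are no longer influenced by the padding tokens.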

In [15]:
# Create Dataset & Dataloader

from torch.utils.data import random_split

class Rnadataset(Dataset):
    def __init__(self, X_tensor, y_tensor):
        super().__init__()
        self.X_tensor = X_tensor
        self.y_tensor = y_tensor
    
    def __len__(self):
        return len(self.X_tensor)
    
    def __getitem__(self, index) :
        return self.X_tensor[index], self.y_tensor[index]
    
dataset = Rnadataset(X_tensor, y_tensor)

train_size = int(len(dataset)*0.8)
test_size = int(len(dataset)-train_size)

train_data, test_data = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16, shuffle=False) # no need to shuffle for evaluation

In [6]:
# Embedding test

embedding = nn.Embedding(7, 4) # num_embeddings is the vocabulary size (6 symbols + padding index 0), not the sequence length

embedded_seq = embedding(X_tensor[0])
embedded_seq

tensor([[-0.2162,  0.4166, -0.4622, -1.1417],
        [-0.2162,  0.4166, -0.4622, -1.1417],
        [-0.2162,  0.4166, -0.4622, -1.1417],
        [-0.9356,  0.4105, -0.3889,  0.0239],
        [-0.2162,  0.4166, -0.4622, -1.1417],
        [ 0.1681, -1.9141, -0.7487,  1.1676],
        [-0.9356,  0.4105, -0.3889,  0.0239],
        [ 0.1681, -1.9141, -0.7487,  1.1676],
        [-0.6137,  1.4808, -2.4116, -0.5547],
        [-0.2162,  0.4166, -0.4622, -1.1417],
        [-0.9356,  0.4105, -0.3889,  0.0239],
        [-0.6137,  1.4808, -2.4116, -0.5547],
        [ 0.1681, -1.9141, -0.7487,  1.1676],
        [-0.2162,  0.4166, -0.4622, -1.1417],
        [-0.6137,  1.4808, -2.4116, -0.5547],
        [-0.2162,  0.4166, -0.4622, -1.1417],
        [-0.6137,  1.4808, -2.4116, -0.5547],
        [-0.2162,  0.4166, -0.4622, -1.1417],
        [-0.2162,  0.4166, -0.4622, -1.1417],
        [-0.6137,  1.4808, -2.4116, -0.5547],
        [-0.6137,  1.4808, -2.4116, -0.5547],
        [ 0.1681, -1.9141, -0.7487

### Note on loss function

The competition uses TM-score to evaluate predictions, which among other things is based on distances between residues rather than absolute coordinate differences. For this task I will therefore convert both ground-truth and predicted coordinates to pairwise distance matrices and minimise the loss between the two. Since the score depends on squared distances, we'll use MSE (for now).
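
Written out, this objective would look roughly as follows, where $x_i$ and $\hat{x}_i$ are the true and predicted coordinates of residue $i$, $L$ is the padded sequence length and $B$ the batch size (notation introduced here for illustration, and assuming the mean is taken over all matrix entries):

$$
D_{ij} = \lVert x_i - x_j \rVert_2, \qquad
\hat{D}_{ij} = \lVert \hat{x}_i - \hat{x}_j \rVert_2, \qquad
\mathcal{L} = \frac{1}{B L^2} \sum_{b=1}^{B} \sum_{i=1}^{L} \sum_{j=1}^{L} \left( D^{(b)}_{ij} - \hat{D}^{(b)}_{ij} \right)^2
$$

Because $D$ depends only on differences between coordinates, this loss is unchanged by rotating or translating a prediction as a whole. It is also unchanged by reflection, so a mirror image of the true structure would not be penalised.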

### Build initial simple model
The architecture will consist of:
- embedding: maps the integer index of each nucleotide in the sequence to a learned vector representation
- sequence encoder: inspired by RibonanzaNet; 9 blocks, each with a 1D convolution with a residual connection, multi-head self-attention, and a feed-forward network

In [16]:
# Define blocks of the model

class SeqEncoder(nn.Module): # Define single encoder block
    def __init__(self, hidden_size=256, kernel_size=3):
        super().__init__()
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(hidden_size, hidden_size, kernel_size=kernel_size, padding = kernel_size // 2)
        self.attn = nn.MultiheadAttention(hidden_size, 8)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.norm3 = nn.LayerNorm(hidden_size)
        self.ff = nn.Sequential(
            nn.Linear(hidden_size, 4*hidden_size),
            nn.GELU(),
            nn.Linear(4*hidden_size, hidden_size)
        )

    def forward(self, X):
        X = X + self.conv(X.transpose(1,2)).transpose(1,2) # 1D conv with residual connection + Layer Norm; transpose to expected input
        X = self.norm1(X)
        res = X
        attn_out, _ = self.attn(X.transpose(0,1), X.transpose(0,1), X.transpose(0,1))
        attn_out = attn_out.transpose(0,1) + res
        X = self.norm2(attn_out)
        res = X
        X = self.norm3(res + self.ff(X))
        return X
        
class ConvEncoder(nn.Module): # define a whole transformer pipeline
    def __init__(self, n_blocks = 9, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList([SeqEncoder(**kwargs) for _ in range(n_blocks)])
    
    def forward(self, X):
        for layer in self.layers:
            X = layer(X)
        return X


In [38]:
# Define model 

class InitModel(Module): # define rest of model
    def __init__(self, seq_length=35, vocab=6, n_blocks=9, hidden_size=256):
        super().__init__()
        self.l = seq_length # maximum (padded) sequence length
        self.vocab = vocab # number of symbols in nts
        # One row per token index: 0 is reserved for padding, 1..vocab are the symbols,
        # so the table size is vocab + 1 (not the sequence length)
        self.embedding = nn.Embedding(vocab + 1, hidden_size, padding_idx=0)
        self.pos_embedding = nn.Embedding(self.l, hidden_size) # one vector per position
        self.convencoder = ConvEncoder(n_blocks=n_blocks, hidden_size=hidden_size)
        self.output = nn.Linear(hidden_size, 3) # map each position to (x, y, z)

    def forward(self, X):

        # Make embeddings (+ positional embeddings)

        seq_len = X.size(1) # use the actual input length (must be <= seq_length)
        X = self.embedding(X)
        positions = torch.arange(seq_len, device=X.device).unsqueeze(0).expand(X.size(0), seq_len)
        pos_embd = self.pos_embedding(positions)
        X = X + pos_embd

        # Pass through convolutional transformer

        X = self.convencoder(X)

        out = self.output(X)
        return out

        ## TO DO: add padding masks, add layers which map the encoded representations to coords, add distance calculation, minimise loss btwn og dist & encoded dist 


initmodel = InitModel()       

In [42]:
# Make pairwise distance

truth_seq = y_tensor[0]

d_matrix = torch.zeros(truth_seq.size()[0], truth_seq.size()[0])

for i in range(truth_seq.size()[0]):
    i_coord = truth_seq[i]
    for j in range(truth_seq.size()[0]):
        j_coord = truth_seq[j]
        d_matrix[i,j] = torch.norm(i_coord - j_coord)


#def distance_mse_loss(pred, truth):

d_matrix

tensor([[ 0.0000,  6.3101, 10.2222,  ..., 29.3938, 29.3938, 29.3938],
        [ 6.3101,  0.0000,  5.2843,  ..., 31.1803, 31.1803, 31.1803],
        [10.2222,  5.2843,  0.0000,  ..., 28.9600, 28.9600, 28.9600],
        ...,
        [29.3938, 31.1803, 28.9600,  ...,  0.0000,  0.0000,  0.0000],
        [29.3938, 31.1803, 28.9600,  ...,  0.0000,  0.0000,  0.0000],
        [29.3938, 31.1803, 28.9600,  ...,  0.0000,  0.0000,  0.0000]])
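
The double loop above can likely be replaced by a single vectorised call. A small sketch, assuming `torch.cdist` is acceptable here; the random `coords` tensor stands in for one entry of `y_tensor`:

```python
import torch

torch.manual_seed(0)
coords = torch.randn(35, 3) * 10  # stand-in for one (length, 3) coordinate tensor

# Loop version: one norm per pair of residues
n = coords.size(0)
d_loop = torch.zeros(n, n)
for i in range(n):
    for j in range(n):
        d_loop[i, j] = torch.norm(coords[i] - coords[j])

# Vectorised version: Euclidean distance between every pair of rows.
# The compute_mode setting avoids the matrix-multiplication shortcut, which
# can be slightly less precise (e.g. a diagonal that is not exactly zero).
d_vec = torch.cdist(coords, coords, p=2,
                    compute_mode='donot_use_mm_for_euclid_dist')

print(torch.allclose(d_loop, d_vec, atol=1e-4))
```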

In [18]:
# Define custom loss function on distance matrices rather than coords

def pairwise_distance_matrix(X):
    diff = X.unsqueeze(2) - X.unsqueeze(1)  # shape: (batch, 35, 35, 3)
    return torch.norm(diff, dim=-1)

class DistanceMatrixLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = MSELoss()
    
    def forward(self, y_true, y_pred):
        y_true_m = pairwise_distance_matrix(y_true)
        y_pred_m = pairwise_distance_matrix(y_pred)
        loss = self.loss(y_true_m, y_pred_m)
        return loss 
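
Since `y_tensor` is zero-padded, the loss above also compares distances that involve padded positions. One way to exclude them is sketched here, under the assumption that a boolean mask of real residues (`True` at real positions) is available for each batch; `masked_distance_mse` is a hypothetical helper, not part of the notebook:

```python
import torch

def pairwise_distance_matrix(X):
    diff = X.unsqueeze(2) - X.unsqueeze(1)  # (batch, L, L, 3)
    return torch.norm(diff, dim=-1)         # (batch, L, L)

def masked_distance_mse(y_true, y_pred, valid):
    """MSE between distance matrices over pairs of real residues only.

    valid: bool tensor (batch, L), True at real (non-padded) positions.
    """
    d_true = pairwise_distance_matrix(y_true)
    d_pred = pairwise_distance_matrix(y_pred)
    pair_mask = valid.unsqueeze(2) & valid.unsqueeze(1)  # (batch, L, L)
    sq_err = (d_true - d_pred) ** 2
    return sq_err[pair_mask].mean()

# Toy check: changing predictions at padded positions does not change the loss
torch.manual_seed(0)
y_true = torch.randn(2, 5, 3)
y_pred = torch.randn(2, 5, 3)
valid = torch.tensor([[True] * 5, [True] * 3 + [False] * 2])

loss_a = masked_distance_mse(y_true, y_pred, valid)
y_pred_shifted = y_pred.clone()
y_pred_shifted[1, 3:] += 100.0  # move only the padded positions
loss_b = masked_distance_mse(y_true, y_pred_shifted, valid)
print(torch.isclose(loss_a, loss_b))
```

In the training loop, such a mask could be derived from the token tensor as `seq != 0`, since index 0 is reserved for padding.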


In [45]:
from func import score

n_epochs = 10

initmodel = InitModel()
criterion = DistanceMatrixLoss()
optimiser = Adam(initmodel.parameters())

cols = ["Epoch", "Train_Loss", "Test_Loss", "Train_TMScore", "Test_TMScore"]
perf = pd.DataFrame(index=range(n_epochs), columns=cols)

for epoch in range(n_epochs):
    loss_train = []
    epoch_pred_train = []
    epoch_true_train = []
    initmodel.train()
    for seq, coords in train_loader:
        optimiser.zero_grad()
        pred_coords = initmodel(seq)
        loss = criterion(coords,pred_coords)
        loss_train.append(loss.item())
        epoch_pred_train.append(pred_coords.detach())
        epoch_true_train.append(coords.detach())
        loss.backward()
        optimiser.step()
    
    #TMTrain = score(pd.DataFrame(epoch_true_train), pd.DataFrame(epoch_pred_train))
    loss_train = sum(loss_train)/len(loss_train)

    initmodel.eval()
    with torch.no_grad():
        loss_test = []
        epoch_pred_test = []
        epoch_true_test = []
        for seq, coords in test_loader:
            pred_coords_test = initmodel(seq)
            epoch_true_test.extend(coords)
            epoch_pred_test.extend(pred_coords_test)
            loss_test.append(criterion(coords, pred_coords_test).item())
        
        #TMTest = score(pd.DataFrame(epoch_true_test), pd.DataFrame(epoch_pred_test))
        loss_test = sum(loss_test)/len(loss_test)
    
    #perf.iloc[epoch, :] = [epoch+1, loss_train, loss_test, TMTrain, TMTest]
    print(f"Epoch {epoch+1}: Loss train {round(loss_train, 2)}, Loss Test {round(loss_test, 2)}")


    

# TO DO: figure out TM Score (expects 2D dataframe of values), add padding masks, refine whole model, and train on full data

Epoch 1: Loss train 285.67, Loss Test 206.67
Epoch 2: Loss train 191.1, Loss Test 147.29
Epoch 3: Loss train 137.03, Loss Test 111.4
Epoch 4: Loss train 111.21, Loss Test 98.1
Epoch 5: Loss train 101.18, Loss Test 91.27
Epoch 6: Loss train 91.03, Loss Test 62.54
Epoch 7: Loss train 71.73, Loss Test 61.36
Epoch 8: Loss train 60.41, Loss Test 55.12
Epoch 9: Loss train 52.11, Loss Test 53.65
Epoch 10: Loss train 51.01, Loss Test 45.95
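
For the TM-score step, the batched tensors first have to become a 2D table. A rough sketch of that flattening, assuming one row per real residue is wanted; the helper `coords_to_frame` and the column names are placeholders, since the exact format that `score` expects has not been checked here:

```python
import pandas as pd
import torch

def coords_to_frame(coord_batches, mask_batches):
    """Flatten batched (batch, L, 3) coordinates into one row per real residue.

    Column names are placeholders; adapt them to whatever `score` expects.
    """
    rows = []
    sample = 0
    for coords, valid in zip(coord_batches, mask_batches):
        for b in range(coords.size(0)):
            for r in range(coords.size(1)):
                if valid[b, r]:  # skip padded positions
                    x, y, z = coords[b, r].tolist()
                    rows.append({"sample": sample, "resid": r + 1,
                                 "x": x, "y": y, "z": z})
            sample += 1
    return pd.DataFrame(rows)

# Toy example: one batch with two sequences of true lengths 3 and 2
coords = [torch.arange(18, dtype=torch.float32).reshape(2, 3, 3)]
masks = [torch.tensor([[True, True, True], [True, True, False]])]
frame = coords_to_frame(coords, masks)
print(frame.shape)  # (5, 5)
```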


In [43]:
pred_coords

tensor([[[-5.1221, -1.5074, -3.5171],
         [-6.6697, -0.9203, -1.8650],
         [ 1.3617, -3.4590, -2.5066],
         [-2.3103, -1.5089, -3.9034],
         [ 4.7777, -4.9347, -0.9478],
         [-6.5557,  0.7278, -1.0197],
         [-6.3039, -0.7329, -3.0912],
         [ 3.4688, -3.2465,  4.8137],
         [ 0.5209,  4.0770, -1.2205],
         [ 6.8408, -0.7116, -0.1893],
         [ 1.5479,  3.5921, -4.1622],
         [ 5.6581,  0.4541,  2.6360],
         [-1.1962,  2.9436,  5.5144],
         [-1.9997,  4.2526,  3.1709],
         [-4.2152,  1.4335,  4.3036],
         [ 6.2778, -1.4252,  4.2280],
         [-3.7223,  4.9912, -2.2801],
         [-2.1177, -5.9105, -0.7925],
         [ 5.0614, -2.5719,  4.3367],
         [-2.9110,  0.8523, -5.8743],
         [ 4.9273, -2.6767,  4.7122],
         [-6.2562,  3.0101,  0.8944],
         [ 0.9202, -2.4327, -4.7375],
         [-6.1439,  2.7268,  2.2095],
         [-6.3950,  2.1013, -2.5844],
         [ 3.2267,  0.4759,  3.1048],
         [-0