# Predicting distances

The model 1 learnt to predict the actual 3D coordinates of each nucleotide. However, this means that it was also learning the rotation in space, which for the purpose of this task is irrelevant - we want rotation & translation invariance. In this next iteration, the model will predict distances, which will be deterministically mapped back to coordinates. 

### Prepare data (X & y)
For now, these are prepared as tensors of one-hot-encoded sequence (padded to make sure they are of same length), and tensors of coordinates. MSA are not yet considered.
Update: since embedding is used, the sequences are instead converted to tensors. The one-hot-encode code is kept below for now.

In [3]:
import pandas as pd
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torch.nn import Module, MSELoss
from torch.optim import Adam
from torch.utils.data import random_split

train_seq = pd.read_csv("../toy_data/train_sequences.csv")
train_lbl = pd.read_csv("../toy_data/train_labels.csv")
train_lbl = train_lbl.interpolate() # For now, interpolate - are there better imputation techniques?

train_lbl["ID_num"] = [n+1 for n in range(len(train_lbl))] # map ID to numeric ID to store in tensor
id_mapping = {idx+1: og_id for idx, og_id in enumerate(train_lbl['ID'])} # create mapping to re-map back to original ID
id_mapping[0] = "padded_row"

#train_lbl[train_lbl.iloc[:,3:6].isna().any(axis=1)] # Check which rows  rows have NaN


  train_lbl = train_lbl.interpolate() # For now, interpolate - are there better imputation techniques?


In [4]:
# Create Dataset & Dataloader

nts = ['G', 'U', 'C', 'A', 'X', '-']
mapping = {nt: idx+1 for idx, nt in enumerate(nts)}
reverse_mapping = {v: k for k, v in mapping.items()}


def tokenise_seq(seq, mapping=mapping):
    seq_idx = [mapping[nt] for nt in seq]
    seq_idx = torch.tensor(seq_idx)
    return seq_idx

def make_coord_tensor(train_lbl):
    train_lbl['base_ID'] = train_lbl['ID'].str.rsplit('_', n=1).str[0] # sequence ID for each nt
    main_id_list = train_lbl['ID']
    y_list = []
    og_id_list_temp = [] # not extended list
    for idx in list(train_lbl['base_ID'].unique()):
        subset = train_lbl[train_lbl['base_ID'] == idx]
        coords = []
        for res in range(len(subset['ID'])):
            coord = list(subset.iloc[res, 3:6])
            coords.append(coord)
        
        og_id_list_temp.append(torch.tensor(list(subset['ID_num'])))
        
        y_list.append(torch.tensor(coords, dtype=torch.float32))
        
    y_tensor = pad_sequence(y_list, batch_first=True)
    og_id_list = pad_sequence(og_id_list_temp, batch_first=True)

    return y_list, y_tensor, og_id_list

# Create Dataset & Dataloader

def collate(batch):
    xs, ys, ids = zip(*batch)
    len_x = [x.size(0) for x in xs]

    x_padded = pad_sequence(xs, batch_first=True)
    y_padded = pad_sequence(ys, batch_first=True)
    id_padded = pad_sequence(ids, batch_first=True)

    return x_padded, y_padded, id_padded, torch.tensor(len_x)


nts = ['G', 'U', 'C', 'A', 'X', '-']
mapping = {nt: idx+1 for idx, nt in enumerate(nts)}
reverse_mapping = {v: k for k, v in mapping.items()}


def tokenise_seq(seq, mapping=mapping):
    seq_idx = [mapping[nt] for nt in seq]
    seq_idx = torch.tensor(seq_idx)
    return seq_idx

def make_coord_tensor(train_lbl):
    train_lbl['base_ID'] = train_lbl['ID'].str.rsplit('_', n=1).str[0] # sequence ID for each nt
    main_id_list = train_lbl['ID']
    y_list = []
    og_id_list_temp = [] # not extended list
    for idx in list(train_lbl['base_ID'].unique()):
        subset = train_lbl[train_lbl['base_ID'] == idx]
        coords = []
        for res in range(len(subset['ID'])):
            coord = list(subset.iloc[res, 3:6])
            coords.append(coord)
        
        og_id_list_temp.append(torch.tensor(list(subset['ID_num'])))
        
        y_list.append(torch.tensor(coords, dtype=torch.float32))
        
    #y_tensor = pad_sequence(y_list, batch_first=True)
    og_id_list = pad_sequence(og_id_list_temp, batch_first=True)

    return y_list, og_id_list

class Rnadataset(Dataset):
    def __init__(self, train_seq, train_lbl):
        super().__init__()
        self.X_list = [tokenise_seq(seq) for seq in train_seq['sequence']]
        #self.X_tensor = pad_sequence(self.X_list, batch_first=True)
        
        self.y_list, self.ids = make_coord_tensor(train_lbl)
        if all(train_lbl["base_ID"].unique() == train_seq['target_id']): # Always good to check
            print("Order corresponds between sequences and coordinates")
        else:
            raise ValueError("Mismatch between base_IDs in train_lbl and target_ids in train_seq.")
            
        #self.ids = train_seq['target_id']

    def __len__(self):
        return len(self.X_list)
    
    def __getitem__(self, index) :
        return self.X_list[index], self.y_list[index], self.ids[index]
    
dataset = Rnadataset(train_seq, train_lbl)

train_size = int(len(dataset)*0.8)
test_size = int(len(dataset)-train_size)

train_data, test_data = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_data, batch_size=32, shuffle=False, collate_fn=collate, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False, collate_fn=collate, num_workers=8, pin_memory=True)


Order corresponds between sequences and coordinates


### Functions to map distances to coords & vica versa

Using classical multidimensional scaling cMDS

In [5]:
def distances_to_coords(D):
    L = D.shape[0] # D: (L, L) symmetric, zero diagonal
    I = torch.eye(L) # Centering matrix - LxL identity
    ones = torch.ones((L, L)) / L
    H = I - ones

    D2 = D**2 # Squared distances
    B = -0.5 * H @ D2 @ H # Double‐centered Gram matrix

    # Eigen‐decomposition
    eigvals, eigvecs = torch.linalg.eigh(B)
    # Sort descending
    idx = torch.argsort(eigvals, descending=True)
    vals = eigvals[idx][:3]
    vecs = eigvecs[:, idx][:, :3]

    # Coordinates = V * sqrt(Λ)
    return vecs * torch.sqrt(vals).unsqueeze(0)

# Define custom loss function on distance matrices rather than coords

def pairwise_distance_matrix(X):
    diff = X.unsqueeze(2) - X.unsqueeze(1)  # shape: (batch, 35, 35, 5)
    return torch.norm(diff, dim=-1)

### Build initial simple model
The main architecture will be the same as in the initial model.

In [6]:
# Define blocks of the model

class SeqEncoder(nn.Module): # Define single encoder block
    def __init__(self, hidden_size=256, kernel_size=3):
        super().__init__()
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(hidden_size, hidden_size, kernel_size=kernel_size, padding = kernel_size // 2)
        self.attn = nn.MultiheadAttention(hidden_size, 8)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.norm3 = nn.LayerNorm(hidden_size)
        self.ff = nn.Sequential(
            nn.Linear(hidden_size, 4*hidden_size),
            nn.GELU(),
            nn.Linear(4*hidden_size, hidden_size)
        )

    def forward(self, X, padding_mask=None):
        X = X + self.conv(X.transpose(1,2)).transpose(1,2) # 1D conv with residual connection + Layer Norm; transpose to expected input
        X = self.norm1(X)
        res = X
        attn_out, _ = self.attn(X.transpose(0,1), X.transpose(0,1), X.transpose(0,1), key_padding_mask=padding_mask)
        attn_out = attn_out.transpose(0,1) + res
        X = self.norm2(attn_out)
        res = X
        X = self.norm3(res + self.ff(X))
        return X
        
class ConvEncoder(nn.Module): # define a whole transformer pipeline
    def __init__(self, n_blocks = 9, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList([SeqEncoder(**kwargs) for _ in range(n_blocks)])
    
    def forward(self, X, padding_mask=None):
        for layer in self.layers:
            X = layer(X, padding_mask=padding_mask)
        return X
    
class DistancePredictor(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        f = 4 * hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(f, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, 1))
        
    def forward(self, X):
        b, l, d = X.size()
        Xi = X.unsqueeze(2).expand(-1, -1, l, -1) # position i
        Xj = X.unsqueeze(1).expand(-1, l, -1, -1) # position j
        f = torch.cat([Xi, Xj, Xi-Xj, Xi*Xj], dim = -1) # stack i & j repr, their distance (-), and similarity (*)
        d = self.mlp(f).squeeze(-1)
        d = torch.relu(d)
        d = (d+d.transpose(1,2))*0.5 # symmetric
        d = d.masked_fill(torch.eye(l).bool(), 0.) # 0 across diagonal
        return d


In [7]:
# Define model 

class InitModel(Module): # define rest of model
    def __init__(self, vocab=6, max_len = 1024, n_blocks=9, hidden_size=256):
        super().__init__()
        self.b = vocab
        self.embedding = nn.Embedding(self.b, hidden_size, padding_idx=0) # map each base to a vector representation of size 256
        self.pos_embedding = nn.Embedding( max_len, hidden_size)
        self.convencoder = ConvEncoder(n_blocks=n_blocks, hidden_size=hidden_size)
        self.output=DistancePredictor(hidden_size=hidden_size)
        #self.output = nn.Linear(hidden_size, 3)

    def forward(self, X):

        # Make embeddings (+ positional embeddings)

        pad_mask = (X == 0)
        seq_length = X.size()[1]

        X = self.embedding(X)
        positions = torch.arange(seq_length).unsqueeze(0).expand(X.size(0), seq_length)
        pos_embd = self.pos_embedding(positions)
        X = X + pos_embd

        # Pass through convolutional transformer

        X = self.convencoder(X, padding_mask=pad_mask)

        out = self.output(X)
        return(out)

        ## TO DO: add padding masks, add layers which map the encoded representations to coords, add distance calculation, minimise loss btwn og dist & encoded dist 

### A brief interlude to test compilers for training speedup

In [23]:
# No compiler

initmodel = InitModel()
#initmodel = torch.compile(initmodel, backend="aot_eager")
print("begin training")
with torch.autograd.profiler.profile(record_shapes=True) as prof:
    initmodel.train()
    loss_train = []
    for i, (seq, coords, ids) in enumerate(train_loader):
        if i >= 3:
            break
        pad_mask = (seq == 0)
        optimiser.zero_grad()
        true = pairwise_distance_matrix(coords)
        pred_i = initmodel(seq)
        mask = (seq!=0).unsqueeze(1).expand_as(pred_i)
        pred = pred_i[mask]
        true = true[mask]
        loss = criterion(pred,true)
        loss_train.append(loss.item())
        num_ids.extend(ids.flatten(0,1).tolist())
        loss.backward()
        optimiser.step()

# Dump the profiling table
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

begin training
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               aten::mm        44.53%     189.373ms        44.54%     189.444ms     681.454us           278  
    autograd::engine::evaluate_function: AddmmBackward0         0.32%       1.355ms        37.28%     158.548ms       2.734ms            58  
                                         AddmmBackward0         0.17%     715.001us        36.50%     155.216ms       2.676ms            58  
                                           aten::linear         0.15%     654.011us        22.56%      95.960ms     856.785us        

In [25]:
# Compiler

initmodel = InitModel()
initmodel = torch.compile(initmodel, backend="aot_eager")
print("begin training")
with torch.autograd.profiler.profile(record_shapes=True) as prof:
    initmodel.train()
    loss_train = []
    for i, (seq, coords, ids) in enumerate(train_loader):
        if i >= 3:
            break
        pad_mask = (seq == 0)
        optimiser.zero_grad()
        true = pairwise_distance_matrix(coords)
        pred_i = initmodel(seq)
        mask = (seq!=0).unsqueeze(1).expand_as(pred_i)
        pred = pred_i[mask]
        true = true[mask]
        loss = criterion(pred,true)
        loss_train.append(loss.item())
        num_ids.extend(ids.flatten(0,1).tolist())
        loss.backward()
        optimiser.step()

# Dump the profiling table
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

begin training
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
autograd::engine::evaluate_function: CompiledFunctio...         0.06%     268.758us        63.29%     304.842ms     152.421ms             2  
                               CompiledFunctionBackward         9.87%      47.532ms        63.24%     304.573ms     152.286ms             2  
                                               aten::mm        37.84%     182.258ms        37.86%     182.340ms     655.898us           278  
                             Torch-Compiled Region: 0/3         0.17%     795.004us        34.81%     167.649ms      83.824ms        

In [30]:
# Trace

initmodel = InitModel()
example_seq, example_coords, example_ids = next(iter(train_loader))

initmodel.eval()
initmodel = torch.jit.trace(initmodel, example_seq)
initmodel.train()
print("begin training")
with torch.autograd.profiler.profile(record_shapes=True) as prof:
    initmodel.train()
    loss_train = []
    for i, (seq, coords, ids) in enumerate(train_loader):
        if i >= 3:
            break
        pad_mask = (seq == 0)
        optimiser.zero_grad()
        true = pairwise_distance_matrix(coords)
        pred_i = initmodel(seq)
        mask = (seq!=0).unsqueeze(1).expand_as(pred_i)
        pred = pred_i[mask]
        true = true[mask]
        loss = criterion(pred,true)
        loss_train.append(loss.item())
        num_ids.extend(ids.flatten(0,1).tolist())
        loss.backward()
        optimiser.step()

# Dump the profiling table
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

begin training
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               aten::mm        44.79%     198.175ms        44.81%     198.251ms     713.134us           278  
    autograd::engine::evaluate_function: AddmmBackward0         0.34%       1.515ms        37.48%     165.807ms       2.859ms            58  
                                         AddmmBackward0         0.14%     634.489us        36.70%     162.370ms       2.799ms            58  
                                                forward         0.39%       1.723ms        33.76%     149.362ms      74.681ms        

In [14]:
print(torch.__version__)

2.6.0+cu124


In [8]:
n_epochs = 10

import time

start = time.time()


initmodel = InitModel()
criterion = MSELoss()
optimiser = Adam(initmodel.parameters())

cols = ["Epoch", "Train_Loss", "Test_Loss"]
perf = pd.DataFrame(index=range(n_epochs), columns=cols)

for epoch in range(n_epochs):
    loss_train = []
    epoch_pred_train = []
    epoch_true_train = []
    num_ids = []
    seq_idx = []
    initmodel.train()
    for seq, coords, ids, _ in train_loader:
        pad_mask = (seq == 0)
        optimiser.zero_grad()
        true = pairwise_distance_matrix(coords)
        pred_i = initmodel(seq)
        mask = (seq!=0).unsqueeze(1).expand_as(pred_i)
        pred = pred_i[mask]
        true = true[mask]
        loss = criterion(pred,true)
        loss_train.append(loss.item())
        num_ids.extend(ids.flatten(0,1).tolist())
        loss.backward()
        optimiser.step()

    loss_train = sum(loss_train)/len(loss_train)

    initmodel.eval()
    with torch.no_grad():
        loss_test = []
        for seq, coords, ids, _ in test_loader:
            pred_test = initmodel(seq)
            true_test = pairwise_distance_matrix(coords)
            mask = (seq!=0).unsqueeze(1).expand_as(pred_test)
            pred_test = pred_test[mask]
            true_test = true_test[mask]
            loss = criterion(pred, true)
            loss_test.append(loss.item())
        
        loss_test_val = sum(loss_test)/len(loss_test)
    
    perf.iloc[epoch, :] = [epoch+1, loss_train, loss_test]
    print(f"Epoch {epoch+1}: Loss train {round(loss_train, 2)}, Loss Test {round(loss_test_val, 2)}")

end = time.time()

print(f"Time: {end-start}s")
# TO DO: figure out TM Score (expects 2D dataframe of values), add padding masks, refine whole model, and train on full data, figure out how to import from src/func

Epoch 1: Loss train 547.58, Loss Test 547.58
Epoch 2: Loss train 447.32, Loss Test 447.32
Epoch 3: Loss train 338.92, Loss Test 338.92
Epoch 4: Loss train 270.27, Loss Test 270.27
Epoch 5: Loss train 223.45, Loss Test 223.45
Epoch 6: Loss train 192.55, Loss Test 192.55
Epoch 7: Loss train 178.04, Loss Test 178.04
Epoch 8: Loss train 178.37, Loss Test 178.37
Epoch 9: Loss train 187.18, Loss Test 187.18
Epoch 10: Loss train 192.23, Loss Test 192.23
Time: 16.685160160064697s


In [513]:
perf.to_csv("../outputs/DistPred/distpred_perf.csv")

In [11]:
pred_i.size()


torch.Size([24, 35, 35])

In [12]:
# Get validation set

validation_seq = pd.read_csv("../data/validation_sequences.csv")
validation_lbl = pd.read_csv("../data/validation_labels.csv")
validation_lbl["ID_num"] = [n+1 for n in range(len(validation_lbl))] # map ID to numeric ID to store in tensor
id_mapping_val = {idx+1: og_id for idx, og_id in enumerate(validation_lbl['ID'])} # create mapping to re-map back to original ID
id_mapping_val[0] = "padded_row"
val_set = Rnadataset(validation_seq, validation_lbl)
val_loader = DataLoader(val_set,batch_size=32,shuffle=False,num_workers=4,pin_memory=False,collate_fn=collate)


# Make predictions on validation set
initmodel.eval()
all_preds = []
with torch.no_grad():
    for seq, coords, ids, lengths in val_loader:
        pred = initmodel(seq)         
        for b in range(pred.size(0)):
            single = pred[b]          
            mask   = seq[b] != 0
            coords = distances_to_coords(single)
            coords = coords[mask]
            all_preds.append(coords)

        

stacked = torch.cat(all_preds, dim=0)



Order corresponds between sequences and coordinates


torch.Size([2515, 3])

In [15]:


submission_cols = ['ID', 'resname', 'resid', 'x_1', 'y_1', 'z_1', 'x_2', 'y_2', 'z_2',
       'x_3', 'y_3', 'z_3', 'x_4', 'y_4', 'z_4', 'x_5', 'y_5', 'z_5']

submission_df = pd.DataFrame(0.0, index = range(stacked.size()[0]), columns = submission_cols)
submission_df[['ID', 'resname', 'resid']] = validation_lbl[['ID', 'resname', 'resid']]

submission_df[['x_1', 'y_1', 'z_1']] = stacked.detach().numpy()
submission_df.head()

#submission_df.to_csv('../outputs/DistPred/submission2.csv')

Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,x_2,y_2,z_2,x_3,y_3,z_3,x_4,y_4,z_4,x_5,y_5,z_5
0,R1107_1,G,1,-12.300216,-10.44787,-0.029812,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,R1107_2,G,2,-0.03209,0.02042,0.2044,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,R1107_3,G,3,0.029442,-0.015258,-0.03518,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,R1107_4,G,4,0.018721,-0.01678,-0.196487,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,R1107_5,G,5,0.017855,-0.018463,-0.12334,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [15]:
submission_df.head()

Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,x_2,y_2,z_2,x_3,y_3,z_3,x_4,y_4,z_4,x_5,y_5,z_5
0,R1107_1,G,1,-13.646433,0.406043,-0.092685,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,R1107_2,G,2,0.013011,0.012916,12.02226,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,R1107_3,G,3,0.029339,0.017286,-0.329253,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,R1107_4,G,4,0.032582,0.01553,-0.470709,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,R1107_5,G,5,0.032398,0.018278,-0.399121,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
