## Import Libraries

In [1]:
import os, random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import grelu.sequence.format as seqfmt

In [2]:
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, Dataset
from Seqformer import *
from lion_pytorch import Lion

## Config

In [3]:
# Training settings
minibatch_size = 4096
batch_size = 4096
epochs = 150
lr = 0.0003
criterion = nn.MSELoss()

In [4]:
seed = 42
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

torch.set_default_dtype(torch.float)
torch.set_num_threads(8)
torch.set_float32_matmul_precision('high')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [5]:
# prefer flash attention
torch.backends.cuda.enable_flash_sdp(enabled=True)
torch.backends.cuda.enable_mem_efficient_sdp(enabled=True)
torch.backends.cuda.enable_math_sdp(enabled=True)
torch.backends.cuda.enable_cudnn_sdp(enabled=True)
torch.backends.cudnn.deterministic = True

## Load Data

In [6]:
# Training dataset generated by sim_gene_seqs.ipynb
traindata = pd.read_csv("./data/train.csv", delimiter=",")
trainX = traindata["seq"][0    :20000].apply(str.upper)
trainY = traindata["score"][0    :20000]
## use half for validation
validX = traindata["seq"][20001:24000].apply(str.upper)
validY = traindata["score"][20001:24000]
print(traindata.shape, len(traindata["seq"][0]))

(40000, 2) 132


In [7]:
# feed input with offsets 0-3
trainX = pd.concat([
    trainX.str.slice(0, 128),
    trainX.str.slice(1, 129),
    trainX.str.slice(2, 130),
    trainX.str.slice(3, 131),
], ignore_index=True)
validX = validX.str.slice(0,128)

trainY = pd.concat([ trainY, trainY, trainY, trainY ], ignore_index=True)
print(len(trainX[0]))

128


In [8]:
trainX = seqfmt.strings_to_one_hot(trainX.to_list()).transpose(1, 2)
validX = seqfmt.strings_to_one_hot(validX.to_list()).transpose(1, 2)
trainY = torch.tensor(trainY.to_list()).to(torch.float)
validY = torch.tensor(validY.to_list()).to(torch.float)

In [9]:
train_loader = DataLoader(
    SeqDataset(trainX, trainY),
    batch_size=minibatch_size,
    shuffle=True
)

valid_loader = DataLoader(
    SeqDataset(validX, validY),
    batch_size=minibatch_size,
    shuffle=True
)

## Model

In [10]:
## Transformer model - my_model.py
## Original code adapted from
## https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_1d.py
model = Seqformer(
    dim=64,
    hidden_dim=256,
    head_dim=64,
    heads=4,
    depth=4,
    dropout=0.2,
    emb_dropout=0.2,
    word_len=4,
    num_classes=1,
    channels=4,
    seq_len=len(trainX[0])
)
print(model)
model = torch.compile(model)
model = model.to(device)

Seqformer(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b (n w) c -> b n (w c)', w=4)
    (1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=16, out_features=64, bias=False)
    (3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  )
  (dropout): Dropout(p=0.2, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0-3): 4 x ModuleList(
        (0): Attention(
          (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (to_q): Linear(in_features=64, out_features=256, bias=False)
          (to_kv): Linear(in_features=64, out_features=512, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=256, out_features=64, bias=False)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (1): FeedForward(
          (net): Sequential(
            (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=64, out_features

In [11]:
#optimizer = optim.AdamW(model.parameters(), lr=lr)
optimizer = Lion(model.parameters(), lr=lr)

# normalization factor for gradient accumulation
normf = minibatch_size / batch_size

## Testing

In [12]:
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0
    count = 0

    for data, label in train_loader:
        data = data.to(device)
        label = label.to(device)

        output = model(data).squeeze()
        loss = criterion(output, label) * normf
        loss.backward()

        count += minibatch_size
        if count >= batch_size:
            optimizer.step()
            optimizer.zero_grad()
            count=0

        acc = np.corrcoef(x=output.detach().cpu(), y=label.detach().cpu())[0,1]
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    optimizer.step()
    optimizer.zero_grad()
    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data).squeeze()
            val_loss = criterion(val_output, label)

            acc = np.corrcoef(x=val_output.cpu(), y=label.cpu())[0,1]
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    if epoch_val_accuracy >= 0.93:
        torch.save( model, f"./models/simtest_v0_e{epoch}.pkl" )

    print(
        f"[{epoch+1:02d}] Loss: {epoch_loss:.2f}\tAcc: {epoch_accuracy:.3f}\tVloss: {epoch_val_loss:.2f}\tVacc: {epoch_val_accuracy:.3f}"
    )

[01] Loss: 52.40	Acc: 0.002	Vloss: 47.22	Vacc: 0.053
[02] Loss: 50.10	Acc: 0.069	Vloss: 46.80	Vacc: 0.102
[03] Loss: 49.40	Acc: 0.125	Vloss: 46.49	Vacc: 0.173
[04] Loss: 48.76	Acc: 0.181	Vloss: 45.39	Vacc: 0.187
[05] Loss: 47.45	Acc: 0.216	Vloss: 44.55	Vacc: 0.250
[06] Loss: 46.88	Acc: 0.265	Vloss: 43.64	Vacc: 0.287
[07] Loss: 46.08	Acc: 0.289	Vloss: 43.11	Vacc: 0.304
[08] Loss: 45.71	Acc: 0.310	Vloss: 42.74	Vacc: 0.305
[09] Loss: 44.79	Acc: 0.326	Vloss: 41.97	Vacc: 0.328
[10] Loss: 44.24	Acc: 0.345	Vloss: 41.34	Vacc: 0.349
[11] Loss: 43.83	Acc: 0.361	Vloss: 41.17	Vacc: 0.351
[12] Loss: 43.27	Acc: 0.373	Vloss: 40.53	Vacc: 0.386
[13] Loss: 43.10	Acc: 0.386	Vloss: 40.37	Vacc: 0.384
[14] Loss: 42.05	Acc: 0.403	Vloss: 40.50	Vacc: 0.387
[15] Loss: 41.70	Acc: 0.415	Vloss: 39.85	Vacc: 0.399
[16] Loss: 41.56	Acc: 0.424	Vloss: 39.70	Vacc: 0.405
[17] Loss: 41.23	Acc: 0.432	Vloss: 39.33	Vacc: 0.405
[18] Loss: 40.70	Acc: 0.442	Vloss: 39.39	Vacc: 0.401
[19] Loss: 40.29	Acc: 0.450	Vloss: 38.28	Vacc: