In [1]:
import pandas as pd
import torch
from tqdm import tqdm
import numpy as np
import json

  from .autonotebook import tqdm as notebook_tqdm


# Initialize paths and variables

In [2]:
# --- Hardware ---------------------------------------------------------------
CUDA_DEVICE_ID = 0

# --- Data locations ---------------------------------------------------------
TRAIN_DATA_PATH = "data/sample_train.txt"
VALID_DATA_PATH = "data/sample_valid.txt"
TEST_DATA_PATH = "data/sample_test.txt"

# --- Training hyper-parameters ----------------------------------------------
MODEL_LOG_DIR = "model_weights"
TRAIN_BATCH_SIZE = 32
N_PROCS = 4
VALID_BATCH_SIZE = 32
# Learning rate: 0.001 for DREAM-Attn, 0.005 for DREAM-CNN and DREAM-RNN.
# NOTE(review): every model cell below assigns to `model`, so the cell that ran
# last is the one that gets trained -- make sure this value matches that model.
lr = 0.005
# Number of full batches contained in each split (a trailing partial batch is dropped).
BATCH_PER_EPOCH = len(pd.read_csv(TRAIN_DATA_PATH)) // TRAIN_BATCH_SIZE
# Validation batches are counted with the validation batch size, not the training one.
BATCH_PER_VALIDATION = len(pd.read_csv(VALID_DATA_PATH)) // VALID_BATCH_SIZE
SEQ_SIZE = 230
NUM_EPOCHS = 5 #80

# --- Reproducibility and device ---------------------------------------------
generator = torch.Generator()
generator.manual_seed(42)
# Use the requested GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device(f"cuda:{CUDA_DEVICE_ID}" if torch.cuda.is_available() else "cpu")

# Model

### DREAM-RNN

In [11]:
from prixfixe.autosome import AutosomeFinalLayersBlock
from prixfixe.bhi import BHIFirstLayersBlock
from prixfixe.bhi import BHICoreBlock
from prixfixe.prixfixe import PrixFixeNet

# DREAM-RNN: BHI first-layer block -> BHI core block -> Autosome final block.
# The sequence length comes from SEQ_SIZE so the model always agrees with the
# data processor, which is configured from the same constant.
first = BHIFirstLayersBlock(
    in_channels=5,
    out_channels=320,
    seqsize=SEQ_SIZE,
    kernel_sizes=[9, 15],
    pool_size=1,
    dropout=0.2,
)

# The core block takes its input width and length from the first block's outputs.
core = BHICoreBlock(
    in_channels=first.out_channels,
    out_channels=320,
    seqsize=first.infer_outseqsize(),
    lstm_hidden_channels=320,
    kernel_sizes=[9, 15],
    pool_size=1,
    dropout1=0.2,
    dropout2=0.5,
)

final = AutosomeFinalLayersBlock(in_channels=core.out_channels)

model = PrixFixeNet(
    first=first,
    core=core,
    final=final,
    generator=generator,
)

from torchinfo import summary
# Summarise the network for a single input of shape (batch=1, channels=5, SEQ_SIZE).
print(summary(model, (1, 5, SEQ_SIZE)))

Layer (type:depth-idx)                   Output Shape              Param #
PrixFixeNet                              [1, 1]                    --
├─BHIFirstLayersBlock: 1-1               --                        --
│    └─ModuleList: 2-1                   --                        --
│    │    └─ConvBlock: 3-1               [1, 160, 230]             7,360
│    │    └─ConvBlock: 3-2               [1, 160, 230]             12,160
├─BHICoreBlock: 1-2                      --                        --
│    └─LSTM: 2-2                         [1, 230, 640]             1,643,520
│    └─ModuleList: 2-3                   --                        --
│    │    └─ConvBlock: 3-3               [1, 160, 230]             921,760
│    │    └─ConvBlock: 3-4               [1, 160, 230]             1,536,160
│    └─Dropout: 2-4                      [1, 320, 230]             --
├─AutosomeFinalLayersBlock: 1-3          --                        --
│    └─Conv1d: 2-5                       [1, 256, 230]     

### DREAM-CNN

In [4]:
from prixfixe.autosome import (AutosomeCoreBlock,
                      AutosomeFinalLayersBlock)
from prixfixe.bhi import BHIFirstLayersBlock
from prixfixe.prixfixe import PrixFixeNet

# DREAM-CNN: BHI first-layer block -> Autosome core block -> Autosome final block.
# The sequence length comes from SEQ_SIZE so the model always agrees with the
# data processor, which is configured from the same constant.
first = BHIFirstLayersBlock(
    in_channels=5,
    out_channels=320,
    seqsize=SEQ_SIZE,
    kernel_sizes=[9, 15],
    pool_size=1,
    dropout=0.2,
)

# The core block takes its input width and length from the first block's outputs.
core = AutosomeCoreBlock(
    in_channels=first.out_channels,
    out_channels=64,
    seqsize=first.infer_outseqsize(),
)

final = AutosomeFinalLayersBlock(in_channels=core.out_channels)

model = PrixFixeNet(
    first=first,
    core=core,
    final=final,
    generator=generator,
)

from torchinfo import summary
# Summarise the network for a single input of shape (batch=1, channels=5, SEQ_SIZE).
print(summary(model, (1, 5, SEQ_SIZE)))

Layer (type:depth-idx)                        Output Shape              Param #
PrixFixeNet                                   [1, 1]                    --
├─BHIFirstLayersBlock: 1-1                    --                        --
│    └─ModuleList: 2-1                        --                        --
│    │    └─ConvBlock: 3-1                    [1, 160, 230]             7,360
│    │    └─ConvBlock: 3-2                    [1, 160, 230]             12,160
├─AutosomeCoreBlock: 1-2                      --                        --
│    └─ModuleDict: 2-2                        --                        --
│    │    └─Sequential: 3-3                   [1, 320, 230]             420,048
│    │    └─Sequential: 3-4                   [1, 128, 230]             573,696
│    │    └─Sequential: 3-5                   [1, 128, 230]             173,856
│    │    └─Sequential: 3-6                   [1, 128, 230]             229,632
│    │    └─Sequential: 3-7                   [1, 128, 230]         

### DREAM-Attn

In [5]:
from prixfixe.autosome import (
                      AutosomeFirstLayersBlock,
                      AutosomeFinalLayersBlock)
from prixfixe.unlockdna import UnlockDNACoreBlock
from prixfixe.prixfixe import PrixFixeNet

# DREAM-Attn: Autosome first-layer block -> UnlockDNA core block -> Autosome final block.
# The sequence length comes from SEQ_SIZE so the model always agrees with the
# data processor, which is configured from the same constant.
first = AutosomeFirstLayersBlock(
    in_channels=5,
    out_channels=256,
    seqsize=SEQ_SIZE,
)

# The core block keeps the channel width produced by the first block.
core = UnlockDNACoreBlock(
    in_channels=first.out_channels,
    out_channels=first.out_channels,
    seqsize=SEQ_SIZE,
    n_blocks=4,
    kernel_size=15,
    rate=0.1,
    num_heads=8,
)

final = AutosomeFinalLayersBlock(in_channels=core.out_channels)

model = PrixFixeNet(
    first=first,
    core=core,
    final=final,
    generator=generator,
)

from torchinfo import summary
# Summarise the network for a single input of shape (batch=1, channels=5, SEQ_SIZE).
print(summary(model, (1, 5, SEQ_SIZE)))

Layer (type:depth-idx)                        Output Shape              Param #
PrixFixeNet                                   [1, 1]                    --
├─AutosomeFirstLayersBlock: 1-1               --                        --
│    └─Sequential: 2-1                        [1, 256, 230]             --
│    │    └─Conv1d: 3-1                       [1, 256, 230]             8,960
│    │    └─BatchNorm1d: 3-2                  [1, 256, 230]             512
│    │    └─SiLU: 3-3                         [1, 256, 230]             --
├─UnlockDNACoreBlock: 1-2                     --                        --
│    └─Embedding: 2-2                         [1, 230, 256]             58,880
│    └─ModuleList: 2-3                        --                        --
│    │    └─ConformerSASwiGLULayer: 3-4       [1, 256, 230]             1,121,280
│    │    └─ConformerSASwiGLULayer: 3-5       [1, 256, 230]             1,121,280
│    │    └─ConformerSASwiGLULayer: 3-6       [1, 256, 230]             1

# DataProcessor

In [6]:
from prixfixe.autosome import AutosomeDataProcessor

# Build the data processor from the constants defined in the configuration cell.
dataprocessor = AutosomeDataProcessor(
    # settings shared by both splits
    seqsize=SEQ_SIZE,
    generator=generator,
    # training split
    path_to_training_data=TRAIN_DATA_PATH,
    train_batch_size=TRAIN_BATCH_SIZE,
    batch_per_epoch=BATCH_PER_EPOCH,
    train_workers=N_PROCS,
    shuffle_train=True,
    # validation split
    path_to_validation_data=VALID_DATA_PATH,
    valid_batch_size=VALID_BATCH_SIZE,
    valid_workers=N_PROCS,
    shuffle_val=False,
)

In [7]:
# Sanity check: pull a single batch from the training pipeline and display it.
# The object returned by prepare_train_dataloader() is consumed directly with next().
next(dataprocessor.prepare_train_dataloader())

{'x': tensor([[[1., 0., 0.,  ..., 0., 0., 1.],
          [0., 1., 1.,  ..., 0., 1., 0.],
          [0., 0., 0.,  ..., 1., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[1., 0., 0.,  ..., 0., 0., 1.],
          [0., 1., 1.,  ..., 0., 1., 0.],
          [0., 0., 0.,  ..., 1., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[1., 0., 0.,  ..., 0., 0., 1.],
          [0., 1., 1.,  ..., 0., 1., 0.],
          [0., 0., 0.,  ..., 1., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         ...,
 
         [[1., 0., 0.,  ..., 0., 0., 1.],
          [0., 1., 1.,  ..., 0., 1., 0.],
          [0., 0., 0.,  ..., 1., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[1., 0., 0.,  ..., 0., 0., 1.],
          [0., 1., 1.,  ..., 0., 1., 0.],
          [0., 0., 0.,  ..., 1., 0., 0.],
 

# Trainer

In [8]:
from prixfixe.autosome import AutosomeTrainer

# Train the current `model` (the one built by the most recently executed model
# cell). The `device` object from the configuration cell is reused so training
# and the later prediction loop are guaranteed to target the same device.
trainer = AutosomeTrainer(
    model,
    device=device,
    model_dir=MODEL_LOG_DIR,
    dataprocessor=dataprocessor,
    num_epochs=NUM_EPOCHS,
    lr=lr,
)

trainer.fit()

  0%|                                                                                                                                                                                                                                                       | 0/5 [00:00<?, ?it/s]
Train epoch:   0%|                                                                                                                                                                                                                                         | 0/31 [00:00<?, ?it/s][A
Train epoch:   3%|███████▎                                                                                                                                                                                                                         | 1/31 [00:00<00:04,  6.05it/s][A
Train epoch:  13%|█████████████████████████████                                                                                                                          

{'MSE': 1.4224580526351929, 'pearsonr': 0.13069323380273878, 'spearmanr': 0.08399409829541563}



Train epoch:   0%|                                                                                                                                                                                                                                         | 0/31 [00:00<?, ?it/s][A
Train epoch:   6%|██████████████▌                                                                                                                                                                                                                  | 2/31 [00:00<00:03,  7.58it/s][A
Train epoch:  16%|████████████████████████████████████▎                                                                                                                                                                                            | 5/31 [00:00<00:01, 15.12it/s][A
Train epoch:  29%|█████████████████████████████████████████████████████████████████▎                                                                                 

{'MSE': 10.506120681762695, 'pearsonr': 0.033149745138612116, 'spearmanr': 0.0692970644590461}



Train epoch:   0%|                                                                                                                                                                                                                                         | 0/31 [00:00<?, ?it/s][A
Train epoch:  13%|█████████████████████████████                                                                                                                                                                                                    | 4/31 [00:00<00:02, 12.00it/s][A
Train epoch:  23%|██████████████████████████████████████████████████▊                                                                                                                                                                              | 7/31 [00:00<00:01, 17.22it/s][A
Train epoch:  35%|███████████████████████████████████████████████████████████████████████████████▍                                                                   

{'MSE': 59.10154724121094, 'pearsonr': 0.027876909221375257, 'spearmanr': -0.029677688582079646}



Train epoch:   0%|                                                                                                                                                                                                                                         | 0/31 [00:00<?, ?it/s][A
Train epoch:  13%|█████████████████████████████                                                                                                                                                                                                    | 4/31 [00:00<00:00, 30.50it/s][A
Train epoch:  26%|██████████████████████████████████████████████████████████                                                                                                                                                                       | 8/31 [00:00<00:01, 16.16it/s][A
Train epoch:  39%|██████████████████████████████████████████████████████████████████████████████████████▋                                                            

{'MSE': 2.3883261680603027, 'pearsonr': 0.02507159148452033, 'spearmanr': -0.028842758862356783}



Train epoch:   0%|                                                                                                                                                                                                                                         | 0/31 [00:00<?, ?it/s][A
Train epoch:  13%|█████████████████████████████                                                                                                                                                                                                    | 4/31 [00:00<00:00, 30.67it/s][A
Train epoch:  26%|██████████████████████████████████████████████████████████                                                                                                                                                                       | 8/31 [00:00<00:01, 16.07it/s][A
Train epoch:  35%|███████████████████████████████████████████████████████████████████████████████▍                                                                   

{'MSE': 2.171882390975952, 'pearsonr': -0.02672071844840184, 'spearmanr': -0.040107870384680985}





# Predict

In [9]:
# Restore the weights saved as model_best.pth. map_location remaps the stored
# tensors onto `device`, so loading does not depend on the device they were saved from.
state_dict = torch.load(f"{MODEL_LOG_DIR}/model_best.pth", map_location=device)
model.load_state_dict(state_dict)
# The prediction loop creates its inputs on `device`; make sure the model lives there too.
model.to(device)
model.eval()

# Load the tab-separated test set and flag sequences whose id marks them as reversed.
test_df = pd.read_csv(TEST_DATA_PATH, sep='\t')
test_df['rev'] = test_df['seq_id'].str.contains('_Reversed:').astype(int)

def one_hot_encode(seq):
    """One-hot encode a nucleotide sequence.

    Channel order is A, G, C, T; 'N' is encoded as an all-zero vector.
    Lower-case bases are treated like their upper-case counterparts.

    Args:
        seq: String of bases drawn from A, G, C, T, N (any case).

    Returns:
        A list with one freshly allocated 4-element list per base, so callers
        may modify a row without affecting any other row.

    Raises:
        KeyError: If `seq` contains a character outside A, G, C, T, N.
    """
    mapping = {'A': [1, 0, 0, 0],
               'G': [0, 1, 0, 0],
               'C': [0, 0, 1, 0],
               'T': [0, 0, 0, 1],
               'N': [0, 0, 0, 0]}
    # Copy each row: repeated bases would otherwise all alias the same list object.
    return [list(mapping[base]) for base in seq.upper()]

In [10]:
# One-hot encode every test sequence and append its 'rev' flag to each base
# vector, producing one [A, G, C, T, rev] row per position.
encoded_seqs = []
# Iterate the two needed columns directly (cheaper than iterrows) and pass the
# row count to tqdm so the progress bar can show a percentage.
for seq, rev in tqdm(zip(test_df['seq'], test_df['rev']), total=len(test_df)):
    encoded_seqs.append([list(base_vec) + [rev] for base_vec in one_hot_encode(seq)])

from tqdm import tqdm

pred_expr = []
# Inference only: disabling autograd avoids recording a graph for every forward pass.
with torch.no_grad():
    for seq in tqdm(encoded_seqs):
        # (SEQ_SIZE, 5) rows -> (1, 5, SEQ_SIZE): a batch of one with channels first,
        # matching the input shape used for the model summaries.
        batch = np.array(seq).reshape(1, SEQ_SIZE, 5).transpose(0, 2, 1)
        pred = model(torch.tensor(batch, device=device, dtype=torch.float32)) #can also predict on batches to speed up prediction
        pred_expr.append(pred.detach().cpu().flatten().tolist())

1000it [00:00, 4944.59it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 323.45it/s]
