### Imports

In [1]:
import os
import pandas as pd
import numpy as np
import torch
import lightning.pytorch as pl

from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)

# Set up path to import parent modules
from pathlib import Path
import sys  

# Add to sys.path
sys.path.insert(0, str(Path().resolve().parents[1]))

### Training

In [4]:
import models as models
from lit_modules import data_modules, modules

data_path = 'data/TB000208a'
genomic_reference_file = '../data/reference/hg38.fa'
n_classes = 3
seq_length = 46
vocab_size = 4
input_size = seq_length*vocab_size
hidden_size = 512
n_hidden = 2
train_test_split = 0.8

# Build the data module
data_module = data_modules.MulticlassDataModule(
    data_path, 
    threshold=0.01, 
    genomic_reference_file=genomic_reference_file,
    n_classes=n_classes, 
    train_test_split=train_test_split, 
    batch_size=32
)

# Build model
lit_model = modules.Classifier(input_size, hidden_size, n_classes, n_hidden, dropout=0.5)

# train the model
tb_logger = pl.loggers.TensorBoardLogger(save_dir="lightning_logs/")
trainer = pl.Trainer(max_epochs=5, logger=tb_logger, default_root_dir='.')
trainer.fit(lit_model, data_module)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


3
Number of classes 2


  weight = 1. / class_sample_count

  | Name     | Type               | Params
------------------------------------------------
0 | model    | MLPModel           | 358 K 
1 | sigmoid  | Sigmoid            | 0     
2 | loss_fn  | BCEWithLogitsLoss  | 0     
3 | accuracy | MulticlassAccuracy | 0     
------------------------------------------------
358 K     Trainable params
0         Non-trainable params
358 K     Total params
1.434     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/matthewbakalar/anaconda3/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


ValueError: Target size (torch.Size([32, 3])) must be the same as input size (torch.Size([32, 2]))

In [3]:
lit_model

Classifier(
  (model): MLPModel(
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (linear_relu_stack): Sequential(
      (0): Linear(in_features=184, out_features=512, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.5, inplace=False)
      (3): Linear(in_features=512, out_features=512, bias=True)
      (4): ReLU()
      (5): Dropout(p=0.5, inplace=False)
      (6): Linear(in_features=512, out_features=2, bias=True)
    )
  )
  (sigmoid): Sigmoid()
  (loss_fn): BCEWithLogitsLoss()
  (accuracy): MulticlassAccuracy()
)

In [2]:
import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler

# Init the cluster for training
ray.shutdown()
ray.init(num_cpus=8, num_gpus=1)
ray.init()

# Model tuning
def train_func():
    # Build the lightning modules
    data_module = data_modules.MulticlassDataModule(data_path, threshold=0.01, n_classes=n_classes, train_test_split=train_test_split, batch_size=32)
    lit_model = modules.Classifier(input_size=input_size, hidden_size=hidden_size, output_size=n_classes, n_hidden=n_hidden, dropout=0.5)

    trainer = pl.Trainer(
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        callbacks=[RayTrainReportCallback()],
        plugins=[RayLightningEnvironment()],
        enable_progress_bar=False,
    )
    trainer = prepare_trainer(trainer)
    trainer.fit(lit_model, datamodule=data_module)

search_space = {
    "hidden_size": tune.choice([512, 1024, 2048]),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([32, 64]),
}

# The maximum training epochs
num_epochs = 5

# Number of sampls from parameter space
num_samples = 10

scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

from ray.train import RunConfig, ScalingConfig, CheckpointConfig

scaling_config = ScalingConfig(
    num_workers=1, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
)

run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="val_acc",
        checkpoint_score_order="max",
    ),
)

from ray.train.torch import TorchTrainer

# Define a TorchTrainer without hyper-parameters for Tuner
ray_trainer = TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    run_config=run_config,
)

def tune_mnist_asha(num_samples=10):
    scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

    tuner = tune.Tuner(
        ray_trainer,
        param_space={"train_loop_config": search_space},
        tune_config=tune.TuneConfig(
            metric="val_acc",
            mode="max",
            num_samples=num_samples,
            scheduler=scheduler,
        ),
    )
    return tuner.fit()


results = tune_mnist_asha(num_samples=num_samples)


0,1
Current time:,2023-11-23 08:09:41
Running for:,00:00:38.24
Memory:,12.8/16.0 GiB

Trial name,# failures,error file
TorchTrainer_79508_00000,1,"/Users/matthewbakalar/ray_results/TorchTrainer_2023-11-23_08-09-02/TorchTrainer_79508_00000_0_batch_size=32,hidden_size=512,lr=0.0022_2023-11-23_08-09-03/error.txt"
TorchTrainer_79508_00001,1,"/Users/matthewbakalar/ray_results/TorchTrainer_2023-11-23_08-09-02/TorchTrainer_79508_00001_1_batch_size=64,hidden_size=512,lr=0.0843_2023-11-23_08-09-03/error.txt"
TorchTrainer_79508_00002,1,"/Users/matthewbakalar/ray_results/TorchTrainer_2023-11-23_08-09-02/TorchTrainer_79508_00002_2_batch_size=32,hidden_size=512,lr=0.0600_2023-11-23_08-09-03/error.txt"

Trial name,status,loc,train_loop_config/ba tch_size,train_loop_config/hi dden_size,train_loop_config/lr
TorchTrainer_79508_00003,PENDING,,64,2048,0.0025049
TorchTrainer_79508_00004,PENDING,,32,2048,0.0022036
TorchTrainer_79508_00005,PENDING,,64,2048,0.0547193
TorchTrainer_79508_00006,PENDING,,32,512,0.00266909
TorchTrainer_79508_00007,PENDING,,32,512,0.0472886
TorchTrainer_79508_00008,PENDING,,32,1024,0.0389879
TorchTrainer_79508_00009,PENDING,,64,1024,0.000128937
TorchTrainer_79508_00000,ERROR,127.0.0.1:27970,32,512,0.00215776
TorchTrainer_79508_00001,ERROR,127.0.0.1:27977,64,512,0.0842684
TorchTrainer_79508_00002,ERROR,127.0.0.1:27985,32,512,0.0600417


[36m(TorchTrainer pid=27970)[0m Starting distributed worker processes: ['27973 (127.0.0.1)']
[36m(RayTrainWorker pid=27973)[0m Setting up process group for: env:// [rank=0, world_size=1]
2023-11-23 08:09:14,065	ERROR tune_controller.py:1383 -- Trial task failed for trial TorchTrainer_79508_00000
Traceback (most recent call last):
  File "/Users/matthewbakalar/anaconda3/lib/python3.11/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
             ^^^^^^^^^^^^^^^
  File "/Users/matthewbakalar/anaconda3/lib/python3.11/site-packages/ray/_private/auto_init_hook.py", line 24, in auto_init_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/matthewbakalar/anaconda3/lib/python3.11/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/matthewbakalar/anaconda3/lib/python3.11/site-packages/ray/_pri

In [8]:
trainer.test(lit_model, data_module)

/Users/matthewbakalar/anaconda3/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_acc_step         0.7878776788711548
        test_loss            0.526176929473877
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_acc_step': 0.7878776788711548, 'test_loss': 0.526176929473877}]

### Training

## Inference Code

In [11]:
# Fast prediction code. Currently runs on one chromosome only
genomic_reference_file = '../../data/reference/hg38.fa'

pred_data_module = data_modules.GenomeDataModule(genomic_reference_file, batch_size=256)
preds = trainer.predict(lit_model, pred_data_module)
preds = torch.hstack(preds[:-1])

/Users/matthewbakalar/anaconda3/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

/Users/matthewbakalar/anaconda3/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


TypeError: 'NoneType' object is not subscriptable

### Analysis

In [37]:
# Load checkpoint

lit_model.eval()

from Bio import SeqIO  #
import torch.nn.functional as F

genomic_reference_file = '../../data/reference/hg38.fa'


def reverse_complement(dna_sequence):
    complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
    reversed_sequence = dna_sequence[::-1]
    reverse_complement_sequence = ''.join(complement[nucleotide] for nucleotide in reversed_sequence)
    return reverse_complement_sequence

def encode_sequence(seq, seq_length=46, vocab_size=5):
    translation_dict = {'A':0,'T':1,'C':2,'G':3,'N':4}
    encoding = torch.tensor([translation_dict[c] for c in seq])
    x = F.one_hot(encoding, num_classes=vocab_size).to(torch.float32)
    return x

# Adjust the sliding window function to use batches
def sliding_window_inference(genome_sequence, seq_length, batch_size):
    predictions = []
    encoded_seqs_front = []
    encoded_seqs_back = []
    
    for i in range(0, len(genome_sequence) - seq_length + 1):
        if i % 10000 == 0:
            print(i)
        # Check for 'N' early
        full_sequence = genome_sequence[i:i+seq_length]
        if 'N' in full_sequence:
            continue
        
        # Process in batches
        front_half_sequence = full_sequence[:22]
        back_half_sequence = reverse_complement(full_sequence[24:])
        
        encoded_seqs_front.append(encode_sequence(front_half_sequence, seq_length))
        encoded_seqs_back.append(encode_sequence(back_half_sequence, seq_length))
        
        if len(encoded_seqs_front) == batch_size:
            # Make predictions on batch
            batch_preds = predict_on_batch(encoded_seqs_front, encoded_seqs_back)
            predictions.extend(batch_preds)
            
            # Clear lists for next batch
            encoded_seqs_front = []
            encoded_seqs_back = []

    # Process the final batch if there are any sequences left
    if encoded_seqs_front:
        batch_preds = predict_on_batch(encoded_seqs_front, encoded_seqs_back)
        predictions.extend(batch_preds)
    
    return predictions

# Define a function to make predictions on batches
def predict_on_batch(front_seqs, back_seqs):
    front_seqs_tensor = torch.stack(front_seqs)
    back_seqs_tensor = torch.stack(back_seqs)
    
    with torch.no_grad():
        front_preds = lit_model.predict_step((front_seqs_tensor, None), 0)
        back_preds = lit_model.predict_step((back_seqs_tensor, None), 0)
        average_logits = (front_preds + back_preds) / 2
        sigmoid = torch.nn.Sigmoid()
        final_preds = sigmoid(average_logits)
    
    # print(final_preds)
        
    return final_preds.tolist()


# Process each sequence in the FASTA file
seq_length = 46
batch_size = 10000  # or any size that fits in your GPU memory
for record in SeqIO.parse(genomic_reference_file, "fasta"):
    chromosome_sequence = record.seq.upper()
    chromosome_id = record.id
    print(f"Processing {chromosome_id}...")
    
    predictions = sliding_window_inference(str(chromosome_sequence), seq_length, batch_size)
    print(predictions)


Processing chr1...
0
10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000
160000
170000
180000
190000
200000
210000
220000
230000
240000
250000
260000
270000
280000
290000
300000
310000
320000
330000
340000
350000
360000
370000
380000
390000
400000
410000
420000
430000
440000
450000
460000
470000
480000
490000
500000
510000
520000
530000
540000
550000
560000
570000
580000
590000
600000
610000
620000
630000
640000
650000
660000
670000
680000
690000
700000
710000
720000
730000
740000
750000
760000
770000
780000
790000
800000
810000
820000
830000
840000
850000
860000
870000
880000
890000
900000
910000
920000
930000
940000
950000
960000
970000
980000
990000
1000000
1010000
1020000
1030000
1040000
1050000
1060000
1070000
1080000
1090000
1100000
1110000
1120000
1130000
1140000
1150000
1160000
1170000
1180000
1190000
1200000
1210000
1220000
1230000
1240000
1250000
1260000
1270000
1280000
1290000
1300000
1310000
1320000
1330000
1340000
1350000
1360000

KeyboardInterrupt: 

In [13]:
# Fix this to unzip a tuple
data = list(data_module.predict_dataloader())
inputs, labels = map(list, zip(*data))
inputs = torch.vstack(inputs)