### Imports

In [1]:
import os
import pandas as pd
import numpy as np
import torch
import genomepy
import lightning as L
import lightning.pytorch as pl
import torch.nn.functional as F

# Set up path to import parent modules
from pathlib import Path
import sys  

# Add to sys.path
sys.path.insert(0, str(Path().resolve().parents[1]))

### Load sequence data from cryptic seq experiment

In [2]:
# Preprosses cryptic seq data to normalize by on target
from cryptic import utils

# Flag to trigger data export from excel files
export_excel_data = False

cs_data_file = '../data/TB000208a.outputs.xlsx'
data_path = '../data/TB000208a'
train_sheets = ['GT-Rep1-N7_S1','GT-Rep2-N7_S2','GT-Rep3-N7_S3']
test_sheets = ['Pool-Rep1-N7_S4','Pool-Rep2-N7_S5','Pool-Rep3-N7_S6']

if export_excel_data:
    utils.cs_excel_data.extract_excel_cs_data(cs_data_file, train_sheets, data_path, 'fit.csv')
    sites = utils.cs_excel_data.extract_excel_cs_data(cs_data_file, test_sheets, data_path, 'test.csv', dn_exclusion=['GT','AC'])

In [3]:
from lightning.pytorch.utilities.types import EVAL_DATALOADERS
from cryptic.models import mlp
from cryptic.datasets import one_hot, data

genomic_reference_file = '../data/references/hg38.fa'
n_classes = 2
seq_length = 22
vocab_size = 4
input_size = seq_length*vocab_size
hidden_size = 1024
n_hidden = 2
train_test_split = 0.8

class CSDataModule(L.LightningDataModule):
    def __init__(self, data_path, add_decoys=False, batch_size=8):
        super().__init__()
        self.data_path = data_path
        self.add_decoys = add_decoys
        self.batch_size = batch_size

    def setup(self, stage: str):
        # Select test/train dataset
        fname = stage + '.csv'

        # Load the cryptic seq data
        sites = pd.read_csv(os.path.join(self.data_path, fname))

        # Threshold the data to assign a label. This code should live somewhere else...
        sites['label'] = (sites['norm_count'] > 1e-2).astype(int)

        # Compute class frequencies for weighting
        class_sample_count = np.array([len(np.where(sites['label'] == c)[0]) for c in np.unique(sites['label'])])

        # Cryptic sites data for training
        sequences = sites['seq'].values
        labels = sites['label'].values

        # Sample weights based on label and class frequency
        weight = 1. / class_sample_count
        samples_weight = np.array([weight[t] for t in labels])
        self.samples_weight = torch.from_numpy(samples_weight)

        # Convert labels to one hot and build dataset
        one_hot_labels = F.one_hot(torch.tensor(labels), num_classes=n_classes)
        self.seq_length = len(sequences[0])
        #self.dataset = one_hot.Dataset(sequences, one_hot_labels, vocab_size=vocab_size, output_size=n_classes)
        self.dataset = data.SequenceDataset(sequences, one_hot_labels)

        if self.add_decoys:
            # Generate random decoy sequences
            decoy_count = len(sites)
            genome = genomepy.genome.Genome(genomic_reference_file)
            samples = genome.get_random_sequences(n=decoy_count, length=seq_length-1, max_n=0)
            decoys = pd.Series(list(map(lambda row: genome.get_seq(*row).seq.upper(), samples)))
            decoys_labels = np.zeros(len(decoys))

        if stage == 'fit':
            # Test and train data split
            train_size = int(train_test_split*len(self.dataset))
            test_size = len(self.dataset) - train_size
            self.train_dataset, self.val_dataset = torch.utils.data.random_split(self.dataset, [train_size, test_size])

            # Weighted random sampler for upsampling minority class for training
            train_sample_weights = samples_weight[self.train_dataset.indices]
            self.train_sampler = torch.utils.data.WeightedRandomSampler(train_sample_weights, len(train_sample_weights), replacement=True)

        elif stage == 'test':
            self.test_dataset = self.dataset

        elif stage == 'predict':
            self.pred_dataset = self.dataset

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, sampler=self.train_sampler)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size)
    
    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size)
    
    def predict_dataloader(self):
        return torch.utils.data.DataLoader(self.pred_dataset, batch_size=self.batch_size)

data_module = CSDataModule(data_path, batch_size=32)

# Build model
model = mlp.Model(input_size=input_size, hidden_size=hidden_size, output_size=n_classes, n_hidden=n_hidden, dropout=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

### Training

In [4]:
from typing import Any
import torchmetrics

# define the LightningModule
class LitClassifier(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.sigmoid = torch.nn.Sigmoid()
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=n_classes)

    def forward(self, x):
        return self.model(x)
    
    def logging(self, logits, target, loss, stage):
        # Logging to TensorBoard (if installed) by default
        pred = self.sigmoid(logits)
        self.accuracy(torch.argmax(pred,1), torch.argmax(target,1))
        self.log(f'{stage}_acc_step', self.accuracy)
        self.log(f'{stage}_loss', loss)

    def training_step(self, batch, batch_idx):
        # Model pass
        data, target = batch
        logits = self(data)
        loss = self.loss_fn(logits, target.float())
        self.logging(logits, target, loss, 'train')
        return loss
    
    def validation_step(self, batch, batch_idx):
        # Model pass
        data, target = batch
        logits = self(data)
        loss = self.loss_fn(logits, target.float())
        self.logging(logits, target, loss, 'val')
        return loss
    
    def test_step(self, batch, batch_idx):
        # Model pass
        data, target = batch
        logits = self(data)
        loss = self.loss_fn(logits, target.float())
        self.logging(logits, target, loss, 'test')
        return loss
    
    def predict_step(self, batch, batch_idx):
        # Model pass
        data, target = batch
        logits = self(data)
        # output = self.sigmoid(self(data))
        # preds = torch.argmax(output, 1)
        return logits

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

  
# init the autoencoder
lit_model = LitClassifier(model)

# train the model
tb_logger = pl.loggers.TensorBoardLogger(save_dir="lightning_logs/")
trainer = pl.Trainer(max_epochs=5, logger=tb_logger, default_root_dir='.')
trainer.fit(lit_model, data_module)

INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /var/folders/x8/067mrh2x5_s1x14x3kq4qkd80000gr/T/tmpxspmnlmk
INFO:torch.distributed.nn.jit.instantiator:Writing /var/folders/x8/067mrh2x5_s1x14x3kq4qkd80000gr/T/tmpxspmnlmk/_remote_module_non_scriptable.py
INFO: GPU available: True (mps), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (mps), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: 
  | Name     | Type               | Params
------------------------------------------------
0 | model    | Model              | 1.1 M 
1 | sigmoid  | Sigmoid            | 0     
2 | loss_fn  | BCEWithLogitsLoss  |

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/matthewbakalar/anaconda3/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
  tp = tp.sum(dim=0 if multidim_average == "global" else 1)
/Users/matthewbakalar/anaconda3/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=5` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


## Inference Code

In [6]:
# Blazing fast prediction code. Currently runs on one chromosome only

from cryptic.datasets.data import GenomeBoxcarDataset

genomic_reference_file = '../../data/reference/hg38.fa'

class GenomeDataModule(L.LightningDataModule):
    def __init__(self, data_file, batch_size=8):
        super().__init__()
        self.data_file = data_file
        self.batch_size = batch_size

    def setup(self, stage: str):
        if stage == 'predict':
            self.pred_dataset = GenomeBoxcarDataset(fasta_file=self.data_file)

    def predict_dataloader(self):
        return torch.utils.data.DataLoader(self.pred_dataset, batch_size=self.batch_size)

pred_data_module = GenomeDataModule(genomic_reference_file)
preds = trainer.predict(lit_model, data_module)
preds = torch.hstack(preds[:-1])
    

/Users/matthewbakalar/anaconda3/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

### Analysis

In [37]:
# Load checkpoint

lit_model.eval()

from Bio import SeqIO  #
import torch.nn.functional as F

genomic_reference_file = '../../data/reference/hg38.fa'


def reverse_complement(dna_sequence):
    complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
    reversed_sequence = dna_sequence[::-1]
    reverse_complement_sequence = ''.join(complement[nucleotide] for nucleotide in reversed_sequence)
    return reverse_complement_sequence

def encode_sequence(seq, seq_length=46, vocab_size=5):
    translation_dict = {'A':0,'T':1,'C':2,'G':3,'N':4}
    encoding = torch.tensor([translation_dict[c] for c in seq])
    x = F.one_hot(encoding, num_classes=vocab_size).to(torch.float32)
    return x

# Adjust the sliding window function to use batches
def sliding_window_inference(genome_sequence, seq_length, batch_size):
    predictions = []
    encoded_seqs_front = []
    encoded_seqs_back = []
    
    for i in range(0, len(genome_sequence) - seq_length + 1):
        if i % 10000 == 0:
            print(i)
        # Check for 'N' early
        full_sequence = genome_sequence[i:i+seq_length]
        if 'N' in full_sequence:
            continue
        
        # Process in batches
        front_half_sequence = full_sequence[:22]
        back_half_sequence = reverse_complement(full_sequence[24:])
        
        encoded_seqs_front.append(encode_sequence(front_half_sequence, seq_length))
        encoded_seqs_back.append(encode_sequence(back_half_sequence, seq_length))
        
        if len(encoded_seqs_front) == batch_size:
            # Make predictions on batch
            batch_preds = predict_on_batch(encoded_seqs_front, encoded_seqs_back)
            predictions.extend(batch_preds)
            
            # Clear lists for next batch
            encoded_seqs_front = []
            encoded_seqs_back = []

    # Process the final batch if there are any sequences left
    if encoded_seqs_front:
        batch_preds = predict_on_batch(encoded_seqs_front, encoded_seqs_back)
        predictions.extend(batch_preds)
    
    return predictions

# Define a function to make predictions on batches
def predict_on_batch(front_seqs, back_seqs):
    front_seqs_tensor = torch.stack(front_seqs)
    back_seqs_tensor = torch.stack(back_seqs)
    
    with torch.no_grad():
        front_preds = lit_model.predict_step((front_seqs_tensor, None), 0)
        back_preds = lit_model.predict_step((back_seqs_tensor, None), 0)
        average_logits = (front_preds + back_preds) / 2
        sigmoid = torch.nn.Sigmoid()
        final_preds = sigmoid(average_logits)
    
    # print(final_preds)
        
    return final_preds.tolist()


# Process each sequence in the FASTA file
seq_length = 46
batch_size = 10000  # or any size that fits in your GPU memory
for record in SeqIO.parse(genomic_reference_file, "fasta"):
    chromosome_sequence = record.seq.upper()
    chromosome_id = record.id
    print(f"Processing {chromosome_id}...")
    
    predictions = sliding_window_inference(str(chromosome_sequence), seq_length, batch_size)
    print(predictions)


Processing chr1...
0
10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000
160000
170000
180000
190000
200000
210000
220000
230000
240000
250000
260000
270000
280000
290000
300000
310000
320000
330000
340000
350000
360000
370000
380000
390000
400000
410000
420000
430000
440000
450000
460000
470000
480000
490000
500000
510000
520000
530000
540000
550000
560000
570000
580000
590000
600000
610000
620000
630000
640000
650000
660000
670000
680000
690000
700000
710000
720000
730000
740000
750000
760000
770000
780000
790000
800000
810000
820000
830000
840000
850000
860000
870000
880000
890000
900000
910000
920000
930000
940000
950000
960000
970000
980000
990000
1000000
1010000
1020000
1030000
1040000
1050000
1060000
1070000
1080000
1090000
1100000
1110000
1120000
1130000
1140000
1150000
1160000
1170000
1180000
1190000
1200000
1210000
1220000
1230000
1240000
1250000
1260000
1270000
1280000
1290000
1300000
1310000
1320000
1330000
1340000
1350000
1360000

KeyboardInterrupt: 

In [None]:
preds = trainer.predict(lit_model, data_module)
preds = torch.hstack(preds)

/home/ubuntu/mambaforge/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 2336/2336 [00:08<00:00, 284.71it/s]


In [None]:
# Fix this to unzip a tuple
data = list(data_module.predict_dataloader())
inputs, labels = map(list, zip(*data))
inputs = torch.vstack(inputs)

In [None]:
# Translate inputs
trans_dict = {0:'A',1:'T',2:'C',3:'G',4:'N'}
translate_func = lambda x: ''.join([trans_dict[y] for y in x])

sequences = [translate_func(x) for x in torch.argmax(inputs,2).numpy()]

labels = torch.argmax(torch.vstack(labels),1)
predictions = pd.DataFrame({'seq':sequences, 'labels':labels, 'preds':preds})

In [None]:
for seq in predictions[predictions['labels'] == 1]['seq'].unique():
    print(seq)

AAAAACAGCTTCTACCGTTTAG
AAAAACTATACCCACTGCAGAG
AAAAATCTTTCAAACCTTGGAG
AAAAATTGTTTCTCCCACTGTG
AAAACAATTTTTAATAATAGAG
AAAACAGATCTCTACCTCTGAG
AAAAGAGGTTTACACGTCAGTG
AAAAGGAATTTTGACATCAGAA
AAAAGGAGACTTAACAACTGAG
AAAAGGAGCTGTAACAACTGAG
AAAAGGAGGTATTACAACTGAT
AAAAGGCTTGATGACCTCAGGG
AAAAGGTTTGGACACAATAGCA
AAAAGTGAATTTTCCAACTGCT
AAAATAGTATTTGACAATAGAC
AAACAGGACTTGTTCAACAGAG
AAACAGGATTTGCACCCCGGGG
AAACAGTGCTTCATCTTCAGCT
AAACGCAGCTCCAACGCCAGCA
AAACGGTGAGTGAGGCACAGAG
AAAGAAACTTTTCACAACTGAG
AAAGAGATACTTGATAGCTGCT
AAAGGCGCCACTGACCACAGAG
AAAGGGCTGAATGGCCACTGAG
AAAGTCAGCTTCTGCTTCAGCC
AAAGTGCTGTTTCACCACTGCT
AAAGTGTGGTTTCACATCTGGC
AAAGTTTTCTTCCACTACAGCC
AAATAGAGTTTGGACTACAGAA
AAATATTATTTAGATATCAGAA
AAATCCTTTTTATTCCACTGAG
AAATCTTTCTTCTCCTACAGCA
AAATGAGGCTTGTCCTGCTGCA
AAATGGAGATATTACGACTGAC
AAATTCTGCTTTAACTACAGTT
AAATTGAATTTAATCAACTGAA
AAATTGCACCTTGAACACTGCA
AAATTGTTTTTATACTGCTGAA
AACAATAATCTCAACATCAGCA
AACAATAGATCTGACAACAGGG
AACAGAGGTTTAAGCCACAGGA
AACAGAGTTTATGACCTCAGAG
AACAGTATTGAATGCAATTGAT
AACATCAGCTT

In [None]:
predictions[(predictions['labels'] == 1) & (predictions['preds'] == 1) ]

Unnamed: 0,seq,labels,preds
45,AAAAACTATACCCACTGCAGAG,1,1
77,AAAAATCTTTCAAACCTTGGAG,1,1
95,AAAAATTGTTTCTCCCACTGTG,1,1
104,AAAACAATTTTTAATAATAGAG,1,1
177,AAAAGGAATTTTGACATCAGAA,1,1
...,...,...,...
74594,TGATTCAGCTTTCACTATTGCT,1,1
74652,TACTAAGTCTTCTACTACTGAG,1,1
74662,TGACAATACTTTGACTTTAGCC,1,1
74665,CTAAGAAGTTTTTACAATAGAG,1,1


In [None]:
# Translate inputs
trans_dict = {0:'A',1:'T',2:'C',3:'G',4:'N'}
translate_func = lambda x: ''.join([trans_dict[y] for y in x])
sequences = [translate_func(x) for x in inputs.numpy()]
predictions = pd.DataFrame({'seq':sequences, 'label':labels, 'pred':preds})

sites = pd.read_csv('data/TB000208a/predict.csv', index_col='index')
predictions = predictions.merge(sites, on='seq')

TypeError: unhashable type: 'numpy.ndarray'

In [None]:
predictions.head()

Unnamed: 0,seq,label,pred,norm_count
0,AAAAAAAAAAAAAAAAAAGTCA,0,0,0.000215
1,AAAAAAAAAAAAAAAAATAGAG,0,0,0.000215
2,AAAAAAAAAAAAACAAAAAGAA,0,0,0.000215
3,AAAAAAAAAAAAAGCCACAGGA,0,0,0.000429
4,AAAAAAAAAAACAACAACAGCA,0,0,0.000859


In [None]:
sorted = predictions.sort_values('norm_count', ascending=False).reset_index()
pred_true = sorted[sorted['pred'] == 1]
label_true = sorted[sorted['label'] == 1]

#sorted['norm_count'].min()
label_true['norm_count'].min()

#plt.semilogy(label_true['norm_count'])

NameError: name 'plt' is not defined

In [47]:
seq_length

22