### Imports

In [1]:
import os
import pandas as pd
import numpy as np
import torch
import lightning as L
import lightning.pytorch as pl

### Load sequence data from cryptic seq experiment

In [4]:
# Preprosses cryptic seq data to normalize by on target
import utils.cs_excel_data

# Flag to trigger data export from excel files
export_excel_data = False

cs_data_file = 'data/TB000208a.outputs.xlsx'
data_path = 'data/Tb000208a'
train_sheets = ['GT-Rep1-N7_S1','GT-Rep2-N7_S2','GT-Rep3-N7_S3']
test_sheets = ['Pool-Rep1-N7_S4','Pool-Rep2-N7_S5','Pool-Rep3-N7_S6']

if export_excel_data:
    utils.cs_excel_data.extract_excel_cs_data(cs_data_file, train_sheets, data_path, 'fit.csv')
    utils.cs_excel_data.extract_excel_cs_data(cs_data_file, test_sheets, data_path, 'test.csv', dn_exclusion=['GT','AC'])

In [9]:
from lightning.pytorch.utilities.types import EVAL_DATALOADERS
from models import mlp
from datasets import one_hot

n_classes = 2
seq_length = 22
vocab_size = 5
input_size = seq_length*vocab_size
hidden_size = 1024
n_hidden = 2
train_test_split = 0.8

class CSDataModule(L.LightningDataModule):
    def __init__(self, data_path, batch_size):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size

    def setup(self, stage: str):
        # Select test/train dataset
        fname = stage + '.csv'

        # Load the cryptic seq data
        sites = pd.read_csv(os.path.join(self.data_path, fname))

        # Threshold the data to assign a label
        sites['label'] = (sites['norm_count'] > 1e-2).astype(int)

        # Compute class frequencies for weighting
        class_sample_count = np.array([len(np.where(sites['label'] == c)[0]) for c in np.unique(sites['label'])])

        # Cryptic sites data for training
        sequences = sites['seq'].values
        labels = sites['label'].values
        self.seq_length = len(sequences[0])
        self.dataset = one_hot.Dataset(sequences, labels, vocab_size=vocab_size, output_size=n_classes)

        # Sample weights based on label and class frequency
        weight = 1. / class_sample_count
        samples_weight = np.array([weight[t] for t in labels])
        self.samples_weight = torch.from_numpy(samples_weight)

        if stage == 'fit':
            # Test and train data split
            train_size = int(train_test_split*len(self.dataset))
            test_size = len(self.dataset) - train_size
            self.train_dataset, self.val_dataset = torch.utils.data.random_split(self.dataset, [train_size, test_size])

            # Weighted random sampler for upsampling minority class for training
            train_sample_weights = samples_weight[self.train_dataset.indices]
            self.train_sampler = torch.utils.data.WeightedRandomSampler(train_sample_weights, len(train_sample_weights), replacement=True)

        elif stage == 'test':
            self.test_dataset = self.dataset

        elif stage == 'predict':
            self.pred_dataset = self.dataset

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, sampler=self.train_sampler)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size)
    
    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size)
    
    def predict_dataloader(self):
        return torch.utils.data.DataLoader(self.pred_dataset, batch_size=self.batch_size)

data_module = CSDataModule(data_path, batch_size=32)

# Build model
model = mlp.Model(input_size=input_size, hidden_size=hidden_size, output_size=n_classes, n_hidden=n_hidden, dropout=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

### Training

In [11]:
from typing import Any
import torchmetrics

# define the LightningModule
class LitClassifier(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.sigmoid = torch.nn.Sigmoid()
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=n_classes)

    def forward(self, x):
        return self.model(x)
    
    def logging(self, logits, target, loss, stage):
        # Logging to TensorBoard (if installed) by default
        pred = self.sigmoid(logits)
        self.accuracy(torch.argmax(pred,1), torch.argmax(target,1))
        self.log(f'{stage}_acc_step', self.accuracy)
        self.log(f'{stage}_loss', loss)

    def training_step(self, batch, batch_idx):
        # Model pass
        data, target = batch
        logits = self(data)
        loss = self.loss_fn(logits, target.float())
        self.logging(logits, target, loss, 'train')
        return loss
    
    def validation_step(self, batch, batch_idx):
        # Model pass
        data, target = batch
        logits = self(data)
        loss = self.loss_fn(logits, target.float())
        self.logging(logits, target, loss, 'val')
        return loss
    
    def test_step(self, batch, batch_idx):
        # Model pass
        data, target = batch
        logits = self(data)
        loss = self.loss_fn(logits, target.float())
        self.logging(logits, target, loss, 'test')
        return loss
    
    def predict_step(self, batch, batch_idx):
        # Model pass
        data, target = batch
        output = self.sigmoid(self(data))
        preds = torch.argmax(output, 1)
        return preds

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
# init the autoencoder
lit_model = LitClassifier(model)

# train the model
tb_logger = pl.loggers.TensorBoardLogger(save_dir="lightning_logs/")
trainer = pl.Trainer(max_epochs=5, logger=tb_logger, accelerator="gpu")
trainer.fit(lit_model, data_module)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name     | Type               | Params
------------------------------------------------
0 | model    | Model              | 1.2 M 
1 | sigmoid  | Sigmoid            | 0     
2 | loss_fn  | BCEWithLogitsLoss  | 0     
3 | accuracy | MulticlassAccuracy | 0     
------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.661     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/matthewbakalar/anaconda3/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/matthewbakalar/anaconda3/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


### Analysis

In [54]:
preds = trainer.predict(lit_model, data_module)
preds = torch.hstack(preds)

/Users/matthewbakalar/anaconda3/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

In [96]:
# Fix this to unzip a tuple
data = list(data_module.predict_dataloader())
inputs, labels = map(list, zip(*data))
inputs = torch.vstack(inputs)

In [102]:
# Translate inputs
trans_dict = {0:'A',1:'T',2:'C',3:'G',4:'N'}
translate_func = lambda x: ''.join([trans_dict[y] for y in x])

sequences = [translate_func(x) for x in torch.argmax(inputs,2).numpy()]

labels = torch.argmax(torch.vstack(labels),1)
predictions = pd.DataFrame({'seq':sequences, 'labels':labels, 'preds':preds})

In [117]:
for seq in predictions[predictions['labels'] == 1]['seq'].unique():
    print(seq)

AAAAAAAAAATTGACCTTAGAT
AAAAAAATTGGAAACATCTGAA
AAAAAAGGCTTTTACAATGAAG
AAAAAATGACTTATGTGGTGAG
AAAAACAGCCTTCGCAACAGTC
AAAAACAGCTTAATCAGCTGAG
AAAAACAGCTTCTACCGTTTAG
AAAAACCATCCTCACAACAGTC
AAAAACTATACCCACTGCAGAG
AAAAACTCTTTCTCCAGCAGAG
AAAAACTGCTTTCAGTGCTGAG
AAAAAGGAGCTTTCACTCTGAC
AAAAAGTTTAGGGACCACTGCT
AAAAAGTTTGGCAGTGTTTGAA
AAAAAGTTTGTCAACTCCTGTT
AAAAATCTTTCAAACCTTGGAG
AAAAATGCATCCAACACTGGAG
AAAAATGGCCTCAGCTTCAGAG
AAAAATGGCTTCTCCACCTGCA
AAAAATGGTTTAAACAATTCTC
AAAAATGTGAGTGACTCCTGCC
AAAAATTGTTTCTCCCACTGTG
AAAACAATTTTTAATAATAGAG
AAAACAGATCTCTACCTCTGAG
AAAACAGGCTTCCACCATTCCA
AAAACCTTTGTCCTCACTGGAG
AAAACGGAATATCACCATAGAC
AAAACGGACTTTCATATTAGGG
AAAACGGGTTCAGAAGACAGCT
AAAACTGAGCTCTACAACAGTC
AAAACTGGTTCAAACCATGCCA
AAAACTGTATTCAGAGGGTGAG
AAAAGAAAATAAAATTATTGAG
AAAAGAAGCTTAGAGTTCTGAT
AAAAGACTTTATTACTGTGGCA
AAAAGAGATTTTCACTACTGCT
AAAAGAGGATTTGCCTCTTGCA
AAAAGAGGGTTTCACATCTGCT
AAAAGAGGTTTACACGTCAGTG
AAAAGAGTGTTTCACAACTGCA
AAAAGATGCTTTCTCATCAGAT
AAAAGCAGATCTGACCACTGAA
AAAAGCAGCTCAGACAATTGCA
AAAAGCAGTTA

In [10]:
# Translate inputs
trans_dict = {0:'A',1:'T',2:'C',3:'G',4:'N'}
translate_func = lambda x: ''.join([trans_dict[y] for y in x])
sequences = [translate_func(x) for x in inputs.numpy()]
predictions = pd.DataFrame({'seq':sequences, 'label':labels, 'pred':preds})

sites = pd.read_csv('data/Tb000208a/predict.csv', index_col='index')
predictions = predictions.merge(sites, on='seq')

In [11]:
predictions.head()

Unnamed: 0,seq,label,pred,norm_count
0,AAAAAAAAAAAAAAAAAAGTCA,0,0,0.000215
1,AAAAAAAAAAAAAAAAATAGAG,0,0,0.000215
2,AAAAAAAAAAAAACAAAAAGAA,0,0,0.000215
3,AAAAAAAAAAAAAGCCACAGGA,0,0,0.000429
4,AAAAAAAAAAACAACAACAGCA,0,0,0.000859


In [13]:
sorted = predictions.sort_values('norm_count', ascending=False).reset_index()
pred_true = sorted[sorted['pred'] == 1]
label_true = sorted[sorted['label'] == 1]

#sorted['norm_count'].min()
label_true['norm_count'].min()

#plt.semilogy(label_true['norm_count'])

NameError: name 'plt' is not defined