In [1]:
from seqinfer.seq.datasets import SeqFromFileDataset
from seqinfer.seq.transforms import Compose, KmerTokenizer, OneHotEncoder, ToTensor
from seqinfer.seq.vocabularies import unambiguous_dna_vocabulary_dict, SpecialToken

import torch
from torch.utils.data import ConcatDataset, Subset, DataLoader
from sklearn.model_selection import train_test_split
import logging

logging.disable(logging.INFO)

In [2]:
MAXLEN = 120
pos_seq_dataset = SeqFromFileDataset(
    seq_file="pos.fasta",
    seq_file_fmt="fasta",
    transform_sequences=Compose(
        [
            KmerTokenizer(
                k=1,
                stride=1,
                vocab_dict=unambiguous_dna_vocabulary_dict,
                num_output_tokens=MAXLEN,
                special_tokens=SpecialToken,
            ),
            OneHotEncoder(vocab_size=len(unambiguous_dna_vocabulary_dict) + len(SpecialToken)),
            ToTensor(dtype=torch.float32),
        ]
    ),
    targets=1,
    transform_targets=ToTensor(torch.float32),
)

neg_seq_dataset = SeqFromFileDataset(
    seq_file="neg.fasta",
    seq_file_fmt="fasta",
    transform_sequences=Compose(
        [
            KmerTokenizer(
                k=1,
                stride=1,
                vocab_dict=unambiguous_dna_vocabulary_dict,
                num_output_tokens=MAXLEN,
                special_tokens=SpecialToken,
            ),
            OneHotEncoder(vocab_size=len(unambiguous_dna_vocabulary_dict) + len(SpecialToken)),
            ToTensor(dtype=torch.float32),
        ]
    ),
    targets=0,
    transform_targets=ToTensor(torch.float32),
)

all_seq = ConcatDataset([pos_seq_dataset, neg_seq_dataset])

In [3]:
len(all_seq), all_seq[1][0].shape, all_seq[2][0].shape

(200, torch.Size([120, 14]), torch.Size([120, 14]))

In [4]:
BATCHSIZE = 64
train_val_ids, test_ids = train_test_split(range(len(all_seq)), test_size=0.2, random_state=0)
train_ids, val_ids = train_test_split(train_val_ids, test_size=0.2, random_state=0)
train_loader = DataLoader(Subset(all_seq, train_ids), batch_size=BATCHSIZE)
val_loader = DataLoader(Subset(all_seq, val_ids), batch_size=BATCHSIZE)
test_loader = DataLoader(Subset(all_seq, test_ids), batch_size=BATCHSIZE)

In [5]:
from seqinfer.infer.classifiers import LitBinaryClassifier, LitClassifier
import lightning as L

L.seed_everything(123, workers=True)

from torch import nn
import torchmetrics
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

In [6]:
class SwapAxes(nn.Module):
    def __init__(self, dim0: int, dim1: int) -> None:
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.swapaxes(x, self.dim0, self.dim1)


class Squeeze(nn.Module):
    def __init__(self, dim: int | None = None) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.squeeze(x)


convnet_binary = nn.Sequential(
    SwapAxes(1, 2),
    nn.Conv1d(in_channels=14, out_channels=8, kernel_size=8),
    nn.AvgPool1d(kernel_size=3),
    nn.Conv1d(in_channels=8, out_channels=8, kernel_size=6),
    nn.AvgPool1d(kernel_size=3),
    nn.Conv1d(in_channels=8, out_channels=8, kernel_size=3),
    nn.AvgPool1d(kernel_size=3),
    nn.Flatten(),
    nn.Linear(16, 1),
    Squeeze(),
)

In [7]:
lit_model = LitBinaryClassifier(
    model=convnet_binary, is_output_logits=True, loss=nn.BCEWithLogitsLoss()
)

early_stop_callback = EarlyStopping(monitor="val_loss", mode="min", patience=15, verbose=False)
checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="val_loss", mode="min")
logger = TensorBoardLogger("tb_logs", name="my_model")

trainer = L.Trainer(
    max_epochs=200,
    # precision="bf16-mixed",
    deterministic=True,
    # enable_checkpointing=False,
    log_every_n_steps=5,
    logger=logger,
    callbacks=[early_stop_callback, checkpoint_callback],
)

In [8]:
trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
best_lit_model = LitBinaryClassifier.load_from_checkpoint(
    checkpoint_callback.best_model_path, model=convnet_binary, loss=nn.BCEWithLogitsLoss()
)

                                                                           

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Epoch 30: 100%|██████████| 2/2 [00:00<00:00, 11.27it/s, v_num=0]


In [9]:
trainer.test(best_lit_model, train_loader)

  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 31.16it/s] 


[{'test_loss': 0.6922537684440613,
  'test_BinaryAccuracy': 0.4921875,
  'test_BinaryAUROC': 0.579487144947052}]

In [10]:
trainer.test(best_lit_model, val_loader)

Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 95.88it/s] 


[{'test_loss': 0.6922627091407776,
  'test_BinaryAccuracy': 0.46875,
  'test_BinaryAUROC': 0.4745097756385803}]

In [11]:
trainer.test(best_lit_model, test_loader)

Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 17.99it/s] 


[{'test_loss': 0.6961395740509033,
  'test_BinaryAccuracy': 0.550000011920929,
  'test_BinaryAUROC': 0.558080792427063}]

In [13]:
x, y = next(iter(train_loader))
output = best_lit_model.model(x)
torchmetrics.classification.BinaryAUROC()(output, y)

tensor(0.6089)