In [None]:
import lightning as L
import torch
from torch.utils.data import DataLoader

from pyaptamer.aptatrans import (
    AptaTrans,
    AptaTransLightning,
    AptaTransPipeline,
    EncoderPredictorConfig,
)
from pyaptamer.datasets import (
    load_csv_dataset,
)
from pyaptamer.datasets.dataclasses import APIDataset
from pyaptamer.utils._base import (
    filter_words,
)

# auto-reloading external modules: `%autoreload 2` reloads imported modules before
# each cell runs, so edits to the pyaptamer sources are picked up without a kernel restart
%load_ext autoreload
%autoreload 2

## Settings

In [None]:
BATCH_SIZE = 64  # mini-batch size
RANDOM_STATE = 42  # global seed, for reproducibility

# Seed Python's `random`, NumPy and PyTorch (and DataLoader workers via `workers=True`)
# so data shuffling, weight initialisation and dropout are repeatable under
# Restart Kernel -> Run All.
L.seed_everything(RANDOM_STATE, workers=True)

# embedding configurations for pretraining
# NOTE(review): presumably these must match the vocabularies/lengths the pretrained
# AptaTrans weights were built with -- confirm before changing them.
# aptamers
N_APTA_VOCABS = 127
N_APTA_TARGET_VOCABS = 344
APTA_MAX_LEN = 275
# proteins
N_PROT_VOCABS = 715
N_PROT_TARGET_VOCABS = 585
PROT_MAX_LEN = 867

In [None]:
# Use CUDA when a GPU is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Data

### Load frequencies of protein words

In [None]:
# (1.) load the protein-word frequency table used for pretraining and turn it into a
# {word: frequency} mapping; the raw table keeps its own name so `prot_words` is
# always the dict used further below
prot_words_df = load_csv_dataset(name="protein_word_freq")
prot_words = prot_words_df.set_index("seq")["freq"].to_dict()

# (2.) drop protein words with below-average frequency; `filter_words` presumably also
# maps the kept words to a numerical representation -- confirm in pyaptamer.utils
filtered_prot_words = filter_words(prot_words)
assert len(filtered_prot_words) > 0, "no protein words survived frequency filtering"

### Load aptamer-protein interaction (API) dataset
Here, we load the aptamer-protein interaction (API) dataset used in the AptaTrans paper, originally from Li et al. (2014).

In [None]:
# (1.) load the raw API tables (aptamer, protein, label columns) for fine-tuning;
# they get their own names instead of being overwritten by the APIDataset wrappers
train_df = load_csv_dataset(name="train_li2014")
test_df = load_csv_dataset(name="test_li2014")

# (2.) wrap the raw tables in APIDataset objects; both splits use the same maximum
# lengths and the same filtered protein vocabulary
train_dataset = APIDataset(
    x_apta=train_df["aptamer"].to_numpy(),
    x_prot=train_df["protein"].to_numpy(),
    y=train_df["label"].to_numpy(),
    apta_max_len=APTA_MAX_LEN,
    prot_max_len=PROT_MAX_LEN,
    prot_words=filtered_prot_words,
)
test_dataset = APIDataset(
    x_apta=test_df["aptamer"].to_numpy(),
    x_prot=test_df["protein"].to_numpy(),
    y=test_df["label"].to_numpy(),
    apta_max_len=APTA_MAX_LEN,
    prot_max_len=PROT_MAX_LEN,
    prot_words=filtered_prot_words,
    split="test",
)

# (3.) create dataloaders; only the training data is shuffled -- the evaluation loader
# keeps a fixed order so its results are reproducible and comparable across runs
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

## Model
Before initializing an AptaTrans instance, we need to define the embedding configuration for both aptamers and proteins.

In [None]:
# One embedding configuration per modality, sized by the constants in the settings cell.
# aptamer side
apta_embedding = EncoderPredictorConfig(
    max_len=APTA_MAX_LEN,
    num_embeddings=N_APTA_VOCABS,
    target_dim=N_APTA_TARGET_VOCABS,
)
# protein side
prot_embedding = EncoderPredictorConfig(
    max_len=PROT_MAX_LEN,
    num_embeddings=N_PROT_VOCABS,
    target_dim=N_PROT_TARGET_VOCABS,
)

Next, we initialize an AptaTrans instance, load its pretrained weights, and wrap it in a PyTorch Lightning module for training.

In [None]:
# Build AptaTrans from the two embedding configs.
# NOTE(review): the architecture hyperparameters below presumably have to match the
# checkpoint restored by `load_pretrained_weights` -- confirm before changing them.
model = AptaTrans(
    apta_embedding=apta_embedding,
    prot_embedding=prot_embedding,
    in_dim=128,
    n_encoder_layers=6,
    n_heads=8,
    conv_layers=[3, 3, 3],
    dropout=0.1,
)
model.load_pretrained_weights()

# Wrap the pretrained model for Lightning training. The wrapper presumably keeps a
# reference to `model`, so fine-tuning updates the same object the pipeline uses later.
model_lightning = AptaTransLightning(model)
model_lightning = model_lightning.to(device)

## Training

In [None]:
# Fine-tune the wrapped model on the training loader (Lightning picks the accelerator).
trainer = L.Trainer(max_epochs=100)
trainer.fit(model=model_lightning, train_dataloaders=train_dataloader)

## Recommend aptamer candidates

In [None]:
# Lightning's strategy teardown moves the module back to CPU once `fit` returns, and the
# model is still in training mode. Put it on `device` (the device the pipeline is told
# to use) and switch to eval mode so dropout does not perturb candidate scoring.
model.to(device).eval()

pipeline = AptaTransPipeline(
    device=device,
    model=model,
    # NOTE(review): the datasets above use `filtered_prot_words`; the raw frequencies are
    # passed here -- confirm the pipeline applies the same filtering internally.
    prot_words=prot_words,
    depth=20,  # i.e., how long the candidates will be
    n_iterations=10,  # number of iterations for the MCTS search
)


# specify the target protein sequence here
target_protein = (
    "STEYKLVVVGADGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVVIDGETCLLDILDTAGQEEYSAM"
    "RDQYMRTGEGFLCVFAINNTKSFEDIHHYREQIKRVKDSEDVPMVLVGNKCDLPSRTVDTKQAQDLARSYGIPFIETSAKTR"
    "QGVDDAFYTLVREIRKHKEKMSK"
)
candidates = pipeline.recommend(
    target=target_protein,
    n_candidates=10,
    verbose=True,
)

In [None]:
# Report each recommended candidate: element 0 is the sequence, element 2 its score.
for rank, candidate in enumerate(candidates):
    sequence = candidate[0]
    score = float(candidate[2])
    print(f"[Candidate {rank}] {sequence} - Score: {score}")