In [None]:
import lightning as L
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from pyaptamer.aptatrans import (
    AptaTrans,
    AptaTransEncoderLightning,
    AptaTransLightning,
    AptaTransPipeline,
    EncoderPredictorConfig,
)
from pyaptamer.datasets import (
    load_csv_dataset,
    load_hf_dataset,
)
from pyaptamer.datasets.dataclasses import APIDataset, MaskedDataset
from pyaptamer.experiments import Aptamer
from pyaptamer.utils._aptatrans_utils import seq2vec
from pyaptamer.utils._augment import augment_reverse
from pyaptamer.utils._base import filter_words
from pyaptamer.utils._rna import rna2vec

# setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# auto-reloading external modules
%load_ext autoreload
%autoreload 2

### Settings

In [None]:
BATCH_SIZE = 16
TEST_SIZE = 0.05  # fraction of the pretraining data held out as the test set
RANDOM_STATE = 42  # global seed, for reproducibility
RAMDOM_STATE = RANDOM_STATE  # alias for the original (misspelled) name used in later cells

# seed Python's `random`, NumPy and PyTorch so that stochastic steps (e.g.,
# dataloader shuffling, weight initialization) are reproducible across reruns;
# passing `random_state` to `train_test_split` alone only fixes the split
L.seed_everything(RANDOM_STATE)

# embeddings for pretraining
# aptamers
N_APTA_VOCABS = 127  # input vocabulary size of the aptamer encoder (last index = mask token)
N_APTA_TARGET_VOCABS = 344  # size of the aptamer pretraining target vocabulary
APTA_MAX_LEN = 275  # maximum aptamer sequence length
# proteins
N_PROT_VOCABS = 715  # input vocabulary size of the protein encoder (last index = mask token)
N_PROT_TARGET_VOCABS = 585  # size of the protein pretraining target vocabulary
PROT_MAX_LEN = 867  # maximum protein sequence length

# AptaTrans
In this notebook, we present a tutorial on using AptaTrans [[1](#ref-1)] for the following tasks:
1. **Aptamer-protein interaction (API) prediction:** given aptamer and target protein sequences, predict whether they bind or not (binary classification);
2. **Recommend candidate aptamers:** given a target protein sequence, combine AptaTrans with Monte Carlo Tree Search (MCTS) to generate candidate aptamers that are likely to bind with the target.

AptaTrans is a deep neural network based on two attention-based encoders, one for aptamers and one for proteins, and multiple convolutional layers. The encoders are pretrained on the task of predicting the secondary structure (ss) of given sequences, so that they capture features relevant to the structural representation of aptamers and proteins, respectively. After pretraining, the whole network may be fine-tuned by training it on a binary classification dataset of aptamer-protein interactions.

The tutorial is organized as follows. In [Step 1](#step-1-data-preparation), we load and pre-process the data needed for pretraining and fine-tuning. In [Step 2](#step-2-model) we initialize AptaTrans' neural network. In [Step 3](#step-3-training) we cover pretraining and fine-tuning, for the encoders and the overall neural network, respectively. 
Finally, [Step 4](#step-4-recommend) and [Step 5](#step-5-api-prediction) are dedicated to examples of candidate aptamer recommendation and API prediction, respectively.

##### References
<a id="ref-1"></a>
[1] Shin, Incheol, et al. "AptaTrans: a deep neural network for predicting aptamer-protein interaction using pretrained encoders." BMC bioinformatics 24.1 (2023): 447. <br>
<a id="ref-2"></a>
[2] Danaee, Padideh, et al. "bpRNA: large-scale automated annotation and analysis of RNA secondary structure." Nucleic acids research 46.11 (2018): 5381-5394. <br>
<a id="ref-3"></a>
[3] Berman, Helen M., et al. "The protein data bank." Nucleic acids research 28.1 (2000): 235-242. <br>
<a id="ref-4"></a>
[4] Li, Bi-Qing, et al. "Prediction of aptamer-target interacting pairs with pseudo-amino acid composition." PLoS One 9.1 (2014): e86729. <br>

## Step 1: Data preparation
Here, we prepare the data needed for pretraining and fine-tuning.

### Load (RNA) aptamer data for pretraining
For pretraining the aptamer encoder, we use $79,890$ RNA aptamer sequences from the *bpRNA-1m* dataset from *bpRNA* [[2](#ref-2)].

The training sequences are augmented by adding their reversed copies (both the sequence and its secondary structure are reversed). Then, they are converted to a numerical format suitable for the encoder, masked for pretraining, and stored in PyTorch dataloaders.

In [None]:
# (1.) load the RNA dataset for pretraining
apta_dataset = load_hf_dataset(name="bpRNA-shin2023", store=True)

# (2.) create training-test splits of (sequence, secondary structure (ss)) pairs
x_apta_train, x_apta_test, y_apta_train, y_apta_test = train_test_split(
    apta_dataset["SEQUENCE"].tolist(),
    apta_dataset["SS"].tolist(),
    test_size=TEST_SIZE,
    random_state=RAMDOM_STATE,
)

# (3.) augment the training data only, by adding reversed copies of each pair
# (the test split keeps the original sequences)
# e.g., (seq="ACG", ss="SHM") -> (seq="GCA", ss="MHS")
x_apta_train, y_apta_train = augment_reverse(x_apta_train, y_apta_train)

# (4.) convert aptamer sequences and secondary structures of BOTH splits to
# (integer) numerical vectors, so that both MaskedDatasets receive the same
# representation (the protein cell below vectorizes both splits as well)
x_apta_train = rna2vec(x_apta_train, sequence_type="rna")
y_apta_train = rna2vec(y_apta_train, sequence_type="ss")
x_apta_test = rna2vec(x_apta_test, sequence_type="rna")
y_apta_test = rna2vec(y_apta_test, sequence_type="ss")

# (5.) mask the dataset for pretraining embeddings; the last vocabulary index
# is reserved as the mask token
train_apta = MaskedDataset(
    x=x_apta_train,
    y=y_apta_train,
    max_len=APTA_MAX_LEN,
    mask_idx=N_APTA_VOCABS - 1,
    is_rna=True,
)
test_apta = MaskedDataset(
    x=x_apta_test,
    y=y_apta_test,
    max_len=APTA_MAX_LEN,
    mask_idx=N_APTA_VOCABS - 1,
    is_rna=True,
)

# (6.) create dataloaders; only the training data is shuffled, evaluation
# order is kept fixed so test metrics are reproducible
train_apta_dataloader = DataLoader(
    train_apta,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_apta_dataloader = DataLoader(
    test_apta,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

### Load protein data for pretraining
For pretraining the protein encoder, we use $166,136$ protein sequences from the Protein Data Bank (PDB) [[3](#ref-3)].

In this case, the sequences are not augmented with reversed copies. Instead, protein words with below-average frequency are filtered out. Then, similarly to above, sequences are transformed to a numerical representation suitable for the encoder and stored in PyTorch dataloaders.

In [None]:
# (1.) load the proteins' dataset for pretraining
prot_dataset = load_hf_dataset(name="proteins-shin2023", store=True)
prot_words = load_csv_dataset(name="protein_word_freq")  # words and their frequencies
prot_words = prot_words.set_index("seq")["freq"].to_dict()

# (2.) create training-test splits of (sequence, secondary structure (ss)) pairs
x_prot_train, x_prot_test, y_prot_train, y_prot_test = train_test_split(
    prot_dataset["SEQUENCE"].tolist(),
    prot_dataset["SS"].tolist(),
    test_size=TEST_SIZE,
    random_state=RAMDOM_STATE,
)

# (3.) transform sequences of both splits to a numerical representation (vectors)
filtered_prot_words = filter_words(prot_words)  # filter below average frequency words
x_prot_train, y_prot_train = seq2vec(
    sequence_list=(x_prot_train, y_prot_train),
    words=filtered_prot_words,
    seq_max_len=PROT_MAX_LEN,
)
x_prot_test, y_prot_test = seq2vec(
    sequence_list=(x_prot_test, y_prot_test),
    words=filtered_prot_words,
    seq_max_len=PROT_MAX_LEN,
)

# (4.) mask the dataset for pretraining embeddings; these are protein
# sequences, so `is_rna=False` for BOTH splits
train_prot = MaskedDataset(
    x=x_prot_train,
    y=y_prot_train,
    max_len=PROT_MAX_LEN,
    mask_idx=N_PROT_VOCABS - 1,
    is_rna=False,
)
test_prot = MaskedDataset(
    x=x_prot_test,
    y=y_prot_test,
    max_len=PROT_MAX_LEN,
    mask_idx=N_PROT_VOCABS - 1,
    is_rna=False,
)

# (5.) create dataloaders; only the training data is shuffled
train_prot_dataloader = DataLoader(
    train_prot,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_prot_dataloader = DataLoader(
    test_prot,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

### Load aptamer-protein interaction (API) dataset
For fine-tuning (i.e., training the neural network on the task of API prediction), we employ a selection of (aptamer, protein) pairs, labeled as binding or non-binding, from [[4](#ref-4)].

In [None]:
# (1.) load the API data for fine-tuning as DataFrames; they get their own
# names so this cell stays re-runnable (`*_dataset` is rebound below)
train_df = load_csv_dataset(name="train_li2014")
test_df = load_csv_dataset(name="test_li2014")

# (2.) create the API datasets; proteins are encoded with the same filtered
# protein words used for pretraining the protein encoder
train_dataset = APIDataset(
    x_apta=train_df["aptamer"].to_numpy(),
    x_prot=train_df["protein"].to_numpy(),
    y=train_df["label"].to_numpy(),
    apta_max_len=APTA_MAX_LEN,
    prot_max_len=PROT_MAX_LEN,
    prot_words=filtered_prot_words,
)
test_dataset = APIDataset(
    x_apta=test_df["aptamer"].to_numpy(),
    x_prot=test_df["protein"].to_numpy(),
    y=test_df["label"].to_numpy(),
    apta_max_len=APTA_MAX_LEN,
    prot_max_len=PROT_MAX_LEN,
    prot_words=filtered_prot_words,
    split="test",
)

# (3.) create dataloaders; only the training data is shuffled
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

## Step 2: Model
Here, we initialize the AptaTrans model, using the same hyperparameters as in the paper.

In [None]:
# (1.) embedding configurations for the two encoders: input vocabulary size,
# pretraining target vocabulary size and maximum sequence length (the same
# constants used to build the pretraining datasets)
apta_embedding = EncoderPredictorConfig(
    num_embeddings=N_APTA_VOCABS, target_dim=N_APTA_TARGET_VOCABS, max_len=APTA_MAX_LEN
)
prot_embedding = EncoderPredictorConfig(
    num_embeddings=N_PROT_VOCABS, target_dim=N_PROT_TARGET_VOCABS, max_len=PROT_MAX_LEN
)

# (2.) architecture hyperparameters, as in the paper
aptatrans_hparams = dict(
    in_dim=128,
    n_encoder_layers=6,
    n_heads=8,
    conv_layers=[3, 3, 3],
    dropout=0.1,
)
model = AptaTrans(
    apta_embedding=apta_embedding,
    prot_embedding=prot_embedding,
    **aptatrans_hparams,
)
model = model.to(device)

# (optional) (2.1) load pretrained weights; with these loaded, Step 3 may be skipped
model.load_pretrained_weights()

## Step 3: Training
For training, we leverage the `Lightning` framework. Hyperparameters related to optimization are already set to the ones used in the paper.

If you decide to load pretrained weights, you may skip this step.

### Pretraining
Sequential pretraining of the aptamer and protein encoders.

In [None]:
# pretrain the aptamer encoder of `model` through its Lightning wrapper
apta_pretrain_hparams = {"lr": 1e-4, "weight_decay": 1e-5}
encoder_lightning_apta = AptaTransEncoderLightning(
    model=model, encoder_type="apta", **apta_pretrain_hparams
).to(device)
trainer = L.Trainer(max_epochs=10)
trainer.fit(encoder_lightning_apta, train_apta_dataloader)

In [None]:
# pretrain the protein encoder of `model` through its Lightning wrapper
prot_pretrain_hparams = {"lr": 1e-4, "weight_decay": 1e-5}
encoder_prot_lightning = AptaTransEncoderLightning(
    model=model, encoder_type="prot", **prot_pretrain_hparams
).to(device)
trainer = L.Trainer(max_epochs=1000)
trainer.fit(encoder_prot_lightning, train_prot_dataloader)

### Fine-tuning
Fine-tune the deep neural network by training it on the task of API prediction.

In [None]:
# fine-tune the whole network on the binary API prediction task
finetune_hparams = {"lr": 1e-5, "weight_decay": 1e-5}
model_lightning = AptaTransLightning(model=model, **finetune_hparams).to(device)
trainer = L.Trainer(max_epochs=100)
trainer.fit(model_lightning, train_dataloader)

## Step 4: Recommend
Here, you may generate candidate aptamers for a given target protein sequence.

In `AptaTransPipeline` we combine Monte Carlo Tree Search (MCTS) and AptaTrans' deep neural network for generating candidates and ranking them, respectively.

In [None]:
# specify the target protein sequence here (pieces are concatenated)
target_protein = "".join(
    [
        "STEYKLVVVGADGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVVIDGETCLLDILDTAGQEEYSAM",
        "RDQYMRTGEGFLCVFAINNTKSFEDIHHYREQIKRVKDSEDVPMVLVGNKCDLPSRTVDTKQAQDLARSYGIPFIETSAKTR",
        "QGVDDAFYTLVREIRKHKEKMSK",
    ]
)

# MCTS generates candidates, AptaTrans' network ranks them.
# NOTE(review): unlike the API datasets and Step 5, this passes the unfiltered
# `prot_words`; presumably the pipeline filters them itself -- confirm.
pipeline = AptaTransPipeline(
    device=device,
    model=model,
    prot_words=prot_words,
    depth=20,  # depth of the search, length of generated candidates
    n_iterations=10,  # number of iterations (higher is better, but slower)
)
n_candidates = 10  # number of candidates to generate
candidates = pipeline.recommend(
    target=target_protein, n_candidates=n_candidates, verbose=True
)

In [None]:
# print the generated candidates and their score (higher is better), ranked by
# score; sorting makes the "Candidate N" labels true ranks regardless of the
# order in which `recommend` returns the candidates
ranked_candidates = sorted(candidates, key=lambda cand: float(cand[2]), reverse=True)
for rank, candidate in enumerate(ranked_candidates, start=1):
    print(f"[Candidate {rank}] {candidate[0]} - Score: {float(candidate[2])}")

## Step 5: API prediction
Here, you may predict whether a given aptamer and protein sequence bind or not.

In [None]:
# specify the target protein sequence here
target_protein = (
    "STEYKLVVVGADGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVVIDGETCLLDILDTAGQEEYSAM"
    "RDQYMRTGEGFLCVFAINNTKSFEDIHHYREQIKRVKDSEDVPMVLVGNKCDLPSRTVDTKQAQDLARSYGIPFIETSAKTR"
    "QGVDDAFYTLVREIRKHKEKMSK"
)
# specify the aptamer candidate sequence here; by default, the best-scoring
# candidate recommended in Step 4 is used, so the notebook runs end to end
# (a placeholder such as "..." is not a valid aptamer sequence)
aptamer_candidate = max(candidates, key=lambda cand: float(cand[2]))[0]

experiment = Aptamer(
    target=target_protein, model=model, device=device, prot_words=filtered_prot_words
)
score = experiment.evaluate(aptamer_candidate=aptamer_candidate)
print(f"Aptamer: {aptamer_candidate}")
print(f"Score: {score.item():.4f}")