In [1]:
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from pyaptamer.aptatrans import (
    AptaTrans,
    AptaTransPipeline,
    AptaTransSolver,
    EncoderPredictorConfig,
)
from pyaptamer.datasets import (
    load_csv_dataset,
    load_hf_dataset,
)
from pyaptamer.datasets.dataclasses import MaskedDataset
from pyaptamer.utils._augment import augment_reverse
from pyaptamer.utils._base import (
    filter_words,
    seq2vec,
)

# auto-reloading external modules
%load_ext autoreload
%autoreload 2

In [2]:
import lightning as pl

In [3]:
# show the installed Lightning version as this cell's output
pl.__version__

'2.5.3'

## Settings

In [None]:
# --- general settings -------------------------------------------------------
BATCH_SIZE = 64  # batch size shared by every DataLoader in this notebook
TEST_SIZE = 0.05  # fraction of the pretraining data held out for testing
RANDOM_STATE = 42  # seed passed to train_test_split, for reproducibility
# Backward-compatible alias: the constant was originally misspelled and other
# cells may still refer to it under the old name.
RAMDOM_STATE = RANDOM_STATE

# --- embedding configurations for pretraining -------------------------------
# aptamers
N_APTA_VOCABS = 127
N_APTA_TARGET_VOCABS = 344
APTA_MAX_LEN = 275
# proteins
N_PROT_VOCABS = 715
N_PROT_TARGET_VOCABS = 585
PROT_MAX_LEN = 867

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Data

### Load RNA data for pretraining

In [None]:
# (1.) load the RNA dataset for pretraining
rna_dataset = load_hf_dataset(name="bpRNA-shin2023", store=True)

# (2.) create training-test splits of (sequence, secondary structure (ss)) pairs
x_rna_train, x_rna_test, y_rna_train, y_rna_test = train_test_split(
    rna_dataset["SEQUENCE"].tolist(),
    rna_dataset["SS"].tolist(),
    test_size=TEST_SIZE,
    random_state=RAMDOM_STATE,
)

# (3.) augment the training data (only) with reversed pairs
# e.g., (seq="ACG", ss="SHM") -> (seq="GCA", ss="MHS")
# NOTE(review): the example above is a plain reversal, not a reverse complement
# (which would turn "ACG" into "CGT") -- confirm against `augment_reverse`.
x_rna_train, y_rna_train = augment_reverse(x_rna_train, y_rna_train)

# (4.) mask the dataset for pretraining embeddings
# the mask index is the last index of the aptamer vocabulary
train_rna = MaskedDataset(
    x=x_rna_train,
    y=y_rna_train,
    max_len=APTA_MAX_LEN,
    mask_idx=N_APTA_VOCABS - 1,
    is_rna=True,
)
test_rna = MaskedDataset(
    x=x_rna_test,
    y=y_rna_test,
    max_len=APTA_MAX_LEN,
    mask_idx=N_APTA_VOCABS - 1,
    is_rna=True,
)

# (5.) create dataloaders
train_rna_dataloader = DataLoader(
    train_rna,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
# The held-out split is iterated in a fixed order: shuffling brings nothing at
# evaluation time, and Lightning warns about shuffled val/test dataloaders.
test_rna_dataloader = DataLoader(
    test_rna,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

### Load protein data for pretraining

In [None]:
# (1.) load the proteins' dataset for pretraining
prot_dataset = load_hf_dataset(name="proteins-shin2023", store=True)
# words and their frequencies, as a {word: frequency} dictionary
# (built in one expression so `prot_words` is never bound to the raw table)
prot_words = (
    load_csv_dataset(name="protein_word_freq").set_index("seq")["freq"].to_dict()
)

# (2.) create training-test splits of (sequence, secondary structure (ss)) pairs
x_prot_train, x_prot_test, y_prot_train, y_prot_test = train_test_split(
    prot_dataset["SEQUENCE"].tolist(),
    prot_dataset["SS"].tolist(),
    test_size=TEST_SIZE,
    random_state=RAMDOM_STATE,
)

# (3.) transform sequences to a numerical representation (vectors)
filtered_prot_words = filter_words(prot_words)
x_prot_train, y_prot_train = seq2vec(
    sequence_list=(x_prot_train, y_prot_train),
    words=filtered_prot_words,
    seq_max_len=PROT_MAX_LEN,
)
x_prot_test, y_prot_test = seq2vec(
    sequence_list=(x_prot_test, y_prot_test),
    words=filtered_prot_words,
    seq_max_len=PROT_MAX_LEN,
)

# (4.) mask the dataset for pretraining embeddings
# the mask index is the last index of the protein vocabulary
# Both splits hold protein data, so both must be flagged `is_rna=False`; the
# training split was previously (and inconsistently) created with `is_rna=True`.
train_prot = MaskedDataset(
    x=x_prot_train,
    y=y_prot_train,
    max_len=PROT_MAX_LEN,
    mask_idx=N_PROT_VOCABS - 1,
    is_rna=False,
)
test_prot = MaskedDataset(
    x=x_prot_test,
    y=y_prot_test,
    max_len=PROT_MAX_LEN,
    mask_idx=N_PROT_VOCABS - 1,
    is_rna=False,
)

# (5.) create dataloaders
train_prot_dataloader = DataLoader(
    train_prot,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
# The held-out split is iterated in a fixed order: shuffling brings nothing at
# evaluation time, and Lightning warns about shuffled val/test dataloaders.
test_prot_dataloader = DataLoader(
    test_prot,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

### Load aptamer-protein interaction (API) dataset for fine-tuning

In [None]:
# Load the two li2014 splits (train first, then test) of the aptamer-protein
# interaction data used for fine-tuning.
train_dataset, test_dataset = (
    load_csv_dataset(name=split_name)
    for split_name in ("train_li2014", "test_li2014")
)

In [None]:
# display the fine-tuning training split as this cell's output
train_dataset

## Model

In [None]:
# One encoder/predictor configuration per modality (aptamer first, protein
# second), each built from that modality's vocabulary sizes and maximum length.
apta_embedding, prot_embedding = (
    EncoderPredictorConfig(
        num_embeddings=n_vocabs,
        target_dim=n_target_vocabs,
        max_len=seq_max_len,
    )
    for n_vocabs, n_target_vocabs, seq_max_len in (
        (N_APTA_VOCABS, N_APTA_TARGET_VOCABS, APTA_MAX_LEN),
        (N_PROT_VOCABS, N_PROT_TARGET_VOCABS, PROT_MAX_LEN),
    )
)

In [None]:
# Architecture hyperparameters, collected in one place so they are easy to tweak.
model_hparams = {
    "in_dim": 128,
    "n_encoder_layers": 6,
    "n_heads": 8,
    "conv_layers": [3, 3, 3],
    "dropout": 0.1,
}

# Assemble AptaTrans from the two embedding configurations and the
# hyperparameters above.
model = AptaTrans(
    apta_embedding=apta_embedding,
    prot_embedding=prot_embedding,
    **model_hparams,
)

## Pretraining

## Train

In [None]:
# Instantiate the solver with its default arguments.
# NOTE(review): `solver` is not used by any later cell shown in this notebook,
# and no training call follows here -- TODO confirm whether the pretraining /
# training steps are still to be added.
solver = AptaTransSolver()

## Recommend

In [None]:
# Target to generate candidate aptamers for.
# example: human insulin - A chain (specify your own protein sequence here)
target_protein = "GIVEQCCTSICSLYQLENYCN"

# Search settings for this run. `depth` determines the length of the
# candidates. Alternative values left by the author for a larger search:
# depth=40, n_iterations=1000.
search_settings = {"depth": 5, "n_iterations": 1}

pipeline = AptaTransPipeline(
    device=device,
    model=model,
    prot_words=prot_words,
    **search_settings,
)

# Ask the pipeline for 10 candidates for the target, with verbose output.
candidates = pipeline.recommend(
    target=target_protein,
    n_candidates=10,
    verbose=True,
)

In [None]:
# display the value returned by `pipeline.recommend` as this cell's output
candidates