In [None]:
import random

import numpy as np
import torch

from pyaptamer.aptatrans import (
    AptaTrans,
    AptaTransPipeline,
    EncoderPredictorConfig,
)
from pyaptamer.datasets import (
    load_csv_dataset,
)
from pyaptamer.utils._base import (
    filter_words,
)

# auto-reloading external modules
%load_ext autoreload
%autoreload 2

## Settings

In [None]:
BATCH_SIZE = 64
TEST_SIZE = 0.05  # fraction held out as the test split (passed as `test_size` to train_test_split)
RANDOM_STATE = 42  # seed for data splits and for the global RNGs below
RAMDOM_STATE = RANDOM_STATE  # deprecated misspelled alias, kept so existing cells keep working

# Seed the global RNGs so that "Restart & Run All" reproduces the same results.
# NOTE(review): whether this makes `AptaTransPipeline.recommend` fully deterministic
# depends on which RNG(s) it draws from (and on GPU kernels) — not verifiable here.
random.seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

# embedding configurations for pretraining
# aptamers
N_APTA_VOCABS = 127  # `num_embeddings` of the aptamer encoder; index N_APTA_VOCABS - 1 is the mask index
N_APTA_TARGET_VOCABS = 344  # `target_dim` of the aptamer encoder/predictor
APTA_MAX_LEN = 275  # `max_len` for aptamer inputs
# proteins
N_PROT_VOCABS = 715  # `num_embeddings` of the protein encoder; index N_PROT_VOCABS - 1 is the mask index
N_PROT_TARGET_VOCABS = 585  # `target_dim` of the protein encoder/predictor
PROT_MAX_LEN = 867  # `max_len` for protein inputs

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Data

### Load RNA data for pretraining

In [None]:
"""# (1.) load the RNA dataset for pretraining
rna_dataset = load_hf_dataset(name="bpRNA-shin2023", store=True)

# (2.) Creaye training-test splits of (sequence, secondary structure (ss)) pairs
x_rna_train, x_rna_test, y_rna_train, y_rna_test = train_test_split(
    rna_dataset["SEQUENCE"].tolist(),
    rna_dataset["SS"].tolist(),
    test_size=TEST_SIZE,
    random_state=RAMDOM_STATE,
)

# (3.) augment training data by adding reverse complements
# e.g., (seq="ACG", ss="SHM") -> (seq="GCA", ss="MHS")
x_rna_train, y_rna_train = augment_reverse(x_rna_train, y_rna_train)

# (4.) mask the dataset for pretraining embeddings
train_rna = MaskedDataset(
    x=x_rna_train,
    y=y_rna_train,
    max_len=APTA_MAX_LEN,
    mask_idx=N_APTA_VOCABS - 1,
    is_rna=True,
)
test_rna = MaskedDataset(
    x=x_rna_test,
    y=y_rna_test,
    max_len=APTA_MAX_LEN,
    mask_idx=N_APTA_VOCABS - 1,
    is_rna=True,
)

# (5.) create dataloaders
train_rna_dataloader = DataLoader(
    train_rna,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_rna_dataloader = DataLoader(
    test_rna,
    batch_size=BATCH_SIZE,
    shuffle=True,
)"""

### Load protein data for pretraining

In [None]:
# (1.) load the protein word-frequency table used to tokenize protein sequences
# NOTE(review): the protein pretraining dataset itself is not loaded yet:
# prot_dataset = load_hf_dataset(name="proteins-shin2023", store=True)
prot_word_freq = load_csv_dataset(name="protein_word_freq")  # words ("seq") and their frequencies ("freq")
prot_words = prot_word_freq.set_index("seq")["freq"].to_dict()  # {word: frequency}

# (2.) DISABLED — create training-test splits of (sequence, secondary structure (ss)) pairs.
# Kept as comments: stored as a bare string, the code would be echoed as cell output.
# Needs `prot_dataset` above plus train_test_split, seq2vec, MaskedDataset and
# DataLoader, which are not imported in this notebook.
# x_prot_train, x_prot_test, y_prot_train, y_prot_test = train_test_split(
#     prot_dataset["SEQUENCE"].tolist(),
#     prot_dataset["SS"].tolist(),
#     test_size=TEST_SIZE,
#     random_state=RAMDOM_STATE,
# )

# (3.) transform sequences to a numerical representation (vectors)
# `filter_words` selects the vocabulary used for tokenization; its criterion lives in pyaptamer.
filtered_prot_words = filter_words(prot_words)
# x_prot_train, y_prot_train = seq2vec(
#     sequence_list=(x_prot_train, y_prot_train),
#     words=filtered_prot_words,
#     seq_max_len=PROT_MAX_LEN,
# )
# x_prot_test, y_prot_test = seq2vec(
#     sequence_list=(x_prot_test, y_prot_test),
#     words=filtered_prot_words,
#     seq_max_len=PROT_MAX_LEN,
# )
#
# # (4.) mask the dataset for pretraining embeddings
# # (both splits are protein data, so both use is_rna=False)
# train_prot = MaskedDataset(
#     x=x_prot_train,
#     y=y_prot_train,
#     max_len=PROT_MAX_LEN,
#     mask_idx=N_PROT_VOCABS - 1,
#     is_rna=False,
# )
# test_prot = MaskedDataset(
#     x=x_prot_test,
#     y=y_prot_test,
#     max_len=PROT_MAX_LEN,
#     mask_idx=N_PROT_VOCABS - 1,
#     is_rna=False,
# )
#
# # (5.) create dataloaders (only the training set needs shuffling)
# train_prot_dataloader = DataLoader(
#     train_prot,
#     batch_size=BATCH_SIZE,
#     shuffle=True,
# )
# test_prot_dataloader = DataLoader(
#     test_prot,
#     batch_size=BATCH_SIZE,
#     shuffle=False,
# )

### Load aptamer-protein interaction (API) dataset for fine-tuning

In [None]:
"""# (1.) load the api dataset for fine-tuning
train_dataset = load_csv_dataset(name="train_li2014")
test_dataset = load_csv_dataset(name="test_li2014")

# (2.) create the API dataset
train_dataset = APIDataset(
    x_apta=train_dataset["aptamer"].to_numpy(),
    x_prot=train_dataset["protein"].to_numpy(),
    y=train_dataset["label"].to_numpy(),
    apta_max_len=APTA_MAX_LEN,
    prot_max_len=PROT_MAX_LEN,
    prot_words=filtered_prot_words,
)
test_dataset = APIDataset(
    x_apta=test_dataset["aptamer"].to_numpy(),
    x_prot=test_dataset["protein"].to_numpy(),
    y=test_dataset["label"].to_numpy(),
    apta_max_len=APTA_MAX_LEN,
    prot_max_len=PROT_MAX_LEN,
    prot_words=filtered_prot_words,
    split="test",
)

# (3.) create dataloaders
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)"""

## Model

In [None]:
# Encoder/predictor configurations for the two input branches, built from the
# settings cell: (vocabulary size, target dimension, max sequence length).
apta_embedding, prot_embedding = (
    EncoderPredictorConfig(
        num_embeddings=n_vocabs,
        target_dim=n_targets,
        max_len=seq_max_len,
    )
    for n_vocabs, n_targets, seq_max_len in (
        (N_APTA_VOCABS, N_APTA_TARGET_VOCABS, APTA_MAX_LEN),  # aptamers
        (N_PROT_VOCABS, N_PROT_TARGET_VOCABS, PROT_MAX_LEN),  # proteins
    )
)

In [None]:
# Architecture hyperparameters for AptaTrans.
# NOTE(review): `load_pretrained_weights()` is called right after construction,
# so these values presumably must match the released checkpoint — confirm
# before changing any of them.
aptatrans_hparams = {
    "in_dim": 128,
    "n_encoder_layers": 6,
    "n_heads": 8,
    "conv_layers": [3, 3, 3],
    "dropout": 0.1,
}
model = AptaTrans(
    apta_embedding=apta_embedding,
    prot_embedding=prot_embedding,
    **aptatrans_hparams,
)
# NOTE(review): the model is not moved to `device` here; presumably
# AptaTransPipeline (which receives `device`) handles placement — verify.
model.load_pretrained_weights()

## Pretraining

In [None]:
# TODO: pretrain the aptamer and protein encoders on the masked RNA / protein
# datasets; the data-loading code for them in the "Data" section is currently disabled.

## Train

In [None]:
"""solver = AptaTransSolver(
    model=model,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    device=device,
)
_ = solver.train(epochs=10, show_progress=True)"""

## Recommend

In [None]:
# Target protein (amino-acid sequence) to generate aptamer candidates for;
# replace it to target a different protein.
target_protein = (
    "STEYKLVVVGADGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVVIDGETCLLDILDTAGQEEYSAM"
    "RDQYMRTGEGFLCVFAINNTKSFEDIHHYREQIKRVKDSEDVPMVLVGNKCDLPSRTVDTKQAQDLARSYGIPFIETSAKTRQG"
    "VDDAFYTLVREIRKHKEKMSK"
)

# NOTE(review): the pipeline receives the unfiltered `prot_words`, while the
# API dataset code uses `filtered_prot_words`; confirm the pipeline filters
# internally or whether the filtered vocabulary should be passed here.
pipeline = AptaTransPipeline(
    model=model,
    device=device,
    prot_words=prot_words,
    n_iterations=1000,
    depth=40,  # i.e., how long the candidates will be
)

candidates = pipeline.recommend(
    target=target_protein,
    n_candidates=10,
    verbose=True,
)

In [None]:
# Print each recommended candidate with its score.
# NOTE(review): assumes each candidate is indexable with the sequence at [0] and
# the score at [2] (inferred only from this print format) — TODO confirm.
for rank, cand in enumerate(candidates):
    sequence, score = cand[0], float(cand[2])
    print(f"[Candidate {rank}] {sequence} - Score: {score}")

In [None]:
# NOTE(review): identical to the target sequence assigned in the Recommend
# section; looks like a leftover scratch cell.
target_protein = "".join(
    [
        "STEYKLVVVGADGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVVIDGETCLLDILDTAGQEEYSAM",
        "RDQYMRTGEGFLCVFAINNTKSFEDIHHYREQIKRVKDSEDVPMVLVGNKCDLPSRTVDTKQAQDLARSYGIPFIETSAKTRQG",
        "VDDAFYTLVREIRKHKEKMSK",
    ]
)