In [None]:
import pickle

import numpy as np
import torch

from pyaptamer.aptatrans.model import AptaTrans
from pyaptamer.aptatrans.pipeline import AptaTransPipeline
from pyaptamer.aptatrans.layers.encoder import EmbeddingConfig

# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules (e.g. the local pyaptamer package) are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Run on a CUDA GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Preliminaries

Let's initialize the constants needed for both the aptamer and the protein embeddings. These are taken directly from the authors' repositories, as they depend on the training data used to pretrain the transformer-based encoders.

In [None]:
# Embedding settings for the aptamer side. The numbers are fixed constants
# copied from the authors' repositories, not values to tune here.
apta_embedding = EmbeddingConfig(n_vocabs=127, n_target_vocabs=344, max_len=275)

# The same three settings for the protein side.
prot_embedding = EmbeddingConfig(n_vocabs=715, n_target_vocabs=585, max_len=867)

Additionally, we need a dictionary that stores the most frequent (above-average) protein 3-mer subsequences with respect to the training data. It is required to encode the target protein during aptamer generation, so that the protein embeddings can be computed. For simplicity, let's just use the protein frequencies stored in the `.pickle` file provided by the authors.

In [None]:
# Load the 3-mer frequency table shipped with the package.
# NOTE: pickle.load can execute arbitrary code while deserializing, so this
# file must come from a trusted source (here: the project's own data folder).
with open('../pyaptamer/data/protein_word_freq.pickle', 'rb') as inf:
    df = pickle.load(inf)

# Build the lookup {value of 'seq' column -> value of 'freq' column}. This is
# done after the file is closed, since only the in-memory table is needed.
words = df.set_index('seq')['freq'].to_dict()

# AptaTrans

Now, let's initialize an instance of AptaTrans (the neural network) to be used within the pipeline that leverages Apta-MCTS for aptamer recommendations. For simplicity, we keep the default parameters (e.g., the number of layers and the layers' dimensions), and we do not train the neural network from scratch, as training is not supported yet.

In [None]:
# Build the network from the two embedding configs only; every other
# constructor argument is left at its default value.
aptatrans = AptaTrans(apta_embedding=apta_embedding, prot_embedding=prot_embedding)

# Wrap the model in the recommendation pipeline, handing it the compute
# device and the protein word-frequency dictionary.
pipeline = AptaTransPipeline(device=device, model=aptatrans, prot_words=words)

In [None]:
# Sequence of the target protein for which aptamer candidates are requested.
target = 'STEYKLVVVGADGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVVIDGETCLLDILDTAGQEEYSAMRDQYMRTGEGFLCVFAINNTKSFEDIHHYREQIKRVKDSEDVPMVLVGNKCDLPSRTVDTKQAQDLARSYGIPFIETSAKTRQGVDDAFYTLVREIRKHKEKMSK'

# Search settings forwarded as keyword arguments to the recommender.
search_settings = {'n_candidates': 2, 'depth': 5, 'n_iterations': 1}
candidates = pipeline.recommend(target=target, **search_settings)