In [None]:
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from pyaptamer.datasets import load_csv_dataset, load_hf_dataset
from pyaptamer.datasets.dataclasses import MaskedDataset
from pyaptamer.utils.preprocessing import augment_reverse

# Auto-reload imported modules (e.g. local edits to pyaptamer) before each cell
# runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Settings

In [None]:
BATCH_SIZE = 64  # batch size for every DataLoader in this notebook
TEST_SIZE = 0.05  # fraction of each dataset held out for testing
RANDOM_STATE = 42  # seed for the train/test split and torch's RNG (reproducibility)
# Backward-compatible alias for the original misspelled name, still used by later cells.
RAMDOM_STATE = RANDOM_STATE

# Seed torch's global RNG so DataLoader shuffling is reproducible across runs;
# `random_state` in train_test_split alone does not cover it.
torch.manual_seed(RANDOM_STATE)

# RNA embeddings (pretraining)
RNA_MAX_LEN = 275  # passed as `max_len` to MaskedDataset

## Data

### Load RNA data for pretraining

In [None]:
# (1.) load the RNA dataset for pretraining
rna_dataset = load_hf_dataset(name="bpRNA-shin2023", store=True)

# (2.) Create training-test splits of (sequence, secondary structure (ss)) pairs
x_rna_train, x_rna_test, y_rna_train, y_rna_test = train_test_split(
    rna_dataset["SEQUENCE"].tolist(),
    rna_dataset["SS"].tolist(),
    test_size=TEST_SIZE,
    random_state=RAMDOM_STATE,
)

# (3.) augment the training data (only) with reversed sequence/structure pairs;
# the test split is not augmented, so evaluation uses the data as loaded.
# e.g., (seq="ACG", ss="SHM") -> (seq="GCA", ss="MHS")
x_rna_train, y_rna_train = augment_reverse(x_rna_train, y_rna_train)
# Sequences and structures must stay paired one-to-one after augmentation.
assert len(x_rna_train) == len(y_rna_train), "sequence/structure count mismatch"

# (4.) wrap the splits in MaskedDataset for pretraining embeddings
# NOTE(review): the meaning of mask_idx=0 / is_rna=True is defined by MaskedDataset — confirm.
train_rna = MaskedDataset(
    x=x_rna_train,
    y=y_rna_train,
    mask_idx=0,
    max_len=RNA_MAX_LEN,
    is_rna=True,
)
test_rna = MaskedDataset(
    x=x_rna_test,
    y=y_rna_test,
    mask_idx=0,
    max_len=RNA_MAX_LEN,
    is_rna=True,
)

# (5.) create dataloaders; only the training loader is shuffled, so the test
# set is always iterated in the same, fixed order during evaluation.
train_rna_dataloader = DataLoader(
    train_rna,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_rna_dataloader = DataLoader(
    test_rna,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

### Load protein data for pretraining

In [None]:
# (1.) load the proteins' dataset for pretraining
prot_dataset = load_hf_dataset(name="proteins-shin2023", store=True)

# (2.) Create training-test splits of (sequence, secondary structure (ss)) pairs
x_prot_train, x_prot_test, y_prot_train, y_prot_test = train_test_split(
    prot_dataset["SEQUENCE"].tolist(),
    prot_dataset["SS"].tolist(),
    test_size=TEST_SIZE,
    random_state=RAMDOM_STATE,
)

# TODO (3.): transform the protein sequences into a numerical representation
#   (vectors), producing `train_prot` / `test_prot` datasets.
# TODO (4.): wrap `train_prot` / `test_prot` in DataLoaders with
#   batch_size=BATCH_SIZE, shuffling the training loader only.
# (Previously a bare string literal held this draft code; as the cell's last
#  expression it was echoed as output and referenced undefined names.)
