In [None]:
from training.src.aasp.data_handler.data_handler import AASPConfig, AASPDataHandler
from training.src.aasp.data_handler.data_handler import AASPDataset

# Location of the YAML config file; reused further down when the datasets are constructed.
config_path = "training/src/aasp/data_handler/config.yaml"
# Configuration object; later cells read `file_path`, `parameters`, `hyperparameters`,
# `fuse_mode` and `embed_metric` from it.
cfg = AASPConfig(config_path)
# Data handler built from the config; later cells call `load_pickle` and `fit_vocab` on it.
handler = AASPDataHandler(cfg)

In [None]:
# Only these feature names may be forwarded to the datasets (as `fields`);
# any other key present in the config's `features` mapping is ignored.
_ALLOWED_FEATURES = {
    "ref_embedding",
    "alt_embedding",
    "biotype",
    "consequence",
    "ref_long",
    "alt_long",
    "scoreset",
}

# A missing `features` attribute is treated as "no features enabled".
_feature_flags = getattr(cfg, "features", {})
selected_features = [
    name
    for name, enabled in _feature_flags.items()
    if enabled and name in _ALLOWED_FEATURES
]

# Maps each categorical field to its encoding type (compared against
# "embedding" / "multi_hot" when the models are built); empty if not configured.
cat_config = getattr(cfg, "categorical_config", {})


In [None]:
# Load every record, then hold out the trailing `val_frac` share for validation.
# The split is positional: no shuffling is applied here.
records = handler.load_pickle(cfg.file_path)
val_frac = cfg.parameters.get("val_frac", 0.15)
val_size = int(len(records) * val_frac)

# Everything before `split_at` trains; everything from it onwards validates.
split_at = len(records) - val_size
train_records, val_records = records[:split_at], records[split_at:]


In [None]:
# Fit one vocabulary per categorical field, using the training records only.
vocabs = {}
for field in cat_config:
    vocabs[field] = handler.fit_vocab(train_records, field)


In [None]:
from torch.utils.data import DataLoader

# Both splits are built with identical constructor arguments; only their records differ.
dataset_kwargs = dict(
    config_path=config_path,
    fields=selected_features,
    fuse_mode=cfg.fuse_mode,
    embed_metric=cfg.embed_metric,
    categorical_config=cat_config,
)
train_dataset = AASPDataset(**dataset_kwargs)
val_dataset = AASPDataset(**dataset_kwargs)

# Point each dataset at its share of the train/validation split.
train_dataset.records = train_records
val_dataset.records = val_records

# Training batches are shuffled; validation batches keep their order.
train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.hyperparameters["train_batch_size"],
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=cfg.hyperparameters["val_batch_size"],
    shuffle=False,
)


Actual (Baseline) Model Training Example
========================================

In [None]:
from training.src.aasp.model.models import BaselineModel
from training.src.aasp.model.trainer import Trainer
import torch

# Read the hyperparameter mapping once so the rest of the cell stays short.
hp = cfg.hyperparameters
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Size the categorical inputs according to how each field is encoded:
#   "embedding" -> (vocabulary size, 4)
#   "multi_hot" -> vocabulary size
# Fields with any other encoding type are skipped.
cat_dims = {}
multi_hot_dims = {}
for field, encoding in cat_config.items():
    if encoding == "embedding":
        cat_dims[field] = (len(vocabs[field]), 4)
    elif encoding == "multi_hot":
        multi_hot_dims[field] = len(vocabs[field])

model = BaselineModel(
    input_dim=1,  # presumably a single "distance" value per record — TODO confirm
    cat_dims=cat_dims,
    multi_hot_dims=multi_hot_dims,
    hidden_dims=tuple(hp["hidden_dims"]),
    dropout_rates=tuple(hp["dropout_rates"]),
)

optimizer = torch.optim.Adam(model.parameters(), lr=hp["learning_rate"])
loss_fn = torch.nn.MSELoss()

# Bundle model, optimisation settings and data loaders; nothing is run in this cell.
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=hp["num_epochs"],
    save_path="output/baseline_model_best.pth",
    device=device,
)


Dumb Model Training Example
================================

In [None]:
from training.src.aasp.model.models import DumbModel
from training.src.aasp.model.trainer import Trainer

# NOTE: this cell rebinds `model`, `optimizer`, `loss_fn` and `trainer`. Whichever
# model cell was executed last decides what the final `trainer.run()` cell acts on.

# Size the categorical inputs according to how each field is encoded:
#   "embedding" -> (vocabulary size, 4); "multi_hot" -> vocabulary size.
cat_dims = {k: (len(vocabs[k]), 4) for k, typ in cat_config.items() if typ == "embedding"}
multi_hot_dims = {k: len(vocabs[k]) for k, typ in cat_config.items() if typ == "multi_hot"}

model = DumbModel(
    input_dim=1,  # presumably a single "distance" value per record — TODO confirm
    cat_dims=cat_dims,
    multi_hot_dims=multi_hot_dims,
    hidden_dims=tuple(cfg.hyperparameters["hidden_dims"]),
    dropout_rates=tuple(cfg.hyperparameters["dropout_rates"])
)

import torch
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.hyperparameters["learning_rate"])
loss_fn = torch.nn.MSELoss()

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=cfg.hyperparameters["num_epochs"],
    # Use a checkpoint path of its own: the previous value was copied from the
    # baseline cell, so a DumbModel checkpoint would have been written under the
    # baseline model's file name.
    save_path="output/dumb_model_best.pth",
    # Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
    device="cuda" if torch.cuda.is_available() else "cpu"
)


In [None]:
# Calls run() on whichever `trainer` was bound most recently: both model cells assign
# to the same `trainer` name, so the last one executed is the one used here.
trainer.run()