### Basic Settings

In [None]:
import os

import torch
import yaml

from sssd.core.model_specs import MODEL_PATH_FORMAT, setup_model
from sssd.data.utils import get_dataloader
from sssd.training.trainer import DiffusionTrainer
from sssd.utils.logger import setup_logger
from sssd.utils.utils import calc_diffusion_hyperparams

# Later cells use relative paths ("configs/...", output directories). When the
# current directory has no "results" entry, step two levels up
# (presumably to the project root -- confirm against the repository layout).
has_results_entry = "results" in os.listdir(os.getcwd())
if not has_results_entry:
    os.chdir("../../")

### Setup Device

In [None]:
# Report how many CUDA devices are visible (if any), then pick the device
# that the dataloader, hyperparameters and model are placed on below.
visible_gpu_count = torch.cuda.device_count()
if visible_gpu_count > 0:
    print(f"Using {visible_gpu_count} GPUs!")

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### Import Configs

In [None]:
# Read each YAML file fully, close it, then parse the text with the safe
# loader. `model_config` and `training_config` are used by the later cells.
with open("configs/model.yaml", "rt") as model_file:
    model_yaml_text = model_file.read()
model_config = yaml.safe_load(model_yaml_text)

with open("configs/training.yaml", "rt") as training_file:
    training_yaml_text = training_file.read()
training_config = yaml.safe_load(training_yaml_text)

### Setup Result Directory

In [None]:
# Build the run-specific sub-path by filling MODEL_PATH_FORMAT with the
# diffusion settings (T, beta_0, beta_T), then place it under the output
# root named in the training config.
diffusion_settings = model_config["diffusion"]
local_path = MODEL_PATH_FORMAT.format(
    T=diffusion_settings["T"],
    beta_0=diffusion_settings["beta_0"],
    beta_T=diffusion_settings["beta_T"],
)
output_root = training_config["output_directory"]
output_directory = os.path.join(output_root, local_path)

# Create the output directory the first time this configuration is used.
# Permissions are set with an explicit chmod because the mode applied by
# makedirs is filtered through the process umask.
if not os.path.isdir(output_directory):
    # exist_ok=True tolerates another process creating the directory between
    # the isdir() check above and this call.
    os.makedirs(output_directory, exist_ok=True)
    os.chmod(output_directory, 0o775)
# The original passed logger-style "%s" arguments to print(), which emitted
# the literal "%s" followed by the path; interpolate the path instead.
print(f"Output directory {output_directory}")

### Define Model Essentials

In [None]:
# Training data loader, built from the configured training-set path and
# batch size (None when the config has no "batch_size" key).
train_path = training_config["data"]["train_path"]
train_batch_size = training_config.get("batch_size")
dataloader = get_dataloader(train_path, batch_size=train_batch_size, device=device)

# Diffusion hyperparameters: every entry of the "diffusion" config section is
# forwarded as a keyword argument, together with the target device.
diffusion_settings = model_config["diffusion"]
diffusion_hyperparams = calc_diffusion_hyperparams(**diffusion_settings, device=device)

# Network selected by the "use_model" entry of the training config.
selected_model = training_config["use_model"]
net = setup_model(selected_model, model_config, device)

### Model Training

In [None]:
# Options passed straight from the training config to the trainer under the
# same keyword names; .get() yields None for any key the config omits.
trainer_option_names = (
    "ckpt_iter",
    "n_iters",
    "iters_per_ckpt",
    "iters_per_logging",
    "learning_rate",
    "only_generate_missing",
    "masking",
    "missing_k",
    "batch_size",
)
trainer_options = {name: training_config.get(name) for name in trainer_option_names}

# Assemble the trainer from the objects built in the previous cells plus the
# config-driven options, then start training.
trainer = DiffusionTrainer(
    dataloader=dataloader,
    diffusion_hyperparams=diffusion_hyperparams,
    net=net,
    device=device,
    output_directory=output_directory,
    **trainer_options,
    logger=setup_logger(),
)
trainer.train()