# Tutorial

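> End-to-end example: build a `RetroDataset` from CodeSearchNet (Python), train a `RetroFitModel` with PyTorch Lightning, and reload the best checkpoint.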

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
import torch

import pytorch_lightning as pl

# from codecarbon import EmissionsTracker
from pathlib import Path
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import WandbLogger
from retrofit.data import RetroDataset
from retrofit.model import RetroFitModel

In [None]:
# The same two checkpoint names are passed to both the dataset and the model,
# so bind each once and reuse it to keep the two constructors in sync.
DISTILBERT_CKPT = "distilbert-base-uncased"
GPT2_CKPT = "gpt2"

# Name handed to both RetroDataset and RetroFitModel
# (presumably the dataset field holding the function source; confirm in retrofit.data).
column = "whole_func_string"

retro_ds = RetroDataset(
    "code_search_net",
    "flax-sentence-embeddings/st-codesearch-distilroberta-base",
    DISTILBERT_CKPT,
    GPT2_CKPT,
    column=column,
    dataset_config="python",
    k=2,
    n_perc=1,
    batch_size=2,
)
model = RetroFitModel(
    DISTILBERT_CKPT,
    GPT2_CKPT,
    column,
    freeze_decoder=True,
    lr=5e-4,
    weight_decay=0.1,
)

In [None]:
def train_model(model, data_module, num_epochs, output_dir, name=None):
    """
    Train a model with PyTorch Lightning, logging to Weights & Biases and
    checkpointing on the validation loss.

    Checkpoints are written to ``output_dir / "checkpoints"`` and the five
    with the lowest ``val_loss`` are kept. An ``EarlyStopping`` callback
    (library defaults) monitors the same metric, so training may end before
    ``num_epochs`` is reached.

    Args:
        model (pl.LightningModule): The model to train. It must log a
            ``val_loss`` metric, which both callbacks monitor.
        data_module (pl.LightningDataModule): The data module to use for training.
        num_epochs (int): The maximum number of epochs to train for.
        output_dir (pathlib.Path or str): The directory to save the model to.
        name (str, optional): The run name passed to the W&B logger.
    Returns:
        best_model_path (str): The path to the best model's checkpoint.
    """
    # Accept plain strings as well as Path objects, and build the checkpoint
    # directory once so the callback and the trainer cannot drift apart.
    checkpoint_dir = str(Path(output_dir) / "checkpoints")

    wandb_logger = WandbLogger(project="Athena", name=name)

    # Saves files like: <checkpoint_dir>/retrofit-epoch=02-val_loss=0.32.ckpt
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=checkpoint_dir,
        filename="retrofit-{epoch:02d}-{val_loss:.2f}",
        save_top_k=5,
        mode="min",
    )
    trainer = pl.Trainer(
        logger=wandb_logger,
        default_root_dir=checkpoint_dir,
        # Use every visible CUDA device; 0 falls back to CPU.
        # NOTE(review): `gpus` was removed in Lightning 2.0 — switch to
        # `accelerator`/`devices` when upgrading.
        gpus=torch.cuda.device_count(),
        max_epochs=num_epochs,
        callbacks=[
            checkpoint_callback,
            EarlyStopping(monitor="val_loss"),
            TQDMProgressBar(refresh_rate=1),
        ],
    )

    # NOTE: seeding, mixed precision, codecarbon emissions tracking and
    # uploading the best checkpoint to W&B are intentionally not enabled here.
    trainer.fit(model, data_module)

    return checkpoint_callback.best_model_path

In [None]:
# Short run: train for at most two epochs, then reload the best checkpoint.
num_epochs = 2
out_dir = Path("/workspace/retrofit/data/output/")

# train_model returns the path of the best checkpoint tracked by its
# ModelCheckpoint callback.
best_model_path = train_model(
    model=model,
    data_module=retro_ds,
    num_epochs=num_epochs,
    output_dir=out_dir.joinpath("model"),
    name="test",
)

# Rebind `model` to the weights restored from that checkpoint.
model = RetroFitModel.load_from_checkpoint(best_model_path)