# Tutorial

> API details.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
import pytorch_lightning as pl

from retrofit.data import RetroDataset
from retrofit.model import RetroFitModel
from pytorch_lightning.loggers import WandbLogger

In [None]:
def train_model(model, data_module, num_epochs, output_dir, name=None):
    """
    Fit a Lightning model, logging to Weights & Biases and tracking emissions.

    The run is seeded with a fixed seed (115). Checkpoints are written to
    ``output_dir / "checkpoints"``, keeping the five best by ``val_loss``.
    Training stops early when ``val_loss`` stops improving. The emissions
    tracker writes to ``output_dir.parent.parent`` and its ``emissions.csv``
    is uploaded to W&B together with the best checkpoint.

    Args:
        model (pl.LightningModule): The model to train.
        data_module (pl.LightningDataModule): The data module to use for training.
        num_epochs (int): The maximum number of epochs to train for.
        output_dir (pathlib.Path or str): The directory to save the model to.
        name (str): The run name used for W&B and for the emissions project.
    Returns:
        best_model_path (str): The path to the best model's checkpoint, or an
            empty string if no checkpoint was saved.
    """
    from pathlib import Path

    import torch
    from pytorch_lightning.callbacks import (
        EarlyStopping,
        ModelCheckpoint,
        TQDMProgressBar,
    )

    # NOTE(review): `wandb` and `EmissionsTracker` are referenced below but are
    # not imported in the visible part of this file; they must be available in
    # the module namespace for this function to run.

    # Accept plain strings as well as Path objects.
    output_dir = Path(output_dir)
    checkpoint_dir = str(output_dir / "checkpoints")
    emissions_dir = output_dir.parent.parent
    num_gpus = torch.cuda.device_count()

    pl.seed_everything(115, workers=True)
    wandb_logger = WandbLogger(project="Athena", name=name)
    # saves a file like: <output_dir>/checkpoints/athena-epoch=02-val_loss=0.32.ckpt
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=checkpoint_dir,
        filename="athena-{epoch:02d}-{val_loss:.2f}",
        save_top_k=5,
        mode="min",
    )
    trainer = pl.Trainer(
        logger=wandb_logger,
        default_root_dir=checkpoint_dir,
        gpus=num_gpus,
        max_epochs=num_epochs,
        # Only request 16-bit precision when a GPU is visible; fall back to
        # full precision on CPU-only machines.
        precision=16 if num_gpus > 0 else 32,
        callbacks=[
            checkpoint_callback,
            EarlyStopping(monitor="val_loss"),
            TQDMProgressBar(refresh_rate=1),
        ],
    )
    tracker = EmissionsTracker(output_dir=emissions_dir, project_name=name)

    # train the model and track emissions; always stop the tracker so it is
    # not left running if training raises
    tracker.start()
    try:
        trainer.fit(model, data_module)
    finally:
        tracker.stop()

    # save the best model to wandb; `best_model_path` is an empty string (not
    # None) when no checkpoint was written, so test for truthiness
    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        wandb.save(best_model_path)

    # save the emissions csv file
    wandb.save(str(emissions_dir / "emissions.csv"))

    return best_model_path

In [None]:
# export
class RetroFitModel(pl.LightningModule):
    """
    Encoder-decoder model trained to generate ``batch[column]`` from the
    retrieved examples of each batch item.

    The retrieved examples of one item are joined with the encoder tokenizer's
    separator token and fed to the encoder; the text in ``batch[column]`` is
    tokenized with the decoder tokenizer and used as labels.

    Args:
        encoder_name (str): Pretrained name/path for the encoder and its tokenizer.
        decoder_name (str): Pretrained name/path for the decoder and its tokenizer.
        column (str): Batch key holding the target text.
        weight_decay (float): Weight decay for parameters not in the no-decay list.
        lr (float): Learning rate passed to AdamW.
        freeze_decoder (bool): If True, only the encoder's parameters are
            handed to the optimizer.
        lr_scheduler_type (str, optional): Stored on the instance; no scheduler
            is created by ``configure_optimizers`` in this class.
    """

    # NOTE(review): `AutoTokenizer` and `EncoderDecoderModel` are referenced
    # below but are not imported in the visible part of this file (presumably
    # they come from `transformers` -- confirm and import at module level).

    def __init__(
        self,
        encoder_name,
        decoder_name,
        column,
        weight_decay=0.1,
        lr=5e-4,
        freeze_decoder=True,
        lr_scheduler_type=None,
    ):
        # Must run before any submodule is assigned to `self`.
        super().__init__()
        self.encoder_tokenizer = AutoTokenizer.from_pretrained(encoder_name)
        self.decoder_tokenizer = AutoTokenizer.from_pretrained(decoder_name)
        # Batched tokenization pads; tokenizers without a pad token reuse EOS.
        for tokenizer in (self.encoder_tokenizer, self.decoder_tokenizer):
            if tokenizer.pad_token is None and tokenizer.eos_token is not None:
                tokenizer.pad_token = tokenizer.eos_token
        self.column = column
        self.weight_decay = weight_decay
        self.lr = lr
        self.lr_scheduler_type = lr_scheduler_type
        self.freeze_decoder = freeze_decoder
        self.model = EncoderDecoderModel.from_encoder_decoder_pretrained(encoder_name, decoder_name)

        # The encoder-decoder wrapper derives decoder inputs from the labels,
        # which needs pad / decoder-start ids on the config. Only fill in
        # values that are missing.
        config = self.model.config
        if config.pad_token_id is None:
            config.pad_token_id = self.decoder_tokenizer.pad_token_id
        if config.decoder_start_token_id is None:
            start_id = self.decoder_tokenizer.bos_token_id
            if start_id is None:
                start_id = self.decoder_tokenizer.cls_token_id
            config.decoder_start_token_id = start_id

    def training_step(self, batch, batch_idx):
        """
        Compute and log the training loss for one batch.

        Args:
            batch (dict): Must contain ``"retrieved_examples"`` (one iterable of
                strings per item) and ``self.column`` (the target texts).
            batch_idx (int): Index of the batch (unused).
        Returns:
            The loss produced by the underlying encoder-decoder model.
        Raises:
            ValueError: If the encoder tokenizer has neither an EOS nor a SEP
                token to join the retrieved examples with.
        """
        separator = self.encoder_tokenizer.eos_token or self.encoder_tokenizer.sep_token
        if separator is None:
            raise ValueError(
                "encoder tokenizer defines neither eos_token nor sep_token; "
                "cannot join retrieved examples"
            )
        retrieved = [separator.join(r) for r in batch["retrieved_examples"]]

        # Pad/truncate so differently sized examples stack into one tensor, and
        # move the tensors to the device the module lives on.
        encoded = self.encoder_tokenizer(
            retrieved, return_tensors="pt", padding=True, truncation=True
        )
        input_ids = encoded.input_ids.to(self.device)
        attention_mask = encoded.attention_mask.to(self.device)

        labels = self.decoder_tokenizer(
            batch[self.column], return_tensors="pt", padding=True, truncation=True
        ).input_ids.to(self.device)
        # Padding positions must not contribute to the loss (-100 is ignored).
        pad_id = self.decoder_tokenizer.pad_token_id
        if pad_id is not None:
            labels = labels.masked_fill(labels == pad_id, -100)

        loss = self.model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        ).loss

        self.log("trn_loss", loss, on_step=True, on_epoch=True, logger=True)
        return loss

    def get_grouped_params(self, model, no_decay=("bias", "LayerNorm.weight")):
        """
        Split ``model``'s parameters into weight-decay and no-weight-decay groups.

        Args:
            model: Any object exposing ``named_parameters()``.
            no_decay (iterable of str): A parameter whose name contains any of
                these substrings gets ``weight_decay=0.0``.
        Returns:
            list[dict]: Two optimizer parameter groups; the first uses
            ``self.weight_decay``, the second uses ``0.0``.
        """
        params_with_wd, params_without_wd = [], []
        for n, p in model.named_parameters():
            if any(nd in n for nd in no_decay):
                params_without_wd.append(p)
            else:
                params_with_wd.append(p)
        return [
            {"params": params_with_wd, "weight_decay": self.weight_decay},
            {"params": params_without_wd, "weight_decay": 0.0},
        ]

    def configure_optimizers(self):
        """
        Build the AdamW optimizer.

        When ``self.freeze_decoder`` is set, only the encoder's parameters are
        passed to the optimizer; otherwise the whole model is optimized.
        """
        from torch.optim import AdamW

        param_model = self.model.encoder if self.freeze_decoder else self.model
        return AdamW(self.get_grouped_params(param_model), lr=self.lr)

In [None]:
# hide
from nbdev.export import notebook2script

# Run nbdev's notebook-to-module export so that cells flagged for export
# (such as the "# export" cell in this notebook) are written out as library code.
notebook2script()