In [None]:
# |default_exp training
# |default_cls_lvl 3

In [None]:
# |hide
# Automatically reload imported modules before running code, so edits to library code are picked up without a kernel restart.
%reload_ext autoreload
%autoreload 2

# training

Classes and methods that provide a uniform way to train, save, and tune our topic segmentation and summarization models.

In [None]:
# |export
from __future__ import annotations

import abc, inspect, os
from pathlib import Path
import random

from dotenv import load_dotenv
import optuna
import wandb

from course_copilot import utils

In [None]:
# | hide
import pdb

from fastcore.test import *
import nbdev

from blurr.utils import print_versions

In [None]:
# | echo: false
# Turn off HuggingFace `tokenizers` parallelism for this process.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Record the library versions this documentation was rendered with.
print("What we're running with at the time this documentation was generated:")
print_versions(" ".join(("torch", "fastai", "transformers", "blurr")))

What we're running with at the time this documentation was generated:
torch: 1.12.1+cu102
fastai: 2.7.9
transformers: 4.22.0
blurr: 1.0.5


## Configuration

In [None]:
# | export
# Populate os.environ from a local `.env` file; ModelTrainer reads WANDB_PROJECT_NAME and WANDB_TEAM from the environment when use_wandb=True.
load_dotenv()


class TrainConfig:
    """Default training settings consumed by `ModelTrainer` and its subclasses.

    Settings are plain class attributes, so the class itself (or a subclass that
    overrides individual attributes) can be handed around as the config object.
    Every attribute whose name does not start with a double underscore is what
    `get_train_config_props` reports (and what gets logged as the W&B run config).
    """

    # --- reproducibility / data splitting ---
    random_seed = 2022
    val_pct = 0.25  # NOTE(review): presumably the validation fraction; confirm in subclasses
    only_seed_splits = True  # NOTE(review): semantics defined by subclasses -- not visible here

    # --- data selection / preprocessing ---
    training_subset = 1  # NOTE(review): presumably how much of the training data to use; confirm
    preprocess_strategy = None  # None => no strategy selected; subclasses decide what this means

In [None]:
# | export
def get_train_config_props(cfg: TrainConfig) -> dict:
    """Flatten the public settings of a training config into a loggable dict.

    Every member of `cfg` (a `TrainConfig` class, subclass, or instance) whose name
    does not start with a double underscore is included. Callable values (e.g. a
    preprocessing function or a method) are replaced by their `__name__`. Callables that
    have no `__name__` (such as `functools.partial` objects or callable instances) fall
    back to their `repr` instead of raising `AttributeError`.

    Args:
        cfg: the configuration object to inspect.

    Returns:
        dict mapping attribute name -> value (or callable name/repr), in name order
        (the order `inspect.getmembers` yields).
    """

    def _loggable(value):
        if not callable(value):
            return value
        # Not every callable carries __name__; don't let config logging crash over it.
        return getattr(value, "__name__", repr(value))

    return {name: _loggable(value) for name, value in inspect.getmembers(cfg) if not name.startswith("__")}

In [None]:
# Preview the flat dict of config values that ModelTrainer passes to W&B as the run config.
get_train_config_props(TrainConfig)

{'only_seed_splits': True,
 'preprocess_strategy': None,
 'random_seed': 2022,
 'training_subset': 1,
 'val_pct': 0.25}

## Model Trainer

In [None]:
# | export
class ModelTrainer(abc.ABC):
    """Abstract base class giving every model trainer a common setup / train / tune interface.

    The constructor resolves the data/output paths, creates the model and log output
    directories, and, when `use_wandb` is True, logs in to Weights & Biases and starts a
    run grouped under `experiment_name` whose config is the flattened `train_config`.

    Subclasses must implement `get_training_data`. `get_preds`, `tune`, and
    `load_learner_or_model` raise `NotImplementedError` here and are meant to be
    overridden. The base `train` only closes the W&B run.

    Args:
        experiment_name: experiment identifier; used as the W&B run group.
        train_config: training configuration (a `TrainConfig` class, subclass, or instance).
        data_path: directory holding the input data (converted to a Path, not created).
        model_output_path: directory for saved models; created if missing.
        log_output_path: directory for logs and W&B run files; created if missing.
        log_preds: stored on the instance. NOTE(review): presumably read by subclasses.
        log_n_preds: stored on the instance. NOTE(review): presumably read by subclasses.
        use_wandb: start a W&B run here and finish it in `train`.
        verbose: when False, the W&B run is finished quietly.
        **kwargs: accepted and ignored by the base class.

    Raises:
        KeyError: if `use_wandb` is True and WANDB_PROJECT_NAME or WANDB_TEAM is not set
            in the environment.
    """

    def __init__(
        self,
        experiment_name,
        train_config: TrainConfig,
        data_path="data",
        model_output_path="models",
        log_output_path="logs",
        log_preds=False,
        log_n_preds=None,
        use_wandb=False,
        verbose=False,
        **kwargs,
    ):
        self.experiment_name = experiment_name
        self.train_config = train_config

        self.data_path = Path(data_path)

        self.model_output_path = Path(model_output_path)
        self.model_output_path.mkdir(parents=True, exist_ok=True)

        self.log_output_path = Path(log_output_path)
        self.log_output_path.mkdir(parents=True, exist_ok=True)
        self.log_preds = log_preds
        self.log_n_preds = log_n_preds
        self.use_wandb = use_wandb

        self.verbose = verbose

        if self.use_wandb:
            # Fail fast with an actionable message *before* wandb.login(), which may prompt interactively.
            missing = [var for var in ("WANDB_PROJECT_NAME", "WANDB_TEAM") if var not in os.environ]
            if missing:
                raise KeyError(
                    f"use_wandb=True requires these environment variables (e.g. in a .env file): {', '.join(missing)}"
                )
            wandb.login()
            wandb.init(
                project=os.environ["WANDB_PROJECT_NAME"],
                entity=os.environ["WANDB_TEAM"],
                group=self.experiment_name,
                config=get_train_config_props(self.train_config),
                dir=self.log_output_path,
                reinit=True,
            )

    @abc.abstractmethod
    def get_training_data(self, on_the_fly=False, split_type="cross_validation"):
        """Build the data this trainer trains on; must be implemented by subclasses.

        NOTE(review): the meaning of `on_the_fly` and the supported `split_type` values
        are defined by subclasses and are not visible here.
        """

    def train(self, trial: optuna.Trial | None = None):
        """Base training hook: only closes the W&B run if one was started.

        Args:
            trial: optional Optuna trial. NOTE(review): presumably supplied by a subclass's
                `tune`; confirm.
        """
        if self.use_wandb:
            # Only show wandb's own shutdown output when running verbosely.
            wandb.finish(quiet=not self.verbose)

    def get_preds(self, model_or_learner, data, **kwargs):
        """Return predictions of `model_or_learner` on `data`; not implemented in the base class."""
        raise NotImplementedError()

    def tune(self):
        """Run hyperparameter tuning; not implemented in the base class."""
        raise NotImplementedError()

    def load_learner_or_model(self, model_learner_fpath: str | Path | None = None, device="cpu", mode="eval"):
        """Load a saved learner/model; not implemented in the base class.

        NOTE(review): `device`/`mode` semantics are subclass-defined (presumably a torch
        device string and train/eval mode).
        """
        raise NotImplementedError()

    def get_train_config_props(self):
        """Return this trainer's `train_config` flattened into a loggable dict."""
        # Inside a method body this name resolves to the module-level function, not to this method.
        return get_train_config_props(self.train_config)

## Export -

In [None]:
# | hide
# Regenerate the library's Python modules from the notebooks' `# | export` cells (nbdev).
nbdev.nbdev_export()