# SCM Training
Here, we train a small concept model (SCM) for next-concept prediction. We use the pre-trained PreNet to invert the predicted embeddings back into sentences.

In [1]:
import torch
from modules.inverter import build_inverter, get_encoder
from modules.data import SCMTrainingDataset, get_bookcorpus_for_scm
from modules.scm import SmallConceptModel, GenerativeSCM
from modules.train import train_scm
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Configs
We define configs for model and training hyperparameters.

In [2]:
models_config = {
    "encoder_id": "BAAI/bge-m3",
}

scm_configs = {
    "d_model": 768,
    "embed_dim": 1024,
    "nhead": 8,
    "num_layers": 6,
    "dim_feedforward": 4 * 512,
    "dropout": 0.1,
    "max_seq_len": 16,
}

train_configs = {
    "load_weights": None,
    "save_weights": "saved_models/scm_v01_BGEM3.pth",
    "lr": 1e-3,
    "weight_decay": 1e-2,
    "max_target_len": 64,
    "embed_batch_size": 32,
    "train_batch_size": 32,
    "sample_data": 0.03,
    "num_epochs": 1,
}

data_configs = {
    "load_cached": None,
}
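
Before building the model, it can help to sanity-check the architecture values. The sketch below assumes `SmallConceptModel` uses standard multi-head attention, which requires `d_model` to be divisible by `nhead`. `check_scm_config` is a hypothetical helper, not part of `modules`.

```python
def check_scm_config(cfg: dict) -> dict:
    """Return derived quantities for a transformer-style config (hypothetical helper)."""
    # Multi-head attention splits d_model evenly across heads.
    if cfg["d_model"] % cfg["nhead"] != 0:
        raise ValueError("d_model must be divisible by nhead")
    return {
        "head_dim": cfg["d_model"] // cfg["nhead"],
        "ffn_ratio": cfg["dim_feedforward"] / cfg["d_model"],
    }


# Same values as scm_configs above, copied so the notebook state is untouched.
example_cfg = {"d_model": 768, "nhead": 8, "dim_feedforward": 4 * 512}
derived = check_scm_config(example_cfg)
print(derived)
```

With these values the feed-forward width is 2048, roughly 2.67 times `d_model` rather than the common factor of 4. Whether that is intended is not stated in this notebook.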

---

## Models
We initialize and load the encoder, inverter, and SCM models.

In [3]:
encoder = get_encoder(models_config["encoder_id"])
inverter = build_inverter()
scm = SmallConceptModel(**scm_configs).to(device)

if train_configs["load_weights"]:
    scm.load_state_dict(torch.load(train_configs["load_weights"], map_location=device))


## Dataset
We load cached embeddings if available, otherwise we rebuild the dataloader via the `get_bookcorpus_for_scm` function.

In [None]:
if data_configs["load_cached"]:
    embeddings = torch.load(data_configs["load_cached"])
    dataset = SCMTrainingDataset(embeddings)
    dataloader = DataLoader(
        dataset,
        batch_size=train_configs["train_batch_size"],
        shuffle=True,
        drop_last=True
    )
else:
    dataloader = get_bookcorpus_for_scm(
        encoder,
        train_configs["train_batch_size"],
        train_configs["embed_batch_size"],
    )
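
For intuition, next-concept prediction pairs each position in a sequence of sentence embeddings with the embedding that follows it. The sketch below shows one way such pairs could be built. It is an assumption about the general technique, not the actual logic inside `SCMTrainingDataset` or `get_bookcorpus_for_scm`, and `make_next_concept_pairs` is a hypothetical name. Plain strings stand in for embedding vectors.

```python
def make_next_concept_pairs(concepts, max_seq_len):
    """Split a sequence into windows and pair each window with its one-step shift."""
    pairs = []
    # Each window needs max_seq_len inputs plus one extra item for the last target.
    for start in range(0, len(concepts) - max_seq_len, max_seq_len):
        window = concepts[start:start + max_seq_len + 1]
        inputs, targets = window[:-1], window[1:]
        pairs.append((inputs, targets))
    return pairs


# Toy "document" of nine sentences, windowed with a short sequence length.
toy = [f"s{i}" for i in range(9)]
pairs = make_next_concept_pairs(toy, max_seq_len=4)
for inputs, targets in pairs:
    print(inputs, "->", targets)
```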

## Training
Finally, we can train the model using the `train_scm` function.

In [None]:
train_scm(
    scm,
    dataloader,
    scm_configs["max_seq_len"],
    train_configs["lr"],
    train_configs["weight_decay"],
    train_configs["num_epochs"]
)
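
The body of `train_scm` is not shown in this notebook. As a rough illustration of what a next-embedding training step can look like, the sketch below assumes an AdamW optimizer (suggested by the `lr` and `weight_decay` settings, but not confirmed) and a mean-squared-error loss between predicted and target embeddings (also an assumption). A single `nn.Linear` layer stands in for the SCM.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in for the SCM: any module mapping (batch, seq, dim) -> (batch, seq, dim).
model = nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
loss_fn = nn.MSELoss()

# Toy batch: 4 sequences of 9 embeddings, each of dimension 16.
batch = torch.randn(4, 9, 16)
inputs, targets = batch[:, :-1], batch[:, 1:]  # target at position t is embedding t+1

losses = []
for _ in range(5):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses)
```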

Save the updated weights of the model.

In [None]:
torch.save(scm.state_dict(), train_configs["save_weights"])
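
`torch.save` does not create missing directories, so the call above fails if `saved_models/` does not exist yet. A minimal sketch of a guard follows; `prepare_save_path` is a hypothetical helper, not part of `modules`.

```python
from pathlib import Path


def prepare_save_path(path_str: str) -> Path:
    """Create the parent directory of a checkpoint path if it is missing."""
    path = Path(path_str)
    path.parent.mkdir(parents=True, exist_ok=True)
    return path
```

It could then be used as `torch.save(scm.state_dict(), prepare_save_path(train_configs["save_weights"]))`.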

## Inference
We can test the model at inference time using the `GenerativeSCM` class.

In [None]:
gen_scm = GenerativeSCM(scm, encoder, inverter)

sentences = [
    "this is a phd research proposal .",
    "it is about neural networks ."
]

gen_scm.generate(sentences)
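
The internals of `GenerativeSCM.generate` are not shown here. A common pattern for this kind of model is an autoregressive rollout in embedding space: predict the next concept vector, append it to the context, and repeat. The sketch below illustrates that pattern under this assumption. `rollout` and `mean_predictor` are hypothetical names, and a simple averaging function stands in for the trained model.

```python
def rollout(predict_next, context, steps, max_seq_len):
    """Autoregressively extend a list of concept vectors.

    predict_next: callable taking a list of vectors and returning the next one.
    """
    sequence = list(context)
    for _ in range(steps):
        window = sequence[-max_seq_len:]  # keep only the most recent concepts
        sequence.append(predict_next(window))
    return sequence[len(context):]  # return only the newly generated vectors


def mean_predictor(window):
    """Toy predictor: the element-wise mean of the window stands in for the model."""
    return [sum(col) / len(window) for col in zip(*window)]


generated = rollout(mean_predictor, [[0.0, 2.0], [2.0, 4.0]], steps=2, max_seq_len=16)
print(generated)
```

In the pipeline described in this notebook, each generated vector would presumably then be passed to the inverter to obtain a sentence.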