# SCM Training
Here, we train a small concept model (SCM) for next-concept prediction. We will use the pre-trained PreNet to invert the embeddings into sentences.

In [None]:
import os

import torch
from torch.utils.data import DataLoader

from modules.inverter import build_inverter, get_encoder
from modules.data import SCMTrainingDataset, get_bookcorpus_for_scm
from modules.scm import SmallConceptModel, GenerativeSCM
from modules.train import train_scm

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Configs
We define configs for model and training hyperparameters.

In [58]:
# Identifier handed to get_encoder() to select the sentence encoder.
models_config = dict(
    encoder_id="paraphrase-multilingual-MiniLM-L12-v2",
)

# Unpacked as keyword arguments into SmallConceptModel(**scm_configs), so the
# key names here must match that constructor's parameter names.
scm_configs = dict(
    d_model=512,
    embed_dim=384,
    nhead=4,
    num_layers=3,
    dim_feedforward=4 * 512,  # four times d_model
    dropout=0.1,
    max_seq_len=16,
)

# Training and checkpoint settings.
# load_weights and save_weights point at the same file, so saving after
# training overwrites the checkpoint that was loaded.
# NOTE(review): max_target_len and sample_data are not read by any cell visible
# in this notebook -- confirm whether they are still needed.
train_configs = dict(
    load_weights="saved_models/scm_v01_multilingual.pth",
    save_weights="saved_models/scm_v01_multilingual.pth",
    lr=1e-3,
    weight_decay=1e-2,
    max_target_len=64,
    embed_batch_size=128,
    train_batch_size=32,
    sample_data=0.03,
    num_epochs=1,
)

# Set load_cached to a path of saved embeddings to skip re-embedding the corpus;
# None means the dataloader is built from scratch.
data_configs = dict(
    load_cached=None,
)

---

## Models
We initialize and load the encoder, inverter, and SCM models.

In [59]:
# Build the three components used in this notebook: the sentence encoder, the
# inverter, and the SCM itself (moved to the selected device).
encoder = get_encoder(models_config["encoder_id"])
inverter = build_inverter()
scm = SmallConceptModel(**scm_configs).to(device)

# Resume from a checkpoint when one is configured. The default config loads from
# and saves to the same path, so on a first run the file does not exist yet;
# report that and keep the freshly constructed weights instead of crashing.
checkpoint_path = train_configs["load_weights"]
if checkpoint_path:
    try:
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # from a trusted source.
        checkpoint_state = torch.load(checkpoint_path, map_location=device)
    except FileNotFoundError:
        print(
            f"No checkpoint found at {checkpoint_path!r}; "
            "keeping the freshly constructed SCM weights."
        )
    else:
        scm.load_state_dict(checkpoint_state)

## Dataset
If `data_configs["load_cached"]` points to a file of cached embeddings, we load it and wrap it in a dataloader; otherwise, we build the dataloader from scratch with the `get_bookcorpus_for_scm` function.

In [None]:
# Two ways to obtain the training dataloader, selected by
# data_configs["load_cached"]: a truthy value is used as the path of cached
# embeddings; otherwise get_bookcorpus_for_scm builds the dataloader.
cached_embeddings_path = data_configs["load_cached"]
train_batch_size = train_configs["train_batch_size"]

if not cached_embeddings_path:
    # No cache configured: let the helper produce the dataloader using the encoder.
    dataloader = get_bookcorpus_for_scm(
        encoder,
        embed_dim=scm_configs["embed_dim"],
        train_batch_size=train_batch_size,
        embed_batch_size=train_configs["embed_batch_size"],
    )
else:
    # Cache configured: wrap the stored embeddings in a shuffled loader that
    # drops the final incomplete batch.
    embeddings = torch.load(cached_embeddings_path)
    dataset = SCMTrainingDataset(embeddings)
    dataloader = DataLoader(
        dataset, batch_size=train_batch_size, shuffle=True, drop_last=True
    )

## Training
Finally, we can train the model using the `train_scm` function.

In [None]:
# Positional hyperparameters for train_scm, in the order the call passes them:
# the sequence length from the model config, then learning rate, weight decay
# and number of epochs from the training config.
scm_train_args = (
    scm_configs["max_seq_len"],
    train_configs["lr"],
    train_configs["weight_decay"],
    train_configs["num_epochs"],
)
train_scm(scm, dataloader, *scm_train_args)

Save the updated weights of the model.

In [None]:
# Create the checkpoint directory before saving so a finished training run is
# not lost to a missing-folder error at save time.
save_path = train_configs["save_weights"]
save_dir = os.path.dirname(save_path)
if save_dir:  # an empty dirname means the current directory, which already exists
    os.makedirs(save_dir, exist_ok=True)
torch.save(scm.state_dict(), save_path)

## Inference
We can test the model at inference time using the `GenerativeSCM` class.

In [60]:
# Put the SCM in evaluation mode before generating: it is configured with
# dropout, and no earlier cell in this notebook switches it out of the default
# training mode, so without this call dropout could perturb the generations.
scm.eval()
gen_scm = GenerativeSCM(scm, encoder, inverter)

# Seed sentences that the generator continues.
sentences = [
    "how are you?",
    "it is about neural networks."
]

# sigma_noise=0.0 requests generation with no added noise; the bare call is the
# last expression so the notebook displays its result.
gen_scm.generate(sentences, sigma_noise=0.0)

[" how are you ? '' how are you ? '' how are you ? '' how are you ? '' how are you ? '' how are you ? ''",
 " it is a neural network . ''",
 " the way that you describe it , i am sure that you will be able to relate to us . ''",
 " i mean , how can you explain how you got into this ? '' jean said . '' i mean , how can you explain how you got into",
 " the way that the computer programed , i mean , how could you not know that . '' , jason said . '' i mean , how could",
 " we are talking about a computer , and that is what i mean . '' im not sure how to describe it , but i know that it is a",
 " the way that the computer programed , i mean , how many people were in the room ? '' jean said . '' it is a good thing"]