# SCM Training
Here, we train a small concept model (SCM) for next-concept prediction. We will use the pre-trained PreNet to invert the embeddings into sentences.
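
To make "next-concept prediction" concrete, here is a minimal pure-Python sketch of how a passage of sentence embeddings can be turned into shifted (input, target) pairs and scored. The helper names `next_concept_pairs` and `mse` are hypothetical, and mean squared error is an assumption: the notebook does not show which loss `train_scm` uses.

```python
from typing import List, Tuple

Vector = List[float]


def next_concept_pairs(seq: List[Vector]) -> Tuple[List[Vector], List[Vector]]:
    """Split a sequence of sentence embeddings into inputs and targets shifted by one."""
    if len(seq) < 2:
        raise ValueError("need at least two embeddings to form a pair")
    return seq[:-1], seq[1:]


def mse(pred: List[Vector], target: List[Vector]) -> float:
    """Mean squared error over every coordinate of every position (assumed loss)."""
    total, count = 0.0, 0
    for p, t in zip(pred, target):
        for a, b in zip(p, t):
            total += (a - b) ** 2
            count += 1
    return total / count


# Toy 2-dimensional "embeddings" for a three-sentence passage.
passage = [[0.0, 1.0], [1.0, 1.0], [1.0, 0.0]]
inputs, targets = next_concept_pairs(passage)

# Trivial baseline: predict that the next concept equals the current one.
baseline_loss = mse(inputs, targets)
```

A trained model would replace the copy-the-input baseline with its own prediction at each position.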

In [1]:
import torch
from modules.inverter import build_inverter, get_encoder
from modules.data import SCMTrainingDataset, get_bookcorpus_for_scm
from modules.scm import SmallConceptModel, GenerativeSCM
from modules.train import train_scm
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## Configs
We define configs for model and training hyperparameters.

In [41]:
models_config = {
    "encoder_id": "paraphrase-multilingual-MiniLM-L12-v2",
}

scm_configs = {
    "d_model": 512,
    "embed_dim": 384,
    "nhead": 4,
    "num_layers": 3,
    "dim_feedforward": 4 * 512,
    "dropout": 0.1,
    "max_seq_len": 16,
}

train_configs = {
    "load_weights": "saved_models/scm_v01_multilingual.pth",
    "save_weights": "saved_models/scm_v01_multilingual.pth",
    "lr": 1e-2,
    "weight_decay": 1e-2,
    "max_target_len": 64,
    "embed_batch_size": 128,
    "train_batch_size": 64,
    "sample_data": 0.03,
    "num_epochs": 5,
}

data_configs = {
    "load_cached": None,
}
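
A few of these values constrain each other, so a quick sanity check before building the model can save a failed run. The sketch below is a hypothetical helper (`check_scm_configs` is not part of `modules`). It assumes `SmallConceptModel` is built on standard multi-head attention, which requires `d_model` to be divisible by `nhead`. Note also that `embed_dim` should match the encoder output size, which is 384 for `paraphrase-multilingual-MiniLM-L12-v2`.

```python
def check_scm_configs(cfg: dict) -> list:
    """Return a list of human-readable problems; an empty list means all checks passed."""
    problems = []
    # Multi-head attention splits d_model evenly across heads (assumed architecture).
    if cfg["d_model"] % cfg["nhead"] != 0:
        problems.append("d_model must be divisible by nhead")
    if not 0.0 <= cfg["dropout"] < 1.0:
        problems.append("dropout must be in [0, 1)")
    for key in ("embed_dim", "num_layers", "dim_feedforward", "max_seq_len"):
        if cfg[key] <= 0:
            problems.append(f"{key} must be positive")
    return problems


# Same values as the config cell above.
scm_configs = {
    "d_model": 512,
    "embed_dim": 384,
    "nhead": 4,
    "num_layers": 3,
    "dim_feedforward": 4 * 512,
    "dropout": 0.1,
    "max_seq_len": 16,
}

ok_report = check_scm_configs(scm_configs)
bad_report = check_scm_configs(dict(scm_configs, nhead=5))
```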

---

## Models
We initialize and load the encoder, inverter, and SCM models.

In [42]:
encoder = get_encoder(models_config["encoder_id"])
inverter = build_inverter()
scm = SmallConceptModel(**scm_configs).to(device)

if train_configs["load_weights"]:
    scm.load_state_dict(torch.load(train_configs["load_weights"], map_location=device))
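
Because `load_weights` and `save_weights` point to the same file, the load above raises `FileNotFoundError` on a first run, before any checkpoint has been saved. One way to guard against that is sketched below. `existing_checkpoint` is a hypothetical helper, and the `torch` call is left as a comment so the sketch runs on its own.

```python
import os
import tempfile
from typing import Optional


def existing_checkpoint(path: Optional[str]) -> Optional[str]:
    """Return `path` if it names an existing file, otherwise None."""
    if path and os.path.isfile(path):
        return path
    return None


# Intended use in the cell above (commented out so this sketch needs no torch):
# ckpt = existing_checkpoint(train_configs["load_weights"])
# if ckpt is not None:
#     scm.load_state_dict(torch.load(ckpt, map_location=device))

# Small self-check with a temporary file standing in for a saved checkpoint.
with tempfile.NamedTemporaryFile(suffix=".pth") as tmp:
    found_matches = existing_checkpoint(tmp.name) == tmp.name

missing = existing_checkpoint("no/such/dir/checkpoint.pth")
```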

## Dataset
We load cached embeddings if available; otherwise, we rebuild the dataloader with the `get_bookcorpus_for_scm` function.

In [36]:
if data_configs["load_cached"]:
    embeddings = torch.load(data_configs["load_cached"])
    dataset = SCMTrainingDataset(embeddings)
    dataloader = DataLoader(
        dataset,
        batch_size=train_configs["train_batch_size"],
        shuffle=True,
        drop_last=True
    )
else:
    dataloader = get_bookcorpus_for_scm(
        encoder,
        embed_dim=scm_configs["embed_dim"],
        train_batch_size=train_configs["train_batch_size"],
        embed_batch_size=train_configs["embed_batch_size"],
    )

Batches: 100%|██████████| 12500/12500 [02:51<00:00, 72.98it/s]


## Training
Finally, we can train the model using the `train_scm` function.

In [38]:
train_scm(
    scm,
    dataloader,
    scm_configs["max_seq_len"],
    train_configs["lr"],
    train_configs["weight_decay"],
    train_configs["num_epochs"]
)

Epoch [1/5]  Batch [100/3125]  Loss: 0.065444
Epoch [1/5]  Batch [200/3125]  Loss: 0.058237
Epoch [1/5]  Batch [300/3125]  Loss: 0.054097
Epoch [1/5]  Batch [400/3125]  Loss: 0.049704
Epoch [1/5]  Batch [500/3125]  Loss: 0.049940
Epoch [1/5]  Batch [600/3125]  Loss: 0.048512
Epoch [1/5]  Batch [700/3125]  Loss: 0.050204
Epoch [1/5]  Batch [800/3125]  Loss: 0.050097
Epoch [1/5]  Batch [900/3125]  Loss: 0.049509
Epoch [1/5]  Batch [1000/3125]  Loss: 0.048793
Epoch [1/5]  Batch [1100/3125]  Loss: 0.048871
Epoch [1/5]  Batch [1200/3125]  Loss: 0.049683
Epoch [1/5]  Batch [1300/3125]  Loss: 0.049092
Epoch [1/5]  Batch [1400/3125]  Loss: 0.048284
Epoch [1/5]  Batch [1500/3125]  Loss: 0.047964
Epoch [1/5]  Batch [1600/3125]  Loss: 0.047858
Epoch [1/5]  Batch [1700/3125]  Loss: 0.048639
Epoch [1/5]  Batch [1800/3125]  Loss: 0.048534
Epoch [1/5]  Batch [1900/3125]  Loss: 0.047878
Epoch [1/5]  Batch [2000/3125]  Loss: 0.048913
Epoch [1/5]  Batch [2100/3125]  Loss: 0.048436
Epoch [1/5]  Batch [22
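
The batch counts in the outputs give a rough idea of the dataset size. The sketch below works backwards from 12500 embedding batches and 3125 training batches per epoch. It assumes the batch sizes in the config cell (128 and 64) were the ones in effect when these cells ran, which the out-of-order execution counters do not guarantee. It also assumes the embedding pass keeps its final partial batch while training drops it. Both helper names are hypothetical.

```python
import math
from typing import Tuple


def num_batches(num_items: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a loader yields for `num_items` items."""
    if drop_last:
        return num_items // batch_size
    return math.ceil(num_items / batch_size)


def item_bounds(batches: int, batch_size: int, drop_last: bool) -> Tuple[int, int]:
    """Inclusive (min, max) item counts consistent with an observed batch count."""
    if drop_last:
        return batches * batch_size, (batches + 1) * batch_size - 1
    return (batches - 1) * batch_size + 1, batches * batch_size


# Embedding pass: 12500 batches of 128 sentences, last batch possibly partial.
sentence_bounds = item_bounds(12500, 128, drop_last=False)

# Training: 3125 batches of 64 sequences per epoch, incomplete batch dropped.
sequence_bounds = item_bounds(3125, 64, drop_last=True)
```

Under these assumptions the run embedded about 1.6 million sentences and trained on about 200,000 sequences per epoch.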

Save the updated weights of the model.

In [27]:
torch.save(scm.state_dict(), train_configs["save_weights"])

## Inference
We can test the model at inference time using the `GenerativeSCM` class.

In [43]:
gen_scm = GenerativeSCM(scm, encoder, inverter)

sentences = [
    "questa è una proposta di ricerca universitaria .",
    "parla di reti neurali ."
]

gen_scm.generate(sentences, sigma_noise=0.0)

[" this is a research project the university has initiated . '' said this prospective study , this is an evaluation that involves more than one theory . '' that is",
 " it 's neural networks . '' zebra 's rodents are similar to zebra 's neural networks . '' does gerald 's lab scientist",
 " they 've been studying these interactions since the beginning . ''\n\n'' and eka 've been talking about the concept of divine modality ,",
 " who is the programmer ? '' delaney replied , hoping that the information she gathered will help him defeat the accursed algorithm . '' that is , until",
 ' the reason why it seems like every single other organism in the parallel universe has been studying the process , despite having something to hide , is that it means',
 " '' it is possible to explain this entire process by comparing the living universe to an electronic system , or to a black hole . '' the professor says ,",
 " we , as a nation , are aware of how the geometric systems of science are making t
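
The internals of `GenerativeSCM.generate` are not shown in this notebook. As a rough mental model only, an autoregressive loop in embedding space could look like the sketch below. All names here are hypothetical, `mean_predictor` is a stand-in for the trained SCM, and treating `sigma_noise` as the standard deviation of Gaussian noise added to each predicted embedding is an assumption.

```python
import random
from typing import Callable, List, Optional, Sequence

Vector = List[float]


def generate_concepts(
    context: Sequence[Vector],
    predict_next: Callable[[Sequence[Vector]], Vector],
    steps: int,
    sigma_noise: float = 0.0,
    max_seq_len: int = 16,
    rng: Optional[random.Random] = None,
) -> List[Vector]:
    """Extend `context` by `steps` predicted embeddings, one at a time."""
    rng = rng or random.Random(0)
    seq = [list(v) for v in context]
    for _ in range(steps):
        # Only the most recent max_seq_len concepts are visible to the predictor.
        window = seq[-max_seq_len:]
        nxt = predict_next(window)
        if sigma_noise > 0.0:
            # Assumed meaning of sigma_noise: Gaussian perturbation per coordinate.
            nxt = [x + rng.gauss(0.0, sigma_noise) for x in nxt]
        seq.append(nxt)
    return seq


def mean_predictor(window: Sequence[Vector]) -> Vector:
    """Stand-in for the SCM: predicts the mean of the visible embeddings."""
    dim = len(window[0])
    return [sum(v[i] for v in window) / len(window) for i in range(dim)]


start = [[0.0, 2.0], [2.0, 0.0]]
deterministic = generate_concepts(start, mean_predictor, steps=2)
noisy = generate_concepts(start, mean_predictor, steps=2, sigma_noise=0.1)
```

In the real pipeline, the encoder would produce the context embeddings from the input sentences and the inverter would map each embedding in the returned sequence back to text.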