# Small Concept Model (SCM) Training
Here, we train our `SmallConceptModel` for next-concept prediction.

In [9]:
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from small_concept_model.inverter import get_encoder
from small_concept_model.data import SCMDataset, get_bookcorpus_scm
from small_concept_model.model import SmallConceptModel

## Configs
Here, we specify the configuration parameters — the sentence encoder, the model dimensions (e.g. the number of layers and attention heads) of the `SmallConceptModel`, and the training hyperparameters.

In [4]:
# Sentence encoder that maps each sentence to a fixed-size "concept" embedding.
ENCODER_ID  : str   = "paraphrase-multilingual-MiniLM-L12-v2"

# Batch size used while embedding the corpus with the encoder.
EMBED_BS    : int   = 128

# Model architecture.
D_EMBED     : int   = 384            # presumably the encoder's output size — TODO confirm if ENCODER_ID changes
D_MODEL     : int   = 512
D_FF        : int   = 4 * D_MODEL    # NOTE(review): currently not passed to SmallConceptModel
N_LAYERS    : int   = 4
N_HEADS     : int   = 8
LOAD_WEIGHTS: Optional[str] = None   # optional checkpoint path; None means no checkpoint

# Training hyperparameters.
NUM_EPOCHS  : int   = 10
TRAIN_BS    : int   = 128
LEARN_RATE  : float = 1e-3
WEIGHT_DECAY: float = 1e-6

# Seed torch so weight init and DataLoader shuffling are reproducible across runs.
SEED        : int   = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

---

## Dataset
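To create the training dataset, we first need a sentence encoder that turns each sentence into an embedding. It is loaded with `trainable=False`, since we only use it to embed the data.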

In [3]:
# Load the sentence encoder identified in the config; trainable=False as we only embed with it.
encoder = get_encoder(
    ENCODER_ID,
    trainable=False,
)

Now we can load and pre-process the dataset using the `get_bookcorpus_scm` function and wrap it into the `SCMDataset` class.

In [None]:
# Embed the (cleaned) BookCorpus with the encoder, EMBED_BS sentences at a time.
embeddings = get_bookcorpus_scm(encoder=encoder, embed_batch_size=EMBED_BS, clean=True)

# Wrap the embeddings for next-concept prediction and serve shuffled training batches.
dataset = SCMDataset(embeddings)
dataloader = DataLoader(dataset, shuffle=True, batch_size=TRAIN_BS)

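Compute the average embedding vector over the whole dataset. During training, we track the cosine similarity between the predictions and this vector as a debugging signal, because predictions tend to collapse toward it.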

In [None]:
# Average embedding over every position in the corpus, used only as a debugging
# reference: predictions tend to collapse toward this vector during training.
# reshape (unlike view) also works when `embeddings` is not contiguous.
flat_embeds = embeddings.reshape(-1, embeddings.size(-1))
# Keep it on the training device so it can be compared with model outputs
# directly; mixing CPU and CUDA tensors in cosine_similarity raises an error.
mean_tensor = flat_embeds.mean(dim=0).to(device)

## SCM Training
First, we initialize our _SCM_ using the `SmallConceptModel` class and define the loss function and the optimizer.

In [7]:
# Build the SCM with the architecture from the config cell.
model = SmallConceptModel(
    d_embed=D_EMBED,
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
)
model.to(device)

# Honor LOAD_WEIGHTS: when a path is given, restore a previously saved state_dict
# (mapped onto the current device); when it is None, start from fresh weights.
if LOAD_WEIGHTS:
    # NOTE(review): torch.load unpickles data — only load checkpoints you trust.
    state_dict = torch.load(LOAD_WEIGHTS, map_location=device)
    model.load_state_dict(state_dict)

# Regress predicted next-concept embeddings onto the true ones.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_RATE, weight_decay=WEIGHT_DECAY)

Finally, we perform the training loop.

In [None]:
# Reference vector on the same device as the predictions; cosine_similarity
# raises a device-mismatch error when mixing CPU and CUDA tensors.
mean_target = mean_tensor.to(device)
n_batches = len(dataloader)
assert n_batches > 0, "dataloader is empty - nothing to train on"

model.train()
for epoch in range(NUM_EPOCHS):
    epoch_loss = 0.0     # sum of per-sample MSE (batch loss * batch size)
    pos_sims = 0.0       # sum of per-batch mean cos(pred, target)
    sims_with_avg = 0.0  # sum of per-batch mean cos(pred, mean embedding)

    progress = tqdm(dataloader, total=n_batches, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS}")
    for idx, (batch_seq, batch_target) in enumerate(progress):
        optimizer.zero_grad()

        batch_seq = batch_seq.to(device)
        batch_target = batch_target.to(device)

        preds = model(batch_seq)

        # Monitoring metrics only: no_grad keeps them out of the autograd graph.
        with torch.no_grad():
            pos_sims += F.cosine_similarity(preds, batch_target, dim=-1).mean().item()
            sims_with_avg += F.cosine_similarity(preds, mean_target, dim=-1).mean().item()

        loss = criterion(preds, batch_target)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * batch_seq.size(0)

        if idx % 500 == 0:
            # tqdm.write prints without breaking the progress bar.
            tqdm.write(f"Epoch [{epoch + 1}/{NUM_EPOCHS}] - Batch [{idx + 1}/{n_batches}] - Loss: {loss.item():.6f}")

    # Loss is averaged per sample; similarities are averaged per batch.
    avg_loss = epoch_loss / len(dataset)
    avg_pos_sim = pos_sims / n_batches
    avg_sims_with_avg = sims_with_avg / n_batches

    print(f"*** Epoch [{epoch + 1}/{NUM_EPOCHS}] - Loss: {avg_loss:.6f} - Pos Sim: {avg_pos_sim:.6f} - Sim with Avg: {avg_sims_with_avg:.6f} ***")

Save a checkpoint of the model's weights.

In [None]:
# Destination for the trained weights; set to an empty string to skip saving.
SAVE_WEIGHTS: str = "saved_models/test_scm_checkpoint"

if SAVE_WEIGHTS:
    save_path = Path(SAVE_WEIGHTS)
    # torch.save does not create missing directories, so make sure the parent exists.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    # Save the trained SCM (the original referenced an undefined `prenet`).
    torch.save(model.state_dict(), save_path)