# PreNet Inversion Model Training
This notebook trains a lightweight `PreNet` network to invert embeddings from any pre-trained sentence-level embedding model, using GPT-2 as a token-level decoder.

In [None]:
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from small_concept_model.inverter import PreNet, Inverter, get_encoder, get_gpt2_decoder
from small_concept_model.data import InverterDataset, get_bookcorpus_inverter

## Configs
Here, we can specify some configuration parameters, such as the encoder model we want to use and the size of the `PreNet`.

In [None]:
# --- Encoder -----------------------------------------------------------------
# Identifier handed to `get_encoder` to select the sentence-embedding model.
ENCODER_ID  : str = "paraphrase-multilingual-MiniLM-L12-v2"

# --- Dataset construction ----------------------------------------------------
SAMPLE_RATIO: float = 1.0   # forwarded as `sample=` to `get_bookcorpus_inverter`
EMBED_BS    : int   = 256   # forwarded as `embed_batch_size=` to `get_bookcorpus_inverter`

# --- PreNet architecture -----------------------------------------------------
INPUT_DIM   : int   = 384   # presumably the embedding size of ENCODER_ID — confirm when changing encoders
OUTPUT_DIM  : int   = 768   # presumably the decoder's token-embedding size — TODO confirm
RANK        : int   = 128
PREFIX_LEN  : int   = 20
# Path to a saved PreNet state_dict to start from; None (or "") trains from scratch.
LOAD_WEIGHTS: Optional[str] = None

# --- Optimisation ------------------------------------------------------------
NUM_EPOCHS  : int   = 1
TRAIN_BS    : int   = 64
LEARN_RATE  : float = 1e-3
WEIGHT_DECAY: float = 1e-5

# --- Reproducibility ---------------------------------------------------------
# Seed torch's global RNG so that parameter initialisation and DataLoader
# shuffling are repeatable after a kernel restart.
SEED        : int   = 42
torch.manual_seed(SEED)

# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

---

## Dataset
To create the training dataset, we first need to get an encoder and a decoder.

In [None]:
# Sentence encoder selected by ENCODER_ID, requested as non-trainable.
encoder = get_encoder(
    ENCODER_ID,
    trainable=False,
)

# GPT-2 decoder together with its tokenizer, also requested as non-trainable.
decoder, tokenizer = get_gpt2_decoder(
    trainable=False,
)

Now we can load and pre-process the dataset using the `get_bookcorpus_inverter` function and wrap it into the `InverterDataset` class.

In [None]:
# Build the (embeddings, input_ids) pair from BookCorpus using the encoder and
# the decoder's tokenizer; sampling ratio and embedding batch size come from
# the config cell.
embeddings, input_ids = get_bookcorpus_inverter(
    encoder=encoder, tokenizer=tokenizer,
    embed_batch_size=EMBED_BS, sample=SAMPLE_RATIO, clean=True,
)

# Wrap the pair in a Dataset (the tokenizer's EOS id is passed as the third
# argument) and serve it in shuffled mini-batches of TRAIN_BS items.
dataset = InverterDataset(embeddings, input_ids, tokenizer.eos_token_id)
dataloader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=TRAIN_BS,
)

## Inverter Model Training
First, we initialize our _PreNet_ using the `PreNet` class and define the loss function and the optimizer.

In [None]:
# Build the PreNet from the config-cell hyper-parameters and place it on `device`.
prenet = PreNet(
    input_dim=INPUT_DIM,
    output_dim=OUTPUT_DIM,
    rank=RANK,
    prefix_len=PREFIX_LEN
).to(device)

# Optionally start from a previously saved state_dict.
if LOAD_WEIGHTS:
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict holds, so a tampered checkpoint file cannot
    # execute arbitrary code on load.
    prenet.load_state_dict(
        torch.load(LOAD_WEIGHTS, map_location=device, weights_only=True)
    )

# Only the PreNet's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(prenet.parameters(), lr=LEARN_RATE, weight_decay=WEIGHT_DECAY)

# Targets equal to the tokenizer's pad id do not contribute to the loss.
# NOTE(review): this needs `tokenizer.pad_token_id` to be an int; presumably
# `get_gpt2_decoder` configures a pad token — confirm, since a None here would
# only surface as an error at the first loss computation.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

Finally, we run the training loop.

In [None]:
# The decoder only provides token embeddings and logits; it is never updated.
decoder.requires_grad_(False)
# Every tensor fed to the decoder below lives on `device`, so the decoder must
# be there as well (this is a no-op when it already is).
decoder.to(device)

num_batches = len(dataloader)

for epoch in range(NUM_EPOCHS):
    prenet.train()
    decoder.eval()
    total_train_loss = 0.0

    # Batch tensors get their own names so the dataset-level `embeddings` /
    # `input_ids` created in the dataset cell are not overwritten by the loop.
    for idx, (batch_embeddings, batch_input_ids) in tqdm(
        enumerate(dataloader), desc=f"Epoch {epoch+1} [Train]", total=num_batches
    ):
        batch_embeddings = batch_embeddings.to(device)
        batch_input_ids = batch_input_ids.to(device)

        # Decoder input = PreNet prefix followed by the embeddings of every
        # token except the last one (teacher forcing).
        prefix_embeds = prenet(batch_embeddings)
        token_embeds = decoder.transformer.wte(batch_input_ids)
        inputs_embeds = torch.cat([prefix_embeds, token_embeds[:, :-1, :]], dim=1)

        outputs = decoder(inputs_embeds=inputs_embeds)
        # Drop the logits produced at the prefix positions; the remainder lines
        # up one-to-one with the shifted targets.
        logits = outputs.logits[:, prefix_embeds.size(1):, :]
        # NOTE(review): the first token of each sequence is never a target
        # (targets start at index 1). That is only right if `InverterDataset`
        # puts a start token at index 0 — confirm.
        labels = batch_input_ids[:, 1:]

        vocab_size = logits.size(-1)
        loss = criterion(logits.reshape(-1, vocab_size), labels.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_train_loss += batch_loss

        if idx % 100 == 0:
            # tqdm.write prints the message without breaking the progress bar.
            tqdm.write(f"Epoch [{epoch + 1}/{NUM_EPOCHS}] - Batch [{idx + 1}/{num_batches}] - Loss: {batch_loss:.6f}")

    # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError.
    avg_epoch_loss = total_train_loss / max(num_batches, 1)

    print(f"*** Epoch [{epoch+1}/{NUM_EPOCHS}] - Loss: {avg_epoch_loss:.6f} ***")

Save a checkpoint of the trained PreNet's weights.

In [None]:
# Destination of the PreNet state_dict; set to "" to skip saving.
SAVE_WEIGHTS: str = "saved_models/test_prenet_checkpoint.pth"

if SAVE_WEIGHTS:
    # Create the target directory first: torch.save does not create missing
    # parent directories, and failing here would discard the finished run.
    Path(SAVE_WEIGHTS).parent.mkdir(parents=True, exist_ok=True)
    torch.save(prenet.state_dict(), SAVE_WEIGHTS)