# PreNet Training for Embedding Inversion
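Here, we train our `PreNet` model for embedding-to-text inversion: the PreNet maps a sentence embedding to a prefix that conditions a GPT-2 decoder to reconstruct the original sentence. We train it on a sample of the `bookcorpus` dataset.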

In [None]:
import torch
from modules.inverter import PreNet, Inverter, get_encoder, get_gpt2_decoder
from modules.data import get_bookcorpus_for_inversion
from modules.train import train_inversion_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Configs
We define the model and training hyperparameters.

In [None]:
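models_config = {
    "encoder_id": "all-MiniLM-L6-v2",  # sentence-transformers encoder (384-dim embeddings)
}

prenet_configs = {
    "input_dim": 384,    # encoder embedding size
    "output_dim": 768,   # GPT-2 (small) hidden size
    "rank": 128,
    "prefix_len": 20,    # number of prefix vectors fed to the decoder
}

train_configs = {
    # Checkpoint to resume from and to save to (set to None to skip)
    "load_weights": "saved_models/prenet_prefix_tuning_bookcorpus.pth",
    "save_weights": "saved_models/prenet_prefix_tuning_bookcorpus.pth",
    "lr": 1e-3,
    "max_target_len": 64,      # maximum target length in tokens
    "embed_batch_size": 32,    # batch size for computing sentence embeddings
    "train_batch_size": 32,
    "sample_data": 0.03,       # fraction of bookcorpus to use (3%)
    "num_epochs": 1,
}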
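
The `PreNet` source is not shown here, but the config implies its shape contract: a 384-dimensional embedding goes in and `prefix_len` = 20 vectors of size 768 come out. The hypothetical `LowRankPrefixNet` below is a minimal sketch assuming that `rank` is the width of a linear bottleneck; it is not the actual `PreNet` implementation.

```python
import torch
import torch.nn as nn


class LowRankPrefixNet(nn.Module):
    """Hypothetical PreNet-style module: maps one sentence embedding to
    `prefix_len` vectors in the decoder's embedding space through a
    rank-`rank` bottleneck (an assumption about what `rank` means)."""

    def __init__(self, input_dim=384, output_dim=768, rank=128, prefix_len=20):
        super().__init__()
        self.prefix_len = prefix_len
        self.output_dim = output_dim
        # Factorized projection: input_dim -> rank -> prefix_len * output_dim
        self.down = nn.Linear(input_dim, rank)
        self.up = nn.Linear(rank, prefix_len * output_dim)

    def forward(self, emb: torch.Tensor) -> torch.Tensor:
        # emb: (batch, input_dim) -> prefix: (batch, prefix_len, output_dim)
        h = self.up(self.down(emb))
        return h.reshape(emb.shape[0], self.prefix_len, self.output_dim)


net = LowRankPrefixNet()
prefix = net(torch.randn(2, 384))
print(prefix.shape)  # torch.Size([2, 20, 768])

n_params = sum(p.numel() for p in net.parameters())
print(n_params)  # 2030720
```

For comparison, a single full-rank `nn.Linear(384, 20 * 768)` would have 5,913,600 parameters, so under this assumption the rank-128 factorization is roughly 2.9× smaller.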

---

## Models
We initialize the encoder, PreNet, and GPT-2 decoder models, and load previously saved PreNet weights if a checkpoint exists.

In [None]:
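import os

encoder = get_encoder(models_config["encoder_id"])
decoder, tokenizer = get_gpt2_decoder()

prenet = PreNet(**prenet_configs).to(device)

# Resume only if the checkpoint exists: on the first run the file has not been
# written yet, and torch.load would raise FileNotFoundError.
load_path = train_configs["load_weights"]
if load_path and os.path.exists(load_path):
    prenet.load_state_dict(torch.load(load_path, map_location=device))
elif load_path:
    print(f"No checkpoint found at {load_path}; training from scratch.")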
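
How `train_inversion_model` feeds the PreNet output to GPT-2 is not shown. In prefix tuning, the prefix vectors are typically prepended to the embedded target tokens, and the loss is masked so only real target tokens are predicted. The toy sketch below assembles those tensors under that assumption; the vocabulary size and `pad_id` are invented for illustration, while the other sizes match the notebook config.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Sizes match the config except the vocabulary, which is shrunk for the toy.
batch, prefix_len, hidden, vocab, max_target_len = 2, 20, 768, 1000, 64
pad_id = vocab - 1  # hypothetical padding id for this toy vocabulary

# Stand-ins: `prefix` plays the PreNet output, `wte` the decoder's token-embedding table.
prefix = torch.randn(batch, prefix_len, hidden)
wte = nn.Embedding(vocab, hidden)
target_ids = torch.randint(0, vocab - 1, (batch, max_target_len))
target_ids[:, -5:] = pad_id  # pretend the last 5 positions are padding
target_mask = (target_ids != pad_id).long()

# 1) Prepend the prefix to the embedded target tokens.
inputs_embeds = torch.cat([prefix, wte(target_ids)], dim=1)

# 2) Labels: -100 (ignored by the loss) on prefix and padding positions.
prefix_labels = torch.full((batch, prefix_len), -100, dtype=torch.long)
target_labels = target_ids.masked_fill(target_mask == 0, -100)
labels = torch.cat([prefix_labels, target_labels], dim=1)

# 3) Attention mask: prefix positions are always visible, padding is not.
attention_mask = torch.cat(
    [torch.ones(batch, prefix_len, dtype=torch.long), target_mask], dim=1
)

print(inputs_embeds.shape, labels.shape, attention_mask.shape)
```

Assuming `get_gpt2_decoder` returns a Hugging Face `GPT2LMHeadModel`, these tensors could be passed as `decoder(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)`; the model shifts `labels` internally, so they stay aligned position-for-position with `inputs_embeds`.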

## Dataset
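We load a sample of `bookcorpus` with `get_bookcorpus_for_inversion`, which returns the training and validation data loaders. The sample fraction, maximum target length, and batch sizes come from `train_configs`.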

In [None]:
train_loader, val_loader = get_bookcorpus_for_inversion(
    encoder,
    tokenizer,
    train_configs["max_target_len"],
    train_configs["train_batch_size"],
    train_configs["embed_batch_size"],
    train_configs["sample_data"],
)
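
The internals of `get_bookcorpus_for_inversion` are not shown. The hypothetical `build_inversion_loaders` below sketches one plausible pipeline: sample a `sample_data` fraction of sentences, embed them in chunks of `embed_batch_size`, truncate or right-pad each tokenized target to `max_target_len`, and split into training and validation loaders. The 10% validation fraction, the padding id, and the toy encoder and tokenizer are all assumptions for illustration.

```python
import random

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def build_inversion_loaders(texts, embed_fn, tokenize_fn, max_target_len,
                            train_batch_size, embed_batch_size, sample_data,
                            pad_id=0, val_frac=0.1, seed=0):
    """Hypothetical sketch, not the get_bookcorpus_for_inversion implementation."""
    rng = random.Random(seed)
    # Keep a `sample_data` fraction of the corpus (at least one sentence).
    sample = rng.sample(texts, max(1, round(len(texts) * sample_data)))

    # Embed in chunks of `embed_batch_size` to bound encoder memory use.
    embs = torch.cat([embed_fn(sample[i:i + embed_batch_size])
                      for i in range(0, len(sample), embed_batch_size)])

    # Truncate or right-pad every tokenized target to exactly `max_target_len` ids.
    ids = torch.full((len(sample), max_target_len), pad_id, dtype=torch.long)
    for row, text in enumerate(sample):
        toks = tokenize_fn(text)[:max_target_len]
        ids[row, :len(toks)] = torch.tensor(toks, dtype=torch.long)

    dataset = TensorDataset(embs, ids)
    n_val = max(1, round(len(dataset) * val_frac))
    train_ds, val_ds = random_split(
        dataset, [len(dataset) - n_val, n_val],
        generator=torch.Generator().manual_seed(seed))
    return (DataLoader(train_ds, batch_size=train_batch_size, shuffle=True),
            DataLoader(val_ds, batch_size=train_batch_size))


# Toy usage: random vectors stand in for encoder.encode, word lengths for token ids.
def toy_embed(batch):
    return torch.randn(len(batch), 384)


def toy_tokenize(text):
    return [len(word) for word in text.split()]


texts = [f"sentence number {i} ." for i in range(1000)]
train_loader, val_loader = build_inversion_loaders(
    texts, toy_embed, toy_tokenize, max_target_len=64,
    train_batch_size=32, embed_batch_size=32, sample_data=0.03)
emb, ids = next(iter(train_loader))
print(emb.shape, ids.shape)
```

Precomputing the embeddings once keeps the frozen encoder out of the training loop, which is one plausible reason for a separate `embed_batch_size`.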

## Training
We train the PreNet with `train_inversion_model`, using the learning rate and number of epochs from `train_configs`.

In [None]:
train_inversion_model(
    prenet,
    decoder,
    tokenizer,
    train_loader,
    val_loader,
    train_configs["lr"],
    train_configs["num_epochs"],
)
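
Whether `train_inversion_model` freezes GPT-2 is an assumption here, although prefix tuning usually trains only the prefix generator. Under that assumption, one optimization step looks roughly like the toy sketch below: the decoder's parameters get `requires_grad=False` and are left out of the optimizer, yet gradients still flow through the decoder into the PreNet. The linear stand-ins and the choice of AdamW are hypothetical; only the learning rate matches the config.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins (hypothetical): a small trainable "PreNet" and a frozen "decoder".
prenet = nn.Linear(8, 16)
decoder = nn.Linear(16, 50)

# Freeze the decoder and give the optimizer only the PreNet's parameters.
decoder.requires_grad_(False)
optimizer = torch.optim.AdamW(prenet.parameters(), lr=1e-3)

embeddings = torch.randn(4, 8)          # batch of sentence embeddings
targets = torch.randint(0, 50, (4,))    # toy next-token targets

decoder_before = decoder.weight.clone()
prenet_before = prenet.weight.clone()

prenet.train()
logits = decoder(prenet(embeddings))    # gradients flow through the frozen decoder
loss = F.cross_entropy(logits, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"loss={loss.item():.4f}")
```

After the step, the decoder weights are unchanged and have no gradient, while the PreNet weights have moved.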

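We save the updated PreNet weights. Because `save_weights` and `load_weights` point to the same file, the next run resumes from this checkpoint.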

In [None]:
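import os

save_path = train_configs["save_weights"]
if save_path:
    # torch.save does not create missing directories, so make sure saved_models/ exists.
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    torch.save(prenet.state_dict(), save_path)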
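
`torch.save` writes the `state_dict` to a file but does not create missing parent folders, so saving to `saved_models/...` fails if that folder is absent. The round-trip below uses a temporary directory and an `nn.Linear` stand-in for the PreNet to check that reloaded weights match the saved ones.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(384, 768)  # stand-in for the trained PreNet

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "saved_models", "prenet.pth")
    # torch.save does not create missing parent directories, so create them first.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model.state_dict(), path)

    # Reload into a freshly initialised module with the same architecture.
    restored = nn.Linear(384, 768)
    restored.load_state_dict(torch.load(path, map_location="cpu"))

same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```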

## Inference
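Finally, we test the model at inference time: the `Inverter` class wraps the PreNet, decoder, and tokenizer, and we use it to reconstruct a sample sentence from its embedding.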

In [None]:
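inverter = Inverter(prenet, decoder, tokenizer)
prenet.eval()  # switch to inference mode (disables dropout, if any)

sample_text = "he came back from the beach , quietly ."
with torch.no_grad():
    # Move the embedding to the same device as the PreNet.
    embedding = encoder.encode(sample_text, convert_to_tensor=True).to(device)
    print(inverter.invert(embedding, max_len=30))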
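
`Inverter.invert` is not shown. A common way to turn a prefix into text is greedy decoding, which the hypothetical `greedy_invert` below sketches with toy stand-ins: at each step it takes the arg-max token, stops at an end-of-sequence id or after `max_len` tokens, and feeds the chosen token's embedding back in. The toy decoder is position-wise, whereas a real GPT-2 attends over the whole prefix and all previous tokens; the actual `Inverter` may also use a different strategy (e.g. beam search or sampling).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden, vocab, prefix_len, eos_id = 16, 50, 4, 0

# Toy stand-ins (hypothetical): a token-embedding table and a position-wise "decoder".
wte = nn.Embedding(vocab, hidden)
toy_decoder = nn.Sequential(nn.Linear(hidden, hidden), nn.Tanh(), nn.Linear(hidden, vocab))


@torch.no_grad()
def greedy_invert(prefix, max_len=30):
    """Greedily decode token ids from a prefix of shape (1, prefix_len, hidden)."""
    inputs = prefix
    generated = []
    for _ in range(max_len):
        logits = toy_decoder(inputs)              # (1, seq_len, vocab)
        next_id = int(logits[0, -1].argmax())     # most likely next token
        if next_id == eos_id:
            break
        generated.append(next_id)
        # Append the chosen token's embedding so the next step can condition on it.
        next_emb = wte(torch.tensor([[next_id]]))  # (1, 1, hidden)
        inputs = torch.cat([inputs, next_emb], dim=1)
    return generated


tokens = greedy_invert(torch.randn(1, prefix_len, hidden), max_len=30)
print(tokens)
```

With a real tokenizer, the resulting ids would then be turned back into text with the tokenizer's decode method.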