# PreNet Training for Embedding Inversion
Here, we train our `PreNet` model for embedding-to-text inversion, using a sample of the `bookcorpus` dataset.

In [None]:
import torch
from small_concept_model.inverter import PreNet, get_encoder, get_gpt2_decoder, Inverter
from small_concept_model.data import get_bookcorpus_inverter
from small_concept_model.train import train_inverter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Configs
We define the model, data, and training hyperparameters, along with optional paths for loading and saving the PreNet weights.

In [None]:
models_config = {
    "encoder_id": "paraphrase-multilingual-MiniLM-L12-v2",
}

prenet_configs = {
    "input_dim": 384,
    "output_dim": 768,
    "rank": 128,
    "prefix_len": 20,
}

data_configs = {
    "max_target_len": 64,
    "embed_batch_size": 256,
    "sample": 0.01,
    "clean": True
}

train_configs = {
    "lr": 1e-4,
    "weight_decay": 0,
    "batch_size": 32,
    "num_epochs": 5
}

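load_weights: str | None = "saved_models/prenet/prenet_100k_good.pth"  # checkpoint to load; None to start from scratch
save_weights: str | None = None  # set to a path to save the trained weights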
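
To get a feel for what the `rank` setting buys, the sketch below counts weights under one plausible reading of `prenet_configs`: the 384-dimensional sentence embedding is mapped to `prefix_len` vectors of width 768 (the hidden size of the base GPT-2 model), either with one dense layer or through a `rank`-sized bottleneck. The factorized layout is an assumption made for illustration; the actual `PreNet` architecture is not shown in this notebook and may differ.

```python
# Hypothetical parameter count; the real PreNet layout may differ.
input_dim, output_dim, rank, prefix_len = 384, 768, 128, 20

# Flattened size of the decoder prefix: prefix_len vectors of width output_dim.
prefix_size = prefix_len * output_dim

# One dense layer mapping the embedding straight to the flattened prefix.
dense_weights = input_dim * prefix_size

# The same map factorized through a rank-sized bottleneck (assumed layout).
low_rank_weights = input_dim * rank + rank * prefix_size

print(f"prefix size:      {prefix_size:,}")
print(f"dense weights:    {dense_weights:,}")
print(f"low-rank weights: {low_rank_weights:,}")
```

Under these assumptions, the bottleneck cuts the weight count to roughly a third of the dense version (biases ignored).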

---

## Models
We initialize the encoder, PreNet, and decoder models, loading pretrained PreNet weights if a path is given.

In [19]:
encoder = get_encoder(models_config["encoder_id"])
decoder, tokenizer = get_gpt2_decoder()

prenet = PreNet(**prenet_configs).to(device)
if load_weights:
    prenet.load_state_dict(torch.load(load_weights, map_location=device))

## Data
We load and preprocess the `bookcorpus` sample for training.

In [None]:
data = get_bookcorpus_inverter(encoder, tokenizer, **data_configs)

## Training
We train the model with the `train_inverter` function, then save the weights if a save path is set.

In [None]:
train_inverter(prenet, decoder, tokenizer, data, **train_configs)

if save_weights:
    torch.save(prenet.state_dict(), save_weights)

---

## Inference
We can check empirically how well the model inverts embeddings at inference time. First, we wrap the PreNet, decoder, and tokenizer in a single `Inverter` object, which provides convenience methods such as `invert`.

In [21]:
inverter = Inverter(prenet, decoder, tokenizer)

Now we can write a sample text, encode it into an embedding vector, and invert that embedding back into text.

In [63]:
sample_text = "I never thought this would be possible."

vec = encoder.encode(sample_text, convert_to_tensor=True)
inverter.invert(
    vec, max_len=50, temperature=0.6, repetition_penalty=1.1
)

' I dont think it would ever happen.'
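
Judging the inversion by eye only goes so far. One simple way to put a number on it is to re-encode the recovered text and compare that embedding with the original one using cosine similarity. The helper below is a minimal, dependency-free sketch of that comparison; `cosine_similarity` and the toy vectors are illustrative names and values, not part of `small_concept_model`.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length sequences of floats."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    # Treat a zero vector as having no similarity to anything.
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


# Toy vectors standing in for encoder outputs.
original = [1.0, 0.0, 1.0]
recovered = [1.0, 0.0, 1.0]
unrelated = [0.0, 1.0, 0.0]

print(cosine_similarity(original, recovered))
print(cosine_similarity(original, unrelated))
```

In this notebook, assuming `inverter.invert` returns a plain string as the output above suggests, the recovered text could be passed back through `encoder.encode` and both embeddings converted with `.tolist()` before calling the helper. With tensors, `torch.nn.functional.cosine_similarity(vec, other_vec, dim=0)` computes the same quantity directly.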