In [21]:
import math
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

## Configs

In [7]:
# Paths and hyper-parameters shared by every cell below.
DATA_PATH     = "saved_models/train_seq_embeddings.pt"       # cached embeddings read with torch.load; expected shape (N, SEQ_LEN, EMBED_DIM)
BATCH_SIZE    = 32
NUM_EPOCHS    = 5
LEARNING_RATE = 1e-4
SEQ_LEN       = 16               # total length of each sequence
EMBED_DIM     = 384              # dimension of each sentence embedding
D_MODEL       = 512              # internal Transformer width; inputs are linearly projected from EMBED_DIM to D_MODEL
NUM_LAYERS    = 3                # number of Transformer layers
NUM_HEADS     = 4                # number of attention heads
FFN_DIM       = 4 * D_MODEL      # feed-forward "intermediate" dimension
DROPOUT       = 0.1
DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer the GPU when one is available

---

## Data
First, load the cached embeddings.

In [8]:
# Always materialise the cached tensor on the CPU: this works even when the
# file was saved from a GPU that is not available here, and keeps the full
# dataset out of GPU memory (batches are moved to DEVICE in the training loop).
embeddings = torch.load(DATA_PATH, map_location="cpu")

Then, we define a custom sentence vector dataset.

In [18]:
class SentenceEmbeddingDataset(Dataset):
    """Next-embedding prediction dataset over fixed-length embedding sequences.

    Every item is a pair ``(input_seq, target_seq)`` taken from one stored
    sequence: the input is the first ``SEQ_LEN - 1`` vectors and the target is
    the same sequence shifted one step ahead (vectors ``1 .. SEQ_LEN - 1``).
    """

    def __init__(self, data: torch.Tensor):
        """Validate and store the embedding tensor as float32.

        Args:
            data: tensor of shape (num_sequences, SEQ_LEN, EMBED_DIM).

        Raises:
            AssertionError: if ``data`` does not have that shape.
        """
        # Validate the argument itself rather than the notebook-level
        # `embeddings` variable, so the check applies to whatever is passed in.
        assert data.ndim == 3, f"expected a 3-D tensor, got {data.ndim} dimension(s)"
        assert data.shape[1] == SEQ_LEN, (
            f"expected sequence length {SEQ_LEN}, got {data.shape[1]}"
        )
        assert data.shape[2] == EMBED_DIM, (
            f"expected embedding dimension {EMBED_DIM}, got {data.shape[2]}"
        )

        self.data = data.float()

    def __len__(self) -> int:
        """Number of stored sequences."""
        return self.data.size(0)

    def __getitem__(self, idx):
        """Return ``(input_seq, target_seq)``, each (SEQ_LEN - 1, EMBED_DIM)."""
        seq = self.data[idx]
        input_seq = seq[: SEQ_LEN - 1, :]   # every vector except the last
        target_seq = seq[1: SEQ_LEN, :]     # every vector except the first
        return input_seq, target_seq

Now create the dataset and dataloader.

In [None]:
# Wrap the cached embeddings, then batch them: reshuffled every epoch, and the
# final short batch is dropped so each batch holds exactly BATCH_SIZE items.
dataset = SentenceEmbeddingDataset(data=embeddings)
train_loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    drop_last=True,
    shuffle=True,
)

Also, we need to generate the causal mask we will use to mask future vectors during training.

In [26]:
def generate_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Build a boolean (seq_len, seq_len) mask that is True strictly above the diagonal.

    Entry ``[i, j]`` is True exactly when ``j > i``, i.e. it marks the
    positions that lie after position ``i`` in the sequence.
    """
    all_ones = torch.ones((seq_len, seq_len), device=device)
    strictly_upper = all_ones.triu(diagonal=1)  # zero out the diagonal and everything below it
    return strictly_upper.to(torch.bool)

# Mask shared by all training batches: the model receives SEQ_LEN - 1 input
# vectors per sequence, so the mask is (SEQ_LEN - 1) x (SEQ_LEN - 1).
causal_mask = generate_causal_mask(SEQ_LEN - 1, device=DEVICE)

---

## Model

In [31]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch-first sequence.

    The table is precomputed for ``max_len`` positions and registered as the
    buffer ``pe`` with shape (1, max_len, d_model): even feature columns hold
    sines, odd feature columns hold cosines.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        """
        Args:
            d_model: feature dimension of the inputs (even or odd).
            max_len: longest sequence length the table supports.
        """
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the angles to the number of odd-indexed columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the positional table to ``x``.

        Args:
            x: tensor of shape (batch, seq_len, d_model).

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` the table was built for.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Fail with a clear message instead of an opaque broadcast error
            # (or a silent broadcast when max_len == 1).
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional table size {max_len}"
            )
        return x + self.pe[:, :seq_len, :]


class TransformerNextEmbedding(nn.Module):
    """Masked Transformer encoder that maps a sequence of embeddings to a sequence of the same shape.

    Pipeline: linear projection EMBED_DIM -> d_model, sinusoidal positional
    encoding, a stack of ``nn.TransformerEncoderLayer`` blocks (batch-first),
    and a linear projection d_model -> EMBED_DIM.
    """

    def __init__(self,
                 d_model: int,
                 nhead: int,
                 num_layers: int,
                 dim_feedforward: int,
                 dropout: float = 0.1,
                 max_seq_len: int = 5000):
        """
        Args:
            d_model: internal width of the Transformer.
            nhead: number of attention heads per layer.
            num_layers: number of stacked encoder layers.
            dim_feedforward: hidden size of each layer's feed-forward block.
            dropout: dropout probability inside the encoder layers.
            max_seq_len: number of positions in the positional-encoding table.
        """
        super().__init__()
        self.d_model = d_model

        # Sub-modules are created in the order projection -> positions ->
        # encoder -> output so parameter initialisation consumes the RNG in a
        # fixed sequence.
        self.input_projection = nn.Linear(EMBED_DIM, d_model)
        self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_seq_len)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers,
        )
        self.output_projection = nn.Linear(d_model, EMBED_DIM)

    def forward(self, src: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        """Run the encoder over ``src`` under the given attention mask.

        Args:
            src: batch-first input whose last dimension is EMBED_DIM.
            src_mask: boolean mask; True entries are filled with -1e9 in the
                additive attention mask (assumed (seq_len, seq_len) - TODO confirm).

        Returns:
            Tensor with the same leading dimensions as ``src`` and last dimension EMBED_DIM.
        """
        # NOTE(review): the sqrt(d_model) scale is applied after the positional
        # encoding has been added, so the positional term is scaled as well -
        # confirm this is intended.
        scale = math.sqrt(self.d_model)
        hidden = self.pos_encoder(self.input_projection(src)) * scale

        # Convert the boolean mask into an additive float mask (0 = keep, -1e9 = block).
        additive_mask = torch.zeros_like(src_mask, dtype=torch.float32).masked_fill(src_mask, -1e9)

        encoded = self.transformer_encoder(hidden, mask=additive_mask)
        return self.output_projection(encoded)
    

Initialize the model and define the optimizer and loss.

In [32]:
# Build the sequence model on DEVICE. The positional table is sized for
# SEQ_LEN positions (max_seq_len=SEQ_LEN).
model = TransformerNextEmbedding(
        d_model=D_MODEL,
        nhead=NUM_HEADS,
        num_layers=NUM_LAYERS,
        dim_feedforward=FFN_DIM,
        dropout=DROPOUT,
        max_seq_len=SEQ_LEN
    ).to(DEVICE)

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-2)
criterion = nn.MSELoss(reduction="none") # element-wise loss; the training loop reduces it with .mean()

# Count only the parameters that receive gradients.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable_params}")

Trainable parameters: 9851264


---

## Training

In [33]:
# Switch the model to training mode before optimising.
model.train()

for epoch in range(1, NUM_EPOCHS + 1):
    # Running sum of per-batch losses and batch count, for the epoch average.
    epoch_loss = 0.0
    n_batches = 0
    for batch_idx, (input_seq, target_seq) in enumerate(train_loader):
        # The target is the input sequence shifted one step ahead.
        input_seq = input_seq.to(DEVICE)    # (BATCH_SIZE, SEQ_LEN - 1, EMBED_DIM)
        target_seq = target_seq.to(DEVICE)  # (BATCH_SIZE, SEQ_LEN - 1, EMBED_DIM)

        optimizer.zero_grad()
        # Forward pass under the causal mask.
        output = model(input_seq, causal_mask)
        # output shape: (BATCH_SIZE, SEQ_LEN - 1, EMBED_DIM)

        # Element-wise MSE (criterion uses reduction="none"), then averaged
        # over every batch item, position and feature; no position is masked out.
        loss_tensor = criterion(output, target_seq)  # (BATCH_SIZE, SEQ_LEN - 1, EMBED_DIM)
        loss = loss_tensor.mean()                    # scalar
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

        # Progress report every 100 batches.
        if (batch_idx + 1) % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}]  "
                f"Batch [{batch_idx+1}/{len(train_loader)}]  "
                f"Loss: {loss.item():.6f}"
            )

    # Mean of the per-batch losses for this epoch.
    avg_epoch_loss = epoch_loss / n_batches
    print(f"*** Epoch {epoch} Complete.  Avg Loss = {avg_epoch_loss:.6f} ***")

Epoch [1/5]  Batch [100/3125]  Loss: 0.010688
Epoch [1/5]  Batch [200/3125]  Loss: 0.008928
Epoch [1/5]  Batch [300/3125]  Loss: 0.007678
Epoch [1/5]  Batch [400/3125]  Loss: 0.006708
Epoch [1/5]  Batch [500/3125]  Loss: 0.005631
Epoch [1/5]  Batch [600/3125]  Loss: 0.004518
Epoch [1/5]  Batch [700/3125]  Loss: 0.003671
Epoch [1/5]  Batch [800/3125]  Loss: 0.003250
Epoch [1/5]  Batch [900/3125]  Loss: 0.003031
Epoch [1/5]  Batch [1000/3125]  Loss: 0.002799
Epoch [1/5]  Batch [1100/3125]  Loss: 0.002704
Epoch [1/5]  Batch [1200/3125]  Loss: 0.002606
Epoch [1/5]  Batch [1300/3125]  Loss: 0.002565
Epoch [1/5]  Batch [1400/3125]  Loss: 0.002524
Epoch [1/5]  Batch [1500/3125]  Loss: 0.002459
Epoch [1/5]  Batch [1600/3125]  Loss: 0.002450
Epoch [1/5]  Batch [1700/3125]  Loss: 0.002464
Epoch [1/5]  Batch [1800/3125]  Loss: 0.002382
Epoch [1/5]  Batch [1900/3125]  Loss: 0.002398
Epoch [1/5]  Batch [2000/3125]  Loss: 0.002381
Epoch [1/5]  Batch [2100/3125]  Loss: 0.002385
Epoch [1/5]  Batch [22

Optionally, we can save the model weights.

In [35]:
# Set to False to skip writing the checkpoint.
SAVE_MODEL: bool = True

if SAVE_MODEL:
    # Only the state_dict (parameters and buffers) is saved, not the module object.
    torch.save(model.state_dict(), "saved_models/scm_v01.pth")

---

## Inference

In [38]:
from sentence_transformers import SentenceTransformer
from modules.prenet import PreNet
from modules.encdec import get_gpt2_decoder

# Sentence encoder used below to turn the prompt sentences into embeddings.
encoder = SentenceTransformer("all-MiniLM-L6-v2")
# NOTE(review): unlike the PreNet, the decoder is not moved to DEVICE here -
# confirm that get_gpt2_decoder handles device placement itself.
decoder, tokenizer = get_gpt2_decoder()

# PreNet turns an embedding into the prefix fed to the decoder in vec_to_text
# (presumably 384-d input -> prefix_len vectors of width 768; verify against modules.prenet).
prenet = PreNet(
    input_dim=384,
    output_dim=768,
    bottleneck_dim=128,
    prefix_len=20
).to(DEVICE)

# Restore the saved PreNet weights, remapped onto DEVICE.
prenet.load_state_dict(torch.load("saved_models/prenet_prefix_tuning_bookcorpus.pth", map_location=DEVICE))


def generative_inference(model, initial_sequence, n_future_steps):
    """Autoregressively extend a sequence of embeddings.

    Args:
        model: called as ``model(sequence, mask)``; it is switched to eval mode.
        initial_sequence: unbatched context tensor; a leading batch dimension
            is added (assumed (k, input_dim) - TODO confirm).
        n_future_steps: number of embeddings to append to the context.

    Returns:
        The batched context followed by the generated embeddings, on DEVICE.
    """
    model.eval()

    # Work on a batched copy on DEVICE so the caller's tensor is never aliased.
    generated = initial_sequence.clone().unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        for _ in range(n_future_steps):
            # The mask must match the current (growing) sequence length.
            mask = generate_causal_mask(generated.size(1), device=DEVICE)
            predictions = model(generated, mask)
            # The output at the last position becomes the next sequence element.
            newest = predictions[:, -1, :].unsqueeze(1)
            generated = torch.cat((generated, newest), dim=1)
    return generated

def vec_to_text(embedding, decoder, tokenizer, prenet, gen_len=50):
    """Greedily decode text from a single embedding vector.

    ``prenet`` turns the embedding into a sequence of prefix embeddings; the
    decoder is then run autoregressively on that prefix, appending the input
    embedding of each arg-max token, until ``gen_len`` tokens are produced.

    Args:
        embedding: embedding tensor; a leading batch dimension is added before
            it is passed to ``prenet``.
        decoder: language model called with ``inputs_embeds=...`` whose output
            exposes ``logits``; its token embedding table is read from
            ``decoder.transformer.wte``. Switched to eval mode.
        tokenizer: object whose ``decode`` converts the generated ids to text.
        prenet: module mapping the batched embedding to prefix embeddings
            (presumably (1, prefix_len, model_dim) - verify against PreNet).
            Switched to eval mode.
        gen_len: number of tokens to generate; a value <= 0 yields ``""``.

    Returns:
        The decoded string (special tokens skipped).
    """
    decoder.eval()
    prenet.eval()
    generated_ids = []
    with torch.no_grad():
        # The prefix embeddings are the initial decoder input.
        generated = prenet(embedding.unsqueeze(0))

        for _ in range(gen_len):
            # The whole sequence is re-fed every step (no key/value cache).
            outputs = decoder(inputs_embeds=generated)
            next_logits = outputs.logits[:, -1, :]
            next_id = torch.argmax(next_logits, dim=-1).unsqueeze(-1)  # greedy
            generated_ids.append(next_id)
            next_embed = decoder.transformer.wte(next_id)
            generated = torch.cat([generated, next_embed], dim=1)

    if not generated_ids:
        # Nothing was generated (gen_len <= 0); torch.cat rejects an empty list.
        return ""

    gen_ids = torch.cat(generated_ids, dim=1)
    return tokenizer.decode(gen_ids[0].cpu().numpy(), skip_special_tokens=True)


In [61]:
# Number of embeddings to generate after the context sentences.
future_steps = 4

# Context sentences, in reading order.
sentences = [
    "``hello, my name is francesco, and this is my phd project''",
    "he said, entering the univerisity class.",
]
# Encode the sentences to one embedding each (as a tensor), then extend the
# sequence; the result holds the context embeddings followed by the generated ones.
encoded_sentences = encoder.encode(sentences, convert_to_tensor=True)
generated_seq = generative_inference(model, encoded_sentences, future_steps)

In [63]:
# Remove only the batch dimension: a bare squeeze() would also collapse a
# length-1 sequence, making the loop iterate over scalars instead of vectors.
for vec in generated_seq.squeeze(0):
    generated_text = vec_to_text(vec, decoder, tokenizer, prenet, 30)
    print(generated_text)
    print("---")

 , my name is diana . '' i say . '' mia , my name is diana . '' i say . '' mia , my
---
 he said , entering the class . . . . . . . . . . . . . . . . . . . . . . . .
---
 , i 'm going to introduce you to the organization . '' '' '' i 'm going to introduce you to the organization . '' `` i '
---
 , i 'm going to introduce you to the organization . '' '' '' i replied . '' i 'm going to introduce you to the organization .
---
 , i 'm going to introduce you to the chairman of the institute . '' '' '' i replied . '' i 'm going to introduce you to
---
 , i asked him to introduce you to the organization . '' '' i replied . '' i 'm a professor of mathematics . '' '' '' he replied
---
