In [1]:
import torch
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer
from small_concept_model.model import SmallConceptModel
from small_concept_model.pipeline import Pipeline
from small_concept_model.inverter import PreNet, Inverter, get_encoder, get_gpt2_decoder
from small_concept_model.train import train_scm, train_inverter
from small_concept_model.data import get_bookcorpus_scm, get_bookcorpus_inverter
from small_concept_model.auto import build_scm, build_inverter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [194]:
scm_configs = {
    "d_model": 384,
    "d_embed": 384,
    "d_ff": 4 * 384,
    "n_heads": 4,
    "n_layers": 3,
    "dropout": 0.1,
    "max_seq_len": 16
}

prenet_configs = {
    "input_dim": 384,
    "output_dim": 768,
    "rank": 128,
    "prefix_len": 20,
}

train_configs = {
    "lr": 1e-4,
    "weight_decay": 0,
    "batch_size": 128,
    "num_epochs": 5
}

In [5]:
encoder = get_encoder("paraphrase-multilingual-MiniLM-L12-v2")
prenet = PreNet(**prenet_configs).to(device)
decoder, tokenizer = get_gpt2_decoder()

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


In [6]:
data = get_bookcorpus_inverter(
    encoder, tokenizer, max_target_len=64, embed_batch_size=256, sample=0.1, clean=True
)

Cleaning texts...
Tokenizing the texts...


Batches: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:09<00:00, 41.56it/s]


In [84]:
train_inverter(prenet, decoder, tokenizer, data, **train_configs)

Epoch 1 [Train]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [05:16<00:00,  4.94it/s]

*** Epoch 0 Complete.  Avg Loss = 2.259933 ***





In [113]:
torch.save(prenet.state_dict(), "saved_models/prenet/prenet_100k_good.pth")

In [114]:
prenet.load_state_dict(torch.load("saved_models/prenet/prenet_100k_good.pth", map_location=device))

<All keys matched successfully>

In [115]:
inverter = Inverter(prenet, decoder, tokenizer)

In [117]:
sample_text = "\"You were never the problem,\" he said, \"and you know it\"."

vec = encoder.encode(sample_text, convert_to_tensor=True)
inverter.invert(
    vec, max_len=50, temperature=0.4, repetition_penalty=1.2
)

'You know, you never knew the problem," he said.'

---

In [4]:
encoder = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2", device="cuda")
dataset = get_bookcorpus_scm(encoder, 32)

Batches: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [10:46<00:00, 77.31it/s]


Train the model.

In [22]:
from datasets import load_dataset

data_x = load_dataset("francescoortame/bookcorpus-rand-1M", split="train")

In [24]:
texts = data_x["text"]

In [25]:
from tqdm import tqdm
from small_concept_model.utils import clean_text

clean_texts = []

for t in tqdm(texts, total=len(texts)):
    clean_texts.append(clean_text(t))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000000/1000000 [00:05<00:00, 176557.61it/s]


In [30]:
clean_texts = [t + tokenizer.eos_token for t in clean_texts]

In [32]:
clean_texts[237920]

'"What shall we do to commemorate our first day of coupledom?"<|endoftext|>'

# SCM Training

In [2]:
encoder = get_encoder("paraphrase-multilingual-MiniLM-L12-v2")

dataset = get_bookcorpus_scm(
    encoder,
    embed_batch_size=128,
    clean=True
)

Batches: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12500/12500 [02:45<00:00, 75.42it/s]


In [126]:
model = SmallConceptModel(**scm_configs).to(device)

In [89]:
train_scm(model, dataset, **train_configs)

Epoch [1/1]  Batch [100/1563]  Loss: 0.053620
Epoch [1/1]  Batch [200/1563]  Loss: 0.049489
Epoch [1/1]  Batch [300/1563]  Loss: 0.047784
Epoch [1/1]  Batch [400/1563]  Loss: 0.046954
Epoch [1/1]  Batch [500/1563]  Loss: 0.047889
Epoch [1/1]  Batch [600/1563]  Loss: 0.046675
Epoch [1/1]  Batch [700/1563]  Loss: 0.047618
Epoch [1/1]  Batch [800/1563]  Loss: 0.046970
Epoch [1/1]  Batch [900/1563]  Loss: 0.046868
Epoch [1/1]  Batch [1000/1563]  Loss: 0.047154
Epoch [1/1]  Batch [1100/1563]  Loss: 0.046410
Epoch [1/1]  Batch [1200/1563]  Loss: 0.046313
Epoch [1/1]  Batch [1300/1563]  Loss: 0.046832
Epoch [1/1]  Batch [1400/1563]  Loss: 0.046077
Epoch [1/1]  Batch [1500/1563]  Loss: 0.046274
*** Epoch 1 Complete.  Avg Loss = 0.048678 ***


In [90]:
prenet = PreNet(**prenet_configs).to(device)
prenet.load_state_dict(torch.load("saved_models/prenet/prenet_100k_good.pth", map_location=device))

decoder, tokenizer = get_gpt2_decoder()
inverter = Inverter(prenet, decoder, tokenizer)

In [103]:
pipe = Pipeline(encoder, model, inverter)

In [129]:
texts = [
    'Lexi stretched her arms.',
    'She heard the door open, and soft voices echoed down the hall toward her.',
]

pipe.generate(
    texts,
    n_future_steps = 5,
    sigma_noise = 0.0,
    temperature = 0.0,
    max_len = 30
)

['elle stretched, tugging her arms.',
 ' she heard the quiet of the room, the soft footsteps of the door.',
 'and then she glanced at the floor, her eyes wide with confusion.',
 'and then she glanced at the floor, her eyes wide with confusion.',
 'and then she glanced at the floor, and then at the floor, and then at the floor, and then at the floor, and then at the',
 'and then she glanced at the floor, and then at the man she had just met.',
 'and then she glanced at her husband, who was staring at her with a confused expression.']

In [57]:
inverter.invert(x)

'mike, who had been standing in the doorway, glanced at her, then turned to the other side of the room, where she was standing.'

## Pipeline

In [10]:
model = model.to("cuda")
inverter = build_inverter("paraphrase_multilingual")
pipe = Pipeline(encoder, model, inverter)

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


In [11]:
texts = [
    'he asked her if she was hungry.',
    'she never heard that before.',
]

pipe.generate(
    texts,
    n_future_steps = 5,
    sigma_noise = 0.0,
    temperature = 0.0,
    max_len = 30
)

[" she asked him if he was hungry . '' he replied . '' she asked him if he was hungry . '' she asked him if he was hungry .",
 " she never heard that before . '\n\nBut she never heard that before . '\n\nBut she never heard that before . '\n\nBut",
 " she said , and then she said , '' yes , she should be able to talk about it . '' . '' she said , and then she said",
 " she said , and then she gave him a small smile . '' she said , and then she ate the rest of the food . '' she said ,",
 " she was glad to have him as her companion , but she knew that she would have to eat some of the food she had been given . '' she",
 " she was glad to have been able to eat the food , but she was not sure how to describe it . '' she said , her voice soft and",
 ' she was glad to have been able to eat the food , but she was also pleased to know that she was not the only one who had been pleased']

---

In [3]:
import math
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Optional


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding."""

    def __init__(self, d_model: int, max_len: Optional[int] = 128):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        denominator = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * denominator)
        pe[:, 1::2] = torch.cos(position * denominator)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]


class InputProj(nn.Module):
    def __init__(self, d_embed, d_model, scaler_mean, scaler_std):
        super().__init__()
        # scaler_mean, scaler_std: each is a [d_embed]-shaped tensor
        self.register_buffer("mean", scaler_mean)   # shape: [d_embed]
        self.register_buffer("std",  scaler_std)    # shape: [d_embed]
        self.linear = nn.Linear(d_embed, d_model)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, T, d_embed]
        # subtract/scale each dimension separately
        return (x - self.mean.unsqueeze(0).unsqueeze(0)) / self.std.unsqueeze(0).unsqueeze(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.normalize(x)            # [B, T, d_embed] → zero‐centered (per‐dim)
        return self.linear(x)            # → [B, T, d_model]



class OutputProj(nn.Module):
    def __init__(self, d_model, d_embed, scaler_mean, scaler_std):
        super().__init__()
        self.register_buffer("mean", scaler_mean)   # [d_embed]
        self.register_buffer("std",  scaler_std)    # [d_embed]
        self.linear = nn.Linear(d_model, d_embed)

    def denormalize(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, T, d_embed] in “normalized space”
        return x * self.std.unsqueeze(0).unsqueeze(0) + self.mean.unsqueeze(0).unsqueeze(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear(x)             # [B, T, d_embed]
        return self.denormalize(x)     # map back to “real” embedding distribution



class Transformer(nn.Module):
    """Transformer encoder with causal masking."""

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        n_heads: Optional[int] = 4,
        n_layers: Optional[int] = 3,
        dropout: Optional[float] = 0.1,
    ):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer=encoder_layer, num_layers=n_layers
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, seq_len, _ = x.size()
        bool_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
        bool_mask = bool_mask.to(x.device)
        return self.transformer(x, bool_mask)


class SmallConceptModel(nn.Module):
    """Autoregressive transformer-based concept model."""

    def __init__(
        self,
        d_model: int,
        d_embed: int,
        d_ff: int,
        n_heads: Optional[int] = 4,
        n_layers: Optional[int] = 3,
        dropout: Optional[float] = 0.1,
        max_seq_len: Optional[int] = 128,
        scaler_mean: Optional[float] = 0.0,
        scaler_std: Optional[float] = 1.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.input_projection = InputProj(d_embed, d_model, scaler_mean, scaler_std)
        self.pos_encoder = PositionalEncoding(d_model, max_seq_len)
        self.transformer = Transformer(d_model, d_ff, n_heads, n_layers, dropout)
        self.output_projection = OutputProj(d_model, d_embed, scaler_mean, scaler_std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.input_projection(x) * math.sqrt(self.d_model)
        x = self.pos_encoder(x)
        x = self.transformer(x)
        return self.output_projection(x)


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import Optional
from small_concept_model.data import InverterDataset, SCMDataset
from small_concept_model.inverter import PreNet
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from small_concept_model.model import SmallConceptModel
from tqdm import tqdm

def combined_mse_cosine_loss(
        preds: torch.Tensor,
        targets: torch.Tensor,
        lambda_cos: float = 0.5
    ) -> torch.Tensor:
    """
    preds:   (B, L, D_embed)  predicted embeddings
    targets: (B, L, D_embed)  ground-truth embeddings
    lambda_cos: weight on the (1 - cosine) term. 
                Total loss = lambda_cos * (1 - cos) + (1-lambda_cos) * MSE.

    Returns: mean loss over all B * L elements (a scalar).
    """
    # 1) Compute MSE term (per-coordinate)
    mse_per_coord = F.mse_loss(preds, targets, reduction="none")  # shape (B, L, D_embed)
    mse_per_vector = mse_per_coord.mean(dim=-1)                     # shape (B, L), avg over D_embed
    mse_term = mse_per_vector.mean()                                # scalar: avg over B * L

    # 2) Compute cosine term
    #    Flatten B*L so we can use F.cosine_similarity on shape ((B*L), D_embed)
    B, L, D = preds.shape
    preds_flat   = preds.view(B * L, D)    # shape (B*L, D_embed)
    targets_flat = targets.view(B * L, D)  # shape (B*L, D_embed)

    #    cosine_similarity returns shape (B*L,), values in [-1, +1]
    cos_sim = F.cosine_similarity(preds_flat, targets_flat, dim=-1, eps=1e-8)  # (B*L,)
    cos_dist = 1.0 - cos_sim                                                   # (B*L,)

    cosine_term = cos_dist.mean()  # scalar

    # 3) Combine
    loss = lambda_cos * cosine_term + (1.0 - lambda_cos) * mse_term
    return loss

def train_scm(
    model: SmallConceptModel,
    train_dataset: SCMDataset,
    lr: Optional[float] = 1e-3,
    weight_decay: Optional[float] = 1e-2,
    batch_size: Optional[int] = 32,
    num_epochs: Optional[int] = 1,
    schedule_length: int = 3,   # # of epochs to go from ε=0 → ε_max
    eps_max: float = 0.5,
):
    """Train the SCM for next-embedding prediction."""

    train_loader = DataLoader(train_dataset, batch_size=batch_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    mse_loss = torch.nn.MSELoss(reduction="none")

    for epoch in range(1, num_epochs + 1):
        epoch_loss = 0.0
        n_batches = 0

        for batch_idx, (input_seq, target_seq) in enumerate(train_loader):
            input_seq = input_seq.to(device)
            target_seq = target_seq.to(device)
            
            output = model(input_seq)
            loss = combined_mse_cosine_loss(output, target_seq, lambda_cos=1e-9)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1

            if (batch_idx + 1) % 100 == 0:
                print(
                    f"Epoch [{epoch}/{num_epochs}]  "
                    f"Batch [{batch_idx+1}/{len(train_loader)}]  "
                    f"Loss: {loss.item():.6f}"
                )

        avg_epoch_loss = epoch_loss / n_batches
        print(f"*** Epoch {epoch} Complete.  Avg Loss = {avg_epoch_loss:.6f} ***")


In [7]:
tensor_list = [dataset[i][0] for i in range(len(dataset))]
all_embeddings = torch.stack(tensor_list)

flat = all_embeddings.view(-1, 384)        # → [1_600_000, d_embed]
mean_vec = flat.mean(dim=0)                    # [d_embed]
std_vec  = flat.std(dim=0, unbiased=False)

scm_configs = {
    "d_model": 384,
    "d_embed": 384,
    "d_ff": 4 * 384,
    "n_heads": 4,
    "n_layers": 3,
    "dropout": 0.1,
    "max_seq_len": 16,
    "scaler_mean": mean_vec,
    "scaler_std": std_vec
}

model = SmallConceptModel(**scm_configs).to(device)

  self.register_buffer("mean", torch.tensor(scaler_mean))
  self.register_buffer("std", torch.tensor(scaler_std))
  self.register_buffer("mean", torch.tensor(scaler_mean))
  self.register_buffer("std", torch.tensor(scaler_std))


In [8]:
train_scm(
    model,
    dataset,
    lr=1e-3,
    weight_decay=0,
    batch_size=128,
    num_epochs=3,
    schedule_length=3,
    eps_max=0.6
)

Epoch [1/3]  Batch [100/782]  Loss: 0.046695
Epoch [1/3]  Batch [200/782]  Loss: 0.046569
Epoch [1/3]  Batch [300/782]  Loss: 0.046447
Epoch [1/3]  Batch [400/782]  Loss: 0.046278
Epoch [1/3]  Batch [500/782]  Loss: 0.046586
Epoch [1/3]  Batch [600/782]  Loss: 0.046114
Epoch [1/3]  Batch [700/782]  Loss: 0.045864
*** Epoch 1 Complete.  Avg Loss = 0.046646 ***
Epoch [2/3]  Batch [100/782]  Loss: 0.045720
Epoch [2/3]  Batch [200/782]  Loss: 0.046038
Epoch [2/3]  Batch [300/782]  Loss: 0.046060
Epoch [2/3]  Batch [400/782]  Loss: 0.046071
Epoch [2/3]  Batch [500/782]  Loss: 0.046395
Epoch [2/3]  Batch [600/782]  Loss: 0.045955
Epoch [2/3]  Batch [700/782]  Loss: 0.045638
*** Epoch 2 Complete.  Avg Loss = 0.046098 ***
Epoch [3/3]  Batch [100/782]  Loss: 0.045557
Epoch [3/3]  Batch [200/782]  Loss: 0.045925
Epoch [3/3]  Batch [300/782]  Loss: 0.045977
Epoch [3/3]  Batch [400/782]  Loss: 0.045961
Epoch [3/3]  Batch [500/782]  Loss: 0.046294
Epoch [3/3]  Batch [600/782]  Loss: 0.045829
Epoch 

In [9]:
texts = [
    'he asked her if she was hungry.',
    'she never heard that before.',
]

pipe.generate(
    texts,
    n_future_steps = 5,
    sigma_noise = 0.0,
    temperature = 0.0,
    max_len = 30
)

NameError: name 'pipe' is not defined

In [None]:
tensor_list = [dataset[i][0] for i in range(len(dataset))]

In [188]:
c = torch.stack(tensor_list)
d = c.view(-1, 384)

global_mean = d.mean(dim=0)

In [189]:
# shift targets:
X = c[:, :14, :]   # your input windows (ground truth)
Y = c[:, 1:, :]    # “true” next embeddings

n_total = float(Y.numel())
baseline_mse = ((Y - global_mean.unsqueeze(0).unsqueeze(0))**2).sum() / n_total
print("Mean-predictor MSE:", baseline_mse.item())


Mean-predictor MSE: 0.04862626641988754


In [177]:
c.view(-1, 384).size()

torch.Size([30, 384])