## Embedding Space Exploration

In [43]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from small_concept_model.inverter import get_encoder
from small_concept_model.data import clean_text, SCMDataset
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the training split, then flatten the per-row "slice" lists and clean
# every text in a single pass.
data = load_dataset("francescoortame/bookcorpus-sorted-100k16x", split="train")
flat_texts = [clean_text(text) for group in data["slice"] for text in group]

In [None]:
# Build the sentence encoder and embed every cleaned text; the result is
# requested as a tensor (convert_to_tensor=True) with a progress bar.
encoder = get_encoder("paraphrase-multilingual-MiniLM-L12-v2")
embeddings = encoder.encode(flat_texts, convert_to_tensor=True, show_progress_bar=True)

In [None]:
# Per-dimension (sample) standard deviation over all embeddings, ranked from
# the most to the least variable dimension.
stds = embeddings.std(dim=0, unbiased=True)
sorted_dims = torch.argsort(stds, descending=True)

print("Top 10 dims by standard deviation:")
for rank, dim_idx in enumerate(sorted_dims[:10].tolist(), start=1):
    print(f"  rank {rank:>2}: dim {dim_idx:>3} (std = {stds[dim_idx].item():.4f})")

In [None]:
# Group the flat embeddings into sequences of 16 consecutive sentences.
# The number of sequences and the embedding width are derived from the tensor
# itself instead of being hard-coded, so a different corpus size or encoder
# does not require editing this cell.
# NOTE(review): 16 is assumed to be the per-row slice length of the dataset — confirm.
reshaped_embeddings = embeddings.contiguous().view(-1, 16, embeddings.size(-1))
dataset = SCMDataset(reshaped_embeddings)

## SCM Definition

In [217]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for sequence-first inputs.

    A table of shape [max_len, 1, d_model] is precomputed and registered as
    the buffer ``pe``; even columns hold sines and odd columns hold cosines.

    Args:
        d_model: Width of the vectors the encoding is added to (may be odd).
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )
        angles = pos * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so the angle table has to be trimmed to fit.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Insert a broadcastable batch axis: [max_len, 1, d_model].
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the encoding of the first T positions to ``x`` of shape [T, B, D]."""
        T, _, _ = x.size()
        return x + self.pe[:T]

def generate_causal_mask(sz: int, device: torch.device) -> torch.Tensor:
    """Return an additive causal attention mask of shape [sz, sz].

    Entries strictly above the diagonal are ``-inf`` (future positions are
    blocked); the diagonal and everything below it are ``0``.

    The mask is allocated directly on ``device`` instead of being built on the
    CPU and copied afterwards.
    """
    blocked = torch.full((sz, sz), float('-inf'), device=device)
    return torch.triu(blocked, diagonal=1)


class SmallConceptModel(nn.Module):
    """Causally masked Transformer encoder over sequences of embeddings.

    Inputs of shape [B, T, embed_dim] are projected to ``d_model``, given
    positional encodings, passed through the encoder stack with a causal
    mask, and projected back to ``embed_dim``. The output has the same
    [B, T, embed_dim] layout as the input.
    """

    def __init__(
        self,
        d_model: int = 384,
        embed_dim: int = 384,
        nhead: int = 8,
        num_layers: int = 6,
        dim_feedforward: int = 384 * 4,
        dropout: float = 0.1,
        max_seq_len: int = 64
    ):
        super().__init__()
        self.d_model = d_model

        # Embedding space -> model space.
        self.input_proj = nn.Linear(embed_dim, d_model, bias=True)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_seq_len)

        # The encoder stack works sequence-first (batch_first=False).
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation='gelu',
                batch_first=False
            ),
            num_layers=num_layers
        )

        # Model space -> embedding space.
        self.output_proj = nn.Linear(d_model, embed_dim, bias=True)
        # Not read by forward(); kept so the module's buffers stay unchanged.
        self.register_buffer('dummy_mask', torch.zeros(1))

    def forward(self, input_seq: torch.Tensor) -> torch.Tensor:
        """Map ``input_seq`` [B, T, embed_dim] to predictions of the same shape."""
        _, seq_len, _ = input_seq.shape

        # [B, T, D] -> [T, B, D] for the sequence-first encoder.
        hidden = self.input_proj(input_seq.permute(1, 0, 2))
        hidden = self.pos_encoder(hidden)

        mask = generate_causal_mask(seq_len, device=input_seq.device)
        encoded = self.transformer_encoder(hidden, mask=mask)

        # Back to batch-first layout.
        return self.output_proj(encoded).permute(1, 0, 2)

#### Loss Functions

In [223]:
# Per-dimension (population) standard deviation, min-max scaled to [0, 1] to
# serve as loss weights; the floor `epsilon` keeps every weight positive.
stds = embeddings.std(dim=0, unbiased=False)
sigma_min = stds.min().item()
sigma_max = stds.max().item()
sigma_range = sigma_max - sigma_min

if sigma_range < 1e-12:
    # All dimensions vary equally: use one constant weight everywhere.
    weights = torch.full_like(stds, 1e-3)
else:
    epsilon = 1e-6
    weights = ((stds - sigma_min) / sigma_range).clamp(min=epsilon)


class WeightedMSELoss(nn.Module):
    """Mean squared error with a fixed weight per embedding dimension.

    The weight vector is stored as the buffer ``w`` with shape [1, D] so it
    broadcasts over the leading dimensions of the inputs.
    """

    def __init__(self, weight_vector: torch.Tensor):
        super().__init__()
        row_weights = weight_vector.view(1, -1)
        self.register_buffer('w', row_weights)

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the mean of the weighted squared errors.

        Raises:
            ValueError: If ``predictions`` and ``targets`` differ in shape.
        """
        if predictions.shape == targets.shape:
            squared_error = (predictions - targets).pow(2)
            return (squared_error * self.w).mean()
        raise ValueError(f"predictions and targets must have same shape. "
                         f"Got {predictions.shape} vs {targets.shape}.")
    

class MSELossWithAvgPenalty(nn.Module):
    """MSE plus a penalty on how closely predictions align with a reference vector.

    The penalty is the mean dot product between the L2-normalized predictions
    and ``avg_vector``. It stays a tensor inside the autograd graph, so it
    actually pushes predictions away from the reference direction. The
    previous implementation called ``.item()`` on it, which turned the penalty
    into a constant that contributed no gradient.

    Args:
        avg_vector: Reference vector of shape [D].
            NOTE(review): it is not normalized here, so the penalty is a true
            cosine similarity only if the caller passes a unit-norm vector.
    """

    def __init__(self, avg_vector: torch.Tensor):
        super().__init__()
        self.register_buffer('avg', avg_vector)

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return ``mean squared error + mean alignment with the reference``.

        Raises:
            ValueError: If ``predictions`` and ``targets`` differ in shape.
        """
        if predictions.shape != targets.shape:
            raise ValueError(f"predictions and targets must have same shape. "
                             f"Got {predictions.shape} vs {targets.shape}.")

        # One unit-length row per predicted vector: [N, D].
        pred_dirs = F.normalize(predictions.reshape(-1, predictions.size(-1)), p=2, dim=1)
        # Broadcasting against the [D] buffer replaces the explicit expand().
        avg_alignment = (pred_dirs * self.avg).sum(dim=1).mean()

        mse = ((predictions - targets) ** 2).mean()
        return mse + avg_alignment

class AntiAverageLoss(torch.nn.Module):
    """
    Cosine-distance loss with an extra penalty for predicting the average.

    The penalty grows as predictions become more similar to the stored
    average embedding, so a model that collapses onto the average is
    explicitly punished. Intended as a diagnostic: if the model still
    collapses under this loss, something else is wrong.
    """
    def __init__(self, avg_embedding, penalty_weight=1.0):
        super().__init__()
        self.register_buffer('avg_embedding', avg_embedding)
        self.penalty_weight = penalty_weight

    def forward(self, predictions, targets):
        """Return ``(total_loss, target_loss)``.

        ``predictions`` and ``targets`` are [B, T, D] or already flat [N, D].
        """
        is_sequence_batch = predictions.dim() == 3
        pred_flat = predictions.reshape(-1, predictions.size(-1)) if is_sequence_batch else predictions
        tgt_flat = targets.reshape(-1, targets.size(-1)) if is_sequence_batch else targets

        # Cosine similarity of each prediction with the average direction.
        unit_preds = F.normalize(pred_flat, p=2, dim=1)
        unit_avg = F.normalize(self.avg_embedding, p=2, dim=0)
        sim_to_avg = (unit_preds @ unit_avg.unsqueeze(1)).squeeze()

        # Sigmoid keeps the penalty bounded; the factor 5 steepens it.
        avg_penalty = torch.sigmoid(sim_to_avg * 5)

        # Cosine distance between each prediction and its target.
        main_loss = 1 - F.cosine_similarity(pred_flat, tgt_flat, dim=1)
        target_term = main_loss.mean()

        total_loss = target_term + self.penalty_weight * avg_penalty.mean()
        return total_loss, main_loss.mean()
        

class CosineSimilarityLoss(nn.Module):
    """Mean cosine distance (1 - cosine similarity) along the last dimension."""

    def __init__(self, eps: float = 1e-8):
        super().__init__()
        # Forwarded to F.cosine_similarity as its `eps` argument.
        self.eps = eps

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the mean cosine distance between matching vectors.

        Raises:
            ValueError: If ``predictions`` and ``targets`` differ in shape.
        """
        if predictions.shape == targets.shape:
            similarity = F.cosine_similarity(predictions, targets, dim=-1, eps=self.eps)
            return (1.0 - similarity).mean()

        raise ValueError(
            f"predictions and targets must have same shape. "
            f"Got {predictions.shape} vs {targets.shape}."
        )
    
def combined_loss(pred, target, alpha=0.7):
    """Blend MSE and cosine distance: ``alpha * mse + (1 - alpha) * cosine``."""
    cosine_distance = 1 - F.cosine_similarity(pred, target, dim=-1).mean()
    mse_term = F.mse_loss(pred, target)
    return alpha * mse_term + (1 - alpha) * cosine_distance

I suspect the model is always predicting the mean vector. Let's check whether that's the case.

In [224]:
# Mean embedding over the whole corpus, rescaled to unit L2 norm and moved to
# the training device.
avg_embedding = embeddings.mean(dim=0)
avg_embedding = avg_embedding / avg_embedding.norm(p=2, dim=0)
avg_embedding = avg_embedding.to(device)

In [276]:
# Shuffled loader; drop_last=True discards a final incomplete batch.
batch_size = 256
dataloader = DataLoader(dataset, batch_size, shuffle=True, drop_last=True)

# Model for this experiment: 512-wide, 3 layers, dropout disabled.
model = SmallConceptModel(
    d_model=512,
    embed_dim=384,
    nhead=4,
    num_layers=3,
    dim_feedforward=512*4,
    dropout=0.0,
    max_seq_len=dataset.seq_len  # so positional encoding covers full length
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Alternative losses tried in earlier runs; only one loss_fn is active at a time.
# NOTE: the training cell unpacks two return values from loss_fn, which only
# AntiAverageLoss provides among the options listed here.
#loss_fn = WeightedMSELoss(weights.to(device))
#loss_fn = nn.MSELoss()
#loss_fn = CosineSimilarityLoss()
#loss_fn = MSELossWithAvgPenalty(avg_embedding)
#loss_fn = MSELossWithAvgPenalty(avg_embedding)
loss_fn = AntiAverageLoss(avg_embedding, 0.35)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

In [None]:
num_epochs = 2
noise_scale = 0.01  # std of the Gaussian noise added to every input batch
log_every = 20      # print diagnostics every `log_every` batches

# Train mode is set once; nothing inside the loop switches the model to eval.
model.train()
for epoch in range(num_epochs):
    for idx, (input_batch, target_batch) in enumerate(dataloader):
        # input_batch, target_batch: [B, T, D]
        input_batch = input_batch.to(device)
        target_batch = target_batch.to(device)

        # Perturb the inputs slightly before the forward pass.
        input_batch = input_batch + torch.randn_like(input_batch) * noise_scale

        optimizer.zero_grad()
        preds = model(input_batch)                          # [B, T, D]
        loss, main_loss = loss_fn(preds, target_batch)      # loss_fn returns (total, main)
        loss.backward()
        optimizer.step()

        # The collapse diagnostics are only needed for batches that are
        # printed, so they are skipped on all other batches.
        if (idx + 1) % log_every == 0:
            with torch.no_grad():
                D = preds.size(-1)

                # Mean cosine similarity between predictions and the average
                # embedding (avg_embedding is already unit-norm).
                pred_flat_norm = F.normalize(preds.reshape(-1, D), p=2, dim=1)
                batch_avg_cos_sim = (pred_flat_norm * avg_embedding).sum(dim=1).mean().item()

                # The same statistic for the targets, as a reference value.
                tgt_flat_norm = F.normalize(target_batch.reshape(-1, D), p=2, dim=1)
                tgt_batch_avg_cos_sim = (tgt_flat_norm * avg_embedding).sum(dim=1).mean().item()

            print(f"(Epoch {epoch+1}) {idx+1} | Loss = {loss.item():.6f} | Main Loss = {main_loss.item():.6f}| Pred Sim = {batch_avg_cos_sim:.6f} | True Sim = {tgt_batch_avg_cos_sim:.6f}")
    # Reports the loss of the last batch of the epoch, not an epoch average.
    print(f"*** Epoch {epoch+1}/{num_epochs} — Loss = {loss.item():.6f} ***")

(Epoch 1) 20 | Loss = 0.953035 | Main Loss = 0.605389| Pred Sim = 0.998983 | True Sim = 0.394653
(Epoch 1) 40 | Loss = 0.961489 | Main Loss = 0.613842| Pred Sim = 0.999051 | True Sim = 0.386534
(Epoch 1) 60 | Loss = 0.955100 | Main Loss = 0.607454| Pred Sim = 0.999026 | True Sim = 0.392910
(Epoch 1) 80 | Loss = 0.959364 | Main Loss = 0.611714| Pred Sim = 0.999357 | True Sim = 0.388550


KeyboardInterrupt: 

---

In [None]:
import torch
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer
from small_concept_model.model import SmallConceptModel
from small_concept_model.pipeline import Pipeline
from small_concept_model.inverter import PreNet, Inverter, get_encoder, get_gpt2_decoder
from small_concept_model.train import train_scm, train_inverter
from small_concept_model.data import get_bookcorpus_scm, get_bookcorpus_inverter
from small_concept_model.auto import build_scm, build_inverter

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Keyword arguments for SmallConceptModel (unpacked with ** when the model is built).
scm_configs = {
    "d_model": 384,
    "d_embed": 384,
    "d_ff": 4 * 384,
    "n_heads": 4,
    "n_layers": 3,
    "dropout": 0.1,
    "max_seq_len": 16
}

# Keyword arguments for PreNet.
prenet_configs = {
    "input_dim": 384,
    "output_dim": 768,
    "rank": 128,
    "prefix_len": 20,
}

# Shared optimizer/loop settings, passed to both train_inverter and train_scm.
train_configs = {
    "lr": 1e-4,
    "weight_decay": 0,
    "batch_size": 128,
    "num_epochs": 5
}

In [None]:
# Components of the inverter: sentence encoder, trainable PreNet (on `device`),
# and the GPT-2 decoder with its tokenizer.
encoder = get_encoder("paraphrase-multilingual-MiniLM-L12-v2")
prenet = PreNet(**prenet_configs).to(device)
decoder, tokenizer = get_gpt2_decoder()

In [None]:
# Build the inverter training data from the encoder and tokenizer.
# NOTE(review): `sample=0.1` presumably selects a 10% subset — confirm against
# get_bookcorpus_inverter.
data = get_bookcorpus_inverter(
    encoder, tokenizer, max_target_len=64, embed_batch_size=256, sample=0.1, clean=True
)

In [None]:
# Train the inverter with the shared training settings.
train_inverter(prenet, decoder, tokenizer, data, **train_configs)

In [None]:
# Persist only the PreNet weights (state_dict), not the whole module object.
torch.save(prenet.state_dict(), "saved_models/prenet/prenet_100k_good.pth")

In [None]:
# Restore the saved PreNet weights, mapping tensors onto the current device.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
prenet.load_state_dict(torch.load("saved_models/prenet/prenet_100k_good.pth", map_location=device))

In [None]:
# Bundle PreNet, decoder and tokenizer into a single Inverter object.
inverter = Inverter(prenet, decoder, tokenizer)

In [None]:
# Round-trip check: encode one sentence and ask the inverter to reconstruct it.
sample_text = "\"You were never the problem,\" he said, \"and you know it\"."

vec = encoder.encode(sample_text, convert_to_tensor=True)
inverter.invert(
    vec, max_len=50, temperature=0.4, repetition_penalty=1.2
)

---

In [None]:
# Load the encoder on the notebook's selected device rather than a hard-coded
# "cuda", so the cell also runs on CPU-only machines.
encoder = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2", device=str(device))
dataset = get_bookcorpus_scm(encoder, 32)

Train the model.

In [None]:
from datasets import load_dataset

# Training split of the randomly sampled BookCorpus subset.
data_x = load_dataset("francescoortame/bookcorpus-rand-1M", split="train")

In [None]:
# Pull the raw "text" column out of the dataset.
texts = data_x["text"]

In [None]:
from tqdm import tqdm
from small_concept_model.utils import clean_text

# Clean every text, with a progress bar; a comprehension replaces the manual
# append loop.
clean_texts = [clean_text(t) for t in tqdm(texts, total=len(texts))]

In [None]:
# Append the tokenizer's end-of-sequence token to every cleaned text.
clean_texts = [t + tokenizer.eos_token for t in clean_texts]

In [None]:
# Spot-check a single cleaned example.
clean_texts[237920]

# SCM Training

In [None]:
# Rebuild the encoder and create the SCM training dataset from it.
encoder = get_encoder("paraphrase-multilingual-MiniLM-L12-v2")

dataset = get_bookcorpus_scm(
    encoder,
    embed_batch_size=128,
    clean=True
)

In [None]:
# Instantiate the SCM from scm_configs and move it to the training device.
model = SmallConceptModel(**scm_configs).to(device)

In [None]:
# Train the SCM with the shared training settings.
train_scm(model, dataset, **train_configs)

In [None]:
# Recreate the PreNet and restore its saved weights.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
prenet = PreNet(**prenet_configs).to(device)
prenet.load_state_dict(torch.load("saved_models/prenet/prenet_100k_good.pth", map_location=device))

# Rebuild the inverter around the restored PreNet.
decoder, tokenizer = get_gpt2_decoder()
inverter = Inverter(prenet, decoder, tokenizer)

In [None]:
# End-to-end pipeline built from the encoder, the SCM and the inverter.
pipe = Pipeline(encoder, model, inverter)

In [None]:
# Seed sentences for a generation test.
texts = [
    'Lexi stretched her arms.',
    'She heard the door open, and soft voices echoed down the hall toward her.',
]

# Generate with no added noise and temperature 0.
# NOTE(review): presumably n_future_steps is the number of sentences to
# predict and max_len caps each decoded sentence — confirm against Pipeline.generate.
pipe.generate(
    texts,
    n_future_steps = 5,
    sigma_noise = 0.0,
    temperature = 0.0,
    max_len = 30
)

In [None]:
# NOTE(review): `x` is not defined anywhere in the visible notebook — confirm
# which embedding was meant to be inverted before running this cell.
inverter.invert(x)

## Pipeline

In [None]:
# Move the model to the notebook's selected device rather than a hard-coded
# "cuda", so the cell also runs on CPU-only machines.
model = model.to(device)
inverter = build_inverter("paraphrase_multilingual")
pipe = Pipeline(encoder, model, inverter)

In [None]:
# Seed sentences for a generation test with the prebuilt inverter.
texts = [
    'he asked her if she was hungry.',
    'she never heard that before.',
]

# Generate with no added noise and temperature 0.
pipe.generate(
    texts,
    n_future_steps = 5,
    sigma_noise = 0.0,
    temperature = 0.0,
    max_len = 30
)

---

In [None]:
import math
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Optional


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding for batch-first inputs.

    A table of shape [1, max_len, d_model] is precomputed and registered as
    the buffer ``pe``; even columns hold sines and odd columns hold cosines.

    Args:
        d_model: Width of the vectors the encoding is added to (may be odd).
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model: int, max_len: int = 128):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        denominator = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / d_model)
        )
        angles = position * denominator  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so the angle table has to be trimmed to fit.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading axis of size 1 broadcasts over the batch.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the encoding of the first ``x.size(1)`` positions to ``x``."""
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]


class InputProj(nn.Module):
    """Standardize embeddings per dimension, then project them to ``d_model``.

    Args:
        d_embed: Width of the incoming embeddings.
        d_model: Width of the projected output.
        scaler_mean: Mean to subtract; a [d_embed] tensor or a scalar.
        scaler_std: Scale to divide by; a [d_embed] tensor or a scalar.

    Scalars are accepted because ``register_buffer`` only takes tensors:
    passing a plain Python float used to raise a ``TypeError``.
    """

    def __init__(self, d_embed, d_model, scaler_mean, scaler_std):
        super().__init__()
        # as_tensor leaves existing tensors untouched and wraps Python numbers.
        self.register_buffer("mean", torch.as_tensor(scaler_mean))
        self.register_buffer("std", torch.as_tensor(scaler_std))
        self.linear = nn.Linear(d_embed, d_model)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(x - mean) / std``, broadcast over the leading dimensions."""
        # [d_embed] (or scalar) buffers broadcast against [..., d_embed] directly.
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Standardize ``x`` and apply the linear projection."""
        return self.linear(self.normalize(x))



class OutputProj(nn.Module):
    """Project model states to ``d_embed`` and undo the input standardization.

    Args:
        d_model: Width of the incoming model states.
        d_embed: Width of the produced embeddings.
        scaler_mean: Mean to add back; a [d_embed] tensor or a scalar.
        scaler_std: Scale to multiply by; a [d_embed] tensor or a scalar.

    Scalars are accepted because ``register_buffer`` only takes tensors:
    passing a plain Python float used to raise a ``TypeError``.
    """

    def __init__(self, d_model, d_embed, scaler_mean, scaler_std):
        super().__init__()
        # as_tensor leaves existing tensors untouched and wraps Python numbers.
        self.register_buffer("mean", torch.as_tensor(scaler_mean))
        self.register_buffer("std", torch.as_tensor(scaler_std))
        self.linear = nn.Linear(d_model, d_embed)

    def denormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x * std + mean``, broadcast over the leading dimensions."""
        return x * self.std + self.mean

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear projection, then map back to the original scale."""
        return self.denormalize(self.linear(x))



class Transformer(nn.Module):
    """Batch-first Transformer encoder stack applied with a causal mask.

    Args:
        d_model: Width of the token vectors.
        d_ff: Hidden width of each feed-forward sublayer.
        n_heads: Number of attention heads.
        n_layers: Number of stacked encoder layers.
        dropout: Dropout probability inside each layer.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        n_heads: int = 4,
        n_layers: int = 3,
        dropout: float = 0.1,
    ):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer=encoder_layer, num_layers=n_layers
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` of shape [B, T, d_model]; position t only sees positions <= t."""
        seq_len = x.size(1)
        # True above the diagonal marks the (future) positions to hide. The
        # mask is allocated on x's device, avoiding a CPU-to-device copy on
        # every forward pass.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        return self.transformer(x, mask=causal_mask)


class SmallConceptModel(nn.Module):
    """Autoregressive transformer-based concept model.

    Pipeline: standardize and project the input embeddings, scale by
    sqrt(d_model), add positional encodings, run the causal transformer, then
    project back and de-standardize.

    Args:
        d_model: Internal model width.
        d_embed: Width of the input and output embeddings.
        d_ff: Feed-forward width inside the transformer layers.
        n_heads: Number of attention heads.
        n_layers: Number of transformer layers.
        dropout: Dropout probability.
        max_seq_len: Number of positions covered by the positional encoding.
        scaler_mean: Standardization mean; a scalar or a [d_embed] tensor.
        scaler_std: Standardization scale; a scalar or a [d_embed] tensor.
    """

    def __init__(
        self,
        d_model: int,
        d_embed: int,
        d_ff: int,
        n_heads: Optional[int] = 4,
        n_layers: Optional[int] = 3,
        dropout: Optional[float] = 0.1,
        max_seq_len: Optional[int] = 128,
        scaler_mean: Optional[float] = 0.0,
        scaler_std: Optional[float] = 1.0,
    ):
        super().__init__()
        self.d_model = d_model

        # The projections store the scaler values as buffers, and buffers must
        # be tensors. The float defaults are therefore wrapped here; tensors
        # supplied by the caller pass through unchanged.
        scaler_mean = torch.as_tensor(scaler_mean)
        scaler_std = torch.as_tensor(scaler_std)

        self.input_projection = InputProj(d_embed, d_model, scaler_mean, scaler_std)
        self.pos_encoder = PositionalEncoding(d_model, max_seq_len)
        self.transformer = Transformer(d_model, d_ff, n_heads, n_layers, dropout)
        self.output_projection = OutputProj(d_model, d_embed, scaler_mean, scaler_std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the full model on ``x`` and return the output projection's result."""
        x = self.input_projection(x) * math.sqrt(self.d_model)
        x = self.pos_encoder(x)
        x = self.transformer(x)
        return self.output_projection(x)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import Optional
from small_concept_model.data import InverterDataset, SCMDataset
from small_concept_model.inverter import PreNet
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from small_concept_model.model import SmallConceptModel
from tqdm import tqdm

def combined_mse_cosine_loss(
        preds: torch.Tensor,
        targets: torch.Tensor,
        lambda_cos: float = 0.5
    ) -> torch.Tensor:
    """Blend cosine distance and MSE between predicted and target embeddings.

    Total loss = lambda_cos * (1 - cos) + (1 - lambda_cos) * MSE, where both
    terms are averaged over every embedding vector.

    Both terms are computed along the last dimension without flattening, so
    inputs of any rank [..., D_embed] are accepted, including non-contiguous
    tensors (which the earlier ``.view``-based flattening rejected).

    Args:
        preds: Predicted embeddings, shape [..., D_embed].
        targets: Ground-truth embeddings, same shape as ``preds``.
        lambda_cos: Weight on the cosine-distance term.

    Returns:
        A scalar tensor.

    Raises:
        ValueError: If ``preds`` and ``targets`` differ in shape.
    """
    if preds.shape != targets.shape:
        raise ValueError(
            f"preds and targets must have same shape. "
            f"Got {preds.shape} vs {targets.shape}."
        )

    # Every vector has the same width, so the mean over all coordinates equals
    # the mean of the per-vector means.
    mse_term = F.mse_loss(preds, targets)

    # Cosine distance per vector, averaged over all vectors.
    cos_sim = F.cosine_similarity(preds, targets, dim=-1, eps=1e-8)
    cosine_term = (1.0 - cos_sim).mean()

    return lambda_cos * cosine_term + (1.0 - lambda_cos) * mse_term

def train_scm(
    model: SmallConceptModel,
    train_dataset: SCMDataset,
    lr: Optional[float] = 1e-3,
    weight_decay: Optional[float] = 1e-2,
    batch_size: Optional[int] = 32,
    num_epochs: Optional[int] = 1,
    schedule_length: int = 3,   # # of epochs to go from ε=0 → ε_max
    eps_max: float = 0.5,
    lambda_cos: float = 1e-9,
):
    """Train the SCM for next-embedding prediction.

    Args:
        model: Model to optimize; moved to CUDA when available, else CPU.
        train_dataset: Dataset yielding ``(input_seq, target_seq)`` pairs.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        batch_size: Mini-batch size.
        num_epochs: Number of passes over the dataset.
        schedule_length: Accepted for backward compatibility; not used here.
        eps_max: Accepted for backward compatibility; not used here.
        lambda_cos: Weight of the cosine term in the combined loss. The
            default keeps the previously hard-coded value, which makes the
            loss almost pure MSE.

    Raises:
        ValueError: If ``train_dataset`` is empty.
    """
    if len(train_dataset) == 0:
        # Fail with a clear message instead of a ZeroDivisionError when the
        # epoch average is computed.
        raise ValueError("train_dataset is empty; nothing to train on.")

    # Reshuffle every epoch so batches are not drawn in the dataset's stored order.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(1, num_epochs + 1):
        epoch_loss = 0.0
        n_batches = 0

        for batch_idx, (input_seq, target_seq) in enumerate(train_loader):
            input_seq = input_seq.to(device)
            target_seq = target_seq.to(device)

            output = model(input_seq)
            loss = combined_mse_cosine_loss(output, target_seq, lambda_cos=lambda_cos)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1

            if (batch_idx + 1) % 100 == 0:
                print(
                    f"Epoch [{epoch}/{num_epochs}]  "
                    f"Batch [{batch_idx+1}/{len(train_loader)}]  "
                    f"Loss: {loss.item():.6f}"
                )

        avg_epoch_loss = epoch_loss / n_batches
        print(f"*** Epoch {epoch} Complete.  Avg Loss = {avg_epoch_loss:.6f} ***")


In [None]:
# Stack the input sequence of every dataset item into one tensor.
tensor_list = [dataset[i][0] for i in range(len(dataset))]
all_embeddings = torch.stack(tensor_list)

# Per-dimension statistics over every embedding in the dataset. The embedding
# width is read from the data instead of being hard-coded.
d_embed = all_embeddings.size(-1)
flat = all_embeddings.reshape(-1, d_embed)     # [num_vectors, d_embed]
mean_vec = flat.mean(dim=0)                    # [d_embed]
# The model divides by the std, so a constant dimension (std == 0) would
# produce inf/NaN; clamp to a tiny positive floor.
std_vec = flat.std(dim=0, unbiased=False).clamp_min(1e-8)

scm_configs = {
    "d_model": 384,
    "d_embed": d_embed,
    "d_ff": 4 * 384,
    "n_heads": 4,
    "n_layers": 3,
    "dropout": 0.1,
    "max_seq_len": 16,
    "scaler_mean": mean_vec,
    "scaler_std": std_vec
}

model = SmallConceptModel(**scm_configs).to(device)

In [None]:
# Train the standardized model for three epochs without weight decay.
# NOTE(review): if the notebook-local train_scm is the one in scope, it accepts
# schedule_length and eps_max but never reads them — confirm they are needed.
train_scm(
    model,
    dataset,
    lr=1e-3,
    weight_decay=0,
    batch_size=128,
    num_epochs=3,
    schedule_length=3,
    eps_max=0.6
)

In [None]:
# Seed sentences for a generation test.
# NOTE(review): `pipe` was built in an earlier cell; if `model` has been
# reassigned since then, `pipe` still holds the older model — confirm.
texts = [
    'he asked her if she was hungry.',
    'she never heard that before.',
]

pipe.generate(
    texts,
    n_future_steps = 5,
    sigma_noise = 0.0,
    temperature = 0.0,
    max_len = 30
)

In [None]:
# Collect the first element (the input sequence) of every dataset item.
tensor_list = [dataset[i][0] for i in range(len(dataset))]

In [None]:
# Stack all sequences, flatten them to one row per embedding (width 384), and
# take the mean embedding over all rows.
c = torch.stack(tensor_list)
d = c.view(-1, 384)

global_mean = d.mean(dim=0)

In [None]:
# Baseline: MSE of a predictor that always outputs the global mean embedding.
# shift targets:
# NOTE(review): X is not used in this cell, and its length (first 14 steps)
# need not match Y (all steps from index 1) — confirm the intended window.
X = c[:, :14, :]   # your input windows (ground truth)
Y = c[:, 1:, :]    # “true” next embeddings

# Sum of squared errors divided by the number of scalar entries in Y.
n_total = float(Y.numel())
baseline_mse = ((Y - global_mean.unsqueeze(0).unsqueeze(0))**2).sum() / n_total
print("Mean-predictor MSE:", baseline_mse.item())


In [None]:
# Show the shape of the flattened embedding matrix.
c.view(-1, 384).size()