In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from small_concept_model.utils import clean_text
from small_concept_model.auto import build_inverter
from small_concept_model.baseline import LSTMConceptModel, LSTMDataset

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

In [None]:
# Sentence encoder used to turn every sentence into a fixed-size embedding.
encoder = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")

data = load_dataset("francescoortame/bookcorpus-sorted-100k16x", split="train")

# Fetch the "slice" column a single time and reuse it; every `data["slice"]`
# access goes back to the dataset, which may materialise the whole column.
slices = data["slice"]

n_seqs = len(data)
seq_len = len(slices[0])
d_embed = encoder.get_sentence_embedding_dimension()

# The reshape at the end of this cell is only meaningful when every sequence
# holds exactly `seq_len` sentences, so fail early with a readable message.
assert all(len(s) == seq_len for s in slices), (
    f"every entry of data['slice'] must contain {seq_len} sentences"
)

# Flatten the sequences and clean each sentence in one pass.
flat_texts = [clean_text(t) for s in slices for t in s]

embeddings = encoder.encode(
    flat_texts,
    batch_size=128,
    show_progress_bar=True,
    convert_to_tensor=True,
)

# Regroup the flat (n_seqs * seq_len, d_embed) matrix into one row of
# sentence embeddings per original sequence.
all_embeds = embeddings.contiguous().view(n_seqs, seq_len, d_embed)

In [None]:
# Size the LSTM input from the encoder's embedding dimension (`d_embed`)
# instead of a hard-coded literal, so the model stays consistent with the
# embeddings stored in `all_embeds` if the encoder is ever swapped.
model = LSTMConceptModel(
    input_dim=d_embed,
    hidden_dim=512,
    num_layers=2,
    dropout=0.1
).to(device)

# Resume from the previously saved weights.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load("saved_models/lstm/lstm_100k.pth", map_location=device))

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training samples are built by LSTMDataset from the precomputed embeddings.
scm_dataset = LSTMDataset(all_embeds)
dataloader = DataLoader(scm_dataset, batch_size=32, shuffle=True)

Train the LSTM.

In [None]:
NUM_EPOCHS = 5
COSINE_LOSS = False   # True: optimise (1 - cosine similarity) instead of MSE
PRINT_COSINE = False  # True: log the batch cosine similarity periodically
LOG_EVERY = 100       # number of batches between cosine-similarity logs

model.train()
for epoch in range(1, NUM_EPOCHS + 1):
    epoch_loss = 0.0
    n_samples = 0
    for i, (batch_seq, batch_target) in enumerate(dataloader):
        optimizer.zero_grad()

        batch_seq = batch_seq.to(device)
        batch_target = batch_target.to(device)

        pred = model(batch_seq)

        # Mean cosine similarity between predictions and targets; used as the
        # training objective when COSINE_LOSS is set, and for logging.
        cos_sim = F.cosine_similarity(pred, batch_target).mean()
        loss = 1 - cos_sim if COSINE_LOSS else criterion(pred, batch_target)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Weight the batch loss by its size so that a smaller final batch
        # does not skew the epoch average.
        n_in_batch = batch_seq.size(0)
        epoch_loss += loss.item() * n_in_batch
        n_samples += n_in_batch

        if PRINT_COSINE and (i + 1) % LOG_EVERY == 0:
            print(f"Cos Sim: {cos_sim.item()}")

    # Average over the samples actually seen this epoch; the number of rows in
    # `all_embeds` need not match the number of samples the dataset yields.
    avg_loss = epoch_loss / max(n_samples, 1)
    print(f"Epoch [{epoch}/{NUM_EPOCHS}] - Loss: {avg_loss:.6f}")

In [None]:
# Write the current weights back to the checkpoint file (the same path the
# model was loaded from, so the previous checkpoint is overwritten).
lstm_weights = model.state_dict()
torch.save(lstm_weights, "saved_models/lstm/lstm_100k.pth")

### Inference

In [None]:
# Build the inverter used below to turn embeddings back into printable output.
# NOTE(review): the key presumably selects the inverter matching the
# paraphrase-multilingual encoder — confirm against build_inverter.
inverter_key = "paraphrase_multilingual"
inverter = build_inverter(inverter_key)

In [None]:
GEN_STEPS: int = 20
# Generation stops once a predicted embedding is this similar (cosine) to the
# last embedding of the context.
COSINE_THRESHOLD: float = 0.9
# Positional arguments passed to inverter.invert after the embedding.
# NOTE(review): their meaning is defined by the inverter and not visible
# here — confirm against its signature before changing them.
INVERT_ARGS = (30, 0.0, 1.1)

texts = [
    "He was a researcher at the time, working on machine learning.",
    "However, he was not satisfied with his job.",
]

# The training cell leaves the model in train mode; switch to eval so dropout
# is disabled while generating.
model.eval()

# Encode the prompts once and reuse the same embeddings for both the
# round-trip printout and the model context.
encoded_texts = encoder.encode(texts, convert_to_tensor=True).to(device)

# Round-trip each prompt through the inverter.
for prompt_embed in encoded_texts:
    print(inverter.invert(prompt_embed, *INVERT_ARGS))

for _ in range(GEN_STEPS):
    print("---")
    # No gradients are needed for generation; without no_grad every step would
    # keep its autograd graph alive through the growing context tensor.
    with torch.no_grad():
        res = model(encoded_texts)
    cos_sim = F.cosine_similarity(res, encoded_texts[-1], dim=0).item()
    print(cos_sim)
    if cos_sim > COSINE_THRESHOLD:
        break
    # Append the prediction to the context and feed it back in the next step.
    encoded_texts = torch.cat([encoded_texts, res.unsqueeze(0)], dim=0)
    print(inverter.invert(res, *INVERT_ARGS))
