In [1]:
import pandas as pd
import os
import torch
from transformers import AutoModel, AutoTokenizer

# Read the benchmark problems into a DataFrame and display it.
# NOTE(review): the location is a machine-specific absolute Windows path, so
# this cell only runs where that exact file exists.
benchmark_csv = os.path.abspath(
    "C:\\Users\\mokrota\\Documents\\GitHub\\math_problem_recommender\\math_problem_recommender\\andreescu-andrica-problems-on-number-theory\\benchmark_v1.csv"
)
df = pd.read_csv(benchmark_csv)
df

Unnamed: 0,TopicMetadata,Problem&Solution
0,(Arithmetic Functions)->(Euler’s totient funct...,['Problem. Prove that there are infinitely man...
1,(Arithmetic Functions)->(Exponent of a prime a...,['Problem. Let $p$ be a prime. Find the expone...
2,(Arithmetic Functions)->(Multiplicative functi...,['Problem. 1) Prove that the convolution produ...
3,(Arithmetic Functions)->(Number of divisors),['Problem. For any $n\\geq2$ \n\n$$\n\\tau(n)...
4,(Arithmetic Functions)->(Sum of divisors),"['Problem. For any $n\\geq2$ , \n\n$$\n\\sigm..."
5,(Basic Principles in Number Theory)->(Inclusio...,"['Problem. Let $S=\\{1,2,3,\\ldots,280\\}$ . F..."
6,(Basic Principles in Number Theory)->(Infinite...,"['Problem. Find all triples $(x,y,z)$ of nonne..."
7,(Basic Principles in Number Theory)->(Mathemat...,"['Problem. Prove that, for any integer $n\\geq..."
8,(Basic Principles in Number Theory)->(Two simp...,['Problem. Let $n_{1}<n_{2}<\\cdots<n_{2000}<1...
9,(Basic Principles in Number Theory)->(Two simp...,['Problem. Show that there exist infinitely ma...


In [2]:
import ast
# Each "Problem&Solution" cell is read from the CSV as text; evaluate it as a
# Python literal so the column holds real Python objects instead of strings.
parsed_problems = df['Problem&Solution'].map(ast.literal_eval)
df['Problem&Solution'] = parsed_problems

In [3]:
# One row per problem: unpack every parsed list into separate rows, rename the
# columns to the generic "label" / "text" names used by the later cells, and
# renumber the rows 0..n-1 (explode repeats the original row labels).
df = (
    df.explode("Problem&Solution")
    .rename(columns={"TopicMetadata": "label", "Problem&Solution": "text"})
    .reset_index(drop=True)
)

In [4]:
import pandas as pd
import random

def make_ds(df, seed=None):
    """Build (anchor, positive, negative) text triplets without reusing any row.

    Rows are repeatedly sampled from ``df``: an anchor and a positive that share
    the same ``label``, plus a negative with a different ``label``. The three
    rows are then removed from the pool, so every row contributes to at most
    one triplet. Sampling stops as soon as no further triplet can be formed.

    Args:
        df: DataFrame with a ``label`` column and a ``text`` column. It is not
            modified; its index is ignored.
        seed: Optional seed for a private random generator, making the result
            reproducible. When ``None`` (default) the global ``random`` module
            state is used.

    Returns:
        A list of ``(anchor_text, positive_text, negative_text)`` tuples. The
        list is empty when ``df`` cannot supply a single triplet.
    """
    rng = random.Random(seed) if seed is not None else random

    # Re-number the rows so every row has a unique label. With duplicated index
    # labels `.at` would return a Series and `.drop` would remove several rows
    # at once. This also yields a new frame, so the caller's df is untouched.
    pool = df.reset_index(drop=True)
    triplets = []

    while len(pool) >= 3:
        label_counts = pool['label'].value_counts()
        # Categorical columns report unused categories with a count of 0;
        # only labels that still have rows may take part in sampling.
        label_counts = label_counts[label_counts > 0]

        # An anchor label needs at least 2 rows (anchor + positive) ...
        valid_labels = label_counts[label_counts > 1].index.tolist()
        # ... and at least one other label must exist to supply the negative.
        if not valid_labels or len(label_counts) < 2:
            break

        anchor_label = rng.choice(valid_labels)
        anchor_pool = pool[pool['label'] == anchor_label]
        anchor_idx = rng.choice(anchor_pool.index.tolist())

        # The guards above guarantee both pools below are non-empty, so no
        # retry branch is needed (a retry would loop forever: nothing changes
        # between iterations unless a triplet is removed).
        positive_idx = rng.choice(anchor_pool.drop(anchor_idx).index.tolist())
        negative_pool = pool[pool['label'] != anchor_label]
        negative_idx = rng.choice(negative_pool.index.tolist())

        triplets.append((
            pool.at[anchor_idx, 'text'],
            pool.at[positive_idx, 'text'],
            pool.at[negative_idx, 'text'],
        ))

        # Used rows leave the pool so they cannot appear in another triplet.
        pool = pool.drop([anchor_idx, positive_idx, negative_idx])

    return triplets

# Build the (anchor, positive, negative) text triplets consumed by the training loop.
# make_ds samples with the `random` module and no seed is set in the visible
# cells, so the triplets differ from run to run.
triplets = make_ds(df)

In [5]:
class TrainableBERTCLSMeanPooler(torch.nn.Module):
    """Embed texts by averaging the first-token embedding of every tokenizer chunk.

    Each text is tokenized with truncation and ``return_overflowing_tokens``,
    so a long text may become several chunks. All chunks are run through the
    wrapped model, the hidden state at position 0 of each chunk is taken, and
    the chunk vectors belonging to the same text are averaged. Tokenization
    happens inside ``forward``, so the module is called with raw strings.
    """

    def __init__(self, model_name='allenai/longformer-base-4096', stride=10, max_length=1024):
        """Load the tokenizer and model and store the tokenizer settings.

        Args:
            model_name: Identifier passed to ``AutoTokenizer.from_pretrained``
                and ``AutoModel.from_pretrained``.
            stride: Passed to the tokenizer as ``stride``.
            max_length: Passed to the tokenizer as ``max_length``.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer_args = {
            "max_length": max_length,
            "stride": stride,
            "return_overflowing_tokens": True,
            "truncation": True,
            "padding": True,
            "return_tensors": "pt"
        }

    def forward(self, texts: list[str]):
        """Return one embedding per input text.

        Args:
            texts: A list of strings; a single string is treated as a
                one-element list.

        Returns:
            Tensor with one row per text, each row being the mean of that
            text's per-chunk first-token hidden states.
        """
        if isinstance(texts, str):
            texts = [texts]

        inputs = self.tokenizer(texts, **self.tokenizer_args)
        device = self.model.device

        model_inputs = {
            "input_ids": inputs['input_ids'].to(device),
            "attention_mask": inputs['attention_mask'].to(device),
        }
        token_type_ids = inputs.get('token_type_ids')
        if token_type_ids is not None:
            model_inputs["token_type_ids"] = token_type_ids.to(device)

        # mapping[k] is the index of the text that chunk k came from.
        mapping = inputs.get('overflow_to_sample_mapping')
        if mapping is None:
            # No mapping returned: treat row k as the only chunk of text k.
            # (A constant [0] here would leave every text after the first
            # with no chunks and therefore a NaN embedding.)
            mapping = range(len(texts))
        elif hasattr(mapping, 'tolist'):
            # Plain ints are cheaper to iterate than 0-d tensors.
            mapping = mapping.tolist()

        outputs = self.model(**model_inputs)  # (total chunks, seq_len, hidden_size)
        cls_embs = outputs.last_hidden_state[:, 0]  # (total chunks, hidden_size)

        # Bucket chunk rows by source text in a single pass over the mapping.
        chunks_per_text = [[] for _ in texts]
        for chunk_idx, sample_idx in enumerate(mapping):
            chunks_per_text[sample_idx].append(chunk_idx)

        pooled = [cls_embs[chunk_ids].mean(dim=0) for chunk_ids in chunks_per_text]
        return torch.stack(pooled, dim=0)  # (batch, hidden_size)


In [6]:
from torch.optim import AdamW
from tqdm import tqdm

# Fail early with a clear message instead of a ZeroDivisionError after the
# model has already been loaded.
if len(triplets) == 0:
    raise ValueError("triplets is empty - nothing to train on")

# Prefer the GPU, but fall back to the CPU so the cell still runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TrainableBERTCLSMeanPooler().to(device)
triplet_loss = torch.nn.TripletMarginLoss(margin=1.0)
optimizer = AdamW(model.parameters(), lr=2e-5)

batch_size = 1
# Start offset of every batch; its length is the true number of batches per
# epoch, including a final partial batch when batch_size does not divide
# len(triplets).
batch_starts = range(0, len(triplets), batch_size)
# Loss of every optimisation step across all epochs, in processing order.
all_losses = []

model.train()
for epoch in range(3):
    total_loss = 0.0
    for i in tqdm(batch_starts):
        batch = triplets[i:i+batch_size]

        # Split the (anchor, positive, negative) tuples into three parallel lists.
        anchors = [a for a, _, _ in batch]
        positives = [p for _, p, _ in batch]
        negatives = [n for _, _, n in batch]

        optimizer.zero_grad()

        # The model tokenizes raw strings itself; one forward pass per role.
        anchor_emb = model(anchors)
        pos_emb = model(positives)
        neg_emb = model(negatives)

        loss = triplet_loss(anchor_emb, pos_emb, neg_emb)
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        total_loss += step_loss
        all_losses.append(step_loss)

    # Average over the batches actually processed in this epoch.
    avg_loss = total_loss / len(batch_starts)
    print(f"Epoch {epoch+1} - Avg Loss: {avg_loss:.4f}")

  0%|          | 0/136 [00:00<?, ?it/s]Input ids are automatically padded to be a multiple of `config.attention_window`: 512
100%|██████████| 136/136 [10:32<00:00,  4.65s/it]


Epoch 1 - Avg Loss: 1.5176


100%|██████████| 136/136 [11:30<00:00,  5.08s/it]


Epoch 2 - Avg Loss: 0.9063


100%|██████████| 136/136 [11:28<00:00,  5.06s/it]

Epoch 3 - Avg Loss: 0.6331





In [7]:
import plotly.graph_objects as go

# Plot the per-batch triplet loss collected in `all_losses` during training
# (one point per optimisation step, epochs concatenated in order).
batch_numbers = list(range(len(all_losses)))

loss_trace = go.Scatter(
    x=batch_numbers,
    y=all_losses,
    mode='lines+markers',
    name='Triplet Loss',
    line={"color": 'royalblue', "width": 2},
    marker={"size": 4},
)

fig = go.Figure()
fig.add_trace(loss_trace)

# Titles, axis labels and a compact white theme; the legend is hidden because
# there is only a single trace.
layout_options = {
    "title": "Triplet Loss per Batch",
    "xaxis_title": "Batch Number",
    "yaxis_title": "Loss",
    "template": "plotly_white",
    "font": {"size": 14},
    "showlegend": False,
    "margin": {"l": 40, "r": 40, "t": 40, "b": 40},
}
fig.update_layout(**layout_options)

fig.show()


In [8]:
# Persist only the model's parameters (its state dict) to disk, not the
# module object itself.
trained_state = model.state_dict()
torch.save(trained_state, "model_weights.pt")