In [1]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from sklearn.model_selection import train_test_split

from datamodule import PromptDataset, get_length_reg
from model_baseline import DistilBertRegressor



  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import pandas as pd

# === Tokenizer ===
tokenizer_BERT = AutoTokenizer.from_pretrained("distilbert-base-uncased")
tokenizer_LLaDa = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)

# I would have done with also 2048 and 4096 but in the training data there are no examples with that length

# === Load data ===
df_train = pd.read_csv(r"..\data\train.csv")
train_data = list(zip(df_train["user_prompt"], get_length_reg(df_train["model_response"], tokenizer_LLaDa)))
del df_train

df_test = pd.read_csv(r"..\data\test.csv")
df_test = df_test.dropna(subset=["model_response"])
data_test = list(zip(df_test["user_prompt"], get_length_reg(df_test["model_response"], tokenizer_LLaDa)))
del df_test

val_data, test_data = train_test_split(data_test, test_size=0.3, random_state=42)

# All the training prompt except one have length < 64
train_ds = PromptDataset(train_data, tokenizer_BERT, max_len=64)
val_ds = PromptDataset(val_data, tokenizer_BERT, max_len=128)
test_ds = PromptDataset(test_data, tokenizer_BERT, max_len=128)

train_dl = DataLoader(train_ds, batch_size=16, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=16)
test_dl = DataLoader(test_ds, batch_size=16)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



In [5]:
model = DistilBertRegressor()
# model = torch.compile(model).to(device)
model = model.to(device)

In [6]:
@torch.no_grad()
def estimate_loss(eval_iters = 10):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        if split == 'train':
            dataloader = train_dl
        else:
            dataloader = val_dl
        k = 0
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            logits, loss = model(input_ids, attention_mask, labels)
            assert loss is not None, "Loss should not be None"
            losses[k] = loss.item()
            k += 1
            if k >= eval_iters:
                break
        out[split] = losses.mean()
    model.train()
    return out

### Freezing the DistilBERT parameters

In [7]:
for param in model.encoder.parameters():
    param.requires_grad = False

optimizer = torch.optim.AdamW(model.regressor.parameters(), lr=2e-5)

In [8]:
eval_interval = 200
max_iters = len(train_dl)

# === Training loop ===
for epoch in range(10):
    model.train()
    total_loss = 0
    for i, batch in enumerate(train_dl):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        logits, loss = model(input_ids, attention_mask, labels)
        assert loss is not None, "Loss should not be None"
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if i % eval_interval == 0 or i == max_iters - 1:
                losses = estimate_loss()
                print(f"step {i}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    print(f"[Epoch {epoch+1}] Loss: {total_loss:.4f}")

step 0: train loss 53821.6680, val loss 56883.5703
step 200: train loss 54077.6875, val loss 56035.7305
step 400: train loss 51715.7578, val loss 54982.7812
step 600: train loss 51472.6094, val loss 53731.6680
step 800: train loss 50682.9062, val loss 52255.3711
step 890: train loss 54525.1992, val loss 51507.0586
[Epoch 1] Loss: 47106147.9385
step 0: train loss 57997.3828, val loss 51499.2109
step 200: train loss 50459.4648, val loss 49719.6914
step 400: train loss 35980.2891, val loss 47815.9922
step 600: train loss 50268.2773, val loss 45851.4609
step 800: train loss 43603.1641, val loss 43753.4688
step 890: train loss 36718.3203, val loss 42808.6953
[Epoch 2] Loss: 40830067.4756
step 0: train loss 44452.1250, val loss 42797.9805
step 200: train loss 35735.5391, val loss 40642.3359
step 400: train loss 30312.0117, val loss 38554.7578
step 600: train loss 35604.3047, val loss 36463.1055
step 800: train loss 27563.8008, val loss 34393.5781
step 890: train loss 28638.4414, val loss 335

### Fine tuning: changing also the parameters of DistilBERT

In [9]:
# Congela tutti i parametri di BERT
for param in model.encoder.parameters():
    param.requires_grad = True

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-6)

In [None]:
eval_interval = 200
max_iters = len(train_dl)

# === Training loop ===
for epoch in range(3):
    model.train()
    total_loss = 0
    for i, batch in enumerate(train_dl):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        logits, loss = model(input_ids, attention_mask, labels)
        assert loss is not None, "Loss should not be None"
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if i % eval_interval == 0 or i == max_iters - 1:
                losses = estimate_loss()
                print(f"step {i}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    print(f"[Epoch {epoch+1}] Loss: {total_loss:.4f}")

step 0: train loss 11786.5625, val loss 15483.3203
step 200: train loss 12569.9678, val loss 13120.4922
step 400: train loss 8993.9922, val loss 11818.5000
step 600: train loss 11440.7998, val loss 11661.8643
step 800: train loss 10326.2275, val loss 11394.1084
step 890: train loss 7899.3687, val loss 11435.4277
[Epoch 1] Loss: 9954936.4500
step 0: train loss 6805.6108, val loss 11389.7490
step 200: train loss 8292.1143, val loss 11462.5674
step 400: train loss 10395.0547, val loss 10864.5762
step 600: train loss 8008.9819, val loss 10445.2129
step 800: train loss 10152.5605, val loss 10665.1387
step 890: train loss 9229.6250, val loss 10639.5869
[Epoch 2] Loss: 8611750.4399
step 0: train loss 9773.0293, val loss 10644.4980
step 200: train loss 9010.0645, val loss 10362.3408
step 400: train loss 9648.1982, val loss 10408.4629
step 600: train loss 9706.8457, val loss 10232.0713
step 800: train loss 6735.9639, val loss 10523.7080
step 890: train loss 7574.8921, val loss 10656.4375
[Epoch

KeyboardInterrupt: 

In [11]:
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-7)

In [12]:
eval_interval = 200
max_iters = len(train_dl)

# === Training loop ===
for epoch in range(1):
    model.train()
    total_loss = 0
    for i, batch in enumerate(train_dl):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        logits, loss = model(input_ids, attention_mask, labels)
        assert loss is not None, "Loss should not be None"
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if i % eval_interval == 0 or i == max_iters - 1:
                losses = estimate_loss()
                print(f"step {i}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    print(f"[Epoch {epoch+1}] Loss: {total_loss:.4f}")

step 0: train loss 7409.3125, val loss 10114.3750
step 200: train loss 8863.0352, val loss 10173.2363
step 400: train loss 7386.7437, val loss 10268.1855
step 600: train loss 7325.1265, val loss 10288.1348
step 800: train loss 6753.1538, val loss 10148.7363
step 890: train loss 5931.0869, val loss 10062.7529
[Epoch 1] Loss: 7347738.3793


In [14]:
torch.save(model.state_dict(), "checkpoints/DistilBERT_LLaDa_reg.pth")


### Evaluation

In [15]:
model = DistilBertRegressor()
model.load_state_dict(torch.load("checkpoints/DistilBERT_LLaDa_reg.pth"))
model = model.to(device)
model.eval()
0

0

In [16]:
input_text = "Can you explain the theory of relativity?"
input_enc = tokenizer_BERT(input_text, return_tensors="pt", padding='max_length', truncation=True, max_length=512).to(device)
output = model(input_enc['input_ids'], input_enc['attention_mask'])
print(output)

(tensor([217.7619], device='cuda:0', grad_fn=<SqueezeBackward1>), None)


In [17]:
input_text = "What's your name?"
input_enc = tokenizer_BERT(input_text, return_tensors="pt", padding='max_length', truncation=True, max_length=512).to(device)
output = model(input_enc['input_ids'], input_enc['attention_mask'])
print(output)

(tensor([46.9342], device='cuda:0', grad_fn=<SqueezeBackward1>), None)


In [18]:
input_text = "What is 3 + 3?"
input_enc = tokenizer_BERT(input_text, return_tensors="pt", padding='max_length', truncation=True, max_length=512).to(device)
output = model(input_enc['input_ids'], input_enc['attention_mask'])
print(output)

(tensor([108.0012], device='cuda:0', grad_fn=<SqueezeBackward1>), None)


In [23]:
@torch.no_grad()
def see_prediction(kappa = 1):
    out = {}
    model.eval()
    for split in ['val']:
        losses = torch.zeros(kappa)
        if split == 'train':
            dataloader = train_dl
        else:
            dataloader = val_dl
        k = 0
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            logits, loss = model(input_ids, attention_mask, labels)
            pred = logits
            print(f"pred: {[(pred[i].item(), labels[i].item()) for i in range(16)]}")
            assert loss is not None, "Loss should not be None"
            losses[k] = loss.item()
            k += 1
            if k >= kappa:
                break
        out[split] = losses
    model.train()
    print("mean cross entropy loss: ", out["val"].mean())
    return 

In [24]:
see_prediction(10)

pred: [(67.27046203613281, 269), (126.904052734375, 47), (84.86556243896484, 75), (156.15689086914062, 158), (76.28671264648438, 53), (295.21954345703125, 348), (231.06982421875, 102), (130.68165588378906, 240), (329.534423828125, 234), (213.53468322753906, 216), (312.6249694824219, 260), (319.0155029296875, 345), (293.59295654296875, 454), (59.38877487182617, 106), (163.07162475585938, 210), (267.2560119628906, 303)]
pred: [(247.94784545898438, 255), (198.24847412109375, 260), (247.11895751953125, 75), (209.23394775390625, 267), (181.27635192871094, 189), (200.6521453857422, 310), (129.39602661132812, 30), (179.17127990722656, 242), (364.22772216796875, 362), (156.38255310058594, 219), (101.56836700439453, 28), (280.0093994140625, 221), (194.967041015625, 215), (152.21661376953125, 101), (159.84963989257812, 175), (318.2129211425781, 2)]
pred: [(122.97423553466797, 64), (163.20803833007812, 125), (73.89338684082031, 26), (117.00796508789062, 167), (92.34988403320312, 153), (271.680480

In [27]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import mean_squared_error


def evaluate_accuracy(model, dataloader, device='cpu'):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            logits, loss = model(input_ids, attention_mask, labels)
            pred = logits

            all_preds.extend(pred.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    mse = mean_squared_error(all_labels, all_preds)
    return mse



In [28]:
mse = evaluate_accuracy(model, val_dl, device=device)
print(f"Mean Squared Error: {mse:.4f}")


Mean Squared Error: 9267.0477


### Final evaluation: obtaining the list of predictions on the test


In [30]:
df_test = pd.read_csv(r"..\data\test.csv")
print(df_test.shape)
df_test = df_test.dropna(subset=["model_response"])
print(df_test.shape)
data_test = list(zip(df_test["user_prompt"], get_length_reg(df_test["model_response"], tokenizer_LLaDa)))
del df_test


test_ds = PromptDataset(data_test, tokenizer_BERT, max_len=64)
test_dl = DataLoader(test_ds, batch_size=16)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

(5000, 3)
(4998, 3)


In [33]:
def get_predictions(model, dataloader, device='cpu'):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            pred, loss = model(input_ids, attention_mask, labels)

            all_preds.extend(pred.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    return all_preds, all_labels

In [34]:
all_preds, all_labels = get_predictions(model, test_dl, device=device)

In [39]:
import numpy as np
np.ceil(np.array(all_preds)), np.array(all_labels)


(array([150., 300.,  98., ..., 133., 316., 155.]),
 array([ 28,  65, 154, ..., 111, 359,  28]))

In [41]:
from sklearn.metrics import mean_squared_error
all_preds = np.ceil(np.array(all_preds))
all_labels = np.array(all_labels)
mse = mean_squared_error(all_labels, all_preds)
print(f"Mean Squared Error: {mse:.4f}")

Mean Squared Error: 9013.2199


In [42]:
import os
import numpy as np

# Ensure the directory exists
os.makedirs("prediction_test", exist_ok=True)

# Save all_preds as a numpy array
np.save("prediction_test/DistilBERT_LLaDa_reg.npy", all_preds)