In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import sys
sys.path.append(os.path.join(os.getcwd(), '..'))
from src.data.dataset import EssayDataset
from src.data.longDataset import LongEssayDataset
from src.models.hierarchicalBertPeft import HierarchicalBertPeft
import pandas as pd
import torch
from torch.utils.data import DataLoader

In [2]:
# Load the dataset CSV (path is relative to the notebook's working directory)
# and print a per-column non-null/dtype summary as a quick sanity check.
csv_path = "../data/aes_dataset_5k.csv"
df = pd.read_csv(csv_path)
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5363 entries, 0 to 5362
Data columns (total 13 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   question           4859 non-null   object 
 1   reference_answer   5363 non-null   object 
 2   answer             5363 non-null   object 
 3   score              5363 non-null   float64
 4   dataset            5363 non-null   object 
 5   normalized_score   5363 non-null   float64
 6   normalized_score2  5363 non-null   int64  
 7   bert_length        5363 non-null   int64  
 8   indobert_length    5363 non-null   int64  
 9   albert_length      5363 non-null   int64  
 10  longformer_length  5363 non-null   int64  
 11  multibert_length   5363 non-null   int64  
 12  indoalbert_length  5363 non-null   int64  
dtypes: float64(2), int64(7), object(4)
memory usage: 544.8+ KB


In [3]:
from transformers import AutoTokenizer

# Pick the compute device once, falling back to CPU when CUDA is unavailable,
# so this cell does not crash on a CPU-only machine (a bare .to("cuda") would).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Single source of truth for the checkpoint so tokenizer and model cannot drift apart.
MODEL_NAME = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = HierarchicalBertPeft(MODEL_NAME).to(device)

# NOTE(review): this `dataset` is reassigned by the LongEssayDataset cells that
# follow before any DataLoader consumes it; kept so the cell's behavior is unchanged.
dataset = EssayDataset(df, tokenizer, 512)

In [4]:
# Build a LongEssayDataset over the full frame using the "bert_length" column.
# The cell output ("max len 512") suggests the third argument is the maximum
# length; the meaning of the `0` argument is not visible here — TODO confirm
# both against src/data/longDataset.py.
# NOTE(review): this assignment is replaced by the next LongEssayDataset(...)
# cell before `dataset` is handed to the DataLoader.
dataset = LongEssayDataset(df, tokenizer, 512, 0, "bert_length")

max len 512


In [5]:
# Rebuild `dataset` with 128 / 64 as the third and fourth arguments; this is
# the instance the DataLoader further down is constructed from.
# Presumably 128 is the maximum length (the cell prints "max len 128") and 64
# an overlap/stride between windows — TODO confirm against
# src/data/longDataset.py.
dataset = LongEssayDataset(df, tokenizer, 128, 64, "bert_length")

max len 128


In [6]:
from torch.nn.utils.rnn import pad_sequence

def custom_collate_fn(batch):
    """Merge a list of ``(features, label)`` samples into one padded batch.

    Args:
        batch: Sequence of samples. ``sample[0]`` is a mapping holding the
            tensors ``"input_ids"``, ``"attention_mask"`` and
            ``"token_type_ids"``; ``sample[1]`` is the label tensor.

    Returns:
        A ``(features, labels)`` tuple. ``features`` maps each of the three
        keys to the per-sample tensors padded with zeros along their first
        dimension and stacked batch-first; ``labels`` is the label tensors
        stacked along a new leading dimension.
    """
    feature_dicts = [sample[0] for sample in batch]
    stacked_labels = torch.stack([sample[1] for sample in batch])

    # Samples may differ in length along dim 0; pad_sequence zero-fills the
    # shorter ones so every field becomes a single rectangular tensor.
    padded = {
        field: pad_sequence(
            [entry[field] for entry in feature_dicts],
            batch_first=True,
            padding_value=0,
        )
        for field in ("input_ids", "attention_mask", "token_type_ids")
    }

    return padded, stacked_labels

# Samples are padded per batch by custom_collate_fn, so the loader can mix
# items of different lengths; order is reshuffled on every pass.
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=custom_collate_fn,
)

In [7]:
from torch.optim import AdamW
# Training setup: target device (CUDA when present, otherwise CPU), the
# mean-squared-error objective, and an AdamW optimizer over the model's parameters.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
criterion = torch.nn.MSELoss()
optimizer = AdamW(model.parameters(), lr=1e-5)

In [8]:
# Number of passes over the dataloader; used for both the loop bound and the banner.
NUM_EPOCHS = 1

for epoch in range(NUM_EPOCHS):
    print(f"====== Training Epoch {epoch + 1}/{NUM_EPOCHS} ======")
    model.train()
    train_mse_loss = 0.0
    completed_batches = 0  # batches that ran forward, backward and optimizer step
    skipped_batches = 0    # batches dropped because of a NaN loss or an exception
    for batchs, targets in dataloader:
        try:
            optimizer.zero_grad()
            input_ids = batchs['input_ids'].to(device)
            attention_mask = batchs['attention_mask'].to(device)
            token_type_ids = batchs['token_type_ids'].to(device)
            targets = targets.to(device)
            # squeeze(1) drops the size-1 dim at index 1 so predictions line up
            # with targets — assumes the model emits (batch, 1); TODO confirm.
            predictions = model(input_ids, attention_mask, token_type_ids).squeeze(1)
            loss = criterion(predictions, targets)
            if torch.isnan(loss):
                # Skip the update: a NaN loss would propagate NaNs into the weights.
                print("⚠️ Warning: NaN detected in loss!")
                print(f"Predictions: {predictions}")
                print(f"Targets: {targets}")
                skipped_batches += 1
                continue
            loss.backward()
            optimizer.step()
            train_mse_loss += loss.item()
            completed_batches += 1
        except Exception as e:
            # Best-effort training: log the failure (with its type, since str(e)
            # can be empty) and move on, but keep a tally so that dropped
            # batches are visible in the epoch summary instead of vanishing.
            skipped_batches += 1
            print(f"Error during training: {type(e).__name__}: {e}")
            torch.cuda.empty_cache()

    # Report the epoch result; previously the accumulated loss was never shown.
    if completed_batches:
        mean_train_mse = train_mse_loss / completed_batches
    else:
        mean_train_mse = float("nan")
    print(
        f"Epoch {epoch + 1}/{NUM_EPOCHS} - mean train MSE: {mean_train_mse:.6f} "
        f"({completed_batches} batches trained, {skipped_batches} skipped)"
    )



  attn_output = torch.nn.functional.scaled_dot_product_attention(
Token indices sequence length is longer than the specified maximum sequence length for this model (725 > 512). Running this sequence through the model will result in indexing errors
