In [2]:
from events import EventPair

from typing import Dict, Iterable
from pathlib import Path

import numpy as np

from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments
    )

from peft import (
    LoraConfig,
    TaskType,
    get_peft_model
)

from datasets import Dataset
import evaluate

## Preprocessing data

In [4]:
checkpoint = 'FacebookAI/roberta-base'
tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [5]:
def load_data(fpath):
    """Load and yield dataset from its path"""
    is_test_set = Path(fpath).suffix == '.test'
    with open(fpath, 'r', encoding='utf-8') as f:
        for line in f:
            yield EventPair(line, is_test_set)


def tokenize_func(sample):
    """Define tokenization function for a single sample"""
    return tokenizer(
        sample['event_1'],
        sample['event_2'],
        truncation=True
    )


def build_dataset_from(fpath) -> Dataset:
    """Build HF's `Dataset` for training"""
    event_pairs = load_data(fpath)
    events_1, events_2, labels = [], [], []

    for pair in event_pairs:
        event_1, event_2 = pair.events
        label = pair.label

        events_1.append(event_1)
        events_2.append(event_2)
        labels.append(label)

    data_dict = {
        'event_1': events_1,
        'event_2': events_2,
        'label': labels
    }

    dataset = Dataset.from_dict(data_dict)

    return dataset.map(tokenize_func, batched=True)

In [6]:
train_data = build_dataset_from('../data/event_pairs.train')
dev_data = build_dataset_from('../data/event_pairs.dev')
test_data = build_dataset_from('../data/event_pairs.test')

Map: 100%|██████████| 227328/227328 [00:07<00:00, 28608.33 examples/s]
Map: 100%|██████████| 36438/36438 [00:01<00:00, 28294.53 examples/s]
Map: 100%|██████████| 42953/42953 [00:01<00:00, 28225.15 examples/s]


In [7]:
type(train_data)

datasets.arrow_dataset.Dataset

In [12]:
for i in range(5):
   print(train_data[i]['label'])

1
1
1
1
1


## Model configuration

### Config LoRA for PEFT

In [8]:
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=3,
    lora_alpha=32,
    lora_dropout=0.1
)

In [9]:
model = RobertaForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 702,722 || all params: 125,349,892 || trainable%: 0.5606


### Config training arguments

In [10]:
# cache_dir = Path('../.cache/ft_models/')
# if not cache_dir.exists():
#     cache_dir.mkdir()

# training_args = TrainingArguments(
#     output_dir=cache_dir / 'lora_roberta',
#     learning_rate=1e-3,
#     per_device_train_batch_size=32,
#     per_device_eval_batch_size=32,
#     num_train_epochs=2,
#     eval_strategy='epoch',
#     save_strategy='epoch',
#     load_best_model_at_end=True
# )
training_args = TrainingArguments('test-trainer', eval_strategy='epoch')

In [49]:
import torch
from torch.nn import CrossEntropyLoss
from collections import Counter

def weighted_loss_fn(training_data: Dataset):
    label_counts = Counter(sample['label'] for sample in training_data)

    total_samples = sum(label_counts.values())
    class_counts = np.array([label_counts.get(0, 1), label_counts.get(1, 1)])

    weights = total_samples / (2.0 * class_counts)
    weights_tensor = torch.tensor(weights, dtype=torch.float32).to('cuda')

    return weights_tensor, CrossEntropyLoss(weight=weights_tensor)


class WeightedTrainer(Trainer):
    def __init__(self, *args, loss_fn=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fn = loss_fn if loss_fn is not None else CrossEntropyLoss()

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")  # Extract labels
        outputs = model(**inputs)
        logits = outputs.logits  # Extract model logits

        loss = self.loss_fn(logits, labels)  # Compute loss

        return (loss, outputs) if return_outputs else loss

In [50]:
w, _ = weighted_loss_fn(train_data)
w


tensor([0.5473, 5.7847], device='cuda:0')

In [46]:
def compute_metrics(eval_preds):
    acc_metric = evaluate.load('accuracy')
    f1_metric = evaluate.load('f1')
    precision_metric = evaluate.load("precision")

    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    f1 = f1_metric.compute(references=labels, predictions=predictions)['f1']
    acc = acc_metric.compute(references=labels, predictions=predictions)
    precision = precision_metric.compute(references=labels, predictions=predictions)

    # return metric.compute(references=labels, predictions=predictions)
    return {'acc': acc, 'precision': precision, 'F1': f1}

In [47]:
loss_fn = weighted_loss_fn(train_data)

trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=dev_data,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    loss_fn=loss_fn
)

  super().__init__(*args, **kwargs)


In [48]:
trainer.train()

Epoch,Training Loss,Validation Loss,Acc,Precision,F1
1,0.405,0.331965,{'accuracy': 0.9300455568362699},{'precision': 0.6361216730038023},0.5676
2,0.3333,0.343859,{'accuracy': 0.9374828475767056},{'precision': 0.7503805175038052},0.564935
3,0.3903,0.352099,{'accuracy': 0.9363576486085954},{'precision': 0.7428131416837782},0.555151


TrainOutput(global_step=85248, training_loss=0.4028636823903333, metrics={'train_runtime': 895.5361, 'train_samples_per_second': 761.537, 'train_steps_per_second': 95.192, 'total_flos': 3098421928873152.0, 'train_loss': 0.4028636823903333, 'epoch': 3.0})