In [1]:
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments
    )

from peft import (
    LoraConfig,
    TaskType,
    get_peft_model
)

from datasets import Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Identifier of the pretrained model; reused below to load both the tokenizer
# and the sequence-classification model.
checkpoint = 'FacebookAI/roberta-base'

## Preprocessing data

In [3]:
from events import EventPair

In [4]:
from typing import Dict, Iterable
from pathlib import Path

In [5]:
# Scratch check: inspect the suffixes of a '.test' path. `load_data` below
# compares the path's suffix against '.test' to detect the test split.
p = Path('../data/event_pairs.test')
p.suffixes

['.test']

In [6]:
# Tokenizer loaded from the same checkpoint as the model.
tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
# Collator handed to the Trainer further down; it pads each batch with this
# tokenizer, which is why `tokenize_func` only truncates and does not pad.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [7]:
def load_data(fpath):
    """Lazily read `fpath` and yield one `EventPair` per line.

    The file is opened as UTF-8 text. Every line is passed to `EventPair`
    exactly as read (not stripped), together with a flag that is true only
    when the path's final suffix is `.test`.
    """
    from_test_split = Path(fpath).suffix == '.test'
    with open(fpath, 'r', encoding='utf-8') as handle:
        yield from (EventPair(raw_line, from_test_split) for raw_line in handle)


def tokenize_func(sample):
    """Tokenize the `event_1` / `event_2` entries of `sample` as a text pair.

    Uses the notebook-level `tokenizer` with truncation enabled and no padding.
    `build_dataset_from` maps this function with `batched=True`, so in practice
    each entry holds a whole batch of values rather than a single one.
    """
    first_text, second_text = sample['event_1'], sample['event_2']
    return tokenizer(first_text, second_text, truncation=True)


def build_dataset_from(fpath) -> Dataset:
    """Turn the event-pair file at `fpath` into a tokenized HF `Dataset`.

    Each pair read by `load_data` contributes one row with the columns
    `event_1`, `event_2` and `label`; the resulting dataset is then run
    through `tokenize_func` in batched mode.
    """
    # Column-oriented accumulator; key order fixes the dataset's column order.
    columns = {'event_1': [], 'event_2': [], 'label': []}

    for pair in load_data(fpath):
        first_event, second_event = pair.events
        pair_label = pair.label

        columns['event_1'].append(first_event)
        columns['event_2'].append(second_event)
        columns['label'].append(pair_label)

    raw_dataset = Dataset.from_dict(columns)
    return raw_dataset.map(tokenize_func, batched=True)

In [8]:
# Build one tokenized dataset per split. Only the '.test' path makes
# `load_data` construct its `EventPair`s with the test-set flag set.
train_data = build_dataset_from('../data/event_pairs.train')
dev_data = build_dataset_from('../data/event_pairs.dev')
test_data = build_dataset_from('../data/event_pairs.test')

Map: 100%|██████████| 227328/227328 [00:07<00:00, 29864.33 examples/s]
Map: 100%|██████████| 36438/36438 [00:01<00:00, 33924.56 examples/s]
Map: 100%|██████████| 42953/42953 [00:01<00:00, 33869.70 examples/s]


In [9]:
# Sanity check: confirm what type the builder returned.
type(train_data)

datasets.arrow_dataset.Dataset

## Model configuration

### Configure LoRA for PEFT

In [10]:
# LoRA adapter settings for a sequence-classification task:
#   - inference_mode=False: the adapter weights are going to be trained
#   - r: rank of the low-rank update matrices
#   - lora_alpha: scaling factor applied to the LoRA update
#   - lora_dropout: dropout probability inside the LoRA layers
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=3,
    lora_alpha=32,
    lora_dropout=0.1
)

In [11]:
# Load the pretrained encoder with a fresh 2-label classification head, then
# wrap it with the LoRA configuration defined above.
model = RobertaForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
model = get_peft_model(model, peft_config)
# Report how many parameters will actually be trained.
model.print_trainable_parameters()

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 702,722 || all params: 125,349,892 || trainable%: 0.5606


### Configure training arguments

In [12]:
# --- Disabled alternative configuration (kept for reference, not executed) ---
# NOTE(review): if this is re-enabled, `cache_dir.mkdir()` fails when the
# parent '../.cache' directory does not exist yet; `mkdir(parents=True,
# exist_ok=True)` would cover that case.
# cache_dir = Path('../.cache/ft_models/')
# if not cache_dir.exists():
#     cache_dir.mkdir()

# training_args = TrainingArguments(
#     output_dir=cache_dir / 'lora_roberta',
#     learning_rate=1e-3,
#     per_device_train_batch_size=32,
#     per_device_eval_batch_size=32,
#     num_train_epochs=2,
#     eval_strategy='epoch',
#     save_strategy='epoch',
#     load_best_model_at_end=True
# )

# Active configuration: 'test-trainer' is the output directory and evaluation
# runs once per epoch; every other setting is left at the library default.
training_args = TrainingArguments('test-trainer', eval_strategy='epoch')

In [13]:
import evaluate
import numpy as np

# Accuracy metric, loaded on first use and then reused so that each evaluation
# pass does not load it again.
_accuracy_metric = None


def compute_metrics(eval_preds, metric=None):
    """Compute accuracy from the model outputs of an evaluation pass.

    Args:
        eval_preds: Pair `(logits, labels)`. If `logits` is a tuple, its first
            element is used as the logits.
        metric: Optional object exposing `compute(references=..., predictions=...)`.
            When omitted, the `accuracy` metric from `evaluate` is loaded once
            and cached for later calls.

    Returns:
        Whatever `metric.compute` returns for the arg-max predictions
        (for the default metric, a dict holding the accuracy).
    """
    global _accuracy_metric

    if metric is None:
        if _accuracy_metric is None:
            _accuracy_metric = evaluate.load('accuracy')
        metric = _accuracy_metric

    logits, labels = eval_preds
    if isinstance(logits, tuple):
        # Some models return extra tensors alongside the logits; keep only the logits.
        logits = logits[0]

    # Predicted class = index of the largest logit along the last axis.
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(references=labels, predictions=predictions)

In [14]:
# Wire the LoRA-wrapped model, the train/dev splits, the padding collator and
# the accuracy callback into a Trainer. The test split is not used here.
# NOTE(review): the saved cell output shows a warning raised at this call (its
# message is cut off). Recent transformers releases deprecate `tokenizer=` in
# favour of `processing_class=` — confirm against the installed version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=dev_data,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

  trainer = Trainer(


In [15]:
# Start fine-tuning; with eval_strategy='epoch' the dev split is evaluated
# after every epoch.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.1725,0.293269,0.935699


KeyboardInterrupt: 