In [1]:
import torch
import numpy as np
from pathlib import Path

from datasets import Dataset

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer
from transformers import pipeline, enable_full_determinism


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Data location ---
# Folder that holds the on-disk 'train' / 'eval' / 'test' splits loaded below.
DATASET_DIR = Path('../datasets/link_classification/1ctx')

# --- Reproducibility ---
RANDOM_STATE = 42  # seed handed to enable_full_determinism

# --- Fine-tuning hyperparameters ---
NUM_UNFROZEN_TRANSFORMER_LAYERS = 8  # number of top encoder layers left trainable
NUM_EPOCHS = 3
BATCH_SIZE = 32  # per-device batch size, shared by training and evaluation


In [3]:
# Seed every RNG and request deterministic algorithms before any model is built.
enable_full_determinism(RANDOM_STATE)

# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device "{device}"')


Using device "cuda"


In [4]:
# Each split lives in its own sub-folder of DATASET_DIR; load them in a fixed order.
train_dataset, eval_dataset, test_dataset = (
    Dataset.load_from_disk(DATASET_DIR / split_name)
    for split_name in ('train', 'eval', 'test')
)


In [5]:
# Single source of truth for the checkpoint shared by the tokenizer and the model.
checkpoint_name = 'ai-forever/ruBert-base'

tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)

# The label <-> integer id mapping is learned from the training split only.
label_encoder = LabelEncoder()
label_encoder.fit(train_dataset['label'])

model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint_name,
    num_labels=label_encoder.classes_.size,
)

# Freeze everything under `model.bert`, then re-enable gradients for the last
# NUM_UNFROZEN_TRANSFORMER_LAYERS encoder layers. Anything else under
# `model.bert` stays frozen; modules outside it are not touched here.
model.bert.requires_grad_(False)
for depth_from_top in range(1, NUM_UNFROZEN_TRANSFORMER_LAYERS + 1):
    model.bert.encoder.layer[-depth_from_top].requires_grad_(True)

# Last expression of the cell: moves the model and displays its architecture.
model.to(device)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ai-forever/ruBert-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(120138, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1

In [6]:
# Tokenizer settings shared by all splits: every pair is truncated and padded
# to exactly 256 tokens and returned as PyTorch tensors.
encode_kwargs = dict(
    max_length=256,
    padding='max_length',
    truncation=True,
    return_tensors='pt',
)

def tokenize_dataset(data):
    """Build a tokenized (premise, conclusion) pair dataset from a raw split.

    For every row, the left context, the fragment, and the right context are
    joined with '. ' for each of the two fragments. The resulting pairs are
    stored in a 'text' column, the labels are mapped to integer ids with the
    module-level `label_encoder`, and the pairs are tokenized in batches with
    the module-level `tokenizer` and `encode_kwargs`.

    Args:
        data: split exposing the `fragment{1,2}`, `fragment{1,2}_left`,
            `fragment{1,2}_right` and `label` fields.

    Returns:
        A `Dataset` with 'text' and 'label' columns plus the tokenizer outputs.
    """
    text_pairs = []
    for row in data:
        premise = f'{row["fragment1_left"]}. {row["fragment1"]}. {row["fragment1_right"]}'
        conclusion = f'{row["fragment2_left"]}. {row["fragment2"]}. {row["fragment2_right"]}'
        text_pairs.append((premise, conclusion))

    encoded_labels = label_encoder.transform(data['label'])
    pair_dataset = Dataset.from_dict({'text': text_pairs, 'label': encoded_labels})

    def _encode_batch(batch):
        # Each 'text' entry is a two-element (premise, conclusion) sequence.
        return tokenizer(batch['text'], **encode_kwargs)

    return pair_dataset.map(_encode_batch, batched=True)

# Tokenize the three splits in train -> eval -> test order.
train_dataset_tok, eval_dataset_tok, test_dataset_tok = (
    tokenize_dataset(raw_split)
    for raw_split in (train_dataset, eval_dataset, test_dataset)
)


Map: 100%|██████████| 15467/15467 [00:02<00:00, 5860.53 examples/s]
Map: 100%|██████████| 1499/1499 [00:00<00:00, 5867.94 examples/s]
Map: 100%|██████████| 1428/1428 [00:00<00:00, 5757.99 examples/s]


In [None]:
def compute_metrics(eval_pred):
    """Compute accuracy and macro-averaged precision/recall/F1 for the Trainer.

    Args:
        eval_pred: a `(logits, labels)` pair; the predicted class is the
            argmax of the logits over the last axis.

    Returns:
        Dict with keys 'acc', 'prec', 'recall' and 'f1'. Classes with no
        predicted samples contribute 0 instead of raising a warning
        (`zero_division=0`).
    """
    logits, references = eval_pred
    predicted = np.argmax(logits, axis=-1)
    macro_precision, macro_recall, macro_f1, _support = precision_recall_fscore_support(
        references,
        predicted,
        average='macro',
        zero_division=0,
    )
    return {
        'acc': accuracy_score(references, predicted),
        'prec': macro_precision,
        'recall': macro_recall,
        'f1': macro_f1,
    }

# Hyperparameters and checkpointing policy for the Hugging Face Trainer.
training_args = TrainingArguments(
    output_dir='trainer_output',
    report_to='none',
    # Keep the Trainer's own seeding tied to the notebook-wide seed instead of
    # relying on the library default happening to match RANDOM_STATE.
    seed=RANDOM_STATE,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # Evaluate, log and checkpoint on the same 128-step cadence so that every
    # evaluated step has a checkpoint that can be restored as the best model.
    eval_strategy='steps',
    eval_steps=128,
    logging_steps=128,
    save_steps=128,
    save_total_limit=3,
    learning_rate=5e-5,
    warmup_ratio=0.1,
    # Mixed precision only when a CUDA device is present: the notebook falls
    # back to CPU otherwise, where fp16 AMP is not generally usable.
    fp16=torch.cuda.is_available(),
    # Restore the checkpoint with the highest 'f1' (as reported by
    # compute_metrics) once training finishes.
    load_best_model_at_end=True,
    metric_for_best_model='f1',
)

# Wire the model, the training configuration, the tokenized splits and the
# metric callback into a single Trainer instance.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    eval_dataset=eval_dataset_tok,
    train_dataset=train_dataset_tok,
)


In [8]:
# Run fine-tuning; the returned TrainOutput is displayed as the cell's result.
trainer.train()


Step,Training Loss,Validation Loss,Acc,Prec,Recall,F1
128,0.7435,0.646779,0.724483,0.490896,0.489344,0.483977
256,0.6162,0.585886,0.760507,0.507559,0.517719,0.512144
384,0.5884,0.555833,0.752502,0.50207,0.512288,0.506726
512,0.5337,0.580711,0.75984,0.83941,0.528293,0.530276
640,0.451,0.589292,0.751167,0.585583,0.529966,0.537087
768,0.4522,0.603905,0.753836,0.58795,0.530992,0.538515
896,0.4373,0.598171,0.771181,0.681316,0.56217,0.58036
1024,0.3734,0.632183,0.756504,0.598424,0.549187,0.561793
1152,0.3042,0.683293,0.758506,0.601804,0.569434,0.581029
1280,0.2784,0.667854,0.762508,0.595593,0.571561,0.580581


TrainOutput(global_step=1452, training_loss=0.4541370717618748, metrics={'train_runtime': 1826.5785, 'train_samples_per_second': 25.403, 'train_steps_per_second': 0.795, 'total_flos': 6104362847998464.0, 'train_loss': 0.4541370717618748, 'epoch': 3.0})

In [9]:
# Score the fine-tuned model end-to-end through the inference pipeline on the
# eval split and report accuracy plus macro precision/recall/F1.
# NOTE(review): this is the split used for best-checkpoint selection; the
# tokenized test split is never scored in this notebook — confirm whether the
# final numbers should come from `test_dataset_tok` instead.
model.eval()
eval_pipeline = pipeline(task='text-classification', model=model, tokenizer=tokenizer)

# The 'text' column holds (premise, conclusion) pairs; the pipeline takes each
# pair as a {'text', 'text_pair'} dict.
eval_inputs = [{'text': s1, 'text_pair': s2} for s1, s2 in eval_dataset_tok['text']]

# Forward the truncation settings used at training time. Without them the
# pipeline tokenizes without truncation, so long pairs would be encoded
# differently from training and could exceed the model's position limit.
eval_pred = eval_pipeline(
    eval_inputs,
    truncation=encode_kwargs['truncation'],
    max_length=encode_kwargs['max_length'],
)

labels = eval_dataset_tok['label']
# The pipeline returns label names; map them back to integer class ids.
predictions = [model.config.label2id[x['label']] for x in eval_pred]
accuracy = accuracy_score(labels, predictions)
precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='macro', zero_division=0)

print(f'Accuracy  = {accuracy:0.4f}')
print(f'Precision = {precision:0.4f}')
print(f'Recall    = {recall:0.4f}')
print(f'F1        = {f1:0.4f}')


Device set to use cuda:0


Accuracy  = 0.7585
Precision = 0.6018
Recall    = 0.5694
F1        = 0.5810
