In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import DatasetDict, Sequence, Value, Features
import torch
from torch.utils.data import DataLoader
import os
import sys
sys.path.append(os.getcwd()+"/../..")
from src import paths
import tqdm

In [None]:
# Load the pre-split line-labelling DatasetDict (train / val / test) from disk
dataset_path = paths.DATA_PATH_PREPROCESSED/'line_labelling/line_labelling_clean_dataset'
dataset = DatasetDict.load_from_disk(dataset_path)

# Number of labels = number of distinct aggregated classes seen in the training split
train_classes = set(dataset['train']['class_agg'])
num_labels = len(train_classes)

In [None]:
# Download the base checkpoint from the Hugging Face Hub and store it under
# MODEL_PATH -- but only when no local copy exists yet. The next cell loads the
# model from this local directory, so this guard lets the notebook run
# end-to-end on a fresh machine without re-downloading on every run.
BASE_CHECKPOINT = "bert-base-multilingual-cased"
base_model_dir = paths.MODEL_PATH/BASE_CHECKPOINT

if not os.path.isdir(base_model_dir):
    # Tokenizer
    AutoTokenizer.from_pretrained(BASE_CHECKPOINT).save_pretrained(base_model_dir)
    # Model with a classification head sized to the number of labels
    AutoModelForSequenceClassification.from_pretrained(
        BASE_CHECKPOINT,
        num_labels=num_labels,
        problem_type="multi_label_classification",
    ).save_pretrained(base_model_dir)

In [None]:
# Run on the first GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Tokenizer and model both come from the local copy of the base checkpoint
base_model_dir = paths.MODEL_PATH / 'bert-base-multilingual-cased'
tokenizer = AutoTokenizer.from_pretrained(base_model_dir)

# Multi-label classifier: one output per class, placed on the chosen device
model = AutoModelForSequenceClassification.from_pretrained(
    base_model_dir,
    num_labels=num_labels,
    problem_type="multi_label_classification",
).to(device)

In [None]:
# Tokenize
def tokenize(examples, max_length=256):
    """Tokenize a batch of examples with the BERT tokenizer.

    Args:
        examples: batch dict passed by `Dataset.map(batched=True)`; must
            contain a 'text' column (list of strings).
        max_length: every sequence is truncated / padded to exactly this many
            tokens, so all rows share the same length.

    Returns:
        The tokenizer's encoding (input_ids, attention_mask, token_type_ids)
        as plain Python lists, which `datasets` writes straight into Arrow.
    """
    # No `return_tensors`: `map` converts the output back to Arrow anyway, so
    # building torch tensors here is pure overhead.
    return tokenizer(
        examples['text'],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

# Explicit output schema for `map`: float32 `labels` vectors (the multi-label
# loss needs float targets), int32 token ids / masks, and the original string
# columns carried through unchanged.
features = Features({'labels': Sequence(Value(dtype='float32')),
                     'input_ids': Sequence(Value(dtype='int32')),
                     'attention_mask': Sequence(Value(dtype='int32')),
                     'token_type_ids': Sequence(Value(dtype='int32')),
                     'class_agg': Value(dtype='string'),
                     'rid': Value(dtype='string'),
                     'text': Value(dtype='string'),
                     'class': Value(dtype='string')
                     })

# Tokenize dataset (all splits)
dataset = dataset.map(tokenize, batched=True, features=features)


In [None]:
# Unpack the three pre-defined splits of the tokenized DatasetDict
train_dataset, val_dataset, test_dataset = (dataset[split] for split in ('train', 'val', 'test'))

In [None]:
# Training Arguments
# Effective train batch size = per_device_train_batch_size * gradient_accumulation_steps
# (times the number of devices).
training_args = TrainingArguments(
    output_dir='./results',            # checkpoints are written here (one per epoch)
    num_train_epochs=12,               # total number of training epochs
    per_device_train_batch_size=16,    # batch size per device during training
    per_device_eval_batch_size=32,     # batch size for evaluation
    warmup_steps=200,                  # number of warmup steps for learning rate scheduler
    weight_decay=0.01,                 # strength of weight decay
    logging_dir='./logs',              # directory for storing logs
    logging_steps=10,
    # Evaluate and checkpoint once per epoch; load_best_model_at_end requires the
    # two strategies to match and restores the best checkpoint after training.
    load_best_model_at_end=True,
    save_strategy='epoch',
    # NOTE(review): renamed to `eval_strategy` in newer transformers releases --
    # switch when upgrading.
    evaluation_strategy='epoch',
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,       # trade extra compute for lower activation memory
    # fp16 mixed precision needs CUDA; enabling it unconditionally makes the
    # Trainer fail on the CPU fallback selected for `device`.
    fp16=torch.cuda.is_available(),
)

# Trainer
trainer = Trainer(
    model=model,                       # the instantiated 🤗 Transformers model to be trained
    args=training_args,                # training arguments, defined above
    train_dataset=train_dataset,       # training dataset
    eval_dataset=val_dataset,          # per-epoch evaluation / best-model selection
)

In [None]:
# Fine-tune the model. Because load_best_model_at_end=True and no
# metric_for_best_model is set, the trainer's model should hold the best
# checkpoint by evaluation loss when this returns.
trainer.train()

In [None]:
# Save model weights + config. The tokenizer was never passed to the Trainer,
# so save_model does not write it; save it alongside explicitly so the
# fine-tuned directory is a self-contained checkpoint.
finetuned_dir = paths.MODEL_PATH/'bert-base-multilingual-cased_finetuned'
trainer.save_model(finetuned_dir)
tokenizer.save_pretrained(finetuned_dir)

In [None]:
# Reload the fine-tuned checkpoint from disk and swap it into the trainer,
# so the following cells run on exactly what was saved.
finetuned_model = AutoModelForSequenceClassification.from_pretrained(
    paths.MODEL_PATH / 'bert-base-multilingual-cased_finetuned',
    num_labels=num_labels,
    problem_type="multi_label_classification",
)
trainer.model = finetuned_model.to(device)

In [None]:
# Embed test set: collect the last-layer hidden states of the fine-tuned model
# for every test example, in dataset order.

# Explicitly disable dropout. If the reload cell was skipped, `trainer.model`
# comes straight from training and may not be in eval mode.
trainer.model.eval()
model_device = trainer.model.device

# Let `datasets` return ready-made (batch, seq_len) tensors for just the model
# inputs, instead of collating every column (including strings) and
# re-stacking per-token tensor lists by hand.
model_input_columns = ['input_ids', 'attention_mask', 'token_type_ids']
test_dataloader = DataLoader(
    test_dataset.with_format('torch', columns=model_input_columns),
    batch_size=8,
    shuffle=False,  # keep rows aligned with test_dataset
)

embeddings = []
with torch.no_grad():
    for batch in tqdm.tqdm(test_dataloader, desc='embedding'):
        inputs = {name: tensor.to(model_device) for name, tensor in batch.items()}
        output = trainer.model(**inputs, output_hidden_states=True)
        embeddings.append(output.hidden_states[-1].cpu())

# NOTE(review): this keeps every token's hidden state, i.e. about
# n_test * max_length * hidden_size floats in memory -- consider pooling if the
# test set grows.
embeddings = torch.cat(embeddings, dim=0)

# Save embeddings (create the results folder if it does not exist yet)
results_dir = paths.RESULTS_PATH/'line_labelling'
os.makedirs(results_dir, exist_ok=True)
torch.save(embeddings, results_dir/'multilingual-bert-fine-tuned-embeddings.pt')

In [None]:
# Predict on the test set. `predictions.predictions` holds the raw model
# outputs (no sigmoid applied here).
predictions = trainer.predict(test_dataset)

# Save predictions (create the results folder if the embedding cell has not run)
results_dir = paths.RESULTS_PATH/'line_labelling'
os.makedirs(results_dir, exist_ok=True)
# NOTE(review): this pickles a transformers PredictionOutput object; recent torch
# versions default torch.load to weights_only=True, so loading it back may need
# weights_only=False.
torch.save(predictions, results_dir/'multilingual-bert-fine-tuned-predictions.pt')