In [17]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import DatasetDict, Sequence, Value, Features
import torch
import os
import sys
sys.path.append(os.getcwd()+"/../..")
from src import paths
from src.utils import model_output

In [2]:
# Load the preprocessed line-labelling dataset (a DatasetDict) from disk
dataset_dir = paths.DATA_PATH_PREPROCESSED/'line_labelling/line_labelling_clean_dataset'
dataset = DatasetDict.load_from_disk(dataset_dir)

# Number of labels = number of distinct 'class_agg' values in the training split
train_classes = dataset['train']['class_agg']
num_labels = len(set(train_classes))

In [None]:
# Run this cell if you want to download and fine-tune the model

# # Checkpoint
# checkpoint = "bert-base-multilingual-cased"

# # Load tokenizer
# tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# # Save tokenizer
# tokenizer.save_pretrained(paths.MODEL_PATH/'bert-base-multilingual-cased')

# # Load model for sequence classification (multi-label head with num_labels outputs)
# model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels, problem_type="multi_label_classification")

# # Save model
# model.save_pretrained(paths.MODEL_PATH/'bert-base-multilingual-cased')

In [13]:
# Pick the first CUDA device when one is available, otherwise run on CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Local directory holding the saved base checkpoint (tokenizer + model weights)
base_checkpoint_dir = paths.MODEL_PATH/'bert-base-multilingual-cased'

# Load tokenizer from the local checkpoint
tokenizer = AutoTokenizer.from_pretrained(base_checkpoint_dir)

# Load the classification model from the local checkpoint and move it to the device
base_model = AutoModelForSequenceClassification.from_pretrained(
    base_checkpoint_dir,
    num_labels=num_labels,
    problem_type="multi_label_classification",
)
model = base_model.to(device)

In [4]:
# Tokenize
def tokenize(examples):
    """Tokenize a batch of examples for use with ``Dataset.map(batched=True)``.

    Args:
        examples: Batch mapping handed over by ``Dataset.map``; only the
            ``'text'`` column is read.

    Returns:
        The tokenizer output for the batch, padded/truncated to exactly
        256 tokens per example.
    """
    # No `return_tensors` here: `Dataset.map` stores the result as Arrow columns
    # (cast by the `features` schema passed to map), so building torch tensors
    # first would only add a conversion round-trip.
    return tokenizer(examples['text'], padding="max_length", truncation=True, max_length=256)

# Explicit output schema for the tokenized dataset: labels are stored as float32
# sequences, tokenizer outputs as int32 sequences, remaining columns as strings.
feature_spec = {'labels': Sequence(Value(dtype='float32'))}
for column in ('input_ids', 'attention_mask', 'token_type_ids'):
    feature_spec[column] = Sequence(Value(dtype='int32'))
for column in ('class_agg', 'rid', 'text', 'class'):
    feature_spec[column] = Value(dtype='string')
features = Features(feature_spec)

# Tokenize every split in batches, casting the result to the schema above
dataset = dataset.map(
    tokenize,
    batched=True,
    features=features,
)


Map:   0%|          | 0/249 [00:00<?, ? examples/s]

In [5]:
# Pull the three splits out of the DatasetDict in one go
train_dataset, val_dataset, test_dataset = (
    dataset[split_name] for split_name in ('train', 'val', 'test')
)

In [15]:
# Training Arguments
# NOTE(review): warmup_steps=200 is larger than the 180 total optimizer steps this
# configuration produced in the recorded run, so the learning rate is still warming
# up when training ends (the log peaks at 4.5e-05 on the final step). Consider a
# smaller value or `warmup_ratio`; left unchanged here to stay consistent with the
# already saved fine-tuned checkpoint.
training_args = TrainingArguments(
    output_dir='./results',          # output directory for checkpoints
    num_train_epochs=12,             # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=32,   # batch size for evaluation
    warmup_steps=200,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,                # log the training loss every 10 optimizer steps
    load_best_model_at_end=True,     # needs matching save/evaluation strategies (both 'epoch')
    save_strategy='epoch',
    evaluation_strategy='epoch',
    gradient_accumulation_steps=4,   # effective train batch = 16 * 4 per device
    gradient_checkpointing=True,     # trade compute for lower activation memory
    # fp16 mixed precision needs a CUDA device; enable it only when one is present
    # so the CPU fallback chosen for `device` keeps working.
    fp16=torch.cuda.is_available(),
)

# Trainer
trainer = Trainer(
    model=model,                         # the instantiated Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset             # evaluation dataset
)

In [8]:
#trainer.train()

  0%|          | 0/180 [00:00<?, ?it/s]



{'loss': 0.6859, 'learning_rate': 2.5e-06, 'epoch': 0.67}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.6506651639938354, 'eval_runtime': 1.8381, 'eval_samples_per_second': 135.468, 'eval_steps_per_second': 4.352, 'epoch': 1.0}




{'loss': 0.6528, 'learning_rate': 5e-06, 'epoch': 1.33}
{'loss': 0.5922, 'learning_rate': 7.5e-06, 'epoch': 2.0}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.5280624628067017, 'eval_runtime': 1.812, 'eval_samples_per_second': 137.416, 'eval_steps_per_second': 4.415, 'epoch': 2.0}




{'loss': 0.4874, 'learning_rate': 1e-05, 'epoch': 2.67}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.40505319833755493, 'eval_runtime': 1.8853, 'eval_samples_per_second': 132.072, 'eval_steps_per_second': 4.243, 'epoch': 3.0}




{'loss': 0.4169, 'learning_rate': 1.25e-05, 'epoch': 3.33}
{'loss': 0.3741, 'learning_rate': 1.5e-05, 'epoch': 4.0}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.34950608015060425, 'eval_runtime': 1.8019, 'eval_samples_per_second': 138.189, 'eval_steps_per_second': 4.44, 'epoch': 4.0}




{'loss': 0.3385, 'learning_rate': 1.75e-05, 'epoch': 4.67}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.30657652020454407, 'eval_runtime': 1.8819, 'eval_samples_per_second': 132.31, 'eval_steps_per_second': 4.251, 'epoch': 5.0}




{'loss': 0.3046, 'learning_rate': 2e-05, 'epoch': 5.33}
{'loss': 0.2704, 'learning_rate': 2.25e-05, 'epoch': 6.0}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.26361310482025146, 'eval_runtime': 1.7876, 'eval_samples_per_second': 139.293, 'eval_steps_per_second': 4.475, 'epoch': 6.0}




{'loss': 0.2333, 'learning_rate': 2.5e-05, 'epoch': 6.67}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.23214852809906006, 'eval_runtime': 1.8316, 'eval_samples_per_second': 135.947, 'eval_steps_per_second': 4.368, 'epoch': 7.0}




{'loss': 0.2112, 'learning_rate': 2.7500000000000004e-05, 'epoch': 7.33}
{'loss': 0.1784, 'learning_rate': 3e-05, 'epoch': 8.0}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.2063852995634079, 'eval_runtime': 1.7619, 'eval_samples_per_second': 141.323, 'eval_steps_per_second': 4.541, 'epoch': 8.0}




{'loss': 0.1529, 'learning_rate': 3.2500000000000004e-05, 'epoch': 8.67}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.1923319548368454, 'eval_runtime': 1.8576, 'eval_samples_per_second': 134.046, 'eval_steps_per_second': 4.307, 'epoch': 9.0}




{'loss': 0.1398, 'learning_rate': 3.5e-05, 'epoch': 9.33}
{'loss': 0.1161, 'learning_rate': 3.7500000000000003e-05, 'epoch': 10.0}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.18598468601703644, 'eval_runtime': 1.7736, 'eval_samples_per_second': 140.394, 'eval_steps_per_second': 4.511, 'epoch': 10.0}




{'loss': 0.1012, 'learning_rate': 4e-05, 'epoch': 10.67}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.18416623771190643, 'eval_runtime': 1.8311, 'eval_samples_per_second': 135.986, 'eval_steps_per_second': 4.369, 'epoch': 11.0}




{'loss': 0.0853, 'learning_rate': 4.25e-05, 'epoch': 11.33}
{'loss': 0.0786, 'learning_rate': 4.5e-05, 'epoch': 12.0}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.18029162287712097, 'eval_runtime': 1.7677, 'eval_samples_per_second': 140.862, 'eval_steps_per_second': 4.526, 'epoch': 12.0}
{'train_runtime': 405.6052, 'train_samples_per_second': 28.047, 'train_steps_per_second': 0.444, 'train_loss': 0.30109807418452367, 'epoch': 12.0}


TrainOutput(global_step=180, training_loss=0.30109807418452367, metrics={'train_runtime': 405.6052, 'train_samples_per_second': 28.047, 'train_steps_per_second': 0.444, 'train_loss': 0.30109807418452367, 'epoch': 12.0})

In [9]:
# Save model
#trainer.save_model(paths.MODEL_PATH/'bert-base-multilingual-cased_finetuned')

In [16]:
# Swap the trainer's model for the fine-tuned checkpoint stored on disk
finetuned_checkpoint_dir = paths.MODEL_PATH/'bert-base-multilingual-cased_finetuned'
finetuned_model = AutoModelForSequenceClassification.from_pretrained(
    finetuned_checkpoint_dir,
    num_labels=num_labels,
    problem_type="multi_label_classification",
)
trainer.model = finetuned_model.to(device)

In [18]:
# Model output: run the fine-tuned model over each split (train, val, test in order)
split_datasets = {'train': train_dataset, 'val': val_dataset, 'test': test_dataset}
split_outputs = {
    split_name: model_output(data=split_data, model=trainer.model, device=device)
    for split_name, split_data in split_datasets.items()
}

# Keep the per-split names available for any cells that use them
train_out = split_outputs['train']
val_out = split_outputs['val']
test_out = split_outputs['test']

# Save model output; create the results directory first so torch.save does not
# fail on a fresh checkout where it does not exist yet
os.makedirs(paths.RESULTS_PATH/'line_labelling', exist_ok=True)
for split_name, split_output in split_outputs.items():
    torch.save(
        split_output,
        paths.RESULTS_PATH/f'line_labelling/BERT-multilingual-finetuned-{split_name}_output.pt',
    )

100%|██████████| 30/30 [00:16<00:00,  1.86it/s]
100%|██████████| 8/8 [00:04<00:00,  1.96it/s]
100%|██████████| 11/11 [00:05<00:00,  1.94it/s]
