In [1]:
from datasets import Dataset, DatasetDict
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments
from sklearn.metrics import precision_recall_fscore_support
import evaluate
import numpy as np

import warnings
warnings.filterwarnings("ignore")

from eval_tool import sample_from_big_string, sample_texts_from_dir


HUGGINGFACE_MODEL = "distilbert/distilbert-base-multilingual-cased"
OUTPUT_DIR = "distilmbert_lc_model_80_b"
num_train_per_lang = 80
train_len = 256
num_test_per_lang = 20
test_len = 256


TOKENIZER = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL)
def preprocess_function(examples):
    return TOKENIZER(examples["text"], max_length=512, truncation=True)


ACCURACY = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return ACCURACY.compute(predictions=predictions, references=labels)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    precision, recall, fscore, _ = precision_recall_fscore_support(labels, predictions, average='weighted')
    return {
        'precision': precision,
        'recall': recall,
        'fscore': fscore
    }

data_collator = DataCollatorWithPadding(tokenizer=TOKENIZER)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import logging
import os
from pathlib import Path

# Small Wikipedia corpus from https://lukelindemann.com/wiki_corpus.html, preprocessed using lplangid training script.
wiki_root = Path.home() / "Data" / "WikipediaLindemann"
language_codes = os.listdir(wiki_root / "train")
label2id = {lang: idx for idx, lang in enumerate(language_codes)}
id2label = {idx: lang for lang, idx in label2id.items()}

train_texts, train_labels, test_texts, test_labels = [], [], [], []

for lang in language_codes:
    train_fh = open(wiki_root / "train" / lang, encoding='utf-8')
    test_fh = open(wiki_root / "test" / lang, encoding='utf-8')
    train_contents = train_fh.read()
    test_contents = test_fh.read()

    # Check there is enough data, otherwise skip.
    if len(train_contents) < train_len * num_train_per_lang or len(test_contents) < test_len * num_test_per_lang:
        logging.warning(f"{len(train_contents)} training characters for language '{lang}'. We need {train_len * num_train_per_lang}.")
        logging.warning(f"{len(test_contents)} training characters for language '{lang}'. We need {test_len * num_test_per_lang}.")
        logging.warning("Skipping")
        continue

    train_texts.extend(sample_from_big_string(train_contents, train_len, num_train_per_lang))
    train_labels.extend([label2id[lang]] * num_train_per_lang)
    test_texts.extend(sample_from_big_string(train_contents, test_len, num_test_per_lang))
    test_labels.extend([label2id[lang]] * num_test_per_lang)

train_dataset = Dataset.from_dict({"text": train_texts, "label": train_labels})
test_dataset = Dataset.from_dict({"text": test_texts, "label": test_labels})

# Bundle into a DatasetDict
dataset_dict = DatasetDict({
    "train": train_dataset,
    "test": test_dataset
})
tokenized_data = dataset_dict.map(preprocess_function, batched=True)




Map: 100%|██████████| 22560/22560 [00:03<00:00, 7501.15 examples/s]
Map: 100%|██████████| 5640/5640 [00:00<00:00, 8118.87 examples/s]


In [11]:
model = AutoModelForSequenceClassification.from_pretrained(
        HUGGINGFACE_MODEL,
        num_labels=len(label2id), id2label=id2label, label2id=label2id
    )

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    tokenizer=TOKENIZER,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
retrain = True
if retrain:
    print(f"Starting training based on {HUGGINGFACE_MODEL} ... outputting to {trainer.args.output_dir}")
    print(f"Labels: {len(set(test_labels))} Num train: {num_train_per_lang} (len {train_len}). Num test: {num_test_per_lang} (len {test_len}).")
    trainer.train()
    print("Done training.")

Starting training based on distilbert/distilbert-base-multilingual-cased ... outputting to distilbert_lc_model_80_b
Labels: 282 Num train: 80 (len 256). Num test: 20 (len 256).


Epoch,Training Loss,Validation Loss,Precision,Recall,Fscore
1,2.8356,1.051048,0.856825,0.860284,0.842108
2,0.5238,0.2933,0.915068,0.920035,0.908701
3,0.2583,0.167055,0.939126,0.945745,0.938646
4,0.1728,0.144305,0.953232,0.945745,0.939434
5,0.137,0.106313,0.958431,0.957979,0.953299
6,0.1195,0.096519,0.96232,0.962234,0.9584
7,0.0893,0.080763,0.967046,0.96844,0.964432
8,0.0892,0.072717,0.97426,0.970745,0.967565
9,0.0823,0.069296,0.974026,0.971986,0.969491
10,0.0802,0.067643,0.969185,0.97305,0.969544


Done training.


In [4]:
lc_model = AutoModelForSequenceClassification.from_pretrained(
    OUTPUT_DIR + "/checkpoint-14100",
    num_labels=len(label2id), id2label=id2label, label2id=label2id)

evaluator = Trainer(
    model=lc_model,
    eval_dataset=tokenized_data["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

test_results = evaluator.evaluate(tokenized_data["test"])
print(", ".join([f"{k}: {v:0.4f}" for k, v in test_results.items()]))


eval_loss: 0.0676, eval_precision: 0.9692, eval_recall: 0.9730, eval_fscore: 0.9695, eval_runtime: 20.5156, eval_samples_per_second: 274.9120, eval_steps_per_second: 34.3640


In [7]:
all_results = {}
for eval_strlen in [16, 32, 64, 128, 256]:
    eval_texts, eval_labels = sample_texts_from_dir(Path(wiki_root) / "test", eval_strlen, num_test_per_lang)

    eval_dataset = Dataset.from_dict({"text": eval_texts, "label": [label2id[l] for l in eval_labels]})
    tokenized_eval_data = eval_dataset.map(preprocess_function, batched=True)

    evaluator = Trainer(
        model=lc_model,
        eval_dataset=tokenized_eval_data,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    eval_results = evaluator.evaluate(tokenized_eval_data)
    all_results[eval_strlen] = eval_results


for eval_strlen, eval_results in all_results.items():
    print(f"Results for length {eval_strlen}")
    print(", ".join([f"{k}: {v:0.4f}" for k, v in eval_results.items()]))


Map: 100%|██████████| 5800/5800 [00:00<00:00, 48306.12 examples/s]


Map: 100%|██████████| 5800/5800 [00:00<00:00, 16108.81 examples/s]


Map: 100%|██████████| 5800/5800 [00:00<00:00, 23825.13 examples/s]


Map: 100%|██████████| 5800/5800 [00:00<00:00, 14059.37 examples/s]


Map: 100%|██████████| 5800/5800 [00:00<00:00, 8162.97 examples/s]


Results for length 16
eval_loss: 2.9565, eval_precision: 0.5560, eval_recall: 0.4667, eval_fscore: 0.4619, eval_runtime: 4.8304, eval_samples_per_second: 1200.7290, eval_steps_per_second: 150.0910
Results for length 32
eval_loss: 2.0615, eval_precision: 0.6743, eval_recall: 0.6347, eval_fscore: 0.6260, eval_runtime: 5.0815, eval_samples_per_second: 1141.3880, eval_steps_per_second: 142.6730
Results for length 64
eval_loss: 1.4544, eval_precision: 0.7739, eval_recall: 0.7717, eval_fscore: 0.7599, eval_runtime: 6.7701, eval_samples_per_second: 856.7070, eval_steps_per_second: 107.0880
Results for length 128
eval_loss: 1.1765, eval_precision: 0.8394, eval_recall: 0.8419, eval_fscore: 0.8306, eval_runtime: 11.0415, eval_samples_per_second: 525.2910, eval_steps_per_second: 65.6610
Results for length 256
eval_loss: 0.8697, eval_precision: 0.8738, eval_recall: 0.8803, eval_fscore: 0.8669, eval_runtime: 19.7455, eval_samples_per_second: 293.7380, eval_steps_per_second: 36.7170
