In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
DATA_PATH = '/content/drive/MyDrive/domain-helper/'
LOG_PATH = '/content/drive/MyDrive/domain-helper/logs/'

!pip install -q transformers datasets nlp

from nlp import load_dataset, DatasetDict



dataset = load_dataset('csv', data_files=[str(DATA_PATH) + '/gar_level_labeled_data.csv'], split='train' )

# 90% train, 10% test + validation
train_test_valid = dataset.train_test_split(test_size=0.1)
# Split the 10% test + valid in half test, half valid
test_valid = train_test_valid['test'].train_test_split(test_size=0.5)
# gather everyone if you want to have a single DatasetDict
dataset = DatasetDict({
    'train': train_test_valid['train'],
    'test': test_valid['test'],
    'validation': test_valid['train']})

labels = [label for label in dataset['train'].features.keys() if label not in ['id', 'text']]
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}




[K     |████████████████████████████████| 4.2 MB 4.9 MB/s 
[K     |████████████████████████████████| 346 kB 92.6 MB/s 
[K     |████████████████████████████████| 1.7 MB 86.0 MB/s 
[K     |████████████████████████████████| 86 kB 6.0 MB/s 
[K     |████████████████████████████████| 6.6 MB 77.4 MB/s 
[K     |████████████████████████████████| 596 kB 84.5 MB/s 
[K     |████████████████████████████████| 212 kB 90.8 MB/s 
[K     |████████████████████████████████| 86 kB 6.5 MB/s 
[K     |████████████████████████████████| 140 kB 29.0 MB/s 
[K     |████████████████████████████████| 1.1 MB 80.7 MB/s 
[K     |████████████████████████████████| 127 kB 94.8 MB/s 
[K     |████████████████████████████████| 94 kB 1.3 MB/s 
[K     |████████████████████████████████| 271 kB 99.7 MB/s 
[K     |████████████████████████████████| 144 kB 92.1 MB/s 
[K     |████████████████████████████████| 112 kB 103.6 MB/s 
[31mERROR: pip's dependency resolver does not currently take into account all the packages

Downloading:   0%|          | 0.00/2.75k [00:00<?, ?B/s]

Using custom data configuration default


Downloading and preparing dataset csv/default-c2001a0f530a8856 (download: Unknown size, generated: Unknown size, post-processed: Unknown sizetotal: Unknown size) to /root/.cache/huggingface/datasets/csv/default-c2001a0f530a8856/0.0.0/ede98314803c971fef04bcee45d660c62f3332e8a74491e0b876106f3d99bd9b...


0 tables [00:00, ? tables/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-c2001a0f530a8856/0.0.0/ede98314803c971fef04bcee45d660c62f3332e8a74491e0b876106f3d99bd9b. Subsequent calls will reuse this data.


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
from transformers import CamembertTokenizerFast
import numpy as np

tokenizer = CamembertTokenizerFast.from_pretrained("camembert-base")

def preprocess_data(examples):
  titles = examples["text"]
  #auto max length is set to 62, but most examples dont exceed 30. Force max length parameter ?
  encoding = tokenizer(titles, padding=True, truncation=True, max_length=36)
  labels_batch = {k: examples[k] for k in examples.keys() if k in labels}
  labels_matrix = np.zeros((len(titles), len(labels)))
  for idx, label in enumerate(labels):
    labels_matrix[:, idx] = labels_batch[label]
  encoding["labels"] = labels_matrix.tolist()
  return encoding

In [None]:
encoded_dataset = dataset.map(preprocess_data, batched=True, remove_columns=dataset['train'].column_names)
encoded_dataset.set_format("torch")

In [None]:
from transformers import CamembertForSequenceClassification

model = CamembertForSequenceClassification.from_pretrained("camembert-base", 
                                                           problem_type="multi_label_classification", 
                                                           num_labels=len(labels),
                                                           id2label=id2label,
                                                           label2id=label2id)

In [None]:
metric_name = "f1"
batch_size = 20
epochs = 100
weight_decay = 0.01

from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    output_dir=f'{DATA_PATH}/level-results-2',
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=epochs,
    weight_decay=weight_decay,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name
)

In [7]:
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import EvalPrediction
import torch
    
# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    # next, use threshold to turn them into integer predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    # finally, compute metrics
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')
    accuracy = accuracy_score(y_true, y_pred)
    # return as dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    return metrics

def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, 
            tuple) else p.predictions
    result = multi_label_metrics(
        predictions=preds, 
        labels=p.label_ids)
    return result

In [8]:
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
trainer.train()

***** Running training *****
  Num examples = 8552
  Num Epochs = 100
  Instantaneous batch size per device = 20
  Total train batch size (w. parallel, distributed & accumulation) = 20
  Gradient Accumulation steps = 1
  Total optimization steps = 42800


Epoch,Training Loss,Validation Loss,F1,Roc Auc,Accuracy
1,No log,0.191343,0.0,0.5,0.0


***** Running Evaluation *****
  Num examples = 475
  Batch size = 20
Saving model checkpoint to /content/drive/MyDrive/domain-helper//level-results-2/checkpoint-428
Configuration saved in /content/drive/MyDrive/domain-helper//level-results-2/checkpoint-428/config.json
Model weights saved in /content/drive/MyDrive/domain-helper//level-results-2/checkpoint-428/pytorch_model.bin
tokenizer config file saved in /content/drive/MyDrive/domain-helper//level-results-2/checkpoint-428/tokenizer_config.json
Special tokens file saved in /content/drive/MyDrive/domain-helper//level-results-2/checkpoint-428/special_tokens_map.json


Epoch,Training Loss,Validation Loss,F1,Roc Auc,Accuracy
1,No log,0.191343,0.0,0.5,0.0
2,0.367400,0.153121,0.0,0.5,0.0


***** Running Evaluation *****
  Num examples = 475
  Batch size = 20
Saving model checkpoint to /content/drive/MyDrive/domain-helper//level-results-2/checkpoint-856
Configuration saved in /content/drive/MyDrive/domain-helper//level-results-2/checkpoint-856/config.json
Model weights saved in /content/drive/MyDrive/domain-helper//level-results-2/checkpoint-856/pytorch_model.bin
tokenizer config file saved in /content/drive/MyDrive/domain-helper//level-results-2/checkpoint-856/tokenizer_config.json
Special tokens file saved in /content/drive/MyDrive/domain-helper//level-results-2/checkpoint-856/special_tokens_map.json


In [None]:
trainer.evaluate()

In [None]:
french_labels = {"http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-002": "école maternelle",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-008": "école élémentaire",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-016": "collège",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-065": "voie professionnelle",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-066": "voie générale et technologique",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-140": "adaptation scolaire et scolarisation des élèves handicapés",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-141": "collège professionnel",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-142": "lycée général et technologique",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-163": "formation des personnels de l'éducation nationale",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-164": "formation professionnelle continue des adultes",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-165": "lycée professionnel",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-185": "licence professionnelle",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-195": "section BTS",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-198": "section DUT",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-226": "voie CPGE",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-239": "11-15 ans",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-240": "15-18 ans",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-241": "3-6 ans",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-242": "6-11 ans",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-615": "cycles de l'enseignement scolaire (2016)",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-063": "formation continue des personnels de l'éducation nationale",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-136": "2de générale et technologique",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-146": "1re générale",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-201": "terminale générale",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-211": "terminale technologique",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-134": "1re ST2S",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-123": "1re générale et technologique",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-083": "terminale STMG",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-047": "terminale STI2D",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-043": "terminale ST2S",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-044": "terminale STD2A",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-048": "terminale STL",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-641": "terminale STHR",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-650": "2de STHR",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-132": "1re STL",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-237": "1re technologique",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-620": "cycle 3 (2016)",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-621": "cycle 4 (2016)",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-619": "cycle 2 (2016)",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-027": "2de professionnelle",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-050": "terminale professionnelle",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-127": "1re professionnelle",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-223": "voie CAP",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-133": "1re STMG",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-129": "1re STD2A",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-131": "1re STI2D",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-088": "terminale générale et technologique",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-617": "cycle 1 (2016)",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-640": "1re STHR",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-162": "enseignement supérieur en lycée",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-214": "voie BMA",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-221": "voie BT",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-222": "voie BTM",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-187": "diplôme de comptabilité et gestion",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-212": "terminale TMD",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-238": "1re TMD",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-023": "3e",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-618": "cycles de la scolarité obligatoire",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-004": "toute petite section",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-005": "petite section",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-006": "moyenne section",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-007": "grande section",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-003": "cycle 1",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-009": "cycle 2",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-010": "CP",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-011": "CE1",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-013": "CE2",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-014": "CM1",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-015": "CM2",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-012": "cycle 3",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-018": "6e",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-020": "5e",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-021": "4e",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-017": "cycle d'adaptation",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-019": "cycle central",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-022": "cycle d'orientation",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-608": "3e d'insertion",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-084": "classe de détermination",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-040": "terminale ES",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-041": "terminale L",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-042": "terminale S",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-085": "cycle terminal",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-233": "voie technologique",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-049": "terminale techno hôtellerie",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-089": "terminale pro domaine de la production",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-090": "terminale pro domaine des services",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-095": "BMA - 1re année",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-096": "BMA - 2e année",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-097": "BP - 1re année",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-098": "BP - 2e année",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-099": "BTM - 1re année",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-100": "BTM - 2e année",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-103": "CAP - 1re année",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-104": "CAP - 2e année",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-125": "1re pro domaine de la production",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-126": "1re pro domaine des services",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-138": "3e professionnelle",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-139": "4e professionnelle",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-213": "voie BEP",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-215": "voie BP",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-227": "voie FC",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-231": "voie MC",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-288": "FC post niveau IV",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-298": "FC post niveau V",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-230": "voie générale",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-135": "1re techno hôtellerie",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-094": "cycles de l'enseignement scolaire",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-101": "BTS - 1re année",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-102": "BTS - 2e année",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-109": "CPGE - 1re année",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-110": "CPGE - 2e année",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-111": "DUT - 1re année",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-112": "DUT - 2e année",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-121": "tranches d'âge de l'enseignement scolaire",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-626": "Niveau éducatif détaillé (2015-)",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-122": "1re ES",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-124": "1re L",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-128": "1re S",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-150": "ASSEH - 1er degré",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-151": "ASSEH - 2d degré",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-153": "BT - première",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-154": "BT - terminale",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-234": "voies générale, technologique, professionnelle",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-284": "formation professionnelle",
                 "http://data.education.fr/voc/scolomfr/concept/scolomfr-voc-022-num-616": "cycle unique pour l'école maternelle"}

In [None]:
texts = list(map(lambda entry: entry['text'], dataset["test"]))

for text  in texts:
  encoding = tokenizer(text, return_tensors="pt")
  encoding = {k: v.to(trainer.model.device) for k,v in encoding.items()}

  outputs = trainer.model(**encoding)
  logits = outputs.logits
  # apply sigmoid + threshold
  sigmoid = torch.nn.Sigmoid()
  probs = sigmoid(logits.squeeze().cpu())
  predictions = np.zeros(probs.shape)
  predictions[np.where(probs >= 0.5)] = 1
  # turn predicted id's into actual label names
  predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
  predicted_labels_with_scores = [french_labels[f'http://data.education.fr/voc/scolomfr/concept/{id2label[idx]}'] + '->' + str(round(probs[idx].item()*100,2))+'%' for idx, label in enumerate(predictions) if label == 1.0]
  print("***************")
  print (text)
  print(predicted_labels_with_scores)
  max = torch.argmax(probs).item()
  print('Maximum : ' + french_labels[f'http://data.education.fr/voc/scolomfr/concept/{id2label[max]}'])
