In [1]:
import json
import os
from datasets import load_dataset
import argparse

import numpy as np
from datasets import load_dataset, load_metric
from transformers import (AutoModelForTokenClassification, AutoTokenizer,
                          DataCollatorForTokenClassification,
                          EarlyStoppingCallback, Trainer, TrainingArguments,
                          set_seed)
from sklearn.metrics import f1_score, recall_score, precision_score

import warnings
warnings.filterwarnings('ignore')
import glob

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Train/validation splits are read from local JSON files; each key becomes a split name.
dataset = load_dataset(
    'json',
    data_files={
        'train': "fial_jsons/train_all.json",
        'validation': 'fial_jsons/val_all.json',
    },
)

# SRL label inventory. The list length sets the size of the classification head.
# NOTE(review): the metric functions further down hard-code ids 2, 15 and 16, which are the
# positions of 'NAA', 'Predicate' and 'NAH' here -- presumably the integer ids stored in the
# `srl` field index into this list; confirm before reordering anything.
class_names = [
    'ARGM-PRP', 'ARGM-NEG', 'NAA', 'ARGA', 'ARGM-EXT',            # ids 0-4
    'ARGM-DIS', 'ARG0', 'ARGM-ADV', 'ARG2-ATR', 'ARG1',           # ids 5-9
    'ARG3', 'ARGM-DIR', 'ARGM-LOC', 'ARG2-LOC', 'ARG0-GOL',       # ids 10-14
    'Predicate', 'NAH', 'ARG2-SOU', 'ARGM-MNS', 'ARGM-CAU',       # ids 15-19
    'ARGM-PRX', 'ARGM-TMP', 'ARGM-MNR', 'ARG2', 'ARG2-GOL',       # ids 20-24
]

# Finetuning bert based models
Instruction: Run the following cells separately for each model, i.e. "ai4bharat/indic-bert" and "google-bert/bert-base-multilingual-cased".

In [7]:
# Encoder to fine-tune. Swap which of the two lines is commented out to switch models.
# model_name = "ai4bharat/indic-bert"
model_name = "google-bert/bert-base-multilingual-cased"

# Root directory for Trainer checkpoints (the training cell appends model_name to it).
output_dir = "output"

# True: every sub-word token of a word gets the word's label;
# False: only the first sub-word is labelled, the rest get -100.
label_all_tokens = True

# do_train gates the whole fine-tuning branch of the next cell; while it is False that cell
# only loads the tokenizer/model and never creates `trainer`.
# do_predict is not read by any of the cells shown in this notebook.
do_train = do_predict = False

In [8]:
# Seed via transformers' set_seed before the model is instantiated further down in this cell.
set_seed(42)

def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and spread word-level labels onto sub-word tokens.

    Reads `examples["words"]` (one list of words per sentence) and `examples["srl"]`
    (one list of per-word labels per sentence). Relies on the notebook-level names
    `tokenizer` and `label_all_tokens`.

    Returns the tokenizer output with an added "labels" entry: one label list per sentence,
    aligned with the tokens. Tokens without a word id get -100; the first token of a word
    gets that word's label; later tokens of the same word get the word's label when
    `label_all_tokens` is truthy and -100 otherwise.
    """
    encoded = tokenizer(
        examples["words"],
        truncation='longest_first',
        is_split_into_words=True,
        max_length=512,
    )

    aligned = []
    for row, word_labels in enumerate(examples["srl"]):
        row_labels = []
        last_word = None
        for word in encoded.word_ids(batch_index=row):
            if word is None:
                # No source word (special token per the tokenizer): -100 so the loss skips it.
                tag = -100
            elif word == last_word and not label_all_tokens:
                # Continuation piece of the same word, and only first pieces are labelled.
                tag = -100
            else:
                # First piece of a word, or a continuation piece when labelling all pieces.
                tag = word_labels[word]
            row_labels.append(tag)
            last_word = word
        aligned.append(row_labels)

    encoded["labels"] = aligned
    return encoded

def compute_metrics(p):
    """Macro-averaged precision, recall and F1 over every labelled token.

    `p` unpacks into (scores, gold label ids). The predicted class of a token is the
    argmax of its scores along axis 2. Tokens whose gold label is -100 are left out.

    Returns a dict with the keys "precision", "recall" and "f1".
    """
    scores, gold_ids = p
    predicted_ids = np.argmax(scores, axis=2)

    # Flatten the batch into (prediction, gold) pairs, dropping ignored positions.
    kept = [
        (guess, gold)
        for guess_row, gold_row in zip(predicted_ids, gold_ids)
        for guess, gold in zip(guess_row, gold_row)
        if gold != -100
    ]
    true_predictions = [guess for guess, _ in kept]
    true_labels = [gold for _, gold in kept]

    return {
        "precision": precision_score(true_labels, true_predictions, average='macro'),
        "recall": recall_score(true_labels, true_predictions, average='macro'),
        "f1": f1_score(true_labels, true_predictions, average='macro'),
    }

label_list = class_names
# Name <-> id maps built from list position, stored in the model config so a saved checkpoint
# reports label names instead of generic LABEL_<n> placeholders.
# NOTE(review): assumes the integer ids in the `srl` column index into class_names -- confirm
# against the data preparation step.
id2label = dict(enumerate(label_list))
label2id = {name: idx for idx, name in id2label.items()}

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
)
# Token-classification collator built from the tokenizer; handed to the Trainer below.
data_collator = DataCollatorForTokenClassification(tokenizer)

if do_train:

    # Tokenize both splits and drop the original columns so that only the tokenizer output
    # (plus the aligned "labels") is left in the mapped datasets.
    train_dataset = dataset["train"].map(
        tokenize_and_align_labels,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )
    print(f"Length of Training Dataset: {len(train_dataset)}")

    validation_dataset = dataset["validation"].map(
        tokenize_and_align_labels,
        batched=True,
        remove_columns=dataset["validation"].column_names,
    )
    print(f"Lenght of Validation Dataset: {len(validation_dataset)}")

    # Evaluate and checkpoint once per epoch; the best checkpoint is the one with the lowest
    # validation loss and is reloaded when training ends.
    training_args = TrainingArguments(
        output_dir=f"{output_dir}/{model_name}",
        save_total_limit=1,
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=20,
        do_eval=True,
        evaluation_strategy="epoch",
        weight_decay=0.0,
        fp16=True,
        warmup_ratio=0.1,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False
    )

    # Stop once eval_loss has failed to improve for 5 consecutive evaluations (= epochs here).
    earlystoppingcallback = EarlyStoppingCallback(early_stopping_patience=5)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[earlystoppingcallback]
    )
    trainer.train()

    # Inside an `if` block the return value is not echoed by the notebook, so keep it and
    # print it explicitly.
    eval_metrics = trainer.evaluate()
    print(f"Validation metrics after training: {eval_metrics}")


Some weights of BertForTokenClassification were not initialized from the model checkpoint at google-bert/bert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11892/11892 [00:03<00:00, 3751.64 examples/s]


Length of Training Dataset: 11892


Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2548/2548 [00:00<00:00, 3911.62 examples/s]


Lenght of Validation Dataset: 2548


Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss,Precision,Recall,F1
1,No log,1.17664,0.171081,0.135584,0.138356
2,1.566000,1.04658,0.372284,0.2105,0.225434
3,1.041100,1.01627,0.370802,0.251212,0.272658
4,1.041100,1.009658,0.328788,0.295133,0.300102
5,0.928900,1.016793,0.405713,0.304734,0.31319
6,0.826900,1.03739,0.386748,0.313435,0.317281
7,0.750200,1.088175,0.388243,0.323283,0.321584
8,0.750200,1.134245,0.391525,0.358146,0.347632
9,0.670900,1.179428,0.398527,0.333578,0.340491


In [9]:
# Saving the best model checkpoints
# The target directory is derived from the current configuration instead of being hard-coded.
# The evaluation cell loads checkpoints named '<encoder>[-lat]-finetuned-srl' and infers
# label_all_tokens from the 'lat' part of the name, so a fixed name would mislabel (and
# overwrite) runs made with a different model_name or label_all_tokens setting.
save_prefixes = {
    "ai4bharat/indic-bert": "indic-bert",
    "google-bert/bert-base-multilingual-cased": "bert-multilingual-cased",
}
# Unknown hub ids fall back to the part after the last '/'.
save_prefix = save_prefixes.get(model_name, model_name.split("/")[-1])
save_dir = f"{save_prefix}{'-lat' if label_all_tokens else ''}-finetuned-srl"

if do_train:
    trainer.save_model(save_dir)
    print(f"Saved fine-tuned model to {save_dir}")
else:
    # `trainer` is only created by the training cell when do_train is True.
    print("do_train is False: no trainer was created, nothing to save.")

# Testing language wise for different cases
1. Argument identification
2. Argument classification
3. Overall

Use `compute_metrics_argument_identification` as the `compute_metrics` argument for argument identification.

Use `compute_metrics_argument_classification` as the `compute_metrics` argument for argument classification.

Use `compute_metrics_all` as the `compute_metrics` argument for overall performance.

Instruction: Run the following cells once for each of the three evaluations above, changing the `compute_metrics` argument passed to the `Trainer` each time.

In [3]:
def compute_metrics_argument_identification(p, excluded_label_ids=(15, 16), non_argument_id=2):
    """Score argument identification: is a token an argument or not, regardless of role.

    Args:
        p: pair (scores, gold label ids); the predicted class is the argmax of the scores
            along axis 2.
        excluded_label_ids: gold label ids whose tokens are left out of the evaluation.
            The defaults 15 and 16 are the positions of 'Predicate' and 'NAH' in this
            notebook's `class_names` list.
        non_argument_id: label id treated as "not an argument" (class 0); every other id
            counts as "argument" (class 1). The default 2 is the position of 'NAA' in
            `class_names`.

    Tokens with gold label -100 are always skipped. Only the gold label decides whether a
    token is excluded: a prediction equal to one of `excluded_label_ids` on a kept token
    is counted as "argument".

    Returns:
        dict with precision/recall/f1 under both macro and weighted averaging, keyed
        "<metric>-<average>" (e.g. "f1-macro").
    """
    scores, gold_ids = p
    predicted_ids = np.argmax(scores, axis=2)
    skipped = tuple(excluded_label_ids)

    true_labels = []
    true_predictions = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        for predicted, gold in zip(predicted_row, gold_row):
            if gold == -100 or gold in skipped:
                continue
            # Collapse the label set to a binary argument / non-argument decision.
            true_labels.append(0 if gold == non_argument_id else 1)
            true_predictions.append(0 if predicted == non_argument_id else 1)

    scorers = (("precision", precision_score), ("recall", recall_score), ("f1", f1_score))
    return {
        f"{metric}-{averaging}": scorer(true_labels, true_predictions, average=averaging)
        for averaging in ("macro", "weighted")
        for metric, scorer in scorers
    }


def compute_metrics_argument_classification(p, excluded_label_ids=(15, 16)):
    """Score argument classification: the full label of every token except excluded gold labels.

    Args:
        p: pair (scores, gold label ids); the predicted class is the argmax of the scores
            along axis 2.
        excluded_label_ids: gold label ids whose tokens are left out of the evaluation.
            The defaults 15 and 16 are the positions of 'Predicate' and 'NAH' in this
            notebook's `class_names` list.

    Tokens with gold label -100 are always skipped. Filtering looks at the gold label
    only; predictions are passed to the scorers unchanged, even when they equal an
    excluded id.

    Returns:
        dict with precision/recall/f1 under both macro and weighted averaging, keyed
        "<metric>-<average>" (e.g. "f1-macro").
    """
    scores, gold_ids = p
    predicted_ids = np.argmax(scores, axis=2)
    skipped = tuple(excluded_label_ids)

    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        for predicted, gold in zip(predicted_row, gold_row):
            if gold == -100 or gold in skipped:
                continue
            true_predictions.append(predicted)
            true_labels.append(gold)

    scorers = (("precision", precision_score), ("recall", recall_score), ("f1", f1_score))
    return {
        f"{metric}-{averaging}": scorer(true_labels, true_predictions, average=averaging)
        for averaging in ("macro", "weighted")
        for metric, scorer in scorers
    }


def compute_metrics_all(p):
    """Overall scores over every labelled token, with no label filtered out.

    `p` unpacks into (scores, gold label ids). The predicted class of a token is the
    argmax of its scores along axis 2; tokens whose gold label is -100 are skipped.

    Returns a dict with precision, recall and F1 under macro and weighted averaging.
    """
    raw_scores, gold_ids = p
    guesses = np.argmax(raw_scores, axis=2)

    true_predictions, true_labels = [], []
    for guess_row, gold_row in zip(guesses, gold_ids):
        for guess, target in zip(guess_row, gold_row):
            if target == -100:
                # Position carries no gold label.
                continue
            true_predictions.append(guess)
            true_labels.append(target)

    return {
        "precision-macro": precision_score(true_labels, true_predictions, average='macro'),
        "recall-macro": recall_score(true_labels, true_predictions, average='macro'),
        "f1-macro": f1_score(true_labels, true_predictions, average='macro'),
        "precision-weighted": precision_score(true_labels, true_predictions, average='weighted'),
        "recall-weighted": recall_score(true_labels, true_predictions, average='weighted'),
        "f1-weighted": f1_score(true_labels, true_predictions, average='weighted'),
    }

def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and project word labels onto the tokens.

    Uses `examples["words"]` and `examples["srl"]` and the notebook-level names
    `tokenizer` and `label_all_tokens`; the current value of the flag is printed once
    per call. The returned tokenizer output gains a "labels" entry with one list per
    sentence: -100 for tokens without a word id, the word's label for the first token of
    a word, and for further tokens of that word either the same label
    (`label_all_tokens` truthy) or -100.
    """
    batch_encoding = tokenizer(
        examples["words"],
        truncation='longest_first',
        is_split_into_words=True,
        max_length=512,
    )
    print(f"Label all tokens from tokenize_and_align_labels function {label_all_tokens}")

    per_sentence_labels = []
    for sentence_idx, word_level in enumerate(examples["srl"]):
        token_labels = []
        prev = None
        for current in batch_encoding.word_ids(batch_index=sentence_idx):
            if current is None:
                # Token does not belong to any word: ignored by the loss via -100.
                token_labels.append(-100)
            elif current == prev and not label_all_tokens:
                # Later piece of an already-labelled word, first-piece-only labelling.
                token_labels.append(-100)
            else:
                token_labels.append(word_level[current])
            prev = current
        per_sentence_labels.append(token_labels)

    batch_encoding["labels"] = per_sentence_labels
    return batch_encoding

In [5]:
# Fine-tuned checkpoint directories to evaluate. Names containing 'lat' are evaluated with
# label_all_tokens=True, all others with False (see the flag derivation in the loop).
models = [
    'indic-bert-finetuned-srl',
    'indic-bert-lat-finetuned-srl',
    'bert-multilingual-cased-finetuned-srl',
    'bert-multilingual-cased-lat-finetuned-srl',
]

# Metric function handed to the Trainer. Swap in compute_metrics_argument_identification or
# compute_metrics_argument_classification for the other two evaluations.
metric_fn = compute_metrics_all

# Listed once (it does not depend on the checkpoint) and sorted, because glob returns files
# in no particular order; this keeps the language order identical for every checkpoint.
test_files = sorted(glob.glob("fial_jsons/test_*.json"))
if not test_files:
    # Without this check the loops below would finish without evaluating anything.
    raise FileNotFoundError("No test files match 'fial_jsons/test_*.json'")

# Loop-local objects use their own names (checkpoint, eval_model, evaluator, test_split, ...)
# so the notebook-level `model_name`, `model`, `trainer` and `dataset` from the training
# section are not overwritten by running this cell.
for checkpoint in models:
    # tokenize_and_align_labels reads `label_all_tokens` and `tokenizer` from the notebook
    # globals, so these two names have to be rebound here for every checkpoint.
    label_all_tokens = 'lat' in checkpoint

    print(f"---------------- {checkpoint} -------------------------")
    print(f"Label all tokens from main function {label_all_tokens}")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    eval_model = AutoModelForTokenClassification.from_pretrained(checkpoint)
    eval_collator = DataCollatorForTokenClassification(tokenizer)

    # The Trainer is used for prediction only, so just the evaluation settings are given.
    eval_args = TrainingArguments(
        output_dir=checkpoint,
        per_device_eval_batch_size=16,
        fp16=True,
    )
    evaluator = Trainer(
        model=eval_model,
        args=eval_args,
        tokenizer=tokenizer,
        data_collator=eval_collator,
        compute_metrics=metric_fn,
    )

    for test_file in test_files:
        # 'fial_jsons/test_<lang>.json' -> '<lang>'
        lang = os.path.basename(test_file).split(".")[0].split("_")[-1]
        test_split = load_dataset('json', data_files={'test': test_file})['test']

        # Re-tokenized per checkpoint because each checkpoint brings its own tokenizer.
        test_dataset = test_split.map(
            tokenize_and_align_labels,
            batched=True,
            remove_columns=test_split.column_names,
        )
        print(f"Length of Test Dataset: {len(test_dataset)}")

        results = evaluator.predict(test_dataset).metrics
        print(f"Results for {lang}: {results}")
    
   



---------------- indic-bert-finetuned-srl -------------------------
Label all tokens from main function False


Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Length of Test Dataset: 638


Results for hindi: {'test_loss': 0.8649789690971375, 'test_precision-macro': 0.3970191912584015, 'test_recall-macro': 0.24843825911307707, 'test_f1-macro': 0.2807057564142384, 'test_precision-weighted': 0.7033877970987126, 'test_recall-weighted': 0.7153618336337884, 'test_f1-weighted': 0.7043757453869288, 'test_runtime': 1.7414, 'test_samples_per_second': 366.371, 'test_steps_per_second': 11.485}
Length of Test Dataset: 638


Results for urdu: {'test_loss': 0.8606521487236023, 'test_precision-macro': 0.3820534250223506, 'test_recall-macro': 0.24611217770609306, 'test_f1-macro': 0.2772327284878324, 'test_precision-weighted': 0.6921091341047291, 'test_recall-weighted': 0.7138384086444007, 'test_f1-weighted': 0.6915894495428145, 'test_runtime': 4.0043, 'test_samples_per_second': 159.328, 'test_steps_per_second': 4.995}
Length of Test Dataset: 638


Results for tamil: {'test_loss': 1.1957036256790161, 'test_precision-macro': 0.308755211025052, 'test_recall-macro': 0.16948987970187757, 'test_f1-macro': 0.1899522338500692, 'test_precision-weighted': 0.5855275957332987, 'test_recall-weighted': 0.6037459760023413, 'test_f1-weighted': 0.5827483566091488, 'test_runtime': 1.941, 'test_samples_per_second': 328.697, 'test_steps_per_second': 10.304}
Length of Test Dataset: 638


Results for telugu: {'test_loss': 1.2218786478042603, 'test_precision-macro': 0.2804714022594595, 'test_recall-macro': 0.16461760862065297, 'test_f1-macro': 0.18487954240028429, 'test_precision-weighted': 0.5842227447156424, 'test_recall-weighted': 0.6014800197335964, 'test_f1-weighted': 0.5793397162779664, 'test_runtime': 1.7363, 'test_samples_per_second': 367.443, 'test_steps_per_second': 11.519}
---------------- indic-bert-lat-finetuned-srl -------------------------
Label all tokens from main function True


Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Length of Test Dataset: 638


Results for hindi: {'test_loss': 0.9695111513137817, 'test_precision-macro': 0.3796522093417347, 'test_recall-macro': 0.2005207752537422, 'test_f1-macro': 0.22977918137163858, 'test_precision-weighted': 0.6776578708219299, 'test_recall-weighted': 0.6758569299552906, 'test_f1-weighted': 0.6618471008361207, 'test_runtime': 1.7795, 'test_samples_per_second': 358.52, 'test_steps_per_second': 11.239}
Length of Test Dataset: 638


Results for urdu: {'test_loss': 0.9307342767715454, 'test_precision-macro': 0.3643717107941405, 'test_recall-macro': 0.24859044649404569, 'test_f1-macro': 0.28514811790653755, 'test_precision-weighted': 0.6806664924433051, 'test_recall-weighted': 0.6872126947982637, 'test_f1-weighted': 0.6736706333911643, 'test_runtime': 4.2778, 'test_samples_per_second': 149.143, 'test_steps_per_second': 4.675}
Length of Test Dataset: 638


Results for tamil: {'test_loss': 1.3113352060317993, 'test_precision-macro': 0.3089140695961398, 'test_recall-macro': 0.15757126708982672, 'test_f1-macro': 0.18493101223007535, 'test_precision-weighted': 0.5622518493283274, 'test_recall-weighted': 0.5698211462756889, 'test_f1-weighted': 0.5441155519019689, 'test_runtime': 2.0694, 'test_samples_per_second': 308.306, 'test_steps_per_second': 9.665}
Length of Test Dataset: 638


Results for telugu: {'test_loss': 1.3257864713668823, 'test_precision-macro': 0.2763183040881881, 'test_recall-macro': 0.1494943525791851, 'test_f1-macro': 0.1735637500549916, 'test_precision-weighted': 0.5574667397856046, 'test_recall-weighted': 0.5686422957675391, 'test_f1-weighted': 0.5407162523320217, 'test_runtime': 1.8396, 'test_samples_per_second': 346.823, 'test_steps_per_second': 10.872}
---------------- bert-multilingual-cased-finetuned-srl -------------------------
Label all tokens from main function False


Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Length of Test Dataset: 638


Results for hindi: {'test_loss': 0.7975223064422607, 'test_precision-macro': 0.36781928753473286, 'test_recall-macro': 0.3212379740234714, 'test_f1-macro': 0.3243661283792209, 'test_precision-weighted': 0.7278269692373353, 'test_recall-weighted': 0.7325521503991759, 'test_f1-weighted': 0.7292267642965696, 'test_runtime': 2.9225, 'test_samples_per_second': 218.304, 'test_steps_per_second': 6.843}
Length of Test Dataset: 638


Results for urdu: {'test_loss': 0.582085132598877, 'test_precision-macro': 0.4552410899012174, 'test_recall-macro': 0.39895659935227235, 'test_f1-macro': 0.4086628383386534, 'test_precision-weighted': 0.8015562478369707, 'test_recall-weighted': 0.806544695481336, 'test_f1-weighted': 0.8031085443057607, 'test_runtime': 2.8038, 'test_samples_per_second': 227.546, 'test_steps_per_second': 7.133}
Length of Test Dataset: 638


Results for tamil: {'test_loss': 1.1070899963378906, 'test_precision-macro': 0.295489053923806, 'test_recall-macro': 0.2346720815334721, 'test_f1-macro': 0.2483453706439215, 'test_precision-weighted': 0.6151124724635257, 'test_recall-weighted': 0.6328163106038436, 'test_f1-weighted': 0.6215929572627091, 'test_runtime': 3.1929, 'test_samples_per_second': 199.815, 'test_steps_per_second': 6.264}
Length of Test Dataset: 638


Results for telugu: {'test_loss': 1.0996220111846924, 'test_precision-macro': 0.28312059328028377, 'test_recall-macro': 0.23602395775867402, 'test_f1-macro': 0.24810021163951732, 'test_precision-weighted': 0.6179542663812712, 'test_recall-weighted': 0.635027133695116, 'test_f1-weighted': 0.6241226039048188, 'test_runtime': 3.433, 'test_samples_per_second': 185.844, 'test_steps_per_second': 5.826}
---------------- bert-multilingual-cased-lat-finetuned-srl -------------------------
Label all tokens from main function True


Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Length of Test Dataset: 638


Results for hindi: {'test_loss': 0.9272770881652832, 'test_precision-macro': 0.3792084843164745, 'test_recall-macro': 0.2982529148743553, 'test_f1-macro': 0.3094828878870961, 'test_precision-weighted': 0.6841318934661724, 'test_recall-weighted': 0.690176322418136, 'test_f1-weighted': 0.6849535174389756, 'test_runtime': 2.994, 'test_samples_per_second': 213.094, 'test_steps_per_second': 6.68}
Length of Test Dataset: 638


Results for urdu: {'test_loss': 0.6847628355026245, 'test_precision-macro': 0.43125761214858943, 'test_recall-macro': 0.360243525323777, 'test_f1-macro': 0.3791191403921131, 'test_precision-weighted': 0.7695020069248988, 'test_recall-weighted': 0.7772697583278902, 'test_f1-weighted': 0.7712362823533837, 'test_runtime': 2.8528, 'test_samples_per_second': 223.643, 'test_steps_per_second': 7.011}
Length of Test Dataset: 638


Results for tamil: {'test_loss': 1.2331571578979492, 'test_precision-macro': 0.2744327903318415, 'test_recall-macro': 0.21872827393823877, 'test_f1-macro': 0.23311105086313155, 'test_precision-weighted': 0.5735115182913081, 'test_recall-weighted': 0.5912012003510461, 'test_f1-weighted': 0.5795851141454745, 'test_runtime': 3.2959, 'test_samples_per_second': 193.572, 'test_steps_per_second': 6.068}
Length of Test Dataset: 638


Results for telugu: {'test_loss': 1.2283591032028198, 'test_precision-macro': 0.28302176382653993, 'test_recall-macro': 0.2270439855315812, 'test_f1-macro': 0.24344856830234443, 'test_precision-weighted': 0.5762767836634966, 'test_recall-weighted': 0.5942475741167641, 'test_f1-weighted': 0.5820895840885354, 'test_runtime': 3.5579, 'test_samples_per_second': 179.321, 'test_steps_per_second': 5.621}
