# Semantic Role Labelling with Fine-tuned BERT models

This notebook walks through fine-tuning pre-trained BERT models for the task of Semantic Role Labelling (SRL). First, make sure the libraries listed in requirements.txt are installed. If you run this notebook on Google Colab, you may need to install additional libraries and restart the kernel afterwards; see requirements.txt for details. The functions are organised in util scripts, which we import in this notebook.

Make sure your version of Transformers is at least 4.11.0, since the functionality used in this notebook was introduced in that version.
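The version requirement can be checked programmatically before running the rest of the notebook. The helper below is a minimal sketch: `meets_minimum` is a hypothetical name that is not part of the util scripts, and it assumes plain dotted numeric versions such as `4.38.2`.

```python
def meets_minimum(installed: str, minimum: str) -> bool:
    """Return True if a dotted numeric version is at least `minimum`."""
    def as_tuple(version: str):
        # Keep only the leading numeric components, e.g. "4.38.2" -> (4, 38, 2)
        parts = []
        for piece in version.split("."):
            if not piece.isdigit():
                break
            parts.append(int(piece))
        return tuple(parts)

    # Tuples compare component by component, so 4.9 < 4.11 as intended
    return as_tuple(installed) >= as_tuple(minimum)


print(meets_minimum("4.38.2", "4.11.0"))
print(meets_minimum("4.9.1", "4.11.0"))
```

In the notebook, the first argument would be `transformers.__version__`. Comparing integer tuples rather than raw strings matters here, because a plain string comparison would rank "4.9.1" above "4.11.0".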


In [3]:
from utils_preprocessing import *
from utils_data_processing import *
from utils_evaluation import *
from utils_model_training import *

print(transformers.__version__)



4.38.2


Ensure that the trained models are placed in the 'models' subdirectory. If you want to start fine-tuning from scratch, use 'distilbert-base-uncased'. Below we also define the output paths where model predictions will be written.

In [1]:
# model1, model2, model3 load the fine-tuned models; "distilbert-base-uncased" starts fine-tuning again from the base model
model = "distilbert-base-uncased"
model1 = "models/BERT1_new" 
model2 = "models/BERT2_new"
model3 = "models/BERT3_new" 

output_path1 = "prediction_output/BERT1_predictions.tsv"
output_path2 = "prediction_output/BERT2_predictions.tsv"
output_path3 = "prediction_output/BERT3_predictions.tsv"

### Step 1: Preprocessing

Next we define file paths to the original dataset, which should be located in the 'data' subdirectory. The transform_raw_data function has three different versions, since we train three different models. 

Commonalities between versions: sentences that contain no predicates are removed, and each remaining sentence is duplicated once per predicate it contains. This means that an input sentence with five predicates yields five copies, one for each predicate. Only the following information is kept after processing: sentence ids, token ids, tokens, gold labels, and some form of predicate information.
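The duplication step can be sketched as follows. This is an illustration under assumptions, not the code in 'utils_preprocessing': the function name `expand_by_predicate`, the dictionary keys, the sentence id suffix, and the simplified role labels in the example are all hypothetical.

```python
def expand_by_predicate(sentence_id, tokens, predicate_indices, role_columns):
    """Return one instance per predicate in the sentence.

    role_columns[k][j] is the gold label of token j with respect to the
    k-th predicate ('_' marks a non-argument token).
    """
    instances = []
    for k, pred_idx in enumerate(predicate_indices):
        instances.append({
            "sentence_id": f"{sentence_id}_{k}",   # hypothetical id scheme
            "token_ids": list(range(1, len(tokens) + 1)),
            "tokens": list(tokens),
            "labels": list(role_columns[k]),
            "predicate_index": pred_idx,
        })
    # A sentence without predicates produces no instances and is thereby dropped
    return instances


tokens = ["She", "wanted", "to", "leave", "."]
instances = expand_by_predicate(
    "s1",
    tokens,
    predicate_indices=[1, 3],                      # "wanted" and "leave"
    role_columns=[
        ["ARG0", "_", "_", "ARG1", "_"],           # roles relative to "wanted"
        ["ARG0", "_", "_", "_", "_"],              # roles relative to "leave"
    ],
)
print(len(instances))
```

Each copy carries the same tokens but its own gold labels, since the roles are defined relative to one predicate at a time.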

Differences between versions: predicate information is processed differently in the sentences for BERT input.
- V1 (baseline): Obama went to Paris last week. [SEP] went
- V2: Obama went to Paris last week. [SEP] Obama went to
- V3: Obama [PRED] went to Paris last week.

In Version 1, we mark which token is the predicate by repeating it after the [SEP] token (Shi and Lin, 2019). In Version 2, we provide the context window of the predicate, which can help the model disambiguate which token is the predicate when the sentence contains the same predicate token twice (Zhou and Xu, 2015). In Version 3, we mark the predicate directly by inserting a '[PRED]' token immediately before the predicate itself (Khandelwal and Sawant, 2020).
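The three input formats can be reproduced on the example sentence with a few lines of list manipulation. The functions below are a hedged sketch: the names `format_v1`, `format_v2` and `format_v3` are hypothetical, and the window of one token on each side in V2 is inferred from the example above rather than taken from the util scripts.

```python
def format_v1(tokens, pred_idx):
    # Sentence, then [SEP], then the predicate token on its own
    return tokens + ["[SEP]", tokens[pred_idx]]


def format_v2(tokens, pred_idx):
    # Sentence, then [SEP], then the predicate with one neighbour on each side
    start = max(0, pred_idx - 1)
    return tokens + ["[SEP]"] + tokens[start:pred_idx + 2]


def format_v3(tokens, pred_idx):
    # [PRED] marker inserted immediately before the predicate
    return tokens[:pred_idx] + ["[PRED]"] + tokens[pred_idx:]


tokens = ["Obama", "went", "to", "Paris", "last", "week", "."]
pred_idx = 1  # position of "went"

for formatter in (format_v1, format_v2, format_v3):
    print(" ".join(formatter(tokens, pred_idx)))
```

Note that V3 changes the sentence itself, so the gold labels need an extra entry for the inserted marker, whereas V1 and V2 only append material after the sentence.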

In [None]:
# Transform the raw data into a format suitable for BERT input
input_train = 'data/en_ewt-up-train.conllu'  
input_dev = 'data/en_ewt-up-dev.conllu' 
input_test = 'data/en_ewt-up-test.conllu' 

# Call the function to perform the preprocessing transformation
transform_raw_data_v1(input_train, 'data/bert-train1-new.conllu')
transform_raw_data_v1(input_dev, 'data/bert-dev1-new.conllu')
transform_raw_data_v1(input_test, 'data/bert-test1-new.conllu')

transform_raw_data_v2(input_train, 'data/bert-train2-new.conllu')
transform_raw_data_v2(input_dev, 'data/bert-dev2-new.conllu')
transform_raw_data_v2(input_test, 'data/bert-test2-new.conllu')

transform_raw_data_v3(input_train, 'data/bert-train3-new.conllu')
transform_raw_data_v3(input_dev, 'data/bert-dev3-new.conllu')
transform_raw_data_v3(input_test, 'data/bert-test3-new.conllu')

### Step 2: Dataset Creation
In the next step we use the functions from 'utils_data_processing'. 

To prepare our dataset for the SRL task with DistilBERT, we first load each preprocessed file into a Pandas DataFrame with the columns 'sentence_id', 'token_id', 'token', and 'argument'. 

Then we transform the DataFrame into a format that is suitable for training by grouping tokens by their sentence_id and converting argument labels into numerical indices. 

Lastly, we use the main function, which extracts all unique argument labels from the training, validation, and test sets to create a label mapping and transforms the structured data into Hugging Face 'Dataset' objects. 
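The grouping and label-mapping logic can be illustrated without any dependencies. The sketch below uses plain Python tuples as a stand-in for the DataFrame rows; `build_label_mapping` and `group_by_sentence` are hypothetical names, the toy rows are invented for illustration, and sorting the labels alphabetically is an assumption about how indices are assigned.

```python
def build_label_mapping(*splits):
    """Map every argument label seen in any split to an integer index."""
    labels = {label for rows in splits for (_, _, _, label) in rows}
    return {label: index for index, label in enumerate(sorted(labels))}


def group_by_sentence(rows, label_mapping):
    """Group (sentence_id, token_id, token, argument) rows into sentences."""
    grouped = {}
    for sentence_id, _token_id, token, argument in rows:
        entry = grouped.setdefault(sentence_id, {"tokens": [], "labels": []})
        entry["tokens"].append(token)
        entry["labels"].append(label_mapping[argument])
    return list(grouped.values())


# Toy rows in the column order described above
train_rows = [
    ("s1", 1, "Obama", "ARG0"),
    ("s1", 2, "went", "_"),
    ("s1", 3, "home", "ARG4"),
]
dev_rows = [("s2", 1, "Stop", "_")]

mapping = build_label_mapping(train_rows, dev_rows)
sentences = group_by_sentence(train_rows, mapping)
print(mapping)
print(sentences)
```

Building the mapping from all splits together, as the notebook does, avoids unknown labels when a class appears only in the validation or test set.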



In [6]:
#dataset_dict1, argument_label_mapping1 = create_datasets('data/bert-train1-new.conllu', 'data/bert-test1-new.conllu', 'data/bert-dev1-new.conllu')
dataset_dict2, argument_label_mapping2 = create_datasets('data/bert-train2-new.conllu', 'data/bert-test2-new.conllu', 'data/bert-dev2-new.conllu')
#dataset_dict3, argument_label_mapping3 = create_datasets('data/bert-train3-new.conllu', 'data/bert-test3-new.conllu', 'data/bert-dev3-new.conllu')

### Step 3: Training and Predicting
For the next step we use the functions from 'utils_model_training.py'. This section covers the model training process, with specialized functions to prepare and fine-tune the DistilBERT model for the SRL task. 

Firstly, we mask the tokens that appear after the special [SEP] token so that they are excluded from the loss calculation. This keeps training focused on the tokens of the sentence itself. 

Secondly, we tokenize the data into subtokens and align them with their word IDs so that each subtoken is matched to the correct semantic role. 

Thirdly, we standardize the input lengths across the dataset by determining the maximum sequence length. This is necessary to make sure the input sentences are padded correctly. 

Lastly, the model and tokenizer are initialized with pre-trained settings, and a trainer is used to manage the fine-tuning. 
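The first two steps can be sketched in isolation. The example below is an assumption-laden illustration, not the code in 'utils_model_training.py': `align_labels` is a hypothetical name, the `word_ids` list is written by hand to mimic what a fast tokenizer's `word_ids()` method returns, and labelling only the first subtoken of each word is one common convention that the notebook may or may not follow. The value -100 is the default index ignored by PyTorch's cross-entropy loss.

```python
IGNORE_INDEX = -100  # skipped by PyTorch's cross-entropy loss by default


def align_labels(word_ids, word_labels, first_ignored_word):
    """Spread word-level labels over subtokens.

    word_ids: one entry per subtoken, None for special tokens.
    first_ignored_word: index of the first word after the sentence
    (the [SEP] marker and the predicate information that follows it).
    """
    aligned = []
    previous = None
    for word_id in word_ids:
        if word_id is None:
            aligned.append(IGNORE_INDEX)       # special tokens added by the tokenizer
        elif word_id >= first_ignored_word:
            aligned.append(IGNORE_INDEX)       # material after [SEP] is not scored
        elif word_id == previous:
            aligned.append(IGNORE_INDEX)       # continuation subtoken of the same word
        else:
            aligned.append(word_labels[word_id])
        previous = word_id
    return aligned


# Words: Obama went to Paris [SEP] went   (label ids are illustrative)
word_labels = [1, 0, 0, 2, 0, 0]
# Hand-written word ids: "Paris" is pretended to split into two subtokens
word_ids = [None, 0, 1, 2, 3, 3, 4, 5, None]

print(align_labels(word_ids, word_labels, first_ignored_word=4))
```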

 

In [7]:
# If you want to train the model from scratch, replace the first argument of the 'set_trainer' function with 'model'

#tokenizer1, trainer1, tokenized_dataset1 = set_trainer(model1, dataset_dict1)
tokenizer2, trainer2, tokenized_dataset2 = set_trainer(model2, dataset_dict2)
# tokenizer3, trainer3, tokenized_dataset3 = set_trainer(model3, dataset_dict3)

Map:   0%|          | 0/40482 [00:00<?, ? examples/s]

Map:   0%|          | 0/4977 [00:00<?, ? examples/s]

Map:   0%|          | 0/4799 [00:00<?, ? examples/s]

In [None]:
# Uncomment desired model to start the fine-tuning process of "distilbert-base-uncased". You can skip this code block if you only want to evaluate an existing model.

# trainer1.train()
# trainer2.train()
# trainer3.train()

In [None]:
# If you fine-tuned a model, you can use this code to save the model for evaluation later.

#trainer.save_model("enter path")

In [8]:
#prediction_output(trainer1, tokenizer1, tokenized_dataset1, argument_label_mapping1, output_path1)
prediction_output(trainer2, tokenizer2, tokenized_dataset2, argument_label_mapping2, output_path2)
# prediction_output(trainer3, tokenizer3, tokenized_dataset3, argument_label_mapping3, output_path3)

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


### Step 4: Evaluation

In the next step we use the functions from 'utils_evaluation.py'. This script contains a function that evaluates the performance of the DistilBERT models by reading in the predicted labels and gold labels from the data. The evaluation metrics are the precision, recall, and F1-score at token level for each argument class. In addition, aggregated scores are computed over the true and predicted labels of all tokens in the dataset. 
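As an illustration of what token-level, per-class scoring means, the sketch below computes precision, recall, F1 and support from two parallel label lists. It is a dependency-free approximation written for this explanation: `per_class_scores` is a hypothetical name, and the toy labels are invented, not taken from the prediction files.

```python
from collections import Counter


def per_class_scores(gold, predicted):
    """Return {label: (precision, recall, f1, support)} at token level."""
    assert len(gold) == len(predicted), "label lists must have equal length"
    tp, fp, fn = Counter(), Counter(), Counter()
    for g, p in zip(gold, predicted):
        if g == p:
            tp[g] += 1
        else:
            fp[p] += 1   # predicted label was wrong for this token
            fn[g] += 1   # gold label was missed for this token

    scores = {}
    support = Counter(gold)
    for label in sorted(set(gold) | set(predicted)):
        predicted_total = tp[label] + fp[label]
        gold_total = tp[label] + fn[label]
        precision = tp[label] / predicted_total if predicted_total else 0.0
        recall = tp[label] / gold_total if gold_total else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores[label] = (precision, recall, f1, support[label])
    return scores


gold = ["ARG0", "_", "ARG1", "ARG1", "_"]
predicted = ["ARG0", "_", "ARG1", "_", "_"]
scores = per_class_scores(gold, predicted)
for label, values in scores.items():
    print(label, values)
```

Classes that are never predicted receive a precision of 0 here, which matches how rows with very small support end up with all-zero scores in the report.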


In [6]:
# evaluate_model_from_file(output_path1)
# evaluate_model_from_file(output_path2)
evaluate_model_from_file(output_path3)

Classification Report:
              precision    recall  f1-score   support

        ARG0      0.867     0.883     0.875      1733
        ARG1      0.848     0.897     0.872      3241
    ARG1-DSP      0.000     0.000     0.000         4
        ARG2      0.778     0.773     0.776      1129
        ARG3      0.000     0.000     0.000        74
        ARG4      0.688     0.589     0.635        56
        ARG5      0.000     0.000     0.000         1
        ARGA      0.000     0.000     0.000         2
    ARGM-ADJ      0.718     0.772     0.744       228
    ARGM-ADV      0.763     0.599     0.671       496
    ARGM-CAU      0.500     0.652     0.566        46
    ARGM-COM      0.500     0.077     0.133        13
    ARGM-CXN      0.600     0.250     0.353        12
    ARGM-DIR      0.500     0.383     0.434        47
    ARGM-DIS      0.770     0.698     0.732       182
    ARGM-EXT      0.769     0.762     0.766       105
    ARGM-GOL      1.000     0.083     0.154        24
    

### Results



| Class      | Precision Model 1 | Precision Model 2 | Precision Model 3 | Recall Model 1 | Recall Model 2 | Recall Model 3 | F1-Score Model 1 | F1-Score Model 2 | F1-Score Model 3 | Support |
|------------|-------------------|-------------------|-------------------|----------------|----------------|----------------|------------------|------------------|------------------|---------|
| ARG0       | 0.838             | 0.856             | 0.867             | 0.866          | 0.881          | 0.883          | 0.851            | 0.868            | 0.875            | 1733    |
| ARG1       | 0.817             | 0.84              | 0.848             | 0.864          | 0.885          | 0.897          | 0.84             | 0.862            | 0.872            | 3241    |
| ARG1-DSP   | 0                 | 0                 | 0                 | 0              | 0              | 0              | 0                | 0                | 0                | 4       |
| ARG2       | 0.761             | 0.77              | 0.778             | 0.724          | 0.757          | 0.773          | 0.742            | 0.763            | 0.776            | 1129    |
| ARG3       | 0.857             | 0                 | 0                 | 0.081          | 0              | 0              | 0.148            | 0                | 0                | 74      |
| ARG4       | 0.6               | 0.696             | 0.688             | 0.589          | 0.571          | 0.589          | 0.595            | 0.627            | 0.635            | 56      |
| ARG5       | 0                 | 0                 | 0                 | 0              | 0              | 0              | 0                | 0                | 0                | 1       |
| ARGA       | 0                 | 0                 | 0                 | 0              | 0              | 0              | 0                | 0                | 0                | 2       |
| ARGM-ADJ   | 0.711             | 0.702             | 0.718             | 0.746          | 0.754          | 0.772          | 0.728            | 0.727            | 0.744            | 228     |
| ARGM-ADV   | 0.734             | 0.723             | 0.763             | 0.516          | 0.583          | 0.599          | 0.606            | 0.645            | 0.671            | 496     |
| ARGM-CAU   | 0.531             | 0.538             | 0.5               | 0.565          | 0.609          | 0.652          | 0.547            | 0.571            | 0.566            | 46      |
| ARGM-COM   | 0.667             | 0.333             | 0.5               | 0.154          | 0.077          | 0.077          | 0.25             | 0.125            | 0.133            | 13      |
| ARGM-CXN   | 1                 | 0.5               | 0.6               | 0.25           | 0.167          | 0.25           | 0.4              | 0.25             | 0.353            | 12      |
| ARGM-DIR   | 0.433             | 0.391             | 0.5               | 0.277          | 0.383          | 0.383          | 0.338            | 0.387            | 0.434            | 47      |
| ARGM-DIS   | 0.76              | 0.792             | 0.77              | 0.764          | 0.775          | 0.698          | 0.762            | 0.783            | 0.732            | 182     |
| ARGM-EXT   | 0.792             | 0.794             | 0.769             | 0.762          | 0.733          | 0.762          | 0.777            | 0.762            | 0.766            | 105     |
| ARGM-GOL   | 0                 | 1                 | 1                 | 0              | 0.042          | 0.083          | 0                | 0.08             | 0.154            | 24      |
| ARGM-LOC   | 0.555             | 0.573             | 0.569             | 0.614          | 0.662          | 0.657          | 0.583            | 0.614            | 0.61             | 207     |
| ARGM-LVB   | 0.659             | 0.753             | 0.713             | 0.812          | 0.797          | 0.826          | 0.727            | 0.775            | 0.765            | 69      |
| ARGM-MNR   | 0.6               | 0.645             | 0.603             | 0.426          | 0.466          | 0.473          | 0.498            | 0.541            | 0.53             | 148     |
| ARGM-MOD   | 0.898             | 0.959             | 0.975             | 0.952          | 0.957          | 0.959          | 0.924            | 0.958            | 0.967            | 442     |
| ARGM-NEG   | 0.904             | 0.933             | 0.972             | 0.954          | 0.963          | 0.958          | 0.928            | 0.948            | 0.965            | 216     |
| ARGM-PRD   | 0.5               | 0.37              | 0.429             | 0.114          | 0.227          | 0.205          | 0.185            | 0.282            | 0.277            | 44      |
| ARGM-PRP   | 0.45              | 0.58              | 0.55              | 0.48           | 0.627          | 0.587          | 0.465            | 0.603            | 0.568            | 75      |
| ARGM-PRR   | 0.754             | 0.746             | 0.746             | 0.667          | 0.681          | 0.681          | 0.708            | 0.712            | 0.712            | 69      |
| ARGM-TMP   | 0.776             | 0.774             | 0.816             | 0.829          | 0.843          | 0.858          | 0.801            | 0.807            | 0.837            | 543     |
| C-ARG0     | 0                 | 0                 | 0                 | 0              | 0              | 0              | 0                | 0                | 0                | 3       |
| C-ARG1     | 0.58              | 0.652             | 0.634             | 0.558          | 0.577          | 0.5            | 0.569            | 0.612            | 0.559            | 52      |
| C-ARG1-DSP | 0                 | 0                 | 0                 | 0              | 0              | 0              | 0                | 0                | 0                | 1       |
| C-ARG2     | 0.5               | 1                 | 0.5               | 0.286          | 0.286          | 0.143          | 0.364            | 0.444            | 0.222            | 7       |
| C-ARG3     | 0                 | 0                 | 0                 | 0              | 0              | 0              | 0                | 0                | 0                | 2       |
| C-ARGM-CXN | 0                 | 0.333             | 0                 | 0              | 0.2            | 0              | 0                | 0.25             | 0                | 5       |
| C-ARGM-LOC | 0                 | 0                 | 0                 | 0              | 0              | 0              | 0                | 0                | 0                | 1       |
| R-ARG0     | 0.87              | 0.886             | 0.859             | 0.896          | 0.925          | 0.91           | 0.882            | 0.905            | 0.884            | 67      |
| R-ARG1     | 0.736             | 0.759             | 0.759             | 0.75           | 0.788          | 0.846          | 0.743            | 0.774            | 0.8              | 52      |
| R-ARG2     | 0                 | 0                 | 0                 | 0              | 0              | 0              | 0                | 0                | 0                | 1       |
| R-ARGM-ADJ | 0                 | 0                 | 0                 | 0              | 0              | 0              | 0                | 0                | 0                | 1       |
| R-ARGM-ADV | 0                 | 0                 | 0                 | 0              | 0              | 0              | 0                | 0                | 0                | 1       |
| R-ARGM-DIR | 0                 | 0                 | 0                 | 0              | 0              | 0              | 0                | 0                | 0                | 1       |
| R-ARGM-LOC | 0.316             | 0.273             | 0.389             | 0.667          | 0.667          | 0.778          | 0.429            | 0.387            | 0.519            | 9       |
| R-ARGM-MNR | 0                 | 0                 | 0                 | 0              | 0              | 0              | 0                | 0                | 0                | 8       |
| R-ARGM-TMP | 0                 | 0                 | 0                 | 0              | 0              | 0              | 0                | 0                | 0                | 2       |
| _          | 0.988             | 0.99              | 0.991             | 0.989          | 0.99           | 0.992          | 0.988            | 0.99             | 0.991            | 91725   |


| **Aggregated Scores**    | Model 1  | Model 2  | Model 3 |
|--------------------------|----------|----------|---------|
| Accuracy                 | 0.97     | 0.973    | 0.975   |
| Macro Avg (F1)           | 0.404    | 0.42     | 0.417   |
| Weighted Avg (F1)        | 0.969    | 0.972    | 0.974   |
| Macro Avg (Recall)       | 0.399    | 0.416    | 0.414   |
| Weighted Avg (Recall)    | 0.97     | 0.973    | 0.975   |
| Macro Avg (Precision)    | 0.455    | 0.469    | 0.461   |
| Weighted Avg (Precision) | 0.969    | 0.972    | 0.974   |


The scores in the tables show that the advanced models perform only slightly better than our baseline model. On the macro averages, which weight every class equally, Model 2 scores slightly higher in F1, precision and recall than Model 1 and Model 3.
Among the individual classes, ARG1, ARG0 and ARG2 are the most frequent after the '_' class, which stands for non-argument tokens. For these three argument classes, Model 3 achieves the highest F1-scores of the three models.
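The gap between the macro and weighted averages follows directly from how they are computed. The sketch below applies both formulas to three rows taken from the Model 3 columns of the table ('_', ARG0 and ARG5); because it uses only this small subset, the resulting numbers illustrate the effect and do not reproduce the aggregated scores above. The function names are hypothetical.

```python
def macro_average(scores):
    """Unweighted mean: every class counts equally."""
    return sum(f1 for f1, _support in scores.values()) / len(scores)


def weighted_average(scores):
    """Mean weighted by support: frequent classes dominate."""
    total = sum(support for _f1, support in scores.values())
    return sum(f1 * support for f1, support in scores.values()) / total


# (F1, support) pairs for three classes of Model 3
subset = {
    "_": (0.991, 91725),
    "ARG0": (0.875, 1733),
    "ARG5": (0.0, 1),
}

print(round(macro_average(subset), 3))
print(round(weighted_average(subset), 3))
```

A single class with one token and an F1 of 0 pulls the macro average down by a third in this subset, while it barely moves the weighted average, which is dominated by the '_' class.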

### Limitations of baseline model

In the baseline model, the predicate is simply indicated after the [SEP] token. This causes one of the main limitations of the baseline model: the possible ambiguity of the predicate. In some sentences, the predicate can have multiple meanings depending on the surrounding words, and this can change the semantic roles of tokens. Take the following sentences as an example:

A. The boy broke the window with his ball.

B. The comedy show broke his concentration. 

In the first sentence, 'the boy' is the agent, while in the second sentence, 'the comedy show' is the causer. The lack of context information may hinder the model's learning and its ability to assign the correct roles. 

In contrast, the second model does include some context by providing the words before and after the predicate after the [SEP] token. This might help the model understand the meaning of the predicate based on its surroundings and be more accurate when predicting the roles. 

The second limitation of the baseline model is that the predicate given after the [SEP] token does not identify a position in the sentence. Some sentences contain multiple instances of the same word where only one of them is the predicate for the given data instance. It can be challenging for the model to identify which one is meant when the predicate is given after the [SEP] token without any additional positional information. Here too, the context window of the second model might help the model identify the correct predicate and learn the correct role assignments. 

We also included another advanced model that accentuates the predicate by inserting a special token '[PRED]' right before the predicate. This special token can also help with the identification of the right predicate when this word gets repeated in the sentence.
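The repeated-predicate problem can be made concrete with a sentence in which the same verb occurs twice. The sketch below is illustrative only: `build_input` is a hypothetical helper, the example sentence is invented, and the V2 window of one token on each side is an assumption.

```python
def build_input(tokens, pred_idx, version):
    """Build the model input for one predicate (hypothetical helper)."""
    if version == 1:   # predicate token repeated after [SEP]
        return tokens + ["[SEP]", tokens[pred_idx]]
    if version == 2:   # one token of context on each side after [SEP]
        return tokens + ["[SEP]"] + tokens[max(0, pred_idx - 1):pred_idx + 2]
    if version == 3:   # [PRED] marker directly before the predicate
        return tokens[:pred_idx] + ["[PRED]"] + tokens[pred_idx:]
    raise ValueError(f"unknown version: {version}")


tokens = ["He", "said", "that", "she", "said", "no", "."]
first, second = 1, 4  # positions of the two occurrences of "said"

for version in (1, 2, 3):
    same = build_input(tokens, first, version) == build_input(tokens, second, version)
    print(f"V{version}: inputs identical for both predicates -> {same}")
```

With the baseline format, the two data instances have exactly the same input but different gold labels, so the model receives contradictory supervision. The other two formats produce distinct inputs for the two predicates.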

### Limitations of advanced models

As mentioned before, the second model includes the context window around the predicate after the [SEP] token. This is an attempt to disambiguate both which token is the predicate and which sense of the predicate is meant. However, the effectiveness of this approach relies heavily on the size of the context. A small window might not give enough information about the context of the predicate, for example when punctuation marks appear immediately before or after the predicate. Conversely, a window that is too large might give so much context that the model can no longer discern what it has to learn, creating confusion rather than clarity. Moreover, a large context window might add noise and make the input too complex for the model to learn from efficiently. 
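Both failure modes show up on a short sentence once the window size is made a parameter. This is a sketch under assumptions: `context_window` and its `size` argument are hypothetical and are not taken from the util scripts, and the example sentence is invented.

```python
def context_window(tokens, pred_idx, size):
    """Return the predicate with up to `size` tokens on each side."""
    start = max(0, pred_idx - size)
    return tokens[start:pred_idx + size + 1]


tokens = ["Well", ",", "go", "!"]
pred_idx = 2  # position of "go"

small = context_window(tokens, pred_idx, size=1)
large = context_window(tokens, pred_idx, size=3)
print(small)
print(large)
```

With a size of 1, the predicate is surrounded only by punctuation, which says little about its sense. With a size of 3, the window covers the whole sentence, so the material after [SEP] simply repeats the input and no longer singles out the predicate.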


With the third model, which uses a special [PRED] token before the predicate, a limitation can be the overemphasis of the predicate. When we insert a special token, the model might rely too much on the marked tokens and miss the general contextual cues. This can cause performance issues when data is not annotated properly or when generalizing to new data that is not annotated.

### Conclusions and Future Directions

The advanced models showed only a marginal improvement in the weighted and macro F1, precision, and recall scores over the baseline model. This suggests that the different input strategies for emphasizing the predicate did not improve generalization as much as expected in this experiment. Future work could experiment with synthetic data for underperforming categories and patterns, and explore different hyperparameters such as learning rates, batch sizes, and sequence lengths to potentially enhance model performance.

### References:

Shi, P. and Lin, J., 2019. Simple BERT models for relation extraction and semantic role labeling. arXiv preprint arXiv:1904.05255.
https://arxiv.org/pdf/1904.05255.pdf

Zhou, J. and Xu, W., 2015, July. End-to-end learning of semantic role labeling using recurrent neural networks. In Proceedings of the 53rd Annual Meeting of the Association for Computational Linguistics and the 7th International Joint Conference on Natural Language Processing (Volume 1: Long Papers) (pp. 1127-1137).
https://aclanthology.org/P15-1109.pdf

Khandelwal, A. and Sawant, S., 2020, May. NegBERT: A Transfer Learning Approach for Negation Detection and Scope Resolution. In Proceedings of the Twelfth Language Resources and Evaluation Conference (pp. 5739-5748).
http://www.lrec-conf.org/proceedings/lrec2020/pdf/2020.lrec-1.704.pdf



### End of Notebook
