In [1]:
# Shell out to the `camel_data` command with `-i all`; per the log below this
# downloads and extracts every listed data package (disambiguators, morphology DBs, ...).
!camel_data -i all

The following packages will be installed: 'disambig-bert-unfactored-glf', 'disambig-bert-unfactored-lev', 'disambig-mle-calima-egy-r13', 'disambig-bert-unfactored-msa', 'sentiment-analysis-mbert', 'morphology-db-msa-r13', 'morphology-db-lev-01', 'ner-arabert', 'disambig-mle-calima-msa-r13', 'disambig-ranking-cache-calima-glf-01', 'disambig-bert-unfactored-egy', 'dialectid-model26', 'dialectid-model6', 'morphology-db-msa-s31', 'sentiment-analysis-arabert', 'disambig-ranking-cache-calima-msa-r13', 'disambig-ranking-cache-calima-egy-r13', 'morphology-db-egy-r13', 'morphology-db-glf-01', 'disambig-ranking-cache-calima-lev-01'
Downloading package 'disambig-bert-unfactored-glf': 100%|[32m█[0m| 442M/442M [00:01<00:[0m
Extracting package 'disambig-bert-unfactored-glf': 100%|[32m█[0m| 442M/442M [00:00<00:0[0m
Downloading package 'disambig-bert-unfactored-lev': 100%|[32m█[0m| 441M/441M [00:01<00:[0m
Extracting package 'disambig-bert-unfactored-lev': 100%|[32m█[0m| 441M/441M [00:00<00:

In [1]:
from data_reshaping import DataReshaping

# Spreadsheet locations: the WSD dataset itself plus the part-of-speech,
# term-frequency and target-index lookup tables.
data_path, pos_path, freq_path, index_path = (
    '/data/WSD_Arabic_Dataset.xlsx',
    '/data/part-of-speech.xlsx',
    '/data/Term_frequency.xlsx',
    '/data/Index_term.xlsx',
)

# Build the reshaper from the four spreadsheets and pull out its two result
# frames (the second one is used further down with a `Gloss_Pair` column).
disambiguator = DataReshaping(
    data_path,
    pos_path,
    freq_path,
    index_path,
)
token_pos_freq, ws_pos_freq = disambiguator.get_result()

# Preview the first rows of the token-level frame.
token_pos_freq.head()

Unnamed: 0,Target_ID,Target_Word,Label,Sentence,Gloss,POS,Freq,start_index,end_index
0,s0001.t0001,اب,1,كشفت دراسة جديدة أن عناق الأب مهم جدا بالنسبة ...,والد الشخص.,Noun,82,5,6
1,s0001.t0001,اب,0,كشفت دراسة جديدة أن عناق الأب مهم جدا بالنسبة ...,اسم الشهر الثامن من السنة السريانيّة,Noun,82,5,6
2,s0001.t0001,اب,0,كشفت دراسة جديدة أن عناق الأب مهم جدا بالنسبة ...,لقب كنسي لرجل الدين المسيحي.,Noun,82,5,6
3,s0002.t0001,اب,1,الأم التي تقوم بدورين في آن واحد بسبب غياب الأ...,والد الشخص.,Noun,82,9,10
4,s0002.t0001,اب,0,الأم التي تقوم بدورين في آن واحد بسبب غياب الأ...,اسم الشهر الثامن من السنة السريانيّة,Noun,82,9,10


In [2]:
# Create the POS-tag -> integer-id dictionary.
# The distinct tags are sorted before being numbered: iterating a bare set() of
# strings follows hash order, which is randomised per interpreter session, so
# the ids would otherwise differ from one kernel restart to the next.
# Sorting on (type name, value) keeps mixed str/NaN columns sortable and makes
# a re-run of this cell (when the column already holds ints) an identity mapping.
# NOTE(review): confirm this ordering matches the ids used when the POS-aware
# models were fine-tuned.
pos_tags = sorted(
    set(token_pos_freq['POS'].values),
    key=lambda tag: (type(tag).__name__, tag),
)
word_to_index = {word: index for index, word in enumerate(pos_tags)}

# Convert every row's tag to its id.
word_indices = [word_to_index[word] for word in token_pos_freq['POS']]

# Both frames receive the same id list, computed from token_pos_freq.
# NOTE(review): this assumes ws_pos_freq has the same rows in the same order.
token_pos_freq['POS'] = word_indices
ws_pos_freq['POS'] = word_indices

ws_pos_freq.head()

Unnamed: 0,Target_ID,Target_Word,Label,Sentence,POS,Freq,start_index,end_index,Gloss_Pair
0,s0001.t0001,اب,1,"كشفت دراسة جديدة أن عناق ""الأب"" مهم جدا بالنسب...",5,82,5,6,اب: والد الشخص.
1,s0001.t0001,اب,0,"كشفت دراسة جديدة أن عناق ""الأب"" مهم جدا بالنسب...",5,82,5,6,اب: اسم الشهر الثامن من السنة السريانيّة
2,s0001.t0001,اب,0,"كشفت دراسة جديدة أن عناق ""الأب"" مهم جدا بالنسب...",5,82,5,6,اب: لقب كنسي لرجل الدين المسيحي.
3,s0002.t0001,اب,1,"الأم التي تقوم بدورين في آن واحد بسبب غياب ""ال...",5,82,9,10,اب: والد الشخص.
4,s0002.t0001,اب,0,"الأم التي تقوم بدورين في آن واحد بسبب غياب ""ال...",5,82,9,10,اب: اسم الشهر الثامن من السنة السريانيّة


In [3]:
import torch
from transformers import AutoModelForSequenceClassification
import os
import torch.nn as nn
from ensemble import EnsembleBERT
from preprocess import (preprocess_ws_freq,
                        preprocess_ws_pos,
                        preprocess_ws)

In [4]:
# Load the five fine-tuned sequence classifiers, one per checkpoint directory.
# Order matters: model1..model5 are unpacked positionally from this tuple.
model1, model2, model3, model4, model5 = (
    AutoModelForSequenceClassification.from_pretrained(checkpoint_dir)
    for checkpoint_dir in (
        '/data/pos_model',
        '/data/freq_model',
        '/data/ws_pos_model',
        '/data/ws_model',
        '/data/ws_freq_model',
    )
)

In [5]:
# Wrap the five classifiers in one ensemble; argument order follows the
# checkpoint each model was loaded from.
ensemble_model = EnsembleBERT(
    model1,  # /data/pos_model
    model2,  # /data/freq_model
    model3,  # /data/ws_pos_model
    model4,  # /data/ws_model
    model5,  # /data/ws_freq_model
)

In [6]:
# Build one preprocessed input collection per ensemble member.
# The first two use the plain `Gloss` column of token_pos_freq; the last three
# use the `Gloss_Pair` column of ws_pos_freq.

# Sentence/gloss + POS id + target span.
sample_pos = preprocess_ws_pos(
    token_pos_freq['Sentence'],
    token_pos_freq['Gloss'],
    token_pos_freq['POS'],
    token_pos_freq['start_index'],
    token_pos_freq['end_index'],
    'enhancedBERTmodel',
)

# Sentence/gloss + frequency.
sample_freq = preprocess_ws_freq(
    token_pos_freq['Sentence'],
    token_pos_freq['Gloss'],
    token_pos_freq['Freq'],
    'enhancedBERTmodel',
)

# Sentence/gloss-pair + POS id + target span.
sample_ws_pos = preprocess_ws_pos(
    ws_pos_freq['Sentence'],
    ws_pos_freq['Gloss_Pair'],
    ws_pos_freq['POS'],
    ws_pos_freq['start_index'],
    ws_pos_freq['end_index'],
    'enhancedBERTmodel',
)

# Sentence/gloss-pair only.
sample_ws = preprocess_ws(
    ws_pos_freq['Sentence'],
    ws_pos_freq['Gloss_Pair'],
    'enhancedBERTmodel',
)

# Sentence/gloss-pair + frequency.
sample_ws_freq = preprocess_ws_freq(
    ws_pos_freq['Sentence'],
    ws_pos_freq['Gloss_Pair'],
    ws_pos_freq['Freq'],
    'enhancedBERTmodel',
)

In [13]:
# Test on the first 500 samples.
# The per-sample sleep(3) was dropped: it added 25 minutes of idle time
# (500 x 3 s) and nothing in the loop needs the pause.
from time import sleep  # unused here now; kept in case a later cell relies on it

n_eval = 500  # number of leading samples to score

predictions = []
# Inference only: no_grad() stops autograd from recording each forward pass.
with torch.no_grad():
    for k in range(n_eval):
        # Feed the k-th preprocessed input of every ensemble member.
        output = ensemble_model(
            sample_pos[k],
            sample_freq[k],
            sample_ws_pos[k],
            sample_ws[k],
            sample_ws_freq[k],
        )
        # argmax over dim 1 picks the predicted class; .item() requires a
        # single-element result, i.e. one sample per call.
        prediction = torch.argmax(output, dim=1)
        predictions.append(prediction.item())

print("Predicted label:", predictions)

Predicted label: [1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0

In [14]:
# Gold labels, taken from the `Label` column of ws_pos_freq as an array.
true_labels = ws_pos_freq['Label'].values

In [15]:
from sklearn.metrics import f1_score

# Weighted F1 over the samples that were actually predicted. Slicing by
# len(predictions) instead of a hard-coded 500 keeps the two arguments the same
# length if the number of evaluated samples changes.
f1_score(true_labels[:len(predictions)], predictions, average='weighted')

0.9979971299600283