In [13]:
import logging
import math
import os
import re
from datetime import datetime

import arabicstopwords.arabicstopwords as ar_stp
import numpy as np
import pandas as pd
import pyterrier as pt
import torch
from sentence_transformers import LoggingHandler, SentenceTransformer, CrossEncoder, util, InputExample
from sentence_transformers import models, losses
from sentence_transformers import evaluation

from snowballstemmer import stemmer
from torch import nn
# from simcse import SimSCE
from torch.utils.data import DataLoader

In [14]:
# Report the installed PyTorch build and whether a CUDA device can be used.
print(torch.__version__)
print(torch.cuda.is_available())
# print(torch.cuda.device_count())
# print(torch.cuda.current_device())
# Asking for a device name raises when no CUDA device is present,
# so only query it when CUDA is actually available.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

2.0.1+cu118
True
NVIDIA GeForce RTX 3090


# Get Data

In [15]:
# Root folder for every input file used by this notebook.
data_path = "../data"
index_path = os.path.join(data_path, "QPC_Index/data.properties")

query_train_path = os.path.join(data_path, "QQA23_TaskA_train.tsv")
query_dev_path = os.path.join(data_path, "QQA23_TaskA_dev.tsv")

passage_path = os.path.join(data_path, "Thematic_QPC/QQA23_TaskA_QPC_v1.1.tsv")

# Pass each path component separately: a literal backslash inside the string
# is an invalid escape sequence ("\Q") and only acts as a separator on Windows.
qp_pair_train_path = os.path.join(data_path, "qrels", "QQA23_TaskA_qrels_train.gold")
qp_pair_dev_path = os.path.join(data_path, "qrels", "QQA23_TaskA_qrels_dev.gold")

## Read file

In [16]:
def read_file(input_file, sep="\t", names = ""):
    """Load a tabular file into a DataFrame, choosing the reader by extension.

    Args:
        input_file: path to the file; a name ending in ".xlsx" is read with
            ``pd.read_excel``, anything else with ``pd.read_csv``.
        sep: field separator handed to ``pd.read_csv`` (ignored for .xlsx).
        names: column names handed to ``pd.read_csv``; the default empty
            string means "do not pass names" (ignored for .xlsx).

    Returns:
        The loaded ``pandas.DataFrame``.
    """
    if input_file.endswith(".xlsx"):
        return pd.read_excel(input_file)

    # Delimited text: always decoded as UTF-8; names only when the caller gave some.
    csv_options = {"sep": sep, "encoding": "utf-8"}
    if names != "":
        csv_options["names"] = names
    return pd.read_csv(input_file, **csv_options)

In [17]:
# Column layout applied to the header-less qrels files.
qrels_columns = ["qid", "Q0", "docid", "relevance"]

def read_qrels_file(qrels_file):
    """Read a tab-separated qrels file into a DataFrame.

    The file is read without a header row using ``qrels_columns`` as column
    names. The ``qid`` and ``docid`` columns are converted to strings; the
    other columns keep whatever dtype pandas inferred.

    Args:
        qrels_file: path to the qrels file (tab-separated only).

    Returns:
        DataFrame with columns qid, Q0, docid, relevance.
    """
    qrels = pd.read_csv(qrels_file, sep='\t', names=qrels_columns)
    # Ids are handled as strings so they compare equal to the string ids used elsewhere.
    return qrels.astype({"qid": str, "docid": str})

In [18]:
def load_index(index_path):
    """Open the PyTerrier index stored at ``index_path``.

    PyTerrier is started first when ``pt.started()`` reports it is not running.

    Args:
        index_path: location handed to ``pt.IndexFactory.of``.

    Returns:
        The loaded index, or an empty list when loading raised an exception
        (the exception text is printed, not re-raised).
    """
    # Start PyTerrier only if it is not already running.
    if not pt.started():
        pt.init(helper_version="0.0.6")

    try:
        loaded_index = pt.IndexFactory.of(index_path)
        print("Index was loaded successfully from this path: ", index_path)
    except Exception as error:
        print('Cannot load the index, check exception details {}'.format(error))
        return []
    return loaded_index

## Cleaning & Preprocessing
Clean text from URLs, handles, special characters, tabs, line breaks, extra white space, and punctuation.
Preprocess the Arabic input text by performing normalization, stop-word removal, and stemming.

In [19]:
def clean(text):
    """Strip URLs, handles, punctuation and redundant whitespace from ``text``.

    Steps, in order:
      1. URLs (``http...``) and ``@handles`` are replaced by a space.
      2. The characters ``. , # _ | : ? / =`` are replaced by a space.
      3. Every remaining character that is neither a word character nor
         whitespace is deleted (no space inserted).
      4. Whitespace runs (including tabs and newlines) collapse to one space
         and the ends are trimmed.

    Whitespace is collapsed *after* punctuation removal; doing it before left
    double spaces behind wherever a free-standing punctuation mark was
    deleted (e.g. "a - b").

    Args:
        text: input string.

    Returns:
        The cleaned string.
    """
    text = re.sub(r"http\S+", " ", text)  # remove urls
    text = re.sub(r"@[\w]*", " ", text)  # remove handles
    text = re.sub(r"[\.\,\#_\|\:\?\/\=]", " ", text)  # special characters become a space
    text = re.sub(r"[^\w\s]", "", text)  # drop any remaining punctuation
    text = re.sub(r"\s+", " ", text)  # collapse tabs, newlines and repeated spaces
    return text.strip()

In [20]:
# Arabic stemmer instance, created once at module level and reused by ar_stem
ar_stemmer = stemmer("arabic")

# Stop-word set, filled on first use so the list is not rebuilt for every sentence.
_AR_STOP_WORDS = None


# remove arabic stop words
def ar_remove_stop_words(sentence):
    """Return ``sentence`` without the terms found in the Arabic stop-word list.

    The sentence is split on whitespace, terms that appear in
    ``ar_stp.stopwords_list()`` are dropped, and the remaining terms are
    joined with single spaces.

    Args:
        sentence: whitespace-separated text.

    Returns:
        The filtered text; an empty string when every term is a stop word.
    """
    global _AR_STOP_WORDS
    if _AR_STOP_WORDS is None:
        # Built once: this function is applied to every passage and query.
        _AR_STOP_WORDS = frozenset(ar_stp.stopwords_list())
    return " ".join(term for term in sentence.split() if term not in _AR_STOP_WORDS)


# normalize the arabic text
def normalize_arabic(text):
    """Map variant Arabic letter forms onto a single canonical letter.

    The substitutions are applied one after another, in the listed order:
    alef variants to bare alef, alef maqsura to yeh, hamza-on-waw and
    hamza-on-yeh to a standalone hamza, and teh marbuta to heh.

    Args:
        text: input string.

    Returns:
        The normalized string.
    """
    substitutions = (
        ("[إأٱآا]", "ا"),
        ("ى", "ي"),
        ("ؤ", "ء"),
        ("ئ", "ء"),
        ("ة", "ه"),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text

# stem the arabic text
def ar_stem(sentence):
    """Stem every whitespace-separated term of ``sentence`` with ``ar_stemmer``.

    Args:
        sentence: whitespace-separated text.

    Returns:
        The stemmed terms joined by single spaces.
    """
    stemmed_terms = map(ar_stemmer.stemWord, sentence.split())
    return " ".join(stemmed_terms)


# apply all preprocessing steps needed for Arabic text
def preprocess_arabic(text):
    """Run the Arabic preprocessing pipeline on ``text``.

    The steps run in this fixed order: letter normalization, stop-word
    removal, stemming. Each step receives the previous step's output.

    Args:
        text: input string.

    Returns:
        The preprocessed string.
    """
    for step in (normalize_arabic, ar_remove_stop_words, ar_stem):
        text = step(text)
    return text

In [21]:
def prepare_data(path, column, id_type, id_column='docno'):
    """Read a (docno, text) file and add a cleaned, preprocessed text column.

    Args:
        path: file to read; its columns are named 'docno' and 'text'.
        column: name of the new column holding the cleaned and
            preprocessed version of 'text'.
        id_type: name of the output id column (e.g. 'pid' or 'qid').
        id_column: column whose values, converted to str, fill ``id_type``.

    Returns:
        DataFrame with the columns ``[id_type, 'text', column]``.
    """
    frame = read_file(path, names=['docno', 'text'])

    print("Cleaning passages")
    # strip urls, handles, punctuation and extra whitespace from the raw text
    cleaned_text = frame['text'].apply(clean)

    # apply normalization, stemming and stop word removal
    print("Preprocessing - Applying normalization, stemming and stop word removal")
    frame[column] = cleaned_text.apply(preprocess_arabic)

    # ids are exposed as strings under the requested column name
    frame[id_type] = frame[id_column].astype(str)
    prepared = frame[[id_type, 'text', column]]

    print("Done with preparation!")
    return prepared


## Loading

In [None]:
# Load the index from index_path; the BM25 retrievers created in later cells are built on `index`.
# NOTE(review): load_index returns [] when loading fails, so check the printed message before continuing.
index = load_index(index_path=index_path)

# Optional index inspection, left disabled:
# print(index.getCollectionStatistics().toString())
# print(index.getMetaIndex().getKeys())

# for kv in index.getLexicon():
#     print((kv.getKey())+"\t"+ kv.getValue().toString())
# index.getLexicon()["فاعل"].toString()

In [22]:
# Passages: columns pid, text (raw) and passage (cleaned + preprocessed)
df_passage = prepare_data(passage_path, 'passage', 'pid')

# Train / dev queries: columns qid, text (raw) and query (cleaned + preprocessed)
df_query_train = prepare_data(query_train_path, 'query', 'qid')
df_query_dev = prepare_data(query_dev_path, 'query', 'qid')

# Gold query-passage judgments for the train split (qid, Q0, docid, relevance)
df_qppair_train = read_qrels_file(qp_pair_train_path)

# Gold query-passage judgments for the dev split
df_qppair_dev = read_qrels_file(qp_pair_dev_path)


Cleaning passages
Preprocessing - Applying normalization, stemming and stop word removal
Done with preparation!
Cleaning passages
Preprocessing - Applying normalization, stemming and stop word removal
Done with preparation!
Cleaning passages
Preprocessing - Applying normalization, stemming and stop word removal
Done with preparation!


In [30]:
# Spot-check one preprocessed passage (index label 5)
df_passage.passage[5]

'ان كفر سواء اانذر ام تنذر يءمن ختم الله قلوب سمع ابصار غشا عذاب عظيم'

# Model - Sentence Embedding

#### Simple passage-passage pair

In [32]:
# Each preprocessed passage is paired with itself, giving one example per passage.
train_samples_passage = [
    InputExample(texts=[passage_text, passage_text])
    for passage_text in df_passage['passage']
]

print("len(train_samples_passage) =", len(train_samples_passage))

len(train_samples_passage) = 1266


#### query-passage double pair with relevance label = 1 (positive)

In [33]:
# id -> preprocessed text lookups, built once instead of scanning the whole
# frame with a boolean mask for every qrels row. drop_duplicates keeps the
# first row per id, matching the former `.tolist()[0]` behaviour.
query_text_by_id = df_query_train.drop_duplicates('qid').set_index('qid')['query'].to_dict()
passage_text_by_id = df_passage.drop_duplicates('pid').set_index('pid')['passage'].to_dict()

train_samples_qp = []
for _, row in df_qppair_train.iterrows():
    passage_id = row['docid']

    # Rows with docid '-1' have no passage to pair with (presumably "no answer" - verify), so skip them
    # before doing any lookup.
    if passage_id == '-1':
        continue

    query = query_text_by_id[row['qid']]  # KeyError names the missing id
    passage = passage_text_by_id[passage_id]
    label = row['relevance']
    # Each gold pair is added in both orders: (query, passage) and (passage, query).
    train_samples_qp.append(InputExample(texts=[query, passage], label=label))
    train_samples_qp.append(InputExample(texts=[passage, query], label=label))

print("len(train_samples_qp) =", len(train_samples_qp))

len(train_samples_qp) = 1892


#### Contrastive: query-passage pairs (in both orders) with relevance label = 1 (positive), plus BM25 top-k passages that are not gold passages for the query with relevance label = 0 (negative)

In [34]:
train_samples_qp_contrastive = []

top_k = 75
print("top_k =", top_k)
BM25_model = pt.BatchRetrieve(index, controls = {"wmodel": "BM25"}, num_results=top_k)

# id -> preprocessed text lookups, built once instead of scanning the frames per sample.
# drop_duplicates keeps the first row per id.
query_text_by_id = df_query_train.drop_duplicates('qid').set_index('qid')['query'].to_dict()
passage_text_by_id = df_passage.drop_duplicates('pid').set_index('pid')['passage'].to_dict()

for query_id, qrels_group in df_qppair_train.groupby('qid'):
    query = query_text_by_id[query_id]
    bm25_related_passage = BM25_model.search(query)['docno'].tolist()
    positive_passage = qrels_group['docid'].tolist()
    positive_ids = set(positive_passage)
    # Negatives = BM25 results that are not gold passages for this query.
    # They are kept in the order BM25 returned them (duplicates dropped); a plain
    # set difference iterates in an order that can change between runs, which
    # made the sample list non-reproducible.
    negative_passage = list(dict.fromkeys(
        docno for docno in bm25_related_passage if docno not in positive_ids
    ))

    for pos_passage in positive_passage:
        if pos_passage == '-1':
            continue
        passage = passage_text_by_id[pos_passage]
        label = 1
        # positive sample, added in both orders
        train_samples_qp_contrastive.append(InputExample(texts=[query, passage], label=label))
        train_samples_qp_contrastive.append(InputExample(texts=[passage, query], label=label))

    for neg_passage in negative_passage:
        if neg_passage == '-1':
            continue
        passage = passage_text_by_id[neg_passage]
        label = 0
        # negative sample, added in both orders
        train_samples_qp_contrastive.append(InputExample(texts=[query, passage], label=label))
        train_samples_qp_contrastive.append(InputExample(texts=[passage, query], label=label))

print("len(train_samples_qp_contrastive) =", len(train_samples_qp_contrastive))

top_k = 75
len(train_samples_qp_contrastive) = 19224


#### multiple negative ranking: query-passage double pair with relevance label = 1 (positive)

In [35]:
# id -> preprocessed text lookups, built once instead of scanning the frames per sample.
# drop_duplicates keeps the first row per id.
query_text_by_id = df_query_train.drop_duplicates('qid').set_index('qid')['query'].to_dict()
passage_text_by_id = df_passage.drop_duplicates('pid').set_index('pid')['passage'].to_dict()

train_samples_qp_multiple_negative_ranking = []
for query_id, qrels_group in df_qppair_train.groupby('qid'):
    query = query_text_by_id[query_id]
    for pos_passage in qrels_group['docid'].tolist():
        # docid '-1' has no passage text to pair with
        if pos_passage == '-1':
            continue
        passage = passage_text_by_id[pos_passage]
        label = 1
        # positive sample, added in both orders
        train_samples_qp_multiple_negative_ranking.append(InputExample(texts=[query, passage], label=label))
        train_samples_qp_multiple_negative_ranking.append(InputExample(texts=[passage, query], label=label))

print("len(train_samples_qp_multiple_negative_ranking) =", len(train_samples_qp_multiple_negative_ranking))

len(train_samples_qp_multiple_negative_ranking) = 1892


#### Triplet: (query, positive passage, negative passage); each negative is a BM25 top-k passage that is not a gold passage for the query

In [36]:
top_k = 20
print("top_k =", top_k)
BM25_model = pt.BatchRetrieve(index, controls = {"wmodel": "BM25"}, num_results=top_k)

# id -> preprocessed text lookups, built once instead of scanning the frames per sample.
# drop_duplicates keeps the first row per id.
query_text_by_id = df_query_train.drop_duplicates('qid').set_index('qid')['query'].to_dict()
passage_text_by_id = df_passage.drop_duplicates('pid').set_index('pid')['passage'].to_dict()

train_samples_qp_triple = []

for query_id, qrels_group in df_qppair_train.groupby('qid'):
    query = query_text_by_id[query_id]
    positive_passage = qrels_group['docid'].tolist()

    # Without at least one real gold passage no triplet can be built,
    # so skip the BM25 search for such queries.
    if all(pos_passage_id == '-1' for pos_passage_id in positive_passage):
        continue

    bm25_related_passage = BM25_model.search(query)['docno'].tolist()
    positive_ids = set(positive_passage)
    # Negatives = BM25 results that are not gold passages, kept in the order BM25
    # returned them (duplicates dropped). A plain set difference iterates in an
    # order that can change between runs.
    negative_passage = list(dict.fromkeys(
        docno for docno in bm25_related_passage if docno not in positive_ids
    ))
    # The negative texts are identical for every positive of this query: resolve them once.
    negative_texts = [passage_text_by_id[neg_passage_id] for neg_passage_id in negative_passage]

    for pos_passage_id in positive_passage:
        if pos_passage_id == '-1':
            continue
        pos_passage = passage_text_by_id[pos_passage_id]
        for neg_passage in negative_texts:
            train_samples_qp_triple.append(InputExample(texts=[query, pos_passage, neg_passage]))

print("len train_samples_qp_triple =", len(train_samples_qp_triple))

top_k = 20
len train_samples_qp_triple = 14047


In [83]:
# Show every preprocessed training query (long output; consider .head() when only a sample is needed)
df_query_train['query'].tolist()

['قوم شعيب',
 'قوم موس',
 'بن كعبه',
 'النب معروف صبر',
 'كفل سيده مريم',
 'معن حطمه',
 'اخو سيد موس',
 'معن قارعه',
 'معن جاثيه',
 'اسباط',
 'ملك سبا',
 'عقر ناقه',
 'عقوب ربا',
 'عقوب سارق',
 'شجره ملعونه',
 'النب علم الله لغه طير والحيو',
 'ميراث الام ولد ان يكن ولد',
 'امر الله زكري الا يكلم ناس',
 'مده عده مطلقه',
 'عدد اشهر حرم',
 'شجره ياكل كفار نار',
 'النب دخل سجن',
 'جبل استقر سفين نوح',
 'فتر رضاع مولود',
 'اين كان رحل اسراء والمعراج',
 'حكم تعدد زواج',
 'زكري',
 'لبث بطن حوت',
 'عقوب قتل خطا',
 'عقوب قتل عمد',
 'مطفف',
 'عدد حمل عرش',
 'وصف حور عين',
 'ابو سيد يوسف سلام',
 'قارون',
 'كفار ظهار',
 'كفار اليم',
 'جزاء يقول ان لله ولد',
 'اين تقع قبل مسلم اول',
 'يوم خلق الله كون',
 'فضل قدر',
 'كتاب انزل موس',
 'كتاب انزل عيس',
 'لغه القر',
 'مسيح',
 'صنع عجل الحل لبن اسراءيل',
 'مخلوق تسبح الله',
 'الا تتحدث موضوع وصيه',
 'الا تتحدث موضوع وصيه سور ماءده',
 'ابناء سيد ابراهيم سلام',
 'حدث لقابيل هابيل',
 'قرن',
 'احداث متعلقه قرن',
 'معجزا النب موس سلام',
 'حوار',
 'احداث متع

In [88]:
from statistics import mean
def stats(df):
    """Print the max, min and mean number of whitespace-separated tokens.

    Args:
        df: object with a ``tolist()`` method that yields strings
            (called below with DataFrame text columns).

    Prints one line: ``<max> <min> <mean>``. Returns nothing.
    """
    token_counts = [len(entry.split()) for entry in df.tolist()]
    print(max(token_counts), min(token_counts), mean(token_counts))


# Token-count statistics (max, min, mean) of the preprocessed passages
print("\ndf_passage")
stats(df_passage.passage)

# Same statistics for the preprocessed training queries
print("\ndf_query_train")
stats(df_query_train['query'])

# Same statistics for the preprocessed dev queries
print("\ndf_query_dev")
stats(df_query_dev['query'])


df_passage
162 4 43.872037914691944

df_query_train
28 1 4.885057471264368

df_query_dev
18 2 4.84


# Train Bi-Encoder

In [37]:
# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
#TODO : change max_seq_length to 384 or 512
bi_model_name = "aubmindlab/bert-base-arabert"

# Token limit handed to the transformer module.
max_seq_length = 256
# Width of the final sentence embedding produced by the dense layer.
# Kept as its own setting: it was previously tied to max_seq_length, so changing
# the token limit would also have silently resized the embedding.
#TODO : change sentence_embedding_dim to 512
sentence_embedding_dim = 256

word_embedding_model = models.Transformer(bi_model_name, max_seq_length=max_seq_length)
print("word_embedding_model Max Sequence Length:", word_embedding_model.max_seq_length)
print("word_embedding_model dimension", word_embedding_model.get_word_embedding_dimension())

# Pool the token embeddings into one fixed sized sentence vector (library default pooling settings)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
print("pooling_model sentence embedding dimension", pooling_model.get_sentence_embedding_dimension())

# Project the pooled vector down to sentence_embedding_dim with a Tanh activation
dense_model = models.Dense(
    in_features=pooling_model.get_sentence_embedding_dimension(),
    out_features=sentence_embedding_dim,
    activation_function=nn.Tanh(),
)

# bi_encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model])
bi_encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model])

word_embedding_model Max Sequence Length: 256
word_embedding_model dimension 768
pooling_model sentence embedding dimension 768
2023-08-14 01:32:45 - Use pytorch device: cuda


In [90]:
 # define some global constants
# Column names shared by the retrieval cells.
TEXT = "text"  # raw text column of the files read by prepare_query_for_search
QUERY = "query"  # cleaned + preprocessed query text
# The next four names are not referenced in the cells shown here; presumably run-file columns - verify.
LABEL = "label"
RANK = "rank"
TAG = "tag"
SCORE = "score"
QID = "qid"  # query id column produced by prepare_query_for_search
DOC_NO = "docno"  # id column name given to the files read by prepare_query_for_search
DOCID = "docid"  # NOTE(review): not referenced in the cells shown here

def prepare_query_for_search(query_path, query_column=TEXT,
                        id_column=DOC_NO):
    """Read a query file and return the (qid, query) frame used for retrieval.

    The file is read with the column names ``[DOC_NO, TEXT]``. The text in
    ``query_column`` is cleaned and then preprocessed (normalization,
    stop-word removal, stemming) into the ``QUERY`` column, and
    ``id_column`` is converted to str into the ``QID`` column.

    Args:
        query_path: path of the query file.
        query_column: column holding the raw query text.
        id_column: column holding the query ids.

    Returns:
        DataFrame with exactly the columns ``[QID, QUERY]``.
    """
    print("Cleaning queries and applying preprocessing steps")
    queries = read_file(query_path, names=[DOC_NO, TEXT])
    # strip urls, handles, punctuation and extra whitespace from the raw queries
    cleaned_queries = queries[query_column].apply(clean)

    # apply normalization, stemming and stop word removal
    print("Applying normalization, stemming and stop word removal")
    queries[QUERY] = cleaned_queries.apply(preprocess_arabic)

    # ids are exposed as strings; only the columns needed for search are kept
    queries[QID] = queries[id_column].astype(str)
    search_ready = queries[[QID, QUERY]]
    print("Done with preparation!")
    return search_ready

In [177]:
# 1. build a BM25 retriever that returns up to 1000 results per query
BM25_model = pt.BatchRetrieve(index, controls = {"wmodel": "BM25"}, num_results=1000)

# 2. read the query file and prepare it for search to match pyterrier format
df_query = prepare_query_for_search(query_train_path)

# 3. search using BM25 model
df_run = BM25_model.transform(df_query)

# 4. save the run in trec format to a file
df_run["Q0"] = "Q0"
df_run["tag"] = "BM25"
df_run['question-id'] = df_run["qid"]
df_run['passage-id'] = df_run["docno"]
df_run = df_run[["question-id", "Q0", "passage-id", "rank", "score", "tag"]]

# Derive the output location from data_path and create the folder first:
# to_csv fails when the target directory does not exist.
run_path = os.path.join(data_path, "runs", "GYM_BM25.tsv")
os.makedirs(os.path.dirname(run_path), exist_ok=True)
df_run.to_csv(run_path, sep="\t", index=False, header=False)
# df_run
# df_run

Cleaning queries and applying preprocessing steps
Applying normalization, stemming and stop word removal
Done with preparation!


In [178]:
# ! python QQA23_TaskA_eval.py \
#     -r "../data/runs/GYM_BM25.tsv" \
#     -q "../data/qrels/QQA23_TaskA_qrels_dev.gold"

In [179]:
# Gold passage ids per query. Selecting the 'docid' column before apply keeps the
# grouping column out of the applied function (applying over the whole group
# frame is deprecated in recent pandas) and yields the same qid -> list mapping.
GOld_label_train = df_qppair_train.groupby('qid')['docid'].apply(list)
GOld_label_dev = df_qppair_dev.groupby('qid')['docid'].apply(list)

# Per query: share of gold passage ids that appear among the retrieved passages.
# NOTE(review): this is a recall-style ratio rather than accuracy; the list name is kept as-is.
# Queries absent from df_run are not counted at all.
acuracy_list = []
for qid, predicted in df_run.groupby('question-id'):
    predicted = predicted['passage-id'].tolist()
    actual = GOld_label_train[qid]
    # print(qid, predicted, actual)
    gold_ids = set(actual)
    n_found = len(set(predicted) & gold_ids)  # computed once, used for both the ratio and the printout
    acuracy_list.append(n_found / len(gold_ids))
    print(n_found, len(gold_ids))

# print("Accuracy =", sum(acuracy_list) / len(acuracy_list))
# print(acuracy_list)


4 4
14 17
0 1
0 2
1 1
1 1
5 7
1 1
1 1
0 1
5 5
2 2
0 3
2 2
1 1
2 2
1 3
1 1
0 3
1 2
1 1
1 3
0 2
0 1
3 3
1 1
1 1
2 3
1 1
2 4
2 2
1 3
0 1
1 1
1 1
0 1
8 8
1 1
0 2
4 4
3 7
5 5
1 2
13 26
3 7
1 1
4 5
0 1
2 2
1 1
9 10
3 3
8 9
1 1
6 7
4 4
4 4
2 2
0 1
8 74
1 21
1 1
1 2
4 5
1 1
7 7
1 2
0 1
0 1
3 3
3 3
0 1
0 1
1 1
0 1
3 3
1 1
1 1
2 4
0 1
1 3
0 10
15 16
1 8
4 4
1 1
3 3
1 21
3 8
8 24
2 30
0 1
0 1
0 1
3 3
3 8
3 5
3 5
5 6
0 1
0 1
0 1
1 5
3 3
0 1
2 6
6 7
1 4
1 1
6 6
1 1
2 2
0 12
9 9
0 1
3 3
7 9
1 1
7 9
0 2
0 6
0 3
1 1
4 8
2 2
111 144
2 2
0 1
1 1
1 1
20 28
0 1
1 1
0 1
0 1
1 1
0 1
0 1
4 5
0 1
0 1
1 8
0 1
1 1
0 2
2 2
2 2
1 20
5 5
2 2
1 3
0 1
7 8
3 5
0 4
1 3
0 4
1 1
3 3
1 25
0 26
6 10
0 7
0 1
0 9
2 5
3 5
0 1
2 2
1 1
