# Embedding + first classifier

In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate, cross_val_score, StratifiedKFold,train_test_split
from sklearn.metrics import make_scorer, accuracy_score, precision_score, recall_score, f1_score,make_scorer
from imblearn.under_sampling import RandomUnderSampler

In [2]:
data = pd.read_csv("dca_for_classifier.csv")

In [3]:
data

Unnamed: 0,גוף המסמך,מספר תיק,Outcome of case,binary_outcome
0,"['פסק-דין בתיק רע""פ 7861/03 בבית המשפט העליון ...",7861/03,G R,1.0
1,"['פסק-דין בתיק רע""פ 8337/04 בבית המשפט העליון ...",8337/04,G R,1.0
2,"['החלטה בתיק רע""פ 7896/04 בבית המשפט העליון רע...",7896/04,G A,1.0
3,"['פסק-דין בתיק רע""פ 2038/04 בבית המשפט העליון ...",2038/04,G R,1.0
4,"['פסק-דין בתיק רע""פ 5978/04 בבית המשפט העליון ...",5978/04,G A,1.0
...,...,...,...,...
726,"['החלטה בתיק רע""פ 3076/07 בבית המשפט העליון רע...",3076/07,D,0.0
727,"['החלטה בתיק רע""פ 6415/07 בבית המשפט העליון רע...",6415/07,D A,0.0
728,"['החלטה בתיק רע""פ 825/07 בבית המשפט העליון רע""...",825/07,D A,0.0
729,"['החלטה בתיק רע""פ 4180/07 בבית המשפט העליון רע...",4180/07,D A,0.0


In [4]:

tokenizer = AutoTokenizer.from_pretrained("dean-ai/sentence_transformer_Legal-heBERT")
model = AutoModel.from_pretrained("dean-ai/sentence_transformer_Legal-heBERT")

# פונקציה שמקבלת טקסט ומחזירה את האימבדינג כוקטור
def get_embeddings(text):
    # המרת הטקסט לפורמט מתאים למודל
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # חישוב האימבדינג של המודל
    outputs = model(**inputs)
    # חישוב ממוצע הוקטורים עבור כל טקסט
    embeddings = outputs.last_hidden_state.mean(dim=1).detach().numpy()
    return embeddings[0]

# הסרת שורות עם ערכים חסרים בעמודת הטקסט או בעמודת המטרה
data = data.dropna(subset=["גוף המסמך", "binary_outcome"])

# עדכון עמודת המטרה ואימבדינג לאחר הסרת הערכים החסרים
y = data["binary_outcome"].values
embeddings = np.array([get_embeddings(text) for text in data["גוף המסמך"]])


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



## Training

In [5]:

classifier = LogisticRegression(random_state=42, max_iter=1000)

# הגדרת Stratified K-Fold עם 5 פולדים כדי לשמר את יחס הכיתות
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# הגדרת מדדי הביצועים הרצויים
scoring = {
    'accuracy': make_scorer(accuracy_score),
    'precision': make_scorer(precision_score, pos_label=1),
    'recall': make_scorer(recall_score, pos_label=1),
    'f1': make_scorer(f1_score, pos_label=1)
}

# ביצוע Cross Validation עם המדדים הנוספים
cv_results = cross_validate(classifier, embeddings, y, cv=cv, scoring=scoring)

# הצגת התוצאות
print("Cross-Validation Accuracy:", cv_results['test_accuracy'])
print("Mean Accuracy:", cv_results['test_accuracy'].mean())
print("Cross-Validation Precision for Class 1:", cv_results['test_precision'])
print("Mean Precision for Class 1:", cv_results['test_precision'].mean())
print("Cross-Validation Recall for Class 1:", cv_results['test_recall'])
print("Mean Recall for Class 1:", cv_results['test_recall'].mean())
print("Cross-Validation F1 Score for Class 1:", cv_results['test_f1'])
print("Mean F1 Score for Class 1:", cv_results['test_f1'].mean())


Cross-Validation Accuracy: [0.97260274 0.97241379 0.96551724 0.96551724 0.9862069 ]
Mean Accuracy: 0.9724515824279643
Cross-Validation Precision for Class 1: [0.875      1.         0.83333333 1.         1.        ]
Mean Precision for Class 1: 0.9416666666666668
Cross-Validation Recall for Class 1: [0.7        0.55555556 0.55555556 0.44444444 0.77777778]
Mean Recall for Class 1: 0.6066666666666667
Cross-Validation F1 Score for Class 1: [0.77777778 0.71428571 0.66666667 0.61538462 0.875     ]
Mean F1 Score for Class 1: 0.7298229548229548


### Loss function - class weight tunning

In [6]:

# הגדרת מסווג לוגיסטי עם משקולות כיתות
classifier = LogisticRegression(random_state=42, max_iter=1000, class_weight={0: 1, 1: 5})

# הגדרת Cross Validation ומדדים כפי שהגדרנו קודם
cv_results = cross_validate(classifier, embeddings, y, cv=cv, scoring=scoring)

# הצגת תוצאות הקרוס ולידיישן עם משקולות כיתות
print("Cross-Validation Accuracy:", cv_results['test_accuracy'])
print("Mean Accuracy:", cv_results['test_accuracy'].mean())
print("Cross-Validation Precision for Class 1:", cv_results['test_precision'])
print("Mean Precision for Class 1:", cv_results['test_precision'].mean())
print("Cross-Validation Recall for Class 1:", cv_results['test_recall'])
print("Mean Recall for Class 1:", cv_results['test_recall'].mean())
print("Cross-Validation F1 Score for Class 1:", cv_results['test_f1'])
print("Mean F1 Score for Class 1:", cv_results['test_f1'].mean())


Cross-Validation Accuracy: [0.97945205 0.95862069 0.95172414 0.97241379 0.99310345]
Mean Accuracy: 0.9710628247520076
Cross-Validation Precision for Class 1: [0.88888889 0.71428571 0.6        1.         1.        ]
Mean Precision for Class 1: 0.8406349206349206
Cross-Validation Recall for Class 1: [0.8        0.55555556 0.66666667 0.55555556 0.88888889]
Mean Recall for Class 1: 0.6933333333333332
Cross-Validation F1 Score for Class 1: [0.84210526 0.625      0.63157895 0.71428571 0.94117647]
Mean F1 Score for Class 1: 0.7508292790800531


## UNDERSAMPLING


In [7]:


# הגדרת undersampler עם יחס של 60-40 לטובת קלאס 0
undersampler = RandomUnderSampler(sampling_strategy=0.6, random_state=42)

# ביצוע undersampling
X_resampled, y_resampled = undersampler.fit_resample(embeddings, y)

# הגדרת מסווג וביצוע Cross Validation על הנתונים המדוללים
cv_results_resampled = cross_validate(classifier, X_resampled, y_resampled, cv=cv, scoring=scoring)

# הצגת תוצאות ה-Cross Validation על הנתונים החדשים
print("Cross-Validation Accuracy (after undersampling):", cv_results_resampled['test_accuracy'])
print("Mean Accuracy (after undersampling):", cv_results_resampled['test_accuracy'].mean())
print("Cross-Validation Precision for Class 1 (after undersampling):", cv_results_resampled['test_precision'])
print("Mean Precision for Class 1 (after undersampling):", cv_results_resampled['test_precision'].mean())
print("Cross-Validation Recall for Class 1 (after undersampling):", cv_results_resampled['test_recall'])
print("Mean Recall for Class 1 (after undersampling):", cv_results_resampled['test_recall'].mean())
print("Cross-Validation F1 Score for Class 1 (after undersampling):", cv_results_resampled['test_f1'])
print("Mean F1 Score for Class 1 (after undersampling):", cv_results_resampled['test_f1'].mean())


Cross-Validation Accuracy (after undersampling): [0.84       0.8        0.91666667 0.79166667 0.91666667]
Mean Accuracy (after undersampling): 0.853
Cross-Validation Precision for Class 1 (after undersampling): [0.72727273 0.77777778 0.88888889 0.75       0.81818182]
Mean Precision for Class 1 (after undersampling): 0.7924242424242425
Cross-Validation Recall for Class 1 (after undersampling): [0.88888889 0.7        0.88888889 0.66666667 1.        ]
Mean Recall for Class 1 (after undersampling): 0.8288888888888888
Cross-Validation F1 Score for Class 1 (after undersampling): [0.8        0.73684211 0.88888889 0.70588235 0.9       ]
Mean F1 Score for Class 1 (after undersampling): 0.8063226694186447


## combine undersampling with class weights

In [10]:

# רשימת משקלים שנבדוק עבור קלאס 1
class_weights = [2, 3, 4, 5]
results = []

# הגדרת undersampler עם יחס של 60-40 לטובת קלאס 0
undersampler = RandomUnderSampler(sampling_strategy=0.6, random_state=42)
X_resampled, y_resampled = undersampler.fit_resample(embeddings, y)

# הגדרת Cross Validation ומדדים
scoring = {
    'accuracy': make_scorer(accuracy_score),
    'precision': make_scorer(precision_score, pos_label=1),
    'recall': make_scorer(recall_score, pos_label=1),
    'f1': make_scorer(f1_score, pos_label=1)
}

# ביצוע בדיקה עבור כל משקל
for weight in class_weights:
    # הגדרת מסווג לוגיסטי עם המשקל הנוכחי לקלאס 1
    classifier = LogisticRegression(random_state=42, max_iter=1000, class_weight={0: 1, 1: weight})

    # ביצוע Cross Validation על הנתונים המדוללים עם המשקל הנוכחי
    cv_results_combined = cross_validate(classifier, X_resampled, y_resampled, cv=cv, scoring=scoring)

    # שמירת התוצאות
    results.append({
        "weight": weight,
        "mean_accuracy": cv_results_combined['test_accuracy'].mean(),
        "mean_precision_class_1": cv_results_combined['test_precision'].mean(),
        "mean_recall_class_1": cv_results_combined['test_recall'].mean(),
        "mean_f1_class_1": cv_results_combined['test_f1'].mean()
    })

# הצגת התוצאות
for result in results:
    print(f"Class Weight for 1: {result['weight']}")
    print(f"Mean Accuracy: {result['mean_accuracy']}")
    print(f"Mean Precision for Class 1: {result['mean_precision_class_1']}")
    print(f"Mean Recall for Class 1: {result['mean_recall_class_1']}")
    print(f"Mean F1 Score for Class 1: {result['mean_f1_class_1']}")
    print("--------------------------------------------------")

Class Weight for 1: 2
Mean Accuracy: 0.869
Mean Precision for Class 1: 0.845
Mean Recall for Class 1: 0.8066666666666666
Mean F1 Score for Class 1: 0.8218094255245958
--------------------------------------------------
Class Weight for 1: 3
Mean Accuracy: 0.8613333333333333
Mean Precision for Class 1: 0.8146464646464647
Mean Recall for Class 1: 0.8288888888888888
Mean F1 Score for Class 1: 0.816780185758514
--------------------------------------------------
Class Weight for 1: 4
Mean Accuracy: 0.853
Mean Precision for Class 1: 0.7924242424242425
Mean Recall for Class 1: 0.8288888888888888
Mean F1 Score for Class 1: 0.8063226694186447
--------------------------------------------------
Class Weight for 1: 5
Mean Accuracy: 0.853
Mean Precision for Class 1: 0.7924242424242425
Mean Recall for Class 1: 0.8288888888888888
Mean F1 Score for Class 1: 0.8063226694186447
--------------------------------------------------


### Best model

In [13]:
# הגדרת undersampler עם יחס של 60-40 לטובת קלאס 0
undersampler = RandomUnderSampler(sampling_strategy=0.6, random_state=42)
X_resampled, y_resampled = undersampler.fit_resample(embeddings, y)

# הגדרת מסווג לוגיסטי עם משקל של 2 עבור קלאס 1
classifier = LogisticRegression(random_state=42, max_iter=1000, class_weight={0: 1, 1: 3})

# הגדרת Cross Validation ומדדים
scoring = {
    'accuracy': make_scorer(accuracy_score),
    'precision': make_scorer(precision_score, pos_label=1),
    'recall': make_scorer(recall_score, pos_label=1),
    'f1': make_scorer(f1_score, pos_label=1)
}

# ביצוע Cross Validation על הנתונים המדוללים עם משקלות כיתות
cv_results = cross_validate(classifier, X_resampled, y_resampled, cv=5, scoring=scoring)

# הצגת תוצאות ה-Cross Validation
print("Cross-Validation Accuracy:", cv_results['test_accuracy'])
print("Mean Accuracy:", cv_results['test_accuracy'].mean())
print("Cross-Validation Precision for Class 1:", cv_results['test_precision'])
print("Mean Precision for Class 1:", cv_results['test_precision'].mean())
print("Cross-Validation Recall for Class 1:", cv_results['test_recall'])
print("Mean Recall for Class 1:", cv_results['test_recall'].mean())
print("Cross-Validation F1 Score for Class 1:", cv_results['test_f1'])
print("Mean F1 Score for Class 1:", cv_results['test_f1'].mean())


Cross-Validation Accuracy: [0.84       0.88       0.83333333 0.91666667 0.875     ]
Mean Accuracy: 0.869
Cross-Validation Precision for Class 1: [0.72727273 0.81818182 0.85714286 0.88888889 0.8       ]
Mean Precision for Class 1: 0.8182972582972583
Cross-Validation Recall for Class 1: [0.88888889 0.9        0.66666667 0.88888889 0.88888889]
Mean Recall for Class 1: 0.8466666666666665
Cross-Validation F1 Score for Class 1: [0.8        0.85714286 0.75       0.88888889 0.84210526]
Mean F1 Score for Class 1: 0.8276274018379282
