In [None]:
# One-time environment setup: create the "NER" virtualenv and install the
# project requirements into it.
# `-y` keeps apt from stopping at its interactive confirmation prompt.
# `.` is the POSIX spelling of `source`, so activation also works when IPython
# runs `!` commands through /bin/sh instead of bash.
# `&&` aborts if activation fails instead of silently installing the
# requirements into the system interpreter.
!apt-get install -y python3.10-venv
!python3 -m venv NER
!. NER/bin/activate && pip install -r requirements.txt

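## Inference

Runs `inference.py` inside the `NER` virtual environment on the English validation split, using the checkpoint(s) passed via `--checkpoint_path_list`.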

In [None]:
# Run inference.py inside the "NER" virtualenv on the English validation split.
# `.` is the POSIX form of `source` (works whether `!` uses bash or /bin/sh).
# `&&` stops here if activation fails instead of falling back to the system Python.
# NOTE(review): all paths below are relative to the kernel's working directory.
# The training cell resolves the same '../datasets/...' path from inside
# .../legal_ner, so this cell presumably expects to run from there — confirm.
!. NER/bin/activate && \
  python inference.py \
    --ds_test_set '../datasets/english/en_val.json' \
    --label_list 'COURT' 'PETITIONER' 'RESPONDENT' 'JUDGE' 'DATE' 'ORG' 'GPE' 'STATUTE' 'PROVISION' 'PRECEDENT' 'CASE_NUMBER' 'WITNESS' 'OTHER_PERSON' 'LAWYER' \
    --checkpoint_path_list 'dslim/bert-large-NER/checkpoint-54975'

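## Training

Fine-tunes `studio-ousia/luke-large` with `main.py` on the English training split, validating on the English validation split, and writes its output to `results/`.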

In [None]:
# Fine-tune studio-ousia/luke-large with main.py inside the "NER" virtualenv;
# output is written to results/ (relative to the legal_ner project folder).
# `.` is the POSIX form of `source`, so activation also works under /bin/sh.
# `&&` (instead of `;`) aborts when the venv cannot be activated or the Drive
# folder is not mounted, rather than running with the wrong interpreter/directory.
# To resume training, uncomment the last line. Keep it last: the shell comment
# would swallow any argument placed after it on the joined command line.
!. NER/bin/activate && \
cd /content/gdrive/MyDrive/DNLP2/model_expl/legal_ner && \
python3 main.py \
  --ds_train_path '../datasets/english/en_train.json' \
  --ds_valid_path '../datasets/english/en_val.json' \
  --output_folder results/ \
  --batch 32 \
  --num_epochs 5 \
  --lr 1e-4 \
  --weight_decay 0.01 \
  --warmup_ratio 0.06 \
  --label_list 'COURT' 'PETITIONER' 'RESPONDENT' 'JUDGE' 'DATE' 'ORG' 'GPE' 'STATUTE' 'PROVISION' 'PRECEDENT' 'CASE_NUMBER' 'WITNESS' 'OTHER_PERSON' 'LAWYER' \
  --model_list 'studio-ousia/luke-large' \
  # --resume_checkpoint # Remove comment to recover from checkpoint in folder "results"

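## t-SNE Analysis

Projects each model's saved embeddings (`saved_results/<model>_embeddings.npy`, with labels in `saved_results/<model>_labels.npy`) to 2-D with t-SNE. It plots one panel per model and gives each label the same color in every panel.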

In [None]:
import os
import json
import numpy as np
from tqdm import tqdm
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

In [None]:
def tsne_analysis(base_dir, save_dir, checkpoint_path_list, n_cols=2, random_state=42):
    """Plot 2-D t-SNE projections of saved embeddings, one panel per model.

    The model name of each checkpoint is the first component of its path
    (``'luke-large/checkpoint-1720'`` -> ``'luke-large'``). For every model the
    function does the following:

    * loads ``{save_dir}/{model}_embeddings.npy`` and ``{save_dir}/{model}_labels.npy``;
    * projects the embeddings to 2-D with t-SNE;
    * scatters the points colored by label.

    Colors come from the sorted union of labels over all models, so a label
    keeps the same color in every panel. ``tab20`` colors repeat beyond 20
    labels. The combined figure is written to
    ``{save_dir}/combined_tsne_visualization`` as ``.svg``, ``.jpg`` and
    ``.png``, then shown.

    Assumes (not checked here) that each embeddings file is a 2-D array whose
    rows line up one-to-one with the entries of the 1-D labels file.

    Args:
        base_dir: Checkpoint root. No file is read from it. Panels are ordered
            by sorting the checkpoint paths, which is the same as sorting
            ``'{base_dir}/{checkpoint}'``.
        save_dir: Directory holding the ``.npy`` inputs. The images are saved
            here as well.
        checkpoint_path_list: Checkpoint paths such as
            ``'bert-large-NER/checkpoint-54975'``.
        n_cols: Number of panel columns. As many rows as needed are added, so
            any number of models fits. The default of 2 gives a 2x2 grid for
            4 models.
        random_state: Seed forwarded to ``TSNE`` for reproducible projections.

    Raises:
        ValueError: If ``checkpoint_path_list`` is empty or ``n_cols < 1``.
    """
    if not checkpoint_path_list:
        raise ValueError('checkpoint_path_list must contain at least one checkpoint')
    if n_cols < 1:
        raise ValueError(f'n_cols must be >= 1, got {n_cols}')

    # Sorting '{base_dir}/{checkpoint}' strings that share one prefix gives the
    # same order as sorting the checkpoints themselves.
    model_names = [checkpoint.split('/')[0] for checkpoint in sorted(checkpoint_path_list)]

    def embeddings_path(model_name):
        return f'{save_dir}/{model_name}_embeddings.npy'

    def labels_path(model_name):
        return f'{save_dir}/{model_name}_labels.npy'

    # Color assignment needs only the labels. Embeddings are loaded once per
    # model, inside the plotting loop, so only one embedding array is held at a time.
    labels_by_model = {name: np.load(labels_path(name)) for name in model_names}

    unique_labels = sorted(set().union(*labels_by_model.values()))
    n_labels = len(unique_labels)
    label_to_color = {label: plt.cm.tab20(i / n_labels)
                      for i, label in enumerate(unique_labels)}

    n_models = len(model_names)
    n_rows = (n_models + n_cols - 1) // n_cols  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 8 * n_rows), squeeze=False)
    flat_axes = axes.ravel()
    # Grid cells without a model stay blank.
    for unused_ax in flat_axes[n_models:]:
        unused_ax.set_visible(False)

    for ax, model_name in zip(flat_axes, model_names):
        embeddings = np.load(embeddings_path(model_name))
        labels = labels_by_model[model_name]
        projected = TSNE(n_components=2, random_state=random_state).fit_transform(embeddings)

        # Sorted label order keeps the drawing (z-)order deterministic.
        for label in unique_labels:
            mask = labels == label
            if mask.any():
                ax.scatter(projected[mask, 0], projected[mask, 1],
                           c=[label_to_color[label]], label=label, alpha=0.7)
        ax.set_title(model_name.upper())

    # One shared legend placed below the grid of panels.
    legend_handles = [Line2D([0], [0], marker='o', color='w', label=label,
                             markerfacecolor=color, markersize=10)
                      for label, color in label_to_color.items()]
    # ncol must be >= 1; n_labels // 2 alone is 0 when there is a single label.
    fig.legend(handles=legend_handles, loc='upper center', bbox_to_anchor=(0.5, 0.08),
               ncol=max(1, n_labels // 2))

    for extension in ('svg', 'jpg', 'png'):
        fig.savefig(f'{save_dir}/combined_tsne_visualization.{extension}', bbox_inches='tight')

    plt.show()

In [None]:
# Compare the four fine-tuned models. The first path component of each
# checkpoint selects the <model>_embeddings.npy / <model>_labels.npy files
# in saved_results.
tsne_analysis(
    base_dir='results/all',
    save_dir='saved_results',
    checkpoint_path_list=[
        'bert-large-NER/checkpoint-54975',
        'legal-bert-base-uncased/checkpoint-54975',
        'bert-base-uncased-echr/checkpoint-54975',
        'luke-large/checkpoint-1720',
    ],
)