In [None]:
# ! python3 -m spacy download en_core_web_lg

In [1]:
import re
from functools import lru_cache
from typing import List, Set, Tuple

import numpy as np
import pandas as pd
import spacy

### Pre-processing

In [2]:
def replace_annotated_party_names(s: str) -> str:
    """Collapse every ``{partyX}...{/partyX}`` annotation into the bare token ``partyX``.

    ``X`` is a single upper-case letter. The opening/closing tags may carry a trailing
    underscore and the closing slash is optional; the closing tag must repeat the same
    letter as the opening one. The annotated text between the tags is dropped.
    """
    annotated_party = re.compile(r"{party([A-Z])_?}(.*?){/?party\1_?}")
    return annotated_party.sub(r"party\1", s)


def replace_annotated_firm_names(s: str) -> str:
    """Collapse every ``{firmN}...{/firmN}`` annotation into the bare token ``firmN``.

    ``N`` is a positive integer without leading zeros (1, 2, ..., 10, 11, ...). The
    opening/closing tags may carry a trailing underscore and the closing slash is
    optional; the closing tag must repeat the same number as the opening one. The
    annotated text between the tags is dropped.
    """
    # `[1-9]\d*` keeps the old behaviour for single digits (no firm0, no leading zeros)
    # while also accepting multi-digit firm numbers. The `_?}` after `\1` ensures that
    # `{firm1}` is never closed by `{/firm12}`.
    pattern = r"{firm([1-9]\d*)_?}(.*?){/?firm\1_?}"
    replacement = r"firm\1"
    return re.sub(pattern, replacement, s)

In [3]:
# Shared spaCy pipeline. In this notebook it is only used by get_pre_processed_sentences,
# which needs sentence boundaries (doc.sents); named entities are not used, hence "ner" is disabled.
# NOTE(review): "pos" does not appear to be a pipeline component name in spaCy v3 models
# (POS tags come from "tagger"/"attribute_ruler") — confirm this entry actually disables anything.
nlp = spacy.load("en_core_web_lg", disable=["ner", "pos"]) # todo: optimize

In [4]:
@lru_cache(maxsize=None)
def _get_pre_processed_sentences_cached(data: str) -> Tuple[str, ...]:
    """Cached worker for get_pre_processed_sentences; a tuple so the cached value is immutable."""
    # Replace annotated names with bare "partyX" / "firmN" tokens first, then let spaCy
    # segment each line of the document into sentences independently.
    paragraphs = replace_annotated_firm_names(replace_annotated_party_names(data)).splitlines()
    return tuple(str(sent).lower() for doc in nlp.pipe(paragraphs) for sent in doc.sents)


def get_pre_processed_sentences(data: str) -> List[str]:
    """Split an opinion's raw text into lower-cased sentences with annotated names replaced.

    Annotated party and firm names become the tokens ``partyx`` / ``firmn`` (lower-cased
    together with the rest of the sentence).

    Results are memoized per distinct ``data`` string: both detectors re-read and
    re-process the same opinion for every party letter, and the spaCy pass is presumably
    the expensive part. A fresh list is returned on each call so callers cannot mutate
    the cached value. The cache is keyed on ``data`` only; if ``nlp`` is reloaded, call
    ``_get_pre_processed_sentences_cached.cache_clear()``.
    """
    return list(_get_pre_processed_sentences_cached(data))

### Modelling

In [5]:
"""
Assumptions:
1. Parties cannot assume mutually exclusive roles. That is, a single party cannot simultaneously be a plaintiff
and a defendant/defendant-intervenor (but could, of course, be a counter-defendant). Similarly, a single party
cannot simultaneously be a petitioner and a respondent, or an appellant and an appellee.
2. When our system delivers mutually exclusive roles (e.g. returns both defendant and plaintiff) we give preference
to the first discovered role. This is motivated by the fact that the search for party roles starts with the " v. "
approach which seems more reliable than the second phase which simply looks for a party and role mention
in the same sentence. However, once in phase 2, it seems fair to assume that the sentence that reveals a given
party's role will appear before other sentences which might mention the party in question and a role that does not
belong to them.
"""


class LegalPartyRolesDetector:
    """Predict the litigation role(s) of one party in one opinion.

    The opinion text is read from ``Opinions/Opinion{opinion}.txt`` and split into
    lower-cased sentences by ``get_pre_processed_sentences``. Roles are searched in two
    phases (see the assumptions above):

    * phase 1 - sentences containing `` v. ``: every side of the split that mentions the
      party is searched for all known roles;
    * phase 2 - any sentence that mentions the party and exactly one role.

    A role that is mutually exclusive with an already identified role is ignored
    (first discovered wins).
    """

    _PARTY_ROLES = {
        "defendant",
        "plaintiff",
        "counter-defendant",
        "counter-plaintiff",
        "counter-claimant",
        "appellee",
        "appellant",
        "defendant-intervenor",
        "petitioner",
        "respondent",
    }
    # role -> roles it cannot coexist with; counter-* roles conflict with nothing.
    _MUTUALLY_EXCLUSIVE_PARTY_ROLES_MAPPING = {
        "defendant": ["plaintiff", "defendant-intervenor"],
        "plaintiff": ["defendant", "defendant-intervenor"],
        "defendant-intervenor": ["plaintiff", "defendant"],
        "petitioner": ["respondent"],
        "respondent": ["petitioner"],
        "appellee": ["appellant"],
        "appellant": ["appellee"],
        "counter-defendant": [],
        "counter-claimant": [],
        "counter-plaintiff": [],
    }

    def __init__(self, opinion: int, party_letter: str):
        """
        Parameters
        ----------
        opinion : int
            Opinion number; selects the file ``Opinions/Opinion{opinion}.txt``.
        party_letter : str
            Party letter, expected lower-case to match the lower-cased ``party<letter>``
            tokens produced by ``get_pre_processed_sentences``.
        """
        self._opinion = opinion
        self._party_letter = party_letter

        self._sorted_party_roles: List[str] = []
        self._identified_party_roles: Set[str] = set()

        self._set_sorted_party_roles_by_length()

    def predict(self) -> Set[str]:
        """Run both search phases and return the identified roles.

        The returned set is the detector's internal state (not a copy); calling
        ``predict`` again accumulates into and returns the same set.
        """
        with open(f"Opinions/Opinion{self._opinion}.txt", "r") as file:
            sentences = get_pre_processed_sentences(file.read())

        party_token = f"party{self._party_letter}"

        # phase 1: sentences containing " v. "; only the side(s) mentioning the party are searched
        for sentence in sentences:
            if (
                " v. " in sentence
                and party_token in sentence
                and any(role in sentence for role in self._sorted_party_roles)
            ):
                for side in sentence.split(" v. "):
                    if party_token not in side:
                        continue
                    for role in self._sorted_party_roles:
                        regex_pattern = self._get_regex_pattern(role)
                        if re.search(regex_pattern, side):
                            self._add(role)
                            # Remove the matched role so shorter roles it contains
                            # (e.g. "defendant" in "counter-defendant") are not matched again.
                            side = re.sub(regex_pattern, "", side)

        # phase 2: sentences mentioning the party and exactly one role. Roles are tried
        # longest first, so the first match is that single role.
        for sentence in sentences:
            if party_token in sentence and self._get_role_count(sentence) == 1:
                for role in self._sorted_party_roles:
                    if re.search(self._get_regex_pattern(role), sentence):
                        self._add(role)
                        break

        return self._identified_party_roles

    def _add(self, role: str) -> None:
        """Record ``role`` unless it conflicts with an already identified role."""
        if self._is_conflicting_role(role):
            return
        self._identified_party_roles.add(role)

    def _set_sorted_party_roles_by_length(self) -> None:
        # Longest first, so compound roles ("counter-defendant", "defendant-intervenor") are
        # matched and stripped before the shorter roles they contain. Ties are broken
        # alphabetically: ordering straight from the set would depend on per-process string
        # hash randomization. That order decides which of two conflicting equal-length roles
        # ("defendant"/"plaintiff", "petitioner"/"respondent") wins in phase 1, so the results
        # could differ between runs.
        self._sorted_party_roles = sorted(self._PARTY_ROLES, key=lambda role: (-len(role), role))

    def _is_conflicting_role(self, role: str) -> bool:
        """True if ``role`` is mutually exclusive with any role identified so far."""
        return any(
            role in self._MUTUALLY_EXCLUSIVE_PARTY_ROLES_MAPPING[identified_party_role]
            for identified_party_role in self._identified_party_roles
        )

    def _get_role_count(self, s: str) -> int:
        "Count the number of distinct party roles present in a given sentence"
        count = 0
        for role in self._sorted_party_roles:
            regex_pattern = self._get_regex_pattern(role)
            if re.search(regex_pattern, s):
                count += 1
                # strip it so a shorter role contained in it is not counted again
                s = re.sub(regex_pattern, "", s)
        return count

    @staticmethod
    def _get_regex_pattern(role: str) -> str:
        """Whole-word pattern for ``role`` or its plural with a trailing "s"."""
        return r"\b" + role + r"s?\b"

In [6]:
class LegalRepresentationDetector:
    """Predict which law firm(s) represent one party in one opinion.

    Only sentences mentioning a firm token (``firmN``) are considered. If such a sentence
    also mentions any party token (``partyX``), its firms are attributed to this party
    only when this party is among those mentioned. Otherwise, its firms are attributed
    when the sentence mentions one of the roles already identified for this party
    (e.g. by ``LegalPartyRolesDetector``).
    """

    def __init__(self, opinion: int, party_letter: str, identified_party_roles: Set[str]):
        """
        Parameters
        ----------
        opinion : int
            Opinion number; selects the file ``Opinions/Opinion{opinion}.txt``.
        party_letter : str
            Party letter, expected lower-case to match the lower-cased sentences.
        identified_party_roles : Set[str]
            Roles previously identified for this party; used as a fallback signal.
        """
        self._opinion = opinion
        self._party_letter = party_letter
        self._identified_party_roles = identified_party_roles

        self._identified_firms: Set[int] = set()

    def predict(self) -> Set[int]:
        """Return the set of firm numbers attributed to the party (internal set, not a copy)."""
        with open(f"Opinions/Opinion{self._opinion}.txt", "r") as file:
            sentences = get_pre_processed_sentences(file.read())

        party_token = f"party{self._party_letter}"
        # The roles don't change during the loop, so build their patterns once.
        role_patterns = [r"\b" + role + r"s?\b" for role in self._identified_party_roles]

        for sentence in sentences:
            if not re.search(r"firm[1-9]", sentence):
                continue
            if re.search(r"party[a-z]", sentence):
                # if sentence contains party mention then we don't resort to party role usage
                if party_token in sentence:
                    self._identified_firms.update(self._get_firm_numbers(sentence))
            elif any(re.search(pattern, sentence) for pattern in role_patterns):
                self._identified_firms.update(self._get_firm_numbers(sentence))

        return self._identified_firms

    @staticmethod
    def _get_firm_numbers(sentence: str) -> Set[int]:
        """Return the numbers N of all ``firmN`` tokens in ``sentence``."""
        # Raw string: "firm\d+" as a plain literal is an invalid escape (SyntaxWarning on 3.12+).
        return {int(number) for number in re.findall(r"firm(\d+)", sentence)}

### Results

In [7]:
# Annotated ground truth: one row per (Opinion, Party Letter).
df = pd.read_csv(filepath_or_buffer="data.csv")

In [8]:
# Detector inputs: snake_case column names matching the detectors' keyword arguments,
# with party letters lower-cased to match the lower-cased pre-processed sentences.
input_df = (
    df.loc[:, ["Opinion", "Party Letter"]]
    .rename(columns={"Opinion": "opinion", "Party Letter": "party_letter"})
    .assign(party_letter=lambda frame: frame["party_letter"].str.lower())
)

In [9]:
predictions = []

for record in input_df.to_dict(orient="records"):
    # predict legal party roles
    party_roles = LegalPartyRolesDetector(**record).predict()
    record["identified_party_roles"] = party_roles
    # predict law firm(s); the detector takes the roles found above as a fallback signal
    law_firms = LegalRepresentationDetector(**record).predict()

    # transform data representation: NaN when nothing was found, otherwise a joined string.
    # Values are sorted so the output does not depend on set iteration order.
    record["identified_party_roles"] = ", ".join(sorted(party_roles)) if party_roles else np.nan
    record["identified_law_firms"] = (
        ",".join(str(firm) for firm in sorted(law_firms)) if law_firms else np.nan
    )

    predictions.append(record)

In [10]:
# Predictions with the notebook's display column names; "Party Letter" is upper-cased
# again to undo the lower-casing done for input_df, so the merge keys line up with df.
prediction_df = (
    pd.DataFrame(predictions)
    .rename(
        columns={
            "opinion": "Opinion",
            "party_letter": "Party Letter",
            "identified_party_roles": "Party type(s) - Modeled",
            "identified_law_firms": "Law firm(s) - Model",
        }
    )
    .assign(**{"Party Letter": lambda frame: frame["Party Letter"].str.upper()})
)

In [11]:
# Join annotations with predictions. suffixes=("_", "") keeps the freshly predicted
# column names unsuffixed when df already has a column of the same name.
results_df = (
    df.merge(prediction_df, on=["Opinion", "Party Letter"], suffixes=("_", ""))
    .loc[
        :,
        [
            "Opinion",
            "Party Letter",
            "Party type(s) - Annotated",
            "Party type(s) - Modeled",
            "Law firm(s) - Annotated",
            "Law firm(s) - Model",
        ],
    ]
)

In [12]:
# Peek at the first two annotated-vs-predicted rows.
results_df.head(2)

Unnamed: 0,Opinion,Party Letter,Party type(s) - Annotated,Party type(s) - Modeled,Law firm(s) - Annotated,Law firm(s) - Model
0,1,B,defendant,defendant,2,2
1,1,G,defendant,,2,2


In [13]:
def get_rates(results_df: pd.DataFrame, actual_col: str, pred_col: str, separator: str) -> Tuple[int, int, int]:
    """Count true positives, false positives and false negatives over all rows.

    Each cell of ``actual_col`` / ``pred_col`` holds a ``separator``-joined list of labels
    (e.g. ``"defendant, counter-plaintiff"`` or ``"2,3"``) or a missing value meaning
    "no labels". Labels are compared per row as exact strings after stripping whitespace.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Must contain ``actual_col`` and ``pred_col``. In this notebook these are
        "Party type(s) - Annotated" / "Party type(s) - Modeled" or
        "Law firm(s) - Annotated" / "Law firm(s) - Model".
    actual_col : str
        Column with the annotated (ground-truth) labels.
    pred_col : str
        Column with the predicted labels.
    separator : str
        String separating labels within one cell.

    Returns
    -------
    Tuple[int, int, int]
        ``(tp, fp, fn)``: annotated labels that were predicted, predicted labels that were
        not annotated, and annotated labels that were not predicted.
    """

    def _labels(value) -> List[str]:
        # Missing cells (NaN/None) mean "no labels". The old `type(x) == float` test also
        # treated real floats as missing and crashed on ints. pd.read_csv yields such values
        # for a column of single firm numbers (float64 when NaNs are present, int64 otherwise).
        if pd.isna(value):
            return []
        if isinstance(value, float) and value.is_integer():
            value = int(value)  # 2.0 -> "2", so it matches the predicted "2"
        return [label.strip() for label in str(value).split(separator) if label.strip()]

    tp, fp, fn = 0, 0, 0

    for record in results_df.to_dict(orient="records"):
        actuals = _labels(record[actual_col])
        preds = _labels(record[pred_col])

        for actual in actuals:
            if actual in preds:
                tp += 1
            else:
                fn += 1
        fp += sum(1 for pred in preds if pred not in actuals)

    return tp, fp, fn

In [14]:
def get_precision(tp: int, fp: int) -> float:
    """Precision tp / (tp + fp); 0.0 when nothing was predicted (tp + fp == 0)."""
    predicted = tp + fp
    return tp / predicted if predicted else 0.0


def get_recall(tp: int, fn: int) -> float:
    """Recall tp / (tp + fn); 0.0 when nothing was annotated (tp + fn == 0)."""
    annotated = tp + fn
    return tp / annotated if annotated else 0.0


def get_f1_score(tp: int, fp: int, fn: int) -> float:
    """Harmonic mean of precision and recall; 0.0 when both are 0 (e.g. tp == 0)."""
    precision = get_precision(tp, fp)
    recall = get_recall(tp, fn)
    # Without this guard tp == 0 raises ZeroDivisionError instead of scoring 0.
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

### Party types

In [15]:
# Party types are joined with ", " by the prediction loop (and in the annotations).
actual_col = "Party type(s) - Annotated"
pred_col = "Party type(s) - Modeled"
separator = ", "

tp, fp, fn = get_rates(results_df, actual_col, pred_col, separator)

In [16]:
# Headline metrics for party types, from the tp/fp/fn counted in the previous cell.
precision, recall = get_precision(tp, fp), get_recall(tp, fn)
f1 = get_f1_score(tp, fp, fn)

print(f"Precision: {round(precision, 2)}\nRecall: {round(recall, 2)}\nF1 score: {round(f1, 2)}")

Precision: 0.94
Recall: 0.92
F1 score: 0.93


### Law firms

In [17]:
# Firm numbers are joined with "," (no space) by the prediction loop.
actual_col = "Law firm(s) - Annotated"
pred_col = "Law firm(s) - Model"
separator = ","

tp, fp, fn = get_rates(results_df, actual_col, pred_col, separator)

In [18]:
# Headline metrics for law firms, from the tp/fp/fn counted in the previous cell.
precision, recall = get_precision(tp, fp), get_recall(tp, fn)
f1 = get_f1_score(tp, fp, fn)

print(f"Precision: {round(precision, 2)}\nRecall: {round(recall, 2)}\nF1 score: {round(f1, 2)}")

Precision: 1.0
Recall: 0.82
F1 score: 0.9
