In [236]:
import pandas as pd
import numpy as np

from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from nltk.corpus import wordnet
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from sklearn.preprocessing import LabelEncoder
import string
import re
import matplotlib.pyplot as plt
import itertools
from transformers import AutoTokenizer, AutoModel, AutoModelForTokenClassification, pipeline,BertTokenizer, BertModel
import spacy
spcy = spacy.load("en_core_web_sm")
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from sklearn.model_selection import StratifiedKFold
from bertopic import BERTopic
import torch
import torch.nn as nn
import torch.functional as f
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from bertopic.vectorizers import ClassTfidfTransformer
from bertopic.dimensionality import BaseDimensionalityReduction
from sklearn.linear_model import LogisticRegression
%matplotlib inline

In [10]:
dataset.to_parquet("legal_cases.parquet")

In [223]:
dataset = pd.read_parquet("legal_cases.parquet")

In [76]:
dataset.head()

Unnamed: 0,id,text,domain,clean_text
0,r-e4EYcBD5gMZwcz41zP,UNITED STATES DISTRICT COURT \nEASTERN DISTRIC...,consumer fraud,united states district court eastern district ...
1,i9H5DocBD5gMZwcztj0y,IN THE UNITED STATES DISTRICT COURT \nFOR THE ...,privacy,in the united states district court for the di...
2,SMn3DYcBD5gMZwcz-hwH,IN THE UNITED STATES DISTRICT COURT\n FOR THE ...,privacy,in the united states district court for the we...
3,GMIWDYcBD5gMZwczDQBb,Case No. _______________ \n \n \nCLASS ACTION ...,criminal & enforcement,case no class action complaint for violations ...
4,lELw_IgBF5pVm5zYONwC,UNITED STATES DISTRICT COURT \n SOUTHERN DISTR...,consumer fraud,united states district court southern district...


Clean Text

In [77]:
def run_preprocessing(text):
    text = preprocess(text)
    text = remove_stopwords(text)
    # text = lemmatizing(text)
    # text = remove_double_words(text)
    return text

def preprocess(text):
    text = re.sub(r"(@\[A-Za-z0-9]+)|([^0-9A-Za-z \t])|(\w+:\/\/\S+)|^rt|http.+?", "", text)
    text = text.lower()
    text = text.strip()
    text = re.compile('<.*?>').sub('', text)
    text = re.compile('[%s]' % re.escape(string.punctuation)).sub(' ', text)
    text = re.sub('\s+', ' ', text)
    text = re.sub(r'\[[0-9]*\]', ' ', text)  # [0-9]
    text = re.sub(r'[^\w\s]', '', str(text).lower().strip())
    text = re.sub(r'\d', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    return text

def remove_stopwords(text):
    chars = [i for i in text.split() if i not in stopwords.words('english') and len(i) > 2]
    return ' '.join(chars)

def remove_double_words(text):
    return ' '.join(list(set([i for i in text.split(' ')])))

def lemmatizing(text):
    lemmatizer = WordNetLemmatizer()
    words = nltk.word_tokenize(text)
    chars = [lemmatizer.lemmatize(word) for word in words]

    return " ".join(chars)

In [78]:
dataset['clean_text'] = dataset['text'].apply(run_preprocessing)

In [87]:
def get_most_common_words(df, threshold=0.8):
    row_thresh_count = threshold * len(df)
    row_count = Counter()
    for i, row in df.iterrows():
        words = set(re.findall(r"\w+", row['clean_text']))
        row_count.update(words)

    return [word for word, count in row_count.items() if count >= row_thresh_count]

In [88]:
common_words = get_most_common_words(dataset)

In [92]:
def remove_common_words(common_words, text):
    words = text.split()
    filtered_words = [word for word in words if word not in common_words]
    return " ".join(filtered_words)

In [94]:
dataset['clean_text'] = dataset['clean_text'].apply(lambda x: remove_common_words(common_words, x))

In [159]:
def remove_entities(text, entity_types=['GPE', 'DATE', 'ORDINAL', 'PERSON']):
    doc = spcy(text)
    unique_ents = []
    for ent in doc.ents:
        if str(ent) in unique_ents:
            pass
        else:
            if ent.label_ in entity_types:
                unique_ents.append(str(ent))    
    
    words = text.split()
    filtered_words = [word for word in words if word not in unique_ents]
    return " ".join(filtered_words)

Encode Labels

In [224]:
le = LabelEncoder()
dataset['label'] = le.fit_transform(dataset['domain'])
label_mapping = dict(zip(range(len(le.classes_)), le.classes_))

## Topic Modeling

In [100]:
empty_dimensionality_model = BaseDimensionalityReduction()
clf = LogisticRegression()
ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True)

In [101]:
topic_model= BERTopic(
        umap_model=empty_dimensionality_model,
        hdbscan_model=clf,
        ctfidf_model=ctfidf_model
)
topics, probs = topic_model.fit_transform(dataset['clean_text'].values, y=dataset['label'].values)

In [102]:
mappings = topic_model.topic_mapper_.get_mappings()
mapping = {value: label_mapping[mappings[value]] for value in mappings.keys()}

In [108]:
df = topic_model.get_topic_info()
df["Class"] = df.Topic.map(mapping)

In [110]:
df

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs,Class
0,-1,299,-1_plan_services_new_contact,"[plan, services, new, contact, telephone, heal...",[john p kristensen sbn david l weisberg sbn kr...,
1,0,144,0_company_companys_statements_financial,"[company, companys, statements, financial, sec...",[southern new york andre brodeur individually ...,employment & labor
2,1,87,1_website_blind_access_services,"[website, blind, access, services, visuallyimp...",[southern new york x shael cruz demand embrace...,discrimination
3,2,68,2_market_prices_apple_price,"[market, prices, apple, price, facebook, produ...",[maine portland huntalpine club llc demanded m...,antitrust
4,3,63,3_debt_collection_letter_fdcpa,"[debt, collection, letter, fdcpa, consumer, co...",[eastern new york ouriel ezra demand fh cann a...,intellectual property & communication
5,4,55,4_generic_drug_patent_fda,"[generic, drug, patent, fda, market, vascepa, ...",[eastern virginia alexandria division louisian...,healthcare
6,5,49,5_hours_pay_york_worked,"[hours, pay, york, worked, new, nyll, flsa, em...",[southern new york x jose taveras collective v...,criminal & enforcement
7,6,44,6_telephone_detail_garage_smart,"[telephone, detail, garage, smart, calls, cali...",[todd friedman sbn meghan george sbn adrian r ...,securities
8,7,41,7_flsa_texas_overtime_hours,"[flsa, texas, overtime, hours, worked, pay, em...",[southern texas corpus christi division cassan...,privacy
9,8,33,8_telephone_calls_tcpa_text,"[telephone, calls, tcpa, text, messages, cellu...",[offices ronald marron ronald marron sbn ronco...,products liability and mass tort


In [106]:
topic_model = BERTopic()
topics, _ = topic_model.fit_transform(dataset['clean_text'].values)
dataset['topics'] = topic_model.get_document_info(dataset['clean_text'].values)['Representation']

In [107]:
topic_model.get_document_info(dataset['clean_text'].values)

Unnamed: 0,Document,Topic,Name,Representation,Representative_Docs,Top_n_words,Probability,Representative_document
0,eastern new york chaya r denciger individually...,3,3_debt_collection_letter_fdcpa,"[debt, collection, letter, fdcpa, consumer, co...",[eastern new york ouriel ezra demand fh cann a...,debt - collection - letter - fdcpa - consumer ...,1.000000,False
1,kansas george jones ad astra recovery services...,8,8_telephone_calls_tcpa_text,"[telephone, calls, tcpa, text, messages, cellu...",[offices ronald marron ronald marron sbn ronco...,telephone - calls - tcpa - text - messages - c...,0.569314,False
2,western pennsylvania pittsburgh division mark ...,-1,-1_plan_services_new_contact,"[plan, services, new, contact, telephone, heal...",[john p kristensen sbn david l weisberg sbn kr...,plan - services - new - contact - telephone - ...,0.000000,False
3,racketeer influenced corrupt organizations ohi...,-1,-1_plan_services_new_contact,"[plan, services, new, contact, telephone, heal...",[john p kristensen sbn david l weisberg sbn kr...,plan - services - new - contact - telephone - ...,0.000000,False
4,southern new york cv bobby phillips individual...,-1,-1_plan_services_new_contact,"[plan, services, new, contact, telephone, heal...",[john p kristensen sbn david l weisberg sbn kr...,plan - services - new - contact - telephone - ...,0.000000,False
...,...,...,...,...,...,...,...,...
1199,lee cirsch ca bar lanier firm pc wilshire blvd...,2,2_market_prices_apple_price,"[market, prices, apple, price, facebook, produ...",[maine portland huntalpine club llc demanded m...,market - prices - apple - price - facebook - p...,0.803713,False
1200,southern florida shawn moore demanded alliance...,11,11_debt_collection_florida_mcdonalds,"[debt, collection, florida, mcdonalds, pounder...",[middle florida jaron finklea demanded credenc...,debt - collection - florida - mcdonalds - poun...,0.949810,True
1201,new mexico gerald davis jr individually collec...,7,7_flsa_texas_overtime_hours,"[flsa, texas, overtime, hours, worked, pay, em...",[southern texas corpus christi division cassan...,flsa - texas - overtime - hours - worked - pay...,0.409937,False
1202,colorado x david katt anb bank injunctive decl...,1,1_website_blind_access_services,"[website, blind, access, services, visuallyimp...",[southern new york x shael cruz demand embrace...,website - blind - access - services - visually...,0.497660,False


____________

In [187]:
word_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
vectorization_model = BertModel.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [176]:
tokens = dataset['clean_text'].apply(lambda x: word_tokenize(x)[-500:])

In [179]:
data = [" ".join(i) for i in tokens]

In [188]:
inputs = word_tokenizer(data,
                        padding=True,
                        truncation=True,
                        return_tensors='pt')

with torch.no_grad():
    output = vectorization_model(**inputs)
    
embedding = output.last_hidden_state

In [210]:
def mean_pooling(embeddings, attention_mask):
    token_embeddings = embeddings[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

In [216]:
embeddings = mean_pooling(output, inputs["attention_mask"])

In [217]:
embeddings.shape

torch.Size([1204, 768])

In [218]:
dataset['embedding'] = [i.numpy().tolist() for i in embeddings]

In [191]:
import pickle
file_path = 'dataset.pkl'
with open(file_path, 'wb') as f:
    pickle.dump(dataset, f)

In [229]:
labels = torch.tensor(dataset['label'].values)

____________

In [260]:
class CNNModel(nn.Module):
    def __init__(self, embedding_size, hidden_size, num_classes):
        super(CNNModel, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv1d(1, hidden_size, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Conv1d(hidden_size, hidden_size, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.AdaptiveMaxPool1d(1),
            nn.Flatten()
            )
        
        self.classifier_head = nn.Linear(hidden_size, num_classes)
        
    def forward(self, x):
        x = self.cnn(x)
        x = self.classifier_head(x)
        return x

In [265]:
def train(model, train_loader, optimizer, criterion):
    batch_losses = []
    for idx, batch in enumerate(train_loader):
        optimizer.zero_grad()
        inpt = batch['embedding'].unsqueeze(1)
        output = model(inpt)
        loss = criterion(output, batch['label'])
        batch_losses.append(loss)
        loss.backward()
        optimizer.step()
        
    return batch_losses

In [204]:
class EmbeddingDataset(Dataset):
    def __init__(self, embedding, labels):
        self.embedding = embedding
        self.labels = labels
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        label = self.labels[idx]
        embed = self.embedding[idx]
        return {'label': label, 'embedding': embed}

In [230]:
dset = EmbeddingDataset(embedding=embeddings,
                          labels=labels)

In [231]:
loader = DataLoader(dataset=dset, batch_size=32, shuffle=True)

In [254]:
embeddings.shape

torch.Size([1204, 768])

In [266]:
model = CNNModel(embedding_size=768,
                hidden_size=50,
                num_classes=12)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.1)

In [268]:
current_loss = 0
current_acc = 0
all_losses = []
acc_list = []

for i in range(10):
    loss = train(model, loader, optimizer, criterion) 
    all_losses.append(loss)

In [278]:
def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            inpt = batch['embedding'].unsqueeze(1)
            predicted_label = model(inpt)
            loss = criterion(predicted_label, batch['label'])
            total_acc += (predicted_label.argmax(1) == batch['label']).sum().item()
            total_count += batch['label'].size(0)
    return total_acc / total_count

In [279]:
evaluate(loader)

0.5132890365448505