In [None]:
"""
In this notebook we train, validate, and visualize different topic models for the basic_stopwords
preprocessing routine. The chosen dataset consists of all cases appearing in the citation graph
whose jurisdiction is Illinois and whose decision dates occurred after 1950.
"""

import pandas as pd

%matplotlib inline
from wordcloud import WordCloud
import matplotlib.pyplot as plt
import os
import time

from gensim.corpora.dictionary import Dictionary
from gensim.models.ldamodel import LdaModel
from gensim.models import CoherenceModel

import networkx as nx

from sklearn.neighbors import NearestNeighbors

# Configure root logging: timestamped records, ERROR level and above only.
import logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.ERROR)

# Suppress DeprecationWarning and FutureWarning output for the whole notebook.
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# NOTE(review): every path below is an absolute, user-specific location; the
# notebook will not run on another machine without editing them.

# Directory containing the preprocessed CSVs; files are read as
# '<dataset>_processed.csv' from here.
processed_data_header = '/Users/jhamer90811/Documents/Insight/legal_topic_modeling/data_uncompressed/basic_stopwords_42k'

# Dataset name stems to validate (one validation run per entry).
datasets = ['cases_IL_after1950_42k']

# Pickled citation graph loaded with networkx.
graph_path = '/Users/jhamer90811/Documents/Insight/legal_topic_modeling/citation_graph.gpickle'

# Directory that receives one '<dataset>.csv' of validation metrics per dataset.
output_header = '/Users/jhamer90811/Documents/Insight/legal_topic_modeling/validation_output/basic_stopwords_42k'

# Parent directory for word cloud images; a per-dataset subdirectory is created beneath it.
wordcloud_header_top = '/Users/jhamer90811/Documents/Insight/legal_topic_modeling/wordclouds/basic_stopwords_42k'




In [None]:
def parse_list_col(df, col_to_parse):
    """Turn a column of stringified token lists into real Python lists, in place.

    Each cell is expected to be a string such as ``"['tok1', 'tok2']"``. The
    surrounding brackets are removed, the remainder is split on commas, and
    whitespace plus single quotes are stripped from every token.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to modify; its ``col_to_parse`` column is replaced.
    col_to_parse : str
        Name of the column holding the stringified lists.

    Returns
    -------
    None
        The frame is mutated in place.

    Notes
    -----
    Tokens that themselves contain commas cannot be recovered by this
    comma-splitting approach.
    """
    def _parse_cell(raw):
        inner = raw.strip('[]')
        if not inner.strip():
            # "[]" must become an empty document, not a document holding the
            # single empty token [''].
            return []
        return [token.strip().strip("'") for token in inner.split(',')]

    # Replace the whole column instead of writing through .loc[:, col]: the
    # new values are lists, so they must not be coerced into the dtype of the
    # original string column.
    df[col_to_parse] = df[col_to_parse].apply(_parse_cell)
    
# HELPER FUNCTIONS FOR PERPLEXITY, COHERENCE, AND WORDCLOUDS
def get_perplexity(model, corpus):
    """Return 2 raised to the negated ``model.log_perplexity(corpus)`` value."""
    log_bound = model.log_perplexity(corpus)
    return 2 ** (-log_bound)

def get_coherence(model, texts, dictionary):
    """Return the 'c_v' coherence score of ``model``.

    ``texts`` must be the tokenized documents (lists of terms), not their
    bag-of-words representation. ``dictionary`` is passed through to the
    CoherenceModel unchanged.
    """
    scorer_options = dict(
        model=model,
        texts=texts,
        dictionary=dictionary,
        coherence='c_v',
    )
    scorer = CoherenceModel(**scorer_options)
    return scorer.get_coherence()

def get_wordclouds(model, num_words=250, save_file=None, num_topics=10):
    """Render one word cloud per topic returned by ``model.show_topics``.

    Parameters
    ----------
    model :
        Topic model exposing ``show_topics(num_topics, num_words, formatted)``
        that yields ``(topic_id, [(word, weight), ...])`` pairs.
    num_words : int
        Number of words requested per topic.
    save_file : str or None
        Directory in which to write ``topic_<id+1>.png``. When falsy, each
        figure is displayed instead of saved.
    num_topics : int
        Number of topics requested from the model.

    Notes
    -----
    Files are named after the topic id reported by the model rather than the
    position in the returned list, so an image stays attached to the right
    topic even when only a subset of the model's topics is returned.
    """
    topics = model.show_topics(num_topics=num_topics, num_words=num_words, formatted=False)
    for topic_id, word_weights in topics:
        frequencies = dict(word_weights)

        wordcloud = WordCloud(width=800, height=800,
                              background_color='white',
                              min_font_size=10).generate_from_frequencies(frequencies)

        # Plot the WordCloud image on its own figure so it can be closed explicitly.
        fig = plt.figure(figsize=(8, 8), facecolor=None)
        plt.imshow(wordcloud)
        plt.axis("off")
        plt.tight_layout(pad=0)
        if save_file:
            path = os.path.join(save_file, f'topic_{topic_id+1}.png')
            fig.savefig(path)
        else:
            plt.show()
        # Release the figure in both branches; many topics would otherwise
        # leave many open figures behind.
        plt.close(fig)

# HELPER FUNCTIONS FOR CITATION-GRAPH KNN VALIDATION

def unpack_topics(df, num_topics):
    """Expand sparse per-document topic vectors into one dense column per topic.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide a ``case_id`` column and a ``topic_vector`` column whose
        cells are iterables of ``(topic_number, value)`` pairs.
    num_topics : int
        Columns ``topic_0`` .. ``topic_{num_topics-1}`` are always present in
        the result.

    Returns
    -------
    pandas.DataFrame
        One row per input row with ``case_id`` (as int) followed by the topic
        columns; topics missing from a row's vector are filled with 0. Topic
        numbers >= ``num_topics`` that occur in the data are kept as extra
        trailing columns.
    """
    expected_cols = ['case_id'] + [f'topic_{i}' for i in range(num_topics)]

    # Collect plain dict records and build the frame once. DataFrame.append
    # no longer exists in pandas 2.x, and appending row by row copies the
    # whole frame on every iteration.
    records = []
    for case_id, topics in zip(df['case_id'], df['topic_vector']):
        record = {'case_id': case_id}
        for topic_num, topic_val in topics:
            record[f'topic_{topic_num}'] = topic_val
        records.append(record)

    new_df = pd.DataFrame(records)
    extra_cols = [col for col in new_df.columns if col not in expected_cols]
    # reindex guarantees every expected topic column exists (even for empty
    # input) and fixes the column order.
    new_df = new_df.reindex(columns=expected_cols + extra_cols)
    new_df = new_df.fillna(0)
    new_df['case_id'] = new_df['case_id'].astype(int)
    return new_df

def get_nearest_neighbors(df, n_neighbors, nn_model):
    """Add columns ``nn_0`` .. ``nn_{n_neighbors-1}`` holding neighbour case ids.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a ``case_id`` column, modified in place. Its row order must
        match the row order of the matrix ``nn_model`` was fitted on.
    n_neighbors : int
        Number of neighbours to record per row.
    nn_model :
        Fitted neighbour model; ``kneighbors(n_neighbors=..., return_distance=False)``
        is called without a query matrix.

    Returns
    -------
    None
    """
    knearest = nn_model.kneighbors(n_neighbors=n_neighbors, return_distance=False)
    # kneighbors reports row positions, not index labels, so resolve case ids
    # positionally. Label-based lookup only worked when df had a default
    # 0..n-1 index.
    case_ids = df['case_id'].to_numpy()
    for k in range(n_neighbors):
        df[f'nn_{k}'] = [case_ids[neighbor_positions[k]] for neighbor_positions in knearest]
        
def edge_length(row, k, graph):
    """Return the shortest-path length in ``graph`` between a row's case and its k-th neighbour.

    ``row`` must provide ``case_id`` and ``nn_<k>`` entries; both are used as
    node identifiers in ``graph``.
    """
    source_case = row['case_id']
    neighbor_case = row[f'nn_{k}']
    return nx.shortest_path_length(graph, source_case, neighbor_case)

def get_edge_lengths(df, n_neighbors, graph):
    """Add ``cite_distance_<k>`` columns to ``df`` for every neighbour rank k.

    Each column holds, row by row, the result of ``edge_length`` for that
    rank. ``df`` is modified in place and nothing is returned.
    """
    for neighbor_rank in range(n_neighbors):
        distances = df.apply(edge_length, axis=1, k=neighbor_rank, graph=graph)
        df[f'cite_distance_{neighbor_rank}'] = distances
        
def get_min_cite_dist(row, n_neighbors):
    """Return, as an int, the smallest of the row's ``cite_distance_<k>`` values."""
    distance_cols = [f'cite_distance_{k}' for k in range(n_neighbors)]
    smallest = row[distance_cols].min()
    return int(smallest)

def get_mean_cite_dist(row, n_neighbors):
    """Return the mean of the row's ``cite_distance_<k>`` values."""
    distance_cols = [f'cite_distance_{k}' for k in range(n_neighbors)]
    distances = row[distance_cols]
    return distances.mean()

def get_max_cite_dist(row, n_neighbors):
    """Return, as an int, the largest of the row's ``cite_distance_<k>`` values."""
    distance_cols = [f'cite_distance_{k}' for k in range(n_neighbors)]
    largest = row[distance_cols].max()
    return int(largest)

def knn_citation_validation(test_ids, lda_model, test_corpus, graph, n_neighbors):
    """Measure how close topic-space neighbours are in the citation graph.

    For every test case that is a node of ``graph``, the ``n_neighbors``
    nearest cases in topic space are found and the shortest-path length to
    each of them in ``graph`` is computed.

    Parameters
    ----------
    test_ids : list
        Case ids, in the same order as ``test_corpus``.
    lda_model :
        Model indexed as ``lda_model[bow]`` to obtain a document's topic vector.
    test_corpus : iterable
        Bag-of-words documents.
    graph :
        Citation graph whose nodes are case ids.
    n_neighbors : int
        Number of nearest neighbours per case.

    Returns
    -------
    pandas.DataFrame
        Per-case topic columns, ``nn_<k>`` and ``cite_distance_<k>`` columns,
        plus ``min_cite_dist``, ``mean_cite_dist`` and ``max_cite_dist``.
    """
    cases = pd.DataFrame(test_ids, columns=['case_id'])
    cases['topic_vector'] = [lda_model[bow] for bow in test_corpus]

    # Keep only the cases that exist as nodes in the citation graph.
    graph_nodes = list(graph.nodes)
    in_graph = cases.case_id.isin(graph_nodes)
    cases = cases.loc[in_graph, :].reset_index(drop=True)

    # NOTE(review): the topic column count is fixed at 15 regardless of the
    # model passed in.
    cases = unpack_topics(cases, num_topics=15)

    topic_matrix = cases.drop(columns='case_id').values
    neighbor_index = NearestNeighbors()
    neighbor_index.fit(topic_matrix)

    get_nearest_neighbors(cases, n_neighbors, neighbor_index)
    get_edge_lengths(cases, n_neighbors, graph)

    # Per-row summaries over the cite_distance_<k> columns, added in a fixed order.
    summaries = (
        ('min_cite_dist', get_min_cite_dist),
        ('mean_cite_dist', get_mean_cite_dist),
        ('max_cite_dist', get_max_cite_dist),
    )
    for column_name, summarize in summaries:
        cases[column_name] = cases.apply(summarize, axis=1, n_neighbors=n_neighbors)

    return cases

In [None]:
# Obtain largest connected component of citation graph and other static variables.

# NOTE(review): a .gpickle file is a Python pickle; only load graphs from a
# trusted source.
G = nx.read_gpickle(graph_path)
# connected_components yields components in no guaranteed order, so the
# largest one has to be selected explicitly by node count rather than by
# taking the first element.
largest_component = max(nx.connected_components(G), key=len)
big_subgraph = nx.subgraph(G, largest_component)
G = None
# Seed for shuffling/splitting; list of topic counts to sweep over.
seed = 9
num_topics = [5, 8, 10, 12, 15]

In [None]:
# For every dataset: load and clean the preprocessed opinions, split them
# 80/20 into train/test, then for each topic count train an LDA model and
# record perplexity, coherence and citation-graph KNN metrics, saving word
# clouds along the way. One results CSV is written per dataset.
for dataset in datasets:
    print(f'BEGINNING VALIDATION OF {dataset}...')
    data = pd.read_csv(os.path.join(processed_data_header, dataset + '_processed.csv'))
    # The CSV header was accidentally appended to the data several times; drop
    # those rows. .copy() yields an independent frame so the column
    # assignments below do not write into a slice of the raw read.
    data = data.loc[data.case_id != 'case_id', :].copy()
    data['case_id'] = data['case_id'].astype(int)
    parse_list_col(data, 'opinion')

    # Shuffle data to ensure jurisdictions are mixed properly.

    data = data.sample(frac=1, random_state=seed).reset_index(drop=True)

    # Split into train/test sets.
    # Positional slicing (.iloc) is end-exclusive. Label slicing with .loc is
    # end-inclusive and would place row `split` in BOTH train and test.

    split = int(0.8*data.shape[0])
    train_ops = data['opinion'].iloc[:split]
    test_ops = data['opinion'].iloc[split:]

    # Materialise the token lists once; they are reused for every topic count.
    train_docs = train_ops.to_list()
    test_docs = test_ops.to_list()

    # Build gensim dictionary (from the training documents only)

    op_dictionary = Dictionary(train_docs)
    train_op_corpus = [op_dictionary.doc2bow(op) for op in train_docs]
    test_op_corpus = [op_dictionary.doc2bow(op) for op in test_docs]

    # BEGIN VALIDATION. THIS WILL TAKE SOME TIME.

    train_perplexity = []
    test_perplexity = []
    train_coherence = []
    test_coherence = []
    min_cite_dist_mean = []
    min_cite_dist_sd = []
    avg_cite_dist_mean = []
    avg_cite_dist_sd = []
    max_cite_dist_mean = []
    max_cite_dist_sd = []

    test_ids = data['case_id'].iloc[split:].to_list()
    data = None

    # exist_ok lets the notebook be re-run without failing on directories
    # left behind by an earlier run; missing parents are created too.
    wordcloud_header = os.path.join(wordcloud_header_top, dataset)
    os.makedirs(wordcloud_header, exist_ok=True)

    metric_cols = ['min_cite_dist', 'mean_cite_dist', 'max_cite_dist']

    start = time.time()

    for nt in num_topics:
        iter_start = time.time()
        print(f'Processing model with {nt} topics...')
        temp_time = time.time()
        # Seed the model so that repeated runs give the same topics and metrics.
        lda = LdaModel(train_op_corpus, id2word=op_dictionary, num_topics=nt, random_state=seed)
        print(f'Model training done. Time: {round(time.time()-temp_time)}')
        print('Computing perplexity on train set.')
        temp_time = time.time()
        train_perplexity.append(get_perplexity(lda, train_op_corpus))
        print(f'Done. Time: {round(time.time()-temp_time)}')
        print('Computing perplexity on test set.')
        temp_time = time.time()
        test_perplexity.append(get_perplexity(lda, test_op_corpus))
        print(f'Done. Time: {round(time.time()-temp_time)}')
        print('Computing coherence on train set.')
        temp_time = time.time()
        train_coherence.append(get_coherence(lda, train_docs, op_dictionary))
        print(f'Done. Time: {round(time.time()-temp_time)}')
        print('Computing coherence on test set.')
        temp_time = time.time()
        test_coherence.append(get_coherence(lda, test_docs, op_dictionary))
        print(f'Done. Time: {round(time.time()-temp_time)}')
        print('Computing citation graph validation metrics.')
        temp_time = time.time()
        citation_dist_results = knn_citation_validation(test_ids, lda, test_op_corpus, big_subgraph, 5)[metric_cols]
        min_cite_dist_mean.append(citation_dist_results.min_cite_dist.mean())
        min_cite_dist_sd.append(citation_dist_results.min_cite_dist.std())
        avg_cite_dist_mean.append(citation_dist_results.mean_cite_dist.mean())
        avg_cite_dist_sd.append(citation_dist_results.mean_cite_dist.std())
        max_cite_dist_mean.append(citation_dist_results.max_cite_dist.mean())
        max_cite_dist_sd.append(citation_dist_results.max_cite_dist.std())
        print(f'Done. Time: {round(time.time()-temp_time)}')
        print('Saving wordclouds...')
        temp_time = time.time()
        topic_cloud_dir = os.path.join(wordcloud_header, f'num_topics_{nt}')
        os.makedirs(topic_cloud_dir, exist_ok=True)
        get_wordclouds(lda, save_file=topic_cloud_dir, num_topics=nt)
        print(f'Done. Time: {round(time.time()-temp_time)}')
        print(f'Done with full iteration. TOTAL TIME: {round(time.time()-iter_start)}')
        print('######################################')
    print(f'FINISHED. TOTAL TIME ELAPSED: {time.time()-start}')

    results_df = pd.DataFrame({'num_topics': num_topics,
                              'train_perplexity': train_perplexity,
                              'test_perplexity': test_perplexity,
                              'train_coherence': train_coherence,
                              'test_coherence': test_coherence,
                              'min_cite_dist_mean': min_cite_dist_mean,
                              'min_cite_dist_sd': min_cite_dist_sd,
                              'avg_cite_dist_mean': avg_cite_dist_mean,
                              'avg_cite_dist_sd': avg_cite_dist_sd,
                              'max_cite_dist_mean': max_cite_dist_mean,
                              'max_cite_dist_sd': max_cite_dist_sd})

    # Make sure the output directory exists before writing the metrics table.
    os.makedirs(output_header, exist_ok=True)
    results_df.to_csv(os.path.join(output_header, dataset + '.csv'), index=False)

    # Drop references to the large per-dataset objects before the next dataset.
    train_ops = None
    train_docs = None
    train_op_corpus = None
    test_ops = None
    test_docs = None
    test_op_corpus = None
    op_dictionary = None
    test_ids = None