In [None]:
"""
This notebook performs a brief performance analysis of the validation metrics produced by
the topic models using the baseline preprocessing routine.
"""

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
%matplotlib inline

# Directory holding one validation-metrics CSV per input dataset.
# Set the LEGAL_TOPIC_VALIDATION_DIR environment variable to run this notebook
# on another machine; the default is the original author's local path.
results_header = os.environ.get(
    'LEGAL_TOPIC_VALIDATION_DIR',
    '/Users/jhamer90811/Documents/Insight/legal_topic_modeling/validation_output/baseline_12k',
)

# Input datasets to compare; each is read from `<results_header>/<name>.csv`.
datasets = ['random_cases2', 'cases_after1950_12k', 'cases_IL_12k', 'cases_IL_after1950_12k']



In [None]:
# Load each dataset's validation metrics and tag every row with the dataset's
# name so the plots below can group/colour by it.
# DataFrame.append was removed in pandas 2.0, and appending inside a loop is
# quadratic anyway: collect the frames first, then concatenate once.
frames = [
    pd.read_csv(os.path.join(results_header, dataset + '.csv')).assign(name=dataset)
    for dataset in datasets
]
# pd.concat raises on an empty list; fall back to an empty frame in that case.
data = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

In [None]:
# Inspect the combined validation-metrics table (all datasets, tagged by the `name` column).
data

In [None]:
# Train perplexity
# x/y are passed as keywords: seaborn >= 0.12 rejects positional data variables.
ax = sns.lineplot(data=data, x='num_topics', y='train_perplexity', hue='name')
ax.set_title('Train Perplexity vs. Number of Topics; by Input Dataset')
plt.show()

In [None]:
# Test perplexity
# x/y are passed as keywords: seaborn >= 0.12 rejects positional data variables.
ax = sns.lineplot(data=data, x='num_topics', y='test_perplexity', hue='name')
ax.set_title('Test Perplexity vs. Number of Topics; by Input Dataset')
plt.show()

In [None]:
# Train coherence
# x/y are passed as keywords: seaborn >= 0.12 rejects positional data variables.
ax = sns.lineplot(data=data, x='num_topics', y='train_coherence', hue='name')
ax.set_title('Train Coherence vs. Number of Topics; by Input Dataset')
plt.show()

In [None]:
# Test coherence
# x/y are passed as keywords: seaborn >= 0.12 rejects positional data variables.
ax = sns.lineplot(data=data, x='num_topics', y='test_coherence', hue='name')
ax.set_title('Test Coherence vs. Number of Topics; by Input Dataset')
plt.show()

In [None]:
# Min citation distance
# x/y are passed as keywords: seaborn >= 0.12 rejects positional data variables.
ax = sns.lineplot(data=data, x='num_topics', y='min_cite_dist_mean', hue='name')
ax.set_title('Min Citation Distance vs. Number of Topics; by Input Dataset')
plt.show()

In [None]:
# Average citation distance
# x/y are passed as keywords: seaborn >= 0.12 rejects positional data variables.
ax = sns.lineplot(data=data, x='num_topics', y='avg_cite_dist_mean', hue='name')
ax.set_title('Avg Citation Distance vs. Number of Topics; by Input Dataset')
plt.show()

In [None]:
# Maximum citation distance
# x/y are passed as keywords: seaborn >= 0.12 rejects positional data variables.
ax = sns.lineplot(data=data, x='num_topics', y='max_cite_dist_mean', hue='name')
ax.set_title('Max Citation Distance vs. Number of Topics; by Input Dataset')
plt.show()

In [None]:
# Use if stdev bars are desired

def plot_with_error_bars(df, x, y, y_err, group_by):
    """Plot ``y`` against ``x`` for each group, with a shaded ``y ± y_err`` band.

    Parameters
    ----------
    df : pandas.DataFrame
        Table containing the ``x``, ``y``, ``y_err`` and ``group_by`` columns.
    x : str
        Column plotted on the horizontal axis.
    y : str
        Column drawn as the line.
    y_err : str
        Column giving the half-width of the shaded band around ``y``.
    group_by : str
        Column whose values split ``df`` into separately coloured, labelled lines.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    fig, ax = plt.subplots()

    for key, group in df.groupby(group_by)[[x, y, y_err]]:
        # Sort by x so the line and band are drawn left-to-right; unsorted rows
        # would otherwise give a zig-zag line and a self-intersecting band.
        group = group.sort_values(x)
        line, = ax.plot(group[x], group[y], label=key)
        # Reuse the line's colour so each band visibly belongs to its line,
        # instead of relying on matplotlib's separate fill colour cycle.
        ax.fill_between(
            group[x],
            group[y] - group[y_err],
            group[y] + group[y_err],
            color=line.get_color(),
            alpha=0.1,
        )

    ax.set_xlabel(x)
    ax.set_ylabel(y)
    ax.legend()
    ax.set_title(f'{y} vs. {x}; by {group_by}')

    plt.show()

In [None]:
# Min citation distance: min_cite_dist_mean with a shaded ± min_cite_dist_sd band, per dataset
plot_with_error_bars(
    data,
    x='num_topics',
    y='min_cite_dist_mean',
    y_err='min_cite_dist_sd',
    group_by='name',
)

In [None]:
# Average citation distance: avg_cite_dist_mean with a shaded ± avg_cite_dist_sd band, per dataset
plot_with_error_bars(
    data,
    x='num_topics',
    y='avg_cite_dist_mean',
    y_err='avg_cite_dist_sd',
    group_by='name',
)

In [None]:
# Maximum citation distance: max_cite_dist_mean with a shaded ± max_cite_dist_sd band, per dataset
plot_with_error_bars(
    data,
    x='num_topics',
    y='max_cite_dist_mean',
    y_err='max_cite_dist_sd',
    group_by='name',
)