In [None]:
!pip install bertopic

In [2]:
import pandas as pd
import numpy as np
import string
import json
import re

from bertopic import BERTopic
import gensim.corpora as corpora 
from gensim.corpora.dictionary import Dictionary
from sentence_transformers import SentenceTransformer
from gensim.models.coherencemodel import CoherenceModel
import spacy
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
nltk.download('wordnet')
nltk.download('stopwords')

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


True

# Import Dataset

In [48]:
data = pd.read_json('/content/us_test_data_final_OFFICIAL.jsonl', lines = True)

# Pre-processing

In [4]:
def replace_semicolon(text, threshold=10):
    '''
    Get rid of semicolons.
    First split text into fragments between the semicolons. If the fragment 
    is longer than the threshold, turn the semicolon into a period. O.w treat
    it as a comma.
    Returns new text
    '''
    new_text = ""
    for subset in re.split(';', text):
        subset = subset.strip() # Clear off spaces
        # Check word count
        if len(subset.split()) > threshold:
            # Turn first char into uppercase
            new_text += ". " + subset[0].upper() + subset[1:]
        else:
            # Just append with a comma 
            new_text += ", " + subset

    return new_text

In [5]:
USC_re = re.compile('[Uu]\.*[Ss]\.*[Cc]\.]+')
PAREN_re = re.compile('\([^(]+\ [^\(]+\)')
BAD_PUNCT_RE = re.compile(r'([%s])' % re.escape('"#%&\*\+/<=>@[\]^{|}~_'), re.UNICODE)
BULLET_RE = re.compile('\n[\ \t]*`*\([a-zA-Z0-9]*\)')
DASH_RE = re.compile('--+')
WHITESPACE_RE = re.compile('\s+')
EMPTY_SENT_RE = re.compile('[,\.]\ *[\.,]')
FIX_START_RE = re.compile('^[^A-Za-z]*')
FIX_PERIOD = re.compile('\.([A-Za-z])')
SECTION_HEADER_RE = re.compile('SECTION [0-9]{1,2}\.|\nSEC\.* [0-9]{1,2}\.|Sec\.* [0-9]{1,2}\.')

FIX_PERIOD = re.compile('\.([A-Za-z])')

SECTION_HEADER_RE = re.compile('SECTION [0-9]{1,2}\.|\nSEC\.* [0-9]{1,2}\.|Sec\.* [0-9]{1,2}\.')

In [6]:
def clean_text(text):
    """
    Borrowed from the FNDS text processing with additional logic added in.
    Note: we do not take care of token breaking - assume SPACY's tokenizer
    will handle this for us.
    """

    # Indicate section headers, we need them for features
    text = SECTION_HEADER_RE.sub('SECTION-HEADER', text)
    # For simplicity later, remove '.' from most common acronym
    text = text.replace("U.S.", "US")
    text = text.replace('SEC.', 'Section')
    text = text.replace('Sec.', 'Section')
    text = USC_re.sub('USC', text)

    # Remove parantheticals because they are almost always references to laws 
    # We could add a special tag, but we just remove for now
    # Note we dont get rid of nested parens because that is a complex re
    #text = PAREN_re.sub('LAWREF', text)
    text = PAREN_re.sub('', text)
    
    # Get rid of enums as bullets or ` as bullets
    text = BULLET_RE.sub(' ',text)
    
    # Clean html 
    text = text.replace('&lt;all&gt;', '')

    # Remove annoying punctuation, that's not relevant
    text = BAD_PUNCT_RE.sub('', text)

    # Get rid of long sequences of dashes - these are formating
    text = DASH_RE.sub( ' ', text)

    # removing newlines, tabs, and extra spaces.
    text = WHITESPACE_RE.sub(' ', text)
    
    # If we ended up with "empty" sentences - get rid of them.
    text = EMPTY_SENT_RE.sub('.', text)
    
    # Attempt to create sentences from bullets 
    text = replace_semicolon(text)
    
    # Fix weird period issues + start of text weirdness
    #text = re.sub('\.(?=[A-Z])', '  . ', text)
    # Get rid of anything thats not a word from the start of the text
    text = FIX_START_RE.sub( '', text)
    # Sometimes periods get formatted weird, make sure there is a space between periods and start of sent   
    text = FIX_PERIOD.sub(". \g<1>", text)

    # Fix quotes
    text = text.replace('``', '"')
    text = text.replace('\'\'', '"')

    # Add special punct back in
    text = text.replace('SECTION-HEADER', '')

    # Remove punctuations
    text = re.sub('[,\.!?()]', '', text)

    # Return lowercased sentences
    text = text.lower()

    return text

In [49]:
data['clean_text'] = data.text.map(clean_text)
# data['clean_summary'] = data.summary.map(clean_text)
# data['clean_title'] = data.title.map(clean_text)

selected = pd.DataFrame()
# selected['text'] = data.clean_summary
selected['text'] = data.clean_text

### Stopwords & Lemmatization

In [50]:
def lemmatize_df(corpus):
    tokens = corpus.split(' ')
    stop_words = stopwords.words('english')
    stop_words.extend(['section','shall', 'act', 'secretary', 'subsection', 'year','may', 'state', 'paragraph', 'program', 'short', 'title', 'cite'])
    filtered = [word for word in tokens if word not in stop_words]

    return " ".join(filtered)

In [53]:
import spacy
nlp = spacy.load("en_core_web_sm", disable=['parser', 'ner'])

selected['text'] = selected['text'].apply(lambda x: re.sub("[0-9]", "", x))
selected['clean_text'] = selected['text'].map(lemmatize_df)
selected['clean_text'] = selected['clean_text'].apply(lambda x: " ".join([w.lemma_ for w in nlp(x)]))
selected['clean_text'] = selected['clean_text'].apply(lambda x: " ".join([ele for ele in x.split() if len(ele) > 1]))

In [54]:
selected.head() 

Unnamed: 0,text,clean_text
0,"short title this act may be cited as the ""nat...",cite national science education tax incentive ...
1,"short title this act may be cited as the ""sma...",cite small business expansion hire business cr...
2,release of documents captured in iraq and afg...,release document capture iraq afghanistan gene...
3,"short title this act may be cited as the ""nat...",cite national cancer finding congress make fol...
4,"short title this act may be cited as the ""mil...",cite military call up relief act waiver early ...


# Baseline BERTopic

In [11]:
model = BERTopic(verbose=True)
topics, probabilities = model.fit_transform(selected['clean_text'])

Downloading:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/10.2k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/349 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/350 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/13.2k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Batches:   0%|          | 0/103 [00:00<?, ?it/s]

2022-04-15 05:01:19,525 - BERTopic - Transformed documents to Embeddings
2022-04-15 05:01:47,259 - BERTopic - Reduced dimensionality with UMAP
2022-04-15 05:01:47,468 - BERTopic - Clustered UMAP embeddings with HDBSCAN


In [12]:
len(model.get_topic_info())

72

In [None]:
len(data)

In [13]:
model.get_topic_info()

Unnamed: 0,Topic,Count,Name
0,-1,1192,-1_service_insert_states_amend
1,0,81,0_hospital_medicare_health_service
2,1,78,1_taxable_tax_corporation_income
3,2,78,2_consumer_institution_depository_transaction
4,3,76,3_transportation_motor_vehicle_carrier
...,...,...,...
67,66,12,66_woman_abortion_pregnancy_pregnant
68,67,12,67_unemployment_compensation_short_week
69,68,12,68_district_columbia_puerto_rico
70,69,11,69_energy_earth_rare_technology


In [14]:
def display_topics(model,top_n):
    for i in range(0,top_n):
        print("Topic %d:" % (i), "(%d documents)" %(model.get_topic_freq(i)))
        words = []
        for tup in model.get_topic(i):
            words.append(tup[0])
        print(words)
        print("="*150)
        print()

display_topics(model, 11)

Topic 0: (81 documents)
['hospital', 'medicare', 'health', 'service', 'care', 'physician', 'payment', 'medical', 'home', 'nurse']

Topic 1: (78 documents)
['taxable', 'tax', 'corporation', 'income', 'amount', 'property', 'code', 'gain', 'apply', 'dividend']

Topic 2: (78 documents)
['consumer', 'institution', 'depository', 'transaction', 'account', 'financial', 'bank', 'fee', 'reserve', 'banking']

Topic 3: (76 documents)
['transportation', 'motor', 'vehicle', 'carrier', 'rail', 'safety', 'equipment', 'passenger', 'railroad', 'highway']

Topic 4: (63 documents)
['water', 'project', 'administrator', 'pollution', 'river', 'environmental', 'lake', 'federal', 'system', 'basin']

Topic 5: (62 documents)
['disaster', 'emergency', 'fire', 'hurricane', 'building', 'response', 'firefighter', 'federal', 'agency', 'volunteer']

Topic 6: (60 documents)
['resolution', 'budget', 'house', 'bill', 'joint', 'motion', 'fiscal', 'appropriation', 'senate', 'congress']

Topic 7: (55 documents)
['employee',

### Coherence Score

In [15]:
cv = model.vectorizer_model
X = cv.fit_transform(selected['clean_text'])
doc_tokens = [text.split(" ") for text in selected['clean_text']]

In [16]:
id2word = corpora.Dictionary(doc_tokens)
texts = doc_tokens
corpus = [id2word.doc2bow(text) for text in texts]

topic_words = []
for i in range(len(model.get_topic_freq())-1):
  interim = []
  interim = [t[0] for t in model.get_topic(i)]
  topic_words.append(interim)

coherence_model = CoherenceModel(topics=topic_words, 
                                 texts=texts, 
                                 corpus=corpus, 
                                 dictionary=id2word, 
                                 coherence='c_v')

baseline_coherence = coherence_model.get_coherence()
baseline_coherence

0.6672959126365035

Unfortunately, BERTopic does not have perplexity and hence will not be able to calculate and compare to the LDA in terms of perplexity score

In [17]:
# log_perplexity = model.log_perplexity(tdf)
# perplexity = 2**(-log_perplexity)
# print('Perplexity: ',perplexity)

### Visualizations

In [18]:
model.visualize_topics()

In [19]:
model.visualize_barchart(topics=[0,1,2,3,4,5,6,7,8,9,10,11])

In [20]:
model.visualize_heatmap()

In [21]:
model.visualize_hierarchy()

In [22]:
model.visualize_term_rank()

# Topic Reduced BERT 

In [23]:
new_topics, new_probs = model.reduce_topics(selected['clean_text'], topics, probabilities, nr_topics="auto")

2022-04-15 05:04:23,599 - BERTopic - Reduced number of topics from 72 to 48


In [24]:
# display_topics(model,20)

In [25]:
model.get_topic_info()

Unnamed: 0,Topic,Count,Name
0,-1,1192,-1_service_states_insert_united
1,0,156,0_united_states_country_iran
2,1,137,1_child_plan_benefit_health
3,2,124,2_vehicle_motor_transportation_fuel
4,3,110,3_drug_prescription_food_product
5,4,85,4_land_area_wilderness_park
6,5,84,5_employee_employer_employment_labor
7,6,81,6_hospital_health_service_medicare
8,7,80,7_energy_nuclear_technology_electric
9,8,78,8_taxable_tax_corporation_income


### Coherence Score of Auto Topic Reduction model

In [26]:
topic_words = []
for i in range(len(model.get_topic_freq())-1):
  interim = []
  interim = [t[0] for t in model.get_topic(i)]
  topic_words.append(interim)

coherence_model = CoherenceModel(topics=topic_words, 
                                 texts=texts, 
                                 corpus=corpus, 
                                 dictionary=id2word, 
                                 coherence='c_v')

reduced_coherence = coherence_model.get_coherence()
reduced_coherence

0.684372498253919

### Visualizations

In [27]:
model.visualize_topics()

In [28]:
model.visualize_barchart([0,1,2,3,4,5,6,7,8,9,10,11])

In [29]:
model.visualize_heatmap()

In [30]:
model.visualize_hierarchy()

In [31]:
model.visualize_term_rank()

# Embedding

In [32]:
sentence_model = SentenceTransformer("nlpaueb/legal-bert-small-uncased")
embeddings = sentence_model.encode(selected['clean_text'], show_progress_bar=True)

em_model = BERTopic(calculate_probabilities=True, nr_topics="auto")
topics, probabilities = em_model.fit_transform(selected['clean_text'], embeddings)

Downloading:   0%|          | 0.00/391 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/11.5k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/989 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/141M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/222k [00:00<?, ?B/s]

Some weights of the model checkpoint at /root/.cache/torch/sentence_transformers/nlpaueb_legal-bert-small-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Batches:   0%|          | 0/103 [00:00<?, ?it/s]

2022-04-15 05:32:17,228 - BERTopic - Reduced dimensionality with UMAP
2022-04-15 05:32:18,146 - BERTopic - Clustered UMAP embeddings with HDBSCAN
2022-04-15 05:32:24,837 - BERTopic - Reduced number of topics from 67 to 44


In [33]:
em_model.get_topic_info()

Unnamed: 0,Topic,Count,Name
0,-1,1096,-1_follow_amend_states_provide
1,0,1000,0_commission_individual_agency_education
2,1,89,1_united_states_country_president
3,2,78,2_health_disease_cancer_care
4,3,67,3_service_computer_commission_internet
5,4,58,4_land_conveyance_tribe_convey
6,5,55,5_election_political_candidate_voter
7,6,52,6_water_waste_administrator_environmental
8,7,50,7_wildlife_river_refuge_specie
9,8,47,8_alien_immigration_nationality_status


In [34]:
em_model.get_topic(9)

[('oil', 0.04259274175493743),
 ('gas', 0.040616692340830436),
 ('fuel', 0.038248139404786705),
 ('energy', 0.032313202735513843),
 ('reliability', 0.03050293051922194),
 ('natural', 0.028450870173541264),
 ('electric', 0.027968909804451762),
 ('gasoline', 0.025838492337621068),
 ('power', 0.022992397172504366),
 ('price', 0.02087260253769079)]

In [None]:
em_model.get_representative_docs(0)

### Coherence Score of Embedded Model

In [36]:
topic_words = []
for i in range(len(em_model.get_topic_freq())-1):
  interim = []
  interim = [t[0] for t in em_model.get_topic(i)]
  topic_words.append(interim)

coherence_model = CoherenceModel(topics=topic_words, 
                                 texts=texts, 
                                 corpus=corpus, 
                                 dictionary=id2word, 
                                 coherence='c_v')

embedded_coherence = coherence_model.get_coherence()
embedded_coherence

0.6607821234514816

## Visualizations

In [37]:
em_model.visualize_topics()

In [38]:
display_topics(em_model, 9)

Topic 0: (1000 documents)
['commission', 'individual', 'agency', 'education', 'amend', 'employee', 'code', 'amount', 'service', 'term']

Topic 1: (89 documents)
['united', 'states', 'country', 'president', 'foreign', 'government', 'iran', 'nations', 'sudan', 'sanction']

Topic 2: (78 documents)
['health', 'disease', 'cancer', 'care', 'treatment', 'patient', 'research', 'service', 'woman', 'center']

Topic 3: (67 documents)
['service', 'computer', 'commission', 'internet', 'carrier', 'broadband', 'video', 'communication', 'provider', 'television']

Topic 4: (58 documents)
['land', 'conveyance', 'tribe', 'convey', 'county', 'right', 'native', 'exchange', 'parcel', 'federal']

Topic 5: (55 documents)
['election', 'political', 'candidate', 'voter', 'ballot', 'vote', 'travel', 'official', 'communication', 'office']

Topic 6: (52 documents)
['water', 'waste', 'administrator', 'environmental', 'project', 'material', 'municipal', 'beverage', 'sediment', 'solid']

Topic 7: (50 documents)
['wild

In [39]:
em_model.visualize_hierarchy()

In [46]:
em_model.visualize_barchart(topics=[0,1,2,3,4,5,6,7,8,9,10,11])

In [41]:
em_model.visualize_heatmap()

In [42]:
em_model.visualize_term_rank()

#Comparison of models

In [43]:
def compute_coherence_values(dictionary, corpus, texts, limit, start=2, step=3):
    """
    Compute c_v coherence for various number of topics

    Parameters:
    ----------
    dictionary : Gensim dictionary
    corpus : Gensim corpus
    texts : List of input texts
    limit : Max num of topics

    Returns:
    -------
    model_list : List of LDA topic models
    coherence_values : Coherence values corresponding to the LDA model with respective number of topics
    """
    coherence_values = []
    model_list = []
    for num_topics in range(start, limit, step):
        model = BERTopic(verbose=True, nr_topics=num_topics)
        topics, probabilities = model.fit_transform(selected['clean_text'])
        model_list.append(model)

        topic_words = []
        for i in range(len(model.get_topic_freq())-1):
            interim = []
            interim = [t[0] for t in model.get_topic(i)]
            topic_words.append(interim)

        coherencemodel = CoherenceModel(topics=topic_words, texts=texts, corpus=corpus, dictionary=dictionary, coherence='c_v')
        coherence_values.append(coherencemodel.get_coherence())

    return model_list, coherence_values


In [44]:
model_list, coherence_values = compute_coherence_values(dictionary=id2word, corpus=corpus, texts=texts, start=2, limit=40, step=6)

# Show graph
import matplotlib.pyplot as plt
limit=40; start=5; step=5;
x = range(start, limit, step)
plt.plot(x, coherence_values)
plt.xlabel("Num Topics")
plt.ylabel("Coherence score")
plt.legend(("coherence_values"), loc='best')
plt.show()

Batches:   0%|          | 0/103 [00:00<?, ?it/s]

2022-04-15 05:41:11,092 - BERTopic - Transformed documents to Embeddings
2022-04-15 05:41:31,194 - BERTopic - Reduced dimensionality with UMAP
2022-04-15 05:41:31,346 - BERTopic - Clustered UMAP embeddings with HDBSCAN
2022-04-15 05:41:39,927 - BERTopic - Reduced number of topics from 68 to 3


Batches:   0%|          | 0/103 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [None]:
for i in range(len(model_list)):
    print('number of topics:',len(model_list[i].get_topic_info()))
    print('coherence score:',coherence_values[i])

In [None]:
print('Baseline BERTopic', baseline_coherence)
print('Automated Reduction BERTopic', reduced_coherence)
print('BERTopic w Legal Embedding', embedded_coherence)