## Import Libraries

In [1]:
import requests  # for HTTPS requests to URLS
import zipfile  # to un-zip any data files
import pypdf  # to extract text from PDFs (is NOT that good)
import sys, pathlib, fitz  # to extract text from PDFs
import bs4  # HTLM parser to extract data from HTML pages
import math
from sympy import divisors
import pandas  # to create and save CSV files
import time  # sleep function to manage rate limits for Congress API
import progress  # progress barfor loops (not used)
from tqdm import tqdm  # progress bar for loops
import pdf2docx  # good method to extract text (in correct order) from PDFs
import docx2txt  # Method for converting a Docx file to a Txt file
from pdfminer.high_level import extract_text  # YES, use this for text extraction from PDFs
import torch  # PyTorch for creating machine learning models
import datetime  # For creating different Supreme Court groups
import re  # For regular expressions
import string  # For string cleaning
import random
import csv

from collections import Counter, namedtuple
from itertools import chain
import json
import math
import os
from pathlib import Path
from tqdm.notebook import tqdm, trange
from typing import List, Tuple, Dict, Set, Union
import matplotlib as plt
import numpy

# import sklearn
# from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import torch.nn.utils
import torch.nn.functional as F
# import lzma
# import torchtext  # For tokenizer
# from torchtext.data import tokenizer

## Data Collection

### U.S. Supreme Court opinions

Collect a list of all Justices and their service dates

In [124]:
# Create date ranges for Supreme Courts with a unique set of justices
url = "https://www.supremecourt.gov/about/members_text.aspx"
r = requests.get(url=url)

soup = bs4.BeautifulSoup(markup=r.content, features="html.parser")
justice_tables = soup.find_all(attrs={"class":"table table-striped justicetable"})

soup = bs4.BeautifulSoup(markup=str(justice_tables[0]), features="html.parser")
chief_justices_data = list(map(lambda x: x.text.split("\n"), soup.find_all("tr")[1:]))
soup = bs4.BeautifulSoup(markup=str(justice_tables[1]), features="html.parser")
associate_justices_data = list(map(lambda x: x.text.split("\n"), soup.find_all("tr")[1:]))
all_justices_data = chief_justices_data + associate_justices_data

def cleanNotes(s):
    return s.replace("(a) ", "").replace("(b) ", "").replace("(c) ", "").replace("*", "").strip()

def toDate(s):
    if s == "":
        return False
    # Two formatting mistakes on the Supreme Court's website makes this necessary (notification has been submitted to the Webmaster)
    try:
        return datetime.datetime.strptime(s, "%B %d, %Y").date()
    except:
        try:
            return datetime.datetime.strptime(s, "%B %d %Y").date()
        except:
            return datetime.datetime.strptime(s, "%B %d,%Y").date()

Create a list of Chief Justice Supreme Courts

In [125]:
chief_justices = list(map(lambda x: [x[1], toDate(cleanNotes(x[4])), toDate(cleanNotes(x[5]))], chief_justices_data))

chief_courts = list(map(lambda x: (x[0][:x[0].find(",")] + " Court", x[1], x[2]), chief_justices))

Create a list of unique member-set Supreme Courts

In [267]:
all_justices = list(map(lambda x: [x[1], toDate(cleanNotes(x[4])), toDate(cleanNotes(x[5]))], all_justices_data))

start_dates = list(map(lambda x: x[1], all_justices))
end_dates = list(filter(lambda x: x, map(lambda x: x[2], all_justices)))
start_dates.sort()
beginning = start_dates[0]
end_dates.sort(reverse=True)
end = end_dates[0]

time_period = []
one_day = datetime.timedelta(days=1)
curr = beginning

for i in range((end - beginning).days + 1):
    time_period.append((curr, []))
    curr = curr + one_day

for justice in all_justices:

    # Finds the index in the time_period list where their service starts
    start_index = (justice[1] - beginning).days
    
    # Finds the index in the time_period list where their service ends
    if justice[2] == False: # if there is no end date
        end_index = len(time_period) # make end index the end of the list
    else: # if there is an end date
        end_index = (justice[2] - beginning).days # find the date index in time_period list

    for day_index in range(start_index, end_index):

        time_period[day_index][1].append(justice[0])

# Initialize pointer / trackers
court_sets = []
prev_court = set(time_period[0][1])
prev_date = time_period[0][0]
start_date = time_period[0][0]

for day in time_period:
    
    curr_court = set(day[1])

    if prev_court != curr_court:
        
        court_sets.append((prev_court, start_date, prev_date))  # add old court to list
        start_date = day[0]  # set start date as current date
        prev_date = day[0]  # set previous date as current date
        prev_court = curr_court  # set previous court as current court

    else:

        prev_date = day[0]  # set previous date as current date

# Add last court
court_sets.append((prev_court, start_date, prev_date))

Collect decisions between 1937 and 1975

In [2]:
url = "https://www.govinfo.gov/bulkdata/SCD/1937/SCD-1937.zip"
r = requests.get(url=url)

open("./data/raw-data/SCD/SCD-1937.zip", "xb").write(r.content)
with zipfile.ZipFile(file="./data/raw-data/SCD/SCD-1937.zip", mode="r") as zip_ref:
    zip_ref.extractall("./data/raw-data/SCD")

Collect decisions between 1991 and 2015

In [3]:
# Loop for 1991 to 2009
for volume in range(502,561):
    url = f"https://www.supremecourt.gov/opinions/boundvolumes/{volume}bv.pdf"
    r = requests.get(url=url)
    open(f"./data/raw-data/SCD/SCD-{volume}.pdf", "xb").write(r.content)

# Loop for 2009 to 2015
for volume in range(561,578):
    url = f"https://www.supremecourt.gov/opinions/boundvolumes/{volume}BV.pdf"
    r = requests.get(url=url)
    open(f"./data/raw-data/SCD/SCD-{volume}.pdf", "xb").write(r.content)

In [4]:
data = []

for volume in tqdm(range(502, 578)):

    text = extract_text(f"./data/raw-data/SCD/SCD-{volume}.pdf")
    data.append([volume, text])

pandas.DataFrame(data=data, columns=["Volume", "Text"]).to_csv("./data/raw-data/SCD.csv")

  0%|          | 0/76 [00:00<?, ?it/s]

### U.S. Presidential Documents

Collect Presidential memorandas, determinations, executive orders, and proclamations

In [292]:
# Dictionary for all executive document-types and their associated url-bases
executive = {
    "memorandas": "app-categories/presidential/memoranda",
    "determinations": "app-attributes/determinations",
    "executive-orders": "app-attributes/executive-orders",
    "proclamations": "app-categories/written-presidential-orders/presidential/proclamations"
}

for document_type, base in executive.items():

    data = []
    if document_type in ["memorandas", "determinations",  "executive-orders"]:
        continue

    # Get counts for page requests
    url = f"https://www.presidency.ucsb.edu/documents/{base}?items_per_page=5"
    r = requests.get(url=url)
    soup = bs4.BeautifulSoup(markup=r.content, features="html.parser")
    count = soup.find(attrs={"class":"tax-count"}).text
    total = int(count[count.index("of")+3: count.index(".")])
    MAX_DISPLAY = 500
    display = max(map(lambda x: x if x<MAX_DISPLAY else 0,divisors(total)))
    pageNum = int(total / display)

    # Loop to gather all documents
    for page in tqdm(range(21,pageNum)):
        url = f"https://www.presidency.ucsb.edu/documents/{base}?items_per_page={display}&page={page}"
        r = requests.get(url=url)
        soup = bs4.BeautifulSoup(markup=r.content, features="html.parser")
        documentLinks = [child.a["href"] for child in soup.find_all(attrs={"class":"field-title"})]

        for href in documentLinks:
            # get content
            r = requests.get(url=f"https://www.presidency.ucsb.edu{href}")
            soup = bs4.BeautifulSoup(markup=r.content, features="html.parser")
            content = ""
            for child in soup.find(attrs={"class":"field-docs-content"}).children:
                content += child.text
            # Get person
            soup = bs4.BeautifulSoup(markup=r.content, features="html.parser")
            person = ""
            for child in soup.find(attrs={"class":"diet-title"}).children:
                person += child.text

            data.append([person, content])

    pandas.DataFrame(data=data, columns=["person", document_type]).to_csv(f"./data/raw-data/{document_type}.csv")

  0%|          | 0/6 [00:00<?, ?it/s]

### Congressional Bill Summaries

Congress.gov now offers a beta [API](https://www.congress.gov/help/using-data-offsite) to request data available through their website.

#### Initialize Collection

In [None]:
# url = "https://api.congress.gov/v3/summaries/"
url = "https://api.congress.gov/v3/summaries?fromDateTime=1973-01-01T00:00:00Z&toDateTime=2023-01-02T00:00:00Z&sort=updateDate+asc&offset=0&limit=250&format=json"
start = "1973-01-01T00:00:00Z"
end = "2023-01-02T00:00:00Z"
params = {
    "api_key": "VdU8MBQu6ztVqNZ6WF6escCDRTXApyktwBdWqZli",
    "limit": 250,
    "fromDateTime": start,
    "toDateTime": end,
    "sort": "updateDate+asc",
}
key_param = {"api_key": "VdU8MBQu6ztVqNZ6WF6escCDRTXApyktwBdWqZli",}
offset = 0
total = 398_395
requestsNeeded = math.ceil(total/250)

summaryData = []
t = tqdm(total=total)

#### Collection Loop

In [27]:
while (url != False):
    
    # params.update({"offset": offset}) # set offset to request new data
    r = requests.get(url=url, params=key_param) # sent request to API
    loopSummaries = r.json()["summaries"]

    # For the summary data from the request, map to collect the Chamber + Congress and the text of the summary
    summaryData += map(
        lambda data: [
            data["bill"]["originChamber"] + " " + str(data["bill"]["congress"]),
            data["text"],
        ], loopSummaries
    )

    url = r.json()["pagination"].get("next", False)
    time.sleep(2)
    t.update(n=len(loopSummaries))

t.close()

pandas.DataFrame(data=summaryData, columns=["Institution", "Text"]).to_csv("data/congress3.csv")  # save data to CSV file

100%|██████████| 398395/398395 [3:59:52<00:00, 27.68it/s]  


## Data Processing

### Cleaning

String Cleaning functions

In [40]:
printable = set(string.printable)  # Get set of ASCII characters
other = "!\"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"

def cleanString(s):

    s = "".join(filter(lambda x: x in printable, s))  # Remove all non-ASCII character from the string
    return re.sub("\s+", " ", s.replace("\n", " ").replace("\t", " ").replace("\r", " ")).strip()

def bleachString(s):

    onlyLetters = re.sub(r'[^a-z]', " ", s.lower())  # remove numbers, and other non-letter characters
    return re.sub("\s+", " ", onlyLetters).strip()

def totalClean(s):
    return bleachString(cleanString(s))

Clean Congressional bill summaries

In [142]:
# Remove HTML artifacts, totally clean strings

congress = pandas.read_csv("./data/raw-data/congress.csv")

congress["Text"] = congress["Text"].apply(
    lambda x: totalClean(bs4.BeautifulSoup(markup=x, features="html.parser").text)
    )

congress["Institution"] = congress["Institution"].apply(lambda x: x + "th")
congress.drop(columns=["Unnamed: 0"], inplace=True)

  lambda x: totalClean(bs4.BeautifulSoup(markup=x, features="html.parser").text)


Clean Supreme Court decisions

In [139]:
# For 1937 through 1975 Supreme Court cases, there are 19 new lines between each case
decision19 = 19 * "\n"
file = open("./data/raw-data/SCD/SCD-1937.txt")
decisions1937 = map(lambda x: cleanString(x), file.read().split(decision19))

def getDate(s):

    try:
        match = re.search(r'(DECIDED|DECODED|DISMISSED|CONTINUED|ANNOUNCED|GRANTED)\s[A-Z]+\s[0-9]+,\s[0-9]+', s)
        date = match[0].split(" ")
        date = " ".join(date[1:])
        date = date.split(" ")
        date[0] = date[0][0] + date[0][1:].lower()
        return " ".join(date)
    except:
        print(s)

def date2Cheif(date):

    date = datetime.datetime.strptime(date, "%B %d, %Y").date()
    
    for chief in chief_courts:
        
        if date >= chief[1]:
            val = chief[0]
            if date <= chief[2]:
                return chief[0]
        else: # return previous Chief Justice if the start date of the current is before the decision
            return val
    


decisions1937 = list(map(lambda x: [date2Cheif(getDate(x)), totalClean(x)], decisions1937))
decisions1937 = pandas.DataFrame(data=decisions1937, columns=["Institution", "Text"])

In [113]:
decisions1991 = pandas.read_csv("./data/raw-data/SCD.csv")

decisions1991["Text"] = decisions1991["Text"].apply(lambda x: cleanString(x))

def getChief(s):

    try:
        match = re.search(r'([A-Z]+|[A-Z]+,\sJr\.),\sChief\sJustice\.', s)
        return match[0].split(",")[0]
    except:
        print("Failed match")

decisions1991["Holder"] = decisions1991["Text"].apply(lambda x: getChief(x))
decisions1991["Text"] = decisions1991["Text"].apply(lambda x: totalClean(x))
decisions1991["Volume"] = decisions1991["Holder"].apply(lambda x: x[0] + x[1:].lower() + " Court")
decisions1991.drop(columns=["Unnamed: 0", "Holder"], inplace=True)
decisions1991.rename(columns={"Volume": "Institution"}, inplace=True)

Clean Presidential documents

In [111]:
determinations = pandas.read_csv("./data/raw-data/determinations.csv")
executive_orders = pandas.read_csv("./data/raw-data/executive-orders.csv")
memorandas = pandas.read_csv("./data/raw-data/memorandas.csv")
proclamations = pandas.read_csv("./data/raw-data/proclamations.csv")

determinations.rename(columns={"person": "Institution", "determinations": "Text"}, inplace=True)
executive_orders.rename(columns={"person": "Institution", "executive-orders": "Text"}, inplace=True)
memorandas.rename(columns={"person": "Institution", "memorandas": "Text"}, inplace=True)
proclamations.rename(columns={"person": "Institution", "proclamations": "Text"}, inplace=True)

president = pandas.concat([determinations, executive_orders, memorandas, proclamations])
president.drop(columns=["Unnamed: 0"], inplace=True)

president["Text"] = president["Text"].apply(lambda x: totalClean(x))

In [144]:
institution_data = pandas.concat([congress, decisions1937, decisions1991, president])
institution_data.to_csv("./data/clean-data/institution.csv")

Remove stop words

In [None]:
stop_words = ['i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', "you're", 
              "you've", "you'll", "you'd", 'your', 'yours', 'yourself', 'yourselves', 'he', 
              'him', 'his', 'himself', 'she', "she's", 'her', 'hers', 'herself', 'it', "it's", 
              'its', 'itself', 'they', 'them', 'their', 'theirs', 'themselves', 'what', 
              'which', 'who', 'whom', 'this', 'that', "that'll", 'these', 'those', 'am', 'is', 
              'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'having', 
              'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 
              'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about', 
              'against', 'between', 'into', 'through', 'during', 'before', 'after', 'above', 
              'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 
              'again', 'further', 'then', 'once', 'here', 'there', 'when', 'where', 'why', 
              'how', 'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 
              'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 
              's', 't', 'can', 'will', 'just', 'don', "don't", 'should', "should've", 'now', 
              'd', 'll', 'm', 'o', 're', 've', 'y', 'ain', 'aren', "aren't", 'couldn', "couldn't", 
              'didn', "didn't", 'doesn', "doesn't", 'hadn', "hadn't", 'hasn', "hasn't", 'haven',
                "haven't", 'isn', "isn't", 'ma', 'mightn', "mightn't", 'mustn', "mustn't", 
                'needn', "needn't", 'shan', "shan't", 'shouldn', "shouldn't", 'wasn', "wasn't", 
                'weren', "weren't", 'won', "won't", 'wouldn', "wouldn't"]



Convert to lower-case

### Processing

Define Vocab class to tokenize a set of words

In [3]:
UNK = "<UNK>"
PAD = "<PAD>"

class Vocab(object):
    """ Vocabulary, i.e. structure containing language terms.
        This vocabulary can, and should, be abstracted to other
        sets of object. For vocabulary, it is words to tokens. 
        For government institutions, it is institutions to tokens.

        Instance attributes:
            word2id: dictionary mapping words to indices
            unk_id: index for UNK
            id2words: dictionary mapping indices to words
    """
    def __init__(self, word2id=None):
        """ Init Vocab Instance.

        :param word2id: dictionary mapping words to indices
        :type word2id: dict[str, int]
        """
        if word2id:
            self.word2id = word2id
        else:
            self.word2id = dict()
            self.word2id[PAD] = 0  # Pad Token
            self.word2id[UNK] = 1  # Unknown Token
        self.unk_id = self.word2id[UNK]
        self.id2word = {v: k for k, v in self.word2id.items()}

    def __getitem__(self, word):
        """ Retrieve word's index. Return the index for unk
        token if the word is out of vocabulary

        :param word: word to look up
        :type word: str
        :returns: index of word
        :rtype: int
        """
        return self.word2id.get(word, self.unk_id)
    
    def __contains__(self, word):
        """ Check if word is captured by Vocab.
        
        :param word: word to look up
        :type word: str
        :returns whether word is in vocab
        :rtype: bool
        """
        return word in self.word2id
    
    def __setitem__(self, key, value):
        """ Raise error, if one tries to edit Vocab directly.
        """
        raise ValueError("Vocab is readonly")
    
    def __len__(self):
        """ Compute number of words in Vocab.
        
        :returns: number of words in Vocab
        :rtype: int
        """
        return len(self.word2id)
    
    def __repr__(self):
        """ Representation of Vocab to be used
        when printing the object.
        """
        return "Vocabulary[size=%d]" % len(self)
    
    def word_from_id(self, wid):
        """ Return mapping of index to word.
        
        :param wid: word index
        :type: int
        :returns: word corresponding to index
        :rtype: str
        """
        return self.id2word[wid]
    
    def add(self, word):
        """ Add word to Vocab, if it is previously unseen.
        
        :param word: to add to Vocab
        :type word: str
        :returns: index that the word has been assigned
        :rtype: int
        """
        if word not in self:
            wid = self.word2id[word] = len(self)
            self.id2word[wid] = word
            return wid
        else:
            return self[word]
        
    def save(self, path):
        """ Save Vocab to CSV file, indicated by `path`
        
        :param path: the relative path to the saving file
        :type path: str
        """
        with open(path, "w") as f:
            write = csv.writer(f)
            write.writerows([[w, id] for w, id in self.word2id.items()])

    @staticmethod
    def load(path):
        """ Load Vocab from CSV file, indicated by `path`
        
        :param path: the relative path to the loading file
        :type path: str
        :returns: Vocab instance produced from CSV file
        :rtype: Vocab
        """
        with open(path) as f:
            reader = csv.reader(f)
            return Vocab({row[0]: int(row[1]) for row in reader})

        
    @staticmethod
    def from_corpus(corpus, remove_frac=None, freq_cutoff=None):
        """ Given a corpus, construct a Vocab.
        
        :param corpus: corpus of text produced by read_corpus function
        :type corpus: List[str]
        :param remove_frac: remove len * remove_frac number of words
        :type remove_frac: float
        :param freq_cutoff: if word occurs n < frew_cutoff times, drop the word
        :type freq_cutoff: int
        :returns: Vocab instance produced from provided corpus
        :rtype: Vocab
        """
        vocab_entry = Vocab()
        word_freq = Counter(chain(corpus))
        if freq_cutoff is None:
            freq_cutoff = 1
        valid_words = [w for w, v in word_freq.items() if v >= freq_cutoff]
        print("number of word types: {}, number of word types w/ frequency >= {}: {}"
              .format(len(word_freq), freq_cutoff, len(valid_words)))
        top_words = sorted(valid_words, key=lambda word: word_freq[word], reverse=True)
        if remove_frac is not None:
            size = len(top_words) - int(remove_frac * len(top_words))
            top_words = top_words[:size]
            print(f"number of unqiue words retained with remove_frac={remove_frac}: {len(top_words)}")
        for word in top_words:
            vocab_entry.add(word)
        return vocab_entry


Create Vocabs for both words and institutions

In [5]:
institution_data = pandas.read_csv("./data/clean-data/institution.csv")
institution_data.drop(columns=["Unnamed: 0"], inplace=True)

corpus = " ".join(institution_data["Text"].to_list()).split(" ")
institutions = institution_data["Institution"].to_list()

word_vocab = Vocab.from_corpus(corpus=corpus)
gov_vocab = Vocab.from_corpus(corpus=institutions)

word_vocab.save("./data/govtext/words.csv")
gov_vocab.save("./data/govtext/govs.csv")

number of word types: 244187, number of word types w/ frequency >= 1: 244187
number of word types: 102, number of word types w/ frequency >= 1: 102


Load Vocabs for both words and institutions

In [3]:
word_vocab = Vocab.load("./data/vocab/words.csv")
gov_vocab = Vocab.load("./data/vocab/govs.csv")

Define tokenization functions

Should look into PyTorch's [tokenizer](https://pytorch.org/text/stable/data_utils.html) functions. Has some pre-build NLP functionality that is most likely implemented with a faster language. Separation of your tokenizer and 

In [6]:
def tokenize_document(document, vocab):
    return [vocab[w] for w in document.split(" ")]

def tokenize_documents(documents, vocab):
    output = []
    for document in documents:
        output.append(tokenize_document(document, vocab))
    return output

Tokenize data

In [7]:
institution_data["Institution"] = institution_data["Institution"].apply(lambda x: gov_vocab[x])
institution_data["Text"] = institution_data["Text"].apply(lambda x: tokenize_document(x, word_vocab))

Consolidate an institution's corpus into a single string

In [9]:
institution_corpus = dict()

for index, row in institution_data.iterrows():

    if row["Institution"] in institution_corpus.keys():
        institution_corpus[row["Institution"]] += row["Text"]
    else:
        institution_corpus[row["Institution"]] = row["Text"]

Save tokenized, consolidated institution data

In [20]:
with open("./data/govtext/institution.csv", 'w') as f:
    write = csv.writer(f)
    write.writerows([[gov, tokens] for gov, tokens in institution_corpus.items()])

Defined data loader class and function

In [4]:
class LanguageDataset(Dataset):
    """ LanguageDataset is a torch dataset to interact with the Language data.

        Dataset (List[ Tuples[ List[ torch.Tensor ], int ] ]): The vectorized dataset with input and expected output values
        Dataset is an abstract class representing a dataset:
    """
    def __init__(self, context, gov, target):
        """ Loads in the context, gov, and target as tensors.

        :param context: context tokens on both sides
        :type context: List[List[int]]
        :param gov: government token
        :type gov: List[int]
        :param target: middle target token
        :type target: List[int]
        """
        self.context = torch.tensor(context)
        self.gov = torch.tensor(gov)
        self.target = torch.tensor(target)
        self.len = len(context)
    
    def __len__(self):
        """ Number of samples in dataset

        :returns: number of samples in dataset
        :rtype: int
        """
        return self.len
    
    def __getitem__(self, index):
        """ The tensor, output for a given index

        :param index: index within dataset
        :type index: int
        :returns: A tuple (x, y, z) where x is the context, y is the govnernment, z is the target
        :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        """
        return self.context[index], self.gov[index], self.target[index]


def get_data_loaders(preprocessed_data, batch_size=1, shuffle=False):
    """
    """
    dataset = LanguageDataset(preprocessed_data["context_tokens"], preprocessed_data["gov_tokens"], preprocessed_data["target_tokens"])
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return loader

Collect 1 millions samples for each government institution

In [5]:
WINDOW_SIZE = 15

In [None]:
preprocessed_data = dict()
preprocessed_data["context_tokens"] = []
preprocessed_data["gov_tokens"] = []
preprocessed_data["target_tokens"] = []


for gov, text in tqdm(institution_corpus.items()):

    count = 0
    length = len(text)

    # Collect 500,000 samples for each Institution
    while(count < 500_000):

        # Sample randomly from the total corpus
        position = math.ceil(random.random() * length)
        if position < (2*WINDOW_SIZE + 1):
            position = 2*WINDOW_SIZE + 1

        context = text[position - (2*WINDOW_SIZE + 1): position - (WINDOW_SIZE + 1)] + text[position - WINDOW_SIZE: position]
        if len(context) != 2 * WINDOW_SIZE:
            continue
        target = text[position - (WINDOW_SIZE + 1)]

        preprocessed_data["context_tokens"].append(context)
        preprocessed_data["gov_tokens"].append(gov)
        preprocessed_data["target_tokens"].append(target)

        count += 1


pandas.DataFrame(data=preprocessed_data).to_csv(f"./data/input/inputRandom_window{WINDOW_SIZE}.csv")

Load all samples (~100 million)

In [None]:
preprocessed_data = pandas.read_csv(filepath_or_buffer=f"./data/input/inputRandom_window{WINDOW_SIZE}.csv", 
                                    converters={"context_tokens": lambda x: [int(y) for y in x.strip('][').split(', ')]}
                                    )

loader = get_data_loaders(preprocessed_data=preprocessed_data, batch_size=32, shuffle=True)

Load 100,000 random samples

In [19]:
preprocessed_data = pandas.read_csv(filepath_or_buffer=f"./data/input/inputRandom_window{WINDOW_SIZE}.csv")

indices = set()
for i in range(100_000):
    index = random.randrange(0, len(preprocessed_data))
    indices.add(index)

preprocessed_data = preprocessed_data.iloc[list(indices)]

preprocessed_data["context_tokens"] = preprocessed_data["context_tokens"].apply(lambda x: [int(y) for y in x.strip('][').split(', ')])
loader = get_data_loaders(preprocessed_data=preprocessed_data.to_dict("list"), batch_size=32, shuffle=True)

## Model

In [26]:
# https://github.com/inejc/paragraph-vectors/blob/master/paragraphvec/models.py
# Implemention of object embedding with target word predictions

# https://github.com/inejc/paragraph-vectors/tree/master

# https://github.com/OlgaChernytska/word2vec-pytorch/tree/main

# https://github.com/jeffchy/pytorch-word-embedding/blob/master/CBOW.py

Check and Set PyTorch backend device

In [7]:
# Verify mps support (Apple Silicon)

if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print (x)
else:
    print ("MPS device not found.")

device = ""
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = "cpu"

tensor([1.], device='mps:0')


Define Gov2Vec Model

In [8]:
class Gov2Vec_Model(nn.Module):
    """ Gov2Vec Model
    """
    def __init__(self, vocab_size, gov_size, word_embedding_dim, gov_embedding_dim):

        super(Gov2Vec_Model, self).__init__()
        # Weights should be given to CrossEntropyLoss that incorporate the frequency of words
        # in the dataset. Weights for the minority classes (words) should be higher.
        self.criterion = nn.CrossEntropyLoss()

        self.vocab_size = vocab_size
        self.gov_size = gov_size
        self.word_embedding_dim = word_embedding_dim
        self.gov_embedding_dim = gov_embedding_dim

        self.word_embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=word_embedding_dim
        )
        self.gov_embedding = nn.Embedding(
            num_embeddings=gov_size,
            embedding_dim=gov_embedding_dim
        )
        self.linear = nn.Linear(
            in_features=word_embedding_dim+gov_embedding_dim,
            out_features=vocab_size
        )

    def forward(self, context, gov):
        """
        """
        context_embedding = self.word_embedding(context)
        gov_embedding = self.gov_embedding(gov)

        # Can add, or mean over context tokens
        # Can concat or add or mean context tokens with gov token
        context_embedding = context_embedding.mean(axis=1)
        combined = torch.cat((gov_embedding, context_embedding), 1)
        out = self.linear(combined)

        return out

    def load_model(self, save_path, is_state_dict=False):
        if not is_state_dict:
            saved_model = torch.load(save_path)
            self.load_state_dict(saved_model.state_dict())
        else:
            self.load_state_dict(torch.load(save_path))

    def save_model(self, save_path, is_state_dict=False):
        if is_state_dict:
            torch.save(self.state_dict(), save_path)
        else:
            torch.save(self, save_path)

Define training functions

In [9]:
def train_epoch(model, train_loader, optimizer):
    model.train()
    total = 0
    batch = 0
    total_loss = 0
    correct = 0
    for (context_batch, gov_batch, target_batch) in tqdm(train_loader, leave=False, desc="Training batches"):
        optimizer.zero_grad()
        batch += 1
        output = model(context_batch.to(device), gov_batch.to(device)).to(device)
        loss = model.criterion(output, target_batch.to(device))
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        
    print("Loss: " + str(total_loss / batch))
    return total_loss / batch

def evaluation(model, val_loader, optimizer):
    model.eval()
    loss = 0
    correct = 0
    total = 0
    for (context_batch, gov_batch, target_batch) in tqdm(val_loader, leave=False, desc="Validation Batches"):
        outputs = model(context_batch.to(device), gov_batch.to(device)).to(device)
        loss += model.criterion(outputs, target_batch.to(device))
    loss /= len(val_loader)
    print("Validation Loss: " + str(loss.item()))
    # print("Validation Accuracy: " + str(correct))
    print()
    return loss.item()


def train_and_evaluate(number_of_epochs, model, train_loader, val_loader=None, min_loss=0, learning_rate=0.001):
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    loss_values = [[],[]]

    for epoch in trange(number_of_epochs, desc="Epochs"):
        curr_loss = train_epoch(model, train_loader, optimizer)
        loss_values[0].append(curr_loss)
        # if val_loader is not None:
        #    curr_loss_val = evaluation(model, val_loader, optimizer)
        #    loss_values[1].append(curr_loss_val)
        if curr_loss <= min_loss: return loss_values

    return loss_values

Set Gov2Vec Model

In [10]:
model = Gov2Vec_Model(
    vocab_size=len(word_vocab), 
    gov_size=len(gov_vocab),
    word_embedding_dim=150,
    gov_embedding_dim=150
)
model.to(device)

Gov2Vec_Model(
  (criterion): CrossEntropyLoss()
  (word_embedding): Embedding(244189, 150)
  (gov_embedding): Embedding(104, 150)
  (linear): Linear(in_features=300, out_features=244189, bias=True)
)

Train Gov2Vec Model

In [20]:
train_and_evaluate(
    number_of_epochs=5,
    model=model,
    train_loader=loader,
    min_loss=0.2,
    learning_rate=0.015
)

Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Training batches:   0%|          | 0/3091 [00:00<?, ?it/s]

Loss: 6.477017898923639


Training batches:   0%|          | 0/3091 [00:00<?, ?it/s]

Loss: 4.80784002718745


Training batches:   0%|          | 0/3091 [00:00<?, ?it/s]

Loss: 3.8650567180629842


Training batches:   0%|          | 0/3091 [00:00<?, ?it/s]

Loss: 3.3638254211620775


Training batches:   0%|          | 0/3091 [00:00<?, ?it/s]

Loss: 3.1140751017987323


[[6.477017898923639,
  4.80784002718745,
  3.8650567180629842,
  3.3638254211620775,
  3.1140751017987323],
 []]

In [51]:
model.save_model("./data/model/e5b16lr025s10000")

## Test

In [14]:
gov_token1 = int(gov_vocab["Robert Court"])
gov_token2 = int(gov_vocab["Rehnquist Court"])
gov_token3 = int(gov_vocab["Joseph R. Biden"])

word_token1 = int(word_vocab["truth"])

gov_embed1 = model.gov_embedding(torch.tensor(gov_token1, device=device))
gov_embed2 = model.gov_embedding(torch.tensor(gov_token2, device=device))
gov_embed3 = model.gov_embedding(torch.tensor(gov_token3, device=device))

word_embed1 = model.word_embedding(torch.tensor(word_token1, device=device))

cos = nn.CosineSimilarity(dim=0)

print("Court v Court", cos(gov_embed1, gov_embed2))
print("President v Court", cos(gov_embed3, gov_embed2))

print("Court v 'truth'", cos(gov_embed1, word_embed1))
print("President v 'truth'", cos(gov_embed3, word_embed1))

Court v Court tensor(0.0835, device='mps:0', grad_fn=<SumBackward1>)
President v Court tensor(0.0675, device='mps:0', grad_fn=<SumBackward1>)
Court v 'truth' tensor(0.0633, device='mps:0', grad_fn=<SumBackward1>)
President v 'truth' tensor(-0.0306, device='mps:0', grad_fn=<SumBackward1>)


In [15]:
def similarity(cVocab, cGov, qVocab, query, model):
    """ Gives the token with the maximum cosine similarity
    score from the comparison vocab against the query
    
    :param cVocab: vocab to compare the query to
    :type cVocab: Vocab
    :param cGov: whether the comparison vocab is a government
    :type cGov: bool
    :param qVocab: vocab of the query
    :type qVocab: Vocab
    :param query: query of words to compare, each word in the query can be positive or negative
    :type query: List[Tuple(str, int)]
    """
    cos = nn.CosineSimilarity(dim=0)

    if cGov:
        cEmbedding = model.gov_embedding
    else:
        cEmbedding = model.word_embedding

    qEmbedding = model.word_embedding

    query_len = len(query)
    query_vec = sum(
        map(
            lambda x: qEmbedding(torch.tensor(qVocab[x[0]], device=device)) * x[1],
            query
        )
    ) / query_len

    max = ("", -10000)
    for word, id in cVocab.word2id.items():

        # Exclude comparison of words in the query
        if word in map(lambda x: x[0],query):
            continue

        score = cos(cEmbedding(torch.tensor(id, device=device)), query_vec)
        if score > max[1]:
            max = (word, score)

    return max

In [16]:
similarity(gov_vocab, True, word_vocab, [("validity", 1), ("truth", 1)], model)

('Senate 116th', tensor(0.2217, device='mps:0', grad_fn=<SumBackward1>))

In [18]:
test_queries = [
    [("candidate",1),("elected",1),("campaign",1)],
    [("long",1),("term",1),("government",1),("carerr",1)],
    [("rule",1),("precedent",1),("interpret",1)],
    [("validity",1), ("truth",1)],
    [("statistics",1),("science",1),("data",1),("story",-1),("anecdote",-1)],
    [("order",1),("direct",1),("contemplate",-1),("consider",1)]
]

for query in test_queries:
    result = similarity(gov_vocab, True, word_vocab, query, model)
    print(result)

('Senate 117th', tensor(0.2138, device='mps:0', grad_fn=<SumBackward1>))
('Senate 104th', tensor(0.2097, device='mps:0', grad_fn=<SumBackward1>))
('House 103th', tensor(0.1977, device='mps:0', grad_fn=<SumBackward1>))
('Senate 116th', tensor(0.2217, device='mps:0', grad_fn=<SumBackward1>))
('Senate 111th', tensor(0.2222, device='mps:0', grad_fn=<SumBackward1>))
('Senate 96th', tensor(0.1679, device='mps:0', grad_fn=<SumBackward1>))


In [None]:
tests = [
    {"113th House Economic": [
        ("climate", 1, "word"),
        ("emissions", 1, "word"),
        ("House 113th", 1, "gov"),
        ("Barack Obama", -1, "gov"),
        ("economy", 1, "word"),
        ("environment", -1, "word"),
    ]},
    {"113th House Environmental": [
        ("climate", 1, "word"),
        ("emissions", 1, "word"),
        ("House 113th", 1, "gov"),
        ("Barack Obama", -1, "gov"),
        ("economy", -1, "word"),
        ("environment", 1, "word"),
    ]},
    {"Obama Economic": [
       ("climate", 1, "word"),
        ("emissions", 1, "word"),
        ("House 113th", -1, "gov"),
        ("Barack Obama", 1, "gov"),
        ("economy", 1, "word"),
        ("environment", -1, "word"),
    ]},
    {"Obama Environmental": [
        ("climate", 1, "word"),
        ("emissions", 1, "word"),
        ("House 113th", -1, "gov"),
        ("Barack Obama", 1, "gov"),
        ("economy", -1, "word"),
        ("environment", 1, "word"),
    ]},
    {"106th House Oil": [
        ("war", 1, "word"),
        ("House 106th", 1, "gov"),
        ("House 107th", -1, "gov"),
        ("oil", 1, "word"),
        ("terrorism", -1, "word"),
    ]},
    {"106th House Terrorism": [
        ("war", 1, "word"),
        ("House 106th", 1, "gov"),
        ("House 107th", -1, "gov"),
        ("oil", -1, "word"),
        ("terrorism", 1, "word"),
    ]},
    {"107th House Oil": [
        ("war", 1, "word"),
        ("House 106th", -1, "gov"),
        ("House 107th", 1, "gov"),
        ("oil", 1, "word"),
        ("terrorism", -1, "word"),
    ]},
    {"107th House Terrorism": [
        ("war", 1, "word"),
        ("House 106th", -1, "gov"),
        ("House 107th", 1, "gov"),
        ("oil", -1, "word"),
        ("terrorism", 1, "word"),
    ]},
]