In [3]:
import argparse
import os
import sys
import numpy as np

from textacy.datasets.supreme_court import SupremeCourt
# from pytorch_pretrained_bert import BertModel, BertTokenizer
import torch
import torch.nn as nn

In [4]:
# Instantiate the SupremeCourt dataset wrapper imported from textacy.
sc = SupremeCourt()
# NOTE(review): presumably needs to be run once to fetch the corpus before
# the records can be read — confirm against the textacy documentation.
# sc.download()
# Show the dataset's `info` attribute.
print('sc.info: ', sc.info)

sc.info:  {'name': 'supreme_court', 'site_url': 'http://caselaw.findlaw.com/court/us-supreme-court', 'description': 'Collection of ~8.4k decisions issued by the U.S. Supreme Court between November 1946 and June 2016.'}


In [5]:
# Display the keys of the issue-area code table exposed by the dataset object.
sc.issue_area_codes.keys()

dict_keys([-1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])

In [31]:
# 15 labels: issue-area codes, sorted first and then converted to strings
issue_codes_15 = [str(code) for code in sorted(sc.issue_area_codes.keys())]

# 279 labels: every issue code plus the extra '-1' entry, sorted
issue_codes_279 = sorted(list(sc.issue_codes.keys()) + ['-1'])

# Each label name is mapped to its position in the corresponding sorted list.
labels_index_15 = {
    name: idx
    for name, idx in zip(issue_codes_15, np.arange(len(issue_codes_15)))
}
labels_index_279 = {
    name: idx
    for name, idx in zip(issue_codes_279, np.arange(len(issue_codes_279)))
}
labels_index_15

{'-1': 0,
 '1': 1,
 '2': 2,
 '3': 3,
 '4': 4,
 '5': 5,
 '6': 6,
 '7': 7,
 '8': 8,
 '9': 9,
 '10': 10,
 '11': 11,
 '12': 12,
 '13': 13,
 '14': 14}

In [7]:
# Take a look at the format of the data.

# Pull only the first record from the dataset iterator.
tempRecord = next(sc.records())
# Not the last expression of the cell, so this value is not displayed.
type(tempRecord)

# Report the record's length, the type of its first element, and the whole
# second element (the metadata dict).
print('--------- The format of one record ---------')
print('length: ', len(tempRecord))
print('tempRecord[0] is the text: ', type(tempRecord[0]))
print('tempRecord[1] is the dict: ', tempRecord[1])

--------- The format of one record ---------
length:  2
tempRecord[0] is the text:  <class 'str'>
tempRecord[1] is the dict:  {'issue': '80180', 'issue_area': 8, 'n_min_votes': 1, 'case_name': 'HALLIBURTON OIL WELL CEMENTING CO. v. WALKER et al., DOING BUSINESS AS DEPTHOGRAPH CO.', 'maj_opinion_author': 78, 'decision_date': '1946-11-18', 'decision_direction': 'liberal', 'n_maj_votes': 8, 'us_cite_id': '329 U.S. 1', 'argument_date': '1946-01-09'}


In [35]:
# Walk over the whole dataset once and collect, in parallel lists (one entry
# per record): the raw text, the coarse and fine label ids, and the
# '<case name>@<cite id>' string used as a dictionary key further on.
texts = []  # list of text samples
labels_15 = []  # label ids looked up in labels_index_15
labels_279 = []  # label ids looked up in labels_index_279

case_name_plus_citeId_list = []

for record in sc.records():
    text_record = record[0]
    feature_record = record[1]

    # process issue number
    issue_record = feature_record['issue']
    if issue_record is None:  # some cases have None as an issue
        labels_15.append(labels_index_15['-1'])
        labels_279.append(labels_index_279['-1'])
    else:
        # Dropping the last four characters of the issue code leaves the
        # coarse label name (e.g. '80180' -> '8').
        labels_15.append(labels_index_15[issue_record[:-4]])
        labels_279.append(labels_index_279[issue_record])

    # process case name
    case_name_record = feature_record['case_name']
    if case_name_record is None:
        print('We do find a None in case_name')
        # Substitute a placeholder (same treatment as a missing cite id) so
        # the membership test and concatenation below do not raise TypeError.
        case_name_record = 'None'
    if '@' in case_name_record:
        # '@' is the separator used in the key, so flag any occurrence.
        print('We do find a @')

    # process cite id
    cite_id_record = feature_record['us_cite_id']
    if cite_id_record is None:
        cite_id_record = 'None'
    if '@' in cite_id_record:
        print('We do find a @')

    # prepare for the dictionary key
    case_name_plus_citeId_list.append(case_name_record + '@' + cite_id_record)

    # add texts
    texts.append(text_record)

print('Found %s texts.' % len(texts))
print('Found %s labels in labels_15.' % len(set(labels_15)))
print('Found %s labels in labels_279.' % len(set(labels_279)))

# Equal list and set lengths mean every key is unique.
print()
print('length of case_name_plus_citeId_list: ', len(case_name_plus_citeId_list))
case_name_plus_citeId_set = set(case_name_plus_citeId_list)
print('length of case_name_plus_citeId_set: ', len(case_name_plus_citeId_set))

Found 8419 texts.
Found 15 labels in labels_15.
Found 264 labels in labels_279.

length of case_name_plus_citeId_list:  8419
length of case_name_plus_citeId_set:  8419


In [11]:
# Print the type, the length, and the second element of the first record
# only; the loop is left after a single iteration.
for record in sc.records():
    print(type(record))
    print(len(record))
    print(record[1])
#     print(record['issue'])
    break
#         if record['issue'] == None

<class 'tuple'>
2
{'issue': '80180', 'issue_area': 8, 'n_min_votes': 1, 'case_name': 'HALLIBURTON OIL WELL CEMENTING CO. v. WALKER et al., DOING BUSINESS AS DEPTHOGRAPH CO.', 'maj_opinion_author': 78, 'decision_date': '1946-11-18', 'decision_direction': 'liberal', 'n_maj_votes': 8, 'us_cite_id': '329 U.S. 1', 'argument_date': '1946-01-09'}


### Save the texts, and build a corresponding dictionary on disk

In [16]:
def getStringNumber(input_num, width=4):
    """Return ``input_num`` as a string left-padded with '0' to ``width`` characters.

    Args:
        input_num: Value to format; it is converted with ``str()``.
        width: Minimum length of the result. Defaults to 4.

    Returns:
        The padded string. Values whose string form is already ``width``
        characters or longer are returned unpadded.
    """
    # rjust pads on the left only and never truncates.
    return str(input_num).rjust(width, '0')

getStringNumber(0)
# getStringNumber(1)
# getStringNumber(123)
# getStringNumber(2345)

'0000'

In [15]:
def getDictKey(feature_record):
    """Build the '<case name>@<cite id>' key for one record's metadata dict.

    Args:
        feature_record: Mapping that provides the 'case_name' and
            'us_cite_id' entries; either value may be None.

    Returns:
        The case name and the cite id joined by '@'. A None value is
        replaced by the string 'None'.

    A warning is printed when the case name is None, and whenever either
    part already contains '@', because '@' is the separator of the key.
    """
    # process case name
    case_name_record = feature_record['case_name']
    if case_name_record is None:
        print('We do find a None in case_name')
        # Use a placeholder, as for a missing cite id, instead of letting
        # the membership test below raise TypeError.
        case_name_record = 'None'
    if '@' in case_name_record:
        print('We do find a @')

    # process cite id
    cite_id_record = feature_record['us_cite_id']
    if cite_id_record is None:
        cite_id_record = 'None'
    if '@' in cite_id_record:
        print('We do find a @')

    # prepare for the dictionary key
    return case_name_record + '@' + cite_id_record

In [68]:
# Write every record's text to its own numbered .txt file and record which
# '<case name>@<cite id>' key went to which file number in a mapping file.
# NOTE(review): the paths below are absolute and machine-specific.
saved_mapping_path = '/misc/grice1/yijun/SCOTUS-Embedding/data/'
saved_files_path = '/misc/grice1/yijun/SCOTUS-Embedding/data/supreme_court_8K/'
saved_mapping_file_name = 'caseName_and_citeID_to_savedID.txt'

# Create the target directories up front so the writes below do not fail
# with FileNotFoundError just because a directory is missing.
os.makedirs(saved_mapping_path, exist_ok=True)
os.makedirs(saved_files_path, exist_ok=True)

count = 0
caseName_and_citeID_to_savedID_list = []

for record in sc.records():
    text_record = record[0]
    feature_record = record[1]

    # save to local: the file name is the zero-padded running record number
    saved_string_num = getStringNumber(count)
    file_name = saved_string_num + '.txt'
    with open(os.path.join(saved_files_path, file_name), 'w') as f:
        f.write(text_record)

    # save to list: '<case name>@<cite id>@<file number>'
    dict_key = getDictKey(feature_record)
    caseName_and_citeID_to_savedID_list.append(dict_key + '@' + saved_string_num)

    count += 1

print('count: ', count)
print('length of caseName_and_citeID_to_savedID_list: ', len(caseName_and_citeID_to_savedID_list))

# save the mapping list, one entry per line; every line, including the last
# one, is terminated by '\n'
saved_mapping_file = os.path.join(saved_mapping_path, saved_mapping_file_name)
with open(saved_mapping_file, 'w') as f:
    f.writelines(entry + '\n' for entry in caseName_and_citeID_to_savedID_list)
print('The mapping info has been saved as: ', saved_mapping_file)

count:  8419
length of caseName_and_citeID_to_savedID_list:  8419
The mapping info has been saved as:  /misc/grice1/yijun/SCOTUS-Embedding/data/caseName_and_citeID_to_savedID.txt


### Read from the saved mapping file

In [17]:
# Load the saved mapping file back, rebuild the key -> file-number dict, and
# check that every record of the dataset can be found in it.
saved_mapping_path = '/misc/grice1/yijun/SCOTUS-Embedding/data/'
saved_files_path = '/misc/grice1/yijun/SCOTUS-Embedding/data/supreme_court_8K/'
saved_mapping_file_name = 'caseName_and_citeID_to_savedID.txt'

# read the saved mapping file
with open(saved_mapping_path + saved_mapping_file_name, 'r') as f:
    # Keep non-empty lines only, so the result is the same whether or not the
    # file ends with a newline (no record is dropped, no empty entry kept).
    read_list = [line for line in f.read().split('\n') if line]

print(type(read_list))
print('length of read_list: ', len(read_list))

# build a dict to map case from 8K dataset to the file name
read_dict = {}
for line in read_list:
    # Each line is '<case name>@<cite id>@<file number>'. Splitting at the
    # last '@' only keeps the full '<case name>@<cite id>' key intact even if
    # the case name or cite id itself contains '@'.
    key, file_name = line.rsplit('@', 1)
    read_dict[key] = file_name

# go through 8K dataset to find the file name
missing_count = 0
for record in sc.records():
    feature_record = record[1]

    dict_key = getDictKey(feature_record)
    if dict_key not in read_dict:
        print('we find a case from 8K dataset is not in our dict')
        missing_count += 1

# Report success only when no case was missing.
if missing_count == 0:
    print('\nwe find every case in our dict')
else:
    print('\n%d case(s) from 8K dataset are not in our dict' % missing_count)


<class 'list'>
length of read_list:  8419

we find every case in our dict


### Save the labels (15-way and 279-way) for each saved file

In [40]:
# Write one line per case: '<file number>@<15-way label id>@<279-way label id>'.
saved_labels_path = '/misc/grice1/yijun/SCOTUS-Embedding/data/'
saves_labels_file_name = 'UWash_labels.txt'
with open(saved_labels_path+saves_labels_file_name, 'w') as f:
    label_rows = []
    for index, case_key in enumerate(case_name_plus_citeId_list):
        # Both label ids are rendered as text before the file number is looked up.
        label_fields = (str(labels_15[index]), str(labels_279[index]))
        label_rows.append('@'.join((read_dict[case_key],) + label_fields) + '\n')
    # The whole content is assembled first and written in a single call.
    output_string = ''.join(label_rows)
    f.write(output_string)