In [1]:
from plotlib.loaders import *
from plotlib.plotters import *

from phdconf import stop
from phdconf import config 

import os
import nltk

from sentence_transformers import SentenceTransformer 
from sklearn.metrics.pairwise import cosine_similarity

In [2]:
queries = load_queries(config.AUS_TOPIC_PATH)

In [3]:
def load_qrel(path: str):
    qrels = {}
    with open(path) as f:
        for line in f:
            parts = line.strip().split()
            if parts[0] not in qrels:
                qrels[parts[0]] = [set(), set()]
            if parts[3] == '0': 
                qrels[parts[0]][1].add(parts[2])
            else:
                qrels[parts[0]][0].add(parts[2])
            
    return qrels

In [4]:
qrels = load_qrel(config.AUS_QREL_PATH)

In [5]:
prefix_path = os.path.join(os.environ["HOME"], 'JSON')

def get_case_path_from_id(path: str, prefix:str=prefix_path):
    prev = False 
    parts = []
    part = ''
    for t in path: 
        curr = t.isalpha()
        if curr == prev: 
            part += t
        else:
            parts.append(part)
            part = t
            
        prev = curr
        
    if part != '':
        parts.append(part)
            
    f_name = ''
    _dir = ''
    if 'FCA' in parts[1]:
        if int(parts[0]) < 2010:
            _dir = 'FCAP10/' + parts[0]
        else: 
            _dir = 'FCA/' + parts[0] 
        parts[2] = '{0:04}'.format(int(parts[2]))
        f_name = ''.join(parts) + '.json'
    else:
        parts[2] = '{0:03}'.format(int(parts[2]))
        f_name = '-'.join(parts[:3])
        if len(parts) == 4:
             parts += parts[3]
        f_name += '.json'
        _dir = 'QLD/{0}/{1}'.format(parts[1], parts[0])
        
    return os.path.join(prefix, _dir, f_name)
    
def load_json_case(_id: str):
    path = get_case_path_from_id(_id)
    with open(path) as f:
        data = json.load(f)
        return [x for x in data['body'] if x['type'] == 'paragraph' or x['type'] == 'quote']

In [31]:
m = 0
m_id = ''
large = []
for topic in qrels:
    for _id in qrels[topic][0].union(qrels[topic][1]):
        all_sent = []
        for para in load_json_case(_id):
            all_sent += nltk.sent_tokenize(para['text'])
        if len(all_sent) > m:
            m = len(all_sent)
            m_id = _id
        if len(all_sent) > 200: 
            large.append((_id, len(all_sent)))
        
print(m, m_id)

12003 2000FCA1084


4704


In [40]:
print(len([x for x in large if x[1] > 300]))
print(len([x for x in large if x[1] > 400]))
print(len([x for x in large if x[1] > 500]))
print(len([x for x in large if x[1] > 600]))
print(len([x for x in large if x[1] > 700]))
print(len([x for x in large if x[1] > 800]))
print(len([x for x in large if x[1] > 900]))
print(len([x for x in large if x[1] > 1000]))
print(len([x for x in large if x[1] > 2000]))
print(len([x for x in large if x[1] > 4000]))
print(len([x for x in large if x[1] > 5000]))

3505
2706
2184
1843
1588
1371
1233
1085
501
205
174


In [67]:
def read_crim_file(path: str):
    lookup = set()
    with open(path) as f:
        for line in f:
            parts = line.split()
            if parts[1] == '1':
                lookup.add(parts[0])
    return lookup

def count_crim_in_res_file(path: str, crim_lookup):
    qry = {}
    with open(path) as f:
        for line in f: 
            parts = line.split()
            v = qry.get(parts[0], 0)
            if parts[2] in crim_lookup: 
                v += 1
            qry[parts[0]] = v
                           
    return qry

In [68]:
crim_lookup = read_crim_file('/home/danlocke/go/src/crim-feature-file/crim-cases.txt')

In [69]:
count_crim_in_res_file('/home/danlocke/phd-generated/dirichlet_prior/case-topics-filtered-stop-unigram_dir_mu_2400.00.run', crim_lookup)

{'1': 0,
 '2': 0,
 '3': 3,
 '4': 2,
 '5': 1,
 '6': 0,
 '7': 0,
 '8': 0,
 '9': 0,
 '10': 0,
 '11': 2,
 '12': 0,
 '13': 2,
 '15': 10,
 '17': 3,
 '19': 0,
 '21': 0,
 '22': 0,
 '23': 0,
 '24': 0,
 '25': 0,
 '27': 1,
 '28': 0,
 '29': 1,
 '32': 2,
 '33': 3,
 '34': 2,
 '35': 0,
 '37': 1,
 '39': 0,
 '41': 0,
 '43': 2,
 '44': 32,
 '45': 13,
 '46': 57,
 '47': 1,
 '49': 0,
 '50': 1,
 '53': 2,
 '54': 0,
 '55': 3,
 '57': 0,
 '58': 0,
 '59': 11,
 '60': 1,
 '61': 6,
 '62': 1,
 '64': 0,
 '65': 0,
 '67': 0,
 '69': 0,
 '70': 0,
 '71': 0,
 '73': 0,
 '74': 7,
 '75': 8,
 '76': 8,
 '77': 0,
 '78': 0,
 '80': 1,
 '81': 0,
 '82': 0,
 '83': 1,
 '84': 0,
 '85': 0,
 '86': 1,
 '87': 0,
 '88': 2,
 '89': 1,
 '90': 0,
 '91': 18,
 '93': 1,
 '94': 1,
 '95': 7,
 '96': 1,
 '97': 1,
 '98': 0,
 '99': 0,
 '101': 0,
 '102': 1,
 '103': 1,
 '104': 0,
 '105': 0,
 '107': 7,
 '108': 1,
 '109': 12,
 '110': 1,
 '111': 0,
 '112': 1,
 '113': 0,
 '114': 36,
 '115': 3,
 '116': 0,
 '117': 13,
 '118': 1}

In [20]:
encoder = SentenceTransformer('distilbert-base-nli-stsb-mean-tokens')

In [21]:
print(type(encoder))

<class 'sentence_transformers.SentenceTransformer.SentenceTransformer'>


In [27]:
def rescore(path: str, out_file:str, queries, encoder:SentenceTransformer): 
    qry = {}
    with open(path) as f:
        for line in f:
            parts = line.split()
            vals = []
            q = int(parts[0])
            if q in qry: 
                vals = qry[q] 
                
            qry_score = encoder.encode(queries[1]['topic'])

            all_sent = []
            for para in load_json_case(parts[2]):
                all_sent += nltk.sent_tokenize(para['text'])

            scored = encoder.encode(all_sent)
            sims = cosine_similarity(qry_score, scored)[0]
            vals.append((parts[2], max(sims)))
            qry[q] = vals
            
    qry = {k: sorted(v, key=lambda x: x[1], reverse=True) for k, v in qry.items()}
    with open(out_file, 'w+') as f:
        for key, vals in m.items():
            for i, val in enumerate(vals):
                f.write('{0} Q0 {1} {2} {3:4f} b\n'.format(key, val[0], i, val[1]))


rescore('/home/danlocke/phd-generated/dirichlet_prior/case-topics-filtered-stop-unigram_dir_mu_2400.00.run', , queries, encoder)

TypeError: rescore() missing 2 required positional arguments: 'queries' and 'encoder'

In [25]:
m = {1: [('a', 100), ('b', 200), ('c', 120)], 2: [('a', 140), ('b', 200), ('c', 120)]}
m = {k: sorted(v, key=lambda x: x[1], reverse=True) for k, v in m.items()}

with open('test.txt', 'w+') as f:
    for key, vals in m.items():
        for i, val in enumerate(vals):
            f.write('{0} Q0 {1} {2} {3:4f} b\n'.format(key, val[0], i, val[1]))

In [19]:
print(m)

{1: [('b', 200), ('c', 120), ('a', 100)], 2: [('b', 200), ('a', 140), ('c', 120)]}
