In [75]:
%matplotlib inline

import sys 
import os 

nb_dir = os.getcwd()
if nb_dir not in sys.path:
    sys.path.append(nb_dir)

from plotlib.loaders import *
from plotlib.plotters import *

from phdconf import stop
from phdconf import config 

from typing import List, Dict
import math 
import copy
from collections import OrderedDict

from sklearn.metrics.pairwise import linear_kernel, cosine_similarity

In [2]:
# Load the topic set from the path configured in phdconf (config.AUS_TOPIC_PATH).
# The query loop further down iterates queries.values() and reads each
# entry's 'topic' string.
# NOTE(review): load_queries is presumably supplied by one of the plotlib
# star imports -- confirm which module defines it.
queries = load_queries(config.AUS_TOPIC_PATH)

In [81]:
def load_vectors(path: str):
    """Read a text-format embedding file into a ``{word: vector}`` dict.

    The first line of the file is always discarded (presumably the fastText
    ``.vec`` header -- confirm for other formats). Each remaining line is
    expected to look like ``word v1 v2 ... vn`` with single-space separators.
    Blank or whitespace-only lines are ignored.

    Args:
        path: Location of the UTF-8 encoded vector file.

    Returns:
        Dict mapping each word to a 1-D ``float32`` numpy array.
    """
    embeddings = {}
    with open(path, 'r', encoding='utf-8') as f:
        f.readline()  # discard the first line unconditionally
        for line in f:
            line = line.rstrip()
            if not line:
                # Without this guard a blank line registers the empty string
                # as a word with a zero-length vector.
                continue
            values = line.split(' ')
            embeddings[values[0]] = np.asarray(values[1:], dtype='float32')

    return embeddings

# Word-level vectors; passed as `embeddings` to embed() in the query loop.
# NOTE(review): an identical load of the same file appears again in a later
# cell, so the file is read twice on a full run.
embeddings = load_vectors('/home/danlocke/fastText/filtered-100d.vec')

class EmbMethod:
    """Integer codes selecting how ``embed`` combines token vectors.

    ``embed`` treats every value greater than ``MEAN`` as a weighted method,
    so the numeric ordering of these constants is significant.
    """
    SUM = 0      # unweighted sum of the token vectors
    MEAN = 1     # unweighted sum divided by the number of input tokens
    CF_SUM = 2   # sum of token vectors, each scaled by its coll_stats weight
    CF_MEAN = 3  # weighted sum divided by the number of input tokens


def embed(tokens: List[str], embeddings, dim: int=100, method: EmbMethod = EmbMethod.SUM, coll_stats = None) -> np.ndarray:
    """Combine the vectors of ``tokens`` into a single vector.

    Tokens missing from ``embeddings`` contribute nothing to the sum.

    Args:
        tokens: Tokens to embed.
        embeddings: Mapping from token to its numpy vector.
        dim: Length of the zero vector returned when no token is found.
        method: One of the ``EmbMethod`` constants.
        coll_stats: Optional mapping from token to a multiplicative weight,
            used by ``CF_SUM`` / ``CF_MEAN``. Tokens without an entry (or all
            tokens when this is ``None``) get a weight of 1.0.

    Returns:
        The combined vector, or a ``float32`` zero vector of length ``dim``
        when none of the tokens has an embedding.
    """
    if method > EmbMethod.MEAN:
        # Previously a weighted method with the default coll_stats=None
        # failed on the membership test; treat None as "no weights known".
        weights = {} if coll_stats is None else coll_stats
        vectors = [
            (weights[tok] if tok in weights else 1.0) * embeddings[tok]
            for tok in tokens
            if tok in embeddings
        ]
    else:
        vectors = [embeddings[tok] for tok in tokens if tok in embeddings]

    if not vectors:
        return np.zeros((dim,), dtype='float32')

    # np.sum allocates a fresh array, so the in-place division below never
    # touches the stored embedding vectors.
    combined = np.sum(vectors, axis=0)
    if method in (EmbMethod.MEAN, EmbMethod.CF_MEAN):
        # The divisor is the number of input tokens, including tokens that
        # had no embedding and therefore added nothing to the sum.
        combined /= len(tokens)

    return combined

In [4]:
# Word-level vectors for the query loop.
# NOTE(review): this repeats the identical load performed in the cell that
# defines load_vectors; one of the two statements is redundant.
embeddings = load_vectors('/home/danlocke/fastText/filtered-100d.vec')

In [21]:
# Phrase-level vectors. Later cells split the keys on '_' (filter_phrases)
# and look up underscore-joined n-grams in this dict.
phrase_embeddings = load_vectors('/home/danlocke/fastText/para-phrase-100d.vec')

In [44]:
# Ordinal and tens number words. filter_phrases drops any phrase containing
# one of these unless the phrase also contains a word from num_exclusion.
find = {
    'first',
    'second',
    'third',
    'fourth',
    'fifth',
    'sixth',
    'seventh',
    'eighth',
    'ninth',
    'tenth',
    'eleventh',
    'twelfth',
    'thirteenth',
    'fourteenth',
    'fifteenth',
    'sixteenth',
    'seventeenth',
    'eighteenth',
    'nineteenth',
    'twenty',
    'thirty',
    'forty',   # correct spelling; previously missing, so it was never matched
    'fourty',  # misspelling kept so phrases containing it are still filtered
}

# Words that exempt a phrase from the number-word rule in filter_phrases:
# a phrase containing a word from `find` is kept if it also contains one of these.
num_exclusion = {'amendment', 'degree', 'refusal'}

# Compass directions; filter_phrases rejects a phrase in which more than one
# token belongs to this set.
find2 = {'north', 'south', 'east', 'west'}

# Words that may not END a phrase: filter_phrases rejects any phrase whose
# last token is in this set.
find3 = {
    # verb forms
    'set', 'hold', 'give', 'giving', 'intending', 'resulting',
    'include', 'includes', 'included', 'including',
    'relies', 'relied',
    # connectives / adverbs / adjectives
    'pursuant', 'thereto', 'behalf', 'apparently', 'subsequently',
    'exemplary', 'jurisdictional', 'such',
    # party names and generic nouns
    'appellant', 'respondent', 'thing', 'subject',
    # token fragments
    's', 'x', 't', 'th', 're',
}

# Words that may not START a phrase: filter_phrases rejects any phrase whose
# first token is in this set.
find4 = {
    # connectives / adverbs / quantifiers
    'nonetheless', 'thereto', 'behalf', 'hastily', 'such', 'only', 'much',
    # verb forms
    'follows', 'referred',
    # party names and generic nouns
    'applicant', 'applicants', 'respondent', 'respondents',
    'instance', 'instances', 'way', 'consideration', 'organisation',
    # token fragments
    's', 'x', 'th', 'll', 're', 'st',
}

# TODO: consider also treating 'made' and 'make' as exclusion words

def filter_phrases(phrases, splitter=' '):
    """Return the phrases that survive the hand-written rejection rules.

    Each phrase is split on ``splitter`` and rejected when any of these hold:

    * it has fewer than two tokens;
    * its last token is a single character or listed in ``find3``;
    * its first token is a single character or listed in ``find4``;
    * it has two tokens and the first starts with ``'claim'`` or is
      ``'intention'``;
    * it has three tokens, starts with ``'intention'`` and the second token
      is not ``'to'``;
    * ``'nunc'`` appears anywhere except as the first token;
    * a token is repeated;
    * more than one token is in ``find2``;
    * a token is in ``find`` and no token is in ``num_exclusion``.

    Args:
        phrases: A dict keyed by phrase, or any other iterable of phrases.
        splitter: Separator between the tokens of a phrase.

    Returns:
        A dict of the surviving ``phrase: value`` pairs when ``phrases`` is a
        dict, otherwise a list of the surviving phrases, in input order.
    """
    def is_rejected(parts):
        size = len(parts)
        if size < 2:
            return True

        head, tail = parts[0], parts[-1]
        if len(tail) == 1 or tail in find3:
            return True
        if len(head) == 1 or head in find4:
            return True
        if size == 2 and (head.startswith('claim') or head == 'intention'):
            return True
        if size == 3 and head == 'intention' and parts[1] != 'to':
            return True
        if 'nunc' in parts[1:]:
            return True
        if len(set(parts)) < size:
            return True
        if sum(1 for part in parts if part in find2) > 1:
            return True

        has_number = any(part in find for part in parts)
        is_exempt = any(part in num_exclusion for part in parts)
        return has_number and not is_exempt

    survivors = (p for p in phrases if not is_rejected(p.split(splitter)))
    if isinstance(phrases, dict):
        return {p: phrases[p] for p in survivors}
    return list(survivors)

In [73]:
def load_coll_stats(path: str, total=None) -> Dict[str, float]:
    """Load per-term counts from a file and turn them into log weights.

    Each non-blank line is whitespace-split; the last field is parsed as the
    term's count and the remaining fields, re-joined with single spaces, form
    the term. The weight stored for a term is
    ``log(total / count + 1.0) + 1.0``.

    Args:
        path: Location of the UTF-8 encoded statistics file.
        total: Numerator used in the weight. When ``None`` it is the sum of
            every count read from the file.

    Returns:
        Dict mapping each term to its weight. If a term occurs on several
        lines, the count from the last occurrence is used.
    """
    counts = {}
    running_total = 0.0
    # Explicit encoding for consistency with load_vectors, instead of the
    # platform default.
    with open(path, encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if not parts:
                # A blank line used to raise IndexError on parts[-1].
                continue
            count = float(parts[-1])
            running_total += count
            counts[' '.join(parts[:-1])] = count

    if total is None:
        total = running_total

    return {term: math.log(total / count + 1.0) + 1.0 for term, count in counts.items()}

In [76]:
# Per-token weights passed as `coll_stats` to embed() in the query loop.
# No `total` is given, so load_coll_stats derives it from the counts in the file.
coll_stats = load_coll_stats('/home/danlocke/phd-generated/filtered-stop-top-tokens.txt')

In [45]:
# Keep only the phrase vectors whose '_'-separated key passes filter_phrases.
# This rebinds phrase_embeddings to the filtered dict, so the unfiltered
# vectors are no longer reachable under this name afterwards.
phrase_embeddings = filter_phrases(phrase_embeddings, '_')

In [None]:
# Inspection only: list the phrases that end in '_date'.
[x for x in phrase_embeddings if x.endswith('_date')]

In [47]:
def n_gram(tokens: List[str], n: int):
    """Build every run of ``n`` consecutive tokens, joined with ``'_'``.

    Args:
        tokens: Token sequence to slide the window over.
        n: Window size.

    Returns:
        List of underscore-joined n-grams in order of their start position;
        empty when ``tokens`` holds fewer than ``n`` items.
    """
    window_starts = range(len(tokens) - n + 1)
    return ['_'.join(tokens[start:start + n]) for start in window_starts]

In [82]:
# Counter used to cap how many topics the preview loop below prints.
cnt = 0


# For each topic: tokenise the text, find which of its 2- and 3-grams have a
# phrase vector, and print how similar each token is to the whole query under
# three embedding methods.
for topic in queries.values(): 
    print(topic['topic'])
    # Lower-case, delete the characters , ? ' / and ’, then split on whitespace.
    tokens = topic['topic'].lower().replace(',', '').replace('?', '').replace('\'', '').replace('/', '').replace('’', '').split()
    # Drop stopwords (stop.stop comes from phdconf).
    tokens = [x for x in tokens if x not in stop.stop]
    # NOTE(review): diff_tokens is not read again anywhere in this cell.
    diff_tokens = copy.copy(tokens)
    # Maps each n-gram that has a phrase vector to
    # (index of its first token in `tokens`, number of tokens in the n-gram).
    keep = {}
    for grams in [n_gram(tokens, 2), n_gram(tokens, 3)]:
        for i, gram in enumerate(grams):
            if gram in phrase_embeddings:
                keep[gram] = (i, len(gram.split('_')))
    for meth in [EmbMethod.SUM, EmbMethod.CF_SUM, EmbMethod.CF_MEAN]:
        # One vector for the whole query and one per individual token,
        # reshaped to (1, -1) because cosine_similarity takes 2-D input.
        qry_vec = embed(tokens, embeddings, method=meth, coll_stats=coll_stats).reshape(1, -1)
        tok_vecs = [embed([tok], embeddings, method=meth, coll_stats=coll_stats).reshape(1, -1) for tok in tokens]
        sims = [(tokens[i], cosine_similarity(tok_vecs[i], qry_vec)[0][0]) for i in range(len(tokens))]
        # Most similar token first.
        print(sorted(sims, key= lambda x: x[1], reverse=True))
    
    print(tokens)
    # Order matches by start position; at the same position the longer n-gram comes first.
    keep = OrderedDict(sorted(keep.items(), key=lambda t: (t[1][0], -t[1][1])))
    print(keep)
    print('-'*40)
    # The check runs before the increment, so the loop stops after the
    # 12th topic has been printed.
    if cnt > 10:
        break
    cnt += 1

What is the effect of reinstating a company that was in liquidation as regards money that may be recovered?
[('company', 0.7679199), ('liquidation', 0.73408693), ('money', 0.69507045), ('recovered', 0.6855011), ('reinstating', 0.61489415), ('may', 0.55273366), ('what', 0.5356588), ('effect', 0.52630407), ('regards', 0.5105081)]
[('company', 0.76757026), ('liquidation', 0.74729365), ('recovered', 0.6891548), ('money', 0.6759849), ('reinstating', 0.65781003), ('may', 0.53240055), ('effect', 0.5117331), ('regards', 0.50749296), ('what', 0.49963984)]
here
here
here
here
here
here
here
here
here
here
[('company', 0.76757026), ('liquidation', 0.74729365), ('recovered', 0.6891548), ('money', 0.6759849), ('reinstating', 0.65781003), ('may', 0.53240055), ('effect', 0.5117331), ('regards', 0.50749296), ('what', 0.49963984)]
['what', 'effect', 'reinstating', 'company', 'liquidation', 'regards', 'money', 'may', 'recovered']
OrderedDict([('reinstating_company', (2, 2)), ('company_liquidation', (3, 