In [99]:
import numpy as np
import pandas as pd
import textdistance
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score
import fuzzywuzzy
from fuzzywuzzy import fuzz
from collections import defaultdict

In [140]:
handcoded = pd.read_csv('handcoded.csv')
handcoded = handcoded.drop(['Unnamed: 0'], axis=1)
handcoded['amicus'] = handcoded['amicus'].apply(lambda x: x.lower())
handcoded['bonica'] = handcoded['bonica'].apply(lambda x: x.lower())

In [141]:
pd.set_option('display.max_rows', 1000)

In [164]:
def match_name(list1, list2):
    out = []
    for i in list1:
        for j in list2:
            score = fuzz.ratio(i,j)
            if score > 50:
                out.append([i,j,score])
    return out

In [134]:
# %%
# Vector of amicus strings concatenated with corresponding bonica strings
handcoded_vec = handcoded['amicus'].map(str) + '_' + handcoded['bonica']

In [144]:
# %%
# Create a set of incorrect matches
# First, copy the correct matches
tmp = handcoded.copy()
# Shuffle the amicus column - makes most of them mismatched
tmp['amicus'] = np.random.permutation(tmp['amicus'].values)
# For any that might still be correct matches, filter them out
# by making sure the concatenated string isn't in the vector of
# correct concatenated strings (handcoded_vec)
tmp_vec = tmp['amicus'].map(str) + '_' + tmp['bonica']
tmp = tmp[~tmp_vec.map(lambda x: handcoded_vec.str.contains(x).any())]
tmp['match'] = 0

In [145]:
# %%
# Get one more batch of incorrect matches
tmp2 = handcoded.copy()
tmp2['amicus'] = np.random.permutation(tmp2['amicus'].values)
tmp2_vec = tmp2['amicus'].map(str) + '_' + tmp2['bonica']
tmp2 = tmp2[~tmp2_vec.map(lambda x: handcoded_vec.str.contains(x).any())]
tmp2['match'] = 0

In [146]:
# %%
print(tmp.shape)
print(tmp2.shape)
print(handcoded.shape)

(228, 3)
(229, 3)
(231, 3)


In [147]:
# %%
# Concatenate the incorrect ones, drop duplicates, and concatenate with the correct ones
tmp_full = pd.concat([tmp, tmp2])
tmp_full.drop_duplicates(inplace=True)
train = pd.concat([handcoded, tmp_full])
train['amicus'] = train['amicus'].str.lower()
train['bonica'] = train['bonica'].str.lower()

In [165]:
train_amicus = sorted(list(set(train['amicus'])))
train_bonica = sorted(list(set(train['bonica'])))

In [167]:
testout = match_name(train_amicus, train_bonica)

In [168]:
testout

[['american association of retired persons',
  'american association of retired persons',
  100],
 ['american association of retired persons',
  'american association of university professors',
  71],
 ['american association of retired persons',
  'american federation of gov emp',
  58],
 ['american association of retired persons',
  'american federation of labor & c',
  56],
 ['american association of retired persons',
  'american federation of tv & radio artists',
  57],
 ['american association of retired persons',
  'american institute of certified public accountants',
  54],
 ['american association of retired persons',
  'american psychiatric association',
  56],
 ['american association of retired persons',
  'american psychiatric association political action committee',
  55],
 ['american association of retired persons',
  'american public power association public ownership',
  56],
 ['american association of retired persons',
  'association of american publishers',
  60],
 ['amer

In [12]:
# %%
# Add more distance metrics?
methods = [textdistance.cosine, textdistance.jaccard]
def stringdist_wrap(row):
    a, b = row[['amicus', 'bonica']]
    out = pd.Series([m.distance(a, b) for m in methods])
    return out

In [13]:
# %%
df = train.apply(stringdist_wrap, axis=1)
df.columns = ['cosine', 'jaccard']

In [15]:
# %%
#the actual model
parameters = {'max_depth': [2,3,4]}

GSCV = GridSearchCV(cv = 5,
                   estimator = RandomForestClassifier(),
                   param_grid = parameters)

model = GSCV.fit(df, train['match'])



In [16]:
preds = model.predict(df)
labels = train['match']

In [17]:
accuracy_score(preds, labels)

0.8330434782608696