<a href="https://colab.research.google.com/github/menicacci/CE_for_ER/blob/main/Ditto_with_CE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Ditto set up
# Clones and installs Ditto and its extra dependencies, and leaves the working
# directory inside the ditto checkout so later cells can run its scripts directly.

# clone the repo
!git clone https://github.com/megagonlabs/ditto
%cd ditto/
# no-op right after a fresh clone; refreshes the checkout when the cell is re-run on an existing clone
!git pull

# install requirements
!pip install -r requirements.txt

# extra packages installed on top of requirements.txt
!pip install transformers
!pip install tensorboardX

# back to the directory that contains the ditto checkout
%cd ..

# support for the model

# NLTK stopword list — presumably needed by Ditto's text preprocessing; TODO confirm
import nltk
nltk.download('stopwords')

# build NVIDIA apex from source — presumably required by the '--fp16' flag passed to Ditto below; TODO confirm
!git clone https://github.com/NVIDIA/apex
%cd apex
!python setup.py install
%cd ..

# urllib3 is pinned to 1.25.4 while awscli is upgraded
!pip install --upgrade "urllib3==1.25.4" awscli
!pip install jsonlines

# end inside the ditto checkout: the train/test commands below use paths relative to it
%cd ditto

In [2]:
import subprocess
import sys
import re

def run(command):
  """Run ``command`` as a subprocess and return the completed process.

  Args:
    command: Argument list handed to ``subprocess.run``.

  Returns:
    The ``subprocess.CompletedProcess`` with ``stdout`` and ``stderr``
    captured as text.

  Raises:
    SystemExit: With status 1 when the command exits with a non-zero
      return code; the captured stderr is printed first.
  """
  try:
    return subprocess.run(command, capture_output=True, text=True, check=True)
  except subprocess.CalledProcessError as e:
    print(f"ERROR\n\n{e.stderr}")
    # sys.exit() takes no keyword arguments: the previous `sys.exit(file=None)`
    # raised TypeError here instead of stopping with a failure status.
    sys.exit(1)

def get_f1_score(output):
  """Parse the F1 value printed as ``real_f1 = <number>`` in ``output``.

  Returns the first reported value as a float, or ``False`` when the text
  contains no such line.
  """
  found = re.search(r'real_f1 = (\d+\.\d+)', output)
  if not found:
    return False
  return float(found.group(1))

In [3]:
# language model name passed to Ditto through '--lm'; also names the checkpoint folder model/<lm>
lm = 'distilbert'
# task name passed through '--task'; also used as the dataset sub-path for the data files and the checkpoint
dataset = 'Structured/Fodors-Zagats'

In [4]:
# command for training Ditto
# (argument list for run(); paths are relative to the ditto checkout)
training = [
    'python', 'train_ditto.py',
    '--task', dataset,
    '--batch_size', '32',
    '--max_len', '64',
    '--lr', '3e-5',
    '--n_epochs', '20',
    '--finetuning',
    '--lm', lm,
    '--fp16',
    '--save_model',
    # checkpoints are written below this directory; the testing command reads them back from it
    '--logdir', f'model/{lm}/'

]

# command for testing Ditto
# (runs matcher.py on the task's test split with the checkpoint produced by the training command)
testing = [
    'python', 'matcher.py',
    '--task', dataset,
    '--input_path', f"data/er_magellan/{dataset}/test.txt",
    '--output_path', "output/output_small.jsonl",
    '--lm', lm,
    # same sequence length as used for training
    '--max_len', '64',
    '--use_gpu',
    '--fp16',
    '--checkpoint_path', f'model/{lm}'
]

In [5]:
# training without CE
# (baseline: the training set has not been extended with counterfactual examples yet)
standard_training = run(training)

# testing
# parse the F1 printed by matcher.py; get_f1_score returns False when no 'real_f1 = ...' line is found
testing_1 = run(testing)
results_1 = get_f1_score(testing_1.stdout)

In [None]:
# leave the ditto checkout so the next cell clones certa next to it
%cd ..

In [None]:
# Certa set-up
# Clones and installs certa from source, then imports what the rest of the notebook needs.

!git clone https://github.com/tteofili/certa.git
%cd certa
# refreshes the checkout when the cell is re-run on an existing clone
!git pull
# install the certa package from the cloned sources
!pip install .

import pandas as pd
from certa.explain import CertaExplainer
from certa.utils import merge_sources
from certa.models.ditto.ditto import DittoModel

# back to the directory that contains both the ditto and certa checkouts
%cd ..

In [None]:
# Load the source tables and the train/valid/test pair lists for `dataset`
# directly from the hosted DeepMatcher data.

datadir = 'https://pages.cs.wisc.edu/~anhai/data1/deepmatcher_data/' + dataset + '/exp_data'

# one CSV per name, read in this fixed order
lsource, rsource, gt, valid, test = (
    pd.read_csv(datadir + suffix)
    for suffix in ('/tableA.csv', '/tableB.csv', '/train.csv', '/valid.csv', '/test.csv')
)

# expand the id pairs into full left/right records, keeping the 'label' column
valid_df = merge_sources(
    valid, 'ltable_', 'rtable_', lsource, rsource, ['label'], []
)
train_df = merge_sources(
    gt, 'ltable_', 'rtable_', lsource, rsource, ['label'], []
)

In [9]:
# load model
import torch
# The checkpoint is produced by the GPU training run above, while the model
# below is built with device='cpu'. map_location='cpu' remaps the stored tensors
# so the load also works on a runtime without CUDA.
pt_model_dict = torch.load('ditto/model/' + lm + '/' + dataset + '/model.pt', map_location='cpu')

In [10]:
# build the Ditto model for the configured language model on CPU ...
ditto_model = DittoModel(lm=lm, device='cpu')
# ... and restore the fine-tuned weights stored under the checkpoint's 'model' key
ditto_model.load_state_dict(pt_model_dict['model'])

<All keys matched successfully>

In [11]:
# summarizer built from both source tables for the configured language model
from certa.models.ditto.summarize import Summarizer
summarizer = Summarizer(lsource, rsource, lm)

# explainer over the two source tables, with data augmentation disabled
certa_explainer = CertaExplainer(lsource, rsource, data_augmentation=None)

# certa's model wrapper in Ditto mode
# NOTE(review): dk='' presumably disables domain-knowledge injection — confirm against certa
from certa.models.bert import EMTERModel
model = EMTERModel(ditto=True, summarizer=summarizer, dk='')

# plug the fine-tuned Ditto network into the wrapper
model.model = ditto_model

In [12]:
def predict_fn(x):
    """Score the record pairs in ``x`` with the wrapped Ditto model.

    Delegates to ``model.predict`` with ``len=64`` (the same value as the
    '--max_len' used for training and testing).
    """
    return model.predict(x, len=64)

# score the training pairs; the result is later filtered on its 'label' and 'match_score' columns
predictions = predict_fn(train_df)

In [13]:
# convert .csv into Ditto string format

def to_string(r1, r2, l):
    """Serialize two entities and a label into one tab-separated Ditto line.

    Each entity is either an already-serialized string (used as is) or a
    mapping, written as consecutive ``COL <attr> VAL <value> `` chunks.
    The line has the form ``<r1>\\t<r2>\\t<l>\\n``.
    """
    def _serialize(entity):
        # strings are assumed to be pre-formatted and pass through untouched
        if isinstance(entity, str):
            return entity
        return ''.join(
            'COL %s VAL %s ' % (key, entity[key]) for key in entity.keys()
        )

    fields = [_serialize(r1), _serialize(r2), str(l)]
    return '\t'.join(fields) + '\n'

def clean_cols(r):
    """Remove every 'ltable_' and then every 'rtable_' occurrence from ``r``."""
    cleaned = r
    for prefix in ('ltable_', 'rtable_'):
        cleaned = cleaned.replace(prefix, '')
    return cleaned

def get_record(r1, r2, l):
    """Build the Ditto line for a pair and drop the table prefixes from it."""
    serialized = to_string(r1, r2, l)
    return clean_cols(serialized)

In [14]:
# returns a DataFrame containing all wrong predictions from the predictions parameter
def get_wrong_preds(predictions):
    """Select the rows of ``predictions`` where the model disagrees with the label.

    Args:
        predictions: DataFrame with a 0/1 ``label`` column and a
            ``match_score`` column.

    Returns:
        DataFrame with the false positives followed by the false negatives.
        The number of each is printed as a side effect.
    """
    scores = predictions['match_score']
    labels = predictions['label']

    false_positives = predictions.loc[(labels == 0) & (scores > 0.5)]
    # A score of exactly 0.5 is not a match (a pair is a match only when the
    # match score is strictly greater than the no-match score), so a positive
    # pair scored 0.5 is a false negative; `< 0.5` silently dropped it.
    false_negatives = predictions.loc[(labels == 1) & (scores <= 0.5)]

    print('# of false positives: ' + str(len(false_positives)))
    print('# of false negatives: ' + str(len(false_negatives)))

    return pd.concat([false_positives, false_negatives])

In [15]:
# splits a prediction row into the left and right source records it refers to
def get_tuples(rand_row):
    """Return the (left, right) source rows referenced by ``rand_row``.

    The row's 'ltable_id' / 'rtable_id' values are cast to int and used as
    positions into ``lsource`` / ``rsource``.
    """
    # NOTE(review): ids are used positionally (.iloc) — assumes ids equal row positions; confirm for other datasets
    left_record = lsource.iloc[int(rand_row['ltable_id'])]
    right_record = rsource.iloc[int(rand_row['rtable_id'])]
    return left_record, right_record

# returns the counterfactual explanations computed for wrong_pred
def get_counterfactual_explanations(certa_exp, predict_fn, wrong_pred):
    """Ask ``certa_exp`` to explain the pair behind ``wrong_pred``.

    Returns only the third of the five values produced by ``explain``
    (the counterfactual explanations); the others are discarded.
    """
    left_record, right_record = get_tuples(wrong_pred)
    explained = certa_exp.explain(left_record, right_record, predict_fn, num_triangles=10)
    _, _, counterfactuals, _, _ = explained

    return counterfactuals

# splits an explanation record into left tuple, right tuple and predicted label
def get_explanation(ex):
    """Split one explanation row into its two entities and its predicted label.

    Args:
        ex: Series whose index holds 'ltable_*' and 'rtable_*' attributes plus
            one 'match_*' and one 'nomatch_*' score.

    Returns:
        ``(l_tuple, r_tuple, label)`` where the tuples are dicts of the
        prefixed attributes and ``label`` is 1 when the match score is
        strictly greater than the no-match score, else 0.
    """
    l_tuple = ex.filter(regex='^ltable_').to_dict()
    r_tuple = ex.filter(regex='^rtable_').to_dict()

    # Pull the scalar out explicitly: calling float() on a one-element Series
    # is deprecated in pandas.
    match_score = float(ex.filter(regex='^match_').iloc[0])
    nomatch_score = float(ex.filter(regex='^nomatch_').iloc[0])

    label = int(match_score > nomatch_score)
    return l_tuple, r_tuple, label

In [16]:
# returns a string containing new records (counterfactual explanations).
# n is the maximum number of explanations kept per wrong prediction record
def get_records(c_exp, predict_fn, wrong_preds, n):
    """Generate Ditto-formatted counterfactual records for every wrong prediction.

    For each row of ``wrong_preds`` the explainer is queried and at most ``n``
    of its explanations are serialized; each record is also printed.
    Returns all records concatenated into one string.
    """
    collected = []

    for row_pos in range(len(wrong_preds)):
        # counterfactual explanations for the current wrong prediction
        explanations = get_counterfactual_explanations(
            c_exp, predict_fn, wrong_preds.iloc[row_pos]
        )

        keep = min(n, len(explanations))
        for expl_pos in range(keep):
            # left and right tuple as dicts + label (int)
            left, right, label = get_explanation(explanations[:keep].iloc[expl_pos])
            # converts the parameters to a string
            record = get_record(left, right, label)

            print('Record: ' + str(expl_pos))
            print(record)
            # accumulate the record for the final output
            collected.append(record)

    return ''.join(collected)

In [17]:
# misclassified training pairs (false positives followed by false negatives)
wrong_preds = get_wrong_preds(predictions)
wrong_preds.head()

# of false positives: 2
# of false negatives: 0


Unnamed: 0,label,ltable_name,ltable_addr,ltable_city,ltable_phone,ltable_type,ltable_class,rtable_name,rtable_addr,rtable_city,rtable_phone,rtable_type,rtable_class,match_score,nomatch_score,ltable_id,rtable_id
231,0,` restaurant ritz-carlton atlanta ',' 181 peachtree st. ',atlanta,404/659 -0400,continental,91.0,` ritz-carlton cafe ( atlanta ) ',' 181 peachtree st. ',atlanta,404-659-0400,` american ( new ) ',711.0,0.849613,0.150387,91.0,330.0
550,0,` dining room ritz-carlton buckhead ',' 3434 peachtree rd. ',atlanta,404/237 -2700,international,90.0,` ritz-carlton cafe ( buckhead ) ',' 3434 peachtree rd. ne ',atlanta,404-237-2700,` american ( new ) ',89.0,0.634264,0.365736,90.0,307.0


In [18]:
# generate up to 2 counterfactual records per wrong prediction, as Ditto-formatted text
new_records = get_records(certa_explainer, predict_fn, wrong_preds, 2)

  records = records.append([record] * len(source), ignore_index=True)
  records = records.append([record] * len(source), ignore_index=True)


Record: 0
COL name VAL ` il fornaio cucina italiana ' COL addr VAL ' 181 peachtree st. ' COL city VAL atlanta COL phone VAL 404/659 -0400 COL type VAL continental COL class VAL 91 	COL name VAL ` ritz-carlton cafe ( atlanta ) ' COL addr VAL ' 181 peachtree st. ' COL city VAL atlanta COL phone VAL 404-659-0400 COL type VAL ` american ( new ) ' COL class VAL 711 	0

Record: 1
COL name VAL ` restaurant ritz-carlton atlanta ' COL addr VAL ' 11705 national blvd. ' COL city VAL atlanta COL phone VAL 404/659 -0400 COL type VAL continental COL class VAL 91 	COL name VAL ` ritz-carlton cafe ( atlanta ) ' COL addr VAL ' 181 peachtree st. ' COL city VAL atlanta COL phone VAL 404-659-0400 COL type VAL ` american ( new ) ' COL class VAL 711 	0



  records = records.append([record] * len(source), ignore_index=True)
  records = records.append([record] * len(source), ignore_index=True)


Record: 0
COL name VAL ` alain rondelli ' COL addr VAL ' 3434 peachtree rd. ' COL city VAL atlanta COL phone VAL 404/237 -2700 COL type VAL international COL class VAL 90 	COL name VAL ` ritz-carlton cafe ( buckhead ) ' COL addr VAL ' 3434 peachtree rd. ne ' COL city VAL atlanta COL phone VAL 404-237-2700 COL type VAL ` american ( new ) ' COL class VAL 89 	0

Record: 1
COL name VAL ` dining room ritz-carlton buckhead ' COL addr VAL ' 3434 peachtree rd. ' COL city VAL atlanta COL phone VAL 213/658 -6340 COL type VAL international COL class VAL 90 	COL name VAL ` ritz-carlton cafe ( buckhead ) ' COL addr VAL ' 3434 peachtree rd. ne ' COL city VAL atlanta COL phone VAL 404-237-2700 COL type VAL ` american ( new ) ' COL class VAL 89 	0



In [19]:
# adding the CE to the training set
training_dir = 'ditto/data/er_magellan/' + dataset + '/train.txt'

# NOTE: appending is not idempotent — re-running this cell adds the same records again.
# The context manager closes (and flushes) the file even if the write fails;
# the former seek(0, 0) on the append-mode handle had no effect and was dropped.
with open(training_dir, "a", encoding="utf8") as training_set:
    training_set.write(new_records)

In [None]:
# back into the ditto checkout: the train/test commands use paths relative to it
%cd ditto

In [21]:
# training with CE
# (same command as the baseline; train.txt now also contains the appended counterfactual records)
ce_training = run(training)

# testing
# parse the F1 printed by matcher.py; get_f1_score returns False when no 'real_f1 = ...' line is found
testing_2 = run(testing)
results_2 = get_f1_score(testing_2.stdout)

In [22]:
# side-by-side F1 of the baseline run and the run retrained with counterfactual examples, rounded to 3 decimals
print(f"F1 score with standard training: \t\t{round(results_1, 3)}")
print(f"F1 score with counterfactual explanations: \t{round(results_2, 3)}")

F1 score with standard training: 		0.955
F1 score with counterfactual explanations: 	0.978
