### This notebook loads the proxy model trained with IIT and evaluates it as a concept-based explanation method on the CEBaB dataset.
This notebook largely follows the corresponding one in our OpenTable repo.

In [1]:
import json
import os 
import pandas as pd
from collections import defaultdict
import numpy as np
import random
from tqdm import tqdm

import torch
from transformers import (
    AutoModelForSequenceClassification, 
    AutoTokenizer, 
)
from math import ceil

import sklearn
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report

from models.modelings_roberta import *
from models.modelings_bert import *

### 1) Get data

In [2]:
def load_split_new_scheme(splitname, data_dir="../OpenTable/mturk/dataset"):
    """
    Load one split of the 2022-05-02 dataset release from disk.

    Args:
        splitname: split identifier that is interpolated into the file name
            (this notebook uses "train", "dev" and "test").
        data_dir: directory that holds the `dataset-2022-05-02-<split>.json`
            files. Defaults to the relative location used by this notebook.

    Returns:
        The parsed JSON content of the split file.

    Raises:
        FileNotFoundError: if the split file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    filename = os.path.join(data_dir, f"dataset-2022-05-02-{splitname}.json")
    # JSON is UTF-8 by specification; do not rely on the platform default encoding.
    with open(filename, encoding="utf-8") as f:
        return json.load(f)

In [3]:
# Read the three dataset splits and wrap each one in a DataFrame.
train, dev, test = (
    pd.DataFrame(load_split_new_scheme(split_name))
    for split_name in ("train", "dev", "test")
)

In [4]:
# Drop every training example whose review-level label is 'no majority', for simplicity.
# NOTE: this is a large part of the dataset, we should deal with this!
# The printed number is the fraction of training examples that gets dropped.
print((train['review_majority'] == 'no majority').sum() / len(train))
train = train.loc[train['review_majority'] != 'no majority']

0.16030013642564803


### 2) Create all intervention pairs in the dev/test set

In [5]:
def get_pairs_per_original(df):
    """
    For a df containing all examples related to one original,
    create and return all the possible intervention pairs.

    Three kinds of pairs are produced, each with `_base` / `_counterfactual`
    suffixed copies of the input columns:
      * (edit, original)
      * (original, edit)
      * (edit, edit) for two different edits that share the same `edit_type`

    An original without any edits yields an empty pairs dataframe (instead of
    failing inside `pd.concat`).

    Args:
        df: dataframe with (at least) the columns `id`, `original_id`,
            `is_original` and `edit_type`; all rows must share one `original_id`.

    Returns:
        The pairs dataframe, annotated by `_get_intervention_type_and_direction`.
    """
    assert len(df.original_id.unique()) == 1, \
        "get_pairs_per_original expects the examples of exactly one original"

    is_original = df['is_original']
    df_edit = df[~is_original].reset_index(drop=True)

    # Repeat the original row(s) once per edit so both frames line up row by row.
    # Positional indexing (rather than concatenating a list of copies) also works
    # when there are zero edits: it then simply yields an empty frame.
    originals = df[is_original]
    repeat_positions = np.tile(np.arange(len(originals)), len(df_edit))
    df_original = originals.iloc[repeat_positions].reset_index(drop=True)

    assert len(df_edit) == len(df_original), \
        "expected exactly one original example per original_id"

    # (edit, original) pairs
    df_edit_base = df_edit.rename(columns=lambda x: x + '_base')
    df_original_counterfactual = df_original.rename(columns=lambda x: x + '_counterfactual')

    edit_original_pairs = pd.concat([df_edit_base, df_original_counterfactual], axis=1)

    # (original, edit) pairs
    df_edit_counterfactual = df_edit.rename(columns=lambda x: x + '_counterfactual')
    df_original_edit = df_original.rename(columns=lambda x: x + '_base')

    original_edit_pairs = pd.concat([df_original_edit, df_edit_counterfactual], axis=1)

    # (edit, edit) pairs

    # The edits are joined based on their edit type.
    # Actually, the 'edit_type' can also differ from the edit performed, but there is no clean way of resolving this.
    edit_edit_pairs = df_edit.merge(df_edit, on='edit_type', how='inner', suffixes=('_base', '_counterfactual'))
    # the self-join also pairs every edit with itself; drop those rows
    edit_edit_pairs = edit_edit_pairs[edit_edit_pairs['id_base'] != edit_edit_pairs['id_counterfactual']]
    # the join key is kept unsuffixed by merge; restore the two suffixed columns
    edit_edit_pairs = edit_edit_pairs.rename(columns={'edit_type': 'edit_type_base'})
    edit_edit_pairs['edit_type_counterfactual'] = edit_edit_pairs['edit_type_base']

    # get all pairs
    pairs = pd.concat([edit_original_pairs, original_edit_pairs, edit_edit_pairs]).reset_index(drop=True)

    # annotate pairs with the intervention type and the direction (calculated from the validated labels)
    pairs = _get_intervention_type_and_direction(pairs)

    return pairs

def _drop_unsuccesful_edits(pairs):
    """
    Drop edits that produce no measured aspect change.
    """
    # Make sure the validated labels of the edited aspects are different.
    # We can not do this comparison based on 'edit_goal_*' because the final label might differ from the goal.
    meaningless_edits = pairs['intervention_aspect_base'] == pairs['intervention_aspect_counterfactual'] 
    print(f'Dropeed {sum(meaningless_edits)} pairs that produced no validated label change. This is due to faulty edits by the workers or edits with the same edit_goal.')
    pairs = pairs[~meaningless_edits]

    return pairs

def _get_intervention_type_and_direction(pairs):
    """
    Annotate a dataframe of pairs with their invention type 
    and the validated label of that type for base and counterfactual.
    """
    # get intervention type
    pairs['intervention_type'] = np.maximum(pairs['edit_type_base'].astype(str), pairs['edit_type_counterfactual'].astype(str))

    # get base/counterfactual value of the intervention aspect
    pairs['intervention_aspect_base'] = \
        ((pairs['intervention_type'] == 'ambiance') * pairs['ambiance_aspect_majority_base']) +\
        ((pairs['intervention_type'] == 'noise') * pairs['noise_aspect_majority_base']) +\
        ((pairs['intervention_type'] == 'service') * pairs['service_aspect_majority_base']) +\
        ((pairs['intervention_type'] == 'food') * pairs['food_aspect_majority_base'])

    pairs['intervention_aspect_counterfactual'] = \
        ((pairs['intervention_type'] == 'ambiance') * pairs['ambiance_aspect_majority_counterfactual']) +\
        ((pairs['intervention_type'] == 'noise') * pairs['noise_aspect_majority_counterfactual']) +\
        ((pairs['intervention_type'] == 'service') * pairs['service_aspect_majority_counterfactual']) +\
        ((pairs['intervention_type'] == 'food') * pairs['food_aspect_majority_counterfactual'])
    
    return pairs


def _int_to_onehot(series, range):
    """
    Encode a series of ints as a series of onehot vectors.
    Assumes the series of ints is contained within the range.
    """
    offset = range[0]
    range = max(range) - min(range) + 1
    
    def _get_onehot(x):
        zeros = np.zeros(range)
        zeros[int(x) - offset] = 1.0
        return zeros

    return series.apply(_get_onehot)

def _pairs_to_onehot(pairs):
    """
    Replace the review majority labels (1 to 5) of base and counterfactual
    by their onehot vectors. The columns are overwritten on `pairs` itself.
    """
    for column in ('review_majority_counterfactual', 'review_majority_base'):
        pairs[column] = _int_to_onehot(pairs[column], range(1, 6))

    return pairs


def get_pairs(df):
    """
    Given a dataframe in the new data scheme, return all intervention pairs.

    Steps: keep only the relevant columns, build the pairs per original,
    drop pairs whose intervened aspect label did not change, drop pairs that
    involve a 'no majority' label, and onehot encode the review labels.

    Args:
        df: dataframe in the new data scheme; columns whose name contains
            'prediction' are carried over to the pairs.

    Returns:
        A dataframe with one row per intervention pair.

    Raises:
        ValueError: if `df` contains no examples at all.
    """
    # Drop label distribution and worker information.
    columns_to_keep = ['id', 'original_id', 'edit_id', 'is_original', 'edit_goal', 'edit_type', 'description', 'review_majority','food_aspect_majority', 'ambiance_aspect_majority', 'service_aspect_majority', 'noise_aspect_majority']
    columns_to_keep += [col for col in df.columns if 'prediction' in col]
    df = df[columns_to_keep]

    # get all the intervention pairs
    unique_originals = df.original_id.unique()
    to_merge = []
    for unique_id in unique_originals:
        df_slice = df[df['original_id'] == unique_id]
        pairs_slice = get_pairs_per_original(df_slice)
        to_merge.append(pairs_slice)
    if not to_merge:
        # pd.concat would fail with an unrelated-looking message on an empty list
        raise ValueError('get_pairs received a dataframe without any examples.')
    pairs = pd.concat(to_merge)

    # drop unsuccesful edits
    pairs = _drop_unsuccesful_edits(pairs)

    # remove all examples where the intervention aspect or the review label is 'no majority'
    # for base or counterfactual, since we cant measure a change here with certainty
    intervention_aspect_nomaj = (pairs['intervention_aspect_counterfactual'] == 'no majority') | (pairs['intervention_aspect_base'] == 'no majority')
    review_nomaj = (pairs['review_majority_counterfactual'] == 'no majority') | (pairs['review_majority_base'] == 'no majority')
    nomaj = intervention_aspect_nomaj | review_nomaj
    # take an explicit copy: the onehot step below assigns columns, which must not
    # target a filtered slice of another frame
    pairs = pairs[~nomaj].copy()
    print(f'Dropped {sum(nomaj)} examples with no_majority.')

    # onehot encode
    pairs = _pairs_to_onehot(pairs)

    return pairs

### 3) model loader

In [6]:
class Transformer_CEBaB():
    """
    Thin wrapper around a sequence classification checkpoint that exposes the
    `fit` / `predict_proba` interface used by `CEBaB_pipeline`.
    """

    def __init__(self, model_path, device='cpu', batch_size=64):
        """
        Args:
            model_path: checkpoint passed to `from_pretrained`.
            device: torch device the model is moved to.
            batch_size: number of examples per forward pass in `predict_proba`.
        """
        self.device = device
        self.model_path = model_path
        self.batch_size = batch_size

        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_path)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        except Exception as error:
            # Best-effort fallback, but no longer silent and no longer catching
            # KeyboardInterrupt/SystemExit like a bare `except:` would.
            print(f"Could not load a tokenizer from {self.model_path!r} ({error}); falling back to 'bert-base-uncased'.")
            self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.model.to(device)

    def preprocess(self, df):
        """
        Tokenize the `description` column and shift the `review_majority`
        labels by -1.

        Returns:
            (x, y): the tokenizer output and the shifted integer labels.
        """
        x = self.tokenizer(df['description'].to_list(), padding=True, truncation=True, return_tensors='pt')
        y = df['review_majority'].astype(int) - 1

        return x, y

    def fit(self, dataset):
        """No-op: the model is assumed to be trained already."""
        pass

    def predict_proba(self, dataset):
        """
        Predict class probabilities for every example in `dataset`.

        Returns:
            (probas, clf_report): an array with one row of probabilities
            (rounded to 4 decimals) per example, and the sklearn classification
            report (as a dict) of the argmax predictions against the labels.
        """
        self.model.eval()

        x, y = self.preprocess(dataset)

        # get the predictions batch per batch
        probas = []
        # inference only: do not record autograd graphs for the forward passes
        with torch.no_grad():
            for i in range(ceil(len(dataset)/self.batch_size)):
                batch = slice(i*self.batch_size, (i+1)*self.batch_size)
                x_batch = {k: v[batch].to(self.device) for k, v in x.items()}
                logits = self.model(**x_batch).logits.cpu()
                # normalise over the class dimension explicitly (implicit dim is deprecated)
                probas.append(torch.nn.functional.softmax(logits, dim=-1).detach())

        probas = torch.cat(probas)
        probas = np.round(probas.numpy(), decimals=4)

        predictions = np.argmax(probas, axis=1)
        clf_report = classification_report(y.to_numpy(), predictions, output_dict=True)

        return probas, clf_report

### 5) explainer loader

In [17]:
class ProxyIIT():
    """
    Explainer that estimates the effect of a concept intervention with a proxy
    model trained with IIT: for every pair, representations of sampled source
    examples (which carry the target concept label) are interchanged into the
    run on the base description, and the resulting output changes are averaged.
    """

    # Value passed as intervention correspondence for each intervention type.
    INTERVENTION_TYPE_TO_INDEX = {"ambiance": 0, "food": 1, "noise": 2, "service": 3}

    def __init__(self, model_path, device, batch_size):
        """
        Args:
            model_path: checkpoint path; must contain "roberta" or "bert"
                (case-insensitive), which selects the model class and tokenizer.
            device: torch device the model is moved to.
            batch_size: number of (base, source) pairs per forward pass.

        Raises:
            ValueError: if the architecture cannot be inferred from `model_path`.
        """
        self.batch_size = batch_size
        self.device = device
        # "roberta" has to be tested first because it contains "bert"
        if "roberta" in model_path.lower():
            print("loading roberta model")
            model = IITRobertaForSequenceClassification.from_pretrained(
                model_path,
                cache_dir="../huggingface_cache/"
            )
            self.tokenizer = AutoTokenizer.from_pretrained("roberta-base")
        elif "bert" in model_path.lower():
            print("loading bert model")
            model = IITBERTForSequenceClassification.from_pretrained(
                model_path,
                cache_dir="../huggingface_cache/"
            )
            self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        else:
            raise ValueError(
                f"Cannot infer the architecture from model_path={model_path!r}; "
                "expected it to contain 'roberta' or 'bert'."
            )
        model.to(device)
        self.model = InterventionableIITTransformerForSequenceClassification(
            model=model
        )

    def fit(self, pairs):
        """No-op: we don't need to train IIT here."""
        pass

    def preprocess(self, pairs_dataset, dev_dataset, min_iit_pair_examples=10):
        """
        Sample source examples for every pair and tokenize everything.

        For each row of `pairs_dataset`, `min_iit_pair_examples` distinct examples
        are sampled from `dev_dataset` whose label for the intervened aspect equals
        the counterfactual label of the pair and which stem from another original.

        Args:
            pairs_dataset: intervention pairs (as returned by `get_pairs`).
            dev_dataset: dataframe the source examples are sampled from.
            min_iit_pair_examples: number of source examples sampled per pair.

        Returns:
            (base_x, source_x, intervention_corr, iit_pairs_dataset): tokenized
            base descriptions, tokenized source descriptions, a long tensor with
            the intervention index per row, and the dataframe describing the
            sampled (base, source) rows. `iit_id` is the position of the pair in
            `pairs_dataset`.

        Raises:
            ValueError: for an unknown intervention type, or when fewer than
                `min_iit_pair_examples` source examples match a pair.
        """
        query_dataset = get_iit_examples(dev_dataset)
        # id -> description lookup, built once; the first row wins for repeated ids
        description_by_id = (
            query_dataset.drop_duplicates(subset="id", keep="first")
            .set_index("id")["description"]
            .to_dict()
        )

        iit_rows = []
        for iit_id, (_, row) in enumerate(pairs_dataset.iterrows()):
            query_description_base = row['description_base']
            query_int_type = row['intervention_type']
            query_int_aspect_base = row["intervention_aspect_base"]
            query_int_aspect_assignment = row['intervention_aspect_counterfactual']
            query_original_id = row["original_id_base"]

            if query_int_type not in self.INTERVENTION_TYPE_TO_INDEX:
                # failing here keeps rows and intervention indices aligned
                raise ValueError(f"Unknown intervention type {query_int_type!r}.")

            matched_iit_examples = query_dataset[
                (query_dataset[f"{query_int_type}_aspect_majority"]==query_int_aspect_assignment)&
                (query_dataset["original_id"]!=query_original_id)
            ]
            # random.sample needs a sequence (sets are rejected on Python >= 3.11);
            # a list in dataframe order also makes the draw reproducible under a seed
            candidate_ids = matched_iit_examples["id"].drop_duplicates().tolist()
            if len(candidate_ids) < min_iit_pair_examples:
                raise ValueError(
                    f"Only {len(candidate_ids)} source examples with "
                    f"{query_int_type}={query_int_aspect_assignment!r} are available, "
                    f"but {min_iit_pair_examples} are required."
                )
            sampled_iit_example_ids = random.sample(candidate_ids, min_iit_pair_examples)
            for _id in sampled_iit_example_ids:
                iit_rows.append([
                    iit_id,
                    query_int_type, query_int_aspect_base,
                    query_int_aspect_assignment,
                    query_description_base, description_by_id[_id]
                ])

        iit_pairs_dataset = pd.DataFrame(
            columns=[
                'iit_id',
                'intervention_type',
                'intervention_aspect_base',
                'intervention_aspect_counterfactual',
                'description_base',
                'description_iit'],
            data=iit_rows
        )

        base_x = self.tokenizer(
            iit_pairs_dataset['description_base'].to_list(),
            padding=True, truncation=True, return_tensors='pt'
        )
        source_x = self.tokenizer(
            iit_pairs_dataset['description_iit'].to_list(),
            padding=True, truncation=True, return_tensors='pt'
        )
        intervention_corr = torch.tensor([
            self.INTERVENTION_TYPE_TO_INDEX[_type]
            for _type in iit_pairs_dataset["intervention_type"].tolist()
        ]).long()
        return base_x, source_x, intervention_corr, iit_pairs_dataset

    def predict(self, pairs, df):
        """
        Estimate the effect of every pair in `pairs`.

        Args:
            pairs: intervention pairs (as returned by `get_pairs`).
            df: dataframe the source examples are sampled from.

        Returns:
            A list with one entry per pair (ordered by `iit_id`): the mean over
            the sampled source examples of
            softmax(counterfactual logits) - softmax(base logits),
            where the individual differences are rounded to 4 decimals.
        """
        ProxyIIT_iTEs = []
        self.model.model.eval()
        base_x, source_x, intervention_corr, iit_pairs_dataset = self.preprocess(
            pairs, df
        )
        with torch.no_grad():
            for i in tqdm(range(ceil(len(iit_pairs_dataset)/self.batch_size))):
                batch = slice(i*self.batch_size, (i+1)*self.batch_size)
                base_x_batch = {k: v[batch].to(self.device) for k, v in base_x.items()}
                source_x_batch = {k: v[batch].to(self.device) for k, v in source_x.items()}
                intervention_corr_batch = intervention_corr[batch].to(self.device)

                base_outputs, _, counterfactual_outputs = self.model.forward(
                    base=(base_x_batch['input_ids'], base_x_batch['attention_mask']),
                    source=(source_x_batch['input_ids'], source_x_batch['attention_mask']),
                    base_intervention_corr=intervention_corr_batch,
                    source_intervention_corr=intervention_corr_batch,
                )
                base_outputs = torch.nn.functional.softmax(base_outputs["logits"][0].cpu(), dim=-1).detach()
                counterfactual_outputs = torch.nn.functional.softmax(counterfactual_outputs["logits"][0].cpu(), dim=-1).detach()
                ProxyIIT_iTE = counterfactual_outputs-base_outputs
                ProxyIIT_iTEs.append(ProxyIIT_iTE)
        ProxyIIT_iTEs = torch.cat(ProxyIIT_iTEs)
        ProxyIIT_iTEs = np.round(ProxyIIT_iTEs.numpy(), decimals=4)

        # only for iit explainer!
        # average the estimates of all source examples that belong to the same pair
        iit_pairs_dataset["EiCaCE"] = list(ProxyIIT_iTEs)
        ProxyIIT_iTEs = list(iit_pairs_dataset.groupby(["iit_id"])["EiCaCE"].mean())

        return ProxyIIT_iTEs
    

### 4) pipeline loader

In [18]:
def get_iit_examples(df):
    """
    Given a dataframe in the new data scheme, return it restricted to the
    identifier, edit, description and majority label columns, plus every
    column whose name contains 'prediction'.
    """
    # Label distributions and worker information are not part of this list.
    base_columns = [
        'id', 'original_id', 'edit_id', 'is_original', 'edit_goal', 'edit_type',
        'description', 'review_majority',
        'food_aspect_majority', 'ambiance_aspect_majority',
        'service_aspect_majority', 'noise_aspect_majority',
    ]
    prediction_columns = [name for name in df.columns if 'prediction' in name]
    return df[base_columns + prediction_columns]

def CEBaB_pipeline(model, explainer, train_dataset, dev_dataset):
    """
    Run the full evaluation: predict with `model`, explain the dev pairs with
    `explainer`, and compute the aggregated and the CEBaB metrics.

    Both datasets receive a 'prediction' column as a side effect.

    Returns:
        (aggregated_metrics, CEBaB_metrics); the latter also carries the
        'accuracy' and 'macro_f1' of the model on the dev set.
    """
    # model predictions on train and dev (only the dev report is used)
    train_probas, _ = model.predict_proba(train_dataset)
    dev_probas, dev_clf_report = model.predict_proba(dev_dataset)

    # attach one probability vector per example
    for dataset, probas in ((train_dataset, train_probas), (dev_dataset, dev_probas)):
        dataset['prediction'] = list(probas)

    explainer.fit(train_dataset)

    # build the dev pairs and attach the explainer estimate to each of them
    dev_pairs = get_pairs(dev_dataset)
    estimated_effects = explainer.predict(dev_pairs, dev_dataset)
    dev_pairs['EiCaCE'] = estimated_effects

    aggregated_metrics = get_aggregated_metrics(dev_pairs)
    CEBaB_metrics = get_CEBaB_metrics(aggregated_metrics)

    # report the model performance next to the CEBaB metrics
    model_scores = (
        round(dev_clf_report['accuracy'], 4),
        round(dev_clf_report['macro avg']['f1-score'], 4),
    )
    CEBaB_metrics['accuracy'], CEBaB_metrics['macro_f1'] = model_scores

    return aggregated_metrics, CEBaB_metrics

### 6) Putting everything together

In [19]:
def calculate_iTE(pairs):
    """
    Add the 'iTE' column: the counterfactual review label minus the base
    review label, computed row by row.
    """
    counterfactual_label = pairs['review_majority_counterfactual']
    base_label = pairs['review_majority_base']
    pairs['iTE'] = counterfactual_label - base_label
    return pairs

def calculate_iCaCE(pairs):
    """
    Add the 'iCaCE' column: the counterfactual prediction minus the base
    prediction, rounded to 4 decimals.
    """
    def _round4(difference):
        return np.round(difference, decimals=4)

    prediction_change = pairs['prediction_counterfactual'] - pairs['prediction_base']
    pairs['iCaCE'] = prediction_change.apply(_round4)
    return pairs


def calculate_clf_loss(pairs):
    """
    Add the '(iTE,iCaCE) loss' column: the L2 norm of iCaCE - iTE per pair.
    """
    gap = pairs['iCaCE'] - pairs['iTE']
    pairs['(iTE,iCaCE) loss'] = gap
    pairs['(iTE,iCaCE) loss'] = gap.apply(lambda vector: np.linalg.norm(vector, ord=2))
    return pairs


def calculate_estimate_loss(pairs):
    """
    Add the '(iCaCE,EiCaCE) loss' column: the L2 norm of iCaCE - EiCaCE per pair.
    """
    gap = pairs['iCaCE'] - pairs['EiCaCE']
    pairs['(iCaCE,EiCaCE) loss'] = gap
    pairs['(iCaCE,EiCaCE) loss'] = gap.apply(lambda vector: np.linalg.norm(vector, ord=2))
    return pairs


def aggregate_metrics(pairs, metrics):
    """
    Average the given metric columns within every
    (intervention_type, intervention_aspect_base, intervention_aspect_counterfactual) group.

    Args:
        pairs: dataframe holding the three grouping columns and every column named in `metrics`.
        metrics: list of column names to average.

    Returns:
        A dataframe indexed by the three grouping columns, with a 'count' column
        and one 'aggregated <metric>' column per metric (rounded to 4 decimals).
    """
    # per-group mean of every metric; the result has two-level columns (metric, 'mean')
    pairs_grouped = pairs.groupby(['intervention_type', 'intervention_aspect_base', 'intervention_aspect_counterfactual']).agg({metric: ['mean'] for metric in metrics})
    # group sizes, taken from the non-null count of the first metric column
    pairs_count = pairs.groupby(['intervention_type', 'intervention_aspect_base', 'intervention_aspect_counterfactual']).count()
    pairs_grouped['count'] = pairs_count[metrics[0]]

    # round
    for metric in metrics:
        # store the rounded mean under a new name, then drop the original (metric, 'mean') column
        pairs_grouped[f'aggregated {metric}'] = pairs_grouped[metric]['mean'].apply(lambda x: np.round(x, decimals=4))
        pairs_grouped = pairs_grouped.drop(columns=metric)
    
    return pairs_grouped    

def get_aggregated_metrics(pairs):
    """
    Compute the per-pair metrics on a copy of `pairs`, average them per
    intervention group, and add the losses between the averaged effects.
    """
    metric_names = ['iTE','iCaCE', 'EiCaCE', '(iTE,iCaCE) loss', '(iCaCE,EiCaCE) loss']
    context_columns = ['intervention_type', 'intervention_aspect_base', 'intervention_aspect_counterfactual', 'description_base',  'description_counterfactual']

    def _rounded_norm(difference):
        # L2 norm of every entry, rounded to 4 decimals
        norms = difference.apply(lambda x: np.linalg.norm(x, ord=2))
        return norms.apply(lambda x: np.round(x, decimals=4))

    # per-pair metrics; the input dataframe is left untouched
    enriched = pairs.copy()
    for add_metric in (calculate_iTE, calculate_iCaCE, calculate_clf_loss, calculate_estimate_loss):
        enriched = add_metric(enriched)

    # keep only what the aggregation needs
    enriched = enriched[context_columns + metric_names]

    # average per group, then compare the averaged effects with each other
    aggregated = aggregate_metrics(enriched, metrics = metric_names)
    aggregated['loss of aggregated (iTE,iCaCE)'] = _rounded_norm(
        aggregated['aggregated iTE'] - aggregated['aggregated iCaCE']
    )
    aggregated['loss of aggregated (iCaCE,EiCaCE)'] = _rounded_norm(
        aggregated['aggregated iCaCE'] - aggregated['aggregated EiCaCE']
    )

    return aggregated

def get_CEBaB_metrics(aggregated_metrics):
    """
    Average every loss column of the aggregated metrics over all groups.

    A column counts as a loss column when the first element of its label
    contains 'loss'.
    """
    loss_columns = list(filter(lambda col: 'loss' in col[0], aggregated_metrics.columns))
    return aggregated_metrics[loss_columns].mean()

In [20]:
model_path = './training_result/bert.aligned.5-class/'
# The reported run used GPU 8. Fall back to the first GPU (or the CPU) when this
# machine does not have that device, instead of failing when the model is moved.
preferred_gpu = 8
if torch.cuda.is_available():
    device = f'cuda:{preferred_gpu}' if torch.cuda.device_count() > preferred_gpu else 'cuda:0'
else:
    device = 'cpu'
batch_size=64

# The explainer samples its source examples with `random`; seed every generator
# so that a Restart & Run All reproduces the same explanations.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# work on copies: the pipeline adds a 'prediction' column to the datasets it receives
train_dataset = train.copy()
dev_dataset = dev.copy()

model = Transformer_CEBaB(model_path, device=device, batch_size=batch_size)
explainer = ProxyIIT(
    model_path, 
    device=device, 
    batch_size=batch_size
)

tf_aggregated_metrics, tf_CEBaB_metrics = CEBaB_pipeline(
    model, explainer, train_dataset, dev_dataset
)

Some weights of the model checkpoint at ./training_result/bert.aligned.5-class/ were not used when initializing BertForSequenceClassification: ['multitask_classifier.out_proj.bias', 'multitask_classifier.out_proj.weight', 'multitask_classifier.dense.bias', 'multitask_classifier.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


loading bert model




Dropeed 262 pairs that produced no validated label change. This is due to faulty edits by the workers or edits with the same edit_goal.
Dropped 0 examples with no_majority.


100%|██████████| 610/610 [03:13<00:00,  3.16it/s]


In [22]:
# Show the summary: the mean of every loss column over all intervention groups,
# plus the accuracy and macro-F1 of the model on the dev set.
tf_CEBaB_metrics

aggregated (iTE,iCaCE) loss            0.557525
aggregated (iCaCE,EiCaCE) loss         0.504625
loss of aggregated (iTE,iCaCE)         0.064267
loss of aggregated (iCaCE,EiCaCE)      0.141125
accuracy                               0.783600
macro_f1                               0.771500
dtype: float64