# Analysing predictions of XNLI

In [1]:
import os
from collections import Counter

import pandas as pd
from sklearn.metrics import confusion_matrix
# Root directory that contains the `results/` and `dev_sets/` folders read by this notebook.
# Set the XNLI_ANALYSIS_PATH environment variable to point at another checkout; the original
# location is kept as the fallback so existing runs behave exactly as before.
PATH = os.environ.get(
    "XNLI_ANALYSIS_PATH",
    "C:/Users/bdolicki/Documents/Git/multilingual-analysis/code/analysing_predictions/xnli",
)


In [2]:
# Show complete cell contents and every row when a DataFrame is displayed.
# `None` means "no limit". The former sentinel -1 for max_colwidth has been
# deprecated since pandas 1.0 and is rejected by current releases.
pd.set_option("display.max_colwidth", None)
pd.set_option("display.max_rows", None)

  """Entry point for launching an IPython kernel.


In [9]:
# Language codes available for the XNLI analysis (15 entries).
all_langs = [
    "ar", "bg", "de", "el", "en", "es", "fr", "hi",
    "ru", "sw", "th", "tr", "ur", "vi", "zh",
]

# Languages to iterate over; each (train, test) pair selects one predictions file.
train_langs = ["de"]
test_langs = ["de"]

## Helper functions

In [29]:
def read_preds(train_lang, test_lang):
    """Load the dev set of ``test_lang`` together with one model's XNLI predictions.

    The predictions file seems to be shifted: its header line ("Prediction")
    takes the place of the first prediction. Both files are therefore read as
    if they had no header, and the row carrying the header text is dropped.

    Args:
        train_lang: language code used in the predictions file name.
        test_lang: language code of the dev set file and of the predictions file name.

    Returns:
        DataFrame with the columns ``sentence1``, ``sentence2``, ``actual`` and ``pred``.
    """
    preds_file = f"{PATH}/results/dev_{train_lang}_{test_lang}_predictions.tsv"
    dev_file = f"{PATH}/dev_sets/dev-{test_lang}.tsv"

    predictions = pd.read_csv(preds_file, sep="\t", header=None)
    examples = pd.read_csv(dev_file, sep="\t", header=None)
    examples.columns = ["sentence1", "sentence2", "actual"]

    # Assignment aligns on the row index, so row i receives line i of the predictions file.
    examples["pred"] = predictions[0]

    # Drop the example whose prediction slot holds the header text.
    keep = examples["pred"] != "Prediction"
    return examples[keep]

def mistaken(x):
    """Return how many rows of ``x`` have a ``pred`` that differs from ``actual``."""
    is_wrong = x["actual"] != x["pred"]
    return int(is_wrong.sum())

def mistaken_class_count(df):
    """Print a heading and display the number of misclassified examples per actual class.

    Args:
        df: DataFrame with the columns ``actual`` and ``pred``.

    The displayed Series is indexed by ``actual`` and sorted by error count,
    largest first. Nothing is returned.
    """
    print("# misclassified examples per class")
    is_wrong = df["actual"] != df["pred"]
    # Summing the boolean mask inside each actual-class group counts the errors directly.
    # This avoids DataFrameGroupBy.apply, which no longer hands the grouping column
    # ("actual") to the callback in recent pandas, and it also copes with an empty frame.
    errors_per_class = is_wrong.groupby(df["actual"]).sum()
    display(errors_per_class.sort_values(ascending=False))

def pred_class_count(df):
    """Print a heading and display how many times each class occurs in ``df.pred``."""
    print("Predicted class distribution")
    counts = df["pred"].value_counts()
    display(counts)

def class_pairs(df):
    """Print a heading and display the number of rows for every (actual, pred) class pair.

    Only classes that occur in ``df.actual`` are enumerated, on both sides of the pair.
    The displayed table has the columns ``actual``, ``pred`` and ``count`` and is
    sorted by ``count``, largest first.
    """
    print("# predictions for each class pair")
    labels = df["actual"].unique()
    rows = [
        (gold, guess, int(((df["actual"] == gold) & (df["pred"] == guess)).sum()))
        for gold in labels
        for guess in labels
    ]
    table = pd.DataFrame(rows, columns=["actual", "pred", "count"])
    display(table.sort_values(by="count", ascending=False))

In [30]:
read_preds('de','de').head()

Unnamed: 0,sentence1,sentence2,actual,pred
1,"und er hat gesagt, Mama ich bin daheim.",Er sagte kein Wort.,contradiction,contradiction
2,"und er hat gesagt, Mama ich bin daheim.","Er sagte seiner Mutter, er sei nach Hause gekommen.",entailment,entailment
3,"Ich wusste nicht was ich vorhatte oder so, ich musste mich an einen bestimmten Ort in Washington melden.","Ich war noch nie in Washington, deshalb habe ich mich auf der Suche nach dem Ort verirrt, als ich dahin entsandt wurde.",neutral,contradiction
4,"Ich wusste nicht was ich vorhatte oder so, ich musste mich an einen bestimmten Ort in Washington melden.","Ich wusste genau, was ich tun musste, als ich nach Washington marschierte.",contradiction,contradiction
5,"Ich wusste nicht was ich vorhatte oder so, ich musste mich an einen bestimmten Ort in Washington melden.",Ich war mir nicht ganz sicher was ich tun soll und deswegen bin ich nach Washington gereist wo ich zugewiesen wurde zu berichten.,entailment,neutral


In [31]:
# For every test language: show the per-model diagnostics and tally which actual class
# each model (one per training language) misclassifies most often.
for test_lang in test_langs:
    print("test lang:",test_lang)
    most_mistaken_class = Counter()
    for train_lang in train_langs:
        df = read_preds(train_lang, test_lang)
        mistaken_class_count(df)
        pred_class_count(df)
        class_pairs(df)

        # Count errors per actual class with a grouped sum of the error mask. This does not
        # depend on DataFrameGroupBy.apply passing the grouping column to the callback,
        # which recent pandas no longer does.
        errors_per_class = (df["actual"] != df["pred"]).groupby(df["actual"]).sum()
        most_mistaken = errors_per_class.sort_values(ascending=False).index[0]
        most_mistaken_class[most_mistaken] += 1
    print("Most mistaken classes per test lang")
    print(most_mistaken_class)

test lang: de
# misclassified examples per class


actual
entailment       249
neutral          195
contradiction    147
dtype: int64

Predicted class distribution


neutral          907
contradiction    905
entailment       677
Name: pred, dtype: int64

# predictions for each class pair


Unnamed: 0,actual,pred,count
0,contradiction,contradiction,683
8,neutral,neutral,634
4,entailment,entailment,581
5,entailment,neutral,154
6,neutral,contradiction,127
2,contradiction,neutral,119
3,entailment,contradiction,95
7,neutral,entailment,68
1,contradiction,entailment,28


Most mistaken classes per test lang
Counter({'entailment': 1})
