In [None]:
# Re-import edited modules (e.g. the local `lib` package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Widen the notebook content area via injected CSS, and never truncate long text values
# when DataFrames are displayed.
# `display` lives in IPython.display; importing it from IPython.core.display is deprecated
# (since IPython 7.14) and emits a DeprecationWarning.
from IPython.display import display, HTML
display(HTML("<style>.container { width: 98% !important; }</style>"))
import pandas as pd
pd.set_option('display.max_colwidth', None)

In [None]:
from typing import List, Optional

In [None]:
import sys
import os
import re
import numpy as np
import pandas as pd

from tqdm.notebook import tqdm
from sklearn.metrics import f1_score, confusion_matrix

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Global seaborn theme with slightly enlarged fonts for every figure below.
# set_theme is the preferred interface; sns.set is only a legacy alias for it.
sns.set_theme(font_scale=1.2)

In [None]:
from lib import util, fitter, visualize

In [None]:
datapath = 'SemEval_2022_Task2-idiomaticity/SubTaskA/Data'
testpath = 'SemEval_2022_Task2-idiomaticity/SubTaskA/TestData'

In [None]:
frames = util.load_csv_dataframes(datapath)
tframes = util.load_csv_dataframes(testpath)

In [None]:
zdf = frames['train_zero_shot.csv']
odf = frames['train_one_shot.csv']
ddf = frames['dev.csv']
ddf_gold = frames['dev_gold.csv']
edf = frames['eval.csv']
tdf = tframes['test.csv']

In [None]:
zdf_bt3 = pd.read_pickle('data/zdf_bt3_20220104_1.pkl')
ddf_bt3 = pd.read_pickle('data/ddf_bt3_20220104_1.pkl')
edf_bt3 = pd.read_pickle('data/edf_bt3_20220105_1.pkl')
tdf_bt3 = pd.read_pickle('data/tdf_bt3_20220111_1.pkl')

In [None]:
zdf_t = fitter.get_trainable(zdf_bt3)
ddf_t = fitter.get_trainable(ddf_bt3)
zdf_t['Label'] = zdf_bt3['Label']
ddf_t['Label'] = ddf_gold['Label']
zdf_t['Set'] = 'Train'
ddf_t['Set'] = 'Dev'

In [None]:
# pairplot = sns.pairplot(zdf_t.drop(['Quotes', 'Hassub', 'Caps', 'Trans'], axis=1), hue='Label')

##### Label distribution per language (training set)

In [None]:
zdf_bt3.groupby(['Language','Label'])['DataID'].count()

In [None]:
zdf_bt3.groupby(['Language','Label'])['DataID'].count()/len(zdf_bt3)

Visualizing class distribution per language for training and dev sets.

In [None]:
# Label-by-language heatmap for the zero-shot training split, with enlarged axis titles.
# NOTE(review): df_heatmap is defined in lib.visualize; its two frame arguments appear to be
# (frame supplying `col`, frame supplying the labels) — zdf carries both here — confirm.
zdf_counts = visualize.df_heatmap(zdf, zdf, col='Language')
zdf_counts.set_xlabel('Label', fontsize=16)
zdf_counts.set_ylabel('Language', fontsize=16)

In [None]:
# util.save_picture(zdf_counts.figure, name='train_counts')

In [None]:
# Dev-set row counts per (Language, Label), taken from the gold-label file.
ddf_gold.groupby(['Language','Label'])['DataID'].count()

In [None]:
# Same heatmap for the dev split; labels come from ddf_gold rather than the raw dev frame.
ddf_counts = visualize.df_heatmap(ddf, ddf_gold, col='Language')
ddf_counts.set_xlabel('Label', fontsize=16)
ddf_counts.set_ylabel('Language', fontsize=16)

In [None]:
# util.save_picture(ddf_counts.figure, name='dev_counts')

In [None]:
# Saved dev-set prediction files for two model variants, compared in the confusion matrices below.
ddf_res, ddf_res2 = map(util.load_df, ['data/ddf_sub_20220121_1.csv', 'data/ddf_sub_20220121_2.csv'])

Confusion matrix on the dev set (zero-shot setting) for the SBERT + features model.

In [None]:
# Confusion matrix of the zero-shot dev predictions in ddf_res against the gold labels.
# NOTE(review): presumably df_heatmap pairs the two frames row-for-row by index, i.e. the
# zero_shot rows of ddf_res must line up with ddf_gold — confirm in lib.visualize.
df_res_heatmap = visualize.df_heatmap(ddf_gold, ddf_res[ddf_res['Setting'] == 'zero_shot'])
df_res_heatmap.set_xlabel('Predicted label', fontsize=16)
df_res_heatmap.set_ylabel('Actual label', fontsize=16)

In [None]:
# util.save_picture(df_res_heatmap.figure, name='sbert_feat_confusion')

In [None]:
# Sanity check: Predicted = 0, Actual = 1 should produce the value in the lower left cell
# The masks from ddf_res and ddf_gold are combined by pandas index alignment, so this count is
# only meaningful if each zero_shot row of ddf_res shares its index label with its ddf_gold row.
ddf_res[(ddf_res['Setting'] == 'zero_shot') & (ddf_res['Label'] == '0') & (ddf_gold['Label'] == '1')]['ID'].count()

Confusion matrix on the dev set (zero-shot setting) for the BERT + features model.

In [None]:
# Confusion matrix of the zero-shot dev predictions in ddf_res2 against the gold labels.
# NOTE(review): presumably df_heatmap pairs the two frames row-for-row by index, i.e. the
# zero_shot rows of ddf_res2 must line up with ddf_gold — confirm in lib.visualize.
df_res2_heatmap = visualize.df_heatmap(ddf_gold, ddf_res2[ddf_res2['Setting'] == 'zero_shot'])
df_res2_heatmap.set_xlabel('Predicted label', fontsize=16)
df_res2_heatmap.set_ylabel('Actual label', fontsize=16)

In [None]:
# util.save_picture(df_res2_heatmap.figure, name='bert_feat_confusion')

In [None]:
# Sanity check (Predicted = 0, Actual = 1) for the lower left cell of the matrix above.
# As above, the masks from two frames are combined by pandas index alignment.
ddf_res2[(ddf_res2['Setting'] == 'zero_shot') & (ddf_res2['Label'] == '0') & (ddf_gold['Label'] == '1')]['ID'].count()

In [None]:
def feat_show(df, feats: List[str], x: str, hue: str, savename: Optional[str] = None, numcols: int = 4):
    """Draw one seaborn box plot per feature column of ``df``.

    Each plot shows ``y=<feature>`` grouped by ``x`` and split by ``hue``.

    Args:
        df: Frame holding the feature columns plus the ``x`` and ``hue`` columns.
        feats: Names of the feature columns to plot, one plot each. If empty, nothing is drawn.
        x: Column used for the categorical x axis.
        hue: Column used to split the boxes within each x category.
        savename: If truthy, every feature gets its own 5x5 figure, saved via
            ``util.save_picture`` under the name ``savename + '_' + feature``.
            Otherwise all features share one grid figure.
        numcols: Number of grid columns used when ``savename`` is not given.

    Returns:
        None. The figures are displayed by the notebook's inline backend.
    """
    if not feats:
        return  # nothing to plot; avoid creating an empty figure

    if savename:
        for f in feats:
            fig, axes = plt.subplots(1, 1, sharex=True, figsize=(5, 5))
            ax = sns.boxplot(ax=axes, data=df, y=f, hue=hue, x=x)
            util.save_picture(ax.figure, name=savename + '_' + f)
        return

    # Ceiling division gives exactly enough rows. The old `len // numcols + 1` added an empty
    # row whenever len(feats) was a multiple of numcols.
    rowcount = -(-len(feats) // numcols)
    # squeeze=False keeps `axes` 2-D even for a single row, so axes[row, col] is always valid.
    # Previously, with fewer feats than numcols, every feature was drawn onto axes[0].
    fig, axes = plt.subplots(rowcount, numcols, sharex=True, squeeze=False,
                             figsize=(5 * numcols, 5 * rowcount))
    for i, f in enumerate(feats):
        row, col = divmod(i, numcols)
        sns.boxplot(ax=axes[row, col], data=df, y=f, hue=hue, x=x)

    # Hide the unused trailing cells of the last row. With sharex only the bottom row keeps
    # its x tick labels, so re-enable them on the plot directly above each hidden cell.
    for i in range(len(feats), rowcount * numcols):
        row, col = divmod(i, numcols)
        axes[row, col].set_visible(False)
        if row > 0:
            axes[row - 1, col].tick_params(labelbottom=True)


In [None]:
# Box plots of these training-set features per language, split by label.
feat_show(
    zdf_bt3,
    feats=['Sentiment', 'Nextdiff', 'Prevdiff', 'MWEdiff', 'Top score', 'Top score 1', 'Top score 2', 'FoundScore'],
    x='Language',
    hue='Label',
)

In [None]:
# feat_show(zdf_bt3, ['Sentiment', 'Nextdiff', 'Prevdiff', 'MWEdiff', 'Top score', 'Top score 1', 'Top score 2', 'FoundScore'], 'Language', 'Label', 'feat')

In [None]:
comb_t = pd.concat([zdf_t, ddf_t], ignore_index=True)
comb_bt3 = pd.concat([zdf_bt3, ddf_bt3], ignore_index=True)

In [None]:
fig5, axes5 = plt.subplots(3, 4, figsize=(20,15))
row=0
col=0
for column in comb_t.drop(['Label', 'Hassub', 'Caps', 'Quotes', 'Trans', 'Set'], axis=1):
    # print(row,col)
    sns.boxplot(ax=axes5[row, col], y=comb_t[column], hue=comb_t['Set'], x=comb_t['Label'])
    col += 1
    if col >= axes5.shape[1]:
        row += 1
        col = 0

In [None]:
def plot_flag_counts_by_label(df):
    """Count plots of each flag feature (x) split by Label (hue), one panel per flag.

    hue_order is fixed so that each label gets the same colour in every panel and figure.
    Previously it was only set for the PT subset, so colours could swap between figures.

    Returns:
        (fig, axes) of the 1x4 figure.
    """
    flags = ['Hassub', 'Quotes', 'Caps', 'Trans']
    fig, axes = plt.subplots(1, len(flags), figsize=(20, 4))
    for ax, flag in zip(axes, flags):
        sns.countplot(ax=ax, data=df, x=flag, hue='Label', hue_order=['0', '1'])
    return fig, axes


# Whole training set, then the EN and PT subsets.
fig6, axes6 = plot_flag_counts_by_label(zdf_bt3)
fig7, axes7 = plot_flag_counts_by_label(zdf_bt3[zdf_bt3['Language'] == 'EN'])
fig8, axes8 = plot_flag_counts_by_label(zdf_bt3[zdf_bt3['Language'] == 'PT'])

In [None]:
# util.save_picture(fig6, name='bool_all')
# util.save_picture(fig7, name='bool_en')
# util.save_picture(fig8, name='bool_pt')

In [None]:
# Split violins of each feature per language, with the two halves coloured by Label (training set).
# Flag features are excluded, as in the train/dev box plots. The grid is sized from the feature
# count rather than a fixed 3x4, which raised IndexError beyond 12 features.
# NOTE(review): x comes from zdf_bt3 while y/hue come from zdf_t, which are aligned on the index;
# this presumably relies on fitter.get_trainable preserving zdf_bt3's index — confirm.
violin_cols = zdf_t.drop(['Label', 'Hassub', 'Caps', 'Quotes', 'Trans', 'Set'], axis=1).columns
violin_ncols = 4
violin_nrows = max(1, -(-len(violin_cols) // violin_ncols))
fig9, axes9 = plt.subplots(violin_nrows, violin_ncols,
                           figsize=(5 * violin_ncols, 5 * violin_nrows), squeeze=False)
for ax, column in zip(axes9.flat, violin_cols):
    # Fixed hue order keeps the left/right halves and colours consistent across panels.
    sns.violinplot(ax=ax, y=zdf_t[column], hue=zdf_t['Label'], x=zdf_bt3['Language'],
                   hue_order=['0', '1'], split=True)
for ax in axes9.flat[len(violin_cols):]:
    ax.set_visible(False)

In [None]:
# util.save_picture(fig9, name='violin')

In [None]:
def plot_label_counts_by_flag(df):
    """Count plots of Label (x) split by each flag feature (hue), one panel per flag.

    order is fixed so the labels sit in the same x positions in every panel and figure.
    Previously it was only set for the PT subset.

    Returns:
        (fig, axes) of the 1x4 figure.
    """
    flags = ['Hassub', 'Quotes', 'Caps', 'Trans']
    fig, axes = plt.subplots(1, len(flags), figsize=(20, 4))
    for ax, flag in zip(axes, flags):
        sns.countplot(ax=ax, data=df, hue=flag, x='Label', order=['0', '1'])
    return fig, axes


# Whole training set, then the EN and PT subsets.
fig10, axes10 = plot_label_counts_by_flag(zdf_bt3)
fig11, axes11 = plot_label_counts_by_flag(zdf_bt3[zdf_bt3['Language'] == 'EN'])
fig12, axes12 = plot_label_counts_by_flag(zdf_bt3[zdf_bt3['Language'] == 'PT'])

In [None]:
# util.save_picture(fig10, name='label_bool_all')
# util.save_picture(fig11, name='label_bool_en')
# util.save_picture(fig12, name='label_bool_pt')

In [None]:
# Clustered correlation heatmap of the training features; the row dendrogram is removed.
# Only numeric/bool columns are correlated. zdf_t also holds the string 'Set' column, which
# DataFrame.corr() rejects in pandas >= 2.0; older pandas silently dropped such columns, so this
# keeps the old result while working on every pandas version.
heatmap_z = sns.clustermap(zdf_t.select_dtypes(include=['number', 'bool']).corr(),
                           cbar_pos=(.1, .5, .03, .2), cmap="Blues")
heatmap_z.ax_row_dendrogram.remove()

In [None]:
# util.save_picture(heatmap_z, name='heatmap')

In [None]:
# cf = pd.crosstab(zdf_bt3['Trans'], zdf_bt3['Label'])
# sns.heatmap(cf, annot=True, cmap='Blues', fmt='d')
# sns.jointplot(data=zdf_bt3, x='Sentiment', y='Top score', hue='Label')
# fig11, axes11 = plt.subplots(1, 2, figsize=(8,4))
# sns.kdeplot(ax=axes11[0],data=zdf_bt3[zdf_bt3['Language'] == 'EN'], hue='Label', x='Sentiment', fill=True)
# sns.kdeplot(ax=axes11[1],data=zdf_bt3[zdf_bt3['Language'] == 'PT'], hue='Label', x='Sentiment', fill=True)

In [None]:
fig13, axes13 = plt.subplots(1, 4, figsize=(20,4))
sns.countplot(ax=axes13[0], data=zdf_bt3, hue='Hassub', x='Language').set_title('Training')
sns.countplot(ax=axes13[1], data=ddf_bt3, hue='Hassub', x='Language').set_title('Development')
sns.countplot(ax=axes13[2], data=edf_bt3, hue='Hassub', x='Language').set_title('Evaluation')
sns.countplot(ax=axes13[3], data=tdf_bt3, hue='Hassub', x='Language').set_title('Test')

In [None]:
fig14, axes14 = plt.subplots(1, 4, figsize=(20,4))
sns.countplot(ax=axes14[0], data=zdf_bt3, hue='Trans', x='Language').set_title('Training')
sns.countplot(ax=axes14[1], data=ddf_bt3, hue='Trans', x='Language').set_title('Development')
sns.countplot(ax=axes14[2], data=edf_bt3, hue='Trans', x='Language').set_title('Evaluation')
sns.countplot(ax=axes14[3], data=tdf_bt3, hue='Trans', x='Language').set_title('Test')

In [None]:
# util.save_picture(fig13, name='hassub_byset')
# util.save_picture(fig14, name='trans_byset')

In [None]:
fig15, axes15 = plt.subplots(1, 4, figsize=(20,4))
sns.boxplot(ax=axes15[0], data=zdf_bt3, x='Language', y='Sentiment').set_title('Training')
sns.boxplot(ax=axes15[1], data=ddf_bt3, x='Language', y='Sentiment').set_title('Development')
sns.boxplot(ax=axes15[2], data=edf_bt3, x='Language', y='Sentiment').set_title('Evaluation')
sns.boxplot(ax=axes15[3], data=tdf_bt3, x='Language', y='Sentiment').set_title('Test')

In [None]:
fig16, axes16 = plt.subplots(1, 4, figsize=(20,4))
sns.violinplot(ax=axes16[0], data=zdf_bt3, x='Language', y='Sentiment').set_title('Training')
sns.violinplot(ax=axes16[1], data=ddf_bt3, x='Language', y='Sentiment').set_title('Development')
sns.violinplot(ax=axes16[2], data=edf_bt3, x='Language', y='Sentiment').set_title('Evaluation')
sns.violinplot(ax=axes16[3], data=tdf_bt3, x='Language', y='Sentiment').set_title('Test')

In [None]:
# util.save_picture(fig15, name='sentiment_box_byset')
# util.save_picture(fig16, name='sentiment_violin_byset')

In [None]:
fig17, axes17 = plt.subplots(1, 4, figsize=(20,4))
sns.countplot(ax=axes17[0], data=zdf_bt3, hue='Caps', x='Language').set_title('Training')
sns.countplot(ax=axes17[1], data=ddf_bt3, hue='Caps', x='Language').set_title('Development')
sns.countplot(ax=axes17[2], data=edf_bt3, hue='Caps', x='Language').set_title('Evaluation')
sns.countplot(ax=axes17[3], data=tdf_bt3, hue='Caps', x='Language').set_title('Test')

In [None]:
fig18, axes18 = plt.subplots(1, 4, figsize=(20,4))
sns.countplot(ax=axes18[0], data=zdf_bt3, hue='Quotes', x='Language').set_title('Training')
sns.countplot(ax=axes18[1], data=ddf_bt3, hue='Quotes', x='Language').set_title('Development')
sns.countplot(ax=axes18[2], data=edf_bt3, hue='Quotes', x='Language').set_title('Evaluation')
sns.countplot(ax=axes18[3], data=tdf_bt3, hue='Quotes', x='Language').set_title('Test')

In [None]:
meanvals = pd.DataFrame(columns=['Set', 'Language', 'Feature', 'Score'])
for setname, setdf in zip(['Training', 'Development', 'Evaluation', 'Test'], [zdf_bt3, ddf_bt3, edf_bt3, tdf_bt3]):
    for language in setdf.Language.unique():
        v = setdf[setdf['Language'] == language].mean()
        for k in v.keys():
            if k in ['FoundIdx', 'Label' ,'ID']:
                continue
            val = v[k]
            meanvals.loc[len(meanvals)] = [setname, language, k, val]

In [None]:
meanvals

In [None]:
fig19, axes19 = plt.subplots(2, 1, figsize=(20,10))
sns.barplot(data=meanvals, ax=axes19[0], hue='Set', x='Feature', y='Score')
sns.barplot(data=meanvals, ax=axes19[1], hue='Language', x='Feature', y='Score')