## Text similarity probings

In [1]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules before
# each execution, so edits to the project modules are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
from tqdm import tqdm
import numpy, pandas

import pathlib
import sys

# All folders are expressed relative to the directory the notebook runs from.
WD = f"{pathlib.Path().absolute()}/"
PROJECT_FOLDER = f"{WD}../"
PROBINGS_FOLDER = f"{PROJECT_FOLDER}probings/"
DATA_FOLDER = f"{PROJECT_FOLDER}data/"

# Make the modules inside the probings folder importable from this notebook.
sys.path.append(PROBINGS_FOLDER)

In [3]:
# Folder holding one sub-folder per dataset.
DATA_FOLDER = WD + '../data/'

# Full path of the validation split (JSON lines) of each dataset to probe.
input_texts = [
    DATA_FOLDER + relative_path
    for relative_path in (
        'rte/val.jsonl',
        'axb/val.jsonl',
        'axg/val.jsonl',
        'mnli/val.jsonl',
    )
]

In [4]:
# For every dataset, collect its premises and the matching precomputed
# pairwise-similarity matrix; both lists follow the order of `input_texts`.
inputs, similarities = list(), list()
for dataset in tqdm(input_texts):
    # Only the 'premise' column is used, so no other column needs dropping first.
    data = pandas.read_json(dataset, lines=True)['premise'].values.tolist()
    inputs.append(data)

    # The dataset name is the folder that directly contains the file. Reading it
    # from the path (instead of splitting on '/data/') stays correct even when
    # the working directory itself contains a '/data/' component.
    dataset_name = pathlib.Path(dataset).parent.name

    # NOTE(review): allow_pickle=True lets numpy unpickle the file contents;
    # only load similarity files that come from a trusted source.
    similarities.append(numpy.load(DATA_FOLDER + 'probings/' + dataset_name + '_pairwise_similarities.dat', allow_pickle=True))

100%|██████████| 4/4 [00:01<00:00,  2.82it/s]


## Visualization

In [6]:
import pandas

from bokeh.io import output_file, output_notebook, show, export_png
from bokeh.models import BasicTicker, ColorBar, ColumnDataSource, LinearColorMapper, PrintfTickFormatter
from bokeh.plotting import figure
from bokeh.sampledata.unemployment1948 import data
from bokeh.transform import transform

import geckodriver_autoinstaller

# palettes
from bokeh.palettes import RdBu11


def similarity_heatmap(premises, data, dataset_name, colors=RdBu11, out_file=None):
    """Export a heatmap of the upper triangle of a pairwise similarity matrix.

    Parameters
    ----------
    premises : sized collection
        Only its length is used: it fixes how many rows/columns of ``data``
        are plotted.
    data : 2-D array indexable as ``data[i, j]``
        Pairwise similarity values; only the entries with ``i < j`` are read.
    dataset_name : str
        Appended to the plot title.
    colors : palette
        Palette handed to the colour mapper (defaults to ``RdBu11``).
    out_file : str or None
        Forwarded to ``export_png`` as ``filename``.
    """
    # The colour scale spans -1 .. +1.
    mapper = LinearColorMapper(palette=colors, low=-1, high=+1)

    # One row per pair (i, j) with i < j. The string labels become the categorical
    # axis values, the integer copies are kept only for ordering the rows.
    n_premises = len(premises)
    vals = list()
    for i in range(n_premises):
        for j in range(i + 1, n_premises):
            vals.append((str(i), str(j), data[i, j], i, j))

    pairs = pandas.DataFrame(vals, columns=['x', 'y', 'val', 'x_int', 'y_int']).sort_values(by=['x_int', 'y_int'])
    grid = pairs.pivot(index='x', columns='y', values='val')
    grid.columns.name = 'y'
    # Flatten the pivoted grid back into long (x, y, val) rows for the data source.
    df = pandas.DataFrame(grid.stack(), columns=['val']).reset_index()
    source = ColumnDataSource(df)

    # Axis categories are ordered numerically, not lexically ('10' after '9').
    x_index = [str(el) for el in sorted(int(x) for x in grid.index)]
    y_index = [str(el) for el in sorted(int(y) for y in grid.columns)]

    p = figure(plot_width=1080, plot_height=1080, title='Pairwise premise similarity on ' + dataset_name,
               x_range=x_index, y_range=y_index)
    p.rect('x', 'y', width=1, height=1, source=source, line_color=None, fill_color={'field': 'val', 'transform': mapper})

    color_bar = ColorBar(color_mapper=mapper, location=(0, 0))
    p.add_layout(color_bar, 'right')
    p.axis.major_label_text_font_size = "10px"

    # The figure is written to disk; for inline display call output_notebook()
    # followed by show(p) instead.
    export_png(p, filename=out_file)

In [7]:
# Export one heatmap per dataset; the last entry of each list (mnli) is left out.
dataset_names = ['rte', 'axb', 'axg', 'mnli']
to_plot = zip(inputs[:-1], similarities[:-1], dataset_names[:-1])
for inp, sim, dataset in to_plot:
    png_name = dataset + '_pairwise_similarities.png'
    similarity_heatmap(inp, sim, dataset, out_file=png_name)