In [None]:
# load the dataset (in full)
import pandas as pd
import numpy as np

headlines_df = pd.read_csv("data/dataset.csv")

# Materialize every row as a plain tuple (no index column, no namedtuple wrapper).
lst = [*headlines_df.itertuples(index=False, name=None)]

# Transpose the rows into two parallel tuples. This unpacking requires the CSV to
# have exactly two columns — presumably (sentence, label); confirm against the file.
sents, labels = zip(*lst)

In [None]:
# create lookups

# Reverse lookup from sentence text to its position in `sents`.
# If a sentence appears more than once, the position of its last occurrence wins.
sent_to_index = dict(zip(sents, range(len(sents))))

def is_sarcastic(sent_or_index):
    """Return the stored label for a sentence, given its text or its row position.

    Args:
        sent_or_index: either the exact sentence string (looked up through
            ``sent_to_index``) or anything ``int()`` accepts, e.g. a Python int
            or a numpy integer scalar, used directly as a position in ``labels``.

    Returns:
        The corresponding entry of ``labels`` (treated as truthy-for-sarcastic
        elsewhere in this notebook — presumably 0/1; confirm against the CSV).

    Raises:
        KeyError: if a string is given that is not one of ``sents``.
        IndexError: if the position is out of range.
    """
    if not isinstance(sent_or_index, str):
        # Non-strings are treated as positions; int() normalizes numpy integers.
        return labels[int(sent_or_index)]
    return labels[sent_to_index[sent_or_index]]

In [None]:
from utils.misc import load

# load recorded metrics
record_path = "./records/"  # directory holding the recorded evaluation artifacts
eval_id = "0001"            # evaluation-run id, used as the file-name prefix

# Load both artifacts of the selected run (baseline first, then per-model records).
baseline_metrics, model_records = (
    load(f"{record_path}{eval_id}_baseline_metrics.json"),
    load(f"{record_path}{eval_id}_model_records.json"),
)

### Graph performance metrics from train/eval iterations

In [None]:

import matplotlib.pyplot as plt

# Set save_to_disk=True and provide file_name to also write the figure under assets/.
save_to_disk = False
file_name = ""
if save_to_disk and not file_name:
    # Fail before any plotting work rather than writing to the bare "assets/" path.
    raise ValueError("save_to_disk is True but file_name is empty")

data = model_records
baseline = baseline_metrics

# One panel per metric; the overfitting panel is only drawn when requested.
include_overfitting = False
metrics = ['polarity_score', 'similarity_score', 'overfitting_indicator']
metric_labels = ["Polarity Score", "Semantic Similarity Score", "Overfitting Indicator"]
cutoff = 3 if include_overfitting else 2
fig, axs = plt.subplots(1, cutoff, figsize=(15, 5))

assert data, "model_records is empty - nothing to plot"
keys = sorted(data)
# Draw a gridline for every step of the longest recorded run, so models that ran
# for more iterations than the alphabetically-first one are still covered.
n_time_steps = max(len(records) for records in data.values())

# (color, linestyle) per model name. Models not listed here fall back to
# default_style (color=None -> matplotlib's property cycle) instead of raising KeyError.
lookup = {
    "Contrastive(lambda=0.2)": ["black", "-"],
    "Contrastive(lambda=0.5)": ["black", "--"],
    "OnlineContrastive(lambda=0.5)": ["blue", "--"],
    "MultipleNegatives(scale=1)": ["m", "-"],
    "MultipleNegatives(scale=20)": ["m", "--"],
    "MultipleNegatives": ["m", "-"],
    "MultipleNegatives(scale=100)": ["m", "-."],
    "Triplet(lambda=5)": ["r", "-"],
    "Triplet(lambda=2)": ["r", "--"],
    "Triplet(lambda=0.5)": ["orange", "-"],
    "Triplet(lambda=0.1)": ["orange", "--"],
    "Triplet(lambda=0.05)": ["green", "-."],
    "Triplet(lambda=0.01)": ["green", ":"],
}
default_style = [None, "-"]

for i, metric in enumerate(metrics[:cutoff]):
    ax = axs[i]

    # Light vertical guides at integer x positions (presumably epochs 0..n-1).
    for j in range(n_time_steps):
        ax.axvline(x=j, color='lightgray', linestyle='-', linewidth=0.5)

    # The reference embedding has a single score per metric: draw it as a horizontal line.
    if metric in ['polarity_score', 'similarity_score']:
        ax.axhline(y=baseline[metric], linestyle='-.', color='gray', label='Reference Embedding')

    # Each model's record list holds one dict per iteration with 'epoch' and metric keys.
    for model_name in keys:
        color, style = lookup.get(model_name, default_style)
        model_data = data[model_name]
        epochs = [d['epoch'] for d in model_data]
        values = [d[metric] for d in model_data]
        ax.plot(epochs, values, label=model_name, color=color, linestyle=style)

    ax.set_xlabel('Iteration / Epoch')
    ax.set_ylabel(metric_labels[i])
    ax.set_title(metric_labels[i])

# One shared legend below the figure, anchored to the second axes. With two panels,
# x=-0.1 (relative to the right panel) sits near the figure's horizontal center;
# with three panels the middle axes' own center (0.5) is the figure center.
legend_x = 0.5 if cutoff == 3 else -0.1
legend = axs[1].legend(loc='upper center', bbox_to_anchor=(legend_x, -0.15), ncol=3, fontsize=12)

if save_to_disk:
    fig.savefig(f"assets/{file_name}", bbox_inches='tight', dpi=300)
# Display explicitly instead of relying on the backend's auto-display.
plt.show()

### PCA visualizations of test sentences

In [None]:
# Sorted so the numbering printed here matches what `model_idx` selects in the next cell.
model_names = sorted(model_records)
for position, name in enumerate(model_names):
    print(f"{position}) {name}")

In [None]:
from utils.visualizations import scatter_pair

import warnings

# Choose what to visualize:
#   model_idx      - index into the numbered model list printed by the previous cell
#   sent_idx       - test-sentence id (0, 1, ..., n_test_sentences-1)
#   iter_1, iter_2 - positions in the model's list of per-iteration records whose
#                    2-D (PCA) projections of the retrieved sentences are compared;
#                    negative values count from the end, so -1 is the last record
model_idx = 0
model_name = model_names[model_idx]
sent_idx = 4
iter_1 = 1
iter_2 = -1

base_1 = model_records[model_name][iter_1]["raw_data"][sent_idx]
base_2 = model_records[model_name][iter_2]["raw_data"][sent_idx]

sentence = base_1["sentence"]
# Both snapshots must describe the same test sentence, or the side-by-side plot is meaningless.
assert base_2.get("sentence", sentence) == sentence, (
    f"records at iter_1={iter_1} and iter_2={iter_2} disagree on test sentence {sent_idx}"
)
# Point labels below are derived from the iter_1 suggestions only. If the retrieved
# suggestions changed by iter_2, the second plot's colouring would not match its points.
if list(base_2.get("suggestions", base_1["suggestions"])) != list(base_1["suggestions"]):
    warnings.warn("suggestions differ between iter_1 and iter_2; point labels are taken from iter_1")

# One label per suggestion, plus -1 for the query sentence, which is stacked as the last
# row of each embedding matrix below (presumably scatter_pair renders -1 as the query
# point — confirm in utils.visualizations).
ys = [is_sarcastic(s) for s in base_1["suggestions"]] + [-1]
embs_1 = np.vstack([base_1["suggestion_embeddings_2d"], base_1["embedding_2d"]])
embs_2 = np.vstack([base_2["suggestion_embeddings_2d"], base_2["embedding_2d"]])

print(f"Model name: {model_name}")
print(f"Sentence: {sentence}")
print(f"Polarity: {'sarcastic' if is_sarcastic(sentence) else 'non-sarcastic'}")
# NOTE(review): iter_1/iter_2 are passed as raw list positions (e.g. -1), not epoch numbers.
scatter_pair(embs_1, iter_1, embs_2, iter_2, ys)