# Synth Data

In [None]:
from measureLM import helpers, measuring, synth_data

## Load Model and Data

In [None]:
# Experiment configuration: the checkpoint to probe and the word pair that
# defines the measurement scale (one [positive, negative] pair per scale).
model_name = "gpt2-medium"
scales = [["good", "bad"]]

# Load the synthetic entity-pair table and the language model (on CPU).
# n=None is forwarded to the loader as-is — presumably "no row limit"; confirm
# in synth_data.load_synth_data.
df = synth_data.load_synth_data(n=None)
model = measuring.load_model(model_name=model_name, device="cpu")

## Compute Biases

In [None]:
# Context-free prompt: no statement about the two entities is prepended, so
# the measurement reflects the model's prior ("bias") for each pair.
prompt = "The relationship between {ent1} and {ent2} is"

# Results are stored back on the dataframe under the "bias " prefix; the
# plotting cell later reads the resulting "bias good-bad" column.
df = synth_data.measure_scale(
    df,
    prompt,
    model,
    scales,
    prefix="bias ",
)
df

## Compute Context Influence

In [None]:
# Same question as the bias prompt, but preceded by a positive statement
# about the pair ("loves"). Results are stored under the "pos " prefix.
pos_prompt = "{ent1} loves {ent2}. The relationship between {ent1} and {ent2} is"
df = synth_data.measure_scale(
    df,
    pos_prompt,
    model,
    scales,
    prefix="pos ",
)

In [None]:
# Same question as the bias prompt, but preceded by a negative statement
# about the pair ("hates"). Results are stored under the "neg " prefix.
neg_prompt = "{ent1} hates {ent2}. The relationship between {ent1} and {ent2} is"
df = synth_data.measure_scale(
    df,
    neg_prompt,
    model,
    scales,
    prefix="neg ",
)

In [None]:
# Display the dataframe after the bias / pos / neg measurements were added.
df

In [None]:
import numpy as np
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from measureLM import helpers
from matplotlib.lines import Line2D
import matplotlib.patches as mpatches

# Hide the top and right spines globally so only the x/y axes are drawn.
matplotlib.rcParams['axes.spines.right'] = False
matplotlib.rcParams['axes.spines.top'] = False


## preprocessing__________________
# Order the entity pairs by their context-free ("bias") score so the x-axis
# runs from the lowest to the highest prior.
df = df.sort_values("bias good-bad")
df = df.reset_index(drop=True)

ent1, ent2, labels = df["ent1"].to_list(), df["ent2"].to_list(), df["label"].to_list()
ent1_2 = [f"{e1}–{e2}" for e1, e2 in zip(ent1, ent2)]
x_vals = np.arange(0, len(ent1_2))

# draw plot_____________________
titlefont = 20
labelfont = 14

fig, ax = plt.subplots(1, 1, figsize=(20, 3), gridspec_kw={'hspace': 0.4})

# Prior (no context): large dots, coloured by their own value.
vals = df["bias good-bad"].to_numpy()
prior_scatter = ax.scatter(x_vals, vals, s=200, alpha=1.0, c=vals, cmap=cm.coolwarm_r)

# Score after prepending the positive ("loves") context.
pos_vals = df["pos good-bad"].to_numpy()
pos_scatter = ax.scatter(x_vals, pos_vals, s=100, alpha=0.8, marker="v", color="blue")

# Score after prepending the negative ("hates") context.
neg_vals = df["neg good-bad"].to_numpy()
neg_scatter = ax.scatter(x_vals, neg_vals, s=100, alpha=0.8, marker="^", color="red")

# Dashed reference line at the mean prior score.
ax.hlines(y=vals.mean(), xmin=x_vals.min() - 1, xmax=x_vals.max() + 1, linewidth=2, linestyle='--', color='grey')
# for x, y in zip(x_vals, vals):
# t = ax.text(x, y, round(y, 1), horizontalalignment='center',
# verticalalignment='center', fontdict={'color':'white'})

ax.xaxis.set_ticks(x_vals)
ax.tick_params(axis='both', which='major', labelsize=labelfont)
ax.set_xticklabels(ent1_2, fontsize=labelfont, rotation=90)
# NOTE(review): the y-limits are derived from the bias values only, so any
# pos/neg marker outside that range falls outside the visible area — confirm
# this is intended.
ax.set_ylim(vals.min() - 0.05, vals.max() + 0.05)
ax.set_xlim(-0.5)
# ax.set_title(scale_name, fontsize=titlefont, color="black", loc='center')
ax.set_ylabel("good-bad scale", fontsize=labelfont)

# Ground-truth label -> (marker letter, colour) drawn next to each tick label.
label_styles = {
    "enemy": ("E", "red"),
    "friend": ("F", "blue"),
}

for x_tick_label, pair_label in zip(ax.get_xticklabels(), labels):
    # Push every tick label down to leave room for the E/F marker row.
    x_tick_label.set_y(-.1)

    style = label_styles.get(pair_label)
    if style is None:
        # Unknown label: draw no marker. The original if/elif had no fallback,
        # so such a row either raised NameError or silently reused the marker
        # of the previous row.
        continue
    label_name, color = style

    position = x_tick_label.get_position()
    # NOTE(review): y=0.435 is in data coordinates and hard-coded; it has to
    # be re-tuned whenever the value range (and thus the y-limits) changes.
    ax.text(position[0]-0.33, 0.435, label_name, fontsize=labelfont, color=color, verticalalignment='top')
    #x_tick_label.set_color(color)

# Proxy artists for the legend (the scatter calls above carry no labels).
prior_scatter = Line2D([0], [0], label='The relationship between A and B is', marker='.',markersize=22, color='grey', linestyle='')
pos_scatter = Line2D([0], [0], label='prepended context: A loves B.', marker='v', markersize=10, color='blue', linestyle='')
neg_scatter = Line2D([0], [0], label='prepended context: A hates B.', marker='^',markersize=10, color='red', linestyle='')

# add manual symbols to auto legend
# Only the handles are needed; the returned legend labels are discarded so the
# per-row `labels` list from the dataframe is not overwritten.
handles = ax.get_legend_handles_labels()[0]
handles.extend([prior_scatter, pos_scatter, neg_scatter])

plt.legend(handles=handles, ncol=len(handles), prop={'size': labelfont}, facecolor='white', framealpha=0, loc='upper left', bbox_to_anchor=(0.0, 1.1))
plt.show()

# Make sure the output directory exists before saving; savefig does not create
# missing parent directories.
# Assumes helpers.ROOT_DIR is a pathlib.Path (it is combined with "/" in the
# original code) — confirm in measureLM.helpers.
plot_dir = helpers.ROOT_DIR / "results" / "plots"
plot_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(plot_dir / f"{model_name}.pdf", bbox_inches='tight', dpi=200,transparent=True)