# Setup and imports

In [None]:
# imports
import numpy as np
from tueplots import bundles, figsizes
import wandb
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import pandas as pd


import sys

# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell execution, so edits to local helper modules are picked up.
%load_ext autoreload
%autoreload 2

# Put the current directory first on the module search path
# (presumably so the local `analysis` module imported in the next cell is found).
sys.path.insert(0, '.')

In [None]:
from analysis import sweep2df, plot_typography, stats2string, RED, BLUE

In [None]:
# Switch for LaTeX text rendering; forwarded to the tueplots bundle and to
# plot_typography in the following cells.
USETEX = True

In [None]:
# Build the tueplots ICML 2022 style bundle and load it into matplotlib's rcParams.
icml_style = bundles.icml2022(usetex=USETEX)
plt.rcParams.update(icml_style)
# Disabled: extra LaTeX packages for the preamble.
# plt.rcParams.update({
#     'text.latex.preamble': [r'\usepackage{amsfonts}', # mathbb
#                             r'\usepackage{amsmath}'] # boldsymbol
# })

In [None]:
# Project helper from `analysis`; small/medium/big are presumably font sizes
# (confirm against its definition).
plot_typography(
    usetex=USETEX,
    small=12,
    medium=16,
    big=20,
)


In [None]:
# Weights & Biases coordinates of the project whose sweeps are analysed below.
ENTITY = "causal-representation-learning"
PROJECT = "llm-non-identifiability"

# API client (created with timeout=200) and a handle to the project's runs.
api = wandb.Api(timeout=200)
runs = api.runs(f"{ENTITY}/{PROJECT}")

# Data loading

## Normal sweep

In [None]:
# Fetch the "normal" sweep and unpack everything `sweep2df` returns for it.
# NOTE(review): save=True / load=False presumably mean "write results under
# `filename`, do not read a previous copy" — confirm in analysis.sweep2df.
SWEEP_ID = "ndbfr3qd"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
filename = "normal_" + SWEEP_ID
(
    df,
    train_loss,
    val_loss,
    val_kl,
    val_accuracy,
    finised,
    ood_finised,
    sos_finised,
    as_before_bs,
    same_as_bs,
    grammatical,
    ood_as_before_bs,
    ood_as_before_bs_completion,
    ood_same_as_bs,
    ood_grammatical,
    sos_as_before_bs,
    sos_same_as_bs,
    sos_grammatical,
) = sweep2df(sweep.runs, filename, save=True, load=False)

## Adversarial sweep

In [None]:
# Fetch the "adversarial" sweep and unpack everything `sweep2df` returns for it;
# every result carries the `_adversarial` suffix.
# NOTE(review): save=True / load=False presumably mean "write results under
# `filename`, do not read a previous copy" — confirm in analysis.sweep2df.
SWEEP_ID = "nohk20ol"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
filename = "adversarial_" + SWEEP_ID
(
    df_adversarial,
    train_loss_adversarial,
    val_loss_adversarial,
    val_kl_adversarial,
    val_accuracy_adversarial,
    finised_adversarial,
    ood_finised_adversarial,
    sos_finised_adversarial,
    as_before_bs_adversarial,
    same_as_bs_adversarial,
    grammatical_adversarial,
    ood_as_before_bs_adversarial,
    ood_as_before_bs_completion_adversarial,
    ood_same_as_bs_adversarial,
    ood_grammatical_adversarial,
    sos_as_before_bs_adversarial,
    sos_same_as_bs_adversarial,
    sos_grammatical_adversarial,
) = sweep2df(sweep.runs, filename, save=True, load=False)

## Extrapolation sweep

In [None]:
# Fetch the "extrapolation" sweep and unpack everything `sweep2df` returns for it;
# every result carries the `_extrapolation` suffix.
# NOTE(review): save=True / load=False presumably mean "write results under
# `filename`, do not read a previous copy" — confirm in analysis.sweep2df.
SWEEP_ID = "gnagvai4"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
filename = "extrapolation_" + SWEEP_ID
(
    df_extrapolation,
    train_loss_extrapolation,
    val_loss_extrapolation,
    val_kl_extrapolation,
    val_accuracy_extrapolation,
    finised_extrapolation,
    ood_finised_extrapolation,
    sos_finised_extrapolation,
    as_before_bs_extrapolation,
    same_as_bs_extrapolation,
    grammatical_extrapolation,
    ood_as_before_bs_extrapolation,
    ood_as_before_bs_completion_extrapolation,
    ood_same_as_bs_extrapolation,
    ood_grammatical_extrapolation,
    sos_as_before_bs_extrapolation,
    sos_same_as_bs_extrapolation,
    sos_grammatical_extrapolation,
) = sweep2df(sweep.runs, filename, save=True, load=False)

# Plots

## Rule extrapolation (normal and adversarial)

In [49]:
# (mean, std) of `min_val_loss` over the rows of the normal-sweep frame `df`.
(
    df.min_val_loss.mean(),
    df.min_val_loss.std(),
)

(0.021478888902651226, 0.0011338101335613084)

In [None]:
# (mean, std) of `min_val_loss` over the rows of `df_adversarial`.
(
    df_adversarial.min_val_loss.mean(),
    df_adversarial.min_val_loss.std(),
)

In [None]:
# (mean, std) of `min_val_loss` over the rows of `df_extrapolation`.
(
    df_extrapolation.min_val_loss.mean(),
    df_extrapolation.min_val_loss.std(),
)

In [50]:
# (mean, std) of `ood_same_as_bs_accuracy4min_val_loss` over the rows of `df`.
(
    df.ood_same_as_bs_accuracy4min_val_loss.mean(),
    df.ood_same_as_bs_accuracy4min_val_loss.std(),
)

(0.43700092771778937, 0.04467848035422576)

In [51]:
# (mean, std) of `ood_same_as_bs_accuracy4min_val_loss` over the rows of `df_adversarial`.
(
    df_adversarial.ood_same_as_bs_accuracy4min_val_loss.mean(),
    df_adversarial.ood_same_as_bs_accuracy4min_val_loss.std(),
)

(0.0, 0.0)

In [52]:
# (mean, std) of `ood_same_as_bs_accuracy4min_val_loss` over the rows of `df_extrapolation`.
(
    df_extrapolation.ood_same_as_bs_accuracy4min_val_loss.mean(),
    df_extrapolation.ood_same_as_bs_accuracy4min_val_loss.std(),
)

(0.8296164155006409, 0.12182615111142998)

In [None]:
# Bar chart of `ood_same_as_bs_accuracy4min_val_loss` for the three sweeps:
# bar height is the mean over the sweep's rows, the error bar is the standard deviation.
TICK_PADDING = 2
LABELPAD = 1
BAR_POSITIONS = [0.5, 1.5, 2.5]
BAR_WIDTH = 0.5

# (tick label, frame) pairs, in plotting order; keeps labels and data aligned
# instead of repeating the long column expression six times.
sweep_frames = [
    ("normal", df),
    ("adversarial", df_adversarial),
    ("extrapolation", df_extrapolation),
]
ood_accuracies = [frame.ood_same_as_bs_accuracy4min_val_loss for _, frame in sweep_frames]

fig = plt.figure(figsize=figsizes.icml2022_full(nrows=1, ncols=2)['figure.figsize'])

# Only the left slot of the 1x2 layout is populated.
ax = fig.add_subplot(121)
ax.set_axisbelow(True)
ax.bar(
    x=BAR_POSITIONS,
    height=[accuracy.mean() for accuracy in ood_accuracies],
    yerr=[accuracy.std() for accuracy in ood_accuracies],
    label="accuracy",
    width=BAR_WIDTH,
    color=BLUE,
)

# With LaTeX rendering, `#` and `%` are TeX special characters: an unescaped `%`
# starts a TeX comment and the percent sign never reaches the figure. The raw
# string also avoids Python's invalid-escape warning for `\#`.
# Without LaTeX rendering the characters are used literally.
ylabel = r"\#a=\#b \%" if plt.rcParams["text.usetex"] else "#a=#b %"
ax.set_ylabel(ylabel, labelpad=LABELPAD)

ax.set_xticks(BAR_POSITIONS)
ax.set_xticklabels([name for name, _ in sweep_frames])

ax.legend()
ax.tick_params(axis='both', which='major', pad=TICK_PADDING)

fig.savefig("adversarial_rule_extrapolation.svg")

In [None]:
# Two-panel scatter plot: minimum validation loss (x) against an OOD accuracy
# column (y) for the normal and adversarial sweeps.
TICK_PADDING = 2
LABELPAD = 1
# Kept defined in case later cells read it. It is no longer passed to `scatter`:
# without a `c=` argument matplotlib ignores `cmap` and emits a warning.
cmap = "coolwarm"

# With LaTeX rendering, `#` and `%` must be escaped (an unescaped `%` starts a
# TeX comment and the percent sign is lost); otherwise use them literally.
if plt.rcParams["text.usetex"]:
    same_as_bs_label = r"\#a=\#b \%"
    as_before_bs_label = r"a's before b's \%"
else:
    same_as_bs_label = "#a=#b %"
    as_before_bs_label = "a's before b's %"


def _scatter_panel(ax, column, ylabel):
    """Draw one panel: `column` vs. `min_val_loss` for the normal and adversarial frames.

    ax: matplotlib Axes to draw into.
    column: name of the column of `df` / `df_adversarial` plotted on the y-axis.
    ylabel: text for the y-axis label.
    """
    ax.grid(True, which="both", ls="-.")
    ax.set_axisbelow(True)
    ax.scatter(df.min_val_loss, df[column], label="normal")
    ax.scatter(df_adversarial.min_val_loss, df_adversarial[column], label="adversarial")
    ax.set_ylabel(ylabel, labelpad=LABELPAD)
    ax.set_xlabel("Minimum validation loss", labelpad=LABELPAD)
    ax.legend()
    ax.tick_params(axis='both', which='major', pad=TICK_PADDING)


fig = plt.figure(figsize=figsizes.icml2022_full(nrows=1, ncols=2)['figure.figsize'])

ax = fig.add_subplot(121)
_scatter_panel(ax, "ood_same_as_bs_accuracy4min_val_loss", same_as_bs_label)

ax = fig.add_subplot(122)
_scatter_panel(ax, "ood_as_before_bs_completion_accuracy4min_val_loss", as_before_bs_label)

# NOTE(review): the bar-chart cell above writes to this same path, so running
# both cells leaves only this figure on disk — consider a distinct filename.
fig.savefig("adversarial_rule_extrapolation.svg")

## Emergence of grammatical correctness

In [None]:
# Convert `val_kl` (returned by sweep2df for the normal sweep) to a NumPy array.
v = np.array(val_kl)

In [None]:
# Rebind `grammatical` (returned by sweep2df for the normal sweep) to a NumPy array.
# The name is reused, so after this cell the original object is no longer reachable
# under it; re-running the cell simply wraps the array again.
grammatical = np.array(grammatical)

In [None]:
# LABELPAD = 1
# TICK_PADDING = 2
#
# fig = plt.figure(figsize=figsizes.neurips2022(nrows=1, ncols=2, rel_width=1)['figure.figsize'])
#
#
# ax = fig.add_subplot(121)
# ax.grid(True, which="both", ls="-.")
#
#
#
# # Remove ticks and labels and set which side to label
# ticksoff = dict(labelleft=False, labelright=False, left=False, right=False)
# ax.tick_params(axis="y", **ticksoff)
# ax.tick_params(axis="y", labelleft=True, labelright=False, left=True, right=False)
# ax.tick_params(axis="y", labelleft=False, labelright=True, left=False, right=True)
#
#
# ax.errorbar(range(v.shape[1]), (grammatical-np.exp(-v)).mean(0), yerr=(grammatical-np.exp(-v)).std(0), label='val_kl', c=BLUE)
#
