# Setup and imports

In [1]:
# imports
import numpy as np
from tueplots import bundles, figsizes
import wandb
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import pandas as pd


import sys

%load_ext autoreload
%autoreload 2

# Put the notebook's working directory first on the module search path so the
# local `analysis` module imported in the next cell can be resolved.
sys.path.insert(0, '.')

In [2]:
from analysis import sweep2df, plot_typography, stats2string, RED, BLUE

In [3]:
# Flag forwarded as `usetex=` to the tueplots bundle and to plot_typography below.
# Several axis labels in the plotting cells contain TeX escapes and rely on it being True.
USETEX = True

In [4]:
# Apply the tueplots ICML 2022 bundle to matplotlib's global rcParams.
plt.rcParams.update(bundles.icml2022(usetex=USETEX))
# Disabled: extra LaTeX packages for the preamble.
# NOTE(review): if this is re-enabled, recent matplotlib releases appear to expect
# 'text.latex.preamble' as a single string rather than a list -- confirm against
# the installed version.
# plt.rcParams.update({
#     'text.latex.preamble': [r'\usepackage{amsfonts}', # mathbb
#                             r'\usepackage{amsmath}'] # boldsymbol
# })

In [5]:
# Project helper imported from `analysis`; presumably sets the small/medium/big
# font sizes used by the figures -- confirm in analysis.py.
plot_typography(usetex=USETEX, small=12, medium=16, big=20)


In [8]:
# Weights & Biases coordinates of the project that holds all sweeps
ENTITY = "causal-representation-learning"
PROJECT = "rule_extrapolation"

# API client (timeout=200) and a handle on the runs of the whole project
api = wandb.Api(timeout=200)
runs = api.runs(f"{ENTITY}/{PROJECT}")

# Data loading

## baN

In [15]:
SWEEP_ID = "ntepxfn4"
filename = f"ban_{SWEEP_ID}"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))

# sweep2df (from analysis) returns 18 values that are unpacked positionally.
# The misspelled `*_finised` names are kept because other cells may refer to them.
(
    ban_df,
    ban_train_loss,
    ban_val_loss,
    ban_val_kl,
    ban_val_accuracy,
    ban_finised,
    ban_ood_finised,
    ban_sos_finised,
    ban_r1,
    ban_r2,
    ban_grammatical,
    ban_ood_r1,
    ban_ood_r1_completion,
    ban_ood_r2,
    ban_ood_grammatical,
    ban_sos_r1,
    ban_sos_r2,
    ban_sos_grammatical,
) = sweep2df(sweep.runs, filename, save=True, load=True)

## aNbN

In [16]:
SWEEP_ID = "iv6wtito"
filename = f"anbn_{SWEEP_ID}"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))

# sweep2df (from analysis) returns 18 values that are unpacked positionally.
# The misspelled `*_finised` names are kept because other cells may refer to them.
(
    anbn_df,
    anbn_train_loss,
    anbn_val_loss,
    anbn_val_kl,
    anbn_val_accuracy,
    anbn_finised,
    anbn_ood_finised,
    anbn_sos_finised,
    anbn_r1,
    anbn_r2,
    anbn_grammatical,
    anbn_ood_r1,
    anbn_ood_r1_completion,
    anbn_ood_r2,
    anbn_ood_grammatical,
    anbn_sos_r1,
    anbn_sos_r2,
    anbn_sos_grammatical,
) = sweep2df(sweep.runs, filename, save=True, load=True)

## aNbNcN

In [17]:
SWEEP_ID = "eaitw1nw"
filename = f"anbncn_{SWEEP_ID}"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))

# sweep2df (from analysis) returns 18 values that are unpacked positionally.
# The misspelled `*_finised` names are kept because other cells may refer to them.
(
    anbncn_df,
    anbncn_train_loss,
    anbncn_val_loss,
    anbncn_val_kl,
    anbncn_val_accuracy,
    anbncn_finised,
    anbncn_ood_finised,
    anbncn_sos_finised,
    anbncn_r1,
    anbncn_r2,
    anbncn_grammatical,
    anbncn_ood_r1,
    anbncn_ood_r1_completion,
    anbncn_ood_r2,
    anbncn_ood_grammatical,
    anbncn_sos_r1,
    anbncn_sos_r2,
    anbncn_sos_grammatical,
) = sweep2df(sweep.runs, filename, save=True, load=True)

## Matched brackets and parentheses

In [18]:
SWEEP_ID = "wr96o2v4"
filename = f"brackets_{SWEEP_ID}"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))

# sweep2df (from analysis) returns 18 values that are unpacked positionally.
# The misspelled `*_finised` names are kept because other cells may refer to them.
(
    brackets_df,
    brackets_train_loss,
    brackets_val_loss,
    brackets_val_kl,
    brackets_val_accuracy,
    brackets_finised,
    brackets_ood_finised,
    brackets_sos_finised,
    brackets_r1,
    brackets_r2,
    brackets_grammatical,
    brackets_ood_r1,
    brackets_ood_r1_completion,
    brackets_ood_r2,
    brackets_ood_grammatical,
    brackets_sos_r1,
    brackets_sos_r2,
    brackets_sos_grammatical,
) = sweep2df(sweep.runs, filename, save=True, load=True)

# Plots

## Rule extrapolation (normal and adversarial)

In [None]:
# Mean and standard deviation of `min_val_loss` in `df`.
# NOTE(review): `df` is not assigned in any cell visible here (the loaders create
# ban_df, anbn_df, ...) -- confirm where it comes from.
tuple(getattr(df.min_val_loss, stat)() for stat in ("mean", "std"))

In [None]:
# Mean and standard deviation of `min_val_loss` in `df_adversarial`.
# NOTE(review): `df_adversarial` is not assigned in any cell visible here -- confirm its origin.
tuple(getattr(df_adversarial.min_val_loss, stat)() for stat in ("mean", "std"))

In [None]:
# Mean and standard deviation of `min_val_loss` in `df_extrapolation`.
# NOTE(review): `df_extrapolation` is not assigned in any cell visible here -- confirm its origin.
tuple(getattr(df_extrapolation.min_val_loss, stat)() for stat in ("mean", "std"))

In [None]:
# (mean, std, min, max) of the `ood_r2_accuracy4min_val_loss` column of `df`.
tuple(
    getattr(df.ood_r2_accuracy4min_val_loss, stat)()
    for stat in ("mean", "std", "min", "max")
)

In [None]:
# (mean, std, min, max) of the `ood_r2_accuracy4min_val_loss` column of `df_adversarial`.
tuple(
    getattr(df_adversarial.ood_r2_accuracy4min_val_loss, stat)()
    for stat in ("mean", "std", "min", "max")
)

In [None]:
# (mean, std, min, max) of the `ood_r2_accuracy4min_val_loss` column of `df_extrapolation`.
tuple(
    getattr(df_extrapolation.ood_r2_accuracy4min_val_loss, stat)()
    for stat in ("mean", "std", "min", "max")
)

In [None]:
TICK_PADDING = 2
LABELPAD = 1
cmap = "coolwarm"

# '#' and '%' are special characters in TeX (an unescaped '%' starts a comment),
# so they are escaped only when matplotlib renders text through LaTeX.
if plt.rcParams["text.usetex"]:
    equal_count_label = r"\#a=\#b \%"
else:
    equal_count_label = "#a=#b %"

# One bar per data frame, in plotting order.
bar_positions = [0.5, 1.5, 2.5]
regime_names = ["normal", "adversarial", "extrapolation"]
regime_accuracy = [
    frame.ood_r2_accuracy4min_val_loss
    for frame in (df, df_adversarial, df_extrapolation)
]

# The figure is sized for a 1x2 layout, but only the left panel is drawn.
fig = plt.figure(figsize=figsizes.icml2022_full(nrows=1, ncols=2)['figure.figsize'])

ax = fig.add_subplot(121)
ax.set_axisbelow(True)

# Bar height is the column mean, the error bar is its standard deviation.
im = ax.bar(
    x=bar_positions,
    height=[accuracy.mean() for accuracy in regime_accuracy],
    yerr=[accuracy.std() for accuracy in regime_accuracy],
    label="accuracy",
    width=0.5,
    color=BLUE,
)
ax.set_ylabel(equal_count_label, labelpad=LABELPAD)

# Name each bar after its training regime.
ax.set_xticks(bar_positions)
ax.set_xticklabels(regime_names)

ax.legend()
ax.tick_params(axis='both', which='major', pad=TICK_PADDING)

# NOTE(review): the scatter-plot cell that follows saves to this same file name
# and therefore overwrites this figure -- give one of the two a distinct name.
fig.savefig("adversarial_rule_extrapolation.svg")

In [None]:
TICK_PADDING = 2
LABELPAD = 1
cmap = "coolwarm"

# '#' and '%' are special characters in TeX (an unescaped '%' starts a comment),
# so they are escaped only when matplotlib renders text through LaTeX.
uses_tex = plt.rcParams["text.usetex"]
equal_count_label = r"\#a=\#b \%" if uses_tex else "#a=#b %"
a_before_b_label = r"a's before b's \%" if uses_tex else "a's before b's %"

# (subplot position, accuracy column, y label, x label) for each panel.
# NOTE(review): both panels plot `min_val_loss` on the x axis, yet one label says
# "test" and the other "validation" -- confirm which wording is intended.
panels = [
    (121, "ood_r2_accuracy4min_val_loss", equal_count_label, "Minimum test loss"),
    (122, "ood_r1_completion_accuracy4min_val_loss", a_before_b_label, "Minimum validation loss"),
]

fig = plt.figure(figsize=figsizes.icml2022_full(nrows=1, ncols=2)['figure.figsize'])

for position, column, ylabel, xlabel in panels:
    ax = fig.add_subplot(position)
    ax.grid(True, which="both", ls="-.")
    ax.set_axisbelow(True)

    # No `cmap` here: without a `c` array matplotlib ignores it and only emits a warning.
    im = ax.scatter(df.min_val_loss, getattr(df, column), label="normal")
    im = ax.scatter(df_adversarial.min_val_loss, getattr(df_adversarial, column), label="adversarial")

    ax.set_ylabel(ylabel, labelpad=LABELPAD)
    ax.set_xlabel(xlabel, labelpad=LABELPAD)
    ax.legend()
    ax.tick_params(axis='both', which='major', pad=TICK_PADDING)

# NOTE(review): the bar-plot cell above saves to this same file name, so this
# call overwrites that figure -- give one of the two a distinct name.
fig.savefig("adversarial_rule_extrapolation.svg")

In [None]:
def _prune_histories(histories, drop_empty=False):
    """Truncate every history to the length of the shortest one and stack them.

    Args:
        histories: iterable of sequences (anything supporting ``len`` and slicing),
            possibly of different lengths.
        drop_empty: when True, zero-length histories are discarded before
            truncating, so a single empty history does not shrink every row to
            length 0. Defaults to False, which keeps every input row.

    Returns:
        ``np.ndarray`` with one row per kept history and ``min_len`` columns.
        An array of shape ``(0, 0)`` is returned when there is nothing to stack.
    """
    if drop_empty:
        histories = [v for v in histories if len(v) > 0]
    else:
        histories = list(histories)

    if not histories:
        # Taking the minimum of an empty collection raises; return an empty 2-D array instead.
        return np.empty((0, 0))

    min_len = min(len(v) for v in histories)
    return np.array([v[:min_len] for v in histories])

In [None]:
# Cut the validation-loss histories of each regime down to a common length.
val_loss_pruned, val_loss_adversarial_pruned, val_loss_extrapolation_pruned = (
    _prune_histories(histories)
    for histories in (val_loss, val_loss_adversarial, val_loss_extrapolation)
)

# Same treatment for the OOD rule-2 accuracy histories.
ood_r2_pruned, ood_r2_adversarial_pruned, ood_r2_extrapolation_pruned = (
    _prune_histories(histories)
    for histories in (ood_r2, ood_r2_adversarial, ood_r2_extrapolation)
)

In [None]:
from scipy.interpolate import make_interp_spline
def spline_interpolation(x, y, num=100):
    """Resample ``y(x)`` on an evenly spaced grid with an interpolating B-spline.

    Args:
        x: 1-D array-like of sample positions; need not be sorted, but must not
            contain duplicates (``make_interp_spline`` rejects them).
        y: array-like of sample values whose first axis matches ``x``.
        num: number of evenly spaced points between ``x.min()`` and ``x.max()``.

    Returns:
        Tuple ``(xnew, ynew)`` with the new grid and the spline evaluated on it.
    """
    # Accept plain lists as well as arrays.
    x = np.asarray(x)
    y = np.asarray(y)

    # make_interp_spline requires increasing abscissas, so sort the samples by x.
    order = np.argsort(x, kind="stable")
    x, y = x[order], y[order]

    xnew = np.linspace(x.min(), x.max(), num=num)
    ynew = make_interp_spline(x, y)(xnew)
    return xnew, ynew

In [None]:
TICK_PADDING = 2
LABELPAD = 2
cmap = "coolwarm"
EPS=1e-8

# '#' and '%' are special characters in TeX, so they are escaped only when
# matplotlib renders text through LaTeX. Raw strings avoid invalid escape sequences.
if plt.rcParams["text.usetex"]:
    percent_sign, hash_sign = r"\%", r"\#"
else:
    percent_sign, hash_sign = "%", "#"

# (legend label, colour, pruned accuracy histories, pruned loss histories)
regimes = [
    ("max likelihood", BLUE, ood_r2_pruned, val_loss_pruned),
    ("adversarial", RED, ood_r2_adversarial_pruned, val_loss_adversarial_pruned),
    ("oracle", "green", ood_r2_extrapolation_pruned, val_loss_extrapolation_pruned),
]

fig = plt.figure(figsize=figsizes.icml2022_full(nrows=2, ncols=3)['figure.figsize'])

ax = fig.add_subplot(121)
ax.grid(True, which="both", ls="-.")
ax.set_axisbelow(True)

for regime_label, color, accuracy_histories, loss_histories in regimes:
    # Average over the first axis (one row per history) and express accuracy in percent.
    mean_accuracy = 100 * accuracy_histories.mean(0)
    mean_loss = loss_histories.mean(0)

    im = ax.plot(mean_accuracy, mean_loss, label=regime_label, c=color)
    # Circle marks the first point of the mean trajectory, star the last one.
    im = ax.scatter(mean_accuracy[0], mean_loss[0], c=color, marker="o", s=35)
    im = ax.scatter(mean_accuracy[-1], mean_loss[-1], c=color, marker="*", s=100)

handles, labels = ax.get_legend_handles_labels()

# Proxy artists so the legend also explains the start/end markers.
start_marker = mlines.Line2D([], [], color='black', marker='o', linestyle='None', markersize=5, label='start')
end_marker = mlines.Line2D([], [], color='black', marker='*', linestyle='None', markersize=10, label='end')
handles2 = [start_marker, end_marker]
labels2 = ['start', 'end']
# ax.legend already attaches the legend to the axes; adding it again with
# ax.add_artist would only draw it a second time.
legend = ax.legend([*handles, *handles2], [*labels, *labels2], loc='upper right')

# Percent tick labels on the accuracy axis.
percent_ticks = [0, 20, 40, 60, 80]
ax.set_xticks(percent_ticks)
ax.set_xticklabels([f"{tick}{percent_sign}" for tick in percent_ticks])

# NOTE(review): the plotted series are the `ood_r2` histories while the label
# says "Rule 1" -- confirm the rule numbering.
ax.set_xlabel(f"OOD Accuracy of Rule 1 ({hash_sign}a={hash_sign}b)", labelpad=LABELPAD)
ax.set_ylabel("Test loss", labelpad=LABELPAD)

fig.savefig("rule_extrapolation.svg")

## Emergence of grammatical correctness

In [None]:
# Convert `val_kl` to a NumPy array.
# NOTE(review): `val_kl` is not assigned in the cells visible here (the loaders
# create ban_val_kl, anbn_val_kl, ...) -- confirm which sweep it should refer to.
v = np.array(val_kl)

In [None]:
# Rebind `grammatical` to a NumPy array built from its current value.
# NOTE(review): `grammatical` is not assigned earlier in the visible cells (only
# the prefixed ban_/anbn_/... variants are) -- confirm its origin.
grammatical = np.array(grammatical)

In [None]:
TICK_PADDING = 2
LABELPAD = 1

fig = plt.figure(figsize=figsizes.neurips2022(nrows=1, ncols=2, rel_width=1)['figure.figsize'])
ax = fig.add_subplot(121)
ax.grid(True, which="both", ls="-.")

# One faint line per history pair (validation loss on x, OOD rule-2 accuracy on y):
# first the normal histories, then the adversarial ones. Empty histories are skipped.
for accuracy_runs, loss_runs in (
    (ood_r2, val_loss),
    (ood_r2_adversarial, val_loss_adversarial),
):
    for as_bs, val in zip(accuracy_runs, loss_runs):
        if len(as_bs) != 0:
            ax.plot(val, as_bs, alpha=0.3)

# Disabled draft of the grammaticality curve (uses `v` and `grammatical` from the cells above):
# ax.errorbar(range(v.shape[1]), (grammatical-np.exp(-v)).mean(0), yerr=(grammatical-np.exp(-v)).std(0), label='val_kl', c=BLUE)


In [None]:
# Draws the first entry of `r2` on whichever axes `ax` currently refers to.
# NOTE(review): `r2` is not assigned in the visible cells (its only assignment is
# commented out in the previous cell) -- this looks like a leftover scratch cell.
ax.plot(r2[0])