In [None]:
import numpy as np
import json
from warnings import simplefilter
simplefilter(action="ignore", category=FutureWarning)
import pandas as pd
simplefilter(action="ignore", category=pd.errors.PerformanceWarning)
import scipy.stats
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import os
import sys
import collections
import torch
from toytask_utils import make_tasks, all_divisions

# Widen pandas' console rendering so large frames are shown without truncation.
pd.set_option(
    'display.max_rows', 500,
    'display.max_columns', 500,
    'display.width', 1000,
)

# Generate Initial Data

In [None]:
n_samples = 100           # number of times to resample each class
x_values = [-1, 1]        # valid causal x values
y_values = list(range(5)) # valid causal y values
std = 0.1                 # standard deviation of noise along causal dims
cov = 0.2                 # covariance between causal dims

# Every (x, y) combination, x-major: one row per class.
varbs = np.asarray([[x, y] for x in x_values for y in y_values])

xmean = 0
ymean = 0

# Noise-free labels: the class grid repeated once per resample.
og_varbs = np.tile(varbs, (n_samples, 1))

# Noisy samples: shear x by the y value, then add isotropic Gaussian noise.
samples = []
for _ in range(n_samples):
    samp = varbs.astype(float)
    samp[:, 0] = samp[:, 0] + cov * samp[:, 1]
    noise = std * np.random.randn(*samp.shape)
    samp = samp + noise
    samples.append(samp)
samples = np.vstack(samples)
# Center both dimensions on zero.
samples = samples - samples.mean(0)
print(samples.shape)

In [None]:
# Native samples; the hue channel is the y coordinate rescaled to [0, 1].
df = pd.DataFrame({
    "x": samples[:,0],
    "y": samples[:,1],
    "hue": samples[:,1],
})
df["x"] = df["x"] - np.mean(df["x"])
df["hue"] = df["hue"] - np.min(df["hue"])
df["hue"] = df["hue"] / np.max(df["hue"])

# Styling constants (later cells read several of these as globals).
fontsize = 25
legendsize = 25
alpha = 0.8
dark = 0.2
light = 0.85
rot = 0
thickness = 2

fig = plt.figure()
ax = plt.gca()

# "Intervened" samples: same x and hue, but y shuffled across rows.
intrv_df = df.copy()
shuffled_rows = np.random.permutation(len(intrv_df)).astype(int)
intrv_df["y"] = np.asarray(intrv_df["y"])[shuffled_rows]
intrv_cmap = sns.cubehelix_palette(
    start=-.3, rot=rot, dark=dark, light=light, reverse=True, as_cmap=True
)
sns.scatterplot(
    data=intrv_df, x="x", y="y", hue="hue",
    palette=intrv_cmap, alpha=alpha, edgecolor="none", ax=ax,
)

# Native samples are drawn second so they sit on top of the intervened layer.
native_cmap = sns.cubehelix_palette(
    start=0.7, rot=rot, dark=dark, light=light, reverse=True, as_cmap=True
)
sns.scatterplot(
    data=df, x="x", y="y", hue="hue",
    palette=native_cmap, alpha=alpha, edgecolor="none", ax=ax,
)

# Dashed class boundaries: one vertical line at x=0 and a horizontal line
# halfway between consecutive (centered) y values.
ax.plot([0,0], [-3,3], "k--", alpha=0.5, linewidth=thickness)
for i in y_values[:-1]:
    y = i+0.5-2
    ax.plot([-2,2], [y,y], "k--", alpha=0.5, linewidth=thickness)

plt.xlim([-2,2])
plt.ylim([-2.75,2.75])

# Bare axes: no labels and no ticks.
plt.xlabel("", fontsize=fontsize)
plt.ylabel("", fontsize=fontsize)
plt.xticks([], fontsize=fontsize)
plt.yticks([], fontsize=fontsize)

# seaborn attaches a hue legend automatically; hide it for this figure.
plt.legend().set_visible(False)
#plt.savefig("figs/example_divergence.png", dpi=600, bbox_inches="tight")

plt.show()

In [None]:
import matplotlib.patches as mpatches
import matplotlib as mpl

# Same scatter as the divergence figure, but without dividers and with a
# two-entry "Native" / "Intervened" legend placed outside the axes.
df = pd.DataFrame({
    "x": samples[:,0],
    "y": samples[:,1],
    "hue": samples[:,1],
})
df["x"] = df["x"] - np.mean(df["x"])
df["hue"] = df["hue"] - np.min(df["hue"])
df["hue"] = df["hue"] / np.max(df["hue"])

# Styling constants (later cells read several of these as globals).
fontsize = 25
legendsize = 25
alpha = 0.8
dark = 0.2
light = 0.85
rot = 0

fig = plt.figure()
ax = plt.gca()

# "Intervened" samples: same x and hue, but y shuffled across rows.
intrv_df = df.copy()
shuffled_rows = np.random.permutation(len(intrv_df)).astype(int)
intrv_df["y"] = np.asarray(intrv_df["y"])[shuffled_rows]
intrv_cmap = sns.cubehelix_palette(
    start=-.3, rot=rot, dark=dark, light=light, reverse=True, as_cmap=True
)
sns.scatterplot(
    data=intrv_df, x="x", y="y", hue="hue",
    palette=intrv_cmap, alpha=alpha, edgecolor="none", ax=ax,
)

# Native samples are drawn second so they sit on top of the intervened layer.
native_cmap = sns.cubehelix_palette(
    start=0.7, rot=rot, dark=dark, light=light, reverse=True, as_cmap=True
)
sns.scatterplot(
    data=df, x="x", y="y", hue="hue",
    palette=native_cmap, alpha=alpha, edgecolor="none", ax=ax,
)

plt.xlim([-2,2])
plt.ylim([-2.75,2.75])

# Bare axes: no labels and no ticks.
plt.xlabel("", fontsize=fontsize)
plt.ylabel("", fontsize=fontsize)
plt.xticks([], fontsize=fontsize)
plt.yticks([], fontsize=fontsize)

# Legend handles: one solid rectangle per layer, colored from each colormap
# at 0.8. This replaces the hue legend seaborn attached automatically.
native_patch = mpatches.Patch(color=native_cmap(0.8), label="Native")
intrv_patch = mpatches.Patch(color=intrv_cmap(0.8), label="Intervened")
ax.legend(
    handles=[native_patch, intrv_patch],
    fontsize=legendsize,
    loc="upper right",
    bbox_to_anchor=(1.75,1),
)
#plt.savefig("figs/legend.png", dpi=600, bbox_inches="tight")

plt.show()

In [None]:
from geomloss import SamplesLoss

# Configuration of the geomloss sample loss used as the distribution distance.
kwargs = dict(loss="sinkhorn", p=2, blur=0.05)
loss_fn = SamplesLoss(**kwargs)

def compute_emd(X,Y):
    """
    Evaluates the module-level `loss_fn` (a geomloss `SamplesLoss`
    configured with loss="sinkhorn", p=2, blur=0.05) between two point
    clouds.

    Args:
        X: torch tensor
        Y: torch tensor
            both are cast to float32 before being passed to `loss_fn`
    Returns:
        whatever `loss_fn` returns for the pair (X, Y)
    """
    X32, Y32 = X.float(), Y.float()
    return loss_fn(X32, Y32)


# CL Loss

In [None]:
# Sanity check: labels and samples should have the same number of rows.
for arr in (og_varbs, samples):
    print(arr.shape)

In [None]:
def quick_plot(
    natty,
    intrv,
    natty_classes=None,
    intrv_classes=None,
    save_name=None,
    incl_legend=False,
    incl_dividers=True,
    xlim=[-2,2],
    ylim=[-2.75,2.75],
    labels=["Native", "Intervened"],
    intrv_cmap=None,
    native_cmap=None,
    intrv_color=None,
    native_color=None,
    intrv_alpha = 0.95,
    native_alpha = 0.6,
    thickness=2,
    dash_alpha=0.5,
):
    """
    Scatter-plots the first two dimensions of the native and intervened
    vectors on one axis (intervened first, native on top).

    Each layer is drawn in a single solid color when its class labels are
    missing or contain only one distinct class, and colored per class
    (seaborn "pastel" for intervened, "dark" for native, unless a palette
    is supplied) otherwise.

    Reads the notebook globals `rot`, `dark`, `light` (default colormaps),
    `legendsize` (patch legend) and `y_values` (horizontal dividers).

    Args:
        natty: torch tensor (B,D)
            native vectors; only dims 0 and 1 are plotted
        intrv: torch tensor (B,D)
            intervened vectors; only dims 0 and 1 are plotted
        natty_classes: None or array-like / torch tensor (B,)
        intrv_classes: None or array-like / torch tensor (B,)
        save_name: str or None
            if truthy, the figure is saved there at 600 dpi
        incl_legend: bool
            if true and natty_classes is None and the intervened layer was
            drawn in a solid color, a two-patch legend built from `labels`
            is shown; if true otherwise, the automatic hue legend is shown;
            if false the legend is hidden
        incl_dividers: bool
            draw the dashed class boundaries
        xlim, ylim: axis limits
        labels: [native label, intervened label] for the patch legend
        intrv_cmap, native_cmap: optional colormap/palette overrides
        intrv_color, native_color: optional solid-color overrides
        intrv_alpha, native_alpha: float marker opacities
        thickness: divider line width
        dash_alpha: divider opacity
    """
    # Imported locally so this function does not rely on an unrelated cell
    # having imported `mpatches` first.
    import matplotlib.patches as mpatches

    fig = plt.figure()
    ax = plt.gca()

    plt.xticks([])
    plt.yticks([])

    natty = natty.cpu().detach().numpy()
    intrv = intrv.cpu().detach().numpy()

    def _as_labels(classes):
        # Torch labels must become numpy: a set() over a tensor hashes its
        # elements by identity, so it can never detect "only one class".
        if classes is None:
            return None
        if torch.is_tensor(classes):
            classes = classes.detach().cpu().numpy()
        return np.asarray(classes)

    def _scatter(points, classes, cmap, color, cmap_start, class_palette, alpha, **hue_kwargs):
        # Draws one layer; returns the solid color used (None if hue-colored
        # and no color override was given).
        if classes is None or len(np.unique(classes)) == 1:
            if cmap is None:
                cmap = sns.cubehelix_palette(
                    start=cmap_start, rot=rot, dark=dark, light=light,
                    reverse=True, as_cmap=True,
                )
            if color is None:
                color = cmap(0.8)
            sns.scatterplot(
                x=points[:,0], y=points[:,1],
                alpha=alpha, ax=ax, color=color, edgecolor="none",
            )
        else:
            if cmap is None:
                cmap = sns.color_palette(class_palette)
            sns.scatterplot(
                x=points[:,0], y=points[:,1],
                alpha=alpha, ax=ax,
                hue=classes, palette=cmap,
                **hue_kwargs,
            )
        return color

    natty_labels = _as_labels(natty_classes)
    intrv_labels = _as_labels(intrv_classes)

    # The intervened layer is keyed on its OWN labels (previously it was keyed
    # on natty_classes, which crashed when only intrv_classes was None).
    # The per-class intervened scatter keeps seaborn's default marker edge,
    # as before.
    intrv_color = _scatter(
        intrv, intrv_labels, intrv_cmap, intrv_color,
        cmap_start=-.3, class_palette="pastel", alpha=intrv_alpha,
    )
    native_color = _scatter(
        natty, natty_labels, native_cmap, native_color,
        cmap_start=0.7, class_palette="dark", alpha=native_alpha,
        edgecolor="none",
    )

    if incl_legend and natty_classes is None and intrv_color is not None:
        native_patch = mpatches.Patch(color=native_color, label=labels[0])
        intrv_patch = mpatches.Patch(color=intrv_color, label=labels[1])
        ax.legend(handles=[native_patch, intrv_patch], fontsize=legendsize, loc="upper right", bbox_to_anchor=(1.75,1))
    elif incl_legend:
        plt.legend()
    else:
        plt.legend().set_visible(False)

    if incl_dividers:
        # vertical divider between the two x classes
        ax.plot([0,0],[-3,3], "k--", linewidth=thickness, alpha=dash_alpha)
        # horizontal dividers halfway between consecutive (centered) y values
        for i in y_values[:-1]:
            y = i+0.5-2
            ax.plot([-2,2],[y,y], "k--", linewidth=thickness, alpha=dash_alpha)
    plt.xlim(xlim)
    plt.ylim(ylim)
    if save_name:
        plt.savefig(save_name, dpi=600, bbox_inches="tight")
    plt.show()

In [None]:
def rot_fwd(vecs, mtx):
    """Right-multiplies `vecs` (B,D) by `mtx.weight` (D,D)."""
    weight = mtx.weight
    return torch.matmul(vecs, weight)

def rot_bck(vecs, mtx):
    """Right-multiplies `vecs` (B,D) by the inverse of `mtx.weight`, undoing rot_fwd."""
    inv_weight = torch.linalg.inv(mtx.weight)
    return torch.matmul(vecs, inv_weight)

def interchange(trg,src,mtx,mask):
    """
    Swaps the masked coordinates of the source vectors into the target
    vectors. The swap happens in the basis defined by `mtx`, and the
    result is mapped back to the original basis.

    Args:
        trg: torch tensor (B,D)
            vectors that receive the patch
        src: torch tensor (B,D)
            vectors that supply the patched coordinates
        mtx: torch module (D,D)
            needs "weight" property
        mask: torch tensor (D,)
            ones denote dimensions that will be transferred
    Returns:
        torch tensor (B,D)
    """
    kept = rot_fwd(trg, mtx) * (1-mask)
    patched = rot_fwd(src, mtx) * mask
    return rot_bck(kept + patched, mtx)

In [None]:
def normalize_fn(vecs, eps=1e-7):
    """
    Standardizes each column of `vecs`: subtracts the per-column mean and
    divides by the per-column standard deviation plus `eps`.
    """
    centered = vecs - vecs.mean(0)
    scale = vecs.std(0) + eps
    return centered / scale

In [None]:
def get_classes_from_varbs(varbs, v2class=None):
    """
    Maps each row of `varbs` to an integer class id.

    Ids are handed out in order of first appearance, continuing from the
    current size of `v2class`. A `v2class` dict passed in is updated in
    place with any newly seen rows.

    Args:
        varbs: torch tensor (B,N)
        v2class: None or dict mapping tuple -> int
    Returns:
        classes: numpy array (B,)
        v2class: dict mapping tuple -> int
    """
    v2class = dict() if v2class is None else v2class
    rows = varbs.detach().cpu().tolist()
    classes = [v2class.setdefault(tuple(row), len(v2class)) for row in rows]
    return np.asarray(classes), v2class
    

In [None]:
extra_dims = 64
cov = 0
mask_dims = 1

# as_tensor (rather than torch.tensor) keeps this cell re-runnable: on a second
# run og_varbs/samples are already tensors, and torch.tensor(tensor) emits a
# copy-construct warning.
og_varbs = torch.as_tensor(og_varbs).long()
d = og_varbs.shape[-1] + extra_dims
# Only the first `mask_dims` dimensions are patched from source to target.
mask = torch.zeros(d)
mask[:mask_dims] = 1

samples = torch.as_tensor(samples).float()
# Extraneous dimensions: standard normal noise plus a component scaled by the y value.
noise = torch.randn(len(og_varbs), extra_dims) + cov * torch.randn(len(og_varbs), extra_dims)*og_varbs[:,1:]
natty_varbs = og_varbs.clone()
natty_classes, v2class = get_classes_from_varbs(natty_varbs)
natty_samps = torch.cat([ samples, noise ], dim=-1)
perm = torch.randperm(len(og_varbs)).long()

# Expected labels after patching: x from the source row, y from the permuted target row.
intrv_varbs = torch.cat([ natty_varbs[:,0:1], natty_varbs[perm,1:] ], dim=-1)
intrv_classes, _ = get_classes_from_varbs(intrv_varbs, v2class=v2class)

trg_vecs = natty_samps[perm].clone()
src_vecs = natty_samps.clone()

# Identity "rotation": patching happens directly in the data basis.
eye = torch.nn.Linear(d,d)
eye.weight.data = torch.eye(d).float()
with torch.no_grad():
    intrv_samps = interchange(trg_vecs, src_vecs, eye, mask)

# savefig does not create missing directories.
os.makedirs("figs", exist_ok=True)

# Sanity check
quick_plot(
    natty_samps,
    intrv_samps,
    natty_classes=natty_classes,
    intrv_classes=intrv_classes,
    dash_alpha=0.2,
    thickness=3,
    save_name="figs/identity_patching.png",
)

# Same patch through an untrained orthogonal matrix.
rot_mtx = torch.nn.utils.parametrizations.orthogonal(torch.nn.Linear(d,d))

with torch.no_grad():
    intrv_samps = interchange(trg_vecs, src_vecs, rot_mtx, mask)

# Sanity check
quick_plot(natty_samps, intrv_samps, natty_classes=natty_classes, intrv_classes=intrv_classes)


In [None]:
# Rows whose variables are unchanged by the intervention must keep their class id.
for i,(iv,nv) in enumerate(zip(intrv_varbs,natty_varbs)):
    unchanged = iv.tolist()==nv.tolist()
    if not unchanged:
        continue
    assert natty_classes[i]==intrv_classes[i]

In [None]:
def train_classifier(
    vecs, classes,
    n_epochs=1000,
    lr=0.01,
    l2=0.01,
    drop_p=0.5,
    bsize=128,
    patience=500,
    print_every=50,
    n_layers=3,
    hidden_dim=256,
    pre_layernorm=False,
    pre_batchnorm=True,
    layernorm=False,
    batchnorm=True,
    model=None,
    ret_best=True,
    verbose=False,
    shuffle_split=True,
):
    """
    Trains an MLP classifier on `vecs` with an 80/20 train/validation split.

    Args:
        vecs: torch tensor (B,D)
        classes: list-like (B,)
            integer class labels
        n_epochs: int
            maximum number of epochs (must be >= 1)
        lr: float
            Adam learning rate
        l2: float
            Adam weight decay
        drop_p: float
            dropout probability after each hidden layer
        bsize: int
            batch size; clamped to the training-set size so that at least
            one batch runs per epoch. Trailing partial batches are dropped.
        patience: int
            training stops after this many consecutive epochs without a new
            lowest validation loss
        print_every: int
            logging period in epochs (only when verbose)
        n_layers: int
            1 -> linear classifier, 2 -> one hidden layer, >=3 -> two
        hidden_dim: int
        pre_layernorm, pre_batchnorm: bool
            normalization applied to the raw inputs
        layernorm, batchnorm: bool
            normalization applied after each hidden layer
        model: None or torch module
            if given it is trained instead of building a new network
        ret_best: bool
            if true, returns a copy of the model from the epoch with the
            best validation accuracy (ties broken by train accuracy);
            otherwise the final model
        verbose: bool
        shuffle_split: bool
            if true (default) the rows are shuffled before the split, so
            class-ordered inputs do not leave whole classes out of the
            training set
    Returns:
        model: torch module in eval mode
        train_acc: float
        valid_acc: float
    """
    # Local import: this definition precedes the notebook's `import copy` cell.
    import copy

    if n_epochs < 1:
        raise ValueError("n_epochs must be at least 1")

    classes = torch.as_tensor(classes).long()
    d = vecs.shape[-1]
    n = len(set(classes.detach().cpu().tolist()))
    if model is None:
        modules = []
        if pre_layernorm:
            modules.append(torch.nn.LayerNorm(d))
        if pre_batchnorm:
            modules.append(torch.nn.BatchNorm1d(d))
        # 0, 1 or 2 identical hidden blocks depending on n_layers
        n_hidden = min(max(n_layers-1, 0), 2)
        for _ in range(n_hidden):
            modules += [
                torch.nn.Linear(d,hidden_dim),
                torch.nn.ReLU(),
                torch.nn.Dropout(drop_p),
            ]
            if layernorm:
                modules.append( torch.nn.LayerNorm(hidden_dim) )
            if batchnorm:
                modules.append( torch.nn.BatchNorm1d(hidden_dim) )
            d = hidden_dim
        modules.append(torch.nn.Linear(d,n))
        model = torch.nn.Sequential(*modules)
    if verbose:
        print(model)
    model.train()
    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2)
    optim.zero_grad()

    # Shuffle before splitting: with class-ordered rows the last 20% would
    # otherwise consist of classes that never appear in the training split.
    if shuffle_split:
        order = torch.randperm(len(vecs)).long()
        vecs = vecs[order]
        classes = classes[order]
    tlen = int(len(vecs)*0.8)
    train_vecs = vecs[:tlen]
    train_classes = classes[:tlen]
    valid_vecs = vecs[tlen:]
    valid_classes = classes[tlen:]
    if len(train_vecs) == 0:
        raise ValueError("not enough samples to build a training split")
    bsize = max(1, min(bsize, len(train_vecs)))

    best_model = None
    best_loss = None
    best_valid_loss = None
    best_train_acc = None
    best_valid_acc = None
    lowest_valid_loss = np.inf
    n_pat = 0
    for epoch in range(n_epochs):
        model.train()
        perm = torch.randperm(len(train_vecs)).long()
        for b in range(0,len(perm)-bsize+1,bsize):
            idxs = perm[b:b+bsize]
            inputs = train_vecs[idxs]
            labels = train_classes[idxs]
            preds = model(inputs)
            loss = torch.nn.functional.cross_entropy(preds, labels)
            loss.backward()
            optim.step()
            optim.zero_grad()

        # Evaluate in eval mode so dropout is off and BatchNorm running
        # statistics are not modified by the evaluation passes.
        model.eval()
        with torch.no_grad():
            preds = model(train_vecs)
            train_acc = (preds.argmax(-1)==train_classes).float().mean()
            preds = model(valid_vecs)
            valid_acc = (preds.argmax(-1)==valid_classes).float().mean()
            valid_loss = torch.nn.functional.cross_entropy(preds, valid_classes)
        if epoch % print_every == 0 and verbose:
            print(epoch,
                  "TrnLoss:", loss.item(),
                  "ValLoss:", valid_loss.item(),
                  "TrnAcc:", train_acc.item(),
                  "ValAcc:", valid_acc.item()
                )

        is_better = (
            best_model is None  # always keep a snapshot from the first epoch
            or valid_acc>best_valid_acc
            or (valid_acc>=best_valid_acc and train_acc>best_train_acc)
        )
        if is_better:
            best_loss = loss.detach()
            best_valid_loss = valid_loss
            best_valid_acc = valid_acc
            best_train_acc = train_acc
            best_model = copy.deepcopy(model)

        # Patience counts consecutive epochs without a new best validation loss.
        cur_valid_loss = valid_loss.item()
        if cur_valid_loss < lowest_valid_loss:
            lowest_valid_loss = cur_valid_loss
            n_pat = 0
        else:
            n_pat += 1
            if n_pat>=patience:
                print("Converged at epoch", epoch)
                break

    model.eval()
    if ret_best:
        print(epoch,
              "TrnLoss:", best_loss.item(),
              "ValLoss:", best_valid_loss.item(),
              "TrnAcc:", best_train_acc.item(),
              "ValAcc:", best_valid_acc.item()
            )
        return best_model, best_train_acc.item(), best_valid_acc.item()
    print(epoch,
          "TrnLoss:", loss.item(),
          "ValLoss:", valid_loss.item(),
          "TrnAcc:", train_acc.item(),
          "ValAcc:", valid_acc.item()
        )
    return model, train_acc.item(), valid_acc.item()

In [None]:
def get_cl_vectors(natty_varbs, intrv_varbs, natty_vecs, method="sample"):
    """
    For every intervened variable setting, picks a native vector that has
    exactly those variable values.

    natty_varbs: tensor (B,2)
        the non-noisy variable values
    intrv_varbs: tensor (B,2)
        the non-noisy variable values
    natty_vecs: tensor (B,D)
        the noisy native vector representations
    method: str
        options: sample, mean (alias: average)
        determines whether the cl vectors should be averaged
        over all possible candidates or individual samples

    Returns:
        tensor (len(intrv_varbs), D). With method "mean", a setting that
        has no native candidates yields a NaN row (mean over an empty
        selection).
    Raises:
        ValueError: unknown `method`, or method "sample" with a setting
            that has no native candidates.
    """
    if method not in {"sample", "average", "mean"}:
        raise ValueError(f"Unknown method {method!r}; expected 'sample', 'mean' or 'average'")
    if len(intrv_varbs) == 0:
        return natty_vecs[:0].clone()

    cl_vectors = []
    all_idxs = torch.arange(len(natty_varbs)).long()
    mean_cache = dict()  # (x, y) -> averaged native vector, computed once per setting
    for intrv in intrv_varbs:
        key = (float(intrv[0]), float(intrv[1]))
        if method != "sample" and key in mean_cache:
            cl_vectors.append(mean_cache[key])
            continue
        valid_bools = (natty_varbs[:,0]==intrv[0])&(natty_varbs[:,1]==intrv[1])
        valid_idxs = all_idxs[valid_bools]
        if method=="sample":
            if len(valid_idxs) == 0:
                raise ValueError(f"No native sample has variable values {key}")
            idx = valid_idxs[int(np.random.randint(len(valid_idxs)))]
            cl_vectors.append(natty_vecs[idx])
        else:
            mean_cache[key] = natty_vecs[valid_idxs].mean(0)
            cl_vectors.append(mean_cache[key])
    return torch.vstack(cl_vectors)

In [None]:
def calc_cl_loss(intrv, cl, cl_loss_type="both"):
    """
    Distance between intervened vectors and their counterfactual targets.

    Args:
        intrv: tensor (B,D)
        cl: tensor (B,D)
        cl_loss_type: str
            "mse": mean squared error over all elements (scalar)
            "cos": one minus the cosine similarity of each row (B,)
            "both": the sum of the two (the scalar mse is broadcast onto
                the per-row cosine term; no averaging of the two terms)
    Returns:
        tensor, scalar for "mse" and (B,) otherwise
    Raises:
        ValueError: unknown `cl_loss_type` (previously this silently
            produced the integer 0)
    """
    if cl_loss_type not in {"mse", "cos", "both"}:
        raise ValueError(f"Unknown cl_loss_type {cl_loss_type!r}; expected 'mse', 'cos' or 'both'")
    l2,cos = 0,0
    if cl_loss_type in {"mse", "both"}:
        l2 = ((intrv-cl)**2).mean()
    if cl_loss_type in {"cos", "both"}:
        cos = 1-torch.nn.functional.cosine_similarity(intrv,cl)
    return l2 + cos
    
def get_cl_loss(
    trg, src,
    mtx, mask,
    cl_vecs,
    empty_mask=None,
    incl_extra=True,
    n_varbs=2,
    calc_loss_in_aligned_basis=False,
    detach_cl_vecs=False,
    cl_loss_type="both",
):
    """
    Runs the interchange intervention and scores the result against the
    counterfactual target vectors.

    Args:
        trg: tensor (B,D)
            target vectors which will be patched into
        src: tensor (B,D)
            source vectors from which activity will be harvested
        mtx: torch module
            must have attribute "weight"
        mask: tensor (D,)
            ones mark the aligned dimensions taken from src
        cl_vecs: tensor (B,D)
            the vectors the intervened result is compared against
        empty_mask: None or tensor (D,)
            if given, a second interchange patches these dimensions from
            a random permutation of src into the intervened vectors
        incl_extra: bool
            if true, all dimensions contribute to the CL loss. Otherwise
            every aligned-basis dimension from index mask.sum()*n_varbs
            onward is zeroed in both the intervened and the cl vectors
            before the loss
        n_varbs: int
            the number of variables in the causal abstraction
        calc_loss_in_aligned_basis: bool
            only used when incl_extra is false. If true the loss is
            computed in the aligned basis; otherwise the zeroed vectors
            are rotated back first
        detach_cl_vecs: bool
            only used when incl_extra is false. If true the rotated cl
            vectors are detached from the gradient calculation
        cl_loss_type: str
            forwarded to calc_cl_loss:
                "cos": cosine loss only
                "mse": mse loss only
                "both": sum of the cos and mse losses
    Returns:
        loss: scalar tensor, the mean of calc_cl_loss's output
        raw_intrv: tensor (B,D)
            a clone of the intervened vectors before any zeroing
    """
    patched = interchange(trg, src, mtx, mask)
    if empty_mask is not None:
        shuffle = torch.randperm(len(src)).long()
        patched = interchange(patched, src[shuffle], mtx, empty_mask)

    raw_intrv = patched.clone()

    loss_inputs, loss_targets = patched, cl_vecs
    if not incl_extra:
        # index of the first aligned dimension excluded from the loss
        n_kept = mask.long().sum()*n_varbs
        loss_inputs = rot_fwd(patched, mtx)
        loss_inputs[:,n_kept:] = 0
        loss_targets = rot_fwd(cl_vecs, mtx)
        loss_targets[:,n_kept:] = 0
        if not calc_loss_in_aligned_basis:
            loss_inputs = rot_bck(loss_inputs, mtx)
            loss_targets = rot_bck(loss_targets, mtx)
        if detach_cl_vecs:
            loss_targets = loss_targets.detach().data

    per_sample = calc_cl_loss(loss_inputs, loss_targets, cl_loss_type=cl_loss_type)
    return per_sample.mean(), raw_intrv
    


In [None]:
def get_actn_loss(preds, labels):
    """
    Cross-entropy loss and accuracy of class predictions.

    Args:
        preds: tensor (B,C)
            unnormalized class scores
        labels: list-like or tensor (B,)
            integer class ids; moved to the device of `preds`
    Returns:
        loss: scalar tensor
        acc: scalar tensor, fraction of rows whose argmax equals the label
    """
    # as_tensor accepts lists, arrays and tensors alike and, unlike
    # torch.tensor, does not warn when handed an existing tensor.
    labels = torch.as_tensor(labels, device=preds.device).long()
    loss = torch.nn.functional.cross_entropy(preds, labels)
    acc = (preds.argmax(-1)==labels).float().mean()
    return loss, acc

### Training Loop

In [None]:
def prep_data(
    og_varbs,
    samples,
    v2class=None,
    extra_dims = 128, # total number of additional noise dimensions
    dupl_rank = 0, # number of additional dimensions that are exact duplicates
    zero_rank = 0, # number of dimensions to zero out
    mask_dims=1, # number of dimensions in the DAS mask
    cov_strength = 1, # how much do the extraneous dimensions covary with the x and y values,
    n_samples_per_class=None,
    normalize=False,
):
    """
    Constructs the interchange-intervention dataset from the samples.

    The native vectors are the samples concatenated with `extra_dims`
    extraneous dimensions. Interventions pair a source row (supplies the
    x variable) with a target row (supplies the y variable).

    Args:
        og_varbs: array-like / tensor (B,2)
            the non-noisy variable values of every sample
        samples: array-like / tensor (B,2)
            the noisy samples
        v2class: None or dict mapping variable tuple -> class id
        extra_dims: int
            total number of extraneous dimensions
        dupl_rank: int
            how many of the extraneous dimensions are copies of randomly
            chosen noise dimensions (capped at extra_dims-1)
        zero_rank: int
            number of noise dimensions zeroed in a random orthogonal basis
        mask_dims: int
            number of leading ones in the returned mask
        cov_strength: float
            scale of the noise component that is multiplied by the
            variable values (y for the first half of the noise
            dimensions, x for the rest)
        n_samples_per_class: None or int
            if falsy or negative, every row is paired with a random
            permutation of the rows. Otherwise this many source/target
            pairs are sampled for each class in v2class
        normalize: bool
            if true, the native vectors are centered per dimension and
            divided by a single standard deviation over all elements
    Returns:
        dict with keys "intrv_idxs", "mask", "src_vecs", "trg_vecs",
        "intrv_varbs", "intrv_classes", "valid_intrvs", "src_varbs",
        "intrv_v2class", "src_classes", "v2class"
    """
    og_varbs = torch.as_tensor(og_varbs)
    n_rows = len(og_varbs)
    dupl_rank = min(extra_dims-1, dupl_rank)

    d = og_varbs.shape[-1] + extra_dims
    mask = torch.zeros(d)
    mask[:mask_dims] = 1

    # as_tensor avoids the copy-construct warning when samples is already a tensor
    natty_vecs = torch.as_tensor(samples).float()
    if extra_dims>0:
        n_noise_dims = extra_dims-dupl_rank
        if n_noise_dims>1:
            # The two halves must add up to n_noise_dims even when it is odd,
            # otherwise the sum below fails on a shape mismatch.
            n_y_dims = n_noise_dims//2
            n_x_dims = n_noise_dims - n_y_dims
            noise = torch.randn(n_rows, n_noise_dims) +\
                torch.cat([
                    cov_strength * torch.randn(n_rows, n_y_dims)*og_varbs[:,1:],
                    cov_strength * torch.randn(n_rows, n_x_dims)*og_varbs[:,0:1]
                ],dim=-1)
        else:
            noise = torch.randn(n_rows, n_noise_dims) + cov_strength * torch.randn(n_rows, n_noise_dims)*og_varbs[:,1:]

        if dupl_rank>0:
            idxs = torch.randint(0,noise.shape[-1],(dupl_rank,)).long()
            dupls = noise.T[idxs].T
            noise = torch.cat([noise,dupls],dim=-1)
        if zero_rank>0:
            # Instead of just zeroing out dimensions, we reduce the rank of the noise
            # in a rotated basis.
            n = noise.shape[-1]
            orth = torch.nn.utils.parametrizations.orthogonal(torch.nn.Linear(n,n))
            with torch.no_grad():
                noise = torch.matmul(noise, orth.weight)
                noise[:,:zero_rank] = 0
                noise = torch.matmul(noise, orth.weight.T)
        natty_vecs = torch.cat([natty_vecs, noise], dim=-1)
    natty_varbs = og_varbs.long()
    natty_classes, v2class = get_classes_from_varbs(natty_varbs, v2class=v2class)
    natty_classes = torch.as_tensor(natty_classes).long()
    if normalize:
        # NOTE(review): std() is taken over all elements, not per dimension
        # (normalize_fn uses std(0)) -- confirm this is intended.
        natty_vecs = (natty_vecs-natty_vecs.mean(0))/natty_vecs.std()

    if not n_samples_per_class or n_samples_per_class<0:
        perm = torch.randperm(n_rows).long()
        intrv_idxs = torch.stack([
            torch.arange(n_rows).long(), perm
        ],dim=1)

        # x comes from the row itself, y from the permuted partner
        intrv_varbs = torch.cat([ natty_varbs[:,0:1], natty_varbs[perm,1:] ], dim=-1)
        intrv_classes, intrv_v2class = get_classes_from_varbs(intrv_varbs, v2class={**v2class})
        intrv_classes = torch.as_tensor(intrv_classes).long()
        if len(intrv_v2class)!=len(v2class):
            # Interventions that produced a combination absent from v2class are flagged invalid.
            valid_intrvs = torch.isin(intrv_classes, torch.tensor(list(v2class.values())).long())
        else:
            valid_intrvs = torch.ones(len(intrv_classes)).bool()

        trg_vecs = natty_vecs[perm].clone()
        src_vecs = natty_vecs.clone()
    else:
        intrv_v2class = {**v2class}
        c2varb = {v:k for k,v in v2class.items()}
        intrv_classes = []
        intrv_idxs = []
        intrv_varbs = []
        arange = torch.arange(len(natty_varbs)).long()
        for c,varb_tup in c2varb.items():
            # Candidate rows depend only on the class, so look them up once.
            idxs1 = arange[(natty_varbs[:,0]==varb_tup[0])]
            idxs2 = arange[(natty_varbs[:,1]==varb_tup[1])]
            for _ in range(n_samples_per_class):
                idx1 = idxs1[int(np.random.randint(len(idxs1)))]
                idx2 = idxs2[int(np.random.randint(len(idxs2)))]
                intrv_idxs.append([int(idx1),int(idx2)])
                intrv_classes.append(c)
                intrv_varbs.append([varb_tup[0], varb_tup[1]])
                assert natty_varbs[idx1,0]==varb_tup[0] and natty_varbs[idx2,1]==varb_tup[1]
        intrv_idxs = torch.tensor(intrv_idxs).long()
        intrv_classes = torch.tensor(intrv_classes).long()
        intrv_varbs = torch.tensor(intrv_varbs).long()
        valid_intrvs = torch.ones(len(intrv_classes)).bool()

        # Column 0 indexes the source row (matches x), column 1 the target row (matches y).
        trg_vecs = natty_vecs[intrv_idxs[:,1]].clone()
        src_vecs = natty_vecs[intrv_idxs[:,0]].clone()
        natty_varbs = natty_varbs[intrv_idxs[:,0]]
        natty_classes = natty_classes[intrv_idxs[:,0]]
        natty_vecs = src_vecs.clone()

    return {
        "intrv_idxs": intrv_idxs,

        "mask": mask.clone(),
        "src_vecs": src_vecs.clone(),
        "trg_vecs": trg_vecs.clone(),
        "intrv_varbs": intrv_varbs.clone(),
        "intrv_classes": intrv_classes.clone(),
        "valid_intrvs": valid_intrvs.clone(),
        "src_varbs": natty_varbs.clone(),
        "intrv_v2class": intrv_v2class,
        "src_classes": natty_classes.clone(),
        "v2class": v2class,
    }

In [None]:
def get_plot_save_name(params, excl_keys={"calc_loss_in_aligned_basis", "lr", "n_epochs", "detach_cl_vecs"}):
    """
    Builds a png file name from the experiment parameters: "toydiv"
    followed by "_<key><value>" for every key (in sorted order) that is
    not in `excl_keys`.
    """
    parts = [
        f"_{key}{params[key]}"
        for key in sorted(params.keys())
        if key not in excl_keys
    ]
    return "toydiv" + "".join(parts) + ".png"

In [None]:
import copy

def train_rotation(
    src_varbs, intrv_varbs, classifier,
    trg_vecs, src_vecs,
    src_classes, intrv_classes,
    mask=None,
    n_epochs = 1000,
    lr = 0.01,
    cl_eps = 1,
    cl_loss_type="both",#"cos" "mse" "both"
    method = "sample",  #"mean" "sample"
    incl_extra = False, # Will include the extraneous subspaces in cl loss if True,
    calc_loss_in_aligned_basis = False,
    detach_cl_vecs = False,
    incl_actn_loss = False,
    incl_cl_loss = True,
    mtx_type = "orthog", # "orthog" "linear",
    print_every = 50,
    fig_every = 200,
    save_fig=False,
    shuffle_empty=False,
    incl_dividers=False,
    early_stopping=True,
    early_stop_thresh=1e-5,
    early_stop_patience=100,
    **kwargs,
):
    """
    Trains an alignment matrix so that interchange interventions in its
    basis produce the intended class (action loss) and/or land on native
    counterfactual vectors (CL loss).

    Args:
        src_varbs: tensor (B,2)
            variable values of the source vectors
        intrv_varbs: tensor (B,2)
            variable values expected after the intervention
        classifier: torch module
            frozen classifier used for the action loss / accuracy. All
            tensors are moved to the device of its parameters.
        trg_vecs, src_vecs: tensor (B,D)
        src_classes, intrv_classes: list-like (B,)
        mask: None or tensor (D,)
            aligned dimensions to transfer; defaults to the first one
        n_epochs: int
        lr: float
            Adam learning rate
        cl_eps: float
            weight of the CL loss
        cl_loss_type, incl_extra, calc_loss_in_aligned_basis,
        detach_cl_vecs: forwarded to get_cl_loss
        method: str
            forwarded to get_cl_vectors
        incl_actn_loss, incl_cl_loss: bool
            which terms make up the training loss (at least one)
        mtx_type: str
            "linear" or any string containing "orthog"
        print_every, fig_every: int
            logging / plotting periods in epochs
        save_fig: bool
            save the final "best" plot under figs/
        shuffle_empty: bool
            additionally patch the dimensions past 2*mask.sum() from
            randomly permuted sources
        incl_dividers: bool
            forwarded to quick_plot
        early_stopping: bool
            stop when accuracy reaches 1 or when the loss has not improved
            on the best loss by more than early_stop_thresh for more than
            early_stop_patience consecutive epochs
        **kwargs: ignored
    Returns:
        best_mtx, exp_params, best_cl_loss, best_actn_loss, best_acc,
        best_emd, best_row_emd
        The snapshot (matrix, losses, emds) is refreshed whenever either
        the accuracy or the loss reaches a new best.
    """
    exp_params = {
        "n_epochs": n_epochs,
        "lr": lr,
        "method": method,
        "incl_extra": incl_extra,
        "calc_loss_in_aligned_basis": calc_loss_in_aligned_basis,
        "detach_cl_vecs": detach_cl_vecs,
        "incl_actn_loss": incl_actn_loss,
        "incl_cl_loss": incl_cl_loss,
        "mtx_type": mtx_type,
        "cl_eps": cl_eps,
    }

    d = src_vecs.shape[-1]
    assert incl_actn_loss or incl_cl_loss
    cl_vecs = get_cl_vectors(src_varbs, intrv_varbs, src_vecs, method=method)
    if mask is None:
        mask = torch.zeros(d)
        mask[:1] = 1

    # Train on whatever device the classifier lives on.
    clf_param = next(iter(classifier.parameters()), None)
    if clf_param is not None:
        device = clf_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    empty_mask = None
    if shuffle_empty:
        empty_mask = torch.zeros_like(mask).to(device)
        empty_mask[int(mask.long().sum())*2:] = 1

    if mtx_type=="linear":
        # NOTE(review): SymmetricDefiniteMatrix is not defined in the visible
        # part of this notebook -- confirm it is available before using "linear".
        rot_mtx = SymmetricDefiniteMatrix( size=d, )
    elif "orthog" in mtx_type:
        rot_mtx = torch.nn.utils.parametrizations.orthogonal(torch.nn.Linear(d,d))
    else:
        raise ValueError(f"Unknown mtx_type {mtx_type!r}; expected 'linear' or 'orthog'")

    if fig_every<np.inf or save_fig:
        # savefig does not create missing directories
        os.makedirs("figs", exist_ok=True)

    with torch.no_grad():
        eye = torch.nn.Linear(d,d)
        eye.weight.data = torch.eye(d).float()
        intrv_vecs = interchange(trg_vecs, src_vecs, eye, mask)
    print("Identity")
    if fig_every<np.inf:
        quick_plot(
            src_vecs, intrv_vecs.detach(),
            src_classes, intrv_classes,
            save_name="figs/identity.png",
            incl_dividers=incl_dividers
        )

    with torch.no_grad():
        intrv_vecs = interchange(trg_vecs, src_vecs, rot_mtx, mask)
    print("Untrained")
    if fig_every<np.inf:
        quick_plot(src_vecs, intrv_vecs.detach(), src_classes, intrv_classes, incl_dividers=incl_dividers)

    # Move everything to the training device once, then build the optimizer
    # on the moved parameters.
    rot_mtx = rot_mtx.to(device)
    optim = torch.optim.Adam(rot_mtx.parameters(), lr=lr)
    optim.zero_grad()
    trg_dev = trg_vecs.to(device)
    src_dev = src_vecs.to(device)
    mask_dev = mask.to(device)
    cl_dev = cl_vecs.to(device)
    causal_mask = torch.zeros_like(mask_dev)
    causal_mask[:2] = 1 # Only using the causal dimensions

    print("Training Rotation")
    n_pat = 0
    epoch = 0
    best_mtx = None
    best_acc = torch.tensor(0.)
    best_loss = np.inf
    best_cl_loss = torch.tensor(np.inf)
    best_actn_loss = torch.tensor(np.inf)
    best_intrv_vecs = intrv_vecs.detach().cpu().data.clone()
    best_emd = np.inf
    best_row_emd = np.inf
    for epoch in range(n_epochs):
        cl_loss, intrv_vecs = get_cl_loss(
            trg=trg_dev,
            src=src_dev,
            mtx=rot_mtx,
            mask=mask_dev,
            empty_mask=empty_mask,
            cl_vecs=cl_dev,
            incl_extra=incl_extra,
            calc_loss_in_aligned_basis=calc_loss_in_aligned_basis,
            detach_cl_vecs=detach_cl_vecs,
            cl_loss_type=cl_loss_type,
        )
        with torch.no_grad():
            perm = torch.randperm(len(src_vecs)).long()
            emd = compute_emd(src_dev[perm], intrv_vecs).item()
            row_emd = compute_emd(
                intrv_vecs*causal_mask, src_dev*causal_mask
            ).item()

        actn_loss, acc = get_actn_loss(classifier(intrv_vecs), intrv_classes)
        loss = 0
        if incl_actn_loss:
            loss = loss + actn_loss
        if incl_cl_loss:
            loss = loss + cl_eps*cl_loss
        loss.backward()
        optim.step()
        optim.zero_grad()

        if epoch % print_every == 0:
            print(epoch, "Cl Loss:", cl_loss.item(), "ActLoss:", actn_loss.item(), "Actn:", acc.item(), "EMD:", emd, "RowEMD", row_emd)
        if epoch % fig_every == 0 and epoch > 0:
            quick_plot(src_vecs, intrv_vecs.detach(), src_classes, intrv_classes, incl_dividers=incl_dividers)

        loss_val = loss.item()
        # Measured against the best loss from BEFORE this epoch. Comparing
        # after the update made every epoch count as stalled.
        stalled = loss_val >= (best_loss-early_stop_thresh)
        improved_acc = bool(acc>best_acc)
        improved_loss = loss_val<best_loss
        if improved_acc:
            best_acc = acc.detach()
        if improved_loss:
            best_loss = loss_val
        if improved_acc or improved_loss or best_mtx is None:
            best_mtx = copy.deepcopy(rot_mtx)
            # detach so the snapshots do not keep the autograd graph alive
            best_cl_loss = cl_loss.detach()
            best_actn_loss = actn_loss.detach()
            best_intrv_vecs = intrv_vecs.detach().cpu().data.clone()
            best_emd = emd
            best_row_emd = row_emd

        if early_stopping:
            if stalled:
                n_pat += 1
                if n_pat>early_stop_patience:
                    print("Converged at epoch", epoch)
                    break
            else:
                n_pat = 0
            if acc.item()==1:
                print("Converged at epoch", epoch)
                break

    if best_mtx is None:
        # n_epochs == 0: nothing was trained, return the untrained matrix
        best_mtx = copy.deepcopy(rot_mtx)

    for p in sorted(list(exp_params.keys())):
        print(p, exp_params[p])
    print()
    if fig_every<np.inf or save_fig:
        save_name = "figs/"+get_plot_save_name(exp_params)
        print("Best Plot by Train Loss")
        quick_plot(
            src_vecs, best_intrv_vecs,
            src_classes, intrv_classes,
            save_name=save_name if save_fig else None,
            incl_dividers=incl_dividers,
        )
        print("Last Plot")
        quick_plot(
            src_vecs, intrv_vecs,
            src_classes, intrv_classes,
            incl_dividers=incl_dividers,
        )
    print(epoch,
        "Cl Loss:", best_cl_loss.item(),
          "ActLoss:", best_actn_loss.item(),
          "Actn:", best_acc.item(),
          "EMD:", best_emd,
         "RowEMD:", best_row_emd)
    return best_mtx, exp_params, best_cl_loss, best_actn_loss, best_acc, best_emd, best_row_emd


In [None]:
extra_dims = 64
dupl_rank = 0 # duplicates uniformly sampled extra dimensions from the set of extra_dims-dupl_rank noisy dimensions
zero_rank = 0
mask_dims = 1
cov_strength = 0 # how much do the extraneous dimensions covary with the x and y values
normalize = False

np.random.seed(12345)
torch.manual_seed(12345)

data_dict = prep_data(
    og_varbs=og_varbs,
    samples=samples,
    v2class=v2class,
    extra_dims=extra_dims,
    dupl_rank=dupl_rank,
    zero_rank=zero_rank,
    mask_dims=mask_dims,
    cov_strength=cov_strength,
    n_samples_per_class=100,
    normalize=normalize
)

# Keep only the interventions flagged as valid.
mask = data_dict["mask"]
valids = data_dict["valid_intrvs"]
(
    src_vecs, trg_vecs, intrv_idxs, intrv_varbs,
    intrv_classes, src_classes, src_varbs,
) = (
    data_dict[key][valids]
    for key in (
        "src_vecs", "trg_vecs", "intrv_idxs", "intrv_varbs",
        "intrv_classes", "src_classes", "src_varbs",
    )
)

# Patch mask over the first mask_dims dimensions; empty_mask marks everything
# past the first 2*mask_dims dimensions.
d = src_vecs.shape[-1]
mask = torch.zeros(d)
mask[:mask_dims] = 1
empty_mask = torch.zeros_like(mask)
if len(mask)>mask_dims*2:
    empty_mask[mask_dims*2:] = 1

rot_mtx = torch.nn.utils.parametrizations.orthogonal(torch.nn.Linear(d,d))

# Baseline 1: patching in the data basis (identity matrix).
with torch.no_grad():
    eye = torch.nn.Linear(d,d)
    eye.weight.data = torch.eye(d).float()
    intrv_vecs = interchange(trg_vecs, src_vecs, eye, mask)
print("Identity")
div = compute_emd(src_vecs, intrv_vecs)
print("Div:", div)
quick_plot(src_vecs, intrv_vecs.detach(), src_classes, intrv_classes)

# Baseline 2: patching through the untrained orthogonal matrix.
with torch.no_grad():
    intrv_vecs = interchange(trg_vecs, src_vecs, rot_mtx, mask)
print("Untrained")
div = compute_emd(src_vecs, intrv_vecs)
print("Div:", div)
quick_plot(src_vecs, intrv_vecs.detach(), src_classes, intrv_classes)


In [None]:
# Classifier hyperparameters
lr = 0.01
bsize = 200
patience = 400
l2 = 0.005
n_epochs = 1000
drop_p = 0.5
hidden_dim = 128
n_layers = 2 # 1-3 layers
pre_layernorm = False
pre_batchnorm = True
layernorm = False
batchnorm = True

np.random.seed(12345)
torch.manual_seed(12345)
print("Training Classifier")
clf_kwargs = dict(
    lr=lr,
    patience=patience,
    l2=l2,
    hidden_dim=hidden_dim,
    n_epochs=n_epochs,
    drop_p=drop_p,
    bsize=bsize,
    n_layers=n_layers,
    pre_layernorm=pre_layernorm,
    pre_batchnorm=pre_batchnorm,
    layernorm=layernorm,
    batchnorm=batchnorm,
    ret_best=True,
    verbose=True,
)
classifier, max_acc, _ = train_classifier(src_vecs, src_classes, **clf_kwargs)

# Freeze the classifier: GPU, eval mode, no parameter gradients.
classifier = classifier.cuda().eval()
classifier.requires_grad_(False);

In [None]:
# Rotation-training hyperparameters
lr = 0.05
cl_eps = 100
mask_dims = 1
n_epochs = 500
fig_every = 250

d = src_vecs.shape[-1]
mask = torch.zeros(d)
mask[:mask_dims] = 1

# Run 1: train the rotation with the CL loss only.
cl_run_kwargs = dict(
    # data
    src_varbs=src_varbs,
    intrv_varbs=intrv_varbs,
    classifier=classifier,
    trg_vecs=trg_vecs,
    src_vecs=src_vecs,
    src_classes=src_classes,
    intrv_classes=intrv_classes,
    mask=mask,
    # loss configuration
    cl_loss_type="both", #"cos", "mse", "both
    method="mean",  #"mean" "sample"
    calc_loss_in_aligned_basis=False,
    detach_cl_vecs=True,
    shuffle_empty=False,
    incl_extra=False, # Will include the extraneous subspaces in cl loss if True,
    incl_actn_loss=False,
    incl_cl_loss=True,
    cl_eps=cl_eps,
    # optimization
    lr=lr,
    n_epochs=n_epochs,
    early_stopping=False,
    mtx_type="orthog", # "orthog" "linear",
    # reporting
    print_every=50,
    fig_every=fig_every,
    save_fig=True,
    incl_dividers=False,
)

np.random.seed(12345)
torch.manual_seed(12345)
rot_mtx, exp_params, cl_loss, actn_loss, acc, emd, row_emd = train_rotation(**cl_run_kwargs)

In [None]:
# Run 2: same setup, but train the rotation with the action loss only.
actn_run_kwargs = dict(
    # data
    src_varbs=src_varbs,
    intrv_varbs=intrv_varbs,
    classifier=classifier,
    trg_vecs=trg_vecs,
    src_vecs=src_vecs,
    src_classes=src_classes,
    intrv_classes=intrv_classes,
    mask=mask,
    # loss configuration
    cl_loss_type="both", #"cos", "mse", "both
    method="mean",  #"mean" "sample"
    calc_loss_in_aligned_basis=False,
    detach_cl_vecs=True,
    shuffle_empty=False,
    incl_extra=False, # Will include the extraneous subspaces in cl loss if True,
    incl_actn_loss=True,
    incl_cl_loss=False,
    cl_eps=cl_eps,
    # optimization
    lr=lr,
    n_epochs=n_epochs,
    early_stopping=False,
    mtx_type="orthog", # "orthog" "linear",
    # reporting
    print_every=50,
    fig_every=fig_every,
    save_fig=True,
    incl_dividers=False,
)

np.random.seed(12345)
torch.manual_seed(12345)
rot_mtx, exp_params, cl_loss, actn_loss, acc, emd, row_emd = train_rotation(**actn_run_kwargs)


# Multi Tasking

In [None]:
def test_rotation(
        rot_mtx,
        data,
        classifier,
        mask=None,
        incl_extra=False,
        calc_loss_in_aligned_basis=False,
        detach_cl_vecs=False,
        method="mean",
        ylim=[-2.75,1.75],
        fig_save_name=None,
        **kwargs,
):
    """
    Evaluate an existing alignment matrix on one task's data without training.

    Computes the CL loss and the intervened vectors with get_cl_loss, scores
    the intervened vectors with `classifier`, measures two EMD values against
    the source vectors, prints a one-line summary and draws a quick_plot of
    source vs. intervened vectors. Tensors are moved with .cuda(), so a CUDA
    device is required.

    Args:
        rot_mtx: matrix handed to get_cl_loss as `mtx`.
        data: dict holding "src_vecs", "trg_vecs", "src_varbs",
            "intrv_varbs", "src_classes" and "intrv_classes".
        classifier: callable applied to the intervened vectors; its output is
            passed to get_actn_loss together with data["intrv_classes"].
        mask: vector passed to get_cl_loss. When None, a vector of zeros with
            a single 1 at index 0 is built, sized to the last axis of
            data["src_vecs"].
        incl_extra: forwarded unchanged to get_cl_loss.
        calc_loss_in_aligned_basis: forwarded unchanged to get_cl_loss.
        detach_cl_vecs: forwarded unchanged to get_cl_loss.
        method: forwarded to get_cl_vectors.
        ylim: y-axis limits forwarded to quick_plot.
        fig_save_name: forwarded to quick_plot as `save_name`.
        **kwargs: accepted and ignored, so callers can splat a larger
            parameter dict (e.g. the training parameters).

    Returns:
        Tuple (cl_loss, actn_loss, acc, emd, row_emd). `emd` and `row_emd`
        are Python numbers (taken with .item()); the other three are returned
        exactly as produced by get_cl_loss / get_actn_loss.
    """
    # Bind the source vectors first: the default-mask branch below needs their
    # width. (Previously `src_vecs` was only assigned further down, so calling
    # with mask=None raised UnboundLocalError.)
    src_vecs = data["src_vecs"]

    cl_vecs = get_cl_vectors(
        data["src_varbs"], data["intrv_varbs"], src_vecs, method=method
    )
    if mask is None:
        mask = torch.zeros(src_vecs.shape[-1]).cuda()
        mask[0] = 1

    cl_loss, intrv_vecs = get_cl_loss(
        trg=data["trg_vecs"].cuda(),
        src=src_vecs.cuda(),
        mtx=rot_mtx.cuda(),
        mask=mask.cuda(),
        cl_vecs=cl_vecs.cuda(),
        incl_extra=incl_extra,
        calc_loss_in_aligned_basis=calc_loss_in_aligned_basis,
        detach_cl_vecs=detach_cl_vecs,
    )
    with torch.no_grad():
        # EMD between a random permutation of the source vectors and the
        # intervened vectors, over all dimensions.
        perm = torch.randperm(len(src_vecs)).long()
        emd = compute_emd(src_vecs[perm].cuda(), intrv_vecs.cuda()).item()
        # Same comparison with everything but the first two dimensions zeroed.
        extra_mask = torch.zeros_like(mask).cuda()
        extra_mask[:2] = 1 # Only using the causal dimensions
        row_emd = compute_emd(
            intrv_vecs.cuda()*extra_mask, src_vecs.cuda()*extra_mask
        ).item()
    actn_loss, acc = get_actn_loss(
        classifier(intrv_vecs.cuda()).cpu(),
        torch.tensor(data["intrv_classes"]).long().cpu(),
    )

    print("Cl Loss:", cl_loss.item(), "ActLoss:", actn_loss.item(), "Actn:", acc.item(), "EMD:", emd, "RowEMD:", row_emd)
    quick_plot(
        src_vecs, intrv_vecs.detach(),
        data["src_classes"], data["intrv_classes"],
        ylim=ylim,
        incl_dividers=False,
        save_name=fig_save_name,
    )
    return cl_loss, actn_loss, acc, emd, row_emd

In [None]:
# Task divisions that the sweep skips and that the plotting filters drop.
excl_divisions = set((
    "mirror_L", "mirror_h",
    "random_overlap",
    "tetris_C", "tetris_F", "tetris_L", "tetris_T",
))

In [None]:
# Reset the torch and NumPy global RNGs to a fixed seed ahead of the sweep.
torch.manual_seed(12345)
np.random.seed(12345)

In [None]:
import time

# Arguments splatted into prep_data(**data_params). "extra_dims" and
# "mask_dims" are overwritten per configuration by the sweep loop.
data_params = {
    "extra_dims": 128,
    "dupl_rank": 0, # duplicates extra dims
    "zero_rank": 0, # zeros out extra dims in a rotated space
    "cov_strength": 0, # how much do the extraneous dimensions covary with the x and y values,
    "n_samples_per_class": 100,
    "mask_dims": 1,
}

# Arguments splatted into train_classifier(**mlp_params).
mlp_params = {
    "lr": 0.01,
    # "patience" used to be listed twice (400, then 100). Python keeps the last
    # duplicate, so 100 was always the effective value; the dead 400 entry is gone.
    "patience": 100,
    "l2": 0.005,
    "bsize": 200,
    "n_epochs": 1000,
    "drop_p": 0.5,
    "hidden_dim": 128,
    "n_layers": 2, # 1-3 layers
    "pre_batchnorm": True,
    "batchnorm": True,
    "ret_best": True,
}


# Sweep axes.
divisions = [ "inner_square", "original" ] # all_divisions
mtx_types = ["orthog",] # "linear"]
cl_epses = [1, 10, 50, 100]
lrs = [0.05] #, 0.1, 0.01, 0.005]
extra_dims_list = [0,16,64,128]
incl_extras = [False, True]
mask_dims = [1,4,8] # note: rebinds the scalar `mask_dims` set in the single-task cell
n_repeats = 5

ylim = [-2.75, 1.75]

# Arguments splatted into train_rotation / test_rotation. "lr", "cl_eps",
# "incl_extra", "incl_actn_loss", "incl_cl_loss" and "mtx_type" are
# overwritten per configuration by the sweep loop.
exp_params = {
    "n_epochs": 3000,
    "lr": 0.005,
    "cl_eps": 1,
    "shuffle_empty": False,
    "method": "mean",
    "incl_extra": False,
    "calc_loss_in_aligned_basis": False,
    "detach_cl_vecs": True,
    "incl_actn_loss": True,
    "incl_cl_loss": True,
    "mtx_type": "orthog",
}

# Accumulates one DataFrame per configuration.
dfs = []
all_keys = {*set(data_params.keys()), *set(mlp_params.keys()), *set(exp_params.keys())}
# Parameters kept when building figure file names; every other key is dropped.
save_keys = ["incl_extra", "extra_dims", "lr", "incl_actn_loss", "incl_cl_loss", "cl_eps", "mask_dims"]
excl_keys = [key for key in all_keys if key not in save_keys]
import itertools

# Figures are saved under figs/ and checkpoints under csvs/; create both up
# front so the first write cannot fail on a missing directory.
os.makedirs("figs", exist_ok=True)
os.makedirs("csvs", exist_ok=True)

# Result columns, in the order they appear in each per-configuration frame.
result_columns = [
    "run_id", "task_num", "task_division",
    "n_samples", "min_class_count", "max_class_count", "mean_class_count",
    "class_trn_acc", "class_val_acc",
    "cl_loss", "actn_loss", "actn_acc", "emd", "row_emd",
    "cross_cl_loss", "cross_actn_loss", "cross_actn_acc", "cross_emd", "cross_row_emd",
    "mtx_type",
]
# Per-sample entries of a prep_data result that are passed on to training/testing.
task_data_keys = (
    "src_vecs", "trg_vecs", "intrv_varbs", "intrv_classes", "src_varbs", "src_classes",
)


def _valid_task_data(data_dict):
    """Restrict the per-sample entries of a prep_data result to its valid interventions."""
    valids = data_dict["valid_intrvs"]
    return {key: data_dict[key][valids] for key in task_data_keys}


def _class_count_stats(intrv_classes):
    """Return sample count and min/max/mean per-class count for a tensor of class labels."""
    counts = [
        (intrv_classes==c).long().sum().item()
        for c in sorted(set(intrv_classes.cpu().tolist()))
    ]
    return {
        "n_samples": len(intrv_classes),
        "min_class_count": np.min(counts),
        "max_class_count": np.max(counts),
        "mean_class_count": np.mean(counts),
    }


def _print_class_distr(task_label, stats):
    """Print the class-count summary for one task."""
    print(f"Class Distr {task_label}:",
          "\n\tMin:", stats["min_class_count"],
          "\n\tMax:", stats["max_class_count"],
          "\n\tMean:", stats["mean_class_count"],
        )


def _freeze_on_gpu(classifier):
    """Move a trained classifier to the GPU, put it in eval mode and disable its gradients."""
    classifier.cuda()
    classifier.eval()
    for param in classifier.parameters():
        param.requires_grad = False
    return classifier


def _metric_columns(prefix, values):
    """Map a (cl_loss, actn_loss, acc, emd, row_emd) tuple to float-valued result columns."""
    cl_loss, actn_loss, acc, emd, row_emd = values
    return {
        prefix+"cl_loss": float(cl_loss),
        prefix+"actn_loss": float(actn_loss),
        prefix+"actn_acc": float(acc),
        prefix+"emd": float(emd),
        prefix+"row_emd": float(row_emd),
    }


# The two products below visit configurations in the same order as the former
# nested loops (rightmost axis varies fastest).
outer_sweep = itertools.product(
    incl_extras, range(n_repeats), extra_dims_list, lrs, [False, True]
)
for incl_extra, repeat, extra_dims, lr, incl_actn_loss in outer_sweep:
    exp_params["incl_extra"] = incl_extra
    data_params["extra_dims"] = extra_dims
    exp_params["lr"] = lr

    inner_sweep = itertools.product([False, True], enumerate(cl_epses), mask_dims)
    for incl_cl_loss, (ice, cl_eps), mask_dim in inner_sweep:
        # At least one loss must be active.
        if not incl_actn_loss and not incl_cl_loss: continue
        # Without the CL loss only the first cl_eps entry is run.
        if ice > 0 and not incl_cl_loss: continue
        # Total dimensionality: extra dims plus the two causal dims.
        d = data_params["extra_dims"]+2
        if mask_dim*2 > d: continue
        data_params["mask_dims"] = mask_dim

        exp_params["incl_cl_loss"] = incl_cl_loss
        exp_params["incl_actn_loss"] = incl_actn_loss
        exp_params["cl_eps"] = cl_eps

        # Rows are only ever appended to df_dict as complete records (all
        # columns at once), so the column lists cannot drift out of step when
        # a division is skipped part-way through.
        df_dict = {name: [] for name in result_columns}

        for task_division in divisions:
            if task_division in excl_divisions:
                print("Skipping", task_division)
                continue
            print("Starting Task Division", task_division)
            task1_bools, task2_bools = make_tasks(task_division=task_division, varbs=og_varbs)

            ##########################################################################
            ### DATA PREP
            ##########################################################################
            data_dict = prep_data(
                og_varbs=og_varbs[task1_bools],
                samples=samples[task1_bools],
                **data_params,
            )
            if data_dict["valid_intrvs"].sum() == 0:
                print("No valid intrvs")
                continue
            task1_data = _valid_task_data(data_dict)
            # NOTE(review): as before, the counts use every row returned by
            # prep_data, not just the valid interventions - confirm intended.
            task1_stats = _class_count_stats(data_dict["intrv_classes"])
            _print_class_distr(1, task1_stats)

            data_dict = prep_data(
                og_varbs=og_varbs[task2_bools],
                samples=samples[task2_bools],
                **data_params,
            )
            # Task 2 now gets the same guard as task 1.
            if data_dict["valid_intrvs"].sum() == 0:
                print("No valid intrvs")
                continue
            task2_data = _valid_task_data(data_dict)
            task2_stats = _class_count_stats(data_dict["intrv_classes"])
            _print_class_distr(2, task2_stats)

            print(task_division)
            quick_plot(
                task1_data["src_vecs"], task2_data["src_vecs"],
                incl_legend=True,
                labels=["Task1", "Task2"],
                incl_dividers=False,
                ylim=ylim,
            )

            ##########################################################################
            ### Classifier Training
            ##########################################################################
            print("Training Classifier1")
            classifier1, trn_acc1, val_acc1 = train_classifier(
                task1_data["src_vecs"],
                task1_data["src_classes"],
                **mlp_params,
                verbose=False,
            )
            _freeze_on_gpu(classifier1)

            print("Training Classifier2")
            classifier2, trn_acc2, val_acc2 = train_classifier(
                task2_data["src_vecs"],
                task2_data["src_classes"],
                **mlp_params,
                verbose=False,
            )
            _freeze_on_gpu(classifier2)

            # The mask comes from the most recent prep_data call (task 2).
            mask = data_dict["mask"]

            ##########################################################################
            ### Rotation Matrix Training
            ##########################################################################
            for mtx_type in mtx_types:
                print("--------------------")
                print("Performing New Training")
                # Same lookup precedence as before: exp_params, then data_params, then mlp_params.
                param_lookup = {**mlp_params, **data_params, **exp_params}
                for k in sorted(save_keys):
                    print(k, param_lookup.get(k))
                print()

                exp_params["mtx_type"] = mtx_type
                run_id = time.time()

                print("Training Task1 Matrix", mtx_type)
                task1_rot_mtx, _, *task1_fit = train_rotation(
                    **task1_data,
                    classifier=classifier1,
                    mask=mask,
                    **exp_params,
                    print_every=200,
                    fig_every=np.inf,
                    early_stopping=True,
                )
                print("End Task1 Training")
                print()

                print("Training Task2 Matrix")
                task2_rot_mtx, _, *task2_fit = train_rotation(
                    **task2_data,
                    classifier=classifier2,
                    mask=mask,
                    **exp_params,
                    print_every=200,
                    fig_every=np.inf,
                    early_stopping=True,
                )
                print("End Task2 Training")
                print()

                ##########################################################################
                ### Testing
                ##########################################################################
                # Figure names are built from the save_keys parameters plus the task number.
                save_params = {
                    k: v for k, v in {**data_params, **exp_params}.items()
                    if k not in excl_keys
                }
                save_params["task"] = 1
                save_name = os.path.join("figs/",get_plot_save_name(save_params))
                print("Testing Task1 Matrix on Task2 Data")
                task1_cross = test_rotation(
                    task1_rot_mtx,
                    task2_data,
                    classifier=classifier2,
                    mask=mask,
                    **exp_params,
                    ylim=ylim,
                    fig_save_name=save_name,
                )

                print("Testing Task2 Matrix on Task1 Data")
                save_params["task"] = 2
                save_name = os.path.join("figs/",get_plot_save_name(save_params))
                task2_cross = test_rotation(
                    task2_rot_mtx,
                    task1_data,
                    classifier=classifier1,
                    mask=mask,
                    **exp_params,
                    ylim=ylim,
                    fig_save_name=save_name,
                )

                # One complete record per trained matrix: task_num 0 is the
                # task-1 matrix, task_num 1 the task-2 matrix.
                task_rows = [
                    {
                        "run_id": run_id, "task_num": 0, "task_division": task_division,
                        **task1_stats,
                        "class_trn_acc": trn_acc1, "class_val_acc": val_acc1,
                        **_metric_columns("", task1_fit),
                        **_metric_columns("cross_", task1_cross),
                        "mtx_type": mtx_type,
                    },
                    {
                        "run_id": run_id, "task_num": 1, "task_division": task_division,
                        **task2_stats,
                        "class_trn_acc": trn_acc2, "class_val_acc": val_acc2,
                        **_metric_columns("", task2_fit),
                        **_metric_columns("cross_", task2_cross),
                        "mtx_type": mtx_type,
                    },
                ]
                for task_row in task_rows:
                    for name in result_columns:
                        df_dict[name].append(task_row[name])

                print("-"*100)
                print("\n"*4)  # five blank lines, as before

        df = pd.DataFrame(df_dict)
        for k in exp_params:
            if k!="mtx_type":
                df[k] = exp_params[k]
        for k in data_params:
            df[k] = data_params[k]
        dfs.append(df)

    # Checkpoint after each block of configurations so a crash loses little work.
    if dfs:
        full_df = pd.concat(dfs)
        full_df.to_csv("csvs/cl_ablations.csv", header=True, index=False)

In [None]:
# Display the aggregated sweep results (full_df is assembled by the sweep cell).
full_df

In [None]:
# Read back the checkpoint CSV written during the sweep to inspect the saved copy.
temp_df = pd.read_csv("csvs/cl_ablations.csv")
temp_df

In [None]:
# Columns shown in the summary table, grouped by what they describe.
cols = (
    ["run_id"]
    + ["task_num", "task_division", "mtx_type", "incl_cl_loss", "incl_actn_loss", "cl_eps"]
    + ["min_class_count", "max_class_count", "mean_class_count"]
    + ["actn_acc", "cross_actn_acc"]
    + ["emd", "cross_emd"]
    + ["row_emd", "cross_row_emd"]
)
# All three sort keys are sorted in descending order.
full_df[cols].sort_values(
    by=["task_division", "task_num", "cross_actn_acc"],
    ascending=False,
)

In [None]:
from datetime import datetime
# Timestamp of the form MM-DD-YYYY_hhHmmM; the literal "H" and "M" follow the
# hour and minute fields.
now = f"{datetime.now():%m-%d-%Y_%HH%MM}"
print(now)
# Keep a timestamped snapshot of the results next to the rolling checkpoint.
full_df.to_csv("csvs/inner_square_ablations_" + now + ".csv", header=True, index=False)

In [None]:
# Display rank of each task-spacing label (also used to index the colour list).
sort_map = {label: rank for rank, label in enumerate(("Original", "Dense", "Sparse"))}

In [None]:
# Runs trained with the action loss only have no meaningful CL epsilon; mark
# them with cl_eps = 0. This happens BEFORE train_type is built so the label
# reflects the normalised value and re-running the cell gives the same labels
# (previously the label was built from the stale cl_eps, so a second run of
# the cell renamed these rows).
actn_only = full_df["incl_actn_loss"] & ~full_df["incl_cl_loss"]
full_df.loc[actn_only, "cl_eps"] = 0

# Label of the form "<mtx_type>[_cl][_actn]_<cl_eps>".
full_df["train_type"] = (
    full_df["mtx_type"].astype(str)
    + full_df["incl_cl_loss"].map({True: "_cl", False: ""})
    + full_df["incl_actn_loss"].map({True: "_actn", False: ""})
    + "_"
    + full_df["cl_eps"].astype(str)
)

# Within the "inner_square" division task 0 is labelled "Sparse" and task 1
# "Dense"; every other row stays "Original".
is_inner_square = full_df["task_division"] == "inner_square"
full_df["task_spacing"] = "Original"
full_df.loc[is_inner_square & (full_df["task_num"] == 0), "task_spacing"] = "Sparse"
full_df.loc[is_inner_square & (full_df["task_num"] == 1), "task_spacing"] = "Dense"

# Numeric rank of the spacing label, used for ordering in the plots.
full_df["spacing_order"] = full_df["task_spacing"].map(sort_map)

In [None]:
# Peek at the first runs trained with the action loss only (no CL loss).
full_df.loc[full_df["incl_actn_loss"]&~full_df["incl_cl_loss"]].head(10)

In [None]:
top_k = 6
hue = "task_spacing"
# Rows must match every inclusion filter and none of the exclusion filters.
incl_filters = {
    #"mtx_type": ["orthog"],
    "incl_extra": [False],
}
excl_filters = {
    "task_division": [*excl_divisions]+["original"],
}
pastel = sns.color_palette("pastel")
p = [pastel[i] for i in [-3,0,4,6,3,2,0,6,7,-1]]

# Build a single boolean mask from all filters, then select once.
keep = pd.Series(True, index=full_df.index)
for filt,vals in incl_filters.items():
    keep &= full_df[filt].isin(vals)
for filt,vals in excl_filters.items():
    keep &= ~full_df[filt].isin(vals)
plot_df = full_df.loc[keep]

# Training types of the first top_k rows after ordering by task_num
# (ascending) and then cross-task accuracy (descending).
ranked = plot_df.sort_values(by=["task_num","cross_actn_acc"], ascending=[True,False])
best_ttypes = set(ranked.head(top_k)["train_type"])
plot_df = plot_df.loc[plot_df["train_type"].isin(best_ttypes)]

rot = 35
# One bar chart per metric, all sharing the same layout.
for metric, title in [
    ("actn_acc", "Within Task Accuracy"),
    ("cross_actn_acc", "Cross Task Accuracy"),
    ("emd", "EMD"),
    ("row_emd", "Row EMD"),
]:
    fig = plt.figure()
    sns.barplot(x="train_type", y=metric, hue=hue, data=plot_df, palette=p)
    plt.xticks(rotation=rot)
    plt.title(title)
    plt.legend(loc="lower left", title="Training Class Spacing")
    plt.show()

In [None]:
#full_df = pd.read_csv("csvs/cl_sweep_128_and_16_dims.csv")

In [None]:
# Show the seaborn "pastel" palette; the plotting cells pick colours from it by index.
sns.color_palette("pastel")

In [None]:
# Indices into the seaborn "pastel" palette used by the figures below.
# (An earlier variant started -3, 0, 4, 4, ... instead of -3, 2, 4, 4, ...)
color_order = [-3, 2, 4, 4, 6, 3, 2, 0, 6, 7, -1]

In [None]:
# Distinct extra_dims values present in the results.
set(full_df["extra_dims"])

In [None]:
# Distinct mask_dims values present in the results.
set(full_df["mask_dims"])

In [None]:
# Distinct learning rates present in the results.
set(full_df["lr"])

In [None]:
extra_dims = 64
# Rows must match every inclusion filter and none of the exclusion filters.
incl_filters = {
    #"mtx_type": ["orthog"],
    "extra_dims": [extra_dims], # 16,],
    "incl_extra": [False],
    #"prop_rank": [1,],
    "cov_strength": [0], #,1,], # how much do the extraneous dimensions covary with the x and y values,
    "mask_dims": [1,],
    "lr": [0.05],
    #"cl_eps": [50],
    #"task_spacing": ["Sparse"],
    #"task_division": ["original"],
}
excl_filters = {
    "task_division": excl_divisions,
    #"task_spacing": ["Dense"],
}
class_acc_threshold = 0

# Figure styling.
rot = 35
ylim = [0.58,1.02]
x = "cl_eps"
y = "cross_actn_acc"
hue = "task_spacing"
mtx_type = "orthog"
leg_title = "Train Task" # " ".join([h.capitalize() for h in hue.split("_")])
p = [sns.color_palette("pastel")[i] for i in color_order]
labelsize = 25
ticksize = 28
fontsize = 30
titlesize = 25
legendsize = 20
linewidth = 4
err_alpha = 0.3
yticks = [0.6, 0.8, 1.0]

# Bars and baseline lines take their colour from the same label -> colour
# mapping, so a hue level keeps its colour even when a panel lacks one of the
# levels (bars used to be coloured by order of appearance, baselines by sort_map).
hue_colors = {label: p[rank] for label, rank in sort_map.items()}

plot_df = full_df.loc[full_df["mtx_type"]==mtx_type].copy()
plot_df = plot_df.loc[plot_df["class_val_acc"]>class_acc_threshold]
for filt,vals in incl_filters.items():
    plot_df = plot_df.loc[plot_df[filt].isin(vals)]
for filt,vals in excl_filters.items():
    plot_df = plot_df.loc[~plot_df[filt].isin(vals)]
plot_df = plot_df.copy()

# Baseline: runs trained with the action loss only, summarised per hue level
# (mean and standard error of the mean).
bools = plot_df["incl_actn_loss"]&~plot_df["incl_cl_loss"]
behavior_only = plot_df.loc[bools].groupby(hue)
bloss_acc = dict(behavior_only[y].mean())
bloss_err = dict(behavior_only[y].sem())
bloss_emd = dict(behavior_only["cross_row_emd"].mean())
bloss_emr = dict(behavior_only["cross_row_emd"].sem())
print("Behavior Only Accuracy", bloss_acc, "+/-", bloss_err)
print("Behavior Only EMD", bloss_emd, "+/-", bloss_emr)

# x positions spanned by the baseline lines: one per distinct x value in
# plot_df, with the minimum extended by one and the maximum repeated.
xs = list(range(len(set(plot_df[x]))))
xs = [np.min(xs)-1] + xs + [np.max(xs)]
xs = np.asarray(xs)


def _draw_baseline(values, errors):
    """Draw one dashed horizontal line with a +/- error band per hue level on the current axes."""
    for label in values:
        color = hue_colors[label]
        ys = np.asarray([values[label] for _ in xs])
        err = errors[label]
        plt.plot(xs, ys, "--", color=color, alpha=1, linewidth=linewidth)
        plt.fill_between(xs, ys-err, ys+err, alpha=err_alpha, color=color, )


def _draw_panel(ax, data, metric, title, ylabel, baseline, baseline_err,
                ylim=None, yticks=None, show_legend=False):
    """Draw one bar-chart panel of `metric` vs. CL epsilon plus the behaviour-only baseline.

    Reads the styling variables defined earlier in this cell (font sizes, x, hue,
    leg_title). `ylim` / `yticks` of None leave matplotlib's automatic limits / ticks.
    """
    plt.sca(ax)
    sns.barplot(x=x, y=metric, hue=hue, data=data, ax=ax, palette=hue_colors)
    plt.title(title, fontsize=titlesize)
    if ylim is not None:
        plt.ylim(ylim)
    plt.xlabel("CL Epsilon", fontsize=labelsize)
    plt.ylabel(ylabel, fontsize=labelsize)
    plt.xticks(fontsize=ticksize)
    if yticks is None:
        plt.yticks(fontsize=ticksize)
    else:
        plt.yticks(yticks, fontsize=ticksize)
    legend = plt.legend(loc="lower left", title=leg_title, fontsize=legendsize, title_fontsize=legendsize)
    legend.set_visible(show_legend)
    _draw_baseline(baseline, baseline_err)


fig,axes = plt.subplots(1,3, figsize=(15,5))

cl_only = plot_df["incl_cl_loss"]&~plot_df["incl_actn_loss"]
das_and_cl = plot_df["incl_cl_loss"]&plot_df["incl_actn_loss"]
panels = [
    dict(rows=cl_only, metric=y, title="CL Loss Only", ylabel="Cross Task IIA",
         baseline=bloss_acc, baseline_err=bloss_err,
         ylim=ylim, yticks=yticks, show_legend=True),
    dict(rows=das_and_cl, metric=y, title="DAS + CL Loss", ylabel="Cross Task IIA",
         baseline=bloss_acc, baseline_err=bloss_err,
         ylim=ylim, yticks=yticks),
    dict(rows=das_and_cl, metric="cross_row_emd", title="DAS + CL EMD",
         ylabel="Cross Task Causal EMD",
         baseline=bloss_emd, baseline_err=bloss_emr),
]
for ax, panel in zip(axes, panels):
    temp_df = plot_df.loc[panel.pop("rows")].sort_values(by="spacing_order")
    _draw_panel(ax, temp_df, **panel)

plt.tight_layout()
#plt.savefig(f"figs/ood_inner_square_{extra_dims}d.png", dpi=600, bbox_inches="tight")
plt.show()

### EMD

In [None]:
# Rows must match every inclusion filter and none of the exclusion filters.
# Further columns can be added to narrow the table (e.g. cov_strength, cl_eps,
# task_spacing).
incl_filters = {
    "incl_extra": [False], # refers to whether the CL loss is applied to only the masked dimensions or all dimensions
    "extra_dims": [64],
    "mask_dims": [1],
    "lr": [0.05],
}
excl_filters = {"task_division": excl_divisions}
class_acc_threshold = 0 # filters out MLPs that failed to solve the classification task

groups = ["cl_eps", "task_spacing", "spacing_order"]
metrics = ["emd", "row_emd", "cross_emd", "cross_row_emd", "actn_acc", "cross_actn_acc",]
mtx_type = "orthog"

# Combine every condition into one boolean mask and select once.
keep = (full_df["mtx_type"]==mtx_type) & (full_df["class_val_acc"]>class_acc_threshold)
for filt,vals in incl_filters.items():
    keep &= full_df[filt].isin(vals)
for filt,vals in excl_filters.items():
    keep &= ~full_df[filt].isin(vals)
plot_df = full_df.loc[keep].copy()

# Mean and standard error of each metric per group.
temp = plot_df.groupby(groups)[metrics].agg(["mean", "sem"]).reset_index()
# Flatten the two-level column labels: "<metric>" for the mean,
# "<metric> sem" for the standard error, plain names for the group keys.
columns = [
    f"{col[0]} {col[1]}".replace(" mean", "").strip() if type(col)==tuple else col.strip()
    for col in temp.columns
]
temp.columns = columns
temp.sort_values(by=["spacing_order", "cl_eps"])


### Ablations

In [None]:
# Show the distinct `cov_strength` values present in `full_df`.
set(full_df["cov_strength"])

In [None]:
# Ablation bar plots: a grid of facets (rows = `row`, columns = `col`) showing the
# metric `y` against `x` (CL epsilon), colored by `hue` (training task spacing).
# Depends on `full_df`, `excl_divisions` and `color_order` defined in earlier cells.

# Rows are kept only if each listed column takes one of the listed values.
incl_filters = {
    #"mtx_type": ["orthog"],
    "incl_extra": [False], # refers to whether the CL loss is applied to only the masked dimensions or all dimensions
    #"extra_dims": [64], # 16,],
    #"mask_dims": [1,],
    "lr": [ 0.05 ],
    #"prop_rank": [1,],
    #"cov_strength": [0], #,1,], # how much do the extraneous dimensions covary with the x and y values,
    #"cl_eps": [50],
    #"task_spacing": ["Sparse"],
}
# Rows are dropped if a listed column takes one of the listed values.
excl_filters = {
    "task_division": excl_divisions,
    #"task_spacing": ["Dense"],
}
class_acc_threshold = 0.99 # filters out MLPs that failed to solve the classification task

# Plot configuration
rot = 35
x = "cl_eps"
y = "cross_actn_acc"
hue = "task_spacing"
col = "mask_dims"
row = "extra_dims"
mtx_type = "orthog"
leg_title = "Train Task" # " ".join([h.capitalize() for h in hue.split("_")])
p = [sns.color_palette("pastel")[i] for i in color_order]
labelsize = 25
ticksize = 28
fontsize = 30
titlesize = 20
legendsize = 20
linewidth = 4
# Index (in `g.axes_dict` iteration order) of the only facet whose in-axes legend
# stays visible; every other facet's legend is hidden.
# NOTE(review): hard-coded for the current grid — with fewer than 10 facets no
# in-axes legend is shown at all.
legend_facet_idx = 9

# Axis label / limits / ticks are derived from the chosen metric name `y`.
ylabel = "Cross Task IIA" if "cross" in y else "Trained Task IIA"
ylim = [0.58,1.02]
yticks = [0.6, 0.8, 1.0]
if "acc" not in y:
    if "row" in y:
        ylabel = ylabel.replace("Task IIA", "Row EMD")
        yticks = [0,0.025,0.05, 0.075, 0.1, 0.125, 0.15]
        ylim = [0,0.155]
    else:
        ylabel = ylabel.replace("Task IIA", "EMD")
        # None -> leave matplotlib's automatic ticks / limits in place
        yticks = None
        ylim = None
print(y)
iia_threshold = 0


# Row selection: matrix type, classifier accuracy and IIA thresholds, then filters.
plot_df = full_df.loc[full_df["mtx_type"]==mtx_type].copy()
plot_df = plot_df.loc[(plot_df["class_val_acc"]>class_acc_threshold)&(plot_df["actn_acc"]>iia_threshold)]
for filt,vals in incl_filters.items():
    plot_df = plot_df.loc[plot_df[filt].isin(vals)]
for filt,vals in excl_filters.items():
    plot_df = plot_df.loc[~plot_df[filt].isin(vals)]
plot_df = plot_df.copy()
#combo_df = plot_df.loc[plot_df["incl_actn_loss"]&~plot_df["incl_cl_loss"]].copy()
#combo_df["cl_eps"] = 0
#combo_df["incl_cl_loss"] = True
#cl_df = combo_df.copy()
#cl_df["incl_actn_loss"] = False
#plot_df = pd.concat([ plot_df, combo_df, cl_df, ])

# Baseline statistics from runs with the action loss but without the CL loss,
# per `hue` group. The mask is computed once and shared by both aggregates.
behavior_only = plot_df["incl_actn_loss"]&~plot_df["incl_cl_loss"]
bloss_acc = dict(plot_df.loc[behavior_only].groupby(hue)[y].mean())
bloss_err = dict(plot_df.loc[behavior_only].groupby(hue)[y].sem())
print("Behavior Only Accuracy", bloss_acc, "+/-", bloss_err)

#fig,axes = plt.subplots(1,2, figsize=(10,5))
temp_df = plot_df.loc[plot_df["incl_actn_loss"]].sort_values(by=["spacing_order",x], ascending=[True,False])
# Scale both EMD columns by (extra_dims + 1).
# NOTE(review): presumably a per-dimension normalization — confirm.
temp_df["emd"] = temp_df["emd"]/(temp_df["extra_dims"]+1)
temp_df["cross_emd"] = temp_df["cross_emd"]/(temp_df["extra_dims"]+1)
# catplot is figure-level and builds its own FacetGrid, so no target `ax` is
# passed: seaborn would only warn and discard it, and passing one made this cell
# depend on an `ax` variable left over from a previously executed cell.
g = sns.catplot(
    x=x, y=y,
    hue=hue,
    col=col,
    row=row,
    data=temp_df,
    palette=p, kind="bar",
)
# Apply identical labels, limits, ticks and fonts to every facet.
# Iterating over .values() avoids binding `_`, which IPython uses for the last output.
for i,ax in enumerate(g.axes_dict.values()):
    plt.sca(ax)
    #plt.title("DAS + CL Loss", fontsize=titlesize)
    if ylim is not None:
        plt.ylim(ylim)
    #xlabel = " ".join([lab.capitalize() for lab in x.split("_")])
    #plt.xlabel(xlabel, fontsize=labelsize)
    plt.xlabel("CL Epsilon", fontsize=labelsize)
    plt.ylabel(ylabel, fontsize=labelsize)
    #plt.xticks(xticks, rotation=rot, fontsize=ticksize)
    plt.xticks(fontsize=ticksize)
    if yticks is None:
        plt.yticks(fontsize=ticksize)
    else:
        plt.yticks(yticks, fontsize=ticksize)
    # Re-set the facet title seaborn generated, only to change its font size.
    plt.title(
        ax.get_title(),
        fontsize=titlesize
    )
    if i==legend_facet_idx:
        plt.legend(
            title=leg_title,
            fontsize=legendsize,
            title_fontsize=legendsize
        )
    else:
        plt.legend(
            loc="lower left",
            title=leg_title,
            fontsize=legendsize,
            title_fontsize=legendsize
        ).set_visible(False)

## Baseline
#xs = list(range(len(set(plot_df[x]))))
#xs = [np.min(xs)-1] + xs + [np.max(xs)]
#xs = np.asarray(xs)
#for i,k in enumerate(bloss_acc):
#    acc = bloss_acc[k]
#    err = bloss_err[k]
#    color = p[i]
#    ys = np.asarray([acc for _ in xs])
#    plt.plot(xs, ys, "--", color=color, alpha=1, linewidth=linewidth)
#    #plt.fill_between(xs, ys-err, ys+err, alpha=0.2, color=color, )

plt.tight_layout()
#plt.savefig(f"figs/ood_ablations_c{col}_r{row}_{y}.png", dpi=600, bbox_inches="tight")
#plt.savefig(f"figs/ood_ablations_c{col}_r{row}_{y}.pdf", dpi=600, bbox_inches="tight")
plt.show()

In [None]:
# Single bar plot of the runs trained with both the CL loss and the action loss,
# with the legend anchored outside the axes; saved to figs/ood_legend.png.
# NOTE(review): judging by the file name this figure presumably exists to export
# the legend for the ablation plots — confirm.
# Reuses `plot_df`, `x`, `y`, `hue`, `p`, `ylabel`, `ylim`, `yticks` and the font
# sizes set by the ablations cell.
fig = plt.figure()
temp_df = plot_df.loc[plot_df["incl_cl_loss"]&plot_df["incl_actn_loss"]].sort_values(by=x, ascending=False)
ax = plt.gca()
sns.barplot(x=x, y=y, hue=hue, data=temp_df, ax=ax, palette=p)
plt.title("DAS + CL Loss", fontsize=titlesize)
plt.ylim(ylim)
#xlabel = " ".join([lab.capitalize() for lab in x.split("_")])
#plt.xlabel(xlabel, fontsize=labelsize)
plt.xlabel("CL Epsilon", fontsize=labelsize)
# Use the label derived from `y` so it stays correct when the metric is changed
# (it was hard-coded to "Cross Task IIA", which is what `ylabel` evaluates to for
# y = "cross_actn_acc").
plt.ylabel(ylabel, fontsize=labelsize)
#plt.xticks(xticks, rotation=rot, fontsize=ticksize)
plt.xticks(fontsize=ticksize)
plt.yticks(yticks, fontsize=ticksize)
plt.legend(
    loc="lower left",
    bbox_to_anchor=(1,0),
    title=leg_title,
    fontsize=legendsize,
    title_fontsize=legendsize
)
# savefig does not create missing directories; make sure the target exists so the
# cell also runs on a fresh checkout.
os.makedirs("figs", exist_ok=True)
plt.savefig("figs/ood_legend.png", dpi=600,bbox_inches="tight")
plt.show()
