## CNN Statistical Model
We recommend running the statistical-model notebooks on Linux. They require roughly 24 GB of RAM.

In [1]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.metrics import root_mean_squared_error
from scipy.stats import rv_discrete, wasserstein_distance
from KDEpy import FFTKDE
import cv2
from cv2 import EMD, DIST_L2
import seaborn as sns
import torch
import torch.nn as nn
import torch.utils.data as utils_data
import torch.distributions as dist
import pickle
import kde_emd_proc as kep
from importlib import reload
import cv2
from tqdm import tqdm
import os
import sys

sys.path.append(os.path.abspath(os.path.join("..")))

import visual
import utils

pygame 2.5.2 (SDL 2.28.3, Python 3.12.2)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [None]:
def get_device():
    """Return the torch device to run on: CUDA when available, otherwise CPU."""
    name = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(name)


device = get_device()
print("Using Device:", device)

In [None]:
# Three-colour palette used for the holes in the plots (RGB scaled to [0, 1]).
rgb_green, rgb_skyblue, rgb_magenta = (
    tuple(channel / 255 for channel in rgb)
    for rgb in ((17, 119, 51), (136, 204, 238), (170, 68, 153))
)

In [None]:
# cutoff: pixels trimmed from the image edges in get_img.
# resize_shape: passed to cv2.resize as `dsize`, which OpenCV reads as (width, height).
cutoff, resize_shape = 32, (100, 120)

In [None]:
def get_ball_pos(world, hole):
    """Return the ball's starting position above hole index `hole`.

    The x coordinate is `world["hole_positions"][hole]`; the y coordinate is
    `world["height"] - world["ball_radius"]`.
    """
    x = world["hole_positions"][hole]
    y = world["height"] - world["ball_radius"]
    return {"x": x, "y": y}

In [None]:
def get_img(world, hole, cutoff=None, resize_shape=None):
    """Render `world` with the ball above `hole` as a normalized grayscale image.

    Args:
        world: world state passed to `visual.snapshot`.
        hole: hole index used to place the ball (see `get_ball_pos`).
        cutoff: if truthy, trim `cutoff` pixels from the bottom and from the
            left and right edges of the snapshot.
        resize_shape: if truthy, passed to cv2.resize as `dsize`
            (OpenCV interprets it as (width, height)).

    Returns:
        ndarray of shape (1, H, W) with values mapped from [0, 255] to [-1, 1].
    """
    snapshot = visual.snapshot(
        world,
        ret_np=True,
        ball_pos=get_ball_pos(world, hole),
        unity_coordinates=True,
    )
    if cutoff:
        snapshot = snapshot[:-cutoff, cutoff:-cutoff]
    if resize_shape:
        snapshot = cv2.resize(snapshot, dsize=resize_shape, interpolation=cv2.INTER_CUBIC)

    half_range = 255. / 2
    # Average over the last (colour) axis -- assumes snapshot is (H, W, C) --
    # then prepend a singleton channel axis.
    gray = snapshot.mean(-1)
    return ((gray - half_range) / half_range)[None]

In [None]:
def process_sample(samp, cutoff=None, resize_shape=None):
    """
    Render one image per hole for a simulated world and pair it with its label.

    Args:
        samp: dict
            "world": dict
                a world state (passed to `get_img`)
            "final_positions": sequence
                the drop location of the ball for each hole; its length
                sets how many holes are rendered
        cutoff: None or int
            forwarded to `get_img` (pixels trimmed from the image edges)
        resize_shape: None or tuple
            forwarded to `get_img`, which hands it to cv2.resize as
            `dsize`, i.e. (width, height)
    Returns:
        imgs: list of ndarrays (1,H,W)
        labels: list
            final_positions[hole] for each hole, in hole order
    """
    world = samp["world"]
    final_positions = samp["final_positions"]
    hole_ids = range(len(final_positions))
    imgs = [get_img(world, hole, cutoff=cutoff, resize_shape=resize_shape) for hole in hole_ids]
    labels = [final_positions[hole] for hole in hole_ids]
    return imgs, labels


In [None]:
def process_visual_data(data, beta_scale=False, cutoff=32, resize_shape=(100,120), verbose=True):
    """
    Processes the data made by the file `world_generation_script.py` into
    inputs X and labels y.

    Args
        data: list of dicts
            each dict should have the keys "world" and "final_positions".
            See the input to `process_sample()`
        beta_scale: bool
            if true, divides the labels by 600 (the original image width)
            to put them on the interval of 0 to 1.
        cutoff: int
            number of pixels trimmed from the image edges (see `get_img`).
        resize_shape: None or tuple
            if a tuple is given, images are resized with cv2.resize using
            it as `dsize`, i.e. (width, height).
        verbose: bool
            if true, wraps `data` in a tqdm progress bar.
    Returns:
        X: ndarray (B,1,H,W)
            the image inputs, one per (sample, hole) pair, sample-major
        y: ndarray (B,)
            the corresponding labels
    Raises:
        ValueError: if `data` produces no images at all (the image shape
            cannot be determined).
    """
    X = []
    y = []
    iterable = tqdm(data) if verbose else data
    for samp in iterable:
        imgs, labels = process_sample(samp, cutoff=cutoff, resize_shape=resize_shape)
        # Flatten as we go so samples may have differing hole counts and the
        # image shape never depends on which sample happened to come last.
        X.extend(imgs)
        y.extend(labels)
    if not X:
        raise ValueError("`data` produced no images; cannot build the input array")
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    if beta_scale:
        og_scale = 600 # original width of image
        y = y/og_scale # normalizing by width of the image
    return X,y

# Load Human Data

In [None]:
# Long-format human predictions; "Unnamed: 0" is presumably a saved index column.
human_data = (
    pd.read_csv("../../../data/human_data/prediction/prediction_long.csv")
    .drop(columns="Unnamed: 0")
)
human_data.head(10)

In [None]:
# KDEs of the human "response" column, built by the project helper kep.compute_kdes.
# Used later as the human reference in kep.compute_emds.
human_kde_dict = kep.compute_kdes(human_data, "response")

In [None]:
# Pickled cross-validation splits (dict of arrays plus a "key" lookup).
split_path = "cv_splits.pkl"
with open(split_path, "rb") as f:
    splits = pickle.load(f)
splits["train"]

In [None]:
with open("cv_splits.pkl", "rb") as f:
    splits = pickle.load(f)
splits["train"]

# splits["key"] maps a flat world index to an entry whose element [1] is used
# as the world number (the "world" column of the human data).
world_key = splits["key"]

# Every other entry is an array of world indices, one row per split.
keys = set(splits.keys())
keys.remove("key")
print("splits:")
for k in keys:
    print(k, splits[k].shape)

# Translate each index array into world numbers, keeping the original shape.
og_shapes = {}
flat_splits = {}
world_num_splits = {}
for k in keys:
    og_shapes[k] = splits[k].shape
    flat_splits[k] = splits[k].reshape(-1)
    world_nums = np.asarray([world_key[i][1] for i in flat_splits[k]])
    world_num_splits[k] = world_nums.reshape(og_shapes[k])

print("\nworld num splits:")
for k in keys:
    print(k, world_num_splits[k].shape)

In [None]:
def isolate_worlds(df, worlds):
    """
    Return the rows of `df` whose "world" value is in `worlds`.

    df: human data frame
        needs a "world" column holding world numbers (not indices)
    worlds: set or list of ints
        the world numbers (not indices) to keep
    """
    in_selection = df["world"].isin(worlds)
    return df.loc[in_selection]
    
def split_df(df, train_worlds, test_worlds):
    """Split `df` into (train_df, test_df) by the world numbers in its "world" column."""
    world_col = df["world"]
    return df.loc[world_col.isin(train_worlds)], df.loc[world_col.isin(test_worlds)]

In [None]:
# world_num_splits["train"] / ["test"] hold one row of world numbers per
# cross-validation split; combine a row with `split_df` to build dataframes.
print("First split:")
first_train_worlds = world_num_splits["train"][0]
first_test_worlds = world_num_splits["test"][0]
print("Train:", first_train_worlds)
print("Test:", first_test_worlds)
# Sanity check: a world must never be in both halves of a split.
assert not set(first_train_worlds) & set(first_test_worlds)
train_df, test_df = split_df(
    human_data,
    train_worlds=first_train_worlds,
    test_worlds=first_test_worlds,
)
train_df.head()

In [None]:
# Mean human response per (world, hole); columns: world, hole, mean.
human_means = (
    human_data.groupby(["world", "hole"])["response"]
    .mean()
    .reset_index(name="mean")
)
human_means

In [None]:
# Distinct world numbers in order of first appearance.
worlds = human_data.world.unique()
worlds

# Gaussian Mixture

## Define Model

In [None]:
class BasicBlock(nn.Module):
    """Residual block: [norm] -> conv3x3 -> norm -> GELU -> conv3x3 -> norm, plus skip, then GELU.

    Args:
        inplanes: input channels.
        planes: output channels.
        stride: stride of the first 3x3 convolution.
        resizer: optional module applied to the skip path so it matches the
            residual path's shape (None means identity).
        bnorm: add nn.BatchNorm2d at each normalization point.
        lnorm: add nn.LayerNorm(channels) at each normalization point.
            NOTE(review): nn.LayerNorm normalizes the trailing axis, which
            for an NCHW tensor is W, not C -- confirm before enabling.
        leading_norm: also normalize the input before the first conv.
        noise: if truthy, add GaussianNoise(noise) at each normalization point.
    """

    def __init__(self, inplanes, planes, stride=1, resizer=None, bnorm=True, lnorm=False, leading_norm=True, noise=0):
        super().__init__()

        def norm_stage(n_chan):
            # The same optional sequence is used at every normalization point.
            stage = []
            if bnorm:
                stage.append(nn.BatchNorm2d(n_chan))
            if lnorm:
                stage.append(nn.LayerNorm(n_chan))
            if noise:
                stage.append(GaussianNoise(noise))
            return stage

        layers = norm_stage(inplanes) if leading_norm else []
        layers.append(nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False))
        layers.extend(norm_stage(planes))
        layers.append(nn.GELU())
        layers.append(nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False))
        layers.extend(norm_stage(planes))

        self.fxns = nn.Sequential(*layers)
        self.resizer = resizer
        self.stride = stride  # kept as an attribute; forward() does not read it

    def forward(self, x):
        # Residual branch first, then the skip branch, so any training-time
        # noise draws happen in the same order as before.
        residual = self.fxns(x)
        skip = x if self.resizer is None else self.resizer(x)
        residual += skip
        return torch.nn.functional.gelu(residual)
    
class GaussianNoise(nn.Module):
    """Additive zero-mean Gaussian noise that is applied only in training mode.

    Args:
        std: noise standard deviation, stored as a 1-element float32 buffer
            named "std" so it moves with the module and is saved in its
            state_dict.
    """

    def __init__(self, std=0.05):
        super().__init__()
        self.register_buffer("std", torch.tensor([std], dtype=torch.float32))

    def forward(self, x):
        if not self.training:
            return x
        noise = torch.randn_like(x) * self.std
        return x + noise
    
class res_gmm(nn.Module):
    """Residual CNN whose three linear heads parameterize a Gaussian mixture.

    Pipeline: 7x7 stride-2 stem conv -> residual stages -> global average
    pool -> heads `a_out` (mixture logits), `mu_out` (means) and `sigma_out`
    (pre-softplus scales). Subclasses reinterpret the heads by overriding
    `get_vars` / `get_distr`.
    """

    def __init__(self,
                 input_shape,
                 layer_counts=(2, 2, 2),
                 chans=(12, 24, 48),
                 num_comp=10,
                 bnorm=True,
                 lnorm=False,
                 leading_norm=True,
                 noise=0,
                 *args, **kwargs):
        """
        input_shape: sequence of ints
            shape of the input batch; only the last three entries are kept
            and the first of those sets the stem's input channels
        layer_counts: sequence of ints
            denotes the number of res blocks for each channel change
        chans: sequence of ints
            output channels of each stage (chans[0] is also the stem width);
            the last value sets the heads' input size
        num_comp: int
            number of mixture components
        bnorm, lnorm, noise:
            enable BatchNorm / LayerNorm / GaussianNoise after the stem conv
            and at every normalization point of every residual block
        leading_norm: bool
            forwarded to every residual block
        *args, **kwargs:
            accepted and ignored (lets callers pass extra keys such as "drop_p")
        """
        super().__init__()

        # Work on a private copy: appending to the argument itself mutated the
        # shared default list and the caller's list on every construction.
        chans = list(chans)
        # Repeat the final width so stage i maps chans[i] -> chans[i + 1].
        chans.append(chans[-1])

        self.input_shape = input_shape[-3:]
        stem = [nn.Conv2d(self.input_shape[0], chans[0], kernel_size=7, stride=2, padding=3, bias=False)]
        if bnorm:
            stem.append(nn.BatchNorm2d(chans[0]))
        if lnorm:
            # NOTE(review): nn.LayerNorm(C) normalizes the trailing axis, which
            # for an NCHW tensor is W, not C -- confirm before enabling lnorm.
            stem.append(nn.LayerNorm(chans[0]))
        if noise:
            stem.append(GaussianNoise(noise))
        stem.append(nn.GELU())
        self.in_conv = nn.Sequential(*stem)

        self.blocks = nn.ModuleList([])
        for i, (chan, n_layers) in enumerate(zip(chans, layer_counts)):
            self.blocks.append(
                self._make_layer(
                    BasicBlock, chan, chans[i + 1], n_layers,
                    stride=2 if i != 0 else 1,
                    # Previously only `noise` was forwarded, so bnorm / lnorm /
                    # leading_norm silently had no effect inside the blocks.
                    bnorm=bnorm, lnorm=lnorm, leading_norm=leading_norm,
                    noise=noise,
                )
            )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.a_out = nn.Linear(chans[-1], num_comp)
        self.mu_out = nn.Linear(chans[-1], num_comp)
        self.sigma_out = nn.Linear(chans[-1], num_comp)

    def _make_layer(self, block, inplanes, planes, n_blocks, stride=1, bnorm=True, lnorm=False, leading_norm=True, noise=0):
        """Build one stage of `n_blocks` residual blocks.

        Only the first block changes stride/width; it gets a 1x1-conv
        `resizer` on its skip path whenever stride != 1 or the width changes.
        """
        resizer = None
        if stride != 1 or inplanes != planes:
            modules = [nn.Conv2d(inplanes, planes, 1, stride, bias=False)]
            if bnorm: modules.append(nn.BatchNorm2d(planes))
            if lnorm: modules.append(nn.LayerNorm(planes))
            if noise: modules.append(GaussianNoise(noise))
            resizer = nn.Sequential(*modules)
        layers = []
        layers.append(block(
            inplanes=inplanes,
            planes=planes,
            stride=stride,
            resizer=resizer,
            bnorm=bnorm,
            lnorm=lnorm,
            leading_norm=leading_norm,
            noise=noise,
        ))
        inplanes = planes
        for _ in range(1, n_blocks):
            layers.append(block(inplanes, planes, stride=1, bnorm=bnorm, lnorm=lnorm, leading_norm=leading_norm, noise=noise))
        return nn.Sequential(*layers)

    def get_vars(self, fx):
        """Map pooled features `fx` to (weights, means, scales).

        Weights are a softmax over the last axis; scales go through softplus
        plus 1e-5 so they stay strictly positive.
        """
        a = self.a_out(fx).softmax(dim=-1)
        mu = self.mu_out(fx)
        sd = torch.nn.functional.softplus(self.sigma_out(fx)) + 1e-5
        return a, mu, sd

    def forward(self, x):
        """Return the mixture parameters (a, mu, sd) for an image batch `x`."""
        x = self.in_conv(x)
        for block in self.blocks:
            x = block(x)
        x = self.avgpool(x).reshape(len(x), -1)

        a, mu, sd = self.get_vars(x)
        return a, mu, sd

    def setup_distr(self, x):
        """Run the network on `x` and return the corresponding mixture distribution."""
        a, mu, sd = self(x)
        return self.get_distr(a, mu, sd)

    def get_distr(self, w, center, spread):
        """Gaussian mixture with weights `w`, means `center` and scales `spread`."""
        mix = dist.Categorical(w)
        comp = dist.Normal(center, spread)
        return dist.MixtureSameFamily(mix, comp)

class res_bmm(res_gmm):
    """res_gmm backbone whose heads parameterize a mixture of Beta distributions.

    The inherited heads are reused as: a_out -> mixture weights (softmax),
    mu_out -> log alpha, sigma_out -> log beta.
    """

    def get_vars(self, fx):
        """Return (weights, alpha, beta) for pooled features `fx`."""
        weights = torch.softmax(self.a_out(fx), dim=1)
        # exp keeps both Beta concentrations strictly positive.
        # NOTE(review): large head outputs can overflow exp to inf.
        alpha = torch.exp(self.mu_out(fx))
        beta = torch.exp(self.sigma_out(fx))
        return weights, alpha, beta

    def setup_distr(self, x):
        """Run the network on `x` and wrap its outputs in a Beta mixture."""
        return self.get_distr(*self(x))

    def get_distr(self, w, center, spread):
        """Beta mixture with weights `w` and concentrations (`center`, `spread`)."""
        return dist.MixtureSameFamily(dist.Categorical(w), dist.Beta(center, spread))
    
# Default constructor kwargs for the mixture CNNs.
# "drop_p" is not a res_gmm parameter: it lands in **kwargs and has no effect.
kwargs = dict(
    num_comp=10,
    drop_p=0.5,
    layer_counts=[2, 2, 2],
    chans=[12, 24, 48],
    bnorm=True,
    lnorm=False,
    leading_norm=True,
    noise=0.01,
)

## Set Up Data

In [None]:

def create_feature_rep(exp_trial, world_num=None, world_rep=None, cutoff=None, resize_shape=None):
    """Render one normalized image per hole for a world.

    Args:
        exp_trial: bool
            if True, load world `world_num` with
            utils.load_trial(world_num, experiment="prediction", hole=1);
            otherwise use the supplied `world_rep`.
        world_num: world number to load when `exp_trial` is True.
        world_rep: world representation to use when `exp_trial` is False.
        cutoff, resize_shape: forwarded to `get_img`.

    Returns:
        list of images from `get_img`, one per entry of the transformed
        world's "hole_positions", in hole order.

    Raises:
        ValueError: if `exp_trial` is False and `world_rep` is None.
    """
    if exp_trial:
        world_rep = utils.load_trial(world_num, experiment="prediction", hole=1)
    elif world_rep is None:
        # Explicit error instead of `assert`, which is stripped under `python -O`.
        raise ValueError("world_rep is required when exp_trial is False")

    world_rep_unity = visual.unity_transform_trial(world_rep)
    n_holes = len(world_rep_unity["hole_positions"])
    return [
        get_img(world_rep_unity, hole, cutoff=cutoff, resize_shape=resize_shape)
        for hole in range(n_holes)
    ]

def create_data_arrays(df_data, beta_scale=False, cutoff=cutoff, resize_shape=resize_shape):
    """Build model inputs and labels from a human-response dataframe.

    Each distinct world is rendered once (one image per hole via
    `create_feature_rep`); every row then reuses the image for its
    (world, hole) pair.

    Args:
        df_data: DataFrame with "world", "hole" (1-based) and "response" columns.
        beta_scale: if True, divide responses by 600 (the original image
            width) so they lie on the unit interval.
        cutoff, resize_shape: forwarded to `create_feature_rep`.

    Returns:
        input_np: ndarray with one image per row of `df_data`, in row order.
        label_np: float ndarray of the (optionally scaled) responses.
    """
    og_width = 600  # original image width in pixels

    # Render each world once; the dataframe numbers holes from 1.
    feature_reps_dict = {}
    for world_num in df_data["world"].unique():
        tr_rep = create_feature_rep(True, world_num=world_num, cutoff=cutoff, resize_shape=resize_shape)
        feature_reps_dict[world_num] = dict(enumerate(tr_rep, start=1))

    # Column-wise lookups replace the former row-by-row iterrows() loop.
    world_ids = df_data["world"].astype(int).to_numpy()
    hole_ids = df_data["hole"].astype(int).to_numpy()
    input_np = np.array([feature_reps_dict[w][h] for w, h in zip(world_ids, hole_ids)])

    label_np = df_data["response"].to_numpy(dtype=float)
    if beta_scale:
        label_np = label_np / og_width

    return input_np, label_np

class MLPDataSet(utils_data.Dataset):
    """Map-style dataset pairing `inputs[idx]` with `labels[idx]`.

    Its length is the size of the first axis of `inputs`.
    """

    def __init__(self, inputs, labels):
        self.inputs, self.labels = inputs, labels

    def __len__(self):
        return self.inputs.shape[0]

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]

In [None]:
def process_human_data(train_data, valid_data, beta_scale=True, cutoff=32, resize_shape=(100,120)):
    """Turn train/validation human-response dataframes into model arrays.

    Both frames go through `create_data_arrays` with the same settings.

    Returns:
        (train_inputs, train_labels, valid_inputs, valid_labels)
    """
    shared = dict(beta_scale=beta_scale, cutoff=cutoff, resize_shape=resize_shape)
    train_arrays = create_data_arrays(train_data, **shared)
    valid_arrays = create_data_arrays(valid_data, **shared)
    return (*train_arrays, *valid_arrays)


## Set Up Model and Run Training Loop

In [None]:
# Seed NumPy and PyTorch so runs are reproducible.
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def _mean_valid_nll(model, inputs, labels, device, chunk_size=1000):
    """Return the mean negative log-likelihood of `labels` under `model` over ALL rows.

    Rows are evaluated `chunk_size` at a time on `device` (memory trade-off
    only); the summed NLL is divided by the number of rows.
    """
    total = 0.0
    with torch.no_grad():
        for start in range(0, len(inputs), chunk_size):
            stop = start + chunk_size
            distr = model.setup_distr(inputs[start:stop].to(device))
            total += -distr.log_prob(labels[start:stop].to(device)).sum().item()
    return total / inputs.shape[0]


def train_loop(model_type, model_kwargs, lrs,
               train_data, valid_data, process_data,
               resize_shape=resize_shape, cutoff=cutoff,
               batch_size=50, beta_scale=True,
               print_prog=False, n_epochs=100,
               init_model_path=None, combos=None,
               data_tup=None,
              ):
    """Train one model per (learning rate, kwargs combination) by maximum likelihood.

    Each model is trained with Adam (weight_decay=0.01) on the summed negative
    log-likelihood of `model.setup_distr(inputs)`. After every epoch the mean
    NLL over the whole validation set is computed, and a copy of the weights
    with the lowest validation NLL is kept.

    Args:
        model_type: nn.Module subclass built as `model_type(**kwargs)`; must
            implement `setup_distr`.
        model_kwargs: dict of constructor kwargs. Not modified; an
            "input_shape" entry (the training-input tensor shape) is added to
            a copy.
        lrs: iterable of learning rates to try.
        train_data, valid_data: raw data handed to `process_data`.
        process_data: callable `(train_data, valid_data, beta_scale=...,
            resize_shape=..., cutoff=...)` returning
            `(train_X, train_y, valid_X, valid_y)`.
        resize_shape, cutoff, beta_scale: forwarded to `process_data`.
        batch_size: training mini-batch size.
        print_prog: print per-batch progress within each epoch.
        n_epochs: training epochs per model.
        init_model_path: optional checkpoint whose "state_dict" entry is
            loaded into every freshly built model.
        combos: optional list of dicts, each merged over `model_kwargs` to
            form one configuration.
        data_tup: optional pre-processed `(train_X, train_y, valid_X,
            valid_y)`; when given, `process_data` is skipped.

    Returns:
        (results, data_tup). `results` has one dict per trained model with
        "lr", "model" (moved to CPU), "train_loss", "val_loss",
        "model_kwargs" (the kwargs actually used), "best_model_sd" (a copy
        of the best-validation weights, or None if no epoch finished),
        "best_loss" and "best_epoch". A KeyboardInterrupt stops training and
        returns what has been gathered so far.
    """
    import copy

    if data_tup is None:
        print("Preprocessing the Image Data")
        data_tup = process_data(train_data, valid_data, beta_scale=beta_scale, resize_shape=resize_shape, cutoff=cutoff)
    train_input_np, train_label_np, valid_input_np, valid_label_np = data_tup

    train_input_torch = torch.FloatTensor(train_input_np)
    train_label_torch = torch.FloatTensor(train_label_np)
    print("Input Shape:", train_input_torch.shape)
    print("Label Shape:", train_label_torch.shape)

    valid_input_torch = torch.FloatTensor(valid_input_np)
    valid_label_torch = torch.FloatTensor(valid_label_np)

    train_ds = MLPDataSet(train_input_torch, train_label_torch)
    train_dl = utils_data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    train_n = len(train_ds)

    # Copy instead of writing "input_shape" into the caller's dict.
    base_kwargs = {**model_kwargs, "input_shape": train_input_torch.shape}
    if combos is None:
        combos = [base_kwargs]
    else:
        print("Searching Over:", list(combos[0].keys()))
        combos = [{**base_kwargs, **combo} for combo in combos]

    return_results = False
    results = []
    for lr in lrs:
        for combo in combos:
            # Initialise everything recorded below so an early interrupt
            # cannot leave these names unbound.
            model = optimizer = None
            train_loss_record = []
            valid_loss_record = []
            best_loss = np.inf
            best_model_sd = None
            best_epoch = None
            try:
                print()
                print("New Training - LR:", lr)
                print("Model Kwargs:")
                for k, v in combo.items():
                    print("\t", k, v)
                model = model_type(**combo)
                if init_model_path:
                    checkpt = torch.load(init_model_path, map_location="cpu")
                    model.load_state_dict(checkpt["state_dict"])
                optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
                model.to(device)

                n_batches = len(train_dl)
                for epoch in range(n_epochs):
                    model.train()
                    train_epoch_loss = 0
                    for i, (batch_inp, batch_lab) in enumerate(train_dl):
                        optimizer.zero_grad()
                        distr = model.setup_distr(batch_inp.to(device))
                        loss = -distr.log_prob(batch_lab.to(device)).sum()
                        loss.backward()
                        optimizer.step()

                        train_epoch_loss += loss.item()
                        if print_prog:
                            print(round(100*(i/n_batches)), "%", end="          \r")

                    train_loss_record.append(train_epoch_loss/train_n)

                    model.eval()
                    # Previously the chunks stepped by 1000 but sliced only
                    # `batch_size` rows, so most validation rows were skipped.
                    valid_loss = _mean_valid_nll(model, valid_input_torch, valid_label_torch, device)
                    valid_loss_record.append(valid_loss)
                    if valid_loss < best_loss:
                        best_loss = valid_loss
                        # state_dict() returns the live tensors; deep-copy so
                        # later epochs cannot overwrite the best snapshot.
                        best_model_sd = copy.deepcopy(model.state_dict())
                        best_epoch = epoch

                    print("Epoch", epoch, "-- Loss", train_epoch_loss/train_n, "-- Val:", valid_loss)

            except KeyboardInterrupt:
                return_results = True
            if model is None:
                # Interrupted before a model existed: nothing to record.
                if return_results:
                    return results, data_tup
                continue
            model.cpu()
            results.append({
                "lr": lr,
                "model": model,
                "train_loss": train_loss_record,
                "val_loss": valid_loss_record,
                "model_kwargs": combo,
                "best_model_sd": best_model_sd,
                "best_loss": best_loss,
                "best_epoch": best_epoch,
            })
            del optimizer
            del model
            torch.cuda.empty_cache()
            if return_results:
                return results, data_tup
    return results, data_tup

In [None]:
def get_combos(d, combo=None, keys=None, idx=0, combos=None):
    """
    Return every combination of the candidate values in `d` as a list of dicts.

    Args:
        d: dict
            maps each name to be searched over to a list of candidate values.
    Ignorable Args:
        combo: dict
            fixed entries merged into every resulting dict
        keys: list
            the keys of `d` to expand (defaults to all, in order)
        idx: int
            only keys[idx:] are expanded
        combos: list of dicts
            list the results are appended to (a new list by default)
    Returns:
        `combos`, extended with one dict per combination. The first key
        varies slowest; inside each dict the last-expanded key comes first.

    >>> get_combos({"a": [1, 2], "b": [3]})
    [{'b': 3, 'a': 1}, {'b': 3, 'a': 2}]
    """
    combo = {} if combo is None else combo
    keys = list(d.keys()) if keys is None else keys
    combos = [] if combos is None else combos

    # Breadth-first expansion: after each key, `frontier` holds every partial
    # assignment of the keys handled so far.
    frontier = [combo]
    for key in keys[idx:]:
        frontier = [{key: value, **partial} for partial in frontier for value in d[key]]
    combos.extend(frontier)
    return combos
    

In [None]:
# The human data is in /data/human_data/prediction/prediction_long.csv.
# cv_splits stores world indices; these map to world numbers, which appear as the "world" column of prediction_long.
# Build the train and evaluation sets by filtering prediction_long on the world numbers from cv_splits.


# Beta Mixture

## Define Model

In [None]:
class res_bmm(res_gmm):
    """Beta-mixture variant of the res_gmm network (re-declared for this section).

    Head mapping: a_out -> mixture weights via softmax over dim 1,
    mu_out -> alpha via exp, sigma_out -> beta via exp.
    """

    def get_vars(self, fx):
        """Return (weights, alpha, beta) for pooled features `fx`."""
        logits = self.a_out(fx)
        log_alpha = self.mu_out(fx)
        log_beta = self.sigma_out(fx)
        return logits.softmax(dim=1), log_alpha.exp(), log_beta.exp()

    def setup_distr(self, x):
        """Forward `x` and build the Beta mixture from the predicted parameters."""
        params = self(x)
        return self.get_distr(*params)

    def get_distr(self, w, center, spread):
        """Mixture of Beta(center, spread) components weighted by `w`."""
        return dist.MixtureSameFamily(
            mixture_distribution=dist.Categorical(w),
            component_distribution=dist.Beta(center, spread),
        )


# BMM Simulated Data

## Set Up Dataset

In [None]:
def _load_pickle(path):
    """Load and return the pickled object at `path` (trusted local file)."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


train_data = _load_pickle("saved_sims/train_phys_worlds.pkl")
val_data = _load_pickle("saved_sims/val_phys_worlds.pkl")
    

In [None]:
# Report how many simulated worlds each split contains.
for _label, _ds in (("Train Data Size:", train_data), ("Valid Data Size:", val_data)):
    print(_label, len(_ds))

In [None]:
def process_phys_data(train_data, val_data, beta_scale=True, cutoff=32, resize_shape=(100,120), verbose=True):
    """Convert simulated train/validation worlds to arrays via `process_visual_data`.

    Returns:
        (train_X, train_y, val_X, val_y)
    """
    opts = dict(beta_scale=beta_scale, cutoff=cutoff, resize_shape=resize_shape, verbose=verbose)
    train_X, train_y = process_visual_data(train_data, **opts)
    val_X, val_y = process_visual_data(val_data, **opts)
    return train_X, train_y, val_X, val_y

## Set Up Model and Data, Then Run Training Loop

In [None]:
# Pre-train the Beta-mixture CNN on the simulated physics data.
lrs = [5e-4]
model_type = res_bmm
batch_size = 50
n_epochs = 20
beta_scale = True
kwargs = dict(
    num_comp=10,
    noise=0.08,
    layer_counts=[2, 2, 2],
    chans=[12, 24, 48],
    bnorm=True,
    lnorm=False,
    leading_norm=True,
)
# Values searched over (0.11 was also tried); each one becomes a configuration merged over `kwargs`.
overrides = {"noise": [0.08]}
combos = get_combos(overrides)

results, data_tup = train_loop(
    model_type, kwargs, lrs,
    train_data=train_data,
    valid_data=val_data,
    process_data=process_phys_data,
    batch_size=batch_size,
    beta_scale=beta_scale,
    print_prog=True,
    n_epochs=n_epochs,
    combos=combos,
)

train_inputs_np, train_labels_np, valid_inputs_np, valid_labels_np = data_tup

In [None]:
# Plot each run's loss curves and keep the run with the lowest FINAL-epoch
# validation loss; its final weights become the pretraining checkpoint.
best_lr = None
min_loss = np.inf
model = None
best_idx = None
best_res = None
for i, res in enumerate(results):
    lr = res["lr"]
    train_loss_record = res["train_loss"]
    valid_loss_record = res["val_loss"]
    if not train_loss_record or not valid_loss_record:
        # Run was interrupted before finishing an epoch: nothing to plot or rank.
        continue
    plt.plot(range(len(train_loss_record)), train_loss_record, label="train_loss")
    plt.plot(range(len(valid_loss_record)), valid_loss_record, color="red", label="validation_loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.ylim([np.min(train_loss_record)-1, max(np.max(valid_loss_record), np.max(train_loss_record))])
    plt.title(f"LR: {lr}")
    plt.legend(loc="upper right")
    plt.show()

    # The original wrapped this scalar in np.min, which had no effect.
    loss = valid_loss_record[-1]
    if loss < min_loss:
        best_idx = i
        best_lr = lr
        min_loss = loss
        model = res["model"]
        best_res = res

if best_res is None:
    raise RuntimeError("No training run finished an epoch with a finite validation loss; nothing to save.")

# NOTE: saves the selected run's final weights; its best-validation snapshot
# is available as best_res["best_model_sd"].
best_model_data = {
    "state_dict": model.state_dict(),
    "lr": best_res["lr"],
    "train_loss": best_res["train_loss"],
    "val_loss": best_res["val_loss"],
    "model_kwargs": best_res["model_kwargs"],
}
os.makedirs("nn_checkpoints", exist_ok=True)
torch.save(best_model_data, "nn_checkpoints/best_phys_cnn_bmm.pth")

print("Best LR:", best_lr, "-- Train Loss:", results[best_idx]["train_loss"][-1], "-- Val Loss:", results[best_idx]["val_loss"][-1])

# Cross-Validation

In [None]:
# NOTE(review): validation and test use the *same* worlds, so best-epoch
# selection in the cross-validation loop sees the test data -- confirm this
# is intended.
train_worlds = world_num_splits["train"]
valid_worlds = test_worlds = world_num_splits["test"]

## Run Cross-Validation

In [None]:
# Re-seed NumPy and PyTorch before cross-validation for reproducibility.
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
# Fine-tune the pretrained Beta-mixture CNN on human data, once per CV split.
lrs = [5e-4]
model_type = res_bmm
batch_size = 50
beta_scale = True
n_epochs = 30
kwargs = dict(
    num_comp=10,
    drop_p=0.5,  # not a res_gmm parameter; absorbed by **kwargs
    layer_counts=[2, 2, 2],
    chans=[12, 24, 48],
    bnorm=True,
    lnorm=False,
    leading_norm=True,
    noise=0.01,
)

# Every split starts from the simulation-pretrained checkpoint.
init_model_path = "nn_checkpoints/best_phys_cnn_bmm.pth"

opt_epochs, opt_model_sds, end_model_sds, models = [], [], [], []

split = 0
for train_split, valid_split in zip(train_worlds, valid_worlds):
    print("Split:", split)
    print()

    df_train_data = human_data.loc[human_data["world"].isin(train_split)]
    df_valid_data = human_data.loc[human_data["world"].isin(valid_split)]

    results, data_tup = train_loop(
        model_type, kwargs, lrs,
        train_data=df_train_data,
        valid_data=df_valid_data,
        process_data=process_human_data,
        batch_size=batch_size,
        beta_scale=beta_scale,
        init_model_path=init_model_path,
        n_epochs=n_epochs,
    )
    train_input_np, train_label_np, valid_input_np, valid_label_np = data_tup
    run = results[0]
    train_loss_record = run["train_loss"]
    valid_loss_record = run["val_loss"]

    fitted = run["model"].cpu()
    models.append(fitted)
    end_model_sds.append(fitted.state_dict())

    # Epoch with the lowest validation loss (earliest one on ties).
    min_epoch = min(range(len(valid_loss_record)), key=valid_loss_record.__getitem__)

    opt_model_sds.append(run["best_model_sd"])
    opt_epochs.append(min_epoch)

    split += 1
    

In [None]:
# Move each stored best-weights snapshot to CPU. The dict is updated in place,
# so other references to it (e.g. in opt_model_sds) see the change.
# `results` here is only the last split's output.
for res in results:
    snapshot = res["best_model_sd"]
    snapshot.update({k: v.cpu() for k, v in snapshot.items()})

In [None]:
# Inspect the fields recorded for the last training result (the loop above leaves `res` bound to it).
res.keys()

In [None]:
# Release cached-but-unused CUDA memory before evaluation.
torch.cuda.empty_cache()

In [None]:
# Evaluate each split's best model on its held-out worlds and collect
# per-image mixture means and densities.
mlp_means = []
mlp_kde_dict = {}

# 600-point grid strictly inside (0, 1); shape (600, 1) so log_prob
# broadcasts against the batch of per-image mixtures.
x_values = torch.linspace(0.0001, 0.9999, 600).unsqueeze(1)

results = []
for mi, (test_split, model, model_sd) in enumerate(zip(test_worlds, models, opt_model_sds)):
    print(f"Testing {mi}/{len(test_worlds)}")
    torch.cuda.empty_cache()
    model.load_state_dict(model_sd)
    model.eval()

    # One image per hole for every test world, in world order.
    test_input_list = []
    for world_num in test_split:
        test_input_list += create_feature_rep(True, world_num, resize_shape=resize_shape, cutoff=cutoff)
    test_input_np = np.array(test_input_list)
    test_input_torch = torch.from_numpy(test_input_np).float()

    # Predict mixture parameters in batches; inference only, so skip
    # building an autograd graph.
    tcw_list, ta_list, tb_list = [], [], []
    bsize = 100
    with torch.no_grad():
        for i in range(0, len(test_input_torch), bsize):
            test_comp_weights, test_alpha, test_beta = model(test_input_torch[i:i+bsize])
            tcw_list.append(test_comp_weights.cpu())
            ta_list.append(test_alpha.cpu())
            tb_list.append(test_beta.cpu())
    model.cpu()
    test_comp_weights = torch.cat(tcw_list, dim=0)
    test_alpha = torch.cat(ta_list, dim=0)
    test_beta = torch.cat(tb_list, dim=0)
    test_bmm = model.get_distr(test_comp_weights, test_alpha, test_beta)

    # Mixture means mapped from the unit interval back to pixels (x 600).
    test_bmm_means = test_bmm.mean.detach().cpu().numpy() * 600

    # Density on the grid; /600 rescales it from the unit interval to pixels.
    log_probs = test_bmm.log_prob(x_values)
    pdf = torch.exp(log_probs)/600
    pdf_np = pdf.cpu().detach().numpy().T

    # Assumes 3 holes (images) per world.
    mlp_means += list(zip(np.repeat(test_split, 3), test_bmm_means))
    split_kde = kep.make_kde_dict(pdf_np, test_split)

    mlp_kde_dict.update(split_kde)

    res_dict = {
        "test_bmm_means": test_bmm_means,
        "log_probs": log_probs.cpu().detach().numpy(),
        "pdf": pdf_np,
        "kde_dicts": split_kde,
    }
    results.append(res_dict)

d = "saved_model_pred/"
os.makedirs(d, exist_ok=True)
save_path = os.path.join(d, "cnn_w_pretraining_cvsplits.p")
# Separate name for the file handle: reusing `f` made the print below show
# the file object instead of the path.
with open(save_path, "wb") as f:
    pickle.dump(results, f)
print("Saved to", save_path)

mlp_means.sort(key=lambda x: x[0])


# Training on All Data

In [None]:
# Final fit on the union of the first split's train and test worlds.
lrs = [5e-4]
model_type = res_bmm
batch_size = 50
beta_scale = True
n_epochs = 30
kwargs = dict(
    num_comp=10,
    drop_p=0.5,  # not a res_gmm parameter; absorbed by **kwargs
    layer_counts=[2, 2, 2],
    chans=[12, 24, 48],
    bnorm=True,
    lnorm=False,
    leading_norm=True,
    noise=0.01,
)

init_model_path = "nn_checkpoints/best_phys_cnn_bmm.pth"

opt_epochs, opt_model_sds, end_model_sds, models = [], [], [], []

split = 0
# test_worlds is the same object as valid_worlds, so it is not added again.
data_split = [*train_worlds[0], *valid_worlds[0]]
df_train_data = human_data.loc[human_data["world"].isin(data_split)]
# NOTE(review): these validation worlds are also in `data_split`, so the
# validation loss below is measured on training data.
df_valid_data = human_data.loc[human_data["world"].isin(valid_worlds[0])]

results, data_tup = train_loop(
    model_type, kwargs, lrs,
    train_data=df_train_data,
    valid_data=df_valid_data,
    process_data=process_human_data,
    batch_size=batch_size,
    beta_scale=beta_scale,
    init_model_path=init_model_path,
    n_epochs=n_epochs,
)
train_input_np, train_label_np, valid_input_np, valid_label_np = data_tup
run = results[0]
train_loss_record = run["train_loss"]
valid_loss_record = run["val_loss"]

fitted = run["model"].cpu()
models.append(fitted)
end_model_sds.append(fitted.state_dict())

# Epoch with the lowest validation loss (earliest one on ties).
min_epoch = min(range(len(valid_loss_record)), key=valid_loss_record.__getitem__)

opt_model_sds.append(run["best_model_sd"])
opt_epochs.append(min_epoch)

split += 1


In [None]:
# In-sample evaluation of the all-data model on every world it was trained on.
mlp_means = []
mlp_kde_dict = {}

# 600-point grid strictly inside (0, 1); shape (600, 1) so log_prob
# broadcasts against the batch of per-image mixtures.
x_values = torch.linspace(0.0001, 0.9999, 600).unsqueeze(1)

results = []
test_split = sorted(data_split)
model = models[0]
model_sd = opt_model_sds[0]

print("Testing")
torch.cuda.empty_cache()
model.load_state_dict(model_sd)
model.eval()

# One image per hole for every world, in world order.
test_input_list = []
for world_num in test_split:
    test_input_list += create_feature_rep(True, world_num, resize_shape=resize_shape, cutoff=cutoff)
test_input_np = np.array(test_input_list)
test_input_torch = torch.from_numpy(test_input_np).float()

# Predict mixture parameters in batches; inference only, so skip building an
# autograd graph.
tcw_list, ta_list, tb_list = [], [], []
bsize = 100
with torch.no_grad():
    for i in range(0, len(test_input_torch), bsize):
        test_comp_weights, test_alpha, test_beta = model(test_input_torch[i:i+bsize])
        tcw_list.append(test_comp_weights.cpu())
        ta_list.append(test_alpha.cpu())
        tb_list.append(test_beta.cpu())
model.cpu()
test_comp_weights = torch.cat(tcw_list, dim=0)
test_alpha = torch.cat(ta_list, dim=0)
test_beta = torch.cat(tb_list, dim=0)
test_bmm = model.get_distr(test_comp_weights, test_alpha, test_beta)

# Mixture means mapped from the unit interval back to pixels (x 600).
test_bmm_means = test_bmm.mean.detach().cpu().numpy() * 600

# Density on the grid; /600 rescales it from the unit interval to pixels.
log_probs = test_bmm.log_prob(x_values)
pdf = torch.exp(log_probs)/600
pdf_np = pdf.cpu().detach().numpy().T

# Assumes 3 holes (images) per world.
mlp_means += list(zip(np.repeat(test_split, 3), test_bmm_means))
split_kde = kep.make_kde_dict(pdf_np, test_split)

mlp_kde_dict.update(split_kde)

res_dict = {
    "test_bmm_means": test_bmm_means,
    "log_probs": log_probs.cpu().detach().numpy(),
    "pdf": pdf_np,
    "kde_dicts": split_kde,
}
results.append(res_dict)

save_path = "saved_model_pred/cnn_w_pretraining_alldata.p"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
# Separate name for the file handle so the print shows the path, not the file object.
with open(save_path, "wb") as f:
    pickle.dump(results, f)
print("Saved to", save_path)

mlp_means.sort(key=lambda x: x[0])

In [None]:
# Shape of the last log-density grid: x_values is (600, 1), so this is expected to be (600, n_images).
log_probs.shape

In [None]:
_, mlp_mean_vals = zip(*mlp_means)

# Alignment assumes mlp_means and human_means are both ordered by world, then hole.
human_mean_vals = human_means["mean"]
r = np.round(np.corrcoef(human_mean_vals, mlp_mean_vals)[0, 1], decimals=2)
rmse = np.round(root_mean_squared_error(human_mean_vals, mlp_mean_vals), decimals=2)

print("r:", r)
print("RMSE:", rmse)

In [None]:
# Scatter of model vs. human mean drop locations, coloured by hole.
means_dict = {
    "human": human_means["mean"].to_numpy(),
    "mlp": mlp_mean_vals,
    # Taken from human_means instead of the hard-coded [1,2,3]*40, which only
    # matched a dataset of exactly 40 worlds.
    "hole": human_means["hole"].to_numpy(),
}

df_to_show = pd.DataFrame(means_dict)

plt.plot([0,600], [0,600], color="black", zorder=1)
sns.scatterplot(data=df_to_show,
                x="mlp",
                y="human",
                hue="hole",
                palette=[rgb_green, rgb_skyblue, rgb_magenta])

plt.xlabel("Badass CNN Mean Prediction", fontsize=20)
plt.ylabel("Human Mean Response", fontsize=20)

plt.xlim(0,600)
plt.ylim(0,600)

plt.xticks(fontsize=16)
plt.yticks(fontsize=16)

# remove the legend
plt.legend([],[], frameon=False)

plt.text(10, 570, f"r = {r}", fontsize=14)
plt.text(9, 535, f"RMSE = {rmse}", fontsize=14)

In [None]:
# Earth mover's distances between the model and human KDEs for the worlds in `worlds`,
# computed by the project helper kep.compute_emds; the result has an "EMD" column (used below).
df_mlp_emd = kep.compute_emds(mlp_kde_dict, human_kde_dict, worlds)

In [None]:
# Summary statistics of the per-world EMD values.
emd_values = df_mlp_emd["EMD"]
print("Mean:", round(emd_values.mean(), 2))
print()
print("Quantiles:")
print(emd_values.quantile([0.05, 0.95]).round(decimals=2))

In [None]:
# Histogram of the model's EMD values.
sns.histplot(data=df_mlp_emd, x="EMD")

plt.xlabel("CNN EMD", fontsize=20)
plt.ylabel("Count", fontsize=20)

for set_ticks in (plt.xticks, plt.yticks):
    set_ticks(fontsize=16)

# Hide any legend.
plt.legend([],[], frameon=False)

In [None]:
# Inspect the keys of the model KDE dict (filled per split from kep.make_kde_dict).
mlp_kde_dict.keys()

In [None]:
# Plot the model's density for the first world, one colour per hole.
world_num = worlds[0]
print(world_num)
hole_colors = [rgb_green, rgb_skyblue, rgb_magenta]
kep.plot_kdes(mlp_kde_dict, world_num, colors=hole_colors, alpha=0.8, format="wide")