In [None]:
# Load the TensorBoard notebook extension so the %tensorboard magic used further down is available.
%load_ext tensorboard

In [None]:
from pathlib import Path

import jax
import jax.numpy as jnp
import jax_verify
import numpy as np
import optax
import pandas as pd
import seaborn as sns
from einops import reduce
from flax import nnx
from orbax import checkpoint as ocp
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from vml_final.data import CSVDataset, CSVDatasetEpochLoader
from vml_final.model import TemporalConvolutionalNetwork
from vml_final.training import do_eval_epoch, do_train_epoch

In [None]:
# --- Dataset configuration ---------------------------------------------------
project_root = Path("../../")

# One directory per subject, AB06 through AB24 (the range end is exclusive).
set_paths = [project_root / "csv_dataset" / f"AB{str(i).zfill(2)}" for i in range(6, 25)]

stack_size = 200
train_frac = 0.9

# Sensor groups passed to the dataset builder (e.g. use ["emg"] for an EMG-only run).
sensors = ["imu", "gon", "emg"]

# Processed datasets are stored under a name derived from the configuration.
# NOTE(review): the "6-25" label mirrors the range() arguments above; the last
# subject actually included is AB24.
dset_name = f"6-25_{'-'.join(sensors)}_{stack_size}"
ckpt_path = (project_root / "processed_sets" / dset_name).resolve()

# Build from the raw CSV directories only when nothing exists at ckpt_path yet.
build_dset = not ckpt_path.exists()

# Shared random stream for the whole notebook, seeded with 0.
rngs = nnx.Rngs(0)

with ocp.StandardCheckpointer() as ckptr:
    if not build_dset:
        # Two-step restore: first without a target, then again using a CSVDataset
        # constructed from that result as the target.
        # NOTE(review): presumably the second pass is what yields a proper
        # CSVDataset instance rather than a raw container — confirm.
        restored_list = ckptr.restore(ckpt_path)
        dummy = CSVDataset(*restored_list)
        dset = ckptr.restore(ckpt_path, dummy)
        # Use the stack size stored with the restored dataset from here on.
        stack_size = dset.stack_size
    else:
        dset = CSVDataset.build(
            rngs(),
            set_paths,
            stack_size=stack_size,
            sensors_to_use=sensors,
            train_frac=train_frac,
        )
        ckptr.save(ckpt_path, dset)

In [None]:
# Inspect the shape of the dataset's input array.
dset.x.shape

In [None]:
# Two epoch loaders over the same dataset; the second is created with train=False.
# NOTE(review): 32_768 is presumably the batch size — confirm against CSVDatasetEpochLoader.
train_loader = CSVDatasetEpochLoader(dset, 32_768)
eval_loader = CSVDatasetEpochLoader(dset, 32_768, train=False)

In [None]:
# Open TensorBoard inline on port 6006, reading from ../../logs (the directory the SummaryWriter below writes to).
%tensorboard --logdir ../../logs --port 6006

In [None]:
# Temporal convolutional network; the number of input channels is taken from the
# last axis of the dataset's input array.
model = TemporalConvolutionalNetwork(
    input_channels=dset.x.shape[-1],
    conv_hidden_dims=[4, 4, 4, 8],
    # mlp_hidden_dims=[128, 128],
    kernel_size=7,
    stride=6,
    dropout=0.2,
    rngs=rngs,
)

# Adam (learning rate 2.5e-3) wrapped in an nnx.Optimizer bound to the model.
optim = nnx.Optimizer(model, optax.adam(2.5e-3))

In [None]:
# TensorBoard event writer targeting <project_root>/logs.
writer = SummaryWriter(project_root / "logs")

In [None]:
# Enable JAX's NaN-debugging flag, then echo the flag so the cell output confirms the setting.
jax.config.update("jax_debug_nans", True)
jax.config.jax_debug_nans

In [None]:
# Number of epochs to run; also the total of the progress bar.
num_epochs = 128

# The bar is advanced once per call to logging_callback below.
# If an earlier run of this cell left a bar open, close it first with pbar.close().
pbar = tqdm(total=num_epochs)


def logging_callback(train_loss, validation_loss, step):
    """Advance the progress bar and record both losses.

    Args:
        train_loss: Training loss for the epoch; must expose ``.item()``.
        validation_loss: Validation loss for the epoch; must expose ``.item()``.
        step: Value forwarded to TensorBoard as ``global_step``.
    """
    pbar.update()

    # Convert each loss to a plain Python scalar once and reuse it for both sinks.
    train_value = train_loss.item()
    validation_value = validation_loss.item()

    pbar.set_postfix({"Train Loss": train_value, "Val Loss": validation_value})
    writer.add_scalars(
        "loss",
        {"train": train_value, "validation": validation_value},
        global_step=step,
    )


# Separate the optimizer into its graph definition and its state so that the state
# can be threaded through jax.lax.scan as the carry.
optim_graphdef, optim_state = nnx.split(optim)


def scanf(optim_state, key):
    """Scan body: one training epoch followed by one evaluation epoch.

    Args:
        optim_state: Optimizer state carried from the previous iteration.
        key: PRNG key for this iteration, used to create a fresh ``nnx.Rngs``.

    Returns:
        A ``(new_optim_state, None)`` pair: the updated carry and no per-step output.
    """
    # Rebuild a live optimizer (and the model it holds) from the carried state.
    epoch_optim = nnx.merge(optim_graphdef, optim_state)
    epoch_rngs = nnx.Rngs(key)

    train_loss = do_train_epoch(epoch_optim, train_loader, rngs=epoch_rngs)
    validation_loss = do_eval_epoch(epoch_optim.model, eval_loader, rngs=epoch_rngs)

    # Hand the losses and the optimizer's step counter to the host-side logger.
    jax.debug.callback(
        logging_callback, train_loss, validation_loss, epoch_optim.step.value
    )

    return nnx.state(epoch_optim), None


# Run every epoch inside a single jax.lax.scan, feeding one PRNG key per epoch.
try:
    optim_state, _ = jax.lax.scan(
        scanf, optim_state, jax.random.split(rngs(), num_epochs)
    )
    # Write the final carried state back into the live optimizer object.
    nnx.update(optim, optim_state)
finally:
    # Always release the progress bar, even when the scan raises (jax_debug_nans
    # is enabled in an earlier cell), so a failed run does not leave a stale bar.
    pbar.close()

In [None]:
# Draw a batch of 128 samples with train=False and show predictions next to targets.
batch_x, batch_y = dset.get_batch(rngs(), 128, train=False)

# NOTE(review): model.eval() is only called in a later cell, so dropout may still be
# active for this call — confirm whether do_train_epoch / do_eval_epoch toggle the mode.
np.stack([model(batch_x), batch_y], -1)

In [None]:
# Re-run one evaluation epoch on the trained model and display the result.
do_eval_epoch(optim.model, eval_loader, rngs=rngs)

In [None]:
# Render the structure of the model held by the optimizer.
nnx.display(optim.model)

In [None]:
# Per-channel extremes over the whole dataset (the leading axis is reduced away).
channel_max = reduce(dset.x, "e c -> c", "max")
channel_min = reduce(dset.x, "e c -> c", "min")

# Repeat each per-channel vector stack_size times along a new leading axis,
# giving arrays of shape (stack_size, channels).
upper = np.tile(np.asarray(channel_max), (stack_size, 1))
lower = np.tile(np.asarray(channel_min), (stack_size, 1))

In [None]:
# Switch to inference mode BEFORE splitting. nnx.split snapshots the module at call
# time, so mode flags changed afterwards would not be visible to pure_call and the
# merged copy would still run in training mode (dropout active).
model.eval()

model_graphdef, model_state = nnx.split(model)


@jax.jit
def pure_call(x):
    """Apply the snapshotted model to ``x`` as a jit-compiled pure function.

    Args:
        x: Input passed straight to the model.

    Returns:
        Whatever the model returns for ``x``.
    """
    # Rebuild a private copy from the snapshot; named so it does not shadow the
    # notebook-level `model`.
    merged_model = nnx.merge(model_graphdef, model_state)
    return merged_model(x)

In [None]:
# Backward CROWN bound propagation over the interval [lower, upper], i.e. the box
# spanned by the per-channel minima and maxima computed above.
# NOTE(review): this passes the nnx module itself rather than the jitted pure_call
# defined in the previous cell — confirm which one is intended.
output_bound = jax_verify.backward_crown_bound_propagation(
    model,
    jax_verify.IntervalBound(lower, upper),
)

In [None]:
# Display the lower and upper output bounds.
output_bound.lower, output_bound.upper

# Now let's verify it for particular common speeds

In [None]:
def make_bound(center, to_add):
    """Build an interval bound that is symmetric around ``center``.

    Args:
        center: Midpoint of the interval.
        to_add: Half-width; subtracted from and added to ``center``.

    Returns:
        A ``jax_verify.IntervalBound`` from ``center - to_add`` to ``center + to_add``.
    """
    lower_edge = center - to_add
    upper_edge = center + to_add
    return jax_verify.IntervalBound(lower_edge, upper_edge)


def crown_verify(centers, stdevs, factor=0.01):
    """Run backward CROWN around each center with a box of half-width ``stdevs * factor``.

    Args:
        centers: Batch of inputs; the leading axis is mapped over with ``jax.vmap``.
        stdevs: Standard deviations that set the shape of the perturbation box.
        factor: Multiplier applied to ``stdevs`` to obtain the half-width.

    Returns:
        A ``(lower, upper)`` pair of output bounds with one entry per center.
    """
    # Derive the half-width from the arguments rather than from notebook-global
    # state, so the function also works on a fresh kernel and honours its parameters.
    to_add = stdevs * factor

    def crown(center):
        output_bound = jax_verify.backward_crown_bound_propagation(
            optim.model, make_bound(center, to_add)
        )
        return output_bound.lower, output_bound.upper

    return jax.vmap(crown)(centers)


def noise_verify(centers, stdevs, factor=0.1, samples_per_center=16):
    """Empirically estimate the output range under Gaussian input noise.

    For every center, ``samples_per_center`` noisy copies are drawn as
    ``normal * stdevs * factor + center`` and passed through ``optim.model``.

    Args:
        centers: Batch of inputs; the leading axis is mapped over with ``jax.vmap``.
        stdevs: Standard deviations scaling the noise.
        factor: Extra multiplier applied to ``stdevs``.
        samples_per_center: Number of noisy copies drawn per center.

    Returns:
        ``(lower, upper)``: the minimum and maximum model output, reduced over the
        noisy copies and then over the centers.
    """

    def verify(center, key):
        noise = jax.random.normal(key, (samples_per_center, *center.shape))
        outputs = optim.model(noise * stdevs * factor + center)
        return jnp.min(outputs, axis=0), jnp.max(outputs, axis=0)

    # One independent key per center, drawn from the notebook-level random stream.
    keys = jax.random.split(rngs(), len(centers))
    per_center_lower, per_center_upper = jax.vmap(verify)(centers, keys)

    return jnp.min(per_center_lower, axis=0), jnp.max(per_center_upper, axis=0)

In [None]:
# Target values that occur more than 2048 times in the full dataset.
uniques, counts = np.unique_counts(dset.y)
common_speeds = uniques[counts > 2048]

# Width of the certified output interval, one entry per speed.
range_size = []

for common_speed in common_speeds:
    # Validation samples whose target equals this speed.
    speed_idxs = dset.validation_idxs[dset.y[dset.validation_idxs] == common_speed]
    speed_x, speed_y = dset[speed_idxs]

    # Half-width of the input box: std_factor times the standard deviation taken
    # across those samples (axis 0).
    speed_channel_x_stds = jnp.std(speed_x, axis=0)
    std_factor = 0.1
    speed_to_add = speed_channel_x_stds * std_factor

    def crown(x):
        # Backward CROWN bounds for the box centred on x with half-width speed_to_add.
        output_bound = jax_verify.backward_crown_bound_propagation(
            optim.model, make_bound(x, speed_to_add)
        )
        return output_bound.lower, output_bound.upper

    # Bound at most 8 randomly chosen samples per speed (drawn without replacement).
    if len(speed_x) > 8:
        speed_x = jax.random.choice(rngs(), speed_x, (8,), replace=False)
    speed_output_lowers, speed_output_uppers = jax.vmap(crown)(speed_x)

    # Envelope over all bounded samples: smallest lower bound, largest upper bound.
    speed_output_lower = jnp.min(speed_output_lowers)
    speed_output_upper = jnp.max(speed_output_uppers)

    range_size.append(speed_output_upper - speed_output_lower)

    print(
        f"For speed: {common_speed} bounds are: {speed_output_lower} to {speed_output_upper}"
    )

print(f"Mean range size: {np.mean(range_size)}")

In [None]:
# Target values that occur more than 200 times in the full dataset.
uniques, counts = np.unique_counts(dset.y)
common_speeds = uniques[counts > 200]

# Width of the empirically observed output interval, one entry per speed.
range_size = []

for common_speed in common_speeds:
    # Validation samples whose target equals this speed.
    speed_idxs = dset.validation_idxs[dset.y[dset.validation_idxs] == common_speed]
    speed_x, speed_y = dset[speed_idxs]

    speed_channel_x_stds = jnp.std(speed_x, axis=0)
    std_factor = 0.01
    # Not read again in this cell: noise_verify receives the stds and the factor directly.
    speed_to_add = speed_channel_x_stds * std_factor

    # Use at most 256 randomly chosen samples per speed (drawn without replacement).
    if len(speed_x) > 256:
        speed_x = jax.random.choice(rngs(), speed_x, (256,), replace=False)

    # Sample Gaussian noise around every chosen sample and collect the output extremes.
    speed_output_lowers, speed_output_uppers = noise_verify(
        speed_x, stdevs=speed_channel_x_stds, factor=std_factor
    )

    speed_output_lower = jnp.min(speed_output_lowers)
    speed_output_upper = jnp.max(speed_output_uppers)

    range_size.append(speed_output_upper - speed_output_lower)

    print(
        f"For speed: {common_speed} bounds are: {speed_output_lower} to {speed_output_upper}"
    )

print(f"Mean range size: {np.mean(range_size)}")

In [None]:
# Now, if we are using all sensors, let's try perturbing each one randomly and separately.

# Number of input channels contributed by each sensor group.
sensor_dims = {
    "imu": 25,
    "gon": 6,
    "emg": 12,
}

# Slice of the channel axis occupied by each active sensor.
# NOTE(review): assumes channels are laid out in the order given by `sensors` —
# confirm against CSVDataset.
sensor_ranges = {}
running = 0
for sensor in sensors:
    sensor_dim = sensor_dims[sensor]
    sensor_ranges[sensor] = slice(running, running + sensor_dim)
    running += sensor_dim

# Target values that occur more than 200 times in the full dataset.
uniques, counts = np.unique_counts(dset.y)
common_speeds = uniques[counts > 200]

std_factor = 0.025
max_centers = 32  # samples evaluated per speed
samples_per_center = 64  # noisy copies drawn around each sample

# Column-oriented results: one row per (speed, sample), one column per sensor.
result_dict = {
    "Speed": [],
    **{f"{sensor} Perturbation": [] for sensor in sensors},
}

for common_speed in common_speeds:
    # Validation samples whose target equals this speed.
    speed_idxs = dset.validation_idxs[dset.y[dset.validation_idxs] == common_speed]
    speed_x, speed_y = dset[speed_idxs]

    speed_channel_x_stds = jnp.std(speed_x, axis=0)
    # Not read again in this cell; kept so the notebook-level name stays defined.
    speed_to_add = speed_channel_x_stds * std_factor

    # Cap every speed at the same number of samples. The guard and the draw size
    # share one constant so that no speed ends up with more rows than another
    # merely because it fell below a separate threshold.
    if len(speed_x) > max_centers:
        speed_x = jax.random.choice(rngs(), speed_x, (max_centers,), replace=False)

    for sensor in sensors:
        # Standard deviations with every channel outside this sensor zeroed, so
        # only this sensor's channels receive noise.
        sensor_channel_x_std = (
            jnp.zeros_like(speed_channel_x_stds)
            .at[..., sensor_ranges[sensor]]
            .set(speed_channel_x_stds[..., sensor_ranges[sensor]])
        )

        def verify(center, key):
            normal_samples = jax.random.normal(
                key, (samples_per_center, *center.shape)
            )
            perturbed_inputs = (
                normal_samples * sensor_channel_x_std * std_factor + center
            )
            outputs = optim.model(perturbed_inputs)
            return jnp.min(outputs, axis=0), jnp.max(outputs, axis=0)

        speed_output_lowers, speed_output_uppers = jax.vmap(verify)(
            speed_x, jax.random.split(rngs(), len(speed_x))
        )

        # Per-sample spread of the model output under noise on this sensor only.
        ranges = speed_output_uppers - speed_output_lowers
        result_dict[f"{sensor} Perturbation"].extend(
            sample_range.item() for sample_range in np.asarray(ranges)
        )

    # Pad the "Speed" column so it has one entry per row appended above.
    # A generator is used instead of max(*[...]), which fails for a single column.
    max_len = max(len(val) for val in result_dict.values())
    missing_len = max_len - len(result_dict["Speed"])
    result_dict["Speed"].extend([common_speed.item()] * missing_len)

    print(f"Did speed {common_speed}")

In [None]:
# Tabulate the collected results: one column per key of result_dict.
df = pd.DataFrame(result_dict)

df

In [None]:
# Pair plot of the DataFrame columns, coloured by the "Speed" column.
# No trailing comma (it would wrap the grid in a 1-tuple); the semicolon
# suppresses the PairGrid repr in the cell output.
sns.pairplot(df, hue="Speed");