In [None]:
import math
import torch
import torch.nn as nn
import key_value_bottleneck.core as kv_core
import tqdm
from einops import rearrange
import numpy as np
from copy import deepcopy
import pathlib
import matplotlib.pyplot as plt
from kv_bottleneck_experiments.utils.model import CodebookVotingLogitsDecoder
from addict import Dict

#### First, we define a few helper functions

In [None]:
# Copied from here: https://jwalton.info/Embed-Publication-Matplotlib-Latex/
def set_size(width, fraction=1, ratio=1.3):
    """Set figure dimensions to avoid scaling in LaTeX.

    Parameters
    ----------
    width: float
            Document textwidth or columnwidth in pts
    fraction: float, optional
            Fraction of the width which you wish the figure to occupy
    ratio: float, optional
            Height-to-width ratio of the figure. Defaults to 1.3, the value
            this notebook uses for its (taller than wide) panels. Note that
            this is NOT the golden ratio used by the recipe this function was
            copied from.

    Returns
    -------
    fig_dim: tuple
            Dimensions of figure in inches, as (width, height)
    """
    # Width of figure (in pts)
    fig_width_pt = width * fraction

    # Convert from pt to inches (72.27 pt per inch)
    inches_per_pt = 1 / 72.27

    # Figure width in inches
    fig_width_in = fig_width_pt * inches_per_pt
    # Figure height in inches, derived from the requested aspect ratio
    fig_height_in = fig_width_in * ratio

    fig_dim = (fig_width_in, fig_height_in)

    return fig_dim

def generate_plot(input_xy, prediction_xy, title, dir_path, file_name, samples_of_tasks=None):
    """Scatter-plot per-point class predictions and save the figure as a PDF.

    Parameters
    ----------
    input_xy:
        Points to colour; indexed as ``input_xy[:, 0, 0]`` (x) and
        ``input_xy[:, 0, 1]`` (y).
    prediction_xy: torch.Tensor
        One class id per point in ``input_xy``; used as the scatter colour
        (colour scale fixed to the range 0..8).
    title: str
        Title drawn above the axes.
    dir_path: str
        Output directory; created (including parents) if it does not exist.
    file_name: str
        File name without extension; the figure is written to
        ``{dir_path}/{file_name}.pdf``.
    samples_of_tasks: list of int, optional
        Task ids for which samples are overlaid with black edges. The overlay
        points are drawn afresh via ``get_train_stream_data`` on every call,
        i.e. they are new samples from the task distribution rather than the
        exact batch a model was trained on.
    """
    if samples_of_tasks is None:
        samples_of_tasks = []

    # Style sheet and rcParams have to be set BEFORE the figure is created:
    # many of them are read when the Figure/Axes/Text objects are constructed,
    # so applying them afterwards leaves the first figure of a fresh kernel
    # looking different from all later ones.
    # The bundled seaborn style was renamed to "seaborn-v0_8" in matplotlib 3.6;
    # use whichever name this matplotlib installation provides.
    style_name = "seaborn" if "seaborn" in plt.style.available else "seaborn-v0_8"
    plt.style.use(style_name)
    tex_fonts = {
       "text.usetex": True,
       "font.family": "serif",
       "axes.labelsize": 8,
       "axes.titlesize": 6,
       "font.size": 8,
       "legend.fontsize": 8,
       "xtick.labelsize": 8,
       "ytick.labelsize": 8
    }
    plt.rcParams.update(tex_fonts)

    width = 398
    fig, ax = plt.subplots(1, 1, figsize=set_size(width, fraction=0.16))

    # Background: one small square marker per input point, coloured by predicted class.
    ax.scatter(input_xy[:,0, 0], input_xy[:,0, 1], s=0.5, marker="s", cmap="Set2", c=prediction_xy.detach().numpy(), vmin=0, vmax=8)
    # Foreground: samples of the requested tasks, coloured by their true class.
    for task_id in samples_of_tasks:
        train_inputs, targets = get_train_stream_data(task_id)
        ax.scatter(train_inputs[:, 0, 0], train_inputs[:, 0, 1], cmap="Set2", edgecolors="black", s=2, c=targets.detach().numpy(), vmin=0, vmax=8)
    ax.set_xlabel("$x_1$")
    ax.set_ylabel("$x_2$")
    ax.set_xlim(0.05, 0.95)
    ax.set_ylim(0.05, 0.95)
    ax.set_aspect(1.0)
    ax.set_axis_off()
    ax.set_title(title)
    pathlib.Path(dir_path).mkdir(parents=True, exist_ok=True)
    fig.savefig(f"{dir_path}/{file_name}.pdf", format='pdf')
    # Close exactly the figure created here instead of whatever pyplot
    # currently considers the "current" figure.
    plt.close(fig)

#### 2D toy domain data loading functions

In [None]:
def sample_key_init_data(n_samples):
    """Draw ``n_samples`` 2-D points uniformly at random from [0, 1) x [0, 1)."""
    shape = (n_samples, 2)
    return torch.rand(shape)

def get_train_stream_data(task_id, n_samples=100):
    """
    Sample one training batch from a task of the class-incremental 2D toy stream.

    The stream has 4 tasks with 2 classes each (8 classes in total). Every
    class is an axis-aligned Gaussian blob; for each sample one of the task's
    two classes is picked at random and a point is drawn from its blob.

    Returns a tuple ``(inputs, targets)`` where ``inputs`` is a float32 tensor
    of shape ``(n_samples, 1, 2)`` and ``targets`` an int64 tensor of shape
    ``(n_samples,)`` holding the class ids.
    """
    blob_std = [[0.035, 0.035], [0.035, 0.035]]
    stream = [
        {"classes": [0, 4], "means": [[0.25, 0.25], [0.75, 0.5]], "std": blob_std},
        {"classes": [1, 5], "means": [[0.75, 0.25], [0.5, 0.75]], "std": blob_std},
        {"classes": [2, 6], "means": [[0.75, 0.75], [0.25, 0.5]], "std": blob_std},
        {"classes": [3, 7], "means": [[0.25, 0.75], [0.5, 0.25]], "std": blob_std},
    ]

    n_dim = 2
    coords = np.zeros((n_samples, n_dim))
    labels = np.zeros(n_samples)
    for row in range(n_samples):
        task = stream[task_id]
        # Class choice comes from numpy's RNG, the coordinates from torch's RNG.
        which = np.random.choice([0, 1])
        centre = task["means"][which]
        spread = task["std"][which]
        for axis in range(n_dim):
            coords[row, axis] = torch.normal(mean=torch.tensor(centre[axis]),
                                             std=torch.tensor(spread[axis]))
        labels[row] = task["classes"][which]

    # Insert a singleton middle axis so inputs are (n_samples, 1, 2).
    inputs = torch.tensor(coords, dtype=torch.float32).unsqueeze(1)
    targets = torch.tensor(labels, dtype=torch.int64)
    return inputs, targets

def get_grid_samples_from_unit_square(num_samples_per_dim=100):
    """Return a regular grid over [0, 1] x [0, 1] as a (N*N, 1, 2) tensor.

    ``N`` is ``num_samples_per_dim``. The last axis holds (x, y); points are
    ordered with x varying fastest.
    """
    ticks_x = torch.linspace(0, 1, steps=num_samples_per_dim)
    ticks_y = torch.linspace(0, 1, steps=num_samples_per_dim)
    grid_x, grid_y = torch.meshgrid(ticks_x, ticks_y, indexing='xy')
    # Turn each coordinate grid into a column, then pair the columns up.
    columns = [grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)]
    return torch.stack(columns, dim=-1)

#### Model helper functions

In [None]:
class KVModel(torch.nn.Module):
    """Pipeline that feeds the first output of a bottlenecked encoder to a decoder."""

    def __init__(self, bottlenecked_encoder, decoder, args):
        """Store the encoder, the decoder and the experiment arguments."""
        super().__init__()
        self.bottlenecked_encoder = bottlenecked_encoder
        self.decoder = decoder
        self.args = args

    def forward(self, x):
        """Encode ``x``, keep element 0 of the encoder's result and decode it."""
        encoder_result = self.bottlenecked_encoder(x)
        # The decoder is invoked with the keyword argument ``x``.
        return self.decoder(x=encoder_result[0])

### MLP baseline: Let's first investigate the behaviour of a naive MLP

In [None]:
# Naive baseline: 2-D input -> 32 hidden units with ReLU -> 8 class logits.
model_mlp = nn.Sequential(
    nn.Linear(in_features=2, out_features=32),
    nn.ReLU(),
    nn.Linear(in_features=32, out_features=8),
)

We first plot the decision boundaries of this randomly initialized MLP model.

In [None]:
%%capture
# Classify every point of a dense grid over the unit square with the untrained
# MLP and save the resulting class regions.
xy_tensor = get_grid_samples_from_unit_square()
model_mlp_predictions = model_mlp(xy_tensor[:, 0, :]).argmax(dim=-1)
generate_plot(
    input_xy=xy_tensor,
    prediction_xy=model_mlp_predictions,
    title="MLP - after init",
    dir_path="../artefacts/toy_experiment",
    file_name="mlp_predictions_after_init",
)

Next, we'll see how the decision boundaries change when trained on the data stream

In [None]:
%%capture
# --- Train the MLP on the four tasks one after another -----------------------
model_mlp.train()
criterion = nn.CrossEntropyLoss()
model_state_dicts = []
for task_id in range(4):
    # A fresh Adam optimizer (and hence fresh moment estimates) for every task.
    optimizer_model_mlp = torch.optim.Adam(model_mlp.parameters(), lr=0.001)
    train_inputs, targets = get_train_stream_data(task_id)
    for epoch in range(1000):
        # Full-batch gradient step on the current task's samples.
        optimizer_model_mlp.zero_grad()
        outputs = model_mlp(train_inputs[:, 0, :])
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer_model_mlp.step()
    # Snapshot the weights reached after finishing this task.
    model_state_dicts.append(deepcopy(model_mlp.state_dict()))

# --- Plot the decision regions of every snapshot -----------------------------
model_mlp_predictions_trained = []
for task_id, snapshot in enumerate(model_state_dicts):
    model_mlp.load_state_dict(snapshot)
    grid_logits = model_mlp(xy_tensor[:, 0, :])
    model_mlp_predictions_trained.append(grid_logits.argmax(dim=-1))
    generate_plot(
        input_xy=xy_tensor,
        prediction_xy=model_mlp_predictions_trained[task_id],
        title=f"MLP - after $D_{task_id+1}$",
        dir_path="../artefacts/toy_experiment",
        file_name=f"mlp_predictions_after_k_{task_id+1}",
        samples_of_tasks=[task_id],
    )

# Final decision boundaries, overlaid with samples from all visited tasks.
generate_plot(
    input_xy=xy_tensor,
    prediction_xy=model_mlp_predictions_trained[-1],
    title=f"MLP - final",
    dir_path="../artefacts/toy_experiment",
    file_name="mlp_final_decision_boundaries",
    samples_of_tasks=[0, 1, 2, 3],
)

Next, we'll repeat the same experiment with a linear probe (a single linear layer) only.

In [None]:
%%capture
# Linear-probe baseline: a single linear layer mapping the 2-D input to 8 class logits.
model_lp = nn.Sequential(nn.Linear(2, 8))
xy_tensor = get_grid_samples_from_unit_square()

# Decision regions of the untrained linear probe. Stored under an LP-specific
# name so the MLP's `model_mlp_predictions` is not overwritten.
model_lp_predictions = torch.argmax(model_lp(xy_tensor[:, 0, :]), dim=-1)
generate_plot(input_xy=xy_tensor,
              prediction_xy=model_lp_predictions,
              title="LP - after init",
              dir_path="../artefacts/toy_experiment",
              file_name="lp_predictions_after_init")

# Put the model trained in THIS cell into training mode (the original cell
# switched `model_mlp` instead, a copy-paste slip).
model_lp.train()
loss_fn = nn.CrossEntropyLoss()
model_state_dicts = []
for task_id in range(4):
    # A fresh Adam optimizer for every task of the stream.
    optimizer_model_lp = torch.optim.Adam(model_lp.parameters(), lr=0.01)
    train_inputs, targets = get_train_stream_data(task_id)
    for epoch in range(1000):
        # Full-batch gradient step on the current task's samples.
        optimizer_model_lp.zero_grad()
        outputs = model_lp(train_inputs[:, 0, :])
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer_model_lp.step()
    # Snapshot the weights reached after finishing this task.
    model_state_dicts.append(deepcopy(model_lp.state_dict()))

# Plot the decision regions of every per-task snapshot.
model_lp_predictions_trained = []
for task_id in range(4):
    model_lp.load_state_dict(model_state_dicts[task_id])
    model_lp_predictions_trained.append(torch.argmax(model_lp(xy_tensor[:, 0, :]), dim=-1))
    generate_plot(input_xy=xy_tensor,
                  prediction_xy=model_lp_predictions_trained[task_id],
                  title=f"LP - after $D_{task_id+1}$",
                  dir_path="../artefacts/toy_experiment",
                  file_name=f"lp_predictions_after_k_{task_id+1}",
                  samples_of_tasks=[task_id])

# Plot the final decision boundaries including samples of all visited tasks
generate_plot(input_xy=xy_tensor,
              prediction_xy=model_lp_predictions_trained[-1],
              title="LP - final",
              dir_path="../artefacts/toy_experiment",
              file_name="lp_final_decision_boundaries",
              samples_of_tasks=[0, 1, 2, 3])

Next, we'll see what would happen if we insert a Discrete Key-Value Bottleneck and combine it with our decoder architecture

In [None]:
# Experiment configuration for the single-codebook variant:
# one codebook holding 400 key-value pairs, 2-D keys, 8 classes.
args = Dict()
args.num_pairs = 400
args.cl_epochs = 1000
args.init_epochs = 1000
args.num_codebooks = 1
args.input_dims = 2
args.dim_value = 8
args.dim_key = 2
args.topk = 1
args.num_classes = 8
args.ff_dropout = 0.0

# The decoder is built with value dimension == number of classes.
# NOTE(review): the decoder presumably reads further settings from `args`
# (e.g. topk, ff_dropout) - confirm against CodebookVotingLogitsDecoder.
decoder = CodebookVotingLogitsDecoder(dim_values=args.num_classes,
                                       class_nums=args.num_classes,
                                       args=args
                                       )

# No learned encoder (nn.Identity): the raw 2-D inputs go straight into the
# bottleneck, here with splitting_mode="chunk".
bottlenecked_encoder = kv_core.BottleneckedEncoder(encoder=nn.Identity(),
                                                   num_codebooks=args.num_codebooks,
                                                   num_channels=args.input_dims,
                                                   key_value_pairs_per_codebook=args.num_pairs,
                                                   dim_keys=args.dim_key,
                                                   dim_values=decoder.dim_values,
                                                   splitting_mode="chunk",
                                                   return_values_only=False,
                                                   encoder_is_channel_last=False,
                                                   concat_values_from_all_codebooks=False)

# Replace the bottleneck's values by small Gaussian noise (std 0.001) of the
# same shape and register them as a trainable parameter.
values = 0.001*torch.randn_like(bottlenecked_encoder.bottleneck.values)
bottlenecked_encoder.bottleneck.values = nn.Parameter(values)
bottlenecked_encoder.bottleneck.values.requires_grad = True

model_kv = KVModel(bottlenecked_encoder=bottlenecked_encoder,
                   decoder=decoder,
                   args=args)

# Key initialisation: run `init_epochs` forward passes on batches of 100
# uniformly random 2-D points, then freeze the keys.
# NOTE(review): presumably each forward pass moves the keys towards the inputs
# it sees until updates are disabled below - confirm in key_value_bottleneck.
for _ in range(args.init_epochs):
    model_kv(sample_key_init_data(100))
bottlenecked_encoder.freeze_keys()
bottlenecked_encoder.disable_update_keys()

We first plot the decision boundaries of our model with keys initialized on random input data

In [None]:
%%capture
# Classify every point of a dense grid over the unit square with the
# key-initialised (but otherwise untrained) KV model and save the plot.
xy_tensor = get_grid_samples_from_unit_square()
model_kv_predictions = model_kv(xy_tensor[:, 0, :]).argmax(dim=-1)
generate_plot(
    input_xy=xy_tensor,
    prediction_xy=model_kv_predictions,
    title="KV - after init",
    dir_path="../artefacts/toy_experiment",
    file_name="kv_chunk_dec1_predictions_after_init",
)

Next, we'll see how the decision boundaries change when trained on the data stream

In [None]:
%%capture
# --- Train the KV model on the four tasks one after another ------------------
bottlenecked_encoder.reset_cluster_size_counter()
bottlenecked_encoder.activate_counts()
model_kv.train()
# One Adam optimizer shared across all tasks of the stream.
optimizer_model_kv = torch.optim.Adam(model_kv.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

model_state_dicts = []
for task_id in tqdm.tqdm(range(4)):
    train_inputs, targets = get_train_stream_data(task_id)
    for epoch in range(args.cl_epochs):
        # Full-batch gradient step on the current task's samples.
        optimizer_model_kv.zero_grad()
        outputs = model_kv(train_inputs[:, 0, :])
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer_model_kv.step()
    # Snapshot the model state reached after finishing this task.
    model_state_dicts.append(deepcopy(model_kv.state_dict()))

# --- Plot the decision regions of every snapshot -----------------------------
model_kv.eval()
bottlenecked_encoder.deactivate_counts()
model_kv_predictions_trained = []
for task_id, snapshot in enumerate(model_state_dicts):
    model_kv.load_state_dict(snapshot)
    grid_logits = model_kv(xy_tensor[:, 0, :])
    model_kv_predictions_trained.append(grid_logits.argmax(dim=-1))
    generate_plot(
        input_xy=xy_tensor,
        prediction_xy=model_kv_predictions_trained[task_id],
        title=f"KV - after $D_{task_id+1}$",
        dir_path="../artefacts/toy_experiment",
        file_name=f"kv_chunk_dec1_predictions_after_k_{task_id+1}",
        samples_of_tasks=[task_id],
    )

# Final decision boundaries, overlaid with samples from all visited tasks.
generate_plot(
    input_xy=xy_tensor,
    prediction_xy=model_kv_predictions_trained[-1],
    title=f"KV - final",
    dir_path="../artefacts/toy_experiment",
    file_name="kv_chunk_dec1_final_decision_boundaries",
    samples_of_tasks=[0, 1, 2, 3],
)

### Finally, we'll investigate what happens if we use random projections and 20 codebooks with 20 key-value pairs each

In [None]:
# Experiment configuration for the multi-codebook variant:
# 20 codebooks with 20 key-value pairs each, 2-D keys, 8 classes.
args = Dict()
args.num_pairs = 20
args.cl_epochs = 1000
args.init_epochs = 1000
args.num_codebooks = 20
args.input_dims = 2
args.dim_value = 8
args.dim_key = 2
args.topk = 1
args.num_classes = 8
args.ff_dropout = 0.0

# The decoder is built with value dimension == number of classes.
# NOTE(review): the decoder presumably reads further settings from `args`
# (e.g. topk, ff_dropout) - confirm against CodebookVotingLogitsDecoder.
decoder = CodebookVotingLogitsDecoder(dim_values=args.num_classes,
                                       class_nums=args.num_classes,
                                       args=args
                                       )

# No learned encoder (nn.Identity); compared to the single-codebook setup the
# bottleneck is created with splitting_mode="random_projection".
bottlenecked_encoder = kv_core.BottleneckedEncoder(encoder=nn.Identity(),
                                                   num_codebooks=args.num_codebooks,
                                                   num_channels=args.input_dims,
                                                   key_value_pairs_per_codebook=args.num_pairs,
                                                   dim_keys=args.dim_key,
                                                   dim_values=decoder.dim_values,
                                                   splitting_mode="random_projection",
                                                   return_values_only=False,
                                                   encoder_is_channel_last=False,
                                                   concat_values_from_all_codebooks=False)

# Replace the bottleneck's values by small Gaussian noise (std 0.001) of the
# same shape and register them as a trainable parameter.
values = 0.001*torch.randn_like(bottlenecked_encoder.bottleneck.values)
bottlenecked_encoder.bottleneck.values = nn.Parameter(values)
bottlenecked_encoder.bottleneck.values.requires_grad = True

# Rebinds `model_kv` (and `args`, `decoder`, `bottlenecked_encoder`) to the new
# variant, replacing the single-codebook objects.
model_kv = KVModel(bottlenecked_encoder=bottlenecked_encoder,
                   decoder=decoder,
                   args=args)

# Key initialisation: run `init_epochs` forward passes on batches of 100
# uniformly random 2-D points, then freeze the keys.
# NOTE(review): presumably each forward pass moves the keys towards the inputs
# it sees until updates are disabled below - confirm in key_value_bottleneck.
for _ in range(args.init_epochs):
    model_kv(sample_key_init_data(100))
bottlenecked_encoder.freeze_keys()
bottlenecked_encoder.disable_update_keys()

In [None]:
%%capture
# Classify every point of a dense grid over the unit square with the
# key-initialised (but otherwise untrained) random-projection KV model.
xy_tensor = get_grid_samples_from_unit_square()
model_kv_predictions = model_kv(xy_tensor[:, 0, :]).argmax(dim=-1)
generate_plot(
    input_xy=xy_tensor,
    prediction_xy=model_kv_predictions,
    title="KV - after init",
    dir_path="../artefacts/toy_experiment",
    file_name="kv_random_dec1_predictions_after_init",
)

In [None]:
%%capture
# --- Train the random-projection KV model on the task stream -----------------
bottlenecked_encoder.reset_cluster_size_counter()
bottlenecked_encoder.activate_counts()
model_kv.train()
# Freeze everything, then re-enable gradients for the bottleneck values only,
# so the values are the only parameters the optimizer actually updates.
for p in model_kv.parameters():
    p.requires_grad = False
bottlenecked_encoder.bottleneck.values.requires_grad = True
bottlenecked_encoder.disable_update_keys()
# One Adam optimizer shared across all tasks; parameters without gradients
# are left untouched by its update step.
optimizer_model_kv = torch.optim.Adam(model_kv.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

model_state_dicts = []
for task_id in tqdm.tqdm(range(4)):
    train_inputs, targets = get_train_stream_data(task_id)
    # Honour the configured number of epochs (args.cl_epochs) instead of a
    # hard-coded 1000, matching the single-codebook training cell.
    for epoch in range(args.cl_epochs):
        optimizer_model_kv.zero_grad()
        outputs = model_kv(train_inputs[:, 0, :])
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer_model_kv.step()
    # Snapshot the model state reached after finishing this task.
    model_state_dicts.append(deepcopy(model_kv.state_dict()))

# --- Plot the decision regions of every snapshot -----------------------------
model_kv.eval()
bottlenecked_encoder.deactivate_counts()
model_kv_predictions_trained = []
for task_id in range(4):
    model_kv.load_state_dict(model_state_dicts[task_id])
    model_kv_predictions_trained.append(torch.argmax(model_kv(xy_tensor[:, 0, :]), dim=-1))
    generate_plot(input_xy=xy_tensor,
                  prediction_xy=model_kv_predictions_trained[task_id],
                  title=f"KV - after $D_{task_id+1}$",
                  dir_path="../artefacts/toy_experiment",
                  file_name=f"kv_random_dec1_predictions_after_k_{task_id+1}",
                  samples_of_tasks=[task_id])

# Plot the final decision boundaries including samples of all visited tasks
generate_plot(input_xy=xy_tensor,
              prediction_xy=model_kv_predictions_trained[-1],
              title="KV - final",
              dir_path="../artefacts/toy_experiment",
              file_name="kv_random_dec1_final_decision_boundaries",
              samples_of_tasks=[0, 1, 2, 3])