<a href="https://colab.research.google.com/github/nitinsb/Advance_C_attendance-management-code/blob/master/baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# baseline : permute,-> context 1
# Standard libraries
import numpy as np
import copy
import tqdm
# Pytorch
import torch
from torch.nn import functional as F
from torchvision import datasets, transforms
# For visualization
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

In [None]:
# Load the MNIST train- and testset (downloaded to 'data/' if needed), with images converted to tensors
MNIST_trainset = datasets.MNIST(root='data/', train=True, download=True,
                                transform=transforms.ToTensor())
MNIST_testset = datasets.MNIST(root='data/', train=False, download=True,
                               transform=transforms.ToTensor())
# Properties of the data: image side-length (in pixels), number of channels and number of classes
config = {'size': 28, 'channels': 1, 'classes': 10}


In [None]:
#@title Visualization functions
def multi_context_barplot(axis, accs, title=None):
    '''Draw on [axis] one black bar per context, with heights given by [accs].

    [axis]   axes-object to draw on (must offer `bar`, `set_ylabel`, `set_xticks` and `set_title`)
    [accs]   sequence with one accuracy-value (in %) per context
    [title]  optional title of the plot (no title is set if None)'''
    positions = range(len(accs))
    tick_labels = [f'Context {k+1}' for k in positions]
    axis.bar(positions, accs, color='k')
    axis.set_ylabel('Testing Accuracy (%)')
    axis.set_xticks(positions, tick_labels)
    if title is None:
        return
    axis.set_title(title)

def plot_examples(axis, dataset, context_id=None):
    '''Show on [axis] a grid (5 per row) of one randomly drawn batch of up to 25 images from [dataset].

    If [context_id] is given, the plot gets the title "Context <context_id+1>".'''
    loader = torch.utils.data.DataLoader(dataset, batch_size=25, shuffle=True)
    images = next(iter(loader))[0]  #--> the labels of the batch are not needed
    # pad_value=1 sets the value of the padding between images (pad_value=0 would give black borders)
    grid = make_grid(images, nrow=5, pad_value=1)
    # move the first dimension of the grid to the end before handing it to `imshow`
    axis.imshow(grid.numpy().transpose(1, 2, 0))
    if context_id is not None:
        axis.set_title("Context {}".format(context_id+1))
    axis.axis('off')


In [None]:
# Function to apply a given permutation to the pixels of an image.
def permutate_image_pixels(image, permutation):
    '''Permutate the pixels of [image] according to [permutation].

    [image]        3D-tensor ([channels] x [height] x [width])
    [permutation]  None, or indices (e.g., array of length [height]*[width]) specifying the new pixel order

    Returns [image] itself if [permutation] is None, otherwise a new tensor with the shape of [image].'''

    if permutation is None:
        return image
    c, h, w = image.size()
    # `reshape` (instead of `view`) so that non-contiguous images (e.g., after a transpose) also work
    flat_image = image.reshape(c, -1)
    permuted = flat_image[:, permutation]  #--> same permutation for each channel
    return permuted.reshape(c, h, w)

In [None]:

# Class to create a dataset with images that have all been transformed in the same way.
class TransformedDataset(torch.utils.data.Dataset):
    '''Wrapper around an existing dataset that applies transforms whenever a sample is accessed.

    Useful for creating different permutations of MNIST without loading the data multiple times.

    [original_dataset]  dataset to wrap; indexing it should give an (input, target)-pair
    [transform]         optional callable applied to each input
    [target_transform]  optional callable applied to each target'''

    def __init__(self, original_dataset, transform=None, target_transform=None):
        super().__init__()
        self.dataset = original_dataset
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample, label = self.dataset[index]
        sample = self.transform(sample) if self.transform else sample
        label = self.target_transform(label) if self.target_transform else label
        return (sample, label)


In [None]:
# Number of contexts: the first one uses the original images, each further one its own pixel permutation
contexts = 2

In [None]:
# Specify for each context the permutation to use (with no permutation for the first context);
# each permutation is a random re-ordering of all size*size pixel-indices (not seeded, so it differs per run)
permutations = [None] + [np.random.permutation(config['size']**2) for _ in range(contexts-1)]

In [None]:
# Build for each context a train- and a testset that apply the permutation of that context
train_datasets, test_datasets = [], []
for context_id, perm in enumerate(permutations):
    # [perm] is bound as default argument, so each context keeps using its own permutation
    permute = transforms.Lambda(lambda x, p=perm: permutate_image_pixels(x, p))
    train_datasets.append(TransformedDataset(MNIST_trainset, transform=permute))
    test_datasets.append(TransformedDataset(MNIST_testset, transform=permute))


In [None]:
# Visualize the contexts (one panel with example images per context)
# NOTE: with squeeze=False a 2D-array of axes is always returned, so indexing also works if contexts==1
figure, axis = plt.subplots(1, contexts, figsize=(3*contexts, 4), squeeze=False)
axis = axis[0]  #--> the single row of axes, with one entry per context

for context_id in range(len(train_datasets)):
    plot_examples(axis[context_id], train_datasets[context_id], context_id=context_id)

In [None]:
#@title Helper functions

class Identity(torch.nn.Module):
    '''A nn-module that returns its input unchanged.'''
    def forward(self, x):
        return x

    def __repr__(self):
        # shown as "<class-name>()"
        return '{}()'.format(type(self).__name__)


class Flatten(torch.nn.Module):
    '''A nn-module that reshapes a multi-dimensional tensor to a 2D-tensor ([batch_size] x [everything else]).'''
    def forward(self, x):
        # the first dimension should be the batch-dimension, all other dimensions are merged
        return x.view(x.size(0), -1)

    def __repr__(self):
        # shown as "<class-name>()"
        return '{}()'.format(type(self).__name__)


class fc_layer(torch.nn.Module):
    '''Fully connected layer: a linear transformation followed by an (optional) non-linearity.

    Input:  [batch_size] x ... x [in_size] tensor
    Output: [batch_size] x ... x [out_size] tensor

    [in_size]   # of input units
    [out_size]  # of output units
    [nl]        non-linearity to use: a <nn.Module>, "relu", "leakyrelu", or "none" / None for a purely linear layer
    [bias]      whether the linear transformation has a bias-term

    Raises a <ValueError> if [nl] is not one of the above options.'''

    def __init__(self, in_size, out_size, nl=torch.nn.ReLU(), bias=True):
        super().__init__()
        self.bias = bias
        self.linear = torch.nn.Linear(in_size, out_size, bias=bias)
        if isinstance(nl, torch.nn.Module):
            self.nl = nl
        elif nl=="relu":
            self.nl = torch.nn.ReLU()
        elif nl=="leakyrelu":
            self.nl = torch.nn.LeakyReLU()
        elif nl is not None and nl!="none":
            # a misspelled option should not silently result in a layer without non-linearity
            raise ValueError(
                "Unknown non-linearity {!r} (options: 'relu', 'leakyrelu', 'none' or a <nn.Module>).".format(nl)
            )
        # NOTE: for "none" / None, the attribute [self.nl] is deliberately not set (see `forward`)

    def forward(self, x):
        pre_activ = self.linear(x)
        # only apply a non-linearity if one was set in `__init__`
        return self.nl(pre_activ) if hasattr(self, 'nl') else pre_activ


class MLP(torch.nn.Module):
    '''Module for a multi-layer perceptron (MLP).

    Input:  [batch_size] x ... x [size_per_layer[0]] tensor
    Output: [batch_size] x ... x [size_per_layer[-1]] tensor'''

    def __init__(self, input_size=1000, output_size=10, layers=2,
                 hid_size=1000, hid_smooth=None, size_per_layer=None,
                 nl="relu", bias=True, output='normal'):
        '''sizes: 0th=[input], 1st=[hid_size], ..., 1st-to-last=[hid_smooth], last=[output].
        [input_size]       # of inputs
        [output_size]      # of units in final layer
        [layers]           # of layers
        [hid_size]         # of units in each hidden layer
        [hid_smooth]       if None, all hidden layers have [hid_size] units, else # of units linearly in-/decreases s.t.
                             final hidden layer has [hid_smooth] units (if only 1 hidden layer, it has [hid_size] units)
        [size_per_layer]   None or <list> with for each layer number of units (1st element = number of inputs)
                                --> overwrites [input_size], [output_size], [layers], [hid_size] and [hid_smooth]
        [nl]               <str>; type of non-linearity to be used (options: "relu", "leakyrelu", "none")
        [bias]             <bool>; whether the layers have a bias-term
        [output]           <str>; if - "normal", final layer is same as all others
                                     - "none", final layer has no non-linearity
                                     - "sigmoid", final layer has sigmoid non-linearity'''

        super().__init__()
        self.output = output

        # get sizes of all layers
        if size_per_layer is None:
            hidden_sizes = []
            if layers > 1:
                if hid_smooth is not None:
                    hidden_sizes = [int(x) for x in np.linspace(hid_size, hid_smooth, num=layers-1)]
                else:
                    hidden_sizes = [int(hid_size)] * (layers - 1)
            size_per_layer = [input_size] + hidden_sizes + [output_size] if layers>0 else [input_size]
        self.layers = len(size_per_layer)-1

        # set label for this module
        # -determine "non-default options"-label
        nd_label = "{bias}{nl}".format(
            bias="" if bias else "n",
            nl="l" if nl=="leakyrelu" else ("n" if nl=="none" else ""),
        )
        nd_label = "{}{}".format("" if nd_label=="" else "-{}".format(nd_label),
                                 "" if output=="normal" else "-{}".format(output))
        # -set label (layer-sizes are listed as "-<size0>x<size1>x...")
        size_statement = ""
        if size_per_layer:
            size_statement = "-" + "x".join("{}".format(i) for i in size_per_layer)
        self.label = "F{}{}".format(size_statement, nd_label) if self.layers>0 else ""

        # set layers
        for lay_id in range(1, self.layers+1):
            # number of units of this layer's input and output
            in_size = size_per_layer[lay_id-1]
            out_size = size_per_layer[lay_id]
            # select the non-linearity: only the final layer is affected by a non-"normal" [output]
            if lay_id==self.layers and not output=="normal":
                # NOTE: `torch.nn` is used here, as no bare `nn` is imported in this file
                layer_nl = "none" if output=="none" else torch.nn.Sigmoid()
            else:
                layer_nl = nl
            # define and set the fully connected layer
            layer = fc_layer(in_size, out_size, bias=bias, nl=layer_nl)
            setattr(self, 'fcLayer{}'.format(lay_id), layer)

        # if no layers, add "identity"-module to indicate in this module's representation nothing happens
        if self.layers<1:
            self.noLayers = Identity()

    def forward(self, x):
        # pass [x] sequentially through all layers (if there are no layers, [x] is returned unchanged)
        for lay_id in range(1, self.layers + 1):
            x = getattr(self, "fcLayer{}".format(lay_id))(x)
        return x

    @property
    def name(self):
        return self.label




class Classifier(torch.nn.Module):
    '''Model for classifying images.

    [image_size]      side-length of the (square) input images
    [image_channels]  # of channels of the input images
    [output_units]    # of output units (i.e., # of classes)
    [fc_layers]       # of fully connected layers (including the final classification layer)
    [fc_units]        # of units in each hidden layer
    [fc_nl]           non-linearity used in the hidden layers
    [bias]            whether the hidden layers have a bias-term'''

    def __init__(self, image_size, image_channels, output_units,
                 fc_layers=3, fc_units=1000, fc_nl="relu", bias=True):

        super().__init__()

        # Flatten image to 2D-tensor
        self.flatten = Flatten()

        # Specify the fully connected hidden layers
        input_size = image_channels * image_size * image_size
        self.fcE = MLP(input_size=input_size, output_size=fc_units, layers=fc_layers-1,
                       hid_size=fc_units, nl=fc_nl, bias=bias)
        # (without hidden layers, the flattened image is fed directly to the classifier;
        #  NOTE: [input_size] is a local variable, it is never stored as attribute of this module)
        mlp_output_size = fc_units if fc_layers>1 else input_size

        # Specify the final linear classifier layer
        self.classifier = fc_layer(mlp_output_size, output_units, nl='none')

    def forward(self, x):
        flatten_x = self.flatten(x)
        final_features = self.fcE(flatten_x)
        out = self.classifier(final_features)
        return out


In [None]:

# Specify the architectural layout of the network to use
# (these values are passed on to `Classifier` when a model is defined)
fc_lay = 4        #--> number of fully-connected layers
fc_units = 40     #--> number of units in each hidden layer
fc_nl = "relu"    #--> what non-linearity to use?

In [None]:
# Define the model (input- and output-size follow from [config], the architecture from the settings above)
model = Classifier(image_size=config['size'], image_channels=config['channels'],
                   output_units=config['classes'],
                   fc_layers=fc_lay, fc_units=fc_units, fc_nl=fc_nl)

In [None]:
# Show the model (as last expression of the cell, its representation is displayed by the notebook)
model

In [None]:
# Count the total number of parameters of the model
# (`numel` gives the number of elements of a tensor, which is also correct for 0-dimensional parameters)
total_params = sum(param.numel() for param in model.parameters())
print( "--> this network has {} parameters (~{}K)"
      .format(total_params, round(total_params / 1000)))

In [None]:
def train(model, dataset, iters, lr, batch_size):
    '''Train [model] on [dataset] for [iters] mini-batches of size [batch_size], using Adam with learning rate [lr].

    Training loss and accuracy of each mini-batch are shown in a progress bar; nothing is returned.'''
    # Define the optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

    # Set model in training-mode
    model.train()

    # Define tqdm progress bar (updated by hand after each iteration)
    progress_bar = tqdm.tqdm(range(1, iters+1))

    # Number of mini-batches that can still be drawn from the current data-loader
    batches_remaining = 0

    # Loop over all iterations
    for _ in range(iters):

        # If the current data-loader is used up (or none exists yet), create a new, reshuffled one
        if batches_remaining==0:
            loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                 shuffle=True, drop_last=True)
            data_loader = iter(loader)
            batches_remaining = len(data_loader)
        batches_remaining -= 1

        # Sample training data of current context
        x, y = next(data_loader)

        # Reset optimizer and run model
        optimizer.zero_grad()
        y_hat = model(x)

        # Calculate prediction loss (averaged over the mini-batch)
        loss = torch.nn.functional.cross_entropy(y_hat, y, reduction='mean')

        # Calculate training-accuracy (in %)
        predicted = y_hat.max(1)[1]
        n_correct = (y == predicted).sum().item()
        accuracy = n_correct*100 / x.size(0)

        # Backpropagate errors and take the optimizer step
        loss.backward()
        optimizer.step()

        # Update progress bar
        description = ' | training loss: {loss:.3} | training accuracy: {prec:.3}% |'.format(
            loss=loss.item(), prec=accuracy
        )
        progress_bar.set_description(description)
        progress_bar.update(1)

    # Close the progress bar
    progress_bar.close()

In [None]:
# Training settings (used for all calls to the training functions in this notebook)
iters = 200       # for how many iterations to train?
lr = 0.01         # learning rate
batch_size = 128  # size of mini-batches

In [None]:
# Train the model on the first context
train(model, dataset=train_datasets[0], iters=iters, lr=lr, batch_size=batch_size)

In [None]:

def test_acc(model, dataset, test_size=None, batch_size=128):
    '''Evaluate accuracy (% samples classified correctly) of a classifier ([model]) on [dataset].'''

    # Set model to eval()-mode
    mode = model.training
    model.eval()

    # Loop over batches in [dataset]
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                              shuffle=True, drop_last=False)
    total_tested = total_correct = 0
    for x, y in data_loader:
        # -break on [test_size] (if "None", full dataset is used)
        if test_size:
            if total_tested >= test_size:
                break
        # -evaluate model
        with torch.no_grad():
            scores = model(x)
        _, predicted = torch.max(scores, 1)
        # -update statistics
        total_correct += (predicted == y).sum().item()
        total_tested += len(x)
    accuracy = total_correct*100 / total_tested

    # Set model back to its initial mode, print result on screen (if requested) and return it
    model.train(mode=mode)

    return accuracy

In [None]:

# Evaluate accuracy per context and print to screen
# (the accuracies after training on only the first context are kept in [context1_accs] for plotting)
print("\n Accuracy (in %) of the model on test-set of:")
context1_accs = []
for i in range(contexts):
    acc = test_acc(model, test_datasets[i], test_size=None)
    print(" - Context {}: {:.1f}".format(i+1, acc))
    context1_accs.append(acc)

In [None]:
# Keep an independent copy of the model as it is at this point (i.e., after training on the first context)
model_after_context1 = copy.deepcopy(model)

In [None]:
# Continue to train the model on the second context
# (plain fine-tuning: only data of the second context is used)
train(model, dataset=train_datasets[1], iters=iters, lr=lr, batch_size=batch_size)

In [None]:
# Evaluate accuracy per context and print to screen
# (the accuracies after also training on the second context are kept in [context2_accs] for plotting)
print("\n Accuracy (in %) of the model on test-set of:")
context2_accs = []
for i in range(contexts):
    acc = test_acc(model, test_datasets[i], test_size=None)
    print(" - Context {}: {:.1f}".format(i+1, acc))
    context2_accs.append(acc)

In [None]:
# Define a new model with same architecture
# (a freshly constructed model, to be trained on the data of all contexts together)
model_joint = Classifier(image_size=config['size'], image_channels=config['channels'],
                         output_units=config['classes'],
                         fc_layers=fc_lay, fc_units=fc_units, fc_nl=fc_nl)

In [None]:
# Create a joint dataset with data from both contexts
# (concatenation of the training sets of all contexts)
joint_trainset = torch.utils.data.ConcatDataset(train_datasets)

In [None]:
# Size of mini-batches for joint training (twice the size used when training on a single context)
batch_size_joint = 256


In [None]:
# Train the joint model (same number of iterations and learning rate as before, but with larger mini-batches)
train(model_joint, dataset=joint_trainset, iters=iters, lr=lr, batch_size=batch_size_joint)

In [None]:

# Evaluate the jointly trained model on the test-set of each context
# (the accuracies are kept in [joint_accs] for plotting)
print("\n Accuracy (in %) of the model on test-set of:")
joint_accs = []
for i in range(contexts):
    acc = test_acc(model_joint, test_datasets[i], test_size=None)
    print(" - Context {}: {:.1f}".format(i+1, acc))
    joint_accs.append(acc)


In [None]:

# Visualize the per-context test accuracies of the three settings in separate bar plots
figure, axis = plt.subplots(1, 4, figsize=(15, 5))

title='After training on context 1, \nbut not yet training on context 2'
multi_context_barplot(axis[0], context1_accs, title)

title='After first training on context 1, \nand then training on context 2'
multi_context_barplot(axis[1], context2_accs, title)

# the third panel is left empty (its axes are hidden)
axis[2].axis('off')

title='After jointly training on both contexts'
multi_context_barplot(axis[3], joint_accs, title)

In [None]:
# Create two independent copies of the model stored after training on the first context:
# one to continue training with EWC, and one to continue training with replay
model_ewc = copy.deepcopy(model_after_context1)
model_replay = copy.deepcopy(model_after_context1)

In [None]:
def estimate_fisher(model, dataset, n_samples, ewc_gamma=1.):
    '''Estimate diagonal of Fisher Information matrix for [model] on [dataset] using [n_samples].

    [model]      classifier for which to estimate the Fisher Information
    [dataset]    dataset from which samples are taken (in order, starting from the first one)
    [n_samples]  max # of samples to use (if None, all samples of [dataset] are used)
    [ewc_gamma]  factor with which a previously stored Fisher-estimate is multiplied before it is added

    Nothing is returned; for each parameter the following buffers are stored in [model] (with "." in the
    parameter-name replaced by "__"):
    - "<name>_EWC_param_values":     copy of the current value of the parameter
    - "<name>_EWC_estimated_fisher": estimated Fisher Information (plus [ewc_gamma] times any previous estimate)

    Raises a <ValueError> if no samples are available for the estimation.'''

    # Prepare <dict> to store estimated Fisher Information matrix
    est_fisher_info = {}
    for n, p in model.named_parameters():
        n = n.replace('.', '__')
        est_fisher_info[n] = p.detach().clone().zero_()

    # Set model to evaluation mode
    mode = model.training
    model.eval()

    # Create data-loader to give batches of size 1
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)

    # Estimate the FI-matrix for [n_samples] batches of size 1
    n_used = 0   #--> explicit count of the samples used, needed for the normalization below
    for x, y in data_loader:
        # break from for-loop if max number of samples has been reached
        if n_samples is not None and n_used >= n_samples:
            break
        # run forward pass of model
        output = model(x)
        n_classes = output.shape[1]
        # calculate the FI-matrix
        with torch.no_grad():
            label_weights = F.softmax(output, dim=1)  #--> get weights, with no gradient tracked
        # - loop over all classes
        for label_index in range(n_classes):
            label = torch.full((1,), label_index, dtype=torch.long, device=output.device)
            negloglikelihood = F.cross_entropy(output, label)
            # Calculate gradient of negative loglikelihood for this class
            # (the graph is only needed again if there are more classes to come for this sample)
            model.zero_grad()
            negloglikelihood.backward(retain_graph=(label_index+1)<n_classes)
            # Square gradients and keep running sum (using the weights)
            for n, p in model.named_parameters():
                n = n.replace('.', '__')
                if p.grad is not None:
                    est_fisher_info[n] += label_weights[0][label_index] * (p.grad.detach() ** 2)
        n_used += 1

    # Without any samples there is nothing to normalize by
    if n_used == 0:
        model.train(mode=mode)
        raise ValueError("Cannot estimate Fisher Information: no samples were available.")

    # Normalize by sample size used for estimation
    est_fisher_info = {n: p/n_used for n, p in est_fisher_info.items()}

    # Store new values in the network
    for n, p in model.named_parameters():
        n = n.replace('.', '__')
        # -mode (=MAP parameter estimate)
        model.register_buffer('{}_EWC_param_values'.format(n,), p.detach().clone())
        # -precision (approximated by diagonal Fisher Information matrix)
        if hasattr(model, '{}_EWC_estimated_fisher'.format(n)):
            existing_values = getattr(model, '{}_EWC_estimated_fisher'.format(n))
            est_fisher_info[n] += ewc_gamma * existing_values
        model.register_buffer('{}_EWC_estimated_fisher'.format(n), est_fisher_info[n])

    # Set model back to its initial mode
    model.train(mode=mode)

In [None]:
# (only the steps that differ from the original `train`-function are commented)
def train_ewc(model, dataset, iters, lr, batch_size, current_context, ewc_lambda=100.):
    '''Train [model] on [dataset] as in `train`, but with an EWC-regularization term added to the loss.

    [current_context]  number of the context being trained on; the EWC-term is only used if larger than 1
    [ewc_lambda]       weight of the EWC-term in the total loss

    If the EWC-term is used, the buffers "<name>_EWC_param_values" and "<name>_EWC_estimated_fisher" are
    read from [model] for each of its parameters.'''
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
    model.train()
    progress_bar = tqdm.tqdm(range(1, iters+1))
    use_ewc = current_context>1
    batches_remaining = 0

    for _ in range(iters):
        if batches_remaining==0:
            loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                 shuffle=True, drop_last=True)
            data_loader = iter(loader)
            batches_remaining = len(data_loader)
        batches_remaining -= 1
        x, y = next(data_loader)
        optimizer.zero_grad()
        y_hat = model(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y, reduction='mean')

        # Compute the EWC-regularization term, and add it to the loss (except if first context)
        total_loss = loss
        if use_ewc:
            ewc_losses = []
            for name, param in model.named_parameters():
                # Retrieve stored mode (MAP estimate) and precision (Fisher Information matrix)
                buffer_prefix = name.replace('.', '__')
                mean = getattr(model, '{}_EWC_param_values'.format(buffer_prefix))
                fisher = getattr(model, '{}_EWC_estimated_fisher'.format(buffer_prefix))
                # Calculate weight regularization loss of this parameter
                ewc_losses.append((fisher * (param-mean)**2).sum())
            ewc_loss = (1./2)*sum(ewc_losses)
            total_loss = loss + ewc_lambda*ewc_loss

        predicted = y_hat.max(1)[1]
        accuracy = (y == predicted).sum().item()*100 / x.size(0)
        total_loss.backward()
        optimizer.step()
        description = '<CLASSIFIER> | training loss: {loss:.3} | training accuracy: {prec:.3}% |'.format(
            loss=total_loss.item(), prec=accuracy
        )
        progress_bar.set_description(description)
        progress_bar.update(1)
    progress_bar.close()

In [None]:
# Estimate the Fisher Information of [model_ewc] on training data of the first context (with n_samples=200);
# the estimate and the current parameter values are stored as buffers inside [model_ewc]
estimate_fisher(model_ewc, train_datasets[0], n_samples=200)

In [None]:
# Train on the second context using EWC parameter regularization
# (current_context=2, so that the EWC-term is added to the loss)
ewc_lambda = 100   #--> this is a "continual learning hyperparameter", setting these is a delicate
                   #    business. Here we ignore that and just use one that gives good performance.
train_ewc(model_ewc, train_datasets[1], iters=iters, lr=lr, batch_size=batch_size,
          current_context=2, ewc_lambda=ewc_lambda)

In [None]:
# Evaluate the model trained with EWC on the test-set of each context
# (the accuracies are kept in [ewc_accs] for plotting)
print("\n Accuracy (in %) of the model on test-set of:")
ewc_accs = []
for i in range(contexts):
    acc = test_acc(model_ewc, test_datasets[i], test_size=None)
    print(" - Context {}: {:.1f}".format(i+1, acc))
    ewc_accs.append(acc)

In [None]:
#@title Helper dataset classes for constructing memory buffer
class SubDataset(torch.utils.data.Dataset):
    '''To sub-sample a dataset, taking only those samples with label in [sub_labels].

    After this selection of samples has been made, it is possible to transform the target-labels,
    which can be useful when doing continual learning with fixed number of output units.

    [original_dataset]  dataset to select samples from; indexing it should give an (input, target)-pair
    [sub_labels]        collection of labels; only samples whose label is in it are kept
    [target_transform]  optional callable applied to the target of each selected sample'''

    def __init__(self, original_dataset, sub_labels, target_transform=None):
        super().__init__()
        self.dataset = original_dataset
        self.target_transform = target_transform

        # Determine how to get the label of each sample (decided once, not for every sample)
        if hasattr(original_dataset, "targets"):
            # -labels are read directly from [targets], which avoids having to load each sample
            #  (NOTE: not every dataset with [targets] also has a [target_transform]-attribute)
            inner_transform = getattr(original_dataset, "target_transform", None)
            if inner_transform is None:
                get_label = lambda index: original_dataset.targets[index]
            else:
                get_label = lambda index: inner_transform(original_dataset.targets[index])
        else:
            # -otherwise, each sample has to be loaded to get its label
            get_label = lambda index: original_dataset[index][1]

        # Store the indices (in the original dataset) of all samples with a label in [sub_labels]
        self.sub_indeces = [
            index for index in range(len(original_dataset)) if get_label(index) in sub_labels
        ]

    def __len__(self):
        return len(self.sub_indeces)

    def __getitem__(self, index):
        sample = self.dataset[self.sub_indeces[index]]
        if self.target_transform:
            target = self.target_transform(sample[1])
            sample = (sample[0], target)
        return sample


class MemorySetDataset(torch.utils.data.Dataset):
    '''Create dataset from list of <np.arrays> with shape (N, C, H, W) (i.e., with N images each).

    The images at the i-th entry of [memory_sets] belong to class [i],
    unless a [target_transform] is specified (which is then applied to [i]).

    Samples are ordered per class: first all images of the 0th entry, then those of the 1st entry, etc.
    '''

    def __init__(self, memory_sets, target_transform=None):
        super().__init__()
        self.memory_sets = memory_sets
        self.target_transform = target_transform

    def __len__(self):
        # total number of images over all classes
        return sum(len(memory_set) for memory_set in self.memory_sets)

    def __getitem__(self, index):
        '''Return (image, label) of sample [index]; raises <IndexError> if [index] is out of range.'''
        n_total = len(self)
        # support negative indices in the usual way (i.e., counting from the end)
        if index < 0:
            index += n_total
        # an <IndexError> is needed for plain iteration over the dataset (`for x in dataset`) to stop
        if not 0 <= index < n_total:
            raise IndexError("index out of range for MemorySetDataset with {} samples".format(n_total))
        # find the class in which this index falls
        offset = 0
        for class_id, memory_set in enumerate(self.memory_sets):
            examples_in_this_class = len(memory_set)
            if index < (offset + examples_in_this_class):
                label = class_id if self.target_transform is None else self.target_transform(class_id)
                image = torch.from_numpy(memory_set[index - offset])
                return (image, label)
            offset += examples_in_this_class
        # (cannot be reached because of the range-check above, but kept as safeguard)
        raise IndexError("index out of range for MemorySetDataset with {} samples".format(n_total))

In [None]:
# Fill the memory buffer using class-balanced random sampling
def fill_memory_buffer(memory_sets, dataset, buffer_size_per_class, class_indeces):
    '''Append to [memory_sets], for each class in [class_indeces], an array with randomly selected images.

    Per class, at most [buffer_size_per_class] images are selected from [dataset] (without replacement).
    The list [memory_sets] is modified in-place and also returned.

    This function is rather slow and can be optimized.'''
    for class_id in class_indeces:
        # Create dataset with only instances of this class
        class_dataset = SubDataset(original_dataset=dataset, sub_labels=[class_id])

        # Randomly select which indeces to store in the buffer
        n_available = len(class_dataset)
        n_to_store = min(buffer_size_per_class, n_available)
        indeces_selected = np.random.choice(n_available, size=n_to_store, replace=False)

        # Collect the selected images and add them as one array to the list of [memory_sets]
        stored_images = [class_dataset[k][0].numpy() for k in indeces_selected]
        memory_sets.append(np.array(stored_images))

    return memory_sets

In [None]:
# Fill the memory buffer with (at most) [buffer_size_per_class] training images of the first context
# for each of the classes 0-9, and wrap the buffer in a dataset
buffer_size_per_class = 20
memory_sets = []
# The next command is unnecessarily slow, apologies! Bonus question: optimize this implementation :)
memory_sets = fill_memory_buffer(memory_sets, train_datasets[0],
                                 buffer_size_per_class=buffer_size_per_class,
                                 class_indeces=list(range(10)))
buffer_dataset = MemorySetDataset(memory_sets)

In [None]:
# (only the steps that differ from the original `train`-function are commented)
def train_replay(model, dataset, iters, lr, batch_size, current_context, buffer_dataset=None):
    '''Train [model] on [dataset] as in `train`, while also replaying data from [buffer_dataset].

    [current_context]  number of the context being trained on; the loss on the current data gets
                         weight 1/[current_context], the loss on the replayed data gets the remaining weight
    [buffer_dataset]   dataset with stored samples to replay (if None, there is no replay)'''
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
    model.train()
    progress_bar = tqdm.tqdm(range(1, iters+1))
    use_replay = buffer_dataset is not None
    batches_remaining = 0
    batches_remaining_replay = 0

    for _ in range(iters):
        optimizer.zero_grad()

        # Data from current context
        if batches_remaining==0:
            loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                 shuffle=True, drop_last=True)
            data_loader = iter(loader)
            batches_remaining = len(data_loader)
        batches_remaining -= 1
        x, y = next(data_loader)
        y_hat = model(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y, reduction='mean')
        predicted = y_hat.max(1)[1]
        accuracy = (y == predicted).sum().item()*100 / x.size(0)

        if use_replay:
            # Replay data from memory buffer (mini-batches cannot be larger than the buffer itself)
            if batches_remaining_replay==0:
                batch_size_to_use = min(batch_size, len(buffer_dataset))
                loader_replay = torch.utils.data.DataLoader(buffer_dataset, batch_size=batch_size_to_use,
                                                            shuffle=True, drop_last=True)
                data_loader_replay = iter(loader_replay)
                batches_remaining_replay = len(data_loader_replay)
            batches_remaining_replay -= 1
            x_, y_ = next(data_loader_replay)
            y_hat_ = model(x_)
            loss_replay = torch.nn.functional.cross_entropy(y_hat_, y_, reduction='mean')

            # Combine both losses to approximate the joint loss over both contexts
            # (i.e., the loss on the replayed data has weight proportional to number of contexts so far)
            rnt = 1./current_context
            total_loss = rnt*loss + (1-rnt)*loss_replay
        else:
            total_loss = loss

        total_loss.backward()
        optimizer.step()
        description = '<CLASSIFIER> | training loss: {loss:.3} | training accuracy: {prec:.3}% |'.format(
            loss=total_loss.item(), prec=accuracy
        )
        progress_bar.set_description(description)
        progress_bar.update(1)
    progress_bar.close()

In [None]:
# Train on the second context using experience replay
# (in each iteration, also a mini-batch from the memory buffer with data of the first context is used)
train_replay(model_replay, train_datasets[1], iters=iters, lr=lr, batch_size=batch_size,
             current_context=2, buffer_dataset=buffer_dataset)

In [None]:
# Evaluate the model trained with replay on the test-set of each context
# (the accuracies are kept in [replay_accs] for plotting)
print("\n Accuracy (in %) of the model on test-set of:")
replay_accs = []
for i in range(contexts):
    acc = test_acc(model_replay, test_datasets[i], test_size=None)
    print(" - Context {}: {:.1f}".format(i+1, acc))
    replay_accs.append(acc)

In [None]:
# Compare, in one bar plot per method, the per-context test accuracies after training on the second
# context with plain fine-tuning, with EWC and with replay
# NOTE(review): [xfigure] looks like a typo for [figure] (the name used for the other plots); it is
#               not used further in this cell - confirm before renaming
xfigure, axis = plt.subplots(1, 3, figsize=(12, 4))

title='Fine-tuning'
multi_context_barplot(axis[0], context2_accs, title)

title='EWC \n(lambda: {})'.format(ewc_lambda)
multi_context_barplot(axis[1], ewc_accs, title)

title='Replay \n(buffer: {} samples per class)'.format(buffer_size_per_class)
multi_context_barplot(axis[2], replay_accs, title)