# Dynamic Sparse Distributed Memory



This notebook implements the DSDM model presented in [Online Task-free Continual Learning with Dynamic Sparse Distributed Memory](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136850721.pdf).

In [85]:
from hashlib import sha256
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy
import numpy as np
import random

import pandas as pd
import pathlib
from preprocess import preprocess_text

from sklearn.neighbors import LocalOutlierFactor

import torch
import torchhd as thd
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F 
from tqdm import tqdm

In [86]:
# TODO: Move to experiment notebook.
# Set device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Vector dimension. TODO: Why was it chosen this high? Cite papers where confusion is not possible after a certain value.
dim = 2000 
n = 100000
# TODO: Might make more sense to be a field in DSDM.
cleanup = {}

In [87]:
def fix_seed():
    seed = 42
    print("[ Using Seed : ", seed, " ]")

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    numpy.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
def load_data(path, bs=0, shuffle=False):
    """Load data from file path. """
    text = pathlib.Path(path).read_text(encoding='utf-8')
    
    lines = text.splitlines()

    return lines


class SubDataset(Dataset):
    '''To sub-sample a dataset, taking only those samples with label in [sub_labels].
    After this selection of samples has been made, it is possible to transform the target-labels,
    which can be useful when doing continual learning with fixed number of output units.'''

    def __init__(self, original_dataset, sub_labels, target_transform=None, transform=None):
        super().__init__()
        self.dataset = original_dataset
        self.sub_indeces = []
        for index in range(len(self.dataset)):
            if hasattr(original_dataset, "targets"):
                if self.dataset.target_transform is None:
                    label = self.dataset.targets[index]
                else:
                    label = self.dataset.target_transform(self.dataset.targets[index])
            else:
                label = self.dataset[index][1]
            if label in sub_labels:
                self.sub_indeces.append(index)
        self.target_transform = target_transform
        self.transform=transform

    def __len__(self):
        return len(self.sub_indeces)

    def __getitem__(self, index):
        sample = self.dataset[self.sub_indeces[index]]
        if self.transform:
            sample=self.transform(sample)
        if self.target_transform:
            target = self.target_transform(sample[1])
            sample = (sample[0], target)
        return sample
    

def compute_distances_gpu(X, Y):
    return torch.sqrt(-2 * torch.mm(X, Y.T) +
                    torch.sum(torch.pow(Y, 2), dim=1) +
                    torch.sum(torch.pow(X, 2), dim=1).view(-1, 1))

In [88]:
# Class that implements a self-organizing neural network which models a DSDM.
class SONN(nn.Module):
    def __init__(self, Time_period, n_mini_batch, n_class=10, n_feat=384):
        super(SONN, self).__init__()
        self.n_feat = n_feat
        self.n_class=n_class
        self.Time_period = Time_period 
        self.ema = 2/(Time_period + 1)
        self.n_mini_batch = n_mini_batch
        self.count = 0
        self.T = 1
        self.Address = torch.zeros(1, n_feat).to(device)
        self.M = torch.zeros(1, self.n_class)
        self.p_norm = "fro"
        self.Error = torch.zeros(len(self.Address)).to(device)
        self.global_error = 0
        self.Time_period_Temperature = self.ema
        self.ema_Temperature = (2 / (self.Time_period_Temperature + 1))
        self.memory_global_error = torch.zeros(1)
        self.memory_min_distance = torch.zeros(1)
        self.memory_count_address = torch.zeros(1)
        self.dataset_name = "MNIST"
        
        self.acc_after_each_task = []
        self.acc_aft_all_task = []
        self.stock_feat = torch.tensor([]).to(device)
        self.forgetting = []
        self.N_prune = 5000 # Pruning threshold
        self.prune_mode = "balance"
        self.n_neighbors = 20
        self.contamination = "auto"
        self.pruning = False 
        self.cum_acc_activ = False
        self.batch_test = True
        
        self.reset()
        
    def apply_param(self, T, pruning, N_prune, n_neighbors, Time_period_Temperature):
        self.T = T
        self.pruning = True
        self.N_prune = N_prune
        self.n_neighbors = n_neighbors
        self.Time_period_Temperature = Time_period_Temperature
        
    def reset(self):
        self.ema = 2 / (self.Time_period + 1)
        self.ema_Temperature = (2 / (self.Time_period_Temperature + 1))
        self.count = 0
        self.Address = torch.zeros(1, self.n_feat).to(device)
        self.M = torch.zeros(1, self.n_class).to(device)
        self.Error = torch.zeros(len(self.Address)).to(device)
        self.global_error = 0
        self.memory_global_error = torch.zeros(1)
        self.memory_min_distance = torch.zeros(1)
        self.memory_count_address = torch.zeros(1)
        
    def retrieve(self, query_address, batch_test=False):
        """TODO: Add description."""
      # No gradient will be computed.
        with torch.no_grad():
            retrieved_content = torch.tensor([]).to(device)
            # Get prediction.
            if batch_test:
                pass
                # Compute distance from inputs to address space.
                #distance = compute_distances_gpu(inputs, self.Address)
                # Calculate address weight based on distance.
                #soft_norm = F.softmin(distance/self.T, dim=-1)
                # Pool weighted (come from the distance) content to get prediction.
                #pred = torch.matmul(soft_norm, self.M)
            else:
                difference = query_address - self.Address
                norm = torch.norm(difference, p=self.p_norm, dim=-1)
                soft_norm = F.softmin(norm/self.T, dim=-1)
                soft_pred = torch.matmul(soft_norm, self.M.to(device)).view(-1)
                retrieved_content = torch.sum(soft_pred.view(1, -1), 0)
        return retrieved_content   
    
    def prune(self):
        N_pruning = self.N_prune
        n_class = self.M.size(1)
        if len(self.Address) > N_pruning:
            clf = LocalOutlierFactor(n_neighbors=min(len(self.Address), self.n_neighbors), contamination=self.contamination)
            A = self.Address
            M = self.M
            y_pred = clf.fit_predict(A.cpu())
            X_scores = clf.negative_outlier_factor_
            x_scor = torch.tensor(X_scores)
            # "Naive" pruning mode.
            if self.prune_mode == "naive":
                if len(A) > N_pruning:
                    prun_N_addr = len(A) - N_pruning # No. of addresses that must be pruned out.
                    val, ind = torch.topk(x_scor, prun_N_addr) 
                    idx_remove = [True] * len(A)
                    for i in ind:
                        idx_remove[i] = False
                    self.M = self.M[idx_remove] # Delete content from address.
                    self.Address = self.Address[idx_remove] # Delete address.
            # "Balance" pruning mode.
            if self.prune_mode == "balance":
                prun_N_addr = len(A) - N_pruning # No. of addresses that must be pruned out.
                mean_addr = N_pruning // n_class
                val, ind = torch.sort(x_scor, descending=True)

                count = prun_N_addr
                idx_remove = [True] * len(A)
                idx = 0
                arg_m = torch.argmax(M, axis=1)
                N_remaining = torch.bincount(arg_m)
                while count != 0:
                    idx +=1
                    indice = ind[idx]
                    if N_remaining[arg_m[indice]] > (N_pruning // n_class):
                        N_remaining[arg_m[indice]] -= 1
                        idx_remove[ind[idx]] = False
                        count-=1
                self.M = self.M[idx_remove]
                self.Address = self.Address[idx_remove]
        
    def test(self, testloader):
      """ Test batch-wise. """
      total = 0
      correct = 0

      for batch_idx, (inputs, targets) in enumerate(testloader):
          targets = targets.type(torch.LongTensor).to(device)
          inputs = inputs.to(device)
          # Pass inputs through NN to get prediction.
          outputs = self.forward(inputs)
          _, predicted = torch.max(outputs, 1)
          total += targets.size(0)
          correct += (predicted == targets).sum().item()

      accuracy = correct / total * 100
      print("test accuracy {:.3f} %,  {:.3f} / {:.3f} ".format(accuracy, correct, total))
      return accuracy
    
    def test_idx(self, test_dataset_10_way_split, idx_test):
        with torch.no_grad():
            total = 0
            correct = 0

            for idx in idx_test:
                curr_correct = 0
                curr_total = 0
                for batch_idx, (inputs, targets) in enumerate(test_dataset_10_way_split[idx]):
                    inputs = inputs.to(device)
                    targets = targets.type(torch.LongTensor).to(device)
                    # Pass inputs through NN to get prediction.
                    outputs = self.forward(inputs)
                    _, predicted = torch.max(outputs,1)
                    total += targets.size(0)
                    corr = (predicted == targets).sum().item()
                    curr_correct +=corr
                    correct += corr
                    curr_total += targets.size(0)
            accuracy = correct / total * 100
        return accuracy, curr_correct / curr_total * 100
    
    def save(self, target_address, target_content, coef_global_error):
        """Add an item (target_address, target_content) to memory."""
        address_difference = target_address - self.Address # Difference tensor
        # Distance tensor, where the distance is the Frobenius norm of the difference. 
        norm = torch.norm(address_difference, p=self.p_norm, dim=-1)
        # Will be later fed as input to the softmin layer.
        soft_norm = norm
        # Get the minimum distance and the corresponding address index.  
        min_distance = torch.min(norm, dim=0)[0].item()
        # Adjust parameter based on the minimum distance..
        self.global_error += self.ema_Temperature * (min_distance - self.global_error)

        # Check if the minimum distance is bigger than the adaptive threshold.
        # If the minimum distance is bigger, then:
        
        if min_distance >= self.global_error * coef_global_error:
            # Add a new entry to the address matrix/tensor equal to the target address.
            self.Address = torch.cat((self.Address, target_address.view(1, -1)))
            # Add a new entry to the content matrix/tensor equal to the target content.
            self.M = torch.cat((self.M, target_content.view(1, -1)))
        # If the minimum distance is not bigger, then:
        else:
            # Apply the softmin function to the distance tensor the get a list of weights, that can be interpretated as probabilities.
            soft_norm = F.softmin(norm/self.T, dim=-1)
            # Weight the address difference by the corresponding address softmin weight and modify all the existing addresses. 
            self.Address += self.ema * torch.mul(soft_norm.view(-1, 1), address_difference)
            # Weight the content difference by the corresponding address softmin weight and modify all the existing content entries. 
            self.M += self.ema * torch.mul(soft_norm.view(-1, 1), (target_content - self.M))
        
        return

    def generate_and_save_chunks(self, tokens, coef_global_error):
        # TODO: Move to the experiment notebook once you've figured out how to make DSDM as modular as possible.
        """TODO: Add function description."""
        # Generate 1-token chunks.
        for token in tokens:
            # Check if the chunk has been encountered before by querying the cleanup memory.
            entry = cleanup.get(token)
            # If it has not, then:
            if entry == None:
                # Generate a random HRR HV representation for the token.
                val = thd.HRRTensor.random(1, dim)[0]
                # Add val, key, and token to store.
                cleanup[token] = {'val': val, 'trans': token}
            # If it has, then:
            else:
                val = entry['val']
            # Add chunk to the DSDM in an autoassociative manner.
            self.save(val, val, coef_global_error)

        # "n" represents the no. of tokens in the sentence, which is also the max. no. of tokens 
        # that can be grouped to form a chunk.
        n = len(tokens)

        for no_tokens in range(n + 1)[2:]:
          print("no. of tokens: ", no_tokens)
          for i in range(n):
            print("start index: ", i)
            # If there are not enough tokens left to construct a chunk comprised of "no_tokens", break. 
            if i + no_tokens > len(tokens):
              print("Not enough tokens left.")
              break 
            val = thd.HRRTensor.empty(1, dim)[0]
            _ = ""
            # Construct val.
            for j in range(no_tokens):
              print(tokens[i + j])
              _ += tokens[i + j] 
              _ += " "
              val += thd.permute(cleanup[tokens[i + j]]['val'], shifts=no_tokens - j - 1) # TODO: you need the original position in the sentence.
            # Check to see if val has been encountered before or not.
            store_key = sha256(''.join([str(elem) for elem in val.tolist()]).encode('utf-8')).hexdigest()
            if cleanup.get(store_key) == None:
              # Add values to the cleanup memory.
              cleanup[store_key] = {'val': val, 'trans': _}

           # Add the chunk representation to DSDM.
          self.save(val, val, coef_global_error)

        return
    
    def train__test_n_way_split(self, train_set, test_set, coef_global_error=1, ema_global_error=None, save_feat=False):
        """TODO: Add description."""
        # Sentence processing train loop. 
        for sentence in train_set:
            # Generate chunks from the sentence and add them to DSDM.
            self.generate_and_save_chunks(sentence, coef_global_error)

            # Prune memory.
            #if self.pruning:
            #    self.prune()

            return 
        
    def grid_search_spread_factor(self, Time_period, n_mini_batch, train_set, test_set, N_try=1, ema_global_error="same", coef_global_error=1, random_ordering=False):
        """Search for the best spread factor."""
        # Instantiate array with a length equal to the number of trials.
        N_address_use = torch.zeros(N_try)
        self.forgetting = []
        
        self.n_mini_batch = n_mini_batch
        self.Time_period = Time_period 
        self.ema = 2 / (Time_period + 1)

        for idx_try in tqdm(range(N_try)):
            # Reset DSDM parameters.
            self.reset()

            # Get train and test accuracy for current trial.
            self.train__test_n_way_split(train_set,
                                         test_set,
                                         ema_global_error=ema_global_error,
                                         coef_global_error=coef_global_error)
            # Number of generated addresses.
            N_address_use[idx_try] = self.M.size(0)

            # Shuffle the data randomly for a new trail.
            if random_ordering:
                dataset_shuffle = list(zip(train_dataset_10_way_split, test_dataset_10_way_split))
                random.shuffle(dataset_shuffle)
                train_dataset_10_way_split, test_dataset_10_way_split = zip(*dataset_shuffle)
          
        return 

### Run experiment

In [89]:
# Load data.
lines_raw = load_data('../data/data.txt')

# Preprocess input. 
lines = []
for line_raw in lines_raw:
    lines.append(preprocess_text(line_raw))

nprune = [0] #TODO: [1000, 2000, 5000, 10000]
for i in nprune:
    N_try = 5 
    n_mini_batch = 55 
    alpha = 1
    Time_period = 500
    Time_period_temperature = 150

    # Instantiate DSDM instance.
    sonn = SONN(Time_period, n_mini_batch, dim, n_feat=dim)
    sonn.n_neighbors = 1000
    sonn.contamination = "auto"
    sonn.p_norm = "fro"
    sonn.T = 2.3
    sonn.pruning = True
    sonn.N_prune = i
    sonn.cum_acc_activ = True
    sonn.Time_period_Temperature = Time_period_temperature
    
    # Flush cleanup memory.
    cleanup = {}
    # Train and test DSDM.
    sonn.grid_search_spread_factor(Time_period,
                                   n_mini_batch,
                                   lines,
                                   lines,
                                   N_try,
                                   ema_global_error="diff",
                                   coef_global_error=alpha)

 40%|████      | 2/5 [00:00<00:00, 17.87it/s]

no. of tokens:  2
start index:  0
the
red
start index:  1
red
house
start index:  2
house
is
start index:  3
is
big
start index:  4
Not enough tokens left.
no. of tokens:  3
start index:  0
the
red
house
start index:  1
red
house
is
start index:  2
house
is
big
start index:  3
Not enough tokens left.
no. of tokens:  4
start index:  0
the
red
house
is
start index:  1
red
house
is
big
start index:  2
Not enough tokens left.
no. of tokens:  5
start index:  0
the
red
house
is
big
start index:  1
Not enough tokens left.
no. of tokens:  2
start index:  0
the
red
start index:  1
red
house
start index:  2
house
is
start index:  3
is
big
start index:  4
Not enough tokens left.
no. of tokens:  3
start index:  0
the
red
house
start index:  1
red
house
is
start index:  2
house
is
big
start index:  3
Not enough tokens left.
no. of tokens:  4
start index:  0
the
red
house
is
start index:  1
red
house
is
big
start index:  2
Not enough tokens left.
no. of tokens:  5
start index:  0
the
red
house
is
bi

100%|██████████| 5/5 [00:00<00:00, 18.92it/s]

start index:  2
Not enough tokens left.
no. of tokens:  5
start index:  0
the
red
house
is
big
start index:  1
Not enough tokens left.
no. of tokens:  2
start index:  0
the
red
start index:  1
red
house
start index:  2
house
is
start index:  3
is
big
start index:  4
Not enough tokens left.
no. of tokens:  3
start index:  0
the
red
house
start index:  1
red
house
is
start index:  2
house
is
big
start index:  3
Not enough tokens left.
no. of tokens:  4
start index:  0
the
red
house
is
start index:  1
red
house
is
big
start index:  2
Not enough tokens left.
no. of tokens:  5
start index:  0
the
red
house
is
big
start index:  1
Not enough tokens left.





In [91]:
def generate_query(tokens: list):
  n = len(tokens)
  val = thd.HRRTensor.empty(1, dim)

  for i in range(n):
    # The token hasn't been encountered before.
    if cleanup.get(tokens[i]) == None:
        # Generate a random value for the unencountered token.
        val += thd.HRRTensor.permute(thd.HRRTensor.random(1, dim), shifts=n - i - 1)
    # The token has been encountered before.
    else:
        val += thd.permute(cleanup[tokens[i]]['val'], shifts=n - i - 1)

    return val

In [95]:
retrieved_content = sonn.retrieve(generate_query(preprocess_text("She likes.")))

sims_df = pd.DataFrame(columns=['chunk', 'sim'])

for key, item in cleanup.items():
  sims_df = pd.concat([sims_df, pd.DataFrame([{'chunk': cleanup[key]['trans'], 'sim': thd.cosine_similarity(cleanup[key]['val'],  retrieved_content).item()}])])

display(sims_df.sort_values('sim', ascending=False))

Unnamed: 0,chunk,sim
0,red house is big,0.88123
0,house is big,0.878236
0,the red house is big,0.834457
0,is big,0.8301
0,big,0.671697
0,the,0.178868
0,house,0.15995
0,is,0.140828
0,red,0.131422
0,the red house,0.124279
