<a href="https://colab.research.google.com/github/jonyghosh444/transformer-res-ger/blob/master/DSDM_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install --upgrade equinox

Collecting equinox
  Downloading equinox-0.12.2-py3-none-any.whl.metadata (18 kB)
Collecting jaxtyping>=0.2.20 (from equinox)
  Downloading jaxtyping-0.3.2-py3-none-any.whl.metadata (7.0 kB)
Collecting wadler-lindig>=0.1.0 (from equinox)
  Downloading wadler_lindig-0.1.6-py3-none-any.whl.metadata (17 kB)
Downloading equinox-0.12.2-py3-none-any.whl (177 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m177.2/177.2 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxtyping-0.3.2-py3-none-any.whl (55 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.4/55.4 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading wadler_lindig-0.1.6-py3-none-any.whl (20 kB)
Installing collected packages: wadler-lindig, jaxtyping, equinox
Successfully installed equinox-0.12.2 jaxtyping-0.3.2 wadler-lindig-0.1.6


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Importing Libraries

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap, pmap
import equinox as eqx
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import collections
import time
from torchvision.datasets import MNIST
from torchvision import transforms

# MNIST DATA

In [None]:
def load_train_test_data():

    train = MNIST(root="./train_data", train=True, transform=transforms.ToTensor(), target_transform=None, download=True)
    test = MNIST(root="./test_data", train=False, transform=transforms.ToTensor(), target_transform=None, download=True)

    train_data = []
    print("Start Preparing Train Data")
    for image, label in tqdm(train):
        image = jnp.array(image.view(-1).numpy())
        train_data.append([image, label])

    test_data = []
    print("Start Preparing Test Data")
    for image, label in tqdm(test):
        image = jnp.array(image.view(-1).numpy())
        test_data.append([image, label])

    train_labels = [label for _, label in train_data]
    test_labels = [label for _, label in test_data]
    print("Train label distribution:", collections.Counter(train_labels))
    print("Test label distribution:", collections.Counter(test_labels))

    return train_data, test_data

# Loading MNIST Train and Test Data

In [None]:
train_data, test_data = load_train_test_data()

100%|██████████| 9.91M/9.91M [00:00<00:00, 11.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 409kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.19MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.90MB/s]
100%|██████████| 9.91M/9.91M [00:00<00:00, 12.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 344kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.21MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.44MB/s]

Start Preparing Train Data





  0%|          | 0/60000 [00:00<?, ?it/s]

Start Preparing Test Data


  0%|          | 0/10000 [00:00<?, ?it/s]

Train label distribution: Counter({1: 6742, 7: 6265, 3: 6131, 2: 5958, 9: 5949, 0: 5923, 6: 5918, 8: 5851, 4: 5842, 5: 5421})
Test label distribution: Counter({1: 1135, 2: 1032, 7: 1028, 3: 1010, 9: 1009, 4: 982, 0: 980, 8: 974, 6: 958, 5: 892})


# Splitting Train and Test Data based on Classes.

In [None]:
def split_dataset_based_on_class(dataset, splitted_labels, num_data = "all"):

    splitted_dataset_feat = []
    splitted_dataset_labels = []

    print('Start Making Splitted Dataset')
    for sub_labels in tqdm(splitted_labels):
        print(f"Creating Dataset for label {sub_labels}")
        sub_dataset_feat = []
        sub_dataset_labels = []
        for input, label in dataset:
            if num_data != "all" and len(sub_dataset_feat) > num_data:
                break
            if label in sub_labels:

                sub_dataset_feat.append(input)
                sub_dataset_labels.append(label)

        splitted_dataset_feat.append(jnp.array(sub_dataset_feat))
        splitted_dataset_labels.append(jnp.array(sub_dataset_labels))

    return splitted_dataset_feat, splitted_dataset_labels

In [None]:
# You can give num_data = "all" or any integer numerical value upto 60000 for train
splitted_train_feat, splitted_train_targets = split_dataset_based_on_class(train_data, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], num_data = "all")

Start Making Splitted Dataset


  0%|          | 0/5 [00:00<?, ?it/s]

Creating Dataset for label [0, 1]
Creating Dataset for label [2, 3]
Creating Dataset for label [4, 5]
Creating Dataset for label [6, 7]
Creating Dataset for label [8, 9]


In [None]:
# You can give num_data = "all" or any integer numerical value upto 10000 for test
splitted_test_feat, splitted_test_targets = split_dataset_based_on_class(test_data, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], num_data = "all")

Start Making Splitted Dataset


  0%|          | 0/5 [00:00<?, ?it/s]

Creating Dataset for label [0, 1]
Creating Dataset for label [2, 3]
Creating Dataset for label [4, 5]
Creating Dataset for label [6, 7]
Creating Dataset for label [8, 9]


# Local Outlier Factor coded in JAX

In [None]:
class LocalOutlierFactorJAX(eqx.Module):
    k: int

    def __init__(self, n_neighbors):

        self.k = n_neighbors

    def _pairwise_distances(self, X):
        diffs = jnp.expand_dims(X, 1) - jnp.expand_dims(X, 0)
        return jnp.sqrt(jnp.sum(diffs ** 2, axis=-1))

    def _k_neighbors(self, dists):
        n = dists.shape[0]
        dists_no_self = dists + jnp.eye(n) * 1e10
        return jnp.argsort(dists_no_self, axis=1)[:, :self.k]

    def _reachability_distance(self, dists, neighbors, k_distances):
        def single_point_reach(i):
            nbrs = neighbors[i]
            reach_d = jnp.maximum(k_distances[nbrs], dists[i, nbrs])
            return jnp.mean(reach_d)
        return jax.vmap(single_point_reach)(jnp.arange(dists.shape[0]))

    def _local_reachability_density(self, reach_dists):
        return 1.0 / reach_dists

    def _lof_score(self, lrd, neighbors):
        def single_lof(i):
            nbrs = neighbors[i]
            return jnp.sum(lrd[nbrs]) / (self.k * lrd[i])
        return jax.vmap(single_lof)(jnp.arange(len(lrd)))

    @eqx.filter_jit
    def __call__(self, X):
        dists = self._pairwise_distances(X)
        neighbors = self._k_neighbors(dists)
        k_dists = jnp.take_along_axis(dists, neighbors, axis=1)[:, -1]
        reach_d = self._reachability_distance(dists, neighbors, k_dists)
        lrd = self._local_reachability_density(reach_d)
        lof = self._lof_score(lrd, neighbors)
        return lof

# Naive Pruning Class

In [None]:
import equinox as eqx
import jax
import jax.numpy as jnp
import functools
class NaivePruning(eqx.Module):
    Q: int = eqx.static_field()
    n_neighbors : int = eqx.static_field()
    K : int = eqx.static_field()
    N_prune : int = eqx.static_field()

    def __init__(self, Q, K, n_neighbors = 20):
        self.Q = Q
        self.n_neighbors = n_neighbors
        self.K = K
        self.N_prune = K - Q

    @eqx.filter_jit
    def __call__(self, A, C):
        #N_prune = K - self.Q
        address = A[:self.K]
        # Assume LocalOutlierFactorJAX works like a function
        clf = LocalOutlierFactorJAX(n_neighbors=self.n_neighbors)
        lof_scores = clf(address)

        indices_to_include = jnp.argsort(-lof_scores)[:self.N_prune]
        new_A = A[indices_to_include]
        new_C = C[indices_to_include]

        # Padding if needed
        pad_A = A.shape[0] - new_A.shape[0]
        pad_C = C.shape[0] - new_C.shape[0]
        padding_A = jnp.zeros((pad_A, A.shape[1]))
        padding_C = jnp.zeros((pad_C, C.shape[1]))
        new_A = jnp.vstack([new_A, padding_A])
        new_C = jnp.vstack([new_C, padding_C])

        new_K = self.Q
        return new_A, new_C, new_K

# Balance Pruning Class

In [None]:
class BalancePruning(eqx.Module):
    Q: int = eqx.static_field()
    buffer: int = eqx.static_field()
    n_class: int = eqx.static_field()
    n_neighbors: int = eqx.static_field()

    def __init__(self, Q, buffer, n_class, n_neighbors=20):
        self.Q = Q
        self.buffer = buffer
        self.n_class = n_class
        self.n_neighbors = n_neighbors

    def __call__(self, A, C, K):
        return self._prune(A, C, K)

    @eqx.filter_jit
    def _prune(self, A, C, K):
        N_prune = K - self.Q
        mean_prune = N_prune // self.n_class

        # Get LOF scores and sort indices
        clf = LocalOutlierFactorJAX(n_neighbors=self.n_neighbors)
        lof_scores = clf(A)
        indices_to_include = jnp.argsort(lof_scores)

        # Get class labels
        classes = jnp.argmax(C, axis=1)

        # Count how many to keep per class
        class_counts = jnp.bincount(classes, length=self.n_class)

        class_keep_counts = jnp.where(class_counts >= mean_prune, class_counts - mean_prune, class_counts)  # how many to keep


        # Initialize outputs
        new_A = jnp.zeros_like(A)
        new_C = jnp.zeros_like(C)
        counts = jnp.zeros((self.n_class,), dtype=jnp.int32)

        def insert_data(carry, idx):
            new_A, new_C, counts = carry
            current_class = classes[idx].astype(jnp.int32)
            keep_limit = class_keep_counts[current_class]

            def place():
                current_idx = jnp.sum(counts).astype(jnp.int32)
                new_A_ = new_A.at[current_idx].set(A[idx])
                new_C_ = new_C.at[current_idx].set(C[idx])
                counts_ = counts.at[current_class].add(1)
                return new_A_, new_C_, counts_

            def skip():
                return new_A, new_C, counts

            cond = counts[current_class] < keep_limit
            return jax.lax.cond(cond, place, skip), None

        (new_A, new_C, counts), _ = jax.lax.scan(insert_data, (new_A, new_C, counts), indices_to_include)

        # Total new size
        new_K = jnp.sum(class_keep_counts).astype(jnp.int32)

        return new_A, new_C, new_K, mean_prune

# Slicer

In [None]:
class Slice(eqx.Module):

    num_rows : int = eqx.static_field()

    def __init__(self, num_rows):
        self.num_rows = num_rows

    def __call__(self, X):
        return X[:self.num_rows]

# DSDM IN JAX

In [None]:
class FASTER_DSDM(eqx.Module):
    Address: jnp.ndarray
    Content: jnp.ndarray
    n_feat: int = eqx.static_field()
    n_class: int = eqx.static_field()
    RT: float
    K: int
    Q: int = eqx.static_field()
    buffer: int = eqx.static_field()
    beta: float
    Lambda : float
    Lambda_RT: float
    prune_method: int = eqx.static_field()
    n_neighbors: int = eqx.static_field()
    contamination: float = eqx.static_field()

    def __init__(self, RT=0, Q=100, buffer = 100, beta=0.022, Lambda=0.01, Lambda_RT=0.01,
                 n_feat=784, n_class=10, prune_method=0, n_neighbors=1000, contamination=0.1):
        self.n_feat = n_feat
        self.n_class = n_class
        self.Q = Q
        self.buffer = buffer
        self.Address = jnp.zeros((self.Q + self.buffer, self.n_feat))
        self.Content = jnp.zeros((self.Q + self.buffer, self.n_class))
        self.beta = beta
        self.RT = RT
        self.K = 0
        self.Lambda = Lambda
        self.Lambda_RT = Lambda_RT
        self.prune_method = prune_method
        self.n_neighbors = n_neighbors
        self.contamination = contamination


    def prune(self, address, content, K, RT):
        naive_pruning = NaivePruning(
            Q=self.Q,
            K = K,
            #buffer=self.buffer,
            #n_class=self.n_class,
            n_neighbors=self.n_neighbors,
        )
        #new_address, new_content, new_K, mean_prune = naive_pruning(address, content, K)
        new_address, new_content, new_K = naive_pruning(address, content)
        #print(mean_prune)
        new_RT = RT + 0.0  # dummy update

        return eqx.tree_at(
            lambda model: (model.Address, model.Content, model.K, model.RT),
            self,
            (new_address, new_content, new_K, new_RT)
        )



    def return_same_model(self, address, content, K, RT):
        return eqx.tree_at(lambda model : (model.Address, model.Content, model.K, model.RT),
                           self,
                           (address, content, K, RT))

    #Calculating raw distance
    def calculate_distance(self, x, y):
        return x - y
    #Calculating Eucledian Norm
    def calculate_norm(self, distance):
        return jnp.linalg.norm(distance, ord=2)

    #Intiating Address and Content Matrix for the first Datapoint
    def initiate_address_and_content(self, input_x_and_target_y):
        input_x, target_y = input_x_and_target_y
        new_address = self.Address.at[self.K].set(input_x)
        one_hot_encoded_target = jax.nn.one_hot(target_y, num_classes=self.n_class)
        new_content = self.Content.at[self.K].set(one_hot_encoded_target)
        new_K = jnp.array(self.K + 1, dtype=jnp.int32)
        new_RT = self.RT + 0.0
        return new_address, new_content, new_K, new_RT

    #Adding new node to the Address and Content Matrix if minimum distance of the datapoint is greater than RT(Recursive Temperature) Parameter
    def add_new_node(self, input_x_and_target_y, min_BMU_distance):
        input_x, target_y = input_x_and_target_y
        new_address = self.Address.at[self.K].set(input_x)
        one_hot_encoded_target = jax.nn.one_hot(target_y, num_classes=self.n_class)
        new_content = self.Content.at[self.K].set(one_hot_encoded_target)
        new_K = jnp.array(self.K + 1, dtype=jnp.int32)
        #new_RT = self.Lambda_RT * self.RT + (1 - self.Lambda_RT) * min_BMU_distance
        new_RT = (1 - self.Lambda_RT) * self.RT + self.Lambda_RT * min_BMU_distance
        return new_address, new_content, new_K, new_RT

    # Modifying Existing nodes in Address and Content Matrix using soft norm
    def modify_existing_nodes(self, input_x_and_target_y, min_BMU_distance, all_BMU_distances):
        input_x, target_y = input_x_and_target_y
        soft_norm = jax.nn.softmax(-all_BMU_distances / self.beta, axis=-1)
        soft_norm_reshaped = jnp.expand_dims(soft_norm, axis=-1)
        address_diff = jax.vmap(self.calculate_distance, in_axes = (None, 0))(input_x, self.Address)

        new_address = self.Address + self.Lambda * soft_norm_reshaped * (address_diff)
        one_hot_encoded_target = jax.nn.one_hot(target_y, num_classes=self.n_class)
        content_diff = jax.vmap(self.calculate_distance, in_axes = (None, 0))(one_hot_encoded_target, self.Content)
        new_content = self.Content + self.Lambda * soft_norm_reshaped * content_diff
        new_K = jnp.array(self.K, dtype=jnp.int32)
        new_RT = (1 - self.Lambda_RT) * self.RT + self.Lambda_RT * min_BMU_distance
        return new_address, new_content, new_K, new_RT

    # Calculaing BMU distances of the datapoint from all the nodes in Address Matrix and Deciding whether to add a new node or modify the existing nodes
    def add_or_modify_node(self, input_x_and_target_y):
        input_x, target_y = input_x_and_target_y
        distances = jax.vmap(self.calculate_distance, in_axes=(None, 0))(input_x, self.Address)
        all_BMU_distances = jax.vmap(self.calculate_norm, in_axes=0)(distances)
        min_BMU_distance = jnp.min(all_BMU_distances)

        new_address, new_content, new_K, new_RT = jax.lax.cond(
            abs(min_BMU_distance) > self.RT,
            lambda: self.add_new_node(input_x_and_target_y, min_BMU_distance),
            lambda: self.modify_existing_nodes(input_x_and_target_y, min_BMU_distance, all_BMU_distances)
        )

        return new_address, new_content, new_K, new_RT

    # Forward pass is initiating the Content and Address Matrix if K(number of current Datapoints) == 0 else it is calling add_or_modify_function
    def forward(self, input_x_and_target_y):
        new_address, new_content, new_K, new_RT = jax.lax.cond(
            self.K == 0,
            lambda: self.initiate_address_and_content(input_x_and_target_y),
            lambda: self.add_or_modify_node(input_x_and_target_y)
        )
        return eqx.tree_at(
            lambda model: (model.Address, model.Content, model.K, model.RT),
            self,
            (new_address, new_content, new_K, new_RT)
        ), None

    # This is the Training Loop
    @eqx.filter_jit
    def train(self, inputs, targets):
        #Doing A single Batch Calculation
        def scan_fn(model, batch_x_and_y):
            # Do Training for a single Datapoint in a Batch
            def inner_scan_fn(m, x_and_y):
                new_model, _ = m.forward(x_and_y)

                return new_model, None

            new_model, _ = jax.lax.scan(inner_scan_fn, model, batch_x_and_y)

            return new_model, None

        new_model, _ = jax.lax.scan(scan_fn, self, (inputs, targets))
        return new_model

    #This is the Inference Loop predicting the class for each test datapoint in batch
    def inference(self, inputs, K):
        # Convert to static int for eqx.static_field
        K_int = int(K) if isinstance(K, jax.Array) else K

        address_slicer = Slice(num_rows=K_int)
        content_slicer = Slice(num_rows=K_int)

        @eqx.filter_jit
        def _jit_infer(self, inputs, address_slicer, content_slicer):
            sliced_address = address_slicer(self.Address)
            sliced_content = content_slicer(self.Content)

            def batch_infer(carry, batch_x):
                def single_infer(input_x):
                    distances = jax.vmap(self.calculate_distance, in_axes=(None, 0))(input_x, sliced_address)
                    norm = jax.vmap(self.calculate_norm, in_axes=0)(distances)
                    soft_norm = jax.nn.softmax(-norm / self.beta, axis=-1)
                    return jnp.argmax(jnp.matmul(soft_norm, sliced_content))

                predictions = jax.vmap(single_infer)(batch_x)
                return predictions, predictions

            init_carry = jnp.zeros((inputs.shape[1],), dtype=jnp.int32)
            predictions, _ = jax.lax.scan(batch_infer, init_carry, inputs)
            return predictions

        return _jit_infer(self, inputs, address_slicer, content_slicer)


# Training MNIST Raw pixels with pruning

In [None]:
dsdm = FASTER_DSDM(RT=0, Q=5000, buffer = 12000, beta=0.5, Lambda=0.0022, Lambda_RT=0.0025,
                 n_feat=784, n_class=10, prune_method=0, n_neighbors=20, contamination=0)


class_seen = []
train_data_seen = []
test_data_seen = []
start = time.time()
print("DATASET : MNIST")
for input_x, target_y, test_x, test_y in tqdm(list(zip(splitted_train_feat, splitted_train_targets, splitted_test_feat, splitted_test_targets))):
    unique_classes = list(set(target_y.tolist()))
    class_seen += unique_classes
    input_x = jnp.expand_dims(input_x, axis = 0)
    target_y = jnp.expand_dims(target_y, axis = 0)
    test_x = jnp.expand_dims(test_x, axis = 0)
    test_y = jnp.expand_dims(test_y, axis = 0)
    train_data_seen.append((input_x, target_y))
    test_data_seen.append((test_x, test_y))
    dsdm = dsdm.train(input_x, target_y)

    if dsdm.K > dsdm.Q:
        print("Nodes before pruning, ", dsdm.K)
        dsdm = dsdm.prune(dsdm.Address, dsdm.Content, dsdm.K, dsdm.RT)
        print("Nodes after pruning ", dsdm.K)

    train_accuracy = 0
    train_total_data = 0
    for train_x, train_y in train_data_seen:
        train_pred = dsdm.inference(train_x, dsdm.K)
        train_accuracy += (train_pred == train_y).sum().item()
        train_total_data += train_y.shape[1]

    print(f"Train Accuracy:{(train_accuracy / train_total_data) * 100} % after seeing class {class_seen}")

    test_accuracy = 0
    test_total_data = 0
    for test_x, test_y in test_data_seen:
        test_pred = dsdm.inference(test_x, dsdm.K)
        test_accuracy += (test_pred == test_y).sum().item()
        test_total_data += test_y.shape[1]

    print(f"Test Accuracy:{(test_accuracy / test_total_data * 100)} % after seeing class {class_seen}")
    print("===========================================================================================")

end = time.time()
print(f"Total Training Time required : {end - start}")


DATASET : MNIST


  0%|          | 0/5 [00:00<?, ?it/s]

Nodes before pruning,  6232


  naive_pruning = NaivePruning(


Nodes after pruning  5000
Train Accuracy:99.9526253454402 % after seeing class [0, 1]
Test Accuracy:99.90543735224587 % after seeing class [0, 1]
Nodes before pruning,  10977
Nodes after pruning  5000
Train Accuracy:99.42635533651128 % after seeing class [0, 1, 2, 3]
Test Accuracy:99.3504931440943 % after seeing class [0, 1, 2, 3]
Nodes before pruning,  10155
Nodes after pruning  5000
Train Accuracy:88.55818085903879 % after seeing class [0, 1, 2, 3, 4, 5]
Test Accuracy:87.16630741170619 % after seeing class [0, 1, 2, 3, 4, 5]
Nodes before pruning,  10357
Nodes after pruning  5000
Train Accuracy:85.93775933609959 % after seeing class [0, 1, 2, 3, 4, 5, 6, 7]
Test Accuracy:84.60770861918422 % after seeing class [0, 1, 2, 3, 4, 5, 6, 7]
Nodes before pruning,  10458
Nodes after pruning  5000
Train Accuracy:71.58166666666666 % after seeing class [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Test Accuracy:70.04 % after seeing class [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Total Training Time required : 114.77384710

# Training on MNIST raw pixels without pruning

In [None]:
dsdm = FASTER_DSDM(RT=0, Q=30000, buffer = 5000, beta=0.5, Lambda=0.0022, Lambda_RT=0.0025,
                 n_feat=784, n_class=10, prune_method=0, n_neighbors=20, contamination=0)


class_seen = []
train_data_seen = []
test_data_seen = []
start = time.time()
print("DATASET : MNIST")
for input_x, target_y, test_x, test_y in tqdm(list(zip(splitted_train_feat, splitted_train_targets, splitted_test_feat, splitted_test_targets))):
    unique_classes = list(set(target_y.tolist()))
    class_seen += unique_classes
    input_x = jnp.expand_dims(input_x, axis = 0)
    target_y = jnp.expand_dims(target_y, axis = 0)
    test_x = jnp.expand_dims(test_x, axis = 0)
    test_y = jnp.expand_dims(test_y, axis = 0)
    train_data_seen.append((input_x, target_y))
    test_data_seen.append((test_x, test_y))
    dsdm = dsdm.train(input_x, target_y)


    print("Current Nodes : ", dsdm.K)

    train_accuracy = 0
    train_total_data = 0
    for train_x, train_y in train_data_seen:
        train_pred = dsdm.inference(train_x, dsdm.K)
        train_accuracy += (train_pred == train_y).sum().item()
        train_total_data += train_y.shape[1]

    print(f"Train Accuracy:{(train_accuracy / train_total_data) * 100} % after seeing class {class_seen}")

    test_accuracy = 0
    test_total_data = 0
    for test_x, test_y in test_data_seen:
        test_pred = dsdm.inference(test_x, dsdm.K)
        test_accuracy += (test_pred == test_y).sum().item()
        test_total_data += test_y.shape[1]

    print(f"Test Accuracy:{(test_accuracy / test_total_data * 100)} % after seeing class {class_seen}")
    print("===========================================================================================")

end = time.time()
print(f"Total Training Time required : {end - start}")

DATASET : MNIST


  0%|          | 0/5 [00:00<?, ?it/s]

Current Nodes :  6230
Train Accuracy:100.0 % after seeing class [0, 1]
Test Accuracy:99.90543735224587 % after seeing class [0, 1]
Current Nodes :  12208
Train Accuracy:99.15973176052356 % after seeing class [0, 1, 2, 3]
Test Accuracy:97.83497714698099 % after seeing class [0, 1, 2, 3]
Current Nodes :  17382
Train Accuracy:99.01713079934476 % after seeing class [0, 1, 2, 3, 4, 5]
Test Accuracy:97.2972972972973 % after seeing class [0, 1, 2, 3, 4, 5]
Current Nodes :  22730
Train Accuracy:98.91078838174275 % after seeing class [0, 1, 2, 3, 4, 5, 6, 7]
Test Accuracy:96.50742172882624 % after seeing class [0, 1, 2, 3, 4, 5, 6, 7]
Current Nodes :  28133
Train Accuracy:98.00999999999999 % after seeing class [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Test Accuracy:94.67999999999999 % after seeing class [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Total Training Time required : 214.70278644561768


train by using Encoded Data

CORE 50

# Loading Core50_resnet18_224.npz

In [None]:
import numpy as np

data = np.load("/content/drive/MyDrive/DSDM/dataset/Core50_resnet18_224.npz")
train_x = data["traindata"]
train_y = data["trainlabel"]
test_x  = data["testdata"]
test_y  = data["label_test"]

In [None]:
import jax.numpy as jnp

train_x = jnp.array(train_x)
train_y = jnp.array(train_y)
test_x  = jnp.array(test_x)
test_y  = jnp.array(test_y)

In [None]:
def split_dataset_by_class(X, Y, class_splits):
    split_x = []
    split_y = []
    for cls_group in class_splits:
        idxs = jnp.isin(Y, jnp.array(cls_group))
        split_x.append(X[idxs])
        split_y.append(Y[idxs])
    return split_x, split_y

In [None]:
class_splits = [list(range(i, i+10)) for i in range(0, 50, 10)]
splitted_train_x, splitted_train_y = split_dataset_by_class(train_x, train_y, class_splits)
splitted_test_x, splitted_test_y = split_dataset_by_class(test_x, test_y, class_splits)

Train DSDM

In [None]:
dsdm = FASTER_DSDM(
    RT=0,
    Q=5000,
    buffer = 15000,
    beta=0.5,
    Lambda=0.0022,
    Lambda_RT=0.0025,
    n_feat=512,
    n_class=50,
    prune_method=0,
    n_neighbors=20,
    contamination=0.0,
)

In [None]:
train_data_seen = []
test_data_seen = []
class_seen = []

# Training on Core 50 Resnet18 Encoded with pruning

In [None]:
for train_x_batch, train_y_batch, test_x_batch, test_y_batch in tqdm(
    list(zip(splitted_train_x, splitted_train_y, splitted_test_x, splitted_test_y)),
    desc="Class-incremental Training"
):

    unique_classes = list(set(train_y_batch.tolist()))
    class_seen += unique_classes


    train_x_batch = jnp.expand_dims(train_x_batch, axis=0)
    train_y_batch = jnp.expand_dims(train_y_batch, axis=0)
    test_x_batch = jnp.expand_dims(test_x_batch, axis=0)
    test_y_batch = jnp.expand_dims(test_y_batch, axis=0)

    train_data_seen.append((train_x_batch, train_y_batch))
    test_data_seen.append((test_x_batch, test_y_batch))


    dsdm = dsdm.train(train_x_batch, train_y_batch)


    if dsdm.K > dsdm.Q:
        print("Nodes before pruning : ", dsdm.K)
        dsdm = dsdm.prune(dsdm.Address, dsdm.Content, dsdm.K, dsdm.RT)
        print("Nodes after pruning : ", dsdm.K)

    correct = 0
    total = 0
    for tx, ty in train_data_seen:
        preds = dsdm.inference(tx, dsdm.K)
        correct += (preds == ty).sum().item()
        total += ty.shape[1]
    train_acc = (correct / total) * 100
    print(f"Train Accuracy: {train_acc:.2f}% after classes {class_seen}")


    correct = 0
    total = 0
    for tx, ty in test_data_seen:
        preds = dsdm.inference(tx, dsdm.K)
        correct += (preds == ty).sum().item()
        total += ty.shape[1]
    test_acc = (correct / total) * 100
    print(f"Test Accuracy: {test_acc:.2f}% after classes {class_seen}")
    print("=" * 80)

Class-incremental Training:   0%|          | 0/5 [00:00<?, ?it/s]

Nodes before pruning :  11433


  naive_pruning = NaivePruning(


Nodes after pruning :  5000
Train Accuracy: 89.20% after classes [1, 2, 3, 4, 5, 6, 7, 8, 9]
Test Accuracy: 88.98% after classes [1, 2, 3, 4, 5, 6, 7, 8, 9]
Nodes before pruning :  17514
Nodes after pruning :  5000
Train Accuracy: 78.04% after classes [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
Test Accuracy: 78.09% after classes [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
Nodes before pruning :  17646
Nodes after pruning :  5000
Train Accuracy: 69.16% after classes [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
Test Accuracy: 69.27% after classes [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
Nodes before pruning :  17719
Nodes after pruning :  5000
Train Accuracy: 64.74% after classes [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 32, 33, 34, 35, 36, 37, 38

# Training on Core 50 Resnet18 Encoded without pruning,  final nodes amount for without pruing case, it's so much

In [None]:
dsdm = FASTER_DSDM(
    RT=0,
    Q=80000,
    buffer = 15000,
    beta=0.5,
    Lambda=0.0022,
    Lambda_RT=0.0025,
    n_feat=512,
    n_class=50,
    prune_method=0,
    n_neighbors=20,
    contamination=0.0,
)

train_data_seen = []
test_data_seen = []
class_seen = []

for train_x_batch, train_y_batch, test_x_batch, test_y_batch in tqdm(
    list(zip(splitted_train_x, splitted_train_y, splitted_test_x, splitted_test_y)),
    desc="Class-incremental Training"
):

    unique_classes = list(set(train_y_batch.tolist()))
    class_seen += unique_classes


    train_x_batch = jnp.expand_dims(train_x_batch, axis=0)
    train_y_batch = jnp.expand_dims(train_y_batch, axis=0)
    test_x_batch = jnp.expand_dims(test_x_batch, axis=0)
    test_y_batch = jnp.expand_dims(test_y_batch, axis=0)

    train_data_seen.append((train_x_batch, train_y_batch))
    test_data_seen.append((test_x_batch, test_y_batch))


    dsdm = dsdm.train(train_x_batch, train_y_batch)

    """
    if dsdm.K > dsdm.Q:
        print("Nodes before pruning : ", dsdm.K)
        dsdm = dsdm.prune(dsdm.Address, dsdm.Content, dsdm.K, dsdm.RT)
        print("Nodes after pruning : ", dsdm.K)
    """
    print("Current Nodes : ", dsdm.K)
    correct = 0
    total = 0
    for tx, ty in train_data_seen:
        preds = dsdm.inference(tx, dsdm.K)
        correct += (preds == ty).sum().item()
        total += ty.shape[1]
    train_acc = (correct / total) * 100
    print(f"Train Accuracy: {train_acc:.2f}% after classes {class_seen}")


    correct = 0
    total = 0
    for tx, ty in test_data_seen:
        preds = dsdm.inference(tx, dsdm.K)
        correct += (preds == ty).sum().item()
        total += ty.shape[1]
    test_acc = (correct / total) * 100
    print(f"Test Accuracy: {test_acc:.2f}% after classes {class_seen}")
    print("=" * 80)

Class-incremental Training:   0%|          | 0/5 [00:00<?, ?it/s]

Current Nodes :  11433
Train Accuracy: 100.00% after classes [1, 2, 3, 4, 5, 6, 7, 8, 9]
Test Accuracy: 99.93% after classes [1, 2, 3, 4, 5, 6, 7, 8, 9]
Current Nodes :  23965
Train Accuracy: 100.00% after classes [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
Test Accuracy: 99.94% after classes [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
Current Nodes :  36618
Train Accuracy: 99.99% after classes [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
Test Accuracy: 99.93% after classes [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
Current Nodes :  49339
Train Accuracy: 99.99% after classes [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 32, 33, 34, 35, 36, 37, 38, 39, 30, 31]
Test Accuracy: 99.92% after classes [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17