Merge pull request #515 from karangrewal/vernon-dendrites-2
Centroid method for inferring context signal
Karan Grewal committed May 18, 2021
2 parents bfae744 + dd20307 commit a042d3e
Showing 11 changed files with 441 additions and 284 deletions.
@@ -34,6 +34,7 @@ class DendriteContinualLearningExperiment(ContinualLearningExperiment):

def setup_experiment(self, config):
super().setup_experiment(config)
self.context_vector = None

def run_task(self):
"""
@@ -46,7 +47,8 @@ def run_task(self):
# Run epochs, inner loop
# TODO: return the results from run_epoch
self.current_epoch = 0
for _ in range(self.epochs):
for e in range(self.epochs):
self.logger.info("Training task %d, epoch %d...", self.current_task, e)
self.run_epoch()

# TODO: put back evaluation_metrics from cl_experiment
25 changes: 21 additions & 4 deletions nupic/research/frameworks/dendrites/model_utils.py
@@ -37,6 +37,7 @@ def train_dendrite_model(
share_labels=False,
num_labels=None,
active_classes=None,
context_vector=None,
post_batch_callback=None,
complexity_loss_fn=None,
batches_in_epoch=None,
@@ -65,6 +66,8 @@ def train_dendrite_model(
:param active_classes: List of indices of the heads that are active for a given
task; only relevant if this function is being used in a
continual learning scenario
:param context_vector: If not None, use this context vector in place of any that
we get from the loader
:param post_batch_callback: Callback function to be called after every batch
with the following parameters: model, batch_idx
:param complexity_loss_fn: Unused
@@ -75,12 +78,19 @@
"""
model.train()
context = None
if context_vector is not None:
# Tile context vector
context = context_vector.repeat(loader.batch_size, 1)

for batch_idx, (data, target) in enumerate(loader):

# `data` may be a 2-item list comprising the example data and context signal in
# case context is explicitly provided
if isinstance(data, list):
data, context = data
if context_vector is None:
data, context = data
else:
data, _ = data
data = data.flatten(start_dim=1)

# Since labels are shared, target values should be in
@@ -114,10 +124,11 @@ def evaluate_dendrite_model(
model,
loader,
device,
criterion=F.nll_loss,
criterion=F.cross_entropy,
share_labels=False,
num_labels=None,
active_classes=None,
infer_context_fn=None,
batches_in_epoch=None,
complexity_loss_fn=None,
progress=None,
@@ -143,6 +154,8 @@
:param active_classes: List of indices of the heads that are active for a given
task; only relevant if this function is being used in a
continual learning scenario
:param infer_context_fn: A function that computes the context vector to use given a batch
of data samples
:param batches_in_epoch: Unused
:param complexity_loss_fn: Unused
:param progress: Unused
@@ -155,6 +168,7 @@
loss = torch.tensor(0., device=device)
correct = torch.tensor(0, device=device)

infer_context = (infer_context_fn is not None)
context = None

with torch.no_grad():
@@ -174,6 +188,9 @@

data = data.to(device)
target = target.to(device)
if infer_context:
# Use `infer_context_fn` to retrieve the context vector
context = infer_context_fn(data)
if context is not None:
context = context.to(device)

@@ -188,9 +205,9 @@
total += len(data)

results = {
"total_correct": correct,
"total_correct": correct.item(),
"total_tested": total,
"mean_loss": loss / total if total > 0 else 0,
"mean_loss": loss.item() / total if total > 0 else 0,
"mean_accuracy": torch.true_divide(correct, total).item() if total > 0 else 0,
}
return results
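For orientation, here is a minimal usage sketch of the two new parameters (not part of this diff; `model`, `train_loader`, `val_loader`, `optimizer`, `device`, `context_vector`, and `my_infer_context_fn` are assumed to exist already):

import torch.nn.functional as F

from nupic.research.frameworks.dendrites import (
    evaluate_dendrite_model,
    train_dendrite_model,
)

# Training: one fixed context vector is tiled across every batch of the loader.
train_dendrite_model(
    model=model,
    loader=train_loader,
    optimizer=optimizer,
    device=device,
    criterion=F.cross_entropy,
    context_vector=context_vector,
)

# Evaluation: no fixed context is given; `infer_context_fn` computes one per batch.
results = evaluate_dendrite_model(
    model=model,
    loader=val_loader,
    device=device,
    criterion=F.cross_entropy,
    infer_context_fn=my_infer_context_fn,  # e.g. infer_centroid(contexts), defined below
)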
7 changes: 5 additions & 2 deletions nupic/research/frameworks/dendrites/modules/dendritic_mlp.py
@@ -43,7 +43,7 @@ class DendriticMLP(nn.Module):
:param num_segments: the number of dendritic segments that each hidden unit has
:param dim_context: the size of the context input to the network
:param kw: whether to apply k-Winners to the outputs of each hidden layer
:param kw_percent_on: percent of hidden units activated by K-winners.
:param kw_percent_on: percent of hidden units activated by K-winners. If 0, use ReLU
:param context_percent_on: percent of non-zero units in the context input.
:param dendrite_weight_sparsity: the sparsity level of dendritic weights.
:param weight_sparsity: the sparsity level of feed-forward weights.
@@ -84,9 +84,12 @@ def __init__(
# "modified"
assert weight_init in ("kaiming", "modified")
assert dendrite_init in ("kaiming", "modified")
assert kw_percent_on >= 0.0
assert kw_percent_on >= 0.0 and kw_percent_on < 1.0
assert context_percent_on >= 0.0

if kw_percent_on == 0.0:
kw = False

super().__init__()

if num_segments == 1:
@@ -84,7 +84,7 @@ def setup_experiment(self, config):

# Set train and validate methods.
self.train_model = config.get("train_model_func", train_model)
self.train_model_args = config.get("train_model_args", None)
self.train_model_args = config.get("train_model_args", {})
self.evaluate_model = config.get("evaluate_model_func", evaluate_model)
self.tasks_to_validate = config.get("tasks_to_validate",
range(self.num_tasks - 3,
2 changes: 2 additions & 0 deletions nupic/research/frameworks/vernon/mixins/__init__.py
@@ -19,6 +19,7 @@
# http://numenta.org/licenses/
# ----------------------------------------------------------------------

from .centroid_context import *
from .composite_loss import CompositeLoss
from .configure_optimizer_param_groups import ConfigureOptimizerParamGroups
from .constrain_parameters import ConstrainParameters
@@ -42,6 +43,7 @@
from .maxup import MaxupPerSample, MaxupStandard
from .multi_cycle_lr import MultiCycleLR
from .oml import OnlineMetaLearning
from .permuted_mnist_task_indices import *
from .profile import Profile
from .profile_autograd import ProfileAutograd
from .prune_low_magnitude import PruneLowMagnitude
134 changes: 134 additions & 0 deletions nupic/research/frameworks/vernon/mixins/centroid_context.py
@@ -0,0 +1,134 @@
# ----------------------------------------------------------------------
# Numenta Platform for Intelligent Computing (NuPIC)
# Copyright (C) 2021, Numenta, Inc. Unless you have an agreement
# with Numenta, Inc., for a separate license for this software code, the
# following terms and conditions apply:
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Affero Public License for more details.
#
# You should have received a copy of the GNU Affero Public License
# along with this program. If not, see http://www.gnu.org/licenses.
#
# http://numenta.org/licenses/
# ----------------------------------------------------------------------

import abc

import torch

from nupic.research.frameworks.dendrites import (
evaluate_dendrite_model,
train_dendrite_model,
)

__all__ = [
"CentroidContext",
"compute_centroid",
"infer_centroid",
]


class CentroidContext(metaclass=abc.ABCMeta):
"""
When training a dendritic network, use the centroid method for computing context
vectors (that dendrites receive as input) for both training and inference.
"""

def setup_experiment(self, config):
# Since the centroid vector is an element-wise mean of individual data samples,
# it's necessarily the same dimension as the input
model_args = config.get("model_args")
dim_context = model_args.get("dim_context")
input_size = model_args.get("input_size")

assert dim_context == input_size, ("For centroid experiments `dim_context` "
"must match `input_size`")

super().setup_experiment(config)

# Store batch size
self.batch_size = config.get("batch_size", 1)

# Tensor for accumulating each task's centroid vector
self.contexts = torch.zeros((0, self.model.input_size))
self.contexts = self.contexts.to(self.device)

# The following will point to the 'active' context vector used to train on
# the current task
self.context_vector = None

def run_task(self):
self.train_loader.sampler.set_active_tasks(self.current_task)

# Construct a context vector by computing the centroid of all training examples
self.context_vector = compute_centroid(self.train_loader).to(self.device)
self.contexts = torch.cat((self.contexts, self.context_vector.unsqueeze(0)))

return super().run_task()

def train_epoch(self):
# TODO: take out constants in the call below. How do we determine num_labels?
train_dendrite_model(
model=self.model,
loader=self.train_loader,
optimizer=self.optimizer,
device=self.device,
criterion=self.error_loss,
share_labels=True,
num_labels=10,
context_vector=self.context_vector,
post_batch_callback=self.post_batch_wrapper,
)

def validate(self, loader=None):
if loader is None:
loader = self.val_loader

# TODO: take out constants in the call below
return evaluate_dendrite_model(model=self.model,
loader=loader,
device=self.device,
criterion=self.error_loss,
share_labels=True, num_labels=10,
infer_context_fn=infer_centroid(self.contexts))


def compute_centroid(loader):
"""
Returns the centroid vector of all samples iterated over in `loader`.
"""
centroid_vector = torch.zeros([])
n_centroid = 0
for x, _ in loader:
if isinstance(x, list):
x = x[0]
x = x.flatten(start_dim=1)
n_x = x.size(0)

centroid_vector = centroid_vector + x.sum(dim=0)
n_centroid += n_x

centroid_vector /= n_centroid
return centroid_vector


def infer_centroid(contexts):
"""
Returns a function that takes a batch of test examples and returns a 2D array where
row i gives the centroid vector closest to the ith test example.
"""

def _infer_centroid(data):
context = torch.cdist(contexts, data)
context = context.argmin(dim=0)
context = contexts[context]
return context

return _infer_centroid
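As a toy illustration (made-up numbers, not part of the commit), the function returned by `infer_centroid` maps each test sample to its nearest stored centroid:

import torch

from nupic.research.frameworks.vernon.mixins import infer_centroid

# Two stored task centroids (one per row) and a batch of two test samples.
contexts = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
data = torch.tensor([[0.1, 0.2], [0.9, 0.8]])

infer_context_fn = infer_centroid(contexts)
print(infer_context_fn(data))
# tensor([[0., 0.],
#         [1., 1.]])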
62 changes: 62 additions & 0 deletions nupic/research/frameworks/vernon/mixins/permuted_mnist_task_indices.py
@@ -0,0 +1,62 @@
# ----------------------------------------------------------------------
# Numenta Platform for Intelligent Computing (NuPIC)
# Copyright (C) 2021, Numenta, Inc. Unless you have an agreement
# with Numenta, Inc., for a separate license for this software code, the
# following terms and conditions apply:
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Affero Public License for more details.
#
# You should have received a copy of the GNU Affero Public License
# along with this program. If not, see http://www.gnu.org/licenses.
#
# http://numenta.org/licenses/
# ----------------------------------------------------------------------

import math
from collections import defaultdict

__all__ = [
"PermutedMNISTTaskIndices",
]


class PermutedMNISTTaskIndices:
"""
A mixin that overrides `compute_task_indices` when using permutedMNIST to allow
for much faster dataset initialization. Note that this mixin may not work with
other datasets.
"""

@classmethod
def compute_task_indices(cls, config, dataset):
# Assume dataloaders are already created
class_indices = defaultdict(list)
for idx in range(len(dataset)):
target = _get_target(dataset, idx)
class_indices[target].append(idx)

# Defines how many classes should exist per task
num_tasks = config.get("num_tasks", 1)
num_classes = config.get("num_classes", None)
assert num_classes is not None, "num_classes should be defined"
num_classes_per_task = math.floor(num_classes / num_tasks)

task_indices = defaultdict(list)
for i in range(num_tasks):
for j in range(num_classes_per_task):
task_indices[i].extend(class_indices[j + (i * num_classes_per_task)])
return task_indices


def _get_target(dataset, idx):
target = int(dataset.targets[idx % len(dataset.data)])
task_id = dataset.get_task_id(idx)
target += 10 * task_id
return target
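A rough sketch of the resulting index math, using a hypothetical stand-in for the permutedMNIST dataset (the class and numbers below are invented for illustration):

import torch

from nupic.research.frameworks.vernon.mixins import PermutedMNISTTaskIndices

class FakePermutedMNIST:
    # Stand-in exposing the attributes `_get_target` relies on: `data`, `targets`,
    # and `get_task_id`. Four base samples, repeated once per task.
    def __init__(self, num_tasks=2):
        self.num_tasks = num_tasks
        self.data = torch.zeros(4, 784)
        self.targets = torch.tensor([0, 1, 2, 3])

    def __len__(self):
        return self.num_tasks * len(self.data)

    def get_task_id(self, idx):
        return idx // len(self.data)

config = {"num_tasks": 2, "num_classes": 20}  # 10 MNIST classes per task
indices = PermutedMNISTTaskIndices.compute_task_indices(config, FakePermutedMNIST())
# indices[0] == [0, 1, 2, 3]  (shifted targets 0-3   -> task 0)
# indices[1] == [4, 5, 6, 7]  (shifted targets 10-13 -> task 1)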