Skip to content

Commit

Permalink
Merge pull request #560 from lscheinkman/RES-2372
Browse files Browse the repository at this point in the history
RES-2372: Add profiler and GPU Optimizations to dendrite experiments
  • Loading branch information
lscheinkman committed Aug 20, 2021
2 parents edff022 + 1b05f3d commit 1299011
Show file tree
Hide file tree
Showing 8 changed files with 216 additions and 23 deletions.
41 changes: 24 additions & 17 deletions nupic/research/frameworks/dendrites/functional/apply_dendrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"""

from collections import namedtuple
from typing import Optional

import torch

Expand All @@ -50,6 +51,7 @@
"""


@torch.jit.script
def dendritic_bias_1d(y, dendrite_activations):
"""
Returns the sum of the feedforward output and the max of the dendrite
Expand All @@ -65,7 +67,24 @@ def dendritic_bias_1d(y, dendrite_activations):
return dendrite_output(y + winning_activations, indices)


def dendritic_gate_1d(y, dendrite_activations, indices=None):
@torch.jit.script
def gather_activations(dendrite_activations, indices):
"""
Gathers dendritic activations from the given indices.
:param indices: tensor of indices of winning segments;
shape of batch_size x num_units
:param indices: tensor of dendritic activations;
shape of batch_size x num_units x num_segments
"""
unsqueezed = indices.unsqueeze(dim=2)
dendrite_activations = torch.gather(dendrite_activations, dim=2, index=unsqueezed)
dendrite_activations = dendrite_activations.squeeze(dim=2)
return dendrite_activations


@torch.jit.script
def dendritic_gate_1d(y, dendrite_activations, indices: Optional[torch.Tensor] = None):
"""
Returns the product of the feedforward output and sigmoid of the the max
of the dendrite activations along each segment.
Expand All @@ -87,6 +106,7 @@ def dendritic_gate_1d(y, dendrite_activations, indices=None):
return dendrite_output(y * torch.sigmoid(winning_activations), indices)


@torch.jit.script
def dendritic_absolute_max_gate_1d(y, dendrite_activations):
"""
Returns the product of the feedforward output and the sigmoid of the
Expand All @@ -101,7 +121,8 @@ def dendritic_absolute_max_gate_1d(y, dendrite_activations):
return dendritic_gate_1d(y, dendrite_activations, indices=indices)


def dendritic_gate_2d(y, dendrite_activations, indices=None):
@torch.jit.script
def dendritic_gate_2d(y, dendrite_activations, indices: Optional[torch.Tensor] = None):
"""
Returns the output of the max gating convolutional dendritic layer by
multiplying all values in each output channel by the selected dendrite
Expand Down Expand Up @@ -136,6 +157,7 @@ def dendritic_gate_2d(y, dendrite_activations, indices=None):
return dendrite_output(y_gated, indices)


@torch.jit.script
def dendritic_absolute_max_gate_2d(y, dendrite_activations):
"""
Returns the output of the absolute max gating convolutional dendritic layer by
Expand All @@ -155,18 +177,3 @@ def dendritic_absolute_max_gate_2d(y, dendrite_activations):
"""
indices = dendrite_activations.abs().max(dim=2).indices
return dendritic_gate_2d(y, dendrite_activations, indices=indices)


def gather_activations(dendrite_activations, indices):
"""
Gathers dendritic activations from the given indices.
:param indices: tensor of indices of winning segments;
shape of batch_size x num_units
:param indices: tensor of dendritic activations;
shape of batch_size x num_units x num_segments
"""
unsqueezed = indices.unsqueeze(dim=2)
dendrite_activations = torch.gather(dendrite_activations, dim=2, index=unsqueezed)
dendrite_activations = dendrite_activations.squeeze(dim=2)
return dendrite_activations
16 changes: 12 additions & 4 deletions nupic/research/frameworks/dendrites/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def train_dendrite_model(
:param progress_bar: Unused
"""
model.train()
# Use asynchronous GPU copies when the memory is pinned
# See https://pytorch.org/docs/master/notes/cuda.html
async_gpu = loader.pin_memory
context = None
if context_vector is not None:
# Tile context vector
Expand All @@ -96,7 +99,9 @@ def train_dendrite_model(
data, context = data
else:
data, _ = data
data = data.to(device, non_blocking=async_gpu)
data = data.flatten(start_dim=1)
target = target.to(device, non_blocking=async_gpu)

if train_context_fn is not None:
context = train_context_fn(data)
Expand All @@ -106,12 +111,15 @@ def train_dendrite_model(
if share_labels:
target = target % num_labels

data = data.to(device)
target = target.to(device)
if context is not None:
context = context.to(device)
context = context.to(device, non_blocking=async_gpu)

# FIXME: Pytorch 1.7: Replace with optimizer.zero_grad(set_to_none=True)
# optimizer.zero_grad(set_to_none=True)
for group in optimizer.param_groups:
for p in group["params"]:
p.grad = None

optimizer.zero_grad()
forward_args = [data] if context is None else [data, context]
output = model(*forward_args)
if active_classes is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def reset_parameters(self):
init_linear_(weight, bias)

def rezero_weights(self):
self.weights.data[self.zero_mask.bool()] = 0
self.weights.data.masked_fill_(self.zero_mask.bool(), 0)

def forward(self, context):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def set_active_tasks(self, tasks):
self.indices = np.concatenate([self.task_indices[t] for t in self.active_tasks])

def __iter__(self):
return (self.indices[i] for i in torch.randperm(len(self.indices)))
return (self.indices[i] for i in np.random.permutation(len(self.indices)))

def __len__(self):
return len(self.indices)
1 change: 1 addition & 0 deletions nupic/research/frameworks/vernon/mixins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from .save_final_checkpoint import SaveFinalCheckpoint
from .si import *
from .step_based_logging import *
from .torch_profiler import *
from .track_representation_sparsity import *
from .update_boost_strength import UpdateBoostStrength
from .update_dendrite_boost_strength import UpdateDendriteBoostStrength
Expand Down
99 changes: 99 additions & 0 deletions nupic/research/frameworks/vernon/mixins/torch_profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# ----------------------------------------------------------------------
# Numenta Platform for Intelligent Computing (NuPIC)
# Copyright (C) 2021, Numenta, Inc. Unless you have an agreement
# with Numenta, Inc., for a separate license for this software code, the
# following terms and conditions apply:
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Affero Public License for more details.
#
# You should have received a copy of the GNU Affero Public License
# along with this program. If not, see http://www.gnu.org/licenses.
#
# http://numenta.org/licenses/
# ----------------------------------------------------------------------

import os

import pkg_resources
import torch

from nupic.research.frameworks.vernon.experiments import SupervisedExperiment

__all__ = ["TorchProfilerMixin", "inject_torch_profiler_mixin"]


class TorchProfilerMixin:
"""
Mixin class enabling profiling via pytorch's native profiler.
See https://pytorch.org/docs/stable/profiler.html
.. note::
Requires pytorch 1.8.1 or higher
"""

def __init__(self, *args, **kwargs):
pkg_resources.require("torch>=1.8.1")
super().__init__(*args, **kwargs)
self._profiler = None

def setup_experiment(self, config):
super().setup_experiment(config)

# Whether or not to export chrome trace
self._export_chrome_trace = config.get("export_chrome_trace", False)

# Default profiler args
self._profiler_args = config.get("profiler", {
"with_stack": True,
"record_shapes": True,
"schedule": torch.profiler.schedule(wait=1, warmup=1, active=5)
})

def train_epoch(self):
profiler_path = os.path.join(self.logdir, "profiler")
# Default profiler output to tensorboard.
# Requires `torch-tb-profiler` tensorboard plugin
profiler_args = {
**self._profiler_args,
"on_trace_ready": torch.profiler.tensorboard_trace_handler(profiler_path)
}
with torch.profiler.profile(**profiler_args) as prof:
self._profiler = prof
super().train_epoch()

if self._export_chrome_trace and self._profiler is not None:
self._profiler.export_chrome_trace(profiler_path)

self._profiler = None

def post_batch(self, *args, **kwargs):
super().post_batch(*args, **kwargs)
if self._profiler is not None:
self._profiler.step()

@classmethod
def get_execution_order(cls):
eo = super().get_execution_order()
eo["setup_experiment"].append("TorchProfilerMixin initialization")
eo["train_epoch"].insert(0, "TorchProfilerMixin begin")
eo["post_batch"].append("TorchProfilerMixin step")
eo["train_epoch"].append("TorchProfilerMixin end")
return eo


def inject_torch_profiler_mixin(experiment_class):
"""
Injects torch profiler mixin to the given experiment class
"""
assert issubclass(experiment_class, SupervisedExperiment)
return type(
f"Profile{experiment_class.__name__}", (TorchProfilerMixin, experiment_class),
{}
)
2 changes: 2 additions & 0 deletions projects/dendrites/permutedMNIST/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .centroid import CONFIGS as CENTROID
from .hyperparameter_search import CONFIGS as HYPERPARAMETERSEARCH
from .no_dendrites import CONFIGS as NO_DENDRITES
from .profiler import CONFIGS as PROFILER
from .si_centroid import CONFIGS as SI_CENTROID
from .sp_context import CONFIGS as SP_CONTEXT
from .sp_context_search import CONFIGS as SP_PROTO
Expand All @@ -42,6 +43,7 @@
CONFIGS.update(CENTROID)
CONFIGS.update(HYPERPARAMETERSEARCH)
CONFIGS.update(NO_DENDRITES)
CONFIGS.update(PROFILER)
CONFIGS.update(SI_CENTROID)
CONFIGS.update(SP_CONTEXT)
CONFIGS.update(SP_PROTO)
76 changes: 76 additions & 0 deletions projects/dendrites/permutedMNIST/experiments/profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Numenta Platform for Intelligent Computing (NuPIC)
# Copyright (C) 2021, Numenta, Inc. Unless you have an agreement
# with Numenta, Inc., for a separate license for this software code, the
# following terms and conditions apply:
#
# This program is free software you can redistribute it and/or modify
# it under the terms of the GNU Affero Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Affero Public License for more details.
#
# You should have received a copy of the GNU Affero Public License
# along with this program. If not, see htt"://www.gnu.org/licenses.
#
# http://numenta.org/licenses/
#

"""
Experiments profiling dendrites experiments
"""
from copy import deepcopy

import torch

from nupic.research.frameworks.vernon import mixins

from .centroid import CENTROID_10, CENTROID_50

__all__ = ["CONFIGS"]

PROFILER_ARGS = {
"with_stack": True,
"record_shapes": False,
"schedule": torch.profiler.schedule(wait=1, warmup=1, active=5),
}
WANDB_ARGS = {
"project": "dendrite_baselines",
"group": "profiler",
"notes": "Profiler for dendrite network",
}
CENTROID_10_PROFILER = deepcopy(CENTROID_10)
experiment_class = CENTROID_10_PROFILER["experiment_class"]
CENTROID_10_PROFILER.update(
experiment_class=mixins.inject_torch_profiler_mixin(experiment_class),
epochs=1,
num_samples=1,
profiler=PROFILER_ARGS,
wandb_args=WANDB_ARGS,
)

CENTROID_10_ONE_SEGMENT_PROFILER = deepcopy(CENTROID_10_PROFILER)
CENTROID_10_ONE_SEGMENT_PROFILER["model_args"].update(num_segments=1)

CENTROID_50_PROFILER = deepcopy(CENTROID_50)
experiment_class = CENTROID_50_PROFILER["experiment_class"]
CENTROID_50_PROFILER.update(
experiment_class=mixins.inject_torch_profiler_mixin(experiment_class),
epochs=1,
num_samples=1,
profiler=PROFILER_ARGS,
wandb_args=WANDB_ARGS,
)

CENTROID_10_TWO_SEGMENT_PROFILER = deepcopy(CENTROID_10_PROFILER)
CENTROID_10_TWO_SEGMENT_PROFILER["model_args"].update(num_segments=2)

# Export configurations in this file
CONFIGS = dict(
centroid_10_profiler=CENTROID_10_PROFILER,
centroid_10_one_segment_profiler=CENTROID_10_ONE_SEGMENT_PROFILER,
centroid_10_two_segment_profiler=CENTROID_10_TWO_SEGMENT_PROFILER,
centroid_50_profiler=CENTROID_50_PROFILER,
)

0 comments on commit 1299011

Please sign in to comment.