Commit

- Adding implementation of Per-Location Normalization
- Updating README.md
- Advancing version number to 0.1.2
- Adding whitespace to improve readability
- Fixing minor typos in docstrings
- Increasing version requirement on PyTorch to 1.8.0. Tests now performed at 1.13.1

PiperOrigin-RevId: 503432432
james-martens authored and DKSdev committed Jan 20, 2023
1 parent dd7bfb0 commit 2cf357a
Showing 16 changed files with 666 additions and 33 deletions.
40 changes: 22 additions & 18 deletions README.md
@@ -3,13 +3,14 @@

# Official Python package for Deep Kernel Shaping (DKS) and Tailored Activation Transformations (TAT)

This Python package implements the activation function transformations and
weight initializations used in Deep Kernel Shaping (DKS) and Tailored Activation
Transformations (TAT). DKS and TAT, which were introduced in the [DKS paper] and
[TAT paper], are methods for constructing/transforming neural networks to make
them much easier to train. For example, these methods can be used in conjunction
with K-FAC to train vanilla deep convnets (without skip connections or
normalization layers) as fast as standard ResNets of the same depth.
This Python package implements the activation function transformations, weight
initializations, and dataset preprocessing used in Deep Kernel Shaping (DKS) and
Tailored Activation Transformations (TAT). DKS and TAT, which were introduced in
the [DKS paper] and [TAT paper], are methods for constructing/transforming
neural networks to make them much easier to train. For example, these methods
can be used in conjunction with K-FAC to train vanilla deep convnets
(without skip connections or normalization layers) as fast as standard ResNets
of the same depth.

The package supports the JAX, PyTorch, and TensorFlow tensor programming
frameworks.
@@ -23,16 +24,18 @@ from Github will be rejected. Instead, please email us if you find a bug.
## Usage

For each of the supported tensor programming frameworks, there is a
corresponding subpackage which handles the activation function transformations
and weight initializations. (These are `dks.jax`, `dks.pytorch`, and
`dks.tensorflow`.) It's up to the user to import these and use them
appropriately within their model code. Activation functions are transformed by
the function `get_transformed_activations()` in the module
corresponding subpackage which handles the activation function transformations,
weight initializations, and (optional) data preprocessing. (These are `dks.jax`,
`dks.pytorch`, and `dks.tensorflow`.) It's up to the user to import these and
use them appropriately within their model code. Activation functions are
transformed by the function `get_transformed_activations()` in the module
`activation_transform` of the appropriate subpackage. Sampling initial
parameters is done using functions inside of the module
`parameter_sampling_functions` of said subpackage. Note that in order to avoid
having to import all of the tensor programming frameworks, the user is required
to individually import whatever framework subpackage they want. e.g. `import
`parameter_sampling_functions` of said subpackage, and data preprocessing is
done using the function `per_location_normalization` inside the module
`data_preprocessing` of said subpackage. Note that in order to avoid having to
import all of the tensor programming frameworks, the user is required to
individually import whatever framework subpackage they want. e.g. `import
dks.jax`. Meanwhile, `import dks` won't actually do anything.
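
A minimal sketch of how these pieces might fit together for the JAX subpackage
is shown below. Only the module and function names mentioned above are taken
from the package; the keyword arguments (`method`, `subnet_max_func`), the
returned dictionary, and the single-argument call to
`per_location_normalization()` are assumptions made for illustration, so
consult the docstrings of those modules for the actual signatures.

```python
# Hypothetical usage sketch (JAX). Keyword names below are illustrative
# guesses, not the package's documented signatures.
import jax.numpy as jnp

import dks.jax  # import the framework subpackage directly; `import dks` does nothing


def subnet_max_func(x, r_fn):
  # Toy "subnetwork maximizing function" for a hypothetical 3-layer network;
  # see dks/examples/haiku/modified_resnet.py in this commit for a real one.
  return r_fn(r_fn(r_fn(x)))


# Transform the activation functions used by the model (DKS in this sketch).
# Assumed to return a dict mapping names to transformed activation functions.
transformed_acts = dks.jax.activation_transform.get_transformed_activations(
    ["leaky_relu"], method="DKS", subnet_max_func=subnet_max_func)
leaky_relu = transformed_acts["leaky_relu"]

# Initial weights should be sampled with the functions in
# dks.jax.parameter_sampling_functions (or, for Haiku models, the wrappers in
# dks.jax.haiku_initializers).

# Apply Per-Location Normalization to the input data (assumed here to accept a
# [batch, height, width, channels] array as its only required argument).
images = jnp.ones([2, 32, 32, 3])
images = dks.jax.data_preprocessing.per_location_normalization(images)
```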

`get_transformed_activations()` requires the user to pass either the "maximal
@@ -52,9 +55,10 @@ weighted sums into "normalized sums" (which are weighted sums whose
non-trainable weights have a sum of squares equal to 1). See the section titled
"Summary of our method" of the [DKS paper] for more details.

Note that this package doesn't currently include an implementation of
Per-Location Normalization (PLN) data pre-processing. While not required for
CIFAR or ImageNet, PLN could potentially be important for other datasets. Also
Note that the data preprocessing method implemented, called Per-Location
Normalization (PLN), may not always be needed in practice, but we have observed
certain situations where not using it can lead to problems. (For example, training
on datasets that contain all-zero pixels, such as CIFAR-10.) Also
note that ReLUs are only partially supported by DKS, and unsupported by TAT, and
so their use is *highly* discouraged. Instead, one should use Leaky ReLUs, which
are fully supported by DKS, and work especially well with TAT.
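
The sketch below illustrates one plausible form of PLN, under the assumption
that it appends a constant coordinate to every spatial location and then
rescales each location's feature vector to a fixed norm so that all-zero pixels
stay well defined; the actual implementation added by this commit lives in the
new `data_preprocessing` modules (e.g. the `dks.jax.data_preprocessing` module
imported below) and may differ in its details.

```python
import jax.numpy as jnp


def pln_sketch(images):
  """Rough sketch of Per-Location Normalization (an assumption, not the
  package's implementation -- see dks.jax.data_preprocessing for that).

  Appends a constant coordinate to every spatial location, then rescales each
  location's feature vector to norm sqrt(dim), so all-zero pixels no longer
  produce degenerate values.
  """
  # images: [batch, height, width, channels]
  ones = jnp.ones(images.shape[:-1] + (1,), dtype=images.dtype)
  x = jnp.concatenate([images, ones], axis=-1)        # homogeneous coordinate
  dim = x.shape[-1]
  norms = jnp.linalg.norm(x, axis=-1, keepdims=True)  # >= 1, never zero
  return jnp.sqrt(dim) * x / norms
```
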
2 changes: 1 addition & 1 deletion dks/__init__.py
@@ -16,4 +16,4 @@
# Do not directly import this package; it won't do anything. Instead, import one
# of the framework-specific subpackages.

__version__ = "0.1.1"
__version__ = "0.1.2"
36 changes: 36 additions & 0 deletions dks/examples/haiku/modified_resnet.py
@@ -47,12 +47,15 @@ def __init__(
w_init: Optional[Any],
name: Optional[str] = None,
):

super().__init__(name=name)

self.use_projection = use_projection
self.use_batch_norm = use_batch_norm
self.shortcut_weight = shortcut_weight

if self.use_projection and self.shortcut_weight != 0.0:

self.proj_conv = hk.Conv2D(
output_channels=channels,
kernel_shape=1,
@@ -61,11 +64,13 @@ def __init__(
with_bias=not use_batch_norm,
padding="SAME",
name="shortcut_conv")

if use_batch_norm:
self.proj_batchnorm = hk.BatchNorm(
name="shortcut_batchnorm", **BN_CONFIG)

channel_div = 4 if bottleneck else 1

conv_0 = hk.Conv2D(
output_channels=channels // channel_div,
kernel_shape=1 if bottleneck else 3,
@@ -87,8 +92,10 @@ def __init__(
layers = (conv_0, conv_1)

if use_batch_norm:

bn_0 = hk.BatchNorm(name="batchnorm_0", **BN_CONFIG)
bn_1 = hk.BatchNorm(name="batchnorm_1", **BN_CONFIG)

bn_layers = (bn_0, bn_1)

if bottleneck:
@@ -112,23 +119,31 @@ def __init__(
self.activation = activation

def __call__(self, inputs, is_training, test_local_stats):

out = shortcut = inputs

if self.use_projection and self.shortcut_weight != 0.0:

shortcut = self.proj_conv(shortcut)

if self.use_batch_norm:
shortcut = self.proj_batchnorm(shortcut, is_training, test_local_stats)

for i, conv_i in enumerate(self.layers):

out = conv_i(out)

if self.use_batch_norm:
out = self.bn_layers[i](out, is_training, test_local_stats)

if i < len(self.layers) - 1: # Don't apply activation on last layer
out = self.activation(out)

if self.shortcut_weight is None:
return self.activation(out + shortcut)

elif self.shortcut_weight != 0.0:

return self.activation(
math.sqrt(1 - self.shortcut_weight**2) * out +
self.shortcut_weight * shortcut)
@@ -151,12 +166,15 @@ def __init__(
w_init: Optional[Any],
name: Optional[str] = None,
):

super().__init__(name=name)

self.use_projection = use_projection
self.use_batch_norm = use_batch_norm
self.shortcut_weight = shortcut_weight

if self.use_projection and self.shortcut_weight != 0.0:

self.proj_conv = hk.Conv2D(
output_channels=channels,
kernel_shape=1,
@@ -167,6 +185,7 @@ def __init__(
name="shortcut_conv")

channel_div = 4 if bottleneck else 1

conv_0 = hk.Conv2D(
output_channels=channels // channel_div,
kernel_shape=1 if bottleneck else 3,
@@ -188,11 +207,14 @@ def __init__(
layers = (conv_0, conv_1)

if use_batch_norm:

bn_0 = hk.BatchNorm(name="batchnorm_0", **BN_CONFIG)
bn_1 = hk.BatchNorm(name="batchnorm_1", **BN_CONFIG)

bn_layers = (bn_0, bn_1)

if bottleneck:

conv_2 = hk.Conv2D(
output_channels=channels,
kernel_shape=1,
@@ -205,8 +227,10 @@ def __init__(
layers = layers + (conv_2,)

if use_batch_norm:

bn_2 = hk.BatchNorm(name="batchnorm_2", **BN_CONFIG)
bn_layers += (bn_2,)

self.bn_layers = bn_layers

self.layers = layers
@@ -229,9 +253,11 @@ def __call__(self, inputs, is_training, test_local_stats):

if self.shortcut_weight is None:
return x + shortcut

elif self.shortcut_weight != 0.0:
return math.sqrt(
1 - self.shortcut_weight**2) * x + self.shortcut_weight * shortcut

else:
return x

@@ -272,13 +298,17 @@ def __init__(
name="block_%d" % (i)))

def __call__(self, inputs, is_training, test_local_stats):

out = inputs

for block in self.blocks:
out = block(out, is_training, test_local_stats)

return out


def check_length(length, value, name):

if len(value) != length:
raise ValueError(f"`{name}` must be of length 4 not {len(value)}")

@@ -481,12 +511,15 @@ def __init__(
self.logits = hk.Linear(num_classes, **logits_config)

def __call__(self, inputs, is_training, test_local_stats=False):

out = inputs
out = self.initial_conv(out)

if not self.resnet_v2:

if self.use_batch_norm:
out = self.initial_batchnorm(out, is_training, test_local_stats)

out = self.activation(out)

out = hk.max_pool(
@@ -525,15 +558,18 @@ def subnet_max_func(x, r_fn, depth, shortcut_weight, resnet_v2=True):

if bottleneck and resnet_v2:
res_fn = lambda z: r_fn(r_fn(r_fn(z)))

elif (not bottleneck and resnet_v2) or (bottleneck and not resnet_v2):
res_fn = lambda z: r_fn(r_fn(z))

else:
res_fn = r_fn

res_branch_subnetwork = res_fn(x)

for i in range(4):
for j in range(blocks_per_group[i]):

res_x = res_fn(x)

if j == 0 and use_projection[i] and resnet_v2:
1 change: 1 addition & 0 deletions dks/jax/__init__.py
@@ -15,5 +15,6 @@
"""Subpackage for JAX."""

from dks.jax import activation_transform
from dks.jax import data_preprocessing
from dks.jax import haiku_initializers
from dks.jax import parameter_sampling_functions
[Diffs for the remaining 12 changed files not shown.]
