Fix to allow creation of custom crypten.nn networks inheriting directly from crypten.nn.Module (#165)

Summary:
Pull Request resolved: fairinternal/CrypTen#165

This fix allows custom crypten.nn networks to be created directly, analogous to how custom PyTorch networks are created.

It ensures that any crypten.nn submodule created within a subclass of crypten.nn.Module is registered with the calling module. This is achieved by redefining `__getattr__` and `__setattr__`, similar to PyTorch:
`__setattr__`: https://github.com/pytorch/pytorch/blob/2171f910531be28f7d5dd8e6ab8bff3a5486e6fd/torch/nn/modules/module.py#L578
`__getattr__`: https://github.com/pytorch/pytorch/blob/2171f910531be28f7d5dd8e6ab8bff3a5486e6fd/torch/nn/modules/module.py#L562
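For illustration, a minimal sketch of the pattern this enables (assuming an initialized CrypTen environment; `TinyNet` and its layer sizes are hypothetical, chosen only for this example):

```python
import crypten
import crypten.nn

crypten.init()


class TinyNet(crypten.nn.Module):
    def __init__(self):
        super(TinyNet, self).__init__()
        # With this fix, the assignment below is intercepted by __setattr__,
        # which registers fc in self._modules, as torch.nn.Module does.
        self.fc = crypten.nn.Linear(4, 2)

    def forward(self, x):
        # Attribute access goes through __getattr__, which looks fc up
        # in self._modules.
        return self.fc(x)


net = TinyNet()
assert "fc" in net._modules  # the submodule was registered with its parent
net.encrypt()  # encrypt() can now recurse into the registered submodule
```

Before this change, an assignment like `self.fc = ...` bypassed registration, so methods that recurse over `_modules` (such as `encrypt()`) could not see the submodule.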

This fixes the bug identified by T58626019.

Reviewed By: shubho

Differential Revision: D18839679

fbshipit-source-id: c01ce8aa1d60cf99cef7da72f68558c3f3b77157
Shobha Venkataraman authored and facebook-github-bot committed Dec 6, 2019
1 parent f26b80c commit 9f7a2e0
Showing 2 changed files with 176 additions and 5 deletions.
61 changes: 57 additions & 4 deletions crypten/nn/module.py
@@ -5,6 +5,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from collections import OrderedDict
+
 import crypten
 import torch.nn
 from crypten.autograd_cryptensor import AutogradCrypTensor
@@ -16,9 +18,9 @@ class Module:
     """
 
     def __init__(self):
-        self._parameters = {}
-        self._buffers = {}
-        self._modules = {}
+        self._parameters = OrderedDict()
+        self._buffers = OrderedDict()
+        self._modules = OrderedDict()
        self.encrypted = False
        self.train()
 
@@ -158,9 +160,9 @@ def named_buffers(self, recurse=True):
 
     def _apply(self, fn):
         """Applies a function recursively on all modules."""
-        fn(self)
         for module in self.modules():
             module._apply(fn)
+        fn(self)
         return self
 
     def encrypt(self, mode=True, src=0):
@@ -204,6 +206,57 @@ def decrypt(self):
         """Decrypts model."""
         return self.encrypt(mode=False)
 
+    def __getattr__(self, name):
+        """Redefine __getattr__ so that any parameters, modules, or buffers
+        inside the Module object can be accessed as attributes.
+        """
+        if "_parameters" in self.__dict__:
+            parameters = self.__dict__["_parameters"]
+            if name in parameters:
+                return parameters[name]
+        if "_modules" in self.__dict__:
+            modules = self.__dict__["_modules"]
+            if name in modules:
+                return modules[name]
+        if "_buffers" in self.__dict__:
+            buffers = self.__dict__["_buffers"]
+            if name in buffers:
+                return buffers[name]
+        raise AttributeError(
+            "'{}' object has no attribute '{}'".format(type(self).__name__, name)
+        )
+
+    def __setattr__(self, name, value):
+        """Redefine __setattr__ so that any submodules created
+        inside the Module object are registered in the _modules
+        OrderedDict.
+        """
+
+        def remove_from(*dicts):
+            for d in dicts:
+                if name in d:
+                    del d[name]
+
+        modules = self.__dict__.get("_modules")
+        if isinstance(value, Module):
+            if modules is None:
+                raise AttributeError(
+                    "cannot assign module before Module.__init__() call"
+                )
+            remove_from(self.__dict__, self._parameters, self._buffers)
+            modules[name] = value
+        elif modules is not None and name in modules:
+            if value is not None:
+                raise TypeError(
+                    "cannot assign '{}' as child module '{}' "
+                    "(crypten.nn.Module or None expected)".format(
+                        torch.typename(value), name
+                    )
+                )
+            modules[name] = value
+        else:
+            object.__setattr__(self, name, value)
+
 
 class Container(Module):
     """
120 changes: 119 additions & 1 deletion test/test_nn.py
@@ -268,7 +268,6 @@ def test_non_pytorch_modules(self):
         }
         # loop over all modules:
         for module_name in module_args.keys():
-
             # create encrypted CrypTen module:
             encr_module = getattr(crypten.nn, module_name)(*module_args[module_name])
             encr_module.encrypt()
@@ -578,6 +577,35 @@ def test_losses(self):
         )
         self._check(encrypted_loss, loss, "cross-entropy loss failed")
 
+    def test_getattr_setattr(self):
+        """Tests the __getattr__ and __setattr__ functions"""
+
+        tensor1 = get_random_test_tensor(size=(3, 3), is_float=True)
+        tensor2 = get_random_test_tensor(size=(3, 3), is_float=True)
+
+        class ExampleNet(crypten.nn.Module):
+            def __init__(self):
+                super(ExampleNet, self).__init__()
+                self.fc1 = crypten.nn.Linear(20, 1)
+                sample_buffer = tensor1
+                self.register_buffer("sample_buffer", sample_buffer)
+                sample_param = tensor2
+                self.register_parameter("sample_param", sample_param)
+
+            def forward(self, x):
+                out = self.fc1(x)
+                return out
+
+        model = ExampleNet()
+        model.encrypt()
+
+        self.assertTrue("fc1" in model._modules.keys(), "modules __setattr__ failed")
+        self._check(model.sample_buffer, tensor1, "buffer __getattr__ failed")
+        self._check(model.sample_param, tensor2, "parameter __getattr__ failed")
+        self.assertTrue(
+            isinstance(model.fc1, crypten.nn.Linear), "modules __getattr__ failed"
+        )
+
     def test_training(self):
         """
         Tests training of simple model in crypten.nn.
@@ -634,6 +662,96 @@ def test_training
             model.update_parameters(learning_rate)
         self._check_reference_parameters("", reference, model)
 
+    def test_custom_module_training(self):
+        """Tests training CrypTen models created directly using crypten.nn.Module"""
+
+        class ExampleNet(crypten.nn.Module):
+            def __init__(self):
+                super(ExampleNet, self).__init__()
+                self.fc1 = crypten.nn.Linear(20, 5)
+                self.fc2 = crypten.nn.Linear(5, 2)
+
+            def forward(self, x):
+                out = self.fc1(x)
+                out = self.fc2(out)
+                return out
+
+        model = ExampleNet()
+
+        batch_size = 5
+        x_orig = get_random_test_tensor(size=(batch_size, 20), is_float=True)
+        y_orig = (
+            get_random_test_tensor(size=(batch_size, 1), is_float=True).gt(0).long()
+        )
+        y_one_hot = onehot(y_orig, num_targets=2)
+
+        # encrypt training sample:
+        x_train = AutogradCrypTensor(crypten.cryptensor(x_orig))
+        y_train = crypten.cryptensor(y_one_hot)
+
+        for loss_name in ["BCELoss", "CrossEntropyLoss", "MSELoss"]:
+            # create loss function
+            loss = getattr(crypten.nn, loss_name)()
+
+            # create encrypted model
+            model.train()
+            model.encrypt()
+
+            num_epochs = 3
+            learning_rate = 0.001
+
+            for i in range(num_epochs):
+                output = model(x_train)
+                if loss_name == "MSELoss":
+                    output_norm = output
+                else:
+                    output_norm = output.softmax(1)
+                loss_value = loss(output_norm, y_train)
+
+                # set gradients to "zero"
+                model.zero_grad()
+                for param in model.parameters():
+                    self.assertIsNone(param.grad, "zero_grad did not reset gradients")
+
+                # perform backward pass:
+                loss_value.backward()
+                for param in model.parameters():
+                    if param.requires_grad:
+                        self.assertIsNotNone(
+                            param.grad, "required parameter gradient not created"
+                        )
+
+                # update parameters
+                orig_parameters, upd_parameters = {}, {}
+                orig_parameters = self._compute_reference_parameters(
+                    "", orig_parameters, model, 0
+                )
+                model.update_parameters(learning_rate)
+                upd_parameters = self._compute_reference_parameters(
+                    "", upd_parameters, model, learning_rate
+                )
+
+                # check that the update changed at least one parameter
+                parameter_changed = False
+                for name, value in orig_parameters.items():
+                    unchanged = torch.allclose(upd_parameters[name], value)
+                    if not unchanged:
+                        parameter_changed = True
+                self.assertTrue(
+                    parameter_changed, "no parameter changed in training step"
+                )
+
+                # record initial and current loss
+                if i == 0:
+                    orig_loss = loss_value.get_plain_text()
+                curr_loss = loss_value.get_plain_text()
+
+            # check that the loss has decreased after training
+            self.assertTrue(
+                curr_loss.item() < orig_loss.item(),
+                "loss has not decreased after training",
+            )
+
     def test_from_pytorch_training(self):
         """Tests the from_pytorch code path for training CrypTen models"""
         import torch.nn as nn
