In [1]:
import torch
import math
from collections import OrderedDict
import torch.nn.functional as F

In [2]:
class Parameter(torch.Tensor):
    """A ``torch.Tensor`` subclass that marks a tensor as a module parameter.

    Mirrors ``torch.nn.Parameter``: assigning an instance as an attribute of a
    ``simpleModule`` registers it in that module's ``_parameters``.

    Args:
        data: tensor to wrap; an empty tensor is used when omitted.
        requires_grad: whether autograd should track this parameter
            (default ``True``).
    """

    # Opt out of the default Tensor.__torch_function__ subclass propagation,
    # exactly as torch.nn.Parameter does. Without this, results of ops on a
    # Parameter (``p * 2``, ``p.clone()``, even ``p.data``) are themselves
    # Parameters. Derived tensors would then be registered as parameters on
    # assignment, and pickling ``self.data`` in __reduce_ex__ would re-enter
    # __reduce_ex__ forever.
    __torch_function__ = torch._C._disabled_torch_function_impl

    def __new__(cls, data=None, requires_grad=True):
        if data is None:
            data = torch.Tensor()
        # Re-wrap the tensor's storage as an instance of ``cls`` without copying.
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __deepcopy__(self, memo):
        """Deep-copy into a fresh instance backed by cloned storage.

        ``memo`` is consulted first so a parameter shared by several objects
        is copied only once.
        """
        if id(self) in memo:
            return memo[id(self)]
        result = type(self)(self.data.clone(memory_format=torch.preserve_format),
                            self.requires_grad)
        memo[id(self)] = result
        return result

    def __repr__(self):
        return 'Parameter containing:\n' + super().__repr__()

    def __reduce_ex__(self, proto):
        """Pickle as ``type(self)(data, requires_grad)``.

        Rebuilding through this class itself keeps the unpickled object an
        instance of this ``Parameter`` type, so ``simpleModule`` still
        registers it. Backward hooks are deliberately not serialized.
        """
        return (type(self), (self.data, self.requires_grad))

In [3]:
class simpleModule:
    """Minimal re-implementation of ``torch.nn.Module``'s attribute bookkeeping.

    Attribute assignment is routed by value type:

    * ``Parameter`` instances go into ``self._parameters``;
    * ``simpleModule`` instances go into ``self._modules``;
    * names already registered with ``register_buffer`` are updated in
      ``self._buffers`` (value must be a ``torch.Tensor`` or ``None``);
    * everything else becomes an ordinary instance attribute.

    Registered entries are not stored in the instance ``__dict__``;
    ``__getattr__`` finds them in the registries instead. Assigning a
    parameter or submodule removes any same-named entry from the other
    places, so a stale value can never shadow a newer one.
    """

    # Registries searched, in this order, by __getattr__ and __delattr__.
    _REGISTRIES = ('_parameters', '_buffers', '_modules')

    def __init__(self):
        # None of the registries exist yet, so __setattr__ falls through to
        # object.__setattr__ and these land in the instance __dict__.
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._modules = OrderedDict()

    @staticmethod
    def _remove_from(name, *mappings):
        """Delete ``name`` from every mapping that contains it (``None`` mappings are skipped)."""
        for mapping in mappings:
            if mapping is not None and name in mapping:
                del mapping[name]

    def __getattr__(self, name):
        """Resolve ``name`` from the parameter, buffer and module registries.

        Python only calls this after normal lookup (instance ``__dict__``,
        then the class) has failed.

        Raises:
            AttributeError: if ``name`` is not registered anywhere.
        """
        # Read through __dict__ directly: touching self._parameters here would
        # re-enter __getattr__ (and recurse) before __init__ has run.
        for registry_name in self._REGISTRIES:
            registry = self.__dict__.get(registry_name)
            if registry is not None and name in registry:
                return registry[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))

    def __setattr__(self, name, value):
        """Route an attribute assignment into the matching registry.

        Raises:
            AttributeError: if a parameter or submodule is assigned before
                ``__init__`` has created the registries.
            TypeError: if a registered parameter / submodule name is assigned
                something other than the same kind (or ``None``), or a
                registered buffer is assigned a non-tensor.
        """
        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before simpleModule.__init__() call")
            # Evict same-named entries so they cannot shadow the new parameter.
            self._remove_from(name, self.__dict__,
                              self.__dict__.get('_buffers'),
                              self.__dict__.get('_modules'))
            self.register_parameter(name, value)
        elif params is not None and name in params:
            # Only None may replace a registered parameter (clears it).
            if value is not None:
                raise TypeError(
                    "cannot assign '{}' as parameter '{}' "
                    "(Parameter or None expected)".format(type(value).__name__, name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, simpleModule):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before simpleModule.__init__() call")
                self._remove_from(name, self.__dict__, params,
                                  self.__dict__.get('_buffers'))
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError(
                        "cannot assign '{}' as child module '{}' "
                        "(simpleModule or None expected)".format(type(value).__name__, name))
                modules[name] = value
            else:
                buffers = self.__dict__.get('_buffers')
                if buffers is not None and name in buffers:
                    if value is not None and not isinstance(value, torch.Tensor):
                        raise TypeError(
                            "cannot assign '{}' as buffer '{}' "
                            "(torch.Tensor or None expected)".format(type(value).__name__, name))
                    buffers[name] = value
                else:
                    object.__setattr__(self, name, value)

    def __delattr__(self, name):
        """Remove ``name`` from whichever registry holds it, else from ``__dict__``."""
        for registry_name in self._REGISTRIES:
            registry = self.__dict__.get(registry_name)
            if registry is not None and name in registry:
                del registry[name]
                return
        object.__delattr__(self, name)

    def register_parameter(self, name, value):
        """Store ``value`` (a ``Parameter`` or ``None``) as parameter ``name``.

        No type checking or conflict removal happens here; ``__setattr__``
        performs both before delegating.
        """
        self._parameters[name] = value

    def register_buffer(self, name, tensor):
        """Register ``tensor`` as buffer ``name`` so it resolves via ``__getattr__``.

        Later assignments to ``name`` update the buffer instead of creating
        an instance attribute.

        Raises:
            AttributeError: if called before ``__init__``.
            KeyError: if ``name`` already names some other attribute.
            TypeError: if ``tensor`` is neither a ``torch.Tensor`` nor ``None``.
        """
        if '_buffers' not in self.__dict__:
            raise AttributeError(
                "cannot assign buffer before simpleModule.__init__() call")
        if hasattr(self, name) and name not in self._buffers:
            raise KeyError("attribute '{}' already exists".format(name))
        if tensor is not None and not isinstance(tensor, torch.Tensor):
            raise TypeError(
                "cannot assign '{}' as buffer '{}' "
                "(torch.Tensor or None expected)".format(type(tensor).__name__, name))
        self._buffers[name] = tensor