# NN Weights Wrapper

In [1]:
import torch
from torch.autograd import Variable
import json
import hashlib
import math
from numbers import Number

In [2]:
class WrapedWeights():
    """Dict-like wrapper around a network's named weight tensors.

    Provides container access, key-wise arithmetic against another wrapper /
    dict (matched by key) or against a Python scalar, comparisons based on the
    Frobenius norm, and helpers that build new tensors shaped like the wrapped
    ones.  Methods with a trailing underscore replace ``self.params`` in place.
    """

    def __init__(
        self,
        params
    ):
        """Wrap NN's weights and their methods.

        Parameters
        ----------
        params : WrapedWeights or dict or iterable of (name, tensor) pairs
            Weight information of a neural network (model), e.g. the
            generator returned by ``named_parameters()``.

        Raises
        ------
        ValueError
            If ``params`` cannot be converted into a dict.
        """
        if isinstance(params, WrapedWeights):
            params = params.to_dict()
        elif not isinstance(params, dict):
            try:
                params = dict(params)
            except (TypeError, ValueError) as err:
                raise ValueError("params must be `WrapedWeights (dict)` or `generator` which is retured by named_parameters() but {}.".format(type(params))) from err

        # A dict (or another wrapper's dict) is stored as-is, not copied.
        self.params = params  # dict

    def to_dict(self):
        """Return the underlying ``{name: tensor}`` dict (not a copy)."""
        return self.params

    # ------------------------------------------------------------------
    # container
    # ------------------------------------------------------------------

    def keys(self):
        return self.params.keys()

    def values(self):
        return self.params.values()

    def items(self):
        return self.params.items()

    def __getitem__(self, key):
        if not isinstance(key, str):
            raise TypeError("key must be a `str` but {}.".format(type(key)))
        try:
            return self.params[key]
        except KeyError:
            raise KeyError("key '{}' is not in params.".format(key)) from None

    def __setitem__(self, key, value):
        self.params[key] = value

    def __delitem__(self, key):
        del self.params[key]

    def __iter__(self):
        return iter(self.params)

    def __contains__(self, key):
        return key in self.params

    # ------------------------------------------------------------------
    # arithmetic
    #
    # A trailing underscore (_) marks the in-place version, which replaces
    # ``self.params`` with the result's dict.
    # ------------------------------------------------------------------

    def _binary_op(self, other, op_name):
        """Apply the tensor method ``op_name`` key-wise against ``other``.

        ``other`` is either a dict / WrapedWeights that contains every key of
        ``self`` (its tensors are used through ``.data``), or a Python
        ``Number``.  A scalar is handed directly to the tensor method so each
        result keeps its own tensor's dtype and device.

        Raises
        ------
        KeyError
            If a key of ``self`` is missing from ``other``.
        TypeError
            If ``other`` is neither a mapping of tensors nor a ``Number``.
        """
        if isinstance(other, (dict, WrapedWeights)):
            res = dict()
            for key, value in self.items():
                if key not in other.keys():
                    raise KeyError("'{}' is not in a argument.".format(key))
                res[key] = getattr(value, op_name)(other[key].data)
        elif isinstance(other, Number):
            res = {key: getattr(value, op_name)(other) for key, value in self.items()}
        else:
            raise TypeError("The argument must be `WrapedWeights (dict)` or `Number` but {}.".format(type(other)))
        return WrapedWeights(res)

    # -x  (computed on ``.data``, so the result does not track gradients)
    def neg(self):
        return WrapedWeights({key: -1 * value.data for key, value in self.items()})

    def neg_(self):
        self.params = self.neg().params

    def __neg__(self):
        return self.neg()

    # x + (y: dict or WrapedWeights)  or  x + (y: Number)
    def add(self, other):
        return self._binary_op(other, "add")

    def add_(self, other):
        self.params = self.add(other).params

    def __add__(self, other):
        return self.add(other)

    # x - (y: dict or WrapedWeights)  or  x - (y: Number)
    def sub(self, other):
        return self._binary_op(other, "sub")

    def sub_(self, other):
        self.params = self.sub(other).params

    def __sub__(self, other):
        return self.sub(other)

    # x * (y: dict or WrapedWeights): Hadamard product
    # x * (y: Number): scalar multiplication
    def mul(self, other):
        return self._binary_op(other, "mul")

    def mul_(self, other):
        self.params = self.mul(other).params

    def __mul__(self, other):
        return self.mul(other)

    # x / (y: dict or WrapedWeights): element-wise division
    # x / (y: Number): division by a scalar
    def div(self, other):
        return self._binary_op(other, "div")

    def div_(self, other):
        self.params = self.div(other).params

    def __truediv__(self, other):
        return self.div(other)

    # x // (y: dict or WrapedWeights): element-wise floor_divide
    # x // (y: Number): floor_divide with a scalar
    def floor_divide(self, other):
        return self._binary_op(other, "floor_divide")

    def floor_divide_(self, other):
        self.params = self.floor_divide(other).params

    def __floordiv__(self, other):
        return self.floor_divide(other)

    # x % (y: dict or WrapedWeights): element-wise remainder
    # x % (y: Number): remainder with a scalar
    def remainder(self, other):
        return self._binary_op(other, "remainder")

    def remainder_(self, other):
        self.params = self.remainder(other).params

    def __mod__(self, other):
        return self.remainder(other)

    # divmod(x, y) -> (x // y, x % y), matching Python's built-in contract.
    def __divmod__(self, other):
        return (self.floor_divide(other), self.remainder(other))

    # x ** (y: dict or WrapedWeights): element-wise power
    # x ** (y: Number): power by a scalar
    def pow(self, other):
        return self._binary_op(other, "pow")

    def pow_(self, other):
        self.params = self.pow(other).params

    def __pow__(self, other):
        return self.pow(other)

    # round()
    def round(self):
        return WrapedWeights({key: value.round() for key, value in self.items()})

    def round_(self):
        self.params = self.round().params

    def __round__(self):
        return self.round()

    # ------------------------------------------------------------------
    # comparison, by Frobenius (L2) norm
    #
    # Operands that are not weight mappings yield NotImplemented so Python
    # falls back to its default handling (e.g. ``w == None`` is False).
    # ------------------------------------------------------------------

    @staticmethod
    def _comparable(other):
        return isinstance(other, (dict, WrapedWeights))

    # x < y
    def __lt__(self, other):
        if not self._comparable(other):
            return NotImplemented
        return Frobenius(self) < Frobenius(other)

    # x <= y
    def __le__(self, other):
        if not self._comparable(other):
            return NotImplemented
        return Frobenius(self) <= Frobenius(other)

    # x > y
    def __gt__(self, other):
        if not self._comparable(other):
            return NotImplemented
        return Frobenius(self) > Frobenius(other)

    # x >= y
    def __ge__(self, other):
        if not self._comparable(other):
            return NotImplemented
        return Frobenius(self) >= Frobenius(other)

    # x == y
    def __eq__(self, other):
        if not self._comparable(other):
            return NotImplemented
        return Frobenius(self) == Frobenius(other)

    # x != y
    def __ne__(self, other):
        if not self._comparable(other):
            return NotImplemented
        return Frobenius(self) != Frobenius(other)

    # Defining __eq__ already disables hashing; stated explicitly here.
    # Use ``hash()`` below for a content digest.
    __hash__ = None

    # ------------------------------------------------------------------
    # conversion
    # ------------------------------------------------------------------

    def __str__(self):
        """JSON text of ``{name: nested list of values}``."""
        return json.dumps({key: value.tolist() for key, value in self.items()})

    def hash(self):
        """SHA-256 hex digest of ``str(self)``."""
        return hashlib.sha256(str(self).encode()).hexdigest()

    # ------------------------------------------------------------------
    # copy
    # ------------------------------------------------------------------

    def _copy(self, other):
        if not isinstance(other, WrapedWeights):
            raise TypeError("The argument must be `WrapedWeights` but {}.".format(type(other)))
        # Independent (detached) clones, so later in-place edits on either
        # object do not leak into the other.
        return {key: value.detach().clone() for key, value in other.items()}

    def copy_(self, other):
        self.params = self._copy(other)

    # ------------------------------------------------------------------
    # new tensors shaped like the wrapped ones
    # ------------------------------------------------------------------

    # zeros
    def _zeros(self):
        return {key: torch.zeros_like(value) for key, value in self.items()}

    def zeros(self):
        return WrapedWeights(self._zeros())

    def zeros_(self):
        self.params = self._zeros()

    # ones
    def _ones(self):
        return {key: torch.ones_like(value) for key, value in self.items()}

    def ones(self):
        return WrapedWeights(self._ones())

    def ones_(self):
        self.params = self._ones()

    # fill and full
    def _pack(self, value):
        return {key: torch.empty_like(elem).fill_(value) for key, elem in self.items()}

    def fill_(self, value):
        self.params = self._pack(value)

    def full(self, value):
        return WrapedWeights(self._pack(value))

    # empty (uninitialized contents)
    def _empty(self):
        return {key: torch.empty_like(value) for key, value in self.items()}

    def empty_(self):
        self.params = self._empty()

    def empty(self):
        return WrapedWeights(self._empty())

    # ------------------------------------------------------------------
    # random
    # ------------------------------------------------------------------

    def _rand(self):
        return {key: torch.rand_like(value) for key, value in self.items()}

    def rand_(self):
        self.params = self._rand()

    def rand(self):
        return WrapedWeights(self._rand())

    def _randn(self):
        return {key: torch.randn_like(value) for key, value in self.items()}

    def randn_(self):
        self.params = self._randn()

    def randn(self):
        return WrapedWeights(self._randn())

    def randint_(self, high):
        self.params = self._randint(high)

    def _randint(self, high):
        return {key: torch.randint_like(value, high) for key, value in self.items()}

    def randint(self, high):
        return WrapedWeights(self._randint(high))

    # TODO: type, cat, split

    def apply(
        self,
        net
    ):
        """Copy the wrapped tensors into ``net``'s matching state entries.

        Entries of ``net.state_dict()`` whose name is not wrapped here are
        left untouched; wrapped names absent from the state dict are ignored.
        """
        state_dict = net.state_dict()
        # The copy is a plain value assignment, not part of any graph.
        with torch.no_grad():
            for name, target in state_dict.items():
                if name in self.params:
                    target.copy_(self[name])

## Distance

In [3]:
def FilterNorm(weights):
    """Rescale every weight tensor to the norm of the whole weight set.

    See 'GradVis: Visualization and Second Order Analysis of Optimization
    Surfaces during the Training of Deep Neural Networks'.

    NOTE(review): normalization here is per parameter tensor (per key),
    not per individual filter as the name suggests -- confirm intended.

    Parameters
    ----------
    weights : WrapedWeights or dict or iterable of (name, tensor) pairs
        Weights to normalize.  Non-wrapper input is wrapped first so the
        norm helpers can operate on it.

    Returns
    -------
    WrapedWeights if ``weights`` is a WrapedWeights, otherwise a plain dict
        ``{key: value / (||value|| + 1e-10) * ||weights||}``.
    """
    wrapped = weights if isinstance(weights, WrapedWeights) else WrapedWeights(weights)

    theta = Frobenius(wrapped)

    res = dict()
    for key, value in wrapped.items():
        d = Frobenius(WrapedWeights({key: value}))
        d += 1e-10  # Ref. https://github.com/tomgoldstein/loss-landscape/blob/master/net_plotter.py#L111
        res[key] = value.div(d).mul(theta)

    if isinstance(weights, WrapedWeights):
        return WrapedWeights(res)
    return res

In [4]:
def Frobenius(weights, base_weights=None):
    """Return the Frobenius (L2) norm of ``weights - base_weights``.

    Parameters
    ----------
    weights : WrapedWeights or dict
        Weights to measure.  A non-wrapper input is wrapped first.
    base_weights : WrapedWeights or dict, optional
        Reference point containing every key of ``weights``.  ``None``
        (default) measures the norm of ``weights`` itself, i.e. the distance
        from all-zeros.

    Returns
    -------
    float
        Square root of the sum of squared element-wise differences over
        all keys.
    """
    if not isinstance(weights, WrapedWeights):
        weights = WrapedWeights(weights)
    if base_weights is None:
        base_weights = weights.zeros()
    square = (weights - base_weights) ** 2

    total = 0.
    for value in square.values():
        # Accumulate in float64 so sums over large float32 tensors stay precise.
        total += torch.sum(value, dtype=torch.float64).item()

    return math.sqrt(total)

# Main

In [5]:
if __name__ == "__main__":
    # Smoke-test the wrapper on a ResNet-18 loaded from the sibling notebook.
    import import_ipynb  # lets `from nets import ...` load nets.ipynb
    from nets import resnet18

    def fc_head(model):
        """First five weights of the first row of the model's `fc.weight`."""
        return dict(model.named_parameters())['fc.weight'][0][0:5]

    model = resnet18(num_classes=10)
    base = WrapedWeights(model.named_parameters())
    shifted = WrapedWeights(model.named_parameters()) + 2

    # Norm-based comparison.
    print(shifted <= shifted)

    # Distances: own norm, norm after FilterNorm, base vs shifted, base vs itself.
    print(Frobenius(base))
    print(Frobenius(FilterNorm(base)))
    print(Frobenius(base, shifted))
    print(Frobenius(base, base))

    # Randomize the wrapper, then push it into the model.
    base.randn_()
    print(base['fc.weight'][0][0:5])
    print(fc_head(model))  # model not yet updated
    base.apply(model)
    print(fc_head(model))  # now mirrors the randomized wrapper

importing Jupyter notebook from nets.ipynb
True
111.75378386949036
724.2472993004975
6687.7924608947005
0.0
tensor([ 0.0150, -0.9943,  0.0685, -1.3592, -0.0428])
tensor([ 0.0001, -0.0081,  0.0240,  0.0008,  0.0052], grad_fn=<SliceBackward>)
tensor([ 0.0150, -0.9943,  0.0685, -1.3592, -0.0428], grad_fn=<SliceBackward>)
