In [1]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
from torch import nn
import torch
import seaborn as sns
from torch import optim
import torch.nn.functional as F
import math

# Data

In [299]:
n = 200
train_x = torch.randn(1000, n)
train_y = torch.randint(3, size=(1000,))

# Models

In [314]:
multiplier = 0.1

class AbsLinear(nn.Linear):
    ''' A linear module that always applies abs() on the weight '''
    def forward(self, input):
        return F.linear(input, self.weight.abs(), self.bias)
    
class Elliptical(nn.Linear):
    def __init__(self, *args, **kwargs):
        super(Elliptical, self).__init__(*args, **kwargs)
        kwargs['bias'] = False
        self._quadratic = AbsLinear(*args, **kwargs)
    def forward(self, input):
        linear_term = super(Elliptical, self).forward(input)
        quadratic_term = self._quadratic.forward(input*input)
#         print(quadratic_term.mean())
        return -multiplier * quadratic_term + linear_term + multiplier * math.sqrt(self.weight.shape[1])

def plot_model(model):
    val_x = np.linspace(0, 1, 100)
    val_y = np.linspace(0, 1, 100)
    val_x, val_y = np.meshgrid(val_x, val_y)
    input = np.stack([val_x.flatten(), val_y.flatten()], axis=-1)
    val_z = model(torch.from_numpy(input).float()).detach().numpy().reshape(val_x.shape)
    fig, ax = plt.subplots(figsize=(2, 2))
    ax.contourf(val_x, val_y, val_z, 10, cmap=plt.cm.bone, origin='lower')

def build_model(use_elliptical=False):
    layers = []
    for _ in range(20):
        layers.extend([
            Elliptical(n, n) if use_elliptical else nn.Linear(n, n),
            nn.ReLU(),
            nn.BatchNorm1d(n)
        ])
    return nn.Sequential(*(layers + [nn.Linear(n, 3)]))
    
def demo_model(use_elliptical=False, n_epochs=100, lr=0.01, report_interval=10):
    model = build_model(use_elliptical)
    # train
    lr = 0.5
    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    plt.ion()
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        preds = model(train_x)
        loss = loss_func(preds, train_y)
        loss.backward()
        optimizer.step()
        if report_interval > 0 and epoch % report_interval == 0:
            _, pred_y = torch.max(preds, dim=1)
            print('Acc:', (pred_y == train_y).float().mean().item())
#             plot_model(model)
#     plot_model(model)
    plt.ioff()
    return model

In [308]:
def get_weight(m):
    if isinstance(m, Elliptical):
        return torch.cat([m.weight, m._quadratic.weight])
    else:
        return m.weight
    
get_weight(Elliptical(2,1))

tensor([[ 0.1387, -0.2251],
        [ 0.6259, -0.3840]], grad_fn=<CatBackward>)

In [302]:
m = demo_model(use_elliptical=False)

Acc: 0.33000001311302185
Acc: 0.3709999918937683


KeyboardInterrupt: 

In [312]:
m = demo_model(use_elliptical=True)

tensor(7.0222, grad_fn=<MeanBackward0>)
tensor(7.0930, grad_fn=<MeanBackward0>)
tensor(7.0522, grad_fn=<MeanBackward0>)
tensor(7.0624, grad_fn=<MeanBackward0>)
tensor(7.0991, grad_fn=<MeanBackward0>)
tensor(7.0961, grad_fn=<MeanBackward0>)
tensor(7.0896, grad_fn=<MeanBackward0>)
tensor(7.0788, grad_fn=<MeanBackward0>)
tensor(7.0558, grad_fn=<MeanBackward0>)
tensor(7.0556, grad_fn=<MeanBackward0>)
tensor(7.0434, grad_fn=<MeanBackward0>)
tensor(7.0913, grad_fn=<MeanBackward0>)
tensor(7.0916, grad_fn=<MeanBackward0>)
tensor(7.0748, grad_fn=<MeanBackward0>)
tensor(7.0682, grad_fn=<MeanBackward0>)
tensor(7.0725, grad_fn=<MeanBackward0>)
tensor(7.0486, grad_fn=<MeanBackward0>)
tensor(7.0985, grad_fn=<MeanBackward0>)
tensor(7.0773, grad_fn=<MeanBackward0>)
tensor(7.0781, grad_fn=<MeanBackward0>)
Acc: 0.33799999952316284
tensor(7.0283, grad_fn=<MeanBackward0>)
tensor(7.0923, grad_fn=<MeanBackward0>)
tensor(7.0448, grad_fn=<MeanBackward0>)
tensor(7.0664, grad_fn=<MeanBackward0>)
tensor(7.0796, 

KeyboardInterrupt: 

In [273]:
def report_stats(use_elliptical=False):
    m = build_model(use_elliptical)
    preds = m(train_x)
    print("Mean activation size last layer:", preds.abs().mean())
    print("--> std:", preds.abs().std())
    loss_func = nn.CrossEntropyLoss()
    loss = loss_func(preds, train_y)
    loss.backward()
    linear_terms_grad = torch.cat([layer.weight.grad.flatten() 
                                   for layer in m if hasattr(layer, 'weight')])
    print("Mean gradient size of linear terms' weights:", linear_terms_grad.abs().mean())
    print("--> std:", linear_terms_grad.abs().std())    
    quadratic_terms_grad = [layer._quadratic.weight.grad.flatten() 
                            for layer in m if hasattr(layer, '_quadratic')]
    if len(quadratic_terms_grad) > 0:
        quadratic_terms_grad = torch.cat(quadratic_terms_grad)
        print("Mean gradient size of quadratic terms' weights:", quadratic_terms_grad.abs().mean())
        print("--> std:", quadratic_terms_grad.abs().std())        

In [292]:
report_stats(use_elliptical=False)

Mean activation size last layer: tensor(0.4366, grad_fn=<MeanBackward0>)
--> std: tensor(0.4079, grad_fn=<StdBackward0>)
Mean gradient size of linear terms' weights: tensor(0.0919)
--> std: tensor(0.1508)


In [313]:
report_stats(use_elliptical=True)

tensor(7.0027, grad_fn=<MeanBackward0>)
tensor(7.0817, grad_fn=<MeanBackward0>)
tensor(7.0575, grad_fn=<MeanBackward0>)
tensor(7.0511, grad_fn=<MeanBackward0>)
tensor(7.0601, grad_fn=<MeanBackward0>)
tensor(7.0851, grad_fn=<MeanBackward0>)
tensor(7.0749, grad_fn=<MeanBackward0>)
tensor(7.0586, grad_fn=<MeanBackward0>)
tensor(7.0686, grad_fn=<MeanBackward0>)
tensor(7.0444, grad_fn=<MeanBackward0>)
tensor(7.0583, grad_fn=<MeanBackward0>)
tensor(7.0703, grad_fn=<MeanBackward0>)
tensor(7.0539, grad_fn=<MeanBackward0>)
tensor(7.0856, grad_fn=<MeanBackward0>)
tensor(7.0520, grad_fn=<MeanBackward0>)
tensor(7.0715, grad_fn=<MeanBackward0>)
tensor(7.0537, grad_fn=<MeanBackward0>)
tensor(7.0935, grad_fn=<MeanBackward0>)
tensor(7.0740, grad_fn=<MeanBackward0>)
tensor(7.0701, grad_fn=<MeanBackward0>)
Mean activation size last layer: tensor(0.4343, grad_fn=<MeanBackward0>)
--> std: tensor(0.3282, grad_fn=<StdBackward0>)
Mean gradient size of linear terms' weights: tensor(0.0071)
--> std: tensor(0.0

In [239]:
report_stats(use_elliptical=True)

Mean activation size last layer: tensor(0.5348, grad_fn=<MeanBackward0>)
--> std: tensor(0.4389, grad_fn=<StdBackward0>)
Mean gradient size of linear terms' weights: tensor(9.6714)
--> std: tensor(46.4385)
Mean gradient size of quadratic terms' weights: tensor(1.1180)
--> std: tensor(4.4834)
