In [1]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
from torch import nn
import torch
import seaborn as sns
from torch import optim
import torch.nn.functional as F
import math

# Data

In [5]:
train_x = torch.randn(1000, 9)
train_y = torch.randint(3, size=(1000,))

# Models

In [43]:
multiplier = 0.1

class AbsLinear(nn.Linear):
    ''' A linear module that always applies abs() on the weight '''
    def forward(self, input):
        return F.linear(input, self.weight.abs(), self.bias)
    
class Elliptical(nn.Linear):
    def __init__(self, *args, **kwargs):
        super(Elliptical, self).__init__(*args, **kwargs)
        kwargs['bias'] = False
        self._quadratic = AbsLinear(*args, **kwargs)
    def forward(self, input):
        linear_term = super(Elliptical, self).forward(input)
        quadratic_term = self._quadratic.forward(input*input)
        return -multiplier * quadratic_term + linear_term

def plot_model(model):
    val_x = np.linspace(0, 1, 100)
    val_y = np.linspace(0, 1, 100)
    val_x, val_y = np.meshgrid(val_x, val_y)
    input = np.stack([val_x.flatten(), val_y.flatten()], axis=-1)
    val_z = model(torch.from_numpy(input).float()).detach().numpy().reshape(val_x.shape)
    fig, ax = plt.subplots(figsize=(2, 2))
    ax.contourf(val_x, val_y, val_z, 10, cmap=plt.cm.bone, origin='lower')

def build_model(use_elliptical=False):
    layers = []
    for _ in range(20):
        layers.extend([
            Elliptical(9, 9) if use_elliptical else nn.Linear(9, 9),
            nn.ReLU(),
            nn.BatchNorm1d(9)
        ])
    return nn.Sequential(*(layers + [nn.Linear(9, 3)]))
    
def demo_model(use_elliptical=False, n_epochs=100, lr=0.01, report_interval=10):
    model = build_model(use_elliptical)
    # train
    lr = 0.5
    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    plt.ion()
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        preds = model(train_x)
        loss = loss_func(preds, train_y)
        loss.backward()
        optimizer.step()
        if report_interval > 0 and epoch % report_interval == 0:
            _, pred_y = torch.max(preds, dim=1)
            print('Acc:', (pred_y == train_y).float().mean().item())
#             plot_model(model)
#     plot_model(model)
    plt.ioff()
    return model

In [14]:
def get_weight(m):
    if isinstance(m, Elliptical):
        return torch.cat([m.weight, m._quadratic.weight])
    else:
        return m.weight
    
get_weight(Elliptical(2,1))

tensor([[-0.0221, -0.0720],
        [ 0.0463,  0.6099]], grad_fn=<CatBackward>)

In [197]:
m = demo_model(use_elliptical=False)

Acc: 0.33899998664855957
Acc: 0.32600000500679016
Acc: 0.35600000619888306
Acc: 0.3700000047683716
Acc: 0.3869999945163727
Acc: 0.4050000011920929
Acc: 0.3840000033378601
Acc: 0.40400001406669617
Acc: 0.41999998688697815
Acc: 0.42100000381469727


In [198]:
m = demo_model(use_elliptical=True)

Acc: 0.3540000021457672
Acc: 0.3160000145435333
Acc: 0.3109999895095825
Acc: 0.34700000286102295
Acc: 0.34700000286102295
Acc: 0.34200000762939453
Acc: 0.33500000834465027
Acc: 0.33899998664855957
Acc: 0.34700000286102295
Acc: 0.3490000069141388


In [213]:
def report_stats(use_elliptical=False):
    m = build_model(use_elliptical)
    preds = m(train_x)
    print("Mean activation size last layer:", preds.abs().mean())
    print("--> std:", preds.abs().std())
    loss_func = nn.CrossEntropyLoss()
    loss = loss_func(preds, train_y)
    loss.backward()
    linear_terms_grad = torch.cat([layer.weight.grad.flatten() 
                                   for layer in m if hasattr(layer, 'weight')])
    print("Mean gradient size of linear terms' weights:", linear_terms_grad.abs().mean())
    print("--> std:", linear_terms_grad.abs().std())    
    quadratic_terms_grad = [layer._quadratic.weight.grad.flatten() 
                            for layer in m if hasattr(layer, '_quadratic')]
    if len(quadratic_terms_grad) > 0:
        quadratic_terms_grad = torch.cat(quadratic_terms_grad)
        print("Mean gradient size of quadratic terms' weights:", quadratic_terms_grad.abs().mean())
        print("--> std:", quadratic_terms_grad.abs().std())        

In [226]:
report_stats(use_elliptical=False)

Mean activation size last layer: tensor(0.3503, grad_fn=<MeanBackward0>)
--> std: tensor(0.2829, grad_fn=<StdBackward0>)
Mean gradient size of linear terms' weights: tensor(0.1711)
--> std: tensor(0.4079)


In [239]:
report_stats(use_elliptical=True)

Mean activation size last layer: tensor(0.5348, grad_fn=<MeanBackward0>)
--> std: tensor(0.4389, grad_fn=<StdBackward0>)
Mean gradient size of linear terms' weights: tensor(9.6714)
--> std: tensor(46.4385)
Mean gradient size of quadratic terms' weights: tensor(1.1180)
--> std: tensor(4.4834)
