In [1]:
import torch
import torch.nn as nn
import torch.functional as F

from torch.optim import SGD, Adam
from torch.autograd import Variable

import os
import numpy as np

## Model Definition

In [2]:
# fc network: 784 input features -> 512 -> 256 -> 128 -> 10 classes.
# A ReLU follows every hidden Linear layer; the head emits log-probabilities
# (LogSoftmax over dim 1), so it pairs with an NLL-style loss.
layers = [
    nn.Linear(784, 512), nn.ReLU(),
    nn.Linear(512, 256), nn.ReLU(),
    nn.Linear(256, 128), nn.ReLU(),
    nn.Linear(128, 10), nn.LogSoftmax(dim=1),
]
fc_net = nn.Sequential(*layers)

In [43]:
# conv network
class Net(nn.Module):
    """Small CNN classifier: two conv + max-pool stages, then two linear layers.

    The hard-coded 320-feature flatten assumes single-channel 28x28 inputs:
    28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4, and 20 channels * 4 * 4 = 320.
    `forward` returns per-class log-probabilities of shape (batch, 10).
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # The notebook-level `F` is `torch.functional`, which does not provide
        # relu / max_pool2d / dropout / log_softmax; those live in torch.nn.functional.
        import torch.nn.functional as F

        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten per sample. Keeping the batch dimension explicit makes a
        # wrongly-sized input fail in fc1 instead of silently changing the batch size.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


conv_net = Net()

## Initializations

In [4]:
# Layer types whose `.weight` the initialisers below overwrite.
_INIT_WEIGHT_TYPES = (nn.Conv2d, nn.Linear)
# BatchNorm layers of any dimensionality are reset to the identity affine map.
_INIT_BATCHNORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def _iter_layers(model):
    """Return an iterator over the layers of `model`.

    An `nn.Module` is walked recursively via `.modules()`, so containers and
    custom modules such as `Net` (which is not iterable) are covered. Any other
    iterable, e.g. a plain list of layers, is iterated as given.
    """
    if isinstance(model, nn.Module):
        return model.modules()
    return iter(model)


def _apply_init(model, init_weight):
    """Call `init_weight(layer)` on every Conv2d/Linear layer of `model` and
    reset every affine BatchNorm layer to weight=1, bias=0, in place.

    Conv2d/Linear biases are left untouched.
    """
    for m in _iter_layers(model):
        if isinstance(m, _INIT_WEIGHT_TYPES):
            init_weight(m)
        elif isinstance(m, _INIT_BATCHNORM_TYPES) and m.affine:
            # affine=False BatchNorm has weight/bias = None, so there is nothing to reset.
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)


def xavier_init(model):
    """Xavier/Glorot-normal initialise the Conv2d/Linear weights of `model` in place."""
    _apply_init(model, lambda m: nn.init.xavier_normal_(m.weight))


def kaiming_init(model):
    """Kaiming/He-normal initialise (fan_in mode, gain sqrt(2)) the Conv2d/Linear weights in place."""
    _apply_init(model, lambda m: nn.init.kaiming_normal_(m.weight))


def orthogonal_init(model):
    """(Semi-)orthogonally initialise the Conv2d/Linear weights of `model` in place."""
    _apply_init(model, lambda m: nn.init.orthogonal_(m.weight))


def _lecun_normal_(m):
    """Draw `m.weight` from N(0, 1/fan_in), where fan_in is the number of inputs per output unit."""
    # Dims after the first give one unit's incoming weights:
    # in_features for Linear, (in_channels / groups) * kh * kw for Conv2d.
    fan_in = m.weight.shape[1:].numel()
    nn.init.normal_(m.weight, 0, fan_in ** -0.5)


def selu_init(model):
    """LeCun-normal initialise the Conv2d/Linear weights of `model` for SELU networks, in place.

    Self-normalising networks expect zero-mean weights with variance
    1/fan_in (std = 1/sqrt(fan_in)).
    """
    _apply_init(model, _lecun_normal_)

## Regularizations

In [5]:
def _zero_penalty(model):
    """Return a zero tensor of shape (1,) on the device/dtype of `model`'s parameters.

    This is the starting value of the penalty sums below, so each penalty is
    returned with shape (1,). It falls back to a CPU float32 zero when `model`
    has no parameters.
    """
    first = next(iter(model.parameters()), None)
    return torch.zeros(1) if first is None else first.new_zeros(1)


def l2_reg(model):
    """Return the L2 penalty 0.5 * sum of squared entries over all parameters of `model`.

    Biases are included. The result is a shape-(1,) tensor that is
    differentiable w.r.t. the parameters.
    """
    l2_loss = _zero_penalty(model)
    for W in model.parameters():
        l2_loss = l2_loss + 0.5 * W.pow(2).sum()
    return l2_loss


def l1_reg(model):
    """Return the L1 penalty (sum of absolute entries) over all parameters of `model`.

    Biases are included. The result is a shape-(1,) tensor that is
    differentiable w.r.t. the parameters.
    """
    l1_loss = _zero_penalty(model)
    for W in model.parameters():
        l1_loss = l1_loss + W.abs().sum()
    return l1_loss


def orthogonal_reg(model):
    """Return the orthogonality penalty sum |W W^T - I| over the weight tensors of `model`.

    Each parameter with at least 2 dims is flattened to one row per output unit
    (shape (out, -1)). Its rows are pushed towards orthonormality by penalising
    the absolute deviation of the Gram matrix from the identity. Parameters
    with fewer than 2 dims (biases, BatchNorm affine params) have no row
    structure and are skipped. Returns a shape-(1,) tensor.
    """
    orth_loss = _zero_penalty(model)
    for W in model.parameters():
        if W.dim() < 2:
            continue
        W_flat = W.reshape(W.shape[0], -1)
        gram = torch.mm(W_flat, W_flat.t())
        eye = torch.eye(W_flat.shape[0], dtype=W.dtype, device=W.device)
        # abs() keeps the penalty non-negative; a signed sum lets deviations cancel out.
        orth_loss = orth_loss + (gram - eye).abs().sum()
    return orth_loss

def max_norm(model, max_val=3, eps=1e-8):
    """
    Apply a max-norm constraint to the weights of `model`, in place.

    For every non-bias parameter with at least 2 dims, each slice along dim 0
    is treated as one weight vector. That is row j of a Linear weight, or
    filter j of a Conv weight. A vector with L2 norm L greater than `max_val`
    is rescaled by (max_val / L). Vectors already within the bound are left
    unchanged, up to `eps`, which guards against division by zero.
    Returns None.
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'bias' in name or param.dim() < 2:
                continue
            # view (not reshape) so the in-place update below writes through to
            # the parameter; a non-contiguous parameter raises instead of silently copying.
            rows = param.view(param.shape[0], -1)
            norm = rows.norm(2, dim=1, keepdim=True)
            desired = torch.clamp(norm, 0, max_val)
            rows.mul_(desired / (eps + norm))

## Test: orthogonality penalty on a fresh `Net`

In [79]:
# Fresh conv net instance; the scratch orthogonality check in the next cell iterates over its parameters.
model = Net()

In [None]:
# Scratch check of the orthogonality penalty on `model`: the sum over the
# non-bias weights W (flattened to one row per output unit) of |W W^T - I|.
orth_loss = torch.zeros(1)
for name, param in model.named_parameters():
    if 'bias' in name:
        continue
    W_reshaped = param.view(param.shape[0], -1)
    sym = torch.mm(W_reshaped, torch.t(W_reshaped))
    # Out-of-place subtraction; the identity is built on the weight's dtype/device.
    sym = sym - torch.eye(W_reshaped.shape[0], dtype=param.dtype, device=param.device)
    # abs() keeps the penalty non-negative; a signed sum lets deviations cancel out.
    orth_loss = orth_loss + sym.abs().sum()
orth_loss