This notebook explores the trainability of a deep ReLog network (using an old formulation: $relog = log_n(x+1/n)+1$).

In [1]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
from torch import nn
import torch
import seaborn as sns
from torch import optim
import torch.nn.functional as F
import math

In [2]:
n = 200
train_x = torch.randn(1000, n)
train_y = torch.randint(3, size=(1000,))

In [29]:
log_strength = 0.5

class ReLog(nn.Module):
    def __init__(self, n=10):
        super(ReLog, self).__init__()
        self.n = n

    def forward(self, input):
        effective_log_strength = max(0, log_strength)
        linear_term = F.relu(input)
        relog_func = lambda x: torch.log(F.relu(x) + 1/self.n) / math.log(self.n) + 1
        log_term = relog_func(input + effective_log_strength)
        return log_term * effective_log_strength + linear_term * (1-effective_log_strength)

In [11]:
def build_model(use_relog=False):
    layers = []
    for _ in range(20):
        layers.extend([
            nn.Linear(n, n),
            ReLog() if use_relog else nn.ReLU(),
            nn.BatchNorm1d(n)
        ])
    return nn.Sequential(*(layers + [nn.Linear(n, 3)]))
    
def demo_model(use_relog=False, n_epochs=5, lr=0.01, report_interval=1):
    model = build_model(use_relog)
    # train
    lr = 0.5
    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        preds = model(train_x)
        loss = loss_func(preds, train_y)
        loss.backward()
        optimizer.step()
        if report_interval > 0 and epoch % report_interval == 0:
            _, pred_y = torch.max(preds, dim=1)
            print('Acc:', (pred_y == train_y).float().mean().item())
    return model

In [19]:
relu_model = demo_model(use_relog=False)

Acc: 0.30300000309944153
Acc: 0.33899998664855957
Acc: 0.3109999895095825
Acc: 0.3140000104904175
Acc: 0.3070000112056732


In [20]:
relog_model = demo_model(use_relog=True)

Acc: 0.32199999690055847
Acc: 0.35499998927116394
Acc: 0.2939999997615814
Acc: 0.3610000014305115
Acc: 0.30000001192092896


In [30]:
def report_stats(use_relog=False):
    m = build_model(use_relog)
    preds = m(train_x)
    print("Mean activation size last layer:", preds.abs().mean())
    print("--> std:", preds.abs().std())
    loss_func = nn.CrossEntropyLoss()
    loss = loss_func(preds, train_y)
    loss.backward()
    weight_grad = torch.cat([layer.weight.grad.flatten() 
                             for layer in m if hasattr(layer, 'weight')])
    print("Mean gradient size of weights:", weight_grad.abs().mean())
    print("--> std:", weight_grad.abs().std())    

In [31]:
report_stats(relu_model)

Mean activation size last layer: tensor(0.4362, grad_fn=<MeanBackward0>)
--> std: tensor(0.3328, grad_fn=<StdBackward0>)
Mean gradient size of weights: tensor(0.1243)
--> std: tensor(0.2441)


In [32]:
report_stats(relog_model)

Mean activation size last layer: tensor(0.4737, grad_fn=<MeanBackward0>)
--> std: tensor(0.3555, grad_fn=<StdBackward0>)
Mean gradient size of weights: tensor(0.1290)
--> std: tensor(0.2552)
