# Continual Learning with Deep Artificial Neurons
https://arxiv.org/pdf/2011.07035.pdf

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [150]:
class Neuron(nn.Module):
    """
    One deep artificial neuron: a small fully connected network mapping an
    ``ni``-wide input to an ``nout_ni``-wide output.

    Layer stack, in order:
    BatchNorm1d(ni) -> Linear(ni, nf) -> act -> BatchNorm1d(nf) -> Linear(nf, nf) -> act -> Linear(nf, nout_ni)
    The very same activation module instance sits at both ``act`` positions.

    ::param ni: input layer dimension
    ::param nf: intermediate layer dimension
    ::param nout_ni: output dimension (the per-neuron width handed to the next layer)
    ::param bias: forwarded to all three linear layers
    ::param act_fn: activation module; nn.Tanh() is used when None
    """
    def __init__(self, ni, nf, nout_ni, bias=True, act_fn=None):
        super(Neuron, self).__init__()
        self._ni = ni

        if act_fn is None:
            act_fn = nn.Tanh()
        assert isinstance(act_fn, nn.Module), 'activation must be of nn.Module instance'

        # The linear layers are instantiated in input -> output order, so they
        # draw their initial weights from the RNG in that order.
        self.core = nn.Sequential(
            nn.BatchNorm1d(ni),
            nn.Linear(ni, nf, bias=bias),
            act_fn,
            nn.BatchNorm1d(nf),
            nn.Linear(nf, nf, bias=bias),
            act_fn,
            nn.Linear(nf, nout_ni, bias=bias),
        )

    def forward(self, x):
        """
        ::param x: tensor whose dimension 1 must have size ``ni``
        """
        # Dimension 1 is checked against the width this neuron was built for.
        assert x.shape[1] == self._ni, (
            f'input to neuron ({x.shape[1]}) does not match input dimension size: ({self._ni})'
        )
        return self.core(x)

In [151]:
class DANLayer(nn.Module):
    """
    A layer made of several independent Neurons that all read the same input.
    Their outputs are concatenated along the feature dimension.

    ::param neurons: number of neurons in given layer
    ::param n_ni: dendritic dimension; the actual input width of the layer is n_ni * p_neurons
    ::param n_nf: intermediate neuron dimension
    ::param n_nout_ni: axonal dimension (output width of each neuron)
    ::param p_neurons: number of neurons in previous layer
    ::param bias: forwarded to every Neuron
    ::param act_fn: forwarded to every Neuron (the same object is handed to each of them)
    """
    def __init__(self, neurons, n_ni, n_nf, n_nout_ni, p_neurons=1, bias=True, act_fn=None):
        super(DANLayer, self).__init__()
        in_width = n_ni * p_neurons
        self._n = neurons
        self._ni = in_width

        # Every neuron is built with the same hyper-parameters.
        self.neurons = nn.ModuleList()
        for _ in range(neurons):
            self.neurons.append(Neuron(in_width, n_nf, n_nout_ni, bias=bias, act_fn=act_fn))

    def forward(self, x):
        """
        ::param x: dimensions -> (batch, features)
        for intermediate layers, features is the concatenated output of the previous layer
        returns a tensor of dimensions (batch, neurons * n_nout_ni)
        """
        # Unpacking into exactly two names also rejects inputs that are not 2-D.
        _, f = x.size()
        assert f==self._ni, f'input dimension ({f}) do not match neuron input dimension ({self._ni})'

        per_neuron = [neuron(x) for neuron in self.neurons]
        return torch.cat(per_neuron, dim=1)

In [152]:
class DAN(nn.Module):
    """
    Fixed three-layer DAN architecture, following https://arxiv.org/pdf/2011.07035.pdf

    Layer widths (per sample):
      dan1: 3 neurons, ni -> 100 each           => 300 features out
      dan2: 2 neurons, 300 -> 25 each           => 50 features out
      dan3: 1 neuron,  50 -> num_classes        => num_classes features out

    ::param ni: number of dimension for first DANLayer
    ::param num_classes: number of classes to predict
    ::param act: activation module forwarded to every DANLayer (None keeps the Neuron default)
    """
    def __init__(self, ni, num_classes, act=None):
        super(DAN, self).__init__()
        # NOTE: each layer's n_nout_ni must equal the next layer's n_ni, and each
        #       layer's neuron count must equal the next layer's p_neurons.
        # TODO: make this dynamic
        self.dan1 = DANLayer(3, ni, 50, 100, p_neurons=1, act_fn=act)
        self.dan2 = DANLayer(2, 100, 50, 25, p_neurons=3, act_fn=act)
        self.dan3 = DANLayer(1, 25, 10, num_classes, p_neurons=2, act_fn=act)

    def forward(self, x):
        """
        ::param x: shape (batch, features)
        """
        for layer in (self.dan1, self.dan2, self.dan3):
            x = layer(x)
        return x

In [153]:
def train(epochs, model, opt, loss_fn, device='cuda:0', data=None):
    """
    Minimal supervised training loop; prints the loss of the last mini-batch
    after every epoch.

    ::param epochs: number of passes over the data
    ::param model: nn.Module to train; it is moved to the training device with model.to(...)
    ::param opt: optimizer used for zero_grad() / step()
    ::param loss_fn: callable invoked as loss_fn(model(x), y); its result must support backward() and item()
    ::param device: device to train on. If a CUDA device is requested but CUDA is
                    not available, training falls back to the CPU with a warning
    ::param data: optional iterable of (x, y) batches, iterated once per epoch (so it
                  must be re-iterable when epochs > 1). When None, 100 batches of
                  standard-normal noise are generated, x of shape (64, 100) and
                  y of shape (64, 1)
    ::raises ValueError: if an epoch sees no batches at all
    """
    import warnings

    device = torch.device(device)
    if device.type == 'cuda' and not torch.cuda.is_available():
        # The default device is a GPU; do not make the loop unusable on CPU-only machines.
        warnings.warn(f'CUDA is not available, training on CPU instead of {device}')
        device = torch.device('cpu')
    model = model.to(device)

    if data is None:
        # Same noise, drawn in the same order, as the original hard-coded data set.
        data = [(torch.randn(64, 100), torch.randn(64, 1)) for _ in range(100)]

    for i in range(epochs):
        loss = None
        for x, y in data:
            x = x.to(device)
            y = y.to(device)
            out = model(x)
            loss = loss_fn(out, y)
            opt.zero_grad()
            loss.backward()
            opt.step()

        if loss is None:
            # Without this guard an empty data set would fail on the print below
            # with an unrelated error about `loss`.
            raise ValueError('data yielded no batches')

        # Reports the final mini-batch loss of the epoch, not an epoch average.
        print(f'Epoch: {i+1}/{epochs}, loss: {loss.item()}')

In [154]:
def seed_everything(seed):
    """
    Seed the random number generators so that repeated runs start from the same state.

    Seeds Python's built-in ``random`` module, PyTorch's default generator and
    the generators of all CUDA devices.

    ::param seed: integer seed applied to every generator
    """
    import random

    random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every GPU, not only the current one; per the PyTorch
    # docs it is silently ignored when CUDA is not available.
    torch.cuda.manual_seed_all(seed)

In [155]:
# Start from a fixed RNG state so the run can be repeated.
seed_everything(42)

# DAN with 100 input features and a single output, using ReLU activations.
model = DAN(ni=100, num_classes=1, act=nn.ReLU())
loss_fn = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

train(epochs=100, model=model, opt=optimizer, loss_fn=loss_fn)

Epoch: 1/100, loss: 0.840283215045929
Epoch: 2/100, loss: 0.758152961730957
Epoch: 3/100, loss: 0.7313737869262695
Epoch: 4/100, loss: 0.6388325691223145
Epoch: 5/100, loss: 0.5919957160949707
Epoch: 6/100, loss: 0.3751508891582489
Epoch: 7/100, loss: 0.34337562322616577
Epoch: 8/100, loss: 0.21609780192375183
Epoch: 9/100, loss: 0.20937995612621307
Epoch: 10/100, loss: 0.12115730345249176
Epoch: 11/100, loss: 0.11588992178440094
Epoch: 12/100, loss: 0.1057678684592247
Epoch: 13/100, loss: 0.11734393239021301
Epoch: 14/100, loss: 0.07674159109592438
Epoch: 15/100, loss: 0.0907539576292038
Epoch: 16/100, loss: 0.06652944535017014
Epoch: 17/100, loss: 0.05083984136581421
Epoch: 18/100, loss: 0.051033034920692444
Epoch: 19/100, loss: 0.05533038079738617
Epoch: 20/100, loss: 0.07391571253538132
Epoch: 21/100, loss: 0.06154697760939598
Epoch: 22/100, loss: 0.09060288220643997
Epoch: 23/100, loss: 0.0421791598200798
Epoch: 24/100, loss: 0.0698050707578659
Epoch: 25/100, loss: 0.0479958504438

In [164]:
seed_everything(42)

# Plain fully connected baseline. Three BatchNorm1d -> Linear -> ReLU stages
# (100 -> 350 -> 100 -> 100) followed by a final Linear(100, 1).
basic_model = nn.Sequential(
    *[
        layer
        for n_in, n_out in [(100, 350), (350, 100), (100, 100)]
        for layer in (nn.BatchNorm1d(n_in), nn.Linear(n_in, n_out), nn.ReLU())
    ],
    nn.Linear(100, 1),
)

loss_fn = nn.MSELoss()
optimizer = optim.Adam(params=basic_model.parameters(), lr=1e-3)

train(epochs=100, model=basic_model, opt=optimizer, loss_fn=loss_fn)

Epoch: 1/100, loss: 0.9728431701660156
Epoch: 2/100, loss: 0.828559398651123
Epoch: 3/100, loss: 0.6443156003952026
Epoch: 4/100, loss: 0.2395050972700119
Epoch: 5/100, loss: 0.16950803995132446
Epoch: 6/100, loss: 0.2235107123851776
Epoch: 7/100, loss: 0.1715463399887085
Epoch: 8/100, loss: 0.14578860998153687
Epoch: 9/100, loss: 0.11118602752685547
Epoch: 10/100, loss: 0.08071592450141907
Epoch: 11/100, loss: 0.10418432950973511
Epoch: 12/100, loss: 0.09245992451906204
Epoch: 13/100, loss: 0.06690025329589844
Epoch: 14/100, loss: 0.05668084695935249
Epoch: 15/100, loss: 0.07833646237850189
Epoch: 16/100, loss: 0.09935720264911652
Epoch: 17/100, loss: 0.054995857179164886
Epoch: 18/100, loss: 0.09818050265312195
Epoch: 19/100, loss: 0.05502917245030403
Epoch: 20/100, loss: 0.05882522091269493
Epoch: 21/100, loss: 0.05111636966466904
Epoch: 22/100, loss: 0.05144578218460083
Epoch: 23/100, loss: 0.10784360021352768
Epoch: 24/100, loss: 0.09862090647220612
Epoch: 25/100, loss: 0.04022400

In [166]:
# Number of trainable parameters in the DAN model.
sum(param.numel() for param in model.parameters() if param.requires_grad)

78901

In [167]:
# Number of trainable parameters in the fully connected baseline.
sum(param.numel() for param in basic_model.parameters() if param.requires_grad)

81751