# Hoyer Training Tutorial

This tutorial first examines the `cuba_hoyer` neuron and its associated Hoyer spike function to explain how they work.

It then shows how to enable Hoyer training, which involves Hoyer neurons, the corresponding layers, and Hoyer regularization.

The VGG16 architecture and the CIFAR10 dataset serve as the running example throughout.

## 1. The structure of Hoyer Neuron

### 1.1 init function

```python
import torch
import torch.nn as nn
from lava.lib.dl.slayer.neuron import neuron_params, Neuron
class HoyerNeuron(Neuron):
    def __init__(
        self, threshold, current_decay, voltage_decay,
        tau_grad=1, scale_grad=1, scale=1 << 6, norm=None, dropout=None,
        shared_param=True, persistent_state=False, requires_grad=False,
        graded_spike=False, num_features=1, T=1, hoyer_type='sum', momentum=0.9, delay=False
    ):
        super(HoyerNeuron, self).__init__(...)
        
        self.learnable_thr = nn.Parameter(torch.FloatTensor([self.threshold]), requires_grad=True)
        self.T = T
        self.hoyer_type = hoyer_type
        self.num_features = num_features
        self.momentum = momentum
        if self.num_features > 1:
            self.bn = nn.BatchNorm2d(num_features=self.num_features)
        self.delay = delay

        if self.num_features > 1: 
            if self.hoyer_type == 'sum':
                self.register_buffer('running_hoyer_ext', torch.zeros([1, 1, 1, 1, T]))
            else:
                self.register_buffer('running_hoyer_ext', torch.zeros([1, self.num_features, 1, 1, T]))
        else:
            self.register_buffer('running_hoyer_ext', torch.zeros([1, 1, T]))

        self.clamp()
```

The Hoyer neuron extends the CUBA neuron. In addition to the CUBA initialization arguments, it adds a trainable threshold (`learnable_thr`). When it follows a Convolutional (Conv) layer, i.e. when `num_features > 1`, it also contains a Batch Normalization (BN) layer. The neuron further needs the number of time steps `T` and the number of input features `num_features`, because together with `hoyer_type` they determine the shape of the `running_hoyer_ext` buffer.
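
The shape rule in the constructor can be summarized with a small helper. This is only an illustrative sketch: `running_hoyer_ext_shape` is a hypothetical function, not part of lava-dl, and it simply mirrors the branches of `__init__` shown above.

```python
def running_hoyer_ext_shape(num_features, hoyer_type, T):
    """Mirror the buffer-shape branches of HoyerNeuron.__init__ (illustrative)."""
    if num_features > 1:
        if hoyer_type == 'sum':
            # Conv layer, one shared value per time step
            return (1, 1, 1, 1, T)
        # Conv layer, one value per channel and time step
        return (1, num_features, 1, 1, T)
    # Dense/Affine layer, one shared value per time step
    return (1, 1, T)


print(running_hoyer_ext_shape(64, 'sum', 4))  # (1, 1, 1, 1, 4)
print(running_hoyer_ext_shape(64, 'cw', 4))   # (1, 64, 1, 1, 4)
print(running_hoyer_ext_shape(1, 'sum', 4))   # (1, 1, 4)
```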

### 1.2 Spike function

```python
def spike(self, voltage, hoyer_ext=1.0):
        spike = HoyerSpike.apply(
            voltage,
            hoyer_ext,
            self.tau_rho * TAU_RHO_MULT,
            self.scale_rho * SCALE_RHO_MULT,
            self.graded_spike,
            self.voltage_state,
            # self.s_scale,
            1,
        )

        if self.persistent_state is True:
            with torch.no_grad():
                self.voltage_state = leaky_integrator.persistent_state(
                    self.voltage_state, spike[..., -1]
                ).detach().clone()

        if self.drop is not None:
            spike = self.drop(spike)

        return spike
```

The spike function is almost identical to the original CUBA one. The difference is that it calls `HoyerSpike` and passes `hoyer_ext` to it as the spiking threshold.

### 1.3 Calculate Hoyer loss $$ \frac{(\sum |x|)^2}{\sum x^2 }$$

```python
def cal_hoyer_loss(self, x, thr=None):
        if thr:
            x[x>thr] = thr
        x[x<0.0] = 0.0
        # avoid division by zero
        return (torch.sum(torch.abs(x))**2) / (torch.sum(x**2) + 1e-9)
```
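
To get a feel for this measure, the sketch below evaluates it on plain Python lists instead of tensors. `hoyer_measure` is a hypothetical stand-in for `cal_hoyer_loss`, written only for illustration. For a vector with N entries the value lies between 1 (a single non-zero entry) and N (all entries equal), so minimizing it pushes activations toward sparsity.

```python
def hoyer_measure(values, thr=None):
    """(sum |x|)^2 / sum x^2 on a flat list, after clipping to [0, thr]."""
    clipped = []
    for v in values:
        v = max(v, 0.0)          # negative entries are set to zero
        if thr is not None:
            v = min(v, thr)      # entries above thr are capped
        clipped.append(v)
    l1 = sum(abs(v) for v in clipped)
    l2_squared = sum(v * v for v in clipped)
    return l1 ** 2 / (l2_squared + 1e-9)  # small constant avoids division by zero


dense = [1.0, 1.0, 1.0, 1.0]
sparse = [1.0, 0.0, 0.0, 0.0]
print(hoyer_measure(dense))   # close to 4 (all entries equal)
print(hoyer_measure(sparse))  # close to 1 (one active entry)
```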

### 1.4 Forward function

```python
def forward(self, input):
        if self.num_features > 1 and hasattr(self, 'bn'):
            B,C,H,W,T = input.shape
            input = self.bn(input.permute(4,0,1,2,3).reshape(T*B,C, H, W).contiguous())\
                .reshape(T,B,C,H,W).permute(1,2,3,4,0).contiguous()
        _, voltage = self.dynamics(input)
        self.hoyer_loss = self.cal_hoyer_loss(torch.clamp(voltage.clone(), min=0.0, max=1.0), 1.0)
        voltage = voltage / self.learnable_thr
        if self.training:
            clamped_input = torch.clamp(voltage.clone().detach(), min=0.0, max=1.0)
            dim = tuple(range(clamped_input.ndim-1))
            if self.hoyer_type == 'sum':
                hoyer_ext = (torch.sum(clamped_input**2, dim=dim)
                             / torch.sum(torch.abs(clamped_input), dim=dim))
            else:
                hoyer_ext = (torch.sum(clamped_input**2, dim=(0, 2, 3), keepdim=True)
                             / torch.sum(torch.abs(clamped_input), dim=(0, 2, 3), keepdim=True))

            hoyer_ext = torch.nan_to_num(hoyer_ext, nan=1.0)
            with torch.no_grad():
                
                if self.delay:
                    # delay hoyer ext
                    self.running_hoyer_ext[..., 0] = 0
                    self.running_hoyer_ext = torch.roll(self.running_hoyer_ext, shifts=-1, dims=-1)
                    self.running_hoyer_ext = self.momentum * hoyer_ext + (1 - self.momentum) * self.running_hoyer_ext
                else:
                    # do not delay hoyer ext
                    self.running_hoyer_ext = self.momentum * hoyer_ext + (1 - self.momentum) * self.running_hoyer_ext
                
        else:
            hoyer_ext = self.running_hoyer_ext
        output = self.spike(voltage, hoyer_ext)
        return output
```

In the forward pass, if the preceding synapse is a Conv layer, the input first passes through the Batch Normalization (BN) layer. The dynamics function then computes the voltage, which is also used to compute the Hoyer loss. The voltage is divided by the trainable threshold `learnable_thr` rather than by a fixed value of 1. During training, the Hoyer Ext value defined below is computed from the clamped, normalized voltage $x$, and `running_hoyer_ext` is updated as a moving average, similar to the running statistics in BN. During evaluation, the stored `running_hoyer_ext` is used instead. Finally, the spike function generates spikes from the normalized voltage.

$$ Ext_{hoyer} = \frac{\sum x^2 }{\sum |x|} $$
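
As a worked example, the sketch below computes the Hoyer Ext value for four normalized voltages and applies one moving-average update. The helper names are hypothetical. The `>=` comparison used to produce spikes is an assumption, since the forward pass of `HoyerSpike` is not shown in this tutorial.

```python
def hoyer_ext(values):
    """sum(x^2) / sum(|x|) over values clamped to [0, 1]; 1.0 if all are zero."""
    clamped = [min(max(v, 0.0), 1.0) for v in values]
    denom = sum(abs(v) for v in clamped)
    if denom == 0.0:
        return 1.0  # mirrors torch.nan_to_num(..., nan=1.0) in the forward pass
    return sum(v * v for v in clamped) / denom


def update_running(running, new, momentum=0.9):
    """Moving average with the same weighting as the forward pass."""
    return momentum * new + (1 - momentum) * running


voltages = [0.2, 0.4, 1.5, -0.3]  # already divided by the threshold
ext = hoyer_ext(voltages)         # clamped values are [0.2, 0.4, 1.0, 0.0]
spikes = [1 if v >= ext else 0 for v in voltages]  # assumed comparison
running = update_running(0.0, ext)
print(ext, spikes, running)       # roughly 0.75, [0, 0, 1, 0], 0.675
```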

## 2. The structure of Hoyer Spike Function

```python
def _hoyer_spike_backward(
    voltage, threshold, tau_rho, scale_rho,
    graded_spike=False
):
    grad_inp = torch.zeros_like(voltage)  # already on the same device as voltage

    grad_inp[voltage > 0] = 1.0
    grad_inp[voltage > 2.0] = 0.0

    return 0.5 * grad_inp


class HoyerSpike(Spike):
    derivative = None
    @staticmethod
    def backward(ctx, grad_spikes):
        voltage, threshold, tau_rho, scale_rho, graded_spike \
            = ctx.saved_tensors
        graded_spike = True if graded_spike > 0.5 else False
        return (
            _hoyer_spike_backward(
                voltage, threshold, tau_rho, scale_rho, graded_spike
            ) * grad_spikes,
            None, None, None, None, None, None
        )
```

In the forward pass, the normalized voltage is compared with the Hoyer Ext value to decide whether a spike is emitted. In the backward pass, the surrogate gradient shown in the figure is used: it is constant for voltages in the interval (0, 2] and zero elsewhere. (The code scales it by 0.5.)

<img src="figs/hoyer_spike_bp.png" width="300" height="200">
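
The same surrogate can be written for a single scalar voltage. `hoyer_surrogate_grad` is a hypothetical helper that follows the masking in `_hoyer_spike_backward` above; it is meant as a reading aid, not as the library implementation.

```python
def hoyer_surrogate_grad(voltage, scale=0.5):
    """Boxcar surrogate gradient: `scale` for 0 < voltage <= 2, otherwise 0."""
    return scale if 0.0 < voltage <= 2.0 else 0.0


for v in (-0.5, 0.0, 0.75, 2.0, 2.5):
    print(v, hoyer_surrogate_grad(v))
```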


## 3. Training example

### 3.1 Import libraries

```python
import os
import sys
import h5py
import copy
import torch
import datetime
import numpy as np
from torch import nn
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import torch.nn.functional as F

# import slayer from lava-dl
import lava.lib.dl.slayer as slayer
```


### 3.2 Create Dataset

Since CIFAR10 is a common static image dataset, we load and preprocess it with the standard torchvision pipeline, exactly as in a regular PyTorch project.

```python
labels      = 10
batch_size  = 128
normalize   = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # this augmentation can improve accuracy by about 2%
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize])
transform_test = transforms.Compose([transforms.ToTensor(), normalize])
train_dataset   = datasets.CIFAR10(root='./cifar_data', train=True, download=True, transform=transform_train)
test_dataset    = datasets.CIFAR10(root='./cifar_data', train=False, download=True, transform=transform_test)
train_loader    = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=8, shuffle=True)
test_loader     = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=8, shuffle=False)
```

### 3.3 Define Network Structure

All Hoyer versions of the CUBA layers live in `lava.lib.dl.slayer.block.cuba_hoyer`. Currently implemented are `cuba_hoyer.Conv`, `cuba_hoyer.Dense`, `cuba_hoyer.Affine`, and `cuba_hoyer.Pool`.

In addition to the original parameters of the CUBA layers, `cuba_hoyer.Conv` has three extra parameters:
- `T`: the total number of time steps of the input data;
- `hoyer_type`: `'sum'` means all channels share one threshold, `'cw'` (channel-wise) means every channel has its own threshold;
- `num_features`: the number of features/channels (used for the channel-wise threshold).

For `cuba_hoyer.Dense` and `cuba_hoyer.Affine`, `hoyer_type` must be `'sum'` and `num_features` must be set to 1. `cuba_hoyer.Pool` does not need the three extra parameters.

```python
import lava.lib.dl.slayer.block.cuba_hoyer as cuba_hoyer

cfg = {
    'VGG5': [64, 'M', 128, 'M', 256, 'M', 512, 'M', 512],
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512],
}

class VGG_Lava(nn.Module):
    def __init__(self, vgg_name='VGG16', labels=10, dataset='CIFAR10', time_steps=1):
        super(VGG_Lava, self).__init__()
        self.dataset = dataset
        self.T = time_steps
        sdnn_params = {
                'threshold'     : 1,    # neuron threshold
                'current_decay' : 1,    # u[t] = (1 - alpha_u) * u[t-1] + x[t]
                'voltage_decay' : 1,    # v[t] = (1 - alpha_v) * v[t-1] + u[t] + bias
                'tau_grad'      : 1.0,  # surrogate gradient relaxation parameter
                'scale_grad'    : 1.0,  # surrogate gradient scale parameter
                'requires_grad' : True, # make neuron parameters trainable
            }
        sdnn_cnn_params = {
                **sdnn_params,                            # copy all sdnn_params
                'dropout' : slayer.neuron.Dropout(p=0.1), # neuron dropout
                'T': self.T,                              # time steps
                'hoyer_type': 'cw',                       # channel-wise Hoyer Ext
            }
        sdnn_dense_params = {
                **sdnn_params,                            # copy all sdnn_params
                'dropout' : slayer.neuron.Dropout(p=0.1), # neuron dropout
                'T': self.T,                              # time steps
                'num_features': 1, # set 1 for linear layer to avoid bn and set different shape of hoyer ext
            }
        self.weight_norm = False
        self.delay = True
        self.delay_shift = True
        self.cfg = cfg[vgg_name]
        self.sdnn_cnn_params = sdnn_cnn_params
        self.features = self._make_layers(cfg[vgg_name], sdnn_cnn_params, sdnn_params)

        self.classifier = nn.Sequential(
            slayer.block.cuba.Flatten(),
            cuba_hoyer.Dense(sdnn_dense_params, 2048, 4096, weight_norm=self.weight_norm, delay=self.delay, delay_shift=self.delay_shift),
            cuba_hoyer.Dense(sdnn_dense_params, 4096, 4096, weight_norm=self.weight_norm, delay=self.delay, delay_shift=self.delay_shift),
            cuba_hoyer.Affine(sdnn_params, 4096, labels, weight_norm=self.weight_norm),
        )
        self.blocks = torch.nn.Sequential(*(list(self.features)+list(self.classifier)))
        del self.features
        del self.classifier

    def forward(self, x):
        count = []
        event_cost = torch.zeros(1).to(x.device)
        if self.dataset == 'CIFAR10':
            # static image: repeat it along a new time axis
            x = slayer.utils.time.replicate(x, self.T)
        elif self.dataset == 'CIFAR10DVS':
            # move the time axis to the last dimension
            x = x.permute(1, 2, 3, 4, 0).contiguous()

        for i, block in enumerate(self.blocks):
            x = block(x)
            if hasattr(block, 'neuron'):
                # number of events emitted by this block
                count.append(torch.sum((torch.abs(x[..., 1:]) > 0).to(x.dtype)).item())
        out = x[:, :, -1]  # output at the last time step
        return out, event_cost, torch.FloatTensor(count).reshape((1, -1)).to(x.device)

    def _make_layers(self, cfg, sdnn_cnn_params, sdnn_params):
        layers = []
        in_channels = 3 if self.dataset == 'CIFAR10' else 2
        if self.dataset == 'IMAGENET':
            cfg.append('M')
        for i, x in enumerate(cfg):
            if x == 'M':
                continue
            sdnn_cnn_params['num_features'] = x
            conv = cuba_hoyer.Conv(sdnn_cnn_params, in_channels, x, 3, padding=1, stride=1, weight_scale=1, weight_norm=self.weight_norm, delay=self.delay, delay_shift=self.delay_shift)
            if i+1 < len(cfg) and cfg[i+1] == 'M':
                layers += [conv,
                        cuba_hoyer.Pool(sdnn_params, 2, stride=2, weight_scale=1, weight_norm=False, delay=self.delay, delay_shift=self.delay_shift)]
            else:
                layers += [conv]
            in_channels = x
        return nn.Sequential(*layers)

    def grad_flow(self, path):
        # helps monitor the gradient flow
        grad = [b.synapse.grad_norm for b in self.blocks if hasattr(b, 'synapse')]

        plt.figure()
        plt.semilogy(grad)
        plt.savefig(path + 'gradFlow.png')
        plt.close()

        return grad

    def merge_bn(self):
        # merge batch normalization layer to convolutional layer
        new_layers = list(self.blocks)
        for i, b in enumerate(new_layers):
            if isinstance(b, slayer.block.cuba.Flatten):
                # only the convolutional part (before Flatten) contains bn layers
                self.blocks = torch.nn.Sequential(*new_layers)
                return
            if hasattr(b, 'neuron'):
                if hasattr(b.neuron, 'bn'):
                    gamma = b.neuron.bn.weight
                    beta = b.neuron.bn.bias
                    mean = b.neuron.bn.running_mean
                    var = b.neuron.bn.running_var
                    eps = b.neuron.bn.eps
                    print('merge bn layer: ', i)
                    W = b.synapse.weight
                    bias = b.synapse.bias if b.synapse.bias is not None else torch.zeros_like(mean)
                    W_prime = W * (gamma / torch.sqrt(var + eps)).reshape(-1, 1, 1, 1, 1)
                    bias_prime = (bias - mean) * (gamma / torch.sqrt(var + eps)) + beta
                    b.synapse.weight = nn.Parameter(W_prime)
                    b.synapse.bias = nn.Parameter(bias_prime)
                    del b.neuron.bn

    def merge_pool_conv(self, b, s=2):
        # merge the preceding pooling layer (kernel s, stride s) into this convolutional layer
        ori_weight = b.synapse.weight
        ori_bias = b.synapse.bias
        assert ori_bias is None
        padding = b.synapse.padding[0] * s
        stride = b.synapse.stride[0] * s
        # c_in is kept separate from the loop indices below so it is not overwritten
        o, c_in, k1, k2, t = ori_weight.shape
        merged_weight = torch.zeros(o, c_in, k1*s, k2*s, t)
        # each original weight is repeated over an s x s patch of the enlarged kernel
        for i in range(k1*s):
            for j in range(k2*s):
                merged_weight[:, :, i, j, :] = ori_weight[:, :, i//s, j//s, :]
        self.sdnn_cnn_params['num_features'] = o
        neuron_shape = b.neuron.shape
        b = cuba_hoyer.Conv(self.sdnn_cnn_params, c_in, o, k1*s, padding=padding, stride=stride,
                weight_scale=1, weight_norm=self.weight_norm, delay=self.delay, delay_shift=self.delay_shift)
        b.synapse.weight = nn.Parameter(merged_weight)
        b.neuron.shape = neuron_shape
        return b

    def export_hdf5(self, filename):
        # network export to hdf5 format, first merge pooling layer
        # then merge batch normalization layer
        merge_flag = False
        for i, b in enumerate(self.blocks):
            if merge_flag:
                self.blocks[i] = self.merge_pool_conv(b)
                merge_flag = False
            if i < len(self.cfg) and self.cfg[i] == 'M':
                merge_flag = True
        self.merge_bn()
        # the context manager closes the file once all layers are written
        with h5py.File(filename, 'w') as h:
            layer = h.create_group('layer')
            layer_index = 0
            for i, b in enumerate(self.blocks):
                if i < len(self.cfg) and self.cfg[i] == 'M':
                    continue  # pooling blocks were merged into the next conv
                b.export_hdf5(layer.create_group(f'{layer_index}'))
                layer_index += 1
```
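
The arithmetic behind `merge_bn` can be checked on scalars. The sketch below uses plain Python with hypothetical helper names and arbitrary example numbers; it verifies that applying inference-mode batch normalization after an affine map gives the same result as a single affine map with the folded weight and bias.

```python
import math


def bn(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode batch norm applied to a scalar pre-activation."""
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fold_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Return (w', b') such that w' * x + b' equals bn(w * x + b)."""
    scale = gamma / math.sqrt(var + eps)
    return w * scale, (b - mean) * scale + beta


# arbitrary example values, chosen only for illustration
w, b = 0.8, 0.1
gamma, beta, mean, var = 1.5, -0.2, 0.3, 4.0
w_folded, b_folded = fold_bn(w, b, gamma, beta, mean, var)

x = 2.0
separate = bn(w * x + b, gamma, beta, mean, var)
merged = w_folded * x + b_folded
print(separate, merged)  # the two values agree up to rounding
```

In `merge_bn` the same scale is applied per output channel, which is why it is reshaped to `(-1, 1, 1, 1, 1)` before multiplying the weight tensor.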

### 3.4 Define Meter

```python
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
```
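
The meter weights each value by the batch size `n`, which matters when the last batch is smaller than the others. The following sketch reproduces that arithmetic with a hypothetical `weighted_average` helper and made-up batch accuracies.

```python
def weighted_average(batches):
    """batches: list of (value, n) pairs, accumulated like AverageMeter.update."""
    total, count = 0.0, 0
    for value, n in batches:
        total += value * n
        count += n
    return total / count


# made-up accuracies: one full batch of 128 samples and one short batch of 32
batches = [(0.50, 128), (0.75, 32)]
print(weighted_average(batches))                      # 0.55
print(sum(v for v, _ in batches) / len(batches))      # 0.625 (unweighted)
```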

### 3.5 Instantiate Network, Optimizer, Learning rate scheduler

```python
lam = 1e-9
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = VGG_Lava('VGG16', labels, dataset='CIFAR10', time_steps=1).to(device)
print(net)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-4)

# epochs = 300
epochs = 10
lr_reduce = 5
steps  = [0.6, 0.8, 0.9]
lr_interval = [step * epochs for step in steps]

def lr_scale(epoch):
    # number of milestones in lr_interval that have already been reached
    for i, step_epoch in enumerate(lr_interval):
        if epoch < step_epoch:
            return i
    return i + 1

lambda0 = lambda cur_epoch : 1.0/(lr_reduce**lr_scale(cur_epoch))
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda0, last_epoch=-1)

stats = slayer.utils.LearningStats()
```
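
To see what this schedule does, the sketch below recomputes the multiplier that `LambdaLR` applies to the base learning rate for each of the 10 epochs. `lr_multiplier` is a hypothetical re-implementation of `lr_scale` and `lambda0` in plain Python, using the same `steps` and `lr_reduce` values as above.

```python
def lr_multiplier(epoch, epochs=10, steps=(0.6, 0.8, 0.9), lr_reduce=5):
    """Factor applied to the base learning rate at a given epoch."""
    # count how many milestones (fractions of the total epochs) have been reached
    passed = sum(1 for step in steps if epoch >= step * epochs)
    return 1.0 / (lr_reduce ** passed)


base_lr = 1e-4
multipliers = [lr_multiplier(epoch) for epoch in range(10)]
for epoch, m in enumerate(multipliers):
    print(epoch, m, base_lr * m)
```

With these settings the learning rate stays at its base value for epochs 0 to 5, is divided by 5 at epoch 6, by 25 at epoch 8, and by 125 at epoch 9.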


### 3.6 Set log file

```python
log = True
log_file = './logs_ann/'
# defined outside the `if` because the export step in the training loop also uses it
identifier = 'cifar10_VGG' + '_' + datetime.datetime.now().strftime('%Y%m%d%H%M')

if log:
    os.makedirs(log_file, exist_ok=True)
    log_file += identifier + '.log'
    print('log file: ', log_file)
    f = open(log_file, 'w', buffering=1)
    f.write('\n{}'.format(net))
else:
    print('use stdout')
    f = sys.stdout
```


### 3.7 Training Loop

```python
best_accuracy = 0.0
for epoch in range(epochs):
    net.train()
    train_losses = AverageMeter('Loss')
    train_event_losses = AverageMeter('Loss')
    train_total_losses = AverageMeter('Loss')
    train_accuracy = AverageMeter('Accuracy')
    for i, (input, label) in enumerate(train_loader, 0):
        input = input.to(device)
        output, event_cost, count = net(input)

        rate = output.reshape((input.shape[0], -1))
        loss = F.cross_entropy(rate, label.to(device))
        total_loss = loss + lam * event_cost
        prediction = rate.data.max(1, keepdim=True)[1].cpu().flatten()

        train_losses.update(loss.item(), input.shape[0])
        train_event_losses.update(lam * event_cost.item(), input.shape[0])
        train_total_losses.update(total_loss.item(), input.shape[0])
        train_accuracy.update(torch.sum(prediction == label).item() / input.shape[0], input.shape[0])

        stats.training.num_samples += len(label)
        stats.training.loss_sum += total_loss.item() * input.shape[0]
        stats.training.correct_samples += torch.sum(prediction == label).item()

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    for param_group in optimizer.param_groups:
        learning_rate = param_group['lr']
    # scientific notation keeps small learning rates readable in the log
    f.write(f'\n[Epoch {epoch:2d}/{epochs}], lr: {learning_rate:.2e} loss: {train_losses.avg:.4f}, event_loss: {train_event_losses.avg:.4f}, total loss: {train_total_losses.avg:.4f}, accuracy: {train_accuracy.avg:.4f} ')

    net.eval()
    test_losses = AverageMeter('Loss')
    test_event_losses = AverageMeter('Loss')
    test_total_losses = AverageMeter('Loss')
    test_accuracy = AverageMeter('Accuracy')
    for i, (input, label) in enumerate(test_loader, 0):
        with torch.no_grad():
            input = input.to(device)
            output, event_cost, count = net(input)
            rate = output.reshape((input.shape[0], -1))

            loss = F.cross_entropy(rate, label.to(device))
            total_loss = loss + lam * event_cost
            prediction = rate.data.max(1, keepdim=True)[1].cpu().flatten()

            test_losses.update(loss.item(), input.shape[0])
            test_event_losses.update(lam * event_cost.item(), input.shape[0])
            test_total_losses.update(total_loss.item(), input.shape[0])
            test_accuracy.update(torch.sum(prediction == label).item() / input.shape[0], input.shape[0])

        stats.testing.num_samples += len(label)
        stats.testing.loss_sum += loss.item() * input.shape[0]
        stats.testing.correct_samples += torch.sum(prediction == label).item()
    f.write(f'test: loss: {test_losses.avg:.4f}, event_loss: {test_event_losses.avg:.4f}, total loss: {test_total_losses.avg:.4f}, accuracy: {test_accuracy.avg:.4f}')
    lr_scheduler.step()
    if test_accuracy.avg > best_accuracy:
        # export a copy, because export_hdf5 merges layers in place
        best_accuracy = test_accuracy.avg
        new_net = copy.deepcopy(net)
        new_net.export_hdf5(f'network_{identifier}.net')
print('best accuracy: ', best_accuracy)
f.write('\nbest accuracy: {}'.format(best_accuracy))
```