[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neuromorphs/osn23-huge-tapeout/blob/main/Train_Binarized_SNN.ipynb)

In [1]:
!pip install snntorch



In [2]:
import torch
# Experiment settings shared by the cells below. The key order is kept stable
# because the whole dictionary is printed at the start of an experiment.
config = {
    # --- experiment setup ---
    "model": "NetFC",

    "exp_name": "mnist_tha",
    "num_trials": 5,
    "num_epochs": 10,  # 500 for a full-length run
    "binarize": True,
    "binarize_input": True,
    "data_dir": "~/data/mnist",
    "batch_size": 128,
    "seed": 0,
    "num_workers": 0,

    # --- final run sweeps: outputs and stopping ---
    "save_csv": True,
    "save_model": True,
    "early_stopping": True,
    "patience": 100,

    # --- final (tuned) hyper-parameters ---
    "grad_clip": False,
    "weight_clip": False,
    "batch_norm": True,
    "dropout1": 0.02856,
    "beta": 0.99,
    "lr": 9.97e-3,
    "slope": 10.22,

    # --- threshold annealing, one triple per spiking layer ---
    # Note: `thr_finalN` is an offset; the annealing target is
    # `thresholdN + thr_finalN`, so the target always exceeds the start value.
    "threshold1": 11.666,
    "alpha_thr1": 0.024,
    "thr_final1": 4.317,

    "threshold2": 14.105,
    "alpha_thr2": 0.119,
    "thr_final2": 16.29,

    "threshold3": 0.6656,
    "alpha_thr3": 0.0011,
    "thr_final3": 3.496,

    # --- fixed parameters ---
    "num_steps": 100,
    "correct_rate": 0.8,
    "incorrect_rate": 0.2,
    "betas": (0.9, 0.999),
    "t_0": 4688,   # used as T_max of the cosine learning-rate schedule
    "eta_min": 0,
    "df_lr": True,  # record the learning rate per batch. Useful for scheduling
}

def optim_func(net, config):
    """Create the Adam optimizer and cosine-annealing scheduler for `net`.

    Args:
        net: module whose ``parameters()`` are optimised.
        config: mapping providing ``lr``, ``betas``, ``t_0`` (passed to the
            scheduler as ``T_max``) and ``eta_min``.

    Returns:
        tuple: ``(optimizer, scheduler)``.
    """
    adam = torch.optim.Adam(
        net.parameters(),
        lr=config["lr"],
        betas=config['betas'],
    )
    cosine_schedule = torch.optim.lr_scheduler.CosineAnnealingLR(
        adam,
        T_max=config['t_0'],
        eta_min=config['eta_min'],
        last_epoch=-1,
    )
    return adam, cosine_schedule


In [3]:
import torch
import torch.nn as nn
from torch.autograd import Function

class BinarizeF(Function):
    """Sign binarization with a straight-through gradient.

    Forward maps every element to ``+1`` (element ``>= 0``) or ``-1``
    (everything else, including NaN). Backward passes the incoming gradient
    through unchanged.
    """

    @staticmethod
    def forward(ctx, input):
        # torch.where assigns *every* element. The previous implementation
        # allocated an uninitialised tensor and filled it through two masks,
        # so entries matching neither mask (NaN) kept whatever was in memory.
        plus_one = torch.ones_like(input)
        return torch.where(input >= 0, plus_one, -plus_one)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: treat the sign function as identity.
        return grad_output.clone()

# aliases
binarize = BinarizeF.apply

def binarize_activations(input):
    """Threshold activations at 0.5 into {0, 1}.

    Args:
        input: tensor of activations.

    Returns:
        Tensor with the shape, dtype and device of `input`, holding ``1``
        where ``input >= 0.5`` and ``0`` elsewhere (NaN maps to ``0``). The
        result carries no gradient with respect to `input`.
    """
    # Casting the comparison result defines every element. The previous
    # version wrote through two masks into uninitialised memory, which left
    # NaN entries undefined.
    return (input >= 0.5).to(input.dtype)

In [4]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class BinaryTanh(nn.Module):
    """Activation that clamps with ``nn.Hardtanh`` and then applies `binarize`."""

    def __init__(self):
        super().__init__()
        self.hardtanh = nn.Hardtanh()

    def forward(self, input):
        # Clamp first, then binarize the clamped values.
        return binarize(self.hardtanh(input))


class BinaryLinear(nn.Linear):
    """Linear layer whose forward pass uses `binarize(self.weight)`.

    The stored weights stay real-valued; only the weights used in the matrix
    product are binarized. The bias, when present, is used as-is.
    """

    def forward(self, input):
        # F.linear accepts bias=None, so one call covers both cases.
        return F.linear(input, binarize(self.weight), self.bias)

    def reset_parameters(self):
        """Uniform init in ``[-b, b]`` with ``b = sqrt(1.5 / (sum of weight dims))``."""
        # Glorot initialization
        dim_a, dim_b = self.weight.size()
        bound = math.sqrt(1.5 / (dim_a + dim_b))
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.zero_()

        # Stored on the parameter; nothing in this class reads it back.
        self.weight.lr_scale = 1. / bound

class SparseBinaryLinear(nn.Linear):
    """Binary linear layer with a fixed random connectivity mask.

    At construction each weight is kept with probability ``1 - sparsity``.
    The resulting 0/1 mask is stored as the buffer ``weight_mask_const``, so
    it is saved in the state dict and follows the module across ``.to()``.

    Args:
        in_features, out_features, bias, device, dtype: forwarded to
            ``nn.Linear``.
        sparsity: probability that a connection is removed.
    """

    def __init__(self, in_features, out_features, sparsity=0, bias=True, device=None, dtype=None):
        super(SparseBinaryLinear, self).__init__(in_features, out_features, bias, device, dtype)
        keep_mask = torch.bernoulli(torch.ones_like(self.weight) * (1-sparsity))
        self.register_buffer('weight_mask_const', keep_mask)
        # Reports the fraction of connections that were kept.
        print(self.mask.mean())

    @property
    def mask(self):
        """Current connectivity mask.

        Exposed as a view of the registered buffer rather than a second
        tensor attribute, so it never goes stale after ``.to()``/``.double()``.
        """
        return self.weight_mask_const

    def forward(self, input):
        # Buffers are plain tensors; no autograd.Variable wrapper is needed.
        binary_weight = binarize(self.weight).mul(self.weight_mask_const)
        # F.linear accepts bias=None, so no separate no-bias branch is required.
        return F.linear(input, binary_weight, self.bias)

    def reset_parameters(self):
        """Uniform init in ``[-b, b]`` with ``b = sqrt(1.5 / (sum of weight dims))``."""
        # Glorot initialization
        dim_a, dim_b = self.weight.size()
        stdv = math.sqrt(1.5 / (dim_a + dim_b))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

        self.weight.lr_scale = 1. / stdv

class BinaryConv2d(nn.Conv2d):
    """2-D convolution whose forward pass uses `binarize(self.weight)`."""

    def forward(self, input):
        return F.conv2d(
            input,
            binarize(self.weight),
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

    def reset_parameters(self):
        """Uniform init in ``[-b, b]``, ``b = sqrt(1.5 / (fan_in + fan_out))``.

        Both fan terms are the channel count multiplied by every kernel side.
        """
        # Glorot initialization
        kernel_area = 1
        for side in self.kernel_size:
            kernel_area *= side
        fan_in = self.in_channels * kernel_area
        fan_out = self.out_channels * kernel_area

        bound = math.sqrt(1.5 / (fan_in + fan_out))
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.zero_()

        # Stored on the parameter; nothing in this class reads it back.
        self.weight.lr_scale = 1. / bound

In [5]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Side length, in pixels, that every MNIST image is resized to.
MNIST_INPUT_RESOLUTION = 16

def load_data(config):
    """Build the MNIST train and test datasets, downloading them if needed.

    Each image is resized to ``MNIST_INPUT_RESOLUTION`` square, converted to
    grayscale, turned into a tensor and passed through
    ``Normalize((0,), (1,))``.

    Args:
        config: mapping providing ``data_dir``, the dataset root directory.

    Returns:
        tuple: ``(trainset, testset)``.
    """
    side = MNIST_INPUT_RESOLUTION
    preprocessing = transforms.Compose([
        transforms.Resize((side, side)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ])

    root = config['data_dir']
    # The train split is built first, then the test split.
    trainset, testset = (
        datasets.MNIST(root, train=is_train, download=True, transform=preprocessing)
        for is_train in (True, False)
    )
    return trainset, testset

In [6]:
import torch


class EarlyStopping_acc:
    """Signal early stopping when test accuracy stops improving.

    The monitored score is treated as higher-is-better. The ``test_loss``
    parameter and ``test_loss_min`` attribute keep their historical names but
    hold accuracy values.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): Number of consecutive non-improving calls after
                            which ``early_stop`` is set.
                            Default: 7
            verbose (bool): If True, reports every checkpoint that is saved.
                            Default: False
            delta (float): Margin the score must exceed the best score by to
                            count as an improvement.
                            Default: 0
            path (str): File the model state dict is written to.
                            Default: 'checkpoint.pt'
            trace_func (function): Callable used for all messages.
                            Default: print
        """
        # user-supplied settings
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
        # running state
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.test_loss_min = 0

    def __call__(self, test_loss, model):
        """Record a new score; checkpoint on improvement, otherwise count a stall."""
        first_call = self.best_score is None
        stalled = (not first_call) and test_loss <= self.best_score + self.delta

        if stalled:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
                self.counter = 0
            return

        self.best_score = test_loss
        self.save_checkpoint(test_loss, model)
        # The very first score leaves the stall counter untouched.
        if not first_call:
            self.counter = 0

    def save_checkpoint(self, test_loss, model):
        '''Write the model state dict to ``self.path`` and remember the score.'''
        if self.verbose:
            self.trace_func(f'Test acc increased ({self.test_loss_min:.6f} --> {test_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.test_loss_min = test_loss

In [7]:
import torch
import snntorch as snn
from snntorch import functional as SF

def test_accuracy(config, net, testloader, device="cpu"):
    """Evaluate `net` on `testloader` and return the accuracy in percent.

    The network is switched to eval mode and left in that mode. Each batch's
    ``SF.accuracy_rate`` is weighted by the batch size before averaging.

    Args:
        config: unused; kept so all helpers share one call signature.
        net: model returning a pair whose first item is passed to
            ``SF.accuracy_rate``.
        testloader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to.

    Returns:
        ``100 * weighted_correct / total_samples``.
    """
    weighted_hits = 0
    samples_seen = 0

    with torch.no_grad():
        net.eval()
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)

            spikes, _ = net(images)
            batch_accuracy = SF.accuracy_rate(spikes, labels)

            batch_len = labels.size(0)
            samples_seen += batch_len
            weighted_hits += batch_accuracy * batch_len

    return 100 * weighted_hits / samples_seen

In [8]:
# exp relaxation implementation of THA based on Eq (4)

def thr_annealing(config, network):
    """Move each spiking layer's threshold one step towards its final value.

    For layer ``N`` in 1..3 the update is::

        target = config['thr_finalN'] + config['thresholdN']
        lifN.threshold += (target - lifN.threshold) * config['alpha_thrN']

    Args:
        config: mapping with ``thresholdN``, ``thr_finalN`` and ``alpha_thrN``
            for ``N`` in 1..3.
        network: object exposing ``lif1``, ``lif2``, ``lif3`` (each with a
            ``threshold`` attribute), either directly or behind a ``.module``
            wrapper attribute.

    Returns:
        None. The thresholds are updated in place.
    """
    # run() wraps the model in nn.DataParallel on multi-GPU hosts, which hides
    # the layers behind `.module`; unwrap so the attribute lookups succeed.
    if not hasattr(network, 'lif1') and hasattr(network, 'module'):
        network = network.module

    # Read every setting before touching the network, so a missing key fails
    # without leaving the thresholds half-updated.
    ### to address conditional parameters, s.t. thr_final > threshold
    schedule = [
        (
            f'lif{layer}',
            config[f'alpha_thr{layer}'],
            config[f'thr_final{layer}'] + config[f'threshold{layer}'],
        )
        for layer in (1, 2, 3)
    ]

    for layer_name, alpha, target in schedule:
        neuron = getattr(network, layer_name)
        neuron.threshold += (target - neuron.threshold) * alpha

    return

In [9]:
# snntorch
import snntorch as snn
from snntorch import spikegen
from snntorch import surrogate
from snntorch import functional as SF

# torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR

# misc
import os
import numpy as np
import math
import itertools
import matplotlib.pyplot as plt
import pandas as pd
import shutil
import time
from tqdm import tqdm

def train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device):
    """Run one training epoch over `trainloader`.

    Per batch: forward pass, loss, backward pass, optional gradient-norm and
    weight clipping, optimizer step, scheduler step, then one
    threshold-annealing step. A tqdm bar shows an exponentially smoothed loss.

    Args:
        config: mapping providing ``grad_clip``, ``weight_clip``,
            ``num_steps`` and the keys `thr_annealing` reads.
        net: model returning a pair whose first item is fed to `criterion`.
        epoch: unused here.
        trainloader: iterable of ``(data, labels)`` batches.
        testloader: unused here.
        criterion: callable mapping ``(net_output, labels)`` to a loss.
        optimizer, scheduler: both are stepped once per batch.
        device: device the batches are moved to.

    Returns:
        tuple: ``(losses, learning_rates)``, two per-batch lists. Each loss is
        divided by ``config['num_steps']``; each learning rate is read from
        the first parameter group after the scheduler step.
    """
    net.train()
    batch_losses, batch_lrs = [], []

    # TRAIN
    batches = tqdm(trainloader)
    smoothed_loss = None

    for data, labels in batches:
        data = data.to(device)
        labels = labels.to(device)

        spikes, _ = net(data)
        loss = criterion(spikes, labels)
        optimizer.zero_grad()
        loss.backward()

        # Exponential moving average (0.9 old / 0.1 new), seeded by the first batch.
        loss_value = loss.item()
        smoothed_loss = (
            loss_value if smoothed_loss is None
            else 0.9 * smoothed_loss + 0.1 * loss_value
        )
        batches.set_description(f"loss: {smoothed_loss:.4f}")

        if config['grad_clip']:
            nn.utils.clip_grad_norm_(net.parameters(), 1.0)
        if config['weight_clip']:
            with torch.no_grad():
                for weights in net.parameters():
                    weights.clamp_(-1, 1)

        optimizer.step()
        scheduler.step()
        thr_annealing(config, net)

        batch_losses.append(loss_value / config['num_steps'])
        batch_lrs.append(optimizer.param_groups[0]["lr"])

    return batch_losses, batch_lrs

In [10]:
# snntorch
import snntorch as snn
from snntorch import spikegen
from snntorch import surrogate

# torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# misc
import numpy as np
import pandas as pd
import time
import logging

def run(config):
    """Train and evaluate ``config['num_trials']`` networks, one after another.

    For every trial this builds a fresh network, optimizer, scheduler and data
    loaders, trains for up to ``config['num_epochs']`` epochs, evaluates on
    the test set after each epoch, and optionally writes CSV logs
    (``loss_*``, ``acc_*``, ``lr_*``) and a model checkpoint named after
    ``config['exp_name']`` and the trial index.

    Args:
        config: experiment settings. ``config['model']`` is either the name of
            a class defined at notebook level (instantiated as
            ``cls(config)``) or a Python expression that evaluates to a
            network; the expression may refer to ``config``.

    Side effects:
        Prints progress, writes files into the working directory and stores
        ``config['dataset_length']``.
    """
    print(config)
    file_name = config['exp_name']
    SAVE_CSV = config['save_csv']
    SAVE_MODEL = config['save_model']
    use_early_stopping = config['early_stopping']
    track_lr = config['df_lr']
    num_epochs = config['num_epochs']
    acc_columns = ['epoch', 'test_acc', 'train_time']

    for trial in range(config['num_trials']):
        # file names
        csv_name = file_name + '_t' + str(trial) + '.csv'
        model_name = file_name + '_t' + str(trial) + '.pt'

        # Offset the seed by the trial index: reseeding every trial with the
        # same value makes all trials replicas of trial 0 instead of
        # independent repetitions. Trial 0 still uses config['seed'].
        torch.manual_seed(config['seed'] + trial)

        # Per-trial logs, collected as plain lists and converted to DataFrames
        # only when written out.
        loss_history = []
        lr_history = []
        acc_history = []

        # initialize network
        net_desc = config['model']
        if net_desc in globals():
            net = globals()[net_desc](config)
        else:
            # NOTE(security): eval executes arbitrary Python. Only pass model
            # expressions written in this notebook, never untrusted input.
            net = eval(net_desc)
        if trial == 0:
            print(net)
        device = "cpu"
        #device = "mps"
        if torch.cuda.is_available():
            device = "cuda:0"
            if torch.cuda.device_count() > 1:
                net = nn.DataParallel(net)
        net.to(device)

        # net params
        criterion = SF.mse_count_loss(correct_rate=config['correct_rate'], incorrect_rate=config['incorrect_rate'])
        optimizer, scheduler = optim_func(net, config)

        # early stopping condition (a fresh tracker per trial, so no manual
        # state reset is needed)
        early_stopping = None
        if use_early_stopping:
            early_stopping = EarlyStopping_acc(patience=config['patience'], verbose=True, path=model_name)

        # load data
        trainset, testset = load_data(config)
        config['dataset_length'] = len(trainset)
        batch_size = int(config["batch_size"])
        # Honour the configured worker count (0 loads in the main process).
        num_workers = int(config.get('num_workers', 0))
        trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        print(f"=======Trial: {trial}=======")

        for epoch in range(num_epochs):

            # train
            start_time = time.time()
            loss_list, lr_list = train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device)
            epoch_time = time.time() - start_time

            # test
            test_acc = test_accuracy(config, net, testloader, device)
            print(f'Epoch: {epoch} \tTest Accuracy: {test_acc}')

            loss_history.extend(loss_list)
            if track_lr:
                lr_history.extend(lr_list)
            acc_history.append([epoch, test_acc, epoch_time])

            # Logs are rewritten every epoch so an interrupted run keeps its data.
            if SAVE_CSV:
                pd.DataFrame(loss_history).to_csv('loss_' + csv_name, index=False)
                pd.DataFrame(acc_history, columns=acc_columns).to_csv('acc_' + csv_name, index=False)
                if track_lr:
                    pd.DataFrame(lr_history).to_csv('lr_' + csv_name, index=False)

            if early_stopping is not None:
                # The tracker saves the checkpoint itself on each new best accuracy.
                early_stopping(test_acc, net)

                if early_stopping.early_stop:
                    print("Early stopping")
                    break

            if SAVE_MODEL and not use_early_stopping:
                torch.save(net.state_dict(), model_name)


In [11]:
# snntorch
import snntorch as snn
from snntorch import surrogate

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F

class NetConv(nn.Module):
    """Two-conv-layer spiking network with a 10-way spiking output layer.

    Both a binarized and a full-precision version of every weight layer are
    created; ``config['binarize']`` selects which set the forward pass uses.
    The same static input is presented at each of ``config['num_steps']``
    time steps.

    Args:
        config: mapping providing ``threshold1..3``, ``slope``, ``beta``,
            ``num_steps``, ``batch_norm``, ``dropout1``, ``binarize`` and
            ``binarize_input``.
    """

    def __init__(self, config):
        super().__init__()

        self.thr1 = config['threshold1']
        self.thr2 = config['threshold2']
        self.thr3 = config['threshold3']
        slope = config['slope']
        beta = config['beta']
        self.num_steps = config['num_steps']
        self.batch_norm = config['batch_norm']
        p1 = config['dropout1']
        self.binarize = config['binarize']
        self.binarize_input = config['binarize_input']

        # Feature-map side after conv(5) -> pool(2) -> conv(5) -> pool(2),
        # all unpadded. Derived from the loader's resolution rather than
        # hard-coded to 4 (which only matches 28x28 images), so the
        # classifier matches the images `load_data` actually produces.
        side = ((MNIST_INPUT_RESOLUTION - 4) // 2 - 4) // 2
        if side < 1:
            raise ValueError(
                f"MNIST_INPUT_RESOLUTION={MNIST_INPUT_RESOLUTION} is too small for NetConv"
            )
        flat_features = 64 * side * side

        spike_grad = surrogate.fast_sigmoid(slope)
        # Initialize layers with spike operator
        self.bconv1 = BinaryConv2d(1, 16, 5, bias=False)
        self.conv1 = nn.Conv2d(1, 16, 5, bias=False)
        self.conv1_bn = nn.BatchNorm2d(16)
        self.lif1 = snn.Leaky(beta, threshold=self.thr1, spike_grad=spike_grad)
        self.bconv2 = BinaryConv2d(16, 64, 5, bias=False)
        self.conv2 = nn.Conv2d(16, 64, 5, bias=False)
        self.conv2_bn = nn.BatchNorm2d(64)
        self.lif2 = snn.Leaky(beta, threshold=self.thr2, spike_grad=spike_grad)
        self.bfc1 = BinaryLinear(flat_features, 10)
        self.fc1 = nn.Linear(flat_features, 10)
        self.lif3 = snn.Leaky(beta, threshold=self.thr3, spike_grad=spike_grad)
        self.dropout = nn.Dropout(p1)

    def forward(self, x):
        """Simulate ``num_steps`` time steps on the static input `x`.

        Returns:
            tuple: ``(spk3_rec, mem3_rec)``, the output layer's spikes and
            membrane values stacked along a new leading time dimension.
        """
        # Initialize hidden states and outputs at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Record the final layer
        spk3_rec = []
        mem3_rec = []

        # Pick the weight layers once; the time loop is the same for both modes.
        if self.binarize:
            conv1, conv2, fc = self.bconv1, self.bconv2, self.bfc1
            if self.binarize_input:
                # Thresholding is idempotent, so doing it once before the
                # loop equals repeating it at every step.
                x = binarize_activations(x)
        else:
            conv1, conv2, fc = self.conv1, self.conv2, self.fc1

        for step in range(self.num_steps):
            cur1 = F.avg_pool2d(conv1(x), 2)
            if self.batch_norm:
                cur1 = self.conv1_bn(cur1)
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = F.avg_pool2d(conv2(spk1), 2)
            if self.batch_norm:
                cur2 = self.conv2_bn(cur2)
            spk2, mem2 = self.lif2(cur2, mem2)

            cur3 = self.dropout(fc(spk2.flatten(1)))
            spk3, mem3 = self.lif3(cur3, mem3)

            spk3_rec.append(spk3)
            mem3_rec.append(mem3)

        return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)

class NetFC(nn.Module):
    """Three-layer fully connected spiking network.

    Both a binarized and a full-precision version of every weight layer are
    created; ``config['binarize']`` selects which set the forward pass uses.
    The same static input is presented at each of ``config['num_steps']``
    time steps.

    Args:
        config: mapping providing ``threshold1..3``, ``slope``, ``beta``,
            ``num_steps``, ``batch_norm``, ``dropout1``, ``binarize`` and
            ``binarize_input``.
        neurons: widths of the two hidden layers.
        sparsity: sparsity passed to the first and second binarized layers.
    """

    def __init__(self, config, neurons=[256, 128], sparsity=[0.0, 0.0]):
        super().__init__()

        self.thr1 = config['threshold1']
        self.thr2 = config['threshold2']
        self.thr3 = config['threshold3']
        self.num_steps = config['num_steps']
        self.batch_norm = config['batch_norm']
        self.binarize = config['binarize']
        self.binarize_input = config['binarize_input']

        decay = config['beta']
        drop_prob = config['dropout1']
        spike_grad = surrogate.fast_sigmoid(config['slope'])

        n_inputs = MNIST_INPUT_RESOLUTION * MNIST_INPUT_RESOLUTION
        hidden1, hidden2 = neurons[0], neurons[1]

        # Initialize layers with spike operator
        # layer 1
        self.bfc1 = SparseBinaryLinear(n_inputs, hidden1, sparsity[0], bias=False)
        self.fc1 = nn.Linear(n_inputs, hidden1, bias=False)
        self.bn1 = nn.BatchNorm1d(hidden1)
        self.lif1 = snn.Leaky(decay, threshold=self.thr1, spike_grad=spike_grad)

        # layer 2
        self.bfc2 = SparseBinaryLinear(hidden1, hidden2, sparsity[1], bias=False)
        self.fc2 = nn.Linear(hidden1, hidden2, bias=False)
        self.bn2 = nn.BatchNorm1d(hidden2)
        self.lif2 = snn.Leaky(decay, threshold=self.thr2, spike_grad=spike_grad)

        # output layer
        self.bfc3 = BinaryLinear(hidden2, 10, bias=False)
        self.fc3 = nn.Linear(hidden2, 10, bias=False)
        self.lif3 = snn.Leaky(decay, threshold=self.thr3, spike_grad=spike_grad)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        """Simulate ``num_steps`` time steps on the static input `x`.

        Returns:
            tuple: ``(spk3_rec, mem3_rec)``, the output layer's spikes and
            membrane values stacked along a new leading time dimension.
        """
        # Initialize hidden states and outputs at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Record the final layer
        spike_trace, membrane_trace = [], []

        # Flattening and input thresholding give the same result every step,
        # so they are applied once up front.
        x = x.flatten(1)
        if self.binarize:
            layer1, layer2, layer3 = self.bfc1, self.bfc2, self.bfc3
            if self.binarize_input:
                x = binarize_activations(x)
        else:
            layer1, layer2, layer3 = self.fc1, self.fc2, self.fc3

        for _ in range(self.num_steps):
            cur1 = layer1(x)
            if self.batch_norm:
                cur1 = self.bn1(cur1)
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = layer2(spk1)
            if self.batch_norm:
                cur2 = self.bn2(cur2)
            spk2, mem2 = self.lif2(cur2, mem2)

            cur3 = self.dropout(layer3(spk2))
            spk3, mem3 = self.lif3(cur3, mem3)

            spike_trace.append(spk3)
            membrane_trace.append(mem3)

        return torch.stack(spike_trace, dim=0), torch.stack(membrane_trace, dim=0)

class NetFC_FirstConv(nn.Module):
    """Spiking network with one conv layer followed by two dense layers.

    Both a binarized and a full-precision version of every weight layer are
    created; ``config['binarize']`` selects which set the forward pass uses.
    The same static input is presented at each of ``config['num_steps']``
    time steps.

    Args:
        config: mapping providing ``threshold1..3``, ``slope``, ``beta``,
            ``num_steps``, ``batch_norm``, ``dropout1``, ``binarize`` and
            ``binarize_input``.
        neurons: ``[conv output channels, hidden dense width]``.
    """

    def __init__(self, config, neurons=[16, 256]):
        super().__init__()

        self.thr1 = config['threshold1']
        self.thr2 = config['threshold2']
        self.thr3 = config['threshold3']
        slope = config['slope']
        beta = config['beta']
        self.num_steps = config['num_steps']
        self.batch_norm = config['batch_norm']
        p1 = config['dropout1']
        self.binarize = config['binarize']
        self.binarize_input = config['binarize_input']

        spike_grad = surrogate.fast_sigmoid(slope)
        # Initialize layers with spike operator
        self.bconv1 = BinaryConv2d(1, neurons[0], 5, bias=False)
        self.conv1 = nn.Conv2d(1, neurons[0], 5, bias=False)
        self.conv1_bn = nn.BatchNorm2d(neurons[0])
        self.lif1 = snn.Leaky(beta, threshold=self.thr1, spike_grad=spike_grad)

        # Flattened size after the unpadded 5x5 conv and 2x2 average pool:
        # channels * side * side. It must be an int (nn.Linear rejects floats),
        # must include the channel count, and must follow the resolution that
        # `load_data` produces instead of assuming 28x28.
        pooled_side = (MNIST_INPUT_RESOLUTION - 5 + 1) // 2
        if pooled_side < 1:
            raise ValueError(
                f"MNIST_INPUT_RESOLUTION={MNIST_INPUT_RESOLUTION} is too small for NetFC_FirstConv"
            )
        n = neurons[0] * pooled_side * pooled_side

        self.bfc2 = BinaryLinear(n, neurons[1])
        self.fc2 = nn.Linear(n, neurons[1])
        self.bn2 = nn.BatchNorm1d(neurons[1])
        self.lif2 = snn.Leaky(beta, threshold=self.thr2, spike_grad=spike_grad)

        self.bfc3 = BinaryLinear(neurons[1], 10)
        self.fc3 = nn.Linear(neurons[1], 10)
        self.lif3 = snn.Leaky(beta, threshold=self.thr3, spike_grad=spike_grad)
        self.dropout = nn.Dropout(p1)

    def forward(self, x):
        """Simulate ``num_steps`` time steps on the static input `x`.

        Returns:
            tuple: ``(spk3_rec, mem3_rec)``, the output layer's spikes and
            membrane values stacked along a new leading time dimension.
        """
        # Initialize hidden states and outputs at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Record the final layer
        spk3_rec = []
        mem3_rec = []

        # Pick the weight layers once; the time loop is the same for both modes.
        if self.binarize:
            conv1, fc2, fc3 = self.bconv1, self.bfc2, self.bfc3
            if self.binarize_input:
                # Thresholding is idempotent, so doing it once before the
                # loop equals repeating it at every step.
                x = binarize_activations(x)
        else:
            conv1, fc2, fc3 = self.conv1, self.fc2, self.fc3

        for step in range(self.num_steps):
            cur1 = F.avg_pool2d(conv1(x), 2)
            if self.batch_norm:
                cur1 = self.conv1_bn(cur1)
            spk1, mem1 = self.lif1(cur1, mem1)

            # Batch norm acts on the input current, before the spiking layer,
            # as in NetConv and NetFC. Normalising the spikes afterwards
            # would feed non-binary values into the output layer.
            cur2 = fc2(spk1.flatten(1))
            if self.batch_norm:
                cur2 = self.bn2(cur2)
            spk2, mem2 = self.lif2(cur2, mem2)

            cur3 = self.dropout(fc3(spk2))
            spk3, mem3 = self.lif3(cur3, mem3)

            spk3_rec.append(spk3)
            mem3_rec.append(mem3)

        return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)

In [24]:
# Train the default fully connected network.
# Work on a copy: `cfg = config` only aliases the shared dict, so run()'s
# writes (e.g. 'dataset_length') and this cell's model choice would leak
# into every later cell.
cfg = {**config, 'model': 'NetFC'}
run(cfg)

{'model': 'NetFC', 'exp_name': 'mnist_tha', 'num_trials': 5, 'num_epochs': 500, 'binarize': True, 'data_dir': '~/data/mnist', 'batch_size': 128, 'seed': 0, 'num_workers': 0, 'save_csv': True, 'save_model': True, 'early_stopping': True, 'patience': 100, 'grad_clip': False, 'weight_clip': False, 'batch_norm': True, 'dropout1': 0.02856, 'beta': 0.99, 'lr': 0.00997, 'slope': 10.22, 'threshold1': 11.666, 'alpha_thr1': 0.024, 'thr_final1': 4.317, 'threshold2': 14.105, 'alpha_thr2': 0.119, 'thr_final2': 16.29, 'threshold3': 0.6656, 'alpha_thr3': 0.0011, 'thr_final3': 3.496, 'num_steps': 100, 'correct_rate': 0.8, 'incorrect_rate': 0.2, 'betas': (0.9, 0.999), 't_0': 4688, 'eta_min': 0, 'df_lr': True}
NetFC(
  (bfc1): SparseBinaryLinear(in_features=256, out_features=256, bias=False)
  (fc1): Linear(in_features=256, out_features=256, bias=False)
  (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lif1): Leaky()
  (bfc2): SparseBinaryLinear(in_features=256,

loss: 0.7818: 100%|██████████| 469/469 [03:02<00:00,  2.56it/s]


Epoch: 0 	Test Accuracy: 94.05
Test acc increased (0.000000 --> 94.050000).  Saving model ...


loss: 0.7076:  51%|█████     | 238/469 [01:29<01:26,  2.67it/s]


KeyboardInterrupt: ignored

In [13]:
# binarized inputs
# Work on a copy: `cfg = config` only aliases the shared dict, so run()'s
# writes (e.g. 'dataset_length') and this cell's model choice would leak
# into every later cell.
cfg = {**config, 'model': 'NetFC'}
run(cfg)

{'model': 'NetFC', 'exp_name': 'mnist_tha', 'num_trials': 5, 'num_epochs': 500, 'binarize': True, 'data_dir': '~/data/mnist', 'batch_size': 128, 'seed': 0, 'num_workers': 0, 'save_csv': True, 'save_model': True, 'early_stopping': True, 'patience': 100, 'grad_clip': False, 'weight_clip': False, 'batch_norm': True, 'dropout1': 0.02856, 'beta': 0.99, 'lr': 0.00997, 'slope': 10.22, 'threshold1': 11.666, 'alpha_thr1': 0.024, 'thr_final1': 4.317, 'threshold2': 14.105, 'alpha_thr2': 0.119, 'thr_final2': 16.29, 'threshold3': 0.6656, 'alpha_thr3': 0.0011, 'thr_final3': 3.496, 'num_steps': 100, 'correct_rate': 0.8, 'incorrect_rate': 0.2, 'betas': (0.9, 0.999), 't_0': 4688, 'eta_min': 0, 'df_lr': True}
NetFC(
  (bfc1): SparseBinaryLinear(in_features=256, out_features=256, bias=False)
  (fc1): Linear(in_features=256, out_features=256, bias=False)
  (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lif1): Leaky()
  (bfc2): SparseBinaryLinear(in_features=256,

loss: 0.8944: 100%|██████████| 469/469 [02:55<00:00,  2.67it/s]


Epoch: 0 	Test Accuracy: 91.56
Test acc increased (0.000000 --> 91.560000).  Saving model ...


loss: 0.8133: 100%|██████████| 469/469 [02:55<00:00,  2.67it/s]


Epoch: 1 	Test Accuracy: 93.5
Test acc increased (91.560000 --> 93.500000).  Saving model ...


loss: 0.7441:  27%|██▋       | 125/469 [00:46<02:08,  2.69it/s]


KeyboardInterrupt: ignored

In [14]:
# NetFC with half of the connections removed in both sparse layers.
# The model string is evaluated inside run(), where `config` is the dict
# passed in, i.e. this copy. Copying keeps the shared `config` unchanged
# for later cells.
cfg = {**config, 'model': 'NetFC(config, sparsity=[0.5, 0.5])'}
run(cfg)

{'model': 'NetFC(config, sparsity=0.5)', 'exp_name': 'mnist_tha', 'num_trials': 5, 'num_epochs': 500, 'binarize': True, 'data_dir': '~/data/mnist', 'batch_size': 128, 'seed': 0, 'num_workers': 0, 'save_csv': True, 'save_model': True, 'early_stopping': True, 'patience': 100, 'grad_clip': False, 'weight_clip': False, 'batch_norm': True, 'dropout1': 0.02856, 'beta': 0.99, 'lr': 0.00997, 'slope': 10.22, 'threshold1': 11.666, 'alpha_thr1': 0.024, 'thr_final1': 15.983, 'threshold2': 14.105, 'alpha_thr2': 0.119, 'thr_final2': 30.395, 'threshold3': 0.6656, 'alpha_thr3': 0.0011, 'thr_final3': 4.1616, 'num_steps': 100, 'correct_rate': 0.8, 'incorrect_rate': 0.2, 'betas': (0.9, 0.999), 't_0': 4688, 'eta_min': 0, 'df_lr': True, 'dataset_length': 60000}
NetFC(
  (bfc1): SparseBinaryLinear(in_features=256, out_features=256, bias=False)
  (fc1): Linear(in_features=256, out_features=256, bias=False)
  (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lif1): Lea

loss: 1.0576: 100%|██████████| 469/469 [02:57<00:00,  2.64it/s]


Epoch: 0 	Test Accuracy: 89.58
Test acc increased (0.000000 --> 89.580000).  Saving model ...


loss: 0.9438: 100%|██████████| 469/469 [02:57<00:00,  2.64it/s]


Epoch: 1 	Test Accuracy: 90.51
Test acc increased (89.580000 --> 90.510000).  Saving model ...


loss: 0.9071: 100%|██████████| 469/469 [02:57<00:00,  2.64it/s]


Epoch: 2 	Test Accuracy: 91.0
Test acc increased (90.510000 --> 91.000000).  Saving model ...


loss: 0.8651: 100%|██████████| 469/469 [02:57<00:00,  2.64it/s]


Epoch: 3 	Test Accuracy: 91.85
Test acc increased (91.000000 --> 91.850000).  Saving model ...


loss: 0.8304: 100%|██████████| 469/469 [02:53<00:00,  2.71it/s]


Epoch: 4 	Test Accuracy: 92.0
Test acc increased (91.850000 --> 92.000000).  Saving model ...


loss: 0.8051: 100%|██████████| 469/469 [02:52<00:00,  2.72it/s]


Epoch: 5 	Test Accuracy: 92.3
Test acc increased (92.000000 --> 92.300000).  Saving model ...


loss: 0.7615: 100%|██████████| 469/469 [02:50<00:00,  2.74it/s]


Epoch: 6 	Test Accuracy: 92.17
EarlyStopping counter: 1 out of 100


loss: 0.7586: 100%|██████████| 469/469 [02:49<00:00,  2.77it/s]


Epoch: 7 	Test Accuracy: 92.73
Test acc increased (92.300000 --> 92.730000).  Saving model ...


loss: 0.7067: 100%|██████████| 469/469 [02:47<00:00,  2.80it/s]


Epoch: 8 	Test Accuracy: 92.95
Test acc increased (92.730000 --> 92.950000).  Saving model ...


loss: 0.6712: 100%|██████████| 469/469 [02:44<00:00,  2.85it/s]


Epoch: 9 	Test Accuracy: 92.91
EarlyStopping counter: 1 out of 100


loss: 0.7154:  30%|███       | 141/469 [00:50<01:58,  2.78it/s]


KeyboardInterrupt: ignored

In [None]:
# NetFC with 90% of the connections removed in both sparse layers.
# The model string is evaluated inside run(), where `config` is the dict
# passed in, i.e. this copy. Copying keeps the shared `config` unchanged
# for later cells.
cfg = {**config, 'model': 'NetFC(config, sparsity=[0.9, 0.9])'}
run(cfg)

{'model': 'NetFC(config, sparsity=0.9)', 'exp_name': 'mnist_tha', 'num_trials': 5, 'num_epochs': 10, 'binarize': True, 'binarize_input': True, 'data_dir': '~/data/mnist', 'batch_size': 128, 'seed': 0, 'num_workers': 0, 'save_csv': True, 'save_model': True, 'early_stopping': True, 'patience': 100, 'grad_clip': False, 'weight_clip': False, 'batch_norm': True, 'dropout1': 0.02856, 'beta': 0.99, 'lr': 0.00997, 'slope': 10.22, 'threshold1': 11.666, 'alpha_thr1': 0.024, 'thr_final1': 4.317, 'threshold2': 14.105, 'alpha_thr2': 0.119, 'thr_final2': 16.29, 'threshold3': 0.6656, 'alpha_thr3': 0.0011, 'thr_final3': 3.496, 'num_steps': 100, 'correct_rate': 0.8, 'incorrect_rate': 0.2, 'betas': (0.9, 0.999), 't_0': 4688, 'eta_min': 0, 'df_lr': True}
tensor(0.1001)
tensor(0.1024)
NetFC(
  (bfc1): SparseBinaryLinear(in_features=256, out_features=256, bias=False)
  (fc1): Linear(in_features=256, out_features=256, bias=False)
  (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_

loss: 1.6109: 100%|██████████| 469/469 [02:47<00:00,  2.80it/s]


Epoch: 0 	Test Accuracy: 85.25
Test acc increased (0.000000 --> 85.250000).  Saving model ...


loss: 1.4953: 100%|██████████| 469/469 [02:46<00:00,  2.82it/s]


Epoch: 1 	Test Accuracy: 84.37
EarlyStopping counter: 1 out of 100


loss: 1.4063: 100%|██████████| 469/469 [02:48<00:00,  2.79it/s]


Epoch: 2 	Test Accuracy: 85.62
Test acc increased (85.250000 --> 85.620000).  Saving model ...


loss: 1.3945: 100%|██████████| 469/469 [02:48<00:00,  2.78it/s]


Epoch: 3 	Test Accuracy: 86.7
Test acc increased (85.620000 --> 86.700000).  Saving model ...


loss: 1.4162: 100%|██████████| 469/469 [02:46<00:00,  2.82it/s]


Epoch: 4 	Test Accuracy: 87.12
Test acc increased (86.700000 --> 87.120000).  Saving model ...


loss: 1.3976: 100%|██████████| 469/469 [02:46<00:00,  2.82it/s]


Epoch: 5 	Test Accuracy: 87.33
Test acc increased (87.120000 --> 87.330000).  Saving model ...


loss: 1.2937: 100%|██████████| 469/469 [02:48<00:00,  2.79it/s]


Epoch: 6 	Test Accuracy: 87.65
Test acc increased (87.330000 --> 87.650000).  Saving model ...


loss: 1.3561:  95%|█████████▍| 445/469 [02:38<00:08,  2.76it/s]

In [None]:
# NetFC with two hidden layers of 128 neurons each.
# The model string is evaluated inside run(), where `config` is the dict
# passed in, i.e. this copy. Copying keeps the shared `config` unchanged
# for later cells.
cfg = {**config, 'model': 'NetFC(config, [128, 128])'}
run(cfg)