[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neuromorphs/osn23-huge-tapeout/blob/main/Train_Binarized_SNN.ipynb)

In [None]:
!pip install snntorch



In [None]:
import torch
# Hyper-parameters and run options shared by every function in this notebook.
config = {
    # --- experiment -------------------------------------------------------
    'model': 'Net',              # name of the network class looked up by run()
    'exp_name': 'mnist_tha',     # prefix of the csv / checkpoint file names
    'num_trials': 5,
    'num_epochs': 500,
    'binarize': True,            # True: binarized layers, False: full-precision layers
    'data_dir': "~/data/mnist",
    'batch_size': 128,
    'seed': 0,
    'num_workers': 0,

    # --- final run sweeps -------------------------------------------------
    'save_csv': True,
    'save_model': True,
    'early_stopping': True,
    'patience': 100,             # early-stopping patience, counted in epochs

    # --- final params -----------------------------------------------------
    'grad_clip': False,
    'weight_clip': False,
    'batch_norm': True,
    'dropout1': 0.02856,
    'beta': 0.99,                # decay rate handed to the snn.Leaky neurons
    'lr': 9.97e-3,
    'slope': 10.22,              # slope of the fast-sigmoid surrogate gradient

    # --- threshold annealing ----------------------------------------------
    # note: run() replaces thr_finalN by thresholdN + thr_finalN, i.e. the
    # values below are offsets on top of the initial threshold.
    'threshold1': 11.666,
    'alpha_thr1': 0.024,
    'thr_final1': 4.317,

    'threshold2': 14.105,
    'alpha_thr2': 0.119,
    'thr_final2': 16.29,

    'threshold3': 0.6656,
    'alpha_thr3': 0.0011,
    'thr_final3': 3.496,

    # --- fixed params -----------------------------------------------------
    'num_steps': 100,            # simulation time steps per forward pass
    'correct_rate': 0.8,
    'incorrect_rate': 0.2,
    'betas': (0.9, 0.999),       # Adam betas
    't_0': 4688,                 # T_max of the cosine LR schedule (scheduler steps)
    'eta_min': 0,
    'df_lr': True,               # record the learning rate of every batch (lr_*.csv)
}

def optim_func(net, config):
    """Create the optimizer / learning-rate scheduler pair used for training.

    Args:
        net: module whose ``parameters()`` are optimized.
        config: mapping providing ``lr``, ``betas``, ``t_0`` and ``eta_min``.

    Returns:
        ``(optimizer, scheduler)``: an Adam optimizer and a
        ``CosineAnnealingLR`` scheduler attached to it, with ``T_max`` taken
        from ``config['t_0']``.
    """
    adam = torch.optim.Adam(
        net.parameters(),
        lr=config["lr"],
        betas=config['betas'],
    )
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        adam,
        T_max=config['t_0'],
        eta_min=config['eta_min'],
        last_epoch=-1,
    )
    return adam, cosine


In [None]:
import torch
import torch.nn as nn
from torch.autograd import Function

class BinarizeF(Function):
    """Sign binarization with a straight-through gradient.

    Forward maps every element to ``+1`` when it is ``>= 0`` and to ``-1``
    otherwise. Backward passes the incoming gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, input):
        """Return a tensor shaped and typed like ``input`` holding only +1 / -1."""
        # torch.where assigns every element explicitly. The previous version
        # started from uninitialised memory (``input.new``), so elements that
        # are neither ``>= 0`` nor ``< 0`` (NaN) kept arbitrary values.
        ones = torch.ones_like(input)
        return torch.where(input >= 0, ones, -ones)

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: the gradient is copied as-is."""
        grad_input = grad_output.clone()
        return grad_input

# aliases
binarize = BinarizeF.apply

In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class BinaryTanh(nn.Module):
    """Activation that applies ``nn.Hardtanh`` and then ``binarize``."""

    def __init__(self):
        super().__init__()
        self.hardtanh = nn.Hardtanh()

    def forward(self, input):
        """Return ``binarize(hardtanh(input))``."""
        clipped = self.hardtanh(input)
        return binarize(clipped)


class BinaryLinear(nn.Linear):
    """``nn.Linear`` variant whose forward pass uses ``binarize(self.weight)``.

    The stored weights stay real-valued; only the copy used in the matrix
    product is binarized. The bias (if any) is used unchanged.
    """

    def forward(self, input):
        # F.linear treats ``bias=None`` as "no bias", so one call covers both
        # the biased and the bias-free configuration.
        return F.linear(input, binarize(self.weight), self.bias)

    def reset_parameters(self):
        """Uniform Glorot-style weight init, zero bias, and an ``lr_scale`` hint."""
        # nn.Linear stores the weight as (out_features, in_features); only the
        # sum of both fan values enters the bound.
        fan_out, fan_in = self.weight.size()
        bound = math.sqrt(1.5 / (fan_in + fan_out))

        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.zero_()

        # Inverse of the init bound, attached to the parameter itself.
        self.weight.lr_scale = 1. / bound

class SparseBinaryLinear(nn.Linear):
    """Binarized linear layer with a fixed random 0/1 connectivity mask.

    Args:
        in_features, out_features, bias, device, dtype: forwarded to
            ``nn.Linear``.
        sparsity: each weight's mask entry is drawn as Bernoulli(1 - sparsity),
            so ``sparsity=0`` keeps every connection.
    """

    def __init__(self, in_features, out_features, sparsity=0, bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        keep_probability = torch.ones_like(self.weight) * (1-sparsity)
        # Sampled once at construction; registered as a buffer so it is saved
        # in the state dict and follows the module across devices.
        self.mask = torch.bernoulli(keep_probability)
        self.register_buffer('weight_mask_const', self.mask)

    def forward(self, input):
        """Linear map using the binarized weights multiplied by the mask."""
        masked_weight = binarize(self.weight) * self.weight_mask_const
        return F.linear(input, masked_weight, self.bias)

    def reset_parameters(self):
        """Uniform Glorot-style weight init, zero bias, and an ``lr_scale`` hint."""
        # weight is (out_features, in_features); only the sum is needed
        fan_out, fan_in = self.weight.size()
        bound = math.sqrt(1.5 / (fan_in + fan_out))

        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.zero_()

        self.weight.lr_scale = 1. / bound

class BinaryConv2d(nn.Conv2d):
    """``nn.Conv2d`` variant that convolves with ``binarize(self.weight)``."""

    def forward(self, input):
        """Convolve ``input`` with the binarized kernel and the layer's own
        bias, stride, padding, dilation and groups."""
        return F.conv2d(
            input,
            binarize(self.weight),
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

    def reset_parameters(self):
        """Uniform Glorot-style weight init, zero bias, and an ``lr_scale`` hint."""
        # Number of kernel positions (product of the kernel dimensions).
        kernel_area = 1
        for extent in self.kernel_size:
            kernel_area *= extent

        fan_in = self.in_channels * kernel_area
        fan_out = self.out_channels * kernel_area
        bound = math.sqrt(1.5 / (fan_in + fan_out))

        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.zero_()

        self.weight.lr_scale = 1. / bound

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def load_data(config):
    """Build the MNIST training and test datasets.

    Args:
        config: mapping whose ``data_dir`` entry is the dataset root; the
            data is downloaded there if necessary (``download=True``).

    Returns:
        ``(trainset, testset)`` as torchvision ``datasets.MNIST`` objects
        sharing one preprocessing pipeline.
    """
    root = config['data_dir']

    preprocessing = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ])

    # Training split first, then the test split.
    trainset, testset = (
        datasets.MNIST(root, train=is_train, download=True, transform=preprocessing)
        for is_train in (True, False)
    )
    return trainset, testset

In [None]:
import torch


class EarlyStopping_acc:
    """Signal early stopping when the monitored test accuracy stops improving.

    Call the instance once per evaluation with the current score. A score
    counts as an improvement when it is not ``<= best_score + delta``; every
    improvement (and the very first call) writes a checkpoint.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): Number of consecutive non-improving calls after
                            which ``early_stop`` is set.
                            Default: 7
            verbose (bool): If True, reports every checkpoint via ``trace_func``.
                            Default: False
            delta (float): Margin the score has to exceed the best score by
                            to qualify as an improvement.
                            Default: 0
            path (str): File the model ``state_dict`` is saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): Callable used to emit messages.
                            Default: print
        """
        # configuration
        self.patience = patience
        self.delta = delta
        self.path = path
        self.verbose = verbose
        self.trace_func = trace_func

        # running state
        self.best_score = None
        self.counter = 0
        self.early_stop = False
        self.test_loss_min = 0  # score stored with the latest checkpoint

    def __call__(self, test_loss, model):
        """Update the stopping state with the newest score (``test_loss``)."""
        first_call = self.best_score is None

        if not first_call and test_loss <= self.best_score + self.delta:
            # No improvement: count it and stop once patience is exhausted.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
                self.counter = 0
            return

        self.best_score = test_loss
        self.save_checkpoint(test_loss, model)
        if not first_call:
            self.counter = 0

    def save_checkpoint(self, test_loss, model):
        '''Save the model's state dict and remember the score it achieved.'''
        if self.verbose:
            self.trace_func(f'Test acc increased ({self.test_loss_min:.6f} --> {test_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.test_loss_min = test_loss

In [None]:
import torch
import snntorch as snn
from snntorch import functional as SF

def test_accuracy(config, net, testloader, device="cpu"):
    correct = 0
    total = 0
    with torch.no_grad():
        net.eval()
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs, _ = net(images)
            accuracy = SF.accuracy_rate(outputs, labels)

            total += labels.size(0)
            correct += accuracy * labels.size(0)

    return 100 * correct / total

In [None]:
# exp relaxation implementation of THA based on Eq (4)

def thr_annealing(config, network):
    """Relax the firing threshold of each LIF layer toward its final value.

    For layer ``i`` in 1..3 the update is
    ``threshold += (thr_final_i - threshold) * alpha_thr_i``, i.e. each call
    closes a fraction ``alpha_thr_i`` of the remaining gap.

    Args:
        config: mapping with ``alpha_thr1..3`` and ``thr_final1..3``.
        network: model exposing ``lif1``, ``lif2`` and ``lif3``, each with a
            ``threshold`` attribute. A wrapper that keeps the real model in
            ``.module`` (e.g. ``nn.DataParallel``, which ``run`` uses on
            multi-GPU machines) is unwrapped first.

    Returns:
        None; the thresholds are updated in place.
    """
    # nn.DataParallel does not forward attribute access to the wrapped model,
    # so ``network.lif1`` would raise AttributeError on the wrapper itself.
    model = getattr(network, 'module', network)

    for layer_idx in (1, 2, 3):
        alpha = config[f'alpha_thr{layer_idx}']
        target = config[f'thr_final{layer_idx}']
        lif = getattr(model, f'lif{layer_idx}')
        lif.threshold += (target - lif.threshold) * alpha

    return

In [None]:
# snntorch
import snntorch as snn
from snntorch import spikegen
from snntorch import surrogate
from snntorch import functional as SF

# torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR

# misc
import os
import numpy as np
import math
import itertools
import matplotlib.pyplot as plt
import pandas as pd
import shutil
import time
from tqdm import tqdm

def train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device):
    """Train ``net`` for one pass over ``trainloader``.

    For every batch: forward pass, loss, backward pass, optional gradient
    clipping, optimizer step, optional weight clipping, scheduler step and one
    threshold-annealing step.

    Args:
        config: mapping with ``grad_clip``, ``weight_clip`` and ``num_steps``
            (plus the keys read by ``thr_annealing``).
        net: model returning ``(spikes, membrane)`` for a batch.
        epoch: unused; kept for call-site compatibility.
        trainloader: iterable of ``(data, labels)`` batches.
        testloader: unused; kept for call-site compatibility.
        criterion: loss called as ``criterion(spikes, labels)``.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        device: device the batches are moved to.

    Returns:
        ``(loss_accum, lr_accum)``: per-batch loss divided by
        ``config['num_steps']`` and the per-batch learning rate of the first
        parameter group, read after the scheduler step.
    """
    net.train()
    loss_accum = []
    lr_accum = []

    progress_bar = tqdm(trainloader)
    loss_current = None  # exponentially smoothed loss, for the progress bar only

    for data, labels in progress_bar:
        data, labels = data.to(device), labels.to(device)

        spk_rec, _ = net(data)
        loss = criterion(spk_rec, labels)
        optimizer.zero_grad()
        loss.backward()

        # Read the scalar once; .item() forces a device synchronisation.
        loss_value = loss.item()
        if loss_current is None:
            loss_current = loss_value
        else:
            loss_current = 0.9 * loss_current + 0.1 * loss_value
        progress_bar.set_description(f"loss: {loss_current:.4f}")

        if config['grad_clip']:
            nn.utils.clip_grad_norm_(net.parameters(), 1.0)

        optimizer.step()

        # Clip after the update: clamping before optimizer.step() would let
        # the step push the weights outside [-1, 1] again, so the weights used
        # by the next forward pass would not actually be clipped.
        if config['weight_clip']:
            with torch.no_grad():
                for param in net.parameters():
                    param.clamp_(-1, 1)

        scheduler.step()
        thr_annealing(config, net)

        loss_accum.append(loss_value / config['num_steps'])
        lr_accum.append(optimizer.param_groups[0]["lr"])

    return loss_accum, lr_accum

In [None]:
# snntorch
import snntorch as snn
from snntorch import spikegen
from snntorch import surrogate

# torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# misc
import numpy as np
import pandas as pd
import time
import logging

def run(config):
    """Train and evaluate the configured network for ``config['num_trials']`` trials.

    Each trial builds a fresh network, optimizer, scheduler and data loaders,
    trains for up to ``config['num_epochs']`` epochs, evaluates on the test
    set after every epoch, and optionally writes csv logs (``loss_*``,
    ``acc_*``, ``lr_*``) and a model checkpoint named after
    ``config['exp_name']`` and the trial index.

    Args:
        config: experiment configuration. ``thr_finalN`` entries are
            interpreted as offsets on top of ``thresholdN``. The mapping is
            not modified, except that ``dataset_length`` is stored in it.

    Returns:
        None.
    """
    print(config)

    # Work on a private copy. The thr_final adjustment below used to be written
    # into the caller's dict, so every additional run(config) call (e.g.
    # re-running the notebook cell) added the threshold on top once more.
    caller_config = config
    config = dict(config)

    file_name = config['exp_name']
    layer_ids = (1, 2, 3)

    ### to address conditional parameters, s.t. thr_final > threshold
    for layer in layer_ids:
        config[f'thr_final{layer}'] += config[f'threshold{layer}']

    initial_thresholds = {f'threshold{layer}': config[f'threshold{layer}'] for layer in layer_ids}

    SAVE_CSV = config['save_csv']
    SAVE_MODEL = config['save_model']
    num_epochs = config['num_epochs']
    acc_columns = ['epoch', 'test_acc', 'train_time']

    for trial in range(config['num_trials']):
        # file names
        csv_name = file_name + '_t' + str(trial) + '.csv'
        model_name = file_name + '_t' + str(trial) + '.pt'

        # NOTE(review): every trial is seeded with the same value, so trials
        # start from identical initial weights -- confirm this is intended.
        torch.manual_seed(config['seed'])

        # start every trial from the configured initial thresholds
        config.update(initial_thresholds)

        # Per-trial logs. Plain lists are extended every epoch and turned into
        # DataFrames only when written, instead of pd.concat inside the loop.
        loss_history = []
        lr_history = []
        acc_rows = []

        # initialize network
        net_desc = config['model']
        if net_desc in globals():
            net = globals()[net_desc](config)
        else:
            # SECURITY: eval() executes arbitrary code from the config string.
            # Only use model descriptions from trusted sources.
            net = eval(net_desc)
        if trial == 0:
            print(net)

        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda:0"
            if torch.cuda.device_count() > 1:
                net = nn.DataParallel(net)
        net.to(device)

        # net params
        criterion = SF.mse_count_loss(correct_rate=config['correct_rate'], incorrect_rate=config['incorrect_rate'])
        optimizer, scheduler = optim_func(net, config)

        # early stopping condition (a fresh instance starts with
        # early_stop=False and best_score=None)
        early_stopping = None
        if config['early_stopping']:
            early_stopping = EarlyStopping_acc(patience=config['patience'], verbose=True, path=model_name)

        # load data
        trainset, testset = load_data(config)
        config['dataset_length'] = len(trainset)
        caller_config['dataset_length'] = config['dataset_length']  # as before
        batch_size = int(config["batch_size"])
        num_workers = config.get('num_workers', 0)
        trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        print(f"=======Trial: {trial}=======")

        for epoch in range(num_epochs):

            # train
            start_time = time.time()
            loss_list, lr_list = train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device)
            epoch_time = time.time() - start_time

            # test
            test_acc = test_accuracy(config, net, testloader, device)
            print(f'Epoch: {epoch} \tTest Accuracy: {test_acc}')

            if config['df_lr']:
                lr_history.extend(lr_list)
            loss_history.extend(loss_list)
            acc_rows.append([epoch, test_acc, epoch_time])

            if SAVE_CSV:
                # rewritten after every epoch so partial runs leave usable logs
                pd.DataFrame(loss_history).to_csv('loss_' + csv_name, index=False)
                pd.DataFrame(acc_rows, columns=acc_columns).to_csv('acc_' + csv_name, index=False)
                if config['df_lr']:
                    pd.DataFrame(lr_history).to_csv('lr_' + csv_name, index=False)

            if early_stopping is not None:
                # saves a checkpoint to model_name whenever test_acc improves
                early_stopping(test_acc, net)

                if early_stopping.early_stop:
                    print("Early stopping")
                    break

            if SAVE_MODEL and early_stopping is None:
                torch.save(net.state_dict(), model_name)

        # The best checkpoint of this trial can be restored with:
        # net.load_state_dict(torch.load(model_name))


In [None]:
# snntorch
import snntorch as snn
from snntorch import surrogate

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Two-conv-layer spiking CNN with a binarized and a full-precision path.

    Both sets of weight layers are always created; ``config['binarize']``
    selects which set the forward pass uses. Batch norm, LIF neurons and
    dropout are shared between the two paths.
    """

    def __init__(self, config):
        super().__init__()

        # scalar settings
        self.thr1 = config['threshold1']
        self.thr2 = config['threshold2']
        self.thr3 = config['threshold3']
        self.num_steps = config['num_steps']
        self.batch_norm = config['batch_norm']
        self.binarize = config['binarize']
        decay = config['beta']
        dropout_p = config['dropout1']

        spike_grad = surrogate.fast_sigmoid(config['slope'])

        # block 1: 1 -> 16 channels, 5x5 kernel
        self.bconv1 = BinaryConv2d(1, 16, 5, bias=False)
        self.conv1 = nn.Conv2d(1, 16, 5, bias=False)
        self.conv1_bn = nn.BatchNorm2d(16)
        self.lif1 = snn.Leaky(decay, threshold=self.thr1, spike_grad=spike_grad)

        # block 2: 16 -> 64 channels, 5x5 kernel
        self.bconv2 = BinaryConv2d(16, 64, 5, bias=False)
        self.conv2 = nn.Conv2d(16, 64, 5, bias=False)
        self.conv2_bn = nn.BatchNorm2d(64)
        self.lif2 = snn.Leaky(decay, threshold=self.thr2, spike_grad=spike_grad)

        # output block: 64 * 4 * 4 features -> 10 outputs
        self.bfc1 = BinaryLinear(64 * 4 * 4, 10)
        self.fc1 = nn.Linear(64 * 4 * 4, 10)
        self.lif3 = snn.Leaky(decay, threshold=self.thr3, spike_grad=spike_grad)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Present ``x`` for ``num_steps`` time steps.

        Returns:
            ``(spikes, membrane)`` of the output layer, each stacked along a
            new leading time dimension.
        """
        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Pick the weight layers once; the rest of the pipeline is shared.
        if self.binarize:
            conv_a, conv_b, readout = self.bconv1, self.bconv2, self.bfc1
        else:
            conv_a, conv_b, readout = self.conv1, self.conv2, self.fc1

        # Record the final layer
        spk3_rec = []
        mem3_rec = []

        for _ in range(self.num_steps):
            cur1 = F.avg_pool2d(conv_a(x), 2)
            if self.batch_norm:
                cur1 = self.conv1_bn(cur1)
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = F.avg_pool2d(conv_b(spk1), 2)
            if self.batch_norm:
                cur2 = self.conv2_bn(cur2)
            spk2, mem2 = self.lif2(cur2, mem2)

            cur3 = self.dropout(readout(spk2.flatten(1)))
            spk3, mem3 = self.lif3(cur3, mem3)

            spk3_rec.append(spk3)
            mem3_rec.append(mem3)

        return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)

class NetFC(nn.Module):
    """Three-layer fully connected spiking network (784 -> neurons -> 10).

    Both a binarized and a full-precision set of weight layers are created;
    ``config['binarize']`` selects which set the forward pass uses. The two
    hidden binarized layers are ``SparseBinaryLinear`` with the given
    ``sparsity``.
    """

    def __init__(self, config, neurons=[256, 128], sparsity=0.0):
        super().__init__()

        # scalar settings
        self.thr1 = config['threshold1']
        self.thr2 = config['threshold2']
        self.thr3 = config['threshold3']
        self.num_steps = config['num_steps']
        self.batch_norm = config['batch_norm']
        self.binarize = config['binarize']
        decay = config['beta']
        dropout_p = config['dropout1']
        hidden1, hidden2 = neurons[0], neurons[1]

        spike_grad = surrogate.fast_sigmoid(config['slope'])

        # layer 1: 28 * 28 inputs -> hidden1
        self.bfc1 = SparseBinaryLinear(28 * 28, hidden1, sparsity, bias=False)
        self.fc1 = nn.Linear(28 * 28, hidden1, bias=False)
        self.bn1 = nn.BatchNorm1d(hidden1)
        self.lif1 = snn.Leaky(decay, threshold=self.thr1, spike_grad=spike_grad)

        # layer 2: hidden1 -> hidden2
        self.bfc2 = SparseBinaryLinear(hidden1, hidden2, sparsity, bias=False)
        self.fc2 = nn.Linear(hidden1, hidden2, bias=False)
        self.bn2 = nn.BatchNorm1d(hidden2)
        self.lif2 = snn.Leaky(decay, threshold=self.thr2, spike_grad=spike_grad)

        # output layer: hidden2 -> 10
        self.bfc3 = BinaryLinear(hidden2, 10, bias=False)
        self.fc3 = nn.Linear(hidden2, 10, bias=False)
        self.lif3 = snn.Leaky(decay, threshold=self.thr3, spike_grad=spike_grad)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Present ``x`` (flattened per sample) for ``num_steps`` time steps.

        Returns:
            ``(spikes, membrane)`` of the output layer, each stacked along a
            new leading time dimension.
        """
        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Pick the weight layers once; the rest of the pipeline is shared.
        if self.binarize:
            dense1, dense2, readout = self.bfc1, self.bfc2, self.bfc3
        else:
            dense1, dense2, readout = self.fc1, self.fc2, self.fc3

        # Record the final layer
        spk3_rec = []
        mem3_rec = []

        # Experiment note (binarized path, from logged runs): normalising the
        # input current before the LIF neuron, as done below, gave 91.18 % test
        # accuracy after epoch 0 and 95.63 % after epoch 9. The alternative of
        # normalising the output spikes after the LIF neuron,
        #     spk1, mem1 = self.lif1(self.bfc1(x.flatten(1)), mem1)
        #     spk1 = self.bn1(spk1)
        # started at 62.82 % and reached 90.92 % after epoch 4.
        for _ in range(self.num_steps):
            cur1 = dense1(x.flatten(1))
            if self.batch_norm:
                cur1 = self.bn1(cur1)
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = dense2(spk1)
            if self.batch_norm:
                cur2 = self.bn2(cur2)
            spk2, mem2 = self.lif2(cur2, mem2)

            cur3 = self.dropout(readout(spk2))
            spk3, mem3 = self.lif3(cur3, mem3)

            spk3_rec.append(spk3)
            mem3_rec.append(mem3)

        return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)

class NetFC_FirstConv(nn.Module):
    """Three-layer spiking network: one 5x5 convolution followed by two linear layers.

    Every weight layer exists in two variants, a binarized one (``bconv1``,
    ``bfc2``, ``bfc3``) and a full-precision one (``conv1``, ``fc2``, ``fc3``);
    ``config['binarize']`` selects which set ``forward`` uses.

    Args:
        config: dict providing ``threshold1``..``threshold3``, ``slope``,
            ``beta``, ``num_steps``, ``batch_norm``, ``dropout1`` and
            ``binarize``.
        neurons: ``(conv_channels, hidden_units)``; index 0 is the number of
            convolution output channels, index 1 the width of the hidden
            linear layer.
    """

    # Geometry used to size the first linear layer.
    # assumes square single-channel 28x28 inputs (MNIST) — TODO confirm against the data loader
    _INPUT_SIZE = 28
    _KERNEL_SIZE = 5
    _POOL_SIZE = 2

    def __init__(self, config, neurons=(16, 256)):
        super().__init__()

        self.thr1 = config['threshold1']
        self.thr2 = config['threshold2']
        self.thr3 = config['threshold3']
        slope = config['slope']
        beta = config['beta']
        self.num_steps = config['num_steps']
        self.batch_norm = config['batch_norm']
        p1 = config['dropout1']
        self.binarize = config['binarize']

        spike_grad = surrogate.fast_sigmoid(slope)

        # Layer 1: convolution (binarized and full-precision variants).
        self.bconv1 = BinaryConv2d(1, neurons[0], self._KERNEL_SIZE, bias=False)
        self.conv1 = nn.Conv2d(1, neurons[0], self._KERNEL_SIZE, bias=False)
        self.conv1_bn = nn.BatchNorm2d(neurons[0])
        self.lif1 = snn.Leaky(beta, threshold=self.thr1, spike_grad=spike_grad)

        # Flattened feature count after the unpadded, stride-1 convolution and
        # the average pool: (28 - 5 + 1) // 2 = 12 per side, times the channel
        # count. Must be an int, since it is used as a layer dimension.
        pooled_side = (self._INPUT_SIZE - self._KERNEL_SIZE + 1) // self._POOL_SIZE
        n = neurons[0] * pooled_side ** 2

        # Layer 2: hidden linear layer.
        self.bfc2 = BinaryLinear(n, neurons[1])
        self.fc2 = nn.Linear(n, neurons[1])
        self.bn2 = nn.BatchNorm1d(neurons[1])
        self.lif2 = snn.Leaky(beta, threshold=self.thr2, spike_grad=spike_grad)

        # Layer 3: 10-way output layer.
        self.bfc3 = BinaryLinear(neurons[1], 10)
        self.fc3 = nn.Linear(neurons[1], 10)
        self.lif3 = snn.Leaky(beta, threshold=self.thr3, spike_grad=spike_grad)
        self.dropout = nn.Dropout(p1)

    def forward(self, x):
        """Present ``x`` to the network for ``num_steps`` steps.

        Args:
            x: input batch; the same tensor is fed at every step.
               assumes shape (batch, 1, 28, 28) — TODO confirm.

        Returns:
            Tuple ``(spk3_rec, mem3_rec)``: the output layer's spikes and
            membrane potentials, each stacked along a new leading (step)
            dimension.
        """
        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Record the final layer
        spk3_rec = []
        mem3_rec = []

        # Pick the weight layers once; both modes share the same dataflow.
        if self.binarize:
            conv1, fc2, fc3 = self.bconv1, self.bfc2, self.bfc3
        else:
            conv1, fc2, fc3 = self.conv1, self.fc2, self.fc3

        for _ in range(self.num_steps):
            cur1 = F.avg_pool2d(conv1(x), self._POOL_SIZE)
            if self.batch_norm:
                cur1 = self.conv1_bn(cur1)
            spk1, mem1 = self.lif1(cur1, mem1)

            # Batch norm is applied to the input current before the LIF
            # neuron, as in layer 1 and in the fully-connected network's
            # forward pass, rather than to the neuron's output spikes.
            cur2 = fc2(spk1.flatten(1))
            if self.batch_norm:
                cur2 = self.bn2(cur2)
            spk2, mem2 = self.lif2(cur2, mem2)

            cur3 = self.dropout(fc3(spk2))
            spk3, mem3 = self.lif3(cur3, mem3)

            spk3_rec.append(spk3)
            mem3_rec.append(mem3)

        return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)

In [None]:
# Work on a copy: `cfg = config` would alias the shared dict, so anything
# run() writes into it would leak into every later experiment (the printed
# configs of later runs in this notebook show the thr_final* values drifting).
cfg = dict(config)
cfg['model'] = 'NetFC'
run(cfg)

{'model': 'NetFC', 'exp_name': 'mnist_tha', 'num_trials': 5, 'num_epochs': 500, 'binarize': True, 'data_dir': '~/data/mnist', 'batch_size': 128, 'seed': 0, 'num_workers': 0, 'save_csv': True, 'save_model': True, 'early_stopping': True, 'patience': 100, 'grad_clip': False, 'weight_clip': False, 'batch_norm': True, 'dropout1': 0.02856, 'beta': 0.99, 'lr': 0.00997, 'slope': 10.22, 'threshold1': 11.666, 'alpha_thr1': 0.024, 'thr_final1': 4.317, 'threshold2': 14.105, 'alpha_thr2': 0.119, 'thr_final2': 16.29, 'threshold3': 0.6656, 'alpha_thr3': 0.0011, 'thr_final3': 3.496, 'num_steps': 100, 'correct_rate': 0.8, 'incorrect_rate': 0.2, 'betas': (0.9, 0.999), 't_0': 4688, 'eta_min': 0, 'df_lr': True}
tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])
tensor([[1., 1., 1.,  ..., 1., 1., 1.],
     

loss: 0.7371: 100%|██████████| 469/469 [02:37<00:00,  2.97it/s]


Epoch: 0 	Test Accuracy: 95.05
Test acc increased (0.000000 --> 95.050000).  Saving model ...


loss: 0.6158: 100%|██████████| 469/469 [02:35<00:00,  3.02it/s]


Epoch: 1 	Test Accuracy: 95.83
Test acc increased (95.050000 --> 95.830000).  Saving model ...


loss: 0.5232: 100%|██████████| 469/469 [02:33<00:00,  3.05it/s]


Epoch: 2 	Test Accuracy: 96.39
Test acc increased (95.830000 --> 96.390000).  Saving model ...


loss: 0.4982: 100%|██████████| 469/469 [02:37<00:00,  2.98it/s]


Epoch: 3 	Test Accuracy: 96.7
Test acc increased (96.390000 --> 96.700000).  Saving model ...


loss: 0.4986: 100%|██████████| 469/469 [02:34<00:00,  3.03it/s]


Epoch: 4 	Test Accuracy: 96.19
EarlyStopping counter: 1 out of 100


loss: 0.4009: 100%|██████████| 469/469 [02:34<00:00,  3.03it/s]


Epoch: 5 	Test Accuracy: 96.82
Test acc increased (96.700000 --> 96.820000).  Saving model ...


loss: 0.4119: 100%|██████████| 469/469 [02:35<00:00,  3.02it/s]


Epoch: 6 	Test Accuracy: 97.03
Test acc increased (96.820000 --> 97.030000).  Saving model ...


loss: 0.3949: 100%|██████████| 469/469 [02:34<00:00,  3.04it/s]


Epoch: 7 	Test Accuracy: 96.9
EarlyStopping counter: 1 out of 100


loss: 0.3644: 100%|██████████| 469/469 [02:35<00:00,  3.01it/s]


Epoch: 8 	Test Accuracy: 97.42
Test acc increased (97.030000 --> 97.420000).  Saving model ...


loss: 0.2978: 100%|██████████| 469/469 [02:34<00:00,  3.04it/s]


Epoch: 9 	Test Accuracy: 97.45
Test acc increased (97.420000 --> 97.450000).  Saving model ...


loss: 0.3801: 100%|██████████| 469/469 [02:36<00:00,  2.99it/s]


Epoch: 10 	Test Accuracy: 97.19
EarlyStopping counter: 1 out of 100


loss: 0.4160: 100%|██████████| 469/469 [02:35<00:00,  3.01it/s]


Epoch: 11 	Test Accuracy: 97.1
EarlyStopping counter: 2 out of 100


loss: 0.3943: 100%|██████████| 469/469 [02:35<00:00,  3.01it/s]


Epoch: 12 	Test Accuracy: 96.76
EarlyStopping counter: 3 out of 100


loss: 0.3945: 100%|██████████| 469/469 [02:35<00:00,  3.01it/s]


Epoch: 13 	Test Accuracy: 97.16
EarlyStopping counter: 4 out of 100


loss: 0.4393:  91%|█████████ | 427/469 [02:21<00:14,  2.93it/s]

In [None]:
# Work on a copy so that changes run() makes to the dict do not leak into the
# shared `config` used by other cells.
cfg = dict(config)
cfg['model'] = 'NetFC(config, sparsity=0.1)'
run(cfg)

In [None]:
# Work on a copy so that changes run() makes to the dict do not leak into the
# shared `config` used by other cells (this cell's own printed config shows
# thr_final* values inflated by earlier runs).
cfg = dict(config)
cfg['model'] = 'NetFC(config, [128, 128], sparsity=0.1)'
run(cfg)

{'model': 'NetFC(config, [128, 128], sparsity=0.1)', 'exp_name': 'mnist_tha', 'num_trials': 5, 'num_epochs': 500, 'binarize': True, 'data_dir': '~/data/mnist', 'batch_size': 128, 'seed': 0, 'num_workers': 0, 'save_csv': True, 'save_model': True, 'early_stopping': True, 'patience': 100, 'grad_clip': False, 'weight_clip': False, 'batch_norm': True, 'dropout1': 0.02856, 'beta': 0.99, 'lr': 0.00997, 'slope': 10.22, 'threshold1': 11.666, 'alpha_thr1': 0.024, 'thr_final1': 39.315, 'threshold2': 14.105, 'alpha_thr2': 0.119, 'thr_final2': 58.605000000000004, 'threshold3': 0.6656, 'alpha_thr3': 0.0011, 'thr_final3': 5.492799999999999, 'num_steps': 100, 'correct_rate': 0.8, 'incorrect_rate': 0.2, 'betas': (0.9, 0.999), 't_0': 4688, 'eta_min': 0, 'df_lr': True, 'dataset_length': 60000}
tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 1., 0., 0.],
 

loss: 2.2480: 100%|██████████| 469/469 [02:39<00:00,  2.94it/s]


Epoch: 0 	Test Accuracy: 68.11
Test acc increased (0.000000 --> 68.110000).  Saving model ...


loss: 1.6682: 100%|██████████| 469/469 [02:39<00:00,  2.93it/s]


Epoch: 1 	Test Accuracy: 82.4
Test acc increased (68.110000 --> 82.400000).  Saving model ...


loss: 1.3879: 100%|██████████| 469/469 [02:41<00:00,  2.90it/s]


Epoch: 2 	Test Accuracy: 86.15
Test acc increased (82.400000 --> 86.150000).  Saving model ...


loss: 1.3864:  32%|███▏      | 148/469 [00:51<01:51,  2.87it/s]


KeyboardInterrupt: ignored

In [None]:
# Work on a copy so that changes run() makes to the dict do not leak into the
# shared `config` used by other cells.
cfg = dict(config)
cfg['model'] = 'NetFC(config, [128, 128])'
run(cfg)

{'model': 'NetFC(config, [128, 128])', 'exp_name': 'mnist_tha', 'num_trials': 5, 'num_epochs': 500, 'binarize': True, 'data_dir': '~/data/mnist', 'batch_size': 128, 'seed': 0, 'num_workers': 0, 'save_csv': True, 'save_model': True, 'early_stopping': True, 'patience': 100, 'grad_clip': False, 'weight_clip': False, 'batch_norm': True, 'dropout1': 0.02856, 'beta': 0.99, 'lr': 0.00997, 'slope': 10.22, 'threshold1': 11.666, 'alpha_thr1': 0.024, 'thr_final1': 4.317, 'threshold2': 14.105, 'alpha_thr2': 0.119, 'thr_final2': 16.29, 'threshold3': 0.6656, 'alpha_thr3': 0.0011, 'thr_final3': 3.496, 'num_steps': 100, 'correct_rate': 0.8, 'incorrect_rate': 0.2, 'betas': (0.9, 0.999), 't_0': 4688, 'eta_min': 0, 'df_lr': True}
NetFC(
  (bfc1): SparseBinaryLinear(in_features=784, out_features=128, bias=True)
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lif1): Leaky()
  (bfc2): SparseBinaryLinea

100%|██████████| 9912422/9912422 [00:00<00:00, 81587872.50it/s]


Extracting /root/data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /root/data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /root/data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 102831658.59it/s]

Extracting /root/data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /root/data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /root/data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 20970082.89it/s]


Extracting /root/data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /root/data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /root/data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 6998724.75it/s]


Extracting /root/data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /root/data/mnist/MNIST/raw



loss: 0.8217: 100%|██████████| 469/469 [02:45<00:00,  2.84it/s]


Epoch: 0 	Test Accuracy: 93.57
Test acc increased (0.000000 --> 93.570000).  Saving model ...


loss: 0.7384: 100%|██████████| 469/469 [02:41<00:00,  2.91it/s]


Epoch: 1 	Test Accuracy: 94.28
Test acc increased (93.570000 --> 94.280000).  Saving model ...


loss: 0.7035: 100%|██████████| 469/469 [02:40<00:00,  2.92it/s]


Epoch: 2 	Test Accuracy: 94.8
Test acc increased (94.280000 --> 94.800000).  Saving model ...


loss: 0.6322: 100%|██████████| 469/469 [02:40<00:00,  2.92it/s]


Epoch: 3 	Test Accuracy: 94.93
Test acc increased (94.800000 --> 94.930000).  Saving model ...


loss: 0.5964: 100%|██████████| 469/469 [02:40<00:00,  2.92it/s]


Epoch: 4 	Test Accuracy: 94.57
EarlyStopping counter: 1 out of 100


loss: 0.6028: 100%|██████████| 469/469 [02:40<00:00,  2.92it/s]


Epoch: 5 	Test Accuracy: 95.05
Test acc increased (94.930000 --> 95.050000).  Saving model ...


loss: 0.5782: 100%|██████████| 469/469 [02:42<00:00,  2.88it/s]


Epoch: 6 	Test Accuracy: 95.6
Test acc increased (95.050000 --> 95.600000).  Saving model ...


loss: 0.5397: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 7 	Test Accuracy: 95.74
Test acc increased (95.600000 --> 95.740000).  Saving model ...


loss: 0.5119: 100%|██████████| 469/469 [02:40<00:00,  2.92it/s]


Epoch: 8 	Test Accuracy: 95.81
Test acc increased (95.740000 --> 95.810000).  Saving model ...


loss: 0.4337: 100%|██████████| 469/469 [02:35<00:00,  3.02it/s]


Epoch: 9 	Test Accuracy: 96.15
Test acc increased (95.810000 --> 96.150000).  Saving model ...


loss: 0.5021: 100%|██████████| 469/469 [02:34<00:00,  3.04it/s]


Epoch: 10 	Test Accuracy: 96.11
EarlyStopping counter: 1 out of 100


loss: 0.5650:   7%|▋         | 34/469 [00:11<02:25,  2.99it/s]


KeyboardInterrupt: ignored

In [None]:
# Temporarily swap the global `binarize` function for the sparsifying variant.
# NOTE(review): presumably the binary layers look up this global at call
# time — confirm against their forward() definitions.
binarize = BinarizeAnsSparsifyF.apply
print(binarize)

# Work on a copy so that changes run() makes to the dict do not leak into the
# shared `config` used by other cells.
cfg = dict(config)
cfg['model'] = 'NetFC(config, [128, 128])'
try:
    run(cfg)
finally:
    # Always restore the default, even if training raises or is interrupted,
    # so later cells do not silently keep the sparsifying variant.
    binarize = BinarizeF.apply

<bound method Function.apply of <class '__main__.BinarizeAnsSparsifyF'>>
{'model': 'NetFC(config, [128, 128])', 'exp_name': 'mnist_tha', 'num_trials': 5, 'num_epochs': 500, 'binarize': True, 'data_dir': '~/data/mnist', 'batch_size': 128, 'seed': 0, 'num_workers': 0, 'save_csv': True, 'save_model': True, 'early_stopping': True, 'patience': 100, 'grad_clip': False, 'weight_clip': False, 'batch_norm': True, 'dropout1': 0.02856, 'beta': 0.99, 'lr': 0.00997, 'slope': 10.22, 'threshold1': 11.666, 'alpha_thr1': 0.024, 'thr_final1': 132.64299999999997, 'threshold2': 14.105, 'alpha_thr2': 0.119, 'thr_final2': 171.445, 'threshold3': 0.6656, 'alpha_thr3': 0.0011, 'thr_final3': 10.817599999999995, 'num_steps': 100, 'correct_rate': 0.8, 'incorrect_rate': 0.2, 'betas': (0.9, 0.999), 't_0': 4688, 'eta_min': 0, 'df_lr': True, 'dataset_length': 60000}
NetFC(
  (bfc1): BinaryLinear(in_features=784, out_features=128, bias=True)
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (bn1): BatchN

loss: 3.1639: 100%|██████████| 469/469 [02:50<00:00,  2.76it/s]


Epoch: 0 	Test Accuracy: 19.18
Test acc increased (0.000000 --> 19.180000).  Saving model ...


loss: 2.9104: 100%|██████████| 469/469 [02:48<00:00,  2.78it/s]


Epoch: 1 	Test Accuracy: 50.52
Test acc increased (19.180000 --> 50.520000).  Saving model ...


loss: 2.6256: 100%|██████████| 469/469 [02:46<00:00,  2.82it/s]


Epoch: 2 	Test Accuracy: 63.73
Test acc increased (50.520000 --> 63.730000).  Saving model ...


loss: 2.4258: 100%|██████████| 469/469 [02:44<00:00,  2.85it/s]


Epoch: 3 	Test Accuracy: 69.65
Test acc increased (63.730000 --> 69.650000).  Saving model ...


loss: 2.3231: 100%|██████████| 469/469 [02:46<00:00,  2.82it/s]


Epoch: 4 	Test Accuracy: 67.64
EarlyStopping counter: 1 out of 100


loss: 2.2189: 100%|██████████| 469/469 [02:44<00:00,  2.85it/s]


Epoch: 5 	Test Accuracy: 72.19
Test acc increased (69.650000 --> 72.190000).  Saving model ...


loss: 2.1478: 100%|██████████| 469/469 [02:44<00:00,  2.85it/s]


Epoch: 6 	Test Accuracy: 72.03
EarlyStopping counter: 1 out of 100


loss: 2.1336: 100%|██████████| 469/469 [02:44<00:00,  2.85it/s]


Epoch: 7 	Test Accuracy: 71.62
EarlyStopping counter: 2 out of 100


loss: 2.1269: 100%|██████████| 469/469 [02:46<00:00,  2.83it/s]


Epoch: 8 	Test Accuracy: 72.59
Test acc increased (72.190000 --> 72.590000).  Saving model ...


loss: 2.1047: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 9 	Test Accuracy: 73.29
Test acc increased (72.590000 --> 73.290000).  Saving model ...


loss: 2.1254: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 10 	Test Accuracy: 73.38
Test acc increased (73.290000 --> 73.380000).  Saving model ...


loss: 2.1354: 100%|██████████| 469/469 [02:41<00:00,  2.90it/s]


Epoch: 11 	Test Accuracy: 71.16
EarlyStopping counter: 1 out of 100


loss: 2.0641: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 12 	Test Accuracy: 72.75
EarlyStopping counter: 2 out of 100


loss: 1.9617: 100%|██████████| 469/469 [02:43<00:00,  2.87it/s]


Epoch: 13 	Test Accuracy: 72.59
EarlyStopping counter: 3 out of 100


loss: 1.7518: 100%|██████████| 469/469 [02:41<00:00,  2.90it/s]


Epoch: 14 	Test Accuracy: 73.98
Test acc increased (73.380000 --> 73.980000).  Saving model ...


loss: 1.5294: 100%|██████████| 469/469 [02:41<00:00,  2.91it/s]


Epoch: 15 	Test Accuracy: 74.77
Test acc increased (73.980000 --> 74.770000).  Saving model ...


loss: 1.2905: 100%|██████████| 469/469 [02:43<00:00,  2.87it/s]


Epoch: 16 	Test Accuracy: 87.52
Test acc increased (74.770000 --> 87.520000).  Saving model ...


loss: 1.0854: 100%|██████████| 469/469 [02:41<00:00,  2.90it/s]


Epoch: 17 	Test Accuracy: 90.2
Test acc increased (87.520000 --> 90.200000).  Saving model ...


loss: 0.9083: 100%|██████████| 469/469 [02:43<00:00,  2.87it/s]


Epoch: 18 	Test Accuracy: 90.07
EarlyStopping counter: 1 out of 100


loss: 0.7851: 100%|██████████| 469/469 [02:44<00:00,  2.84it/s]


Epoch: 19 	Test Accuracy: 91.35
Test acc increased (90.200000 --> 91.350000).  Saving model ...


loss: 0.7478: 100%|██████████| 469/469 [02:41<00:00,  2.90it/s]


Epoch: 20 	Test Accuracy: 90.61
EarlyStopping counter: 1 out of 100


loss: 0.7122: 100%|██████████| 469/469 [02:41<00:00,  2.91it/s]


Epoch: 21 	Test Accuracy: 91.9
Test acc increased (91.350000 --> 91.900000).  Saving model ...


loss: 0.7341: 100%|██████████| 469/469 [02:44<00:00,  2.86it/s]


Epoch: 22 	Test Accuracy: 91.52
EarlyStopping counter: 1 out of 100


loss: 0.7067: 100%|██████████| 469/469 [02:44<00:00,  2.86it/s]


Epoch: 23 	Test Accuracy: 92.7
Test acc increased (91.900000 --> 92.700000).  Saving model ...


loss: 0.6618: 100%|██████████| 469/469 [02:45<00:00,  2.84it/s]


Epoch: 24 	Test Accuracy: 92.98
Test acc increased (92.700000 --> 92.980000).  Saving model ...


loss: 0.6329: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 25 	Test Accuracy: 93.25
Test acc increased (92.980000 --> 93.250000).  Saving model ...


loss: 0.6487: 100%|██████████| 469/469 [02:43<00:00,  2.88it/s]


Epoch: 26 	Test Accuracy: 93.6
Test acc increased (93.250000 --> 93.600000).  Saving model ...


loss: 0.5709: 100%|██████████| 469/469 [02:43<00:00,  2.88it/s]


Epoch: 27 	Test Accuracy: 94.17
Test acc increased (93.600000 --> 94.170000).  Saving model ...


loss: 0.5760: 100%|██████████| 469/469 [02:42<00:00,  2.88it/s]


Epoch: 28 	Test Accuracy: 93.74
EarlyStopping counter: 1 out of 100


loss: 0.6159: 100%|██████████| 469/469 [02:43<00:00,  2.86it/s]


Epoch: 29 	Test Accuracy: 93.53
EarlyStopping counter: 2 out of 100


loss: 0.6306: 100%|██████████| 469/469 [02:43<00:00,  2.87it/s]


Epoch: 30 	Test Accuracy: 93.45
EarlyStopping counter: 3 out of 100


loss: 0.5531: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 31 	Test Accuracy: 93.79
EarlyStopping counter: 4 out of 100


loss: 0.5907: 100%|██████████| 469/469 [02:44<00:00,  2.85it/s]


Epoch: 32 	Test Accuracy: 93.71
EarlyStopping counter: 5 out of 100


loss: 0.5746: 100%|██████████| 469/469 [02:42<00:00,  2.88it/s]


Epoch: 33 	Test Accuracy: 93.97
EarlyStopping counter: 6 out of 100


loss: 0.6001: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 34 	Test Accuracy: 93.88
EarlyStopping counter: 7 out of 100


loss: 0.6649: 100%|██████████| 469/469 [02:40<00:00,  2.92it/s]


Epoch: 35 	Test Accuracy: 93.3
EarlyStopping counter: 8 out of 100


loss: 0.5827: 100%|██████████| 469/469 [02:39<00:00,  2.94it/s]


Epoch: 36 	Test Accuracy: 93.54
EarlyStopping counter: 9 out of 100


loss: 0.5955: 100%|██████████| 469/469 [02:38<00:00,  2.96it/s]


Epoch: 37 	Test Accuracy: 94.04
EarlyStopping counter: 10 out of 100


loss: 0.5881: 100%|██████████| 469/469 [02:38<00:00,  2.95it/s]


Epoch: 38 	Test Accuracy: 93.94
EarlyStopping counter: 11 out of 100


loss: 0.5821: 100%|██████████| 469/469 [02:39<00:00,  2.95it/s]


Epoch: 39 	Test Accuracy: 94.14
EarlyStopping counter: 12 out of 100


loss: 0.5927: 100%|██████████| 469/469 [02:40<00:00,  2.93it/s]


Epoch: 40 	Test Accuracy: 94.03
EarlyStopping counter: 13 out of 100


loss: 0.5290: 100%|██████████| 469/469 [02:39<00:00,  2.94it/s]


Epoch: 41 	Test Accuracy: 94.26
Test acc increased (94.170000 --> 94.260000).  Saving model ...


loss: 0.5380: 100%|██████████| 469/469 [02:40<00:00,  2.93it/s]


Epoch: 42 	Test Accuracy: 94.42
Test acc increased (94.260000 --> 94.420000).  Saving model ...


loss: 0.5211: 100%|██████████| 469/469 [02:38<00:00,  2.96it/s]


Epoch: 43 	Test Accuracy: 93.75
EarlyStopping counter: 1 out of 100


loss: 0.5172: 100%|██████████| 469/469 [02:39<00:00,  2.95it/s]


Epoch: 44 	Test Accuracy: 94.87
Test acc increased (94.420000 --> 94.870000).  Saving model ...


loss: 0.4963: 100%|██████████| 469/469 [02:39<00:00,  2.93it/s]


Epoch: 45 	Test Accuracy: 94.73
EarlyStopping counter: 1 out of 100


loss: 0.4974: 100%|██████████| 469/469 [02:39<00:00,  2.94it/s]


Epoch: 46 	Test Accuracy: 95.26
Test acc increased (94.870000 --> 95.260000).  Saving model ...


loss: 0.5190: 100%|██████████| 469/469 [02:41<00:00,  2.91it/s]


Epoch: 47 	Test Accuracy: 95.15
EarlyStopping counter: 1 out of 100


loss: 0.4819: 100%|██████████| 469/469 [02:42<00:00,  2.88it/s]


Epoch: 48 	Test Accuracy: 95.36
Test acc increased (95.260000 --> 95.360000).  Saving model ...


loss: 0.4989: 100%|██████████| 469/469 [02:43<00:00,  2.86it/s]


Epoch: 49 	Test Accuracy: 95.2
EarlyStopping counter: 1 out of 100


loss: 0.4952: 100%|██████████| 469/469 [02:45<00:00,  2.84it/s]


Epoch: 50 	Test Accuracy: 95.23
EarlyStopping counter: 2 out of 100


loss: 0.4882: 100%|██████████| 469/469 [02:45<00:00,  2.83it/s]


Epoch: 51 	Test Accuracy: 94.95
EarlyStopping counter: 3 out of 100


loss: 0.5096: 100%|██████████| 469/469 [02:45<00:00,  2.84it/s]


Epoch: 52 	Test Accuracy: 95.29
EarlyStopping counter: 4 out of 100


loss: 0.5233: 100%|██████████| 469/469 [02:44<00:00,  2.86it/s]


Epoch: 53 	Test Accuracy: 95.19
EarlyStopping counter: 5 out of 100


loss: 0.5324: 100%|██████████| 469/469 [02:44<00:00,  2.86it/s]


Epoch: 54 	Test Accuracy: 94.38
EarlyStopping counter: 6 out of 100


loss: 0.5261: 100%|██████████| 469/469 [02:48<00:00,  2.79it/s]


Epoch: 55 	Test Accuracy: 95.4
Test acc increased (95.360000 --> 95.400000).  Saving model ...


loss: 0.5344: 100%|██████████| 469/469 [02:44<00:00,  2.85it/s]


Epoch: 56 	Test Accuracy: 94.94
EarlyStopping counter: 1 out of 100


loss: 0.4907: 100%|██████████| 469/469 [02:44<00:00,  2.85it/s]


Epoch: 57 	Test Accuracy: 94.42
EarlyStopping counter: 2 out of 100


loss: 0.5112: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 58 	Test Accuracy: 94.92
EarlyStopping counter: 3 out of 100


loss: 0.5206: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 59 	Test Accuracy: 95.2
EarlyStopping counter: 4 out of 100


loss: 0.4733: 100%|██████████| 469/469 [02:42<00:00,  2.88it/s]


Epoch: 60 	Test Accuracy: 95.36
EarlyStopping counter: 5 out of 100


loss: 0.5005: 100%|██████████| 469/469 [02:43<00:00,  2.88it/s]


Epoch: 61 	Test Accuracy: 95.33
EarlyStopping counter: 6 out of 100


loss: 0.5165: 100%|██████████| 469/469 [02:44<00:00,  2.84it/s]


Epoch: 62 	Test Accuracy: 94.9
EarlyStopping counter: 7 out of 100


loss: 0.4586: 100%|██████████| 469/469 [02:41<00:00,  2.91it/s]


Epoch: 63 	Test Accuracy: 94.5
EarlyStopping counter: 8 out of 100


loss: 0.4810: 100%|██████████| 469/469 [02:40<00:00,  2.92it/s]


Epoch: 64 	Test Accuracy: 95.36
EarlyStopping counter: 9 out of 100


loss: 0.4687: 100%|██████████| 469/469 [02:40<00:00,  2.92it/s]


Epoch: 65 	Test Accuracy: 95.27
EarlyStopping counter: 10 out of 100


loss: 0.4455: 100%|██████████| 469/469 [02:39<00:00,  2.94it/s]


Epoch: 66 	Test Accuracy: 95.47
Test acc increased (95.400000 --> 95.470000).  Saving model ...


loss: 0.4888: 100%|██████████| 469/469 [02:39<00:00,  2.95it/s]


Epoch: 67 	Test Accuracy: 95.6
Test acc increased (95.470000 --> 95.600000).  Saving model ...


loss: 0.4374: 100%|██████████| 469/469 [02:40<00:00,  2.93it/s]


Epoch: 68 	Test Accuracy: 95.68
Test acc increased (95.600000 --> 95.680000).  Saving model ...


loss: 0.4090: 100%|██████████| 469/469 [02:39<00:00,  2.94it/s]


Epoch: 69 	Test Accuracy: 95.95
Test acc increased (95.680000 --> 95.950000).  Saving model ...


loss: 0.4351: 100%|██████████| 469/469 [02:39<00:00,  2.94it/s]


Epoch: 70 	Test Accuracy: 95.62
EarlyStopping counter: 1 out of 100


loss: 0.4192: 100%|██████████| 469/469 [02:39<00:00,  2.95it/s]


Epoch: 71 	Test Accuracy: 95.76
EarlyStopping counter: 2 out of 100


loss: 0.4403: 100%|██████████| 469/469 [02:40<00:00,  2.93it/s]


Epoch: 72 	Test Accuracy: 95.77
EarlyStopping counter: 3 out of 100


loss: 0.4479: 100%|██████████| 469/469 [02:39<00:00,  2.95it/s]


Epoch: 73 	Test Accuracy: 95.42
EarlyStopping counter: 4 out of 100


loss: 0.4754: 100%|██████████| 469/469 [02:38<00:00,  2.96it/s]


Epoch: 74 	Test Accuracy: 95.33
EarlyStopping counter: 5 out of 100


loss: 0.4649: 100%|██████████| 469/469 [02:39<00:00,  2.95it/s]


Epoch: 75 	Test Accuracy: 95.41
EarlyStopping counter: 6 out of 100


loss: 0.4671: 100%|██████████| 469/469 [02:39<00:00,  2.95it/s]


Epoch: 76 	Test Accuracy: 95.48
EarlyStopping counter: 7 out of 100


loss: 0.4490: 100%|██████████| 469/469 [02:39<00:00,  2.95it/s]


Epoch: 77 	Test Accuracy: 95.68
EarlyStopping counter: 8 out of 100


loss: 0.4931: 100%|██████████| 469/469 [02:38<00:00,  2.95it/s]


Epoch: 78 	Test Accuracy: 94.14
EarlyStopping counter: 9 out of 100


loss: 0.4651: 100%|██████████| 469/469 [02:40<00:00,  2.92it/s]


Epoch: 79 	Test Accuracy: 95.53
EarlyStopping counter: 10 out of 100


loss: 0.5143: 100%|██████████| 469/469 [02:40<00:00,  2.91it/s]


Epoch: 80 	Test Accuracy: 95.54
EarlyStopping counter: 11 out of 100


loss: 0.4348: 100%|██████████| 469/469 [02:40<00:00,  2.92it/s]


Epoch: 81 	Test Accuracy: 95.41
EarlyStopping counter: 12 out of 100


loss: 0.4203: 100%|██████████| 469/469 [02:40<00:00,  2.91it/s]


Epoch: 82 	Test Accuracy: 95.39
EarlyStopping counter: 13 out of 100


loss: 0.4676: 100%|██████████| 469/469 [02:41<00:00,  2.90it/s]


Epoch: 83 	Test Accuracy: 95.52
EarlyStopping counter: 14 out of 100


loss: 0.4151: 100%|██████████| 469/469 [02:43<00:00,  2.86it/s]


Epoch: 84 	Test Accuracy: 95.89
EarlyStopping counter: 15 out of 100


loss: 0.4091: 100%|██████████| 469/469 [02:43<00:00,  2.88it/s]


Epoch: 85 	Test Accuracy: 95.76
EarlyStopping counter: 16 out of 100


loss: 0.4507: 100%|██████████| 469/469 [02:41<00:00,  2.90it/s]


Epoch: 86 	Test Accuracy: 95.96
Test acc increased (95.950000 --> 95.960000).  Saving model ...


loss: 0.4301: 100%|██████████| 469/469 [02:42<00:00,  2.88it/s]


Epoch: 87 	Test Accuracy: 95.13
EarlyStopping counter: 1 out of 100


loss: 0.3903: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 88 	Test Accuracy: 95.67
EarlyStopping counter: 2 out of 100


loss: 0.3983: 100%|██████████| 469/469 [02:43<00:00,  2.88it/s]


Epoch: 89 	Test Accuracy: 95.97
Test acc increased (95.960000 --> 95.970000).  Saving model ...


loss: 0.4230: 100%|██████████| 469/469 [02:43<00:00,  2.87it/s]


Epoch: 90 	Test Accuracy: 95.92
EarlyStopping counter: 1 out of 100


loss: 0.3965: 100%|██████████| 469/469 [02:44<00:00,  2.85it/s]


Epoch: 91 	Test Accuracy: 95.86
EarlyStopping counter: 2 out of 100


loss: 0.4205: 100%|██████████| 469/469 [02:43<00:00,  2.87it/s]


Epoch: 92 	Test Accuracy: 95.8
EarlyStopping counter: 3 out of 100


loss: 0.4070: 100%|██████████| 469/469 [02:43<00:00,  2.86it/s]


Epoch: 93 	Test Accuracy: 95.87
EarlyStopping counter: 4 out of 100


loss: 0.3994: 100%|██████████| 469/469 [02:43<00:00,  2.87it/s]


Epoch: 94 	Test Accuracy: 95.77
EarlyStopping counter: 5 out of 100


loss: 0.4376: 100%|██████████| 469/469 [02:44<00:00,  2.86it/s]


Epoch: 95 	Test Accuracy: 95.41
EarlyStopping counter: 6 out of 100


loss: 0.4241: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 96 	Test Accuracy: 95.39
EarlyStopping counter: 7 out of 100


loss: 0.4662: 100%|██████████| 469/469 [02:41<00:00,  2.90it/s]


Epoch: 97 	Test Accuracy: 95.81
EarlyStopping counter: 8 out of 100


loss: 0.4445: 100%|██████████| 469/469 [02:40<00:00,  2.93it/s]


Epoch: 98 	Test Accuracy: 95.56
EarlyStopping counter: 9 out of 100


loss: 0.4313: 100%|██████████| 469/469 [02:43<00:00,  2.88it/s]


Epoch: 99 	Test Accuracy: 95.81
EarlyStopping counter: 10 out of 100


loss: 0.4438: 100%|██████████| 469/469 [02:40<00:00,  2.92it/s]


Epoch: 100 	Test Accuracy: 95.83
EarlyStopping counter: 11 out of 100


loss: 0.4472: 100%|██████████| 469/469 [02:45<00:00,  2.83it/s]


Epoch: 101 	Test Accuracy: 95.63
EarlyStopping counter: 12 out of 100


loss: 0.4313: 100%|██████████| 469/469 [02:44<00:00,  2.86it/s]


Epoch: 102 	Test Accuracy: 95.65
EarlyStopping counter: 13 out of 100


loss: 0.4407: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 103 	Test Accuracy: 95.55
EarlyStopping counter: 14 out of 100


loss: 0.4281: 100%|██████████| 469/469 [02:41<00:00,  2.91it/s]


Epoch: 104 	Test Accuracy: 95.6
EarlyStopping counter: 15 out of 100


loss: 0.3991: 100%|██████████| 469/469 [02:41<00:00,  2.90it/s]


Epoch: 105 	Test Accuracy: 95.81
EarlyStopping counter: 16 out of 100


loss: 0.3624: 100%|██████████| 469/469 [02:41<00:00,  2.91it/s]


Epoch: 106 	Test Accuracy: 95.67
EarlyStopping counter: 17 out of 100


loss: 0.3822: 100%|██████████| 469/469 [02:44<00:00,  2.85it/s]


Epoch: 107 	Test Accuracy: 95.86
EarlyStopping counter: 18 out of 100


loss: 0.3560: 100%|██████████| 469/469 [02:43<00:00,  2.87it/s]


Epoch: 108 	Test Accuracy: 95.95
EarlyStopping counter: 19 out of 100


loss: 0.3865: 100%|██████████| 469/469 [02:43<00:00,  2.86it/s]


Epoch: 109 	Test Accuracy: 95.92
EarlyStopping counter: 20 out of 100


loss: 0.3676: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 110 	Test Accuracy: 95.66
EarlyStopping counter: 21 out of 100


loss: 0.3671: 100%|██████████| 469/469 [02:43<00:00,  2.88it/s]


Epoch: 111 	Test Accuracy: 96.02
Test acc increased (95.970000 --> 96.020000).  Saving model ...


loss: 0.3827: 100%|██████████| 469/469 [02:42<00:00,  2.88it/s]


Epoch: 112 	Test Accuracy: 96.11
Test acc increased (96.020000 --> 96.110000).  Saving model ...


loss: 0.3810: 100%|██████████| 469/469 [02:44<00:00,  2.84it/s]


Epoch: 113 	Test Accuracy: 96.28
Test acc increased (96.110000 --> 96.280000).  Saving model ...


loss: 0.4117: 100%|██████████| 469/469 [02:43<00:00,  2.87it/s]


Epoch: 114 	Test Accuracy: 95.55
EarlyStopping counter: 1 out of 100


loss: 0.4160: 100%|██████████| 469/469 [02:43<00:00,  2.86it/s]


Epoch: 115 	Test Accuracy: 96.11
EarlyStopping counter: 2 out of 100


loss: 0.3939: 100%|██████████| 469/469 [02:44<00:00,  2.85it/s]


Epoch: 116 	Test Accuracy: 95.98
EarlyStopping counter: 3 out of 100


loss: 0.4134: 100%|██████████| 469/469 [02:45<00:00,  2.83it/s]


Epoch: 117 	Test Accuracy: 95.17
EarlyStopping counter: 4 out of 100


loss: 0.4291: 100%|██████████| 469/469 [02:45<00:00,  2.83it/s]


Epoch: 118 	Test Accuracy: 95.77
EarlyStopping counter: 5 out of 100


loss: 0.4031: 100%|██████████| 469/469 [02:46<00:00,  2.81it/s]


Epoch: 119 	Test Accuracy: 95.59
EarlyStopping counter: 6 out of 100


loss: 0.3906: 100%|██████████| 469/469 [02:45<00:00,  2.84it/s]


Epoch: 120 	Test Accuracy: 95.74
EarlyStopping counter: 7 out of 100


loss: 0.4024: 100%|██████████| 469/469 [02:46<00:00,  2.82it/s]


Epoch: 121 	Test Accuracy: 95.88
EarlyStopping counter: 8 out of 100


loss: 0.3836: 100%|██████████| 469/469 [02:48<00:00,  2.79it/s]


Epoch: 122 	Test Accuracy: 95.74
EarlyStopping counter: 9 out of 100


loss: 0.3900: 100%|██████████| 469/469 [02:52<00:00,  2.73it/s]


Epoch: 123 	Test Accuracy: 95.91
EarlyStopping counter: 10 out of 100


loss: 0.3803: 100%|██████████| 469/469 [02:50<00:00,  2.75it/s]


Epoch: 124 	Test Accuracy: 96.0
EarlyStopping counter: 11 out of 100


loss: 0.3778: 100%|██████████| 469/469 [02:51<00:00,  2.74it/s]


Epoch: 125 	Test Accuracy: 95.86
EarlyStopping counter: 12 out of 100


loss: 0.3772: 100%|██████████| 469/469 [02:48<00:00,  2.79it/s]


Epoch: 126 	Test Accuracy: 95.58
EarlyStopping counter: 13 out of 100


loss: 0.3788: 100%|██████████| 469/469 [02:46<00:00,  2.81it/s]


Epoch: 127 	Test Accuracy: 95.86
EarlyStopping counter: 14 out of 100


loss: 0.3812: 100%|██████████| 469/469 [02:46<00:00,  2.81it/s]


Epoch: 128 	Test Accuracy: 95.97
EarlyStopping counter: 15 out of 100


loss: 0.3537: 100%|██████████| 469/469 [02:45<00:00,  2.83it/s]


Epoch: 129 	Test Accuracy: 95.84
EarlyStopping counter: 16 out of 100


loss: 0.3314: 100%|██████████| 469/469 [02:46<00:00,  2.82it/s]


Epoch: 130 	Test Accuracy: 96.02
EarlyStopping counter: 17 out of 100


loss: 0.3444: 100%|██████████| 469/469 [02:44<00:00,  2.86it/s]


Epoch: 131 	Test Accuracy: 96.11
EarlyStopping counter: 18 out of 100


loss: 0.3468: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 132 	Test Accuracy: 95.98
EarlyStopping counter: 19 out of 100


loss: 0.3666: 100%|██████████| 469/469 [02:43<00:00,  2.86it/s]


Epoch: 133 	Test Accuracy: 96.16
EarlyStopping counter: 20 out of 100


loss: 0.3827: 100%|██████████| 469/469 [02:45<00:00,  2.83it/s]


Epoch: 134 	Test Accuracy: 96.06
EarlyStopping counter: 21 out of 100


loss: 0.3848: 100%|██████████| 469/469 [02:45<00:00,  2.84it/s]


Epoch: 135 	Test Accuracy: 96.1
EarlyStopping counter: 22 out of 100


loss: 0.3737: 100%|██████████| 469/469 [02:45<00:00,  2.83it/s]


Epoch: 136 	Test Accuracy: 96.02
EarlyStopping counter: 23 out of 100


loss: 0.4342: 100%|██████████| 469/469 [02:45<00:00,  2.83it/s]


Epoch: 137 	Test Accuracy: 96.17
EarlyStopping counter: 24 out of 100


loss: 0.4081: 100%|██████████| 469/469 [02:44<00:00,  2.85it/s]


Epoch: 138 	Test Accuracy: 95.06
EarlyStopping counter: 25 out of 100


loss: 0.3728: 100%|██████████| 469/469 [02:45<00:00,  2.84it/s]


Epoch: 139 	Test Accuracy: 95.84
EarlyStopping counter: 26 out of 100


loss: 0.4270: 100%|██████████| 469/469 [02:44<00:00,  2.85it/s]


Epoch: 140 	Test Accuracy: 95.51
EarlyStopping counter: 27 out of 100


loss: 0.3998: 100%|██████████| 469/469 [02:45<00:00,  2.83it/s]


Epoch: 141 	Test Accuracy: 95.53
EarlyStopping counter: 28 out of 100


loss: 0.3869: 100%|██████████| 469/469 [02:46<00:00,  2.82it/s]


Epoch: 142 	Test Accuracy: 95.61
EarlyStopping counter: 29 out of 100


loss: 0.4065: 100%|██████████| 469/469 [02:45<00:00,  2.83it/s]


Epoch: 143 	Test Accuracy: 95.64
EarlyStopping counter: 30 out of 100


loss: 0.4086: 100%|██████████| 469/469 [02:44<00:00,  2.85it/s]


Epoch: 144 	Test Accuracy: 95.55
EarlyStopping counter: 31 out of 100


loss: 0.3428: 100%|██████████| 469/469 [02:43<00:00,  2.86it/s]


Epoch: 145 	Test Accuracy: 95.96
EarlyStopping counter: 32 out of 100


loss: 0.3628: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 146 	Test Accuracy: 96.16
EarlyStopping counter: 33 out of 100


loss: 0.3721: 100%|██████████| 469/469 [02:42<00:00,  2.88it/s]


Epoch: 147 	Test Accuracy: 95.76
EarlyStopping counter: 34 out of 100


loss: 0.3439: 100%|██████████| 469/469 [02:43<00:00,  2.88it/s]


Epoch: 148 	Test Accuracy: 95.82
EarlyStopping counter: 35 out of 100


loss: 0.3442: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 149 	Test Accuracy: 96.36
Test acc increased (96.280000 --> 96.360000).  Saving model ...


loss: 0.3419: 100%|██████████| 469/469 [02:41<00:00,  2.91it/s]


Epoch: 150 	Test Accuracy: 96.28
EarlyStopping counter: 1 out of 100


loss: 0.3614: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 151 	Test Accuracy: 96.26
EarlyStopping counter: 2 out of 100


loss: 0.3355: 100%|██████████| 469/469 [02:41<00:00,  2.90it/s]


Epoch: 152 	Test Accuracy: 96.26
EarlyStopping counter: 3 out of 100


loss: 0.3544: 100%|██████████| 469/469 [02:43<00:00,  2.87it/s]


Epoch: 153 	Test Accuracy: 96.18
EarlyStopping counter: 4 out of 100


loss: 0.3952: 100%|██████████| 469/469 [02:41<00:00,  2.90it/s]


Epoch: 154 	Test Accuracy: 96.07
EarlyStopping counter: 5 out of 100


loss: 0.3572: 100%|██████████| 469/469 [02:41<00:00,  2.90it/s]


Epoch: 155 	Test Accuracy: 96.02
EarlyStopping counter: 6 out of 100


loss: 0.4028: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 156 	Test Accuracy: 95.86
EarlyStopping counter: 7 out of 100


loss: 0.3771: 100%|██████████| 469/469 [02:41<00:00,  2.91it/s]


Epoch: 157 	Test Accuracy: 96.29
EarlyStopping counter: 8 out of 100


loss: 0.3822: 100%|██████████| 469/469 [02:41<00:00,  2.90it/s]


Epoch: 158 	Test Accuracy: 95.89
EarlyStopping counter: 9 out of 100


loss: 0.3631: 100%|██████████| 469/469 [02:40<00:00,  2.92it/s]


Epoch: 159 	Test Accuracy: 96.27
EarlyStopping counter: 10 out of 100


loss: 0.3859: 100%|██████████| 469/469 [02:41<00:00,  2.91it/s]


Epoch: 160 	Test Accuracy: 95.77
EarlyStopping counter: 11 out of 100


loss: 0.3793: 100%|██████████| 469/469 [02:41<00:00,  2.91it/s]


Epoch: 161 	Test Accuracy: 95.72
EarlyStopping counter: 12 out of 100


loss: 0.4077: 100%|██████████| 469/469 [02:43<00:00,  2.87it/s]


Epoch: 162 	Test Accuracy: 95.8
EarlyStopping counter: 13 out of 100


loss: 0.3862: 100%|██████████| 469/469 [02:42<00:00,  2.89it/s]


Epoch: 163 	Test Accuracy: 95.65
EarlyStopping counter: 14 out of 100


loss: 0.4017: 100%|██████████| 469/469 [02:44<00:00,  2.84it/s]


Epoch: 164 	Test Accuracy: 95.55
EarlyStopping counter: 15 out of 100


loss: 0.3431: 100%|██████████| 469/469 [02:41<00:00,  2.90it/s]


Epoch: 165 	Test Accuracy: 95.7
EarlyStopping counter: 16 out of 100


loss: 0.3536:  54%|█████▎    | 252/469 [01:26<01:11,  3.03it/s]

In [None]:
# Bind the BinarizeF autograd Function's `apply` to a short name and print it
# to show what the name now refers to.
binarize = BinarizeF.apply
print(binarize)

# Build this run's settings as a shallow copy of the shared `config`, so the
# model override stays local to this cell instead of being written into the
# dict that other cells also start from (plain `cfg = config` only aliases it).
# The 'model' value is a constructor expression kept as a string; presumably
# run() evaluates it -- confirm against run()'s definition.
cfg = {**config, 'model': 'NetFC(config, [128, 128])'}
run(cfg)

<bound method Function.apply of <class '__main__.BinarizeF'>>


In [None]:
# Shallow-copy the shared settings so this cell's override is not written back
# into `config` (plain `cfg = config` only aliases the same dict, which makes
# later cells depend on which cells ran before them).
# NOTE(review): the other run cells select the network through the 'model'
# key; confirm that run() actually reads 'model_class_name'.
cfg = {**config, 'model_class_name': 'Net'}
run(cfg)

In [None]:
# Shallow-copy the shared settings so this cell's override is not written back
# into `config` (plain `cfg = config` only aliases the same dict, which makes
# later cells depend on which cells ran before them).
# NOTE(review): the other run cells select the network through the 'model'
# key; confirm that run() actually reads 'model_class_name'.
cfg = {**config, 'model_class_name': 'NetFC_FirstConv'}
run(cfg)

In [None]:
# NOTE(review): this cell is textually identical to the cell before it --
# presumably a repeated run; confirm whether the duplication is intentional.
# Shallow-copy the shared settings so the override is not written back into
# `config` (plain `cfg = config` only aliases the same dict).
# NOTE(review): the other run cells select the network through the 'model'
# key; confirm that run() actually reads 'model_class_name'.
cfg = {**config, 'model_class_name': 'NetFC_FirstConv'}
run(cfg)


In [None]:
# Build this run's settings as a shallow copy of the shared `config`, so the
# model override stays local to this cell instead of being written into the
# dict that other cells also start from (plain `cfg = config` only aliases it).
# The 'model' value is a constructor expression kept as a string; presumably
# run() evaluates it -- confirm against run()'s definition.
cfg = {**config, 'model': 'NetFC_FirstConv(config)'}
run(config=cfg)