In [13]:
# Copyright (c) 2024 Byeonghyeon Kim 
# github site: https://github.com/bhkim003/ByeonghyeonKim
# email: bhkim003@snu.ac.kr
 
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
 
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


In [14]:
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.datasets
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

import time

from snntorch import spikegen
import matplotlib.pyplot as plt
import snntorch.spikeplot as splt
from IPython.display import HTML

from tqdm import tqdm

In [15]:
class SYNAPSE_FC_BPTT(nn.Module):
    def __init__(self, in_features, out_features, trace_const1=1, trace_const2=0.7, TIME=8):
        super(SYNAPSE_FC_BPTT, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.trace_const1 = trace_const1
        self.trace_const2 = trace_const2

        # self.weight = torch.randn(self.out_features, self.in_features, requires_grad=True)
        # self.bias = torch.randn(self.out_features, requires_grad=True)
        self.weight = nn.Parameter(torch.randn(self.out_features, self.in_features))
        self.bias = nn.Parameter(torch.randn(self.out_features))

        self.TIME = TIME

    def forward(self, spike):
        # spike: [Time, Batch, Features]   
        Time = spike.shape[0]
        assert Time == self.TIME, 'Time dimension should be same as TIME'
        Batch = spike.shape[1] 

        output_current = []

        for t in range(Time):
            output_current.append(F.linear(spike[t], weight = self.weight, bias= self.bias))

        output_current = torch.stack(output_current, dim=0)
        return output_current 


class SYNAPSE_CONV_BPTT(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, trace_const1=1, trace_const2=0.7, TIME=8):
        super(SYNAPSE_CONV_BPTT, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.trace_const1 = trace_const1
        self.trace_const2 = trace_const2
        # self.weight = torch.randn(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size, requires_grad=True)
        # self.bias = torch.randn(self.out_channels, requires_grad=True)
        self.weight = nn.Parameter(torch.randn(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size))
        self.bias = nn.Parameter(torch.randn(self.out_channels))

        self.TIME = TIME

    def forward(self, spike):
        # spike: [Time, Batch, Channel, Height, Width]   
        # print('spike.shape', spike.shape)
        Time = spike.shape[0]
        assert Time == self.TIME, 'Time dimension should be same as TIME'
        Batch = spike.shape[1] 
        Channel = self.out_channels
        Height = (spike.shape[3] + self.padding*2 - self.kernel_size) // self.stride + 1
        Width = (spike.shape[4] + self.padding*2 - self.kernel_size) // self.stride + 1

        # output_current = torch.zeros(Time, Batch, Channel, Height, Width, device=spike.device)
        output_current = []
        
        for t in range(Time):
            # print(f'time:{t}', torch.sum(spike_detach[t]/ torch.numel(spike_detach[t])))
            output_current.append(F.conv2d(spike[t], self.weight, bias=self.bias, stride=self.stride, padding=self.padding))
            # print(f'time:{t}', torch.sum(output_current[t]/ torch.numel(output_current[t])))

        output_current = torch.stack(output_current, dim=0)
        return output_current


class SYNAPSE_FC_METHOD(torch.autograd.Function):
    @staticmethod
    def forward(ctx, spike_one_time, spike_now, weight, bias):
        ctx.save_for_backward(spike_one_time, spike_now, weight, bias)
        return F.linear(spike_one_time, weight, bias=bias)

    @staticmethod
    def backward(ctx, grad_output_current):
        #############밑에부터 수정해라#######
        spike_one_time, spike_now, weight, bias = ctx.saved_tensors
        
        ## 이거 클론해야되는지 모르겠음!!!!
        grad_output_current_clone = grad_output_current.clone()

        grad_input_spike = grad_weight = grad_bias = None


        if ctx.needs_input_grad[0]:
            grad_input_spike = grad_output_current_clone @ weight
        if ctx.needs_input_grad[2]:
            grad_weight = grad_output_current_clone.t() @ spike_now
        if bias is not None and ctx.needs_input_grad[3]:
            grad_bias = grad_output_current_clone.sum(0)

        return grad_input_spike, None, grad_weight, grad_bias

    
class SYNAPSE_FC(nn.Module):
    def __init__(self, in_features, out_features, trace_const1=1, trace_const2=0.7, TIME=8):
        super(SYNAPSE_FC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.trace_const1 = trace_const1
        self.trace_const2 = trace_const2

        # self.weight = torch.randn(self.out_features, self.in_features, requires_grad=True)
        # self.bias = torch.randn(self.out_features, requires_grad=True)
        self.weight = nn.Parameter(torch.randn(self.out_features, self.in_features))
        self.bias = nn.Parameter(torch.randn(self.out_features))

        self.TIME = TIME

    def forward(self, spike):
        # spike: [Time, Batch, Features]   
        Time = spike.shape[0]
        assert Time == self.TIME, 'Time dimension should be same as TIME'
        Batch = spike.shape[1] 

        # output_current = torch.zeros(Time, Batch, self.out_features, device=spike.device)
        output_current = []

        # spike_detach = spike.detach().clone()
        spike_detach = spike.detach()
        spike_past = torch.zeros_like(spike_detach[0], device=spike.device)
        spike_now = torch.zeros_like(spike_detach[0], device=spike.device)

        for t in range(Time):
            spike_now = self.trace_const1*spike_detach[t] + self.trace_const2*spike_past
            # output_current[t]= SYNAPSE_FC_METHOD.apply(spike[t], spike_now, self.weight, self.bias) 
            output_current.append( SYNAPSE_FC_METHOD.apply(spike[t], spike_now, self.weight, self.bias) )
            
            spike_past = spike_now

        output_current = torch.stack(output_current, dim=0)
        return output_current 



class SYNAPSE_CONV_METHOD(torch.autograd.Function):
    @staticmethod
    def forward(ctx, spike_one_time, spike_now, weight, bias, stride=1, padding=1):
        ctx.save_for_backward(spike_one_time, spike_now, weight, bias, torch.tensor([stride], requires_grad=False), torch.tensor([padding], requires_grad=False))
        return F.conv2d(spike_one_time, weight, bias=bias, stride=stride, padding=padding)

    @staticmethod
    def backward(ctx, grad_output_current):
        spike_one_time, spike_now, weight, bias, stride, padding = ctx.saved_tensors
        stride=stride.item()
        padding=padding.item()
        
        ## 이거 클론해야되는지 모르겠음!!!!
        grad_output_current_clone = grad_output_current.clone()

        grad_input_spike = grad_weight = grad_bias = None


        if ctx.needs_input_grad[0]:
            grad_input_spike = F.conv_transpose2d(grad_output_current_clone, weight, stride=stride, padding=padding)
        if ctx.needs_input_grad[2]:
            grad_weight = torch.nn.grad.conv2d_weight(spike_now, weight.shape, grad_output_current_clone,
                                                    stride=stride, padding=padding)
        if bias is not None and ctx.needs_input_grad[3]:
            grad_bias = grad_output_current_clone.sum((0, -1, -2))

        return grad_input_spike, None, grad_weight, grad_bias, None, None

    



class SYNAPSE_CONV(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, trace_const1=1, trace_const2=0.7, TIME=8):
        super(SYNAPSE_CONV, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.trace_const1 = trace_const1
        self.trace_const2 = trace_const2
        # self.weight = torch.randn(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size, requires_grad=True)
        # self.bias = torch.randn(self.out_channels, requires_grad=True)
        self.weight = nn.Parameter(torch.randn(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size))
        self.bias = nn.Parameter(torch.randn(self.out_channels))

        self.TIME = TIME

    def forward(self, spike):
        # spike: [Time, Batch, Channel, Height, Width]   
        # print('spike.shape', spike.shape)
        Time = spike.shape[0]
        assert Time == self.TIME, 'Time dimension should be same as TIME'
        Batch = spike.shape[1] 
        Channel = self.out_channels
        Height = (spike.shape[3] + self.padding*2 - self.kernel_size) // self.stride + 1
        Width = (spike.shape[4] + self.padding*2 - self.kernel_size) // self.stride + 1

        # output_current = torch.zeros(Time, Batch, Channel, Height, Width, device=spike.device)
        output_current = []
        
        # spike_detach = spike.detach().clone()
        spike_detach = spike.detach()
        spike_past = torch.zeros_like(spike_detach[0])
        spike_now = torch.zeros_like(spike_detach[0])
        for t in range(Time):
            # print(f'time:{t}', torch.sum(spike_detach[t]/ torch.numel(spike_detach[t])))
            spike_now = self.trace_const1*spike_detach[t] + self.trace_const2*spike_past

            # output_current[t]= SYNAPSE_CONV_METHOD.apply(spike[t], spike_now, self.weight, self.bias, self.stride, self.padding) 
            output_current.append( SYNAPSE_CONV_METHOD.apply(spike[t], spike_now, self.weight, self.bias, self.stride, self.padding) )
            
            spike_past = spike_now
            # print(f'time:{t}', torch.sum(output_current[t]/ torch.numel(output_current[t])))

        output_current = torch.stack(output_current, dim=0)
        return output_current



class LIF_METHOD(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_current_one_time, v_one_time, v_decay, v_threshold, v_reset, sg_width, surrogate):
        v_one_time = v_one_time * v_decay + input_current_one_time # leak + pre-synaptic current integrate
        spike = (v_one_time >= v_threshold).float() #fire
        if surrogate == 'sigmoid':
            surrogate = 1
        elif surrogate == 'rectangle':
            surrogate = 2
        elif surrogate == 'rough_rectangle':
            surrogate = 3
        else:
            pass
        ctx.save_for_backward(v_one_time, torch.tensor([v_decay], requires_grad=False), 
                            torch.tensor([v_threshold], requires_grad=False), 
                            torch.tensor([v_reset], requires_grad=False), 
                            torch.tensor([sg_width], requires_grad=False),
                            torch.tensor([surrogate], requires_grad=False)) # save before reset
        
        v_one_time = (v_one_time - spike * v_threshold).clamp_min(0) # reset
        return spike, v_one_time

    @staticmethod
    def backward(ctx, grad_output_spike, grad_output_v):
        v_one_time, v_decay, v_threshold, v_reset, sg_width, surrogate = ctx.saved_tensors
        v_decay=v_decay.item()
        v_threshold=v_threshold.item()
        v_reset=v_reset.item()
        sg_width=sg_width.item()
        surrogate=surrogate.item()

        grad_input_current = grad_output_spike.clone()
        # grad_temp_v = grad_output_v.clone() # not used

        ################ select one of the following surrogate gradient functions ################
        if (surrogate == 1):
            #===========surrogate gradient function (sigmoid)
            sig = torch.sigmoid((v_one_time - v_threshold))
            grad_input_current *= 4*sig*(1-sig)
            # grad_x = grad_output * (1. - sgax) * sgax * ctx.alpha

        elif (surrogate == 2):
            # ===========surrogate gradient function (rectangle)
            grad_input_current *= ((v_one_time - v_threshold).abs() < sg_width/2).float() / sg_width

        elif (surrogate == 3):
            #===========surrogate gradient function (rough rectangle)
            grad_input_current[(v_one_time - v_threshold).abs() > sg_width/2] = 0
        else: 
            pass
        ###########################################################################################
        return grad_input_current, None, None, None, None, None, None

class LIF_layer(nn.Module):
    def __init__ (self, v_init , v_decay , v_threshold , v_reset , sg_width, surrogate):
        super(LIF_layer, self).__init__()
        self.v_init = v_init
        self.v_decay = v_decay
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.sg_width = sg_width
        self.surrogate = surrogate

    def forward(self, input_current):
        v = torch.full_like(input_current, fill_value = self.v_init, dtype = torch.float) # v (membrane potential) init
        post_spike = torch.full_like(input_current, fill_value = self.v_init, device=input_current.device, dtype = torch.float) # v (membrane potential) init
        # i와 v와 post_spike size는 여기서 다 같음: [Time, Batch, Channel, Height, Width] 
        Time = v.shape[0]
        for t in range(Time):
            # leaky하고 input_current 더하고 fire하고 reset까지 (backward직접처리)
            post_spike[t], v[t] = LIF_METHOD.apply(input_current[t], v[t], 
                                            self.v_decay, self.v_threshold, self.v_reset, self.sg_width, self.surrogate) 
        return post_spike 




class DimChanger_for_pooling(nn.Module):
    def __init__(self, module):
        super(DimChanger_for_pooling, self).__init__()
        self.ann_module = module

    def forward(self, x):
        timestep, batch_size, *dim = x.shape
        output = self.ann_module(x.reshape(timestep * batch_size, *dim))
        _, *dim = output.shape
        output = output.view(timestep, batch_size, *dim).contiguous()
        return output


class DimChanger_for_FC(nn.Module):
    def __init__(self):
        super(DimChanger_for_FC, self).__init__()

    def forward(self, x):
        x = x.view(x.size(0), x.size(1), -1)
        return x

class tdBatchNorm(nn.BatchNorm2d):
    def __init__(self, channel):
        super(tdBatchNorm, self).__init__(channel)
        # according to tdBN paper, the initialized weight is changed to alpha*Vth
        # self.weight.data.mul_(0.5)

    def forward(self, x):
        T, B, *spatial_dims = x.shape
        out = super().forward(x.reshape(T * B, *spatial_dims))
        TB, *spatial_dims = out.shape
        out = out.view(T, B, *spatial_dims).contiguous()
        return out
    
    
class tdBatchNorm_FC(nn.BatchNorm1d):
    def __init__(self, channel):
        super(tdBatchNorm_FC, self).__init__(channel)
        # according to tdBN paper, the initialized weight is changed to alpha*Vth
        # self.weight.data.mul_(0.5)

    def forward(self, x):
        T, B, *spatial_dims = x.shape
        out = super().forward(x.reshape(T * B, *spatial_dims))
        TB, *spatial_dims = out.shape
        out = out.view(T, B, *spatial_dims).contiguous()
        return out


class BatchNorm(nn.Module):
    def __init__(self, out_channels, TIME):
        super(BatchNorm, self).__init__()
        self.out_channels = out_channels
        self.TIME = TIME
        self.bn_layers = nn.ModuleList([nn.BatchNorm2d(self.out_channels) for _ in range(self.TIME)])

    def forward(self, x):
        # out = torch.zeros_like(x, device=x.device) #Time, Batch, Channel, Height, Width
        out = [] #Time, Batch, Channel, Height, Width
        for t in range(self.TIME):
            out.append(self.bn_layers[t](x[t]))
        out = torch.stack(out, dim=0)
        return out
    
class BatchNorm_FC(nn.Module):
    def __init__(self, out_channels, TIME):
        super(BatchNorm_FC, self).__init__()
        self.out_channels = out_channels
        self.TIME = TIME
        self.bn_layers = nn.ModuleList([nn.BatchNorm1d(self.out_channels) for _ in range(self.TIME)])

    def forward(self, x):
        # out = torch.zeros_like(x, device=x.device) #Time, Batch, Channel, Height, Width
        out = [] #Time, Batch, Channel, Height, Width
        for t in range(self.TIME):
            out.append(self.bn_layers[t](x[t]))
        out = torch.stack(out, dim=0)
        return out

def make_layers_conv(cfg, in_c, IMAGE_SIZE,
                     synapse_conv_kernel_size, synapse_conv_stride, 
                     synapse_conv_padding, synapse_conv_trace_const1, 
                     synapse_conv_trace_const2, 
                     lif_layer_v_init, lif_layer_v_decay, 
                     lif_layer_v_threshold, lif_layer_v_reset,
                     lif_layer_sg_width,
                     tdBN_on,
                     BN_on, TIME,
                     surrogate,
                     BPTT_on,
                     synapse_fc_out_features):
    
    layers = []
    in_channels = in_c
    img_size_var = IMAGE_SIZE
    for which in cfg:
        if which == 'P':
            layers += [DimChanger_for_pooling(nn.AvgPool2d(kernel_size=2, stride=2))]
            # layers += [DimChanger_for_pooling(nn.MaxPool2d(kernel_size=2, stride=2))]
            img_size_var = img_size_var // 2
        else:
            out_channels = which
            if (BPTT_on == False):
                layers += [SYNAPSE_CONV(in_channels=in_channels,
                                        out_channels=out_channels, 
                                        kernel_size=synapse_conv_kernel_size, 
                                        stride=synapse_conv_stride, 
                                        padding=synapse_conv_padding, 
                                        trace_const1=synapse_conv_trace_const1, 
                                        trace_const2=synapse_conv_trace_const2,
                                        TIME=TIME)]
            else:
                layers += [SYNAPSE_CONV_BPTT(in_channels=in_channels,
                                        out_channels=out_channels, 
                                        kernel_size=synapse_conv_kernel_size, 
                                        stride=synapse_conv_stride, 
                                        padding=synapse_conv_padding, 
                                        trace_const1=synapse_conv_trace_const1, 
                                        trace_const2=synapse_conv_trace_const2,
                                        TIME=TIME)]
            
            img_size_var = (img_size_var - synapse_conv_kernel_size + 2*synapse_conv_padding)//synapse_conv_stride + 1
           
            in_channels = which
            
            if (tdBN_on == True):
                layers += [tdBatchNorm(in_channels)] # 여기서 in_channel이 out_channel임

            if (BN_on == True):
                layers += [BatchNorm(in_channels, TIME)]

            layers += [LIF_layer(v_init=lif_layer_v_init, 
                                    v_decay=lif_layer_v_decay, 
                                    v_threshold=lif_layer_v_threshold, 
                                    v_reset=lif_layer_v_reset, 
                                    sg_width=lif_layer_sg_width,
                                    surrogate=surrogate)]
            
    layers += [DimChanger_for_FC()]
    if (BPTT_on == False):
        layers += [SYNAPSE_FC(in_features=in_channels*img_size_var*img_size_var,  # 마지막CONV의 OUT_CHANNEL * H * W
                                        out_features=synapse_fc_out_features, 
                                        trace_const1=synapse_conv_trace_const1, 
                                        trace_const2=synapse_conv_trace_const2,
                                        TIME=TIME)]
    else:
        layers += [SYNAPSE_FC_BPTT(in_features=in_channels*img_size_var*img_size_var,  # 마지막CONV의 OUT_CHANNEL * H * W
                                        out_features=synapse_fc_out_features, 
                                        trace_const1=synapse_conv_trace_const1, 
                                        trace_const2=synapse_conv_trace_const2,
                                        TIME=TIME)]

    return nn.Sequential(*layers)



class MY_SNN_CONV(nn.Module):
    def __init__(self, cfg, in_c, IMAGE_SIZE,
                     synapse_conv_kernel_size, synapse_conv_stride, 
                     synapse_conv_padding, synapse_conv_trace_const1, 
                     synapse_conv_trace_const2, 
                     lif_layer_v_init, lif_layer_v_decay, 
                     lif_layer_v_threshold, lif_layer_v_reset,
                     lif_layer_sg_width,
                     synapse_fc_out_features, synapse_fc_trace_const1, synapse_fc_trace_const2,
                     tdBN_on,
                     BN_on, TIME,
                     surrogate,
                     BPTT_on):
        super(MY_SNN_CONV, self).__init__()
        self.layers = make_layers_conv(cfg, in_c, IMAGE_SIZE,
                                    synapse_conv_kernel_size, synapse_conv_stride, 
                                    synapse_conv_padding, synapse_conv_trace_const1, 
                                    synapse_conv_trace_const2, 
                                    lif_layer_v_init, lif_layer_v_decay, 
                                    lif_layer_v_threshold, lif_layer_v_reset,
                                    lif_layer_sg_width,
                                    tdBN_on,
                                    BN_on, TIME,
                                    surrogate,
                                    BPTT_on,
                                    synapse_fc_out_features)


    def forward(self, spike_input):
        # inputs: [Batch, Time, Channel, Height, Width]   
        spike_input = spike_input.permute(1, 0, 2, 3, 4)
        # inputs: [Time, Batch, Channel, Height, Width]   
        spike_input = self.layers(spike_input)

        spike_input = spike_input.sum(axis=0)
        return spike_input



def make_layers_fc(cfg, in_c, IMAGE_SIZE, out_c,
                     synapse_fc_trace_const1, synapse_fc_trace_const2, 
                     lif_layer_v_init, lif_layer_v_decay, 
                     lif_layer_v_threshold, lif_layer_v_reset,
                     lif_layer_sg_width,
                     tdBN_on,
                     BN_on, TIME,
                     surrogate,
                     BPTT_on):

    layers = []
    img_size = IMAGE_SIZE
    in_channels = in_c * img_size * img_size
    class_num = out_c
    for which in cfg:
        out_channels = which

        if(BPTT_on == False):
            layers += [SYNAPSE_FC(in_features=in_channels,  # 마지막CONV의 OUT_CHANNEL * H * W
                                        out_features=out_channels, 
                                        trace_const1=synapse_fc_trace_const1, 
                                        trace_const2=synapse_fc_trace_const2,
                                        TIME=TIME)]
        else:
            layers += [SYNAPSE_FC_BPTT(in_features=in_channels,  # 마지막CONV의 OUT_CHANNEL * H * W
                                        out_features=out_channels, 
                                        trace_const1=synapse_fc_trace_const1, 
                                        trace_const2=synapse_fc_trace_const2,
                                        TIME=TIME)]
            



        in_channels = which
        
        if (tdBN_on == True):
            layers += [tdBatchNorm_FC(in_channels)] # 여기서 in_channel이 out_channel임

        if (BN_on == True):
            layers += [BatchNorm_FC(in_channels, TIME)]

        layers += [LIF_layer(v_init=lif_layer_v_init, 
                                v_decay=lif_layer_v_decay, 
                                v_threshold=lif_layer_v_threshold, 
                                v_reset=lif_layer_v_reset, 
                                sg_width=lif_layer_sg_width,
                                surrogate=surrogate)]

    
    out_channels = class_num
    if(BPTT_on == False):
        layers += [SYNAPSE_FC(in_features=in_channels,  # 마지막CONV의 OUT_CHANNEL * H * W
                                    out_features=out_channels, 
                                    trace_const1=synapse_fc_trace_const1, 
                                    trace_const2=synapse_fc_trace_const2,
                                    TIME=TIME)]
    else:
        layers += [SYNAPSE_FC_BPTT(in_features=in_channels,  # 마지막CONV의 OUT_CHANNEL * H * W
                                    out_features=out_channels, 
                                    trace_const1=synapse_fc_trace_const1, 
                                    trace_const2=synapse_fc_trace_const2,
                                    TIME=TIME)]
        
    return nn.Sequential(*layers)

class MY_SNN_FC(nn.Module):
    def __init__(self, cfg, in_c, IMAGE_SIZE, out_c,
                     synapse_fc_trace_const1, synapse_fc_trace_const2, 
                     lif_layer_v_init, lif_layer_v_decay, 
                     lif_layer_v_threshold, lif_layer_v_reset,
                     lif_layer_sg_width,
                     tdBN_on,
                     BN_on, TIME,
                     surrogate,
                     BPTT_on):
        super(MY_SNN_FC, self).__init__()

        self.layers = make_layers_fc(cfg, in_c, IMAGE_SIZE, out_c,
                     synapse_fc_trace_const1, synapse_fc_trace_const2, 
                     lif_layer_v_init, lif_layer_v_decay, 
                     lif_layer_v_threshold, lif_layer_v_reset,
                     lif_layer_sg_width,
                     tdBN_on,
                     BN_on, TIME,
                     surrogate,
                     BPTT_on)

    def forward(self, spike_input):
        # inputs: [Batch, Time, Channel, Height, Width]   
        spike_input = spike_input.permute(1, 0, 2, 3, 4)
        # inputs: [Time, Batch, Channel, Height, Width]   
        spike_input = spike_input.view(spike_input.size(0), spike_input.size(1), -1)
        
        spike_input = self.layers(spike_input)

        spike_input = spike_input.sum(axis=0)

        return spike_input
    












def my_snn_system(devices = "0,1,2,3",
                    my_seed = 42,
                    TIME = 8,
                    BATCH = 256,
                    IMAGE_SIZE = 32,
                    which_data = 'CIFAR10',
                    rate_coding = True,
    
                    lif_layer_v_init = 0.0,
                    lif_layer_v_decay = 0.6,
                    lif_layer_v_threshold = 1.2,
                    lif_layer_v_reset = 0.0,
                    lif_layer_sg_width = 1,

                    # synapse_conv_in_channels = IMAGE_PIXEL_CHANNEL,
                    synapse_conv_kernel_size = 3,
                    synapse_conv_stride = 1,
                    synapse_conv_padding = 1,
                    synapse_conv_trace_const1 = 1,
                    synapse_conv_trace_const2 = 0.6,

                    # synapse_fc_out_features = CLASS_NUM,
                    synapse_fc_trace_const1 = 1,
                    synapse_fc_trace_const2 = 0.6,

                    pre_trained = False,
                    convTrue_fcFalse = True,
                    cfg = [64, 64],
                    pre_trained_path = "net_save/save_now_net.pth",
                    learning_rate = 0.0001,
                    epoch_num = 100,
                    verbose_interval = 100,
                    tdBN_on = False,
                    BN_on = False,

                    surrogate = 'sigmoid',

                    gradient_verbose = False,

                    BPTT_on = False
                  ):


    os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
    os.environ["CUDA_VISIBLE_DEVICES"]= devices

    
    torch.manual_seed(my_seed)


    






    if (which_data == 'MNIST'):
        data_path = '/data2'

        if rate_coding :
            transform = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0,), (1,))])
        else : 
            transform = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5,),(0.5))])

        trainset = torchvision.datasets.MNIST(root=data_path,
                                            train=True,
                                            download=True,
                                            transform=transform)


        testset = torchvision.datasets.MNIST(root=data_path,
                                            train=False,
                                            download=True,
                                            transform=transform)

        train_loader = DataLoader(trainset,
                                batch_size =BATCH,
                                shuffle = True,
                                num_workers =2)
        test_loader = DataLoader(testset,
                                batch_size =BATCH,
                                shuffle = False,
                                num_workers =2)


    if (which_data == 'CIFAR10'):
        data_path = '/data2/cifar10'

        if rate_coding :
            transform_train = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                                transforms.RandomHorizontalFlip(),
                                                transforms.ToTensor()])

            transform_test = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                                transforms.ToTensor()])
        
        else :
            transform_train = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                                transforms.RandomHorizontalFlip(),
                                                transforms.ToTensor(),
                                                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])
                                            # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

            transform_test = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                                transforms.ToTensor(),
                                                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),])
                                            # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))


        trainset = torchvision.datasets.CIFAR10(root=data_path,
                                            train=True,
                                            download=True,
                                            transform=transform_train)


        testset = torchvision.datasets.CIFAR10(root=data_path,
                                            train=False,
                                            download=True,
                                            transform=transform_test)

        train_loader = DataLoader(trainset,
                                batch_size =BATCH,
                                shuffle = True,
                                num_workers =2)
        test_loader = DataLoader(testset,
                                batch_size =BATCH,
                                shuffle = False,
                                num_workers =2)

        '''
        classes = ('plane', 'car', 'bird', 'cat', 'deer',
                'dog', 'frog', 'horse', 'ship', 'truck') 
        '''

    if (which_data == 'FASHION_MNIST'):
        data_path = '/data2'

        if rate_coding :
            transform = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                        transforms.ToTensor()])
        else : 
            transform = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5,),(0.5))])

        trainset = torchvision.datasets.FashionMNIST(root=data_path,
                                            train=True,
                                            download=True,
                                            transform=transform)


        testset = torchvision.datasets.FashionMNIST(root=data_path,
                                            train=False,
                                            download=True,
                                            transform=transform)

        train_loader = DataLoader(trainset,
                                batch_size =BATCH,
                                shuffle = True,
                                num_workers =2)
        test_loader = DataLoader(testset,
                                batch_size =BATCH,
                                shuffle = False,
                                num_workers =2)

        

    data_iter = IMAGE_PIXEL_CHANNEL = iter(train_loader)
    images, labels = data_iter.next()

    # 채널 수와 클래스 개수를 확인합니다.
    synapse_conv_in_channels = images.shape[1]
    synapse_fc_out_features = CLASS_NUM = len(torch.unique(labels))















    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    if pre_trained == False:
        if (convTrue_fcFalse == False):
            net = MY_SNN_FC(cfg, synapse_conv_in_channels, IMAGE_SIZE, synapse_fc_out_features,
                     synapse_fc_trace_const1, synapse_fc_trace_const2, 
                     lif_layer_v_init, lif_layer_v_decay, 
                     lif_layer_v_threshold, lif_layer_v_reset,
                     lif_layer_sg_width,
                     tdBN_on,
                     BN_on, TIME,
                     surrogate,
                     BPTT_on).to(device)
        else:
            net = MY_SNN_CONV(cfg, synapse_conv_in_channels, IMAGE_SIZE,
                     synapse_conv_kernel_size, synapse_conv_stride, 
                     synapse_conv_padding, synapse_conv_trace_const1, 
                     synapse_conv_trace_const2, 
                     lif_layer_v_init, lif_layer_v_decay, 
                     lif_layer_v_threshold, lif_layer_v_reset,
                     lif_layer_sg_width,
                     synapse_fc_out_features, synapse_fc_trace_const1, synapse_fc_trace_const2,
                     tdBN_on,
                     BN_on, TIME,
                     surrogate,
                     BPTT_on).to(device)
        net = torch.nn.DataParallel(net)
    else:
        net = torch.load(pre_trained_path)

    val_acc = 0

    net = net.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
    # optimizer = torch.optim.Adam(net.parameters(), lr=0.00001)


    for epoch in range(epoch_num):
        print('EPOCH', epoch)
        epoch_start_time = time.time()
        running_loss = 0.0
        # for i, data in enumerate(train_loader, 0):
        for i, data in tqdm(enumerate(train_loader, 0), total=len(train_loader), desc='train', dynamic_ncols=True, position=0, leave=True):
            # print('iter', i)
            net.train()

            # print('\niter', i)
            iter_one_train_time_start = time.time()

            inputs, labels = data

            if rate_coding == True :
                inputs = spikegen.rate(inputs, num_steps=TIME)
            else :
                inputs = inputs.repeat(TIME, 1, 1, 1, 1)
            
            # inputs: [Time, Batch, Channel, Height, Width]   

            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # inputs: [Time, Batch, Channel, Height, Width]   
            inputs = inputs.permute(1, 0, 2, 3, 4) # net에 넣어줄때는 batch가 젤 앞 차원으로 와야함. # dataparallel때매
            # inputs: [Batch, Time, Channel, Height, Width]   
        
            outputs = net(inputs)

            batch = BATCH 
            if labels.size(0) != BATCH: 
                batch = labels.size(0)


            ####### training accruacy ######
            correct = 0
            total = 0
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted[0:batch] == labels).sum().item()
            if i % verbose_interval == 9:
                print(f'iter: {i} / {len(train_loader)}')
                print(f'training acc: {100 * correct / total:.2f}%')
            ################################
            training_acc_string = f'training acc: {100 * correct / total:.2f}%'

            loss = criterion(outputs[0:batch,:], labels)
            loss.backward()

            # # optimizer.zero_grad()와 loss.backward() 호출 후에 실행해야 합니다.
            if (gradient_verbose == True):
                if (i % 100 == 9):
                    print('\n\nepoch', epoch, 'iter', i)
                    for name, param in net.named_parameters():
                        if param.requires_grad:
                            print('\n\n\n\n' , name, param.grad)
            
            optimizer.step()

            running_loss += loss.item()
            # print("Epoch: {}, Iter: {}, Loss: {}".format(epoch + 1, i + 1, running_loss / 100))

            iter_one_train_time_end = time.time()
            elapsed_time = iter_one_train_time_end - iter_one_train_time_start  # 실행 시간 계산
            # print(f"iter_one_train_time: {elapsed_time} seconds")


            if i % verbose_interval == 9:
                iter_one_val_time_start = time.time()
                
                correct = 0
                total = 0
                with torch.no_grad():
                    net.eval()
                    how_many_val_image=0
                    for data in test_loader:
                        how_many_val_image += 1
                        inputs, labels = data
            
                        if rate_coding == True :
                            inputs = spikegen.rate(inputs, num_steps=TIME)
                        else :
                            inputs = inputs.repeat(TIME, 1, 1, 1, 1)

                        
                        inputs = inputs.to(device)
                        labels = labels.to(device)
                        outputs = net(inputs.permute(1, 0, 2, 3, 4))
                        _, predicted = torch.max(outputs.data, 1)
                        total += labels.size(0)
                        batch = BATCH 
                        if labels.size(0) != BATCH: 
                            batch = labels.size(0)
                        correct += (predicted[0:batch] == labels).sum().item()
                    print(f'{epoch}-{i} validation acc: {100 * correct / total:.2f}%\n')


                iter_one_val_time_end = time.time()
                elapsed_time = iter_one_val_time_end - iter_one_val_time_start  # 실행 시간 계산
                print(f"iter_one_val_time: {elapsed_time} seconds")
                if val_acc < correct / total:
                    val_acc = correct / total
                    torch.save(net.state_dict(), "net_save/save_now_net_weights.pth")
                    torch.save(net, "net_save/save_now_net.pth")
                    torch.save(net.module.state_dict(), "net_save/save_now_net_weights2.pth")
                    torch.save(net.module, "net_save/save_now_net2.pth")
        epoch_time_end = time.time()
        epoch_time = epoch_time_end - epoch_start_time  # 실행 시간 계산
        
        print(f"epoch_time: {epoch_time} seconds")
        print('\n')



In [18]:
decay = 0.95

my_snn_system(  devices = "2,3,4,5",
                my_seed = 42,
                TIME = 6,
                BATCH = 128,
                IMAGE_SIZE = 32,
                which_data = 'CIFAR10',# 'CIFAR10' 'MNIST' 'FASHION_MNIST'
                rate_coding = False, # True # False

                lif_layer_v_init = 0.0,
                lif_layer_v_decay = decay,
                lif_layer_v_threshold = 1.2,
                lif_layer_v_reset = 0.0, #현재 안씀. 걍 빼기 해버림
                lif_layer_sg_width = 1.0, # surrogate sigmoid 쓸 때는 의미없음

                # synapse_conv_in_channels = IMAGE_PIXEL_CHANNEL,
                synapse_conv_kernel_size = 3,
                synapse_conv_stride = 1,
                synapse_conv_padding = 1,
                synapse_conv_trace_const1 = 1,
                synapse_conv_trace_const2 = decay, # lif_layer_v_decay

                # synapse_fc_out_features = CLASS_NUM,
                synapse_fc_trace_const1 = 1,
                synapse_fc_trace_const2 = decay, # lif_layer_v_decay

                pre_trained = True, # True # False
                convTrue_fcFalse = True, # True # False
                cfg = [64, 128, 'P', 256, 256, 'P', 512, 512, 'P', 512, 512], # 끝에 linear classifier 하나 자동으로 붙습니다
                pre_trained_path = "net_save/save_now_net.pth",
                learning_rate = 0.000001,
                epoch_num = 200,
                verbose_interval = 100, #숫자 크게 하면 꺼짐
                tdBN_on = False,  # True # False
                BN_on = True,  # True # False
                
                surrogate = 'sigmoid', # 'rectangle' 'sigmoid' 'rough_rectangle'
                
                gradient_verbose = False,  # True # False

                BPTT_on = False  # True # False
                )

'''
cfg 종류 = {
[64, 64]
[64, 64, 64, 64]
[64, 128, 256]
[64, 128, 'P', 256, 256, 'P', 512, 512, 'P', 512, 512]
[64, 64]
}
'''


Files already downloaded and verified
Files already downloaded and verified
EPOCH 0


train:   2%|▏         | 9/391 [00:07<05:01,  1.27it/s]

iter: 9 / 391
training acc: 78.91%
0-9 validation acc: 71.44%

iter_one_val_time: 17.939154863357544 seconds


train:  28%|██▊       | 109/391 [01:45<03:42,  1.27it/s]

iter: 109 / 391
training acc: 76.56%
0-109 validation acc: 71.67%

iter_one_val_time: 7.7305519580841064 seconds


train:  53%|█████▎    | 209/391 [03:14<02:25,  1.25it/s]

iter: 209 / 391
training acc: 71.09%
0-209 validation acc: 71.75%

iter_one_val_time: 7.886085748672485 seconds


train:  79%|███████▉  | 309/391 [04:42<01:05,  1.25it/s]

iter: 309 / 391
training acc: 78.12%


train:  79%|███████▉  | 310/391 [04:51<04:12,  3.12s/it]

0-309 validation acc: 71.75%

iter_one_val_time: 7.728886842727661 seconds


train: 100%|██████████| 391/391 [05:56<00:00,  1.10it/s]

epoch_time: 356.9304814338684 seconds


EPOCH 1



train:   2%|▏         | 9/391 [00:07<05:06,  1.25it/s]

iter: 9 / 391
training acc: 71.88%
1-9 validation acc: 72.04%

iter_one_val_time: 7.879246473312378 seconds


train:  28%|██▊       | 109/391 [01:35<03:48,  1.23it/s]

iter: 109 / 391
training acc: 78.12%
1-109 validation acc: 72.13%

iter_one_val_time: 7.725550174713135 seconds


train:  53%|█████▎    | 209/391 [03:04<02:32,  1.20it/s]

iter: 209 / 391
training acc: 75.78%
1-209 validation acc: 72.42%

iter_one_val_time: 7.773592710494995 seconds


train:  79%|███████▉  | 309/391 [04:33<01:06,  1.24it/s]

iter: 309 / 391
training acc: 76.56%


train:  79%|███████▉  | 310/391 [04:41<04:17,  3.18s/it]

1-309 validation acc: 71.55%

iter_one_val_time: 7.919850587844849 seconds


train: 100%|██████████| 391/391 [05:46<00:00,  1.13it/s]

epoch_time: 347.24430871009827 seconds


EPOCH 2



train:   2%|▏         | 9/391 [00:07<05:05,  1.25it/s]

iter: 9 / 391
training acc: 77.34%


train:   3%|▎         | 10/391 [00:15<20:18,  3.20s/it]

2-9 validation acc: 72.06%

iter_one_val_time: 7.76666522026062 seconds


train:  28%|██▊       | 109/391 [01:35<03:45,  1.25it/s]

iter: 109 / 391
training acc: 76.56%


train:  28%|██▊       | 110/391 [01:44<15:02,  3.21s/it]

2-109 validation acc: 71.89%

iter_one_val_time: 8.037434577941895 seconds


train:  53%|█████▎    | 209/391 [03:03<02:26,  1.24it/s]

iter: 209 / 391
training acc: 82.81%
2-209 validation acc: 72.44%

iter_one_val_time: 7.9448816776275635 seconds


train:  79%|███████▉  | 309/391 [04:32<01:07,  1.22it/s]

iter: 309 / 391
training acc: 71.09%


train:  79%|███████▉  | 310/391 [04:40<04:13,  3.13s/it]

2-309 validation acc: 72.40%

iter_one_val_time: 7.7203803062438965 seconds


train: 100%|██████████| 391/391 [05:45<00:00,  1.13it/s]

epoch_time: 346.17419838905334 seconds


EPOCH 3



train:   2%|▏         | 9/391 [00:07<05:07,  1.24it/s]

iter: 9 / 391
training acc: 80.47%


train:   3%|▎         | 10/391 [00:16<20:50,  3.28s/it]

3-9 validation acc: 72.24%

iter_one_val_time: 8.024026870727539 seconds


train:  28%|██▊       | 109/391 [01:36<03:46,  1.24it/s]

iter: 109 / 391
training acc: 72.66%


train:  28%|██▊       | 110/391 [01:45<14:54,  3.18s/it]

3-109 validation acc: 71.88%

iter_one_val_time: 7.933567762374878 seconds


train:  53%|█████▎    | 209/391 [03:05<02:25,  1.25it/s]

iter: 209 / 391
training acc: 73.44%


train:  54%|█████▎    | 210/391 [03:14<09:40,  3.20s/it]

3-209 validation acc: 71.88%

iter_one_val_time: 8.015734195709229 seconds


train:  79%|███████▉  | 309/391 [04:33<01:05,  1.25it/s]

iter: 309 / 391
training acc: 79.69%


train:  79%|███████▉  | 310/391 [04:42<04:16,  3.17s/it]

3-309 validation acc: 71.96%

iter_one_val_time: 7.910808801651001 seconds


train: 100%|██████████| 391/391 [05:47<00:00,  1.13it/s]


epoch_time: 347.39737272262573 seconds


EPOCH 4


train:   2%|▏         | 9/391 [00:07<05:11,  1.23it/s]

iter: 9 / 391
training acc: 72.66%


train:   3%|▎         | 10/391 [00:16<20:51,  3.29s/it]

4-9 validation acc: 72.35%

iter_one_val_time: 8.010568141937256 seconds


train:  28%|██▊       | 109/391 [01:35<03:47,  1.24it/s]

iter: 109 / 391
training acc: 77.34%


train:  28%|██▊       | 110/391 [01:44<14:36,  3.12s/it]

4-109 validation acc: 72.05%

iter_one_val_time: 7.712385177612305 seconds


train:  53%|█████▎    | 209/391 [03:03<02:25,  1.25it/s]

iter: 209 / 391
training acc: 72.66%
4-209 validation acc: 72.69%

iter_one_val_time: 7.943535804748535 seconds


train:  79%|███████▉  | 309/391 [04:32<01:05,  1.24it/s]

iter: 309 / 391
training acc: 75.78%


train:  79%|███████▉  | 310/391 [04:40<04:12,  3.12s/it]

4-309 validation acc: 71.91%

iter_one_val_time: 7.737257719039917 seconds


train: 100%|██████████| 391/391 [05:46<00:00,  1.13it/s]

epoch_time: 346.4662718772888 seconds


EPOCH 5



train:   2%|▏         | 9/391 [00:07<05:11,  1.23it/s]

iter: 9 / 391
training acc: 77.34%


train:   3%|▎         | 10/391 [00:16<20:50,  3.28s/it]

5-9 validation acc: 72.59%

iter_one_val_time: 8.00572156906128 seconds


train:  28%|██▊       | 109/391 [01:35<03:45,  1.25it/s]

iter: 109 / 391
training acc: 70.31%
5-109 validation acc: 72.80%

iter_one_val_time: 7.731691360473633 seconds


train:  53%|█████▎    | 209/391 [03:04<02:27,  1.23it/s]

iter: 209 / 391
training acc: 77.34%


train:  54%|█████▎    | 210/391 [03:13<09:35,  3.18s/it]

5-209 validation acc: 72.29%

iter_one_val_time: 7.907668828964233 seconds


train:  79%|███████▉  | 309/391 [04:33<01:06,  1.23it/s]

iter: 309 / 391
training acc: 75.00%


train:  79%|███████▉  | 310/391 [04:42<04:24,  3.27s/it]

5-309 validation acc: 72.09%

iter_one_val_time: 8.173746824264526 seconds


train: 100%|██████████| 391/391 [05:48<00:00,  1.12it/s]

epoch_time: 348.5004925727844 seconds


EPOCH 6



train:   2%|▏         | 9/391 [00:07<05:08,  1.24it/s]

iter: 9 / 391
training acc: 71.88%


train:   3%|▎         | 10/391 [00:16<20:53,  3.29s/it]

6-9 validation acc: 72.51%

iter_one_val_time: 8.047602653503418 seconds


train:  28%|██▊       | 109/391 [01:36<03:49,  1.23it/s]

iter: 109 / 391
training acc: 76.56%


train:  28%|██▊       | 110/391 [01:45<14:54,  3.18s/it]

6-109 validation acc: 72.20%

iter_one_val_time: 7.9039387702941895 seconds


train:  53%|█████▎    | 209/391 [03:05<02:28,  1.23it/s]

iter: 209 / 391
training acc: 78.91%


train:  54%|█████▎    | 210/391 [03:13<09:43,  3.22s/it]

6-209 validation acc: 72.11%

iter_one_val_time: 8.029848337173462 seconds


train:  79%|███████▉  | 309/391 [04:34<01:06,  1.24it/s]

iter: 309 / 391
training acc: 73.44%


train:  79%|███████▉  | 310/391 [04:43<04:25,  3.28s/it]

6-309 validation acc: 72.52%

iter_one_val_time: 8.246172904968262 seconds


train: 100%|██████████| 391/391 [05:48<00:00,  1.12it/s]

epoch_time: 349.0660755634308 seconds


EPOCH 7



train:   2%|▏         | 9/391 [00:07<05:08,  1.24it/s]

iter: 9 / 391
training acc: 80.47%
7-9 validation acc: 73.06%

iter_one_val_time: 7.865886688232422 seconds


train:  28%|██▊       | 109/391 [01:36<03:48,  1.23it/s]

iter: 109 / 391
training acc: 64.84%


train:  28%|██▊       | 110/391 [01:45<14:56,  3.19s/it]

7-109 validation acc: 72.43%

iter_one_val_time: 7.9383323192596436 seconds


train:  53%|█████▎    | 209/391 [03:05<02:29,  1.22it/s]

iter: 209 / 391
training acc: 75.00%


train:  54%|█████▎    | 210/391 [03:14<09:26,  3.13s/it]

7-209 validation acc: 72.07%

iter_one_val_time: 7.7211997509002686 seconds


train:  79%|███████▉  | 309/391 [04:35<01:06,  1.24it/s]

iter: 309 / 391
training acc: 75.78%


train:  79%|███████▉  | 310/391 [04:52<07:53,  5.85s/it]

7-309 validation acc: 72.02%

iter_one_val_time: 16.793192148208618 seconds


train: 100%|██████████| 391/391 [05:58<00:00,  1.09it/s]

epoch_time: 358.46886587142944 seconds


EPOCH 8



train:   2%|▏         | 9/391 [00:07<05:12,  1.22it/s]

iter: 9 / 391
training acc: 74.22%


train:   3%|▎         | 10/391 [00:16<21:48,  3.43s/it]

8-9 validation acc: 72.63%

iter_one_val_time: 8.477874517440796 seconds


train:  28%|██▊       | 109/391 [01:37<03:47,  1.24it/s]

iter: 109 / 391
training acc: 71.88%


train:  28%|██▊       | 110/391 [01:46<14:55,  3.19s/it]

8-109 validation acc: 72.69%

iter_one_val_time: 7.9320714473724365 seconds


train:  53%|█████▎    | 209/391 [03:06<02:27,  1.23it/s]

iter: 209 / 391
training acc: 73.44%


train:  54%|█████▎    | 210/391 [03:15<09:42,  3.22s/it]

8-209 validation acc: 72.58%

iter_one_val_time: 8.032639980316162 seconds


train:  79%|███████▉  | 309/391 [04:36<01:05,  1.25it/s]

iter: 309 / 391
training acc: 81.25%


train:  79%|███████▉  | 310/391 [04:45<04:18,  3.19s/it]

8-309 validation acc: 72.59%

iter_one_val_time: 7.948352813720703 seconds


train: 100%|██████████| 391/391 [05:50<00:00,  1.11it/s]

epoch_time: 351.14962577819824 seconds


EPOCH 9



train:   2%|▏         | 9/391 [00:07<05:09,  1.23it/s]

iter: 9 / 391
training acc: 75.78%


train:   3%|▎         | 10/391 [00:16<20:36,  3.25s/it]

9-9 validation acc: 71.79%

iter_one_val_time: 7.8998377323150635 seconds


train:  28%|██▊       | 109/391 [01:36<03:49,  1.23it/s]

iter: 109 / 391
training acc: 73.44%
9-109 validation acc: 73.10%

iter_one_val_time: 7.697368621826172 seconds


train:  53%|█████▎    | 209/391 [03:05<02:26,  1.24it/s]

iter: 209 / 391
training acc: 70.31%


train:  54%|█████▎    | 210/391 [03:14<09:34,  3.17s/it]

9-209 validation acc: 72.34%

iter_one_val_time: 7.902353525161743 seconds


train:  79%|███████▉  | 309/391 [04:35<01:06,  1.24it/s]

iter: 309 / 391
training acc: 67.97%


train:  79%|███████▉  | 310/391 [04:44<04:18,  3.19s/it]

9-309 validation acc: 73.10%

iter_one_val_time: 7.934512615203857 seconds


train: 100%|██████████| 391/391 [05:49<00:00,  1.12it/s]

epoch_time: 349.6759581565857 seconds


EPOCH 10



train:   2%|▏         | 9/391 [00:07<05:09,  1.23it/s]

iter: 9 / 391
training acc: 72.66%


train:   3%|▎         | 10/391 [00:15<20:25,  3.22s/it]

10-9 validation acc: 72.64%

iter_one_val_time: 7.794861316680908 seconds


train:  28%|██▊       | 109/391 [01:35<03:45,  1.25it/s]

iter: 109 / 391
training acc: 75.00%


train:  28%|██▊       | 110/391 [01:44<14:38,  3.13s/it]

10-109 validation acc: 72.89%

iter_one_val_time: 7.758264541625977 seconds


train:  53%|█████▎    | 209/391 [03:03<02:26,  1.25it/s]

iter: 209 / 391
training acc: 72.66%


train:  54%|█████▎    | 210/391 [03:12<09:29,  3.15s/it]

10-209 validation acc: 72.96%

iter_one_val_time: 7.82605767250061 seconds


train:  79%|███████▉  | 309/391 [04:31<01:05,  1.25it/s]

iter: 309 / 391
training acc: 72.66%


train:  79%|███████▉  | 310/391 [04:40<04:17,  3.18s/it]

10-309 validation acc: 72.68%

iter_one_val_time: 7.913005352020264 seconds


train: 100%|██████████| 391/391 [05:45<00:00,  1.13it/s]

epoch_time: 345.94245743751526 seconds


EPOCH 11



train:   2%|▏         | 9/391 [00:07<05:09,  1.24it/s]

iter: 9 / 391
training acc: 75.78%


train:   3%|▎         | 10/391 [00:16<20:31,  3.23s/it]

11-9 validation acc: 72.52%

iter_one_val_time: 7.84930157661438 seconds


train:  28%|██▊       | 109/391 [01:35<03:46,  1.25it/s]

iter: 109 / 391
training acc: 74.22%


train:  28%|██▊       | 110/391 [01:44<14:47,  3.16s/it]

11-109 validation acc: 73.01%

iter_one_val_time: 7.8651933670043945 seconds


train:  53%|█████▎    | 209/391 [03:04<02:26,  1.24it/s]

iter: 209 / 391
training acc: 81.25%


train:  54%|█████▎    | 210/391 [03:13<09:34,  3.17s/it]

11-209 validation acc: 72.29%

iter_one_val_time: 7.8989503383636475 seconds


train:  79%|███████▉  | 309/391 [04:32<01:05,  1.24it/s]

iter: 309 / 391
training acc: 79.69%


train:  79%|███████▉  | 310/391 [04:41<04:18,  3.20s/it]

11-309 validation acc: 73.01%

iter_one_val_time: 7.970198631286621 seconds


train: 100%|██████████| 391/391 [05:46<00:00,  1.13it/s]

epoch_time: 347.01211404800415 seconds


EPOCH 12



train:   2%|▏         | 9/391 [00:07<05:07,  1.24it/s]

iter: 9 / 391
training acc: 78.12%


train:   3%|▎         | 10/391 [00:15<20:24,  3.21s/it]

12-9 validation acc: 72.92%

iter_one_val_time: 7.812759160995483 seconds


train:  28%|██▊       | 109/391 [01:35<03:48,  1.24it/s]

iter: 109 / 391
training acc: 77.34%
12-109 validation acc: 73.12%

iter_one_val_time: 7.7220423221588135 seconds


train:  53%|█████▎    | 209/391 [03:05<02:26,  1.25it/s]

iter: 209 / 391
training acc: 75.00%


train:  54%|█████▎    | 210/391 [03:14<09:35,  3.18s/it]

12-209 validation acc: 72.98%

iter_one_val_time: 7.927709102630615 seconds


train:  79%|███████▉  | 309/391 [04:34<01:06,  1.24it/s]

iter: 309 / 391
training acc: 65.62%
12-309 validation acc: 73.34%

iter_one_val_time: 7.781642436981201 seconds


train: 100%|██████████| 391/391 [05:48<00:00,  1.12it/s]

epoch_time: 348.73473834991455 seconds


EPOCH 13



train:   2%|▏         | 9/391 [00:07<05:06,  1.25it/s]

iter: 9 / 391
training acc: 79.69%


train:   3%|▎         | 10/391 [00:16<20:54,  3.29s/it]

13-9 validation acc: 72.86%

iter_one_val_time: 8.064665079116821 seconds


train:  28%|██▊       | 109/391 [01:35<03:46,  1.25it/s]

iter: 109 / 391
training acc: 75.78%


train:  28%|██▊       | 110/391 [01:44<14:51,  3.17s/it]

13-109 validation acc: 72.89%

iter_one_val_time: 7.914062023162842 seconds


train:  53%|█████▎    | 209/391 [03:04<02:26,  1.24it/s]

iter: 209 / 391
training acc: 74.22%


train:  54%|█████▎    | 210/391 [03:12<09:40,  3.21s/it]

13-209 validation acc: 72.80%

iter_one_val_time: 8.024676322937012 seconds


train:  79%|███████▉  | 309/391 [04:32<01:05,  1.25it/s]

iter: 309 / 391
training acc: 79.69%


train:  79%|███████▉  | 310/391 [04:41<04:17,  3.18s/it]

13-309 validation acc: 73.27%

iter_one_val_time: 7.936282634735107 seconds


train: 100%|██████████| 391/391 [05:46<00:00,  1.13it/s]

epoch_time: 346.9191937446594 seconds


EPOCH 14



train:   2%|▏         | 9/391 [00:07<05:07,  1.24it/s]

iter: 9 / 391
training acc: 71.09%


train:   3%|▎         | 10/391 [00:16<20:57,  3.30s/it]

14-9 validation acc: 72.33%

iter_one_val_time: 8.08316421508789 seconds


train:  28%|██▊       | 109/391 [01:36<03:48,  1.24it/s]

iter: 109 / 391
training acc: 75.00%


train:  28%|██▊       | 110/391 [01:44<14:43,  3.14s/it]

14-109 validation acc: 72.76%

iter_one_val_time: 7.781051397323608 seconds


train:  53%|█████▎    | 209/391 [03:04<02:26,  1.24it/s]

iter: 209 / 391
training acc: 74.22%


train:  54%|█████▎    | 210/391 [03:13<09:41,  3.22s/it]

14-209 validation acc: 72.87%

iter_one_val_time: 8.036190509796143 seconds


train:  79%|███████▉  | 309/391 [04:33<01:05,  1.25it/s]

iter: 309 / 391
training acc: 76.56%


train:  79%|███████▉  | 310/391 [04:42<04:18,  3.19s/it]

14-309 validation acc: 72.67%

iter_one_val_time: 7.969099283218384 seconds


train: 100%|██████████| 391/391 [05:47<00:00,  1.12it/s]

epoch_time: 347.9593415260315 seconds


EPOCH 15



train:   2%|▏         | 9/391 [00:07<05:06,  1.25it/s]

iter: 9 / 391
training acc: 76.56%


train:   3%|▎         | 10/391 [00:15<20:20,  3.20s/it]

15-9 validation acc: 72.41%

iter_one_val_time: 7.776015281677246 seconds


train:  28%|██▊       | 109/391 [01:35<03:46,  1.25it/s]

iter: 109 / 391
training acc: 77.34%


train:  28%|██▊       | 110/391 [01:44<14:41,  3.14s/it]

15-109 validation acc: 72.82%

iter_one_val_time: 7.780127048492432 seconds


train:  53%|█████▎    | 209/391 [03:03<02:25,  1.25it/s]

iter: 209 / 391
training acc: 65.62%


train:  54%|█████▎    | 210/391 [03:12<09:35,  3.18s/it]

15-209 validation acc: 73.09%

iter_one_val_time: 7.931337833404541 seconds


train:  79%|███████▉  | 309/391 [04:32<01:05,  1.25it/s]

iter: 309 / 391
training acc: 75.00%


train:  79%|███████▉  | 310/391 [04:41<04:16,  3.17s/it]

15-309 validation acc: 72.94%

iter_one_val_time: 7.883540868759155 seconds


train: 100%|██████████| 391/391 [05:46<00:00,  1.13it/s]

epoch_time: 346.9578363895416 seconds


EPOCH 16



train:   2%|▏         | 9/391 [00:07<05:05,  1.25it/s]

iter: 9 / 391
training acc: 72.66%


train:   3%|▎         | 10/391 [00:15<20:14,  3.19s/it]

16-9 validation acc: 73.05%

iter_one_val_time: 7.734458923339844 seconds


train:  28%|██▊       | 109/391 [01:35<03:47,  1.24it/s]

iter: 109 / 391
training acc: 75.78%


train:  28%|██▊       | 110/391 [01:44<14:44,  3.15s/it]

16-109 validation acc: 72.76%

iter_one_val_time: 7.797995328903198 seconds


train:  53%|█████▎    | 209/391 [03:04<02:26,  1.25it/s]

iter: 209 / 391
training acc: 75.78%


train:  54%|█████▎    | 210/391 [03:13<09:30,  3.15s/it]

16-209 validation acc: 72.82%

iter_one_val_time: 7.823532342910767 seconds


train:  79%|███████▉  | 309/391 [04:33<01:05,  1.24it/s]

iter: 309 / 391
training acc: 78.91%


train:  79%|███████▉  | 310/391 [04:41<04:13,  3.13s/it]

16-309 validation acc: 72.90%

iter_one_val_time: 7.742618083953857 seconds


train: 100%|██████████| 391/391 [05:46<00:00,  1.13it/s]

epoch_time: 347.07173800468445 seconds


EPOCH 17



train:   2%|▏         | 9/391 [00:07<05:07,  1.24it/s]

iter: 9 / 391
training acc: 76.56%


train:   3%|▎         | 10/391 [00:16<20:49,  3.28s/it]

17-9 validation acc: 72.83%

iter_one_val_time: 8.010719299316406 seconds


train:  28%|██▊       | 109/391 [01:36<03:46,  1.24it/s]

iter: 109 / 391
training acc: 75.78%


train:  28%|██▊       | 110/391 [01:45<14:58,  3.20s/it]

17-109 validation acc: 72.77%

iter_one_val_time: 7.977971315383911 seconds


train:  53%|█████▎    | 209/391 [03:04<02:26,  1.24it/s]

iter: 209 / 391
training acc: 78.12%


train:  54%|█████▎    | 210/391 [03:13<09:34,  3.18s/it]

17-209 validation acc: 72.94%

iter_one_val_time: 7.90313196182251 seconds


train:  79%|███████▉  | 309/391 [04:33<01:05,  1.24it/s]

iter: 309 / 391
training acc: 71.88%


train:  79%|███████▉  | 310/391 [04:42<04:17,  3.17s/it]

17-309 validation acc: 72.51%

iter_one_val_time: 7.91007661819458 seconds


train: 100%|██████████| 391/391 [05:47<00:00,  1.12it/s]

epoch_time: 348.1126263141632 seconds


EPOCH 18



train:   2%|▏         | 9/391 [00:07<05:08,  1.24it/s]

iter: 9 / 391
training acc: 71.88%


train:   3%|▎         | 10/391 [00:16<20:56,  3.30s/it]

18-9 validation acc: 72.73%

iter_one_val_time: 8.075047969818115 seconds


train:  28%|██▊       | 109/391 [01:36<03:54,  1.20it/s]

iter: 109 / 391
training acc: 75.00%


train:  28%|██▊       | 110/391 [01:45<15:19,  3.27s/it]

18-109 validation acc: 72.58%

iter_one_val_time: 8.133409261703491 seconds


train:  53%|█████▎    | 209/391 [03:05<02:26,  1.24it/s]

iter: 209 / 391
training acc: 77.34%


train:  54%|█████▎    | 210/391 [03:14<09:36,  3.19s/it]

18-209 validation acc: 73.05%

iter_one_val_time: 7.795913457870483 seconds


train:  79%|███████▉  | 309/391 [04:34<01:05,  1.24it/s]

iter: 309 / 391
training acc: 75.00%


train:  79%|███████▉  | 310/391 [04:43<04:18,  3.19s/it]

18-309 validation acc: 72.55%

iter_one_val_time: 7.957659006118774 seconds


train: 100%|██████████| 391/391 [05:48<00:00,  1.12it/s]

epoch_time: 348.8010678291321 seconds


EPOCH 19



train:   2%|▏         | 9/391 [00:07<05:08,  1.24it/s]

iter: 9 / 391
training acc: 75.00%


train:   3%|▎         | 10/391 [00:16<20:37,  3.25s/it]

19-9 validation acc: 72.90%

iter_one_val_time: 7.898122549057007 seconds


train:  28%|██▊       | 109/391 [01:35<03:48,  1.23it/s]

iter: 109 / 391
training acc: 75.00%


train:  28%|██▊       | 110/391 [01:44<14:53,  3.18s/it]

19-109 validation acc: 73.23%

iter_one_val_time: 7.893476963043213 seconds


train:  53%|█████▎    | 209/391 [03:04<02:27,  1.23it/s]

iter: 209 / 391
training acc: 78.12%


train:  54%|█████▎    | 210/391 [03:13<09:27,  3.13s/it]

19-209 validation acc: 72.84%

iter_one_val_time: 7.747832775115967 seconds


train:  79%|███████▉  | 309/391 [04:33<01:06,  1.24it/s]

iter: 309 / 391
training acc: 77.34%


train:  79%|███████▉  | 310/391 [04:41<04:13,  3.13s/it]

19-309 validation acc: 72.69%

iter_one_val_time: 7.758384943008423 seconds


train: 100%|██████████| 391/391 [05:47<00:00,  1.13it/s]


epoch_time: 347.5045096874237 seconds


EPOCH 20


train:   2%|▏         | 9/391 [00:07<05:08,  1.24it/s]

iter: 9 / 391
training acc: 72.66%


train:   3%|▎         | 10/391 [00:16<20:15,  3.19s/it]

20-9 validation acc: 73.15%

iter_one_val_time: 7.720524072647095 seconds


train:  28%|██▊       | 109/391 [01:36<03:48,  1.23it/s]

iter: 109 / 391
training acc: 75.78%


train:  28%|██▊       | 110/391 [01:44<14:55,  3.19s/it]

20-109 validation acc: 72.86%

iter_one_val_time: 7.93156361579895 seconds


train:  53%|█████▎    | 209/391 [03:04<02:27,  1.24it/s]

iter: 209 / 391
training acc: 71.09%


train:  54%|█████▎    | 210/391 [03:13<09:31,  3.16s/it]

20-209 validation acc: 73.32%

iter_one_val_time: 7.821039199829102 seconds


train:  79%|███████▉  | 309/391 [04:33<01:05,  1.25it/s]

iter: 309 / 391
training acc: 75.00%


train:  79%|███████▉  | 310/391 [04:42<04:14,  3.14s/it]

20-309 validation acc: 72.99%

iter_one_val_time: 7.7756569385528564 seconds


train: 100%|██████████| 391/391 [05:47<00:00,  1.13it/s]

epoch_time: 347.61015152931213 seconds


EPOCH 21



train:   2%|▏         | 9/391 [00:07<05:14,  1.22it/s]

iter: 9 / 391
training acc: 82.03%


train:   3%|▎         | 10/391 [00:15<20:19,  3.20s/it]

21-9 validation acc: 72.34%

iter_one_val_time: 7.728041172027588 seconds


train:  28%|██▊       | 109/391 [01:35<03:47,  1.24it/s]

iter: 109 / 391
training acc: 78.12%


train:  28%|██▊       | 110/391 [01:44<14:38,  3.13s/it]

21-109 validation acc: 72.80%

iter_one_val_time: 7.743985891342163 seconds


train:  53%|█████▎    | 209/391 [03:04<02:27,  1.24it/s]

iter: 209 / 391
training acc: 79.69%


train:  54%|█████▎    | 210/391 [03:12<09:31,  3.16s/it]

21-209 validation acc: 73.14%

iter_one_val_time: 7.8283302783966064 seconds


train:  79%|███████▉  | 309/391 [04:33<01:05,  1.24it/s]

iter: 309 / 391
training acc: 80.47%


train:  79%|███████▉  | 310/391 [04:41<04:17,  3.18s/it]

21-309 validation acc: 72.86%

iter_one_val_time: 7.9239583015441895 seconds


train: 100%|██████████| 391/391 [05:47<00:00,  1.13it/s]

epoch_time: 347.3126301765442 seconds


EPOCH 22



train:   2%|▏         | 9/391 [00:07<05:21,  1.19it/s]

iter: 9 / 391
training acc: 78.91%


train:   3%|▎         | 10/391 [00:16<22:01,  3.47s/it]

22-9 validation acc: 72.65%

iter_one_val_time: 8.546625137329102 seconds


train:  28%|██▊       | 109/391 [01:37<03:53,  1.21it/s]

iter: 109 / 391
training acc: 71.09%


train:  28%|██▊       | 110/391 [01:46<15:09,  3.24s/it]

22-109 validation acc: 72.93%

iter_one_val_time: 8.046215057373047 seconds


train:  53%|█████▎    | 209/391 [03:06<02:25,  1.25it/s]

iter: 209 / 391
training acc: 79.69%


train:  54%|█████▎    | 210/391 [03:14<09:38,  3.19s/it]

22-209 validation acc: 72.76%

iter_one_val_time: 7.977060794830322 seconds


train:  79%|███████▉  | 309/391 [04:35<01:07,  1.21it/s]

iter: 309 / 391
training acc: 75.00%
22-309 validation acc: 73.52%

iter_one_val_time: 7.985539436340332 seconds


train: 100%|██████████| 391/391 [05:51<00:00,  1.11it/s]

epoch_time: 351.62087416648865 seconds


EPOCH 23



train:   2%|▏         | 9/391 [00:07<05:11,  1.23it/s]

iter: 9 / 391
training acc: 72.66%


train:   3%|▎         | 10/391 [00:16<20:44,  3.27s/it]

23-9 validation acc: 72.59%

iter_one_val_time: 7.949851989746094 seconds


train:  28%|██▊       | 109/391 [01:36<03:52,  1.21it/s]

iter: 109 / 391
training acc: 76.56%


train:  28%|██▊       | 110/391 [01:44<14:45,  3.15s/it]

23-109 validation acc: 72.77%

iter_one_val_time: 7.774939298629761 seconds


train:  53%|█████▎    | 209/391 [03:05<02:29,  1.21it/s]

iter: 209 / 391
training acc: 80.47%


train:  54%|█████▎    | 210/391 [03:13<09:27,  3.14s/it]

23-209 validation acc: 73.14%

iter_one_val_time: 7.725614547729492 seconds


train:  79%|███████▉  | 309/391 [04:33<01:07,  1.22it/s]

iter: 309 / 391
training acc: 77.34%


train:  79%|███████▉  | 310/391 [04:42<04:14,  3.14s/it]

23-309 validation acc: 72.83%

iter_one_val_time: 7.7615931034088135 seconds


train: 100%|██████████| 391/391 [05:47<00:00,  1.12it/s]

epoch_time: 348.07648730278015 seconds


EPOCH 24



train:   2%|▏         | 9/391 [00:07<05:10,  1.23it/s]

iter: 9 / 391
training acc: 78.91%


train:   3%|▎         | 10/391 [00:16<20:44,  3.27s/it]

24-9 validation acc: 72.88%

iter_one_val_time: 7.945330858230591 seconds


train:  28%|██▊       | 109/391 [01:36<03:50,  1.22it/s]

iter: 109 / 391
training acc: 72.66%


train:  28%|██▊       | 110/391 [01:44<14:48,  3.16s/it]

24-109 validation acc: 72.64%

iter_one_val_time: 7.828792572021484 seconds


train:  53%|█████▎    | 209/391 [03:04<02:25,  1.25it/s]

iter: 209 / 391
training acc: 76.56%


train:  54%|█████▎    | 210/391 [03:13<09:30,  3.15s/it]

24-209 validation acc: 73.22%

iter_one_val_time: 7.812132120132446 seconds


train:  79%|███████▉  | 309/391 [04:33<01:06,  1.24it/s]

iter: 309 / 391
training acc: 76.56%


train:  79%|███████▉  | 310/391 [04:42<04:18,  3.19s/it]

24-309 validation acc: 72.68%

iter_one_val_time: 7.925442934036255 seconds


train: 100%|██████████| 391/391 [05:47<00:00,  1.12it/s]

epoch_time: 347.8764147758484 seconds


EPOCH 25



train:   2%|▏         | 9/391 [00:07<05:08,  1.24it/s]

iter: 9 / 391
training acc: 74.22%


train:   3%|▎         | 10/391 [00:16<20:41,  3.26s/it]

25-9 validation acc: 72.84%

iter_one_val_time: 7.936924457550049 seconds


train:  28%|██▊       | 109/391 [01:36<03:48,  1.23it/s]

iter: 109 / 391
training acc: 75.00%


train:  28%|██▊       | 110/391 [01:44<14:58,  3.20s/it]

25-109 validation acc: 73.14%

iter_one_val_time: 7.957545042037964 seconds


train:  53%|█████▎    | 209/391 [03:04<02:26,  1.24it/s]

iter: 209 / 391
training acc: 78.91%


train:  54%|█████▎    | 210/391 [03:13<09:30,  3.15s/it]

25-209 validation acc: 72.72%

iter_one_val_time: 7.836515188217163 seconds


train:  68%|██████▊   | 264/391 [03:57<01:54,  1.11it/s]


KeyboardInterrupt: 