In [1]:
# Copyright (c) 2024 Byeonghyeon Kim 
# github site: https://github.com/bhkim003/ByeonghyeonKim
# email: bhkim003@snu.ac.kr
 
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
 
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


In [2]:
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.datasets
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

import time

from snntorch import spikegen
import matplotlib.pyplot as plt
import snntorch.spikeplot as splt
from IPython.display import HTML

from tqdm import tqdm

In [3]:
class SYNAPSE_FC_BPTT(nn.Module):
    def __init__(self, in_features, out_features, trace_const1=1, trace_const2=0.7, TIME=8):
        super(SYNAPSE_FC_BPTT, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.trace_const1 = trace_const1
        self.trace_const2 = trace_const2

        # self.weight = torch.randn(self.out_features, self.in_features, requires_grad=True)
        # self.bias = torch.randn(self.out_features, requires_grad=True)
        self.weight = nn.Parameter(torch.randn(self.out_features, self.in_features))
        self.bias = nn.Parameter(torch.randn(self.out_features))

        self.TIME = TIME

    def forward(self, spike):
        # spike: [Time, Batch, Features]   
        Time = spike.shape[0]
        assert Time == self.TIME, 'Time dimension should be same as TIME'
        Batch = spike.shape[1] 

        output_current = []

        for t in range(Time):
            output_current.append(F.linear(spike[t], weight = self.weight, bias= self.bias))

        output_current = torch.stack(output_current, dim=0)
        return output_current 


class SYNAPSE_CONV_BPTT(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, trace_const1=1, trace_const2=0.7, TIME=8):
        super(SYNAPSE_CONV_BPTT, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.trace_const1 = trace_const1
        self.trace_const2 = trace_const2
        # self.weight = torch.randn(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size, requires_grad=True)
        # self.bias = torch.randn(self.out_channels, requires_grad=True)
        self.weight = nn.Parameter(torch.randn(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size))
        self.bias = nn.Parameter(torch.randn(self.out_channels))

        self.TIME = TIME

    def forward(self, spike):
        # spike: [Time, Batch, Channel, Height, Width]   
        # print('spike.shape', spike.shape)
        Time = spike.shape[0]
        assert Time == self.TIME, 'Time dimension should be same as TIME'
        Batch = spike.shape[1] 
        Channel = self.out_channels
        Height = (spike.shape[3] + self.padding*2 - self.kernel_size) // self.stride + 1
        Width = (spike.shape[4] + self.padding*2 - self.kernel_size) // self.stride + 1

        # output_current = torch.zeros(Time, Batch, Channel, Height, Width, device=spike.device)
        output_current = []
        
        for t in range(Time):
            # print(f'time:{t}', torch.sum(spike_detach[t]/ torch.numel(spike_detach[t])))
            output_current.append(F.conv2d(spike[t], self.weight, bias=self.bias, stride=self.stride, padding=self.padding))
            # print(f'time:{t}', torch.sum(output_current[t]/ torch.numel(output_current[t])))

        output_current = torch.stack(output_current, dim=0)
        return output_current


class SYNAPSE_FC_METHOD(torch.autograd.Function):
    @staticmethod
    def forward(ctx, spike_one_time, spike_now, weight, bias):
        ctx.save_for_backward(spike_one_time, spike_now, weight, bias)
        return F.linear(spike_one_time, weight, bias=bias)

    @staticmethod
    def backward(ctx, grad_output_current):
        #############밑에부터 수정해라#######
        spike_one_time, spike_now, weight, bias = ctx.saved_tensors
        
        ## 이거 클론해야되는지 모르겠음!!!!
        grad_output_current_clone = grad_output_current.clone()

        grad_input_spike = grad_weight = grad_bias = None


        if ctx.needs_input_grad[0]:
            grad_input_spike = grad_output_current_clone @ weight
        if ctx.needs_input_grad[2]:
            grad_weight = grad_output_current_clone.t() @ spike_now
        if bias is not None and ctx.needs_input_grad[3]:
            grad_bias = grad_output_current_clone.sum(0)

        return grad_input_spike, None, grad_weight, grad_bias

    
class SYNAPSE_FC(nn.Module):
    def __init__(self, in_features, out_features, trace_const1=1, trace_const2=0.7, TIME=8):
        super(SYNAPSE_FC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.trace_const1 = trace_const1
        self.trace_const2 = trace_const2

        # self.weight = torch.randn(self.out_features, self.in_features, requires_grad=True)
        # self.bias = torch.randn(self.out_features, requires_grad=True)
        self.weight = nn.Parameter(torch.randn(self.out_features, self.in_features))
        self.bias = nn.Parameter(torch.randn(self.out_features))

        self.TIME = TIME

    def forward(self, spike):
        # spike: [Time, Batch, Features]   
        Time = spike.shape[0]
        assert Time == self.TIME, 'Time dimension should be same as TIME'
        Batch = spike.shape[1] 

        # output_current = torch.zeros(Time, Batch, self.out_features, device=spike.device)
        output_current = []

        # spike_detach = spike.detach().clone()
        spike_detach = spike.detach()
        spike_past = torch.zeros_like(spike_detach[0], device=spike.device)
        spike_now = torch.zeros_like(spike_detach[0], device=spike.device)

        for t in range(Time):
            spike_now = self.trace_const1*spike_detach[t] + self.trace_const2*spike_past
            # output_current[t]= SYNAPSE_FC_METHOD.apply(spike[t], spike_now, self.weight, self.bias) 
            output_current.append( SYNAPSE_FC_METHOD.apply(spike[t], spike_now, self.weight, self.bias) )
            
            spike_past = spike_now

        output_current = torch.stack(output_current, dim=0)
        return output_current 



class SYNAPSE_CONV_METHOD(torch.autograd.Function):
    @staticmethod
    def forward(ctx, spike_one_time, spike_now, weight, bias, stride=1, padding=1):
        ctx.save_for_backward(spike_one_time, spike_now, weight, bias, torch.tensor([stride], requires_grad=False), torch.tensor([padding], requires_grad=False))
        return F.conv2d(spike_one_time, weight, bias=bias, stride=stride, padding=padding)

    @staticmethod
    def backward(ctx, grad_output_current):
        spike_one_time, spike_now, weight, bias, stride, padding = ctx.saved_tensors
        stride=stride.item()
        padding=padding.item()
        
        ## 이거 클론해야되는지 모르겠음!!!!
        grad_output_current_clone = grad_output_current.clone()

        grad_input_spike = grad_weight = grad_bias = None


        if ctx.needs_input_grad[0]:
            grad_input_spike = F.conv_transpose2d(grad_output_current_clone, weight, stride=stride, padding=padding)
        if ctx.needs_input_grad[2]:
            grad_weight = torch.nn.grad.conv2d_weight(spike_now, weight.shape, grad_output_current_clone,
                                                    stride=stride, padding=padding)
        if bias is not None and ctx.needs_input_grad[3]:
            grad_bias = grad_output_current_clone.sum((0, -1, -2))

        return grad_input_spike, None, grad_weight, grad_bias, None, None

    



class SYNAPSE_CONV(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, trace_const1=1, trace_const2=0.7, TIME=8):
        super(SYNAPSE_CONV, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.trace_const1 = trace_const1
        self.trace_const2 = trace_const2
        # self.weight = torch.randn(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size, requires_grad=True)
        # self.bias = torch.randn(self.out_channels, requires_grad=True)
        self.weight = nn.Parameter(torch.randn(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size))
        self.bias = nn.Parameter(torch.randn(self.out_channels))

        self.TIME = TIME

    def forward(self, spike):
        # spike: [Time, Batch, Channel, Height, Width]   
        # print('spike.shape', spike.shape)
        Time = spike.shape[0]
        assert Time == self.TIME, 'Time dimension should be same as TIME'
        Batch = spike.shape[1] 
        Channel = self.out_channels
        Height = (spike.shape[3] + self.padding*2 - self.kernel_size) // self.stride + 1
        Width = (spike.shape[4] + self.padding*2 - self.kernel_size) // self.stride + 1

        # output_current = torch.zeros(Time, Batch, Channel, Height, Width, device=spike.device)
        output_current = []
        
        # spike_detach = spike.detach().clone()
        spike_detach = spike.detach()
        spike_past = torch.zeros_like(spike_detach[0])
        spike_now = torch.zeros_like(spike_detach[0])
        for t in range(Time):
            # print(f'time:{t}', torch.sum(spike_detach[t]/ torch.numel(spike_detach[t])))
            spike_now = self.trace_const1*spike_detach[t] + self.trace_const2*spike_past

            # output_current[t]= SYNAPSE_CONV_METHOD.apply(spike[t], spike_now, self.weight, self.bias, self.stride, self.padding) 
            output_current.append( SYNAPSE_CONV_METHOD.apply(spike[t], spike_now, self.weight, self.bias, self.stride, self.padding) )
            
            spike_past = spike_now
            # print(f'time:{t}', torch.sum(output_current[t]/ torch.numel(output_current[t])))

        output_current = torch.stack(output_current, dim=0)
        return output_current



class LIF_METHOD(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_current_one_time, v_one_time, v_decay, v_threshold, v_reset, sg_width, surrogate):
        v_one_time = v_one_time * v_decay + input_current_one_time # leak + pre-synaptic current integrate
        spike = (v_one_time >= v_threshold).float() #fire
        if surrogate == 'sigmoid':
            surrogate = 1
        elif surrogate == 'rectangle':
            surrogate = 2
        elif surrogate == 'rough_rectangle':
            surrogate = 3
        else:
            pass
        ctx.save_for_backward(v_one_time, torch.tensor([v_decay], requires_grad=False), 
                            torch.tensor([v_threshold], requires_grad=False), 
                            torch.tensor([v_reset], requires_grad=False), 
                            torch.tensor([sg_width], requires_grad=False),
                            torch.tensor([surrogate], requires_grad=False)) # save before reset
        
        v_one_time = (v_one_time - spike * v_threshold).clamp_min(0) # reset
        return spike, v_one_time

    @staticmethod
    def backward(ctx, grad_output_spike, grad_output_v):
        v_one_time, v_decay, v_threshold, v_reset, sg_width, surrogate = ctx.saved_tensors
        v_decay=v_decay.item()
        v_threshold=v_threshold.item()
        v_reset=v_reset.item()
        sg_width=sg_width.item()
        surrogate=surrogate.item()

        grad_input_current = grad_output_spike.clone()
        # grad_temp_v = grad_output_v.clone() # not used

        ################ select one of the following surrogate gradient functions ################
        if (surrogate == 1):
            #===========surrogate gradient function (sigmoid)
            sig = torch.sigmoid((v_one_time - v_threshold))
            grad_input_current *= 4*sig*(1-sig)
            # grad_x = grad_output * (1. - sgax) * sgax * ctx.alpha

        elif (surrogate == 2):
            # ===========surrogate gradient function (rectangle)
            grad_input_current *= ((v_one_time - v_threshold).abs() < sg_width/2).float() / sg_width

        elif (surrogate == 3):
            #===========surrogate gradient function (rough rectangle)
            grad_input_current[(v_one_time - v_threshold).abs() > sg_width/2] = 0
        else: 
            pass
        ###########################################################################################
        return grad_input_current, None, None, None, None, None, None

class LIF_layer(nn.Module):
    def __init__ (self, v_init , v_decay , v_threshold , v_reset , sg_width, surrogate):
        super(LIF_layer, self).__init__()
        self.v_init = v_init
        self.v_decay = v_decay
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.sg_width = sg_width
        self.surrogate = surrogate

    def forward(self, input_current):
        v = torch.full_like(input_current, fill_value = self.v_init, dtype = torch.float) # v (membrane potential) init
        post_spike = torch.full_like(input_current, fill_value = self.v_init, device=input_current.device, dtype = torch.float) # v (membrane potential) init
        # i와 v와 post_spike size는 여기서 다 같음: [Time, Batch, Channel, Height, Width] 
        Time = v.shape[0]
        for t in range(Time):
            # leaky하고 input_current 더하고 fire하고 reset까지 (backward직접처리)
            post_spike[t], v[t] = LIF_METHOD.apply(input_current[t], v[t], 
                                            self.v_decay, self.v_threshold, self.v_reset, self.sg_width, self.surrogate) 
        return post_spike 




class DimChanger_for_pooling(nn.Module):
    def __init__(self, module):
        super(DimChanger_for_pooling, self).__init__()
        self.ann_module = module

    def forward(self, x):
        timestep, batch_size, *dim = x.shape
        output = self.ann_module(x.reshape(timestep * batch_size, *dim))
        _, *dim = output.shape
        output = output.view(timestep, batch_size, *dim).contiguous()
        return output


class DimChanger_for_FC(nn.Module):
    def __init__(self):
        super(DimChanger_for_FC, self).__init__()

    def forward(self, x):
        x = x.view(x.size(0), x.size(1), -1)
        return x

class tdBatchNorm(nn.BatchNorm2d):
    def __init__(self, channel):
        super(tdBatchNorm, self).__init__(channel)
        # according to tdBN paper, the initialized weight is changed to alpha*Vth
        # self.weight.data.mul_(0.5)

    def forward(self, x):
        T, B, *spatial_dims = x.shape
        out = super().forward(x.reshape(T * B, *spatial_dims))
        TB, *spatial_dims = out.shape
        out = out.view(T, B, *spatial_dims).contiguous()
        return out
    
    
class tdBatchNorm_FC(nn.BatchNorm1d):
    def __init__(self, channel):
        super(tdBatchNorm_FC, self).__init__(channel)
        # according to tdBN paper, the initialized weight is changed to alpha*Vth
        # self.weight.data.mul_(0.5)

    def forward(self, x):
        T, B, *spatial_dims = x.shape
        out = super().forward(x.reshape(T * B, *spatial_dims))
        TB, *spatial_dims = out.shape
        out = out.view(T, B, *spatial_dims).contiguous()
        return out


class BatchNorm(nn.Module):
    def __init__(self, out_channels, TIME):
        super(BatchNorm, self).__init__()
        self.out_channels = out_channels
        self.TIME = TIME
        self.bn_layers = nn.ModuleList([nn.BatchNorm2d(self.out_channels) for _ in range(self.TIME)])

    def forward(self, x):
        # out = torch.zeros_like(x, device=x.device) #Time, Batch, Channel, Height, Width
        out = [] #Time, Batch, Channel, Height, Width
        for t in range(self.TIME):
            out.append(self.bn_layers[t](x[t]))
        out = torch.stack(out, dim=0)
        return out
    
class BatchNorm_FC(nn.Module):
    def __init__(self, out_channels, TIME):
        super(BatchNorm_FC, self).__init__()
        self.out_channels = out_channels
        self.TIME = TIME
        self.bn_layers = nn.ModuleList([nn.BatchNorm1d(self.out_channels) for _ in range(self.TIME)])

    def forward(self, x):
        # out = torch.zeros_like(x, device=x.device) #Time, Batch, Channel, Height, Width
        out = [] #Time, Batch, Channel, Height, Width
        for t in range(self.TIME):
            out.append(self.bn_layers[t](x[t]))
        out = torch.stack(out, dim=0)
        return out

def make_layers_conv(cfg, in_c, IMAGE_SIZE,
                     synapse_conv_kernel_size, synapse_conv_stride, 
                     synapse_conv_padding, synapse_conv_trace_const1, 
                     synapse_conv_trace_const2, 
                     lif_layer_v_init, lif_layer_v_decay, 
                     lif_layer_v_threshold, lif_layer_v_reset,
                     lif_layer_sg_width,
                     tdBN_on,
                     BN_on, TIME,
                     surrogate,
                     BPTT_on,
                     synapse_fc_out_features):
    
    layers = []
    in_channels = in_c
    img_size_var = IMAGE_SIZE
    for which in cfg:
        if which == 'P':
            layers += [DimChanger_for_pooling(nn.AvgPool2d(kernel_size=2, stride=2))]
            # layers += [DimChanger_for_pooling(nn.MaxPool2d(kernel_size=2, stride=2))]
            img_size_var = img_size_var // 2
        else:
            out_channels = which
            if (BPTT_on == False):
                layers += [SYNAPSE_CONV(in_channels=in_channels,
                                        out_channels=out_channels, 
                                        kernel_size=synapse_conv_kernel_size, 
                                        stride=synapse_conv_stride, 
                                        padding=synapse_conv_padding, 
                                        trace_const1=synapse_conv_trace_const1, 
                                        trace_const2=synapse_conv_trace_const2,
                                        TIME=TIME)]
            else:
                layers += [SYNAPSE_CONV_BPTT(in_channels=in_channels,
                                        out_channels=out_channels, 
                                        kernel_size=synapse_conv_kernel_size, 
                                        stride=synapse_conv_stride, 
                                        padding=synapse_conv_padding, 
                                        trace_const1=synapse_conv_trace_const1, 
                                        trace_const2=synapse_conv_trace_const2,
                                        TIME=TIME)]
            
            img_size_var = (img_size_var - synapse_conv_kernel_size + 2*synapse_conv_padding)//synapse_conv_stride + 1
           
            in_channels = which
            
            if (tdBN_on == True):
                layers += [tdBatchNorm(in_channels)] # 여기서 in_channel이 out_channel임

            if (BN_on == True):
                layers += [BatchNorm(in_channels, TIME)]

            layers += [LIF_layer(v_init=lif_layer_v_init, 
                                    v_decay=lif_layer_v_decay, 
                                    v_threshold=lif_layer_v_threshold, 
                                    v_reset=lif_layer_v_reset, 
                                    sg_width=lif_layer_sg_width,
                                    surrogate=surrogate)]
            
    layers += [DimChanger_for_FC()]
    if (BPTT_on == False):
        layers += [SYNAPSE_FC(in_features=in_channels*img_size_var*img_size_var,  # 마지막CONV의 OUT_CHANNEL * H * W
                                        out_features=synapse_fc_out_features, 
                                        trace_const1=synapse_conv_trace_const1, 
                                        trace_const2=synapse_conv_trace_const2,
                                        TIME=TIME)]
    else:
        layers += [SYNAPSE_FC_BPTT(in_features=in_channels*img_size_var*img_size_var,  # 마지막CONV의 OUT_CHANNEL * H * W
                                        out_features=synapse_fc_out_features, 
                                        trace_const1=synapse_conv_trace_const1, 
                                        trace_const2=synapse_conv_trace_const2,
                                        TIME=TIME)]

    return nn.Sequential(*layers)



class MY_SNN_CONV(nn.Module):
    def __init__(self, cfg, in_c, IMAGE_SIZE,
                     synapse_conv_kernel_size, synapse_conv_stride, 
                     synapse_conv_padding, synapse_conv_trace_const1, 
                     synapse_conv_trace_const2, 
                     lif_layer_v_init, lif_layer_v_decay, 
                     lif_layer_v_threshold, lif_layer_v_reset,
                     lif_layer_sg_width,
                     synapse_fc_out_features, synapse_fc_trace_const1, synapse_fc_trace_const2,
                     tdBN_on,
                     BN_on, TIME,
                     surrogate,
                     BPTT_on):
        super(MY_SNN_CONV, self).__init__()
        self.layers = make_layers_conv(cfg, in_c, IMAGE_SIZE,
                                    synapse_conv_kernel_size, synapse_conv_stride, 
                                    synapse_conv_padding, synapse_conv_trace_const1, 
                                    synapse_conv_trace_const2, 
                                    lif_layer_v_init, lif_layer_v_decay, 
                                    lif_layer_v_threshold, lif_layer_v_reset,
                                    lif_layer_sg_width,
                                    tdBN_on,
                                    BN_on, TIME,
                                    surrogate,
                                    BPTT_on,
                                    synapse_fc_out_features)


    def forward(self, spike_input):
        # inputs: [Batch, Time, Channel, Height, Width]   
        spike_input = spike_input.permute(1, 0, 2, 3, 4)
        # inputs: [Time, Batch, Channel, Height, Width]   
        spike_input = self.layers(spike_input)

        spike_input = spike_input.sum(axis=0)
        return spike_input



def make_layers_fc(cfg, in_c, IMAGE_SIZE, out_c,
                     synapse_fc_trace_const1, synapse_fc_trace_const2, 
                     lif_layer_v_init, lif_layer_v_decay, 
                     lif_layer_v_threshold, lif_layer_v_reset,
                     lif_layer_sg_width,
                     tdBN_on,
                     BN_on, TIME,
                     surrogate,
                     BPTT_on):

    layers = []
    img_size = IMAGE_SIZE
    in_channels = in_c * img_size * img_size
    class_num = out_c
    for which in cfg:
        out_channels = which

        if(BPTT_on == False):
            layers += [SYNAPSE_FC(in_features=in_channels,  # 마지막CONV의 OUT_CHANNEL * H * W
                                        out_features=out_channels, 
                                        trace_const1=synapse_fc_trace_const1, 
                                        trace_const2=synapse_fc_trace_const2,
                                        TIME=TIME)]
        else:
            layers += [SYNAPSE_FC_BPTT(in_features=in_channels,  # 마지막CONV의 OUT_CHANNEL * H * W
                                        out_features=out_channels, 
                                        trace_const1=synapse_fc_trace_const1, 
                                        trace_const2=synapse_fc_trace_const2,
                                        TIME=TIME)]
            



        in_channels = which
        
        if (tdBN_on == True):
            layers += [tdBatchNorm_FC(in_channels)] # 여기서 in_channel이 out_channel임

        if (BN_on == True):
            layers += [BatchNorm_FC(in_channels, TIME)]

        layers += [LIF_layer(v_init=lif_layer_v_init, 
                                v_decay=lif_layer_v_decay, 
                                v_threshold=lif_layer_v_threshold, 
                                v_reset=lif_layer_v_reset, 
                                sg_width=lif_layer_sg_width,
                                surrogate=surrogate)]

    
    out_channels = class_num
    if(BPTT_on == False):
        layers += [SYNAPSE_FC(in_features=in_channels,  # 마지막CONV의 OUT_CHANNEL * H * W
                                    out_features=out_channels, 
                                    trace_const1=synapse_fc_trace_const1, 
                                    trace_const2=synapse_fc_trace_const2,
                                    TIME=TIME)]
    else:
        layers += [SYNAPSE_FC_BPTT(in_features=in_channels,  # 마지막CONV의 OUT_CHANNEL * H * W
                                    out_features=out_channels, 
                                    trace_const1=synapse_fc_trace_const1, 
                                    trace_const2=synapse_fc_trace_const2,
                                    TIME=TIME)]
        
    return nn.Sequential(*layers)

class MY_SNN_FC(nn.Module):
    def __init__(self, cfg, in_c, IMAGE_SIZE, out_c,
                     synapse_fc_trace_const1, synapse_fc_trace_const2, 
                     lif_layer_v_init, lif_layer_v_decay, 
                     lif_layer_v_threshold, lif_layer_v_reset,
                     lif_layer_sg_width,
                     tdBN_on,
                     BN_on, TIME,
                     surrogate,
                     BPTT_on):
        super(MY_SNN_FC, self).__init__()

        self.layers = make_layers_fc(cfg, in_c, IMAGE_SIZE, out_c,
                     synapse_fc_trace_const1, synapse_fc_trace_const2, 
                     lif_layer_v_init, lif_layer_v_decay, 
                     lif_layer_v_threshold, lif_layer_v_reset,
                     lif_layer_sg_width,
                     tdBN_on,
                     BN_on, TIME,
                     surrogate,
                     BPTT_on)

    def forward(self, spike_input):
        # inputs: [Batch, Time, Channel, Height, Width]   
        spike_input = spike_input.permute(1, 0, 2, 3, 4)
        # inputs: [Time, Batch, Channel, Height, Width]   
        spike_input = spike_input.view(spike_input.size(0), spike_input.size(1), -1)
        
        spike_input = self.layers(spike_input)

        spike_input = spike_input.sum(axis=0)

        return spike_input
    












def my_snn_system(devices = "0,1,2,3",
                    my_seed = 42,
                    TIME = 8,
                    BATCH = 256,
                    IMAGE_SIZE = 32,
                    which_data = 'CIFAR10',
                    rate_coding = True,
    
                    lif_layer_v_init = 0.0,
                    lif_layer_v_decay = 0.6,
                    lif_layer_v_threshold = 1.2,
                    lif_layer_v_reset = 0.0,
                    lif_layer_sg_width = 1,

                    # synapse_conv_in_channels = IMAGE_PIXEL_CHANNEL,
                    synapse_conv_kernel_size = 3,
                    synapse_conv_stride = 1,
                    synapse_conv_padding = 1,
                    synapse_conv_trace_const1 = 1,
                    synapse_conv_trace_const2 = 0.6,

                    # synapse_fc_out_features = CLASS_NUM,
                    synapse_fc_trace_const1 = 1,
                    synapse_fc_trace_const2 = 0.6,

                    pre_trained = False,
                    convTrue_fcFalse = True,
                    cfg = [64, 64],
                    pre_trained_path = "net_save/save_now_net.pth",
                    learning_rate = 0.0001,
                    epoch_num = 100,
                    verbose_interval = 100,
                    tdBN_on = False,
                    BN_on = False,

                    surrogate = 'sigmoid',

                    gradient_verbose = False,

                    BPTT_on = False
                  ):


    os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
    os.environ["CUDA_VISIBLE_DEVICES"]= devices

    
    torch.manual_seed(my_seed)


    






    if (which_data == 'MNIST'):
        data_path = '/data2'

        if rate_coding :
            transform = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0,), (1,))])
        else : 
            transform = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5,),(0.5))])

        trainset = torchvision.datasets.MNIST(root=data_path,
                                            train=True,
                                            download=True,
                                            transform=transform)


        testset = torchvision.datasets.MNIST(root=data_path,
                                            train=False,
                                            download=True,
                                            transform=transform)

        train_loader = DataLoader(trainset,
                                batch_size =BATCH,
                                shuffle = True,
                                num_workers =2)
        test_loader = DataLoader(testset,
                                batch_size =BATCH,
                                shuffle = False,
                                num_workers =2)


    if (which_data == 'CIFAR10'):
        data_path = '/data2/cifar10'

        if rate_coding :
            transform_train = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                                transforms.RandomHorizontalFlip(),
                                                transforms.ToTensor()])

            transform_test = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                                transforms.ToTensor()])
        
        else :
            transform_train = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                                transforms.RandomHorizontalFlip(),
                                                transforms.ToTensor(),
                                                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])
                                            # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

            transform_test = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                                transforms.ToTensor(),
                                                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),])
                                            # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))


        trainset = torchvision.datasets.CIFAR10(root=data_path,
                                            train=True,
                                            download=True,
                                            transform=transform_train)


        testset = torchvision.datasets.CIFAR10(root=data_path,
                                            train=False,
                                            download=True,
                                            transform=transform_test)

        train_loader = DataLoader(trainset,
                                batch_size =BATCH,
                                shuffle = True,
                                num_workers =2)
        test_loader = DataLoader(testset,
                                batch_size =BATCH,
                                shuffle = False,
                                num_workers =2)

        '''
        classes = ('plane', 'car', 'bird', 'cat', 'deer',
                'dog', 'frog', 'horse', 'ship', 'truck') 
        '''

    if (which_data == 'FASHION_MNIST'):
        data_path = '/data2'

        if rate_coding :
            transform = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                        transforms.ToTensor()])
        else : 
            transform = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5,),(0.5))])

        trainset = torchvision.datasets.FashionMNIST(root=data_path,
                                            train=True,
                                            download=True,
                                            transform=transform)


        testset = torchvision.datasets.FashionMNIST(root=data_path,
                                            train=False,
                                            download=True,
                                            transform=transform)

        train_loader = DataLoader(trainset,
                                batch_size =BATCH,
                                shuffle = True,
                                num_workers =2)
        test_loader = DataLoader(testset,
                                batch_size =BATCH,
                                shuffle = False,
                                num_workers =2)

        

    data_iter = IMAGE_PIXEL_CHANNEL = iter(train_loader)
    images, labels = data_iter.next()

    # 채널 수와 클래스 개수를 확인합니다.
    synapse_conv_in_channels = images.shape[1]
    synapse_fc_out_features = CLASS_NUM = len(torch.unique(labels))















    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    if pre_trained == False:
        if (convTrue_fcFalse == False):
            net = MY_SNN_FC(cfg, synapse_conv_in_channels, IMAGE_SIZE, synapse_fc_out_features,
                     synapse_fc_trace_const1, synapse_fc_trace_const2, 
                     lif_layer_v_init, lif_layer_v_decay, 
                     lif_layer_v_threshold, lif_layer_v_reset,
                     lif_layer_sg_width,
                     tdBN_on,
                     BN_on, TIME,
                     surrogate,
                     BPTT_on).to(device)
        else:
            net = MY_SNN_CONV(cfg, synapse_conv_in_channels, IMAGE_SIZE,
                     synapse_conv_kernel_size, synapse_conv_stride, 
                     synapse_conv_padding, synapse_conv_trace_const1, 
                     synapse_conv_trace_const2, 
                     lif_layer_v_init, lif_layer_v_decay, 
                     lif_layer_v_threshold, lif_layer_v_reset,
                     lif_layer_sg_width,
                     synapse_fc_out_features, synapse_fc_trace_const1, synapse_fc_trace_const2,
                     tdBN_on,
                     BN_on, TIME,
                     surrogate,
                     BPTT_on).to(device)
        net = torch.nn.DataParallel(net)
    else:
        net = torch.load(pre_trained_path)

    val_acc = 0

    net = net.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
    # optimizer = torch.optim.Adam(net.parameters(), lr=0.00001)


    for epoch in range(epoch_num):
        print('EPOCH', epoch)
        epoch_start_time = time.time()
        running_loss = 0.0
        # for i, data in enumerate(train_loader, 0):
        for i, data in tqdm(enumerate(train_loader, 0), total=len(train_loader), desc='train', dynamic_ncols=True, position=0, leave=True):
            # print('iter', i)
            net.train()

            # print('\niter', i)
            iter_one_train_time_start = time.time()

            inputs, labels = data

            if rate_coding == True :
                inputs = spikegen.rate(inputs, num_steps=TIME)
            else :
                inputs = inputs.repeat(TIME, 1, 1, 1, 1)
            
            # inputs: [Time, Batch, Channel, Height, Width]   

            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # inputs: [Time, Batch, Channel, Height, Width]   
            inputs = inputs.permute(1, 0, 2, 3, 4) # net에 넣어줄때는 batch가 젤 앞 차원으로 와야함. # dataparallel때매
            # inputs: [Batch, Time, Channel, Height, Width]   
        
            outputs = net(inputs)

            batch = BATCH 
            if labels.size(0) != BATCH: 
                batch = labels.size(0)


            ####### training accruacy ######
            correct = 0
            total = 0
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted[0:batch] == labels).sum().item()
            if i % verbose_interval == 9:
                print(f'iter: {i} / {len(train_loader)}')
                print(f'training acc: {100 * correct / total:.2f}%')
            ################################
            training_acc_string = f'training acc: {100 * correct / total:.2f}%'

            loss = criterion(outputs[0:batch,:], labels)
            loss.backward()

            # # optimizer.zero_grad()와 loss.backward() 호출 후에 실행해야 합니다.
            if (gradient_verbose == True):
                if (i % 100 == 9):
                    print('\n\nepoch', epoch, 'iter', i)
                    for name, param in net.named_parameters():
                        if param.requires_grad:
                            print('\n\n\n\n' , name, param.grad)
            
            optimizer.step()

            running_loss += loss.item()
            # print("Epoch: {}, Iter: {}, Loss: {}".format(epoch + 1, i + 1, running_loss / 100))

            iter_one_train_time_end = time.time()
            elapsed_time = iter_one_train_time_end - iter_one_train_time_start  # 실행 시간 계산
            # print(f"iter_one_train_time: {elapsed_time} seconds")


            if i % verbose_interval == 9:
                iter_one_val_time_start = time.time()
                
                correct = 0
                total = 0
                with torch.no_grad():
                    net.eval()
                    how_many_val_image=0
                    for data in test_loader:
                        how_many_val_image += 1
                        inputs, labels = data
            
                        if rate_coding == True :
                            inputs = spikegen.rate(inputs, num_steps=TIME)
                        else :
                            inputs = inputs.repeat(TIME, 1, 1, 1, 1)

                        
                        inputs = inputs.to(device)
                        labels = labels.to(device)
                        outputs = net(inputs.permute(1, 0, 2, 3, 4))
                        _, predicted = torch.max(outputs.data, 1)
                        total += labels.size(0)
                        batch = BATCH 
                        if labels.size(0) != BATCH: 
                            batch = labels.size(0)
                        correct += (predicted[0:batch] == labels).sum().item()
                    print(f'{epoch}-{i} validation acc: {100 * correct / total:.2f}%\n')


                iter_one_val_time_end = time.time()
                elapsed_time = iter_one_val_time_end - iter_one_val_time_start  # 실행 시간 계산
                print(f"iter_one_val_time: {elapsed_time} seconds")
                if val_acc < correct / total:
                    val_acc = correct / total
                    torch.save(net.state_dict(), "net_save/save_now_net_weights.pth")
                    torch.save(net, "net_save/save_now_net.pth")
                    torch.save(net.module.state_dict(), "net_save/save_now_net_weights2.pth")
                    torch.save(net.module, "net_save/save_now_net2.pth")
        epoch_time_end = time.time()
        epoch_time = epoch_time_end - epoch_start_time  # 실행 시간 계산
        
        print(f"epoch_time: {epoch_time} seconds")
        print('\n')



In [4]:
decay = 0.95

my_snn_system(  devices = "2,3,4,5",
                my_seed = 42,
                TIME = 6,
                BATCH = 128,
                IMAGE_SIZE = 32,
                which_data = 'CIFAR10',# 'CIFAR10' 'MNIST' 'FASHION_MNIST'
                rate_coding = False, # True # False

                lif_layer_v_init = 0.0,
                lif_layer_v_decay = decay,
                lif_layer_v_threshold = 1.2,
                lif_layer_v_reset = 0.0, #현재 안씀. 걍 빼기 해버림
                lif_layer_sg_width = 1.0, # surrogate sigmoid 쓸 때는 의미없음

                # synapse_conv_in_channels = IMAGE_PIXEL_CHANNEL,
                synapse_conv_kernel_size = 3,
                synapse_conv_stride = 1,
                synapse_conv_padding = 1,
                synapse_conv_trace_const1 = 1,
                synapse_conv_trace_const2 = decay, # lif_layer_v_decay

                # synapse_fc_out_features = CLASS_NUM,
                synapse_fc_trace_const1 = 1,
                synapse_fc_trace_const2 = decay, # lif_layer_v_decay

                pre_trained = False, # True # False
                convTrue_fcFalse = True, # True # False
                cfg = [64, 128, 'P', 256, 256, 'P', 512, 512, 'P', 512, 512], # 끝에 linear classifier 하나 자동으로 붙습니다
                pre_trained_path = "net_save/save_now_net.pth",
                learning_rate = 0.00001,
                epoch_num = 200,
                verbose_interval = 100, #숫자 크게 하면 꺼짐
                tdBN_on = False,  # True # False
                BN_on = True,  # True # False
                
                surrogate = 'sigmoid', # 'rectangle' 'sigmoid' 'rough_rectangle'
                
                gradient_verbose = False,  # True # False

                BPTT_on = False  # True # False
                )

'''
cfg 종류 = {
[64, 64]
[64, 64, 64, 64]
[64, 128, 256]
[64, 128, 'P', 256, 256, 'P', 512, 512, 'P', 512, 512]
[64, 64]
}
'''


Files already downloaded and verified
Files already downloaded and verified
EPOCH 0


train:   2%|▏         | 9/391 [00:16<05:55,  1.07it/s]  

iter: 9 / 391
training acc: 10.16%
0-9 validation acc: 9.89%

iter_one_val_time: 7.685722351074219 seconds


train:  28%|██▊       | 109/391 [01:40<03:37,  1.30it/s]

iter: 109 / 391
training acc: 18.75%
0-109 validation acc: 19.17%

iter_one_val_time: 7.694387674331665 seconds


train:  53%|█████▎    | 209/391 [03:07<02:23,  1.27it/s]

iter: 209 / 391
training acc: 25.78%
0-209 validation acc: 22.44%

iter_one_val_time: 7.920332908630371 seconds


train:  79%|███████▉  | 309/391 [04:36<01:08,  1.20it/s]

iter: 309 / 391
training acc: 20.31%
0-309 validation acc: 24.80%

iter_one_val_time: 7.904771327972412 seconds


train: 100%|██████████| 391/391 [05:50<00:00,  1.11it/s]

epoch_time: 350.9365441799164 seconds


EPOCH 1



train:   2%|▏         | 9/391 [00:07<05:06,  1.25it/s]

iter: 9 / 391
training acc: 31.25%
1-9 validation acc: 25.82%

iter_one_val_time: 7.8667378425598145 seconds


train:  28%|██▊       | 109/391 [01:36<03:55,  1.20it/s]

iter: 109 / 391
training acc: 24.22%
1-109 validation acc: 27.87%

iter_one_val_time: 7.872048854827881 seconds


train:  53%|█████▎    | 209/391 [03:06<02:28,  1.23it/s]

iter: 209 / 391
training acc: 29.69%
1-209 validation acc: 30.60%

iter_one_val_time: 7.785880088806152 seconds


train:  79%|███████▉  | 309/391 [04:36<01:07,  1.22it/s]

iter: 309 / 391
training acc: 32.03%


train:  79%|███████▉  | 310/391 [04:45<04:18,  3.19s/it]

1-309 validation acc: 29.33%

iter_one_val_time: 7.906373977661133 seconds


train: 100%|██████████| 391/391 [05:50<00:00,  1.12it/s]

epoch_time: 350.8450846672058 seconds


EPOCH 2



train:   2%|▏         | 9/391 [00:07<05:11,  1.23it/s]

iter: 9 / 391
training acc: 31.25%
2-9 validation acc: 31.87%

iter_one_val_time: 7.671346426010132 seconds


train:  28%|██▊       | 109/391 [01:37<03:47,  1.24it/s]

iter: 109 / 391
training acc: 32.81%
2-109 validation acc: 32.65%

iter_one_val_time: 7.903830528259277 seconds


train:  53%|█████▎    | 209/391 [03:07<02:27,  1.23it/s]

iter: 209 / 391
training acc: 35.16%
2-209 validation acc: 33.12%

iter_one_val_time: 8.043986797332764 seconds


train:  79%|███████▉  | 309/391 [04:37<01:06,  1.24it/s]

iter: 309 / 391
training acc: 31.25%
2-309 validation acc: 33.73%

iter_one_val_time: 7.909653663635254 seconds


train: 100%|██████████| 391/391 [05:51<00:00,  1.11it/s]

epoch_time: 352.00875639915466 seconds


EPOCH 3



train:   2%|▏         | 9/391 [00:07<05:10,  1.23it/s]

iter: 9 / 391
training acc: 28.12%
3-9 validation acc: 34.16%

iter_one_val_time: 7.790113210678101 seconds


train:  28%|██▊       | 109/391 [01:37<03:51,  1.22it/s]

iter: 109 / 391
training acc: 35.94%
3-109 validation acc: 34.45%

iter_one_val_time: 7.782779216766357 seconds


train:  53%|█████▎    | 209/391 [03:06<02:27,  1.23it/s]

iter: 209 / 391
training acc: 31.25%
3-209 validation acc: 35.36%

iter_one_val_time: 7.874624013900757 seconds


train:  79%|███████▉  | 309/391 [04:36<01:05,  1.24it/s]

iter: 309 / 391
training acc: 35.94%
3-309 validation acc: 36.62%

iter_one_val_time: 7.801546335220337 seconds


train: 100%|██████████| 391/391 [05:51<00:00,  1.11it/s]

epoch_time: 351.8660707473755 seconds


EPOCH 4



train:   2%|▏         | 9/391 [00:07<05:11,  1.23it/s]

iter: 9 / 391
training acc: 42.97%


train:   3%|▎         | 10/391 [00:16<20:27,  3.22s/it]

4-9 validation acc: 36.56%

iter_one_val_time: 7.803544282913208 seconds


train:  28%|██▊       | 109/391 [01:36<03:48,  1.23it/s]

iter: 109 / 391
training acc: 32.81%
4-109 validation acc: 37.61%

iter_one_val_time: 7.7427074909210205 seconds


train:  53%|█████▎    | 209/391 [03:05<02:26,  1.25it/s]

iter: 209 / 391
training acc: 34.38%


train:  54%|█████▎    | 210/391 [03:14<09:27,  3.13s/it]

4-209 validation acc: 37.60%

iter_one_val_time: 7.762232065200806 seconds


train:  79%|███████▉  | 309/391 [04:34<01:05,  1.24it/s]

iter: 309 / 391
training acc: 42.19%
4-309 validation acc: 38.13%

iter_one_val_time: 7.783702850341797 seconds


train: 100%|██████████| 391/391 [05:48<00:00,  1.12it/s]

epoch_time: 348.93655037879944 seconds


EPOCH 5



train:   2%|▏         | 9/391 [00:07<05:16,  1.21it/s]

iter: 9 / 391
training acc: 46.09%
5-9 validation acc: 38.33%

iter_one_val_time: 7.715498924255371 seconds


train:  28%|██▊       | 109/391 [01:36<03:52,  1.21it/s]

iter: 109 / 391
training acc: 39.06%
5-109 validation acc: 40.16%

iter_one_val_time: 7.87175989151001 seconds


train:  53%|█████▎    | 209/391 [03:06<02:27,  1.24it/s]

iter: 209 / 391
training acc: 42.19%


train:  54%|█████▎    | 210/391 [03:14<09:21,  3.10s/it]

5-209 validation acc: 39.05%

iter_one_val_time: 7.654790878295898 seconds


train:  79%|███████▉  | 309/391 [04:35<01:06,  1.23it/s]

iter: 309 / 391
training acc: 35.94%


train:  79%|███████▉  | 310/391 [04:43<04:10,  3.10s/it]

5-309 validation acc: 38.85%

iter_one_val_time: 7.619792461395264 seconds


train: 100%|██████████| 391/391 [05:49<00:00,  1.12it/s]

epoch_time: 349.2941462993622 seconds


EPOCH 6



train:   2%|▏         | 9/391 [00:07<05:08,  1.24it/s]

iter: 9 / 391
training acc: 42.19%
6-9 validation acc: 41.50%

iter_one_val_time: 7.817790269851685 seconds


train:  28%|██▊       | 109/391 [01:37<03:47,  1.24it/s]

iter: 109 / 391
training acc: 37.50%


train:  28%|██▊       | 110/391 [01:45<14:34,  3.11s/it]

6-109 validation acc: 41.23%

iter_one_val_time: 7.669596910476685 seconds


train:  53%|█████▎    | 209/391 [03:05<02:26,  1.24it/s]

iter: 209 / 391
training acc: 39.84%


train:  54%|█████▎    | 210/391 [03:14<09:29,  3.15s/it]

6-209 validation acc: 40.96%

iter_one_val_time: 7.80966329574585 seconds


train:  79%|███████▉  | 309/391 [04:35<01:07,  1.21it/s]

iter: 309 / 391
training acc: 48.44%


train:  79%|███████▉  | 310/391 [04:43<04:19,  3.21s/it]

6-309 validation acc: 40.73%

iter_one_val_time: 7.9535088539123535 seconds


train: 100%|██████████| 391/391 [05:49<00:00,  1.12it/s]

epoch_time: 349.7991576194763 seconds


EPOCH 7



train:   2%|▏         | 9/391 [00:07<05:13,  1.22it/s]

iter: 9 / 391
training acc: 38.28%


train:   3%|▎         | 10/391 [00:16<20:49,  3.28s/it]

7-9 validation acc: 40.57%

iter_one_val_time: 7.971982002258301 seconds


train:  28%|██▊       | 109/391 [01:36<03:49,  1.23it/s]

iter: 109 / 391
training acc: 36.72%
7-109 validation acc: 43.05%

iter_one_val_time: 7.811573505401611 seconds


train:  53%|█████▎    | 209/391 [03:05<02:26,  1.24it/s]

iter: 209 / 391
training acc: 39.84%


train:  54%|█████▎    | 210/391 [03:14<09:25,  3.12s/it]

7-209 validation acc: 42.67%

iter_one_val_time: 7.732365369796753 seconds


train:  79%|███████▉  | 309/391 [04:34<01:06,  1.24it/s]

iter: 309 / 391
training acc: 39.84%


train:  79%|███████▉  | 310/391 [04:42<04:15,  3.15s/it]

7-309 validation acc: 42.11%

iter_one_val_time: 7.809930086135864 seconds


train: 100%|██████████| 391/391 [05:48<00:00,  1.12it/s]

epoch_time: 348.6072943210602 seconds


EPOCH 8



train:   2%|▏         | 9/391 [00:07<05:09,  1.24it/s]

iter: 9 / 391
training acc: 47.66%
8-9 validation acc: 43.66%

iter_one_val_time: 7.823739051818848 seconds


train:  28%|██▊       | 109/391 [01:36<03:49,  1.23it/s]

iter: 109 / 391
training acc: 46.09%


train:  28%|██▊       | 110/391 [01:45<14:47,  3.16s/it]

8-109 validation acc: 42.81%

iter_one_val_time: 7.819440126419067 seconds


train:  53%|█████▎    | 209/391 [03:05<02:28,  1.23it/s]

iter: 209 / 391
training acc: 45.31%
8-209 validation acc: 44.56%

iter_one_val_time: 8.012867450714111 seconds


train:  79%|███████▉  | 309/391 [04:35<01:06,  1.24it/s]

iter: 309 / 391
training acc: 49.22%
8-309 validation acc: 44.68%

iter_one_val_time: 7.906424283981323 seconds


train: 100%|██████████| 391/391 [05:50<00:00,  1.12it/s]

epoch_time: 350.32393312454224 seconds


EPOCH 9



train:   2%|▏         | 9/391 [00:07<05:11,  1.23it/s]

iter: 9 / 391
training acc: 46.88%
9-9 validation acc: 45.36%

iter_one_val_time: 8.020035982131958 seconds


train:  28%|██▊       | 109/391 [01:37<03:50,  1.22it/s]

iter: 109 / 391
training acc: 46.88%
9-109 validation acc: 45.54%

iter_one_val_time: 8.09017300605774 seconds


train:  53%|█████▎    | 209/391 [03:06<02:26,  1.24it/s]

iter: 209 / 391
training acc: 46.09%
9-209 validation acc: 46.99%

iter_one_val_time: 7.82838773727417 seconds


train:  79%|███████▉  | 309/391 [04:35<01:06,  1.24it/s]

iter: 309 / 391
training acc: 48.44%


train:  79%|███████▉  | 310/391 [04:44<04:13,  3.13s/it]

9-309 validation acc: 46.27%

iter_one_val_time: 7.732442140579224 seconds


train: 100%|██████████| 391/391 [05:49<00:00,  1.12it/s]

epoch_time: 349.9833028316498 seconds


EPOCH 10



train:   2%|▏         | 9/391 [00:07<05:10,  1.23it/s]

iter: 9 / 391
training acc: 39.06%


train:   3%|▎         | 10/391 [00:15<20:17,  3.20s/it]

10-9 validation acc: 45.49%

iter_one_val_time: 7.731903553009033 seconds


train:  28%|██▊       | 109/391 [01:36<03:46,  1.24it/s]

iter: 109 / 391
training acc: 45.31%


train:  28%|██▊       | 110/391 [01:45<15:00,  3.21s/it]

10-109 validation acc: 46.12%

iter_one_val_time: 7.996338129043579 seconds


train:  53%|█████▎    | 209/391 [03:05<02:26,  1.24it/s]

iter: 209 / 391
training acc: 43.75%


train:  54%|█████▎    | 210/391 [03:13<09:23,  3.11s/it]

10-209 validation acc: 46.22%

iter_one_val_time: 7.687251091003418 seconds


train:  79%|███████▉  | 309/391 [04:33<01:06,  1.24it/s]

iter: 309 / 391
training acc: 47.66%
10-309 validation acc: 47.17%

iter_one_val_time: 7.741566181182861 seconds


train: 100%|██████████| 391/391 [05:48<00:00,  1.12it/s]

epoch_time: 348.73838090896606 seconds


EPOCH 11



train:   2%|▏         | 9/391 [00:07<05:10,  1.23it/s]

iter: 9 / 391
training acc: 44.53%


train:   3%|▎         | 10/391 [00:15<20:19,  3.20s/it]

11-9 validation acc: 45.94%

iter_one_val_time: 7.743763446807861 seconds


train:  28%|██▊       | 109/391 [01:36<03:47,  1.24it/s]

iter: 109 / 391
training acc: 46.09%
11-109 validation acc: 48.52%

iter_one_val_time: 7.65193247795105 seconds


train:  53%|█████▎    | 209/391 [03:05<02:32,  1.20it/s]

iter: 209 / 391
training acc: 46.88%


train:  54%|█████▎    | 210/391 [03:13<09:25,  3.12s/it]

11-209 validation acc: 48.20%

iter_one_val_time: 7.648504972457886 seconds


train:  79%|███████▉  | 309/391 [04:33<01:05,  1.25it/s]

iter: 309 / 391
training acc: 46.09%
11-309 validation acc: 48.82%

iter_one_val_time: 7.943361520767212 seconds


train: 100%|██████████| 391/391 [05:48<00:00,  1.12it/s]

epoch_time: 349.05605697631836 seconds


EPOCH 12



train:   2%|▏         | 9/391 [00:07<05:08,  1.24it/s]

iter: 9 / 391
training acc: 42.19%
12-9 validation acc: 49.63%

iter_one_val_time: 7.99199104309082 seconds


train:  28%|██▊       | 109/391 [01:38<03:50,  1.22it/s]

iter: 109 / 391
training acc: 47.66%


train:  28%|██▊       | 110/391 [01:46<14:49,  3.17s/it]

12-109 validation acc: 49.15%

iter_one_val_time: 7.828287363052368 seconds


train:  53%|█████▎    | 209/391 [03:06<02:27,  1.24it/s]

iter: 209 / 391
training acc: 46.09%
12-209 validation acc: 50.67%

iter_one_val_time: 7.839506149291992 seconds


train:  79%|███████▉  | 309/391 [04:36<01:07,  1.22it/s]

iter: 309 / 391
training acc: 50.00%


train:  79%|███████▉  | 310/391 [04:44<04:14,  3.14s/it]

12-309 validation acc: 47.02%

iter_one_val_time: 7.731997966766357 seconds


train: 100%|██████████| 391/391 [05:50<00:00,  1.11it/s]

epoch_time: 350.9204785823822 seconds


EPOCH 13



train:   2%|▏         | 9/391 [00:07<05:07,  1.24it/s]

iter: 9 / 391
training acc: 55.47%


train:   3%|▎         | 10/391 [00:15<20:23,  3.21s/it]

13-9 validation acc: 49.54%

iter_one_val_time: 7.792961835861206 seconds


train:  28%|██▊       | 109/391 [01:36<03:50,  1.23it/s]

iter: 109 / 391
training acc: 50.78%


train:  28%|██▊       | 110/391 [01:44<14:35,  3.12s/it]

13-109 validation acc: 49.13%

iter_one_val_time: 7.664130210876465 seconds


train:  53%|█████▎    | 209/391 [03:05<02:27,  1.24it/s]

iter: 209 / 391
training acc: 53.12%


train:  54%|█████▎    | 210/391 [03:13<09:26,  3.13s/it]

13-209 validation acc: 50.24%

iter_one_val_time: 7.737093448638916 seconds


train:  79%|███████▉  | 309/391 [04:34<01:06,  1.23it/s]

iter: 309 / 391
training acc: 47.66%


train:  79%|███████▉  | 310/391 [04:42<04:13,  3.13s/it]

13-309 validation acc: 50.03%

iter_one_val_time: 7.737652540206909 seconds


train: 100%|██████████| 391/391 [05:47<00:00,  1.12it/s]

epoch_time: 348.2054831981659 seconds


EPOCH 14



train:   2%|▏         | 9/391 [00:07<05:09,  1.24it/s]

iter: 9 / 391
training acc: 50.78%


train:   3%|▎         | 10/391 [00:15<20:11,  3.18s/it]

14-9 validation acc: 50.46%

iter_one_val_time: 7.6896071434021 seconds


train:  28%|██▊       | 109/391 [01:36<03:47,  1.24it/s]

iter: 109 / 391
training acc: 45.31%


train:  28%|██▊       | 110/391 [01:44<14:19,  3.06s/it]

14-109 validation acc: 50.18%

iter_one_val_time: 7.494447708129883 seconds


train:  53%|█████▎    | 209/391 [03:09<02:27,  1.24it/s]

iter: 209 / 391
training acc: 54.69%
14-209 validation acc: 52.08%

iter_one_val_time: 8.034749746322632 seconds


train:  79%|███████▉  | 309/391 [04:39<01:06,  1.23it/s]

iter: 309 / 391
training acc: 51.56%


train:  79%|███████▉  | 310/391 [04:47<04:12,  3.12s/it]

14-309 validation acc: 50.92%

iter_one_val_time: 7.694060564041138 seconds


train: 100%|██████████| 391/391 [05:53<00:00,  1.11it/s]

epoch_time: 353.8044550418854 seconds


EPOCH 15



train:   2%|▏         | 9/391 [00:07<05:08,  1.24it/s]

iter: 9 / 391
training acc: 44.53%


train:   3%|▎         | 10/391 [00:15<20:02,  3.16s/it]

15-9 validation acc: 51.47%

iter_one_val_time: 7.602135896682739 seconds


train:  28%|██▊       | 109/391 [01:35<03:50,  1.22it/s]

iter: 109 / 391
training acc: 53.91%
15-109 validation acc: 52.36%

iter_one_val_time: 7.759235858917236 seconds


train:  53%|█████▎    | 209/391 [03:27<04:09,  1.37s/it]

iter: 209 / 391
training acc: 47.66%


train:  54%|█████▎    | 210/391 [03:38<12:34,  4.17s/it]

15-209 validation acc: 51.59%

iter_one_val_time: 9.321491956710815 seconds


train:  79%|███████▉  | 309/391 [05:54<01:54,  1.39s/it]

iter: 309 / 391
training acc: 48.44%


train:  79%|███████▉  | 310/391 [06:04<05:36,  4.16s/it]

15-309 validation acc: 51.27%

iter_one_val_time: 9.242517709732056 seconds


train: 100%|██████████| 391/391 [07:48<00:00,  1.20s/it]

epoch_time: 468.5993592739105 seconds


EPOCH 16



train:   2%|▏         | 9/391 [00:12<08:40,  1.36s/it]

iter: 9 / 391
training acc: 57.81%
16-9 validation acc: 52.53%

iter_one_val_time: 9.134217500686646 seconds


train:  28%|██▊       | 109/391 [02:39<06:27,  1.37s/it]

iter: 109 / 391
training acc: 50.00%
16-109 validation acc: 52.63%

iter_one_val_time: 9.310616254806519 seconds


train:  53%|█████▎    | 209/391 [05:01<04:13,  1.39s/it]

iter: 209 / 391
training acc: 50.00%
16-209 validation acc: 53.01%

iter_one_val_time: 9.556948184967041 seconds


train:  79%|███████▉  | 309/391 [07:07<01:06,  1.24it/s]

iter: 309 / 391
training acc: 51.56%
16-309 validation acc: 54.23%

iter_one_val_time: 7.989675283432007 seconds


train: 100%|██████████| 391/391 [08:26<00:00,  1.29s/it]

epoch_time: 506.35412788391113 seconds


EPOCH 17



train:   2%|▏         | 9/391 [00:07<05:16,  1.21it/s]

iter: 9 / 391
training acc: 44.53%


train:   3%|▎         | 10/391 [00:16<20:57,  3.30s/it]

17-9 validation acc: 51.29%

iter_one_val_time: 8.010636329650879 seconds


train:  28%|██▊       | 109/391 [01:37<03:48,  1.23it/s]

iter: 109 / 391
training acc: 62.50%


train:  28%|██▊       | 110/391 [01:46<14:50,  3.17s/it]

17-109 validation acc: 53.67%

iter_one_val_time: 7.863712787628174 seconds


train:  53%|█████▎    | 209/391 [03:07<02:28,  1.22it/s]

iter: 209 / 391
training acc: 50.78%


train:  54%|█████▎    | 210/391 [03:16<09:46,  3.24s/it]

17-209 validation acc: 53.12%

iter_one_val_time: 8.075648069381714 seconds


train:  79%|███████▉  | 309/391 [04:37<01:07,  1.22it/s]

iter: 309 / 391
training acc: 55.47%


train:  79%|███████▉  | 310/391 [04:46<04:20,  3.21s/it]

17-309 validation acc: 53.46%

iter_one_val_time: 7.9820945262908936 seconds


train: 100%|██████████| 391/391 [05:53<00:00,  1.11it/s]

epoch_time: 353.22983169555664 seconds


EPOCH 18



train:   2%|▏         | 9/391 [00:07<05:25,  1.17it/s]

iter: 9 / 391
training acc: 53.12%


train:   3%|▎         | 10/391 [00:17<22:58,  3.62s/it]

18-9 validation acc: 54.03%

iter_one_val_time: 8.993160963058472 seconds


train:  28%|██▊       | 109/391 [01:38<03:51,  1.22it/s]

iter: 109 / 391
training acc: 45.31%


train:  28%|██▊       | 110/391 [01:47<15:40,  3.35s/it]

18-109 validation acc: 53.42%

iter_one_val_time: 8.433325052261353 seconds


train:  53%|█████▎    | 209/391 [03:11<02:29,  1.22it/s]

iter: 209 / 391
training acc: 54.69%


train:  54%|█████▎    | 210/391 [03:20<10:25,  3.46s/it]

18-209 validation acc: 54.09%

iter_one_val_time: 8.764744997024536 seconds


train:  79%|███████▉  | 309/391 [04:42<01:09,  1.19it/s]

iter: 309 / 391
training acc: 54.69%
18-309 validation acc: 54.43%

iter_one_val_time: 7.631380081176758 seconds


train: 100%|██████████| 391/391 [06:00<00:00,  1.09it/s]

epoch_time: 360.5802607536316 seconds


EPOCH 19



train:   2%|▏         | 9/391 [00:07<05:13,  1.22it/s]

iter: 9 / 391
training acc: 60.94%


train:   3%|▎         | 10/391 [00:17<22:49,  3.59s/it]

19-9 validation acc: 53.09%

iter_one_val_time: 8.98316240310669 seconds


train:  28%|██▊       | 109/391 [01:38<03:51,  1.22it/s]

iter: 109 / 391
training acc: 53.91%
19-109 validation acc: 54.75%

iter_one_val_time: 9.01245665550232 seconds


train:  53%|█████▎    | 209/391 [03:13<02:33,  1.19it/s]

iter: 209 / 391
training acc: 46.88%
19-209 validation acc: 55.44%

iter_one_val_time: 9.215416431427002 seconds


train:  79%|███████▉  | 309/391 [04:47<01:08,  1.20it/s]

iter: 309 / 391
training acc: 62.50%
19-309 validation acc: 55.74%

iter_one_val_time: 9.267043352127075 seconds


train: 100%|██████████| 391/391 [06:11<00:00,  1.05it/s]

epoch_time: 371.28404211997986 seconds


EPOCH 20



train:   2%|▏         | 9/391 [00:07<05:23,  1.18it/s]

iter: 9 / 391
training acc: 60.16%


train:   3%|▎         | 10/391 [00:17<22:34,  3.55s/it]

20-9 validation acc: 54.23%

iter_one_val_time: 8.76214051246643 seconds


train:  28%|██▊       | 109/391 [01:40<03:58,  1.18it/s]

iter: 109 / 391
training acc: 53.91%


train:  28%|██▊       | 110/391 [01:50<16:31,  3.53s/it]

20-109 validation acc: 54.50%

iter_one_val_time: 8.965749979019165 seconds


train:  53%|█████▎    | 209/391 [03:12<02:33,  1.18it/s]

iter: 209 / 391
training acc: 53.91%


train:  54%|█████▎    | 210/391 [03:23<11:41,  3.87s/it]

20-209 validation acc: 54.70%

iter_one_val_time: 10.088842391967773 seconds


train:  79%|███████▉  | 309/391 [04:47<01:07,  1.21it/s]

iter: 309 / 391
training acc: 46.88%


train:  79%|███████▉  | 310/391 [04:56<04:42,  3.49s/it]

20-309 validation acc: 53.65%

iter_one_val_time: 8.90199613571167 seconds


train: 100%|██████████| 391/391 [06:03<00:00,  1.08it/s]

epoch_time: 363.5675299167633 seconds


EPOCH 21



train:   2%|▏         | 9/391 [00:07<05:18,  1.20it/s]

iter: 9 / 391
training acc: 53.12%


train:   3%|▎         | 10/391 [00:17<22:01,  3.47s/it]

21-9 validation acc: 55.16%

iter_one_val_time: 8.525668621063232 seconds


train:  28%|██▊       | 109/391 [01:40<04:00,  1.17it/s]

iter: 109 / 391
training acc: 48.44%
21-109 validation acc: 56.21%

iter_one_val_time: 8.649115800857544 seconds


train:  53%|█████▎    | 209/391 [03:47<02:34,  1.18it/s]

iter: 209 / 391
training acc: 51.56%


train:  54%|█████▎    | 210/391 [03:59<12:37,  4.19s/it]

21-209 validation acc: 55.88%

iter_one_val_time: 10.58210039138794 seconds


train:  79%|███████▉  | 309/391 [06:15<01:51,  1.36s/it]

iter: 309 / 391
training acc: 53.12%


train:  79%|███████▉  | 310/391 [06:26<06:00,  4.45s/it]

21-309 validation acc: 56.04%

iter_one_val_time: 10.28677773475647 seconds


train: 100%|██████████| 391/391 [08:13<00:00,  1.26s/it]

epoch_time: 493.6343421936035 seconds


EPOCH 22



train:   2%|▏         | 9/391 [00:12<08:33,  1.34s/it]

iter: 9 / 391
training acc: 51.56%


train:   3%|▎         | 10/391 [00:23<27:04,  4.26s/it]

22-9 validation acc: 56.09%

iter_one_val_time: 9.448765277862549 seconds


train:  28%|██▊       | 109/391 [02:39<06:23,  1.36s/it]

iter: 109 / 391
training acc: 60.94%
22-109 validation acc: 56.26%

iter_one_val_time: 9.458154201507568 seconds


train:  53%|█████▎    | 209/391 [05:04<04:00,  1.32s/it]

iter: 209 / 391
training acc: 57.03%
22-209 validation acc: 56.31%

iter_one_val_time: 9.468120574951172 seconds


train:  79%|███████▉  | 309/391 [07:32<01:48,  1.33s/it]

iter: 309 / 391
training acc: 53.91%
22-309 validation acc: 56.90%

iter_one_val_time: 9.414029598236084 seconds


train: 100%|██████████| 391/391 [09:35<00:00,  1.47s/it]

epoch_time: 575.8832693099976 seconds


EPOCH 23



train:   2%|▏         | 9/391 [00:12<08:34,  1.35s/it]

iter: 9 / 391
training acc: 67.97%


train:   3%|▎         | 10/391 [00:22<26:40,  4.20s/it]

23-9 validation acc: 56.12%

iter_one_val_time: 9.2360680103302 seconds


train:  28%|██▊       | 109/391 [02:16<03:48,  1.23it/s]

iter: 109 / 391
training acc: 55.47%
23-109 validation acc: 56.96%

iter_one_val_time: 7.771877288818359 seconds


train:  53%|█████▎    | 209/391 [04:28<03:59,  1.32s/it]

iter: 209 / 391
training acc: 62.50%


train:  54%|█████▎    | 210/391 [04:39<12:37,  4.18s/it]

23-209 validation acc: 55.71%

iter_one_val_time: 9.513392925262451 seconds


train:  79%|███████▉  | 309/391 [06:20<01:07,  1.22it/s]

iter: 309 / 391
training acc: 58.59%


train:  79%|███████▉  | 310/391 [06:29<04:28,  3.31s/it]

23-309 validation acc: 56.96%

iter_one_val_time: 8.323137998580933 seconds


train: 100%|██████████| 391/391 [07:35<00:00,  1.17s/it]

epoch_time: 456.0833361148834 seconds


EPOCH 24



train:   2%|▏         | 9/391 [00:07<05:22,  1.18it/s]

iter: 9 / 391
training acc: 58.59%
24-9 validation acc: 57.80%

iter_one_val_time: 8.856001853942871 seconds


train:  28%|██▊       | 109/391 [01:42<03:59,  1.18it/s]

iter: 109 / 391
training acc: 56.25%


train:  28%|██▊       | 110/391 [01:53<18:19,  3.91s/it]

24-109 validation acc: 56.76%

iter_one_val_time: 10.21252989768982 seconds


train:  53%|█████▎    | 209/391 [03:18<02:34,  1.18it/s]

iter: 209 / 391
training acc: 55.47%


train:  54%|█████▎    | 210/391 [03:29<11:46,  3.91s/it]

24-209 validation acc: 56.67%

iter_one_val_time: 10.189338207244873 seconds


train:  79%|███████▉  | 309/391 [04:53<01:12,  1.13it/s]

iter: 309 / 391
training acc: 57.81%
24-309 validation acc: 58.25%

iter_one_val_time: 10.122416257858276 seconds


train: 100%|██████████| 391/391 [06:14<00:00,  1.04it/s]

epoch_time: 374.68619441986084 seconds


EPOCH 25



train:   2%|▏         | 9/391 [00:07<05:24,  1.18it/s]

iter: 9 / 391
training acc: 58.59%


train:   3%|▎         | 10/391 [00:18<25:36,  4.03s/it]

25-9 validation acc: 56.82%

iter_one_val_time: 10.29680609703064 seconds


train:  28%|██▊       | 109/391 [01:42<03:59,  1.18it/s]

iter: 109 / 391
training acc: 61.72%


train:  28%|██▊       | 110/391 [01:53<17:04,  3.65s/it]

25-109 validation acc: 57.44%

iter_one_val_time: 9.363298654556274 seconds


train:  53%|█████▎    | 209/391 [03:16<02:32,  1.19it/s]

iter: 209 / 391
training acc: 55.47%


train:  54%|█████▎    | 210/391 [03:26<11:01,  3.65s/it]

25-209 validation acc: 57.40%

iter_one_val_time: 9.396446466445923 seconds


train:  79%|███████▉  | 309/391 [04:50<01:08,  1.19it/s]

iter: 309 / 391
training acc: 63.28%
25-309 validation acc: 58.26%

iter_one_val_time: 9.333195686340332 seconds


train: 100%|██████████| 391/391 [06:08<00:00,  1.06it/s]

epoch_time: 368.58819365501404 seconds


EPOCH 26



train:   2%|▏         | 9/391 [00:07<05:21,  1.19it/s]

iter: 9 / 391
training acc: 55.47%
26-9 validation acc: 58.61%

iter_one_val_time: 9.426865339279175 seconds


train:  28%|██▊       | 109/391 [01:41<04:00,  1.17it/s]

iter: 109 / 391
training acc: 59.38%
26-109 validation acc: 58.99%

iter_one_val_time: 9.682835102081299 seconds


train:  53%|█████▎    | 209/391 [03:15<02:28,  1.23it/s]

iter: 209 / 391
training acc: 50.78%


train:  54%|█████▎    | 210/391 [03:24<09:50,  3.26s/it]

26-209 validation acc: 57.27%

iter_one_val_time: 8.062931537628174 seconds


train:  79%|███████▉  | 309/391 [04:45<01:06,  1.23it/s]

iter: 309 / 391
training acc: 61.72%


train:  79%|███████▉  | 310/391 [04:54<04:16,  3.17s/it]

26-309 validation acc: 57.60%

iter_one_val_time: 7.7615439891815186 seconds


train: 100%|██████████| 391/391 [06:00<00:00,  1.09it/s]

epoch_time: 360.5681960582733 seconds


EPOCH 27



train:   2%|▏         | 9/391 [00:07<05:16,  1.21it/s]

iter: 9 / 391
training acc: 60.94%


train:   3%|▎         | 10/391 [00:16<21:38,  3.41s/it]

27-9 validation acc: 57.72%

iter_one_val_time: 8.37693190574646 seconds


train:  28%|██▊       | 109/391 [01:37<03:49,  1.23it/s]

iter: 109 / 391
training acc: 64.06%


train:  28%|██▊       | 110/391 [01:46<14:57,  3.20s/it]

27-109 validation acc: 58.29%

iter_one_val_time: 7.9428441524505615 seconds


train:  53%|█████▎    | 209/391 [03:07<02:29,  1.21it/s]

iter: 209 / 391
training acc: 62.50%


train:  54%|█████▎    | 210/391 [03:16<09:29,  3.15s/it]

27-209 validation acc: 57.52%

iter_one_val_time: 7.76222825050354 seconds


train:  79%|███████▉  | 309/391 [04:37<01:06,  1.23it/s]

iter: 309 / 391
training acc: 57.81%


train:  79%|███████▉  | 310/391 [04:45<04:19,  3.20s/it]

27-309 validation acc: 55.80%

iter_one_val_time: 7.870113372802734 seconds


train: 100%|██████████| 391/391 [05:51<00:00,  1.11it/s]

epoch_time: 352.1470625400543 seconds


EPOCH 28



train:   2%|▏         | 9/391 [00:07<05:11,  1.23it/s]

iter: 9 / 391
training acc: 63.28%


train:   3%|▎         | 10/391 [00:15<20:01,  3.15s/it]

28-9 validation acc: 57.21%

iter_one_val_time: 7.566985130310059 seconds


train:  28%|██▊       | 109/391 [01:37<03:54,  1.20it/s]

iter: 109 / 391
training acc: 67.19%
28-109 validation acc: 59.43%

iter_one_val_time: 8.130760192871094 seconds


train:  53%|█████▎    | 209/391 [03:08<02:29,  1.22it/s]

iter: 209 / 391
training acc: 58.59%
28-209 validation acc: 60.22%

iter_one_val_time: 7.8469624519348145 seconds


train:  79%|███████▉  | 309/391 [04:37<01:06,  1.23it/s]

iter: 309 / 391
training acc: 67.19%


train:  79%|███████▉  | 310/391 [04:46<04:14,  3.15s/it]

28-309 validation acc: 59.09%

iter_one_val_time: 7.703172445297241 seconds


train: 100%|██████████| 391/391 [05:52<00:00,  1.11it/s]

epoch_time: 352.458034992218 seconds


EPOCH 29



train:   2%|▏         | 9/391 [00:07<05:11,  1.23it/s]

iter: 9 / 391
training acc: 56.25%


train:   3%|▎         | 10/391 [00:15<20:22,  3.21s/it]

29-9 validation acc: 58.00%

iter_one_val_time: 7.764055252075195 seconds


train:  28%|██▊       | 109/391 [01:36<03:47,  1.24it/s]

iter: 109 / 391
training acc: 60.16%


train:  28%|██▊       | 110/391 [01:44<14:35,  3.12s/it]

29-109 validation acc: 59.77%

iter_one_val_time: 7.696803331375122 seconds


train:  53%|█████▎    | 209/391 [03:05<02:27,  1.24it/s]

iter: 209 / 391
training acc: 59.38%


train:  54%|█████▎    | 210/391 [03:14<09:35,  3.18s/it]

29-209 validation acc: 60.20%

iter_one_val_time: 7.90010142326355 seconds


train:  79%|███████▉  | 309/391 [04:34<01:06,  1.24it/s]

iter: 309 / 391
training acc: 58.59%


train:  79%|███████▉  | 310/391 [04:42<04:12,  3.12s/it]

29-309 validation acc: 60.18%

iter_one_val_time: 7.619698524475098 seconds


train: 100%|██████████| 391/391 [05:48<00:00,  1.12it/s]

epoch_time: 348.83278727531433 seconds


EPOCH 30



train:   2%|▏         | 9/391 [00:07<05:08,  1.24it/s]

iter: 9 / 391
training acc: 58.59%


train:   3%|▎         | 10/391 [00:15<20:02,  3.16s/it]

30-9 validation acc: 59.03%

iter_one_val_time: 7.607578754425049 seconds


train:  28%|██▊       | 109/391 [01:36<03:48,  1.23it/s]

iter: 109 / 391
training acc: 58.59%


train:  28%|██▊       | 110/391 [01:44<14:40,  3.13s/it]

30-109 validation acc: 60.11%

iter_one_val_time: 7.736724138259888 seconds


train:  31%|███       | 120/391 [01:52<03:56,  1.15it/s]