In [None]:
# Copyright (c) 2024 Byeonghyeon Kim 
# github site: https://github.com/bhkim003/ByeonghyeonKim
# email: bhkim003@snu.ac.kr
 
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
 
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.datasets
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

import time

from snntorch import spikegen
import matplotlib.pyplot as plt
import snntorch.spikeplot as splt
from IPython.display import HTML

from tqdm import tqdm

from apex.parallel import DistributedDataParallel as DDP

import random
import datetime

import json

from sklearn.utils import shuffle

''' 레퍼런스
https://spikingjelly.readthedocs.io/zh-cn/0.0.0.0.4/spikingjelly.datasets.html#module-spikingjelly.datasets
https://github.com/GorkaAbad/Sneaky-Spikes/blob/main/datasets.py
https://github.com/GorkaAbad/Sneaky-Spikes/blob/main/how_to.md
https://github.com/nmi-lab/torchneuromorphic
https://snntorch.readthedocs.io/en/latest/snntorch.spikevision.spikedata.html#shd
'''

import snntorch
from snntorch.spikevision import spikedata

from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from spikingjelly.datasets.cifar10_dvs import CIFAR10DVS
from spikingjelly.datasets.n_mnist import NMNIST
# from spikingjelly.datasets.es_imagenet import ESImageNet
from spikingjelly.datasets import split_to_train_test_set
from spikingjelly.datasets.n_caltech101 import NCaltech101
from spikingjelly.datasets import pad_sequence_collate, padded_sequence_mask

import torchneuromorphic

import wandb

from torchviz import make_dot
import graphviz

In [None]:
# my module import
from modules import *

# When you create a new module (e.g. new_module.py) in the `modules` folder,
# add `from .new_module import *` to modules/__init__.py,
# and then, in new_module.py, use `from modules.new_module import *`.


In [None]:
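# DVS data visualization helper
# (the commented-out mapping below lists the class-index -> gesture-name labels, kept for reference)
##############################################################################################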
            # mapping = {
            #     0: 'Hand Clapping',
            #     1: 'Right Hand Wave',
            #     2: 'Left Hand Wave',
            #     3: 'Right Arm CW',
            #     4: 'Right Arm CCW',
            #     5: 'Left Arm CW',
            #     6: 'Left Arm CCW',
            #     7: 'Arm Roll',
            #     8: 'Air Drums',
            #     9: 'Air Guitar',
            #     10: 'Other'
            # }
def dvs_visualization(inputs, labels, TIME, BATCH):
    """Plot channels 0 and 1 of one random sample at every time step, then halt.

    Debugging aid for event-based (DVS) inputs. One sample is picked at random
    from the batch. For each of the first ``TIME`` time steps, channel 0 and
    channel 1 are drawn side by side as heat maps with colour bars. After all
    figures are shown, execution is stopped with ``sys.exit``.

    Args:
        inputs: 5-D tensor laid out as [time, batch, channel, H, W]. It is
            permuted with (1, 0, 2, 3, 4) so the sample can be indexed first.
            The commented-out call in the training loop passes
            [Time, Batch, Channel, Height, Width]. The tensor may live on any
            device and may require grad, because it is detached and copied to
            the CPU before conversion to numpy.
        labels: per-sample labels, indexable by batch position. The chosen
            sample's label is shown in every subplot title.
        TIME: number of time steps to plot (steps 0 .. TIME-1).
        BATCH: nominal batch size. The sample index is drawn from
            ``[0, min(BATCH, actual_batch) - 1]``, so a short final batch
            cannot cause an IndexError.

    Raises:
        SystemExit: always, once all figures have been shown.
    """
    # Never index past the samples actually present (the last batch can be short).
    num_samples = min(BATCH, inputs.size(1))
    what_input = random.randint(0, num_samples - 1)
    # [time, batch, ...] -> [batch, time, ...] so the chosen sample is indexed first.
    inputs_for_view = inputs.permute(1, 0, 2, 3, 4)
    sample_label = labels[what_input]
    for i in range(TIME):
        # Tensor.numpy() only works on CPU tensors that do not require grad.
        frame = inputs_for_view[what_input][i].detach().cpu()

        # One row, two columns: channel 0 on the left, channel 1 on the right.
        fig, axs = plt.subplots(1, 2, figsize=(12, 6))
        for channel, ax in enumerate(axs):
            im = ax.imshow(frame[channel].numpy(), cmap='viridis', interpolation='nearest')
            ax.set_title(f'Channel {channel}\nLabel: {sample_label}  Time: {i}')
            ax.set_xlabel('X axis')
            ax.set_ylabel('Y axis')
            ax.grid(False)
            fig.colorbar(im, ax=ax)

        plt.tight_layout()  # tighten spacing between the subplots
        plt.show()
        # Release the figure so TIME iterations don't accumulate open figures.
        plt.close(fig)
    # Deliberately stop the run after visual inspection (the message means "terminate").
    sys.exit("종료")

######################################################################################################

In [None]:
def my_snn_system(devices = "0,1,2,3",
                    single_step = False, # True # False
                    unique_name = 'main',
                    my_seed = 42,
                    TIME = 10,
                    BATCH = 256,
                    IMAGE_SIZE = 32,
                    which_data = 'CIFAR10',
                    # CLASS_NUM = 10,
                    data_path = '/data2',
                    rate_coding = True,
    
                    lif_layer_v_init = 0.0,
                    lif_layer_v_decay = 0.6,
                    lif_layer_v_threshold = 1.2,
                    lif_layer_v_reset = 0.0,
                    lif_layer_sg_width = 1,

                    # synapse_conv_in_channels = IMAGE_PIXEL_CHANNEL,
                    synapse_conv_kernel_size = 3,
                    synapse_conv_stride = 1,
                    synapse_conv_padding = 1,
                    synapse_conv_trace_const1 = 1,
                    synapse_conv_trace_const2 = 0.6,

                    # synapse_fc_out_features = CLASS_NUM,
                    synapse_fc_trace_const1 = 1,
                    synapse_fc_trace_const2 = 0.6,

                    pre_trained = False,
                    convTrue_fcFalse = True,
                    cfg = [64, 64],
                    net_print = False, # True # False
                    weight_count_print = False, # True # False
                    pre_trained_path = "net_save/save_now_net.pth",
                    learning_rate = 0.0001,
                    epoch_num = 200,
                    verbose_interval = 100, #숫자 크게 하면 꺼짐
                    validation_interval = 10, #숫자 크게 하면 꺼짐
                    tdBN_on = False,
                    BN_on = False,

                    surrogate = 'sigmoid',

                    gradient_verbose = False,

                    BPTT_on = False,

                    optimizer_what = 'SGD', # 'SGD' 'Adam', 'RMSprop'
                    scheduler_name = 'no',
                    
                    ddp_on = True,

                    nda_net = False,
                    
                    domain_il_epoch = 0, # over 0, then domain il mode on

                    dvs_clipping = True, 
                    dvs_duration = 1000000,

                    OTTT_sWS_on = True, # True # False
                  ):
    if OTTT_sWS_on == True:
        assert BPTT_on == False and tdBN_on == False and convTrue_fcFalse == True
    if single_step == True:
        assert BPTT_on == False and tdBN_on == False

    ## 함수 내 모든 로컬 변수 저장 ########################################################
    hyperparameters = locals()
    hyperparameters['current epoch'] = 0
    ######################################################################################
    
    
    ## wandb 세팅 ###################################################################
    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    wandb.config.update(hyperparameters)
    wandb.run.name = f'lr_{learning_rate}_{unique_name}_{which_data}_tstep{TIME}'
    wandb.define_metric("summary_val_acc", summary="max")
    ###################################################################################



    ## gpu setting ##################################################################################################################
    os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
    os.environ["CUDA_VISIBLE_DEVICES"]= devices
    ###################################################################################################################################


    ## seed setting ##################################################################################################################
    torch.manual_seed(my_seed)
    ###################################################################################################################################


    ## data_loader 가져오기 ##################################################################################################################
    # data loader, pixel channel, class num
    train_loader, test_loader, synapse_conv_in_channels, CLASS_NUM = data_loader(
            which_data,
            data_path, 
            rate_coding, 
            BATCH, 
            IMAGE_SIZE,
            ddp_on,
            TIME,
            dvs_clipping,
            dvs_duration)
    synapse_fc_out_features = CLASS_NUM
    ###########################################################################################################################################

    
    ## parameter number calculator (안 중요함) ##################################################################################################################
    params_num = 0
    img_size = IMAGE_SIZE 
    bias_param = 1 # 1 or 0
    classifier_making = False
    if (convTrue_fcFalse == True):
        past_kernel = synapse_conv_in_channels
        for kernel in cfg:
            if (classifier_making == False):
                if (type(kernel) == list):
                    for residual_kernel in kernel:
                        if (residual_kernel >= 10000 and residual_kernel < 20000): # separable
                            residual_kernel -= 10000
                            params_num += (synapse_conv_kernel_size**2 + bias_param) * past_kernel
                            params_num += (1**2 * past_kernel + bias_param) * residual_kernel
                            past_kernel = residual_kernel  
                        elif (residual_kernel >= 20000 and residual_kernel < 30000): # depthwise
                            residual_kernel -= 20000
                            # 'past_kernel' should be same with 'kernel'
                            params_num += (synapse_conv_kernel_size**2 + bias_param) * past_kernel
                            past_kernel = residual_kernel  
                        else:
                            params_num += residual_kernel * ((synapse_conv_kernel_size**2) * past_kernel + bias_param)
                            past_kernel = residual_kernel
                elif (kernel == 'P' or kernel == 'M'):
                    img_size = img_size // 2
                elif (kernel == 'D'):
                    img_size = 1
                elif (kernel == 'L'):
                    classifier_making = True
                    past_kernel = past_kernel * (img_size**2)
                else:
                    if (kernel >= 10000 and kernel < 20000): # separable
                        kernel -= 10000
                        params_num += (synapse_conv_kernel_size**2 + bias_param) * past_kernel
                        params_num += (1**2 * past_kernel + bias_param) * kernel
                        past_kernel = kernel  
                    elif (kernel >= 20000 and kernel < 30000): # depthwise
                        kernel -= 20000
                        # 'past_kernel' should be same with 'kernel'
                        params_num += (synapse_conv_kernel_size**2 + bias_param) * past_kernel
                        past_kernel = kernel  
                    else:
                        params_num += kernel * (synapse_conv_kernel_size**2 * past_kernel + bias_param)
                        past_kernel = kernel    
            else: # classifier making
                params_num += (past_kernel + bias_param) * kernel
                past_kernel = kernel
        
        
        if classifier_making == False:
            past_kernel = past_kernel*img_size*img_size

        params_num += (past_kernel + bias_param) * synapse_fc_out_features
    else:
        past_in_channel = synapse_conv_in_channels*img_size*img_size
        for in_channel in cfg:
            if (type(in_channel) == list):
                for residual_in_channel in in_channel:
                    params_num += (past_in_channel + bias_param) * residual_in_channel
                    past_in_channel = residual_in_channel
            # elif (in_channel == 'M'): #it's a holy FC layer!
            #     img_size = img_size // 2
            else:
                print('past_in_channel', past_in_channel)
                print('bias_param', bias_param)
                print('in_channel', in_channel)
                params_num += (past_in_channel + bias_param) * in_channel
                past_in_channel = in_channel
        params_num += (past_in_channel + bias_param) * synapse_fc_out_features
    ###########################################################################################################################################


    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ### network setting #######################################################################################################################
    if pre_trained == False:
        if (convTrue_fcFalse == False):
            if (single_step == False):
                net = MY_SNN_FC(cfg, synapse_conv_in_channels, IMAGE_SIZE, synapse_fc_out_features,
                            synapse_fc_trace_const1, synapse_fc_trace_const2, 
                            lif_layer_v_init, lif_layer_v_decay, 
                            lif_layer_v_threshold, lif_layer_v_reset,
                            lif_layer_sg_width,
                            tdBN_on,
                            BN_on, TIME,
                            surrogate,
                            BPTT_on).to(device)
        else:
            if (single_step == False):
                net = MY_SNN_CONV(cfg, synapse_conv_in_channels, IMAGE_SIZE,
                            synapse_conv_kernel_size, synapse_conv_stride, 
                            synapse_conv_padding, synapse_conv_trace_const1, 
                            synapse_conv_trace_const2, 
                            lif_layer_v_init, lif_layer_v_decay, 
                            lif_layer_v_threshold, lif_layer_v_reset,
                            lif_layer_sg_width,
                            synapse_fc_out_features, synapse_fc_trace_const1, synapse_fc_trace_const2,
                            tdBN_on,
                            BN_on, TIME,
                            surrogate,
                            BPTT_on,
                            OTTT_sWS_on).to(device)
            else:
                net = MY_SNN_CONV_ottt_sstep(cfg, synapse_conv_in_channels, IMAGE_SIZE,
                            synapse_conv_kernel_size, synapse_conv_stride, 
                            synapse_conv_padding, synapse_conv_trace_const1, 
                            synapse_conv_trace_const2, 
                            lif_layer_v_init, lif_layer_v_decay, 
                            lif_layer_v_threshold, lif_layer_v_reset,
                            lif_layer_sg_width,
                            synapse_fc_out_features, synapse_fc_trace_const1, synapse_fc_trace_const2,
                            tdBN_on,
                            BN_on, TIME,
                            surrogate,
                            BPTT_on,
                            OTTT_sWS_on).to(device)
        if (nda_net == True):
            net = VGG(cfg = cfg, num_classes=10, batch_norm = tdBN_on, in_c = synapse_conv_in_channels, 
                      lif_layer_v_threshold=lif_layer_v_threshold, lif_layer_v_decay=lif_layer_v_decay, lif_layer_sg_width=lif_layer_sg_width)
            net.T = TIME
        net = torch.nn.DataParallel(net) #나중에풀어줘
    else:
        net = torch.load(pre_trained_path)

    net = net.to(device)
    if (net_print == True):
        print(net)        
    ####################################################################################################################################
    

    ## wandb logging ###########################################
    wandb.watch(net, log="all", log_freq = 10) #gradient, parameter logging해줌
    ############################################################

    ## param num and memory estimation except BN with MY own calculation some lines above ##########################################
    real_param_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    if (weight_count_print == True):
        for name, param in net.named_parameters():
            if param.requires_grad:
                print(f'Layer: {name} | Number of parameters: {param.numel()}')
    # Batch norm 있으면 아래 두 개 서로 다를 수 있음.
    # assert real_param_num == params_num, f'parameter number is not same. real_param_num: {real_param_num}, params_num: {params_num}'    
    print('='*50)
    print(f"My Num of PARAMS: {params_num:,}, system's param_num : {real_param_num:,}")
    memory = params_num / 8 / 1024 / 1024 # MB
    precision = 32
    memory = memory * precision 
    print(f"Memory: {memory:.2f}MiB at {precision}-bit")
    print('='*50)
    ##############################################################################################################################



    ## criterion ########################################## # loss 구해주는 친구
    criterion = nn.CrossEntropyLoss().to(device)
    if (OTTT_sWS_on == True):
        # criterion = nn.CrossEntropyLoss().to(device)
        criterion = lambda y_t, target_t: ((1 - 0.05) * F.cross_entropy(y_t, target_t) + 0.05 * F.mse_loss(y_t, F.one_hot(target_t, CLASS_NUM).float())) / TIME 
    ####################################################
    



    ## optimizer, scheduler ########################################################################
    if(optimizer_what == 'SGD'):
        # optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
        optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0)
    elif(optimizer_what == 'Adam'):
        # optimizer = torch.optim.Adam(net.parameters(), lr=0.00001)
        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate/256 * BATCH, weight_decay=1e-4)
        # optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=0, betas=(0.9, 0.999))
    elif(optimizer_what == 'RMSprop'):
        pass


    if (scheduler_name == 'StepLR'):
        scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    elif (scheduler_name == 'ExponentialLR'):
        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    elif (scheduler_name == 'ReduceLROnPlateau'):
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)
    elif (scheduler_name == 'CosineAnnealingLR'):
        # scheduler = lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=50)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=epoch_num)
    elif (scheduler_name == 'OneCycleLR'):
        scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, steps_per_epoch=len(train_loader), epochs=100)
    else:
        pass # 'no' scheduler
    ## optimizer, scheduler ########################################################################


    tr_acc = 0
    tr_correct = 0
    tr_total = 0
    val_acc = 0
    val_acc_now = 0
    elapsed_time_val = 0
    iter_acc_array = np.array([])
    tr_acc_array = np.array([])
    val_acc_now_array = np.array([])
    #======== EPOCH START ==========================================================================================
    for epoch in range(epoch_num):
        print('EPOCH', epoch)
        epoch_start_time = time.time()
        running_loss = 0.0

        # if (domain_il_epoch>0 and which_data == 'PMNIST'):
        #     k = epoch // domain_il_epoch
        #     xtrain=data[k]['train']['x']
        #     ytrain=data[k]['train']['y']
        #     xtest =data[k]['test']['x']
        #     ytest =data[k]['test']['y']

        
        ####### iterator : input_loading & tqdm을 통한 progress_bar 생성###################
        iterator = enumerate(train_loader, 0)
        if (ddp_on == True):
            if torch.distributed.get_rank() == 0:   
                iterator = tqdm(iterator, total=len(train_loader), desc='train', dynamic_ncols=True, position=0, leave=True)
        else:
            iterator = tqdm(iterator, total=len(train_loader), desc='train', dynamic_ncols=True, position=0, leave=True)
        ##################################################################################   
        
        #### validation_interval이 batch size보다 작을 시 validation_interval을 batch size로 맞춰줌#############
        validation_interval2 = validation_interval
        if (validation_interval > len(iterator)):
            validation_interval2 = len(iterator)
        ##################################################################################################



        ###### ITERATION START ##########################################################################################################
        for i, data in iterator:
            iter_one_train_time_start = time.time()
            net.train() # train 모드로 바꿔줘야함

            ### data loading & semi-pre-processing ################################################################################
            if len(data) == 2:
                inputs, labels = data
                # 처리 로직 작성
            elif len(data) == 3:
                inputs, labels, x_len = data
                # print('x_len',x_len)
                # mask = padded_sequence_mask(x_len)
                # max_time_step = x_len.max()
                # min_time_step = x_len.min()
            # print('inputs',inputs.size(),'\nlabels',labels.size())
                    
            if (which_data == 'n_tidigits'):
                inputs = inputs.permute(0, 1, 3, 2, 4)
                labels = labels[:, 0, :]
                labels = torch.argmax(labels, dim=1)
            elif (which_data == 'heidelberg'):
                inputs = inputs.view(5, 1000, 1, 700, 1)
                print("\n\n\n경고!!!! heidelberg 이거 타임스텝이랑 채널 잘 바꿔줘라!!!\n\n\n\n")
            # print('inputs',inputs.size(),'\nlabels',labels.size())
            # print(labels)
                
            if (which_data == 'DVS_CIFAR10' or which_data == 'DVS_GESTURE' or which_data == 'DVS_CIFAR10_2' or which_data == 'NMNIST' or which_data == 'N_CALTECH101' or which_data == 'n_tidigits' or which_data == 'heidelberg'):
                inputs = inputs.permute(1, 0, 2, 3, 4)
            elif rate_coding == True :
                inputs = spikegen.rate(inputs, num_steps=TIME)
            else :
                inputs = inputs.repeat(TIME, 1, 1, 1, 1)
            # inputs: [Time, Batch, Channel, Height, Width]  
            ####################################################################################################################### 
                

                
            # # dvs 데이터 시각화 코드 (확인 필요할 시 써라)
            # ##############################################################################################
            # dvs_visualization(inputs, labels, TIME, BATCH)
            # ######################################################################################################


            ## device로 보내주기 ######################################
            inputs = inputs.to(device)
            labels = labels.to(device)
            real_batch = labels.size(0)
            ###########################################################


            ## gradient 초기화 #######################################
            optimizer.zero_grad()
            ###########################################################


            if single_step == False:
                # net에 넣어줄때는 batch가 젤 앞 차원으로 와야함. # dataparallel때매##############################
                # inputs: [Time, Batch, Channel, Height, Width]   
                inputs = inputs.permute(1, 0, 2, 3, 4) # net에 넣어줄때는 batch가 젤 앞 차원으로 와야함. # dataparallel때매
                # inputs: [Batch, Time, Channel, Height, Width] 
                #################################################################################################
            else:
                labels = labels.repeat(TIME, 1)

            

            if single_step == False:
                ### input --> net --> output #####################################################
                outputs = net(inputs)
                ##################################################################################
                ## loss, backward ##########################################
                loss = criterion(outputs, labels)
                loss.backward()
                ############################################################
                ## weight 업데이트!! ##################################
                optimizer.step()
                ################################################################
            else:
                outputs_all = []
                loss = 0.0
                for t in range(TIME):
                    outputs = net(inputs[t])
                    one_time_loss = criterion(outputs, labels[t].contiguous())
                    one_time_loss.backward() # one_time backward
                    loss += one_time_loss.data
                    outputs_all.append(outputs.detach())
                optimizer.step() # full step time update
                outputs_all = torch.stack(outputs_all, dim=1)
                outputs = outputs_all.mean(1) # ottt꺼 쓸때
                

            ## net 그림 출력해보기 #################################################################
            # print('시각화')
            # make_dot(outputs, params=dict(list(net.named_parameters()))).render("net_torchviz", format="png")
            # return 0
            ##################################################################################

            #### batch 어긋남 방지 ###############################################
            assert real_batch == BATCH, f'batch size is not same. real_batch: {real_batch}, BATCH: {BATCH}'
            #######################################################################
            

            ####### training accruacy save for print ###############################
            _, predicted = torch.max(outputs.data, 1)
            total = labels.size(0)
            correct = (predicted == labels).sum().item()
            tr_total += total
            tr_correct += correct
            iter_acc = correct / total
            if i % verbose_interval == verbose_interval-1:
                print(f'{epoch}-{i} training acc: {100 * iter_acc:.2f}%, lr={[f"{lr}" for lr in (param_group["lr"] for param_group in optimizer.param_groups)]}, val_acc: {100 * val_acc_now:.2f}%')
            iter_acc_string = f'{epoch}-{i}/{len(train_loader)} iter_acc: {100 * iter_acc:.2f}%, lr={[f"{lr}" for lr in (param_group["lr"] for param_group in optimizer.param_groups)]}'
            ################################################################
            



            running_loss += loss.item()
            # print("Epoch: {}, Iter: {}, Loss: {}".format(epoch + 1, i + 1, running_loss / 100))

            iter_one_train_time_end = time.time()
            elapsed_time = iter_one_train_time_end - iter_one_train_time_start  # 실행 시간 계산

            if (i % verbose_interval == verbose_interval-1):
                print(f"iter_one_train_time: {elapsed_time} seconds, last one_val_time: {elapsed_time_val} seconds\n")

            ##### validation ##################################################################################################################################
            if i % validation_interval2 == validation_interval2-1:
                iter_one_val_time_start = time.time()
                tr_acc = tr_correct/tr_total
                tr_correct = 0
                tr_total = 0
                correct = 0
                total = 0
                with torch.no_grad():
                    net.eval() # eval 모드로 바꿔줘야함 
                    for data in test_loader:
                        ## data loading & semi-pre-processing ##########################################################
                        if len(data) == 2:
                            inputs, labels = data
                            # 처리 로직 작성
                        elif len(data) == 3:
                            inputs, labels, x_len = data
                            # print('x_len',x_len)
                            # mask = padded_sequence_mask(x_len)
                            # max_time_step = x_len.max()
                            # min_time_step = x_len.min()
                            # B, T, *spatial_dims = inputs.shape

                        if (which_data == 'DVS_CIFAR10' or which_data == 'DVS_GESTURE' or which_data == 'DVS_CIFAR10_2' or which_data == 'NMNIST' or which_data == 'N_CALTECH101' or which_data == 'n_tidigits' or which_data == 'heidelberg'):
                            inputs = inputs.permute(1, 0, 2, 3, 4)
                        elif rate_coding == True :
                            inputs = spikegen.rate(inputs, num_steps=TIME)
                        else :
                            inputs = inputs.repeat(TIME, 1, 1, 1, 1)
                        # inputs: [Time, Batch, Channel, Height, Width]  
                        ###################################################################################################

                        inputs = inputs.to(device)
                        labels = labels.to(device)
                        real_batch = labels.size(0)

                        if single_step == False:
                            outputs = net(inputs.permute(1, 0, 2, 3, 4)) #inputs: [Batch, Time, Channel, Height, Width]  
                            val_loss = criterion(outputs, labels)
                        else:
                            val_loss=0
                            outputs_all = []
                            for t in range(TIME):
                                outputs = net(inputs[t])
                                loss = criterion(outputs, labels)
                                outputs_all.append(outputs.detach())
                                val_loss += loss.data
                            outputs_all = torch.stack(outputs_all, dim=1)
                            outputs = outputs_all.mean(1)


                        _, predicted = torch.max(outputs.data, 1)
                        total += real_batch
                        assert real_batch == BATCH, f'batch size is not same. real_batch: {real_batch}, batch: {BATCH}'
                        correct += (predicted == labels).sum().item()

                    val_acc_now = correct / total
                    # print(f'{epoch}-{i} validation acc: {100 * val_acc_now:.2f}%, lr={[f"{lr:.10f}" for lr in (param_group["lr"] for param_group in optimizer.param_groups)]}')

                iter_one_val_time_end = time.time()
                elapsed_time_val = iter_one_val_time_end - iter_one_val_time_start  # 실행 시간 계산
                # print(f"iter_one_val_time: {elapsed_time_val} seconds")

                # network save
                if val_acc < val_acc_now:
                    val_acc = val_acc_now
                    torch.save(net.state_dict(), f"net_save/save_now_net_weights_{unique_name}.pth")
                    torch.save(net, f"net_save/save_now_net_{unique_name}.pth")
                    # torch.save(net.module.state_dict(), f"net_save/save_now_net_weights2_{unique_name}.pth")
                    # torch.save(net.module, f"net_save/save_now_net2_{unique_name}.pth")
            ####################################################################################################################################################
            iterator.set_description(f"iter_acc: {iter_acc_string}, iter_loss: {loss}, val_acc: {100 * val_acc_now:.2f}%")  
            wandb.log({"iter_acc": iter_acc}, step=i+epoch*len(train_loader))
            wandb.log({"tr_acc": tr_acc}, step=i+epoch*len(train_loader))
            wandb.log({"val_acc_now": val_acc_now}, step=i+epoch*len(train_loader))
            wandb.log({"summary_val_acc": val_acc_now})
            iter_acc_array = np.append(iter_acc_array, iter_acc)
            tr_acc_array = np.append(tr_acc_array, tr_acc)
            val_acc_now_array = np.append(val_acc_now_array, val_acc_now)
            base_name = f'{current_time}'
            iter_acc_file_name_time = f'result_save/{base_name}_iter_acc_array_{unique_name}.npy'
            tr_acc_file_name_time = f'result_save/{base_name}_tr_acc_array_{unique_name}.npy'
            val_acc_file_name_time = f'result_save/{base_name}_val_acc_now_array_{unique_name}.npy'
            hyperparameters_file_name_time = f'result_save/{base_name}_hyperparameters_{unique_name}.json'

            hyperparameters['current epoch'] = epoch

            ### 모듈 세이브: 덮어쓰기 하기 싫으면 주석 풀어서 사용 (시간마다 새로 쓰기) 비추천 ########################
            # np.save(iter_acc_file_name_time, iter_acc_array)
            # np.save(tr_acc_file_name_time, iter_acc_array)
            # np.save(val_acc_file_name_time, val_acc_now_array)
            # with open(hyperparameters_file_name_time, 'w') as f:
            #     json.dump(hyperparameters, f, indent=4)
            #########################################################################################################

            ## 모듈 세이브 ###########################################################################################
            # np.save(f'result_save/iter_acc_array_{unique_name}.npy', iter_acc_array)
            # np.save(f'result_save/tr_acc_array_{unique_name}.npy', tr_acc_array)
            # np.save(f'result_save/val_acc_now_array_{unique_name}.npy', val_acc_now_array)
            # with open(f'result_save/hyperparameters_{unique_name}.json', 'w') as f:
            #     json.dump(hyperparameters, f, indent=4)
            ##########################################################################################################
        ###### ITERATION END ##########################################################################################################
                

        ## scheduler update #############################################################################
        if (scheduler_name != 'no'):
            if (scheduler_name == 'ReduceLROnPlateau'):
                scheduler.step(val_loss)
            else:
                scheduler.step()
        #################################################################################################
        
        # 실행 시간 계산
        epoch_time_end = time.time()
        print(f"epoch_time: {epoch_time_end - epoch_start_time} seconds\n") 
    #======== EPOCH END ==========================================================================================


In [6]:
# ---- my_snn run configuration ---------------------------------------------
# Membrane decay shared by the LIF layers and both synaptic trace constants.
# Other values tried: 0.875, 0.25, 0.125, 0.75, 0.5  (NDA used 0.25, OTTT used 0.5)
decay = 0.7

# Changing unique_name sends every saved artifact to a new path.
unique_name = 'main'
wandb.init(project=f'my_snn {unique_name}')

# All keyword arguments for my_snn_system, grouped by concern.
run_config = {
    # --- hardware / reproducibility ---
    'devices': "5",
    'single_step': True,            # True | False
    'unique_name': unique_name,
    'my_seed': 42,

    # --- data ---
    # TIME: DVS-CIFAR10 10, OTTT 6 or 10, NDA 10. Custom DVS sets are truncated
    # or padded when their length differs from TIME.
    'TIME': 6,
    # BATCH: must be >= 2 when batch norm is used; NDA 256, OTTT 128.
    'BATCH': 96,
    # IMAGE_SIZE: DVS-CIFAR10 48, MNIST 28, CIFAR10 32, PMNIST 28.
    # DVS_GESTURE 128, DVS_CIFAR10_2 128, NMNIST 34, N_CALTECH101 180x240,
    # n_tidigits 64, heidelberg 700. PMNIST must stay 28; others still run if changed.
    'IMAGE_SIZE': 32,
    # which_data: use TIME = 10 for DVS_CIFAR10. Options:
    #   'CIFAR100' 'CIFAR10' 'MNIST' 'FASHION_MNIST' 'DVS_CIFAR10' 'PMNIST' (not ready yet)
    #   'DVS_GESTURE' 'DVS_CIFAR10_2' 'NMNIST' 'N_CALTECH101' 'n_tidigits' 'heidelberg'
    'which_data': 'CIFAR10',
    # 'CLASS_NUM': 10,
    'data_path': '/data2',          # YOU NEED TO CHANGE THIS
    'rate_coding': False,           # True | False

    # --- LIF neuron ---
    'lif_layer_v_init': 0.0,
    'lif_layer_v_decay': decay,
    'lif_layer_v_threshold': 1.0,   # >= 10000 switches to the NDA LIF; NDA 0.5, OTTT 1.0
    'lif_layer_v_reset': 0,         # >= 10000 means hard reset (still uses this repo's LIF)
    'lif_layer_sg_width': 1.0,      # ignored when the sigmoid surrogate is used

    # --- convolution synapses ---
    # 'synapse_conv_in_channels': IMAGE_PIXEL_CHANNEL,
    'synapse_conv_kernel_size': 3,
    'synapse_conv_stride': 1,
    'synapse_conv_padding': 1,
    'synapse_conv_trace_const1': 1,
    'synapse_conv_trace_const2': decay,   # mirrors lif_layer_v_decay

    # --- fully-connected synapses ---
    # 'synapse_fc_out_features': CLASS_NUM,
    'synapse_fc_trace_const1': 1,
    'synapse_fc_trace_const2': decay,     # mirrors lif_layer_v_decay

    # --- architecture ---
    'pre_trained': False,           # True | False
    'convTrue_fcFalse': True,       # True | False
    # cfg legend: 'P' average pooling, 'D' (1,1) average pooling, 'M' max pooling,
    # 'L' linear classifier, [ ... ] residual block. In conv layers, entries >= 10000
    # are depth-wise separable and >= 20000 are depth-wise (both BPTT only).
    # A final linear classifier is appended automatically.
    # Alternatives tried:
    #   [64]
    #   [64, [64, 64], 64]
    #   [64, 128, 'P', 256, 256, 'P', 512, 512, 'P', 512, 512]              # ottt
    #   [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512]              # ottt
    #   [64, 'P', 128, 'P', 256, 256, 'P', 512, 512, 512, 512, 'D']         # nda
    #   [64, 'P', 128, 'P', 256, 256, 'P', 512, 512, 512, 512]              # nda 128 pixel
    #   [64, 'P', 128, 'P', 256, 256, 'P', 512, 512, 512, 512, 'L', 4096, 4096]
    #   [20001, 10001]                  # depth-wise, separable
    #   [64, 20064, 10001]              # vanilla conv, depth-wise, separable
    #   [8, 'P', 8, 'P', 8, 'P', 8, 'P', 8, 'P']
    #   []
    'cfg': [64, 128, 'P', 256, 256, 'P', 512, 512, 'P', 512, 512, 'D'],  # ottt

    'net_print': True,              # True | False
    'weight_count_print': False,    # True | False

    # --- optimisation ---
    'pre_trained_path': f"net_save/save_now_net_{unique_name}.pth",
    'learning_rate': 0.5,           # default 0.001; OTTT 0.1 / 0.00001; NDA 0.001
    'epoch_num': 300,
    # A huge interval disables intermediate printing.
    'verbose_interval': 999999999,
    # A huge interval validates only on the last iteration of each epoch.
    'validation_interval': 999999999,

    'tdBN_on': False,               # True | False
    'BN_on': False,                 # True | False

    'surrogate': 'sigmoid',         # 'rectangle' | 'sigmoid' | 'rough_rectangle'

    'gradient_verbose': False,      # True prints per-layer weight gradients

    # True: BPTT, False: OTTT. Depth-wise / separable convs require BPTT.
    'BPTT_on': False,
    'optimizer_what': 'SGD',        # 'SGD' | 'Adam' | 'RMSprop'
    # 'no' | 'StepLR' | 'ExponentialLR' | 'ReduceLROnPlateau' | 'CosineAnnealingLR' | 'OneCycleLR'
    'scheduler_name': 'CosineAnnealingLR',

    'ddp_on': False,                # True | False

    'nda_net': False,               # True | False

    # > 0 enables domain-incremental mode. For PMNIST this still needs work based
    # on the HLOP code; development is on hold.
    'domain_il_epoch': 0,

    # --- DVS preprocessing (gesture, cifar-dvs2, nmnist, ncaltech101) ---
    'dvs_clipping': True,           # clip DVS frames to zero/one
    # Non-zero: time sampling; otherwise event-number sampling.
    # Existing preprocessed data: gesture 1000000, nmnist 10000.
    'dvs_duration': 1000000,

    'OTTT_sWS_on': True,            # requires BPTT off; applied to conv layers only
}

my_snn_system(**run_config)
# Notes: training works better with the sigmoid surrogate and BN,
# and average pooling works better.

# NDA setup: decay = 0.25, threshold = 0.5, width = 1, surrogate = rectangle, batch = 256, tdBN = True
# OTTT setup: decay = 0.5, threshold = 1.0, surrogate = sigmoid, batch = 128, BN = True


iter_acc: 80-520/521 iter_acc: 97.50%, lr=['0.41728265158971456'], iter_loss: 0.03622371330857277, val_acc: 90.69%: 100%|██████████| 521/521 [03:21<00:00,  2.58it/s]

epoch_time: 201.94062638282776 seconds

EPOCH 81



iter_acc: 81-520/521 iter_acc: 95.00%, lr=['0.41532796633091296'], iter_loss: 0.04835034906864166, val_acc: 90.95%: 100%|██████████| 521/521 [03:21<00:00,  2.59it/s]  

epoch_time: 201.57207775115967 seconds

EPOCH 82



iter_acc: 82-520/521 iter_acc: 98.75%, lr=['0.4133551509975264'], iter_loss: 0.03099818527698517, val_acc: 90.97%: 100%|██████████| 521/521 [03:22<00:00,  2.58it/s]  

epoch_time: 202.32380032539368 seconds

EPOCH 83



iter_acc: 83-520/521 iter_acc: 97.50%, lr=['0.41136442193098766'], iter_loss: 0.03676402568817139, val_acc: 90.66%: 100%|██████████| 521/521 [03:21<00:00,  2.59it/s]  

epoch_time: 201.33381700515747 seconds

EPOCH 84



iter_acc: 84-520/521 iter_acc: 97.50%, lr=['0.40935599743717244'], iter_loss: 0.03478919714689255, val_acc: 91.19%: 100%|██████████| 521/521 [03:22<00:00,  2.57it/s]  

epoch_time: 202.82430982589722 seconds

EPOCH 85



iter_acc: 85-520/521 iter_acc: 97.50%, lr=['0.4073300977624594'], iter_loss: 0.03763521462678909, val_acc: 90.81%: 100%|██████████| 521/521 [03:20<00:00,  2.60it/s]  

epoch_time: 200.89998507499695 seconds

EPOCH 86



iter_acc: 86-520/521 iter_acc: 93.75%, lr=['0.40528694506957763'], iter_loss: 0.05076569318771362, val_acc: 90.84%: 100%|██████████| 521/521 [03:21<00:00,  2.59it/s]  

epoch_time: 201.41532588005066 seconds

EPOCH 87



iter_acc: 87-520/521 iter_acc: 93.75%, lr=['0.40322676341324415'], iter_loss: 0.03963596746325493, val_acc: 91.10%: 100%|██████████| 521/521 [03:25<00:00,  2.54it/s]  

epoch_time: 205.46025967597961 seconds

EPOCH 88



iter_acc: 88-520/521 iter_acc: 97.50%, lr=['0.40114977871559376'], iter_loss: 0.03747342526912689, val_acc: 91.14%: 100%|██████████| 521/521 [03:20<00:00,  2.59it/s]  

epoch_time: 201.01103520393372 seconds

EPOCH 89



iter_acc: 89-520/521 iter_acc: 97.50%, lr=['0.39905621874140396'], iter_loss: 0.030353834852576256, val_acc: 90.87%: 100%|██████████| 521/521 [03:21<00:00,  2.59it/s] 

epoch_time: 201.25519943237305 seconds

EPOCH 90



iter_acc: 90-520/521 iter_acc: 96.25%, lr=['0.3969463130731183'], iter_loss: 0.03198334574699402, val_acc: 91.54%: 100%|██████████| 521/521 [03:39<00:00,  2.37it/s]  

epoch_time: 219.7375888824463 seconds

EPOCH 91



iter_acc: 91-520/521 iter_acc: 97.50%, lr=['0.3948202930856697'], iter_loss: 0.03172188997268677, val_acc: 91.03%: 100%|██████████| 521/521 [03:27<00:00,  2.51it/s]  

epoch_time: 207.58921480178833 seconds

EPOCH 92



iter_acc: 92-520/521 iter_acc: 97.50%, lr=['0.392678391921108'], iter_loss: 0.032344572246074677, val_acc: 91.19%: 100%|██████████| 521/521 [03:23<00:00,  2.56it/s] 

epoch_time: 203.76957845687866 seconds

EPOCH 93



iter_acc: 93-520/521 iter_acc: 92.50%, lr=['0.3905208444630327'], iter_loss: 0.0492185577750206, val_acc: 90.64%: 100%|██████████| 521/521 [03:22<00:00,  2.57it/s]   

epoch_time: 202.77678322792053 seconds

EPOCH 94



iter_acc: 94-520/521 iter_acc: 92.50%, lr=['0.388347887310836'], iter_loss: 0.05689641088247299, val_acc: 91.07%: 100%|██████████| 521/521 [03:21<00:00,  2.58it/s]  

epoch_time: 201.7227065563202 seconds

EPOCH 95



iter_acc: 95-520/521 iter_acc: 93.75%, lr=['0.38615975875375674'], iter_loss: 0.05089832469820976, val_acc: 91.52%: 100%|██████████| 521/521 [03:21<00:00,  2.59it/s]  

epoch_time: 201.5309829711914 seconds

EPOCH 96



iter_acc: 96-520/521 iter_acc: 95.00%, lr=['0.38395669874474914'], iter_loss: 0.04326064512133598, val_acc: 91.30%: 100%|██████████| 521/521 [03:20<00:00,  2.59it/s]  

epoch_time: 201.02212381362915 seconds

EPOCH 97



iter_acc: 97-520/521 iter_acc: 97.50%, lr=['0.3817389488741694'], iter_loss: 0.03413093835115433, val_acc: 91.36%: 100%|██████████| 521/521 [03:22<00:00,  2.58it/s]  

epoch_time: 202.32265663146973 seconds

EPOCH 98



iter_acc: 98-520/521 iter_acc: 95.00%, lr=['0.37950675234328257'], iter_loss: 0.04775248095393181, val_acc: 91.08%: 100%|██████████| 521/521 [03:21<00:00,  2.58it/s]  

epoch_time: 201.9774785041809 seconds

EPOCH 99



iter_acc: 99-520/521 iter_acc: 97.50%, lr=['0.37726035393759283'], iter_loss: 0.03004661202430725, val_acc: 91.47%: 100%|██████████| 521/521 [03:20<00:00,  2.59it/s]  

epoch_time: 201.0492663383484 seconds

EPOCH 100



iter_acc: 100-520/521 iter_acc: 97.50%, lr=['0.375'], iter_loss: 0.035349149256944656, val_acc: 91.42%: 100%|██████████| 521/521 [03:21<00:00,  2.59it/s] 

epoch_time: 201.24844527244568 seconds

EPOCH 101



iter_acc: 101-520/521 iter_acc: 96.25%, lr=['0.3727259384037852'], iter_loss: 0.03543423116207123, val_acc: 91.22%: 100%|██████████| 521/521 [03:21<00:00,  2.59it/s]  

epoch_time: 201.18239784240723 seconds

EPOCH 102



iter_acc: 102-520/521 iter_acc: 98.75%, lr=['0.3704384185254288'], iter_loss: 0.028175026178359985, val_acc: 91.24%: 100%|██████████| 521/521 [03:21<00:00,  2.58it/s] 

epoch_time: 201.8127737045288 seconds

EPOCH 103



iter_acc: 103-520/521 iter_acc: 98.75%, lr=['0.36813769121726353'], iter_loss: 0.031213607639074326, val_acc: 91.20%: 100%|██████████| 521/521 [03:21<00:00,  2.59it/s] 

epoch_time: 201.47113490104675 seconds

EPOCH 104



iter_acc: 104-520/521 iter_acc: 95.00%, lr=['0.36582400877996546'], iter_loss: 0.041281770914793015, val_acc: 91.04%: 100%|██████████| 521/521 [03:21<00:00,  2.58it/s] 

epoch_time: 202.12622022628784 seconds

EPOCH 105



iter_acc: 105-520/521 iter_acc: 97.50%, lr=['0.3634976249348867'], iter_loss: 0.038526467978954315, val_acc: 91.21%: 100%|██████████| 521/521 [03:21<00:00,  2.58it/s] 

epoch_time: 201.7678644657135 seconds

EPOCH 106



iter_acc: 106-520/521 iter_acc: 97.50%, lr=['0.36115879479623186'], iter_loss: 0.038240429013967514, val_acc: 91.64%: 100%|██████████| 521/521 [03:21<00:00,  2.58it/s] 

epoch_time: 201.95917177200317 seconds

EPOCH 107



iter_acc: 107-520/521 iter_acc: 95.00%, lr=['0.3588077748430819'], iter_loss: 0.04653308540582657, val_acc: 91.38%: 100%|██████████| 521/521 [03:21<00:00,  2.58it/s]  

epoch_time: 201.85935258865356 seconds

EPOCH 108



iter_acc: 108-520/521 iter_acc: 98.75%, lr=['0.35644482289126816'], iter_loss: 0.029707174748182297, val_acc: 91.15%: 100%|██████████| 521/521 [03:21<00:00,  2.58it/s] 

epoch_time: 201.78996539115906 seconds

EPOCH 109



iter_acc: 109-520/521 iter_acc: 100.00%, lr=['0.3540701980651003'], iter_loss: 0.028416162356734276, val_acc: 91.52%: 100%|██████████| 521/521 [03:24<00:00,  2.54it/s]

epoch_time: 205.0502495765686 seconds

EPOCH 110



iter_acc: 110-520/521 iter_acc: 93.75%, lr=['0.35168416076895004'], iter_loss: 0.06093659624457359, val_acc: 91.49%: 100%|██████████| 521/521 [03:22<00:00,  2.58it/s]  

epoch_time: 202.40620708465576 seconds

EPOCH 111



iter_acc: 111-520/521 iter_acc: 96.25%, lr=['0.3492869726586951'], iter_loss: 0.037808410823345184, val_acc: 91.40%: 100%|██████████| 521/521 [03:27<00:00,  2.51it/s] 

epoch_time: 207.79018187522888 seconds

EPOCH 112



iter_acc: 112-520/521 iter_acc: 98.75%, lr=['0.34687889661302573'], iter_loss: 0.031209345906972885, val_acc: 91.53%: 100%|██████████| 521/521 [04:01<00:00,  2.15it/s] 

epoch_time: 242.0584433078766 seconds

EPOCH 113



iter_acc: 113-520/521 iter_acc: 97.50%, lr=['0.3444601967046168'], iter_loss: 0.03914804384112358, val_acc: 91.48%: 100%|██████████| 521/521 [03:48<00:00,  2.28it/s]  

epoch_time: 228.82825446128845 seconds

EPOCH 114



iter_acc: 114-520/521 iter_acc: 97.50%, lr=['0.34203113817116954'], iter_loss: 0.03797927871346474, val_acc: 91.31%: 100%|██████████| 521/521 [03:53<00:00,  2.23it/s]  

epoch_time: 234.0280795097351 seconds

EPOCH 115



iter_acc: 115-520/521 iter_acc: 98.75%, lr=['0.339591987386325'], iter_loss: 0.030957099050283432, val_acc: 91.53%: 100%|██████████| 521/521 [04:15<00:00,  2.04it/s] 

epoch_time: 256.25695991516113 seconds

EPOCH 116



iter_acc: 116-520/521 iter_acc: 97.50%, lr=['0.3371430118304538'], iter_loss: 0.037034835666418076, val_acc: 91.31%: 100%|██████████| 521/521 [03:58<00:00,  2.18it/s] 

epoch_time: 238.7909712791443 seconds

EPOCH 117



iter_acc: 117-520/521 iter_acc: 95.00%, lr=['0.3346844800613229'], iter_loss: 0.04270452260971069, val_acc: 91.27%: 100%|██████████| 521/521 [03:54<00:00,  2.22it/s]  

epoch_time: 234.59290432929993 seconds

EPOCH 118



iter_acc: 118-520/521 iter_acc: 98.75%, lr=['0.3322166616846458'], iter_loss: 0.03210258483886719, val_acc: 91.45%: 100%|██████████| 521/521 [04:03<00:00,  2.14it/s]  

epoch_time: 243.84248733520508 seconds

EPOCH 119



iter_acc: 119-520/521 iter_acc: 98.75%, lr=['0.3297398273245175'], iter_loss: 0.033257029950618744, val_acc: 91.31%: 100%|██████████| 521/521 [03:31<00:00,  2.47it/s] 

epoch_time: 211.3317744731903 seconds

EPOCH 120



iter_acc: 120-520/521 iter_acc: 93.75%, lr=['0.32725424859373686'], iter_loss: 0.04928436130285263, val_acc: 91.83%: 100%|██████████| 521/521 [03:29<00:00,  2.49it/s]  

epoch_time: 209.67794013023376 seconds

EPOCH 121



iter_acc: 121-520/521 iter_acc: 95.00%, lr=['0.3247601980640217'], iter_loss: 0.03837728127837181, val_acc: 91.56%: 100%|██████████| 521/521 [03:31<00:00,  2.46it/s]  

epoch_time: 211.7609646320343 seconds

EPOCH 122



iter_acc: 122-520/521 iter_acc: 98.75%, lr=['0.3222579492361179'], iter_loss: 0.03299914672970772, val_acc: 91.53%: 100%|██████████| 521/521 [03:25<00:00,  2.53it/s]  

epoch_time: 205.9233274459839 seconds

EPOCH 123



iter_acc: 123-520/521 iter_acc: 95.00%, lr=['0.31974777650980735'], iter_loss: 0.04444747790694237, val_acc: 91.47%: 100%|██████████| 521/521 [03:51<00:00,  2.25it/s]  

epoch_time: 232.10794162750244 seconds

EPOCH 124



iter_acc: 124-520/521 iter_acc: 97.50%, lr=['0.3172299551538164'], iter_loss: 0.033681005239486694, val_acc: 91.71%: 100%|██████████| 521/521 [03:30<00:00,  2.48it/s] 

epoch_time: 210.58421635627747 seconds

EPOCH 125



iter_acc: 125-520/521 iter_acc: 96.25%, lr=['0.31470476127563024'], iter_loss: 0.031651630997657776, val_acc: 91.85%: 100%|██████████| 521/521 [03:22<00:00,  2.58it/s] 

epoch_time: 202.47425866127014 seconds

EPOCH 126



iter_acc: 126-520/521 iter_acc: 96.25%, lr=['0.31217247179121366'], iter_loss: 0.0361785814166069, val_acc: 91.58%: 100%|██████████| 521/521 [03:22<00:00,  2.57it/s]   

epoch_time: 202.77396893501282 seconds

EPOCH 127



iter_acc: 127-520/521 iter_acc: 97.50%, lr=['0.3096333643946452'], iter_loss: 0.03106074593961239, val_acc: 91.48%: 100%|██████████| 521/521 [03:25<00:00,  2.53it/s]  

epoch_time: 205.86650252342224 seconds

EPOCH 128



iter_acc: 128-520/521 iter_acc: 96.25%, lr=['0.30708771752766395'], iter_loss: 0.041001759469509125, val_acc: 92.00%: 100%|██████████| 521/521 [03:22<00:00,  2.57it/s] 

epoch_time: 202.94825172424316 seconds

EPOCH 129



iter_acc: 129-520/521 iter_acc: 98.75%, lr=['0.3045358103491357'], iter_loss: 0.02796543762087822, val_acc: 91.55%: 100%|██████████| 521/521 [03:22<00:00,  2.57it/s]  

epoch_time: 202.65483212471008 seconds

EPOCH 130



iter_acc: 130-520/521 iter_acc: 98.75%, lr=['0.30197792270443985'], iter_loss: 0.03356296569108963, val_acc: 91.68%: 100%|██████████| 521/521 [03:28<00:00,  2.50it/s]  

epoch_time: 208.585697889328 seconds

EPOCH 131



iter_acc: 131-520/521 iter_acc: 92.50%, lr=['0.2994143350947815'], iter_loss: 0.048098281025886536, val_acc: 91.67%: 100%|██████████| 521/521 [03:31<00:00,  2.47it/s] 

epoch_time: 211.52884531021118 seconds

EPOCH 132



iter_acc: 132-520/521 iter_acc: 100.00%, lr=['0.2968453286464312'], iter_loss: 0.027554422616958618, val_acc: 92.08%: 100%|██████████| 521/521 [03:36<00:00,  2.41it/s]

epoch_time: 216.21486639976501 seconds

EPOCH 133



iter_acc: 133-520/521 iter_acc: 96.25%, lr=['0.29427118507989586'], iter_loss: 0.039599400013685226, val_acc: 91.70%: 100%|██████████| 521/521 [03:36<00:00,  2.41it/s] 

epoch_time: 216.6088764667511 seconds

EPOCH 134



iter_acc: 134-520/521 iter_acc: 98.75%, lr=['0.2916921866790256'], iter_loss: 0.025828000158071518, val_acc: 91.46%: 100%|██████████| 521/521 [03:36<00:00,  2.41it/s] 

epoch_time: 216.47261381149292 seconds

EPOCH 135



iter_acc: 135-520/521 iter_acc: 95.00%, lr=['0.28910861626005774'], iter_loss: 0.04196532815694809, val_acc: 91.61%: 100%|██████████| 521/521 [03:30<00:00,  2.48it/s]  

epoch_time: 210.26702332496643 seconds

EPOCH 136



iter_acc: 136-520/521 iter_acc: 97.50%, lr=['0.28652075714060293'], iter_loss: 0.02994176559150219, val_acc: 91.50%: 100%|██████████| 521/521 [03:34<00:00,  2.43it/s]  

epoch_time: 214.86708736419678 seconds

EPOCH 137



iter_acc: 137-520/521 iter_acc: 96.25%, lr=['0.2839288931085761'], iter_loss: 0.039136968553066254, val_acc: 91.71%: 100%|██████████| 521/521 [03:33<00:00,  2.43it/s] 

epoch_time: 214.15104126930237 seconds

EPOCH 138



iter_acc: 138-520/521 iter_acc: 97.50%, lr=['0.28133330839107606'], iter_loss: 0.03759156912565231, val_acc: 91.56%: 100%|██████████| 521/521 [03:39<00:00,  2.37it/s]  

epoch_time: 219.9703812599182 seconds

EPOCH 139



iter_acc: 139-520/521 iter_acc: 97.50%, lr=['0.27873428762321667'], iter_loss: 0.033373355865478516, val_acc: 91.40%: 100%|██████████| 521/521 [03:32<00:00,  2.45it/s] 

epoch_time: 212.78360080718994 seconds

EPOCH 140



iter_acc: 140-520/521 iter_acc: 98.75%, lr=['0.27613211581691344'], iter_loss: 0.03212208300828934, val_acc: 91.83%: 100%|██████████| 521/521 [03:31<00:00,  2.46it/s]  

epoch_time: 211.88402199745178 seconds

EPOCH 141



iter_acc: 141-520/521 iter_acc: 98.75%, lr=['0.27352707832962864'], iter_loss: 0.029839392751455307, val_acc: 91.65%: 100%|██████████| 521/521 [03:39<00:00,  2.38it/s] 

epoch_time: 219.2690155506134 seconds

EPOCH 142



iter_acc: 142-520/521 iter_acc: 98.75%, lr=['0.2709194608330789'], iter_loss: 0.02659936249256134, val_acc: 91.79%: 100%|██████████| 521/521 [03:32<00:00,  2.45it/s]  

epoch_time: 213.06636452674866 seconds

EPOCH 143



iter_acc: 143-520/521 iter_acc: 97.50%, lr=['0.2683095492819079'], iter_loss: 0.028870541602373123, val_acc: 91.98%: 100%|██████████| 521/521 [03:41<00:00,  2.36it/s] 

epoch_time: 221.3850884437561 seconds

EPOCH 144



iter_acc: 144-520/521 iter_acc: 91.25%, lr=['0.26569762988232837'], iter_loss: 0.05792314559221268, val_acc: 91.78%: 100%|██████████| 521/521 [03:36<00:00,  2.40it/s]  

epoch_time: 216.9208528995514 seconds

EPOCH 145



iter_acc: 145-520/521 iter_acc: 93.75%, lr=['0.263083989060736'], iter_loss: 0.060582034289836884, val_acc: 91.71%: 100%|██████████| 521/521 [03:33<00:00,  2.44it/s] 

epoch_time: 213.4339497089386 seconds

EPOCH 146



iter_acc: 146-520/521 iter_acc: 95.00%, lr=['0.26046891343229994'], iter_loss: 0.04317118972539902, val_acc: 91.78%: 100%|██████████| 521/521 [03:28<00:00,  2.49it/s]  

epoch_time: 209.12950897216797 seconds

EPOCH 147



iter_acc: 147-520/521 iter_acc: 98.75%, lr=['0.257852689769532'], iter_loss: 0.03158324956893921, val_acc: 91.78%: 100%|██████████| 521/521 [03:31<00:00,  2.47it/s]  

epoch_time: 211.32663774490356 seconds

EPOCH 148



iter_acc: 148-520/521 iter_acc: 100.00%, lr=['0.25523560497083925'], iter_loss: 0.02745889686048031, val_acc: 91.89%: 100%|██████████| 521/521 [03:35<00:00,  2.42it/s] 

epoch_time: 215.78389859199524 seconds

EPOCH 149



iter_acc: 149-520/521 iter_acc: 97.50%, lr=['0.25261794602906146'], iter_loss: 0.03339303284883499, val_acc: 91.93%: 100%|██████████| 521/521 [03:33<00:00,  2.44it/s]  

epoch_time: 213.5838053226471 seconds

EPOCH 150



iter_acc: 150-520/521 iter_acc: 98.75%, lr=['0.25'], iter_loss: 0.02646477520465851, val_acc: 91.56%: 100%|██████████| 521/521 [03:34<00:00,  2.43it/s]  

epoch_time: 214.6154580116272 seconds

EPOCH 151



iter_acc: 151-520/521 iter_acc: 100.00%, lr=['0.24738205397093857'], iter_loss: 0.02681245468556881, val_acc: 91.80%: 100%|██████████| 521/521 [03:34<00:00,  2.43it/s] 

epoch_time: 214.90006279945374 seconds

EPOCH 152



iter_acc: 152-520/521 iter_acc: 96.25%, lr=['0.24476439502916084'], iter_loss: 0.033599257469177246, val_acc: 91.99%: 100%|██████████| 521/521 [03:32<00:00,  2.45it/s] 

epoch_time: 212.45397305488586 seconds

EPOCH 153



iter_acc: 153-520/521 iter_acc: 97.50%, lr=['0.24214731023046793'], iter_loss: 0.0328083261847496, val_acc: 91.86%: 100%|██████████| 521/521 [03:33<00:00,  2.45it/s]   

epoch_time: 213.2173728942871 seconds

EPOCH 154



iter_acc: 154-520/521 iter_acc: 100.00%, lr=['0.23953108656770009'], iter_loss: 0.026431512087583542, val_acc: 91.80%: 100%|██████████| 521/521 [03:27<00:00,  2.51it/s]

epoch_time: 208.00167798995972 seconds

EPOCH 155



iter_acc: 155-520/521 iter_acc: 98.75%, lr=['0.23691601093926404'], iter_loss: 0.02921244502067566, val_acc: 91.66%: 100%|██████████| 521/521 [03:50<00:00,  2.26it/s]  

epoch_time: 230.2600576877594 seconds

EPOCH 156



iter_acc: 156-520/521 iter_acc: 97.50%, lr=['0.23430237011767172'], iter_loss: 0.03498245030641556, val_acc: 91.90%: 100%|██████████| 521/521 [04:16<00:00,  2.03it/s]  

epoch_time: 256.67881655693054 seconds

EPOCH 157



iter_acc: 157-520/521 iter_acc: 97.50%, lr=['0.23169045071809213'], iter_loss: 0.03082546591758728, val_acc: 92.01%: 100%|██████████| 521/521 [04:22<00:00,  1.98it/s]  

epoch_time: 263.0055105686188 seconds

EPOCH 158



iter_acc: 158-520/521 iter_acc: 96.25%, lr=['0.22908053916692112'], iter_loss: 0.03805198892951012, val_acc: 91.74%: 100%|██████████| 521/521 [04:17<00:00,  2.02it/s]  

epoch_time: 257.6414601802826 seconds

EPOCH 159



iter_acc: 159-520/521 iter_acc: 97.50%, lr=['0.22647292167037142'], iter_loss: 0.03169791400432587, val_acc: 91.92%: 100%|██████████| 521/521 [04:22<00:00,  1.99it/s]  

epoch_time: 262.4047887325287 seconds

EPOCH 160



iter_acc: 160-520/521 iter_acc: 98.75%, lr=['0.22386788418308667'], iter_loss: 0.027776405215263367, val_acc: 91.90%: 100%|██████████| 521/521 [03:59<00:00,  2.18it/s] 

epoch_time: 239.55426216125488 seconds

EPOCH 161



iter_acc: 161-520/521 iter_acc: 93.75%, lr=['0.22126571237678339'], iter_loss: 0.03572535142302513, val_acc: 92.22%: 100%|██████████| 521/521 [04:33<00:00,  1.91it/s]  

epoch_time: 273.7361583709717 seconds

EPOCH 162



iter_acc: 162-520/521 iter_acc: 97.50%, lr=['0.21866669160892396'], iter_loss: 0.0432014986872673, val_acc: 92.17%: 100%|██████████| 521/521 [04:28<00:00,  1.94it/s]   

epoch_time: 269.0829887390137 seconds

EPOCH 163



iter_acc: 163-520/521 iter_acc: 96.25%, lr=['0.21607110689142392'], iter_loss: 0.033990368247032166, val_acc: 91.99%: 100%|██████████| 521/521 [04:40<00:00,  1.86it/s] 

epoch_time: 280.4782748222351 seconds

EPOCH 164



iter_acc: 164-520/521 iter_acc: 96.25%, lr=['0.21347924285939718'], iter_loss: 0.030808981508016586, val_acc: 92.12%: 100%|██████████| 521/521 [04:18<00:00,  2.02it/s] 

epoch_time: 258.50845980644226 seconds

EPOCH 165



iter_acc: 165-520/521 iter_acc: 98.75%, lr=['0.21089138373994235'], iter_loss: 0.02809979021549225, val_acc: 91.77%: 100%|██████████| 521/521 [04:12<00:00,  2.06it/s]  

epoch_time: 252.82312893867493 seconds

EPOCH 166



iter_acc: 166-520/521 iter_acc: 98.75%, lr=['0.2083078133209744'], iter_loss: 0.02848820947110653, val_acc: 92.08%: 100%|██████████| 521/521 [04:12<00:00,  2.06it/s]  

epoch_time: 252.79206609725952 seconds

EPOCH 167



iter_acc: 167-520/521 iter_acc: 98.75%, lr=['0.2057288149201042'], iter_loss: 0.034561194479465485, val_acc: 92.24%: 100%|██████████| 521/521 [04:08<00:00,  2.10it/s] 

epoch_time: 248.59015941619873 seconds

EPOCH 168



iter_acc: 168-520/521 iter_acc: 100.00%, lr=['0.20315467135356885'], iter_loss: 0.027626723051071167, val_acc: 92.08%: 100%|██████████| 521/521 [04:16<00:00,  2.03it/s]

epoch_time: 256.7699272632599 seconds

EPOCH 169



iter_acc: 169-520/521 iter_acc: 97.50%, lr=['0.20058566490521845'], iter_loss: 0.031078428030014038, val_acc: 91.98%: 100%|██████████| 521/521 [04:31<00:00,  1.92it/s] 

epoch_time: 272.08397579193115 seconds

EPOCH 170



iter_acc: 170-520/521 iter_acc: 97.50%, lr=['0.19802207729556015'], iter_loss: 0.03579970821738243, val_acc: 92.08%: 100%|██████████| 521/521 [04:31<00:00,  1.92it/s]  

epoch_time: 271.4677720069885 seconds

EPOCH 171



iter_acc: 171-520/521 iter_acc: 97.50%, lr=['0.19546418965086437'], iter_loss: 0.03933637961745262, val_acc: 92.36%: 100%|██████████| 521/521 [04:16<00:00,  2.03it/s]  

epoch_time: 256.95850467681885 seconds

EPOCH 172



iter_acc: 172-520/521 iter_acc: 100.00%, lr=['0.1929122824723361'], iter_loss: 0.02517789974808693, val_acc: 92.33%: 100%|██████████| 521/521 [04:14<00:00,  2.05it/s] 

epoch_time: 254.922696352005 seconds

EPOCH 173



iter_acc: 173-520/521 iter_acc: 96.25%, lr=['0.19036663560535483'], iter_loss: 0.03870956599712372, val_acc: 92.19%: 100%|██████████| 521/521 [04:17<00:00,  2.03it/s]  

epoch_time: 257.2760548591614 seconds

EPOCH 174



iter_acc: 174-520/521 iter_acc: 98.75%, lr=['0.18782752820878634'], iter_loss: 0.028743591159582138, val_acc: 91.96%: 100%|██████████| 521/521 [03:40<00:00,  2.36it/s] 

epoch_time: 221.13499402999878 seconds

EPOCH 175



iter_acc: 175-520/521 iter_acc: 98.75%, lr=['0.18529523872436984'], iter_loss: 0.027453085407614708, val_acc: 92.22%: 100%|██████████| 521/521 [03:34<00:00,  2.43it/s] 

epoch_time: 214.7875862121582 seconds

EPOCH 176



iter_acc: 176-520/521 iter_acc: 97.50%, lr=['0.18277004484618364'], iter_loss: 0.03654478117823601, val_acc: 91.98%: 100%|██████████| 521/521 [03:44<00:00,  2.32it/s]  

epoch_time: 224.53073287010193 seconds

EPOCH 177



iter_acc: 177-520/521 iter_acc: 96.25%, lr=['0.18025222349019265'], iter_loss: 0.033949725329875946, val_acc: 92.16%: 100%|██████████| 521/521 [03:42<00:00,  2.35it/s] 

epoch_time: 222.23891305923462 seconds

EPOCH 178



iter_acc: 178-520/521 iter_acc: 97.50%, lr=['0.17774205076388205'], iter_loss: 0.03386326879262924, val_acc: 92.13%: 100%|██████████| 521/521 [03:38<00:00,  2.39it/s]  

epoch_time: 218.20318937301636 seconds

EPOCH 179



iter_acc: 179-520/521 iter_acc: 98.75%, lr=['0.17523980193597835'], iter_loss: 0.031751833856105804, val_acc: 92.10%: 100%|██████████| 521/521 [03:56<00:00,  2.20it/s] 

epoch_time: 236.98455238342285 seconds

EPOCH 180



iter_acc: 180-520/521 iter_acc: 100.00%, lr=['0.17274575140626316'], iter_loss: 0.023203104734420776, val_acc: 92.34%: 100%|██████████| 521/521 [03:42<00:00,  2.34it/s]

epoch_time: 222.7893407344818 seconds

EPOCH 181



iter_acc: 181-520/521 iter_acc: 98.75%, lr=['0.17026017267548252'], iter_loss: 0.02806231379508972, val_acc: 91.97%: 100%|██████████| 521/521 [03:41<00:00,  2.35it/s]  

epoch_time: 221.80643129348755 seconds

EPOCH 182



iter_acc: 182-520/521 iter_acc: 97.50%, lr=['0.1677833383153542'], iter_loss: 0.03220386430621147, val_acc: 92.15%: 100%|██████████| 521/521 [03:38<00:00,  2.38it/s]  

epoch_time: 218.91363549232483 seconds

EPOCH 183



iter_acc: 183-520/521 iter_acc: 98.75%, lr=['0.16531551993867716'], iter_loss: 0.0288365688174963, val_acc: 92.29%: 100%|██████████| 521/521 [03:47<00:00,  2.29it/s]   

epoch_time: 227.84930682182312 seconds

EPOCH 184



iter_acc: 184-520/521 iter_acc: 98.75%, lr=['0.16285698816954625'], iter_loss: 0.03236357867717743, val_acc: 92.04%: 100%|██████████| 521/521 [03:30<00:00,  2.47it/s]  

epoch_time: 211.1877465248108 seconds

EPOCH 185



iter_acc: 185-520/521 iter_acc: 98.75%, lr=['0.160408012613675'], iter_loss: 0.026749078184366226, val_acc: 92.25%: 100%|██████████| 521/521 [03:42<00:00,  2.34it/s] 

epoch_time: 222.59865593910217 seconds

EPOCH 186



iter_acc: 186-520/521 iter_acc: 97.50%, lr=['0.15796886182883058'], iter_loss: 0.028887879103422165, val_acc: 92.21%: 100%|██████████| 521/521 [03:36<00:00,  2.41it/s] 

epoch_time: 216.22685503959656 seconds

EPOCH 187



iter_acc: 187-520/521 iter_acc: 96.25%, lr=['0.1555398032953832'], iter_loss: 0.03620710223913193, val_acc: 91.97%: 100%|██████████| 521/521 [03:52<00:00,  2.24it/s]  

epoch_time: 232.3807761669159 seconds

EPOCH 188



iter_acc: 188-520/521 iter_acc: 97.50%, lr=['0.15312110338697427'], iter_loss: 0.0341271311044693, val_acc: 92.36%: 100%|██████████| 521/521 [03:35<00:00,  2.42it/s]   

epoch_time: 215.22720503807068 seconds

EPOCH 189



iter_acc: 189-520/521 iter_acc: 98.75%, lr=['0.15071302734130482'], iter_loss: 0.02766011469066143, val_acc: 92.39%: 100%|██████████| 521/521 [03:31<00:00,  2.46it/s]  

epoch_time: 211.8346128463745 seconds

EPOCH 190



iter_acc: 190-520/521 iter_acc: 96.25%, lr=['0.14831583923104993'], iter_loss: 0.03543270751833916, val_acc: 92.07%: 100%|██████████| 521/521 [03:35<00:00,  2.41it/s]  

epoch_time: 216.14324307441711 seconds

EPOCH 191



iter_acc: 191-520/521 iter_acc: 96.25%, lr=['0.14592980193489974'], iter_loss: 0.03530602902173996, val_acc: 92.36%: 100%|██████████| 521/521 [04:02<00:00,  2.15it/s]  

epoch_time: 242.6422107219696 seconds

EPOCH 192



iter_acc: 192-520/521 iter_acc: 98.75%, lr=['0.14355517710873184'], iter_loss: 0.02945483848452568, val_acc: 92.34%: 100%|██████████| 521/521 [03:51<00:00,  2.25it/s]  

epoch_time: 231.80335116386414 seconds

EPOCH 193



iter_acc: 193-520/521 iter_acc: 95.00%, lr=['0.14119222515691815'], iter_loss: 0.04224197193980217, val_acc: 92.15%: 100%|██████████| 521/521 [03:34<00:00,  2.43it/s]  

epoch_time: 214.30146884918213 seconds

EPOCH 194



iter_acc: 194-520/521 iter_acc: 100.00%, lr=['0.1388412052037682'], iter_loss: 0.02800450287759304, val_acc: 92.07%: 100%|██████████| 521/521 [03:44<00:00,  2.33it/s] 

epoch_time: 224.26727414131165 seconds

EPOCH 195



iter_acc: 195-520/521 iter_acc: 100.00%, lr=['0.1365023750651133'], iter_loss: 0.023092135787010193, val_acc: 92.43%: 100%|██████████| 521/521 [03:30<00:00,  2.47it/s]

epoch_time: 211.19450783729553 seconds

EPOCH 196



iter_acc: 196-520/521 iter_acc: 97.50%, lr=['0.13417599122003462'], iter_loss: 0.03345286101102829, val_acc: 92.64%: 100%|██████████| 521/521 [04:25<00:00,  1.96it/s]  

epoch_time: 265.9976444244385 seconds

EPOCH 197



iter_acc: 197-520/521 iter_acc: 98.75%, lr=['0.13186230878273653'], iter_loss: 0.02944573014974594, val_acc: 92.11%: 100%|██████████| 521/521 [03:48<00:00,  2.28it/s]  

epoch_time: 228.82966995239258 seconds

EPOCH 198



iter_acc: 198-520/521 iter_acc: 100.00%, lr=['0.12956158147457114'], iter_loss: 0.022917233407497406, val_acc: 92.39%: 100%|██████████| 521/521 [04:05<00:00,  2.13it/s]

epoch_time: 245.38776397705078 seconds

EPOCH 199



iter_acc: 199-520/521 iter_acc: 96.25%, lr=['0.1272740615962148'], iter_loss: 0.03458113223314285, val_acc: 92.41%: 100%|██████████| 521/521 [04:05<00:00,  2.12it/s]  

epoch_time: 245.93979954719543 seconds

EPOCH 200



iter_acc: 200-520/521 iter_acc: 98.75%, lr=['0.12499999999999994'], iter_loss: 0.025422824546694756, val_acc: 92.34%: 100%|██████████| 521/521 [03:48<00:00,  2.28it/s] 

epoch_time: 229.17895317077637 seconds

EPOCH 201



iter_acc: 201-520/521 iter_acc: 100.00%, lr=['0.12273964606240717'], iter_loss: 0.027271132916212082, val_acc: 92.29%: 100%|██████████| 521/521 [03:28<00:00,  2.49it/s]

epoch_time: 209.24375081062317 seconds

EPOCH 202



iter_acc: 202-520/521 iter_acc: 97.50%, lr=['0.12049324765671748'], iter_loss: 0.038545407354831696, val_acc: 92.25%: 100%|██████████| 521/521 [03:41<00:00,  2.36it/s] 

epoch_time: 221.35342669487 seconds

EPOCH 203



iter_acc: 203-520/521 iter_acc: 100.00%, lr=['0.1182610511258306'], iter_loss: 0.023625817149877548, val_acc: 92.52%: 100%|██████████| 521/521 [03:37<00:00,  2.39it/s]

epoch_time: 218.19170713424683 seconds

EPOCH 204



iter_acc: 204-520/521 iter_acc: 98.75%, lr=['0.11604330125525089'], iter_loss: 0.028399653732776642, val_acc: 92.31%: 100%|██████████| 521/521 [03:33<00:00,  2.44it/s] 

epoch_time: 214.17340970039368 seconds

EPOCH 205



iter_acc: 205-520/521 iter_acc: 97.50%, lr=['0.11384024124624331'], iter_loss: 0.03139436990022659, val_acc: 92.12%: 100%|██████████| 521/521 [03:34<00:00,  2.42it/s]  

epoch_time: 215.10604596138 seconds

EPOCH 206



iter_acc: 206-520/521 iter_acc: 98.75%, lr=['0.111652112689164'], iter_loss: 0.02580947056412697, val_acc: 92.12%: 100%|██████████| 521/521 [03:32<00:00,  2.45it/s]  

epoch_time: 213.06210327148438 seconds

EPOCH 207



iter_acc: 207-520/521 iter_acc: 100.00%, lr=['0.10947915553696741'], iter_loss: 0.023978084325790405, val_acc: 92.29%: 100%|██████████| 521/521 [03:37<00:00,  2.40it/s]

epoch_time: 217.4837682247162 seconds

EPOCH 208



iter_acc: 208-520/521 iter_acc: 100.00%, lr=['0.1073216080788921'], iter_loss: 0.027983641251921654, val_acc: 92.28%: 100%|██████████| 521/521 [03:33<00:00,  2.44it/s]

epoch_time: 213.92218279838562 seconds

EPOCH 209



iter_acc: 209-520/521 iter_acc: 100.00%, lr=['0.10517970691433026'], iter_loss: 0.026605531573295593, val_acc: 92.15%: 100%|██████████| 521/521 [03:35<00:00,  2.42it/s]

epoch_time: 215.76855278015137 seconds

EPOCH 210



iter_acc: 210-520/521 iter_acc: 98.75%, lr=['0.10305368692688174'], iter_loss: 0.027001459151506424, val_acc: 92.05%: 100%|██████████| 521/521 [03:36<00:00,  2.40it/s] 

epoch_time: 216.978835105896 seconds

EPOCH 211



iter_acc: 211-520/521 iter_acc: 98.75%, lr=['0.10094378125859602'], iter_loss: 0.03350551053881645, val_acc: 92.40%: 100%|██████████| 521/521 [03:33<00:00,  2.44it/s]  

epoch_time: 213.9743185043335 seconds

EPOCH 212



iter_acc: 212-520/521 iter_acc: 98.75%, lr=['0.09885022128440629'], iter_loss: 0.026252765208482742, val_acc: 92.45%: 100%|██████████| 521/521 [03:45<00:00,  2.31it/s] 

epoch_time: 225.74599194526672 seconds

EPOCH 213



iter_acc: 213-520/521 iter_acc: 98.75%, lr=['0.09677323658675593'], iter_loss: 0.025241974741220474, val_acc: 92.28%: 100%|██████████| 521/521 [04:05<00:00,  2.12it/s] 

epoch_time: 245.5973265171051 seconds

EPOCH 214



iter_acc: 214-519/521 iter_acc: 100.00%, lr=['0.09471305493042242'], iter_loss: 0.024585148319602013, val_acc: 92.28%: 100%|█████████▉| 520/521 [04:07<00:00,  2.76it/s]

In [None]:
# NOTE(review): this entire cell is commented out and currently does nothing.
# It defines a wandb Bayesian hyperparameter sweep over learning_rate,
# batch_size and decay, then launches up to 1000 runs of hyper_iter, each
# calling my_snn_system with epoch_num = 4.
# To enable it, uncomment it and comment out the single-run training cell above
# (see the original note on the next line).
# NOTE(review): `wandb` does not appear in the visible import cell and
# `my_snn_system` is not defined in this chunk. TODO: confirm both are in scope
# before enabling.
# NOTE(review): the sweep metric 'val_acc_now' must match a key that
# my_snn_system actually logs via wandb. TODO: confirm against its wandb.log calls.
# # Sweep code: the training cell above must be commented out before running this.

# unique_name_hyper = 'main'
# sweep_configuration = {
#     'method': 'bayes',
#     'name': 'my_snn_sweep',
#     'metric': {'goal': 'maximize', 'name': 'val_acc_now'},
#     'parameters': 
#     {
#         "learning_rate": {"values": [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,2.0]},
#         "batch_size": {"values": [64, 96, 128]},
#         "decay": {"values": [0.3,0.4,0.5,0.6,0.7,0.8,0.875,0.9]},
#      }
# }

# def hyper_iter():
#     ### my_snn control board ########################
#     unique_name = unique_name_hyper ## setting this saves everything under a new path
    
#     wandb.init()
#     learning_rate  =  wandb.config.learning_rate
#     batch_size  =  wandb.config.batch_size
#     decay  =  wandb.config.decay

#     my_snn_system(  devices = "3",
#                     single_step = True, # True # False
#                     unique_name = unique_name,
#                     my_seed = 42,
#                     TIME = 6 , # dvscifar 10 # ottt 6 or 10 # nda 10  # for the custom-built DVS datasets, sequences longer/shorter than TIME are truncated or padded
#                     BATCH = batch_size, # must be >= 2 if using batch norm   # nda 256   #  ottt 128
#                     IMAGE_SIZE = 32, # dvscifar 48 # MNIST 28 # CIFAR10 32 # PMNIST 28
#                     # dvsgesture 128, dvs_cifar2 128, nmnist 34, n_caltech101 180,240, n_tidigits 64, heidelberg 700, 
#                     #pmnist must use 28. The other datasets still run if this is changed.

#                     # use TIME = 10 for DVS_CIFAR10
#                     which_data = 'CIFAR10',
#     # 'CIFAR100' 'CIFAR10' 'MNIST' 'FASHION_MNIST' 'DVS_CIFAR10' 'PMNIST'(not yet)
#     # 'DVS_GESTURE','DVS_CIFAR10_2','NMNIST','N_CALTECH101','n_tidigits','heidelberg'
#                     # CLASS_NUM = 10,
#                     data_path = '/data2', # YOU NEED TO CHANGE THIS
#                     rate_coding = False, # True # False

#                     lif_layer_v_init = 0.0,
#                     lif_layer_v_decay = decay,
#                     lif_layer_v_threshold = 1.0,  # >= 10000 uses the NDA LIF. #nda 0.5  #ottt 1.0
#                     lif_layer_v_reset = 0, # >= 10000 means hard reset (still uses my own LIF)
#                     lif_layer_sg_width = 1.0, # # meaningless when using the sigmoid surrogate

#                     # synapse_conv_in_channels = IMAGE_PIXEL_CHANNEL,
#                     synapse_conv_kernel_size = 3,
#                     synapse_conv_stride = 1,
#                     synapse_conv_padding = 1,
#                     synapse_conv_trace_const1 = 1,
#                     synapse_conv_trace_const2 = decay, # lif_layer_v_decay

#                     # synapse_fc_out_features = CLASS_NUM,
#                     synapse_fc_trace_const1 = 1,
#                     synapse_fc_trace_const2 = decay, # lif_layer_v_decay

#                     pre_trained = False, # True # False
#                     convTrue_fcFalse = True, # True # False

#                     # 'P' for average pooling, 'D' for (1,1) aver pooling, 'M' for maxpooling, 'L' for linear classifier, [  ] for residual block
#                     # in conv, >= 10000 means depth-wise separable (BPTT only), >= 20000 means depth-wise (BPTT only)
#                     # cfg = [64],
#                     # cfg = [64,[64,64],64], # a single linear classifier is appended automatically at the end
#                     cfg = [64, 128, 'P', 256, 256, 'P', 512, 512, 'P', 512, 512, 'D'], #ottt
#                     # cfg = [64, 128, 'P', 256, 256, 'P', 512, 512, 'P', 512, 512], #ottt
#                     # cfg = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], # ottt 
#                     # cfg = [64, 'P', 128, 'P', 256, 256, 'P', 512, 512, 512, 512, 'D'], # nda
#                     # cfg = [64, 'P', 128, 'P', 256, 256, 'P', 512, 512, 512, 512], # nda 128pixel
#                     # cfg = [64, 'P', 128, 'P', 256, 256, 'P', 512, 512, 512, 512, 'L', 4096, 4096],
#                     # cfg = [20001,10001], # depthwise, separable
#                     # cfg = [64,20064,10001], # vanilla conv, depthwise, separable
#                     # cfg = [8, 'P', 8, 'P', 8, 'P', 8,'P', 8, 'P'],
#                     # cfg = [], 
                    
#                     net_print = True, # True # False
#                     weight_count_print = False, # True # False
                    
#                     pre_trained_path = f"net_save/save_now_net_{unique_name}.pth",
#                     learning_rate = learning_rate, # default 0.001  # ottt 0.1 0.00001 # nda 0.001 
#                     epoch_num = 4,
#                     verbose_interval = 999999999, #a large number disables it #otherwise prints at intermediate iterations
#                     validation_interval = 999999999, #a large number means validation runs only at the last iter of each epoch

#                     tdBN_on = False,  # True # False
#                     BN_on = False,  # True # False
                    
#                     surrogate = 'sigmoid', # 'rectangle' 'sigmoid' 'rough_rectangle'
                    
#                     gradient_verbose = False,  # True # False  # show weight gradients for each layer

#                     BPTT_on = False,  # True # False # True -> BPTT, False -> OTTT  # depthwise/separable only support BPTT
#                     optimizer_what = 'SGD', # 'SGD' 'Adam', 'RMSprop'
#                     scheduler_name = 'CosineAnnealingLR', # 'no' 'StepLR' 'ExponentialLR' 'ReduceLROnPlateau' 'CosineAnnealingLR' 'OneCycleLR'
                    
#                     ddp_on = False,   # True # False

#                     nda_net = False,   # True # False

#                     domain_il_epoch = 0, # over 0, then domain il mode on # to use pmnist, develop further based on the HLOP code. Development is currently on hold.
                    
#                     dvs_clipping = True, # dvs zero&one  # gesture, cifar-dvs2, nmnist, ncaltech101
#                     dvs_duration = 1000000, # nonzero means time sampling # dvs number sampling OR time sampling # gesture, cifar-dvs2, nmnist, ncaltech101
#                     #available datasets #gesture 1000000 #nmnist 10000

#                     OTTT_sWS_on = True, # True # False # requires BPTT off; applied to CONV layers only.
                    
#                     ) 
#     # works well only with sigmoid and BN.
#     # average pooling
#     # is better.
    
#     # in nda: decay = 0.25, threshold = 0.5, width =1, surrogate = rectangle, batch = 256, tdBN = True
#     ## in OTTT: decay = 0.5, threshold = 1.0, surrogate = sigmoid, batch = 128, BN = True


# sweep_id = wandb.sweep(sweep=sweep_configuration, project=f'my_snn {unique_name_hyper}')
# wandb.agent(sweep_id, function=hyper_iter, count=1000)


In [None]:
# NOTE(review): this plotting cell is entirely commented out and does nothing.
# When enabled, it loads saved accuracy arrays (.npy) and a hyperparameter JSON
# from result_save/. It plots iteration, training and validation accuracy on a
# shared x-axis, annotated with the top train/val accuracy and key hyperparameters.
# NOTE(review): it requires a global `unique_name` from an earlier cell.
# load_hyperparameters' default-argument f-string is evaluated once, at `def`
# time, so that name must exist before this cell runs even if the default is
# never used.
# NOTE(review): the timestamped file names (the `current_time` block) are
# immediately overwritten by the "most recent" file names below them, so they
# have no effect. That block also never sets tr_acc_file_name.
# NOTE(review): pad_array_to_match_length pads the shorter array with zeros at
# the end, so padded tails plot as 0% accuracy. If the arrays are recorded at
# different rates, padding does not put them on a common axis. TODO: confirm
# they share a time base before trusting the curves.
# import numpy as np
# import matplotlib.pyplot as plt
# import json




# # Zero-pads the shorter of two 1-D arrays at the end so both have equal length.
# def pad_array_to_match_length(array1, array2):
#     if len(array1) > len(array2):
#         padded_array2 = np.pad(array2, (0, len(array1) - len(array2)), 'constant')
#         return array1, padded_array2
#     elif len(array2) > len(array1):
#         padded_array1 = np.pad(array1, (0, len(array2) - len(array1)), 'constant')
#         return padded_array1, array2
#     else:
#         return array1, array2
# # Reads a JSON hyperparameter file and returns the parsed object.
# def load_hyperparameters(filename=f'result_save/hyperparameters_{unique_name}.json'):
#     with open(filename, 'r') as f:
#         return json.load(f)
    




# current_time = '20240628_110116'
# base_name = f'{current_time}'
# iter_acc_file_name = f'result_save/{base_name}_iter_acc_array_{unique_name}.npy'
# val_acc_file_name = f'result_save/{base_name}_val_acc_now_array_{unique_name}.npy'
# hyperparameters_file_name = f'result_save/{base_name}_hyperparameters_{unique_name}.json'

# ### if you want to just see most recent train and val acc###########################
# iter_acc_file_name = f'result_save/iter_acc_array_{unique_name}.npy'
# tr_acc_file_name = f'result_save/tr_acc_array_{unique_name}.npy'
# val_acc_file_name = f'result_save/val_acc_now_array_{unique_name}.npy'
# hyperparameters_file_name = f'result_save/hyperparameters_{unique_name}.json'

# # Arrays are scaled to percent here (x100).
# loaded_iter_acc_array = np.load(iter_acc_file_name)*100
# loaded_tr_acc_array = np.load(tr_acc_file_name)*100
# loaded_val_acc_array = np.load(val_acc_file_name)*100
# hyperparameters = load_hyperparameters(hyperparameters_file_name)

# loaded_iter_acc_array, loaded_val_acc_array = pad_array_to_match_length(loaded_iter_acc_array, loaded_val_acc_array)
# loaded_iter_acc_array, loaded_tr_acc_array = pad_array_to_match_length(loaded_iter_acc_array, loaded_tr_acc_array)
# loaded_val_acc_array, loaded_tr_acc_array = pad_array_to_match_length(loaded_val_acc_array, loaded_tr_acc_array)

# top_iter_acc = np.max(loaded_iter_acc_array)
# top_tr_acc = np.max(loaded_tr_acc_array)
# top_val_acc = np.max(loaded_val_acc_array)

# which_data = hyperparameters['which_data']
# BPTT_on = hyperparameters['BPTT_on']
# current_epoch = hyperparameters['current epoch']
# surrogate = hyperparameters['surrogate']
# cfg = hyperparameters['cfg']
# tdBN_on = hyperparameters['tdBN_on']
# BN_on = hyperparameters['BN_on']


# iterations = np.arange(len(loaded_iter_acc_array))

# # Draw the plot
# plt.figure(figsize=(10, 5))
# plt.plot(iterations, loaded_iter_acc_array, label='Iter Accuracy', color='g', alpha=0.2)
# plt.plot(iterations, loaded_tr_acc_array, label='Training Accuracy', color='b')
# plt.plot(iterations, loaded_val_acc_array, label='Validation Accuracy', color='r')

# # # Add text (old variant; NOTE(review): it multiplies by 100 again although the arrays above are already x100)
# # plt.text(0.05, 0.95, f'Top Training Accuracy: {100*top_iter_acc:.2f}%', transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color='blue')
# # plt.text(0.05, 0.90, f'Top Validation Accuracy: {100*top_val_acc:.2f}%', transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color='red')
# # Add text
# plt.text(0.5, 0.10, f'Top Training Accuracy: {top_tr_acc:.2f}%', transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='center', color='blue')
# plt.text(0.5, 0.05, f'Top Validation Accuracy: {top_val_acc:.2f}%', transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='center', color='red')

# plt.xlabel('Iterations')
# plt.ylabel('Accuracy [%]')

# # Add hyperparameter info to the plot title
# title = f'Training and Validation Accuracy over Iterations\n\nData: {which_data}, BPTT: {"On" if BPTT_on else "Off"}, Current Epoch: {current_epoch}, Surrogate: {surrogate},\nCFG: {cfg}, tdBN: {"On" if tdBN_on else "Off"}, BN: {"On" if BN_on else "Off"}'

# plt.title(title)

# plt.legend(loc='lower right')
# plt.xlim(0)  # start the x-axis at 0
# plt.grid(True)
# plt.show()