In [1]:
# Copyright (c) 2024 Byeonghyeon Kim 
# github site: https://github.com/bhkim003/ByeonghyeonKim
# email: bhkim003@snu.ac.kr
 
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
 
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


In [2]:
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.datasets
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

import time

from snntorch import spikegen
import matplotlib.pyplot as plt
import snntorch.spikeplot as splt
from IPython.display import HTML

from tqdm import tqdm

from apex.parallel import DistributedDataParallel as DDP

import random

In [3]:
# my module import
from modules import *

# When you create a new module file in the modules folder (e.g. modules/new_module.py),
# add `from .new_module import *` to modules/__init__.py,
# and then in new_module.py use `from modules.new_module import *`.


In [4]:
def _estimate_param_count(cfg, in_channels, image_size, kernel_size,
                          out_features, convTrue_fcFalse, bias_param=1):
    """Analytically estimate the trainable parameter count implied by `cfg`.

    Counts weights + biases of every layer listed in `cfg` plus the final
    linear classifier with `out_features` outputs. Nested lists in `cfg` are
    counted as consecutive layers. In conv mode, an entry 'P' halves the
    spatial size used for the classifier input and adds no parameters.
    Layers the estimate does not model (e.g. batch norm) are not counted, so
    the result can differ from the real parameter count.

    Returns:
        int: estimated number of parameters.
    """
    params_num = 0
    img_size = image_size
    if convTrue_fcFalse:
        past_kernel = in_channels
        for entry in cfg:
            if entry == 'P':
                img_size = img_size // 2
                continue
            for width in (entry if isinstance(entry, list) else [entry]):
                params_num += width * (kernel_size ** 2 * past_kernel + bias_param)
                past_kernel = width
        params_num += (past_kernel * (img_size ** 2) + bias_param) * out_features
    else:
        past_in = in_channels * img_size * img_size
        for entry in cfg:
            for width in (entry if isinstance(entry, list) else [entry]):
                params_num += (past_in + bias_param) * width
                past_in = width
        params_num += (past_in + bias_param) * out_features
    return params_num


def _encode_inputs(inputs, which_data, rate_coding, TIME):
    """Convert a loader batch into a time-major tensor [Time, Batch, C, H, W].

    - 'DVS-CIFAR10': the batch already carries a time axis and is swapped to
      time-major (assumes the loader yields [Batch, Time, C, H, W] -- TODO confirm).
    - rate_coding: rate-encoded over TIME steps with snntorch.spikegen.rate.
    - otherwise: the static frame is repeated TIME times.
    """
    if which_data == 'DVS-CIFAR10':
        return inputs.permute(1, 0, 2, 3, 4)
    if rate_coding:
        return spikegen.rate(inputs, num_steps=TIME)
    return inputs.repeat(TIME, 1, 1, 1, 1)


def _build_scheduler(scheduler_name, optimizer, steps_per_epoch, epoch_num):
    """Return the LR scheduler named `scheduler_name`, or None for 'no'.

    Raises:
        ValueError: for an unrecognised scheduler name (previously this only
            surfaced as a NameError at the end of the first epoch).
    """
    if scheduler_name == 'no':
        return None
    if scheduler_name == 'StepLR':
        return lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    if scheduler_name == 'ExponentialLR':
        return lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    if scheduler_name == 'ReduceLROnPlateau':
        return lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)
    if scheduler_name == 'CosineAnnealingLR':
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
    if scheduler_name == 'OneCycleLR':
        # OneCycleLR is a per-batch schedule: its total length must cover
        # every optimizer step of the run.
        return lr_scheduler.OneCycleLR(optimizer, max_lr=0.1,
                                       steps_per_epoch=steps_per_epoch, epochs=epoch_num)
    raise ValueError(f"unknown scheduler_name: {scheduler_name!r}")


def _evaluate(net, test_loader, criterion, device, which_data, rate_coding, TIME):
    """Run one full pass over `test_loader` in eval mode without gradients.

    Returns:
        (accuracy, mean_loss): accuracy in [0, 1] and the sample-weighted mean
        criterion loss, or (0.0, None) if the loader yielded no samples.
    """
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad():
        net.eval()
        for inputs, labels in test_loader:
            inputs = _encode_inputs(inputs, which_data, rate_coding, TIME).to(device)
            labels = labels.to(device)
            # The network expects batch-first input: [Batch, Time, C, H, W].
            outputs = net(inputs.permute(1, 0, 2, 3, 4))
            batch = labels.size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted[0:batch] == labels).sum().item()
            total += batch
            loss_sum += criterion(outputs[0:batch, :], labels).item() * batch
    if total == 0:
        return 0.0, None
    return correct / total, loss_sum / total


def _save_checkpoint(net, save_dir="net_save"):
    """Save the (possibly DataParallel-wrapped) network and its inner module."""
    os.makedirs(save_dir, exist_ok=True)
    # A bare (unwrapped) model has no `.module`; save it in both slots then.
    inner = getattr(net, "module", net)
    torch.save(net.state_dict(), os.path.join(save_dir, "save_now_net_weights.pth"))
    torch.save(net, os.path.join(save_dir, "save_now_net.pth"))
    torch.save(inner.state_dict(), os.path.join(save_dir, "save_now_net_weights2.pth"))
    torch.save(inner, os.path.join(save_dir, "save_now_net2.pth"))


def my_snn_system(devices = "0,1,2,3",
                    my_seed = 42,
                    TIME = 8,
                    BATCH = 256,
                    IMAGE_SIZE = 32,
                    which_data = 'CIFAR10',
                    CLASS_NUM = 10,
                    data_path = '/data2',
                    rate_coding = True,
    
                    lif_layer_v_init = 0.0,
                    lif_layer_v_decay = 0.6,
                    lif_layer_v_threshold = 1.2,
                    lif_layer_v_reset = 0.0,
                    lif_layer_sg_width = 1,

                    # synapse_conv_in_channels = IMAGE_PIXEL_CHANNEL,
                    synapse_conv_kernel_size = 3,
                    synapse_conv_stride = 1,
                    synapse_conv_padding = 1,
                    synapse_conv_trace_const1 = 1,
                    synapse_conv_trace_const2 = 0.6,

                    # synapse_fc_out_features = CLASS_NUM,
                    synapse_fc_trace_const1 = 1,
                    synapse_fc_trace_const2 = 0.6,

                    pre_trained = False,
                    convTrue_fcFalse = True,
                    cfg = [64, 64],
                    pre_trained_path = "net_save/save_now_net.pth",
                    learning_rate = 0.0001,
                    epoch_num = 200,
                    verbose_interval = 100, # a large value effectively disables it
                    validation_interval = 10, # a large value effectively disables it
                    tdBN_on = False,
                    BN_on = False,

                    surrogate = 'sigmoid',

                    gradient_verbose = False,

                    BPTT_on = False,

                    scheduler_name = 'no',
                    
                    ddp_on = True,
                  ):
    """Build, train and periodically validate a spiking neural network.

    Builds MY_SNN_CONV (convTrue_fcFalse=True) or MY_SNN_FC, wrapped in
    DataParallel, or loads a whole pickled model from `pre_trained_path`. It
    then trains with SGD (momentum 0.9) and cross-entropy loss. Every
    `validation_interval` iterations it evaluates on the test loader and, when
    validation accuracy improves, saves checkpoints under `net_save/`.

    Args:
        devices: value written to CUDA_VISIBLE_DEVICES (only effective if CUDA
            has not been initialised yet in this process).
        my_seed: seed for `random`, `numpy` and `torch`.
        TIME: number of simulation time steps fed to the network.
        BATCH, IMAGE_SIZE, which_data, data_path, rate_coding, ddp_on: forwarded
            to `data_loader`; `which_data`/`rate_coding` also select the input
            encoding (see `_encode_inputs`).
        CLASS_NUM: number of output classes of the final classifier.
        cfg: layer specification; ints are layer widths, nested lists are
            counted as consecutive layers, 'P' halves the spatial size in conv
            mode (presumably a pooling layer -- confirm in MY_SNN_CONV).
        scheduler_name: 'no', 'StepLR', 'ExponentialLR', 'ReduceLROnPlateau',
            'CosineAnnealingLR' or 'OneCycleLR'.
        verbose_interval / validation_interval: print / validate every N
            iterations; set them large to effectively disable.
        Remaining arguments are forwarded unchanged to the model constructor.

    Returns:
        None. Side effects: environment variables, console output, checkpoint files.

    Raises:
        ValueError: if `scheduler_name` is not one of the supported names.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = devices

    # Seed every RNG the pipeline may use (torch alone was seeded before).
    random.seed(my_seed)
    np.random.seed(my_seed)
    torch.manual_seed(my_seed)

    # data loader, pixel channel count, class count
    train_loader, test_loader, synapse_conv_in_channels = data_loader(
            which_data,
            data_path,
            rate_coding,
            BATCH,
            IMAGE_SIZE,
            ddp_on)
    synapse_fc_out_features = CLASS_NUM

    params_num = _estimate_param_count(cfg, synapse_conv_in_channels, IMAGE_SIZE,
                                       synapse_conv_kernel_size, synapse_fc_out_features,
                                       convTrue_fcFalse)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not pre_trained:
        if convTrue_fcFalse:
            net = MY_SNN_CONV(cfg, synapse_conv_in_channels, IMAGE_SIZE,
                     synapse_conv_kernel_size, synapse_conv_stride,
                     synapse_conv_padding, synapse_conv_trace_const1,
                     synapse_conv_trace_const2,
                     lif_layer_v_init, lif_layer_v_decay,
                     lif_layer_v_threshold, lif_layer_v_reset,
                     lif_layer_sg_width,
                     synapse_fc_out_features, synapse_fc_trace_const1, synapse_fc_trace_const2,
                     tdBN_on,
                     BN_on, TIME,
                     surrogate,
                     BPTT_on).to(device)
        else:
            net = MY_SNN_FC(cfg, synapse_conv_in_channels, IMAGE_SIZE, synapse_fc_out_features,
                     synapse_fc_trace_const1, synapse_fc_trace_const2,
                     lif_layer_v_init, lif_layer_v_decay,
                     lif_layer_v_threshold, lif_layer_v_reset,
                     lif_layer_sg_width,
                     tdBN_on,
                     BN_on, TIME,
                     surrogate,
                     BPTT_on).to(device)
        net = torch.nn.DataParallel(net)
    else:
        # NOTE: torch.load unpickles a whole model object -- only load trusted
        # checkpoints. Newer torch releases default to weights_only=True, which
        # rejects whole-model pickles like the ones _save_checkpoint writes.
        net = torch.load(pre_trained_path)

    net = net.to(device)

    ## param num and memory estimation ##########################################
    real_param_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    # The estimate ignores e.g. batch-norm parameters, so the two can differ.
    print('='*50)
    print(f"Num of PARAMS: {params_num:,}")
    print(f"Num of trainable PARAMS (actual): {real_param_num:,}")
    precision = 32
    memory = params_num * precision / 8 / 1024 / 1024  # bits -> bytes -> MiB
    print(f"Memory: {memory:.2f}MiB at {precision}-bit")
    print('='*50)
    ##########################################################################

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
    scheduler = _build_scheduler(scheduler_name, optimizer, len(train_loader), epoch_num)
    step_scheduler_per_batch = scheduler_name == 'OneCycleLR'

    # Only rank 0 shows a tqdm bar under DDP; a plain enumerate has no
    # set_description, so that call must be guarded by the same flag.
    show_progress = (not ddp_on) or torch.distributed.get_rank() == 0

    val_acc = 0
    val_acc_now = 0
    last_val_loss = None
    elapsed_time_val = 0
    for epoch in range(epoch_num):
        print('EPOCH', epoch)
        epoch_start_time = time.time()
        running_loss = 0.0

        iterator = enumerate(train_loader, 0)
        if show_progress:
            iterator = tqdm(iterator, total=len(train_loader), desc='train', dynamic_ncols=True, position=0, leave=True)

        for i, (inputs, labels) in iterator:
            iter_one_train_time_start = time.time()
            net.train()  # validation below switches to eval mode mid-epoch

            # inputs: [Time, Batch, Channel, Height, Width]
            inputs = _encode_inputs(inputs, which_data, rate_coding, TIME).to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # The network takes batch-first input ([Batch, Time, C, H, W]) so
            # DataParallel can split the batch along dim 0.
            outputs = net(inputs.permute(1, 0, 2, 3, 4))
            batch = labels.size(0)  # last batch may be smaller than BATCH

            ####### training accuracy print ###############################
            _, predicted = torch.max(outputs.detach(), 1)
            total = batch
            correct = (predicted[0:batch] == labels).sum().item()
            lr_strs = [f"{param_group['lr']}" for param_group in optimizer.param_groups]
            if i % verbose_interval == verbose_interval-1:
                print(f'{epoch}-{i} training acc: {100 * correct / total:.2f}%, lr={lr_strs}, val_acc: {100 * val_acc_now:.2f}%')
            training_acc_string = f'{epoch}-{i}/{len(train_loader)} tr_acc: {100 * correct / total:.2f}%, lr={lr_strs}'
            ################################################################

            loss = criterion(outputs[0:batch, :], labels)
            loss.backward()

            ### gradient verbose ##########################################
            if gradient_verbose and i % verbose_interval == verbose_interval-1:
                print('\n\nepoch', epoch, 'iter', i)
                for name, param in net.named_parameters():
                    if param.requires_grad:
                        print('\n\n\n\n' , name, param.grad)
            ################################################################

            optimizer.step()
            if scheduler is not None and step_scheduler_per_batch:
                scheduler.step()

            running_loss += loss.item()

            elapsed_time = time.time() - iter_one_train_time_start
            if i % verbose_interval == verbose_interval-1:
                print(f"iter_one_train_time: {elapsed_time} seconds, last one_val_time: {elapsed_time_val} seconds")

            ##### validation ##############################################
            if i % validation_interval == validation_interval-1:
                iter_one_val_time_start = time.time()
                val_acc_now, last_val_loss = _evaluate(net, test_loader, criterion, device,
                                                       which_data, rate_coding, TIME)
                elapsed_time_val = time.time() - iter_one_val_time_start

                # keep the best network seen so far
                if val_acc < val_acc_now:
                    val_acc = val_acc_now
                    _save_checkpoint(net)
            ################################################################
            if show_progress:
                iterator.set_description(f"train: {training_acc_string}, val_acc: {100 * val_acc_now:.2f}%")

        if scheduler is not None and not step_scheduler_per_batch:
            if scheduler_name == 'ReduceLROnPlateau':
                # Needs a validation loss; skip until one has been measured.
                if last_val_loss is not None:
                    scheduler.step(last_val_loss)
            else:
                scheduler.step()

        epoch_time = time.time() - epoch_start_time
        print(f"epoch_time: {epoch_time} seconds")
        print('\n')



In [5]:
### my_snn control board ########################
# One decay constant shared by the LIF membrane decay and both synaptic trace decays.
decay = 0.875

my_snn_system(  devices = "2,3,4,5",
                my_seed = 42,
                TIME = 10, # DVS-CIFAR10: 10
                BATCH = 32, # must be >= 2 when batch norm is enabled
                IMAGE_SIZE = 48, # DVS-CIFAR10: 48, MNIST: 28, CIFAR10: 32
                which_data = 'DVS-CIFAR10',# 'CIFAR10' 'MNIST' 'FASHION_MNIST' 'DVS-CIFAR10'
                CLASS_NUM = 10,
                data_path = '/data2', # YOU NEED TO CHANGE THIS
                rate_coding = False, # True # False

                lif_layer_v_init = 0.0,
                lif_layer_v_decay = decay,
                lif_layer_v_threshold = 1.2,
                lif_layer_v_reset = 0.0, # currently unused: reset is done by plain subtraction
                lif_layer_sg_width = 1.0, # has no effect with the sigmoid surrogate

                # synapse_conv_in_channels = IMAGE_PIXEL_CHANNEL,
                synapse_conv_kernel_size = 3,
                synapse_conv_stride = 1,
                synapse_conv_padding = 1,
                synapse_conv_trace_const1 = 1,
                synapse_conv_trace_const2 = decay, # lif_layer_v_decay

                # synapse_fc_out_features = CLASS_NUM,
                synapse_fc_trace_const1 = 1,
                synapse_fc_trace_const2 = decay, # lif_layer_v_decay

                pre_trained = False, # True # False
                convTrue_fcFalse = True, # True # False
                # cfg = [64],
                # cfg = [64,[64,64],64], # a linear classifier is appended automatically at the end
                cfg = [64, 128, 'P', 256, 256, 'P', 512, 512, 'P', 512, 512],
                pre_trained_path = "net_save/save_now_net.pth",
                learning_rate = 0.00001,
                epoch_num = 200,
                verbose_interval = 10000, # a large value effectively disables it
                validation_interval = 50, # a large value effectively disables it
                tdBN_on = False,  # True # False
                BN_on = True,  # True # False
                
                surrogate = 'sigmoid', # 'rectangle' 'sigmoid' 'rough_rectangle'
                
                gradient_verbose = False,  # True # False  # prints each layer's weight gradient
                
                BPTT_on = False,  # True # False # True: BPTT, False: OTTT

                scheduler_name = 'no', # 'no' 'StepLR' 'ExponentialLR' 'ReduceLROnPlateau' 'CosineAnnealingLR' 'OneCycleLR'
                
                ddp_on = False,
                )
# Open question: does training only work well with the sigmoid surrogate plus BN?

# cfg variants tried so far (kept as a comment so the cell does not echo it as output):
#   [64, 64]
#   [64, 64, 64, 64]
#   [64, [64, 64], 64]
#   [64, 128, 256]
#   [64, 128, 'P', 256, 256, 'P', 512, 512, 'P', 512, 512]
#   [64, 'P', 128, 'P', 256, 256, 'P', 512, 512, 512, 512]


Num of PARAMS: 9,301,834
Memory: 35.48MiB at 32-bit
EPOCH 0


train: 0-5/282 tr_acc: 18.75%, lr=['1e-05'], val_acc: 0.00%:   2%|▏         | 6/282 [00:15<11:57,  2.60s/it]


KeyboardInterrupt: 

In [None]:
# my_snn_system(  devices = "2,3,4,5",
#                 my_seed = 42,
#                 TIME = 8,
#                 BATCH = 128, # batch norm 할거면 2이상으로 해야함
#                 IMAGE_SIZE = 32,
#                 which_data = 'CIFAR10',# 'CIFAR10' 'MNIST' 'FASHION_MNIST' 'DVS-CIFAR10'
#                 CLASS_NUM = 10,
#                 data_path = '/data2', # YOU NEED TO CHANGE THIS
#                 rate_coding = False, # True # False

#                 lif_layer_v_init = 0.0,
#                 lif_layer_v_decay = decay,
#                 lif_layer_v_threshold = 1.2,
#                 lif_layer_v_reset = 0.0, #현재 안씀. 걍 빼기 해버림
#                 lif_layer_sg_width = 1.0, # surrogate sigmoid 쓸 때는 의미없음

#                 # synapse_conv_in_channels = IMAGE_PIXEL_CHANNEL,
#                 synapse_conv_kernel_size = 3,
#                 synapse_conv_stride = 1,
#                 synapse_conv_padding = 1,
#                 synapse_conv_trace_const1 = 1,
#                 synapse_conv_trace_const2 = decay, # lif_layer_v_decay

#                 # synapse_fc_out_features = CLASS_NUM,
#                 synapse_fc_trace_const1 = 1,
#                 synapse_fc_trace_const2 = decay, # lif_layer_v_decay

#                 pre_trained = False, # True # False
#                 convTrue_fcFalse = True, # True # False
#                 # cfg = [64],
#                 # cfg = [64,[64,64],64], # 끝에 linear classifier 하나 자동으로 붙습니다
#                 cfg = [64, 128, 'P', 256, 256, 'P', 512, 512, 'P', 512, 512],
#                 pre_trained_path = "net_save/save_now_net.pth",
#                 learning_rate = 0.00001,
#                 epoch_num = 200,
#                 verbose_interval = 10000, #숫자 크게 하면 꺼짐
#                 validation_interval = 50, #숫자 크게 하면 꺼짐
#                 tdBN_on = False,  # True # False
#                 BN_on = True,  # True # False
                
#                 surrogate = 'sigmoid', # 'rectangle' 'sigmoid' 'rough_rectangle'
                
#                 gradient_verbose = False,  # True # False  # weight gradient 각 layer마다 띄워줌

#                 BPTT_on = False,  # True # False

#                 scheduler_name = 'no', # 'no' 'StepLR' 'ExponentialLR' 'ReduceLROnPlateau' 'CosineAnnealingLR' 'OneCycleLR'
                
#                 ddp_on = False,
#                 )


# Files already downloaded and verified
# Files already downloaded and verified
# ==================================================
# Num of PARAMS: 9,302,410
# Memory: 35.49MiB at 32-bit
# ==================================================
# EPOCH 0
# train: 0-390/391 tr_acc: 27.50%, lr=['1e-05'], val_acc: 25.65%: 100%|██████████| 391/391 [08:39<00:00,  1.33s/it]
# epoch_time: 519.6511018276215 seconds


# EPOCH 1

# train: 1-390/391 tr_acc: 32.50%, lr=['1e-05'], val_acc: 32.00%: 100%|██████████| 391/391 [08:40<00:00,  1.33s/it]
# epoch_time: 520.4912984371185 seconds


# EPOCH 2

# train: 2-390/391 tr_acc: 33.75%, lr=['1e-05'], val_acc: 35.14%: 100%|██████████| 391/391 [08:33<00:00,  1.31s/it]
# epoch_time: 513.511162519455 seconds


# EPOCH 3

# train: 3-390/391 tr_acc: 28.75%, lr=['1e-05'], val_acc: 37.17%: 100%|██████████| 391/391 [08:30<00:00,  1.30s/it]
# epoch_time: 510.2808663845062 seconds


# EPOCH 4

# train: 4-390/391 tr_acc: 58.75%, lr=['1e-05'], val_acc: 40.37%: 100%|██████████| 391/391 [08:21<00:00,  1.28s/it]
# epoch_time: 501.48126435279846 seconds


# EPOCH 5

# train: 5-390/391 tr_acc: 37.50%, lr=['1e-05'], val_acc: 42.59%: 100%|██████████| 391/391 [08:18<00:00,  1.28s/it]
# epoch_time: 498.8543312549591 seconds


# EPOCH 6

# train: 6-390/391 tr_acc: 35.00%, lr=['1e-05'], val_acc: 43.45%: 100%|██████████| 391/391 [08:15<00:00,  1.27s/it]
# epoch_time: 495.6830530166626 seconds


# EPOCH 7

# train: 7-390/391 tr_acc: 48.75%, lr=['1e-05'], val_acc: 44.97%: 100%|██████████| 391/391 [08:16<00:00,  1.27s/it]
# epoch_time: 496.7004475593567 seconds


# EPOCH 8

# train: 8-390/391 tr_acc: 47.50%, lr=['1e-05'], val_acc: 46.19%: 100%|██████████| 391/391 [08:16<00:00,  1.27s/it]
# epoch_time: 496.552987575531 seconds


# EPOCH 9

# train: 9-390/391 tr_acc: 46.25%, lr=['1e-05'], val_acc: 46.52%: 100%|██████████| 391/391 [08:37<00:00,  1.32s/it]
# epoch_time: 517.5774216651917 seconds


# EPOCH 10

# train: 10-390/391 tr_acc: 48.75%, lr=['1e-05'], val_acc: 48.67%: 100%|██████████| 391/391 [08:57<00:00,  1.37s/it]
# epoch_time: 537.5675616264343 seconds


# EPOCH 11

# train: 11-390/391 tr_acc: 48.75%, lr=['1e-05'], val_acc: 48.40%: 100%|██████████| 391/391 [08:53<00:00,  1.37s/it]
# epoch_time: 533.9970047473907 seconds


# EPOCH 12

# train: 12-390/391 tr_acc: 43.75%, lr=['1e-05'], val_acc: 50.58%: 100%|██████████| 391/391 [08:57<00:00,  1.38s/it]
# epoch_time: 537.9225826263428 seconds


# EPOCH 13

# train: 13-390/391 tr_acc: 42.50%, lr=['1e-05'], val_acc: 50.98%: 100%|██████████| 391/391 [08:58<00:00,  1.38s/it]
# epoch_time: 538.5700080394745 seconds


# EPOCH 14

# train: 14-390/391 tr_acc: 48.75%, lr=['1e-05'], val_acc: 52.68%: 100%|██████████| 391/391 [08:59<00:00,  1.38s/it]
# epoch_time: 539.7910151481628 seconds


# EPOCH 15

# train: 15-390/391 tr_acc: 55.00%, lr=['1e-05'], val_acc: 54.05%: 100%|██████████| 391/391 [08:52<00:00,  1.36s/it]
# epoch_time: 532.8848164081573 seconds


# EPOCH 16

# train: 16-390/391 tr_acc: 52.50%, lr=['1e-05'], val_acc: 53.80%: 100%|██████████| 391/391 [08:58<00:00,  1.38s/it]
# epoch_time: 538.2881193161011 seconds


# EPOCH 17

# train: 17-390/391 tr_acc: 57.50%, lr=['1e-05'], val_acc: 54.73%: 100%|██████████| 391/391 [09:00<00:00,  1.38s/it]
# epoch_time: 540.2721989154816 seconds


# EPOCH 18

# train: 18-390/391 tr_acc: 60.00%, lr=['1e-05'], val_acc: 55.54%: 100%|██████████| 391/391 [09:00<00:00,  1.38s/it]
# epoch_time: 540.7067131996155 seconds


# EPOCH 19

# train: 19-390/391 tr_acc: 61.25%, lr=['1e-05'], val_acc: 56.20%: 100%|██████████| 391/391 [08:53<00:00,  1.36s/it]
# epoch_time: 533.882045507431 seconds


# EPOCH 20

# train: 20-390/391 tr_acc: 61.25%, lr=['1e-05'], val_acc: 57.64%: 100%|██████████| 391/391 [08:55<00:00,  1.37s/it]
# epoch_time: 535.4623730182648 seconds


# EPOCH 21

# train: 21-390/391 tr_acc: 57.50%, lr=['1e-05'], val_acc: 57.87%: 100%|██████████| 391/391 [08:57<00:00,  1.37s/it]
# epoch_time: 537.7847683429718 seconds


# EPOCH 22

# train: 22-390/391 tr_acc: 57.50%, lr=['1e-05'], val_acc: 58.46%: 100%|██████████| 391/391 [08:55<00:00,  1.37s/it]
# epoch_time: 535.7902438640594 seconds


# EPOCH 23

# train: 23-390/391 tr_acc: 57.50%, lr=['1e-05'], val_acc: 57.95%: 100%|██████████| 391/391 [08:57<00:00,  1.37s/it]
# epoch_time: 537.5092451572418 seconds


# EPOCH 24

# train: 24-390/391 tr_acc: 55.00%, lr=['1e-05'], val_acc: 57.98%: 100%|██████████| 391/391 [08:44<00:00,  1.34s/it]
# epoch_time: 524.3707118034363 seconds


# EPOCH 25

# train: 25-390/391 tr_acc: 50.00%, lr=['1e-05'], val_acc: 60.33%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 521.3296625614166 seconds


# EPOCH 26

# train: 26-390/391 tr_acc: 52.50%, lr=['1e-05'], val_acc: 60.02%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 521.8723931312561 seconds


# EPOCH 27

# train: 27-390/391 tr_acc: 60.00%, lr=['1e-05'], val_acc: 60.80%: 100%|██████████| 391/391 [08:42<00:00,  1.34s/it]
# epoch_time: 522.6468982696533 seconds


# EPOCH 28

# train: 28-390/391 tr_acc: 56.25%, lr=['1e-05'], val_acc: 60.99%: 100%|██████████| 391/391 [08:43<00:00,  1.34s/it]
# epoch_time: 524.2045466899872 seconds


# EPOCH 29

# train: 29-390/391 tr_acc: 66.25%, lr=['1e-05'], val_acc: 62.56%: 100%|██████████| 391/391 [08:42<00:00,  1.34s/it]
# epoch_time: 523.102609872818 seconds


# EPOCH 30

# train: 30-390/391 tr_acc: 65.00%, lr=['1e-05'], val_acc: 62.23%: 100%|██████████| 391/391 [08:43<00:00,  1.34s/it]
# epoch_time: 523.3475232124329 seconds


# EPOCH 31

# train: 31-390/391 tr_acc: 70.00%, lr=['1e-05'], val_acc: 61.05%: 100%|██████████| 391/391 [08:42<00:00,  1.34s/it]
# epoch_time: 522.611293554306 seconds


# EPOCH 32

# train: 32-390/391 tr_acc: 57.50%, lr=['1e-05'], val_acc: 63.73%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 521.550628900528 seconds


# EPOCH 33

# train: 33-390/391 tr_acc: 67.50%, lr=['1e-05'], val_acc: 63.01%: 100%|██████████| 391/391 [08:40<00:00,  1.33s/it]
# epoch_time: 520.7382657527924 seconds


# EPOCH 34

# train: 34-390/391 tr_acc: 66.25%, lr=['1e-05'], val_acc: 63.66%: 100%|██████████| 391/391 [08:42<00:00,  1.34s/it]
# epoch_time: 523.2350842952728 seconds


# EPOCH 35

# train: 35-390/391 tr_acc: 65.00%, lr=['1e-05'], val_acc: 64.20%: 100%|██████████| 391/391 [08:39<00:00,  1.33s/it]
# epoch_time: 519.9772260189056 seconds


# EPOCH 36

# train: 36-390/391 tr_acc: 61.25%, lr=['1e-05'], val_acc: 64.67%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 522.0173165798187 seconds


# EPOCH 37

# train: 37-390/391 tr_acc: 62.50%, lr=['1e-05'], val_acc: 64.44%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 522.151460647583 seconds


# EPOCH 38

# train: 38-390/391 tr_acc: 62.50%, lr=['1e-05'], val_acc: 64.77%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 521.8787703514099 seconds


# EPOCH 39

# train: 39-390/391 tr_acc: 66.25%, lr=['1e-05'], val_acc: 65.22%: 100%|██████████| 391/391 [08:42<00:00,  1.34s/it]
# epoch_time: 522.234944820404 seconds


# EPOCH 40

# train: 40-390/391 tr_acc: 68.75%, lr=['1e-05'], val_acc: 64.49%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 521.9287745952606 seconds


# EPOCH 41

# train: 41-390/391 tr_acc: 66.25%, lr=['1e-05'], val_acc: 66.06%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 521.8189256191254 seconds


# EPOCH 42

# train: 42-390/391 tr_acc: 76.25%, lr=['1e-05'], val_acc: 65.47%: 100%|██████████| 391/391 [08:43<00:00,  1.34s/it]
# epoch_time: 523.2214741706848 seconds


# EPOCH 43

# train: 43-390/391 tr_acc: 78.75%, lr=['1e-05'], val_acc: 66.97%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 521.9151771068573 seconds


# EPOCH 44

# train: 44-390/391 tr_acc: 72.50%, lr=['1e-05'], val_acc: 66.76%: 100%|██████████| 391/391 [08:40<00:00,  1.33s/it]
# epoch_time: 520.2902402877808 seconds


# EPOCH 45

# train: 45-390/391 tr_acc: 67.50%, lr=['1e-05'], val_acc: 66.30%: 100%|██████████| 391/391 [08:40<00:00,  1.33s/it]
# epoch_time: 520.3684139251709 seconds


# EPOCH 46

# train: 46-390/391 tr_acc: 76.25%, lr=['1e-05'], val_acc: 66.27%: 100%|██████████| 391/391 [08:42<00:00,  1.34s/it]
# epoch_time: 522.4130618572235 seconds


# EPOCH 47

# train: 47-390/391 tr_acc: 67.50%, lr=['1e-05'], val_acc: 66.20%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 521.4402587413788 seconds


# EPOCH 48

# train: 48-390/391 tr_acc: 67.50%, lr=['1e-05'], val_acc: 67.45%: 100%|██████████| 391/391 [08:40<00:00,  1.33s/it]
# epoch_time: 521.1478617191315 seconds


# EPOCH 49

# train: 49-390/391 tr_acc: 77.50%, lr=['1e-05'], val_acc: 68.14%: 100%|██████████| 391/391 [08:40<00:00,  1.33s/it]
# epoch_time: 520.3577303886414 seconds


# EPOCH 50

# train: 50-390/391 tr_acc: 63.75%, lr=['1e-05'], val_acc: 67.95%: 100%|██████████| 391/391 [08:42<00:00,  1.34s/it]
# epoch_time: 522.6194558143616 seconds


# EPOCH 51

# train: 51-390/391 tr_acc: 75.00%, lr=['1e-05'], val_acc: 68.41%: 100%|██████████| 391/391 [08:40<00:00,  1.33s/it]
# epoch_time: 520.8790509700775 seconds


# EPOCH 52

# train: 52-390/391 tr_acc: 66.25%, lr=['1e-05'], val_acc: 67.84%: 100%|██████████| 391/391 [08:40<00:00,  1.33s/it]
# epoch_time: 521.1223595142365 seconds


# EPOCH 53

# train: 53-390/391 tr_acc: 68.75%, lr=['1e-05'], val_acc: 68.00%: 100%|██████████| 391/391 [08:42<00:00,  1.34s/it]
# epoch_time: 522.3871352672577 seconds


# EPOCH 54

# train: 54-390/391 tr_acc: 73.75%, lr=['1e-05'], val_acc: 68.75%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 521.407853603363 seconds


# EPOCH 55

# train: 55-390/391 tr_acc: 71.25%, lr=['1e-05'], val_acc: 69.21%: 100%|██████████| 391/391 [08:42<00:00,  1.34s/it]
# epoch_time: 522.9369058609009 seconds


# EPOCH 56

# train: 56-390/391 tr_acc: 67.50%, lr=['1e-05'], val_acc: 69.28%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 521.430141210556 seconds


# EPOCH 57

# train: 57-390/391 tr_acc: 65.00%, lr=['1e-05'], val_acc: 69.16%: 100%|██████████| 391/391 [08:43<00:00,  1.34s/it]
# epoch_time: 523.3304183483124 seconds


# EPOCH 58

# train: 58-390/391 tr_acc: 67.50%, lr=['1e-05'], val_acc: 68.30%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 521.6461317539215 seconds


# EPOCH 59

# train: 59-390/391 tr_acc: 70.00%, lr=['1e-05'], val_acc: 69.50%: 100%|██████████| 391/391 [08:39<00:00,  1.33s/it]
# epoch_time: 520.0395126342773 seconds


# EPOCH 60

# train: 60-390/391 tr_acc: 68.75%, lr=['1e-05'], val_acc: 69.68%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 521.3599922657013 seconds


# EPOCH 61

# train: 61-390/391 tr_acc: 63.75%, lr=['1e-05'], val_acc: 70.64%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 521.446202993393 seconds


# EPOCH 62

# train: 62-390/391 tr_acc: 75.00%, lr=['1e-05'], val_acc: 70.57%: 100%|██████████| 391/391 [08:42<00:00,  1.34s/it]
# epoch_time: 522.744499206543 seconds


# EPOCH 63

# train: 63-390/391 tr_acc: 65.00%, lr=['1e-05'], val_acc: 70.47%: 100%|██████████| 391/391 [08:39<00:00,  1.33s/it]
# epoch_time: 519.3143074512482 seconds


# EPOCH 64

# train: 64-390/391 tr_acc: 63.75%, lr=['1e-05'], val_acc: 70.40%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 521.6692788600922 seconds


# EPOCH 65

# train: 65-390/391 tr_acc: 80.00%, lr=['1e-05'], val_acc: 70.64%: 100%|██████████| 391/391 [08:42<00:00,  1.34s/it]
# epoch_time: 522.667236328125 seconds


# EPOCH 66

# train: 66-390/391 tr_acc: 72.50%, lr=['1e-05'], val_acc: 69.29%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 522.0747861862183 seconds


# EPOCH 67

# train: 67-390/391 tr_acc: 73.75%, lr=['1e-05'], val_acc: 70.66%: 100%|██████████| 391/391 [08:42<00:00,  1.34s/it]
# epoch_time: 522.6984555721283 seconds


# EPOCH 68

# train: 68-390/391 tr_acc: 71.25%, lr=['1e-05'], val_acc: 70.71%: 100%|██████████| 391/391 [08:42<00:00,  1.34s/it]
# epoch_time: 523.1850342750549 seconds


# EPOCH 69

# train: 69-390/391 tr_acc: 86.25%, lr=['1e-05'], val_acc: 70.95%: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]
# epoch_time: 522.0473079681396 seconds


# EPOCH 70

# train: 70-390/391 tr_acc: 70.00%, lr=['1e-05'], val_acc: 71.07%: 100%|██████████| 391/391 [08:44<00:00,  1.34s/it]
# epoch_time: 524.3250975608826 seconds


# EPOCH 71

# train: 71-390/391 tr_acc: 71.25%, lr=['1e-05'], val_acc: 71.03%: 100%|██████████| 391/391 [08:43<00:00,  1.34s/it]
# epoch_time: 524.1436457633972 seconds


# EPOCH 72

# train: 72-390/391 tr_acc: 72.50%, lr=['1e-05'], val_acc: 70.97%: 100%|██████████| 391/391 [08:42<00:00,  1.34s/it]
# epoch_time: 522.3500926494598 seconds


# EPOCH 73

# train: 73-390/391 tr_acc: 81.25%, lr=['1e-05'], val_acc: 69.71%: 100%|██████████| 391/391 [08:43<00:00,  1.34s/it]
# epoch_time: 523.6104423999786 seconds


# EPOCH 74

# train: 74-390/391 tr_acc: 72.50%, lr=['1e-05'], val_acc: 71.35%: 100%|██████████| 391/391 [08:42<00:00,  1.34s/it]
# epoch_time: 523.1159906387329 seconds


# EPOCH 75

# train: 75-390/391 tr_acc: 70.00%, lr=['1e-05'], val_acc: 72.26%: 100%|██████████| 391/391 [08:43<00:00,  1.34s/it]
# epoch_time: 523.9588918685913 seconds


# EPOCH 76

# train: 76-105/391 tr_acc: 75.78%, lr=['1e-05'], val_acc: 71.69%:  27%|██▋       | 106/391 [02:24<06:28,  1.36s/it]