# 06 => CIFAR-10. 30 에폭 학습시켜보기. val_accuracy의 변화 추이 알아보기


=> 이후에는 CIFAR-10, CIFAR-100, ImageNet dataset에 대해 validation accuracy가 더 이상 향상되지 않을 때까지 몇 epoch 학습시켜야 하며, 시간이 얼마나 소요되는지 실험을 진행해야 함.

[Reference] https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

# Main

In [1]:
import sys
import os
import logging
from easydict import EasyDict
import numpy as np
import random

import time
import datetime

from deap import tools

In [2]:
# Put the parent directory first on the import path so that the
# `utils_kyy` package imported below can be resolved from this notebook.
sys.path.insert(0,'../')
from utils_kyy.utils_graph import make_random_graph
from utils_kyy.create_toolbox import create_toolbox_for_NSGA_RWNN

## 1. generation pool 구성하기 (Small RWNN 대상)

In [3]:
# Environment setup for this experiment: every output location is derived
# from `run_code`, so changing it gives a fresh graph pool and log directory.
run_code = 'test_kyy_CIFAR10_time_check'
stage_pool_path = '../graph_pool' + '/' + run_code + '/'
log_path = '../logs/' + run_code + '/'

# Create both output directories. `exist_ok=True` makes re-running the cell
# idempotent and removes the check-then-create race of a separate
# exists()/isdir() test followed by makedirs().
os.makedirs(stage_pool_path, exist_ok=True)
os.makedirs(log_path, exist_ok=True)

# Send INFO-level (and above) log records to a file inside `log_path`.
log_file_name = log_path + 'logging.log'
logging.basicConfig(filename=log_file_name, level=logging.INFO)
logging.info('Start to write log.')

In [4]:
# Build the random graph pool: `num_graph` graphs are requested and
# `stage_pool_path` is passed as the destination.
num_graph = 100
make_random_graph(num_graph, stage_pool_path)

######################################################
# => Eventually, provide a function that takes num_graph and stage_pool_path
#    as arguments and drops num_graph graphs into that path.
#    For now the graph model is fixed to 'WS'; K and P are not taken as arguments.
#      =>  To be extended later.
######################################################

Start to make random graph pool...
Finished


## 2. Main NSGA_RWNN

In [5]:
# Training settings handed to create_toolbox_for_NSGA_RWNN, which uses them
# to define the toolbox's 'evaluate' function.
# NOTE(review): the meaning of each key is decided by the training code in
# utils_kyy, which is not visible here — confirm there before changing values.
args_train = EasyDict({
    'lr_mode': 'cosine',
    'warmup_mode': 'linear',    # default
    'base_lr': 0.1,
    'momentum': 0.9, 
    'weight_decay': 0.00005,
    'print_freq': 100,  # the recorded output shows a progress line every 100 iterations

    # NOTE(review): the notebook title mentions training for 30 epochs, but
    # this is set to 1 — confirm which value is intended for this run.
    'epochs': 1,
    'batch_size': 32,   # earlier note: 128 => 256 (presumably values tried before — TODO confirm)

    'workers': 4,  # earlier note: 2 => (presumably the previous value — TODO confirm)

    'warmup_epochs': 0,
    'warmup_lr': 0.0,
    'targetlr': 0.0,

})

In [6]:
# Create the custom toolbox.
# num_graph, args_train and stage_pool_path are passed in so that the
# toolbox can define its 'evaluate' function.

# NOTE(review): hard-coded, machine-specific dataset location — adjust it
# for the environment the notebook is run in.
data_path = 'D:/data/cifar10/'
toolbox = create_toolbox_for_NSGA_RWNN(num_graph, args_train, stage_pool_path, data_path ,log_file_name)


In [7]:
"""
4. Algorithms
 For the purpose of completeness we will develop the complete generational algorithm.
"""

# Evolution settings.
# Generation 0 is the initial population, and the generational loop iterates
# over range(1, NGEN); with NGEN = 1 only the initial evaluation is run.
POP_SIZE, NGEN = 4, 1        # population size, number of generations
CXPB, MUTPB = 0.5, 0.5       # crossover probability, mutation probability

# Statistics written to the logbook: per-objective minimum and maximum of the
# fitness values (axis=0 reduces across individuals, one value per objective).
stats = tools.Statistics(lambda individual: individual.fitness.values)
for stat_name, reducer in (("min", np.min), ("max", np.max)):
    stats.register(stat_name, reducer, axis=0)

# Logbook with the column order used when its stream is printed.
logbook = tools.Logbook()
logbook.header = ("gen", "evals", "min", "max", "evals_time", "gen_time")

# Create the initial population. toolbox.population returns a list holding
# n creator.Individual objects (i.e. the population).
now_str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print("Initialization starts ...")
logging.info("Initialization starts at %s", now_str)
init_start_time = time.time()

pop = toolbox.population(n=POP_SIZE)

# Evaluate the individuals with an invalid fitness.
# toolbox.evaluate returns a tuple, so `fitnesses` yields one tuple per individual.
invalid_ind = [ind for ind in pop if not ind.fitness.valid]
fitnesses = toolbox.map(toolbox.evaluate, invalid_ind)
for ind, fit in zip(invalid_ind, fitnesses):
    ind.fitness.values = fit   # (val_accuracy, flops) tuple

# This is just to assign the crowding distance to the individuals;
# no actual selection is done because all len(pop) individuals are kept.
pop = toolbox.select(pop, len(pop))

# Record generation 0. It carries no 'evals_time' / 'gen_time' entries.
record = stats.compile(pop)
logbook.record(gen=0, evals=len(invalid_ind), **record)
print(logbook.stream)

now_str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print("Initialization is finished at", now_str)
logging.info("Initialization is finished at %s", now_str)

# Wall-clock duration of the whole initialization phase, in seconds.
init_time = time.time() - init_start_time
logging.info("Initialization time = %ss", init_time)


print()

# Begin the generational process.
# Generation 0 is the initial population evaluated before this loop, so the
# loop body runs NGEN - 1 times (and not at all when NGEN == 1).
for gen in range(1, NGEN):
    now_str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("#####", gen, "th generation starts at", now_str)
    logging.info("%sth generation starts at %s", gen, now_str)

    start_gen = time.time()

    # Vary the population: pick parents with DEAP's dominance / crowding-distance
    # tournament, then clone them so the current population is left untouched.
    # NOTE(review): DEAP documents a multiple-of-4 length requirement for
    # selTournamentDCD — keep POP_SIZE a multiple of 4.
    offspring = tools.selTournamentDCD(pop, len(pop))
    offspring = [toolbox.clone(ind) for ind in offspring]

    # Walk the offspring in consecutive pairs: mate with probability CXPB,
    # always mutate, then invalidate both fitnesses so they are re-evaluated.
    for ind1, ind2 in zip(offspring[::2], offspring[1::2]):
        if random.random() <= CXPB:
            toolbox.mate(ind1, ind2)

        toolbox.mutate(ind1, indpb=MUTPB)
        toolbox.mutate(ind2, indpb=MUTPB)
        del ind1.fitness.values, ind2.fitness.values

    # Evaluate the individuals with an invalid fitness.
    print("##### Evaluation starts")
    start_time = time.time()

    invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
    fitnesses = toolbox.map(toolbox.evaluate, invalid_ind)
    for ind, fit in zip(invalid_ind, fitnesses):
        ind.fitness.values = fit

    eval_time_for_one_generation = time.time() - start_time
    print("##### Evaluation ends (Time : %.3f)" % eval_time_for_one_generation)

    # Select the next generation from parents and offspring combined.
    pop = toolbox.select(pop + offspring, POP_SIZE)

    gen_time = time.time() - start_gen
    print('##### [gen_time: %.3fs]' % gen_time, gen, 'th generation is finished.')

    # Record this generation's statistics together with its timings.
    record = stats.compile(pop)
    logbook.record(gen=gen, evals=len(invalid_ind), **record,
                   evals_time=eval_time_for_one_generation, gen_time=gen_time)

    logging.info('Gen [%03d/%03d] -- evals: %03d, evals_time: %.4fs, gen_time: %.4fs',
                 gen, NGEN, len(invalid_ind), eval_time_for_one_generation, gen_time)
    print(logbook.stream)

Initialion starts ...
Files already downloaded and verified
Files already downloaded and verified
	 - Epoch: [0][0/1563]	Time 7.287 (7.287)	Loss 6.9324 (6.9324)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
	 - Epoch: [0][100/1563]	Time 0.625 (0.649)	Loss 2.2371 (2.5518)	Prec@1 15.625 (15.625)	Prec@5 78.125 (61.015)
	 - Epoch: [0][200/1563]	Time 0.595 (0.617)	Loss 1.8860 (2.3488)	Prec@1 25.000 (17.864)	Prec@5 93.750 (68.097)
	 - Epoch: [0][300/1563]	Time 0.564 (0.605)	Loss 2.1227 (2.2472)	Prec@1 18.750 (19.923)	Prec@5 75.000 (71.865)
	 - Epoch: [0][400/1563]	Time 0.562 (0.601)	Loss 2.0643 (2.1770)	Prec@1 21.875 (21.766)	Prec@5 81.250 (74.213)
	 - Epoch: [0][500/1563]	Time 0.579 (0.598)	Loss 1.8191 (2.1159)	Prec@1 40.625 (23.646)	Prec@5 84.375 (75.979)
	 - Epoch: [0][600/1563]	Time 0.565 (0.595)	Loss 1.6630 (2.0713)	Prec@1 31.250 (24.808)	Prec@5 90.625 (77.262)
	 - Epoch: [0][700/1563]	Time 0.594 (0.594)	Loss 1.8579 (2.0295)	Prec@1 28.125 (26.092)	Prec@5 81.250 (78.504)
	 - Epoch: [0][800/1

In [8]:
# Display the whole logbook (one record is added per logbook.record call).
logbook

[{'gen': 0,
  'evals': 4,
  'min': array([-5.05900000e+01,  7.92123008e+08]),
  'max': array([-4.77900000e+01,  9.94054464e+08])}]

### logbook - plot

In [9]:
# Inspect the logbook's concrete type.
type(logbook)

deap.tools.support.Logbook

In [10]:
# Number of records stored in the logbook.
len(logbook)

1

In [11]:
# First record: the one written for generation 0 (initial population).
logbook[0]

{'gen': 0,
 'evals': 4,
 'min': array([-5.05900000e+01,  7.92123008e+08]),
 'max': array([-4.77900000e+01,  9.94054464e+08])}

In [12]:
# Per-objective minimum of the fitness values recorded for generation 0.
logbook[0]['min']

array([-5.05900000e+01,  7.92123008e+08])

In [13]:
# Split the 'min' entry into its two components, flipping the sign of the first.
# NOTE(review): the first component shows up negative in the record —
# presumably the accuracy is stored negated so it can be minimized; confirm
# against the evaluate function.
-logbook[0]['min'][0], logbook[0]['min'][1]

(50.59, 792123008.0)

In [14]:
# Per-generation extremes of the two fitness components, read from the logbook.
# The first component is sign-flipped on read; the second is taken as is.
#
# NOTE(review): because of that sign flip, `min_val_acc` (built from the 'min'
# statistic) is numerically the LARGER accuracy and `max_val_acc` the smaller
# one — e.g. 50.59 vs 47.79 for generation 0 in the recorded output. The names
# are kept because later cells refer to them.
min_val_acc = [-entry['min'][0] for entry in logbook]
min_flops = [entry['min'][1] for entry in logbook]

max_val_acc = [-entry['max'][0] for entry in logbook]
max_flops = [entry['max'][1] for entry in logbook]

In [15]:
from matplotlib import pyplot as plt
%matplotlib inline

ModuleNotFoundError: No module named 'matplotlib'

In [17]:
# Author's note: this plot shows that NSGA-2 is working properly.
# Plots min_flops against min_val_acc, one point per logbook record.
# NOTE(review): needs `plt` from the matplotlib import cell; in the recorded
# run that import failed, so this cell raised NameError.
plt.plot(min_val_acc, min_flops)

plt.xlabel('min_val_acc')
plt.ylabel('min_flops')
plt.title('Experiment Result')

plt.show()

NameError: name 'plt' is not defined

In [None]:
# Collect the wall-clock time of each generation.
# The first record (initialization, generation 0) is written without a
# 'gen_time' entry, so it is skipped.
gen_time_list = [entry['gen_time'] for entry in logbook[1:]]

In [None]:
# Plot the per-generation wall-clock time; the x axis is the index into
# gen_time_list (index 0 corresponds to generation 1, since generation 0
# has no gen_time).
plt.plot(gen_time_list)

plt.xlabel('generation')
plt.ylabel('gen_time_list')
plt.title('Experiment Result')

plt.show()