# CIFAR-100. val-accuracy 줄어들지 않을 때까지 학습시켜보기

[Reference] https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

# Main

In [1]:
import os
import logging
from easydict import EasyDict
import numpy as np
import random

import time
import datetime

from deap import tools

In [2]:
from utils_kyy.utils_graph import make_random_graph
from utils_kyy.create_toolbox_cifar100 import create_toolbox_for_NSGA_RWNN

## 1. generation pool 구성하기 (Small RWNN 대상)

In [3]:
# 실험을 위한 환경 셋팅
run_code = 'test_kyy_CIFAR100_time_check'
stage_pool_path = './graph_pool' + '/' + run_code + '/'
log_path = './logs/' + run_code + '/'

# check & make the directory
if not os.path.exists(stage_pool_path): os.makedirs(stage_pool_path)
if not os.path.isdir(log_path): os.makedirs(log_path)

# write the log
log_file_name = log_path + 'logging.log'
logging.basicConfig(filename=log_file_name, level=logging.INFO)
logging.info('Start to write log.')

In [4]:
# make random graph pool
num_graph = 100
make_random_graph(num_graph, stage_pool_path)

######################################################
# => 최종적으로, num_graph와 stage_pool_path 를 인수로 받아서, 해당 path에 num_graph 수 만큼의 그래프 떨궈주는 함수 만들기
#    일단은 정해진 graph_model은 'WS', K, P 는 인수로 받지 말고 구현
#      =>  이후에 확장하기.
######################################################

Start to make random graph pool...
Finished


## 2. Main NSGA_RWNN

In [5]:
# define 'args_train' for evaluation
args_train = EasyDict({
    'lr_mode': 'cosine',
    'warmup_mode': 'linear',    # default
    'base_lr': 0.1,
    'momentum': 0.9, 
    'weight_decay': 0.00005,
    'print_freq': 100,

    'epochs': -1,    # -1: val_accuracy가 증가하는 한 계속. 즉, 3 epoch 기다려보고 증가하지 않으면 멈추기.
    'batch_size': 256,   # 128 => 256

    'workers': 32,  # 2 => 

    'warmup_epochs': 0,
    'warmup_lr': 0.0,
    'targetlr': 0.0,

})

In [6]:
# create custom_toolbox
# num_graph, args_train, stage_pool_path => to define the 'evaluate' function 
toolbox = create_toolbox_for_NSGA_RWNN(num_graph, args_train, stage_pool_path, log_file_name)

In [7]:
"""
4. Algorithms
 For the purpose of completeness we will develop the complete generational algorithm.
"""

POP_SIZE = 1    # population size
NGEN = 1    # number of Generation
CXPB = 0   # crossover probability 
MUTPB = 0    # mutation probability


# log에 기록할 stats
stats = tools.Statistics(lambda ind: ind.fitness.values)
stats.register("min", np.min, axis=0)
stats.register("max", np.max, axis=0)

logbook = tools.Logbook()
logbook.header = "gen", "evals", "min", "max", "evals_time", "gen_time"

# population 생성.  (toolbox.population은 creator.Individual n개를 담은 list를 반환. (=> population)
now = datetime.datetime.now()
now_str = now.strftime('%Y-%m-%d %H:%M:%S')
print("Initialion starts ...")
logging.info("Initialion starts at " + now_str)
init_start_time = time.time()

pop = toolbox.population(n=POP_SIZE)

# Evaluate the individuals with an invalid fitness
invalid_ind = [ind for ind in pop if not ind.fitness.valid]
fitnesses = toolbox.map(toolbox.evaluate, invalid_ind)    # .evaluate는 tuple을 반환. 따라서 fitnesses는 튜플을 원소로 가지는 list
for ind, fit in zip(invalid_ind, fitnesses):
    ind.fitness.values = fit   # ind.fitness.values = (val_accuracy, flops) 튜플

# This is just to assign the crowding distance to the individuals
# no actual selection is done
pop = toolbox.select(pop, len(pop))

record = stats.compile(pop)
logbook.record(gen=0, evals=len(invalid_ind), **record)
print(logbook.stream)

now = datetime.datetime.now()
now_str = now.strftime('%Y-%m-%d %H:%M:%S')
print("Initialization is finished at", now_str)
logging.info("Initialion is finished at " + now_str)

init_time = time.time() - init_start_time
logging.info("Initialion time = " + str(init_time) + "s")


print()

# Begin the generational process
for gen in range(1, NGEN):
    now = datetime.datetime.now()
    now_str = now.strftime('%Y-%m-%d %H:%M:%S')
    print("#####", gen, "th generation starts at", now_str)
    logging.info(str(gen) + "th generation starts at" + now_str)
    
    start_gen = time.time()
    # Vary the population
    offspring = tools.selTournamentDCD(pop, len(pop))
    offspring = [toolbox.clone(ind) for ind in offspring]

    for ind1, ind2 in zip(offspring[::2], offspring[1::2]):
        if random.random() <= CXPB:
            toolbox.mate(ind1, ind2)

        toolbox.mutate(ind1, indpb=MUTPB)
        toolbox.mutate(ind2, indpb=MUTPB)
        del ind1.fitness.values, ind2.fitness.values

    # Evaluate the individuals with an invalid fitness
    print("##### Evaluation starts")
    start_time = time.time()
    
    invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
    fitnesses = toolbox.map(toolbox.evaluate, invalid_ind)
    for ind, fit in zip(invalid_ind, fitnesses):
        ind.fitness.values = fit
        
    eval_time_for_one_generation = time.time() - start_time        
    print("##### Evaluation ends (Time : %.3f)" % eval_time_for_one_generation)
    
    # Select the next generation population
    pop = toolbox.select(pop + offspring, POP_SIZE)
    
    gen_time = time.time() - start_gen
    print('##### [gen_time: %.3fs]' % gen_time, gen, 'th generation is finished.')
    
    record = stats.compile(pop)
    logbook.record(gen=gen, evals=len(invalid_ind), **record,
                   evals_time=eval_time_for_one_generation, gen_time=gen_time)
    
    logging.info('Gen [%03d/%03d] -- evals: %03d, evals_time: %.4fs, gen_time: %.4fs' % (gen, NGEN, len(invalid_ind), eval_time_for_one_generation, gen_time))
    print(logbook.stream)

Initialion starts ...
Files already downloaded and verified
Files already downloaded and verified
	 - Epoch: [0][0/196]	Time 14.505 (14.505)	Loss 6.9288 (6.9288)	Prec@1 0.000 (0.000)	Prec@5 0.781 (0.781)
	 - Epoch: [0][100/196]	Time 0.690 (0.784)	Loss 3.8885 (4.4187)	Prec@1 9.766 (5.036)	Prec@5 28.906 (19.086)
##### Validation_time 15.225 Prec@1 15.570 Prec@5 41.120 #####
	 - Epoch: [1][0/196]	Time 7.577 (7.577)	Loss 3.5685 (3.5685)	Prec@1 14.062 (14.062)	Prec@5 39.453 (39.453)
	 - Epoch: [1][100/196]	Time 0.627 (0.716)	Loss 3.6931 (3.5705)	Prec@1 12.109 (15.114)	Prec@5 35.938 (39.817)
##### Validation_time 14.597 Prec@1 13.250 Prec@5 36.820 #####
	 - Epoch: [2][0/196]	Time 6.056 (6.056)	Loss 3.4154 (3.4154)	Prec@1 15.234 (15.234)	Prec@5 40.625 (40.625)
	 - Epoch: [2][100/196]	Time 0.663 (0.719)	Loss 3.0281 (3.3016)	Prec@1 27.734 (19.446)	Prec@5 55.078 (47.567)
##### Validation_time 14.813 Prec@1 29.940 Prec@5 60.320 #####
gen	evals	min                              	max                

In [8]:
# Check logbook
logbook

[{'evals': 1,
  'gen': 0,
  'max': array([-2.99400000e+01,  1.73641306e+09]),
  'min': array([-2.99400000e+01,  1.73641306e+09])}]