## Full Training 
구현해야 될 것
* 모델 선정 리스트 만들기
* 모델 생성
* 훈련, hyper parameter는 기존과 동일, Termination 조건은 5 epoch 동안 Val acc 안 오를 경우
* 중간에 끊겨도 중간에서 부터 돌릴 수 있게 하기
* 로그 떨구기 -> 최종 로그 (시간, epoch, val acc)

In [11]:
from deap import base, creator
from deap import tools


import random
from itertools import repeat
from collections import Sequence, OrderedDict

# For evaluate function --------------------------
import glob
from easydict import EasyDict

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn    # for hardware tunning (cudnn.benchmark = True)

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from thop import profile
from thop import clever_format

import logging

# Gray code package
from utils_kyy.utils_graycode import *

# custom package in utils_kyy
from utils_kyy.utils_graph import load_graph
from utils_kyy.models import RWNN
from utils_kyy.train_validate import train, validate, train_AMP
from utils_kyy.lr_scheduler import LRScheduler
from torchsummary import summary
# -------------------------------------------------

#from apex import amp


## For MNIST
class ReshapeTransform:
    """Callable transform that reshapes its input tensor to a fixed size.

    The target size is stored at construction time and handed unchanged to
    ``torch.reshape`` on every call.
    """

    def __init__(self, new_size):
        # Target shape forwarded to torch.reshape on each call.
        self.new_size = new_size

    def __call__(self, img):
        target_size = self.new_size
        reshaped = torch.reshape(img, target_size)
        return reshaped
    
def _decode_graph_indices(individual, graycode):
    """Return the three per-stage graph indices described by ``individual``.

    When ``graycode`` is falsy the individual is returned unchanged (it is
    already a list of indices, e.g. ``[0, 4, 17]``). Otherwise the individual
    is split into three equal-length chunks; the digits of each chunk are
    concatenated into a string, converted with ``int()`` and passed to
    ``graydecode``.
    """
    if not graycode:
        return individual

    gray_len = len(individual) // 3
    graph_name = []
    for stage in range(3):
        chunk = individual[gray_len * stage:gray_len * (stage + 1)]
        digits = ''.join(str(bit) for bit in chunk)
        graph_name.append(graydecode(int(digits)))
    return graph_name


def _build_datasets(data_name, data_path):
    """Build the (train, validation) torchvision datasets for ``data_name``.

    Only ``"CIFAR10"`` and ``"MNIST"`` are supported. Images are resized to
    224 so they match the ImageNet-sized input the network is profiled with.

    Raises:
        ValueError: if ``data_name`` is neither ``"CIFAR10"`` nor ``"MNIST"``.
    """
    if data_name == "CIFAR10":
        train_transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.Resize(224),  # match the ImageNet input size
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # rescale 0 ~ 1 => -1 ~ 1
            ])
        val_transform = transforms.Compose(
            [
                transforms.Resize(224),  # match the ImageNet input size
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # rescale 0 ~ 1 => -1 ~ 1
            ])
        train_dataset = torchvision.datasets.CIFAR10(root=data_path, train=True,
                                                     download=True, transform=train_transform)
        val_dataset = torchvision.datasets.CIFAR10(root=data_path, train=False,
                                                   download=True, transform=val_transform)
    elif data_name == "MNIST":
        train_transform = transforms.Compose(
            [
                transforms.Resize(224),  # match the ImageNet input size
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (1.0,)),
            ])
        val_transform = transforms.Compose(
            [
                transforms.Resize(224),  # match the ImageNet input size
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (1.0,)),
            ])
        train_dataset = torchvision.datasets.MNIST(root=data_path, train=True,
                                                   transform=train_transform, download=True)
        val_dataset = torchvision.datasets.MNIST(root=data_path, train=False,
                                                 transform=val_transform, download=True)
    else:
        # ValueError is a subclass of Exception, so callers catching Exception still work.
        raise ValueError("Data Error, Only CIFAR10, MNIST allowed for the moment")

    return train_dataset, val_dataset


def evaluate_v2_full_train(individual, args_train, stage_pool_path_list, data_path=None, channels=109,
                           log_file_name=None, patience=5, max_epochs=None):
    """Fully train the RWNN described by ``individual`` and report its fitness.

    Training runs epoch by epoch and stops once the validation top-1 accuracy
    has not improved for ``patience`` consecutive epochs (or once
    ``max_epochs`` epochs have run, when that cap is given).

    Args:
        individual: list describing the model. If ``args_train.graycode`` is
            truthy it is a flat list of digits split into three equal chunks
            that are decoded with ``graydecode``; otherwise it is used
            directly as three graph indices, e.g. ``[0, 4, 17]``.
        args_train: attribute-style settings object. This function reads
            ``graycode``, ``num_classes``, ``input_dim``, ``base_lr``,
            ``momentum``, ``weight_decay``, ``data``, ``batch_size``,
            ``workers`` and ``print_freq``; the object is also passed whole
            to ``LRScheduler``.
        stage_pool_path_list: three directory prefixes (each ending with a
            path separator) in which the ``*.yaml`` stage graphs are looked up.
        data_path: dataset root directory; defaults to ``'./data'``.
        channels: channel count forwarded to ``RWNN``.
        log_file_name: forwarded to ``train`` and ``validate``.
        patience: number of consecutive epochs without a new best validation
            accuracy after which training stops.
        max_epochs: optional hard cap on the number of epochs; ``None`` means
            only ``patience`` terminates the loop.

    Returns:
        ``((-best_prec1, flops), epoch)``: the negated best validation
        accuracy (negated because the surrounding search minimises),
        the FLOPs reported by ``thop.profile``, and the number of epochs run.

    Raises:
        ValueError: if ``args_train.data`` is not ``"CIFAR10"`` or ``"MNIST"``.
    """
    # 1) Load the three stage graphs.
    # NOTE(review): glob.glob does not guarantee an ordering, so the index -> file
    # mapping depends on the directory listing. Left unsorted on purpose so it stays
    # aligned with whatever ordering was used when the individuals were produced;
    # confirm against the search-time evaluator.
    total_graph_path_list = [glob.glob(stage_pool_path_list[i] + '*.yaml') for i in range(3)]

    graph_name = _decode_graph_indices(individual, args_train.graycode)

    graphs = EasyDict({'stage_1': load_graph(total_graph_path_list[0][graph_name[0]]),
                       'stage_2': load_graph(total_graph_path_list[1][graph_name[1]]),
                       'stage_3': load_graph(total_graph_path_list[2][graph_name[2]])
                      })

    # 2) Build the RWNN.
    NN_model = RWNN(net_type='small', graphs=graphs, channels=channels,
                    num_classes=args_train.num_classes, input_channel=args_train.input_dim)
    NN_model.cuda()

    # FLOPs are measured on the bare model, before any nn.DataParallel wrapping.
    input_flops = torch.randn(1, args_train.input_dim, 224, 224).cuda()
    flops, params = profile(NN_model, inputs=(input_flops, ), verbose=False)

    # 3) Prepare for training (multi-GPU wrapping is intentionally left off).
    #NN_model = nn.DataParallel(NN_model)  # for multi-GPU
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(NN_model.parameters(), args_train.base_lr,
                                momentum=args_train.momentum,
                                weight_decay=args_train.weight_decay)

    # Let cudnn auto-tune the convolution algorithms for this hardware.
    cudnn.benchmark = True

    # Dataset & dataloader (download=True fetches the data on first use).
    if data_path is None:
        data_path = './data'

    train_dataset, val_dataset = _build_datasets(args_train.data, data_path)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args_train.batch_size,
                                               shuffle=True, num_workers=args_train.workers)

    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args_train.batch_size,
                                             shuffle=False, num_workers=args_train.workers)

    # Train.
    # The original computed len(train_loader) and then overwrote it with a hard-coded 1;
    # the real number of iterations per epoch is what gets passed to the scheduler now.
    niters = len(train_loader)

    # NOTE(review): the scheduler is configured from args_train, while this loop is
    # bounded by patience / max_epochs; confirm the schedule is well-defined for
    # every epoch that can be reached here.
    lr_scheduler = LRScheduler(optimizer, niters, args_train)

    best_prec1 = 0
    epochs_without_improvement = 0
    epoch = 0

    while True:
        # train for one epoch
        train(train_loader, NN_model, criterion, optimizer, lr_scheduler, epoch,
              args_train.print_freq, log_file_name)

        # evaluate on validation set
        prec1 = validate(val_loader, NN_model, criterion, epoch, log_file_name)

        epoch += 1

        # remember the best prec@1 and count epochs since it last improved
        if prec1 > best_prec1:
            best_prec1 = prec1
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        # Early stopping: no new best for `patience` epochs in a row.
        if epochs_without_improvement >= patience:
            break

        # Optional hard cap on the training length.
        if max_epochs is not None and epoch >= max_epochs:
            break

    # The search minimises (-val_accuracy, flops), hence the negated top-1 accuracy.
    return (-best_prec1, flops), epoch


In [15]:
import os
import sys
import os
import logging
from easydict import EasyDict
import numpy as np
import random
import time
import datetime
from deap import tools
from collections import OrderedDict
from pprint import pprint
import json
import torch

sys.path.insert(0, '../')
from utils_kyy.utils_graph import make_random_graph_v2
from utils_kyy.create_toolbox import create_toolbox_for_NSGA_RWNN, evaluate_v2


import argparse
import random


class full_train:
    """Fully train a fixed list of selected models and log the results.

    Settings are read from ``<parent of cwd>/parameters/<json_file>``. Each
    model is trained with ``evaluate_v2_full_train`` and, after every model,
    the accumulated results are written as JSON to ``train_logging.log`` so an
    interrupted run can be resumed from a previous log.
    """

    def __init__(self, json_file):
        """Load the parameter file, prepare directories and set up logging.

        Args:
            json_file: file name of the JSON parameter file inside the
                ``parameters`` directory next to the current working
                directory's parent. It must provide the keys ``NAME``,
                ``ARGS_TRAIN``, ``DATA_PATH``, ``RUN_CODE``, ``NUM_GRAPH``,
                ``TRAIN_FROM_LOGS`` and ``REAL_TRAIN_LOG_PATH``.
        """
        self.root = os.path.abspath(os.path.join(os.getcwd(), '..'))
        self.param_dir = os.path.join(self.root + '/parameters/', json_file)
        # Use a context manager so the parameter file handle is always closed.
        with open(self.param_dir) as f:
            params = json.load(f)
        pprint(params)
        self.name = params['NAME']

        ## toolbox params
        self.args_train = EasyDict(params['ARGS_TRAIN'])
        self.data_path = params['DATA_PATH']
        self.run_code = params['RUN_CODE']
        self.stage_pool_path = '../graph_pool' + '/' + self.run_code + '_' + 'experiment_1' + '/'
        # e.g. [graph_pool/run_code_name/1/, graph_pool/run_code_name/2/, graph_pool/run_code_name/3/]
        self.stage_pool_path_list = [self.stage_pool_path + str(i) + '/' for i in range(1, 4)]

        self.log_path = '../logs/' + self.run_code + '_' + self.name + '/'
        # self.log_file_name: general logging for the whole run.
        self.log_file_name = self.log_path + 'logging.log'
        # self.train_log_file_name: per-model results (fitness, time, epochs) kept for later use.
        self.train_log_file_name = self.log_path + 'train_logging.log'

        # Create every directory that is missing; previously the stage
        # sub-directories were only created when their parent did not exist yet.
        for stage_dir in self.stage_pool_path_list:
            os.makedirs(stage_dir, exist_ok=True)
        os.makedirs(self.log_path, exist_ok=True)

        logging.basicConfig(filename=self.log_file_name, level=logging.INFO)
        logging.info('[Start] Rwns_train class is initialized.')
        logging.info('Start to write log.')

        self.num_graph = params['NUM_GRAPH']

        ## Temporary: hard-coded candidate models (gray-coded individuals).
        self.models = [[1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1],
                       [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1],
                       [1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1],
                       [1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1],
                       [1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1],
                       [1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1],
                       [1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0]]

        # Hard-coded models used when train() is called with mode=1.
        self.random_models = [[0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1],
                              [1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0],
                              [1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0],
                              [0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0],
                              [0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0],
                              [1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0],
                              [1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0]]

        ## logs
        self.log = OrderedDict()
        self.log['hp'] = self.args_train
        self.train_log = OrderedDict()

        ## Resume settings: continue training from an existing train log.
        self.TRAIN_FROM_LOGS = params['TRAIN_FROM_LOGS']
        self.REAL_TRAIN_LOG_PATH = params['REAL_TRAIN_LOG_PATH']

    def train(self, mode=0):
        """Train every selected model in order, resuming from a log if configured.

        Args:
            mode: ``0`` trains ``self.models``; ``1`` trains
                ``self.random_models``. (Previously the selection was
                overwritten and ``self.models`` was always used.)

        When ``self.TRAIN_FROM_LOGS`` is truthy, the JSON file at
        ``self.REAL_TRAIN_LOG_PATH`` is read, its ``'train_log'`` entries
        ``'0' .. str(n - 1)`` are copied into ``self.train_log`` and training
        continues with model index ``n``.
        """
        # mode = 0 : GA-selected models, 1 : random models
        inds = self.random_models if mode == 1 else self.models

        start_idx = 0
        if self.TRAIN_FROM_LOGS:
            print("################# [KYY-check] Read train_log from the middle #################")
            logging.info("################# [KYY-check] Read train_log from the middle #################")

            # The saved file holds the hyper-parameters ('hp') and the 'train_log'.
            with open(self.REAL_TRAIN_LOG_PATH) as train_log_json_file:
                data = json.load(train_log_json_file)

            train_log_past = data['train_log']
            # n logged entries means models 0 .. n-1 are already trained.
            start_idx = len(train_log_past)

            for i in range(start_idx):
                self.train_log[str(i)] = train_log_past[str(i)]

        for idx in range(start_idx, len(inds)):
            self._train_one(idx, inds[idx])

    def _train_one(self, idx, ind):
        """Train model ``ind``, record its result under ``str(idx)`` and save the log."""
        # Take a fresh timestamp for each message instead of reusing one
        # captured once when train() started.
        now_str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        # str(idx) is required: concatenating the int directly raises TypeError.
        logging.info("Start Model Training " + now_str + " model " + str(idx))
        init_start_time = time.time()

        fitness, epoch = evaluate_v2_full_train(ind, args_train=self.args_train,
                                                stage_pool_path_list=self.stage_pool_path_list,
                                                data_path=self.data_path,
                                                log_file_name=self.log_file_name)

        now_str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        logging.info("Initialion is finished at " + now_str)

        end_time = time.time()
        model_dict = {
            'model_id': ind,
            'fitness': fitness,
            'time': end_time - init_start_time,
            'epoch': epoch,
        }

        # Persist after every model so an interrupted run loses at most one model.
        self.train_log[str(idx)] = model_dict
        self.save_log()

    ## Save Log
    def save_log(self):
        """Write ``self.log`` (hyper-parameters plus train log) as JSON to ``self.train_log_file_name``."""
        self.log['train_log'] = self.train_log

        with open(self.train_log_file_name, 'w', encoding='utf-8') as make_file:
            json.dump(self.log, make_file, ensure_ascii=False, indent='\t')

            
        
# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--params', type=str, help='Parameter Json file')
    
#     args = parser.parse_args()
    
#     trainer = full_train(json_file=args.params)

#     trainer.train()
#     trainer.save_log()


In [16]:
# Build the trainer from the 'full_train.json' parameter file; the constructor pretty-prints the loaded parameters.
trainer = full_train(json_file='full_train.json')

{'ARGS_TRAIN': {'base_lr': 0.2,
                'batch_size': 32,
                'data': 'CIFAR10',
                'epochs': 5,
                'graycode': True,
                'input_dim': 3,
                'lr_mode': 'cosine',
                'momentum': 0.9,
                'num_classes': 10,
                'print_freq': 100,
                'targetlr': 0.0,
                'warmup_epochs': 0,
                'warmup_lr': 0.0,
                'warmup_mode': 'linear',
                'weight_decay': 5e-05,
                'workers': 4},
 'DATA_PATH': 'D:/data/cifar10/',
 'NAME': 'full_train',
 'NUM_GRAPH': 128,
 'REAL_TRAIN_LOG_PATH': '/root/data/basic_model/logs/__New_main_experiment_1_to13gen/train_logging.log',
 'RUN_CODE': 'New_main',
 'TRAIN_FROM_LOGS': False}


In [17]:
# Train the models in trainer.models one after another (mode defaults to 0); the train log is saved after each model.
trainer.train()

Files already downloaded and verified
Files already downloaded and verified
	 - Epoch: [0][0/1563]	Time 6.871 (6.871)	Loss 2.3372 (2.3372)	Prec@1 3.125 (3.125)	Prec@5 37.500 (37.500)
	 - Epoch: [0][100/1563]	Time 0.871 (0.847)	Loss 2.2453 (2.4310)	Prec@1 9.375 (13.181)	Prec@5 53.125 (57.426)
	 - Epoch: [0][200/1563]	Time 0.730 (0.813)	Loss 2.1650 (2.2854)	Prec@1 15.625 (15.967)	Prec@5 78.125 (65.205)
	 - Epoch: [0][300/1563]	Time 0.722 (0.801)	Loss 2.0291 (2.2044)	Prec@1 28.125 (18.272)	Prec@5 71.875 (68.978)
	 - Epoch: [0][400/1563]	Time 0.729 (0.793)	Loss 1.7214 (2.1513)	Prec@1 46.875 (20.433)	Prec@5 84.375 (71.626)
	 - Epoch: [0][500/1563]	Time 0.743 (0.792)	Loss 1.8915 (2.1148)	Prec@1 21.875 (21.850)	Prec@5 87.500 (73.428)
	 - Epoch: [0][600/1563]	Time 0.798 (0.792)	Loss 1.7782 (2.0791)	Prec@1 28.125 (23.102)	Prec@5 81.250 (75.192)
	 - Epoch: [0][700/1563]	Time 0.893 (0.796)	Loss 1.8582 (2.0521)	Prec@1 31.250 (24.055)	Prec@5 90.625 (76.306)
	 - Epoch: [0][800/1563]	Time 0.877 (0.79

	 - Epoch: [0][1000/1563]	Time 0.788 (0.819)	Loss 1.7804 (2.0923)	Prec@1 25.000 (21.831)	Prec@5 87.500 (72.883)
	 - Epoch: [0][1100/1563]	Time 0.792 (0.819)	Loss 1.6928 (2.0718)	Prec@1 40.625 (22.624)	Prec@5 87.500 (73.856)
	 - Epoch: [0][1200/1563]	Time 0.785 (0.818)	Loss 2.0390 (2.0522)	Prec@1 34.375 (23.397)	Prec@5 78.125 (74.768)
	 - Epoch: [0][1300/1563]	Time 0.788 (0.818)	Loss 1.8812 (2.0344)	Prec@1 25.000 (24.152)	Prec@5 75.000 (75.579)
	 - Epoch: [0][1400/1563]	Time 0.803 (0.817)	Loss 1.7490 (2.0161)	Prec@1 40.625 (24.946)	Prec@5 78.125 (76.280)
	 - Epoch: [0][1500/1563]	Time 0.783 (0.817)	Loss 1.6381 (1.9992)	Prec@1 40.625 (25.656)	Prec@5 87.500 (76.932)
##### Validation_time 74.642 Prec@1 39.140 Prec@5 88.220 #####
	 - Epoch: [1][0/1563]	Time 6.703 (6.703)	Loss 1.8182 (1.8182)	Prec@1 31.250 (31.250)	Prec@5 84.375 (84.375)
	 - Epoch: [1][100/1563]	Time 0.801 (0.875)	Loss 1.7814 (1.7186)	Prec@1 21.875 (35.984)	Prec@5 90.625 (87.129)
	 - Epoch: [1][200/1563]	Time 0.889 (0.849)	L

KeyboardInterrupt: 

In [None]:
# Inspect the graph-pool directory prefix the trainer resolved from RUN_CODE.
trainer.stage_pool_path