## GA trainer
* read graph from chromosome
* modify checkpoint handling
* termination criteria

In [5]:
import os
import time
import torch
import logging
import argparse

from utils.train import train
from utils.hparams import HParam
from utils.writer import MyWriter
from utils.graph_reader import read_graph
from dataset.dataloader import create_dataloader, MNIST_dataloader, CIFAR10_dataloader


# Originally the yaml file, checkpoint path and model name were passed in as command-line args.
## Adapting the parser should be all that is needed.
def ga_trainer(args, index_list, f_path, f_name):
    """Train the model described by one GA individual and return its fitness.

    Args:
        args: namespace-like object. This function reads ``args.model``,
            ``args.config`` and ``args.checkpoint_path`` and forwards ``args``
            unchanged to the dataloader factory.
        index_list: indices of the graph files that make up this individual.
            The first three entries name the run; every entry is loaded as a
            graph from ``<f_path>/<args.model>_<idx>.txt``.
        f_path: directory containing the generated graph text files.
        f_name: unused here; kept so existing call sites keep working.

    Returns:
        Whatever ``ga_train`` returns for this individual (its validation
        accuracy), so the caller can use it as the GA fitness value.

    Raises:
        Exception: if ``hp.data.train``/``hp.data.val`` or any of
            ``hp.model.graph0..2`` is an empty string.
        AssertionError: if ``hp.data.type`` is not a supported dataset name.
    """
    # Each individual gets its own checkpoint/log directory, keyed by the
    # three graph indices.
    individual_model_name = args.model + "_{}_{}_{}".format(
        index_list[0], index_list[1], index_list[2])

    hp = HParam(args.config)
    # Raw config text is stored alongside checkpoints by the training loop.
    with open(args.config, 'r') as f:
        hp_str = f.read()

    # Directory where the PyTorch checkpoints of this individual are saved.
    pt_path = os.path.join('.', hp.log.chkpt_dir)
    out_dir = os.path.join(pt_path, individual_model_name)
    os.makedirs(out_dir, exist_ok=True)

    log_dir = os.path.join('.', hp.log.log_dir)
    log_dir = os.path.join(log_dir, individual_model_name)
    os.makedirs(log_dir, exist_ok=True)

    # None means "start a fresh run".
    chkpt_path = args.checkpoint_path

    # NOTE(review): logging.basicConfig() is a no-op once the root logger has
    # handlers, so when this function is called for several individuals in one
    # process only the first call's log file is actually attached.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(os.path.join(log_dir,
                '%s-%d.log' % (args.model, time.time()))),
            logging.StreamHandler()
        ]
    )
    logger = logging.getLogger()

    if hp.data.train == '' or hp.data.val == '':
        logger.error("hp.data.train, hp.data.val cannot be empty")
        raise Exception("Please specify directories of train data.")

    # NOTE(review): the graphs below are read from f_path, not from
    # hp.model.graph*, yet the config entries are still required to be
    # non-empty. Kept as-is to preserve behaviour.
    if hp.model.graph0 == '' or hp.model.graph1 == '' or hp.model.graph2 == '':
        logger.error("hp.model.graph0, graph1, graph2 cannot be empty")
        raise Exception("Please specify random DAG architecture.")

    # Read the graphs from the freshly generated files of this individual.
    graphs = [
        read_graph(os.path.join(f_path, args.model + '_' + str(idx) + '.txt'))
        for idx in index_list
    ]

    writer = MyWriter(log_dir)

    dataset = hp.data.type
    switcher = {
        'MNIST': MNIST_dataloader,
        'CIFAR10': CIFAR10_dataloader,
        'ImageNet': create_dataloader,
    }
    assert dataset in switcher.keys(), 'Dataset type currently not supported'
    dl_func = switcher[dataset]
    trainset = dl_func(hp, args, True)
    valset = dl_func(hp, args, False)

    val_acc = ga_train(out_dir, chkpt_path, trainset, valset, writer, logger, hp, hp_str, graphs)
    # Previously the accuracy was computed and then dropped; hand it back so
    # the GA can use it as the individual's fitness.
    return val_acc
## Test



In [6]:
#def ga_trainer(args,index_list,f_path,f_name):

## 테스트성공
# NOTE(review): this test cell is truncated in the export. The EasyDict literal
# opened on the next line is never closed and is followed by a stray fragment of
# the old argparse code, so the cell is not valid Python as shown. `index_list`
# and `f_path` are also not defined anywhere in the visible part of the cell.
# The remainder mirrors the body of ga_trainer(), run at module level so that
# `val_acc` stays available to the following cell.
if __name__ == '__main__':
    args = easydict.EasyDict({
  parser.add_argument('-m', '--model', type=str, required=True,
#                         help="name of the model. used for logging/saving checkpoints")
#     args = parser.parse_args()

    individual_model_name =args.model + "_{}_{}_{}".format(index_list[0],index_list[1],
                                                          index_list[2])
    
    hp = HParam(args.config)
    with open(args.config, 'r') as f:
        hp_str = ''.join(f.readlines())
    ## directory where the PyTorch checkpoints are saved
    
    pt_path = os.path.join('.', hp.log.chkpt_dir)
    ## save under the per-individual model name built above
    out_dir = os.path.join(pt_path, individual_model_name)
    os.makedirs(out_dir, exist_ok=True)

    log_dir = os.path.join('.', hp.log.log_dir)
    log_dir = os.path.join(log_dir, individual_model_name)
    os.makedirs(log_dir, exist_ok=True)

    # None means "start a fresh run" for ga_train
    if args.checkpoint_path is not None:
        chkpt_path = args.checkpoint_path
    else:
        chkpt_path = None

    # log both to a timestamped file inside log_dir and to the console
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(os.path.join(log_dir,
                '%s-%d.log' % (args.model, time.time()))),
            logging.StreamHandler()
        ]
    )
    logger = logging.getLogger()
    
    if hp.data.train == '' or hp.data.val == '':
        logger.error("hp.data.train, hp.data.val cannot be empty")
        raise Exception("Please specify directories of train data.")

    if hp.model.graph0 == '' or hp.model.graph1 == '' or hp.model.graph2 == '':
        logger.error("hp.model.graph0, graph1, graph2 cannot be empty")
        raise Exception("Please specify random DAG architecture.")

#     graphs = [
#         read_graph(hp.model.graph0),
#         read_graph(hp.model.graph1),
#         read_graph(hp.model.graph2),
#     ]

    ## read the graphs from the location of the newly generated files
    graphs = [read_graph(os.path.join(f_path,args.model + '_' + str(idx)+ '.txt')) for idx in index_list]

    writer = MyWriter(log_dir)
    
    # pick the dataloader factory matching the configured dataset type
    dataset = hp.data.type
    switcher = {
            'MNIST': MNIST_dataloader,
            'CIFAR10':CIFAR10_dataloader,
            'ImageNet':create_dataloader,
            }
    assert dataset in switcher.keys(), 'Dataset type currently not supported'
    dl_func = switcher[dataset]
    trainset = dl_func(hp, args, True)
    valset = dl_func(hp, args, False)

    val_acc = ga_train(out_dir, chkpt_path, trainset, valset, writer, logger, hp, hp_str, graphs)

2019-07-12 14:59:09,903 - INFO - Starting new training run
2019-07-12 14:59:09,906 - INFO - Writing graph to tensorboardX...
2019-07-12 14:59:13,215 - INFO - Finished.
Loss 0.13 at step 58: 100%|████████████████████████████████████████████████████████████| 59/59 [00:35<00:00,  1.77it/s]
78it [00:06, 12.37it/s]
2019-07-12 14:59:55,357 - INFO - Saved checkpoint to: .\chkpt\ws-4-0.75_1_2_3\chkpt_000.pt
2019-07-12 14:59:55,359 - INFO - Validation Accuracy imporved: 0.488600
Loss 0.08 at step 117: 100%|███████████████████████████████████████████████████████████| 59/59 [00:35<00:00,  1.78it/s]
78it [00:06, 12.88it/s]
2019-07-12 15:00:37,195 - INFO - Saved checkpoint to: .\chkpt\ws-4-0.75_1_2_3\chkpt_001.pt
2019-07-12 15:00:37,196 - INFO - Validation Accuracy imporved: 0.974000
Loss 0.04 at step 176: 100%|███████████████████████████████████████████████████████████| 59/59 [00:35<00:00,  1.71it/s]
78it [00:06, 12.31it/s]
2019-07-12 15:01:19,156 - INFO - Saved checkpoint to: .\chkpt\ws-4-0.75_1_

In [7]:
val_acc

0.9775

## GA Train

In [2]:
import easydict

In [3]:
import os
import math
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import adabound
import itertools
import traceback
from torchsummary import summary

from utils.hparams import load_hparam_str
from utils.evaluation import validate
from model.model import RandWire


## GA train
def ga_train(out_dir, chkpt_path, trainset, valset, writer, logger, hp, hp_str, graphs,
             patience=5, max_epoch=3):
    """Train a RandWire model for one GA individual with early stopping.

    A checkpoint is written only for epochs whose validation accuracy beats
    the best accuracy seen so far in this call.

    Args:
        out_dir: directory that receives ``chkpt_XXX.pt`` files.
        chkpt_path: checkpoint file to resume from, or None for a fresh run.
        trainset: iterable yielding ``(data, target)`` training batches.
        valset: validation data, passed straight to ``validate``.
        writer: object providing ``write_graph`` and ``log_training``.
        logger: logger used for progress and error messages.
        hp: hyper-parameter object (``hp.train.*``, ``hp.model.input_maps``).
        hp_str: raw config text; stored in checkpoints and compared on resume.
        graphs: graph descriptions handed to ``RandWire``.
        patience: stop once more than this many consecutive epochs failed to
            improve validation accuracy. Defaults to the previously
            hard-coded value 5.
        max_epoch: stop after the first epoch whose index exceeds this value;
            ``None`` disables the cap. Defaults to the previously hard-coded
            temporary limit 3.

    Returns:
        The best validation accuracy observed during this call. It is 0 when
        no epoch improved on 0 or when training aborted before the first
        validation; exceptions raised inside the loop are logged, not
        propagated.

    Raises:
        Exception: if ``hp.train.optimizer`` names an unsupported optimizer.
    """
    model = RandWire(hp, graphs).cuda()
    if hp.train.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=hp.train.adam)
    elif hp.train.optimizer == 'adabound':
        optimizer = adabound.AdaBound(model.parameters(),
                                      lr=hp.train.adabound.initial,
                                      final_lr=hp.train.adabound.final)
    elif hp.train.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=hp.train.sgd.lr,
                                    momentum=hp.train.sgd.momentum,
                                    weight_decay=hp.train.sgd.weight_decay)
    else:
        raise Exception("Optimizer not supported: %s" % hp.train.optimizer)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, hp.train.epoch)

    # -1 so that a fresh run starts counting epochs at 0.
    init_epoch = -1
    step = 0

    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        step = checkpoint['step']
        init_epoch = checkpoint['epoch']

        # The hparams passed in always win over the ones in the checkpoint.
        if hp_str != checkpoint['hp_str']:
            logger.warning("New hparams are different from checkpoint.")
            logger.warning("Will use new hparams.")
    else:
        logger.info("Starting new training run")
        logger.info("Writing graph to tensorboardX...")
        # Dummy batch of 7 inputs of size 224x224, used only to trace the graph.
        writer.write_graph(model, torch.randn(7, hp.model.input_maps, 224, 224).cuda())
        logger.info("Finished.")

    # Best accuracy so far; not restored from the checkpoint on resume.
    prev_acc = 0
    try:
        # NOTE(review): train mode is set once here; this assumes validate()
        # leaves the model in train mode when it returns -- confirm.
        model.train()
        stale_epochs = 0  # consecutive epochs without a validation improvement
        for epoch in itertools.count(init_epoch + 1):
            loader = tqdm.tqdm(trainset, desc='Train data loader')
            for data, target in loader:
                data, target = data.cuda(), target.cuda()
                optimizer.zero_grad()
                output = model(data)
                loss = F.nll_loss(output, target)
                loss.backward()
                optimizer.step()

                loss = loss.item()
                # Abort on a diverged run instead of continuing to train on NaNs.
                if loss > 1e8 or math.isnan(loss):
                    logger.error("Loss exploded to %.02f at step %d!" % (loss, step))
                    raise Exception("Loss exploded")

                writer.log_training(loss, step)
                loader.set_description('Loss %.02f at step %d' % (loss, step))
                step += 1

            # Validation; only the accuracy drives checkpointing/early stopping.
            _, val_acc = validate(model, valset, writer, epoch)

            if prev_acc < val_acc:
                save_path = os.path.join(out_dir, 'chkpt_%03d.pt' % epoch)
                torch.save({
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'step': step,
                    'epoch': epoch,
                    'hp_str': hp_str,
                }, save_path)
                logger.info("Saved checkpoint to: %s" % save_path)
                logger.info("Validation Accuracy imporved: %3f" % val_acc)

                stale_epochs = 0
                prev_acc = val_acc
            else:
                stale_epochs += 1

            # Early stopping: too many epochs without improvement.
            if stale_epochs > patience:
                break

            # Hard cap on the epoch index (originally a temporary condition).
            if max_epoch is not None and epoch > max_epoch:
                break

            lr_scheduler.step()

    except Exception as e:
        # Best-effort: a failed individual still reports the best accuracy
        # reached so far instead of aborting the caller.
        logger.info("Exiting due to exception: %s" % e)
        traceback.print_exc()

    return prev_acc


In [None]:
import os
import math
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import adabound
import itertools
import traceback
from torchsummary import summary

from utils.hparams import load_hparam_str
from utils.evaluation import validate
from model.model import RandWire


def train(out_dir, chkpt_path, trainset, valset, writer, logger, hp, hp_str, graphs):
    """Train a RandWire network, saving a checkpoint after every epoch.

    The epoch loop has no exit condition of its own; it runs until an
    exception is raised inside it, which is then logged and printed before
    the function returns None.

    Args:
        out_dir: directory that receives ``chkpt_XXX.pt`` files.
        chkpt_path: checkpoint file to resume from, or None for a fresh run.
        trainset: iterable yielding ``(data, target)`` training batches.
        valset: validation data, passed straight to ``validate``.
        writer: object providing ``write_graph`` and ``log_training``.
        logger: logger used for progress and error messages.
        hp: hyper-parameter object (``hp.train.*``, ``hp.model.input_maps``).
        hp_str: raw config text; stored in checkpoints and compared on resume.
        graphs: graph descriptions handed to ``RandWire``.

    Raises:
        Exception: if ``hp.train.optimizer`` names an unsupported optimizer.
    """
    net = RandWire(hp, graphs).cuda()

    # Build the optimizer selected in the config.
    opt_name = hp.train.optimizer
    if opt_name == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=hp.train.adam)
    elif opt_name == 'adabound':
        bound_cfg = hp.train.adabound
        optimizer = adabound.AdaBound(net.parameters(),
                                      lr=bound_cfg.initial,
                                      final_lr=bound_cfg.final)
    elif opt_name == 'sgd':
        sgd_cfg = hp.train.sgd
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=sgd_cfg.lr,
                                    momentum=sgd_cfg.momentum,
                                    weight_decay=sgd_cfg.weight_decay)
    else:
        raise Exception("Optimizer not supported: %s" % opt_name)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, hp.train.epoch)

    # Defaults for a fresh run: the first epoch index becomes 0.
    last_epoch, step = -1, 0

    if chkpt_path is None:
        logger.info("Starting new training run")
        logger.info("Writing graph to tensorboardX...")
        # Dummy input laid out as (batch, input channels, height, width).
        dummy_input = torch.randn(7, hp.model.input_maps, 224, 224).cuda()
        writer.write_graph(net, dummy_input)
        logger.info("Finished.")
    else:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        state = torch.load(chkpt_path)
        net.load_state_dict(state['model'])
        optimizer.load_state_dict(state['optimizer'])
        lr_scheduler.load_state_dict(state['lr_scheduler'])
        step = state['step']
        last_epoch = state['epoch']

        # The hparams passed in take precedence over the stored ones.
        if hp_str != state['hp_str']:
            logger.warning("New hparams are different from checkpoint.")
            logger.warning("Will use new hparams.")

    try:
        net.train()
        for epoch in itertools.count(last_epoch + 1):
            progress = tqdm.tqdm(trainset, desc='Train data loader')
            for inputs, labels in progress:
                inputs = inputs.cuda()
                labels = labels.cuda()

                optimizer.zero_grad()
                batch_loss = F.nll_loss(net(inputs), labels)
                batch_loss.backward()
                optimizer.step()

                loss_value = batch_loss.item()
                # Stop as soon as the loss diverges or turns into NaN.
                if math.isnan(loss_value) or loss_value > 1e8:
                    logger.error("Loss exploded to %.02f at step %d!" % (loss_value, step))
                    raise Exception("Loss exploded")

                writer.log_training(loss_value, step)
                progress.set_description('Loss %.02f at step %d' % (loss_value, step))
                step += 1

            # Persist the full training state once per epoch.
            save_path = os.path.join(out_dir, 'chkpt_%03d.pt' % epoch)
            snapshot = {
                'model': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'step': step,
                'epoch': epoch,
                'hp_str': hp_str,
            }
            torch.save(snapshot, save_path)
            logger.info("Saved checkpoint to: %s" % save_path)

            validate(net, valset, writer, epoch)
            lr_scheduler.step()

    except Exception as err:
        logger.info("Exiting due to exception: %s" % err)
        traceback.print_exc()


## Graph making code

In [None]:
import os
import math
import argparse
import numpy as np

def generate_ws_edges(n, k, p):
    """Build one Watts-Strogatz style random graph and return its edge list.

    Starts from a ring lattice in which every node is linked to its ``k // 2``
    nearest neighbours on each side, then rewires each lattice edge with
    probability ``p`` to a node the source is not yet connected to.

    Args:
        n: number of nodes.
        k: number of lattice neighbours per node (callers ensure it is even
            and ``0 < k < n``).
        p: rewiring probability.

    Returns:
        List of ``(i, j)`` tuples with ``i < j``, in ascending order.
    """
    half = k // 2

    adj = [[False] * n for _ in range(n)]  # adjacency matrix
    for i in range(n):
        # Mark the diagonal as occupied so rewiring never creates a self-loop.
        adj[i][i] = True

    # initial ring-lattice connection
    for i in range(n):
        for j in range(i - half, i + half + 1):
            real_j = j % n
            if real_j == i:
                continue
            adj[real_j][i] = adj[i][real_j] = True

    # One uniform draw per lattice edge decides whether it is rewired.
    rand = np.random.uniform(0.0, 1.0, size=(n, half))
    for i in range(n):
        for j in range(1, half + 1):  # 'j' here is 'i' of paper's notation
            if rand[i][j - 1] >= p:
                continue
            unoccupied = [x for x in range(n) if not adj[i][x]]
            if not unoccupied:
                # Node i is already linked to every other node: there is no
                # valid rewiring target, so keep the edge instead of letting
                # np.random.choice fail on an empty list.
                continue
            current = (i + j) % n
            rewired = np.random.choice(unoccupied)
            adj[i][current] = adj[current][i] = False
            adj[i][rewired] = adj[rewired][i] = True

    # Scanning the upper triangle row by row yields the edges already sorted.
    return [(i, j) for i in range(n) for j in range(i + 1, n) if adj[i][j]]


if __name__ == '__main__':
    # The first positional argument of ArgumentParser is `prog`, so the text is
    # passed as `description`; the generator implements Watts-Strogatz (WS).
    parser = argparse.ArgumentParser(description='Watts-Strogatz graph generator')
    parser.add_argument('-n', '--n_nodes', type=int, default=32,
                        help="number of nodes for random graph")
    parser.add_argument('-k', '--k_neighbors', type=int, required=True,
                        help="connecting neighboring nodes for WS")
    parser.add_argument('-p', '--prob', type=float, required=True,
                        help="probablity of rewiring for WS")
    parser.add_argument('-o', '--out_txt', type=str, required=True,
                        help="name of output txt file")
    # Optional: `required=True` made the default of 1 unreachable.
    parser.add_argument('-f', '--file_num', type=int, default=1,
                        help="number of files to generate")

    args = parser.parse_args()
    n, k, p, file_num = args.n_nodes, args.k_neighbors, args.prob, args.file_num

    # Report bad input as a usage error; asserts vanish under `python -O`.
    if k % 2 != 0:
        parser.error("k must be even.")
    if not 0 < k < n:
        parser.error("k must be larger than 0 and smaller than n.")

    os.makedirs('ga_initials', exist_ok=True)
    # Strip the extension properly; slicing off 4 characters mangled names
    # that did not end in a 3-letter suffix such as ".txt".
    stem = os.path.splitext(args.out_txt)[0]
    for num in range(file_num):
        edges = generate_ws_edges(n, k, p)

        # File layout: node count, edge count, then one "i j" pair per line.
        file_name = stem + "_" + str(num) + ".txt"
        with open(os.path.join('ga_initials', file_name), 'w') as f:
            f.write(str(n) + '\n')
            f.write(str(len(edges)) + '\n')
            for edge in edges:
                f.write('%d %d\n' % (edge[0], edge[1]))
