In [None]:
import math
import os
import time

import numpy as np
from torch.backends import cudnn

from models import SOMGAN_model
from utils.data_loader import Data_Loader
from utils.utils import make_folder
from utils.utils import Configuration

In [None]:
# --- Experiment configuration --------------------------------------------
# d_num is presumably the number of discriminators (it is rendered as "<d_num>D"
# in the run name); topo_style is the topology they are arranged in.
d_num = 9
topo_style = 'grid'  # linear, grid, circular


def validate_topology(d_num, topo_style):
    """Raise ValueError when ``d_num`` cannot be laid out in ``topo_style``.

    ``d_num`` must be at least 1; a 'grid' topology additionally needs
    ``d_num`` to be a perfect square (a side x side grid).
    """
    if d_num < 1:
        raise ValueError('d_num should be at least 1, got {}.'.format(d_num))
    if topo_style == 'grid':
        # round() guards against sqrt returning e.g. 2.9999999 for a perfect square.
        side = int(round(math.sqrt(d_num)))
        if side * side != d_num:
            raise ValueError('d_num should be a squared number, got {}.'.format(d_num))


# Fail fast: an invalid topology used to only print a warning and keep going.
validate_topology(d_num, topo_style)

total_step = 100
batch_size = 64
network = 'CNN_network'  # CNN_network, SN_network, Attn_SN_network

dataset = 'stl10'  # 'mnist', 'cifar', 'lsun', 'celeb', 'stl10'
# Expand '~' here instead of relying on the data loader to do it.
image_path = os.path.expanduser('~/Datasets/STL-10/')
imsize = 96
channel = 3
conv_dim = 64

In [None]:
# Unique run identifier: network/dataset/topology settings plus a launch timestamp.
# The timestamp uses '_' and '-' instead of ' ' and ':' so the run directory name
# is valid on every OS (':' is a reserved path character on Windows) and needs no
# shell quoting.
run_timestamp = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
version = 'SOMGAN_{}_{}_{}D_{}_{}iters_{}'.format(
    network, dataset, d_num, topo_style, total_step, run_timestamp)

# Every artefact of this run lives under ./output/<version>/.
run_dir = os.path.join('./output', version)
model_save_path = os.path.join(run_dir, 'models')
loss_save_path = os.path.join(run_dir, 'loss')
sample_path = os.path.join(run_dir, 'samples')

# Create directories if they do not exist yet.
for run_subdir in (model_save_path, loss_save_path, sample_path):
    make_folder(path=run_subdir)

# Data loader
data_loader = Data_Loader(dataset=dataset, image_path=image_path, image_size=imsize,
                          batch_size=batch_size, shuf=True)

print("Data_Loader: 1 epoch = %d iterations" % len(data_loader.loader()))

In [None]:
# Training configuration, wrapped in Configuration and handed to the SOMGAN trainer.
# The log / sample / checkpoint intervals are capped at total_step. With the
# previous fixed values (1000 / 5000 / 5000), a short run such as total_step = 100
# would never reach an interval boundary, so nothing would be logged, sampled or
# saved. NOTE(review): this assumes the trainer fires these actions on multiples
# of the interval -- confirm in SOMGAN_model. max(..., 1) keeps every interval
# positive (no modulo-by-zero) even if total_step is set to 0.
log_step = min(1000, max(total_step, 1))
sample_step = min(5000, max(total_step, 1))
model_save_step = min(5000, max(total_step, 1))

configs = {
    # Data / model settings defined in the cells above.
    'dataset': dataset,
    'imsize': imsize,
    'network': network,
    'batch_size': batch_size,
    'g_conv_dim': conv_dim,
    'd_conv_dim': conv_dim,
    'channel': channel,
    'd_num': d_num,
    'topo_style': topo_style,
    'z_dim': 64,
    # Schedule.
    'total_step': total_step,
    'log_step': log_step,
    'sample_step': sample_step,
    'model_save_step': model_save_step,
    # Optimiser settings.
    'g_lr': 0.001,
    'd_lr': 0.001,
    'lr_decay': 0.95,
    'beta1': 0.9,
    'beta2': 0.999,
    # Run bookkeeping.
    'version': version,
    'parallel': False,
    'printnet': False,
    'use_tensorboard': True,
    'pretrained_model': None,
    'model_save_path': model_save_path,
    'sample_path': sample_path,
    'loss_save_path': loss_save_path
}

trainer = SOMGAN_model.Trainer(data_loader.loader(), Configuration(configs))

In [None]:
# Launch training. The trainer was configured with total_step, model_save_path,
# sample_path and loss_save_path; presumably train() runs for total_step
# iterations and writes its outputs there (TODO confirm in SOMGAN_model).
trainer.train()