In [None]:
import math
import os
import time

import numpy as np
import torch
from torch.backends import cudnn

from models import SOMGAN_toy2D_model
from utils.utils import Configuration
from utils.utils import make_folder

In [None]:
# Experiment size: d_num generators arranged in a `topo_style` topology.
d_num = 9
topo_style = 'circular'  # linear, grid, circular
if topo_style == 'grid':
    # The grid topology lays the generators out on a side x side lattice, so
    # d_num must be a perfect square. Raise (rather than print) so a bad
    # setting cannot silently reach training.
    grid_side = int(round(math.sqrt(d_num)))
    if grid_side * grid_side != d_num:
        raise ValueError('d_num should be a squared number, got %d.' % d_num)

total_step = 25000
batch_size = 512

# Reproducibility: seed numpy and torch before any model or data is created.
# Set `seed = None` for a different run each time. Note that
# cudnn.benchmark (set below) may still introduce nondeterminism in
# cuDNN-backed ops.
seed = 0
if seed is not None:
    np.random.seed(seed)
    torch.manual_seed(seed)

# Toy 2-D target: num_mixtures modes with spread parameter `std`, laid out
# according to `data_style`.
num_mixtures = 8
std = 0.02
radius = 2.0
data_style = 'circular'  # circular, grid, spiral

if data_style == 'circular':
    # Modes evenly spaced on a circle of the given radius (endpoint dropped
    # so the first and last angles do not coincide).
    thetas = np.linspace(0, 2 * np.pi, num_mixtures + 1)[:num_mixtures]
    xs, ys = radius * np.sin(thetas), radius * np.cos(thetas)
elif data_style == 'grid':
    # Modes on a side x side lattice. num_mixtures must be a perfect square,
    # otherwise fewer means than mixture weights would be produced.
    data_side = int(round(math.sqrt(num_mixtures)))
    if data_side * data_side != num_mixtures:
        raise ValueError('num_mixtures should be a squared number for grid data, got %d.' % num_mixtures)
    thetas = np.linspace(0, math.sqrt(num_mixtures), data_side + 1)[:data_side]
    xs, ys = radius * np.repeat(thetas, data_side), radius * np.tile(thetas, data_side)
elif data_style == 'spiral':
    # Modes along the Archimedean spiral r = 1 + 0.5 * theta over 1.5 turns.
    thetas = np.linspace(0, 3 * np.pi, num_mixtures + 1)[:num_mixtures]
    xs, ys = (1 + 0.5 * thetas) * np.cos(thetas), (1 + 0.5 * thetas) * np.sin(thetas)
else:
    raise ValueError("data_style should be 'circular', 'grid' or 'spiral', got %r." % (data_style,))

cudnn.benchmark = True

In [None]:
# A run is identified by model name, generator count, topology, step budget
# and launch time. The timestamp avoids ':' and spaces: ':' is not allowed
# in Windows file names, and the version string becomes a directory name.
timestamp = time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime())
version = 'SOMGAN_toy2D_{}G_{}_{}iters_{}'.format(d_num, topo_style, total_step, timestamp)

run_dir = os.path.join('.', 'output', version)
model_save_path = os.path.join(run_dir, 'models')
loss_save_path = os.path.join(run_dir, 'loss')
sample_path = os.path.join(run_dir, 'samples')

# Create directories if not exist
for folder in (model_save_path, loss_save_path, sample_path):
    make_folder(path=folder)

In [None]:
# Build the hyper-parameter set handed to SOMGAN_toy2D_model.Trainer.

# Target Gaussian mixture: equal weights, one mean per (x, y) mode computed
# earlier, and the same (std, std) pair for every component.
# NOTE(review): `std` is stored under the 'cov' key -- confirm whether the
# trainer treats it as a standard deviation or as a variance.
mixture_weights = tuple([1 / num_mixtures] * num_mixtures)
mixture_means = tuple(zip(xs, ys))
mixture_covs = tuple([(std, std)] * num_mixtures)

configs = {
    # Model sizes (z_dim / h_dim presumably latent and hidden widths) and
    # the number of generators.
    'z_dim': 256,
    'h_dim': 128,
    'd_num': d_num,
    # Target distribution.
    'mix_coeffs': mixture_weights,
    'mean': mixture_means,
    'cov': mixture_covs,
    # Sampling, topology and training schedule.
    'num_samples': 512,
    'batch_size': batch_size,
    'topo_style': topo_style,
    'total_step': total_step,
    'log_step': 1000,
    'sample_step': 5000,
    'model_save_step': 5000,
    # Optimiser hyper-parameters.
    'g_lr': 0.001,
    'd_lr': 0.001,
    'lr_decay': 0.95,
    'beta1': 0.9,
    'beta2': 0.999,
    # Run bookkeeping and output locations.
    'version': version,
    'parallel': False,
    'printnet': False,
    'pretrained_model': None,
    'model_save_path': model_save_path,
    'sample_path': sample_path,
    'loss_save_path': loss_save_path
}

trainer_config = Configuration(configs)
trainer = SOMGAN_toy2D_model.Trainer(trainer_config)

In [None]:
# Launch training with the Trainer constructed in the previous cell. The
# output directories it was configured with live under ./output/<version>/
# (models, loss, samples); what gets written there is decided inside
# SOMGAN_toy2D_model.Trainer, which is not visible here.
trainer.train()