Skip to content

Commit

Permalink
update interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
nashory committed Dec 26, 2017
1 parent f5bfeab commit 17c1767
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 21 deletions.
2 changes: 2 additions & 0 deletions clean.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Remove training artifacts: the repo/ output directory (models, images,
# interpolations) and any compiled Python bytecode in the project root.
rm -rf repo
rm -rf *.pyc
86 changes: 86 additions & 0 deletions generate_interpolated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Generate interpolated images: walk the latent space between two random
# noise vectors z1 and z2 with a trained progressive-growing generator and
# save one image per interpolation step.

import os
import sys

import torch
from torch.autograd import Variable

from config import config
import network as net
import utils as utils


use_cuda = True
checkpoint_path = 'repo/model/gen_R8_T55.pth.tar'
n_intp = 20  # number of interpolation steps between z1 and z2


# Build the generator and grow it to the checkpoint's resolution before
# loading weights (state dict keys depend on the grown architecture).
test_model = net.Generator(config)
if use_cuda:
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    test_model = torch.nn.DataParallel(test_model).cuda(device=0)
else:
    torch.set_default_tensor_type('torch.FloatTensor')

for resl in range(3, config.max_resl + 1):
    test_model.module.grow_network(resl)
    test_model.module.flush_network()
print(test_model)


print('load checkpoint from ... {}'.format(checkpoint_path))
checkpoint = torch.load(checkpoint_path)
test_model.module.load_state_dict(checkpoint['state_dict'])


# Create a fresh output folder repo/interpolation/try_<i>.
# The for/else guards against `name` being unbound if all 1000 slots exist.
for i in range(1000):
    name = 'repo/interpolation/try_{}'.format(i)
    if not os.path.exists(name):
        os.makedirs(name)  # safer than shelling out to `mkdir -p`
        break
else:
    sys.exit('no free output folder under repo/interpolation')

# Interpolate between two noise vectors (z1, z2).
z_intp = torch.FloatTensor(1, config.nz)
z1 = torch.FloatTensor(1, config.nz).normal_(0.0, 1.0)
z2 = torch.FloatTensor(1, config.nz).normal_(0.0, 1.0)
if use_cuda:
    z_intp = z_intp.cuda()
    z1 = z1.cuda()
    z2 = z2.cuda()
    test_model = test_model.cuda()

z_intp = Variable(z_intp)


for i in range(1, n_intp + 1):
    # alpha sweeps from ~0 toward 1 with i, so the images move from z1 to z2.
    # NOTE: the original used a constant alpha with in-place mul_(), which
    # both ignored the step index and corrupted z1/z2 on every iteration.
    alpha = float(i) / float(n_intp + 1)
    z_intp.data = z1 * (1.0 - alpha) + z2 * alpha
    fake_im = test_model.module(z_intp)
    fname = os.path.join(name, '_intp{}.jpg'.format(i))
    utils.save_image_single(fake_im.data, fname, imsize=pow(2, config.max_resl))
    print('saved {}-th interpolated image ...'.format(i))

12 changes: 4 additions & 8 deletions network.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def first_block(self):

def intermediate_block(self, resl):
halving = False
layer_name = 'itermediate_{}x{}_{}x{}'.format(pow(2,resl-1), pow(2,resl-1), pow(2, resl), pow(2, resl))
layer_name = 'intermediate_{}x{}_{}x{}'.format(int(pow(2,resl-1)), int(pow(2,resl-1)), int(pow(2, resl)), int(pow(2, resl)))
ndim = self.ngf
if resl==3 or resl==4 or resl==5:
halving = False
Expand Down Expand Up @@ -132,7 +132,7 @@ def grow_network(self, resl):
new_model[-1].load_state_dict(module.state_dict()) # copy pretrained weights

if resl >= 3 and resl <= 9:
print 'growing network[{}x{} to {}x{}]. It may take few seconds...'.format(pow(2,resl-1), pow(2,resl-1), pow(2,resl), pow(2,resl))
print 'growing network[{}x{} to {}x{}]. It may take few seconds...'.format(int(pow(2,resl-1)), int(pow(2,resl-1)), int(pow(2,resl)), int(pow(2,resl)))
low_resl_to_rgb = deepcopy_module(self.model, 'to_rgb_block')
prev_block = nn.Sequential()
prev_block.add_module('low_resl_upsample', nn.Upsample(scale_factor=2, mode='nearest'))
Expand Down Expand Up @@ -168,12 +168,10 @@ def flush_network(self):
new_model.add_module('to_rgb_block', high_resl_to_rgb)
self.model = new_model
self.module_names = get_module_names(self.model)
print(self.model)


except:
self.model = self.model
print(self.model)

def freeze_layers(self):
# let's freeze pretrained blocks. (Found freezing layers not helpful, so did not use this func.)
Expand Down Expand Up @@ -217,7 +215,7 @@ def last_block(self):

def intermediate_block(self, resl):
halving = False
layer_name = 'itermediate_{}x{}_{}x{}'.format(pow(2,resl), pow(2,resl), pow(2, resl-1), pow(2, resl-1))
layer_name = 'intermediate_{}x{}_{}x{}'.format(int(pow(2,resl)), int(pow(2,resl)), int(pow(2, resl-1)), int(pow(2, resl-1)))
ndim = self.ndf
if resl==3 or resl==4 or resl==5:
halving = False
Expand Down Expand Up @@ -254,7 +252,7 @@ def get_init_dis(self):
def grow_network(self, resl):

if resl >= 3 and resl <= 9:
print 'growing network[{}x{} to {}x{}]. It may take few seconds...'.format(pow(2,resl-1), pow(2,resl-1), pow(2,resl), pow(2,resl))
print 'growing network[{}x{} to {}x{}]. It may take few seconds...'.format(int(pow(2,resl-1)), int(pow(2,resl-1)), int(pow(2,resl)), int(pow(2,resl)))
low_resl_from_rgb = deepcopy_module(self.model, 'from_rgb_block')
prev_block = nn.Sequential()
prev_block.add_module('low_resl_downsample', nn.AvgPool2d(kernel_size=2))
Expand Down Expand Up @@ -299,10 +297,8 @@ def flush_network(self):

self.model = new_model
self.module_names = get_module_names(self.model)
print new_model
except:
self.model = self.model
print self.model

def freeze_layers(self):
# let's freeze pretrained blocks. (Found freezing layers not helpful, so did not use this func.)
Expand Down
47 changes: 34 additions & 13 deletions trainer.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def resl_scheduler(self):
self.fadein['dis'].update_alpha(d_alpha)
self.complete['dis'] = self.fadein['dis'].alpha*100
self.phase = 'dtrns'
elif self.resl%1.0 >= (self.stab_tick + self.trns_tick*2)*delta:
elif self.resl%1.0 >= (self.stab_tick + self.trns_tick*2)*delta and self.phase!='final':
self.phase = 'dstab'

prev_kimgs = self.kimgs
Expand All @@ -127,6 +127,7 @@ def resl_scheduler(self):
self.complete['gen'] = self.fadein['gen'].alpha*100
self.flag_flush_gen = False
self.G.module.flush_network() # flush G
print(self.G.module.model)
#self.Gs.module.flush_network() # flush Gs
self.fadein['gen'] = None
self.complete['gen'] = 0.0
Expand All @@ -137,12 +138,14 @@ def resl_scheduler(self):
self.complete['dis'] = self.fadein['dis'].alpha*100
self.flag_flush_dis = False
self.D.module.flush_network() # flush and,
print(self.D.module.model)
self.fadein['dis'] = None
self.complete['dis'] = 0.0
self.phase = 'gtrns'

if floor(self.resl) < self.max_resl and self.phase != 'final':
self.phase = 'gtrns'

# grow network.
if floor(self.resl) != prev_resl:
if floor(self.resl) != prev_resl and floor(self.resl)<self.max_resl+1:
self.lr = self.lr * float(self.config.lr_decay)
self.G.module.grow_network(floor(self.resl))
#self.Gs.module.grow_network(floor(self.resl))
Expand All @@ -153,15 +156,16 @@ def resl_scheduler(self):
self.flag_flush_gen = True
self.flag_flush_dis = True

if floor(self.resl) >= self.max_resl:
self.resl = self.max_resl
if floor(self.resl) >= self.max_resl and self.resl%1.0 >= (self.stab_tick + self.trns_tick*2)*delta:
self.phase = 'final'
self.resl = self.max_resl + (self.stab_tick + self.trns_tick*2)*delta



def renew_everything(self):
# renew dataloader.
self.loader = DL.dataloader(config)
self.loader.renew(floor(self.resl))
self.loader.renew(min(floor(self.resl), self.max_resl))

# define tensors
self.z = torch.FloatTensor(self.loader.batchsize, self.nz)
Expand Down Expand Up @@ -199,7 +203,7 @@ def renew_everything(self):


def feed_interpolated_input(self, x):
if self.phase == 'gtrns' and floor(self.resl)>2:
if self.phase == 'gtrns' and floor(self.resl)>2 and floor(self.resl)<=self.max_resl:
alpha = self.complete['gen']/100.0
transform = transforms.Compose( [ transforms.ToPILImage(),
transforms.Scale(size=int(pow(2,floor(self.resl)-1)), interpolation=0), # 0: nearest
Expand Down Expand Up @@ -305,19 +309,36 @@ def train(self):
self.tb.add_image_grid('grid/x_intp', 4, utils.adjust_dyn_range(self.x.data.float(), [-1,1], [0,1]), self.globalIter)


def get_state(self, target):
    """Assemble a checkpoint dict for one network.

    target: 'gen' for the generator, 'dis' for the discriminator.
    Returns a dict with the current resolution, the network's state dict,
    and its optimizer's state dict; returns None for any other target
    (matching the original implicit fall-through).
    """
    if target == 'gen':
        network, optimizer = self.G, self.opt_g
    elif target == 'dis':
        network, optimizer = self.D, self.opt_d
    else:
        return None
    return {
        'resl': self.resl,
        'state_dict': network.module.state_dict(),
        'optimizer': optimizer.state_dict(),
    }


def snapshot(self, path):
if not os.path.exists(path):
os.system('mkdir -p {}'.format(path))
# save every 100 tick if the network is in stab phase.
ndis = 'dis_R{}_T{}.pth'.format(int(floor(self.resl)), self.globalTick)
ngen = 'gen_R{}_T{}.pth'.format(int(floor(self.resl)), self.globalTick)
ndis = 'dis_R{}_T{}.pth.tar'.format(int(floor(self.resl)), self.globalTick)
ngen = 'gen_R{}_T{}.pth.tar'.format(int(floor(self.resl)), self.globalTick)
if self.globalTick%50==0:
if self.phase == 'gstab' or self.phase =='dstab':
if self.phase == 'gstab' or self.phase =='dstab' or self.phase == 'final':
save_path = os.path.join(path, ndis)
if not os.path.exists(save_path):
utils.save_model(self.D, save_path)
torch.save(self.get_state('dis'), save_path)
save_path = os.path.join(path, ngen)
utils.save_model(self.G, save_path)
torch.save(self.get_state('gen'), save_path)
print('[snapshot] model saved @ {}'.format(path))


Expand Down

0 comments on commit 17c1767

Please sign in to comment.