
Commit

formatting
kwotsin committed Aug 11, 2020
1 parent fd3c512 commit 88b1b71
Showing 5 changed files with 37 additions and 22 deletions.
4 changes: 3 additions & 1 deletion tests/modules/test_layers.py
@@ -56,7 +56,9 @@ def test_SNEmbedding(self):
num_classes = 10
X = torch.ones(self.N, dtype=torch.int64)
for default in [True, False]:
-layer = layers.SNEmbedding(num_classes, self.n_out, default=default)
+layer = layers.SNEmbedding(num_classes,
+                           self.n_out,
+                           default=default)

assert layer(X).shape == (self.N, self.n_out)

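For readers unfamiliar with the layer under test: SNEmbedding is an embedding layer whose weight carries spectral normalization, and the test above only checks output shapes. The snippet below is a minimal stand-in built from PyTorch's stock spectral_norm wrapper, assuming that is roughly what SNEmbedding does; it is not the library's implementation, and the default flag toggled in the test is not modelled here.

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Hypothetical stand-in for layers.SNEmbedding: a plain nn.Embedding whose
# weight is wrapped with spectral normalization.
num_classes, n_out, N = 10, 32, 4
sn_embedding = spectral_norm(nn.Embedding(num_classes, n_out))

X = torch.ones(N, dtype=torch.int64)  # same dummy input style as the test
assert sn_embedding(X).shape == (N, n_out)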
13 changes: 9 additions & 4 deletions tests/training/test_scheduler.py
@@ -36,12 +36,13 @@ def test_linear_decay(self):
assert abs(2e-4 - self.get_lr(optG)) < 1e-5

else:
-curr_lr = ((1 - (max(0, step - lr_scheduler.start_step) / (self.num_steps-lr_scheduler.start_step))) * self.lr_D)
+curr_lr = ((1 - (max(0, step - lr_scheduler.start_step) /
+                 (self.num_steps - lr_scheduler.start_step))) *
+           self.lr_D)

assert abs(curr_lr - self.get_lr(optD)) < 1e-5
assert abs(curr_lr - self.get_lr(optG)) < 1e-5


def test_no_decay(self):
optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))
@@ -60,8 +61,12 @@ def test_no_decay(self):

def test_arguments(self):
with pytest.raises(NotImplementedError):
-optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
-optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))
+optD = optim.Adam(self.netD.parameters(),
+                  self.lr_D,
+                  betas=(0.0, 0.9))
+optG = optim.Adam(self.netG.parameters(),
+                  self.lr_G,
+                  betas=(0.0, 0.9))
scheduler.LRScheduler(lr_decay='does_not_exist',
optD=optD,
optG=optG,
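As background for the curr_lr expression reformatted above: under linear decay, the expected learning rate holds at its initial value until start_step and then falls linearly to zero at num_steps. The helper below is a standalone restatement of that formula for illustration only; the names mirror the test, not a public API of the library.

def expected_linear_decay_lr(initial_lr, step, start_step, num_steps):
    # Fraction of the decay window already consumed, clamped at 0 before
    # start_step so the learning rate stays at initial_lr until then.
    progress = max(0, step - start_step) / (num_steps - start_step)
    return (1 - progress) * initial_lr

# Example: an initial lr of 2e-4 decayed over 100 steps from step 0
# is halved at step 50.
print(expected_linear_decay_lr(2e-4, 50, 0, 100))  # 1e-4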
26 changes: 13 additions & 13 deletions tests/training/test_trainer.py
@@ -143,19 +143,19 @@ def test_attributes(self):

with pytest.raises(ValueError):
bad_trainer = Trainer(netD=self.netD,
-netG=self.netG,
-optD=self.optD,
-optG=self.optG,
-netG_ckpt_file=netG_ckpt_file,
-netD_ckpt_file=netD_ckpt_file,
-log_dir=os.path.join(self.log_dir, 'extra'),
-dataloader=self.dataloader,
-num_steps=-1000,
-device=device,
-save_steps=float('inf'),
-log_steps=float('inf'),
-vis_steps=float('inf'),
-lr_decay='linear')
+netG=self.netG,
+optD=self.optD,
+optG=self.optG,
+netG_ckpt_file=netG_ckpt_file,
+netD_ckpt_file=netD_ckpt_file,
+log_dir=os.path.join(self.log_dir, 'extra'),
+dataloader=self.dataloader,
+num_steps=-1000,
+device=device,
+save_steps=float('inf'),
+log_steps=float('inf'),
+vis_steps=float('inf'),
+lr_decay='linear')

def test_get_latest_checkpoint(self):
ckpt_files = [
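The bad_trainer construction above passes num_steps=-1000 and is expected to raise ValueError. The snippet below sketches the kind of argument check this implies; it is hypothetical, and the actual validation inside Trainer may look different.

def check_num_steps(num_steps):
    # Reject non-positive training lengths, as the test above expects.
    if num_steps <= 0:
        raise ValueError(
            "num_steps must be a positive integer, got {}.".format(num_steps))
    return num_steps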
2 changes: 1 addition & 1 deletion torch_mimicry/nets/wgan_gp/wgan_gp_resblocks.py
@@ -88,7 +88,7 @@ def __init__(self,
self.norm2 = None

# TODO: Verify again. Interestingly, LN has no effect on FID. Not using LN
# has almost no difference in FID score.
# has almost no difference in FID score.
# def residual(self, x):
# r"""
# Helper function for feedforwarding through main layers.
14 changes: 11 additions & 3 deletions torch_mimicry/training/scheduler.py
@@ -19,7 +19,13 @@ class LRScheduler:
lr_D (float): The initial learning rate of optD.
lr_G (float): The initial learning rate of optG.
"""
-def __init__(self, lr_decay, optD, optG, num_steps, start_step=0, **kwargs):
+def __init__(self,
+             lr_decay,
+             optD,
+             optG,
+             num_steps,
+             start_step=0,
+             **kwargs):
if lr_decay not in [None, 'None', 'linear']:
raise NotImplementedError(
    "lr_decay {} is not currently supported.".format(lr_decay))
@@ -90,12 +96,14 @@ def step(self, log_data, global_step):
lr_D = self.linear_decay(optimizer=self.optD,
global_step=global_step,
lr_value_range=(self.lr_D, 0.0),
-lr_step_range=(self.start_step, self.num_steps))
+lr_step_range=(self.start_step,
+               self.num_steps))

lr_G = self.linear_decay(optimizer=self.optG,
global_step=global_step,
lr_value_range=(self.lr_G, 0.0),
-lr_step_range=(self.start_step, self.num_steps))
+lr_step_range=(self.start_step,
+               self.num_steps))

elif self.lr_decay in [None, "None"]:
lr_D = self.lr_D
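For context on the linear_decay calls reformatted above: each call passes a value range (initial_lr, 0.0) and a step range (start_step, num_steps), which suggests a clamped linear interpolation whose result is written into the optimizer's parameter groups. The function below is a rough standalone sketch of that behaviour, inferred from the call signature alone; it is an assumption, not the library's actual implementation.

def linear_decay_sketch(optimizer, global_step, lr_value_range, lr_step_range):
    start_lr, end_lr = lr_value_range
    start_step, end_step = lr_step_range

    # Clamp progress to [0, 1] so the lr holds at start_lr before start_step
    # and at end_lr after end_step.
    progress = (global_step - start_step) / float(end_step - start_step)
    progress = min(max(progress, 0.0), 1.0)
    lr = start_lr + progress * (end_lr - start_lr)

    # Write the interpolated learning rate into every parameter group.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr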
