
Commit

fixed docstrings and kwargs
stes committed Nov 26, 2018
1 parent abecab8 commit cea9ffd
Showing 2 changed files with 12 additions and 6 deletions.
7 changes: 4 additions & 3 deletions salad/solver/da/coral.py
@@ -138,9 +138,10 @@ def __init__(self, model, dataset, *args, **kwargs):

 class CentroidDistanceLossSolver(CorrelationDistanceSolver):
     """
-    Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation
-    Paper: https://openreview.net/pdf?id=rJWechg0Z
-    and: https://arxiv.org/pdf/1705.08180.pdf
+    Notes
+    -----
+    Needs work.
     """
 
     def __init__(self, model, dataset, *args, **kwargs):
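The docstring deleted above pointed at the Minimal Entropy Correlation Alignment paper (https://openreview.net/pdf?id=rJWechg0Z, also at https://arxiv.org/pdf/1705.08180.pdf). Since the diff leaves the solver itself unexplained, here is a minimal sketch of what a centroid-distance loss between domains typically computes: the squared distance between per-class feature centroids of a source and a target batch (in the unsupervised setting, labels_tgt would be pseudo-labels). The function name, signature, and averaging scheme are illustrative assumptions, not the actual internals of CentroidDistanceLossSolver, which this commit does not show.

import torch

def centroid_distance(feats_src, labels_src, feats_tgt, labels_tgt, n_classes):
    # Hypothetical helper, not salad API: mean squared distance between the
    # class-conditional feature centroids of two domains.
    loss = feats_src.new_zeros(())
    matched = 0
    for c in range(n_classes):
        src_mask, tgt_mask = labels_src == c, labels_tgt == c
        if src_mask.any() and tgt_mask.any():          # class present in both batches
            mu_src = feats_src[src_mask].mean(dim=0)   # source centroid for class c
            mu_tgt = feats_tgt[tgt_mask].mean(dim=0)   # target centroid for class c
            loss = loss + (mu_src - mu_tgt).pow(2).sum()
            matched += 1
    return loss / max(matched, 1)                      # average over matched classes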
11 changes: 8 additions & 3 deletions salad/solver/gan.py
@@ -92,6 +92,11 @@ def derive_losses(self, batch):

 class ConditionalGANSolver(Solver):
 
+    """ Train a class conditional GAN model
+    """
+
     names = ['D_GAN', 'D_CL', 'D_CON']
     n_classes = [1, 11, 3]
@@ -125,12 +130,12 @@ def _init_models(self, **kwargs):
         D.weight_init(mean=0.0, std=0.02)
         self.register_model(D, name)
 
-    def _init_optims(self, **kwargs):
-        opt = torch.optim.Adam(self.model.parameters(),lr = learningrate, betas = (.5,.999))
+    def _init_optims(self, lr = 2e-4, beta1 = .5, beta2 = .999, **kwargs):
+        opt = torch.optim.Adam(self.model.parameters(),lr = lr, betas = (beta1, beta2))
         self.register_optimizer(opt, CGANLoss(), name="Generator")
 
         for D in self.discriminators:
-            optim = torch.optim.Adam(D.parameters(), lr = learningrate, betas = (.5,.999))
+            optim = torch.optim.Adam(D.parameters(), lr = lr, betas = (beta1, beta2))
             self.register_optim(optim, CGANLoss())
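The kwargs change above is a real bug fix, not a style edit: learningrate was never defined in the old _init_optims, so calling it would raise a NameError, and the Adam betas were hard-coded. The new signature defaults to the common DCGAN settings (lr=2e-4, betas=(0.5, 0.999)) and makes them overridable. A minimal usage sketch, assuming the base Solver forwards leftover construction **kwargs to _init_optims (that plumbing is not shown in this diff):

# Hypothetical call; argument forwarding through Solver is an assumption here.
solver = ConditionalGANSolver(model, dataset,
                              lr=1e-4,      # overrides the new 2e-4 default
                              beta1=0.5,    # Adam first-moment decay
                              beta2=0.999)  # Adam second-moment decay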
