
Commit

Another batch of fixes.
fruffy committed Dec 14, 2019
1 parent 1974c81 commit bdfaed0
Showing 4 changed files with 22 additions and 18 deletions.
distillers/ab_distiller.py (16 changes: 6 additions & 10 deletions)
@@ -239,16 +239,12 @@ def __init__(self, s_net, d_net, config):
         # the student net is the base net
         self.s_net = self.net
         self.d_net = d_net
-        d_net.train()
-        d_net.s_net.train()
-        d_net.t_net.train()
-        # unfreeze the layers of the teacher
-        for param in d_net.t_net.parameters():
-            param.requires_grad = True
-        self.optimizer = optim.SGD([{'params': s_net.parameters()},
-                                    {'params': d_net.connectors.parameters()}],
-                                   lr=0.1, nesterov=True, momentum=0.9,
-                                   weight_decay=5e-4)
+        optim_params = [{"params": self.s_net.parameters()},
+                        {"params": self.d_net.connectors.parameters()}]
+
+        # Retrieve preconfigured optimizers and schedulers for all runs
+        self.optimizer = self.optim_cls(optim_params, **self.optim_args)
+        self.scheduler = self.sched_cls(self.optimizer, **self.sched_args)

     def calculate_loss(self, data, target):

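For context: the hard-coded SGD instance is replaced by the optimizer/scheduler factory the base trainer now exposes, applied to both the student parameters and the distiller's connectors. A minimal sketch of that pattern, using stand-in modules and illustrative hyperparameters rather than the repository's actual config:

import torch.nn as nn
import torch.optim as optim

# Stand-ins for the student network and the distiller's connector layers.
student = nn.Linear(32, 10)
connectors = nn.ModuleList([nn.Linear(32, 32)])

# A get_optimizer-style factory returns the optimizer class plus its kwargs,
# so each trainer can instantiate it over its own parameter groups.
optim_cls, optim_args = optim.SGD, {"lr": 0.1, "momentum": 0.9,
                                    "weight_decay": 5e-4, "nesterov": True}

optim_params = [{"params": student.parameters()},
                {"params": connectors.parameters()}]
optimizer = optim_cls(optim_params, **optim_args)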
models/cifar10/resnet.py (10 changes: 8 additions & 2 deletions)
@@ -94,7 +94,10 @@ def _make_layer(self, block, planes, num_blocks, stride):
         return nn.Sequential(*layers)

     def forward(self, x, is_feat=False, use_relu=True):
-        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.conv1(x)
+        out = self.bn1(out)
+        if use_relu:
+            out = F.relu(out)
         b1 = self.layer1(out)
         if use_relu:
             b1 = F.relu(b1)
@@ -158,7 +161,10 @@ def _make_layer(self, block, planes, num_blocks, stride):
         return nn.Sequential(*layers)

     def forward(self, x, is_feat=False, use_relu=True):
-        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.conv1(x)
+        out = self.bn1(out)
+        if use_relu:
+            out = F.relu(out)
         b1 = self.layer1(out)
         if use_relu:
             b1 = F.relu(b1)
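The two forward() changes above split the fused F.relu(self.bn1(self.conv1(x))) into separate conv, batch-norm, and optional ReLU steps, so callers can request pre-activation features (presumably what connector-based distillers such as ab_distiller compare). A toy sketch of the gated pattern, not the repository's actual ResNet:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyStem(nn.Module):
    """Toy conv -> bn stem with an optional ReLU, mirroring the change above."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)

    def forward(self, x, use_relu=True):
        out = self.conv1(x)
        out = self.bn1(out)
        if use_relu:
            out = F.relu(out)
        return out

x = torch.randn(2, 3, 32, 32)
pre_act = ToyStem()(x, use_relu=False)   # pre-activation features, may be negative
post_act = ToyStem()(x, use_relu=True)   # standard post-ReLU features, >= 0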
optimizer.py (2 changes: 1 addition & 1 deletion)
@@ -11,7 +11,7 @@ def get_optimizer(optim_str, params):
     if optim_str.lower() == "sgd":
         optim_args["momentum"] = params["momentum"]
         optim_args["weight_decay"] = params["weight_decay"]
-        optim_args["nesterov"] = False
+        optim_args["nesterov"] = True
         return optim.SGD, optim_args
     elif optim_str.lower() == "novograd":
         optim_args["weight_decay"] = params["weight_decay"]
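One constraint worth noting: torch.optim.SGD accepts nesterov=True only together with a non-zero momentum and zero dampening, which the SGD branch above satisfies by also setting optim_args["momentum"]. A quick illustration with made-up hyperparameters:

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)

# Valid: Nesterov momentum with momentum > 0 (dampening defaults to 0).
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9,
                      weight_decay=5e-4, nesterov=True)

# Invalid: nesterov=True without momentum raises a ValueError.
# optim.SGD(model.parameters(), lr=0.1, nesterov=True)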
trainer.py (12 changes: 7 additions & 5 deletions)
@@ -28,12 +28,14 @@ def __init__(self, net, config):
         self.net = net
         self.device = config["device"]
         self.name = config["test_name"]

-        # Retrieve preconfigured optimizers and schedulers for all runs
-        optim_cls, optim_args = get_optimizer(config["optim"], config)
-        sched_cls, sched_args = get_scheduler(config["sched"], config)
-        self.optimizer = optim_cls(net.parameters(), **optim_args)
-        self.scheduler = sched_cls(self.optimizer, **sched_args)
+        optim = config["optim"]
+        sched = config["sched"]
+        self.optim_cls, self.optim_args = get_optimizer(optim, config)
+        self.sched_cls, self.sched_args = get_scheduler(sched, config)
+        self.optimizer = self.optim_cls(net.parameters(), **self.optim_args)
+        self.scheduler = self.sched_cls(self.optimizer, **self.sched_args)
+
         self.loss_fun = nn.CrossEntropyLoss()
         self.train_loader = config["train_loader"]
         self.test_loader = config["test_loader"]
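Keeping the optimizer/scheduler class and kwargs on self (rather than only the built instances) lets subclasses rebuild the optimizer over their own parameter groups, as the ab_distiller.py change above does. A minimal sketch of that idea with illustrative class names:

import torch.nn as nn
import torch.optim as optim

class BaseTrainer:
    def __init__(self, net):
        self.net = net
        # Keep the factory outputs around, not just the built optimizer.
        self.optim_cls, self.optim_args = optim.SGD, {"lr": 0.1, "momentum": 0.9}
        self.optimizer = self.optim_cls(net.parameters(), **self.optim_args)

class DistillationTrainer(BaseTrainer):
    def __init__(self, net, extra_modules):
        super().__init__(net)
        # Rebuild the optimizer over an extended parameter list without
        # duplicating the optimizer configuration.
        groups = [{"params": net.parameters()},
                  {"params": extra_modules.parameters()}]
        self.optimizer = self.optim_cls(groups, **self.optim_args)

trainer = DistillationTrainer(nn.Linear(8, 2), nn.ModuleList([nn.Linear(8, 8)]))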
