Commit
last attempt to fix oh distillation
Fabian Ruffy Varga committed Dec 17, 2019
1 parent 73543b5 commit 8f4e4cf
Showing 3 changed files with 142 additions and 45 deletions.
92 changes: 71 additions & 21 deletions distillers/oh_distiller.py
@@ -85,45 +85,94 @@ def __init__(self, s_net, t_net):
i + 1), margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())

self.s_net = s_net
self.t_net = t_net

def forward(self, x):

def forward(self, x, t_feats=None):
s_feats, s_out = self.s_net.module.extract_feature(x, preReLU=True)
t_feats, t_out = self.t_net.module.extract_feature(x, preReLU=True)
s_feats_num = len(s_feats)

s_feats, s_pool, s_out = self.s_net(x, is_feat=True, use_relu=False)
loss_distill = 0
for i in range(s_feats_num):
s_feats[i] = self.connectors[i](s_feats[i])
loss_distill += distillation_loss(s_feats[i], t_feats[i].detach(), getattr(self, 'margin%d' % (i + 1))) \
/ 2 ** (s_feats_num - i - 1)

if t_feats:
s_feats_num = len(s_feats)
loss_distill = 0
for i in range(s_feats_num):
s_feats[i] = self.connectors[i](s_feats[i])
loss_distill += distillation_loss(s_feats[i], t_feats[i].detach(), getattr(self, 'margin%d' % (i + 1))) \
/ 2 ** (s_feats_num - i - 1)
return s_out, loss_distill
return s_out
return s_out, loss_distill


class OHTrainer(BaseTrainer):
def __init__(self, d_net, t_net, config):
def __init__(self, d_net, config):
# the student net is the base net
super(OHTrainer, self).__init__(d_net, config)
# decouple the teacher from the student
self.t_net = t_net
optim_params = [{"params": self.net.parameters()}]
super(OHTrainer, self).__init__(d_net.s_net, config)
# We train on the distillation net
self.d_net = d_net
optim_params = [{"params": self.d_net.s_net.parameters()},
{"params": self.d_net.connectors.parameters()}]

# Retrieve preconfigured optimizers and schedulers for all runs
self.optimizer = self.optim_cls(optim_params, **self.optim_args)
self.scheduler = self.sched_cls(self.optimizer, **self.sched_args)

def calculate_loss(self, data, target):
t_feats, t_pool, t_out = self.t_net(data, is_feat=True, use_relu=False)

s_out, loss_distill = self.net(data, t_feats)
loss_CE = self.loss_fun(s_out, target)
output, loss_distill = self.d_net(data)
loss_CE = self.loss_fun(output, target)

loss = loss_CE + loss_distill.sum() / self.batch_size / 1000

loss.backward()
self.optimizer.step()
return t_out, loss
return output, loss

def train_single_epoch(self, t_bar):
self.d_net.train()
self.d_net.s_net.train()
self.d_net.t_net.train()
total_correct = 0.0
total_loss = 0.0
len_train_set = len(self.train_loader.dataset)
for batch_idx, (x, y) in enumerate(self.train_loader):
x = x.to(self.device)
y = y.to(self.device)
self.optimizer.zero_grad()

# this function is implemented by the subclass
y_hat, loss = self.calculate_loss(x, y)

# Metric tracking boilerplate
pred = y_hat.data.max(1, keepdim=True)[1]
total_correct += pred.eq(y.data.view_as(pred)).sum()
total_loss += loss
curr_acc = 100.0 * (total_correct / float(len_train_set))
curr_loss = (total_loss / float(batch_idx + 1))
t_bar.update(self.batch_size)
t_bar.set_postfix_str(f"Acc {curr_acc:.3f}% Loss {curr_loss:.3f}")
total_acc = float(total_correct / len_train_set)
return total_acc

def validate(self, epoch=0):
self.d_net.s_net.eval()
acc = 0.0
with torch.no_grad():
correct = 0
acc = 0
for images, labels in self.test_loader:
images = images.to(self.device)
labels = labels.to(self.device)
output = self.d_net.s_net(images, use_relu=False)
# Standard learning loss (classification loss)
loss = self.loss_fun(output, labels)
# get the index of the max log-probability
pred = output.data.max(1, keepdim=True)[1]
correct += pred.eq(labels.data.view_as(pred)).cpu().sum()

acc = float(correct) / len(self.test_loader.dataset)
print(f"\nEpoch {epoch}: Validation set: Average loss: {loss:.4f},"
f" Accuracy: {correct}/{len(self.test_loader.dataset)} "
f"({acc * 100.0:.3f}%)")
return acc


def run_oh_distillation(s_net, t_net, **params):
@@ -136,8 +136,9 @@ def run_oh_distillation(s_net, t_net, **params):
# Student training
# Define loss and the optimizer
print("---------- Training OKD Student -------")
params = params.copy()
d_net = Distiller(s_net, t_net).to(params["device"])
s_trainer = OHTrainer(d_net, t_net, config=params)
s_trainer = OHTrainer(d_net, config=params)
best_s_acc = s_trainer.train()

return best_s_acc
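Note on the new loss path: Distiller.forward now extracts the pre-ReLU features of both networks itself, maps each student feature through its connector, and accumulates a per-level distillation term weighted by 1 / 2 ** (s_feats_num - i - 1), so with four feature levels the weights are 1/8, 1/4, 1/2 and 1 from shallow to deep. OHTrainer.calculate_loss then adds the summed term to the cross-entropy loss after dividing by batch_size * 1000. The distillation_loss helper itself is not part of this diff; a minimal sketch of the margin-based partial L2 loss commonly used for this kind of overhaul-style feature distillation might look like the following (an assumption about the helper, not code from this commit):

import torch

def distillation_loss(source, target, margin):
    # Sketch (assumed): clip the teacher feature from below by the per-channel
    # margin, then penalize the squared difference only where the student lies
    # above the clipped teacher or the teacher response is positive.
    target = torch.max(target, margin)
    loss = (source - target) ** 2
    loss = loss * ((source > target) | (target > 0)).float()
    return loss.sum()

The per-level margin tensors stored in __init__ are shaped to broadcast over (N, C, H, W), which matches the unsqueeze calls visible at the top of the first hunk.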
4 changes: 2 additions & 2 deletions evaluate_kd.py
@@ -180,7 +180,7 @@ def test_pkd(s_net, t_net, params):


def test_oh(s_net, t_net, params):
t_net = freeze_teacher(t_net)
# do not freeze the teacher in OH distillation
best_acc = run_oh_distillation(s_net, t_net, **params)
return best_acc

@@ -333,4 +333,4 @@ def start_evaluation(args):

if __name__ == "__main__":
ARGS = parse_arguments()
start_evaluation(ARGS)
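Note on the evaluate_kd.py change: freeze_teacher is no longer called before OH distillation. Because Distiller.forward detaches the teacher features and OHTrainer only hands the student and connector parameters to the optimizer, the teacher's weights still receive no gradient updates; leaving it unfrozen mainly means its batch-norm running statistics keep drifting while d_net.t_net.train() is set in train_single_epoch. A rough sketch of the distinction (the freeze helper below is an assumption about what the removed call did, not code from this repository):

def freeze_teacher_sketch(t_net):
    # Hypothetical stand-in for the removed freeze_teacher call:
    # stop gradients and keep batch-norm statistics fixed via eval mode.
    for param in t_net.parameters():
        param.requires_grad = False
    t_net.eval()
    return t_net

# In the updated OH path the teacher stays in train mode instead, so only its
# BN running_mean / running_var change; its weights stay fixed because the
# distillation targets are detached and it is absent from the optimizer.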
91 changes: 69 additions & 22 deletions models/cifar10/resnet.py
@@ -32,10 +32,10 @@ def __init__(self, in_planes, planes, stride=1):
)

def forward(self, x):
out = F.relu(x)
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
# out = F.relu(out)
return out


@@ -62,11 +62,11 @@ def __init__(self, in_planes, planes, stride=1):
)

def forward(self, x):
out = F.relu(x)
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
# out = F.relu(out)
return out


@@ -83,7 +83,7 @@ def __init__(self, block, num_blocks, num_classes=10):
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear = nn.Linear(512 * block.expansion, num_classes)
self.n_channels = [64, 128, 256, 512, 512 * block.expansion]
self.n_channels = [64, 128, 256, 512]

def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
@@ -98,23 +98,24 @@ def forward(self, x, is_feat=False, use_relu=True):
out = self.bn1(out)
if use_relu:
out = F.relu(out)
b1 = self.layer1(out)
feat1 = self.layer1(out)
if use_relu:
b1 = F.relu(b1)
b2 = self.layer2(b1)
feat1 = F.relu(feat1)
feat2 = self.layer2(feat1)
if use_relu:
b2 = F.relu(b2)
b3 = self.layer3(b2)
feat2 = F.relu(feat2)
feat3 = self.layer3(feat2)
if use_relu:
b3 = F.relu(b3)
b4 = self.layer4(b3)
b4 = F.relu(b4)
pool = F.avg_pool2d(b4, 4)
feat3 = F.relu(feat3)

feat4 = self.layer4(feat3)
feat4 = F.relu(feat4)
pool = F.avg_pool2d(feat4, 4)
pool = pool.view(pool.size(0), -1)
out = self.linear(pool)

if is_feat:
return[b1, b2, b3, b4], pool, out
return [feat1, feat2, feat3, feat4], pool, out

return out

@@ -137,6 +138,29 @@ def get_bn_before_relu(self):
def get_channel_num(self):
return self.n_channels

def extract_feature(self, x, preReLU=False):

x = self.conv1(x)
x = self.bn1(x)

feat1 = self.layer1(x)
feat2 = self.layer2(feat1)
feat3 = self.layer3(feat2)
feat4 = self.layer4(feat3)

x = F.relu(feat4)
x = F.avg_pool2d(x, 4)
x = x.view(x.size(0), -1)
out = self.linear(x)

if not preReLU:
feat1 = F.relu(feat1)
feat2 = F.relu(feat2)
feat3 = F.relu(feat3)
feat4 = F.relu(feat4)

return [feat1, feat2, feat3, feat4], out


class ResNetSmall(nn.Module):
def __init__(self, block, num_blocks, num_classes=10):
@@ -150,7 +174,7 @@ def __init__(self, block, num_blocks, num_classes=10):
self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
self.linear = nn.Linear(256 * block.expansion, num_classes)
self.n_channels = [16, 32, 64, 256 * block.expansion]
self.n_channels = [16, 32, 64]

def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
@@ -165,20 +189,22 @@ def forward(self, x, is_feat=False, use_relu=True):
out = self.bn1(out)
if use_relu:
out = F.relu(out)
b1 = self.layer1(out)
feat1 = self.layer1(out)
if use_relu:
b1 = F.relu(b1)
b2 = self.layer2(b1)
feat1 = F.relu(feat1)
feat2 = self.layer2(feat1)
if use_relu:
b2 = F.relu(b2)
b3 = self.layer3(b2)
b3 = F.relu(b3)
pool = F.avg_pool2d(b3, 4)
feat2 = F.relu(feat2)
feat3 = self.layer3(feat2)

# the last ReLU is always applied
feat3 = F.relu(feat3)
pool = F.avg_pool2d(feat3, 4)
pool = pool.view(pool.size(0), -1)
out = self.linear(pool)

if is_feat:
return[b1, b2, b3], pool, out
return [feat1, feat2, feat3], pool, out

return out

@@ -199,6 +225,27 @@ def get_bn_before_relu(self):
def get_channel_num(self):
return self.n_channels

def extract_feature(self, x, preReLU=False):

x = self.conv1(x)
x = self.bn1(x)

feat1 = self.layer1(x)
feat2 = self.layer2(feat1)
feat3 = self.layer3(feat2)

x = F.relu(feat3)
x = F.avg_pool2d(x, 4)
x = x.view(x.size(0), -1)
out = self.linear(x)

if not preReLU:
feat1 = F.relu(feat1)
feat2 = F.relu(feat2)
feat3 = F.relu(feat3)

return [feat1, feat2, feat3], out


def resnet8(**kwargs):
return ResNetSmall(BasicBlock, [1, 1, 1], **kwargs)
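The extract_feature methods added to both ResNet variants return the per-stage feature maps, pre-ReLU when preReLU=True, which is exactly what the new Distiller consumes, while forward(..., is_feat=True) keeps returning the pooled vector as well. A small, hypothetical smoke test for the new interface, assuming CIFAR-10 sized 32x32 inputs and the resnet8 factory defined above:

import torch

net = resnet8(num_classes=10)      # ResNetSmall with three feature stages

x = torch.randn(2, 3, 32, 32)      # a CIFAR-10 sized batch of two images

# Pre-ReLU per-stage features, as consumed by the OH distiller
feats, out = net.extract_feature(x, preReLU=True)
print([f.shape for f in feats])    # (2, 16, 32, 32), (2, 32, 16, 16), (2, 64, 8, 8)
print(out.shape)                   # torch.Size([2, 10])

# The forward pass keeps its old interface and also returns the pooled vector
feats, pool, out = net(x, is_feat=True, use_relu=False)
print(pool.shape)                  # torch.Size([2, 256])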
