From 8f4e4cf935fef63174d39966e9b096c8d81b456c Mon Sep 17 00:00:00 2001
From: Fabian Ruffy Varga
Date: Tue, 17 Dec 2019 14:46:34 -0500
Subject: [PATCH] Fix overhaul (OH) distillation: train the connectors and
 distill pre-ReLU features

---
 distillers/oh_distiller.py | 92 +++++++++++++++++++++++++++++---------
 evaluate_kd.py             |  4 +-
 models/cifar10/resnet.py   | 91 ++++++++++++++++++++++++++++---------
 3 files changed, 142 insertions(+), 45 deletions(-)

diff --git a/distillers/oh_distiller.py b/distillers/oh_distiller.py
index 6a02af3..93c064a 100644
--- a/distillers/oh_distiller.py
+++ b/distillers/oh_distiller.py
@@ -85,45 +85,94 @@ def __init__(self, s_net, t_net):
                 i + 1), margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())
 
         self.s_net = s_net
+        self.t_net = t_net
+
+    def forward(self, x):
+        s_feats, s_out = self.s_net.module.extract_feature(x, preReLU=True)
+        t_feats, t_out = self.t_net.module.extract_feature(x, preReLU=True)
+        s_feats_num = len(s_feats)
 
-    def forward(self, x, t_feats=None):
-        s_feats, s_pool, s_out = self.s_net(x, is_feat=True, use_relu=False)
-        if t_feats:
-            s_feats_num = len(s_feats)
-            loss_distill = 0
-            for i in range(s_feats_num):
-                s_feats[i] = self.connectors[i](s_feats[i])
-                loss_distill += distillation_loss(s_feats[i], t_feats[i].detach(), getattr(self, 'margin%d' % (i + 1))) \
-                    / 2 ** (s_feats_num - i - 1)
-            return s_out, loss_distill
-        return s_out
+        loss_distill = 0
+        for i in range(s_feats_num):
+            s_feats[i] = self.connectors[i](s_feats[i])
+            loss_distill += distillation_loss(s_feats[i], t_feats[i].detach(),
+                                              getattr(self, 'margin%d' % (i + 1))) \
+                / 2 ** (s_feats_num - i - 1)
+
+        return s_out, loss_distill
 
 
 class OHTrainer(BaseTrainer):
-    def __init__(self, d_net, t_net, config):
+    def __init__(self, d_net, config):
         # the student net is the base net
-        super(OHTrainer, self).__init__(d_net, config)
-        # decouple the teacher from the student
-        self.t_net = t_net
-        optim_params = [{"params": self.net.parameters()}]
+        super(OHTrainer, self).__init__(d_net.s_net, config)
+        # We train on the distillation net
+        self.d_net = d_net
+        optim_params = [{"params": self.d_net.s_net.parameters()},
+                        {"params": self.d_net.connectors.parameters()}]
 
         # Retrieve preconfigured optimizers and schedulers for all runs
         self.optimizer = self.optim_cls(optim_params, **self.optim_args)
         self.scheduler = self.sched_cls(self.optimizer, **self.sched_args)
 
     def calculate_loss(self, data, target):
-        t_feats, t_pool, t_out = self.t_net(data, is_feat=True, use_relu=False)
-        s_out, loss_distill = self.net(data, t_feats)
-        loss_CE = self.loss_fun(s_out, target)
+        output, loss_distill = self.d_net(data)
+        loss_CE = self.loss_fun(output, target)
 
         loss = loss_CE + loss_distill.sum() / self.batch_size / 1000
         loss.backward()
         self.optimizer.step()
-        return t_out, loss
+        return output, loss
+
+    def train_single_epoch(self, t_bar):
+        self.d_net.train()
+        self.d_net.s_net.train()
+        self.d_net.t_net.train()
+        total_correct = 0.0
+        total_loss = 0.0
+        len_train_set = len(self.train_loader.dataset)
+        for batch_idx, (x, y) in enumerate(self.train_loader):
+            x = x.to(self.device)
+            y = y.to(self.device)
+            self.optimizer.zero_grad()
+
+            # this function is implemented by the subclass
+            y_hat, loss = self.calculate_loss(x, y)
+
+            # Metric tracking boilerplate
+            pred = y_hat.data.max(1, keepdim=True)[1]
+            total_correct += pred.eq(y.data.view_as(pred)).sum()
+            total_loss += loss.item()
+            curr_acc = 100.0 * (total_correct / float(len_train_set))
+            curr_loss = total_loss / float(batch_idx + 1)
+            t_bar.update(self.batch_size)
+            t_bar.set_postfix_str(f"Acc {curr_acc:.3f}% Loss {curr_loss:.3f}")
+        total_acc = float(total_correct / len_train_set)
+        return total_acc
+
+    def validate(self, epoch=0):
+        self.d_net.s_net.eval()
+        total_loss = 0.0
+        correct = 0
+        with torch.no_grad():
+            for images, labels in self.test_loader:
+                images = images.to(self.device)
+                labels = labels.to(self.device)
+                output = self.d_net.s_net(images, use_relu=False)
+                # Standard learning loss (classification loss)
+                total_loss += self.loss_fun(output, labels).item()
+                # get the index of the max log-probability
+                pred = output.data.max(1, keepdim=True)[1]
+                correct += pred.eq(labels.data.view_as(pred)).cpu().sum()
+
+        avg_loss = total_loss / len(self.test_loader)
+        acc = float(correct) / len(self.test_loader.dataset)
+        print(f"\nEpoch {epoch}: Validation set: Average loss: {avg_loss:.4f},"
+              f" Accuracy: {correct}/{len(self.test_loader.dataset)} "
+              f"({acc * 100.0:.3f}%)")
+        return acc
 
 
 def run_oh_distillation(s_net, t_net, **params):
@@ -136,8 +185,9 @@ def run_oh_distillation(s_net, t_net, **params):
 
     # Student training
     # Define loss and the optimizer
-    print("---------- Training OKD Student -------")
+    print("---------- Training OH Student -------")
+    params = params.copy()
     d_net = Distiller(s_net, t_net).to(params["device"])
-    s_trainer = OHTrainer(d_net, t_net, config=params)
+    s_trainer = OHTrainer(d_net, config=params)
     best_s_acc = s_trainer.train()
     return best_s_acc
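For reference, the `distillation_loss` that `Distiller.forward` calls is defined earlier in `oh_distiller.py` and is not part of this hunk. Below is a minimal sketch of the margin-based partial-L2 loss used in overhaul-style feature distillation, assuming `margin` is the precomputed per-channel tensor registered in `__init__` (broadcast shape `(1, C, 1, 1)`); treat it as an approximation, not the repo's verbatim code:

```python
import torch
import torch.nn.functional as F

def distillation_loss(source, target, margin):
    # Clamp teacher activations from below by the margin, so the student
    # is not pushed toward values the teacher's ReLU would zero anyway.
    target = torch.max(target, margin)
    loss = F.mse_loss(source, target, reduction="none")
    # Only penalize positions where the student overshoots the clamped
    # teacher, or where the teacher activation is positive (survives ReLU).
    loss = loss * ((source > target) | (target > 0)).float()
    return loss.sum()
```

This is also why the features must be extracted pre-ReLU: the margin term only makes sense on activations that have not yet been clipped at zero.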
{curr_loss:.3f}") + total_acc = float(total_correct / len_train_set) + return total_acc + + def validate(self, epoch=0): + self.d_net.s_net.eval() + acc = 0.0 + with torch.no_grad(): + correct = 0 + acc = 0 + for images, labels in self.test_loader: + images = images.to(self.device) + labels = labels.to(self.device) + output = self.d_net.s_net(images, use_relu=False) + # Standard Learning Loss ( Classification Loss) + loss = self.loss_fun(output, labels) + # get the index of the max log-probability + pred = output.data.max(1, keepdim=True)[1] + correct += pred.eq(labels.data.view_as(pred)).cpu().sum() + + acc = float(correct) / len(self.test_loader.dataset) + print(f"\nEpoch {epoch}: Validation set: Average loss: {loss:.4f}," + f" Accuracy: {correct}/{len(self.test_loader.dataset)} " + f"({acc * 100.0:.3f}%)") + return acc def run_oh_distillation(s_net, t_net, **params): @@ -136,8 +185,9 @@ def run_oh_distillation(s_net, t_net, **params): # Student training # Define loss and the optimizer print("---------- Training OKD Student -------") + params = params.copy() d_net = Distiller(s_net, t_net).to(params["device"]) - s_trainer = OHTrainer(d_net, t_net, config=params) + s_trainer = OHTrainer(d_net, config=params) best_s_acc = s_trainer.train() return best_s_acc diff --git a/evaluate_kd.py b/evaluate_kd.py index 626ac1d..960f3a7 100644 --- a/evaluate_kd.py +++ b/evaluate_kd.py @@ -180,7 +180,7 @@ def test_pkd(s_net, t_net, params): def test_oh(s_net, t_net, params): - t_net = freeze_teacher(t_net) + # do not freeze the teacher in oh distillation best_acc = run_oh_distillation(s_net, t_net, **params) return best_acc @@ -333,4 +333,4 @@ def start_evaluation(args): if __name__ == "__main__": ARGS = parse_arguments() - start_evaluation(ARGS) + start_eva \ No newline at end of file diff --git a/models/cifar10/resnet.py b/models/cifar10/resnet.py index 78060f4..fc0e137 100644 --- a/models/cifar10/resnet.py +++ b/models/cifar10/resnet.py @@ -32,10 +32,10 @@ def __init__(self, in_planes, planes, stride=1): ) def forward(self, x): + out = F.relu(x) out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) out += self.shortcut(x) - # out = F.relu(out) return out @@ -62,11 +62,11 @@ def __init__(self, in_planes, planes, stride=1): ) def forward(self, x): + out = F.relu(x) out = F.relu(self.bn1(self.conv1(x))) out = F.relu(self.bn2(self.conv2(out))) out = self.bn3(self.conv3(out)) out += self.shortcut(x) - # out = F.relu(out) return out @@ -83,7 +83,7 @@ def __init__(self, block, num_blocks, num_classes=10): self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) self.linear = nn.Linear(512 * block.expansion, num_classes) - self.n_channels = [64, 128, 256, 512, 512 * block.expansion] + self.n_channels = [64, 128, 256, 512] def _make_layer(self, block, planes, num_blocks, stride): strides = [stride] + [1] * (num_blocks - 1) @@ -98,23 +98,24 @@ def forward(self, x, is_feat=False, use_relu=True): out = self.bn1(out) if use_relu: out = F.relu(out) - b1 = self.layer1(out) + feat1 = self.layer1(out) if use_relu: - b1 = F.relu(b1) - b2 = self.layer2(b1) + feat1 = F.relu(feat1) + feat2 = self.layer2(feat1) if use_relu: - b2 = F.relu(b2) - b3 = self.layer3(b2) + feat2 = F.relu(feat2) + feat3 = self.layer3(feat2) if use_relu: - b3 = F.relu(b3) - b4 = self.layer4(b3) - b4 = F.relu(b4) - pool = F.avg_pool2d(b4, 4) + feat3 = F.relu(feat3) + + feat4 = self.layer4(feat3) + feat4 = F.relu(feat4) + pool = F.avg_pool2d(feat4, 4) pool 
diff --git a/models/cifar10/resnet.py b/models/cifar10/resnet.py
index 78060f4..fc0e137 100644
--- a/models/cifar10/resnet.py
+++ b/models/cifar10/resnet.py
@@ -32,10 +32,10 @@ def __init__(self, in_planes, planes, stride=1):
         )
 
     def forward(self, x):
-        out = F.relu(self.bn1(self.conv1(x)))
+        out = F.relu(x)
+        out = F.relu(self.bn1(self.conv1(out)))
         out = self.bn2(self.conv2(out))
         out += self.shortcut(x)
-        # out = F.relu(out)
         return out
 
@@ -62,11 +62,11 @@ def __init__(self, in_planes, planes, stride=1):
         )
 
     def forward(self, x):
-        out = F.relu(self.bn1(self.conv1(x)))
+        out = F.relu(x)
+        out = F.relu(self.bn1(self.conv1(out)))
         out = F.relu(self.bn2(self.conv2(out)))
         out = self.bn3(self.conv3(out))
         out += self.shortcut(x)
-        # out = F.relu(out)
         return out
 
@@ -83,7 +83,7 @@ def __init__(self, block, num_blocks, num_classes=10):
         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
         self.linear = nn.Linear(512 * block.expansion, num_classes)
-        self.n_channels = [64, 128, 256, 512, 512 * block.expansion]
+        self.n_channels = [64, 128, 256, 512]
 
     def _make_layer(self, block, planes, num_blocks, stride):
         strides = [stride] + [1] * (num_blocks - 1)
@@ -98,23 +98,24 @@ def forward(self, x, is_feat=False, use_relu=True):
         out = self.bn1(out)
         if use_relu:
             out = F.relu(out)
-        b1 = self.layer1(out)
+        feat1 = self.layer1(out)
         if use_relu:
-            b1 = F.relu(b1)
-        b2 = self.layer2(b1)
+            feat1 = F.relu(feat1)
+        feat2 = self.layer2(feat1)
         if use_relu:
-            b2 = F.relu(b2)
-        b3 = self.layer3(b2)
+            feat2 = F.relu(feat2)
+        feat3 = self.layer3(feat2)
         if use_relu:
-            b3 = F.relu(b3)
-        b4 = self.layer4(b3)
-        b4 = F.relu(b4)
-        pool = F.avg_pool2d(b4, 4)
+            feat3 = F.relu(feat3)
+
+        feat4 = self.layer4(feat3)
+        feat4 = F.relu(feat4)
+        pool = F.avg_pool2d(feat4, 4)
         pool = pool.view(pool.size(0), -1)
         out = self.linear(pool)
         if is_feat:
-            return[b1, b2, b3, b4], pool, out
+            return [feat1, feat2, feat3, feat4], pool, out
         return out
 
@@ -137,6 +138,29 @@ def get_bn_before_relu(self):
 
     def get_channel_num(self):
         return self.n_channels
 
+    def extract_feature(self, x, preReLU=False):
+
+        x = self.conv1(x)
+        x = self.bn1(x)
+
+        feat1 = self.layer1(x)
+        feat2 = self.layer2(feat1)
+        feat3 = self.layer3(feat2)
+        feat4 = self.layer4(feat3)
+
+        x = F.relu(feat4)
+        x = F.avg_pool2d(x, 4)
+        x = x.view(x.size(0), -1)
+        out = self.linear(x)
+
+        if not preReLU:
+            feat1 = F.relu(feat1)
+            feat2 = F.relu(feat2)
+            feat3 = F.relu(feat3)
+            feat4 = F.relu(feat4)
+
+        return [feat1, feat2, feat3, feat4], out
+
 
 class ResNetSmall(nn.Module):
     def __init__(self, block, num_blocks, num_classes=10):
@@ -150,7 +174,7 @@ def __init__(self, block, num_blocks, num_classes=10):
         self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
         self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
         self.linear = nn.Linear(256 * block.expansion, num_classes)
-        self.n_channels = [16, 32, 64, 256 * block.expansion]
+        self.n_channels = [16, 32, 64]
 
     def _make_layer(self, block, planes, num_blocks, stride):
         strides = [stride] + [1] * (num_blocks - 1)
@@ -165,20 +189,22 @@ def forward(self, x, is_feat=False, use_relu=True):
         out = self.bn1(out)
         if use_relu:
             out = F.relu(out)
-        b1 = self.layer1(out)
+        feat1 = self.layer1(out)
         if use_relu:
-            b1 = F.relu(b1)
-        b2 = self.layer2(b1)
+            feat1 = F.relu(feat1)
+        feat2 = self.layer2(feat1)
         if use_relu:
-            b2 = F.relu(b2)
-        b3 = self.layer3(b2)
-        b3 = F.relu(b3)
-        pool = F.avg_pool2d(b3, 4)
+            feat2 = F.relu(feat2)
+        feat3 = self.layer3(feat2)
+
+        # the final ReLU is always applied before pooling
+        feat3 = F.relu(feat3)
+        pool = F.avg_pool2d(feat3, 4)
         pool = pool.view(pool.size(0), -1)
         out = self.linear(pool)
         if is_feat:
-            return[b1, b2, b3], pool, out
+            return [feat1, feat2, feat3], pool, out
         return out
 
@@ -199,6 +225,27 @@ def get_bn_before_relu(self):
 
     def get_channel_num(self):
         return self.n_channels
 
+    def extract_feature(self, x, preReLU=False):
+
+        x = self.conv1(x)
+        x = self.bn1(x)
+
+        feat1 = self.layer1(x)
+        feat2 = self.layer2(feat1)
+        feat3 = self.layer3(feat2)
+
+        x = F.relu(feat3)
+        x = F.avg_pool2d(x, 4)
+        x = x.view(x.size(0), -1)
+        out = self.linear(x)
+
+        if not preReLU:
+            feat1 = F.relu(feat1)
+            feat2 = F.relu(feat2)
+            feat3 = F.relu(feat3)
+
+        return [feat1, feat2, feat3], out
+
 
 def resnet8(**kwargs):
     return ResNetSmall(BasicBlock, [1, 1, 1], **kwargs)
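The residual blocks now activate their own input (`out = F.relu(x)` feeding `conv1`) instead of applying a final ReLU, so every block returns a pre-ReLU feature map and `extract_feature` can collect them without changing the network's effective computation (ReLU is idempotent, so the `use_relu=True` path behaves as before). The connectors that the patch adds to the optimizer are built in `Distiller.__init__` from `get_channel_num()` and `get_bn_before_relu()`, outside these hunks. A self-contained sketch of the idea, with illustrative names: a 1x1 convolution plus batch norm maps each pre-ReLU student feature map onto the teacher's channel count so `distillation_loss` can compare them position-wise:

```python
import torch
import torch.nn as nn

def build_connector(s_channels, t_channels):
    # 1x1 conv + BN: project student channels onto teacher channels
    # without touching spatial resolution.
    return nn.Sequential(
        nn.Conv2d(s_channels, t_channels, kernel_size=1, bias=False),
        nn.BatchNorm2d(t_channels),
    )

# e.g. ResNetSmall layer1 (16 ch) vs. ResNet layer1 (64 ch) on CIFAR-10
s_feat = torch.randn(2, 16, 32, 32)   # pre-ReLU student feature map
connector = build_connector(16, 64)
aligned = connector(s_feat)
print(aligned.shape)                  # torch.Size([2, 64, 32, 32])
```

This pairing is also why `n_channels` shrinks to the four block outputs: each entry now corresponds to exactly one connector and one teacher margin.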