
Add Knowledge-Distillation #42

Merged: 10 commits, Jan 13, 2021
6 changes: 3 additions & 3 deletions engine.py
@@ -16,16 +16,16 @@
from timm.data import Mixup
from timm.utils import accuracy, ModelEma

from losses import DistillationLoss
import utils


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None):
# TODO fix this for finetuning
model.train()
criterion.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
@@ -40,7 +40,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,

with torch.cuda.amp.autocast():
outputs = model(samples)
loss = criterion(outputs, targets)
loss = criterion(samples, outputs, targets)

loss_value = loss.item()

59 changes: 59 additions & 0 deletions losses.py
@@ -0,0 +1,59 @@
import torch
from torch.nn import functional as F


class DistillationLoss(torch.nn.Module):
"""
This module wraps a standard criterion and adds an extra knowledge distillation loss by
taking a teacher model prediction and using it as additional supervision.
"""
def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
distillation_type: str, alpha: float, tau: float):
super().__init__()
self.base_criterion = base_criterion
self.teacher_model = teacher_model
assert distillation_type in ['none', 'soft', 'hard']
self.distillation_type = distillation_type
self.alpha = alpha
self.tau = tau

def forward(self, inputs, outputs, labels):
"""
Args:
            inputs: the original inputs that are fed to the teacher model
outputs: the outputs of the model to be trained. It is expected to be
either a Tensor, or a Tuple[Tensor, Tensor], with the original output
in the first position and the distillation predictions as the second output
labels: the labels for the base criterion
"""
outputs_kd = None
if not isinstance(outputs, torch.Tensor):
# assume that the model outputs a tuple of [outputs, outputs_kd]
outputs, outputs_kd = outputs
base_loss = self.base_criterion(outputs, labels)
if self.distillation_type == 'none':
return base_loss

if outputs_kd is None:
raise ValueError("When knowledge distillation is enabled, the model is "
"expected to return a Tuple[Tensor, Tensor] with the output of the "
"class_token and the dist_token")
        # don't backprop through the teacher
with torch.no_grad():
teacher_outputs = self.teacher_model(inputs)

if self.distillation_type == 'soft':
T = self.tau
# taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
# with slight modifications
distillation_loss = F.kl_div(
F.log_softmax(outputs_kd / T, dim=1),
F.log_softmax(teacher_outputs / T, dim=1),
reduction='sum',
log_target=True
) * (T * T) / outputs_kd.numel()
elif self.distillation_type == 'hard':
distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))

loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
return loss
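
For reference, the combined objective implemented above is (1 - alpha) * base loss + alpha * distillation term, where the soft term is a temperature-scaled KL divergence and the hard term is cross-entropy against the teacher's argmax. Below is a minimal, self-contained sketch (not part of this PR) showing how DistillationLoss can be exercised on dummy tensors; the toy teacher and the random student outputs are assumptions made purely for illustration.

# Illustrative only: a tiny stand-in "teacher" and random student logits,
# just to show the (inputs, outputs, labels) call signature.
import torch
from torch import nn
from losses import DistillationLoss

batch, num_classes = 4, 10
teacher = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, num_classes))
teacher.eval()
base_criterion = nn.CrossEntropyLoss()

criterion = DistillationLoss(base_criterion, teacher,
                             distillation_type='soft', alpha=0.5, tau=3.0)

inputs = torch.randn(batch, 3, 32, 32)
labels = torch.randint(0, num_classes, (batch,))
# In training mode a distilled model returns (class_token logits, dist_token logits).
outputs = (torch.randn(batch, num_classes), torch.randn(batch, num_classes))

loss = criterion(inputs, outputs, labels)  # scalar mixing base loss and the KD term
print(loss.item())
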
34 changes: 34 additions & 0 deletions main.py
@@ -23,6 +23,7 @@

from datasets import build_dataset
from engine import train_one_epoch, evaluate
from losses import DistillationLoss
from samplers import RASampler
import models
import utils
@@ -127,6 +128,14 @@ def get_args_parser():
parser.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

# Distillation parameters
parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
                        help='Name of teacher model (default: "regnety_160")')
    parser.add_argument('--teacher-path', type=str, default='', help='Path or URL of the teacher model checkpoint')
    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help='Distillation loss to use: "none", "soft" (KL against the teacher) or "hard" (cross-entropy on the teacher argmax)')
    parser.add_argument('--distillation-alpha', default=0.5, type=float, help='Weight of the distillation term in the total loss (default: 0.5)')
    parser.add_argument('--distillation-tau', default=1.0, type=float, help='Temperature for soft distillation (default: 1.0)')

# Dataset parameters
parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
help='dataset path')
@@ -260,6 +269,31 @@ def main(args):
else:
criterion = torch.nn.CrossEntropyLoss()

teacher_model = None
if args.distillation_type != 'none':
assert args.teacher_path, 'need to specify teacher-path when using distillation'
print(f"Creating teacher model: {args.teacher_model}")
teacher_model = create_model(
args.teacher_model,
pretrained=False,
num_classes=args.nb_classes,
global_pool='avg',
)
if args.teacher_path.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.teacher_path, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.teacher_path, map_location='cpu')
teacher_model.load_state_dict(checkpoint['model'])
teacher_model.to(device)
teacher_model.eval()

# wrap the criterion in our custom DistillationLoss, which
# just dispatches to the original criterion if args.distillation_type is 'none'
criterion = DistillationLoss(
criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
)

output_dir = Path(args.output_dir)
if args.resume:
if args.resume.startswith('https'):
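
Note that main.py expects the teacher weights under a 'model' key in the checkpoint (teacher_model.load_state_dict(checkpoint['model'])). As a hedged sketch, and assuming timm ships pretrained regnety_160 weights, one way to produce a compatible local file for --teacher-path could be:

# Sketch (not part of this PR): wrap timm's pretrained RegNetY-160 weights in the
# {'model': state_dict} layout that the --teacher-path loading code above expects.
import torch
from timm import create_model

teacher = create_model('regnety_160', pretrained=True, num_classes=1000, global_pool='avg')
torch.save({'model': teacher.state_dict()}, 'regnety_160_teacher.pth')

The resulting file could then be passed as --teacher-path regnety_160_teacher.pth together with --distillation-type soft or hard.
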
88 changes: 88 additions & 0 deletions models.py
@@ -10,6 +10,49 @@

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_


class DistilledVisionTransformer(VisionTransformer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
num_patches = self.patch_embed.num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

trunc_normal_(self.dist_token, std=.02)
trunc_normal_(self.pos_embed, std=.02)
self.head_dist.apply(self._init_weights)

def forward_features(self, x):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to add the dist_token
B = x.shape[0]
x = self.patch_embed(x)

cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
dist_token = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_token, x), dim=1)

x = x + self.pos_embed
x = self.pos_drop(x)

for blk in self.blocks:
x = blk(x)

x = self.norm(x)
return x[:, 0], x[:, 1]

def forward(self, x):
x, x_dist = self.forward_features(x)
x = self.head(x)
x_dist = self.head_dist(x_dist)
if self.training:
return x, x_dist
else:
# during inference, return the average of both classifier predictions
return (x + x_dist) / 2


@register_model
@@ -55,3 +98,48 @@ def deit_base_patch16_224(pretrained=False, **kwargs):
)
model.load_state_dict(checkpoint["model"])
return model


@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
model = DistilledVisionTransformer(
patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="",
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model


@register_model
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
model = DistilledVisionTransformer(
patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="",
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model


@register_model
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
model = DistilledVisionTransformer(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="",
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
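

As a quick sanity check of the new distilled variants, here is an illustrative sketch (not part of this PR): in training mode the model returns the class-token and dist-token logits separately, while in eval mode it returns their average.

# Sketch: check the train/eval output contract of a distilled DeiT variant.
import torch
from timm import create_model
import models  # importing the module above registers the deit_*_distilled_* entries with timm

model = create_model('deit_tiny_distilled_patch16_224', pretrained=False)
x = torch.randn(2, 3, 224, 224)

model.train()
logits_cls, logits_dist = model(x)           # two heads during training
print(logits_cls.shape, logits_dist.shape)   # torch.Size([2, 1000]) twice

model.eval()
with torch.no_grad():
    logits = model(x)                        # (head(x) + head_dist(x_dist)) / 2 at inference
print(logits.shape)                          # torch.Size([2, 1000])
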