diff --git a/timm/kd/__init__.py b/timm/kd/__init__.py
new file mode 100644
index 0000000000..8b3d7f2cea
--- /dev/null
+++ b/timm/kd/__init__.py
@@ -0,0 +1,4 @@
+"""Knowledge Distillation module for timm"""
+from .distillation import DistillationTeacher, apply_kd_loss
+
+__all__ = ['DistillationTeacher', 'apply_kd_loss']
diff --git a/timm/kd/distillation.py b/timm/kd/distillation.py
new file mode 100644
index 0000000000..109d8703c2
--- /dev/null
+++ b/timm/kd/distillation.py
@@ -0,0 +1,142 @@
+"""Knowledge Distillation helpers for training with a teacher model."""
+import logging
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+
+from timm.models import create_model
+
+
+_logger = logging.getLogger(__name__)
+
+
+class DistillationTeacher(nn.Module):
+    """Wrapper for a teacher model used in knowledge distillation.
+
+    Creates and manages a teacher model for knowledge distillation, handling
+    model compilation and normalization differences between teacher and student.
+    Pretrained weights are always loaded for the teacher.
+
+    Args:
+        model_name: Name of the teacher model to create
+        num_classes: Number of output classes
+        in_chans: Number of input channels
+        device: Device to place the model on (default: 'cuda')
+        dtype: Model dtype (default: None, uses float32)
+    """
+
+    def __init__(
+            self,
+            model_name: str,
+            num_classes: int,
+            in_chans: int = 3,
+            device: torch.device = torch.device('cuda'),
+            dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+
+        _logger.info(f"Creating KD teacher model: '{model_name}'")
+
+        model_kd = create_model(
+            model_name=model_name,
+            num_classes=num_classes,
+            pretrained=True,
+            in_chans=in_chans,
+        )
+
+        model_kd = model_kd.to(device=device, dtype=dtype)
+        model_kd.eval()
+
+        try:
+            model_kd = torch.compile(model_kd)
+            _logger.info("torch.compile applied successfully to KD teacher model")
+        except Exception as e:
+            _logger.warning(f"torch.compile failed with error {e}, continuing without compilation")
+
+        self.model = model_kd
+        self.mean_model_kd = model_kd.pretrained_cfg['mean']
+        self.std_model_kd = model_kd.pretrained_cfg['std']
+
+    def normalize_input(
+            self,
+            input: torch.Tensor,
+            student_model: nn.Module,
+    ) -> torch.Tensor:
+        """Normalize input to match the teacher's expected normalization.
+
+        Handles different normalization between teacher and student models by
+        converting the student's normalized input to the teacher's expected format.
+
+        Args:
+            input: Input tensor (already normalized for student)
+            student_model: Student model to extract normalization params from
+
+        Returns:
+            Input tensor normalized for the teacher model
+        """
+        # Unwrap DDP / DataParallel so pretrained_cfg is reachable
+        if hasattr(student_model, 'module'):
+            model_s = student_model.module
+        else:
+            model_s = student_model
+
+        mean_student = model_s.pretrained_cfg['mean']
+        std_student = model_s.pretrained_cfg['std']
+
+        input_kd = input
+        if mean_student != self.mean_model_kd or std_student != self.std_model_kd:
+            # Re-normalize in one step: x_t = ((x_s * std_s + mean_s) - mean_t) / std_t,
+            # i.e. Normalize(mean=(mean_t - mean_s) / std_s, std=std_t / std_s)
+            std = tuple(t_std / s_std for t_std, s_std in zip(self.std_model_kd, std_student))
+            mean = tuple(
+                (t_mean - s_mean) / s_std
+                for t_mean, s_mean, s_std in zip(self.mean_model_kd, mean_student, std_student)
+            )
+            input_kd = T.Normalize(mean=mean, std=std)(input)
+
+        return input_kd
+
+
+def apply_kd_loss(
+        loss: torch.Tensor,
+        student_output: torch.Tensor,
+        input: torch.Tensor,
+        student_model: nn.Module,
+        teacher_model: DistillationTeacher,
+        alpha_kd: float,
+        use_kd_only: bool = False,
+) -> torch.Tensor:
+    """Apply knowledge distillation loss.
+
+    Computes the KL divergence between the student and teacher output
+    distributions and adds it to the base loss (or replaces the base loss
+    entirely if use_kd_only is True).
+
+    Args:
+        loss: Base loss (e.g., cross-entropy with labels)
+        student_output: Logits from student model
+        input: Input tensor (already normalized for student)
+        student_model: Student model being trained
+        teacher_model: Teacher model for distillation
+        alpha_kd: Weight for the KD loss component
+        use_kd_only: If True, only use KD loss (ignore base loss)
+
+    Returns:
+        Combined loss with KD component
+    """
+    # Student log-probabilities (F.kl_div expects log-probabilities as its input)
+    prob_s = torch.nn.functional.log_softmax(student_output, dim=-1)
+
+    # Teacher probabilities, computed without tracking gradients
+    with torch.no_grad():
+        input_kd = teacher_model.normalize_input(input, student_model)
+        out_t = teacher_model.model(input_kd.detach())
+        prob_t = torch.nn.functional.softmax(out_t, dim=-1)
+
+    # Weighted KL divergence between student and teacher distributions
+    kd_loss = alpha_kd * torch.nn.functional.kl_div(prob_s, prob_t, reduction='batchmean')
+
+    if use_kd_only:
+        return kd_loss
+    else:
+        return loss + kd_loss
diff --git a/train.py b/train.py
index 131260dca4..19366b3ebb 100755
--- a/train.py
+++ b/train.py
@@ -41,6 +41,7 @@
 from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler_v2, scheduler_kwargs
 from timm.utils import ApexScaler, NativeScaler
+from timm.kd import DistillationTeacher, apply_kd_loss
 
 try:
     from apex import amp
@@ -415,6 +416,14 @@
 group.add_argument('--naflex-loss-scale', default='linear', type=str,
                    help='Scale loss (gradient) by batch_size ("none", "sqrt", or "linear")')
 
+# Knowledge Distillation parameters
+parser.add_argument('--kd-model-name', default=None, type=str,
+                    help='Name of teacher model for knowledge distillation')
+parser.add_argument('--alpha-kd', default=5, type=float,
+                    help='Weight for KD loss (default: 5)')
+parser.add_argument('--use-kd-only-loss', action='store_true', default=False,
+                    help='Use only KD loss, without cross-entropy loss')
+
 
 def _parse_args():
     # Do we have a config file to parse?
@@ -529,6 +538,17 @@ def main():
     if args.grad_checkpointing:
         model.set_grad_checkpointing(enable=True)
 
+    # Create the KD teacher model if specified
+    model_kd = None
+    if args.kd_model_name is not None:
+        model_kd = DistillationTeacher(
+            model_name=args.kd_model_name,
+            num_classes=args.num_classes,
+            in_chans=in_chans,
+            device=device,
+            dtype=model_dtype,
+        )
+
     if utils.is_primary(args):
         _logger.info(
             f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
@@ -1006,6 +1026,7 @@ def main():
                 mixup_fn=mixup_fn,
                 num_updates_total=num_epochs * updates_per_epoch,
                 naflex_mode=naflex_mode,
+                model_kd=model_kd,
             )
 
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -1109,6 +1130,7 @@ def train_one_epoch(
         mixup_fn=None,
         num_updates_total=None,
         naflex_mode=False,
+        model_kd=None,
 ):
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -1155,6 +1177,19 @@ def _forward():
             with amp_autocast():
                 output = model(input)
                 _loss = loss_fn(output, target)
+
+                # KD logic
+                if model_kd is not None:
+                    _loss = apply_kd_loss(
+                        loss=_loss,
+                        student_output=output,
+                        input=input,
+                        student_model=model,
+                        teacher_model=model_kd,
+                        alpha_kd=args.alpha_kd,
+                        use_kd_only=args.use_kd_only_loss,
+                    )
+
             if accum_steps > 1:
                 _loss /= accum_steps
             return _loss
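For reference, the new helpers also compose outside of train.py. The sketch below is illustrative only: the student/teacher model names (resnet50, convnext_base), batch shape, loss, and optimizer settings are placeholder assumptions, not part of the patch. In train.py the equivalent path is enabled with --kd-model-name and tuned with --alpha-kd / --use-kd-only-loss.

```python
import torch
import torch.nn as nn

from timm import create_model
from timm.kd import DistillationTeacher, apply_kd_loss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Student under training and a pretrained teacher; both model names are placeholders.
student = create_model('resnet50', pretrained=False, num_classes=1000).to(device)
teacher = DistillationTeacher('convnext_base', num_classes=1000, device=device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(student.parameters(), lr=0.1, momentum=0.9)

# Synthetic batch standing in for a student-normalized input pipeline.
input = torch.randn(8, 3, 224, 224, device=device)
target = torch.randint(0, 1000, (8,), device=device)

output = student(input)
loss = criterion(output, target)

# Blend the cross-entropy loss with the weighted teacher/student KL term.
loss = apply_kd_loss(
    loss=loss,
    student_output=output,
    input=input,
    student_model=student,
    teacher_model=teacher,
    alpha_kd=5.0,
    use_kd_only=False,
)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```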