From b48e88cd81dd8254f78ca494860706456c646808 Mon Sep 17 00:00:00 2001 From: Tal <21198860+mrT23@users.noreply.github.com> Date: Wed, 15 Oct 2025 14:35:06 +0300 Subject: [PATCH 1/3] Add knowledge distillation model and loss function support --- timm/utils/model_kd.py | 77 ++++++++++++++++++++++++++++++++++++++++++ train.py | 21 ++++++++++++ 2 files changed, 98 insertions(+) create mode 100644 timm/utils/model_kd.py diff --git a/timm/utils/model_kd.py b/timm/utils/model_kd.py new file mode 100644 index 0000000000..45d50833ba --- /dev/null +++ b/timm/utils/model_kd.py @@ -0,0 +1,77 @@ +import logging +import torch +import torch.nn as nn +import torchvision.transforms as T +from timm.models import create_model + +_logger = logging.getLogger(__name__) + +class build_kd_model(nn.Module): + def __init__(self, args): + super(build_kd_model, self).__init__() + + _logger.info(f"Creating KD model: from '{args.kd_model_name}'") + in_chans = 3 + if args.in_chans is not None: + in_chans = args.in_chans + model_kd = create_model( + model_name=args.kd_model_name, + num_classes=args.num_classes, + pretrained=True, + in_chans=in_chans) + + # compile model + model_kd.cpu().eval() + try: + model_kd = torch.compile(model_kd) + _logger.info(f"torch.compile applied successfully to KD model") + except Exception as e: + _logger.warning(f"torch.compile failed with error {e}, continuing KD model without torch compilation") + + self.model = model_kd.cuda() + self.mean_model_kd = model_kd.default_cfg['mean'] + self.std_model_kd = model_kd.default_cfg['std'] + + # handling different normalization of teacher and student + def normalize_input(self, input, student_model): + if hasattr(student_model, 'module'): + model_s = student_model.module + else: + model_s = student_model + + mean_student = model_s.default_cfg['mean'] + std_student = model_s.default_cfg['std'] + + input_kd = input + if mean_student != self.mean_model_kd or std_student != self.std_model_kd: + std = (self.std_model_kd[0] / std_student[0], self.std_model_kd[1] / std_student[1], + self.std_model_kd[2] / std_student[2]) + transform_std = T.Normalize(mean=(0, 0, 0), std=std) + + mean = (self.mean_model_kd[0] - mean_student[0], self.mean_model_kd[1] - mean_student[1], + self.mean_model_kd[2] - mean_student[2]) + transform_mean = T.Normalize(mean=mean, std=(1, 1, 1)) + + input_kd = transform_mean(transform_std(input)) + + return input_kd + + +def add_kd_loss(_loss, output, input, model, model_kd, args): + # student probability calculation + prob_s = torch.nn.functional.log_softmax(output, dim=-1) + + # teacher probability calculation + with torch.no_grad(): + input_kd = model_kd.normalize_input(input, model) + out_t = model_kd.model(input_kd.detach()) + prob_t = torch.nn.functional.softmax(out_t, dim=-1) + + # adding KL loss + if not args.use_kd_only_loss: + _loss += args.alpha_kd * torch.nn.functional.kl_div(prob_s, prob_t, reduction='batchmean') + else: # only kd + _loss = args.alpha_kd * torch.nn.functional.kl_div(prob_s, prob_t, reduction='batchmean') + + return _loss + diff --git a/train.py b/train.py index 131260dca4..32a3b683fd 100755 --- a/train.py +++ b/train.py @@ -41,6 +41,7 @@ from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler_v2, scheduler_kwargs from timm.utils import ApexScaler, NativeScaler +from timm.utils.model_kd import build_kd_model, add_kd_loss try: from apex import amp @@ -415,6 +416,14 @@ group.add_argument('--naflex-loss-scale', default='linear', type=str, help='Scale loss 
(gradient) by batch_size ("none", "sqrt", or "linear")') +# Knowledge Distillation parameters +parser.add_argument('--kd-model-name', default=None, type=str, + help='Name of teacher model for knowledge distillation') +parser.add_argument('--alpha-kd', default=5, type=float, + help='Weight for KD loss (default: 5)') +parser.add_argument('--use-kd-only-loss', action='store_true', default=False, + help='Use only KD loss, without cross-entropy loss') + def _parse_args(): # Do we have a config file to parse? @@ -480,6 +489,11 @@ def main(): utils.random_seed(args.seed, args.rank) + # Create the KD teacher model if specified + model_kd = None + if args.kd_model_name is not None: + model_kd = build_kd_model(args) + if args.fuser: utils.set_jit_fuser(args.fuser) if args.fast_norm: @@ -1006,6 +1020,7 @@ def main(): mixup_fn=mixup_fn, num_updates_total=num_epochs * updates_per_epoch, naflex_mode=naflex_mode, + model_kd=model_kd, ) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): @@ -1109,6 +1124,7 @@ def train_one_epoch( mixup_fn=None, num_updates_total=None, naflex_mode=False, + model_kd=None, ): if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.prefetcher and loader.mixup_enabled: @@ -1155,6 +1171,11 @@ def _forward(): with amp_autocast(): output = model(input) _loss = loss_fn(output, target) + + # KD logic + if model_kd is not None: + _loss= add_kd_loss(_loss, output, input, model, model_kd, args) + if accum_steps > 1: _loss /= accum_steps return _loss From d705d6782b07bb465645f9271a9ce2e182ad2cc6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 17 Oct 2025 12:45:37 -0700 Subject: [PATCH 2/3] Cleanup distillation code --- timm/kd/__init__.py | 4 ++ timm/kd/distillation.py | 142 ++++++++++++++++++++++++++++++++++++++++ timm/utils/model_kd.py | 77 ---------------------- train.py | 28 ++++++-- 4 files changed, 167 insertions(+), 84 deletions(-) create mode 100644 timm/kd/__init__.py create mode 100644 timm/kd/distillation.py delete mode 100644 timm/utils/model_kd.py diff --git a/timm/kd/__init__.py b/timm/kd/__init__.py new file mode 100644 index 0000000000..8b3d7f2cea --- /dev/null +++ b/timm/kd/__init__.py @@ -0,0 +1,4 @@ +"""Knowledge Distillation module for timm""" +from .distillation import DistillationTeacher, apply_kd_loss + +__all__ = ['DistillationTeacher', 'apply_kd_loss'] diff --git a/timm/kd/distillation.py b/timm/kd/distillation.py new file mode 100644 index 0000000000..b9d051993d --- /dev/null +++ b/timm/kd/distillation.py @@ -0,0 +1,142 @@ +"""Knowledge Distillation helpers for training with a teacher model.""" +import logging +from typing import Tuple + +import torch +import torch.nn as nn +import torchvision.transforms as T + +from timm.models import create_model + + +_logger = logging.getLogger(__name__) + + +class DistillationTeacher(nn.Module): + """Wrapper for a teacher model used in knowledge distillation. + + Creates and manages a pre-trained teacher model for knowledge distillation, + handling model compilation and normalization differences between teacher and student. 
+ + Args: + model_name: Name of the teacher model to create + num_classes: Number of output classes + in_chans: Number of input channels + pretrained: Whether to load pretrained weights + device: Device to place the model on (default: 'cuda') + dtype: Model dtype (default: None, uses float32) + """ + + def __init__( + self, + model_name: str, + num_classes: int, + in_chans: int = 3, + device: torch.device = torch.device('cuda'), + dtype: torch.dtype = None, + ): + super().__init__() + + _logger.info(f"Creating KD teacher model: '{model_name}'") + + model_kd = create_model( + model_name=model_name, + num_classes=num_classes, + pretrained=True, + in_chans=in_chans, + ) + + model_kd = model_kd.to(device=device, dtype=dtype) + model_kd.eval() + + try: + model_kd = torch.compile(model_kd) + _logger.info("torch.compile applied successfully to KD teacher model") + except Exception as e: + _logger.warning(f"torch.compile failed with error {e}, continuing without compilation") + + self.model = model_kd + self.mean_model_kd = model_kd.pretrained_cfg['mean'] + self.std_model_kd = model_kd.pretrained_cfg['std'] + + def normalize_input( + self, + input: torch.Tensor, + student_model: nn.Module, + ) -> torch.Tensor: + """Normalize input to match teacher's expected normalization. + + Handles different normalization between teacher and student models by + converting the student's normalized input to the teacher's expected format. + + Args: + input: Input tensor (already normalized for student) + student_model: Student model to extract normalization params from + + Returns: + Input tensor normalized for the teacher model + """ + if hasattr(student_model, 'module'): + model_s = student_model.module + else: + model_s = student_model + + mean_student = model_s.pretrained_cfg['mean'] + std_student = model_s.pretrained_cfg['std'] + + input_kd = input + if mean_student != self.mean_model_kd or std_student != self.std_model_kd: + # Compute normalized std and mean transformations + std = tuple(t_std / s_std for t_std, s_std in zip(self.std_model_kd, std_student)) + transform_std = T.Normalize(mean=(0, 0, 0), std=std) + + mean = tuple(t_mean - s_mean for t_mean, s_mean in zip(self.mean_model_kd, mean_student)) + transform_mean = T.Normalize(mean=mean, std=(1, 1, 1)) + + input_kd = transform_mean(transform_std(input)) + + return input_kd + + +def apply_kd_loss( + loss: torch.Tensor, + student_output: torch.Tensor, + input: torch.Tensor, + student_model: nn.Module, + teacher_model: DistillationTeacher, + alpha_kd: float, + use_kd_only: bool = False, +) -> torch.Tensor: + """Apply knowledge distillation loss. + + Computes KL divergence between student and teacher outputs and combines + with the base loss (or replaces it if use_kd_only is True). 
+ + Args: + loss: Base loss (e.g., cross-entropy with labels) + student_output: Logits from student model + input: Input tensor (already normalized for student) + student_model: Student model being trained + teacher_model: Teacher model for distillation + alpha_kd: Weight for the KD loss component + use_kd_only: If True, only use KD loss (ignore base loss) + + Returns: + Combined loss with KD component + """ + # Student probability calculation + prob_s = torch.nn.functional.log_softmax(student_output, dim=-1) + + # Teacher probability calculation + with torch.inference_mode(): + input_kd = teacher_model.normalize_input(input, student_model) + out_t = teacher_model.model(input_kd.detach()) + prob_t = torch.nn.functional.softmax(out_t, dim=-1) + + # Compute KL divergence loss + kd_loss = alpha_kd * torch.nn.functional.kl_div(prob_s, prob_t, reduction='batchmean') + + if use_kd_only: + return kd_loss + else: + return loss + kd_loss diff --git a/timm/utils/model_kd.py b/timm/utils/model_kd.py deleted file mode 100644 index 45d50833ba..0000000000 --- a/timm/utils/model_kd.py +++ /dev/null @@ -1,77 +0,0 @@ -import logging -import torch -import torch.nn as nn -import torchvision.transforms as T -from timm.models import create_model - -_logger = logging.getLogger(__name__) - -class build_kd_model(nn.Module): - def __init__(self, args): - super(build_kd_model, self).__init__() - - _logger.info(f"Creating KD model: from '{args.kd_model_name}'") - in_chans = 3 - if args.in_chans is not None: - in_chans = args.in_chans - model_kd = create_model( - model_name=args.kd_model_name, - num_classes=args.num_classes, - pretrained=True, - in_chans=in_chans) - - # compile model - model_kd.cpu().eval() - try: - model_kd = torch.compile(model_kd) - _logger.info(f"torch.compile applied successfully to KD model") - except Exception as e: - _logger.warning(f"torch.compile failed with error {e}, continuing KD model without torch compilation") - - self.model = model_kd.cuda() - self.mean_model_kd = model_kd.default_cfg['mean'] - self.std_model_kd = model_kd.default_cfg['std'] - - # handling different normalization of teacher and student - def normalize_input(self, input, student_model): - if hasattr(student_model, 'module'): - model_s = student_model.module - else: - model_s = student_model - - mean_student = model_s.default_cfg['mean'] - std_student = model_s.default_cfg['std'] - - input_kd = input - if mean_student != self.mean_model_kd or std_student != self.std_model_kd: - std = (self.std_model_kd[0] / std_student[0], self.std_model_kd[1] / std_student[1], - self.std_model_kd[2] / std_student[2]) - transform_std = T.Normalize(mean=(0, 0, 0), std=std) - - mean = (self.mean_model_kd[0] - mean_student[0], self.mean_model_kd[1] - mean_student[1], - self.mean_model_kd[2] - mean_student[2]) - transform_mean = T.Normalize(mean=mean, std=(1, 1, 1)) - - input_kd = transform_mean(transform_std(input)) - - return input_kd - - -def add_kd_loss(_loss, output, input, model, model_kd, args): - # student probability calculation - prob_s = torch.nn.functional.log_softmax(output, dim=-1) - - # teacher probability calculation - with torch.no_grad(): - input_kd = model_kd.normalize_input(input, model) - out_t = model_kd.model(input_kd.detach()) - prob_t = torch.nn.functional.softmax(out_t, dim=-1) - - # adding KL loss - if not args.use_kd_only_loss: - _loss += args.alpha_kd * torch.nn.functional.kl_div(prob_s, prob_t, reduction='batchmean') - else: # only kd - _loss = args.alpha_kd * torch.nn.functional.kl_div(prob_s, prob_t, 
reduction='batchmean') - - return _loss - diff --git a/train.py b/train.py index 32a3b683fd..19366b3ebb 100755 --- a/train.py +++ b/train.py @@ -41,7 +41,7 @@ from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler_v2, scheduler_kwargs from timm.utils import ApexScaler, NativeScaler -from timm.utils.model_kd import build_kd_model, add_kd_loss +from timm.kd import DistillationTeacher, apply_kd_loss try: from apex import amp @@ -489,11 +489,6 @@ def main(): utils.random_seed(args.seed, args.rank) - # Create the KD teacher model if specified - model_kd = None - if args.kd_model_name is not None: - model_kd = build_kd_model(args) - if args.fuser: utils.set_jit_fuser(args.fuser) if args.fast_norm: @@ -543,6 +538,17 @@ def main(): if args.grad_checkpointing: model.set_grad_checkpointing(enable=True) + # Create the KD teacher model if specified + model_kd = None + if args.kd_model_name is not None: + model_kd = DistillationTeacher( + model_name=args.kd_model_name, + num_classes=args.num_classes, + in_chans=in_chans, + device=device, + dtype=model_dtype, + ) + if utils.is_primary(args): _logger.info( f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') @@ -1174,7 +1180,15 @@ def _forward(): # KD logic if model_kd is not None: - _loss= add_kd_loss(_loss, output, input, model, model_kd, args) + _loss = apply_kd_loss( + loss=_loss, + student_output=output, + input=input, + student_model=model, + teacher_model=model_kd, + alpha_kd=args.alpha_kd, + use_kd_only=args.use_kd_only_loss, + ) if accum_steps > 1: _loss /= accum_steps From 6dcbc22c4678c4279f5bf19410d764aab4fb476a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 17 Oct 2025 12:54:08 -0700 Subject: [PATCH 3/3] Keep as no_grad --- timm/kd/distillation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/kd/distillation.py b/timm/kd/distillation.py index b9d051993d..109d8703c2 100644 --- a/timm/kd/distillation.py +++ b/timm/kd/distillation.py @@ -128,7 +128,7 @@ def apply_kd_loss( prob_s = torch.nn.functional.log_softmax(student_output, dim=-1) # Teacher probability calculation - with torch.inference_mode(): + with torch.no_grad(): input_kd = teacher_model.normalize_input(input, student_model) out_t = teacher_model.model(input_kd.detach()) prob_t = torch.nn.functional.softmax(out_t, dim=-1)
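
Usage note (not part of the patches): after PATCH 2/3 the distillation helpers live in timm.kd as DistillationTeacher and apply_kd_loss, and train.py wires them up through the --kd-model-name, --alpha-kd, and --use-kd-only-loss arguments added in PATCH 1/3. The sketch below shows roughly how the same pieces fit together in a standalone training step; the student and teacher model names, the optimizer, and the loop scaffolding are placeholders for illustration, not values taken from the patches.

import torch
import torch.nn.functional as F

from timm.models import create_model
from timm.kd import DistillationTeacher, apply_kd_loss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Student being trained (model name is illustrative).
student = create_model('resnet50', num_classes=1000, pretrained=False)
student = student.to(device)

# Pretrained teacher wrapped by the new timm.kd helper (teacher name is illustrative).
teacher = DistillationTeacher(
    model_name='convnext_large.fb_in22k_ft_in1k',
    num_classes=1000,
    in_chans=3,
    device=device,
)

optimizer = torch.optim.AdamW(student.parameters(), lr=1e-3)


def train_step(input, target, alpha_kd=5.0, use_kd_only=False):
    output = student(input)
    loss = F.cross_entropy(output, target)
    # apply_kd_loss re-normalizes the batch to the teacher's mean/std, runs the
    # teacher under no_grad, and adds alpha_kd * KL(teacher || student) computed
    # from the teacher's softmax and the student's log-softmax outputs
    # (or replaces the cross-entropy term entirely when use_kd_only=True).
    loss = apply_kd_loss(
        loss=loss,
        student_output=output,
        input=input,
        student_model=student,
        teacher_model=teacher,
        alpha_kd=alpha_kd,
        use_kd_only=use_kd_only,
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

Through the train.py entry point, the same path is enabled with the new flags, e.g. ./train.py <data-dir> --model resnet50 --kd-model-name <teacher-name> --alpha-kd 5, with --use-kd-only-loss dropping the cross-entropy term; the model names here are again placeholders.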