knowledge distillation to object detection(YOLONAS) #1992

Closed
kimsjpk1 opened this issue May 13, 2024 · 1 comment


kimsjpk1 commented May 13, 2024

💡 Your Question

I want to apply knowledge distillation to object detection, but training failed.
Can you guide me on how to train it?
I think object detection differs from image classification in how the cross-entropy loss is used.

My code is below:

import torch

print(f'{torch.cuda.device_count()} GPU found: {torch.cuda.get_device_name("cuda")}')

from super_gradients import Trainer
from super_gradients.training import MultiGPUMode
from super_gradients.training import KDTrainer

CHECKPOINT_DIR = './' # Local path
trainer = Trainer(experiment_name='transfer_learning_object_detection_yolonas', ckpt_root_dir=CHECKPOINT_DIR)
experiment_name = "kd_coco_yolonas"

kd_trainer = KDTrainer(experiment_name=experiment_name, ckpt_root_dir=CHECKPOINT_DIR)

num_classes = 80

from super_gradients.training.dataloaders import coco2017_train, coco2017_val

train_dataloader = coco2017_train(
    dataloader_params={"batch_size": 16, "shuffle": True},
    dataset_params={"data_dir": "c:/Users/user/PycharmProjects/pythonProject/datasets/coco"})

val_dataloader = coco2017_val(
    dataloader_params={"batch_size": 16, "shuffle": True},
    dataset_params={"data_dir": "c:/Users/user/PycharmProjects/pythonProject/datasets/coco"})

from super_gradients.training import models
from super_gradients.common.object_names import Models

student_model = models.get(Models.YOLO_NAS_S, pretrained_weights="coco", num_classes=num_classes)
teacher_model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco", num_classes=num_classes)

print('Num classes in the model:', student_model.num_classes)

from super_gradients.training import training_hyperparams
from super_gradients.training.losses import KDLogitsLoss, CrossEntropyLoss
from torchvision import transforms
from multiprocessing import freeze_support

train_params = training_hyperparams.get('coco2017_yolo_nas_s')
train_params['max_epochs'] = 5
train_params['lr_warmup_epochs'] = 0
train_params['lr_cooldown_epochs'] = 0
train_params['criterion_params']['num_classes'] = num_classes
train_params['average_best_models'] = False
train_params['initial_lr'] = 0.0005
train_params['cosine_final_lr_ratio'] = 0.9
train_params['mixed_precision'] = False

kd_params = {
"max_epochs": 20,  # Short run, because training is slow on Google Colab
'lr_cooldown_epochs': 0,  # No lr cooldown, since the run is short
'lr_warmup_epochs': 0,  # No lr warmup, since the run is short
"loss": KDLogitsLoss(distillation_loss_coeff=0.8, task_loss_fn=CrossEntropyLoss()),
"loss_logging_items_names": ["Loss", "Task Loss", "Distillation Loss"]}

training_params = training_hyperparams.get('coco2017_yolo_nas_s', overriding_params=kd_params)
arch_params={"teacher_input_adapter": transforms.Resize(640)}

if __name__ == '__main__':
    # freeze_support() is only needed for frozen Windows executables; harmless otherwise
    freeze_support()

    kd_trainer.train(training_params=training_params,
                     student=student_model,
                     teacher=teacher_model,
                     kd_architecture="kd_module",
                     kd_arch_params=arch_params,
                     train_loader=train_dataloader, valid_loader=val_dataloader)

#trainer.train(model=student_model, training_params=train_params, train_loader=train_dataloader, valid_loader=val_dataloader)

Thank you in advance for any solution to my question.

Versions

No response

@BloodAxe
Collaborator

Sorry to say, but KD for object detection is not supported at the moment.
The KD trainer can be used for image classification and image segmentation tasks, but not for object detection.
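
For context, here is a minimal sketch of the kind of loss a logits-based KD setup computes for classification, assuming the usual weighted sum of a task loss and a KL-divergence term between teacher and student. The function names are illustrative and are not SuperGradients API; the sketch is plain Python so it runs without dependencies. It expects one logit vector and one class index per sample. As far as I understand, a detector such as YOLO-NAS instead returns box regressions plus per-anchor class scores, and its targets are lists of boxes, so neither a plain cross-entropy task loss nor a single-vector KL term fits directly.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over one list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(student_logits, target_index):
    """Task loss: negative log-probability of the ground-truth class."""
    return -math.log(softmax(student_logits)[target_index])


def kl_divergence(teacher_logits, student_logits, temperature=1.0):
    """Distillation loss: KL(teacher || student) on softened distributions."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def kd_logits_loss(student_logits, teacher_logits, target_index,
                   distillation_coeff=0.8):
    """Weighted sum of task loss and distillation loss (assumed formulation)."""
    task = cross_entropy(student_logits, target_index)
    distill = kl_divergence(teacher_logits, student_logits)
    total = (1.0 - distillation_coeff) * task + distillation_coeff * distill
    return total, task, distill


# One sample, three classes, ground-truth class 0 (made-up numbers).
student = [2.0, 0.5, -1.0]
teacher = [3.0, 0.2, -2.0]
total, task, distill = kd_logits_loss(student, teacher, target_index=0)
print(f"Loss={total:.4f}  Task Loss={task:.4f}  Distillation Loss={distill:.4f}")
```

Distilling a detector would need a loss that matches student and teacher predictions anchor by anchor (or on intermediate features) and a detection task loss in place of cross-entropy, which is the part the KD trainer does not provide.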

BloodAxe closed this as not planned on May 13, 2024