You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
I want to train an object detection model with knowledge distillation, but it failed.
Can you guide me on how to train it?
I think object detection differs from image classification in how the cross-entropy loss applies.
# Excerpt of the knowledge-distillation setup: prints the student's class count
# and builds the hyperparameter overrides used for KD training.
# NOTE(review): `student_model` is not defined in this excerpt; it must be created earlier.
print('Num classes in the model:', student_model.num_classes)
from super_gradients.training import training_hyperparams
from super_gradients.training.losses import KDLogitsLoss, CrossEntropyLoss
from torchvision import transforms
from multiprocessing import freeze_support
# Overrides for KD training: epoch count, no LR warmup/cooldown, and a combined
# loss wrapping CrossEntropyLoss as the task loss with a 0.8 distillation coefficient.
kd_params = {
"max_epochs": 20, # NOTE(review): an earlier comment said "3 epochs" (short Colab run), but the value set here is 20
'lr_cooldown_epochs': 0, # LR cooldown disabled (originally justified by a short run; confirm it is still wanted for 20 epochs)
'lr_warmup_epochs': 0, # LR warmup disabled (originally justified by a short run; confirm it is still wanted for 20 epochs)
"loss": KDLogitsLoss(distillation_loss_coeff=0.8, task_loss_fn=CrossEntropyLoss()),
"loss_logging_items_names": ["Loss", "Task Loss", "Distillation Loss"]}
Sorry to say that,
but KD for object detection is not supported at the moment.
KD trainer can be used for image classification & image segmentation tasks, but not for OD.
💡 Your Question
I want to train an object detection model with knowledge distillation, but it failed.
Can you guide me on how to train it?
I think object detection differs from image classification in how the cross-entropy loss applies.
My code is this:
# Knowledge-distillation experiment script: a pretrained YOLO-NAS-L teacher and a
# pretrained YOLO-NAS-S student are loaded together with the COCO 2017 dataloaders
# and a set of KD training hyperparameters.
#
# NOTE(review): per the maintainer reply in this thread, KDTrainer currently
# supports image classification and segmentation but not object detection, so
# the KDLogitsLoss/CrossEntropyLoss setup below is not expected to work as-is
# with YOLO-NAS models.
import torch

print(f'{torch.cuda.device_count()} GPU found: {torch.cuda.get_device_name("cuda")}')

from super_gradients import Trainer
from super_gradients.training import MultiGPUMode
from super_gradients.training import KDTrainer

CHECKPOINT_DIR = './'  # Local path

# Plain (non-KD) trainer; it is not referenced again in the visible part of the script.
trainer = Trainer(experiment_name='transfer_learning_object_detection_yolonas', ckpt_root_dir=CHECKPOINT_DIR)

experiment_name = "kd_coco_yolonas"
kd_trainer = KDTrainer(experiment_name=experiment_name, ckpt_root_dir=CHECKPOINT_DIR)

num_classes = 80

from super_gradients.training.dataloaders import coco2017_train, coco2017_val

# Single source of truth for the dataset location (was duplicated per loader).
COCO_DATA_DIR = "c:/Users/user/PycharmProjects/pythonProject/datasets/coco"

train_dataloader = coco2017_train(
    dataloader_params={"batch_size": 16, "shuffle": True},
    dataset_params={"data_dir": COCO_DATA_DIR},
)
val_dataloader = coco2017_val(
    dataloader_params={"batch_size": 16, "shuffle": True},
    dataset_params={"data_dir": COCO_DATA_DIR},
)

from super_gradients.training import models
from super_gradients.common.object_names import Models

# Student (small) and teacher (large) both start from COCO-pretrained weights.
student_model = models.get(Models.YOLO_NAS_S, pretrained_weights="coco", num_classes=num_classes)
teacher_model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco", num_classes=num_classes)
print('Num classes in the model:', student_model.num_classes)

from super_gradients.training import training_hyperparams
from super_gradients.training.losses import KDLogitsLoss, CrossEntropyLoss
from torchvision import transforms
from multiprocessing import freeze_support

# Detection recipe with a few overrides. This dict is not referenced again in
# the visible part of the script; `training_params` below is built separately.
train_params = training_hyperparams.get('coco2017_yolo_nas_s')
train_params['max_epochs'] = 5
train_params['lr_warmup_epochs'] = 0
train_params['lr_cooldown_epochs'] = 0
train_params['criterion_params']['num_classes'] = num_classes
train_params['average_best_models'] = False
train_params['initial_lr'] = 0.0005
train_params['cosine_final_lr_ratio'] = 0.9
train_params['mixed_precision'] = False

# KD-specific overrides applied on top of the 'coco2017_yolo_nas_s' recipe.
kd_params = {
    "max_epochs": 20,  # Short run; the original comment mentioned 3 epochs but the value used is 20
    'lr_cooldown_epochs': 0,  # LR cooldown disabled for this short run
    'lr_warmup_epochs': 0,  # LR warmup disabled for this short run
    # Combined loss: 0.8 weight on distillation, CrossEntropyLoss as the task loss.
    # NOTE(review): CrossEntropyLoss is a classification loss; see the note at the top.
    "loss": KDLogitsLoss(distillation_loss_coeff=0.8, task_loss_fn=CrossEntropyLoss()),
    "loss_logging_items_names": ["Loss", "Task Loss", "Distillation Loss"],
}

training_params = training_hyperparams.get('coco2017_yolo_nas_s', overriding_params=kd_params)

# Resize inputs to 640 before they are fed to the teacher.
arch_params = {"teacher_input_adapter": transforms.Resize(640)}

# The entry-point guard must use the dunder names; the bare `name == 'main'`
# raised NameError and its body was not indented.
if __name__ == '__main__':
    # Required on Windows when multiprocessing is used from a frozen executable;
    # it has no effect otherwise.
    freeze_support()
Thank you in advance to anyone who can provide a solution to my question.
Versions
No response
The text was updated successfully, but these errors were encountered: