
yolo nas pose F.conv2d(input, weight, bias, self.stride, RuntimeError: expected scalar type Byte but found Float #1969

Closed
wesboyt opened this issue Apr 24, 2024 · 3 comments


@wesboyt

wesboyt commented Apr 24, 2024

🐛 Describe the bug

Hello, I'm attempting to compare YOLO-NAS Pose with my custom YOLOv8 pose model.
Please advise me on training the model on my custom dataset, which has 1 class with 12 keypoints. Right now I am receiving the error:
```
venv\lib\site-packages\torch\nn\modules\conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: expected scalar type Byte but found Float
```

I reused your animal example, which I'm very confident is broken, and pulled out the relevant config parameters. I'm also using a COCO keypoints JSON file generated by a CVAT server, together with your coco_pose_estimation_dataset.COCOPoseEstimationDataset.
It seems like it starts a training loop and fails on the first forward call.
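
For context, here is a minimal, self-contained PyTorch sketch (independent of super-gradients) that reproduces the same dtype mismatch: conv layers hold float32 weights, so feeding them a raw uint8 image tensor fails until the input is cast to float. The exact error wording may vary slightly between PyTorch versions.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3)  # conv weights are float32

# A freshly decoded image tensor is typically uint8 (Byte).
img_uint8 = torch.randint(0, 256, (1, 3, 64, 64), dtype=torch.uint8)

try:
    conv(img_uint8)  # fails: input is Byte, weights are Float
except RuntimeError as e:
    print(e)  # e.g. "expected scalar type Byte but found Float"

# Casting to float and scaling to [0, 1] is what an image-normalization
# transform would normally do before the first forward pass.
out = conv(img_uint8.float() / 255.0)
print(out.shape)  # torch.Size([1, 8, 62, 62])
```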

```python
import json
from typing import Union

import super_gradients
import yaml
from super_gradients.common.object_names import Models
from super_gradients.training import Trainer
from super_gradients.training.datasets.pose_estimation_datasets import coco_pose_estimation_dataset
from super_gradients.training.datasets.pose_estimation_datasets import YoloNASPoseCollateFN
from super_gradients.training.metrics import PoseEstimationMetrics
from super_gradients.training.models.pose_estimation_models import YoloNASPosePostPredictionCallback
from super_gradients.training.utils import EarlyStop
from super_gradients.training.utils.callbacks import ExtremeBatchPoseEstimationVisualizationCallback, Phase
from torch.utils.data import DataLoader

CHECKPOINT_DIR = './pose'
trainer = Trainer(experiment_name='my_first_yn_pose_run', ckpt_root_dir=CHECKPOINT_DIR)

post_prediction_callback = YoloNASPosePostPredictionCallback(
    pose_confidence_threshold=0.01,
    nms_iou_threshold=0.7,
    pre_nms_max_predictions=300,
    post_nms_max_predictions=30,
)

metrics = PoseEstimationMetrics(
    num_joints=12,
    oks_sigmas=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
    max_objects_per_image=30,
    post_prediction_callback=post_prediction_callback,
)

dataset = coco_pose_estimation_dataset.COCOPoseEstimationDataset(
    data_dir='./pose/',
    images_dir='images/',
    json_file='annotations/person_keypoints_default.json',
    include_empty_samples=True,
    transforms=[],
    edge_links=[],
    edge_colors=[],
    keypoint_colors=[],
)

visualization_callback = ExtremeBatchPoseEstimationVisualizationCallback(
    keypoint_colors=[],
    edge_colors=[],
    edge_links=[],
    loss_to_monitor="YoloNASPoseLoss/loss",
    max=True,
    freq=1,
    max_images=16,
    enable_on_train_loader=True,
    enable_on_valid_loader=True,
    post_prediction_callback=post_prediction_callback,
)

early_stop = EarlyStop(
    phase=Phase.VALIDATION_EPOCH_END,
    monitor="AP",
    mode="max",
    min_delta=0.0001,
    patience=100,
    verbose=True,
)

# yolo_nas = super_gradients.training.models.get("yolo_nas_pose_l", pretrained_weights="coco_pose")
yolo_nas_pose = super_gradients.training.models.get(Models.YOLO_NAS_POSE_N, num_classes=1)

train_dataloader_params = {
    'shuffle': True,
    'batch_size': 3,
    'drop_last': True,
    'pin_memory': False,
    'collate_fn': YoloNASPoseCollateFN(),
}

train_params = {
    "warmup_mode": "LinearBatchLRWarmup",
    "warmup_initial_lr": 1e-8,
    "lr_warmup_epochs": 2,
    "initial_lr": 5e-5,
    "lr_mode": "cosine",
    "cosine_final_lr_ratio": 5e-3,
    "max_epochs": 10,
    "zero_weight_decay_on_bias_and_bn": True,
    "batch_accumulate": 1,
    "average_best_models": True,
    "save_ckpt_epoch_list": [5, 10, 15, 20],
    "loss": "yolo_nas_pose_loss",
    "criterion_params": {
        "oks_sigmas": [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
        "classification_loss_weight": 1,
        "classification_loss_type": "focal",
        "regression_iou_loss_type": "ciou",
        "iou_loss_weight": 2.5,
        "dfl_loss_weight": 0.01,
        "pose_cls_loss_weight": 1.0,
        "pose_reg_loss_weight": 34.0,
        "pose_classification_loss_type": "focal",
        "rescale_pose_loss_with_assigned_score": True,
        "assigner_multiply_by_pose_oks": True,
    },
    "optimizer": "AdamW",
    "optimizer_params": {
        "weight_decay": 0.000001,
    },
    "ema": True,
    "ema_params": {
        "decay": 0.997,
        "decay_type": "threshold",
    },
    "mixed_precision": True,
    "sync_bn": False,
    "valid_metrics_list": [metrics],
    "phase_callbacks": [visualization_callback, early_stop],
    "pre_prediction_callback": None,
    "metric_to_watch": "AP",
    "greater_metric_to_watch_is_better": True,
    "convert": "all",
}

train_dataloader = DataLoader(dataset, **train_dataloader_params)
valid_dataloader = DataLoader(dataset, **train_dataloader_params)
trainer.train(
    model=yolo_nas_pose,
    training_params=train_params,
    train_loader=train_dataloader,
    valid_loader=valid_dataloader,
)
```

Versions

I

@BloodAxe
Collaborator

Greetings.
Usually this error message indicates that you are sending uint8 input to the model instead of floats, which suggests you don't have any image normalization. And indeed, based on your code, you are not using any transformations in the training/validation datasets. You will get poor results.

Please follow our tutorial notebook on fine-tuning YoloNAS-Pose: https://github.com/Deci-AI/super-gradients/blob/master/notebooks/YoloNAS_Pose_Fine_Tuning_Animals_Pose_Dataset.ipynb
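
For anyone hitting the same error, below is a minimal sketch of what adding the missing preprocessing could look like. The transform class names, arguments, and import path (KeypointsLongestMaxSize, KeypointsPadIfNeeded, KeypointsImageStandardize under super_gradients.training.transforms.keypoints) follow the linked notebook and are assumptions here; verify them against the notebook and your installed super-gradients version.

```python
# Sketch only: transform names/arguments follow the fine-tuning notebook and are
# assumptions -- check them against your installed super-gradients version.
from super_gradients.training.transforms.keypoints import (
    KeypointsLongestMaxSize,
    KeypointsPadIfNeeded,
    KeypointsImageStandardize,
)

transforms = [
    # Resize the longest side to 640 while keeping the aspect ratio.
    KeypointsLongestMaxSize(max_height=640, max_width=640),
    # Pad to a fixed 640x640 canvas so images can be batched.
    KeypointsPadIfNeeded(min_height=640, min_width=640, image_pad_value=127, mask_pad_value=1),
    # Convert the uint8 image to float and scale to [0, 1] -- the step whose absence
    # triggers "expected scalar type Byte but found Float" on the first forward pass.
    KeypointsImageStandardize(max_value=255),
]

dataset = coco_pose_estimation_dataset.COCOPoseEstimationDataset(
    data_dir='./pose/',
    images_dir='images/',
    json_file='annotations/person_keypoints_default.json',
    include_empty_samples=True,
    transforms=transforms,  # instead of transforms=[]
    edge_links=[],
    edge_colors=[],
    keypoint_colors=[],
)
```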

@wesboyt
Author

wesboyt commented Apr 25, 2024

Thank you, I will update with results.

Does "normalize" in this context mean turning pixel coordinates into floats proportional to the image?

@wesboyt
Author

wesboyt commented Apr 25, 2024

Thank you, I was missing the normalization transform from the example. I appreciate the prompt response, sir!

@wesboyt wesboyt closed this as completed Apr 25, 2024