
How to train with custom datasets? #7

Open
ManuBN786 opened this issue Mar 27, 2023 · 7 comments

@ManuBN786

I trained a resnet34 teacher on my custom dataset with 9 classes, which I arranged in the ImageNet format.
I modified dataset/builder.py like this:

```python
# pre-configuration for the dataset
if args.dataset == 'imagenet':
    args.data_path = 'data/imagenet' if args.data_path == '' else args.data_path
    args.num_classes = 9
    args.input_shape = (3, 384, 384)
```
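Hard-coding `args.num_classes = 9` in the `imagenet` branch works, but it also changes the class count for any other run that uses `--dataset imagenet`. A minimal sketch of deriving the count from the data instead, assuming an ImageNet-style folder layout with one subdirectory per class under a `train` split folder (the split name and the helper `count_classes` are hypothetical, not part of the repo):

```python
import os


def count_classes(data_path, split="train"):
    """Count class folders under <data_path>/<split> (hypothetical helper).

    Assumes one subdirectory per class, e.g.
    data/imagenet/train/<class_name>/<image files>; plain files are ignored.
    """
    split_dir = os.path.join(data_path, split)
    with os.scandir(split_dir) as entries:
        return sum(1 for entry in entries if entry.is_dir())


# Possible use in the dataset pre-configuration (sketch only):
# args.num_classes = count_classes(args.data_path)
```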

I used the command:

```shell
python tools/train.py --dataset imagenet --data-path data/imagenet/ --model resnet34 -c configs/strategies/resnet/resnet.yaml --teacher-pretrained --image-mean 0.604 0.327 0.249 --image-std 0.109 0.076 0.070 -b 32 --experiment teacher_model_train --epochs 100
```

Even after 100 epochs, the best checkpoint's accuracy is only 0.3!
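The `--image-mean` and `--image-std` values above are per-channel statistics of the training images. A minimal pure-Python sketch of one way to compute them, assuming pixel values are already scaled to [0, 1] and each image is given as a list of (R, G, B) tuples; it takes the mean and population standard deviation over every pixel of every image. The thread does not say how the values in the command were computed, and in practice you would stream images with an image library rather than hold them in lists.

```python
import math


def channel_mean_std(images):
    """Per-channel mean and population std over all pixels of all images.

    `images` is an iterable of images; each image is a list of (r, g, b)
    tuples with values in [0, 1].
    """
    n = 0
    sums = [0.0, 0.0, 0.0]
    sq_sums = [0.0, 0.0, 0.0]
    for image in images:
        for pixel in image:
            n += 1
            for c in range(3):
                sums[c] += pixel[c]
                sq_sums[c] += pixel[c] ** 2
    means = [s / n for s in sums]
    # Var = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
    stds = [math.sqrt(max(sq / n - m * m, 0.0)) for sq, m in zip(sq_sums, means)]
    return means, stds
```

The resulting lists map directly onto `--image-mean <R> <G> <B> --image-std <R> <G> <B>`.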

After that, I tried to train a resnet18 student with the command:

```shell
python tools/train.py --dataset imagenet --data-path data/imagenet/ --model resnet18 -c configs/strategies/distill/resnet_dist.yaml --image-mean 0.604 0.327 0.249 --image-std 0.109 0.076 0.070 --teacher-pretrained --teacher-ckpt experiments/teacher_model_train/best.pth.tar -b 16 --experiment student_model_train --epochs 100
```

It fails with this error:

```
12:29:01 INFO Model resnet18 created, params: 11.181 M, FLOPs: 5.330 G
12:29:02 INFO Loading pretrained checkpoint from experiments/teacher_model_train/best.pth.tar
Traceback (most recent call last):
  File "tools/train.py", line 363, in <module>
    main()
  File "tools/train.py", line 91, in main
    teacher_model = build_model(args, args.teacher_model, args.teacher_pretrained, args.teacher_ckpt)
  File "/home/manu/PycharmProjects/DIST_KD/classification/tools/models/builder.py", line 71, in build_model
    model.load_state_dict(ckpt, strict=False)
  File "/home/manu/.virtualenvs/dl4cv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ResNet:
    size mismatch for fc.weight: copying a param with shape torch.Size([9, 512]) from checkpoint, the shape in current model is torch.Size([1000, 512]).
    size mismatch for fc.bias: copying a param with shape torch.Size([9]) from checkpoint, the shape in current model is torch.Size([1000]).
```
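The traceback says the checkpoint's classifier head (`fc`) has 9 outputs, while the freshly built teacher has 1000, so the teacher was not created with the custom class count. Note that `strict=False` in `load_state_dict` only tolerates missing or unexpected keys; tensors whose shapes differ still raise an error. A small diagnostic sketch (the helpers `shape_mismatches` and `head_size` are hypothetical, not part of the repo) that lists such mismatches before loading. It only relies on each value having a `.shape`, so it should work on a checkpoint state dict and on `model.state_dict()` alike:

```python
def shape_mismatches(ckpt_state, model_state):
    """Return (key, ckpt_shape, model_shape) for every key present in both
    dicts whose tensor shapes differ. Values only need a `.shape` attribute."""
    mismatches = []
    for key, value in ckpt_state.items():
        if key in model_state:
            ckpt_shape = tuple(value.shape)
            model_shape = tuple(model_state[key].shape)
            if ckpt_shape != model_shape:
                mismatches.append((key, ckpt_shape, model_shape))
    return mismatches


def head_size(state, key="fc.weight"):
    """Number of output classes stored in a ResNet-style `fc` head."""
    return tuple(state[key].shape)[0]
```

If this reports `fc.weight` as (9, 512) vs (1000, 512), the fix is to build the teacher so its head matches the checkpoint (9 classes), not to drop the `fc` weights: a teacher with a randomly initialized head would produce meaningless distillation targets.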

Please tell me how to train with custom datasets.

@ManuBN786
Author

I fixed the error by changing the teacher model name in 'configs/strategies/distill/resnet_dist.yaml' from 'tv_resnet34' to 'resnet34'.
Now the student model trains well.

But I don't know how to improve the teacher model's accuracy.

@hunto
Owner

hunto commented Mar 31, 2023

Dear @ManuBN786,

Sorry for the late reply. Have you tried your dataset and training settings on your training framework or other frameworks?

@ManuBN786
Author

Yes. With a resnet50 from PyTorch, it gives a validation accuracy of 0.93.

I don't know why the validation accuracy is so poor with DIST_KD.

@hunto
Owner

hunto commented Mar 31, 2023

One bug I can find is that your training uses 384x384 input images, but the resolution in our framework is hard-coded to 224 (see build_train_transforms and build_val_transforms in https://github.com/hunto/image_classification_sota/blob/main/lib/dataset/transform.py).

You should manually change every 224 to 384 at L21, L32, and L61, and change 256 to 440 at L60.
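Since these sizes are hard-coded, an alternative to editing each constant by hand is to derive the evaluation resize size from a single input-size setting. A minimal sketch, assuming you want to keep the 256/224 resize-to-crop ratio; rounding up to a multiple of 8 is an arbitrary choice made here, which happens to reproduce both 256 for 224 and the suggested 440 for 384. The function name is hypothetical and does not exist in transform.py.

```python
def eval_resize_size(crop_size, ref_crop=224, ref_resize=256, multiple=8):
    """Resize size that keeps the ref_resize / ref_crop ratio for a given
    crop size, rounded up to a multiple of `multiple`. Integer math only."""
    scaled = -(-crop_size * ref_resize // ref_crop)  # ceiling division
    return -(-scaled // multiple) * multiple         # round up to multiple


for crop in (224, 384):
    print(crop, eval_resize_size(crop))
# prints:
# 224 256
# 384 440
```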

@ManuBN786
Author

Ok. Thanks for letting me know.

@ManuBN786
Author

(screenshot: low_acc)

I made all of the changes mentioned above for image size 384, but I still get very low accuracy for the teacher.

@hunto
Owner

hunto commented Apr 14, 2023

It's difficult for me to identify the differences between this repo and the PyTorch example code. If you want to use DIST KD in your project, I think the easiest way is to add the KD code to your existing, working code (you just need to initialize a pretrained teacher, compute its outputs for each input batch, and compute and backpropagate the KD loss).
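The steps listed above only need a loss that compares student and teacher outputs for the same batch. As a sketch, here is a pure-Python version of the relation-based loss described in the DIST paper: Pearson correlation between temperature-scaled softmax outputs, taken per sample across classes (inter-class) and per class across the batch (intra-class). It is a reimplementation from that description, not the repo's code; the repo's own loss module is authoritative, and `beta`, `gamma` and `tau` here are placeholder defaults. It is written without a framework so values can be checked by hand; in a training loop the same operations would be written on torch tensors so autograd can backpropagate through the student.

```python
import math


def softmax(logits, tau=1.0):
    """Temperature-scaled softmax over one row of logits (numerically stable)."""
    m = max(logits)
    exps = [math.exp((z - m) / tau) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pearson(a, b, eps=1e-8):
    """Pearson correlation = cosine similarity of the mean-centred vectors."""
    ma, mb = sum(a) / len(a), sum(b) / len(b)
    ca, cb = [x - ma for x in a], [y - mb for y in b]
    num = sum(x * y for x, y in zip(ca, cb))
    den = math.sqrt(sum(x * x for x in ca)) * math.sqrt(sum(y * y for y in cb))
    return num / (den + eps)


def relation_loss(ys, yt):
    """1 - mean Pearson correlation over matching rows of two matrices."""
    return 1 - sum(pearson(a, b) for a, b in zip(ys, yt)) / len(ys)


def transpose(rows):
    return [list(col) for col in zip(*rows)]


def dist_loss(student_logits, teacher_logits, beta=1.0, gamma=1.0, tau=1.0):
    """Batch loss; each argument is a list of per-sample logit rows."""
    ys = [softmax(row, tau) for row in student_logits]
    yt = [softmax(row, tau) for row in teacher_logits]
    inter = relation_loss(ys, yt)                        # per sample, across classes
    intra = relation_loss(transpose(ys), transpose(yt))  # per class, across the batch
    return tau ** 2 * (beta * inter + gamma * intra)
```

In the training step, the teacher would run in eval mode under `torch.no_grad()`, and this KD term would be added to the student's cross-entropy loss before calling `backward()`.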
