Distill Problem #106

Closed
saijo0404 opened this issue Nov 9, 2022 · 4 comments

Comments

@saijo0404

I tried to train a pix2pix model on the edges2shoes-r dataset using train_full.sh.

#!/usr/bin/env bash
python distill.py --dataroot database/edges2shoes-r \
  --distiller resnet \
  --log_dir logs/pix2pix/edges2shoes-r/distill \
  --batch_size 4 \
  --restore_teacher_G_path logs/pix2pix/edges2shoes-r/train/checkpoints/latest_net_G.pth \
  --restore_pretrained_G_path logs/pix2pix/edges2shoes-r/train/checkpoints/latest_net_G.pth \
  --pretrained_netG resnet_9blocks \
  --teacher_netG resnet_9blocks \
  --student_netG resnet_9blocks \
  --restore_D_path logs/pix2pix/edges2shoes-r/train/checkpoints/latest_net_D.pth \
  --real_stat_path real_stat/edges2shoes-r_B.npz \
  --meta_path datasets/metas/edges2shoes-r/train1.meta 

After training, I ran the script above, but I got an AssertionError:
In weight_transfer.py line 14, in transfer_Conv2d
assert isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d))
How can I solve this problem?

@lmxyy
Collaborator

lmxyy commented Nov 9, 2022

Could you provide some more information? What is the type of your m1 and m2?
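
One way to gather that information is to replace the bare assert with a check that names both types when it fails. The following is only a minimal sketch: describe_type and require_types are hypothetical helpers, not part of this repository, and the allowed types passed in are whatever the caller expects (for example nn.Conv2d for m1).

```python
def describe_type(obj):
    """Return the fully qualified class name of obj,
    e.g. 'torch.nn.modules.conv.Conv2d' for a Conv2d layer."""
    cls = type(obj)
    return f"{cls.__module__}.{cls.__qualname__}"


def require_types(m1, m2, allowed_m1, allowed_m2):
    """Hypothetical helper: raise a TypeError that names both types
    if either isinstance check fails, instead of a bare AssertionError."""
    if not (isinstance(m1, allowed_m1) and isinstance(m2, allowed_m2)):
        raise TypeError(
            f"unexpected module types: m1={describe_type(m1)}, "
            f"m2={describe_type(m2)}"
        )
```

Called in place of the assert, this would report the offending pair directly in the traceback, so no extra print statements are needed.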

@saijo0404
Author

I printed the types of m1 and m2; the result looks like this:

distiller [ResnetDistiller] was created
Load network at logs/pix2pix/edges2shoes-r/train/checkpoints/latest_net_G.pth
isinstance(netA, nn.DataParallel):  False
isinstance(netB, nn.DataParallel):  False
isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d)):  True
m1 type:  <class 'torch.nn.modules.conv.Conv2d'>
m2 type:  <class 'torch.nn.modules.conv.Conv2d'>
isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d)):  True
m1 type:  <class 'torch.nn.modules.conv.Conv2d'>
m2 type:  <class 'torch.nn.modules.conv.Conv2d'>
isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d)):  True
m1 type:  <class 'torch.nn.modules.conv.Conv2d'>
m2 type:  <class 'torch.nn.modules.conv.Conv2d'>
isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d)):  False
m1 type:  <class 'models.modules.resnet_architecture.resnet_generator.ResnetBlock'>
m2 type:  <class 'models.modules.resnet_architecture.resnet_generator.ResnetBlock'>
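
The log above shows the check failing once a ResnetBlock pair, rather than a Conv2d pair, reaches transfer_Conv2d. To see every pair in one pass instead of printing inside the transfer function, the two networks can be walked in parallel. This is a rough sketch, not code from this repository: paired_children is a hypothetical name, and it assumes both networks expose named_children() the way torch.nn.Module does and have the same child layout.

```python
def paired_children(a, b, prefix=""):
    """Yield (path, type name of a, type name of b) for each pair of
    corresponding submodules, depth first.

    Assumes a and b both provide named_children(), as torch.nn.Module
    does, and that their children line up one to one.
    """
    yield prefix or "<root>", type(a).__name__, type(b).__name__
    pairs = zip(a.named_children(), b.named_children())
    for (name_a, child_a), (_, child_b) in pairs:
        path = f"{prefix}.{name_a}" if prefix else name_a
        yield from paired_children(child_a, child_b, path)
```

Printing the rows for the teacher and student generators would list container modules such as ResnetBlock alongside the Conv2d layers nested inside them, which makes it easier to spot where a container is being handed to a function that expects a leaf layer.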

@lmxyy
Collaborator

lmxyy commented Nov 10, 2022

I see. This is a minor bug in weight_transfer.py because of a typo. I've fixed it in this commit. Could you pull the latest commit and try again?

@lmxyy
Collaborator

lmxyy commented Nov 10, 2022

I will close this issue. Let me know if there are any further issues!

@lmxyy lmxyy closed this as completed Nov 10, 2022