
Error finetuning from pretrained checkpoint #30

Closed
yadamonk opened this issue May 6, 2021 · 12 comments

yadamonk commented May 6, 2021

Hi all, I'm running into an error when trying to fine-tune from one of the pretrained checkpoints.

Code

!mkdir "$output"
!wget -q -O "$output/checkpoint.pth" https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth

!python -m torch.distributed.launch \
  --nproc_per_node=1 ./dino/main_dino.py \
  --arch deit_small \
  --data_path "$input" \
  --output_dir "$output"

Error

| distributed init (rank 0): env://
git:
  sha: 8aa93fdc90eae4b183c4e3c005174a9f634ecfbf, status: clean, branch: main

arch: deit_small
batch_size_per_gpu: 64
...
...
Student and Teacher are built: they are both deit_small network.
Loss, optimizer and schedulers ready.
Found checkpoint at ./drive/MyDrive/DINO/checkpoint.pth
=> failed to load student from checkpoint './drive/MyDrive/DINO/checkpoint.pth'
=> failed to load teacher from checkpoint './drive/MyDrive/DINO/checkpoint.pth'
=> failed to load optimizer from checkpoint './drive/MyDrive/DINO/checkpoint.pth'
=> failed to load fp16_scaler from checkpoint './drive/MyDrive/DINO/checkpoint.pth'
=> failed to load dino_loss from checkpoint './drive/MyDrive/DINO/checkpoint.pth'

Any suggestions would be very much appreciated.
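For anyone hitting the same failure: a quick way to see why main_dino.py rejects the file is to inspect it with torch.load. A minimal sketch (the path is whatever the wget call above wrote; main_dino.py restores entries named student, teacher, optimizer, fp16_scaler and dino_loss, which are exactly the ones reported as failing):

import torch

# Load the downloaded file on CPU and look at its top-level structure.
ckpt = torch.load("checkpoint.pth", map_location="cpu")

# The weights-only release is presumably a flat state_dict (keys like
# 'cls_token', 'blocks.0.attn.qkv.weight', ...), whereas main_dino.py looks
# for top-level entries such as 'student', 'teacher' and 'optimizer'.
print(type(ckpt))
print(list(ckpt.keys())[:10])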


woctezuma commented May 6, 2021

Get the full checkpoint (the dino_deitsmall16_pretrain_full_checkpoint.pth file) rather than the backbone-only weights.

Illustration
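The full training checkpoint for deit_small/16 sits next to the backbone-only weights; its URL also appears later in this thread. A sketch of fetching it from Python rather than wget, using torch.hub.download_url_to_file:

import torch

# Full training checkpoint (student, teacher, optimizer, epoch, ...) rather
# than the backbone-only weights downloaded in the first post.
url = ("https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/"
       "dino_deitsmall16_pretrain_full_checkpoint.pth")

torch.hub.download_url_to_file(url, "checkpoint.pth")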


yadamonk commented May 6, 2021

Thank you so much for the fast reply. With the full checkpoint I see two messages about incompatible keys, and the reported training time is 0:00:00. However, once I specify more epochs than the pretrained model was trained for, everything looks good.

yadamonk closed this as completed May 6, 2021

yadamonk commented May 7, 2021

I just finished finetuning for 50 epochs but the self-attention maps look worse than those I get from the pretrained checkpoint. Not sure if that's due to a poor choice of args or related to the messages about incompatible keys.

Student and Teacher are built: they are both deit_small network.
Loss, optimizer and schedulers ready.
Found checkpoint at ./drive/MyDrive/DINO/checkpoint.pth
=> loaded student from checkpoint './drive/MyDrive/DINO/checkpoint.pth' with msg _IncompatibleKeys(missing_keys=['module.cls_token', 'module.pos_embed', 'module.patch_embed.proj.weight', 'module.patch_embed.proj.bias', 'module.blocks.0.norm1.weight', 'module.blocks.0.norm1.bias', 'module.blocks.0.attn.qkv.weight', 'module.blocks.0.attn.qkv.bias', 'module.blocks.0.attn.proj.weight', 'module.blocks.0.attn.proj.bias', 'module.blocks.0.norm2.weight', 'module.blocks.0.norm2.bias', 'module.blocks.0.mlp.fc1.weight', 'module.blocks.0.mlp.fc1.bias', 'module.blocks.0.mlp.fc2.weight', 'module.blocks.0.mlp.fc2.bias', 'module.blocks.1.norm1.weight', 'module.blocks.1.norm1.bias', 'module.blocks.1.attn.qkv.weight', 'module.blocks.1.attn.qkv.bias', 'module.blocks.1.attn.proj.weight', 'module.blocks.1.attn.proj.bias', 'module.blocks.1.norm2.weight', 'module.blocks.1.norm2.bias', 'module.blocks.1.mlp.fc1.weight', 'module.blocks.1.mlp.fc1.bias', 'module.blocks.1.mlp.fc2.weight', 'module.blocks.1.mlp.fc2.bias', 'module.blocks.2.norm1.weight', 'module.blocks.2.norm1.bias', 'module.blocks.2.attn.qkv.weight', 'module.blocks.2.attn.qkv.bias', 'module.blocks.2.attn.proj.weight', 'module.blocks.2.attn.proj.bias', 'module.blocks.2.norm2.weight', 'module.blocks.2.norm2.bias', 'module.blocks.2.mlp.fc1.weight', 'module.blocks.2.mlp.fc1.bias', 'module.blocks.2.mlp.fc2.weight', 'module.blocks.2.mlp.fc2.bias', 'module.blocks.3.norm1.weight', 'module.blocks.3.norm1.bias', 'module.blocks.3.attn.qkv.weight', 'module.blocks.3.attn.qkv.bias', 'module.blocks.3.attn.proj.weight', 'module.blocks.3.attn.proj.bias', 'module.blocks.3.norm2.weight', 'module.blocks.3.norm2.bias', 'module.blocks.3.mlp.fc1.weight', 'module.blocks.3.mlp.fc1.bias', 'module.blocks.3.mlp.fc2.weight', 'module.blocks.3.mlp.fc2.bias', 'module.blocks.4.norm1.weight', 'module.blocks.4.norm1.bias', 'module.blocks.4.attn.qkv.weight', 'module.blocks.4.attn.qkv.bias', 'module.blocks.4.attn.proj.weight', 'module.blocks.4.attn.proj.bias', 'module.blocks.4.norm2.weight', 'module.blocks.4.norm2.bias', 'module.blocks.4.mlp.fc1.weight', 'module.blocks.4.mlp.fc1.bias', 'module.blocks.4.mlp.fc2.weight', 'module.blocks.4.mlp.fc2.bias', 'module.blocks.5.norm1.weight', 'module.blocks.5.norm1.bias', 'module.blocks.5.attn.qkv.weight', 'module.blocks.5.attn.qkv.bias', 'module.blocks.5.attn.proj.weight', 'module.blocks.5.attn.proj.bias', 'module.blocks.5.norm2.weight', 'module.blocks.5.norm2.bias', 'module.blocks.5.mlp.fc1.weight', 'module.blocks.5.mlp.fc1.bias', 'module.blocks.5.mlp.fc2.weight', 'module.blocks.5.mlp.fc2.bias', 'module.blocks.6.norm1.weight', 'module.blocks.6.norm1.bias', 'module.blocks.6.attn.qkv.weight', 'module.blocks.6.attn.qkv.bias', 'module.blocks.6.attn.proj.weight', 'module.blocks.6.attn.proj.bias', 'module.blocks.6.norm2.weight', 'module.blocks.6.norm2.bias', 'module.blocks.6.mlp.fc1.weight', 'module.blocks.6.mlp.fc1.bias', 'module.blocks.6.mlp.fc2.weight', 'module.blocks.6.mlp.fc2.bias', 'module.blocks.7.norm1.weight', 'module.blocks.7.norm1.bias', 'module.blocks.7.attn.qkv.weight', 'module.blocks.7.attn.qkv.bias', 'module.blocks.7.attn.proj.weight', 'module.blocks.7.attn.proj.bias', 'module.blocks.7.norm2.weight', 'module.blocks.7.norm2.bias', 'module.blocks.7.mlp.fc1.weight', 'module.blocks.7.mlp.fc1.bias', 'module.blocks.7.mlp.fc2.weight', 'module.blocks.7.mlp.fc2.bias', 'module.blocks.8.norm1.weight', 'module.blocks.8.norm1.bias', 'module.blocks.8.attn.qkv.weight', 'module.blocks.8.attn.qkv.bias', 'module.blocks.8.attn.proj.weight', 
'module.blocks.8.attn.proj.bias', 'module.blocks.8.norm2.weight', 'module.blocks.8.norm2.bias', 'module.blocks.8.mlp.fc1.weight', 'module.blocks.8.mlp.fc1.bias', 'module.blocks.8.mlp.fc2.weight', 'module.blocks.8.mlp.fc2.bias', 'module.blocks.9.norm1.weight', 'module.blocks.9.norm1.bias', 'module.blocks.9.attn.qkv.weight', 'module.blocks.9.attn.qkv.bias', 'module.blocks.9.attn.proj.weight', 'module.blocks.9.attn.proj.bias', 'module.blocks.9.norm2.weight', 'module.blocks.9.norm2.bias', 'module.blocks.9.mlp.fc1.weight', 'module.blocks.9.mlp.fc1.bias', 'module.blocks.9.mlp.fc2.weight', 'module.blocks.9.mlp.fc2.bias', 'module.blocks.10.norm1.weight', 'module.blocks.10.norm1.bias', 'module.blocks.10.attn.qkv.weight', 'module.blocks.10.attn.qkv.bias', 'module.blocks.10.attn.proj.weight', 'module.blocks.10.attn.proj.bias', 'module.blocks.10.norm2.weight', 'module.blocks.10.norm2.bias', 'module.blocks.10.mlp.fc1.weight', 'module.blocks.10.mlp.fc1.bias', 'module.blocks.10.mlp.fc2.weight', 'module.blocks.10.mlp.fc2.bias', 'module.blocks.11.norm1.weight', 'module.blocks.11.norm1.bias', 'module.blocks.11.attn.qkv.weight', 'module.blocks.11.attn.qkv.bias', 'module.blocks.11.attn.proj.weight', 'module.blocks.11.attn.proj.bias', 'module.blocks.11.norm2.weight', 'module.blocks.11.norm2.bias', 'module.blocks.11.mlp.fc1.weight', 'module.blocks.11.mlp.fc1.bias', 'module.blocks.11.mlp.fc2.weight', 'module.blocks.11.mlp.fc2.bias', 'module.norm.weight', 'module.norm.bias', 'module.head.mlp.0.weight', 'module.head.mlp.0.bias', 'module.head.mlp.2.weight', 'module.head.mlp.2.bias', 'module.head.mlp.4.weight', 'module.head.mlp.4.bias', 'module.head.last_layer.weight_g', 'module.head.last_layer.weight_v'], unexpected_keys=['cls_token', 'pos_embed', 'patch_embed.proj.weight', 'patch_embed.proj.bias', 'blocks.0.norm1.weight', 'blocks.0.norm1.bias', 'blocks.0.attn.qkv.weight', 'blocks.0.attn.qkv.bias', 'blocks.0.attn.proj.weight', 'blocks.0.attn.proj.bias', 'blocks.0.norm2.weight', 'blocks.0.norm2.bias', 'blocks.0.mlp.fc1.weight', 'blocks.0.mlp.fc1.bias', 'blocks.0.mlp.fc2.weight', 'blocks.0.mlp.fc2.bias', 'blocks.1.norm1.weight', 'blocks.1.norm1.bias', 'blocks.1.attn.qkv.weight', 'blocks.1.attn.qkv.bias', 'blocks.1.attn.proj.weight', 'blocks.1.attn.proj.bias', 'blocks.1.norm2.weight', 'blocks.1.norm2.bias', 'blocks.1.mlp.fc1.weight', 'blocks.1.mlp.fc1.bias', 'blocks.1.mlp.fc2.weight', 'blocks.1.mlp.fc2.bias', 'blocks.2.norm1.weight', 'blocks.2.norm1.bias', 'blocks.2.attn.qkv.weight', 'blocks.2.attn.qkv.bias', 'blocks.2.attn.proj.weight', 'blocks.2.attn.proj.bias', 'blocks.2.norm2.weight', 'blocks.2.norm2.bias', 'blocks.2.mlp.fc1.weight', 'blocks.2.mlp.fc1.bias', 'blocks.2.mlp.fc2.weight', 'blocks.2.mlp.fc2.bias', 'blocks.3.norm1.weight', 'blocks.3.norm1.bias', 'blocks.3.attn.qkv.weight', 'blocks.3.attn.qkv.bias', 'blocks.3.attn.proj.weight', 'blocks.3.attn.proj.bias', 'blocks.3.norm2.weight', 'blocks.3.norm2.bias', 'blocks.3.mlp.fc1.weight', 'blocks.3.mlp.fc1.bias', 'blocks.3.mlp.fc2.weight', 'blocks.3.mlp.fc2.bias', 'blocks.4.norm1.weight', 'blocks.4.norm1.bias', 'blocks.4.attn.qkv.weight', 'blocks.4.attn.qkv.bias', 'blocks.4.attn.proj.weight', 'blocks.4.attn.proj.bias', 'blocks.4.norm2.weight', 'blocks.4.norm2.bias', 'blocks.4.mlp.fc1.weight', 'blocks.4.mlp.fc1.bias', 'blocks.4.mlp.fc2.weight', 'blocks.4.mlp.fc2.bias', 'blocks.5.norm1.weight', 'blocks.5.norm1.bias', 'blocks.5.attn.qkv.weight', 'blocks.5.attn.qkv.bias', 'blocks.5.attn.proj.weight', 'blocks.5.attn.proj.bias', 'blocks.5.norm2.weight', 
'blocks.5.norm2.bias', 'blocks.5.mlp.fc1.weight', 'blocks.5.mlp.fc1.bias', 'blocks.5.mlp.fc2.weight', 'blocks.5.mlp.fc2.bias', 'blocks.6.norm1.weight', 'blocks.6.norm1.bias', 'blocks.6.attn.qkv.weight', 'blocks.6.attn.qkv.bias', 'blocks.6.attn.proj.weight', 'blocks.6.attn.proj.bias', 'blocks.6.norm2.weight', 'blocks.6.norm2.bias', 'blocks.6.mlp.fc1.weight', 'blocks.6.mlp.fc1.bias', 'blocks.6.mlp.fc2.weight', 'blocks.6.mlp.fc2.bias', 'blocks.7.norm1.weight', 'blocks.7.norm1.bias', 'blocks.7.attn.qkv.weight', 'blocks.7.attn.qkv.bias', 'blocks.7.attn.proj.weight', 'blocks.7.attn.proj.bias', 'blocks.7.norm2.weight', 'blocks.7.norm2.bias', 'blocks.7.mlp.fc1.weight', 'blocks.7.mlp.fc1.bias', 'blocks.7.mlp.fc2.weight', 'blocks.7.mlp.fc2.bias', 'blocks.8.norm1.weight', 'blocks.8.norm1.bias', 'blocks.8.attn.qkv.weight', 'blocks.8.attn.qkv.bias', 'blocks.8.attn.proj.weight', 'blocks.8.attn.proj.bias', 'blocks.8.norm2.weight', 'blocks.8.norm2.bias', 'blocks.8.mlp.fc1.weight', 'blocks.8.mlp.fc1.bias', 'blocks.8.mlp.fc2.weight', 'blocks.8.mlp.fc2.bias', 'blocks.9.norm1.weight', 'blocks.9.norm1.bias', 'blocks.9.attn.qkv.weight', 'blocks.9.attn.qkv.bias', 'blocks.9.attn.proj.weight', 'blocks.9.attn.proj.bias', 'blocks.9.norm2.weight', 'blocks.9.norm2.bias', 'blocks.9.mlp.fc1.weight', 'blocks.9.mlp.fc1.bias', 'blocks.9.mlp.fc2.weight', 'blocks.9.mlp.fc2.bias', 'blocks.10.norm1.weight', 'blocks.10.norm1.bias', 'blocks.10.attn.qkv.weight', 'blocks.10.attn.qkv.bias', 'blocks.10.attn.proj.weight', 'blocks.10.attn.proj.bias', 'blocks.10.norm2.weight', 'blocks.10.norm2.bias', 'blocks.10.mlp.fc1.weight', 'blocks.10.mlp.fc1.bias', 'blocks.10.mlp.fc2.weight', 'blocks.10.mlp.fc2.bias', 'blocks.11.norm1.weight', 'blocks.11.norm1.bias', 'blocks.11.attn.qkv.weight', 'blocks.11.attn.qkv.bias', 'blocks.11.attn.proj.weight', 'blocks.11.attn.proj.bias', 'blocks.11.norm2.weight', 'blocks.11.norm2.bias', 'blocks.11.mlp.fc1.weight', 'blocks.11.mlp.fc1.bias', 'blocks.11.mlp.fc2.weight', 'blocks.11.mlp.fc2.bias', 'norm.weight', 'norm.bias', 'head.projection_head.0.weight', 'head.projection_head.0.bias', 'head.projection_head.2.weight', 'head.projection_head.2.bias', 'head.projection_head.4.weight', 'head.projection_head.4.bias', 'head.prototypes.weight_g', 'head.prototypes.weight_v'])
=> loaded teacher from checkpoint './drive/MyDrive/DINO/checkpoint.pth' with msg _IncompatibleKeys(missing_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'], unexpected_keys=['head.projection_head.0.weight', 'head.projection_head.0.bias', 'head.projection_head.2.weight', 'head.projection_head.2.bias', 'head.projection_head.4.weight', 'head.projection_head.4.bias', 'head.prototypes.weight_g', 'head.prototypes.weight_v'])
=> failed to load optimizer from checkpoint './drive/MyDrive/DINO/checkpoint.pth'
=> failed to load fp16_scaler from checkpoint './drive/MyDrive/DINO/checkpoint.pth'
=> failed to load dino_loss from checkpoint './drive/MyDrive/DINO/checkpoint.pth'

yadamonk reopened this May 7, 2021
@woctezuma


yadamonk commented May 7, 2021

Thanks for the suggestion @woctezuma. I'll definitely need to spend some time to find args that work well for my single GPU setup. The main reason I re-opened the issue is to double-check that the messages about incompatible keys can indeed be safely ignored. Maybe @mathildecaron31 could comment on that.


yadamonk commented May 8, 2021

It looks like the checkpoints were trained on a slightly different version of the released code. Luckily it's not difficult to change the names of the affected keys.

!wget -q -O "checkpoint.pth" https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain_full_checkpoint.pth

import gc
import torch

checkpoint = torch.load("checkpoint.pth", map_location="cpu")

# Rename the student keys: the released checkpoint uses "projection_head" and
# "prototypes", while the released code expects "mlp" and "last_layer" (see the
# _IncompatibleKeys messages above). The student state dict also needs the
# "module." prefix of the DDP-wrapped model.
student = {}
for key, value in checkpoint['student'].items():
    if "projection_head" in key:
        student['module.' + key.replace("projection_head", "mlp")] = value
    elif "prototypes" in key:
        student['module.' + key.replace("prototypes", "last_layer")] = value
    else:
        student['module.' + key] = value

# Same renames for the teacher, whose keys carry no "module." prefix.
teacher = {}
for key, value in checkpoint['teacher'].items():
    if "projection_head" in key:
        teacher[key.replace("projection_head", "mlp")] = value
    elif "prototypes" in key:
        teacher[key.replace("prototypes", "last_layer")] = value
    else:
        teacher[key] = value

# Overwrite the checkpoint with the renamed state dicts.
torch.save({
    'student': student,
    'teacher': teacher,
    'epoch': checkpoint['epoch'],
    'optimizer': checkpoint['optimizer'],
}, "checkpoint.pth")

del checkpoint, student, teacher
gc.collect();

Now training starts at a much smaller loss and I see the following message.

Found checkpoint at ./checkpoint.pth
=> loaded student from checkpoint './checkpoint.pth' with msg <All keys matched successfully>
=> loaded teacher from checkpoint './checkpoint.pth' with msg <All keys matched successfully>

yadamonk closed this as completed May 8, 2021
@mathildecaron31
Contributor

Hi @yadamonk

Thanks for flagging this. I indeed need to update the keys in the released checkpoints to match this codebase. I renamed the projection head keys during the refactoring done to prepare the code release.
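In other words, going from the released checkpoints to this codebase boils down to two substring renames, which is what the workaround above applies. A compact sketch of the same mapping as a helper (names here are illustrative, not part of the codebase):

# Old key substrings in the released checkpoints -> names used by this codebase,
# as observed in the _IncompatibleKeys messages above.
RENAMES = {
    "projection_head": "mlp",    # DINOHead MLP layers
    "prototypes": "last_layer",  # weight-normalized last layer (weight_g/weight_v)
}

def rename_key(key):
    # Apply each substring rename; keys without these substrings pass through.
    for old, new in RENAMES.items():
        key = key.replace(old, new)
    return key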

@yadamonk
Author

Hi @mathildecaron31

Thank you so much for your wonderful work and all the time you're putting into helping others build on it.


mathildecaron31 commented May 18, 2021

Thanks a lot for your kind words @yadamonk :) :).

@tejassp2002

@mathildecaron31 Any updates on this? I tried the method suggested by @yadamonk, but it doesn't seem to work for me.
Thanks!


yadamonk commented Jun 6, 2021

Hi @tejassp2002

The checkpoint seems to have been updated since I posted my workaround. If you replace the corresponding sections with the snippets below, things should work again. Hope that helps.

# Same renames as before, but the backbone weights now need a "backbone."
# prefix ("module.backbone." for the student).
student = {}
for key, value in checkpoint['student'].items():
    if "projection_head" in key:
        student['module.' + key.replace("projection_head", "mlp")] = value
    elif "prototypes" in key:
        student['module.' + key.replace("prototypes", "last_layer")] = value
    else:
        student['module.backbone.' + key] = value

teacher = {}
for key, value in checkpoint['teacher'].items():
    if "projection_head" in key:
        teacher[key.replace("projection_head", "mlp")] = value
    elif "prototypes" in key:
        teacher[key.replace("prototypes", "last_layer")] = value
    else:
        teacher['backbone.' + key] = value

@tejassp2002
Copy link

Hey @yadamonk, that really helped!
But the checkpoint loading of optimizer, fp16_scaler and dino_loss still fails; I assume I can safely ignore that.
Thanks!
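If you would rather have those entries resume too, one option is to replace the torch.save step in the workaround above with something like the sketch below, which keeps the optional training state only when the full checkpoint actually provides it (if it does not, there is nothing to restore and the failure messages are harmless):

import torch

# 'checkpoint', 'student' and 'teacher' are the objects built in the workaround above.
to_save = {'student': student, 'teacher': teacher, 'epoch': checkpoint['epoch']}

# Keep the optional training-state entries that main_dino.py tries to restore,
# but only if the full checkpoint actually contains them.
for extra in ('optimizer', 'fp16_scaler', 'dino_loss'):
    if extra in checkpoint:
        to_save[extra] = checkpoint[extra]

torch.save(to_save, "checkpoint.pth")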
