Loading weights into timm models #353

Answered by rwightman
AlaaKhaddaj asked this question in Q&A
@AlaaKhaddaj if you used one of the timm- models, it's very easy: you just need to strip the prefix from the state dict keys and then re-save...

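import torch
ckpt = torch.load('open_clip.pt', map_location='cpu')
# keep only the timm trunk weights and strip the open_clip prefix from each key
sd_timm = {k.replace('module.visual.trunk.', '', 1): v for k, v in ckpt['state_dict'].items() if k.startswith('module.visual.trunk.')}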
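
The snippet above stops before the "re-save" step. A minimal sketch of the full round trip might look like the following. The helper names `strip_prefix` and `convert_checkpoint`, the `src`/`dst`/`model_name` arguments, and the choice of `num_classes=0` (assuming the trunk was saved without a classifier head) are illustrative assumptions, not part of the original answer.

```python
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")


def strip_prefix(state_dict: Mapping[str, V], prefix: str) -> Dict[str, V]:
    """Keep only keys that start with `prefix` and drop the prefix from each."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


def convert_checkpoint(src: str, dst: str, model_name: str) -> None:
    """Strip the trunk prefix, check the keys against a timm model, re-save.

    `model_name` is whatever timm architecture the open_clip trunk wraps;
    it is a placeholder here and must be supplied by the caller.
    """
    import torch
    import timm

    ckpt = torch.load(src, map_location="cpu")
    sd_timm = strip_prefix(ckpt["state_dict"], "module.visual.trunk.")

    # Assumption: the trunk has no classifier head, hence num_classes=0.
    model = timm.create_model(model_name, pretrained=False, num_classes=0)
    # strict loading raises if any key is missing or unexpected.
    model.load_state_dict(sd_timm)

    torch.save(sd_timm, dst)
```

Loading into a freshly created timm model before saving is only a sanity check: if the key names do not line up, the error shows up at conversion time rather than later.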

The open_clip ViT and ResNet models need remapping, though. I don't have a remap for open_clip ResNet -> timm, but I do have code for ViT. NOTE: the open_clip ViT differs from the default timm ViT (there is an extra norm near the beginning of the network), so the target needs to be one of the vit_xxx_patchxx_clip_xx models, not the ones without 'clip' in the name.
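
One way to spot the extra norm before picking a target model is to look at the checkpoint keys. The sketch below assumes the open_clip ViT names that layer `ln_pre` (as the original CLIP vision transformer does); the helper name `has_pre_norm` is made up for illustration.

```python
from typing import Iterable


def has_pre_norm(keys: Iterable[str]) -> bool:
    """Return True if any key belongs to an `ln_pre` layer.

    Assumption: the extra norm near the start of the network is
    stored under the name `ln_pre` in the checkpoint.
    """
    return any(k.startswith("ln_pre.") or ".ln_pre." in k for k in keys)
```

If this returns True for `ckpt['state_dict'].keys()`, a 'clip' variant is the one to target; `timm.list_models('vit_*clip*')` should list the candidates available in the installed timm version.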

https://github.com/rwightman/pytorch-image-models/blob/94a9159…

Answer selected by AlaaKhaddaj
This discussion was converted from issue #343 on January 12, 2023 05:20.