Describe the bug
I managed to fine-tune 'vit_deit_base_patch16_384' on larger image sizes (512, 600) simply by calling timm.create_model with a different img_size, and it works well. When I try the same with 'vit_deit_base_distilled_patch16_384', I get an error during init, inside resize_pos_embed (see the traceback below).
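Roughly what I am running (a minimal sketch; the non-distilled call succeeds, the distilled one fails while loading the pretrained weights):

```python
import timm

# Works: the non-distilled ViT/DeiT resizes its position embedding to the new img_size
model = timm.create_model('vit_deit_base_patch16_384', pretrained=True, img_size=512)

# Fails with the RuntimeError shown below, inside resize_pos_embed
model = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=True, img_size=512)
```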
From a first look, I think this code path is used by both the distilled and the non-distilled versions of ViT, but I cannot see what differs between them. A possible workaround sketch is below.
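As a temporary workaround, one could build the distilled model at the target size without pretrained weights, then load the 384px checkpoint and resize pos_embed by hand, keeping both prefix tokens (class + distillation) instead of only one. This is only a sketch under that assumption; resize_distilled_pos_embed is a hypothetical helper written here for illustration, not part of timm:

```python
import math
import torch
import torch.nn.functional as F
import timm

def resize_distilled_pos_embed(posemb_old, posemb_new, num_prefix_tokens=2):
    # Hypothetical helper: keep the class + distillation tokens untouched,
    # interpolate only the patch-grid portion of the position embedding.
    prefix, grid = posemb_old[:, :num_prefix_tokens], posemb_old[:, num_prefix_tokens:]
    gs_old = int(math.sqrt(grid.shape[1]))
    gs_new = int(math.sqrt(posemb_new.shape[1] - num_prefix_tokens))
    grid = grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(gs_new, gs_new), mode='bilinear', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
    return torch.cat([prefix, grid], dim=1)

# Build the distilled model at the target size without pretrained weights,
# pull the 384px weights from a default-size model, and fix pos_embed by hand.
model = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=False, img_size=512)
ref = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=True)
state_dict = ref.state_dict()
state_dict['pos_embed'] = resize_distilled_pos_embed(state_dict['pos_embed'], model.pos_embed)
model.load_state_dict(state_dict)
```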
Desktop (please complete the following information):
Kaggle TPU environment
Additional context
in __init__(self, model_arch, n_class, input_size, pretrained)
3 super().__init__()
4 self.img_size=input_size
----> 5 self.model = timm.create_model(model_arch, img_size=self.img_size, pretrained=pretrained)
6 #print(self.model)
7 n_features = self.model.head.in_features
/kaggle/input/pytorch-image-models-034/pytorch-image-models-master/timm/models/factory.py in create_model(model_name, pretrained, checkpoint_path, scriptable, exportable, no_jit, **kwargs)
51 if is_model(model_name):
52 create_fn = model_entrypoint(model_name)
---> 53 model = create_fn(**model_args, **kwargs)
54 else:
55 raise RuntimeError('Unknown model (%s)' % model_name)
/kaggle/input/pytorch-image-models-034/pytorch-image-models-master/timm/models/vision_transformer.py in vit_deit_base_distilled_patch16_384(pretrained, **kwargs)
784 model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
785 model = _create_vision_transformer(
--> 786 'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
787 return model
/kaggle/input/pytorch-image-models-034/pytorch-image-models-master/timm/models/vision_transformer.py in _create_vision_transformer(variant, pretrained, distilled, **kwargs)
483 load_pretrained(
484 model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3),
--> 485 filter_fn=partial(checkpoint_filter_fn, model=model))
486 return model
487
/kaggle/input/pytorch-image-models-034/pytorch-image-models-master/timm/models/helpers.py in load_pretrained(model, cfg, num_classes, in_chans, filter_fn, strict, progress)
158 state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
159 if filter_fn is not None:
--> 160 state_dict = filter_fn(state_dict)
161
162 if in_chans == 1:
/kaggle/input/pytorch-image-models-034/pytorch-image-models-master/timm/models/vision_transformer.py in checkpoint_filter_fn(state_dict, model)
457 elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
458 # To resize pos embedding when using model at different size from pretrained weights
--> 459 v = resize_pos_embed(v, model.pos_embed)
460 out_dict[k] = v
461 return out_dict
/kaggle/input/pytorch-image-models-034/pytorch-image-models-master/timm/models/vision_transformer.py in resize_pos_embed(posemb, posemb_new)
437 gs_new = int(math.sqrt(ntok_new))
438 _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)
--> 439 posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
440 posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')
441 posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
RuntimeError: shape '[1, 24, 24, -1]' is invalid for input of size 443136
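If I am reading the shapes right, the numbers line up with a missing distillation token: 443136 / 768 = 577 grid positions are left after resize_pos_embed strips a single prefix token, but a 24x24 grid only holds 576. The distilled checkpoint's pos_embed carries two prefix tokens (class + distillation), so one extra token ends up in the grid part. A quick sanity check, assuming embed_dim = 768 for the base model:

```python
embed_dim = 768
grid_tokens_left = 443136 // embed_dim   # 577: what reshape(1, 24, 24, -1) is asked to fit
print(grid_tokens_left, 24 * 24)         # 577 vs 576 -> one extra (distillation) token
```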