Commit 1468df8

[upd] (#38, #14) update viz_reconstruction.ipynb for visualizing your own model
keyu-tian committed May 25, 2023
1 parent f02a5a9 commit 1468df8
Showing 8 changed files with 69 additions and 83 deletions.
4 changes: 2 additions & 2 deletions downstream_d2/README.md
@@ -61,12 +61,12 @@ Before fine-tuning a ResNet50 pre-trained by SparK, you should first convert our

```shell script
$ cd /path/to/SparK/downstream_d2
-$ python3 convert-timm-to-d2.py /some/path/to/timm_resnet50_1kpretrained.pth d2-style.pkl
+$ python3 convert-timm-to-d2.py /some/path/to/resnet50_1kpretrained_timm_style.pth d2-style.pkl
```

For a ResNet50, you should see a log reporting `len(state)==318`:
```text
-[convert] .pkl is generated! (from `/some/path/to/timm_resnet50_1kpretrained.pth`, to `d2-style.pkl`, len(state)==318)
+[convert] .pkl is generated! (from `/some/path/to/resnet50_1kpretrained_timm_style.pth`, to `d2-style.pkl`, len(state)==318)
```
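For context, the conversion essentially repackages a timm-style `state_dict` as a Detectron2-style pickle. The following is only a rough sketch, not the repository's actual `convert-timm-to-d2.py`: the function name is made up, and the unwrapping step and the `matching_heuristics` flag are assumptions based on common Detectron2 checkpoint conventions.

```python
# Hypothetical sketch of a timm -> Detectron2 checkpoint conversion (not the repo's real script).
import pickle
import torch

def convert_timm_to_d2(timm_ckpt_path: str, d2_pkl_path: str) -> None:
    state = torch.load(timm_ckpt_path, map_location='cpu')
    state = state.get('module', state)  # unwrap a wrapper dict if present (assumption)
    # Detectron2-style .pkl checkpoints store numpy arrays rather than torch tensors
    model = {k: v.numpy() for k, v in state.items()}
    with open(d2_pkl_path, 'wb') as f:
        pickle.dump({'model': model, '__author__': 'SparK', 'matching_heuristics': True}, f)
    print(f'[convert] .pkl is generated! (from `{timm_ckpt_path}`, to `{d2_pkl_path}`, len(state)=={len(model)})')
```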

Then run fine-tuning on a single machine with 8 GPUs:
4 changes: 2 additions & 2 deletions downstream_imagenet/README.md
@@ -19,12 +19,12 @@ All the other configurations have their default values, listed in [/downstream_i
You can override any default with a command-line flag such as `--bs=1024`.


-Here is an example to fine-tune a ResNet50 on an 8-GPU single machine:
+Here is an example to fine-tune a ConvNeXt-Small on an 8-GPU single machine:
```shell script
$ cd /path/to/SparK/downstream_imagenet
$ torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=<some_port> main.py \
--data_path=/path/to/imagenet --exp_name=<your_exp_name> --exp_dir=/path/to/logdir \
---model=resnet50 --resume_from=/some/path/to/timm_resnet50_1kpretrained.pth
+--model=convnext_small --resume_from=/some/path/to/convnextS_1kpretrained_official_style.pth
```

For multiple machines, change `--nnodes` and `--master_addr` to your own configuration, e.g. `--nnodes=2 --master_addr=10.0.0.1`, and launch the same command on each node with that node's `--node_rank`.
2 changes: 1 addition & 1 deletion downstream_imagenet/util.py
@@ -100,7 +100,7 @@ def load_checkpoint(resume_from, model_without_ddp, ema_module, optimizer):
# return 0, '[no performance_desc]'
print(f'[try to resume from file `{resume_from}`]')
checkpoint = torch.load(resume_from, map_location='cpu')
-assert checkpoint.get('is_pretrain', False) == False, 'Please do not use `*_still_pretraining.pth`, which is ONLY for resuming the pretraining. Use `*_1kpretrained.pth` or `*_1kfinetuned*.pth` instead.'
+assert checkpoint.get('is_pretrain', False) == False, 'Please do not use `*_withdecoder_1kpretrained_spark_style.pth`, which is ONLY for resuming the pretraining. Use `*_1kpretrained_timm_style.pth` or `*_1kfinetuned*.pth` instead.'

ep_start, performance_desc = checkpoint.get('epoch', -1) + 1, checkpoint.get('performance_desc', '[no performance_desc]')
missing, unexpected = model_without_ddp.load_state_dict(checkpoint.get('module', checkpoint), strict=False)
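Incidentally, the `is_pretrain` flag asserted above gives a quick way to tell the two checkpoint styles apart before launching a run. A small sanity-check sketch (the file name is a placeholder):

```python
# Sketch: distinguish the two checkpoint styles saved by SparK pretraining.
import torch

ckpt = torch.load('some_checkpoint.pth', map_location='cpu')  # placeholder path
if isinstance(ckpt, dict) and ckpt.get('is_pretrain', False):
    print('spark-style (encoder + decoder): pass it to --resume_from to continue pretraining')
else:
    print('timm-style (encoder only): use it for downstream fine-tuning')
```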
6 changes: 3 additions & 3 deletions pretrain/README.md
@@ -75,8 +75,8 @@ $ torchrun --nproc_per_node=8 --nnodes=<your_nnodes> --node_rank=<rank_starts_fr

See files under `--exp_dir` to track your experiment:

-- `<model>_still_pretraining.pth`: saves model and optimizer states, current epoch, current reconstruction loss, etc; can be used to resume pretraining
-- `<model>_1kpretrained.pth`: can be used for downstream finetuning
+- `<model>_withdecoder_1kpretrained_spark_style.pth`: saves model and optimizer states, current epoch, current reconstruction loss, etc.; can be used to resume pretraining; can also be used for visualization in [/pretrain/viz_reconstruction.ipynb](/pretrain/viz_reconstruction.ipynb)
+- `<model>_1kpretrained_timm_style.pth`: can be used for downstream finetuning
- `pretrain_log.txt`: records some important information such as:
- `git_commit_id`: git version
- `cmd`: all arguments passed to the script
@@ -89,7 +89,7 @@ See files under `--exp_dir` to track your experiment:

## Resuming

-Add the arg `--resume_from=path/to/<model>_still_pretraining.pth` to resume pretraining. Note this is different from `--init_weight`:
+Add the arg `--resume_from=path/to/<model>_withdecoder_1kpretrained_spark_style.pth` to resume pretraining. Note this is different from `--init_weight`:

- `--resume_from` will load three things: model weights, optimizer states, and current epoch, so it's used to resume your interrupted experiment.
- `--init_weight` ONLY loads the model weights, so it's just like a model initialization; the sketch below illustrates the difference.
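To make the distinction concrete, here is a rough sketch of the two code paths. It is not the repository's exact implementation: the `'module'` and `'epoch'` keys match `load_checkpoint` in `downstream_imagenet/util.py` above, while the `'optimizer'` key name is an assumption.

```python
# Rough sketch of the two flags (not the repo's exact code).
import torch

def resume_from(path, model, optimizer):
    ckpt = torch.load(path, map_location='cpu')   # <model>_withdecoder_1kpretrained_spark_style.pth
    model.load_state_dict(ckpt['module'])         # 1) model weights
    optimizer.load_state_dict(ckpt['optimizer'])  # 2) optimizer states ('optimizer' key is assumed)
    return ckpt['epoch'] + 1                      # 3) continue from the next epoch

def init_weight(path, model):
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt.get('module', ckpt), strict=False)  # weights only
    return 0                                      # training still starts from epoch 0
```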
4 changes: 2 additions & 2 deletions pretrain/main.py
@@ -102,8 +102,8 @@ def main_pt():
last_loss = stats['last_loss']
min_loss = min(min_loss, last_loss)
performance_desc = f'{min_loss:.4f} {last_loss:.4f}'
-misc.save_checkpoint(f'{args.model}_still_pretraining.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())
-misc.save_checkpoint_for_finetune(f'{args.model}_1kpretrained.pth', args, model_without_ddp.sparse_encoder.sp_cnn.state_dict())
+misc.save_checkpoint(f'{args.model}_withdecoder_1kpretrained_spark_style.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())
+misc.save_checkpoint_for_finetune(f'{args.model}_1kpretrained_timm_style.pth', args, model_without_ddp.sparse_encoder.sp_cnn.state_dict())

ep_cost = round(time.time() - ep_start_time, 2) + 1 # +1s: approximate the following logging cost
remain_secs = (args.ep-1 - ep) * ep_cost
14 changes: 4 additions & 10 deletions pretrain/spark.py
@@ -10,7 +10,6 @@
import sys
import torch
import torch.nn as nn
-from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from timm.models.layers import trunc_normal_

import encoder
@@ -73,10 +72,9 @@ def __init__(

print(f'[SparK.__init__] dims of mask_tokens={tuple(p.numel() for p in self.mask_tokens)}')

-m = torch.tensor(IMAGENET_DEFAULT_MEAN).view(1, 3, 1, 1)
-s = torch.tensor(IMAGENET_DEFAULT_STD).view(1, 3, 1, 1)
-self.register_buffer('imn_m', m)
-self.register_buffer('imn_s', s)
+# these buffers are deprecated and never used; kept only as placeholders and can be removed
+self.register_buffer('imn_m', torch.empty(1, 3, 1, 1))
+self.register_buffer('imn_s', torch.empty(1, 3, 1, 1))
self.register_buffer('norm_black', torch.zeros(1, 3, input_size, input_size))
self.vis_active = self.vis_active_ex = self.vis_inp = self.vis_inp_mask = ...

@@ -125,7 +123,7 @@ def forward(self, inp_bchw: torch.Tensor, active_b1ff=None, vis=False):
masked_bchw = inp_bchw * active_b1hw
rec_bchw = self.unpatchify(rec * var + mean)
rec_or_inp = torch.where(active_b1hw, inp_bchw, rec_bchw)
-return [self.denorm_for_vis(i) for i in (inp_bchw, masked_bchw, rec_or_inp)]
+return inp_bchw, masked_bchw, rec_or_inp
else:
return recon_loss

@@ -186,7 +184,3 @@ def load_state_dict(self, state_dict, strict=True):
else:
print(err, file=sys.stderr)
return incompatible_keys

-def denorm_for_vis(self, normalized_im):
-normalized_im = (normalized_im * self.imn_s).add_(self.imn_m)
-return torch.clamp(normalized_im, 0, 1)
1 change: 1 addition & 0 deletions pretrain/utils/misc.py
@@ -145,6 +145,7 @@ def save_checkpoint(save_to, args, epoch, performance_desc, model_without_ddp_st
if dist.is_local_master():
to_save = {
'args': str(args),
+'input_size': args.input_size,
'arch': args.model,
'epoch': epoch,
'performance_desc': performance_desc,
117 changes: 54 additions & 63 deletions pretrain/viz_reconstruction.ipynb
@@ -6,8 +6,10 @@
"metadata": {},
"source": [
"# SparK: A Visualization Demo\n",
"A demo using our pre-trained SparK model (ConvNeXt-L with input size 384, or ConvNeXt-S with 224) to reconstruct masked images.\n",
"The mask is whether specified by the user or randomly generated."
"A demo using our pretrained SparK model to reconstruct masked images.\n",
"The mask is whether specified by the user or randomly generated.\n",
"\n",
"**NOTE:** if you want to visualize your own pretrained model, you may need to modify the `IMAGENET_RGB_MEAN` and `IMAGENET_RGB_STD` below to your numbers. In Step 5, use `spark, input_size = build_spark('<model_name>_still_pretraining.pth')` to load your model."
]
},
{
@@ -16,7 +18,7 @@
"metadata": {},
"source": [
"## 1. Preparation\n",
"Install dependencies, specify the device, and specify the pre-trained model."
"Install dependencies, specify the device, and specify the pretrained model."
]
},
{
@@ -41,8 +43,7 @@
"\n",
"# specify the device to use\n",
"USING_GPU_IF_AVAILABLE = True\n",
"# specify the CNN to use\n",
"USING_LARGE384_MODEL = True # True for ConvNeXt-L-384, False for ConvNeXt-S-224\n",
"\n",
"import torch\n",
"_ = torch.empty(1)\n",
"if torch.cuda.is_available() and USING_GPU_IF_AVAILABLE:\n",
@@ -67,18 +68,16 @@
"outputs": [],
"source": [
"from PIL import Image\n",
"import torchvision.transforms as T\n",
"IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)\n",
"IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)\n",
"def load_image(img_file: str):\n",
"import torchvision.transforms.functional as TF\n",
"IMAGENET_RGB_MEAN = torch.tensor((0.485, 0.456, 0.406), device=DEVICE).reshape(1, 3, 1, 1)\n",
"IMAGENET_RGB_STD = torch.tensor((0.229, 0.224, 0.225), device=DEVICE).reshape(1, 3, 1, 1)\n",
"def load_image(size: int, img_file: str):\n",
" img = Image.open(img_file).convert('RGB')\n",
" transform = T.Compose([\n",
" T.Resize((384, 384) if USING_LARGE384_MODEL else (224, 224)),\n",
" T.ToTensor(),\n",
" T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),\n",
" ])\n",
" img = transform(img)\n",
" return img.unsqueeze(0).to(DEVICE)"
" img = TF.center_crop(TF.resize(img, size), [size, size])\n",
" img = TF.to_tensor(img).unsqueeze(0).to(DEVICE).sub(IMAGENET_RGB_MEAN).div_(IMAGENET_RGB_STD)\n",
" return img\n",
"def denormalize(img_bchw):\n",
" return img_bchw.mul(IMAGENET_RGB_STD).add_(IMAGENET_RGB_MEAN).clamp_(0., 1.)"
]
},
{
@@ -101,31 +100,31 @@
"from encoder import SparseEncoder\n",
"from models import build_sparse_encoder\n",
"from spark import SparK\n",
"def build_spark():\n",
" # download and load the checkpoint\n",
" if USING_LARGE384_MODEL:\n",
" model_name, input_size = 'convnext_large', 384\n",
" ckpt_file = 'cnxL384_withdecoder_1kpretrained_spark_style.pth'\n",
" ckpt_link = 'https://drive.google.com/file/d/1ZI9Jgtb3fKWE_vDFEly29w-1FWZSNwa0/view?usp=share_link'\n",
"def build_spark(your_still_pretraining_file_path: str = ''):\n",
" if len(your_still_pretraining_file_path) > 0 and os.path.exists(your_still_pretraining_file_path):\n",
" all_state = torch.load(your_still_pretraining_file_path, map_location='cpu')\n",
" input_size, model_name = all_state['input_size'], all_state['arch']\n",
" pretrained_state = all_state['module']\n",
" print(f\"[in function `build_spark`] your ckpt `{your_still_pretraining_file_path}` loaded; don't forget to modify IMAGENET_RGB_MEAN and IMAGENET_RGB_MEAN above if needed\")\n",
" else:\n",
" input_size = 224\n",
" model_name, ckpt_file, ckpt_link = {\n",
" 'ResNet50': ('resnet50', 'res50_withdecoder_1kpretrained_spark_style.pth', 'https://drive.google.com/file/d/1STt3w3e5q9eCPZa8VzcJj1zG6p3jLeSF/view?usp=share_link'),\n",
" 'ResNet101': ('resnet101', 'res101_withdecoder_1kpretrained_spark_style.pth', 'https://drive.google.com/file/d/1GjN48LKtlop2YQre6---7ViCWO-3C0yr/view?usp=share_link'),\n",
" 'ResNet152': ('resnet152', 'res152_withdecoder_1kpretrained_spark_style.pth', 'https://drive.google.com/file/d/1U3Cd94j4ZHfYR2dUjWmsEWfjP6Opx4oo/view?usp=share_link'),\n",
" 'ResNet200': ('resnet200', 'res200_withdecoder_1kpretrained_spark_style.pth', 'https://drive.google.com/file/d/13AFSqvIr0v-2hmb4DzVza45t_lhf2CnD/view?usp=share_link'),\n",
" 'ConvNeXt-S': ('convnext_small', 'cnxS224_withdecoder_1kpretrained_spark_style.pth', 'https://drive.google.com/file/d/1bKvrE4sNq1PfzhWlQJXEPrl2kHqHRZM-/view?usp=share_link'),\n",
" }['ConvNeXt-S'] # you can choose any model here\n",
" # download and load the checkpoint\n",
" input_size, model_name, file_path, ckpt_link = {\n",
" 'ResNet50': (224, 'resnet50', 'res50_withdecoder_1kpretrained_spark_style.pth', 'https://drive.google.com/file/d/1STt3w3e5q9eCPZa8VzcJj1zG6p3jLeSF/view?usp=share_link'),\n",
" 'ResNet101': (224, 'resnet101', 'res101_withdecoder_1kpretrained_spark_style.pth', 'https://drive.google.com/file/d/1GjN48LKtlop2YQre6---7ViCWO-3C0yr/view?usp=share_link'),\n",
" 'ResNet152': (224, 'resnet152', 'res152_withdecoder_1kpretrained_spark_style.pth', 'https://drive.google.com/file/d/1U3Cd94j4ZHfYR2dUjWmsEWfjP6Opx4oo/view?usp=share_link'),\n",
" 'ResNet200': (224, 'resnet200', 'res200_withdecoder_1kpretrained_spark_style.pth', 'https://drive.google.com/file/d/13AFSqvIr0v-2hmb4DzVza45t_lhf2CnD/view?usp=share_link'),\n",
" 'ConvNeXt-S': (224, 'convnext_small', 'cnxS224_withdecoder_1kpretrained_spark_style.pth', 'https://drive.google.com/file/d/1bKvrE4sNq1PfzhWlQJXEPrl2kHqHRZM-/view?usp=share_link'),\n",
" 'ConvNeXt-L': (384, 'convnext_large', 'cnxL384_withdecoder_1kpretrained_spark_style.pth', 'https://drive.google.com/file/d/1ZI9Jgtb3fKWE_vDFEly29w-1FWZSNwa0/view?usp=share_link')\n",
" }['ConvNeXt-L'] # you can choose any model here\n",
" assert os.path.exists(file_path), f'please download checkpoint {file_path} from {ckpt_link}'\n",
" pretrained_state = torch.load(file_path, map_location='cpu')\n",
"\n",
" assert os.path.exists(ckpt_file), f'please download checkpoint {ckpt_file} from {ckpt_link}'\n",
" pretrained_state = torch.load(ckpt_file, map_location='cpu')\n",
" using_bn_in_densify = 'densify_norms.0.running_mean' in pretrained_state\n",
" \n",
" # build a SparK model\n",
" config = pretrained_state['config']\n",
" enc: SparseEncoder = build_sparse_encoder(model_name, input_size=input_size)\n",
" spark = SparK(\n",
" sparse_encoder=enc, dense_decoder=LightDecoder(enc.downsample_raito, sbn=False),\n",
" mask_ratio=0.6, densify_norm='bn' if using_bn_in_densify else 'ln', sbn=False,\n",
" mask_ratio=config['mask_ratio'], densify_norm=config['densify_norm_str'], sbn=config['sbn'],\n",
" ).to(DEVICE)\n",
" spark.eval(), [p.requires_grad_(False) for p in spark.parameters()]\n",
" \n",
@@ -134,7 +133,7 @@
" assert len(missing) == 0, f'load_state_dict missing keys: {missing}'\n",
" assert len(unexpected) == 0, f'load_state_dict unexpected keys: {unexpected}'\n",
" del pretrained_state\n",
" return spark\n"
" return spark, input_size\n"
]
},
{
@@ -153,15 +152,15 @@
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"def show(spark: SparK, img_file='viz_imgs/recon.png', active_b1ff: torch.BoolTensor = None):\n",
" inp_bchw = load_image(img_file)\n",
"def show(spark: SparK, size: int, img_file='viz_imgs/recon.png', active_b1ff: torch.BoolTensor = None):\n",
" inp_bchw = load_image(size, img_file)\n",
" spark.forward\n",
" inp_bchw, masked_bchw, rec_or_inp = spark(inp_bchw, active_b1ff=active_b1ff, vis=True)\n",
" # plot these three images in a row\n",
" masked_title = 'rand masked' if active_b1ff is None else 'specified masked'\n",
" for col, (title, tensor) in enumerate(zip(['input', masked_title, 'reconstructed'], [inp_bchw, masked_bchw, rec_or_inp])):\n",
" for col, (title, bchw) in enumerate(zip(['input', masked_title, 'reconstructed'], [inp_bchw, masked_bchw, rec_or_inp])):\n",
" plt.subplot2grid((1, 3), (0, col))\n",
" plt.imshow(tensor[0].permute(1, 2, 0).cpu().numpy())\n",
" plt.imshow(denormalize(bchw)[0].permute(1, 2, 0).cpu().numpy())\n",
" plt.title(title)\n",
" plt.axis('off')\n",
" plt.show()"
@@ -177,26 +176,16 @@
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dbe10ac3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[build_sparse_encoder] model kwargs={'sparse': True, 'drop_path_rate': 0.4, 'pretrained': False, 'num_classes': 0, 'global_pool': ''}\n",
"[SparK.__init__, densify 1/4]: densify_proj(ksz=1, #para=1.18M)\n",
"[SparK.__init__, densify 2/4]: densify_proj(ksz=3, #para=2.65M)\n",
"[SparK.__init__, densify 3/4]: densify_proj(ksz=3, #para=0.66M)\n",
"[SparK.__init__, densify 4/4]: densify_proj(ksz=3, #para=0.17M)\n",
"[SparK.__init__] dims of mask_tokens=(1536, 768, 384, 192)\n"
]
}
],
"execution_count": null,
"outputs": [],
"source": [
"spark = build_spark()"
]
"spark, input_size = build_spark()\n",
"# for visualizing your own pretrained model of '<model_name>_withdecoder_1kpretrained_spark_style.pth', run:\n",
"# spark, input_size = build_spark('<model_name>_withdecoder_1kpretrained_spark_style.pth')"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
@@ -217,7 +206,7 @@
],
"source": [
"# specify the mask\n",
"if USING_LARGE384_MODEL:\n",
"if input_size == 384:\n",
" active_b1ff = torch.tensor([\n",
" [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
@@ -232,7 +221,7 @@
" [1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],\n",
" ], device=DEVICE).bool().reshape(1, 1, 12, 12)\n",
"else:\n",
"elif input_size == 224:\n",
" active_b1ff = torch.tensor([\n",
" [0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 0, 0, 0, 0, 1],\n",
Expand All @@ -242,8 +231,10 @@
" [0, 0, 0, 0, 0, 1, 0],\n",
" [0, 0, 0, 1, 0, 0, 0],\n",
" ], device=DEVICE).bool().reshape(1, 1, 7, 7)\n",
"else:\n",
" raise NotImplementedError('define your mask for other input_size')\n",
"\n",
"show(spark, 'viz_imgs/recon.png', active_b1ff=active_b1ff)"
"show(spark, input_size, 'viz_imgs/recon.png', active_b1ff=active_b1ff)"
]
},
{
@@ -262,7 +253,7 @@
"outputs": [],
"source": [
"# use a random mask (don't specify the mask)\n",
"show(spark, 'viz_imgs/recon.png', active_b1ff=None)"
"show(spark, input_size, 'viz_imgs/recon.png', active_b1ff=None)"
]
}
],
