Skip to content

Commit

Permalink
[20230525 v0.5.0] Support datasets: Flowers102 and Stanford Cars; Sup…
Browse files Browse the repository at this point in the history
…port models: DeiT, PVT, TNT and Twins; Add some practical module wrappers; Update demo config; etc
  • Loading branch information
horrible-dong committed May 25, 2023
1 parent eca94b2 commit 44151e9
Show file tree
Hide file tree
Showing 39 changed files with 1,864 additions and 341 deletions.
32 changes: 23 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ QTClassification

**A lightweight and extensible toolbox for image classification**

[![version](https://img.shields.io/badge/Version-0.4.0-brightgreen)](https://github.com/horrible-dong/QTClassification)
[![version](https://img.shields.io/badge/Version-0.5.0-brightgreen)](https://github.com/horrible-dong/QTClassification)
 [![docs](https://img.shields.io/badge/Docs-Latest-orange)](https://github.com/horrible-dong/QTClassification/blob/main/README.md)
 [![license](https://img.shields.io/badge/License-Apache--2.0-blue)](https://github.com/horrible-dong/QTClassification/blob/main/LICENSE)

Expand Down Expand Up @@ -134,7 +134,7 @@ arguments*.
|:------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:----------------:|
| `--data_root` | Directory where your datasets is stored. | `./data` |
| `--dataset`<br />`-d` | Dataset name defined in [qtcls/datasets/\_\_init\_\_.py](qtcls/datasets/__init__.py), such as `cifar10` and `imagenet1k`. | / |
| `--model_lib` | Model library where models come from. Our basic model library is extended from `torchvision` (default), and also supports `timm`. | `torchvision-ex` |
| `--model_lib` | Model library where models come from. Our basic (default) model library is extended from `torchvision` and `timm`, and also supports the original `timm`. | `default` |
| `--model`<br />`-m` | Model name defined in [qtcls/models/\_\_init\_\_.py](qtcls/models/__init__.py), such as `resnet50` and `vit_b_16`. Currently supported model names are listed in <a href="#model_zoo">Model Zoo</a>. | / |
| `--criterion` | Criterion name defined in [qtcls/criterions/\_\_init\_\_.py](qtcls/criterions/__init__.py), such as `ce`. | `default` |
| `--optimizer` | Optimizer name defined in [qtcls/optimizers/\_\_init\_\_.py](qtcls/optimizers/__init__.py), such as `sgd` and `adam`. | `adamw` |
Expand Down Expand Up @@ -168,8 +168,9 @@ Arguments in the config file merge or override command line arguments `args`. Fo

**How to put your dataset**

Currently, `mnist`, `cifar10`, `cifar100`, `stl10`, `svhn` and `pets` datasets will be automatically downloaded to
the `--data_root` directory. For other datasets, please refer to ["How to put your dataset"](data/README.md).
Currently, `mnist`, `cifar10`, `cifar100`, `stl10`, `svhn`, `pets`, `flowers` and `cars` datasets will be automatically
downloaded to the `--data_root` directory. For other datasets, please refer
to ["How to put your dataset"](data/README.md).

### How to customize

Expand All @@ -192,17 +193,18 @@ Our toolbox is flexible enough to be extended. Please follow the instructions be
## <span id="dataset_zoo">Dataset Zoo</span>

Currently supported argument `--dataset` / `-d`:
`mnist`, `cifar10`, `cifar100`, `stl10`, `svhn`, `pets`, `imagenet1k`, `imagenet21k (also called imagenet22k)`,
`mnist`, `cifar10`, `cifar100`, `stl10`, `svhn`, `pets`, `flowers`, `cars`, `imagenet1k`,
`imagenet21k (also called imagenet22k)`,
and all datasets in `folder` format (consistent with `imagenet` storage format, that is, images of each category are
stored in a folder/directory, and the folder/directory name is the category name).

## <span id="model_zoo">Model Zoo</span>

Our basic model library is extended from `torchvision` (default), and also supports `timm`.
Our basic (default) model library is extended from `torchvision` and `timm`, and also supports the original `timm`.

### torchvision (extended)
### default

Set the argument `--model_lib` to `torchvision-ex`.
Set the argument `--model_lib` to `default`.

Currently supported argument `--model` / `-m`:

Expand All @@ -215,6 +217,9 @@ Currently supported argument `--model` / `-m`:
**ConvNeXt**
`convnext_tiny`, `convnext_small`, `convnext_base`, `convnext_large`

**DeiT**
`deit_tiny_patch16_224`, `deit_small_patch16_224`, `deit_base_patch16_224`, `deit_base_patch16_384`, `deit_tiny_distilled_patch16_224`, `deit_small_distilled_patch16_224`, `deit_base_distilled_patch16_224`, `deit_base_distilled_patch16_384`, `deit3_small_patch16_224`, `deit3_small_patch16_384`, `deit3_medium_patch16_224`, `deit3_base_patch16_224`, `deit3_base_patch16_384`, `deit3_large_patch16_224`, `deit3_large_patch16_384`, `deit3_huge_patch14_224`, `deit3_small_patch16_224_in21ft1k`, `deit3_small_patch16_384_in21ft1k`, `deit3_medium_patch16_224_in21ft1k`, `deit3_base_patch16_224_in21ft1k`, `deit3_base_patch16_384_in21ft1k`, `deit3_large_patch16_224_in21ft1k`, `deit3_large_patch16_384_in21ft1k`, `deit3_huge_patch14_224_in21ft1k`

**DenseNet**
`densenet121`, `densenet169`, `densenet201`, `densenet161`

Expand Down Expand Up @@ -242,6 +247,9 @@ Currently supported argument `--model` / `-m`:
**PoolFormer**
`poolformer_s12`, `poolformer_s24`, `poolformer_s36`, `poolformer_m36`, `poolformer_m48`

**PVT**
`pvt_tiny`, `pvt_small`, `pvt_medium`, `pvt_large`, `pvt_huge_v2`

**RegNet**
`regnet_y_400mf`, `regnet_y_800mf`, `regnet_y_1_6gf`, `regnet_y_3_2gf`, `regnet_y_8gf`, `regnet_y_16gf`, `regnet_y_32gf`, `regnet_y_128gf`, `regnet_x_400mf`, `regnet_x_800mf`, `regnet_x_1_6gf`, `regnet_x_3_2gf`, `regnet_x_8gf`, `regnet_x_16gf`, `regnet_x_32gf`

Expand All @@ -260,6 +268,12 @@ Currently supported argument `--model` / `-m`:
**Swin Transformer V2**
`swinv2_tiny_window8_256`, `swinv2_tiny_window16_256`, `swinv2_small_window8_256`, `swinv2_small_window16_256`, `swinv2_base_window8_256`, `swinv2_base_window16_256`, `swinv2_base_window12_192_22k`, `swinv2_base_window12to16_192to256_22kft1k`, `swinv2_base_window12to24_192to384_22kft1k`, `swinv2_large_window12_192_22k`, `swinv2_large_window12to16_192to256_22kft1k`, `swinv2_large_window12to24_192to384_22kft1k`

**TNT**
`tnt_s_patch4_32`, `tnt_s_patch16_224`, `tnt_b_patch16_224`

**Twins**
`twins_pcpvt_small`, `twins_pcpvt_base`, `twins_pcpvt_large`, `twins_svt_small`, `twins_svt_base`, `twins_svt_large`

**VGG**
`vgg11`, `vgg11_bn`, `vgg13`, `vgg13_bn`, `vgg16`, `vgg16_bn`, `vgg19`, `vgg19_bn`

Expand Down Expand Up @@ -296,7 +310,7 @@ If you find QTClassification Toolbox useful in your research, please consider ci
```bibtex
@misc{2023QTClassification,
title={QTClassification},
author={QTClassification Contributors},
author={Qiu, Tian},
howpublished = {\url{https://github.com/horrible-dong/QTClassification}},
year={2023}
}
Expand Down
28 changes: 20 additions & 8 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ QTClassification

**轻量可扩展的图像分类工具箱**

[![version](https://img.shields.io/badge/Version-0.4.0-brightgreen)](https://github.com/horrible-dong/QTClassification)
[![version](https://img.shields.io/badge/Version-0.5.0-brightgreen)](https://github.com/horrible-dong/QTClassification)
&emsp;[![docs](https://img.shields.io/badge/Docs-Latest-orange)](https://github.com/horrible-dong/QTClassification/blob/main/README_zh-CN.md)
&emsp;[![license](https://img.shields.io/badge/License-Apache--2.0-blue)](https://github.com/horrible-dong/QTClassification/blob/main/LICENSE)

Expand Down Expand Up @@ -131,7 +131,7 @@ python main.py \
|:------------------------:|:-----------------------------------------------------------------------------------------------------------------------------------:|:----------------:|
| `--data_root` | 你的数据集存放的路径。 | `./data` |
| `--dataset`<br />`-d` | 数据集名称,在 [qtcls/datasets/\_\_init\_\_.py](qtcls/datasets/__init__.py) 里定义,如 `cifar10``imagenet1k`| / |
| `--model_lib` | 模型库,模型都取自模型库。我们的基础模型库由 `torchvision` 扩展而来 (我们的默认模型库),同时我们也支持 `timm` 模型库。 | `torchvision-ex` |
| `--model_lib` | 模型库,模型都取自模型库。我们的基础(默认)模型库由 `torchvision` `timm` 扩展而来,同时我们也支持原生 `timm` 模型库。 | `default` |
| `--model`<br />`-m` | 模型名称,在 [qtcls/models/\_\_init\_\_.py ](qtcls/models/__init__.py) 里定义,如 `resnet50``vit_b_16`。目前支持的模型名称在<a href="#模型库">模型库</a>中列出。 | / |
| `--criterion` | 损失函数名称,在 [qtcls/criterions/\_\_init\_\_.py](qtcls/criterions/__init__.py) 里定义,如 `ce`| `default` |
| `--optimizer` | 优化器名称,在 [qtcls/optimizers/\_\_init\_\_.py](qtcls/optimizers/__init__.py),如 `sgd``adam`| `adamw` |
Expand Down Expand Up @@ -163,7 +163,7 @@ python main.py -c configs/_demo_.py

**如何放置你的数据集**

目前,`mnist``cifar10``cifar100``stl10``svhn``pets` 数据集会自动下载到 `--data_root`
目前,`mnist``cifar10``cifar100``stl10``svhn``pets``flowers``cars` 数据集会自动下载到 `--data_root`
目录下。其余数据集请参考 [“如何放置你的数据集”](data/README_zh-CN.md)

### 如何自定义
Expand All @@ -187,16 +187,16 @@ python main.py -c configs/_demo_.py
## <span id="数据集">数据集</span>

目前支持的 `--dataset` 参数:
`mnist``cifar10``cifar100``stl10``svhn``pets``imagenet1k``imagenet21k (也叫做 imagenet22k)`
`mnist``cifar10``cifar100``stl10``svhn`, `pets`, `flowers`, `cars`, `imagenet1k``imagenet21k (也叫做 imagenet22k)`
以及所有 `folder` 格式的数据集(与 `imagenet` 存储格式一致,即每个类别的图片存放在一个文件夹内,文件夹名称是类别名称)。

## <span id="模型库">模型库</span>

我们的基础模型库由 `torchvision` 扩展而来(我们的默认模型库),同时我们也支持 `timm` 模型库。
我们的基础(默认)模型库由 `torchvision` `timm` 扩展而来,同时我们也支持原生 `timm` 模型库。

### torchvision(经过我们扩展的)
### 默认模型库

`--model_lib` 赋值为 `torchvision-ex`
`--model_lib` 赋值为 `default`

目前支持的 `--model` 参数:

Expand All @@ -209,6 +209,9 @@ python main.py -c configs/_demo_.py
**ConvNeXt**
`convnext_tiny`, `convnext_small`, `convnext_base`, `convnext_large`

**DeiT**
`deit_tiny_patch16_224`, `deit_small_patch16_224`, `deit_base_patch16_224`, `deit_base_patch16_384`, `deit_tiny_distilled_patch16_224`, `deit_small_distilled_patch16_224`, `deit_base_distilled_patch16_224`, `deit_base_distilled_patch16_384`, `deit3_small_patch16_224`, `deit3_small_patch16_384`, `deit3_medium_patch16_224`, `deit3_base_patch16_224`, `deit3_base_patch16_384`, `deit3_large_patch16_224`, `deit3_large_patch16_384`, `deit3_huge_patch14_224`, `deit3_small_patch16_224_in21ft1k`, `deit3_small_patch16_384_in21ft1k`, `deit3_medium_patch16_224_in21ft1k`, `deit3_base_patch16_224_in21ft1k`, `deit3_base_patch16_384_in21ft1k`, `deit3_large_patch16_224_in21ft1k`, `deit3_large_patch16_384_in21ft1k`, `deit3_huge_patch14_224_in21ft1k`

**DenseNet**
`densenet121`, `densenet169`, `densenet201`, `densenet161`

Expand Down Expand Up @@ -236,6 +239,9 @@ python main.py -c configs/_demo_.py
**PoolFormer**
`poolformer_s12`, `poolformer_s24`, `poolformer_s36`, `poolformer_m36`, `poolformer_m48`

**PVT**
`pvt_tiny`, `pvt_small`, `pvt_medium`, `pvt_large`, `pvt_huge_v2`

**RegNet**
`regnet_y_400mf`, `regnet_y_800mf`, `regnet_y_1_6gf`, `regnet_y_3_2gf`, `regnet_y_8gf`, `regnet_y_16gf`, `regnet_y_32gf`, `regnet_y_128gf`, `regnet_x_400mf`, `regnet_x_800mf`, `regnet_x_1_6gf`, `regnet_x_3_2gf`, `regnet_x_8gf`, `regnet_x_16gf`, `regnet_x_32gf`

Expand All @@ -254,6 +260,12 @@ python main.py -c configs/_demo_.py
**Swin Transformer V2**
`swinv2_tiny_window8_256`, `swinv2_tiny_window16_256`, `swinv2_small_window8_256`, `swinv2_small_window16_256`, `swinv2_base_window8_256`, `swinv2_base_window16_256`, `swinv2_base_window12_192_22k`, `swinv2_base_window12to16_192to256_22kft1k`, `swinv2_base_window12to24_192to384_22kft1k`, `swinv2_large_window12_192_22k`, `swinv2_large_window12to16_192to256_22kft1k`, `swinv2_large_window12to24_192to384_22kft1k`

**TNT**
`tnt_s_patch4_32`, `tnt_s_patch16_224`, `tnt_b_patch16_224`

**Twins**
`twins_pcpvt_small`, `twins_pcpvt_base`, `twins_pcpvt_large`, `twins_svt_small`, `twins_svt_base`, `twins_svt_large`

**VGG**
`vgg11`, `vgg11_bn`, `vgg13`, `vgg13_bn`, `vgg16`, `vgg16_bn`, `vgg19`, `vgg19_bn`

Expand Down Expand Up @@ -285,7 +297,7 @@ QTClassification 基于 Apache 2.0 开源许可证. 具体请看[开源许可证
```bibtex
@misc{2023QTClassification,
title={QTClassification},
author={QTClassification Contributors},
author={Qiu, Tian},
howpublished = {\url{https://github.com/horrible-dong/QTClassification}},
year={2023}
}
Expand Down
4 changes: 2 additions & 2 deletions configs/README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ python main.py -c configs/_demo_.py
值得注意的是,`--config` / `-c` 可以支持任意文件系统路径,比如
`configs/_demo_.py`, `D:\\QTClassification\\configs\\_demo_.py`, `../../other_project/cfg.py`

然后,这些参数会和 [`main.py`](../main.py) 中的命令行参数 `args` 进行合并。**如果参数名相同,配置文件参数会覆盖命令行参数。
**
然后,这些参数会和 [`main.py`](../main.py) 中的命令行参数 `args` 进行合并。
**如果参数名相同,配置文件参数会覆盖命令行参数。**
50 changes: 28 additions & 22 deletions configs/_demo_.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,36 @@
from torch import nn

device = 'cuda'
seed = 42
batch_size = 256
epochs = 12
epochs = 300
eval_interval = 1
num_workers = 2
num_workers = None # auto
pin_memory = True
sync_bn = True
data_root = './data'
dataset = 'mnist'
no_pretrain = True
model_lib = 'torchvision-ex'
model = 'resnet18'
optimizer = 'sgd'
lr = 2e-3
momentum = 0.9
find_unused_params = False
dist_url = 'env://'
need_targets = False
model_lib = 'default'
criterion = 'ce'
optimizer = 'adamw'
weight_decay = 5e-2
scheduler = 'cosine'
save_interval = 1

output_dir = './runs/resnet18_baseline-cifar10'
note = "using the demo config in configs/_demo_.py"

print_freq = 10
clip_max_norm = 5.0

warmup_epochs = 2
min_lr = 1e-6
warmup_epochs = 20
warmup_lr = 1e-06
min_lr = 1e-05
evaluator = 'default'
no_pretrain = True
save_interval = 5
clip_max_norm = 1.0
amp = True

model_kwargs = dict(in_chans=1, groups=1, width_per_group=64) # Do NOT set 'num_classes' in 'model_kwargs'.
image_size = 32
batch_size = 256
lr = 0.0005 * (batch_size / 512)
data_root = './data'
dataset = 'cifar10'
model = 'vit_tiny_patch4_32'
model_kwargs = dict(in_chans=3, act_layer=nn.GELU, drop_path_rate=0.1) # Do NOT set 'num_classes' in 'model_kwargs'.
output_dir = f'./runs/{model}-{dataset}'
note = f"Using the demo config in 'configs/_demo_.py'. | dataset: {dataset} | model: {model} | output_dir: {output_dir}"
print_freq = 20
3 changes: 3 additions & 0 deletions data/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ data/
├── cifar10/
├── cifar100/
├── imagenet1k/
├── imagenet21k/
├── stl10/
├── svhn/
├── pets/
├── flowers/
├── cars/
└── your_dataset/
```

Expand Down
3 changes: 3 additions & 0 deletions data/README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ data/
├── cifar10/
├── cifar100/
├── imagenet1k/
├── imagenet21k/
├── stl10/
├── svhn/
├── pets/
├── flowers/
├── cars/
└── your_dataset/
```

Expand Down
7 changes: 3 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_args_parser():
parser.add_argument('--print_freq', type=int, default=50)
parser.add_argument('--need_targets', action='store_true', help='need targets for training')
parser.add_argument('--drop_lr_now', action='store_true')
parser.add_argument('--drop_last', type=bool, default=True)
parser.add_argument('--drop_last', action='store_true')
parser.add_argument('--amp', action='store_true', help='automatic mixed precision training')
parser.add_argument('--no_dist', action='store_true', help='forcibly disable distributed mode')

Expand All @@ -61,8 +61,7 @@ def get_args_parser():
parser.add_argument('--label_smoothing', type=float, default=0.0, help='for LabelSmoothingCrossEntropy')

# model
parser.add_argument('--model_lib', default='torchvision-ex', type=str, choices=['torchvision-ex', 'timm'],
help='model library')
parser.add_argument('--model_lib', default='default', type=str, choices=['default', 'timm'], help='model library')
parser.add_argument('--model', '-m', default='resnet50', type=str, help='model name')
parser.add_argument('--model_kwargs', default=dict(), help='model specific kwargs')

Expand Down Expand Up @@ -171,7 +170,7 @@ def main(args):
data_loader_train = Data.DataLoader(dataset=dataset_train,
sampler=sampler_train,
batch_size=args.batch_size,
drop_last=args.drop_last,
drop_last=bool(args.drop_last or len(dataset_train) % 2 or args.batch_size % 2),
pin_memory=args.pin_memory,
num_workers=args.num_workers,
collate_fn=dataset_train.collate_fn)
Expand Down
2 changes: 1 addition & 1 deletion qtcls/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# ********************************************
# Copyright (c) QIU, Tian. All rights reserved.

__version__ = "0.4.0"
__version__ = "0.5.0"

from .criterions import build_criterion
from .datasets import build_dataset
Expand Down
2 changes: 1 addition & 1 deletion qtcls/datasets/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def build_dataset(args, split, download=True):
return YourDataset(root=dataset_path,
split=split,
transform=transform, # Can also be written explicitly as 'transform=transform[split]'.
batch_transform=batch_transform) # Can also be written explicitly as 'batch_transform[split]'.
batch_transform=batch_transform) # Can also be written explicitly as 'batch_transform=batch_transform[split]'.
...
```

Expand Down
4 changes: 2 additions & 2 deletions qtcls/datasets/README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def build_dataset(args, split, download=True):

return YourDataset(root=dataset_path,
split=split,
transform=transform, # 也可以显式写成 transform=transform[split]
batch_transform=batch_transform) # 也可以显式写成 batch_transform[split]
transform=transform, # 也可以显式地写成 transform=transform[split]
batch_transform=batch_transform) # 也可以显式地写成 batch_transform=batch_transform[split]
...
```

Expand Down
Loading

0 comments on commit 44151e9

Please sign in to comment.