Skip to content

Commit

Permalink
[20230424 v0.4.0] Add strong data augmentations; Support the argument…
Browse files Browse the repository at this point in the history
… 'in_chans' for all models; Support imagenet21k (22k)
  • Loading branch information
horrible-dong committed Apr 23, 2023
1 parent 85a5143 commit 5b17f6a
Show file tree
Hide file tree
Showing 45 changed files with 323 additions and 220 deletions.
18 changes: 9 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ QTClassification

**A lightweight and extensible toolbox for image classification**

[![version](https://img.shields.io/badge/Version-0.3.0-brightgreen)](https://github.com/horrible-dong/QTClassification)
[![version](https://img.shields.io/badge/Version-0.4.0-brightgreen)](https://github.com/horrible-dong/QTClassification)
 [![docs](https://img.shields.io/badge/Docs-Latest-orange)](https://github.com/horrible-dong/QTClassification/blob/main/README.md)
 [![license](https://img.shields.io/badge/License-Apache--2.0-blue)](https://github.com/horrible-dong/QTClassification/blob/main/LICENSE)

Expand Down Expand Up @@ -136,7 +136,7 @@ arguments*.
| `--dataset`<br />`-d` | Dataset name defined in [qtcls/datasets/\_\_init\_\_.py](qtcls/datasets/__init__.py), such as `cifar10` and `imagenet1k`. | / |
| `--model_lib` | Model library where models come from. Our basic model library is extended from `torchvision` (default), and also supports `timm`. | `torchvision-ex` |
| `--model`<br />`-m` | Model name defined in [qtcls/models/\_\_init\_\_.py](qtcls/models/__init__.py), such as `resnet50` and `vit_b_16`. Currently supported model names are listed in <a href="#model_zoo">Model Zoo</a>. | / |
| `--criterion` | Criterion name defined in [qtcls/criterions/\_\_init\_\_.py](qtcls/criterions/__init__.py). The `default` criterion computes the cross entropy loss. | `default` |
| `--criterion` | Criterion name defined in [qtcls/criterions/\_\_init\_\_.py](qtcls/criterions/__init__.py), such as `ce`. | `default` |
| `--optimizer` | Optimizer name defined in [qtcls/optimizers/\_\_init\_\_.py](qtcls/optimizers/__init__.py), such as `sgd` and `adam`. | `adamw` |
| `--scheduler` | Scheduler name defined in [qtcls/schedulers/\_\_init\_\_.py](qtcls/schedulers/__init__.py), such as `cosine`. | `cosine` |
| `--evaluator` | Evaluator name defined in [qtcls/evaluators/\_\_init\_\_.py](qtcls/evaluators/__init__.py). The `default` evaluator computes the accuracy, recall, precision, and f1_score. | `default` |
Expand Down Expand Up @@ -191,10 +191,10 @@ Our toolbox is flexible enough to be extended. Please follow the instructions be

## <span id="dataset_zoo">Dataset Zoo</span>

Currently supported argument `--dataset`:
`mnist`, `cifar10`, `cifar100`, `stl10`, `svhn`, `pets`, `imagenet1k`, and all datasets in `folder` format (consistent
with `imagenet` storage format, that is, images of each category are stored in a folder/directory, and the
folder/directory name is the category name).
Currently supported argument `--dataset` / `-d`:
`mnist`, `cifar10`, `cifar100`, `stl10`, `svhn`, `pets`, `imagenet1k`, `imagenet21k (also called imagenet22k)`,
and all datasets in `folder` format (consistent with `imagenet` storage format, that is, images of each category are
stored in a folder/directory, and the folder/directory name is the category name).

## <span id="model_zoo">Model Zoo</span>

Expand All @@ -204,15 +204,15 @@ Our basic model library is extended from `torchvision` (default), and also suppo

Set the argument `--model_lib` to `torchvision-ex`.

Currently supported argument `--model`:
Currently supported argument `--model` / `-m`:

**AlexNet**
`alexnet`

**CaiT**
`cait_xxs24_224`, `cait_xxs24_384`, `cait_xxs36_224`, `cait_xxs36_384`, `cait_xs24_384`, `cait_s24_224`, `cait_s24_384`, `cait_s36_384`, `cait_m36_384`, `cait_m48_448`

**ConvNext**
**ConvNeXt**
`convnext_tiny`, `convnext_small`, `convnext_base`, `convnext_large`

**DenseNet**
Expand Down Expand Up @@ -273,7 +273,7 @@ Currently supported argument `--model`:

Set the argument `--model_lib` to `timm`.

Currently supported argument `--model`:
Currently supported argument `--model` / `-m`:
All supported. Please refer to `timm` for the specific model name.

## LICENSE
Expand Down
10 changes: 5 additions & 5 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ QTClassification

**轻量可扩展的图像分类工具箱**

[![version](https://img.shields.io/badge/Version-0.3.0-brightgreen)](https://github.com/horrible-dong/QTClassification)
[![version](https://img.shields.io/badge/Version-0.4.0-brightgreen)](https://github.com/horrible-dong/QTClassification)
&emsp;[![docs](https://img.shields.io/badge/Docs-Latest-orange)](https://github.com/horrible-dong/QTClassification/blob/main/README_zh-CN.md)
&emsp;[![license](https://img.shields.io/badge/License-Apache--2.0-blue)](https://github.com/horrible-dong/QTClassification/blob/main/LICENSE)

Expand Down Expand Up @@ -133,7 +133,7 @@ python main.py \
| `--dataset`<br />`-d` | 数据集名称,在 [qtcls/datasets/\_\_init\_\_.py](qtcls/datasets/__init__.py) 里定义,如 `cifar10``imagenet1k`| / |
| `--model_lib` | 模型库,模型都取自模型库。我们的基础模型库由 `torchvision` 扩展而来 (我们的默认模型库),同时我们也支持 `timm` 模型库。 | `torchvision-ex` |
| `--model`<br />`-m` | 模型名称,在 [qtcls/models/\_\_init\_\_.py ](qtcls/models/__init__.py) 里定义,如 `resnet50``vit_b_16`。目前支持的模型名称在<a href="#模型库">模型库</a>中列出。 | / |
| `--criterion` | 损失函数名称,在 [qtcls/criterions/\_\_init\_\_.py](qtcls/criterions/__init__.py) 里定义。默认的损失函数会计算交叉熵损失。 | `default` |
| `--criterion` | 损失函数名称,在 [qtcls/criterions/\_\_init\_\_.py](qtcls/criterions/__init__.py) 里定义,如 `ce` | `default` |
| `--optimizer` | 优化器名称,在 [qtcls/optimizers/\_\_init\_\_.py](qtcls/optimizers/__init__.py),如 `sgd``adam`| `adamw` |
| `--scheduler` | 学习率调整策略名称,在 [qtcls/schedulers/\_\_init\_\_.py](qtcls/schedulers/__init__.py) 中定义,如 `cosine`| `cosine` |
| `--evaluator` | 验证器名称,在 [qtcls/evaluators/\_\_init\_\_.py](qtcls/evaluators/__init__.py) 中定义。默认的验证器会计算准确率、召回率、精确率和f1分数。 | `default` |
Expand Down Expand Up @@ -187,8 +187,8 @@ python main.py -c configs/_demo_.py
## <span id="数据集">数据集</span>

目前支持的 `--dataset` 参数:
`mnist``cifar10``cifar100``stl10``svhn``pets``imagenet1k` 以及所有 `folder` 格式的数据集(与 `imagenet`
存储格式一致,即每个类别的图片存放在一个文件夹内,文件夹名称是类别名称)。
`mnist``cifar10``cifar100``stl10``svhn``pets``imagenet1k``imagenet21k (也叫做 imagenet22k)`
以及所有 `folder` 格式的数据集(与 `imagenet` 存储格式一致,即每个类别的图片存放在一个文件夹内,文件夹名称是类别名称)。

## <span id="模型库">模型库</span>

Expand All @@ -206,7 +206,7 @@ python main.py -c configs/_demo_.py
**CaiT**
`cait_xxs24_224`, `cait_xxs24_384`, `cait_xxs36_224`, `cait_xxs36_384`, `cait_xs24_384`, `cait_s24_224`, `cait_s24_384`, `cait_s36_384`, `cait_m36_384`, `cait_m48_448`

**ConvNext**
**ConvNeXt**
`convnext_tiny`, `convnext_small`, `convnext_base`, `convnext_large`

**DenseNet**
Expand Down
4 changes: 2 additions & 2 deletions configs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ English | [简体中文](README_zh-CN.md)

If needed, refer to [`_demo_.py`](_demo_.py) and write your arguments in a config file (.py).

Since v0.2.0, when using your config file, set `--config` to your **_config file path_**,
Since v0.2.0, when using your config file, set `--config` / `-c` to your **_config file path_**,
such as `configs/_demo_.py`.

For example,
Expand All @@ -19,7 +19,7 @@ or
python main.py -c configs/_demo_.py
```

Note that `--config` supports any file system path, such as `configs/_demo_.py`,
Note that `--config` / `-c` supports any file system path, such as `configs/_demo_.py`,
`D:\\QTClassification\\configs\\_demo_.py`, `../../other_project/cfg.py`.

Then, the config arguments will be merged with the command line arguments `args` in [`main.py`](../main.py).
Expand Down
6 changes: 3 additions & 3 deletions configs/README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

如果需要的话,仿照 [`_demo_.py`](_demo_.py) 编写你的参数。

从 v0.2.0 开始,当使用你的配置文件时,把 `--config` 赋值为你的 **_配置文件路径_**,如 `configs/_demo_.py`
从 v0.2.0 开始,当使用你的配置文件时,把 `--config` / `-c` 赋值为你的 **_配置文件路径_**,如 `configs/_demo_.py`

样例:

Expand All @@ -18,8 +18,8 @@ python main.py --config configs/_demo_.py
python main.py -c configs/_demo_.py
```

值得注意的是,`--config` 可以支持任意文件系统路径,比如 `configs/_demo_.py`, `D:\\QTClassification\\configs\\_demo_.py`,
`../../other_project/cfg.py`
值得注意的是,`--config` / `-c` 可以支持任意文件系统路径,比如
`configs/_demo_.py`, `D:\\QTClassification\\configs\\_demo_.py`, `../../other_project/cfg.py`

然后,这些参数会和 [`main.py`](../main.py) 中的命令行参数 `args` 进行合并。**如果参数名相同,配置文件参数会覆盖命令行参数。
**
8 changes: 5 additions & 3 deletions configs/_demo_.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
device = 'cuda'
seed = 42
batch_size = 128
batch_size = 256
epochs = 12
eval_interval = 1
num_workers = 2
pin_memory = True
sync_bn = True
data_root = './data'
dataset = 'cifar10'
dataset = 'mnist'
no_pretrain = True
model_lib = 'torchvision-ex'
model = 'resnet18'
optimizer = 'sgd'
Expand All @@ -23,6 +24,7 @@
clip_max_norm = 5.0

warmup_epochs = 2
min_lr = 1e-6
amp = True

model_kwargs = dict(groups=1, width_per_group=64) # do NOT set 'num_classes' in 'model_kwargs'
model_kwargs = dict(in_chans=1, groups=1, width_per_group=64) # Do NOT set 'num_classes' in 'model_kwargs'.
10 changes: 6 additions & 4 deletions engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, lr_scheduler, devi
model.train()
criterion.train()
n_steps = len(data_loader)

metric_logger = MetricLogger(delimiter=' ')
metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('class_error', SmoothedValue(window_size=1, fmt='{value:.2f}'))
Expand All @@ -29,7 +30,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, lr_scheduler, devi
else:
outputs = model(samples)

loss_dict = criterion(outputs, targets)
loss_dict = criterion(outputs, targets, training=True)
weight_dict = criterion.weight_dict
losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

Expand All @@ -51,15 +52,16 @@ def train_one_epoch(model, criterion, data_loader, optimizer, lr_scheduler, devi
lr_scheduler.step_update(epoch * n_steps + batch_idx)

metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
metric_logger.update(class_error=loss_dict_reduced['class_error'])
metric_logger.update(lr=optimizer.param_groups[0]['lr'])
if 'class_error' in loss_dict_reduced.keys():
metric_logger.update(class_error=loss_dict_reduced['class_error'])

lr_scheduler.step(epoch)

metric_logger.synchronize_between_processes()
print('Averaged stats:', metric_logger)

stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
stats = {k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0}

return stats

Expand All @@ -85,7 +87,7 @@ def evaluate(model, data_loader, criterion, device, args, print_freq=10, need_ta
else:
outputs = model(samples)

loss_dict = criterion(outputs, targets)
loss_dict = criterion(outputs, targets, training=False)

weight_dict = criterion.weight_dict
loss_dict_reduced = reduce_dict(loss_dict)
Expand Down
7 changes: 7 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,14 @@ def get_args_parser():
# dataset
parser.add_argument('--data_root', type=str, default='./data')
parser.add_argument('--dataset', '-d', type=str, default='cifar10')

# data augmentation
parser.add_argument('--image_size', type=int)
parser.add_argument('--train_aug_kwargs', default=dict())
parser.add_argument('--eval_aug_kwargs', default=dict())
parser.add_argument('--train_batch_aug_kwargs', default=dict())
parser.add_argument('--eval_batch_aug_kwargs', default=dict())
parser.add_argument('--label_smoothing', type=float, default=0.0, help='for LabelSmoothingCrossEntropy')

# model
parser.add_argument('--model_lib', default='torchvision-ex', type=str, choices=['torchvision-ex', 'timm'],
Expand Down
2 changes: 1 addition & 1 deletion qtcls/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# ********************************************
# Copyright (c) QIU, Tian. All rights reserved.

__version__ = "0.3.0"
__version__ = "0.4.0"

from .criterions import build_criterion
from .datasets import build_dataset
Expand Down
16 changes: 13 additions & 3 deletions qtcls/criterions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
# Copyright (c) QIU, Tian. All rights reserved.

from .default import DefaultCriterion
from .cross_entropy import CrossEntropy, LabelSmoothingCrossEntropy, SoftTargetCrossEntropy


def build_criterion(args):
criterion_name = args.criterion.lower()

if criterion_name in ['ce', 'default']:
if criterion_name == 'ce':
losses = ['labels']
weight_dict = {'loss_ce': 1}
return DefaultCriterion(losses=losses, weight_dict=weight_dict)
return CrossEntropy(losses=losses, weight_dict=weight_dict)

if criterion_name == 'label_smoothing_ce':
losses = ['labels']
weight_dict = {'loss_ce': 1}
return LabelSmoothingCrossEntropy(losses=losses, weight_dict=weight_dict, smoothing=args.label_smoothing)

if criterion_name in ['soft_target_ce', 'default']:
losses = ['labels']
weight_dict = {'loss_ce': 1}
return SoftTargetCrossEntropy(losses=losses, weight_dict=weight_dict)

raise ValueError(f"Criterion '{criterion_name}' is not found.")
4 changes: 2 additions & 2 deletions qtcls/criterions/_base_.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def __init__(self, losses: list, weight_dict: dict):
self.losses = losses
self.weight_dict = weight_dict

def forward(self, outputs, targets):
def forward(self, outputs, targets, **kwargs):
losses = {}
for loss in self.losses:
losses.update(getattr(self, f'loss_{loss}')(outputs, targets))
losses.update(getattr(self, f'loss_{loss}')(outputs, targets, **kwargs))
return losses
74 changes: 74 additions & 0 deletions qtcls/criterions/cross_entropy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) QIU, Tian. All rights reserved.

import torch
import torch.nn.functional as F

from ._base_ import BaseCriterion
from ..utils.misc import accuracy

__all__ = ['CrossEntropy', 'LabelSmoothingCrossEntropy', 'SoftTargetCrossEntropy']


class CrossEntropy(BaseCriterion):
def __init__(self, losses: list, weight_dict: dict):
super().__init__(losses, weight_dict)

def loss_labels(self, outputs, targets, **kwargs):
if isinstance(outputs, dict):
assert 'logits' in outputs.keys(), \
f"When using 'loss_labels(self, outputs, targets, **kwargs)' in '{self.__class__.__name__}', " \
f"if 'outputs' is a dict, 'logits' MUST be the key."
outputs = outputs["logits"]

loss_ce = F.cross_entropy(outputs, targets, reduction='mean')
losses = {'loss_ce': loss_ce, 'class_error': 100 - accuracy(outputs, targets)[0]}

return losses


class LabelSmoothingCrossEntropy(BaseCriterion):
def __init__(self, losses: list, weight_dict: dict, smoothing: float = 0.1):
super().__init__(losses, weight_dict)
self.smoothing = smoothing
self.confidence = 1. - smoothing

def loss_labels(self, outputs, targets, training, **kwargs):
if isinstance(outputs, dict):
assert 'logits' in outputs.keys(), \
f"When using 'loss_labels(self, outputs, targets, **kwargs)' in '{self.__class__.__name__}', " \
f"if 'outputs' is a dict, 'logits' MUST be the key."
outputs = outputs["logits"]

if training:
logprobs = F.log_softmax(outputs, dim=-1)
nll_loss = -logprobs.gather(dim=-1, index=targets.unsqueeze(1))
nll_loss = nll_loss.squeeze(1)
smooth_loss = -logprobs.mean(dim=-1)
loss_ce = (self.confidence * nll_loss + self.smoothing * smooth_loss).mean()
else:
loss_ce = F.cross_entropy(outputs, targets, reduction='mean')

losses = {'loss_ce': loss_ce, 'class_error': 100 - accuracy(outputs, targets)[0]}

return losses


class SoftTargetCrossEntropy(BaseCriterion): # Compatible with 'CrossEntropy'
def __init__(self, losses: list, weight_dict: dict):
super().__init__(losses, weight_dict)

def loss_labels(self, outputs, targets, **kwargs):
if isinstance(outputs, dict):
assert 'logits' in outputs.keys(), \
f"When using 'loss_labels(self, outputs, targets, **kwargs)' in '{self.__class__.__name__}', " \
f"if 'outputs' is a dict, 'logits' MUST be the key."
outputs = outputs["logits"]

if targets.dim() == 1:
loss_ce = F.cross_entropy(outputs, targets, reduction='mean')
losses = {'loss_ce': loss_ce, 'class_error': 100 - accuracy(outputs, targets)[0]}
else:
loss_ce = torch.sum(-targets * F.log_softmax(outputs, dim=-1), dim=-1).mean()
losses = {'loss_ce': loss_ce}

return losses
22 changes: 0 additions & 22 deletions qtcls/criterions/default.py

This file was deleted.

Loading

0 comments on commit 5b17f6a

Please sign in to comment.