A library for quickly training image classification models.
pip install image-classification-pytorch
import image_classification_pytorch as icp
# add a model
# you can add several models to train them one after another
tf_efficientnet_b4_ns = {'model_type': 'tf_efficientnet_b4_ns',
                         'im_size': 380,
                         'im_size_test': 380,
                         'batch_size': 8,
                         'mean': [0.485, 0.456, 0.406],
                         'std': [0.229, 0.224, 0.225]}
models = [tf_efficientnet_b4_ns]
# create trainer
trainer = icp.ICPTrainer(models=models, data_dir='my_data')
# start training
trainer.fit_test()
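Each model dict must use the mean, std, and image size that the pretrained weights expect. Below is a minimal sketch of a hypothetical helper (make_model_cfg is not part of the library) that fills in a model dict from timm's default_cfg, described in the Get Model Parameters section at the end:

import timm

# Hypothetical helper, not a library API: build a model dict from the
# default_cfg that timm stores for every model.
def make_model_cfg(model_type, batch_size=8):
    cfg = timm.create_model(model_type, pretrained=False).default_cfg
    im_size = cfg['input_size'][-1]  # input_size is (channels, height, width)
    return {'model_type': model_type,
            'im_size': im_size,
            'im_size_test': im_size,
            'batch_size': batch_size,
            'mean': list(cfg['mean']),
            'std': list(cfg['std'])}

models = [make_model_cfg('tf_efficientnet_b4_ns')]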
Simple example of training and prediction
Put folders with class samples into a single data folder (data_dir), using class labels as folder names. Example folder structure:
├── inference            # data_dir folder
│   ├── dogs             # Class 1 folder
│   └── cats             # Class 2 folder
For inference, use the same parameters (image size, mean, std) as for training.
import image_classification_pytorch as icp
icp.ICPInference(data_dir='inference',
                 img_size=380,
                 show_accuracy=True,
                 checkpoint='tb_logs/tf_efficientnet_b4_ns/version_4/checkpoints/tf_efficientnet_b4_ns__epoch=2_val_loss=0.922_val_acc=0.830_val_f1_epoch=0.000.ckpt',
                 std=[0.229, 0.224, 0.225],
                 mean=[0.485, 0.456, 0.406],
                 confidence_threshold=1).predict()
After prediction, the folder structure looks like this:
├── inference                    # data_dir folder
│   ├── dogs                     # initial dogs folder
│   │   ├── dogs_gt___dogs       # ground truth (gt) dogs predicted as dogs
│   │   └── dogs_gt___cats       # ground truth (gt) dogs predicted as cats
│   └── cats                     # initial cats folder
│       └── cats_gt___cats       # ground truth (gt) cats predicted as cats
As you can see, all cats were predicted as cats, but some dogs were predicted as cats.
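Because the ground truth and the prediction are both encoded in the output folder names, the overall accuracy can be recomputed directly from this layout. A minimal sketch, assuming the <gt_class>_gt___<predicted_class> naming shown above:

from pathlib import Path

inference_dir = Path('inference')
correct = total = 0
for folder in inference_dir.glob('*/*_gt___*'):
    # folder names look like 'dogs_gt___cats': ground truth, then prediction
    gt_class, predicted_class = folder.name.split('_gt___')
    n_images = sum(1 for f in folder.iterdir() if f.is_file())
    total += n_images
    if gt_class == predicted_class:
        correct += n_images
print(f'accuracy: {correct / total:.3f}')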
Prepare data for training in the following format:
├── animals              # data folder
│   ├── dogs             # Class 1 folder
│   ├── cats             # Class 2 folder
│   ├── gray_bears       # Class 3 folder
│   ├── zebras           # Class 4 folder
│   └── ...
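Before training, it is worth sanity-checking the layout and the class balance. A minimal sketch, assuming jpg images (matching the images_ext argument used below):

from pathlib import Path

data_dir = Path('animals')  # the data folder from the structure above
for class_dir in sorted(p for p in data_dir.iterdir() if p.is_dir()):
    n_images = sum(1 for _ in class_dir.glob('*.jpg'))
    print(f'{class_dir.name}: {n_images} images')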
import image_classification_pytorch as icp
# add models
# you can add several models to train them one after another
tf_efficientnet_b4_ns = {'model_type': 'tf_efficientnet_b4_ns',
                         'im_size': 380,
                         'im_size_test': 380,
                         'batch_size': 8,
                         'mean': [0.485, 0.456, 0.406],
                         'std': [0.229, 0.224, 0.225]}
ens_adv_inception_resnet_v2 = {'model_type': 'ens_adv_inception_resnet_v2',
                               'im_size': 256,
                               'im_size_test': 256,
                               'batch_size': 8,
                               'mean': [0.5, 0.5, 0.5],
                               'std': [0.5, 0.5, 0.5]}
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

models = [tf_efficientnet_b4_ns, ens_adv_inception_resnet_v2]
# create trainer
# note: the optimizer and scheduler are passed as classes here (an assumption:
# Adam() and ExponentialLR() cannot be instantiated without arguments, so the
# trainer presumably constructs them internally)
trainer = icp.ICPTrainer(models=models,
                         data_dir='my_data',
                         images_ext='jpg',
                         init_lr=1e-5,
                         max_epochs=500,
                         augment_p=0.7,
                         progress_bar_refresh_rate=10,
                         early_stop_patience=6,
                         optimizer=Adam,
                         scheduler=ExponentialLR)
# start training
trainer.fit_test()
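Judging by the checkpoint path in the inference example above, logs and checkpoints are written under tb_logs/<model_type>/version_N/checkpoints/. Assuming standard TensorBoard logging, training can be monitored with:

tensorboard --logdir tb_logs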
323 models are available out of the box, all taken from PyTorch Image Models (timm).
Models
- adv_inception_v3
- cspdarknet53
- cspresnet50
- cspresnext50
- densenet121
- densenet161
- densenet169
- densenet201
- densenetblur121d
- dla102
- dla102x
- dla102x2
- dla169
- dla34
- dla46_c
- dla46x_c
- dla60_res2net
- dla60_res2next
- dla60
- dla60x_c
- dla60x
- dm_nfnet_f0
- dm_nfnet_f1
- dm_nfnet_f2
- dm_nfnet_f3
- dm_nfnet_f4
- dm_nfnet_f5
- dm_nfnet_f6
- dpn107
- dpn131
- dpn68
- dpn68b
- dpn92
- dpn98
- ecaresnet101d_pruned
- ecaresnet101d
- ecaresnet269d
- ecaresnet26t
- ecaresnet50d_pruned
- ecaresnet50d
- ecaresnet50t
- ecaresnetlight
- efficientnet_b0
- efficientnet_b1_pruned
- efficientnet_b1
- efficientnet_b2
- efficientnet_b2a
- efficientnet_b3_pruned
- efficientnet_b3
- efficientnet_b3a
- efficientnet_em
- efficientnet_es
- efficientnet_lite0
- ens_adv_inception_resnet_v2
- ese_vovnet19b_dw
- ese_vovnet39b
- fbnetc_100
- gernet_l
- gernet_m
- gernet_s
- gluon_inception_v3
- gluon_resnet101_v1b
- gluon_resnet101_v1c
- gluon_resnet101_v1d
- gluon_resnet101_v1s
- gluon_resnet152_v1b
- gluon_resnet152_v1c
- gluon_resnet152_v1d
- gluon_resnet152_v1s
- gluon_resnet18_v1b
- gluon_resnet34_v1b
- gluon_resnet50_v1b
- gluon_resnet50_v1c
- gluon_resnet50_v1d
- gluon_resnet50_v1s
- gluon_resnext101_32x4d
- gluon_resnext101_64x4d
- gluon_resnext50_32x4d
- gluon_senet154
- gluon_seresnext101_32x4d
- gluon_seresnext101_64x4d
- gluon_seresnext50_32x4d
- gluon_xception65
- hrnet_w18_small_v2
- hrnet_w18_small
- hrnet_w18
- hrnet_w30
- hrnet_w32
- hrnet_w40
- hrnet_w44
- hrnet_w48
- hrnet_w64
- ig_resnext101_32x16d
- ig_resnext101_32x32d
- ig_resnext101_32x48d
- ig_resnext101_32x8d
- inception_resnet_v2
- inception_v3
- inception_v4
- legacy_senet154
- legacy_seresnet101
- legacy_seresnet152
- legacy_seresnet18
- legacy_seresnet34
- legacy_seresnet50
- legacy_seresnext101_32x4d
- legacy_seresnext26_32x4d
- legacy_seresnext50_32x4d
- mixnet_l
- mixnet_m
- mixnet_s
- mixnet_xl
- mnasnet_100
- mobilenetv2_100
- mobilenetv2_110d
- mobilenetv2_120d
- mobilenetv2_140
- mobilenetv3_large_100
- mobilenetv3_rw
- nasnetalarge
- nf_regnet_b1
- nf_resnet50
- nfnet_l0c
- pnasnet5large
- regnetx_002
- regnetx_004
- regnetx_006
- regnetx_008
- regnetx_016
- regnetx_032
- regnetx_040
- regnetx_064
- regnetx_080
- regnetx_120
- regnetx_160
- regnetx_320
- regnety_002
- regnety_004
- regnety_006
- regnety_008
- regnety_016
- regnety_032
- regnety_040
- regnety_064
- regnety_080
- regnety_120
- regnety_160
- regnety_320
- repvgg_a2
- repvgg_b0
- repvgg_b1
- repvgg_b1g4
- repvgg_b2
- repvgg_b2g4
- repvgg_b3
- repvgg_b3g4
- res2net101_26w_4s
- res2net50_14w_8s
- res2net50_26w_4s
- res2net50_26w_6s
- res2net50_26w_8s
- res2net50_48w_2s
- res2next50
- resnest101e
- resnest14d
- resnest200e
- resnest269e
- resnest26d
- resnest50d_1s4x24d
- resnest50d_4s2x40d
- resnest50d
- resnet101d
- resnet152d
- resnet18
- resnet18d
- resnet200d
- resnet26
- resnet26d
- resnet34
- resnet34d
- resnet50
- resnet50d
- resnetblur50
- resnetv2_101x1_bitm_in21k
- resnetv2_101x1_bitm
- resnetv2_101x3_bitm_in21k
- resnetv2_101x3_bitm
- resnetv2_152x2_bitm_in21k
- resnetv2_152x2_bitm
- resnetv2_152x4_bitm_in21k
- resnetv2_152x4_bitm
- resnetv2_50x1_bitm_in21k
- resnetv2_50x1_bitm
- resnetv2_50x3_bitm_in21k
- resnetv2_50x3_bitm
- resnext101_32x8d
- resnext50_32x4d
- resnext50d_32x4d
- rexnet_100
- rexnet_130
- rexnet_150
- rexnet_200
- selecsls42b
- selecsls60
- selecsls60b
- semnasnet_100
- seresnet152d
- seresnet50
- seresnext26d_32x4d
- seresnext26t_32x4d
- seresnext50_32x4d
- skresnet18
- skresnet34
- skresnext50_32x4d
- spnasnet_100
- ssl_resnet18
- ssl_resnet50
- ssl_resnext101_32x16d
- ssl_resnext101_32x4d
- ssl_resnext101_32x8d
- ssl_resnext50_32x4d
- swsl_resnet18
- swsl_resnet50
- swsl_resnext101_32x16d
- swsl_resnext101_32x4d
- swsl_resnext101_32x8d
- swsl_resnext50_32x4d
- tf_efficientnet_b0_ap
- tf_efficientnet_b0_ns
- tf_efficientnet_b0
- tf_efficientnet_b1_ap
- tf_efficientnet_b1_ns
- tf_efficientnet_b1
- tf_efficientnet_b2_ap
- tf_efficientnet_b2_ns
- tf_efficientnet_b2
- tf_efficientnet_b3_ap
- tf_efficientnet_b3_ns
- tf_efficientnet_b3
- tf_efficientnet_b4_ap
- tf_efficientnet_b4_ns
- tf_efficientnet_b4
- tf_efficientnet_b5_ap
- tf_efficientnet_b5_ns
- tf_efficientnet_b5
- tf_efficientnet_b6_ap
- tf_efficientnet_b6_ns
- tf_efficientnet_b6
- tf_efficientnet_b7_ap
- tf_efficientnet_b7_ns
- tf_efficientnet_b7
- tf_efficientnet_b8_ap
- tf_efficientnet_b8
- tf_efficientnet_cc_b0_4e
- tf_efficientnet_cc_b0_8e
- tf_efficientnet_cc_b1_8e
- tf_efficientnet_el
- tf_efficientnet_em
- tf_efficientnet_es
- tf_efficientnet_l2_ns_475
- tf_efficientnet_l2_ns
- tf_efficientnet_lite0
- tf_efficientnet_lite1
- tf_efficientnet_lite2
- tf_efficientnet_lite3
- tf_efficientnet_lite4
- tf_inception_v3
- tf_mixnet_l
- tf_mixnet_m
- tf_mixnet_s
- tf_mobilenetv3_large_075
- tf_mobilenetv3_large_100
- tf_mobilenetv3_large_minimal_100
- tf_mobilenetv3_small_075
- tf_mobilenetv3_small_100
- tf_mobilenetv3_small_minimal_100
- tresnet_l_448
- tresnet_l
- tresnet_m_448
- tresnet_m
- tresnet_xl_448
- tresnet_xl
- tv_densenet121
- tv_resnet101
- tv_resnet152
- tv_resnet34
- tv_resnet50
- tv_resnext50_32x4d
- vgg11_bn
- vgg11
- vgg13_bn
- vgg13
- vgg16_bn
- vgg16
- vgg19_bn
- vgg19
- vit_base_patch16_224_in21k
- vit_base_patch16_224
- vit_base_patch16_384
- vit_base_patch32_224_in21k
- vit_base_patch32_384
- vit_base_resnet50_224_in21k
- vit_base_resnet50_384
- vit_deit_base_distilled_patch16_224
- vit_deit_base_distilled_patch16_384
- vit_deit_base_patch16_224
- vit_deit_base_patch16_384
- vit_deit_small_distilled_patch16_224
- vit_deit_small_patch16_224
- vit_deit_tiny_distilled_patch16_224
- vit_deit_tiny_patch16_224
- vit_large_patch16_224_in21k
- vit_large_patch16_224
- vit_large_patch16_384
- vit_large_patch32_224_in21k
- vit_large_patch32_384
- vit_small_patch16_224
- wide_resnet101_2
- wide_resnet50_2
- xception
- xception41
- xception65
- xception71
Get the List of Models
import timm
from pprint import pprint
model_names = timm.list_models(pretrained=True)
pprint(model_names)
>>> ['adv_inception_v3',
'cspdarknet53',
'cspresnext50',
'densenet121',
'densenet161',
'densenet169',
'densenet201',
'densenetblur121d',
'dla34',
'dla46_c',
...
]
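timm.list_models also accepts a wildcard filter, which is convenient when you are only interested in one family of architectures:

import timm
from pprint import pprint

# list only EfficientNet variants that ship with pretrained weights
pprint(timm.list_models('*efficientnet*', pretrained=True))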
Get Model Parameters
import timm
from pprint import pprint
m = timm.create_model('efficientnet_b0', pretrained=True)
pprint(m.default_cfg)
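The printed default_cfg contains exactly the values needed for a model dict above (input size, mean, std). For efficientnet_b0 the output looks roughly like this (abbreviated; keys may vary by timm version):

{'crop_pct': 0.875,
 'input_size': (3, 224, 224),
 'interpolation': 'bicubic',
 'mean': (0.485, 0.456, 0.406),
 'num_classes': 1000,
 'std': (0.229, 0.224, 0.225),
 'url': '...'}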
See the timm documentation for more details.
The project is distributed under the MIT License.