
Image Classification PyTorch

A library for quickly training image classification models.

Installation

pip install image-classification-pytorch

Quick Start

import image_classification_pytorch as icp

# add a model
# you can add several models to train them consecutively
tf_efficientnet_b4_ns = {'model_type': 'tf_efficientnet_b4_ns', 
                         'im_size': 380, 
                         'im_size_test': 380, 
                         'batch_size': 8, 
                         'mean': [0.485, 0.456, 0.406], 
                         'std': [0.229, 0.224, 0.225]}
                         
models = [tf_efficientnet_b4_ns]

# create trainer
trainer = icp.ICPTrainer(models=models, data_dir='my_data')
# start training
trainer.fit_test()

Example

A simple example of training and prediction is available as an Open In Colab notebook.


Prediction

Put folders with samples into a single folder (data_dir). Use the class labels as folder names.

Example folder structure:

├── inference                    # data_dir folder
    ├── dogs                     # Folder Class 1
    ├── cats                     # Folder Class 2

Use the same preprocessing parameters (image size, mean, std) as for training.

import image_classification_pytorch as icp

icp.ICPInference(data_dir='inference',
                 img_size=380,
                 show_accuracy=True,
                 checkpoint='tb_logs/tf_efficientnet_b4_ns/version_4/checkpoints/tf_efficientnet_b4_ns__epoch=2_val_loss=0.922_val_acc=0.830_val_f1_epoch=0.000.ckpt',
                 std=[0.229, 0.224, 0.225],
                 mean=[0.485, 0.456, 0.406],
                 confidence_threshold=1).predict()

After prediction you will see a folder structure like this:

├── inference                    # data_dir folder
    ├── dogs                     # Initial dogs folder
    ├── dogs_gt___dogs           # Dog pictures (ground truth (gt): dogs) predicted as dogs
    ├── dogs_gt___cats           # Dog pictures (ground truth (gt): dogs) predicted as cats
    ├── cats                     # Initial cats folder
    ├── cats_gt___cats           # Cat pictures (ground truth (gt): cats) predicted as cats

As you can see, all cats were predicted as cats, while some dogs were misclassified as cats.
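
Since the ground-truth and predicted labels are encoded in the output folder names, you can tally a quick confusion summary directly from the directory tree. A minimal sketch, assuming the <gt>_gt___<pred> naming shown above (confusion_from_folders is a hypothetical helper, not part of the library):

import os

def confusion_from_folders(data_dir):
    # Count images per (ground truth, predicted) pair based on
    # output folder names of the form '<gt>_gt___<pred>'.
    counts = {}
    for name in os.listdir(data_dir):
        folder = os.path.join(data_dir, name)
        if not os.path.isdir(folder) or '_gt___' not in name:
            continue  # skip the untouched class folders
        gt, pred = name.split('_gt___', 1)
        counts[(gt, pred)] = len(os.listdir(folder))
    return counts

# illustrative output: {('dogs', 'dogs'): 40, ('dogs', 'cats'): 3, ('cats', 'cats'): 25}
print(confusion_from_folders('inference'))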


Data Preparation

Prepare the data for training in the following format:

├── animals                      # Data folder
    ├── dogs                     # Folder Class 1
    ├── cats                     # Folder Class 2
    ├── gray_bears               # Folder Class 3
    ├── zebras                   # Folder Class 4
    ├── ...
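
Before training, it can help to sanity-check that every class folder actually contains images with the expected extension (the same images_ext you pass to ICPTrainer). A minimal sketch; summarize_dataset is a hypothetical helper, not part of the library:

from pathlib import Path

def summarize_dataset(data_dir, images_ext='jpg'):
    # Print how many images each class folder contains.
    for class_dir in sorted(Path(data_dir).iterdir()):
        if class_dir.is_dir():
            n = len(list(class_dir.glob(f'*.{images_ext}')))
            print(f'{class_dir.name}: {n} images')

summarize_dataset('animals')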

Detailed Quick Start

import image_classification_pytorch as icp
# assuming Adam and ExponentialLR come from torch.optim
# (the original README does not show these imports)
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

# add models
# you can add several models to train them consecutively
tf_efficientnet_b4_ns = {'model_type': 'tf_efficientnet_b4_ns', 
                         'im_size': 380, 
                         'im_size_test': 380, 
                         'batch_size': 8, 
                         'mean': [0.485, 0.456, 0.406], 
                         'std': [0.229, 0.224, 0.225]}
                         
ens_adv_inception_resnet_v2 = {'model_type': 'ens_adv_inception_resnet_v2',
                               'im_size': 256,
                               'im_size_test': 256,
                               'batch_size': 8,
                               'mean': [0.5, 0.5, 0.5],
                               'std': [0.5, 0.5, 0.5]}
                         
models = [tf_efficientnet_b4_ns, ens_adv_inception_resnet_v2]

# create trainer
trainer = icp.ICPTrainer(models=models,
                         data_dir='my_data',
                         images_ext='jpg',
                         init_lr=1e-5,
                         max_epochs=500,
                         augment_p=0.7,
                         progress_bar_refresh_rate=10,
                         early_stop_patience=6,
                         optimizer=Adam(),
                         scheduler=ExponentialLR())

# start training
trainer.fit_test()

Models

323 models out of the box, taken from PyTorch Image Models (timm).

  1. adv_inception_v3
  2. cspdarknet53
  3. cspresnet50
  4. cspresnext50
  5. densenet121
  6. densenet161
  7. densenet169
  8. densenet201
  9. densenetblur121d
  10. dla102
  11. dla102x
  12. dla102x2
  13. dla169
  14. dla34
  15. dla46_c
  16. dla46x_c
  17. dla60_res2net
  18. dla60_res2next
  19. dla60
  20. dla60x_c
  21. dla60x
  22. dm_nfnet_f0
  23. dm_nfnet_f1
  24. dm_nfnet_f2
  25. dm_nfnet_f3
  26. dm_nfnet_f4
  27. dm_nfnet_f5
  28. dm_nfnet_f6
  29. dpn107
  30. dpn131
  31. dpn68
  32. dpn68b
  33. dpn92
  34. dpn98
  35. ecaresnet101d_pruned
  36. ecaresnet101d
  37. ecaresnet269d
  38. ecaresnet26t
  39. ecaresnet50d_pruned
  40. ecaresnet50d
  41. ecaresnet50t
  42. ecaresnetlight
  43. efficientnet_b0
  44. efficientnet_b1_pruned
  45. efficientnet_b1
  46. efficientnet_b2
  47. efficientnet_b2a
  48. efficientnet_b3_pruned
  49. efficientnet_b3
  50. efficientnet_b3a
  51. efficientnet_em
  52. efficientnet_es
  53. efficientnet_lite0
  54. ens_adv_inception_resnet_v2
  55. ese_vovnet19b_dw
  56. ese_vovnet39b
  57. fbnetc_100
  58. gernet_l
  59. gernet_m
  60. gernet_s
  61. gluon_inception_v3
  62. gluon_resnet101_v1b
  63. gluon_resnet101_v1c
  64. gluon_resnet101_v1d
  65. gluon_resnet101_v1s
  66. gluon_resnet152_v1b
  67. gluon_resnet152_v1c
  68. gluon_resnet152_v1d
  69. gluon_resnet152_v1s
  70. gluon_resnet18_v1b
  71. gluon_resnet34_v1b
  72. gluon_resnet50_v1b
  73. gluon_resnet50_v1c
  74. gluon_resnet50_v1d
  75. gluon_resnet50_v1s
  76. gluon_resnext101_32x4d
  77. gluon_resnext101_64x4d
  78. gluon_resnext50_32x4d
  79. gluon_senet154
  80. gluon_seresnext101_32x4d
  81. gluon_seresnext101_64x4d
  82. gluon_seresnext50_32x4d
  83. gluon_xception65
  84. hrnet_w18_small_v2
  85. hrnet_w18_small
  86. hrnet_w18
  87. hrnet_w30
  88. hrnet_w32
  89. hrnet_w40
  90. hrnet_w44
  91. hrnet_w48
  92. hrnet_w64
  93. ig_resnext101_32x16d
  94. ig_resnext101_32x32d
  95. ig_resnext101_32x48d
  96. ig_resnext101_32x8d
  97. inception_resnet_v2
  98. inception_v3
  99. inception_v4
  100. legacy_senet154
  101. legacy_seresnet101
  102. legacy_seresnet152
  103. legacy_seresnet18
  104. legacy_seresnet34
  105. legacy_seresnet50
  106. legacy_seresnext101_32x4d
  107. legacy_seresnext26_32x4d
  108. legacy_seresnext50_32x4d
  109. mixnet_l
  110. mixnet_m
  111. mixnet_s
  112. mixnet_xl
  113. mnasnet_100
  114. mobilenetv2_100
  115. mobilenetv2_110d
  116. mobilenetv2_120d
  117. mobilenetv2_140
  118. mobilenetv3_large_100
  119. mobilenetv3_rw
  120. nasnetalarge
  121. nf_regnet_b1
  122. nf_resnet50
  123. nfnet_l0c
  124. pnasnet5large
  125. regnetx_002
  126. regnetx_004
  127. regnetx_006
  128. regnetx_008
  129. regnetx_016
  130. regnetx_032
  131. regnetx_040
  132. regnetx_064
  133. regnetx_080
  134. regnetx_120
  135. regnetx_160
  136. regnetx_320
  137. regnety_002
  138. regnety_004
  139. regnety_006
  140. regnety_008
  141. regnety_016
  142. regnety_032
  143. regnety_040
  144. regnety_064
  145. regnety_080
  146. regnety_120
  147. regnety_160
  148. regnety_320
  149. repvgg_a2
  150. repvgg_b0
  151. repvgg_b1
  152. repvgg_b1g4
  153. repvgg_b2
  154. repvgg_b2g4
  155. repvgg_b3
  156. repvgg_b3g4
  157. res2net101_26w_4s
  158. res2net50_14w_8s
  159. res2net50_26w_4s
  160. res2net50_26w_6s
  161. res2net50_26w_8s
  162. res2net50_48w_2s
  163. res2next50
  164. resnest101e
  165. resnest14d
  166. resnest200e
  167. resnest269e
  168. resnest26d
  169. resnest50d_1s4x24d
  170. resnest50d_4s2x40d
  171. resnest50d
  172. resnet101d
  173. resnet152d
  174. resnet18
  175. resnet18d
  176. resnet200d
  177. resnet26
  178. resnet26d
  179. resnet34
  180. resnet34d
  181. resnet50
  182. resnet50d
  183. resnetblur50
  184. resnetv2_101x1_bitm_in21k
  185. resnetv2_101x1_bitm
  186. resnetv2_101x3_bitm_in21k
  187. resnetv2_101x3_bitm
  188. resnetv2_152x2_bitm_in21k
  189. resnetv2_152x2_bitm
  190. resnetv2_152x4_bitm_in21k
  191. resnetv2_152x4_bitm
  192. resnetv2_50x1_bitm_in21k
  193. resnetv2_50x1_bitm
  194. resnetv2_50x3_bitm_in21k
  195. resnetv2_50x3_bitm
  196. resnext101_32x8d
  197. resnext50_32x4d
  198. resnext50d_32x4d
  199. rexnet_100
  200. rexnet_130
  201. rexnet_150
  202. rexnet_200
  203. selecsls42b
  204. selecsls60
  205. selecsls60b
  206. semnasnet_100
  207. seresnet152d
  208. seresnet50
  209. seresnext26d_32x4d
  210. seresnext26t_32x4d
  211. seresnext50_32x4d
  212. skresnet18
  213. skresnet34
  214. skresnext50_32x4d
  215. spnasnet_100
  216. ssl_resnet18
  217. ssl_resnet50
  218. ssl_resnext101_32x16d
  219. ssl_resnext101_32x4d
  220. ssl_resnext101_32x8d
  221. ssl_resnext50_32x4d
  222. swsl_resnet18
  223. swsl_resnet50
  224. swsl_resnext101_32x16d
  225. swsl_resnext101_32x4d
  226. swsl_resnext101_32x8d
  227. swsl_resnext50_32x4d
  228. tf_efficientnet_b0_ap
  229. tf_efficientnet_b0_ns
  230. tf_efficientnet_b0
  231. tf_efficientnet_b1_ap
  232. tf_efficientnet_b1_ns
  233. tf_efficientnet_b1
  234. tf_efficientnet_b2_ap
  235. tf_efficientnet_b2_ns
  236. tf_efficientnet_b2
  237. tf_efficientnet_b3_ap
  238. tf_efficientnet_b3_ns
  239. tf_efficientnet_b3
  240. tf_efficientnet_b4_ap
  241. tf_efficientnet_b4_ns
  242. tf_efficientnet_b4
  243. tf_efficientnet_b5_ap
  244. tf_efficientnet_b5_ns
  245. tf_efficientnet_b5
  246. tf_efficientnet_b6_ap
  247. tf_efficientnet_b6_ns
  248. tf_efficientnet_b6
  249. tf_efficientnet_b7_ap
  250. tf_efficientnet_b7_ns
  251. tf_efficientnet_b7
  252. tf_efficientnet_b8_ap
  253. tf_efficientnet_b8
  254. tf_efficientnet_cc_b0_4e
  255. tf_efficientnet_cc_b0_8e
  256. tf_efficientnet_cc_b1_8e
  257. tf_efficientnet_el
  258. tf_efficientnet_em
  259. tf_efficientnet_es
  260. tf_efficientnet_l2_ns_475
  261. tf_efficientnet_l2_ns
  262. tf_efficientnet_lite0
  263. tf_efficientnet_lite1
  264. tf_efficientnet_lite2
  265. tf_efficientnet_lite3
  266. tf_efficientnet_lite4
  267. tf_inception_v3
  268. tf_mixnet_l
  269. tf_mixnet_m
  270. tf_mixnet_s
  271. tf_mobilenetv3_large_075
  272. tf_mobilenetv3_large_100
  273. tf_mobilenetv3_large_minimal_100
  274. tf_mobilenetv3_small_075
  275. tf_mobilenetv3_small_100
  276. tf_mobilenetv3_small_minimal_100
  277. tresnet_l_448
  278. tresnet_l
  279. tresnet_m_448
  280. tresnet_m
  281. tresnet_xl_448
  282. tresnet_xl
  283. tv_densenet121
  284. tv_resnet101
  285. tv_resnet152
  286. tv_resnet34
  287. tv_resnet50
  288. tv_resnext50_32x4d
  289. vgg11_bn
  290. vgg11
  291. vgg13_bn
  292. vgg13
  293. vgg16_bn
  294. vgg16
  295. vgg19_bn
  296. vgg19
  297. vit_base_patch16_224_in21k
  298. vit_base_patch16_224
  299. vit_base_patch16_384
  300. vit_base_patch32_224_in21k
  301. vit_base_patch32_384
  302. vit_base_resnet50_224_in21k
  303. vit_base_resnet50_384
  304. vit_deit_base_distilled_patch16_224
  305. vit_deit_base_distilled_patch16_384
  306. vit_deit_base_patch16_224
  307. vit_deit_base_patch16_384
  308. vit_deit_small_distilled_patch16_224
  309. vit_deit_small_patch16_224
  310. vit_deit_tiny_distilled_patch16_224
  311. vit_deit_tiny_patch16_224
  312. vit_large_patch16_224_in21k
  313. vit_large_patch16_224
  314. vit_large_patch16_384
  315. vit_large_patch32_224_in21k
  316. vit_large_patch32_384
  317. vit_small_patch16_224
  318. wide_resnet101_2
  319. wide_resnet50_2
  320. xception
  321. xception41
  322. xception65
  323. xception71

Get the List of Models

import timm
from pprint import pprint

# list every timm model that has pretrained weights available
model_names = timm.list_models(pretrained=True)
pprint(model_names)
>>> ['adv_inception_v3',
 'cspdarknet53',
 'cspresnext50',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenetblur121d',
 'dla34',
 'dla46_c',
...
]
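
timm.list_models also accepts a wildcard filter, which is handy for narrowing the 300+ models down to one family:

import timm

# the wildcard filter narrows the list to one model family;
# pretrained=True keeps only models that ship weights
efficientnets = timm.list_models('tf_efficientnet*', pretrained=True)
print(efficientnets)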

Get Model Parameters

import timm
from pprint import pprint

# default_cfg holds the pretrained weights' expected input size
# and normalization statistics
m = timm.create_model('efficientnet_b0', pretrained=True)
pprint(m.default_cfg)
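
Since default_cfg carries the preprocessing settings the pretrained weights expect, you can fill in the model dict for ICPTrainer from it instead of hard-coding the values. A minimal sketch using timm's standard default_cfg keys:

import timm

m = timm.create_model('tf_efficientnet_b4_ns', pretrained=True)
cfg = m.default_cfg

# default_cfg['input_size'] is (channels, height, width)
tf_efficientnet_b4_ns = {'model_type': 'tf_efficientnet_b4_ns',
                         'im_size': cfg['input_size'][1],
                         'im_size_test': cfg['input_size'][1],
                         'batch_size': 8,
                         'mean': list(cfg['mean']),
                         'std': list(cfg['std'])}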

See the timm documentation for more details.


License

The project is distributed under the MIT License.
