# Fast.ai: best vision models

*This is a reproduction of Jeremy Howard's notebook ['The best vision models for fine tuning'](https://www.kaggle.com/code/jhoward/the-best-vision-models-for-fine-tuning/notebook), where I plugged in the results from my own model fine-tuning sweep runs.*

```python
# statsmodels is needed for plotly's OLS trend lines further down
!pip install statsmodels >/dev/null 2>&1
!pip install fastai >/dev/null 2>&1
!pip install plotly >/dev/null 2>&1
```

## Introduction

In a recent notebook I tried to answer the question "[Which image models are best?](https://www.kaggle.com/code/jhoward/which-image-models-are-best)" This showed which models in Ross Wightman's [PyTorch Image Models](https://timm.fast.ai/) (*timm*) were the fastest and most accurate for training from scratch with Imagenet.

However, this is not what most of us use models for. Most of us fine-tune pretrained models, so what we really want to know is which models are the fastest and most accurate for fine-tuning. As far as I know, no such analysis has previously existed.

Therefore I teamed up with [Thomas Capelle](https://tcapelle.github.io/about/) of [Weights and Biases](https://wandb.ai/) to answer this question. In this notebook, I present our results.

## The analysis

There are two key dimensions along which datasets vary that affect how well a pretrained model fine-tunes on them:

1. How similar they are to the pre-trained model's dataset.
2. How large they are.

So we took models pre-trained on Imagenet and fine-tuned them on two datasets that differ a lot along both of these axes:

1. The [Oxford IIT-Pet Dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/), which is very similar to Imagenet. Imagenet contains many photos in which an animal is the main subject, and IIT-Pet's roughly 7,400 images (37 breeds of cats and dogs) are of the same type.
2. The [Kaggle Planet](https://www.kaggle.com/c/planet-understanding-the-amazon-from-space/data) sample contains 1,000 satellite images of Earth. There are no images of this kind in Imagenet.

So these two datasets are of very different sizes, and very different in terms of their similarity to Imagenet. They also have different types of labels: Planet is a multi-label problem, whereas IIT-Pet is a single-label problem.
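The difference between the two label types mostly shows up in how predictions are decoded. Below is a minimal numpy sketch of the usual approach in each case. The toy logits, the helper names and the 0.5 threshold are assumptions chosen for illustration.

```python
import numpy as np

def decode_single_label(logits: np.ndarray) -> np.ndarray:
    """Single-label (e.g. IIT-Pet breeds): softmax over classes, keep the argmax."""
    z = logits - logits.max(axis=1, keepdims=True)        # subtract row max for numerical stability
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    return probs.argmax(axis=1)                           # exactly one class index per image

def decode_multi_label(logits: np.ndarray, thresh: float = 0.5) -> np.ndarray:
    """Multi-label (e.g. Planet tags): independent sigmoid per class, keep every class above thresh."""
    probs = 1 / (1 + np.exp(-logits))
    return probs > thresh                                 # boolean mask: zero or more labels per image

# Made-up logits for two images and three classes
logits = np.array([[2.0, -1.0, 0.5],
                   [-3.0, 1.5, 1.0]])
print(decode_single_label(logits))
print(decode_multi_label(logits))
```

With a softmax, each image gets exactly one label. With per-class sigmoids, the second image is tagged with both of its last two classes.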

To test the fine-tuning accuracy of different models, Thomas put together [this script](https://github.com/tcapelle/fastai_timm/blob/main/fine_tune.py). The basic script contains the standard 4 lines of code needed for fastai image recognition models, plus some code to handle various configuration options, such as learning rate and batch size. It was particularly easy to handle in fastai since fastai supports all timm models directly.
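The sketch below shows roughly what those "standard 4 lines" look like for IIT-Pet. It is not Thomas's script. The helper names `breed_from_filename` and `fine_tune_pets`, the default epoch count and the use of regex labelling are assumptions for illustration. The learning rate, squish resize and concat-pooling flag mirror the options listed under Hyperparameters.

```python
import re

# Oxford-IIIT Pet file names encode the breed, e.g. "great_pyrenees_173.jpg" or "Abyssinian_1.jpg"
PETS_PAT = r'^(.+)_\d+\.jpg$'

def breed_from_filename(fname: str) -> str:
    """Return the breed label encoded in a Pets file name."""
    return re.match(PETS_PAT, fname).group(1)

def fine_tune_pets(arch: str = 'convnext_tiny', lr: float = 0.008, epochs: int = 5,
                   concat_pool: bool = True):
    """Hypothetical helper: fine-tune a timm model on IIT-Pet with fastai (downloads data, wants a GPU)."""
    # Imported lazily so the label helper above works without fastai installed
    from fastai.vision.all import (untar_data, URLs, get_image_files, ImageDataLoaders,
                                   Resize, vision_learner, error_rate)
    path = untar_data(URLs.PETS) / 'images'
    dls = ImageDataLoaders.from_name_re(path, get_image_files(path), pat=PETS_PAT,
                                       item_tfms=Resize(224, method='squish'))
    # A timm model name can be passed straight to vision_learner as the architecture
    learn = vision_learner(dls, arch, metrics=error_rate, concat_pool=concat_pool)
    learn.fine_tune(epochs, lr)
    return learn
```

Swapping architectures is then just a different `arch` string, e.g. `fine_tune_pets('vit_small_patch16_224')`.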

Then, to allow us to easily try different configuration options, Thomas created Weights and Biases (*wandb*) YAML files such as [this one](https://github.com/tcapelle/fastai_timm/blob/main/sweep_planets_lr.yaml). This takes advantage of the convenient [wandb "sweeps"](https://wandb.ai/site/sweeps) feature, which runs a model over a range of values for each configuration input and tracks the results.

wandb makes it really easy for a group of people to run these kinds of analyses on whatever GPUs they have access to. When you create a sweep using the command-line wandb client, it gives you a command to run to have a computer run experiments for the project. You run that same command on each computer where you want to run experiments. The wandb client automatically ensures that each computer runs different parts of the sweep, and has each one report back its results to the wandb server. You can look at the progress in the wandb web GUI at any time during or after the run. I've got three GPUs in my PC at home, so I ran three copies of the client, with each using a different GPU. Thomas also ran the client on a [Paperspace Gradient](https://gradient.run/notebooks) server.

I liked this approach because I could start and stop the clients any time I wanted, and wandb would automatically handle keeping all the results in sync. When I restarted a client, it would automatically fetch the next set of sweep settings from the server. Furthermore, the integration in fastai is really exceptional, thanks particularly to [Boris Dayma](https://github.com/borisdayma), who worked tirelessly to ensure that wandb automatically tracks every aspect of all fastai data processing, model architectures, and optimisation.

## Hyperparameters

We decided to try out all the timm models which had reasonable performance on timm, and which are capable of working with 224x224 px images. We ended up with a list of 86 models and variants to try.

Our first step was to find a good set of hyperparameters for each model variant and for each dataset. Our experience at fast.ai has been that there's generally not much difference between models and datasets in terms of what hyperparameter settings work well -- and that experience was repeated in this project. Initial sweeps across a smaller number of representative models showed little variation in optimal hyperparameters, so our final sweep included all combinations of the following options:

- Learning rate (AdamW): 0.008 and 0.02
- Resize method: [Squish](https://docs.fast.ai/vision.augment.html#Resize)
- Pooling type: [Concat](https://docs.fast.ai/layers.html#AdaptiveConcatPool2d) and Average Pooling
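As a sketch, the grid above could be written as a wandb-style sweep definition. It would typically be registered with `wandb.sweep(...)` and then executed by one or more `wandb agent` processes. The parameter and metric names here are hypothetical, and only the values come from the list above. The helper expands the grid the same way a grid sweep would:

```python
from itertools import product

# wandb-style grid sweep definition; parameter/metric names are hypothetical,
# the values are the ones listed above.
sweep_config = {
    'method': 'grid',
    'metric': {'name': 'error_rate', 'goal': 'minimize'},
    'parameters': {
        'learning_rate': {'values': [0.008, 0.02]},
        'resize_method': {'values': ['squish']},
        'concat_pool':   {'values': [True, False]},   # True = concat pooling, False = average pooling
    },
}

def grid_runs(config: dict) -> list[dict]:
    """Expand a grid sweep into the individual settings an agent would run."""
    params = config['parameters']
    names = list(params)
    return [dict(zip(names, combo)) for combo in product(*(params[n]['values'] for n in names))]

sweep_runs = grid_runs(sweep_config)
print(len(sweep_runs))  # 4 settings per model and dataset
```

With three runs per setting, that is 12 runs per model variant and dataset.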

For other parameters, we used defaults that we've previously found at fast.ai to be reliable across a range of models and datasets (see the fastai docs for details).

## Analysis

Let's take a look at the data. I've put a CSV of the results into a gist:

```python
from fastai.vision.all import *   # star import also provides pd and np
import plotly.express as px

# Results CSVs: Jeremy's original sweep ("source") and my reproduction ("repro")
url_source = 'https://gist.githubusercontent.com/jph00/959aaf8695e723246b5e21f3cd5deb02/raw/sweep.csv'
url_repro = 'https://gist.githubusercontent.com/eolecvk/5fb35bcd2536e8a78492da6a04686ddf/raw/f8a47c025ff4f5e79bbba09908bf8d3ce5e7114e/fastai_timm_repro.csv'
```

For each model variant and dataset, for each hyperparameter setting, we did three runs. For the final sweep, we just used the hyperparameter settings listed above.

For each model variant and dataset, I take the minimum error rate, fit time and GPU memory use across its runs. Each metric is minimised separately, so the three values need not come from the same run. I use the minimum because there might be some reason that a particular run didn't do so well (e.g. maybe there was some resource contention), and I'm mainly interested in knowing what the best case results for a model can be.
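Here is a tiny illustration of that aggregation, using the same `pivot_table(..., aggfunc='min')` call as the code below on three made-up runs of one model:

```python
import pandas as pd

# Made-up example: three runs of one model on one dataset
demo_runs = pd.DataFrame({
    'dataset':    ['pets'] * 3,
    'model_name': ['convnext_tiny'] * 3,
    'error_rate': [0.051, 0.047, 0.049],
    'fit_time':   [93.0, 95.0, 120.0],   # e.g. the third run hit resource contention
})

# One row per (dataset, model), each metric minimised independently
best = demo_runs.pivot_table(values=['error_rate', 'fit_time'],
                             index=['dataset', 'model_name'], aggfunc='min').reset_index()
print(best)
```

Here the best error rate comes from the second run and the best fit time from the first, and both land in the same summary row.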

I create a "score" which, somewhat arbitrarily, combines the accuracy and speed into a single number. I tried a few options until I came up with something that closely matched my own opinions about the tradeoffs between the two. (Feel free, of course, to fork this notebook and adjust how that's calculated.)
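Written out, the score computed in `load_with_score` below is as follows (lower is better, since the tables are sorted ascending by it):

$$\text{score} = \text{error\_rate} \times (\text{fit\_time} + 80)$$

The constant 80 dampens the reward for speed. A relative speed-up $p$ shrinks the score by only about $p \cdot \frac{t}{t + 80}$, roughly half of $p$ for fit times $t$ around 90. A relative drop in error rate counts in full. Two rows from the source results show the effect:

$$\text{convnext\_tiny\_in22k}: \; 0.044655 \times (94.558 + 80) \approx 7.795$$

$$\text{vit\_small\_patch16\_224}: \; 0.054804 \times (80.740 + 80) \approx 8.809$$

The ViT is about 15% faster but ranks lower, because its error rate is about 23% higher.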

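```python
def load_with_score(url):
    "Load a sweep CSV, keep the best run per model, and add a speed/accuracy score (IIT-Pet rows only)."
    df = pd.read_csv(url)
    # Model family = leading letters (plus an optional "v2") of the timm model name
    df['family'] = df.model_name.str.extract(r'^([a-z]+?(?:v2)?)(?:\d|_|$)', expand=False)
    df.loc[df.family=='swinv2', 'family'] = 'swin'
    # Minimum of each metric across the runs of a model (each column is minimised independently)
    pt_all = df.pivot_table(values=['error_rate','fit_time','GPU_mem'], index=['dataset', 'family', 'model_name'],
                            aggfunc='min').reset_index()
    # Lower is better: error rate weighted by fit time plus a fixed offset of 80
    pt_all['score'] = pt_all.error_rate*(pt_all.fit_time+80)
    pt_all = pt_all[pt_all.dataset=='pets'].sort_values('score').reset_index(drop=True)
    return pt_all

pt_all_source = load_with_score(url_source)
pt_all_repro = load_with_score(url_repro)
```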
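To check what the `family` regex in `load_with_score` extracts, here is the same pattern applied with plain `re` to a few model names from the results. The helper `model_family` is hypothetical and exists only for this illustration.

```python
import re

# Same pattern as in load_with_score: lazily take the leading lowercase letters (plus an
# optional "v2") up to the first digit, underscore, or end of the name.
FAMILY_PAT = re.compile(r'^([a-z]+?(?:v2)?)(?:\d|_|$)')

def model_family(name: str) -> str:
    family = FAMILY_PAT.match(name).group(1)
    return 'swin' if family == 'swinv2' else family   # swinv2 is folded into swin, as above

for name in ['convnext_tiny_in22k', 'resnet26d', 'resnetv2_50x1_bit_distilled',
             'swinv2_cr_tiny_ns_224', 'efficientnetv2_rw_t']:
    print(name, '->', model_family(name))
```

Note that `resnetv2` and `efficientnetv2` stay as families of their own. Only `swinv2` is merged into `swin`.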

### IIT Pet

Here are the top models on the IIT Pet dataset from the original sweep, ordered by score (lower is better):

```python
# Source
pt_all_source.head(15)
```

|   | dataset | family | model_name | GPU_mem | error_rate | fit_time | score |
|---|---|---|---|---|---|---|---|
| 0 | pets | convnext | convnext_tiny_in22k | 2.660156 | 0.044655 | 94.557838 | 7.794874 |
| 1 | pets | swin | swin_s3_tiny_224 | 3.126953 | 0.041949 | 112.2822 | 8.065961 |
| 2 | pets | convnext | convnext_tiny | 2.660156 | 0.047361 | 92.761599 | 8.182216 |
| 3 | pets | vit | vit_small_r26_s32_224 | 3.367188 | 0.045332 | 103.240067 | 8.306554 |
| 4 | pets | mobilevit | mobilevit_s | 2.78125 | 0.046685 | 100.770686 | 8.439222 |
| 5 | pets | resnetv2 | resnetv2_50x1_bit_distilled | 3.892578 | 0.047361 | 105.952172 | 8.806939 |
| 6 | pets | vit | vit_small_patch16_224 | 2.111328 | 0.054804 | 80.739517 | 8.809135 |
| 7 | pets | swin | swin_tiny_patch4_window7_224 | 2.796875 | 0.048038 | 105.797015 | 8.925296 |
| 8 | pets | swin | swinv2_cr_tiny_ns_224 | 3.302734 | 0.042625 | 129.435368 | 8.927222 |
| 9 | pets | resnetrs | resnetrs50 | 2.419922 | 0.047361 | 109.549398 | 8.977309 |


As you can see, the [convnext](https://arxiv.org/abs/2201.03545), [swin](https://arxiv.org/abs/2103.14030), and [vit](https://arxiv.org/abs/2010.11929) families are fairly dominant. The excellent showing of `convnext_tiny` matches my view that we should think of this as our default baseline for image recognition today. It's fast, accurate, and not too much of a memory hog. (And according to Ross Wightman, it could be even faster if NVIDIA and PyTorch make some changes to better optimise the operations it relies on!)

`vit_small_patch16` is also a good option -- it's faster and leaner on memory than `convnext_tiny`, although there is some performance cost too.

Interestingly, resnets are still a great option -- especially the [`resnet26d`](https://arxiv.org/abs/1812.01187) variant, which is the fastest in our top 15.

Further down there's a quick visual representation of the six model families which look best in the above analysis (the "fit lines" are just there to help visually show where the different families are -- they don't necessarily actually follow a linear fit).

```python
# Repro
pt_all_repro.head(15)
```

|   | dataset | family | model_name | GPU_mem | error_rate | fit_time | score |
|---|---|---|---|---|---|---|---|
| 0 | pets | convnext | convnext_base | 5.689453 | 0.030447 | 173.295986 | 7.711984 |
| 1 | pets | swin | swin_s3_tiny_224 | 3.148438 | 0.037889 | 130.055523 | 7.958807 |
| 2 | pets | convnext | convnext_base_in22k | 5.697266 | 0.031123 | 177.765313 | 8.022471 |
| 3 | pets | convnext | convnext_tiny_in22k | 2.572266 | 0.042625 | 109.049817 | 8.058284 |
| 4 | pets | vit | vit_small_patch16_224 | 2.144531 | 0.045332 | 100.08141 | 8.163367 |
| 5 | pets | vit | vit_base_patch16_224 | 4.867188 | 0.037889 | 139.352647 | 8.311066 |
| 6 | pets | swin | swin_tiny_patch4_window7_224 | 2.800781 | 0.040595 | 124.79394 | 8.313694 |
| 7 | pets | vit | vit_base_patch16_224_miil | 4.882812 | 0.038566 | 136.341941 | 8.343364 |
| 8 | pets | convnext | convnext_tiny | 2.591797 | 0.045332 | 107.641382 | 8.506072 |
| 9 | pets | convnext | convnext_tiny_hnf | 2.578125 | 0.044655 | 111.373022 | 8.545756 |


```python
# Show the repro results with "_repro" column suffixes. The renamed frame is only displayed,
# not assigned back, because later cells query pt_all_repro by its original column names.
pt_all_repro.rename(columns={
    'GPU_mem' : 'GPU_mem_repro',
    'error_rate' : 'error_rate_repro',
    'fit_time' : 'fit_time_repro',
    'score' : 'score_repro'
})
```

|   | dataset | family | model_name | GPU_mem_repro | error_rate_repro | fit_time_repro | score_repro |
|---|---|---|---|---|---|---|---|
| 0 | pets | convnext | convnext_base | 5.689453 | 0.030447 | 173.295986 | 7.711984 |
| 1 | pets | swin | swin_s3_tiny_224 | 3.148438 | 0.037889 | 130.055523 | 7.958807 |
| 2 | pets | convnext | convnext_base_in22k | 5.697266 | 0.031123 | 177.765313 | 8.022471 |
| 3 | pets | convnext | convnext_tiny_in22k | 2.572266 | 0.042625 | 109.049817 | 8.058284 |
| 4 | pets | vit | vit_small_patch16_224 | 2.144531 | 0.045332 | 100.081410 | 8.163367 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 80 | pets | efficientnetv2 | efficientnetv2_rw_t | 2.271484 | 0.067659 | 197.714413 | 18.789885 |
| 81 | pets | levit | levit_128s | 0.480469 | 0.103518 | 108.035885 | 19.465145 |
| 82 | pets | efficientnetv2 | efficientnetv2_rw_s | 3.136719 | 0.069689 | 204.498145 | 19.826333 |
| 83 | pets | regnety | regnety_002 | 0.539062 | 0.100812 | 120.410098 | 20.203723 |


```python
# Compiled table: suffixed copies of both result sets. The renames return new frames
# rather than overwriting pt_all_source / pt_all_repro, which later cells still query.
metrics = ['GPU_mem', 'error_rate', 'fit_time', 'score']
source_renamed = pt_all_source.rename(columns={m: f'{m}_source' for m in metrics})
repro_renamed = pt_all_repro.rename(columns={m: f'{m}_repro' for m in metrics})
```

```python
# Outer join keeps models that appear in only one of the two sweeps (their deltas are NaN)
pt_all_merged = source_renamed.join(
    repro_renamed[['model_name'] + [f'{m}_repro' for m in metrics]].set_index('model_name'),
    on='model_name', how='outer')
# Depending on the pandas version, an outer join may sort rows by model name,
# so restore the source-score ordering explicitly
pt_all_merged = pt_all_merged.sort_values('score_source').reset_index(drop=True)

def add_delta_column(in_df, col_name_prefix):
    "Add `<prefix>_%`: the % change from the source run to the repro run."
    source = in_df[f'{col_name_prefix}_source']
    repro = in_df[f'{col_name_prefix}_repro']
    in_df[f'{col_name_prefix}_%'] = ((repro - source) / source * 100).round(2)
    return in_df

res_df = pt_all_merged
for prefix in metrics:
    res_df = add_delta_column(res_df, prefix)

relevant_cols = ['model_name', 'GPU_mem_%', 'error_rate_%', 'fit_time_%', 'score_%']
res_df[relevant_cols].head(15)
```

|   | model_name | GPU_mem_% | error_rate_% | fit_time_% | score_% |
|---|---|---|---|---|---|
| 0 | convnext_tiny_in22k | -3.3 | -4.55 | 15.33 | 3.38 |
| 1 | swin_s3_tiny_224 | 0.69 | -9.68 | 15.83 | -1.33 |
| 2 | convnext_tiny | -2.57 | -4.29 | 16.04 | 3.96 |
| 3 | vit_small_r26_s32_224 | -2.78 | -7.46 | 33.93 | 10.23 |
| 4 | mobilevit_s | -0.28 | 27.54 | 30.39 | 49.14 |
| 5 | resnetv2_50x1_bit_distilled | -1.81 | -10.0 | 15.05 | -2.28 |
| 6 | vit_small_patch16_224 | 1.57 | -17.28 | 23.96 | -7.33 |
| 7 | swin_tiny_patch4_window7_224 | 0.14 | -15.49 | 17.96 | -6.85 |
| 8 | swinv2_cr_tiny_ns_224 | 0.65 | 7.94 | 14.23 | 17.43 |
| 9 | resnetrs50 | 0.4 | 5.71 | 31.3 | 24.84 |
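Each `_%` column is the relative change from the original ("source") run to the reproduction ("repro"), rounded to two decimals:

$$\Delta_{\%} = 100 \times \frac{x_{\text{repro}} - x_{\text{source}}}{x_{\text{source}}}$$

For `convnext_tiny_in22k`, for example:

$$\Delta_{\%}(\text{fit\_time}) = 100 \times \frac{109.050 - 94.558}{94.558} \approx 15.33, \qquad \Delta_{\%}(\text{error\_rate}) = 100 \times \frac{0.042625 - 0.044655}{0.044655} \approx -4.55$$

A positive `fit_time_%` therefore means the reproduction was slower. A negative `error_rate_%` means it was more accurate.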


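Compared with the original sweep, the reproduction's fit times are 14% to 34% longer for every model in this table. GPU memory use stays within about 3.3% of the original. Error rates move in both directions: seven of the ten models shown reach a lower error in the reproduction (for example -17.28% for `vit_small_patch16_224`), while `mobilevit_s` is the clear outlier at +27.54%. The net change in score ranges from -7.33% to +49.14%, so the ranking reshuffles. For instance, `convnext_base` leads the reproduction's top 10 but does not appear in the original's.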

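```python
# Fit time vs error rate for six favourite families, with an OLS trend line per family
w,h = 900,700
faves = ['vit','convnext','resnet','levit', 'regnetx', 'swin']
# Strip any "_source" suffixes added when building the compiled table (no-op otherwise)
pt = pt_all_source.rename(columns=lambda c: c.removesuffix('_source'))
pt2 = pt[pt.family.isin(faves)]
px.scatter(pt2, width=w, height=h, x='fit_time', y='error_rate', color='family', hover_name='model_name', trendline="ols")
```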

This chart shows that there's a big drop-off in performance towards the far left. It seems like there's a big compromise if we want the fastest possible model. It also seems that the best models in terms of accuracy, convnext and swin, aren't able to make great use of the larger capacity of larger models. So an ensemble of smaller models may be effective in some situations.
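The simplest form of such an ensemble averages the class probabilities predicted by several models and then takes the argmax. In fastai, the per-model probabilities could come from something like `learn.get_preds()` on the same validation set. That wiring is an assumption here, so the sketch below uses made-up probability arrays and a hypothetical helper name.

```python
import numpy as np

def ensemble_predict(prob_list: list[np.ndarray]) -> np.ndarray:
    """Average per-model class probabilities (n_images x n_classes) and take the argmax."""
    avg = np.mean(np.stack(prob_list), axis=0)
    return avg.argmax(axis=1)

# Made-up probabilities from two small models on three images (same image order)
model_a = np.array([[0.6, 0.4], [0.3, 0.7], [0.55, 0.45]])
model_b = np.array([[0.7, 0.3], [0.4, 0.6], [0.35, 0.65]])
print(ensemble_predict([model_a, model_b]))  # [0 1 1]
```

On the third image the two models disagree, and the average sides with the more confident one.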

Note that `vit` doesn't include any larger/slower models, since they only work with larger images. We would recommend trying larger models on your dataset if you have larger images and the resources to handle them.

I particularly like using fast and small models, since I want to be able to iterate rapidly to try lots of ideas (see [this notebook](https://www.kaggle.com/code/jhoward/iterate-like-a-grandmaster) for more on this). Here are the top models (based on accuracy) that are smaller and faster than the median model:
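The queries in this section hard-code their cut-offs (`GPU_mem<2.7`, `fit_time<110`), presumably chosen to approximate the medians. A sketch that derives the cut-offs from the table itself could look like the following. It uses toy data and a hypothetical helper name, and on the real results it may select a slightly different set than the hard-coded thresholds.

```python
import pandas as pd

def small_and_fast(df: pd.DataFrame, n: int = 15) -> pd.DataFrame:
    """Models below the median GPU memory and the median fit time, most accurate first."""
    mask = (df.GPU_mem < df.GPU_mem.median()) & (df.fit_time < df.fit_time.median())
    return df[mask].sort_values('error_rate').head(n).reset_index(drop=True)

# Made-up rows with the same column names as the results table
toy = pd.DataFrame({
    'model_name': ['a', 'b', 'c', 'd', 'e'],
    'GPU_mem':    [0.8, 1.2, 2.5, 3.0, 5.0],
    'fit_time':   [70, 65, 100, 90, 170],
    'error_rate': [0.07, 0.06, 0.05, 0.045, 0.03],
})
print(small_and_fast(toy).model_name.tolist())  # ['b', 'a']
```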

```python
# Source
pt_all_source.query("(GPU_mem<2.7) & (fit_time<110)").sort_values("error_rate").head(15).reset_index(drop=True)
```

|   | dataset | family | model_name | GPU_mem | error_rate | fit_time | score |
|---|---|---|---|---|---|---|---|
| 0 | pets | convnext | convnext_tiny_in22k | 2.660156 | 0.044655 | 94.557838 | 7.794874 |
| 1 | pets | convnext | convnext_tiny | 2.660156 | 0.047361 | 92.761599 | 8.182216 |
| 2 | pets | resnetrs | resnetrs50 | 2.419922 | 0.047361 | 109.549398 | 8.977309 |
| 3 | pets | regnety | regnety_006 | 0.914062 | 0.052097 | 93.912189 | 9.06038 |
| 4 | pets | levit | levit_384 | 1.699219 | 0.054127 | 86.199098 | 8.995895 |
| 5 | pets | vit | vit_small_patch16_224 | 2.111328 | 0.054804 | 80.739517 | 8.809135 |
| 6 | pets | resnet | resnet50d | 2.037109 | 0.05548 | 92.989515 | 9.597521 |
| 7 | pets | levit | levit_256 | 1.03125 | 0.056157 | 82.68241 | 9.135755 |
| 8 | pets | regnetx | regnetx_016 | 1.369141 | 0.05954 | 88.658087 | 10.041888 |
| 9 | pets | resnet | resnet26d | 1.412109 | 0.060216 | 69.395598 | 8.996078 |


```python
# Repro
pt_all_repro.query("(GPU_mem<2.7) & (fit_time<110)").sort_values("error_rate").head(15).reset_index(drop=True)
```

|   | dataset | family | model_name | GPU_mem | error_rate | fit_time | score |
|---|---|---|---|---|---|---|---|
| 0 | pets | convnext | convnext_tiny_in22k | 2.572266 | 0.042625 | 109.049817 | 8.058284 |
| 1 | pets | vit | vit_small_patch16_224 | 2.144531 | 0.045332 | 100.08141 | 8.163367 |
| 2 | pets | convnext | convnext_tiny | 2.591797 | 0.045332 | 107.641382 | 8.506072 |
| 3 | pets | vit | vit_base_patch32_224 | 2.291016 | 0.056834 | 95.842336 | 9.993747 |
| 4 | pets | vit | vit_tiny_patch16_224 | 1.017578 | 0.065629 | 96.447705 | 11.580129 |
| 5 | pets | resnet | resnet26 | 1.279297 | 0.066306 | 83.559732 | 10.844962 |
| 6 | pets | efficientnet | efficientnet_es | 1.474609 | 0.069012 | 103.499844 | 12.663722 |
| 7 | pets | vit | vit_small_patch32_224 | 0.763672 | 0.069689 | 92.814611 | 12.043242 |
| 8 | pets | resnet | resnet34d | 1.0625 | 0.070365 | 92.829182 | 12.161189 |
| 9 | pets | vit | vit_base_patch32_224_sam | 2.291016 | 0.071042 | 96.760003 | 12.557374 |


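Under the same cut-offs, the reproduction's list looks noticeably different. `convnext_tiny_in22k` still comes first, and `convnext_tiny` and `vit_small_patch16_224` are still near the top. However, several of the original's entries (such as `resnetrs50`, `regnety_006` and the `levit` models) are missing, and five `vit` variants appear instead of one. Because the reproduction's fit times for the top models run 14% to 34% longer than the original's, a fixed threshold like `fit_time<110` filters the two sweeps differently, so this comparison is only indicative.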

...and here are the top 15 models that are the very fastest and most memory efficient:

```python
# Source
pt_all_source.query("(GPU_mem<1.6) & (fit_time<90)").sort_values("error_rate").head(15).reset_index(drop=True)
```

|   | dataset | family | model_name | GPU_mem | error_rate | fit_time | score |
|---|---|---|---|---|---|---|---|
| 0 | pets | levit | levit_256 | 1.03125 | 0.056157 | 82.68241 | 9.135755 |
| 1 | pets | regnetx | regnetx_016 | 1.369141 | 0.05954 | 88.658087 | 10.041888 |
| 2 | pets | resnet | resnet26d | 1.412109 | 0.060216 | 69.395598 | 8.996078 |
| 3 | pets | levit | levit_192 | 0.78125 | 0.060893 | 82.385787 | 9.888177 |
| 4 | pets | vit | vit_tiny_patch16_224 | 1.074219 | 0.064276 | 65.670202 | 9.363104 |
| 5 | pets | vit | vit_small_patch32_224 | 0.775391 | 0.065629 | 68.478869 | 9.744556 |
| 6 | pets | efficientnet | efficientnet_es_pruned | 1.507812 | 0.066306 | 69.601242 | 9.919432 |
| 7 | pets | efficientnet | efficientnet_es | 1.507812 | 0.066306 | 69.822634 | 9.934112 |
| 8 | pets | resnet | resnet26 | 1.291016 | 0.067659 | 64.398096 | 9.769834 |
| 9 | pets | resnet | resnet34 | 0.951172 | 0.070365 | 66.932345 | 10.338949 |


[ResNet-RS](https://arxiv.org/abs/2103.07579) performs well here, with lower memory use than convnext but nonetheless high accuracy. A version trained on the larger Imagenet-22k dataset (like `convnext_tiny_in22k`) would presumably do even better, and might top the charts!

[RegNet-y](https://arxiv.org/abs/2101.00590) is impressively miserly in terms of memory use, whilst still achieving high accuracy.

```python
# Repro
pt_all_repro.sort_values("error_rate").reset_index(drop=True)
```

|   | dataset | family | model_name | GPU_mem | error_rate | fit_time | score |
|---|---|---|---|---|---|---|---|
| 0 | pets | convnext | convnext_base | 5.689453 | 0.030447 | 173.295986 | 7.711984 |
| 1 | pets | convnext | convnext_base_in22k | 5.697266 | 0.031123 | 177.765313 | 8.022471 |
| 2 | pets | swin | swin_large_patch4_window7_224 | 10.515625 | 0.033830 | 300.285974 | 12.864888 |
| 3 | pets | convnext | convnext_large_in22k | 9.615234 | 0.034506 | 253.844545 | 11.519667 |
| 4 | pets | swin | swin_s3_small_224 | 6.253906 | 0.035183 | 214.181050 | 10.350070 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 80 | pets | levit | levit_128 | 0.664062 | 0.092693 | 122.704538 | 18.789262 |
| 81 | pets | efficientnet | efficientnet_lite0 | 1.498047 | 0.099459 | 100.963104 | 17.998365 |
| 82 | pets | regnety | regnety_002 | 0.539062 | 0.100812 | 120.410098 | 20.203723 |
| 83 | pets | levit | levit_128s | 0.480469 | 0.103518 | 108.035885 | 19.465145 |


```python
pt_all_repro.query("(GPU_mem<1.6) & (fit_time<90)").sort_values("error_rate")
```

|   | dataset | family | model_name | GPU_mem | error_rate | fit_time | score |
|---|---|---|---|---|---|---|---|
| 24 | pets | resnet | resnet26 | 1.279297 | 0.066306 | 83.559732 | 10.844962 |
| 39 | pets | resnet | resnet26d | 1.410156 | 0.073072 | 88.670404 | 12.325036 |
| 47 | pets | resnet | resnet34 | 0.970703 | 0.078484 | 88.967083 | 13.261283 |
| 41 | pets | resnet | resnet18 | 0.625 | 0.085927 | 65.302742 | 12.485421 |
| 50 | pets | resnet | resnet18d | 0.806641 | 0.09134 | 70.762695 | 13.770611 |


## Conclusions

It really seems like it's time for a changing of the guard when it comes to computer vision models. At the time of writing (June 2022), there are three very clear winners when it comes to fine-tuning pretrained models:

- [convnext](https://arxiv.org/abs/2201.03545)
- [vit](https://arxiv.org/abs/2010.11929)
- [swin](https://arxiv.org/abs/2103.14030) (and [v2](https://arxiv.org/abs/2111.09883)).

[Tanishq Abraham](https://www.kaggle.com/tanlikesmath) studied the top results of a [recent Kaggle computer vision competition](https://www.kaggle.com/c/petfinder-pawpularity-score) and found that the above three approaches did indeed appear to be the best approaches. However, there were two other architectures which were also very strong in that competition, but which aren't in our top models above:

- [EfficientNet](https://arxiv.org/abs/1905.11946) and [v2](https://arxiv.org/abs/2104.00298)
- [BEiT](https://arxiv.org/abs/2106.08254).

BEiT isn't there because it's too big to fit on my GPU (even the smallest BEiT model is too big!). This is fixable with gradient accumulation, so perhaps in a future iteration we'll add it in. EfficientNet didn't have any variants that were fast and accurate enough to appear in the top 15 on either dataset. However, it's notoriously fiddly to train, so there might well be some set of hyperparameters that would work for these datasets. Having said that, I'm mainly interested in knowing which architectures can be trained quickly and easily without too much mucking around, so perhaps EfficientNet doesn't really fit here anyway!
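Gradient accumulation works by running several small micro-batches, summing their (scaled) gradients, and taking one optimiser step as if a single large batch had been used. fastai ships this as a `GradientAccumulation` callback. The framework-free numpy sketch below, on made-up linear-regression data, only illustrates the underlying idea: for a mean loss and equal-sized micro-batches, the accumulated gradient equals the full-batch gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
X, y = rng.normal(size=(64, 3)), rng.normal(size=64)   # made-up regression data
w = np.zeros(3)

def mse_grad(w, Xb, yb):
    """Gradient of the mean squared error over one (micro-)batch."""
    return 2 * Xb.T @ (Xb @ w - yb) / len(yb)

# Full batch of 64 in one go (the thing that would not fit in GPU memory)...
full = mse_grad(w, X, y)

# ...versus 4 micro-batches of 16, accumulating gradients before a single update
n_micro = 4
acc = np.zeros_like(w)
for Xb, yb in zip(np.array_split(X, n_micro), np.array_split(y, n_micro)):
    acc += mse_grad(w, Xb, yb) / n_micro   # scale so the sum matches the full-batch mean

print(np.allclose(acc, full))  # True
w -= 0.1 * acc                 # one optimiser step per accumulated batch
```

In a real network, layers such as BatchNorm still only see each micro-batch, so training with accumulation is close to, but not always identical to, training with the larger batch.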

Thankfully, it's easy to try lots of different models, especially if you use fastai and timm, because it's literally as easy as changing the model name in one place in your code. Your existing hyperparameters are most likely going to continue to work fine regardless of what model you try. And it's particularly easy if you use [wandb](https://wandb.ai/), since you can start and stop experiments at any time and they'll all be automatically tracked and managed for you.

If you found this notebook useful, please remember to click the little up-arrow at the top to upvote it, since I like to know when people have found my work useful, and it helps others find it too. And if you have any questions or comments, please pop them below -- I read every comment I receive!