
# ConvNeXt Inference Tutorial

This tutorial shows how to run ConvNeXt inference with a pretrained model from `timm`.

In [None]:
URL = "https://raw.githubusercontent.com/SharanSMenon/swin-transformer-hub/main/imagenet_labels.json" # ImageNet-1k class labels (JSON)
!wget https://www.allaboutbirds.org/guide/assets/photo/306327661-480px.jpg -O house_finch.jpg

--2022-05-07 00:57:19--  https://www.allaboutbirds.org/guide/assets/photo/306327661-480px.jpg
Resolving www.allaboutbirds.org (www.allaboutbirds.org)... 104.26.0.144, 172.67.69.67, 104.26.1.144, ...
Connecting to www.allaboutbirds.org (www.allaboutbirds.org)|104.26.0.144|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 20694 (20K) [image/jpeg]
Saving to: ‘house_finch.jpg’


2022-05-07 00:57:19 (76.7 MB/s) - ‘house_finch.jpg’ saved [20694/20694]



In [None]:
# Install timm, then import everything the notebook needs.
# To classify a different picture, download your own image instead of house_finch.jpg.
!pip install timm
import torch
from torchvision import transforms as T
import timm
from PIL import Image

Collecting timm
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
[?25l[K     |▊                               | 10 kB 24.2 MB/s eta 0:00:01[K     |█▌                              | 20 kB 29.7 MB/s eta 0:00:01[K     |██▎                             | 30 kB 20.2 MB/s eta 0:00:01[K     |███                             | 40 kB 21.9 MB/s eta 0:00:01[K     |███▉                            | 51 kB 9.6 MB/s eta 0:00:01[K     |████▋                           | 61 kB 11.1 MB/s eta 0:00:01[K     |█████▎                          | 71 kB 10.1 MB/s eta 0:00:01[K     |██████                          | 81 kB 11.0 MB/s eta 0:00:01[K     |██████▉                         | 92 kB 12.1 MB/s eta 0:00:01[K     |███████▋                        | 102 kB 10.2 MB/s eta 0:00:01[K     |████████▍                       | 112 kB 10.2 MB/s eta 0:00:01[K     |█████████▏                      | 122 kB 10.2 MB/s eta 0:00:01[K     |█████████▉                      | 133 kB 10.2 MB/s eta 0:00:01

In [None]:
timm.list_models("convnext*")

['convnext_base',
 'convnext_base_384_in22ft1k',
 'convnext_base_in22ft1k',
 'convnext_base_in22k',
 'convnext_large',
 'convnext_large_384_in22ft1k',
 'convnext_large_in22ft1k',
 'convnext_large_in22k',
 'convnext_small',
 'convnext_tiny',
 'convnext_tiny_hnf',
 'convnext_xlarge_384_in22ft1k',
 'convnext_xlarge_in22ft1k',
 'convnext_xlarge_in22k']
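The suffixes in these names encode how each checkpoint was trained: `in22k` means pretrained on ImageNet-22k, `in22ft1k` means pretrained on ImageNet-22k and then fine-tuned on ImageNet-1k, and `384` marks a 384×384 input resolution (the checkpoint downloaded below, `convnext_base_22k_1k_384.pth`, follows the same pattern). The sketch below decodes a name under that assumed convention and filters the list with shell-style wildcards via `fnmatch`, similar in spirit to the pattern passed to `timm.list_models`. The helper `describe` and its defaults (224 pixels, `in1k` when no suffix is present) are hypothetical, not part of `timm`.

```python
import fnmatch

# Names copied from the timm.list_models("convnext*") output above.
CONVNEXT_NAMES = [
    "convnext_base", "convnext_base_384_in22ft1k", "convnext_base_in22ft1k",
    "convnext_base_in22k", "convnext_large", "convnext_large_384_in22ft1k",
    "convnext_large_in22ft1k", "convnext_large_in22k", "convnext_small",
    "convnext_tiny", "convnext_tiny_hnf", "convnext_xlarge_384_in22ft1k",
    "convnext_xlarge_in22ft1k", "convnext_xlarge_in22k",
]


def describe(name: str) -> dict:
    """Hypothetical helper: split a name of the assumed form
    convnext_<size>[_<resolution>][_<pretraining>] into its parts."""
    parts = name.split("_")[1:]  # drop the leading "convnext"
    # Assumed defaults when a suffix is absent.
    info = {"size": parts[0], "resolution": 224, "pretraining": "in1k"}
    for part in parts[1:]:
        if part.isdigit():
            info["resolution"] = int(part)
        elif part in ("in22k", "in22ft1k"):
            info["pretraining"] = part
        # Other suffixes (e.g. "hnf") are ignored in this sketch.
    return info


# Keep only the 384x384 variants with a shell-style wildcard.
at_384 = fnmatch.filter(CONVNEXT_NAMES, "convnext_*_384_*")
print(at_384)
print(describe("convnext_base_384_in22ft1k"))
```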

In [None]:
model = timm.create_model("convnext_base_384_in22ft1k", pretrained=True) # downloads the checkpoint on first use, so this takes a moment

Downloading: "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth" to /root/.cache/torch/hub/checkpoints/convnext_base_22k_1k_384.pth


In [None]:
model

ConvNeXt(
  (stem): Sequential(
    (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    (1): LayerNorm2d((128,), eps=1e-06, elementwise_affine=True)
  )
  (stages): Sequential(
    (0): ConvNeXtStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): ConvNeXtBlock(
          (conv_dw): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
          (norm): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=128, out_features=512, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (fc2): Linear(in_features=512, out_features=128, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
        )
        (1): ConvNeXtBlock(
          (conv_dw): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
          (norm): LayerNorm((128,), eps=1e-06, elemen
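The printed modules pin down some useful arithmetic. The stem is a 4×4 convolution with stride 4, and each block's depthwise 7×7 convolution uses padding 3, so it keeps the spatial size. The sketch below applies the standard convolution output-size formula to get the feature-map size of each stage, assuming (this is not visible in the printout above) that stages 2 to 4 each begin with a 2×2, stride-2 downsampling convolution. It also counts the parameters of the modules printed for one 128-channel block; any layer-scale parameter that the repr does not list is left out.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Standard Conv2d output-size formula (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_resolutions(input_size: int) -> list:
    """Spatial size of the feature map in each of the four stages."""
    size = conv_out(input_size, kernel=4, stride=4)  # stem: 4x4 conv, stride 4
    sizes = [size]
    # Assumption: a 2x2, stride-2 downsampling conv starts stages 2-4.
    for _ in range(3):
        size = conv_out(size, kernel=2, stride=2)
        sizes.append(size)
    return sizes


def block_params(dim: int, kernel: int = 7, mlp_ratio: int = 4) -> int:
    """Parameters of conv_dw, norm, fc1 and fc2 in one ConvNeXtBlock."""
    hidden = dim * mlp_ratio                 # 128 -> 512 in the repr above
    dw = dim * kernel * kernel + dim         # groups=dim: one k x k filter per channel, plus bias
    norm = 2 * dim                           # LayerNorm weight and bias
    fc1 = dim * hidden + hidden              # Linear(dim, hidden) weight and bias
    fc2 = hidden * dim + dim                 # Linear(hidden, dim) weight and bias
    return dw + norm + fc1 + fc2


print(stage_resolutions(384))
print(stage_resolutions(224))
print(conv_out(96, kernel=7, stride=1, padding=3))  # depthwise conv keeps the size
print(block_params(128))
```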

In [None]:
from timm import data

In [None]:
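model.eval()  # switch to inference mode
# This checkpoint was fine-tuned at 384x384, so resize and crop to 384
# instead of the usual 256 -> 224 pipeline.
trans_ = T.Compose([
    T.Resize(384, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(384),
    T.ToTensor(),
    T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD),
])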
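To see what a `Resize` followed by a `CenterCrop` does geometrically, the sketch below reproduces the size arithmetic in plain Python: an integer `Resize(size)` scales the shorter side to `size` and the longer side proportionally (truncated to an integer), and `CenterCrop(crop)` keeps a centered `crop`×`crop` window. The 480×320 image is a hypothetical example, not the actual size of `house_finch.jpg`. Note that the classic 256 -> 224 pipeline keeps 224/256 = 0.875 of the resized image, which is also `timm`'s default `crop_pct`.

```python
def resize_shorter_side(width: int, height: int, size: int) -> tuple:
    """(new_width, new_height) after an integer Resize(size):
    the shorter side becomes `size`, the aspect ratio is preserved."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width: int, height: int, crop: int) -> tuple:
    """(left, top, right, bottom) of the window a CenterCrop(crop) keeps."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


# Hypothetical 480x320 (landscape) image through two pipelines.
for resize, crop in [(256, 224), (384, 384)]:
    w, h = resize_shorter_side(480, 320, resize)
    print(resize, crop, (w, h), center_crop_box(w, h, crop))
```

As an alternative to hand-picking sizes, `timm` can build preprocessing from the model's own pretrained config with `timm.data.resolve_data_config({}, model=model)` followed by `timm.data.create_transform(**config)`.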

In [None]:
image = Image.open("house_finch.jpg").convert("RGB")  # ensure 3 channels (e.g. for grayscale or RGBA files)

In [None]:
transformed = trans_(image)
batched = transformed.unsqueeze(0)
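Per pixel, `T.ToTensor()` scales 8-bit values to [0, 1] (and reorders the image from height × width × channels to channels × height × width), and `T.Normalize` then subtracts the per-channel mean and divides by the per-channel standard deviation. `unsqueeze(0)` adds a leading batch dimension of size 1, because the model expects a batch. The plain-Python sketch below spells out that arithmetic; the constants are the values of `timm.data.IMAGENET_DEFAULT_MEAN` and `IMAGENET_DEFAULT_STD`.

```python
# Values of timm.data.IMAGENET_DEFAULT_MEAN and IMAGENET_DEFAULT_STD (R, G, B).
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb: tuple) -> tuple:
    """ToTensor + Normalize applied to one 8-bit RGB pixel:
    scale to [0, 1], subtract the channel mean, divide by the channel std."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


def batched_shape(chw: tuple) -> tuple:
    """unsqueeze(0) prepends a batch dimension: (C, H, W) -> (1, C, H, W)."""
    return (1,) + tuple(chw)


print([round(x, 3) for x in normalize_pixel((255, 255, 255))])  # pure white
print([round(x, 3) for x in normalize_pixel((0, 0, 0))])        # pure black
print(batched_shape((3, 384, 384)))
```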

In [None]:
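with torch.no_grad():  # gradients are not needed for inference
    out = model(batched)  # logits of shape (1, 1000): one score per ImageNet-1k class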

In [None]:
pred = out.argmax(dim=1)
pred

tensor([12])
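`argmax` returns only the index of the largest logit. To also get a confidence and the runner-up classes, the logits can be turned into probabilities with a softmax; because softmax is monotonic, its argmax is the same index. In PyTorch this would be `out.softmax(dim=1)` followed by `torch.topk(probs, k=5)`. The sketch below shows the same computation in plain Python on hypothetical logits (not real model output), using the usual max-subtraction trick for numerical stability.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=5):
    """(index, probability) pairs for the k most probable classes, best first."""
    probs = softmax(logits)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in order[:k]]


# Hypothetical logits for a 4-class problem.
example = top_k([1.0, 3.0, 0.5, 2.0], k=2)
print(example)
```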

In [None]:
# Index 12 is 'house finch' in the ImageNet labels file; load the JSON labels to check.
import json
from urllib.request import urlopen

In [None]:
res = urlopen(URL)
classes = json.loads(res.read())
len(classes)

1000

In [None]:
classes[pred.item()]  # ConvNeXt correctly classifies the image as a house finch

'house finch'
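The labels file behaves like a JSON array of 1000 class names, so a class index maps directly to its name by list indexing. The sketch below pairs top-k `(index, probability)` predictions with their names in a readable form; the helper `label_predictions` is hypothetical, and the three-entry label list is a stand-in whose indices do not match the real file.

```python
import json


def label_predictions(preds, labels):
    """Format (index, probability) pairs as 'name: percent' strings."""
    return [f"{labels[i]}: {p:.1%}" for i, p in preds]


# Hypothetical stand-in for the 1000-entry labels file.
labels = json.loads('["cat", "dog", "bird"]')
print(label_predictions([(2, 0.9), (1, 0.05)], labels))
```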