# Inference with pre-trained models



## Installation

```
pip install openmim mmengine
mim install "mmcv>=2.0.0rc" "mmcls>=1.0.0rc"
```
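To confirm that the installation worked, you can query the installed versions from Python. This is a minimal stdlib sketch; the helper name `installed_versions` is ours, and it only reports what `pip` installed without checking the version constraints.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for pkg, ver in installed_versions(["openmim", "mmengine", "mmcv", "mmcls"]).items():
        print(f"{pkg:10s} {ver or 'NOT INSTALLED'}")
```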

## Config and Checkpoint File

**Config file**

A Python file that specifies all items required to define a training *experiment*, including:

- model
- dataset and data augmentation pipeline
- training algorithms and learning rate policies
- runtime config of the program
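A trimmed-down sketch of how these four parts can look in a config file, in the MMClassification 1.x dict style. The `model` and pipeline entries mirror the classifier and test pipeline used later in this notebook; the optimizer, scheduler and epoch values are illustrative placeholders, not copied from the real `mobilenet-v2_8xb32_in1k.py`.

```python
# 1. Model (same structure as the ImageClassifier built later)
model = dict(
    type='ImageClassifier',
    backbone=dict(type='MobileNetV2', widen_factor=1.0),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(type='LinearClsHead', num_classes=1000, in_channels=1280),
)

# 2. Dataset and data augmentation pipeline (only the test pipeline shown)
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=256, edge='short', backend='pillow'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackClsInputs'),
]

# 3. Training algorithm and learning-rate policy (placeholder values)
optim_wrapper = dict(optimizer=dict(type='SGD', lr=0.045, momentum=0.9, weight_decay=4e-5))
param_scheduler = dict(type='StepLR', by_epoch=True, step_size=1, gamma=0.98)
train_cfg = dict(by_epoch=True, max_epochs=300)

# 4. Runtime settings
default_scope = 'mmcls'
```

Every component is described by a plain dict whose `type` key names a registered class, which is why the whole experiment can live in one text file.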

**Checkpoint file**

A PyTorch `.pth` file containing the `state_dict` of a model as well as some meta information.
A checkpoint is usually produced by training a model with the toolbox from a config file.
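The sketch below builds and reloads a toy checkpoint with the same two top-level entries, using a tiny `nn.Linear` as a stand-in model. The `'meta'` contents are hypothetical; real checkpoints may carry more keys (for example optimizer state).

```python
import os
import tempfile

import torch
from torch import nn

# Tiny stand-in model; a real checkpoint holds the classifier's weights.
model = nn.Linear(4, 2)

# Assumed layout: a dict with 'meta' and 'state_dict' entries.
checkpoint = {
    'meta': {'note': 'hypothetical meta information'},
    'state_dict': model.state_dict(),
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'toy.pth')
    torch.save(checkpoint, path)
    # map_location='cpu' lets a GPU-trained checkpoint load on a CPU-only machine.
    loaded = torch.load(path, map_location='cpu')

print(sorted(loaded.keys()))              # ['meta', 'state_dict']
print(list(loaded['state_dict'].keys()))  # ['weight', 'bias']
```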

Navigate to the [home page](https://github.com/open-mmlab/mmclassification) to select a model,
then download its config and checkpoint files with `mim`:

```
mim download mmcls --config mobilenet-v2_8xb32_in1k --dest . 
```
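After the download, the config (`mobilenet-v2_8xb32_in1k.py`) and the checkpoint (a `.pth` file) sit in the destination directory. A small stdlib sketch for locating both before passing them to `init_model`; the helper `find_config_and_checkpoint` is hypothetical and assumes exactly one `.pth` file in that folder.

```python
from pathlib import Path


def find_config_and_checkpoint(dest_dir, config_name):
    """Return (config_path, checkpoint_path) for files downloaded into dest_dir.

    Assumes the config is saved as `<config_name>.py` and that exactly one
    `.pth` file sits next to it (true for a fresh single-model download,
    not guaranteed in general).
    """
    dest = Path(dest_dir)
    config = dest / f'{config_name}.py'
    if not config.is_file():
        raise FileNotFoundError(config)
    checkpoints = sorted(dest.glob('*.pth'))
    if len(checkpoints) != 1:
        raise RuntimeError(f'expected one .pth file in {dest}, found {len(checkpoints)}')
    return config, checkpoints[0]
```

For example, `find_config_and_checkpoint('.', 'mobilenet-v2_8xb32_in1k')` would return the two paths used in the next section.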

## Inference using high-level API

```python
# Necessary to register all modules
from mmcls.utils import register_all_modules
register_all_modules()
```
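Configs refer to classes by string `type` names, and those names are resolved through registries, which is why the modules must be registered first. A toy sketch of that lookup; `ToyRegistry` and the stand-in `MobileNetV2` class are illustrative and not MMEngine's implementation.

```python
class ToyRegistry:
    """Maps a `type` string to a class (illustration only)."""

    def __init__(self):
        self._classes = {}

    def register(self, cls):
        """Class decorator: record `cls` under its class name."""
        self._classes[cls.__name__] = cls
        return cls

    def build(self, cfg):
        """Instantiate the class named by cfg['type'] with the remaining keys."""
        cfg = dict(cfg)  # copy so the caller's dict is not modified
        cls = self._classes[cfg.pop('type')]  # KeyError if never registered
        return cls(**cfg)


MODELS = ToyRegistry()


@MODELS.register
class MobileNetV2:  # stand-in class, not the real backbone
    def __init__(self, widen_factor=1.0):
        self.widen_factor = widen_factor


backbone = MODELS.build(dict(type='MobileNetV2', widen_factor=1.0))
print(type(backbone).__name__, backbone.widen_factor)  # MobileNetV2 1.0
```

Building a type that was never registered fails with a `KeyError` here, which mirrors the kind of error you get when the registration call is skipped.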

```python
from mmcls.apis import init_model, inference_model
```

```python
model = init_model('mobilenet-v2_8xb32_in1k.py',
                   'mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth',
                   device='cuda:0')
```

```
local loads checkpoint from path: mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth
```

```python
result = inference_model(model, 'banana.png')
```

```python
result
```

```
{'pred_label': 954, 'pred_score': 0.9999284744262695, 'pred_class': 'banana'}
```
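The result bundles the top-1 class index, its score and the class name. Assuming the score is the softmax probability of the top class (consistent with the value close to 1 above), the conversion from raw model outputs can be sketched like this; `to_prediction` and the three-entry class list are hypothetical.

```python
import torch

# Hypothetical class-name list; the real model uses the 1000 ImageNet names.
CLASSES = ['apple', 'banana', 'cherry']


def to_prediction(logits, classes):
    """Turn one image's raw scores (logits) into a result dict like the one above."""
    probs = torch.softmax(logits, dim=-1)  # logits -> probabilities summing to 1
    score, label = probs.max(dim=-1)       # highest probability and its index
    return {
        'pred_label': int(label),
        'pred_score': float(score),
        'pred_class': classes[int(label)],
    }


logits = torch.tensor([0.5, 4.0, -1.0])
print(to_prediction(logits, CLASSES))  # label 1, i.e. 'banana'
```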

## PyTorch code under the hood

Let's write some raw PyTorch code to do the same thing.

This is essentially the code that the high-level APIs wrap.

### Construct an `ImageClassifier`

Note: the current implementation only accepts configs (dicts) for the backbone, neck and classification head, not Python objects.

If you want to play with the components, you can construct them individually or access them as members of the constructed `ImageClassifier` object (e.g. `classifier.backbone`, `classifier.neck`, `classifier.head`).
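To see how the three components fit together, here is a self-contained plain-PyTorch sketch with the same backbone, neck, head layout. `ToyClassifier` and its layer sizes are hypothetical placeholders, far smaller than MobileNetV2.

```python
import torch
from torch import nn


class ToyClassifier(nn.Module):
    """Backbone -> neck -> head, mirroring the layout of ImageClassifier."""

    def __init__(self, num_classes=10, channels=16):
        super().__init__()
        self.backbone = nn.Sequential(             # feature extractor
            nn.Conv2d(3, channels, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )
        self.neck = nn.AdaptiveAvgPool2d(1)         # global average pooling
        self.head = nn.Linear(channels, num_classes)  # linear classification head

    def extract_feat(self, x):
        return self.neck(self.backbone(x)).flatten(1)

    def forward(self, x):
        return self.head(self.extract_feat(x))    # raw class scores (logits)


model = ToyClassifier().eval()
with torch.no_grad():
    x = torch.randn(2, 3, 32, 32)
    feats = model.extract_feat(x)      # (2, 16)
    logits = model(x)                  # (2, 10)
    # Components are plain attributes, so they can be called on their own.
    backbone_out = model.backbone(x)   # (2, 16, 16, 16)
print(feats.shape, logits.shape, backbone_out.shape)
```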

```python
from mmcls.models import ImageClassifier

classifier = ImageClassifier(
    backbone=dict(type='MobileNetV2', widen_factor=1.0),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1280)
)
```
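The `GlobalAveragePooling` neck collapses each of the backbone's 1280 feature maps to a single number, which is why the head uses `in_channels=1280`. Assuming it averages over the spatial dimensions (the usual definition), the sketch shows two equivalent ways to do this in plain PyTorch; the 7x7 map size assumes a 224x224 input and an overall stride of 32.

```python
import torch
from torch import nn

# Dummy backbone output: batch of 2, 1280 channels, 7x7 spatial map (shapes assumed).
feat_map = torch.randn(2, 1280, 7, 7)

# Option 1: adaptive average pooling to 1x1, then flatten.
pooled = nn.AdaptiveAvgPool2d((1, 1))(feat_map).flatten(1)

# Option 2: mean over the spatial dimensions.
mean = feat_map.mean(dim=(2, 3))

print(pooled.shape)                             # torch.Size([2, 1280])
print(torch.allclose(pooled, mean, atol=1e-6))  # True
```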

### Load trained parameters

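```python
import torch

# map_location='cpu' lets the checkpoint load even on a machine without a GPU
ckpt = torch.load('mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth', map_location='cpu')
classifier.load_state_dict(ckpt['state_dict'])
```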

```
<All keys matched successfully>
```
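That message is the printed return value of `load_state_dict`, which reports missing and unexpected keys. A small sketch on a toy `nn.Linear` (not the classifier) showing what happens when the keys do not line up and `strict=False` is used; with the default `strict=True` the mismatched load would raise a `RuntimeError` instead.

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
state = model.state_dict()

# Strict loading (the default) with matching keys.
print(model.load_state_dict(state))  # <All keys matched successfully>

# Simulate a checkpoint with an extra key and a missing key.
partial = {'weight': state['weight'], 'extra.param': torch.zeros(1)}
result = model.load_state_dict(partial, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # ['extra.param']
```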

### Construct the data preprocessing pipeline

**Important**: a model only gives correct predictions if the image preprocessing pipeline matches the one used to train it.

```python
from mmcv.transforms import Compose

test_pipeline = Compose([
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=256, edge='short', backend='pillow'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackClsInputs')
])
```
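`ResizeEdge` with `scale=256, edge='short'` keeps the aspect ratio while scaling the shorter side to 256, and `CenterCrop` then cuts out a fixed 224x224 window. A pure-Python sketch of that geometry for a hypothetical 480x640 image; the helper names are ours, and the exact rounding is an assumption that may differ by a pixel from the library.

```python
def resize_short_edge(h, w, short=256):
    """Output (h, w) after scaling so the shorter side equals `short`.

    Aspect ratio is kept; the long side is truncated to an int here,
    which may differ by one pixel from a given library's rounding.
    """
    if h <= w:
        return short, int(w * short / h)
    return int(h * short / w), short


def center_crop_box(h, w, size=224):
    """(top, left, bottom, right) of a centered size x size crop."""
    top = (h - size) // 2
    left = (w - size) // 2
    return top, left, top + size, left + size


h, w = resize_short_edge(480, 640)    # (256, 341)
print((h, w), center_crop_box(h, w))  # (256, 341) (16, 58, 240, 282)
```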

```python
data = dict(img_path='banana.png')
data = test_pipeline(data)
```
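As far as we understand MMClassification 1.x, `data['inputs']` at this point is still a `uint8` CHW tensor in BGR order (MMCV loads images as BGR by default); converting to float, swapping to RGB and normalizing are normally left to the model's data preprocessor. A sketch of those steps, assuming the common ImageNet statistics on the 0-255 scale; `preprocess_bgr_uint8` is a hypothetical helper, not an MMClassification API.

```python
import torch

# Assumed ImageNet statistics on the 0-255 pixel scale (RGB order).
MEAN = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
STD = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)


def preprocess_bgr_uint8(img_chw):
    """uint8 BGR (C, H, W) -> normalized float RGB batch (1, C, H, W)."""
    rgb = img_chw.flip(0).float()    # reverse channel order: BGR -> RGB
    normalized = (rgb - MEAN) / STD  # per-channel normalization
    return normalized.unsqueeze(0)   # add a batch dimension


# Stand-in for data['inputs']: a random 224x224 BGR image.
fake_inputs = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
imgs = preprocess_bgr_uint8(fake_inputs)
print(imgs.shape, imgs.dtype)  # torch.Size([1, 3, 224, 224]) torch.float32
```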

#### Equivalent in `torchvision`

```python
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, ToTensor

tv_transform = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

image = Image.open('banana.png').convert('RGB')
tv_data = tv_transform(image)
```
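`ToTensor` scales pixels to [0, 1] before `Normalize`, so the mean and std above are the familiar 0-255 ImageNet statistics divided by 255. A short pure-Python check of that arithmetic; the function names are ours.

```python
# torchvision-style statistics on the [0, 1] scale (from the cell above).
TV_MEAN = [0.485, 0.456, 0.406]
TV_STD = [0.229, 0.224, 0.225]

# The same statistics expressed on the 0-255 pixel scale.
MEAN_255 = [round(m * 255, 3) for m in TV_MEAN]  # [123.675, 116.28, 103.53]
STD_255 = [round(s * 255, 3) for s in TV_STD]    # [58.395, 57.12, 57.375]


def normalize_01(pixel, m, s):
    """torchvision order: scale to [0, 1], then normalize."""
    return (pixel / 255 - m) / s


def normalize_255(pixel, m, s):
    """Same normalization applied directly to a 0-255 pixel value."""
    return (pixel - m * 255) / (s * 255)


for m, s in zip(TV_MEAN, TV_STD):
    for px in (0, 128, 255):
        assert abs(normalize_01(px, m, s) - normalize_255(px, m, s)) < 1e-9
print(MEAN_255, STD_255)
```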

### Forward through the model

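```python
# IMPORTANT: set the classifier to eval mode (BatchNorm uses its running statistics)
classifier.eval()

# `data['inputs']` from the MMCV pipeline is not normalized yet (that is
# normally the data preprocessor's job), so use the normalized torchvision tensor.
imgs = tv_data.unsqueeze(0)  # add a batch dimension: (1, 3, 224, 224)

with torch.no_grad():
    # class scores of the first image (raw logits, before softmax)
    logits = classifier.forward(imgs, mode='tensor')[0]
    # predictions as data samples (label and softmax scores)
    pred = classifier.predict(imgs)
    # features after the neck
    feat = classifier.extract_feat(imgs, stage='neck')[0]

print(len(logits))
print(pred)
print(feat.shape)
```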

1000
[<ClsDataSample(

    META INFORMATION

    DATA FIELDS
    pred_label: <LabelData(
        
            META INFORMATION
            num_classes: 1000
        
            DATA FIELDS
            label: tensor([954])
            score: tensor([8.7886e-11, 1.4862e-10, 7.9490e-10, 7.5087e-10, 2.3911e-08, 1.4624e-08,
                        1.6622e-09, 8.6278e-09, 2.2922e-08, 2.8606e-11, 1.8818e-09, 8.1963e-09,
                        4.3835e-09, 1.3013e-09, 6.0647e-10, 1.4825e-09, 9.7490e-09, 2.1946e-09,
                        1.4078e-08, 2.6998e-10, 9.2191e-11, 4.0907e-10, 2.5204e-09, 1.3026e-08,
                        5.2355e-11, 1.9043e-09, 7.6891e-09, 2.0430e-09, 6.6810e-10, 9.5962e-09,
                        1.0180e-09, 1.2636e-08, 2.4030e-09, 1.6970e-09, 1.1531e-09, 8.5339e-10,
                        1.4433e-08, 7.8068e-10, 3.7663e-10, 8.6339e-10, 1.3663e-08, 3.3271e-10,
                        2.1759e-10, 6.4700e-11, 2.0980e-10, 8.3660e-10, 5.2944e-09, 8.7994e-11,
      