In [2]:
from typing import List, Tuple

import torch
from torch import nn
import torchvision
import torchvision.transforms as T
from torchvision.models import ResNet, resnet18, resnet50
from PIL import Image


In [70]:
class ResnetPredictor(nn.Module):
    """Bundle image preprocessing, a classifier and top-k post-processing.

    Incoming image batches are resized (shorter side to 256), center-cropped
    to 224x224, converted to float and normalized with 3-channel mean/std
    (the values are the usual ImageNet statistics). The wrapped model is then
    run without gradient tracking and its outputs are turned into
    probabilities with a softmax over dimension 1.

    Args:
        model: Classifier to wrap, e.g. a torchvision ``ResNet``. It is
            expected to return per-class scores of shape ``(N, num_classes)``.
        top_k: How many of the highest-scoring classes to report per image.
            Defaults to 5; clamped to the number of classes at call time.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """

    def __init__(self, model: nn.Module, top_k: int = 5):
        super().__init__()
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        self.model = model
        self.top_k = top_k
        self.transforms = nn.Sequential(
            T.Resize([256, ]),  # We use single int value inside a list due to torchscript type restrictions
            T.CenterCrop(224),
            T.ConvertImageDtype(torch.float),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        )

    def forward(self, x: torch.Tensor) -> Tuple[List[List[float]], List[List[int]]]:
        """Classify a batch of images.

        Args:
            x: Image batch accepted by ``self.transforms`` (3 channels are
                required by the normalization step).

        Returns:
            A ``(confidences, class_indices)`` pair of nested Python lists,
            one inner list per image, ordered from most to least likely.
            The order of the pair matches the historical behavior of this
            method: probabilities first, class indices second.
        """
        with torch.no_grad():
            batch = self.transforms(x)
            logits = self.model(batch)
            probs = torch.softmax(logits, dim=1)
            # topk raises if k exceeds the size of the dimension it scans,
            # so never ask for more classes than the model produces.
            k = min(self.top_k, probs.shape[-1])
            top = probs.topk(k)
            # No grad is tracked here, so the tensors can be moved and
            # converted directly without clone()/detach().
            confs = top.values.cpu().tolist()
            clazz = top.indices.cpu().tolist()
            return confs, clazz

In [37]:
# Fetch the pretrained ResNet-18 and ResNet-50 weights (with a download
# progress bar), put both networks in eval mode and move them to the GPU.
r18 = resnet18(pretrained=True, progress=True)
r18 = r18.eval().cuda()
r50 = resnet50(pretrained=True, progress=True)
r50 = r50.eval().cuda()

# Look up a loaded network by its depth.
model_version_mapping = dict([(50, r50), (18, r18)])


In [34]:
# Decode the sample image and force it to 3-channel RGB. The predictor
# normalizes with 3-channel statistics, so every non-RGB mode (RGBA, but also
# grayscale "L", palette "P", CMYK, ...) has to be converted, not only RGBA.
with Image.open("./imgs/dogs.jpg") as image_file:
    # convert() decodes the pixels and returns a separate image object, so
    # pil_image stays usable after the file handle is closed.
    pil_image = image_file.convert('RGB')
# Turn the image into a tensor, add a leading batch axis and move it to the GPU.
img = T.ToTensor()(pil_image).unsqueeze(0).cuda()


In [71]:
# Wrap the ResNet-50 in the predictor and place it on CUDA device 0.
predictor = ResnetPredictor(r50)
predictor = predictor.to(0)
# Run the preprocessing + classification pipeline on the image batch.
res = predictor(img)

In [72]:
# Show the raw prediction: (top-5 softmax probabilities, top-5 class indices),
# with one inner list per image in the batch.
res

([[0.9793123006820679,
   0.013081071898341179,
   0.0036804289557039738,
   0.001444189460016787,
   0.0007551226881332695]],
 [[208, 207, 852, 222, 162]])

In [76]:
# Build one single-element row per top-5 entry of the first image in the batch.
# NOTE(review): res[0] holds the softmax probabilities (res[1] holds the class
# indices), so despite its name this list contains scores, not class labels.
# Confirm whether res[1][0] plus a label lookup was intended.
clazz_zh = [[score] for score in res[0][0]]
clazz_zh

[[0.9793123006820679],
 [0.013081071898341179],
 [0.0036804289557039738],
 [0.001444189460016787],
 [0.0007551226881332695]]