In [None]:
# This notebook runs inference: it downloads car photos from URLs and classifies them with a trained model.

In [18]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from PIL import Image
import ttach as tta

import numpy as np

import pandas as pd
import urllib.request

import albumentations as A
import albumentations.pytorch

In [19]:
# Run on the current CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [20]:
# Evaluation image geometry: the shorter side is resized to `presize` px,
# then a `crop` x `crop` centre crop is taken (see test_transforms).
presize, crop = 224, 224

# DataLoader worker processes; 0 loads batches in the main process.
num_workers = 0

In [21]:
# Names of the classes the model was trained on. predict() looks up the index
# of the model's highest output in this list, so the order must match the
# model's output order. The list is alphabetical, presumably mirroring the
# training label encoding — confirm against the training notebook.
classes = [
    'HYUNDAI_Sonata_Silver',
    'MERCEDES-BENZ_E 350_Black',
    'TOYOTA_Camry_Black',
    'TOYOTA_Camry_White',
    'TOYOTA_Prius_White',
]

# loading the model

In [22]:
def load_checkpoint(filepath, map_location=None):
    """Load a pickled model checkpoint and prepare it for inference.

    The checkpoint must be a dict holding a ``'model'`` entry (a pickled
    ``nn.Module``) and a ``'state_dict'`` entry whose weights are loaded into
    that module. The returned model has ``requires_grad`` switched off on
    every parameter and is put in eval mode.

    Args:
        filepath: path to the checkpoint file.
        map_location: forwarded to ``torch.load``; when ``None`` the
            notebook's global ``device`` is used.

    Returns:
        The restored model.

    Security:
        Loading a whole pickled module unpickles arbitrary Python objects,
        which can execute code. Only load checkpoints you trust.
    """
    import inspect

    if map_location is None:
        map_location = device

    load_kwargs = {'map_location': map_location}
    # The checkpoint stores a full nn.Module object, which torch's weights-only
    # unpickler (the default since torch 2.6) refuses to load. Opt out
    # explicitly wherever the argument exists (torch >= 1.13).
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False

    checkpoint = torch.load(filepath, **load_kwargs)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])

    # inference only: no gradients needed for any parameter
    for parameter in model.parameters():
        parameter.requires_grad = False

    model.eval()

    return model

# Checkpoint path is resolved relative to the current working directory.
model = load_checkpoint(filepath='checkpoint.pth')

# dataset

In [23]:
class dataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from a DataFrame.

    Args:
        df: DataFrame with an ``fname`` column (path to an image file) and a
            ``label`` column (integer class index).
        transforms: optional albumentations-style callable, invoked as
            ``transforms(image=ndarray)`` and expected to return a dict with
            an ``'image'`` key. When falsy, the image is returned as the RGB
            numpy array.
    """

    def __init__(self, df, transforms=None):
        self.df = df
        self.transforms = transforms

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        # DataLoader passes positional indices 0..len-1, so look rows up by
        # position (.iloc). Label-based lookup breaks for any df whose index
        # is not a fresh RangeIndex (e.g. after filtering or shuffling).
        fname = self.df['fname'].iloc[index]
        label_value = self.df['label'].iloc[index]

        # Context manager releases the underlying file handle immediately
        # instead of waiting for garbage collection.
        with Image.open(fname) as img:
            image = np.array(img.convert('RGB'))

        label = torch.tensor(label_value).long()

        if self.transforms:
            image = self.transforms(image=image)['image']

        return image, label

# transforms

In [24]:
# Deterministic evaluation preprocessing applied to every downloaded image.
test_transforms = A.Compose(
    [
        # resize so the shorter side is `presize` px, keeping the aspect ratio
        A.SmallestMaxSize(max_size=presize),
        # take the central `crop` x `crop` patch
        A.CenterCrop(height=crop, width=crop),
        # library-default mean/std (ImageNet statistics in current
        # albumentations releases — confirm for the pinned version)
        A.Normalize(),
        # HWC numpy array -> CHW torch tensor
        albumentations.pytorch.ToTensorV2(),
    ]
)

# df builder

In [25]:
# we will put every sample through our usuall pipeline we used when we built out model
# this cell will build df
def test_df_builder(url, filename='test.jpg'):

    urllib.request.urlretrieve(url, filename)
    test_df = pd.DataFrame({'fname': [filename], 'label':[1]})
    
    return test_df

# predict function

In [26]:
def predict(url, model=model, tta_switch=True, classes=classes):
    """Download one image and return the predicted class name.

    Args:
        url: image URL; it is fetched via ``test_df_builder`` and saved to
            its default file name.
        model: classifier to run; defaults to the model loaded from the
            checkpoint earlier in the notebook.
        tta_switch: when True, wrap the model in ttach five-crop test-time
            augmentation (crops of 90% of ``crop``) and merge the predictions
            with ttach's default merge mode. When False, run the plain model
            on the centre-cropped image.
        classes: class names, indexed by the model's output position.

    Returns:
        The entry of ``classes`` at the model's highest-scoring output.
    """
    test_df = test_df_builder(url)
    test_dataset = dataset(test_df, transforms=test_transforms)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=num_workers)

    # Must test the flag, not the `tta` module (a module object is always
    # truthy, which would apply TTA unconditionally).
    if tta_switch:
        tta_size = int(crop * 0.9)
        model = tta.ClassificationTTAWrapper(
            model, tta.aliases.five_crop_transform(tta_size, tta_size))

    model.eval()

    with torch.no_grad():
        # the loader holds exactly one sample; its placeholder label is unused
        images, _ = next(iter(test_loader))
        outputs = model(images.to(device, non_blocking=True))
        _, preds = torch.max(outputs, 1)

    return classes[preds.item()]

# Run predictions (add your own image URLs to the list below)

In [27]:
# The predictions below match the labels because these photos resemble the
# training images (same distribution); expect lower accuracy on photos that
# differ from them. Only about 400+ photos in total were used to train the model.

urls = [
    'https://i.ytimg.com/vi/qyh8ZRMhLFk/maxresdefault.jpg',
    'https://cdn-ds.com/blogs-media/sites/231/2019/01/20084856/How-Safe-is-the-2019-Toyota-Camry-A_O.png',
    'https://content.homenetiol.com/2000292/2175740/0x0/193d3d8f09404aa7b397354b031c33ad.jpg',
    'https://smgmedia.blob.core.windows.net/images/105075/1024/toyota-unlisted-hatchback-0f33fc2b8a52.jpg',
    'https://www.batfa.com/photo-used-car-toyota-prius-2015-model-pearl-color.files/Prius2015pearl-rear.jpg',
    'https://static.cargurus.com/images/site/2011/02/02/08/39/2011_hyundai_sonata_gls-pic-2044408969940823943-1600x1200.jpeg',
    'https://m.media-amazon.com/images/I/71ZoAfzH22L._UY560_.jpg',
    'https://autompv.ru/wp-content/uploads/2020/08/novaya-toyota-camry-ws-black-edition-2020%E2%80%942021.jpg'
    ]

# Ground-truth labels for `urls`, in the same order, spelled exactly like the
# class names so predictions can be compared against them directly.
expected_labels = [
    'TOYOTA_Camry_White',
    'TOYOTA_Camry_White',
    'MERCEDES-BENZ_E 350_Black',
    'TOYOTA_Prius_White',
    'TOYOTA_Prius_White',
    'HYUNDAI_Sonata_Silver',
    'HYUNDAI_Sonata_Silver',
    'TOYOTA_Camry_Black',
]
assert len(expected_labels) == len(urls), 'need exactly one expected label per URL'


In [28]:
# with TTA
for photo_no, photo_url in enumerate(urls, start=1):
    prediction = predict(photo_url)
    print(f'photo number {photo_no} is {prediction}')

photo number 1 is TOYOTA_Camry_White
photo number 2 is TOYOTA_Camry_White
photo number 3 is MERCEDES-BENZ_E 350_Black
photo number 4 is TOYOTA_Prius_White
photo number 5 is TOYOTA_Prius_White
photo number 6 is HYUNDAI_Sonata_Silver
photo number 7 is HYUNDAI_Sonata_Silver
photo number 8 is TOYOTA_Camry_Black


In [29]:
# w/o the TTA
for photo_no, photo_url in enumerate(urls, start=1):
    prediction = predict(photo_url, tta_switch=False)
    print(f'photo number {photo_no} is {prediction}')

photo number 1 is TOYOTA_Camry_White
photo number 2 is TOYOTA_Camry_White
photo number 3 is MERCEDES-BENZ_E 350_Black
photo number 4 is TOYOTA_Prius_White
photo number 5 is TOYOTA_Prius_White
photo number 6 is HYUNDAI_Sonata_Silver
photo number 7 is HYUNDAI_Sonata_Silver
photo number 8 is TOYOTA_Camry_Black
