# Setup: import the required modules

In [112]:
import urllib.request

import albumentations as A
import albumentations.pytorch
import numpy as np
import pandas as pd
import torch
import torchvision
import ttach as tta
from PIL import Image
from torch.utils.data import Dataset, DataLoader

In [113]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# loading the model

In [114]:
def load_checkpoint(filepath, map_location=None):
    """Load a pickled checkpoint and return its model frozen and in eval mode.

    The checkpoint must be a dict with a ``'model'`` entry holding the model
    object itself and a ``'state_dict'`` entry holding the weights to load
    into it.

    Args:
        filepath: Path to the ``.pth`` checkpoint file.
        map_location: Where stored tensors are mapped on load. ``None`` (the
            default) means the notebook-level ``device``.

    Returns:
        The model with the stored weights loaded, every parameter frozen
        (``requires_grad=False``) and switched to eval mode.
    """
    import inspect

    if map_location is None:
        map_location = device

    load_kwargs = {'map_location': map_location}
    # The checkpoint pickles a whole nn.Module, which can only be restored with
    # weights_only=False; newer PyTorch releases default it to True and would
    # refuse. Older releases lack the parameter entirely, hence the check.
    # SECURITY: this unpickles arbitrary objects -- only load trusted files.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(filepath, **load_kwargs)

    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    # Inference only: freeze all parameters so no gradients are tracked.
    model.requires_grad_(False)
    model.eval()

    return model

# Location of the trained checkpoint in the Kaggle input directory.
CHECKPOINT_PATH = '../input/checkpoint-for-george/checkpoint.pth'
model = load_checkpoint(CHECKPOINT_PATH)

# dataset

In [115]:
class dataset(Dataset):
    """Map-style dataset of (image, label) pairs described by a DataFrame.

    Args:
        df: DataFrame with an ``fnames`` column (image file paths) and a
            ``label`` column (integer class ids). Rows are addressed by
            position, so any index (not just 0..n-1) works.
        transforms: Optional callable invoked as ``transforms(image=array)``
            that returns a dict whose ``'image'`` entry is the transformed
            image (albumentations-style). Skipped when falsy.
    """

    def __init__(self, df, transforms=None):
        self.df = df
        self.transforms = transforms

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        # Positional lookup: label-based df.col[index] breaks when the frame's
        # index is not a plain 0..n-1 RangeIndex (e.g. after filtering).
        path = self.df['fnames'].iloc[index]
        raw_label = self.df['label'].iloc[index]

        # Close the file handle promptly; np.array copies the pixel data, so
        # the image is no longer needed once the array exists.
        with Image.open(path) as img:
            image = np.array(img.convert('RGB'))

        label = torch.tensor(raw_label).long()

        if self.transforms:
            image = self.transforms(image=image)['image']

        return image, label

# Test-time transforms for the dataset

In [116]:
# Evaluation preprocessing: resize so the shorter side is 256 px, take the
# central 256x256 crop, normalize with A.Normalize's default mean/std, and
# convert the result to a torch tensor.
test_transforms = A.Compose(
    [
        A.SmallestMaxSize(max_size=256),
        A.CenterCrop(height=256, width=256),
        A.Normalize(),
        albumentations.pytorch.ToTensorV2(),
    ]
)

# DataFrame builder (downloads an image from a URL)

In [117]:
# helper function to build a df for dataset later
def test_df_builder(url, filename='test.jpg'):

    urllib.request.urlretrieve(url, filename)
    test_df = pd.DataFrame({'fnames': [filename], 'label':[1]})
    
    return test_df

# Prediction function

In [118]:
# Build the download -> dataset -> loader pipeline and classify one image from a URL.
def predict(url=None, model=model, tta_switch=True):
    """Download the image at ``url`` and classify it as George / not George.

    Args:
        url: URL of the image (anything ``urllib.request.urlretrieve``
            accepts). Required; the ``None`` default exists only so the
            signature stays call-compatible.
        model: Classifier to run. Defaults to the module-level ``model`` as it
            was when this function was defined.
        tta_switch: If True, wrap ``model`` in ttach five-crop test-time
            augmentation with crops of ``int(256 * 0.9)`` px; if False, run
            the model directly.

    Returns:
        'This is George' if the predicted class index is 1, else 'Not George'.

    Raises:
        ValueError: if ``url`` is not provided.
        RuntimeError: if the loader yields no images.
    """
    if url is None:
        raise ValueError('predict() requires an image URL')

    test_df = test_df_builder(url)
    test_dataset = dataset(test_df, transforms=test_transforms)
    # num_workers=0: one image does not justify a worker process, and on
    # spawn-based platforms workers can fail to unpickle notebook-defined classes.
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

    # Check the flag itself -- the `tta` module object is always truthy, so
    # testing it would apply TTA even when tta_switch=False.
    if tta_switch:
        crop = int(256 * 0.9)
        model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform(crop, crop))

    model.eval()

    preds = None
    with torch.no_grad():
        # Labels are placeholders from test_df_builder, so they are not needed here.
        for images, _labels in test_loader:
            images = images.to(device, non_blocking=True)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)

    if preds is None:
        raise RuntimeError('no image was loaded for prediction')

    return 'This is George' if preds.item() == 1 else 'Not George'

# Try some predictions (add your own URLs to the list below)

In [119]:
# Sample images to classify; append your own URLs here.
urls = [
    'https://i.pinimg.com/originals/f8/50/d5/f850d5b9cea781d6a67c599b338af8c4.jpg',
    'https://ichef.bbci.co.uk/news/976/cpsprodpb/41CF/production/_109474861_angrycat-index-getty3-3.jpg',
    'https://scene7.zumiez.com/is/image/zumiez/product_main_medium_2x/Primitive-x-Rick-and-Morty-Pickle-Rick-Sticker-_309256-front-US.jpg',
]

In [120]:
# Classify every sample URL with the default settings (TTA enabled).
for url in urls:
    verdict = predict(url)
    print(verdict)

This is George
Not George
Not George


In [121]:
# Same URLs again, this time with test-time augmentation switched off.
for url in urls:
    verdict = predict(url, tta_switch=False)
    print(verdict)

This is George
Not George
Not George
