In [None]:
import os
import glob
import time
import random

import torch
import torchvision

import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch2trt

In [None]:
# Class names, indexed by the position of the model's argmax output (0 = Cat, 1 = Dog).
LABELS = ['Cat', 'Dog']

# Device used for the model and inputs: the GPU when PyTorch can see one, else the CPU.
TARGET = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class CatDog():
    """Load the Cat/Dog image folders into ``[image_tensor, one_hot_label]`` pairs.

    Images are read from ``<path>/Cat`` and ``<path>/Dog`` (``*.jpg`` only),
    resized to ``img_size`` x ``img_size`` and converted with ``ToTensor``.
    """

    def __init__(self, path):
        """
        Args:
            path: dataset root containing ``Cat`` and ``Dog`` sub-folders.
        """
        self.img_size = 100

        self.train_data = []
        self.cat_path = os.path.normpath(path) + '/Cat'
        self.dog_path = os.path.normpath(path) + '/Dog'

        # folder -> class index; the index is also the hot position in the one-hot label
        self.labels = {self.cat_path: 0, self.dog_path: 1}

    def _load_image(self, image_path):
        """Read and resize one image; return it as a float tensor, or None if unreadable."""
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread reports unreadable / corrupt files by returning None
            return None
        img = cv2.resize(img, (self.img_size, self.img_size))
        # channel order is kept exactly as cv2.imread produced it (BGR)
        img = Image.fromarray(img)
        return torchvision.transforms.ToTensor()(img)

    def get_train_data(self):
        """Load every image found under the class folders.

        The sample list is rebuilt on each call, so repeated calls do not
        duplicate samples. Files that cannot be decoded are skipped and counted.
        Unlike the previous bare ``except``, KeyboardInterrupt still stops loading.

        Returns:
            list of ``[image_tensor, label_tensor]`` pairs, where ``label_tensor``
            is a float32 one-hot vector of length ``len(self.labels)``.
        """
        self.train_data = []
        one_hot = np.eye(len(self.labels))

        for label_path, class_index in self.labels.items():
            images_path = sorted(glob.glob(label_path + '/*.jpg'))

            skipped = 0
            for image_path in tqdm(images_path):
                try:
                    img = self._load_image(image_path)
                except Exception:
                    # best-effort loading: a decoding/resizing error only skips this file
                    img = None
                if img is None:
                    skipped += 1
                    continue

                ans = torch.tensor(one_hot[class_index], dtype=torch.float32)
                self.train_data.append([img, ans])

            if skipped:
                print(f'Skipped {skipped} unreadable image(s) in {label_path}')

        return self.train_data

class CatDogDataset(torch.utils.data.Dataset):
    """Map-style dataset over an in-memory list of ``[image, label]`` pairs."""

    def __init__(self, data):
        # Name-mangled attribute holding the caller's list (not copied).
        self.__data = data

    def __len__(self):
        return len(self.__data)

    def __getitem__(self, index):
        """Return the first two entries of the pair at ``index`` as a tuple."""
        pair = self.__data[index]
        image, label = pair[0], pair[1]
        return image, label

In [None]:
# Load every image, shuffle reproducibly and keep the last (1 - TRAIN_FRACTION) as the test split.
# NOTE(review): this split is only a true hold-out if it matches the split used when
# 'resnet18_adv_loss.pt' was trained — otherwise "test" items may include training images.
SPLIT_SEED = 0
TRAIN_FRACTION = 0.8

cat_dog = CatDog(path='PetImages/')
cat_dog_dataset = cat_dog.get_train_data()
assert len(cat_dog_dataset) > 0, 'No images were loaded from PetImages/'

random.Random(SPLIT_SEED).shuffle(cat_dog_dataset)
print('Dataset length', len(cat_dog_dataset))

cat_dog_test = cat_dog_dataset[int(len(cat_dog_dataset) * TRAIN_FRACTION):]

print('Test length', len(cat_dog_test))

test = CatDogDataset(cat_dog_test)

In [None]:
# ResNet-18 with a 2-way head (Cat/Dog), switched to inference mode: BatchNorm then uses
# its stored running statistics instead of normalizing each (size-1) batch by itself,
# and forward passes (including torch2trt's tracing) no longer update those statistics.
net = torchvision.models.resnet18(pretrained=False, num_classes=2).to(TARGET).eval()
# map_location lets a checkpoint saved on the GPU load on a CPU-only machine (and vice versa).
net.load_state_dict(torch.load('resnet18_adv_loss.pt', map_location=TARGET))

In [None]:
# Classify one random test image with the PyTorch model, time the forward pass and show the image.
i = random.randint(0, len(test) - 1)
print('Item', i)

image, real = test[i]
batch = image.unsqueeze(0).to(TARGET)

with torch.no_grad():
    # CUDA kernels run asynchronously: synchronize before and after so the timer
    # measures the forward pass itself rather than just the kernel launches.
    if TARGET == 'cuda':
        torch.cuda.synchronize()
    start = time.perf_counter()
    net_result = net(batch)
    if TARGET == 'cuda':
        torch.cuda.synchronize()
    print('Compute time:', time.perf_counter() - start)

p = np.argmax(net_result.cpu().numpy())
r = np.argmax(real.numpy())

print('Predict:', LABELS[p])
print('Real:', LABELS[r])

# CHW tensor -> HWC array; reverse the channels (kept in cv2's BGR order at load) for matplotlib.
img = image.permute(1, 2, 0).numpy()[:, :, ::-1].copy()
plt.imshow(img)
plt.show()

In [None]:
# Convert the PyTorch model to a TensorRT engine (torch2trt needs the model and inputs on CUDA).
assert TARGET == 'cuda', 'TensorRT conversion requires a CUDA device'
# Conversion runs the model forward; in train mode that would use batch statistics and
# overwrite BatchNorm running stats with those of the dummy input.
net.eval()

# Example input with exactly the per-sample shape the dataset yields, batch size 1.
x = torch.zeros((1, *test[0][0].shape), device=TARGET)

trt_net = torch2trt.torch2trt(net, [x])

# Sanity check: the engine should reproduce the PyTorch output on a real sample.
with torch.no_grad():
    sample = test[0][0].unsqueeze(0).to(TARGET)
    max_diff = torch.max(torch.abs(net(sample) - trt_net(sample))).item()
print('Max abs difference vs PyTorch:', max_diff)

In [None]:
# Classify one random test image with the TensorRT engine, time the forward pass and show the image.
i = random.randint(0, len(test) - 1)
print('Item', i)

image, real = test[i]
batch = image.unsqueeze(0).to(TARGET)

with torch.no_grad():
    # GPU execution is asynchronous: synchronize before and after so the timer
    # measures the engine's forward pass rather than just the launch.
    if TARGET == 'cuda':
        torch.cuda.synchronize()
    start = time.perf_counter()
    net_result = trt_net(batch)
    if TARGET == 'cuda':
        torch.cuda.synchronize()
    print('Compute time:', time.perf_counter() - start)

p = np.argmax(net_result.cpu().numpy())
r = np.argmax(real.numpy())

print('Predict:', LABELS[p])
print('Real:', LABELS[r])

# CHW tensor -> HWC array; reverse the channels (kept in cv2's BGR order at load) for matplotlib.
img = image.permute(1, 2, 0).numpy()[:, :, ::-1].copy()
plt.imshow(img)
plt.show()

In [None]:
# Persist the converted TensorRT module's state dict next to the notebook.
torch.save(obj=trt_net.state_dict(), f='resnet18_trt.pt')