In [1]:
import time

import torch
import torchvision
from numpy import mean
from itertools import product
from torchvision.transforms import transforms
from copy import deepcopy


In [2]:
sizes = [
    512,
    1024,
    2048
]

In [3]:
networks = [
    ('vgg16', torch.nn.Sequential(*list(torchvision.models.vgg16_bn(pretrained=True).children())[:-2])),
    ('resnet50', torch.nn.Sequential(*list(torchvision.models.resnet50(pretrained=True).children())[:-2])),
    ('mobilenet_v2', torch.nn.Sequential(*list(torchvision.models.mobilenet_v2(pretrained=True).children())[:-1]))
]

In [4]:
tensor_core_usage = [
    False,
    True
]

In [5]:
for size, network_tuple, use_tensor_cores in product(sizes, networks, tensor_core_usage):
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = torchvision.datasets.CelebA(root='./data', download=True,
                                          transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                         shuffle=False, num_workers=1)
    network_name = network_tuple[0]
    network = deepcopy(network_tuple[1]).eval()
    if use_tensor_cores:
        network = network.half()
    else:
        network = network.float()
    network = network.cuda()

    iterator = iter(loader)
    with torch.no_grad():
        for i in range(10):
            img = next(iterator)[0]
            if use_tensor_cores:
                img = img.half()
            network(img.cuda())

        times = []
        for i in range(100):
            img = next(iterator)[0]
            if use_tensor_cores:
                img = img.half()
            start = time.perf_counter()
            network(img.cuda())
            end = time.perf_counter()
            times.append(end - start)
        print(f'{size} {network_name} {use_tensor_cores} {1 / (mean(times)):.2f} FPS')

    del network
    torch.cuda.empty_cache()

Files already downloaded and verified
512 vgg16 False 31.78 FPS
Files already downloaded and verified
512 vgg16 True 72.78 FPS
Files already downloaded and verified
512 resnet50 False 57.06 FPS
Files already downloaded and verified
512 resnet50 True 111.61 FPS
Files already downloaded and verified
512 mobilenet_v2 False 121.92 FPS
Files already downloaded and verified
512 mobilenet_v2 True 126.16 FPS
Files already downloaded and verified
1024 vgg16 False 7.80 FPS
Files already downloaded and verified
1024 vgg16 True 18.69 FPS
Files already downloaded and verified
1024 resnet50 False 16.16 FPS
Files already downloaded and verified
1024 resnet50 True 86.26 FPS
Files already downloaded and verified
1024 mobilenet_v2 False 72.45 FPS
Files already downloaded and verified
1024 mobilenet_v2 True 102.41 FPS
Files already downloaded and verified
2048 vgg16 False 1.89 FPS
Files already downloaded and verified
2048 vgg16 True 4.58 FPS
Files already downloaded and verified
2048 resnet50 False 4.16