In [49]:
import torch
import torchvision
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision import models

from time import time
from sklearn.metrics import accuracy_score


## Check CUDA

In [44]:
if torch.cuda.is_available():
    cudnn.benchmark = True
    device = "cuda"
    print(torch.cuda.get_device_name())
else:
    device = "cpu"
    print("Use CPU")


Quadro RTX 3000 with Max-Q Design


## Load data

In [45]:
transform_valid = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

In [None]:
valid_set = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform_valid, download=True)
valid_set = Subset(valid_set, list(range(200, 300)))
valid_loader = DataLoader(valid_set, batch_size=5, shuffle=True)
print(f'## Validation set has {len(valid_set)} instances.')

Files already downloaded and verified
## Validation set has 10000 instances.


## Load model

In [47]:
PATH = r"my_weights\Resnet18_e50_b5_t70_v30.pth"
model = models.resnet18().to(device)
model.load_state_dict(torch.load(PATH))

<All keys matched successfully>

## Inference

In [50]:
# warm up 
print(f'## start warm up')
dummy_data = torch.randn(5, 3, 32, 32).to(device)
for _ in range(500):
    _ = model(dummy_data)
print(f'## finished warm up')

# calculate accuracy
with torch.no_grad():
    sum_vacc = 0.0
    for i, vdata in enumerate(valid_loader):
        vinputs, vlabels = vdata[0].to(device), vdata[1].to(device)
        model.eval()
        voutputs = model(vinputs)
        vacc = accuracy_score(vlabels.cpu(), voutputs.argmax(dim=1).cpu())
        sum_vacc += vacc
avg_vacc = sum_vacc / (i + 1)

# calculate time
start_time = time()
with torch.no_grad():
    for i, vdata in enumerate(valid_loader):
        vinputs, vlabels = vdata[0].to(device), vdata[1].to(device)
        model.eval()
        model(vinputs)

end_time = time()

print(f"## Accuracy: {avg_vacc:.4f}")
print(f"## Inference {len(valid_set)} data with {end_time-start_time} sec.")

## start warm up
## finished warm up
## Accuracy: 0.0969
## Inference 10000 data with 9.37735915184021 sec.
