In [1]:
import torch
import torchvision
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision import models

from time import time
from sklearn.metrics import accuracy_score


## Check CUDA

In [2]:
if torch.cuda.is_available():
    cudnn.benchmark = True
    device = "cuda"
    print(torch.cuda.get_device_name())
else:
    device = "cpu"
    print("Use CPU")


Quadro RTX 3000 with Max-Q Design


## Load data

In [3]:
transform_valid = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

In [4]:
BATCH = 5
valid_set = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform_valid, download=True)
valid_set = Subset(valid_set, list(range(70, 100)))
valid_loader = DataLoader(valid_set, batch_size=BATCH, shuffle=False)
print(f'## Validation set has {len(valid_set)} instances.')

Files already downloaded and verified
## Validation set has 30 instances.


## Load model

In [13]:
PATH = r"my_weights\Resnet18_e25_b5_t70_v30(global L1 sp=0.2).pth"
model = models.resnet18().to(device)
model.load_state_dict(torch.load(PATH))

  model.load_state_dict(torch.load(PATH))


<All keys matched successfully>

## Inference

In [14]:
model.eval()

# warm up 
print(f'## start warm up')
dummy_data = torch.randn(5, 3, 32, 32).to(device)
for _ in range(500):
    _ = model(dummy_data)
print(f'## finished warm up')

# calculate accuracy
with torch.no_grad():
    total_vcorrect = 0.0
    total_vsamples = 0.0
    for i, vdata in enumerate(valid_loader):
        vinputs, vlabels = vdata[0].to(device), vdata[1].to(device)
        voutputs = model(vinputs)
        total_vcorrect += (voutputs.argmax(dim=1) == vlabels).sum().item()
        total_vsamples += vlabels.size(0)
avg_vacc = total_vcorrect / total_vsamples

# calculate time
with torch.no_grad():
    start_time = time()
    for i, vdata in enumerate(valid_loader):
        vinputs, vlabels = vdata[0].to(device), vdata[1].to(device)
        model(vinputs)
    end_time = time()

print(f"## Accuracy: {avg_vacc:.4f}")
print(f"## Inference {len(valid_set)} data with {end_time-start_time} sec.")

## start warm up
## finished warm up
## Accuracy: 0.3000
## Inference 30 data with 0.026927947998046875 sec.
