In [1]:
import torch
from torch import nn
from torchvision import transforms
from torchvision.models import resnet18
from torchvision.datasets import CIFAR10

In [2]:
# Classification objective on raw logits; batch losses are mean-reduced (the default).
loss_fn = nn.CrossEntropyLoss(reduction="mean")

In [3]:
# Randomly initialised ResNet-18: no pretrained weights are requested.
model = resnet18(weights=None)

In [4]:
# Swap the classification head for a 10-way linear layer, keeping the existing
# head's input width (512 for ResNet-18) instead of hardcoding it.
head_width = model.fc.in_features
model.fc = nn.Linear(head_width, 10)

In [5]:
# Prefer Apple-silicon MPS; fall back to CPU so the notebook still runs on
# machines without MPS instead of raising at this cell.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [8]:
# Evaluate the model on the CIFAR-10 test split and report accuracy and mean loss.
# NOTE(review): this cell sits above the cell that loads 'model.pt' but was
# executed after it (In [8] vs In [7]). Run the load cell first; in the current
# top-to-bottom order, Restart & Run All would evaluate untrained weights.
model.eval()
# Send batches to whatever device the model's parameters live on, rather than
# hardcoding it a second time.
device = next(model.parameters()).device

testset = CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
# Order does not affect the metrics; a fixed order keeps runs reproducible.
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=0)

size = len(testset)
test_loss, correct = 0.0, 0

with torch.no_grad():
    for inputs, labels in testloader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        outputs = model(inputs)
        # loss_fn returns a per-batch mean; weight it by the batch size so a smaller
        # final batch does not skew the dataset-level average.
        test_loss += loss_fn(outputs, labels).item() * labels.size(0)
        correct += (outputs.argmax(1) == labels).sum().item()

test_loss /= size
correct /= size

print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")

Files already downloaded and verified
Test Error: 
 Accuracy: 39.9%, Avg loss: 0.463059


In [7]:
# Load trained weights. This must run before the evaluation cell, or that cell
# scores random weights.
# map_location puts the checkpoint tensors on the model's device, so loading works
# even if the checkpoint was saved on a device that is unavailable here (the
# original save device is not known from this notebook).
model.load_state_dict(
    torch.load('model.pt', weights_only=True, map_location=next(model.parameters()).device)
)

<All keys matched successfully>