In [1]:
import torch
from torch import nn
from torch.utils.data import random_split

from torchvision import transforms
from torchvision.models import resnet18
from torchvision.datasets import CIFAR10

In [2]:
model = resnet18()
model.conv1 = nn.Conv2d(
	in_channels=3,
	out_channels=64,
	kernel_size=3,
	stride=1,
	padding=1,
	bias=False
)
model.maxpool = nn.Identity()
model.fc = nn.Linear(in_features=512, out_features=10, bias=True)
model.to(torch.device("mps"))

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): Identity()
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), p

In [3]:
try:
	model.load_state_dict(torch.load("../checkpoints/rnd14.pt", weights_only=True))
except:
	print("Not using loaded weights")

In [4]:
loss_fn = nn.CrossEntropyLoss()

In [5]:
trainset = CIFAR10(root='./data', train=True, download=True, transform=transforms.Compose([
	transforms.RandomCrop(32, padding=4),
	transforms.RandomHorizontalFlip(),
	transforms.ToTensor(),
	transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
]))
testset = CIFAR10(root='./data', train=False, download=True, transform=transforms.Compose([
	transforms.ToTensor(),
	transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
]))

Files already downloaded and verified
Files already downloaded and verified


In [6]:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=True, num_workers=0)

In [7]:
# current_lr = 0.001
# optimizer = torch.optim.SGD(model.parameters(), lr=current_lr, weight_decay=5e-4, momentum=0.9)
# loss_fn = nn.CrossEntropyLoss().to(torch.device("mps"))

# for local_epoch in range(100):
# 	model.train()
	
# 	print(f"Epoch {local_epoch+1}/100... ")

# 	for i, (inputs, labels) in enumerate(trainloader):
# 		optimizer.zero_grad()
# 		inputs, labels = inputs.to(device=torch.device("mps"), non_blocking=True), labels.to(device=torch.device("mps"), non_blocking=True)
# 		outputs = model(inputs)
# 		minibatch_loss = loss_fn(outputs, labels)
# 		minibatch_loss.backward()
# 		optimizer.step()

# 	model.eval()
# 	with torch.no_grad():
# 		num_batches = len(testloader)
# 		size = len(testset)

# 		test_loss, correct = 0, 0

# 		for inputs, labels in testloader:
# 			inputs, labels = inputs.to(device=torch.device("mps"), non_blocking=True), labels.to(device=torch.device("mps"), non_blocking=True)
# 			outputs = model(inputs)
# 			test_loss += loss_fn(outputs, labels).item()
# 			correct += (outputs.argmax(1) == labels).type(torch.float).sum().item()
		
# 		test_loss /= num_batches
# 		correct /= size

# 		print(f"Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")
model.eval()
with torch.no_grad():
	num_batches = len(testloader)
	size = len(testset)

	test_loss, correct = 0, 0

	for inputs, labels in testloader:
		inputs, labels = inputs.to(device=torch.device("mps"), non_blocking=True), labels.to(device=torch.device("mps"), non_blocking=True)
		outputs = model(inputs)
		test_loss += loss_fn(outputs, labels).item()
		correct += (outputs.argmax(1) == labels).type(torch.float).sum().item()
	
	test_loss /= num_batches
	correct /= size

	print(f"Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")

Accuracy: 39.1%, Avg loss: 1.631523
