In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import time

In [3]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")
torch.set_default_device('cpu')

Using mps device


In [4]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 128

# trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
#                                         download=True, transform=transform)
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
#                                           shuffle=True, num_workers=2, generator=torch.Generator(device='cuda'))

# testset = torchvision.datasets.CIFAR10(root='./data', train=False,
#                                        download=True, transform=transform)
# testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
#                                          shuffle=False, num_workers=2, generator=torch.Generator(device='cuda'))

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# i = 0
# for X,y in trainloader:
#     print('x shape',X.shape)
#     print('y shape', y)
#     i += 1
#     if i == 10:
#         break


Files already downloaded and verified
Files already downloaded and verified


In [5]:

class Residual(nn.Module):
  def __init__(self, num_channels_in, num_channels_out, use_1x1conv=False, strides=1):
    super().__init__()
    self.conv1 = nn.Conv2d(num_channels_in, num_channels_in, kernel_size=3, stride=strides, padding=1)
    self.conv2 = nn.Conv2d(num_channels_in, num_channels_out, kernel_size=3, padding=1)
    self.conv3 = nn.Conv2d(num_channels_in, num_channels_out, kernel_size=1, stride=strides)
    if use_1x1conv:
      self.use_1x1conv = True
    else:
      self.use_1x1conv = False
    self.bn1 = nn.BatchNorm2d(num_channels_in)
    self.bn2 = nn.BatchNorm2d(num_channels_out)

  def forward(self, x):
    Y = F.relu(self.bn1(self.conv1(x)))
    Y = self.bn2(self.conv2(Y))
    if self.use_1x1conv:
      x = self.conv3(x)
    return F.relu(Y + x)

In [6]:
def resnet18(num_classes):
  net = nn.Sequential(
      nn.Conv2d(3,128,kernel_size=3, stride=1, padding=1),
      nn.BatchNorm2d(128),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
  def resnet_block(num_channels_in, num_channels_out, num_residuals, first_block=False):
    blk = nn.Sequential()
    for i in range(num_residuals):
      if i == 0 and not first_block:
        blk.add_module("res 1x1conv",Residual(num_channels_in, num_channels_out, use_1x1conv=True, strides=2))
      else:
        blk.add_module("res", Residual(num_channels_out, num_channels_out))
    return blk
  
  # net.add_module("res",resnet_block(64,2,first_block=True))
  # net.add_module("res",resnet_block(128,2))
  # net.add_module("res",resnet_block(256,2))
  # net.add_module("res",resnet_block(512,2))
  net.add_module("res",nn.Sequential(
     resnet_block(64,128,2,first_block=True),
     resnet_block(128,256,2),
     resnet_block(256,512,2),
     resnet_block(512,512,2)
  ))
  net.add_module("avg pool",nn.AvgPool2d(kernel_size=2))
  net.add_module('flatten', nn.Flatten())
  net.add_module("dense",nn.LazyLinear(num_classes))
  return net

def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            nn.init.zeros_(m.bias.data)


def get_net():
  num_classes = 10
  net = resnet18(num_classes)
  net.apply(weights_init)
  return net


loss = nn.CrossEntropyLoss()


In [7]:
def evaluate_accuracy(net, valid_iter):
  train_acc_sum = 0.0
  n = 0
  for X,y in valid_iter:
    y_hat = net(X)
    train_acc_sum += (y_hat.argmax(axis=1) == y).sum().item()
    n += 1
  return train_acc_sum / n

In [18]:

def train(net, train_iter, valid_iter, num_epochs, lr, wd, lr_period, lr_decay):
  loss = nn.CrossEntropyLoss()
  trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
  for epoch in range(num_epochs):
    train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
    if epoch > 0 and epoch % lr_period == 0:
      trainer.set_learning_rate(trainer.learning_rate * lr_decay)
      prev_X = None
    for X, y in train_iter:
      # y = y.astype('float32').as_in_context(ctx)
      # y_hat = net(X.as_in_context(ctx))
      X = X.to(device)
      y = y.to(device)
      y_hat = net(X)
      l = loss(y_hat, y).sum()
      trainer.zero_grad()
      with torch.no_grad():
        l.backward()
        trainer.step()
      train_l_sum += l.item()
      train_acc_sum += (y_hat.argmax(axis=1) == y).sum().item()
      n += len(y)
      prev_X = X
    time_s = "time %.2f sec" % (time.time() - start)
    if valid_iter is not None:
      # valid_acc = evaluate_accuracy(net, valid_iter)
      test_acc_sum = 0.0
      n2 = 0
      for X2,y2 in valid_iter:
        X2 = X2.to(device)
        y2 = y2.to(device)
        y_hat2 = net(X2)
        test_acc_sum += (y_hat2.argmax(axis=1) == y2).sum().item()
        n2 += len(y2)
      epoch_s = ("epoch %d, loss %f, train acc %f, valid acc %f, " % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc_sum / n2))
    else:
      epoch_s = ("epoch %d, loss %f, train acc %f, " % (epoch + 1, train_l_sum / n, train_acc_sum / n))
    print(epoch_s + time_s)

In [9]:
net = get_net()
x = torch.zeros(size=(1,3,32,32))

print(net)
yhat = net(x)
yhat.shape
# for layer in net:
#     print(layer.shape)

Sequential(
  (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (res): Sequential(
    (0): Sequential(
      (res): Residual(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv3): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Sequential(
      (res 1x1conv): Residual(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))



torch.Size([1, 10])

In [19]:
num_epochs = 1
lr = 0.1
wd = 5e-4
net = get_net()
print('net:', net)
train(net, trainloader, testloader, num_epochs, lr, wd, 80, 0.1)

net: Sequential(
  (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (res): Sequential(
    (0): Sequential(
      (res): Residual(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv3): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Sequential(
      (res 1x1conv): Residual(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1