In [None]:
import torch
import torchvision
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import resnet34, ResNet34_Weights
import torch.nn as nn
import torch.nn.functional as F

import time

In [None]:
# Preprocessing should match the backbone that get_net actually builds (ResNet-34).
# torchvision weight enums bundle their own preset transforms, and these are not
# identical across weight sets (e.g. the resize size), so don't borrow ResNet-50's.
weights = ResNet34_Weights.DEFAULT

In [None]:
batch_size = 32

# One preset preprocessing pipeline (from the selected weights) shared by both splits.
preprocess = weights.transforms()

train_set = torchvision.datasets.Food101(root='./data', split="train",
                                        download=True, transform=preprocess)

# The device cell below calls torch.set_default_device; the shuffling sampler's
# permutation is then presumably drawn on that default device, so its generator
# has to live there too. Choose it with the same CUDA -> MPS -> CPU priority as
# that cell instead of hard-coding 'cuda', which fails on machines without a GPU.
_shuffle_device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                          shuffle=True,
                                          generator=torch.Generator(device=_shuffle_device))


test_set = torchvision.datasets.Food101(root='./data', split="test",
                                        download=True, transform=preprocess)
# Accuracy over the whole test split doesn't need a random order; without
# shuffling no sampler generator is needed at all.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                          shuffle=False)

Downloading https://data.vision.ee.ethz.ch/cvl/food-101.tar.gz to ./data/food-101.tar.gz


100%|██████████| 4996278331/4996278331 [00:46<00:00, 107754065.83it/s]


Extracting ./data/food-101.tar.gz to ./data


In [None]:
# Number of examples in the train split, then the test split.
for split in (train_set, test_set):
    print(len(split))

75750
25250


In [None]:
# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")
# Tensors and modules created after this point are allocated on `device` by default.
torch.set_default_device(device)

Using cuda device


In [None]:
def init_xavier(module):
  """Xavier-uniform initialise ``module.weight`` when ``module`` is a linear layer.

  Meant to be passed to ``nn.Module.apply``. Any other module type is left
  untouched, and the bias keeps PyTorch's default initialisation.
  """
  # isinstance (not an exact type comparison) so nn.Linear subclasses are covered too.
  if isinstance(module, nn.Linear):
    nn.init.xavier_uniform_(module.weight)

def get_net(num_classes=101):
  """Build a pretrained ResNet-34 whose classifier is replaced by a new two-layer head.

  The model's last child module (its ``fc`` classifier) is dropped; the rest is
  followed by ``nn.Flatten`` and a ``Linear -> ReLU -> Linear`` head whose
  linear layers are Xavier-initialised with ``init_xavier``. Parameters are
  created on torch's default device; nothing is moved explicitly here.

  Args:
    num_classes: number of output logits (defaults to 101, the Food-101 classes).

  Returns:
    ``(net, new_output)``: the full ``nn.Sequential`` model and its new head,
    so callers can optimise the head's parameters on their own.
  """
  backbone_weights = ResNet34_Weights.DEFAULT
  model = resnet34(weights=backbone_weights)
  # Read the feature width off the layer being replaced instead of hard-coding 512.
  in_features = model.fc.in_features
  model_without_last_layer = nn.Sequential(*list(model.children())[:-1])
  new_output = nn.Sequential(
    nn.Linear(in_features, 256),
    nn.ReLU(),
    nn.Linear(256, num_classes)
  )
  net = nn.Sequential(
    model_without_last_layer,
    nn.Flatten(),
    new_output
  )
  # apply() visits every submodule of the head; init_xavier only touches the Linear layers.
  new_output.apply(init_xavier)

  return net, new_output



In [None]:
# Sanity check: show the first layer of the freshly built classifier head.
net, _ = get_net()
first_head_layer = net[2][0]
print(first_head_layer)

Linear(in_features=512, out_features=256, bias=True)


In [None]:
import pdb

def _evaluate_accuracy(net, data_iter):
  """Return the top-1 accuracy of ``net`` over all ``(X, y)`` batches in ``data_iter``.

  Switches ``net`` to eval mode, so mode-dependent layers (e.g. batch norm) use
  inference behaviour and don't update their state from this data. Runs
  without autograd. Batches are moved to the global ``device``. Raises
  ZeroDivisionError when ``data_iter`` yields no samples.
  """
  net.eval()
  correct, total = 0.0, 0
  with torch.no_grad():
    for X, y in data_iter:
      X = X.to(device)
      y = y.to(device)
      correct += (net(X).argmax(dim=1) == y).sum().item()
      total += len(y)
  return correct / total


def train(train_iter, valid_iter, num_epochs, lr, wd, lr_period, lr_decay):
  """Train the new head of a fresh ``get_net()`` model and print per-epoch metrics.

  Only the head returned by ``get_net`` is optimised (SGD, momentum 0.9,
  weight decay ``wd``); the pretrained backbone's weights are left unchanged.
  The learning rate is multiplied by ``lr_decay`` at every epoch ``e > 0``
  with ``e % lr_period == 0``. After each epoch, prints the mean per-sample
  training loss, the training accuracy, the validation accuracy (when
  ``valid_iter`` is not None) and the time spent in the training pass.

  Args:
    train_iter: iterable of ``(X, y)`` batches used for training.
    valid_iter: iterable of ``(X, y)`` batches for validation, or None to skip it.
    num_epochs: number of passes over ``train_iter``.
    lr: initial learning rate.
    wd: SGD weight decay.
    lr_period: number of epochs between learning-rate decays.
    lr_decay: factor applied to the learning rate at each decay.

  Returns:
    The trained network, so it is still available after the call.
  """
  net, new_output = get_net()
  # Only the head is optimised, so stop autograd from tracking the backbone.
  # This skips the backbone's backward pass (and its never-zeroed .grad
  # buffers) without changing the updates the optimiser applies.
  net.requires_grad_(False)
  new_output.requires_grad_(True)
  loss = nn.CrossEntropyLoss()
  trainer = torch.optim.SGD(new_output.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
  for epoch in range(num_epochs):
    train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
    if epoch > 0 and epoch % lr_period == 0:
      # torch optimisers have no set_learning_rate(); scale each param group's lr.
      for param_group in trainer.param_groups:
        param_group["lr"] *= lr_decay
    # Validation switches the net to eval mode, so return to training mode every epoch.
    net.train()
    for X, y in train_iter:
      X = X.to(device)
      y = y.to(device)
      y_hat = net(X)
      l = loss(y_hat, y)  # already the batch mean (default reduction='mean')
      trainer.zero_grad()
      l.backward()
      trainer.step()  # no explicit no_grad needed; step() updates params outside autograd
      # Weight the batch mean by batch size so train_l_sum / n is a per-sample mean.
      train_l_sum += l.item() * len(y)
      train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
      n += len(y)
    time_s = "time %.2f sec" % (time.time() - start)
    if valid_iter is not None:
      valid_acc = _evaluate_accuracy(net, valid_iter)
      epoch_s = ("epoch %d, loss %f, train acc %f, valid acc %f, " % (epoch + 1, train_l_sum / n, train_acc_sum / n, valid_acc))
    else:
      epoch_s = ("epoch %d, loss %f, train acc %f, " % (epoch + 1, train_l_sum / n, train_acc_sum / n))
    print(epoch_s + time_s)
  return net


In [None]:
# Hyperparameters. The lr decay only fires at epochs e > 0 with e % lr_period == 0,
# and epochs run 0..num_epochs-1, so with lr_period == num_epochs it never fires here.
num_epochs, lr, wd = 10, 0.01, 1e-4
lr_period, lr_decay = 10, 0.1
# Bind train's return value instead of leaving the call as the cell's last
# expression, so the result is kept for later cells rather than echoed as output.
trained_net = train(train_loader, test_loader, num_epochs, lr, wd, lr_period, lr_decay)

epoch 1, loss 0.082622, train acc 0.359868, valid acc 0.449030, time 741.67 sec
epoch 2, loss 0.068790, train acc 0.449096, valid acc 0.494297, time 733.46 sec
epoch 3, loss 0.065332, train acc 0.474271, valid acc 0.492000, time 717.72 sec
epoch 4, loss 0.063173, train acc 0.487828, valid acc 0.503723, time 705.96 sec
epoch 5, loss 0.061627, train acc 0.499393, valid acc 0.503644, time 683.40 sec
epoch 6, loss 0.060161, train acc 0.507921, valid acc 0.491168, time 686.60 sec
epoch 7, loss 0.058798, train acc 0.518284, valid acc 0.523248, time 686.37 sec


KeyboardInterrupt: ignored