In [None]:
# Import PyTorch
import torch
from torch import nn

# Import torchvision
import torchvision
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

# Import matplotlib for visualization
import matplotlib.pyplot as plt

# Report which library versions this notebook is running against
print(f"PyTorch version: {torch.__version__}\ntorchvision version: {torchvision.__version__}")

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Device: {device}")

In [None]:
# Training pipeline: resize to 64x64, apply TrivialAugmentWide augmentation, convert to a tensor
_train_steps = [
    transforms.Resize(size=(64, 64)),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
]
tr = transforms.Compose(_train_steps)

# Evaluation pipeline: same resize and tensor conversion, but no augmentation
_eval_steps = [
    transforms.Resize(size=(64, 64)),
    transforms.ToTensor(),
]
ts = transforms.Compose(_eval_steps)

In [None]:
import random
import json
from pathlib import Path
from typing import Any, Callable, Optional, Tuple

import PIL.Image

from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg
from torchvision.datasets import VisionDataset


class Food101(VisionDataset):
    """`The Food-101 Data Set <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_,
    restricted to a reproducible random subset of its classes.

    The Food-101 is a challenging data set of 101 food categories with 101,000 images.
    For each class, 250 manually reviewed test images are provided as well as 750 training images.
    On purpose, the training images were not cleaned, and thus still contain some amount of noise.
    This comes mostly in the form of intense colors and sometimes wrong labels. All images were
    rescaled to have a maximum side length of 512 pixels.

    Only ``num_classes`` classes are kept. They are drawn from the alphabetically sorted class
    names with a private random generator seeded by ``seed``, so two instances built with the
    same ``num_classes`` and ``seed`` (e.g. the train and test splits) select the same classes
    and share the same ``class_to_idx`` mapping. The chosen class names are printed on
    construction.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a
            transformed version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Default is False.
        num_classes (int, optional): How many classes to keep. Default is 3.
        seed (int, optional): Seed of the private generator that selects the classes. Default is 42.

    Raises:
        RuntimeError: If the ``meta`` and ``images`` folders are not found under ``root/food-101``.
        ValueError: If ``num_classes`` is negative or larger than the number of available classes
            (raised by ``random.Random.sample``).
    """

    _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
    _MD5 = "85eeb15f3717b99a5da872d97d918f87"

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
        num_classes: int = 3,
        seed: int = 42,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = Path(self.root) / "food-101"
        self._meta_folder = self._base_folder / "meta"
        self._images_folder = self._base_folder / "images"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        self._labels = []
        self._image_files = []
        with open(self._meta_folder / f"{self._split}.json") as f:
            metadata = json.load(f)

        # Sample from the sorted names with a private generator: this yields the same
        # selection as seeding the module-level generator, but leaves the global
        # `random` state untouched for the rest of the program.
        all_classes = sorted(metadata.keys())
        self.classes = random.Random(seed).sample(all_classes, num_classes)
        print(self.classes)
        self.class_to_idx = {name: index for index, name in enumerate(self.classes)}

        # Keep the metadata file's own ordering of classes and images; only filter.
        for class_label, im_rel_paths in metadata.items():
            if class_label not in self.class_to_idx:
                continue
            self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
            self._image_files += [
                self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths
            ]

    def __len__(self) -> int:
        """Return the number of images belonging to the selected classes."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Load the image at ``idx`` as RGB and return ``(image, label)`` after the optional transforms."""
        image_file, label = self._image_files[idx], self._labels[idx]
        # convert() returns a new, fully loaded image, so the file handle can be closed right away.
        with PIL.Image.open(image_file) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        """Return the split name for inclusion in the dataset's repr."""
        return f"split={self._split}"

    def _check_exists(self) -> bool:
        """Return True when both the ``meta`` and ``images`` folders exist as directories."""
        return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder))

    def _download(self) -> None:
        """Download and extract the archive into ``root`` unless the dataset folders already exist."""
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)

In [None]:
# Both splits live under the same root folder and are downloaded on demand;
# they differ only in the split name and the transform pipeline.
_dataset_kwargs = {"root": "data", "download": True}
train_data = Food101(split="train", transform=tr, target_transform=None, **_dataset_kwargs)
test_data = Food101(split="test", transform=ts, **_dataset_kwargs)

In [None]:
# Summarise both splits, then list the classes each split exposes
_split_summary = f"Train data:\n{train_data}\nTest data:\n{test_data}"
print(_split_summary)
_train_class_line = f"Train data classes: {train_data.classes}"
print(_train_class_line)
_test_class_line = f"Test data classes: {test_data.classes}"
print(_test_class_line)

In [None]:
# Take at most the first 10000 training / 2000 test items.
# The caps are clamped to the dataset lengths because Subset stores the indices as given:
# indices past the end of the dataset would inflate len() and only fail once such an item is fetched.
indices = torch.arange(min(10000, len(train_data)))
indices2 = torch.arange(min(2000, len(test_data)))
# Pass plain Python ints so the wrapped dataset is indexed with ordinary integers.
train_data2 = torch.utils.data.Subset(train_data, indices.tolist())
test_data2 = torch.utils.data.Subset(test_data, indices2.tolist())

In [None]:
# Inspect the first training sample (the dataset applies its transform on access)
image, label = train_data[0]
# The label is a plain class index, so report its value rather than a "shape"
print(f"Image shape: {image.shape}, Label: {label}")
class_names = train_data.classes
print(f"Class names: {class_names}")
# The permute moves the first image dimension last, the layout imshow expects for colour images
plt.imshow(image.permute(1,2,0))
plt.title(class_names[label]);

In [None]:
# Show a 4x4 grid of randomly picked training images with their class names
torch.manual_seed(42)
rows, cols = 4, 4
fig = plt.figure(figsize=(9,9))
for i in range(1, rows * cols + 1):
  # Draw the index first, then fetch the sample, so the seeded random stream is consumed in a fixed order
  random_idx = torch.randint(0, len(train_data), size=[1]).item()
  img, label = train_data[random_idx]
  ax = fig.add_subplot(rows, cols, i)
  ax.imshow(img.permute(1,2,0))
  ax.set_title(class_names[label])
  ax.axis(False);

In [None]:
import os
from torch.utils.data import DataLoader

# Setup the batch size
BATCH_SIZE = 32

# os.cpu_count() returns None when the CPU count cannot be determined; DataLoader needs an
# integer, so fall back to loading in the main process (num_workers=0) in that case.
NUM_WORKERS = os.cpu_count() or 0

# Pinned host memory only speeds up host-to-GPU copies, so request it only when running on CUDA.
PIN_MEMORY = device == "cuda"

# Turn datasets into iterables (batches); only the training data is shuffled
train_dataloader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)
test_dataloader = DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

# Check out dataloader lengths
print(f"Dataloader: {train_dataloader, test_dataloader}")
print(f"Length of train dataloader: {len(train_dataloader)} batches of {BATCH_SIZE}")
print(f"Length of test dataloader: {len(test_dataloader)} batches of {BATCH_SIZE}")

# Pull one batch from the training dataloader and report its shapes
first_batch = next(iter(train_dataloader))
train_features_batch, train_labels_batch = first_batch
print(f"Features batch shape: {train_features_batch.shape}, Labels batch shape: {train_labels_batch.shape}")

# Pick one image of that batch (seeded, so the choice is repeatable) and display it
torch.manual_seed(42)
random_idx = torch.randint(0, len(train_features_batch), size=[1]).item()
img = train_features_batch[random_idx]
label = train_labels_batch[random_idx]
plt.imshow(img.permute(1,2,0))
plt.title(class_names[label])
plt.axis(False);
print(f"Image size: {img.shape}")
print(f"Label: {label}, label size: {label.shape}")

In [None]:
class VGG(nn.Module):
  """Small VGG-style CNN: three double-convolution blocks followed by a linear classifier.

  Each block is Conv(3x3) -> ReLU -> Conv(3x3) -> ReLU -> MaxPool(2x2). The padded 3x3
  convolutions keep the spatial size and every pooling halves it.

  Args:
    in_shape: Number of channels of the input images.
    hidden_units: Base channel width; the blocks use 1x, 4x and 8x this value.
    out_shape: Number of output logits.
  """

  def __init__(self, in_shape: int, hidden_units: int, out_shape: int):
    super().__init__()
    wide, widest = hidden_units*4, hidden_units*8

    # Channel progression: in -> h -> 4h | 4h -> 8h -> 8h | 8h -> 4h -> h
    self.block1 = self._double_conv(in_shape, hidden_units, wide)
    self.block2 = self._double_conv(wide, widest, widest)
    self.block3 = self._double_conv(widest, wide, hidden_units)

    # in_features assumes an 8x8 feature map with `hidden_units` channels, which is what a
    # 64x64 input yields after the three stride-2 poolings.
    self.classifier = nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features=hidden_units*8*8, out_features=out_shape),
    )

  @staticmethod
  def _double_conv(in_channels: int, mid_channels: int, out_channels: int) -> nn.Sequential:
    """Build one Conv-ReLU-Conv-ReLU-MaxPool block with size-preserving 3x3 convolutions."""
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )

  def forward(self, x: torch.Tensor):
    """Run the three convolutional blocks in order and return the classifier logits."""
    features = self.block1(x)
    features = self.block2(features)
    features = self.block3(features)
    return self.classifier(features)

In [None]:
# Build the network for 3-channel input with one output logit per class, then move it to `device`
model_0 = VGG(in_shape=3, hidden_units=8, out_shape=len(class_names))
model_0 = model_0.to(device)
print(model_0)

# Confirm where the parameters ended up by looking at the first one
first_parameter = next(model_0.parameters())
print(f"Model on Device: {first_parameter.device}")

In [None]:
# Smoke-test the forward pass with one random 3x64x64 input; `x` ends up holding the model output
dummy_input = torch.rand(size=(1,3,64,64))
x = model_0(dummy_input.to(device))

In [None]:
import time
def train_time(start: float, end: float, device: torch.device):
  """Print and return the time elapsed between two timestamps.

  Args:
    start: Timestamp taken before training, in seconds.
    end: Timestamp taken after training, in seconds.
    device: Device the run used; only interpolated into the printed message.

  Returns:
    The difference ``end - start``.
  """
  elapsed = end - start
  print(f"Train time on {device}: {elapsed:.3f} seconds")
  return elapsed

In [None]:
### train and test functions
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device):
  """Train `model` for one pass over `dataloader`.

  Args:
    model: Network to optimise; it is switched to train mode.
    dataloader: Yields `(img, label)` batches.
    loss_fn: Loss applied to `(logits, label)`.
    optimizer: Optimizer that updates the model parameters after every batch.
    device: Device each batch is moved to before the forward pass.

  Returns:
    Tuple `(train_loss, train_acc)`: the loss and the accuracy (fraction of correct
    predictions, between 0 and 1), each averaged over the batches.
  """
  # Put model in train mode
  model.train()

  # Setup train loss and accuracy values
  train_loss, train_acc = 0, 0

  # Total sample count for the progress message. It is read from the dataloader passed in,
  # not from a notebook-level global, so the function works with any dataloader.
  total_samples = len(dataloader.dataset)

  # Loop through DataLoader batches
  for batch, (img, label) in enumerate(dataloader):
    # Send data to target device
    img, label = img.to(device), label.to(device)

    # 1. Forward pass, make predictions
    pred_logits = model(img)

    # 2. Calculate and accumulate loss across all batches
    loss = loss_fn(pred_logits, label)
    train_loss += loss.item()

    # 3. Optimizer zero grad
    optimizer.zero_grad()

    # 4. Loss backward
    loss.backward()

    # 5. Optimizer step
    optimizer.step()

    # Calculate and accumulate accuracy across all batches
    pred = pred_logits.softmax(dim=1).argmax(dim=1)
    train_acc += pred.eq(label).sum().item()/len(label)

    # Print status updates every 33 batches
    if batch % 33 == 0:
      print(f"Looked at {batch * len(img)}/{total_samples} samples")

  # Adjust metrics to get average loss and accuracy per batch
  train_loss /= len(dataloader)
  train_acc /= len(dataloader)

  return train_loss, train_acc

def test_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              device: torch.device):
  # Put model in eval mode
  model.eval()
  # Setup test loss and test accuracy values
  test_loss, test_acc = 0, 0

  # Turn on inference context manager
  with torch.inference_mode():
    # Loop through DataLoader batches
    for batch, (img, label) in enumerate(dataloader):
      # Send data to target device
      img, label = img.to(device), label.to(device)

      # 1. Forward pass, make predictions
      pred_logits = model(img)

      # 2. Calculate and accumulate loss across all batches
      test_loss += loss_fn(pred_logits, label).item()

      # Calculate and accumulate accuracy across all batches
      pred = pred_logits.softmax(dim=1).argmax(dim=1)
      test_acc += pred.eq(label).sum().item()/len(label)

    # Adjust metrics to get average loss and accuracy per batch
    test_loss /= len(dataloader)
    test_acc /= len(dataloader)

  #print(f"Test loss: {test_loss:.4f} | Test accuracy: {test_acc:.2f}%")
  return test_loss, test_acc

In [None]:
from tqdm.auto import tqdm

def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          loss_fn: torch.nn.Module,
          optimizer: torch.optim.Optimizer,
          epochs: int,
          device: torch.device):
  """Run `epochs` rounds of training followed by evaluation and collect the metrics.

  Args:
    model: Network to train; it is moved to `device` first.
    train_dataloader: Batches passed to `train_step`.
    test_dataloader: Batches passed to `test_step`.
    loss_fn: Loss used for both training and evaluation.
    optimizer: Optimizer passed to `train_step`.
    epochs: Number of train/evaluate rounds to run.
    device: Device the model and the batches are moved to.

  Returns:
    Dict with the keys "train_loss", "train_acc", "test_loss" and "test_acc", each
    holding a list with one value per epoch.
  """
  # Create empty results dictionary
  results = {"train_loss": [],
             "train_acc": [],
             "test_loss": [],
             "test_acc": []
             }

  # Send model to target device
  model.to(device)

  # Loop through training and testing steps for a number of epochs
  for epoch in tqdm(range(epochs)):
    # Use the same 1-based epoch number in the header and in the summary line
    epoch_number = epoch + 1
    print(f"Epoch: {epoch_number}\n------")
    train_loss, train_acc = train_step(model=model,
                                      dataloader=train_dataloader,
                                      loss_fn=loss_fn,
                                      optimizer=optimizer,
                                      device=device)
    test_loss, test_acc = test_step(model=model,
                                    dataloader=test_dataloader,
                                    loss_fn=loss_fn,
                                    device=device)

    # Print out what's happening
    print(
        f"Epoch: {epoch_number} | "
        f"train_loss: {train_loss:.4f} | "
        f"train_acc: {train_acc:.3f} | "
        f"test_loss: {test_loss:.4f} | "
        f"test_acc: {test_acc:.3f}"
    )

    # Update results dictionary
    results["train_loss"].append(train_loss)
    results["train_acc"].append(train_acc)
    results["test_loss"].append(test_loss)
    results["test_acc"].append(test_acc)

  # Return the filled results at the end of the epochs
  return results

In [None]:
# Train model_0 for a fixed number of epochs with Adam and time the whole run
torch.manual_seed(42)
epochs = 10
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_0.parameters(), lr=0.001)

t_start = time.time()
res = train(
    model_0,
    train_dataloader,
    test_dataloader,
    loss_fn,
    optimizer,
    epochs,
    device,
)
t_end = time.time()
t_model_0 = train_time(start=t_start, end=t_end, device=device)

In [None]:
# Plot loss curves of a model
def plot_loss_curves(results):
  """Plot train/test loss and train/test accuracy side by side, one point per epoch.

  Args:
    results: Dict with the keys "train_loss", "test_loss", "train_acc" and "test_acc",
      each mapping to a sequence with one value per epoch.
  """
  # Read every curve up front so a missing key fails before any figure is created
  curves = {key: results[key] for key in ("train_loss", "test_loss", "train_acc", "test_acc")}
  epochs = range(len(results["train_loss"]))

  # (panel title, key of the training curve, key of the test curve)
  panels = (
      ("Loss", "train_loss", "test_loss"),
      ("Accuracy", "train_acc", "test_acc"),
  )

  plt.figure(figsize=(15,7))
  for position, (panel_title, train_key, test_key) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(epochs, curves[train_key], label=train_key)
    plt.plot(epochs, curves[test_key], label=test_key)
    plt.title(panel_title)
    plt.xlabel("Epochs")
    plt.legend()

In [None]:
# Visualise the metrics collected during training
plot_loss_curves(results=res)

In [None]:
# Make predictions
def make_predictions(model: torch.nn.Module,
                     data: list,
                     device: torch.device):
  """Return the prediction probabilities of `model` for each sample in `data`.

  Every sample is given a leading batch dimension of size 1, moved to `device` and passed
  through the model on its own. The resulting logits are turned into probabilities with
  a softmax.

  Args:
    model: Network that outputs raw logits; it is switched to eval mode.
    data: Iterable of single-sample tensors without a batch dimension.
    device: Device each sample is moved to before the forward pass.

  Returns:
    A CPU tensor stacking the per-sample probabilities along a new first dimension.

  Raises:
    RuntimeError: If `data` is empty, because `torch.stack` needs at least one tensor.
  """
  pred_probs = []
  model.eval()
  with torch.inference_mode():
    for sample in data:
      # Prepare sample: add the batch dimension and send it to the device
      sample = sample.unsqueeze(dim=0).to(device)

      # Forward pass (model outputs raw logits)
      pred_logit = model(sample)

      # Remove only the batch dimension added above. A bare squeeze() would also drop the
      # class dimension when the model has a single output, changing the result's shape.
      pred_prob = pred_logit.squeeze(dim=0).softmax(dim=0)

      # Send data to cpu
      pred_probs.append(pred_prob.cpu())

  # Stack the pred_probs to turn list into a tensor
  return torch.stack(pred_probs)

In [None]:
import random
random.seed(42)
test_samples = []
test_labels = []
# Sample nine indices and fetch only those items. Building list(test_data) first would load
# and transform every test image just to keep nine of them. random.sample picks positions
# from the population length alone, so the same items are chosen either way.
for sample_idx in random.sample(range(len(test_data)), k=9):
  sample, label = test_data[sample_idx]
  test_samples.append(sample)
  test_labels.append(label)

print(len(test_samples))

# View the first test sample shape and label
print(f"Test sample image shape: {test_samples[0].shape}\nTest sample label: {test_labels[0]} ({class_names[test_labels[0]]})")

# Make predictions on test samples with model_0
pred_probs = make_predictions(model=model_0,
                              data=test_samples,
                              device=device)

# View first two prediction probabilities list
print(pred_probs[:2])

# Turn the prediction probabilities into prediction labels by taking the argmax()
pred_classes = pred_probs.argmax(dim=1)

# Are predictions in the same form as test labels?
print(f"Test labels: {test_labels}\nPredicted labels: {pred_classes}")

In [None]:
# Show the sampled test images in a 3x3 grid, titled with predicted vs. true class
plt.figure(figsize=(9,9))
nrows, ncols = 3, 3
for i, sample in enumerate(test_samples):
  plt.subplot(nrows, ncols, i+1)

  # Move the first image dimension last, the layout imshow expects for colour images
  plt.imshow(sample.permute(1,2,0))

  # Translate the predicted and the true class index into class names
  pred_label = class_names[pred_classes[i]]
  truth_label = class_names[test_labels[i]]
  title_text = f"Pred: {pred_label} | Truth: {truth_label}"

  # Green title for a correct prediction, red for a wrong one
  title_color = 'g' if pred_label == truth_label else 'r'
  plt.title(title_text, fontsize=10, c=title_color)
  plt.axis(False);