# Imports

In [175]:
# For this notebook to run with updated APIs, we need torch 1.12+ and torchvision 0.13+


def parse_major_minor(version: str) -> tuple:
    """Return ``(major, minor)`` as ints from a version string such as ``"2.1.2+cu121"``.

    Raises:
        ValueError: if the first two dot-separated fields are not plain integers.
    """
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


try:
    import torch
    import torchvision
    from torch import nn

    # Compare (major, minor) tuples so that future major releases (torch 3.x,
    # torchvision 1.x) still pass, instead of inspecting only the minor number.
    assert parse_major_minor(torch.__version__) >= (1, 12), "torch version should be 1.12+"
    assert parse_major_minor(torchvision.__version__) >= (0, 13), "torchvision version should be 0.13+"
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")
except (ImportError, AssertionError, ValueError):
    # ValueError covers version strings that parse_major_minor cannot read.
    import importlib
    import subprocess
    import sys

    print("[INFO] torch/torchvision versions not as required, installing latest stable versions.")
    # Install into the interpreter that runs this kernel (same effect as %pip).
    subprocess.check_call(
            [
                sys.executable, "-m", "pip", "install", "-U",
                "torch", "torchvision", "torchaudio",
                "--index-url", "https://download.pytorch.org/whl/cu118",
                ]
            )
    # Let the import system notice packages that were just installed.
    importlib.invalidate_caches()
    import torch
    import torchvision

    # NOTE: if an older torch was already imported in this kernel, the import above
    # returns the cached module; restart the kernel to load the new version.
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")

# Continue with regular imports
import matplotlib.pyplot as plt
import torch
import torchvision

from torch import nn
from torchvision import transforms

# Try to get torchinfo, install it if it isn't available
try:
    from torchinfo import summary
except ImportError:
    import importlib
    import subprocess
    import sys

    print("[INFO] Couldn't find torchinfo... installing it.")
    # Install into the interpreter running this kernel (same effect as `%pip install`).
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "torchinfo"])
    # Make sure the freshly installed package is visible to the import system.
    importlib.invalidate_caches()
    from torchinfo import summary

# Try to import the going_modular directory, download it from GitHub if it doesn't work
try:
    from going_modular.going_modular import data_setup, engine
    from helper_functions import download_data, set_seeds, plot_loss_curves
except ImportError:
    import importlib
    import shutil
    import subprocess
    from pathlib import Path

    # Get the going_modular scripts
    print("[INFO] Couldn't find going_modular or helper_functions scripts... downloading them from GitHub.")
    repo_dir = Path("pytorch-deep-learning")
    # A shallow clone is enough: only two paths from the latest commit are needed.
    subprocess.check_call(
            ["git", "clone", "--depth", "1", "https://github.com/mrdbourke/pytorch-deep-learning", str(repo_dir)]
            )
    # Move the package and the helper script into the working directory. Skip any
    # that already exist so a partially completed earlier run does not break the move.
    for name in ("going_modular", "helper_functions.py"):
        if not Path(name).exists():
            shutil.move(str(repo_dir / name), name)
    shutil.rmtree(repo_dir)
    # The failed import above may have cached the working directory's listing;
    # refresh it so the newly moved modules can be found.
    importlib.invalidate_caches()
    from going_modular.going_modular import data_setup, engine
    from helper_functions import download_data, set_seeds, plot_loss_curves

from pathlib import Path

torch version: 2.1.2
torchvision version: 0.16.2


# Device chooser

In [176]:
def device_chooser(prefer_device: str = "cpu") -> str:
    """Return ``prefer_device`` if that backend is available, otherwise ``"cpu"``.

    Each accelerator backend is probed independently, so a preference for any
    available backend is honoured (not only the first one found).

    Args:
        prefer_device: Backend to use when available: ``"cuda"``, ``"mps"`` or
            ``"cpu"``. Any other value falls back to ``"cpu"``.

    Returns:
        A device string usable with ``torch.device`` / ``Tensor.to``.
    """
    available = {"cpu"}
    if torch.cuda.is_available():
        available.add("cuda")
    if torch.backends.mps.is_available():
        available.add("mps")
    return prefer_device if prefer_device in available else "cpu"


device = device_chooser(prefer_device="mps")

# Get data

In [177]:
# Fetch the pizza/steak/sushi image archive (or reuse an existing copy);
# the helper returns the local dataset directory, shown as the cell output.
image_path = download_data(
    destination=Path("pizza_steak_sushi"),
    source="https://github.com/mrdbourke/pytorch-deep-learning/raw/main/images/pizza_steak_sushi.zip",
)
image_path

[INFO] data/pizza_steak_sushi directory exists, skipping download.


PosixPath('data/pizza_steak_sushi')

In [178]:
# One sub-directory per split; each is read with ImageFolder further down.
train_path, test_path = (image_path / split for split in ("train", "test"))

## Prepare datasets and dataloaders

In [179]:
RANDOM_SEED = 42
HEIGHT, WIDTH = 224, 224
IMG_SIZE = (HEIGHT, WIDTH)

BATCH_SIZE = 32

# Resize every image to the model's input resolution and convert it to a float tensor.
manual_transforms = transforms.Compose(
        [
            transforms.Resize(IMG_SIZE),
            transforms.ToTensor(),
            ]
        )

train_dataset = torchvision.datasets.ImageFolder(root=train_path, transform=manual_transforms)
test_dataset = torchvision.datasets.ImageFolder(root=test_path, transform=manual_transforms)

# ImageFolder lists samples class by class, so the training loader must shuffle;
# otherwise each batch would be (almost) a single class. The shuffle is seeded with
# RANDOM_SEED so the batch order is reproducible. Evaluation order does not matter.
train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        generator=torch.Generator().manual_seed(RANDOM_SEED),
        )
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [180]:
PATCH_SIZE = 16

# Patchify and linearly project in one op: with kernel_size == stride == PATCH_SIZE,
# each non-overlapping PATCH_SIZE x PATCH_SIZE patch maps to one output position.
conv2d = nn.Conv2d(
        in_channels=3,     # RGB input
        out_channels=768,  # embedding dimension D (768, as in ViT-Base)
        kernel_size=PATCH_SIZE,
        stride=PATCH_SIZE,
        padding=0
        )

# Merge the two trailing spatial dims (H/P, W/P) into a single patch axis. Negative
# dims make this work for both unbatched (D, H/P, W/P) and batched (N, D, H/P, W/P)
# feature maps, matching the Flatten used inside PatchEmbedding.
flatten = nn.Flatten(start_dim=-2, end_dim=-1)

In [129]:
# Sanity check: push a single unbatched image through the patch conv, then flatten its grid.
im_batch: torch.Tensor = next(iter(train_dataloader))[0]
im: torch.Tensor = im_batch[0]
print(im.shape)

image_embeddings: torch.Tensor = conv2d(im)                           # (768, 14, 14) per the output below
flattened_image_embeddings: torch.Tensor = flatten(image_embeddings)  # (768, 196)
print(image_embeddings.shape)
print(flattened_image_embeddings.shape)

torch.Size([3, 224, 224])
torch.Size([768, 14, 14])
torch.Size([768, 196])


In [148]:
# Draw a batch and keep only its first image, preserving a leading batch dim of size 1.
im_batch, _ = next(iter(train_dataloader))
im: torch.Tensor = im_batch[:1]
im.shape

torch.Size([1, 3, 224, 224])

In [157]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    Implemented as a Conv2d whose kernel size and stride both equal ``patch_size``
    (one output position per patch), followed by flattening the patch grid.

    Args:
        in_channels: Number of input image channels (e.g. 3 for RGB).
        patch_size: Side length P of each square patch, in pixels.
        embed_dim: Size D of each patch embedding.

    Shapes:
        input:  (N, C, H, W) or unbatched (C, H, W); H and W must both be divisible by P.
        output: (N, H*W / P**2, D) or (H*W / P**2, D) respectively.
    """

    def __init__(self, in_channels: int, patch_size: int, embed_dim: int):
        super().__init__()

        self.patch_size = patch_size
        self.patcher = nn.Sequential(
                nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=embed_dim,
                        stride=patch_size,
                        kernel_size=patch_size,
                        padding=0
                        ),
                # (..., D, H/P, W/P) -> (..., D, num_patches)
                nn.Flatten(start_dim=-2, end_dim=-1)
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        height, width = x.shape[-2], x.shape[-1]
        # Check both spatial dims: the strided conv would otherwise silently drop
        # trailing rows/columns that do not fill a whole patch.
        assert height % self.patch_size == 0 and width % self.patch_size == 0, f"[ERROR] - Image resolution is not divisible by patch size. Image size {x.shape} | Patch size {self.patch_size}"

        # (..., D, num_patches) -> (..., num_patches, D). transpose(-2, -1) equals
        # permute(0, 2, 1) for batched input and also supports unbatched input.
        return self.patcher(x).transpose(-2, -1)


# set_seeds comes from helper_functions; presumably it seeds torch's RNG so the
# conv weights below are reproducible. TODO confirm against the helper.
set_seeds()
patchify = PatchEmbedding(in_channels=3, patch_size=PATCH_SIZE, embed_dim=768)
patch_embedded_image = patchify(im)

# Report input vs. patch-embedded shapes, reusing the result computed above
# instead of running a second forward pass.
print(im.shape, patch_embedded_image.shape, sep="\n")

torch.Size([1, 3, 224, 224])
torch.Size([1, 196, 768])


In [174]:
embedding_dimension = 768
# Derive sizes from the patch embeddings instead of hard-coding them.
batch_size = patch_embedded_image.shape[0]
num_patches = patch_embedded_image.shape[1]

# Learnable [class] token that is prepended to the patch sequence.
class_token = nn.Parameter(
        torch.randn(1, 1, embedding_dimension), requires_grad=True
        )

print(class_token.shape, patch_embedded_image.shape, sep="\n")

# Broadcast the single class token over the batch so this works for any batch size.
patch_embedded_image_with_class_embedding = torch.cat(
        (class_token.expand(batch_size, -1, -1), patch_embedded_image), dim=1
        )
# One learnable position embedding per token: every patch plus the class token.
patch_position_embeddings = nn.Parameter(
        torch.randn(1, num_patches + 1, embedding_dimension),
        requires_grad=True
        )

position_and_class_embeddings = (
        patch_embedded_image_with_class_embedding + patch_position_embeddings)

patch_embedded_image_with_class_embedding.shape, position_and_class_embeddings.shape

torch.Size([1, 1, 768])
torch.Size([1, 196, 768])


(torch.Size([1, 197, 768]), torch.Size([1, 197, 768]))