<a href="https://colab.research.google.com/github/comet-ctrl/Intro-to-PyTorch-mrdbourke/blob/main/04_pytorch_custom_datasets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn

# Confirm which PyTorch version is installed.
torch.__version__

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device

## 1. Get Data

In [None]:
import requests
import zipfile
from pathlib import Path

data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi"

# Download + extract only when the dataset folder is missing. Previously the
# download/unzip ran unconditionally, even after printing "skipping".
if image_path.is_dir():
  print(f"{image_path} directory allready exists... skipping downlaod")
else:
  print(f"{image_path} does not exist, creating one...")
  # Create only the parent for now. image_path itself is created after a
  # successful download, so a failed request doesn't leave an empty folder
  # that later runs would mistake for a finished download.
  data_path.mkdir(parents = True, exist_ok = True)

  zip_path = data_path / "pizza_steak_sushi.zip"
  # Request first and fail loudly on HTTP errors before touching the file,
  # so a bad response is never written out as a (corrupt) zip.
  print("Downloading pizza, sushi, steak data...")
  request = requests.get(
      "https://github.com/mrdbourke/pytorch-deep-learning/raw/refs/heads/main/data/pizza_steak_sushi.zip",
      timeout = 60
  )
  request.raise_for_status()
  with open(zip_path, "wb") as f:
    f.write(request.content)

  image_path.mkdir(parents = True, exist_ok = True)
  with zipfile.ZipFile(zip_path, "r") as zip_ref:
    print("Unzipping pizza, steak and sushi data")
    zip_ref.extractall(image_path)

## 2. Becoming one with the data (data preparation and data exploration)

In [None]:
import os
def walk_through_dir(dir_path):
  """Print sub-directory and file counts for every folder under ``dir_path``.

  Walks the tree top-down with ``os.walk``. For each directory visited, it
  reports how many sub-directories and files (all files, not only images)
  that directory directly contains.
  """
  for current_dir, subdirs, files in os.walk(dir_path):
    n_dirs, n_files = len(subdirs), len(files)
    print(f"There are {n_dirs} directories and {n_files} images in {current_dir}")
# Print directory/file counts for the extracted dataset.
walk_through_dir(image_path)

In [None]:
# Paths to the train/test splits inside the extracted dataset folder.
train_dir = image_path / "train"
test_dir = image_path / "test"

train_dir, test_dir

### 2.1 Visualizing an image

In [None]:
import random
from PIL import Image

# Seed Python's RNG so the same "random" image is picked on every run.
random.seed(42)

# Every .jpg two levels below image_path, i.e. <split>/<class>/<image>.jpg.
image_path_list = list(image_path.glob("*/*/*.jpg"))

random_image_path = random.choice(image_path_list)
print(random_image_path)

# The class name is the name of the image's parent folder.
image_class = random_image_path.parent.stem
print(image_class)

img = Image.open(random_image_path)

print(f"Random image path: {random_image_path}")
print(f"Image class: {image_class}")
print(f"Image height: {img.height}")
print(f"Image width: {img.width}")
img

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# PIL image -> NumPy array laid out [height, width, color_channels] for plotting.
img_as_array = np.asarray(img)

plt.figure(figsize = (10, 7))
plt.imshow(img_as_array)
plt.title(f"Image class: {image_class} | Image shape: {img_as_array.shape} -> [height, width, color_channels]")
plt.axis(False);

## 3. Transforming data

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


### 3.1 Transforming data with `torchvision.transforms`

In [None]:
# Resize to 64x64, mirror horizontally half the time, then convert the PIL
# image to a [color_channels, height, width] tensor.
data_transform = transforms.Compose([
    transforms.Resize(size = (64,64)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor()
])
data_transform(img).shape

In [None]:
def plot_transformed_images(image_paths, transforms, n=3, seed = None):
  """Plot ``n`` random images next to their transformed versions.

  Args:
    image_paths: list of image paths (pathlib.Path). Each image's parent
      folder name is shown as its class in the figure title.
    transforms: callable applied to each opened PIL image. It must return a
      3-D tensor laid out [color_channels, height, width] (e.g. a Compose
      ending in ToTensor()), because the result is permuted to channels-last.
    n: number of images to sample (must not exceed len(image_paths)).
    seed: optional seed for ``random`` so the sample is reproducible.
  """
  # ``is not None`` so that seed=0 is honoured (``if seed:`` skipped it).
  if seed is not None:
    random.seed(seed)
  random_image_paths = random.sample(image_paths, k = n)
  for image_path in random_image_paths:
    with Image.open(image_path) as f:
      fig, ax = plt.subplots(nrows=1, ncols=2)
      ax[0].imshow(f)
      ax[0].set_title(f"Original\nSize: {f.size}")
      ax[0].axis("off")

      # [C, H, W] tensor -> [H, W, C], which matplotlib expects.
      transformed_image = transforms(f).permute(1,2,0)
      ax[1].imshow(transformed_image)
      ax[1].set_title(f"Transformed\nShape: {transformed_image.shape}")
      ax[1].axis("off")

      fig.suptitle(f"Class: {image_path.parent.stem}", fontsize = 16)

# Compare three random originals with their data_transform outputs.
plot_transformed_images(
    image_paths = image_path_list,
    transforms = data_transform,
    n = 3,
    seed = 42
)


## 4. Option 1: Loading image data using `ImageFolder`

In [None]:
from torchvision import datasets
# Build datasets from the class-per-folder layout; labels come from folder names.
train_data = datasets.ImageFolder(
    root = train_dir,
    transform = data_transform,
    target_transform = None
)

# NOTE(review): the test set reuses data_transform, which includes
# RandomHorizontalFlip; evaluation data normally shouldn't be augmented.
test_data = datasets.ImageFolder(
    root = test_dir,
    transform = data_transform
)

train_data, test_data

In [None]:
# Class names, listed in label-index order.
class_names = train_data.classes
class_names

In [None]:
# Mapping from class name to integer label.
class_dict = train_data.class_to_idx
class_dict

In [None]:
# (path, label) pair for the first sample.
train_data.samples[0]

In [None]:
img, label = train_data[0][0], train_data[0][1]
print(f"Image tensor:\n {img}")
print(f"Image shape: {img.shape}")
print(f"Image datatype: {img.dtype}")
print(f"Image label: {label}")
print(f"Label dattype: {type(label)}")

In [None]:
img_permute = img.permute(1,2,0)

print(f"Original shape:  {img.shape} -> [color_channels, height, width]")
print(f"Image shape: {img_permute.shape} -> [height, width, color_channels]")

plt.figure(figsize = (10,7))
plt.imshow(img_permute)
plt.axis("off")
plt.title(class_names[label], fontsize = 14)

### 4.1 Turn loaded images into `DataLoader`s

In [None]:
import os
# Number of CPU cores visible to Python (a common upper bound for num_workers).
os.cpu_count()

In [None]:
from torch.utils.data import DataLoader
BATCH_SIZE = 1
# Shuffle only the training data; the test order stays fixed.
train_dataloader = DataLoader(
    dataset = train_data,
    batch_size = BATCH_SIZE,
    num_workers = 1,
    shuffle = True
)

test_dataloader = DataLoader(
    dataset = test_data,
    batch_size = BATCH_SIZE,
    num_workers = 1,
    shuffle = False
)

train_dataloader, test_dataloader

In [None]:
# Batches per loader; with BATCH_SIZE = 1 this equals the number of images.
len(train_dataloader), len(test_dataloader)

In [None]:
# Pull one batch to check the tensor shapes the loader produces.
img, label = next(iter(train_dataloader))
print(f"Image shape: {img.shape} -> [batch_size, color_channels. height, width]")
print(f"Label shape: {label.shape}")

## 5. Option 2: Loading image data with a Custom `Dataset`

In [None]:
import os
import pathlib
import torch

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from typing import Tuple, Dict, List


In [None]:
# Reference values from ImageFolder that the custom dataset should reproduce.
train_data.classes, train_data.class_to_idx

### 5.1 Creating a helper function to get class names

In [None]:
target_directory = train_dir
print(f"Target dir: {target_directory}")

# Quick look at entry names in the train folder (files would be listed too).
class_names_found = sorted([entry.name for entry in list(os.scandir(target_directory))])
class_names_found

In [None]:
def find_classes(directory: str) -> Tuple[list[str], Dict[str, int]]:
  """Return the sorted class-folder names under ``directory`` and a name->index map.

  Only immediate sub-directories count as classes; plain files are ignored.
  Indices follow alphabetical order of the folder names.

  Raises:
    FileNotFoundError: if ``directory`` has no sub-directories.
  """
  with os.scandir(directory) as entries:
    subdirs = [entry.name for entry in entries if entry.is_dir()]
  subdirs.sort()

  if len(subdirs) == 0:
    raise FileNotFoundError(f"Couldn't find any classes in {directory}... Please check file structure.")

  class_to_idx = dict(zip(subdirs, range(len(subdirs))))
  return subdirs, class_to_idx

In [None]:
# Should match train_data.classes / train_data.class_to_idx from ImageFolder.
find_classes(target_directory)

### 5.2 Create a custom `Dataset` to replicate `ImageFolder`

In [None]:
from torch.utils.data import Dataset

class ImageFolderCustom(Dataset):
  """Minimal re-implementation of ``torchvision.datasets.ImageFolder``.

  Expects ``targ_dir`` laid out as ``targ_dir/<class_name>/<image>.jpg``.
  Each sample is ``(image, class_idx)``, where ``class_idx`` comes from
  ``find_classes`` (alphabetical folder order).

  Args:
    targ_dir: root folder whose sub-directories are the classes.
    transform: optional callable applied to each PIL image, e.g. a
      ``transforms.Compose`` ending in ``ToTensor()``.
  """
  def __init__(self, targ_dir:str, transform = None):
    # Sorted so sample order is deterministic (glob order depends on the
    # filesystem), matching how ImageFolder orders its samples.
    self.paths = sorted(pathlib.Path(targ_dir).glob("*/*.jpg"))
    self.transform = transform
    self.classes, self.class_to_idx = find_classes(targ_dir)

  def load_image(self, index: int) -> Image.Image:
    """Load the image at ``index`` into memory as an RGB PIL image.

    ``Image.open`` alone keeps the file handle open lazily. Here the handle is
    closed before returning. ``convert("RGB")`` guarantees 3 channels even for
    grayscale/CMYK JPEGs, as ImageFolder's default loader does.
    """
    image_path = self.paths[index]
    with Image.open(image_path) as img:
      return img.convert("RGB")

  def __len__(self) -> int:
    """Return the number of images found."""
    return len(self.paths)

  def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
    """Return ``(image, class_idx)`` for sample ``index``.

    Returns the raw PIL image when no transform is set; otherwise returns
    whatever the transform produces (a tensor for ``ToTensor()``).
    """
    img = self.load_image(index)
    # The label is the name of the image's parent folder.
    class_name = self.paths[index].parent.name
    class_idx = self.class_to_idx[class_name]

    if self.transform is not None:
      return self.transform(img), class_idx
    return img, class_idx

In [None]:
# Training transform: resize plus a random horizontal flip as light augmentation.
train_transforms = transforms.Compose([
    transforms.Resize(size = (64, 64)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor()
])

# Test transform: deterministic resize only, no augmentation.
test_transforms = transforms.Compose([
    transforms.Resize(size = (64, 64)),
    transforms.ToTensor()
])

In [None]:
train_data_custom = ImageFolderCustom(
    targ_dir = train_dir,
    transform = train_transforms
)

test_data_custom = ImageFolderCustom(
    targ_dir = test_dir,
    transform = test_transforms
)

In [None]:
train_data_custom, test_data_custom

In [None]:
# Sanity checks: the custom dataset should match ImageFolder's sizes and classes.
len(train_data), len(train_data_custom)

In [None]:
len(test_data), len(test_data_custom)

In [None]:
train_data_custom.classes

In [None]:
train_data_custom.class_to_idx

In [None]:
print(train_data_custom.classes == train_data.classes)
print(test_data_custom.classes == test_data.classes)

### 5.3 Create a function to display random images

In [None]:
def display_random_images(
    dataset: torch.utils.data.Dataset,
    classes: List[str] = None,
    n : int = 10,
    display_shape: bool = True,
    seed: int = None
):
  """Plot ``n`` random samples from ``dataset`` in a single row.

  Args:
    dataset: indexable dataset returning ``(image_tensor, label)``. The image
      must be a 3-D [color_channels, height, width] tensor, because it is
      permuted to channels-last for plotting.
    classes: optional list mapping label index -> class name. When given,
      each subplot is titled with its class.
    n: number of images to show (capped at 10 to keep the figure readable).
    display_shape: whether to append the plotted image shape to the titles.
    seed: optional seed for ``random`` so the sample is reproducible.
  """
  if n > 10:
    n = 10
    display_shape = False
    print(f"For display purposes, n shouldn't be larger than 10, setting to 10 and removing shape display")
  # ``is not None`` so that seed=0 is honoured.
  if seed is not None:
    random.seed(seed)
  random_samples_idx = random.sample(range(len(dataset)), k = n)

  plt.figure(figsize = (16, 8))

  for i, target_sample in enumerate(random_samples_idx):
    # Index once: each dataset[...] call reloads and re-transforms the image.
    targ_image, targ_label = dataset[target_sample]

    # [C, H, W] -> [H, W, C] for matplotlib.
    targ_image_adjust = targ_image.permute(1,2,0)

    plt.subplot(1, n, i+1)
    plt.imshow(targ_image_adjust)
    plt.axis("off")

    # Title only when class names are available. Previously plt.title(title)
    # ran unconditionally and raised UnboundLocalError when classes was None.
    if classes:
      title = f"Class: {classes[targ_label]}"
      if display_shape:
        title = title + f"\nShape: {targ_image_adjust.shape}"
      plt.title(title)

In [None]:
# Show 5 random samples from the ImageFolder dataset.
display_random_images(
    train_data,
    n = 5,
    classes = class_names,
    seed = 42
)

In [None]:
# n = 20 exceeds the limit of 10, so the function clamps it and hides shapes.
display_random_images(
    train_data_custom,
    n = 20,
    classes = train_data_custom.classes,
    seed  = 42
)

### 5.4 Turn custom loaded images into `DataLoader`s

In [None]:
from torch.utils.data import DataLoader
BATCH_SIZE = 32
# os.cpu_count() returns None when the core count can't be determined; fall
# back to 0 workers (load in the main process) rather than passing None.
NUM_WORKERS = os.cpu_count() or 0
train_dataloader_custom = DataLoader(
    dataset = train_data_custom,
    batch_size = BATCH_SIZE,
    num_workers = NUM_WORKERS,
    shuffle = True
)

test_dataloader_custom = DataLoader(
    dataset = test_data_custom,
    batch_size = BATCH_SIZE,
    num_workers = NUM_WORKERS,
    shuffle = False
)
train_dataloader_custom, test_dataloader_custom


In [None]:
# One batch from the custom loader; the leading dim is up to BATCH_SIZE.
img_custom, label_custom = next(iter(train_dataloader_custom))

img_custom.shape, label_custom.shape

## 6. Other forms of transforms (data augmentation)

In [None]:
from torchvision import transforms

# TrivialAugmentWide applies one randomly chosen augmentation per image at a
# random strength; num_magnitude_bins sets how finely that strength is binned.
train_transform = transforms.Compose([
    transforms.Resize(size = (64, 64)),
    transforms.TrivialAugmentWide(num_magnitude_bins = 31),
    transforms.ToTensor()
])

# Test images must use the same 64x64 resolution as training. TinyVGG's
# classifier is sized for 64x64 inputs, so 224x224 test images would not fit.
test_transforms = transforms.Compose([
    transforms.Resize(size = (64, 64)),
    transforms.ToTensor()
])


In [None]:
# Re-collect all image paths (train and test splits) for plotting.
image_path_list = list(image_path.glob("*/*/*.jpg"))
image_path_list[:10]

In [None]:
# Preview the TrivialAugmentWide pipeline on three random images.
plot_transformed_images(
    image_paths = image_path_list,
    transforms = train_transform,
    n = 3,
    seed = 42
)

## 7. Model 0: TinyVGG without data augmentation

### 7.1 Creating transforms and loading data for Model 0

In [None]:
# Plain resize + ToTensor: no augmentation for the baseline model.
simple_transform = transforms.Compose([
    transforms.Resize(size = (64, 64)),
    transforms.ToTensor()
])

from torchvision import datasets
train_data_simple = datasets.ImageFolder(
    root = train_dir,
    transform = simple_transform
)

test_data_simple = datasets.ImageFolder(
    root = test_dir,
    transform = simple_transform
)

import os
from torch.utils.data import DataLoader

BATCH_SIZE = 32
# os.cpu_count() returns None when the core count can't be determined; fall
# back to 0 workers (load in the main process) rather than passing None.
NUM_WORKERS = os.cpu_count() or 0

train_dataloader_simple = DataLoader(
    dataset = train_data_simple,
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers = NUM_WORKERS
)

test_dataloader_simple = DataLoader(
    dataset = test_data_simple,
    batch_size = BATCH_SIZE,
    shuffle = False,
    num_workers = NUM_WORKERS
)


### 7.2 Create TinyVGG model class

In [None]:
class TinyVGG(nn.Module):
  """TinyVGG-style CNN: two convolutional blocks followed by a linear classifier.

  Each block is conv(3x3) -> ReLU -> conv(3x3) -> ReLU -> maxpool(2x2, stride 2).
  Convolutions use stride 1 and no padding, so each conv trims 2 pixels from
  every spatial dimension and each pool halves it (rounding down).

  Args:
    input_shape: number of input channels (3 for RGB images).
    hidden_units: number of feature maps in every conv layer.
    output_shape: number of output classes (length of the logit vector).
    image_size: height/width of the square input images. The default of 64
      reproduces the previously hard-coded classifier size
      (64 -> 60 -> 30 -> 26 -> 13, i.e. 13 * 13 = 169 features per map).

  Raises:
    ValueError: if ``image_size`` is too small to survive both conv blocks.
  """
  def __init__(
      self,
      input_shape: int,
      hidden_units: int,
      output_shape: int,
      image_size: int = 64
  ) -> None:
    super().__init__()
    # Spatial size after both blocks: each block removes 4 pixels (two
    # unpadded 3x3 convs), then floor-halves (2x2 max-pool, stride 2).
    feature_size = ((image_size - 4) // 2 - 4) // 2
    if feature_size < 1:
      raise ValueError(
          f"image_size={image_size} is too small for TinyVGG (need at least 16)"
      )
    self.conv_block_1 = nn.Sequential(
        nn.Conv2d(
            in_channels = input_shape,
            out_channels = hidden_units,
            kernel_size = 3,
            stride = 1,
            padding = 0
        ),
        nn.ReLU(),
        nn.Conv2d(
            in_channels = hidden_units,
            out_channels = hidden_units,
            kernel_size = 3,
            stride = 1,
            padding = 0
        ),
        nn.ReLU(),
        nn.MaxPool2d(
            kernel_size = 2,
            stride = 2
        )
    )
    self.conv_block_2 = nn.Sequential(
        nn.Conv2d(
            in_channels = hidden_units,
            out_channels = hidden_units,
            kernel_size = 3,
            stride = 1,
            padding = 0
        ),
        nn.ReLU(),
        nn.Conv2d(
            in_channels = hidden_units,
            out_channels = hidden_units,
            kernel_size = 3,
            stride = 1,
            padding = 0
        ),
        nn.ReLU(),
        nn.MaxPool2d(
            kernel_size = 2,
            stride = 2
        )
    )
    self.classifier = nn.Sequential(
        nn.Flatten(),
        nn.Linear(
            in_features = hidden_units * feature_size * feature_size,
            out_features = output_shape
        )
    )

  def forward(self, x):
    """Return class logits for a batch laid out [batch, input_shape, image_size, image_size]."""
    x = self.conv_block_1(x)
    x = self.conv_block_2(x)
    return self.classifier(x)


In [None]:
torch.manual_seed(42)
# 3 input channels (RGB), 10 feature maps per conv layer, one logit per class.
model_0 = TinyVGG(
    input_shape = 3,
    hidden_units = 10,
    output_shape = len(class_names)
).to(device)
model_0

### 7.3 Try a forward pass on a single batch (to test the model)

In [None]:
# One batch from the simple (non-augmented) training loader.
image_batch, label_batch = next(iter(train_dataloader_simple))
image_batch.shape, label_batch.shape

In [None]:
# Untrained forward pass to confirm the model accepts this input shape.
model_0.forward(image_batch.to(device))

### 7.4 Use `torchinfo` to get an idea of the shapes going through our model

In [None]:
try:
  import torchinfo
except:
  !pip install torchinfo
  import torchinfo

from torchinfo import summary
summary(model_0, input_size = [1,3,64,64])

### 7.5 Create train and test loop functions

In [None]:
def train_step(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device = device
):
  """Run one training epoch over ``dataloader``.

  Returns:
    (train_loss, train_acc): loss and accuracy averaged over batches, with
    every batch weighted equally regardless of its size.
  """
  model.train()

  running_loss = 0
  running_acc = 0

  for inputs, targets in dataloader:
    inputs = inputs.to(device)
    targets = targets.to(device)

    logits = model(inputs)
    batch_loss = loss_fn(logits, targets)
    running_loss += batch_loss.item()

    # Clear stale gradients, backpropagate, then update the weights.
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

    predicted = torch.argmax(torch.softmax(logits, dim = 1), dim = 1)
    running_acc += (predicted == targets).sum().item() / len(logits)

  num_batches = len(dataloader)
  return running_loss / num_batches, running_acc / num_batches

In [None]:
def test_step(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    loss_fn: torch.nn.Module,
    device = device
):
  """Evaluate ``model`` on ``dataloader`` without tracking gradients.

  Returns:
    (test_loss, test_acc): loss and accuracy averaged over batches, with
    every batch weighted equally.
  """
  model.eval()

  total_loss = 0
  total_acc = 0

  # inference_mode turns off autograd tracking for the whole evaluation loop.
  with torch.inference_mode():
    for inputs, targets in dataloader:
      inputs, targets = inputs.to(device), targets.to(device)

      logits = model(inputs)
      total_loss += loss_fn(logits, targets).item()

      predicted = logits.argmax(dim = 1)
      num_correct = (predicted == targets).sum().item()
      total_acc += num_correct / len(predicted)

  num_batches = len(dataloader)
  return total_loss / num_batches, total_acc / num_batches

### 7.6 Creating a `train()` function to combine `train_step()` and `test_step()`

In [None]:
from tqdm.auto import tqdm

def train(
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    test_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: torch.nn.Module = nn.CrossEntropyLoss(),
    epochs: int = 5,
    device = device
):
  """Train ``model`` for ``epochs`` epochs, evaluating after each epoch.

  Returns:
    dict with "train_loss", "train_acc", "test_loss" and "test_acc" lists,
    each holding one value per epoch.
  """
  metric_names = ("train_loss", "train_acc", "test_loss", "test_acc")
  results = {name: [] for name in metric_names}

  for epoch in tqdm(range(epochs)):
    train_loss, train_acc = train_step(
        model = model,
        dataloader = train_dataloader,
        loss_fn = loss_fn,
        optimizer = optimizer,
        device = device
    )
    test_loss, test_acc = test_step(
        model = model,
        dataloader = test_dataloader,
        loss_fn = loss_fn,
        device = device
    )

    print(f"Epoch: {epoch + 1} | Train loss: {train_loss:.4f} | Train acc: {train_acc:.4f} | Test loss: {test_loss:.4f} | Test_acc: {test_acc:.4f}")

    epoch_values = (train_loss, train_acc, test_loss, test_acc)
    for name, value in zip(metric_names, epoch_values):
      results[name].append(value)

  return results

### 7.7 Train and evaluate model 0



In [None]:
# Seed the CPU and CUDA RNGs so weight initialisation is reproducible.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

NUM_EPOCHS = 5

# Fresh model so training starts from the seeded initial weights.
model_0 = TinyVGG(
    input_shape = 3,
    hidden_units = 10,
    output_shape = len(train_data.classes)
).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    params = model_0.parameters(),
    lr = 0.001
)

from timeit import default_timer as timer
start_time = timer()

# Baseline: train on the non-augmented (simple_transform) data.
model_0_results = train(
    model = model_0,
    train_dataloader = train_dataloader_simple,
    test_dataloader = test_dataloader_simple,
    optimizer = optimizer,
    loss_fn = loss_fn,
    epochs = NUM_EPOCHS,
    device = device
)

end_time = timer()

print(f"Total training time: {end_time - start_time:.3f} seconds")

In [None]:
# Per-epoch metrics returned by train().
model_0_results

### 7.8 Plot the loss curves of Model 0

In [None]:
model_0_results.keys()


In [None]:
def plot_loss_curves(results: Dict[str, List[float]]):
  """Plot train/test loss (left) and train/test accuracy (right) side by side.

  Args:
    results: dict with "train_loss", "test_loss", "train_acc" and "test_acc"
      lists holding one value per epoch, as returned by ``train``.
  """
  train_loss, test_loss = results["train_loss"], results["test_loss"]
  train_acc, test_acc = results["train_acc"], results["test_acc"]

  # x-axis: 0-based epoch index.
  epoch_axis = range(len(train_loss))

  plt.figure(figsize = (15, 7))

  loss_lines = ((train_loss, "train_loss"), (test_loss, "test_loss"))
  plt.subplot(1,2,1)
  for values, line_label in loss_lines:
    plt.plot(epoch_axis, values, label = line_label)
  plt.title("Loss")
  plt.xlabel("Epochs")
  plt.legend()

  acc_lines = ((train_acc, "train_accuracy"), (test_acc, "test_accuracy"))
  plt.subplot(1,2,2)
  for values, line_label in acc_lines:
    plt.plot(epoch_axis, values, label = line_label)
  plt.title("Accuracy")
  plt.xlabel("Epoch")
  plt.legend()

In [None]:
# Loss/accuracy curves for the non-augmented baseline.
plot_loss_curves(model_0_results)

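## 8. What should an ideal loss curve look like?

Ideally, train and test loss both decrease and end up close to each other. The two failure modes to watch for:
- **Underfitting**: train and test loss both stay high, because the model hasn't learned the patterns in the data yet.
- **Overfitting**: train loss keeps falling while test loss stalls or rises, because the model is memorising the training set.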

## 9. Model 1: TinyVGG with Data Augmentation

### 9.1 Create transform with data augmentation

In [None]:
from torchvision import transforms
# Augmented training transform: resize, TrivialAugmentWide, then tensor.
train_transform_trivial = transforms.Compose([
    transforms.Resize(size = (64, 64)),
    transforms.TrivialAugmentWide(num_magnitude_bins = 31),
    transforms.ToTensor()
])

# Test transform: resize + tensor only (evaluation data is not augmented).
test_transform_simple = transforms.Compose([
    transforms.Resize(size = (64, 64)),
    transforms.ToTensor()
])



### 9.2 Create train and test `Dataset`s and `DataLoader`s with data augmentation

In [None]:
from torchvision import datasets
# Training set uses the TrivialAugmentWide pipeline.
train_data_augmented = datasets.ImageFolder(
    root = train_dir,
    transform = train_transform_trivial
)

# Test set is rebuilt with the non-augmented transform.
test_data_simple = datasets.ImageFolder(
    root = test_dir,
    transform = test_transform_simple
)

In [None]:
import os
BATCH_SIZE = 32
# os.cpu_count() returns None when the core count can't be determined; fall
# back to 0 workers (load in the main process) rather than passing None.
NUM_WORKERS = os.cpu_count() or 0

from torch.utils.data import DataLoader
torch.manual_seed(42)
train_dataloader_augmented = DataLoader(
    dataset = train_data_augmented,
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers = NUM_WORKERS
)
test_dataloader_simple = DataLoader(
    dataset = test_data_simple,
    batch_size = BATCH_SIZE,
    shuffle = False,
    num_workers = NUM_WORKERS
)



### 9.3 Construct and train model 1

In [None]:
# Same architecture and seed as model_0, so the only difference is augmentation.
torch.manual_seed(42)
model_1 = TinyVGG(
    input_shape = 3,
    hidden_units = 10,
    output_shape = len(train_data_augmented.classes)
).to(device)
model_1

In [None]:
torch.manual_seed(42)
torch.cuda.manual_seed(42)

NUM_EPOCHS = 5
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    params = model_1.parameters(),
    lr = 0.001
)

from timeit import default_timer as timer
start_time = timer()

# Train on augmented data; evaluate on the plain (non-augmented) test set.
model_1_results = train(
    model = model_1,
    train_dataloader = train_dataloader_augmented,
    test_dataloader = test_dataloader_simple,
    optimizer = optimizer,
    loss_fn = loss_fn,
    epochs = NUM_EPOCHS,
    device = device
)

end_time = timer()

print(f"Total training time for model_1: {end_time - start_time:.3f} seconds")

### 9.4 Plot the loss curves of model 1

In [None]:
# Loss/accuracy curves for the augmented model.
plot_loss_curves(model_1_results)

## 10. Compare model results

In [None]:
import pandas as pd
# One row per epoch, one column per metric, for each model.
model_0_df = pd.DataFrame(model_0_results)
model_1_df = pd.DataFrame(model_1_results)
model_0_df, model_1_df

In [None]:
plt.figure(figsize = (15, 10))

epochs = range(len(model_0_df))

# One subplot per metric in a 2x2 grid (row-major), overlaying both models.
comparison_panels = [
    ("train_loss", "Train loss"),
    ("test_loss", "Test loss"),
    ("train_acc", "Train acc"),
    ("test_acc", "Test acc"),
]
for panel_idx, (metric, panel_title) in enumerate(comparison_panels, start = 1):
  plt.subplot(2, 2, panel_idx)
  plt.plot(epochs, model_0_df[metric], label = "Model 0")
  plt.plot(epochs, model_1_df[metric], label = "Model 1")
  plt.title(panel_title)
  plt.legend()

## 11. Making a prediction on a custom image

In [None]:
import requests

custom_image_path = data_path / "04-pizza-dad.jpeg"

# Download the custom image once. Request (and check the status) before
# opening the file: previously a failed request left an empty file behind,
# and later runs then treated it as already downloaded.
if not custom_image_path.is_file():
  request = requests.get(
      "https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/04-pizza-dad.jpeg",
      timeout = 60
  )
  request.raise_for_status()
  print(f"Downloading {custom_image_path}...")
  with open(custom_image_path, "wb") as f:
    f.write(request.content)
else:
  print(f"{custom_image_path} already exists, skipping downlaod")

### 11.1 Loading in a custom image with PyTorch

In [None]:
import torchvision

custom_image_uint8 = torchvision.io.read_image(custom_image_path)
print(f"Custom image tensor:\n {custom_image_uint8}")
print(f"Custom image shape: {custom_image_uint8.shape}")
print(f"Custom image datatype: {custom_image_uint8.dtype}")

In [None]:
plt.imshow(custom_image_uint8.permute(1,2,0));

### 11.2 Making a prediction on a custom image with a trained PyTorch model

In [None]:
# Demonstration cell, presumably expected to raise: the image is still uint8
# while the model's Conv2d weights are float32.
model_1.eval()
with torch.inference_mode():
  model_1(custom_image_uint8.to(device))

In [None]:
# Convert to float32 and scale [0, 255] -> [0, 1] to match ToTensor() inputs.
# str() because read_image is typed to take a str path.
custom_image = torchvision.io.read_image(str(custom_image_path)).type(torch.float32) / 255
custom_image

In [None]:
# Demonstration cell, presumably expected to raise a shape error: the image is
# float now but not resized to 64x64 (and has no batch dim), so the flattened
# features don't match the classifier's in_features.
model_1.eval()
with torch.inference_mode():
  model_1(custom_image.to(device))

In [None]:
from torchvision import transforms
# Resize to the 64x64 resolution the model was trained on.
custom_image_transform = transforms.Compose([
    transforms.Resize(size = (64, 64))
])
custom_image_transformed = custom_image_transform(custom_image)
print(f"Original shape: {custom_image.shape}")
print(f"Transformed shape: {custom_image_transformed.shape}")

In [None]:
plt.imshow(custom_image_transformed.permute(1,2,0))

In [None]:
# unsqueeze(0) adds the batch dimension the model expects.
model_1.eval()
with torch.inference_mode():
  custom_image_pred = model_1(custom_image_transformed.unsqueeze(0).to(device))
custom_image_pred

In [None]:
# Logits -> probabilities over the classes.
custom_image_pred_probs = torch.softmax(custom_image_pred, dim = 1)
custom_image_pred_probs

In [None]:
# Most likely class index, moved to CPU for indexing class_names.
custom_image_pred_labels = torch.argmax(custom_image_pred_probs, dim = 1).cpu()
custom_image_pred_labels

In [None]:
class_names[custom_image_pred_labels]

### 11.3 Putting custom image prediction together: building a function

In [None]:
def pred_and_plot_image(
    model: torch.nn.Module,
    image_path: str,
    class_names: List[str] = None,
    transform = None,
    device = device
):
  """
  Makes a prediction on a custom image and plots the image with the prediction.

  Args:
    model: classifier returning logits laid out [batch, num_classes].
    image_path: path to the image file (str or pathlib.Path).
    class_names: optional list mapping class index -> name for the title.
    transform: optional callable applied to the [C, H, W] float tensor in
      [0, 1] before prediction (e.g. a Resize to the model's input size).
    device: device to run the model on.
  """
  # Force 3 channels. PNGs with alpha or grayscale images would otherwise give
  # 4- or 1-channel tensors that an RGB-trained model can't consume.
  target_image = torchvision.io.read_image(
      str(image_path), mode = torchvision.io.ImageReadMode.RGB
  ).type(torch.float32)
  # Scale uint8 pixels [0, 255] -> [0, 1], matching ToTensor() training data.
  target_image = target_image / 255

  if transform is not None:
    target_image = transform(target_image)
  model.to(device)

  model.eval()
  with torch.inference_mode():
    target_image = target_image.unsqueeze(0) # add the batch dimension
    target_image_pred = model(target_image.to(device))

  target_image_pred_probs = torch.softmax(target_image_pred, dim = 1)
  # Plain Python values (batch size is 1), so indexing and formatting don't
  # depend on tensor reprs or the device.
  pred_label = target_image_pred_probs.argmax(dim = 1).item()
  pred_prob = target_image_pred_probs.max().item()

  # target_image itself stayed on the CPU (only a copy went to device). Drop
  # just the batch dim and reorder to [H, W, C] for matplotlib.
  plt.imshow(target_image.squeeze(0).permute(1,2,0))
  if class_names:
    title = f"Pred: {class_names[pred_label]} | Prob: {pred_prob:.3f}"
  else:
    title = f"Pred: {pred_label} | Prob: {pred_prob:.3f}"
  plt.title(title)
  plt.axis(False)

In [None]:
# End-to-end prediction on the downloaded custom image with the augmented model.
pred_and_plot_image(
    model = model_1,
    image_path = custom_image_path,
    class_names = class_names,
    transform = custom_image_transform,
    device = device
)