In [27]:
import torch

from PIL import Image
import torchvision.transforms as transforms

In [28]:
# Pick the fastest available backend: CUDA first, then Apple's MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# To force the CPU, uncomment the next line. For the small network used here the CPU
# has been observed to be faster than CUDA (probably because the network is tiny).
# device = "cpu"
print(f"Using {device} device")

Using cuda device


# Loading an image

In [33]:
# Functions for image handling

def show_image_by_path(_image_path: str) -> None:
    """Open the image at ``_image_path`` and display it with PIL's default viewer.

    Args:
        _image_path: Path of the image file to display.
    """
    # The context manager releases the file handle once the image has been shown;
    # previously it stayed open until the Image object was garbage-collected.
    with Image.open(_image_path) as image:
        image.show()

def image_to_tensor(_image_path: str, size: tuple = (224, 224), augment: bool = True) -> torch.Tensor:
    """Load an image from disk and convert it into a batched float tensor.

    The image is converted to RGB first, so the result always has 3 channels, even for
    grayscale, palette or RGBA files. Without this, ``ToTensor`` would produce 1 or 4
    channels. The image is then resized. If ``augment`` is True, it is also randomly
    mirrored left-to-right and rotated by a random angle in [-15, +15] degrees.

    Args:
        _image_path: Path of the image file to load.
        size: Target ``(height, width)`` passed to ``transforms.Resize``.
        augment: Apply the random horizontal flip and rotation. Defaults to True (the
            original behaviour); pass False to get a deterministic tensor.

    Returns:
        A float tensor of shape ``[1, 3, height, width]``, i.e.
        ``[batch_size, channels, height, width]``, where the channels are R, G, B.
        ``ToTensor`` scales the 0-255 pixel values of the RGB image into [0, 1].
    """
    # Open inside a context manager so the file handle is released. convert() returns
    # a new, fully loaded image that remains usable after the `with` block closes.
    with Image.open(_image_path) as image:
        rgb_image = image.convert("RGB")

    steps = [transforms.Resize(size)]
    if augment:
        steps += [
            transforms.RandomHorizontalFlip(),  # randomly mirror the image left-to-right
            transforms.RandomRotation(15),      # rotate by a random angle in [-15, +15] degrees
        ]
    steps.append(transforms.ToTensor())
    transform = transforms.Compose(steps)

    image_tensor = transform(rgb_image)  # shape: [3, height, width]

    # Add a leading batch dimension: models usually consume batches of inputs,
    # even when the batch holds a single sample.
    return image_tensor.unsqueeze(0)  # shape: [1, 3, height, width]

def show_image_by_tensor(_image_tensor: torch.Tensor) -> None:
    """Display an image tensor with PIL's default viewer.

    Args:
        _image_tensor: Image tensor shaped ``[1, C, H, W]`` or ``[C, H, W]``, e.g. the
            output of ``image_to_tensor``. Float values are presumably in [0, 1], since
            ``ToPILImage`` scales float tensors by 255.
    """
    # squeeze(0) removes the batch dimension only when it has size 1.
    # detach()/cpu() let ToPILImage handle tensors that track gradients or live on an
    # accelerator; this notebook may select "cuda" or "mps" as its device.
    single_image = _image_tensor.squeeze(0).detach().cpu()

    # convert tensor to PIL image and display it
    image_pil = transforms.ToPILImage()(single_image)
    image_pil.show()

In [30]:
# Sample picture from the Simpsons dataset, displayed as-is (no transforms applied).
image_path = "data/simpsons_dataset/abraham_grampa_simpson/pic_0000.jpg"
show_image_by_path(_image_path=image_path)

In [31]:
# Convert the sample image into a batched tensor. image_to_tensor applies a random
# flip/rotation, so pixel values can differ between runs, but the shape must not.
image_tensor = image_to_tensor(image_path)
assert image_tensor.shape == (1, 3, 224, 224), f"unexpected tensor shape {tuple(image_tensor.shape)}"
image_tensor.shape

torch.Size([1, 3, 224, 224])

In [34]:
# Display the tensor produced above. It reflects the random flip/rotation applied by
# image_to_tensor, so it may look mirrored or tilted compared with the original file.
show_image_by_tensor(_image_tensor=image_tensor)