In [None]:
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [None]:
# Pick the best available compute device. `torch.accelerator` only exists in
# recent PyTorch releases (2.6+), so fall back to the CUDA check on older
# installs instead of failing with AttributeError.
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator().type
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

# Seed PyTorch's global RNG so DataLoader shuffling (which draws from the
# default generator) is reproducible across Restart Kernel -> Run All.
seed = 42
torch.manual_seed(seed)

training_path = './data/Training'
testing_path = './data/Testing'
batch_size = 64
image_size = (224, 224)  # (height, width) every image is resized to

In [None]:
# Resize the PIL image *before* converting it to a tensor. Resizing the uint8
# PIL image is cheaper than resizing a float32 tensor, and PIL resizing is
# always antialiased, whereas tensor Resize only antialiases by default from
# torchvision 0.17 on. Because image_size is an (h, w) tuple, every image is
# resized to exactly that size (aspect ratio is not preserved).

# Training pipeline: the place to add training-only augmentation.
train_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
])

# Evaluation pipeline for test/validation data: deterministic, no augmentation.
# Kept as a separate object so it can diverge from train_transform.
testval_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
])

In [None]:
# Build two ImageFolder views of the training directory: one with the training
# transforms and one with the evaluation transforms. random_split chooses the
# indices once; the validation subset then reads through the evaluation view,
# so training-only augmentation can never leak into validation.
full_trainset = datasets.ImageFolder(training_path, transform=train_transform)
full_valset = datasets.ImageFolder(training_path, transform=testval_transform)
testset = datasets.ImageFolder(testing_path, transform=testval_transform)

# Both views scan the same directory, so their sample lists must line up for
# the shared split indices to refer to the same files.
assert full_valset.samples == full_trainset.samples
# ImageFolder derives label indices from the sorted sub-folder names, so a class
# folder missing from one split would silently shift every label after it.
assert full_trainset.class_to_idx == testset.class_to_idx, \
    "Training and Testing folders have different class sub-folders"

split_ratio = 0.15  # fraction of the training folder held out for validation
trainset_len = len(full_trainset)
valset_len = int(split_ratio * trainset_len)

# A fixed-seed generator keeps the train/validation split identical across runs.
trainset, val_split = torch.utils.data.random_split(
    full_trainset,
    [trainset_len - valset_len, valset_len],
    generator=torch.Generator().manual_seed(42),
)
validationset = torch.utils.data.Subset(full_valset, val_split.indices)

In [None]:
# Worker processes per loader; tune for the machine's CPU count.
num_workers = 3

# Only the training loader shuffles. Evaluation loaders keep a fixed order so
# metrics, per-sample predictions and plots are reproducible between runs.
train_dl = DataLoader(trainset, batch_size, shuffle=True, num_workers=num_workers)
test_dl = DataLoader(testset, batch_size, shuffle=False, num_workers=num_workers)
validation_dl = DataLoader(validationset, batch_size, shuffle=False, num_workers=num_workers)

In [None]:
# Grab one batch from the training loader for a quick visual sanity check.
examples = iter(train_dl)
imgs, labels = next(examples)

# `trainset` is the Subset returned by random_split; its underlying ImageFolder
# holds the folder-derived class names, where class_names[i] names label i.
class_names = trainset.dataset.classes

In [None]:
def imshow(img, title=None, figsize=(20, 20)):
    """Display a channels-first image tensor (e.g. a make_grid output) without axis ticks.

    Args:
        img: image tensor laid out as (C, H, W). Presumably float values in
            [0, 1] as produced by ToTensor -- re-check if normalization is
            added to the transforms, since matplotlib clips out-of-range values.
        title: optional title drawn above the image.
        figsize: matplotlib figure size in inches.
    """
    fig, ax = plt.subplots(figsize=figsize)
    # matplotlib expects (H, W, C); detach/cpu so GPU or grad-tracking tensors work too.
    ax.imshow(img.detach().cpu().permute(1, 2, 0))
    ax.set_xticks([])
    ax.set_yticks([])
    if title is not None:
        ax.set_title(title)
    plt.show()

imshow(torchvision.utils.make_grid(imgs), title="Sample training batch")