In [9]:
# Imports
import torch
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
import torchvision.transforms as transforms  # Transformations we can perform on our dataset
import torchvision
import os
import pandas as pd
from skimage import io
from torch.utils.data import (
    Dataset,
    DataLoader,
)  # Gives easier dataset managment and creates mini batches

In [10]:
class CatsAndDogsDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)  

    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
        image = io.imread(img_path)
        y_label = torch.tensor(int(self.annotations.iloc[index, 1]))

        if self.transform:
            image = self.transform(image)

        return (image, y_label)

In [31]:
dataset = CatsAndDogsDataset(csv_file='dataset\cats_dogs\cats_dogs.csv', \
                             root_dir='dataset\cats_dogs\cats_dogs_resized', \
                             transform=transforms.ToTensor())

train_set, test_set = torch.utils.data.random_split(dataset, [7, 3])
# print(train_set[0])
# print(test_set[0])
train_loader = DataLoader(dataset=train_set, batch_size=2, shuffle=True)
test_loader = DataLoader(dataset=test_set, batch_size=2, shuffle=True)
# print(len(train_loader))
# print(len(test_loader))
next(iter(test_loader))
next(iter(test_loader))

[tensor([[[[0.2510, 0.2510, 0.1961,  ..., 0.5255, 0.5882, 0.5216],
           [0.2275, 0.2392, 0.2000,  ..., 0.5804, 0.6980, 0.5725],
           [0.3137, 0.2902, 0.2627,  ..., 0.5490, 0.5647, 0.5765],
           ...,
           [0.5098, 0.3725, 0.2392,  ..., 0.1529, 0.1804, 0.0980],
           [0.5882, 0.4471, 0.2510,  ..., 0.1843, 0.1490, 0.5490],
           [0.5686, 0.4510, 0.2706,  ..., 0.1216, 0.5020, 0.7686]],
 
          [[0.2118, 0.2118, 0.1569,  ..., 0.4980, 0.5529, 0.4863],
           [0.1765, 0.1882, 0.1529,  ..., 0.5529, 0.6627, 0.5255],
           [0.2588, 0.2353, 0.1961,  ..., 0.5137, 0.5176, 0.5294],
           ...,
           [0.4784, 0.3333, 0.2000,  ..., 0.1137, 0.1412, 0.0627],
           [0.5608, 0.4196, 0.2118,  ..., 0.1373, 0.1059, 0.5059],
           [0.5412, 0.4235, 0.2314,  ..., 0.0745, 0.4588, 0.7294]],
 
          [[0.2078, 0.2078, 0.1529,  ..., 0.4353, 0.4941, 0.4196],
           [0.2039, 0.2157, 0.1686,  ..., 0.4902, 0.5961, 0.4627],
           [0.3176, 0.29