Install Dependencies & Download Data

In [13]:
!pip install gdown unzip

# Download from Google Drive
!gdown 1fuFurVV8rcrVTAFPjhQvzGLNdnTi1jWZ

# Unzip without verbose output
!unzip -q CATS_DOGS.zip

Downloading...
From (original): https://drive.google.com/uc?id=1fuFurVV8rcrVTAFPjhQvzGLNdnTi1jWZ
From (redirected): https://drive.google.com/uc?id=1fuFurVV8rcrVTAFPjhQvzGLNdnTi1jWZ&confirm=t&uuid=308a6acd-c1ae-4b47-a49e-566503a0dff6
To: /content/CATS_DOGS.zip
100% 812M/812M [00:16<00:00, 50.6MB/s]
replace CATS_DOGS/test/CAT/10000.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

Imports

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

Collect All Image Paths

In [None]:
train_dir = 'CATS_DOGS/train'
img_paths = []
for root, dirs, files in os.walk(train_dir):
    for fname in files:
        if fname.lower().endswith(('jpg', 'jpeg', 'png', 'bmp', 'gif')):
            img_paths.append(os.path.join(root, fname))

print(f"Found {len(img_paths)} images in {train_dir}.")

Found 18743 images in CATS_DOGS/train.


Define CatDogDataset

In [None]:
class CatDogDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        """
        image_paths: list of full image file paths
        transform: torchvision transforms (default: ToTensor)
        """
        self.image_paths = image_paths
        self.transform = transform if transform is not None else transforms.ToTensor()

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load image and convert to RGB
        path = self.image_paths[idx]
        image = Image.open(path).convert('RGB')

        # Apply transform to get a tensor
        image_tensor = self.transform(image)

        # Label = 0 if folder contains 'cat', else 1
        label = 0 if 'cat' in os.path.basename(os.path.dirname(path)).lower() else 1

        return image_tensor, label

Instantiate & Test


In [20]:
dataset = CatDogDataset(img_paths)
dataset[0] # e.g., (tensor([...]), 0 or 1)

(tensor([[[0.9961, 0.9804, 0.9569,  ..., 0.4706, 0.4275, 0.3373],
          [0.9647, 0.9686, 0.9529,  ..., 0.4431, 0.4588, 0.4314],
          [0.9608, 0.9882, 0.9804,  ..., 0.3608, 0.4353, 0.4824],
          ...,
          [0.4902, 0.4941, 0.5059,  ..., 0.4275, 0.4235, 0.4235],
          [0.4863, 0.5020, 0.5373,  ..., 0.4314, 0.4275, 0.4235],
          [0.4863, 0.4980, 0.5333,  ..., 0.4314, 0.4275, 0.4275]],
 
         [[1.0000, 1.0000, 0.9882,  ..., 0.4392, 0.3961, 0.3059],
          [0.9804, 0.9961, 0.9843,  ..., 0.4196, 0.4353, 0.4078],
          [0.9804, 1.0000, 1.0000,  ..., 0.3647, 0.4392, 0.4863],
          ...,
          [0.4863, 0.4902, 0.5020,  ..., 0.4667, 0.4627, 0.4627],
          [0.4824, 0.4980, 0.5373,  ..., 0.4706, 0.4667, 0.4627],
          [0.4745, 0.4863, 0.5216,  ..., 0.4706, 0.4667, 0.4667]],
 
         [[0.9451, 0.9451, 0.9373,  ..., 0.4902, 0.4392, 0.3490],
          [0.9137, 0.9255, 0.9333,  ..., 0.4588, 0.4745, 0.4471],
          [0.9020, 0.9373, 0.9529,  ...,

In [19]:
dataset[-1]  # e.g., (tensor([...]), 0 or 1)

(tensor([[[0.2627, 0.2627, 0.2627,  ..., 0.3686, 0.3725, 0.3608],
          [0.2627, 0.2627, 0.2627,  ..., 0.3647, 0.3569, 0.3412],
          [0.2627, 0.2627, 0.2627,  ..., 0.3529, 0.3333, 0.3059],
          ...,
          [0.0392, 0.0392, 0.0431,  ..., 0.4745, 0.4667, 0.4627],
          [0.0431, 0.0431, 0.0471,  ..., 0.4706, 0.4627, 0.4588],
          [0.0510, 0.0510, 0.0510,  ..., 0.4706, 0.4588, 0.4510]],
 
         [[0.1294, 0.1294, 0.1294,  ..., 0.2510, 0.2549, 0.2431],
          [0.1294, 0.1294, 0.1294,  ..., 0.2471, 0.2392, 0.2235],
          [0.1294, 0.1294, 0.1294,  ..., 0.2431, 0.2235, 0.1961],
          ...,
          [0.0431, 0.0431, 0.0471,  ..., 0.4118, 0.4039, 0.4000],
          [0.0471, 0.0471, 0.0510,  ..., 0.4078, 0.4000, 0.3961],
          [0.0549, 0.0549, 0.0549,  ..., 0.4000, 0.3882, 0.3804]],
 
         [[0.0314, 0.0314, 0.0314,  ..., 0.1490, 0.1529, 0.1412],
          [0.0314, 0.0314, 0.0314,  ..., 0.1451, 0.1373, 0.1216],
          [0.0314, 0.0314, 0.0314,  ...,