In [None]:
!pip install matplotlib

In [1]:
from torch import nn

In [2]:
### CLEANING UP THE DATASET BY REMOVING RETINAL FUNDUS IMAGES FEATURING DISEASES
import os
import pandas as pd
import tqdm

damaged_images = []

print("Scanning through the labels file...")
labels = pd.read_excel("ODIR/labels.xlsx")
for i in tqdm.tqdm(range(len(labels["ID"]))):
    if "normal fundus" not in labels["Left-Diagnostic Keywords"][i]:
        damaged_images.append(labels["Left-Fundus"][i])
    if "normal fundus" not in labels["Right-Diagnostic Keywords"][i]:
        damaged_images.append(labels["Right-Fundus"][i])

print(f"\nLength of the corrected dataset: {len(labels["ID"]+1)*2 - len(damaged_images)}")

print(f"\nDeleting the corrupted images...")
dir = os.listdir("ODIR/images")
i = 0
for file in tqdm.tqdm(dir):
    if file in damaged_images:
        os.remove(f"ODIR/images/{file}")


Scanning through the labels file...


100%|██████████| 3500/3500 [00:00<00:00, 131559.47it/s]



Length of the corrected dataset: 3098

Deleting the corrupted images...


100%|██████████| 3099/3099 [00:00<00:00, 34859.73it/s]


In [3]:
### RENAMING FILES TO FEATURE LABEL
dir = os.listdir("ODIR/images/")
os.chdir("ODIR/images/")

for file in tqdm.tqdm(dir):
    matched_rows = labels[(labels["Left-Fundus"] == file) | (labels["Right-Fundus"] == file)]
    if not matched_rows.empty:
        for index, row in matched_rows.iterrows():
            sex = row["Patient Sex"]
            os.rename(file, f'{sex}_{file}')

  0%|          | 0/3099 [00:00<?, ?it/s]

100%|██████████| 3099/3099 [00:01<00:00, 2907.19it/s]


In [4]:
### CREATING A DATALOADER
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from PIL import Image

class ODIRDataset(Dataset):
    def __init__(self,
                 targ_dir: str,
                 transform=None):
        self.paths = list(os.listdir())
        self.transforms = transform
        self.classes = ['Male', 'Female']
        self.class_to_idx = {'Male': 0, 'Female': 1}
    
    def load_image(self, index: int) -> Image.Image:
        image_path = self.paths[index]
        return Image.open(image_path)
    
    def __len__(self) -> int:
        return len(self.paths)
    
    def __getitem__(self, index):
        img = self.load_image(index)
        class_name = self.paths[index].split('_')[0]
        class_idx = self.class_to_idx[class_name]

        if self.transforms:
            return self.transforms(img), class_idx
        else:
            return img, class_idx

In [5]:
### TRANSFORMS
import torch
from torchvision.transforms import v2

train_transforms = v2.Compose([
    v2.Resize(size=(300, 300)),
    v2.ToTensor()
    # v2.Normalize(mean=[0.2, 0.2, 0.2], std=[0.229, 0.224, 0.225]),
])

test_transforms = v2.Compose([
    v2.Resize(size=(300, 300)),
    v2.ToTensor()
])



In [6]:
### SPLITTING THE 'IMAGES' FOLDER INTO A TRAINING SET AND A TESTING SET
train_size = int(0.8 * len(dir))
test_size = len(dir) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dir, [train_size, test_size])

In [7]:
train_data = ODIRDataset(targ_dir=train_dataset,
                         transform=train_transforms)
test_data = ODIRDataset(targ_dir=test_dataset,
                        transform=test_transforms)

In [15]:
for path in train_data.paths:
    if "jpg" not in path:
        print(path)

.DS_Store


In [8]:
from typing import List
import matplotlib.pyplot as plt
import random

#1. Create a function to take in a dataset
def display_random_images(dataset: Dataset,
                          classes: List[str] = None,
                          n: int = 10,
                          display_shape: bool = True,
                          seed: int = None):
  #2. Adjust display if n is too high
  if n > 10:
    n = 10
    display_shape = False
    print("For display purposes, n shouldn't be larger than 10, setting it to 10 and removing shape display")

  #3. Set the seed
  if seed:
    random.seed(seed)

  #4. Get random sample indexes
  random_samples_idx = random.sample(range(len(dataset)), k=n)

  #5. Setup plot
  plt.figure(figsize=(16, 8))

  #6. Loop through and plot random indexes
  for i, targ_sample in enumerate(random_samples_idx):
    targ_image, targ_label = dataset[targ_sample][0], dataset[targ_sample][1]

    #7. Adjust tensor dimensions
    targ_image_adjust = targ_image.permute(1, 2, 0)

    #8. Plot adjusted samples
    plt.subplot(1, n, i+1)
    plt.imshow(targ_image_adjust)
    plt.axis(False)
    if classes:
      title = f"Class: {classes[targ_label]}"
      if display_shape:
        title = title + f"\nshape: {targ_image_adjust.shape}"
    plt.title(title)

In [None]:
display_random_images(train_data,
                      classes=train_data.classes,
                      n=5)

In [160]:
## LOADING DATA INTO BATCHES
BATCH_SIZE = 32
NUM_WORKERS = os.cpu_count()

train_dataloader = DataLoader(dataset=train_data,
                                     batch_size=BATCH_SIZE,
                                     num_workers=NUM_WORKERS,
                                     shuffle=True)
test_dataloader  = DataLoader(dataset=test_data,
                                    batch_size=BATCH_SIZE,
                                    num_workers=NUM_WORKERS,
                                    shuffle=False)

(3104, 3099)

In [162]:
# Get image and label from custom dataloader
img_custom, label_custom = next(iter(train_dataloader))
img_custom.shape, label_custom.shape

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'ODIRDataset' on <module '__main__' (<class '_frozen_importlib.BuiltinImporter'>)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/spawn.py", li

RuntimeError: DataLoader worker (pid(s) 61647) exited unexpectedly