<a href="https://colab.research.google.com/github/e-mny/drive_retinal_seg/blob/main/rv_gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RV-GAN: Segmenting Retinal Vascular Structure in Fundus Photographs using a Novel Multi-scale Generative Adversarial Network

The original paper was implemented in TensorFlow, so I will practise by reimplementing it in PyTorch.

I will also use this chance to try PyTorch Lightning to speed up the workflow.

In [None]:
!pip install lightning
import torch
import torch.nn as nn
import pytorch_lightning as pl
from lightning.pytorch.callbacks import ModelSummary
from torch.utils.data import DataLoader, random_split, Dataset
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from sklearn.model_selection import KFold
import os
import shutil
import random
from PIL import Image
import numpy as np
import cv2



| Dataset Name | Total Available Images | Training Images Used | Testing Images Used
| -------- | -------- | -------- | ------- |
| DRIVE  | 40 | 20 | 20
| CHASE  | 28 | 20 | 8
| STARE  | 400 (20 with vessel labels) | 16 | 4


## STARE Dataset

In [None]:
# Getting STARE Dataset
# Downloads the STARE vessel-segmentation images and their manual vessel
# labels into STARE/images and STARE/labels.

# Start from an empty STARE/ folder so a re-run does not mix in files left by
# an earlier split (split_train_test_stare moves files out of these folders).
!rm -rf STARE
!mkdir STARE
!mkdir STARE/images
!mkdir STARE/labels
# NOTE(review): the .tar archives are saved in the working directory and kept.
# On a re-run wget does not overwrite them (it saves "*.tar.1" instead), so
# the tar commands below keep extracting the first download.
!wget https://cecas.clemson.edu/~ahoover/stare/probing/stare-images.tar
!wget https://cecas.clemson.edu/~ahoover/stare/probing/labels-ah.tar
!tar xf stare-images.tar -C STARE/images
!tar xf labels-ah.tar -C STARE/labels

# Unzip the ppm.gz zipped files in STARE folder
# (gzip -d decompresses in place, replacing each *.gz file with its contents).
!find STARE/images/ -type f -name "*.gz" -exec gzip -d {} +
!find STARE/labels/ -type f -name "*.gz" -exec gzip -d {} +

mkdir: cannot create directory ‘STARE/train/images’: No such file or directory
mkdir: cannot create directory ‘STARE/train/labels’: No such file or directory
mkdir: cannot create directory ‘STARE/test/images’: No such file or directory
mkdir: cannot create directory ‘STARE/test/labels’: No such file or directory
--2024-01-31 09:36:55--  https://cecas.clemson.edu/~ahoover/stare/probing/stare-images.tar
Resolving cecas.clemson.edu (cecas.clemson.edu)... 130.127.200.74
Connecting to cecas.clemson.edu (cecas.clemson.edu)|130.127.200.74|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 18674176 (18M) [application/x-tar]
Saving to: ‘stare-images.tar’


2024-01-31 09:36:55 (48.8 MB/s) - ‘stare-images.tar’ saved [18674176/18674176]

--2024-01-31 09:36:55--  https://cecas.clemson.edu/~ahoover/stare/probing/labels-ah.tar
Resolving cecas.clemson.edu (cecas.clemson.edu)... 130.127.200.74
Connecting to cecas.clemson.edu (cecas.clemson.edu)|130.127.200.74|:443... connected.
H

In [None]:
def split_train_test_stare(source_folder, train_folder, test_folder, num_test_images, seed=None):
    """Randomly split the STARE images and labels into train and test folders.

    Moves every ``*.ppm`` file (case-insensitive) from ``<source_folder>/images``
    and ``<source_folder>/labels`` into ``<train_folder|test_folder>/images`` and
    ``<train_folder|test_folder>/labels``. A label goes to the test split when
    its name up to the first ``.`` (``im0001`` for ``im0001.ah.ppm``) equals the
    stem of a test image (``im0001`` for ``im0001.ppm``). The emptied source
    ``images``/``labels`` folders are removed afterwards.

    Args:
        source_folder: Folder containing ``images/`` and ``labels/``.
        train_folder: Destination root for the training split.
        test_folder: Destination root for the test split.
        num_test_images: Number of images moved into the test split.
        seed: Optional seed for the random choice. ``None`` (default) draws from
            the global ``random`` state as before; pass an int for a split that
            is identical on every run.

    Raises:
        ValueError: If ``num_test_images`` is negative or larger than the number
            of available images (checked before anything is created or moved).
        OSError: If a source folder still holds other (non-``.ppm``) files and
            therefore cannot be removed.
    """
    image_folder = os.path.join(source_folder, "images")
    label_folder = os.path.join(source_folder, "labels")

    # Sorted so that a given seed always selects the same images
    # (os.listdir order is arbitrary).
    image_files = sorted(f for f in os.listdir(image_folder) if f.lower().endswith('.ppm'))
    label_files = sorted(f for f in os.listdir(label_folder) if f.lower().endswith('.ppm'))

    if not 0 <= num_test_images <= len(image_files):
        raise ValueError(
            f"num_test_images must be between 0 and {len(image_files)}, got {num_test_images}")

    rng = random if seed is None else random.Random(seed)
    test_images = set(rng.sample(image_files, num_test_images))
    test_stems = {os.path.splitext(f)[0] for f in test_images}

    for split_root in (train_folder, test_folder):
        for sub in ("images", "labels"):
            os.makedirs(os.path.join(split_root, sub), exist_ok=True)

    for image_file in image_files:
        split_root = test_folder if image_file in test_images else train_folder
        shutil.move(os.path.join(image_folder, image_file),
                    os.path.join(split_root, "images", image_file))

    for label_file in label_files:
        split_root = test_folder if label_file.split(".")[0] in test_stems else train_folder
        shutil.move(os.path.join(label_folder, label_file),
                    os.path.join(split_root, "labels", label_file))

    os.rmdir(image_folder)
    os.rmdir(label_folder)

# STARE (4 images used for testing)
# Moves the downloaded STARE files into STARE/training and STARE/test; the
# remaining images go to training. No seed is set here, so the split differs
# between runs. Not re-runnable: after the first run STARE/images no longer
# exists (the files have been moved), so re-download STARE first.
stare_source_folder = '/content/STARE'
stare_train_folder = '/content/STARE/training'
stare_test_folder = '/content/STARE/test'
split_train_test_stare(stare_source_folder, stare_train_folder, stare_test_folder, num_test_images = 4)

## CHASE Dataset

In [None]:
# Getting CHASE Dataset
# Downloads CHASE_DB1 and extracts it into CHASE/. Later cells rename and move
# the extracted files in place, so start from an empty folder each time.
# NOTE(review): rename_files/split_image_labels list /content/CHASE directly,
# i.e. they assume the zip extracts its files straight into CHASE/ with no
# sub-folder -- confirm if the archive layout changes.

!rm -rf CHASE
!mkdir CHASE
!wget https://staffnet.kingston.ac.uk/~ku15565/CHASE_DB1/assets/CHASEDB1.zip
!unzip CHASEDB1.zip -d CHASE

In [None]:
def split_image_labels(source_folder):
  """Sort the loose CHASE files in ``source_folder`` into ``images/`` and ``labels/``.

  Names ending in ``.jpg`` (fundus photographs) are moved to ``images/`` and
  names ending in ``.png`` (vessel annotations) to ``labels/``; the suffix
  check is case-insensitive. Every other entry stays where it is. Both
  sub-folders are created if they do not exist yet.
  """
  # Suffix -> destination folder, checked in this order.
  destinations = {
    ".jpg": os.path.join(source_folder, "images"),
    ".png": os.path.join(source_folder, "labels"),
  }
  for folder in destinations.values():
    os.makedirs(folder, exist_ok=True)

  for entry in os.listdir(source_folder):
    lowered = entry.lower()
    target = next(
      (folder for suffix, folder in destinations.items() if lowered.endswith(suffix)),
      None,
    )
    if target is None:
      continue
    shutil.move(os.path.join(source_folder, entry), os.path.join(target, entry))

  print("Files split successfully.")


def rename_files(source_folder):
  """Rename every entry of ``source_folder`` to ``<2-digit ID><rest of name>``.

  The slice ``name[6:9]`` of each entry is the grouping key. For CHASE names
  such as ``Image_01L_1stHO.png`` this is the subject number plus eye letter
  (``01L``). Each distinct key gets a sequential ID, starting at 1 in sorted
  name order. The new name is that ID formatted with two digits, followed by
  ``name[9:]``. Examples: ``Image_01L.jpg -> 01.jpg`` and
  ``Image_01R_1stHO.png -> 02_1stHO.png``.

  Every entry is renamed (there is no filtering), so the folder is assumed to
  contain only the freshly extracted CHASE files.
  """
  key_to_id = {}
  for original in sorted(os.listdir(source_folder)):
    key = original[6:9]
    # First occurrence of a key gets the next ID (1, 2, ...).
    key_to_id.setdefault(key, len(key_to_id) + 1)
    renamed = "{:02d}{}".format(key_to_id[key], original[9:])
    os.rename(os.path.join(source_folder, original), os.path.join(source_folder, renamed))

  print("Folders renamed successfully.")


def split_train_test_chase(source_folder, train_folder, test_folder, num_test_images, seed=None):
    """Randomly split the CHASE images and labels into train and test folders.

    Moves every ``*.jpg`` from ``<source_folder>/images`` and every ``*.png``
    from ``<source_folder>/labels`` (suffixes matched case-insensitively) into
    ``<train_folder|test_folder>/images`` and ``<train_folder|test_folder>/labels``.
    A label goes to the test split when its name up to the first ``_``
    (``01`` for ``01_1stHO.png``) equals the stem of a test image (``01`` for
    ``01.jpg``). The emptied source folders are removed afterwards.

    Args:
        source_folder: Folder containing ``images/`` and ``labels/``.
        train_folder: Destination root for the training split.
        test_folder: Destination root for the test split.
        num_test_images: Number of images moved into the test split.
        seed: Optional seed for the random choice. ``None`` (default) draws from
            the global ``random`` state as before; pass an int for a split that
            is identical on every run.

    Raises:
        ValueError: If ``num_test_images`` is negative or larger than the number
            of available images (checked before anything is created or moved).
        OSError: If a source folder still holds other files and therefore
            cannot be removed.
    """
    image_folder = os.path.join(source_folder, "images")
    label_folder = os.path.join(source_folder, "labels")

    # Sorted so that a given seed always selects the same images
    # (os.listdir order is arbitrary).
    image_files = sorted(f for f in os.listdir(image_folder) if f.lower().endswith('.jpg'))
    label_files = sorted(f for f in os.listdir(label_folder) if f.lower().endswith('.png'))

    if not 0 <= num_test_images <= len(image_files):
        raise ValueError(
            f"num_test_images must be between 0 and {len(image_files)}, got {num_test_images}")

    rng = random if seed is None else random.Random(seed)
    test_images = set(rng.sample(image_files, num_test_images))
    test_stems = {os.path.splitext(f)[0] for f in test_images}

    for split_root in (train_folder, test_folder):
        for sub in ("images", "labels"):
            os.makedirs(os.path.join(split_root, sub), exist_ok=True)

    for image_file in image_files:
        split_root = test_folder if image_file in test_images else train_folder
        shutil.move(os.path.join(image_folder, image_file),
                    os.path.join(split_root, "images", image_file))

    for label_file in label_files:
        split_root = test_folder if label_file.split("_")[0] in test_stems else train_folder
        shutil.move(os.path.join(label_folder, label_file),
                    os.path.join(split_root, "labels", label_file))

    os.rmdir(image_folder)
    os.rmdir(label_folder)

# CHASE (8 images used for testing)
# Order matters:
#   1. rename_files turns e.g. "Image_01L.jpg" / "Image_01L_1stHO.png" into
#      "01.jpg" / "01_1stHO.png", so image stems match label prefixes;
#   2. split_image_labels groups the .jpg images and .png labels into
#      images/ and labels/;
#   3. split_train_test_chase moves them into CHASE/training and CHASE/test.
# Files are renamed and moved in place and no seed is set here, so this cell is
# not re-runnable without re-downloading CHASE, and the split differs per run.
chase_source_folder = '/content/CHASE'
chase_train_folder = '/content/CHASE/training'
chase_test_folder = '/content/CHASE/test'
rename_files(chase_source_folder)
split_image_labels(chase_source_folder)
split_train_test_chase(chase_source_folder, chase_train_folder, chase_test_folder, num_test_images = 8)

Folders renamed successfully.
Files split successfully.


## DRIVE Dataset

Remember to upload your `kaggle.json` API token when prompted.

In [None]:
# Loading DRIVE dataset
# Downloads DRIVE from Kaggle with the Kaggle CLI. files.upload() asks for the
# kaggle.json API token, which the CLI reads from ~/.kaggle/.

from google.colab import files
files.upload()
# `rm -r` only prints an error (the cell keeps going) when ~/.kaggle does not
# exist yet.
!rm -r ~/.kaggle
!mkdir ~/.kaggle
!mv ./kaggle.json ~/.kaggle/
# The token is a credential: make it readable by the owner only.
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d andrewmvd/drive-digital-retinal-images-for-vessel-extraction
!mkdir DRIVE_patches
# NOTE(review): on a re-run unzip finds the existing files and prompts before
# overwriting them.
!unzip drive-digital-retinal-images-for-vessel-extraction.zip
# Rename the manual annotations folder to "labels" (the folder name used for
# CHASE and STARE). Only the training split is renamed here.
!mv DRIVE/training/1st_manual DRIVE/training/labels

In [None]:
def create_image_patches(folder_path, output_path, patch_size=128, stride=32):
  """Cut every image found under ``folder_path`` into square PNG patches.

  Walks ``folder_path`` recursively and processes files ending in ``.tif``,
  ``.jpg`` or ``.ppm``. Patch origins lie on a regular grid with step
  ``stride``, starting at the top-left corner. Trailing rows or columns too
  small for a full ``patch_size`` x ``patch_size`` patch are skipped, and
  images smaller than ``patch_size`` produce no patches.

  Each patch is saved in ``output_path`` as
  ``<name before first '.'><y>_<x>.png``. There is no separator between the
  name and ``y``; this naming is kept unchanged so existing patch folders keep
  the same file names.

  Args:
    folder_path: Root folder searched for images.
    output_path: Destination folder. Created up front, even when no image is
      found, so callers can always ``os.listdir`` it afterwards.
    patch_size: Side length of each square patch in pixels.
    stride: Step between consecutive patch origins in pixels.
  """
  os.makedirs(output_path, exist_ok=True)

  for root, _, files in os.walk(folder_path):
    for filename in files:
      if not filename.endswith((".tif", ".jpg", ".ppm")):
        continue

      # `with` closes the file handle PIL keeps open for lazy loading;
      # np.array() has already read the pixels when the block exits.
      with Image.open(os.path.join(root, filename)) as img:
        img_array = np.array(img)

      # Only the first two axes are spatial, so both (H, W, C) colour images
      # and (H, W) single-channel images are handled.
      height, width = img_array.shape[:2]
      base_name = filename.split(".")[0]

      for y in range(0, height - patch_size + 1, stride):
        for x in range(0, width - patch_size + 1, stride):
          patch = img_array[y:y + patch_size, x:x + patch_size]
          Image.fromarray(patch.astype('uint8')).save(
              os.path.join(output_path, f"{base_name}{y}_{x}.png"))


def generate_FOV_mask(source_dir, threshold=30):
  """Create a binary field-of-view (FOV) mask for every file in ``<source_dir>/images``.

  Each image is converted to grayscale and thresholded with
  ``cv2.THRESH_BINARY``. Pixels brighter than ``threshold`` become 255 (inside
  the FOV); all other pixels, i.e. the dark background around the fundus,
  become 0. The mask is written to ``<source_dir>/mask/<same filename>``, and
  the ``mask`` folder is created if needed.

  Files OpenCV cannot decode are skipped with a message. Without this check,
  ``cv2.imread`` would return None and ``cv2.cvtColor`` would then raise a
  cryptic error.

  NOTE(review): the mask keeps the image's filename and therefore its file
  format. For ``.jpg`` inputs (CHASE), ``cv2.imwrite`` stores the mask as lossy
  JPEG, so values near the FOV border may not be exactly 0/255. Re-threshold
  the mask after loading, or switch to a lossless format, if exact values
  matter.

  Args:
    source_dir: Folder containing an ``images`` sub-folder.
    threshold: Grayscale cut-off (0-255). The default 30 is the value
      previously hard-coded.
  """
  image_folder = os.path.join(source_dir, "images")
  mask_folder = os.path.join(source_dir, "mask")
  os.makedirs(mask_folder, exist_ok=True)

  for filename in os.listdir(image_folder):
    image_path = os.path.join(image_folder, filename)
    output_path = os.path.join(mask_folder, filename)

    # cv2.imread signals failure by returning None rather than raising.
    image = cv2.imread(image_path)
    if image is None:
      print(f"Skipping {image_path}: not a readable image")
      continue

    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Keep pixels brighter than `threshold` (the illuminated retina).
    _, fov_mask = cv2.threshold(gray_image, threshold, 255, cv2.THRESH_BINARY)

    cv2.imwrite(output_path, fov_mask)
  print(f"Generated FOV mask for {source_dir}")


def rename_masks(source_dir):
  """Remove every ``"_mask"`` substring from the file names in ``source_dir``.

  For example ``21_training_mask.gif`` becomes ``21_training.gif``. Names
  without ``"_mask"`` are left alone. Only the top level of ``source_dir`` is
  processed.
  """
  tagged = [name for name in os.listdir(source_dir) if "_mask" in name]
  for old_name in tagged:
    os.rename(
        os.path.join(source_dir, old_name),
        os.path.join(source_dir, old_name.replace("_mask", "")),
    )


## Preparing Training Dataset

I will prepare the image patches here and also generate field-of-view (FOV) masks for CHASE and STARE, since those datasets do not include them.

I will also rename the DRIVE mask files so that they share the base name of their original image (only the extension differs: `.gif` for masks, `.tif` for images):

```
"xx_training_mask.gif" -> "xx_training.gif"
```

where xx is the index of the image

In [None]:
# Training data preparation
# Cuts the training images of each dataset into PATCH_SIZE x PATCH_SIZE
# patches every TRAIN_STRIDE pixels, then prepares FOV masks. DRIVE's existing
# masks are only renamed; CHASE and STARE masks are generated by thresholding.
# PATCH_SIZE is also used by the test-preparation cell.
# NOTE(review): only the fundus images are cut into patches. The vessel labels
# and FOV masks are not, so the *_patches folders hold no matching ground truth.
PATCH_SIZE = 128
TRAIN_STRIDE = 32

# DRIVE
drive_original_train_dir = '/content/DRIVE/training/images/'
drive_patches_train_dir = '/content/DRIVE_patches/training'
create_image_patches(drive_original_train_dir, drive_patches_train_dir, PATCH_SIZE, TRAIN_STRIDE)
# "xx_training_mask.gif" -> "xx_training.gif"
rename_masks('/content/DRIVE/training/mask/')


print(f"DRIVE dataset has {len(os.listdir(drive_patches_train_dir))} patches")

# CHASE
chase_original_train_dir = '/content/CHASE/training/images/'
chase_patches_train_dir = '/content/CHASE_patches/training'
create_image_patches(chase_original_train_dir, chase_patches_train_dir, PATCH_SIZE, TRAIN_STRIDE)
# Writes masks to /content/CHASE/training/mask/ (same filenames as the images).
generate_FOV_mask("/content/CHASE/training/")
print(f"CHASE dataset has {len(os.listdir(chase_patches_train_dir))} patches")

# STARE
stare_original_train_dir = '/content/STARE/training/images/'
stare_patches_train_dir = '/content/STARE_patches/training'
create_image_patches(stare_original_train_dir, stare_patches_train_dir, PATCH_SIZE, TRAIN_STRIDE)
# Writes masks to /content/STARE/training/mask/ (same filenames as the images).
generate_FOV_mask("/content/STARE/training/")
print(f"STARE dataset has {len(os.listdir(stare_patches_train_dir))} patches")

DRIVE dataset has 4200 patches
CHASE dataset has 15120 patches
STARE dataset has 4320 patches


## Preparing Test Dataset

Generating the test-set patches takes a long time because the stride is only 3 pixels, which yields hundreds of thousands of heavily overlapping patches (see the counts below).

In [None]:
# Test data preparation: same patching as for training, but with a 3-pixel
# stride, which produces far more (heavily overlapping) patches and is slow.
# Relies on PATCH_SIZE defined in the training-preparation cell.
# NOTE(review): unlike the training cell, no FOV masks are generated (CHASE,
# STARE) or renamed (DRIVE) for the test splits here. WholeImageDataset reads
# masks from <root>/mask/, so confirm they exist before running a test stage
# on whole images.
TEST_STRIDE = 3

# DRIVE
drive_original_test_dir = '/content/DRIVE/test/images/'
drive_patches_test_dir = '/content/DRIVE_patches/test'
create_image_patches(drive_original_test_dir, drive_patches_test_dir, PATCH_SIZE, TEST_STRIDE)

print(f"DRIVE dataset has {len(os.listdir(drive_patches_test_dir))} patches")

# CHASE
chase_original_test_dir = '/content/CHASE/test/images/'
chase_patches_test_dir = '/content/CHASE_patches/test'
create_image_patches(chase_original_test_dir, chase_patches_test_dir, PATCH_SIZE, TEST_STRIDE)
print(f"CHASE dataset has {len(os.listdir(chase_patches_test_dir))} patches")

# STARE
stare_original_test_dir = '/content/STARE/test/images/'
stare_patches_test_dir = '/content/STARE_patches/test'
create_image_patches(stare_original_test_dir, stare_patches_test_dir, PATCH_SIZE, TEST_STRIDE)
print(f"STARE dataset has {len(os.listdir(stare_patches_test_dir))} patches")

DRIVE dataset has 446760 patches
CHASE dataset has 647184 patches
STARE dataset has 122240 patches


In [None]:
class WholeImageDataset(Dataset):
    """Whole-image dataset pairing each fundus image with its FOV mask.

    Expects images in ``<root_dir>/images/`` and masks in ``<root_dir>/mask/``.
    Each mask is first looked up under the image's exact filename. If no such
    file exists, the first file (in sorted order) with the same stem, i.e. the
    name without its extension, is used. This covers DRIVE, where
    ``images/21_training.tif`` pairs with ``mask/21_training.gif`` once
    ``_mask`` has been stripped from the mask names.

    File names are sorted so that index ``i`` always refers to the same image;
    ``os.listdir`` order is arbitrary. This keeps seeded KFold splits built on
    this dataset reproducible.

    Each item is ``{"image": ..., "mask": ...}``. The image is loaded as RGB
    and the mask as single-channel ``"L"``, and both are passed through
    ``transform`` when one is given.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_folder = "images"
        self.mask_folder = "mask"
        self.file_names = sorted(os.listdir(os.path.join(root_dir, self.image_folder)))

    def __len__(self):
        return len(self.file_names)

    def _mask_path(self, img_name):
        """Return the path of the mask belonging to ``img_name``.

        Raises:
            FileNotFoundError: If neither an exact-name nor a same-stem mask exists.
        """
        mask_dir = os.path.join(self.root_dir, self.mask_folder)
        exact = os.path.join(mask_dir, img_name)
        if os.path.exists(exact):
            return exact
        stem = os.path.splitext(img_name)[0]
        candidates = sorted(f for f in os.listdir(mask_dir) if os.path.splitext(f)[0] == stem)
        if not candidates:
            raise FileNotFoundError(f"No mask for {img_name!r} in {mask_dir}")
        return os.path.join(mask_dir, candidates[0])

    def __getitem__(self, idx):
        img_name = self.file_names[idx]

        # convert() loads the pixels, so the files can be closed right away.
        img_path = os.path.join(self.root_dir, self.image_folder, img_name)
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        with Image.open(self._mask_path(img_name)) as mask_img:
            mask = mask_img.convert("L")

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return {"image": image, "mask": mask}


In [None]:
class PatchDataset(Dataset):
    """Flat folder of image patches, each returned as an RGB image (no labels).

    Every entry of ``root_dir`` is treated as one patch. File names are sorted
    so that index ``i`` refers to the same patch on every machine, which keeps
    the seeded KFold splits built on this dataset reproducible (``os.listdir``
    order is arbitrary).

    Args:
        root_dir: Folder containing the patch files.
        transform: Optional callable applied to each PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.file_names = sorted(os.listdir(self.root_dir))

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.file_names[idx])

        # convert() loads the pixels, so the file handle can be closed here.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image


In [None]:
class DRIVEWholeModule(pl.LightningDataModule):
  """Whole-image DRIVE data module with a K-fold train/validation split.

  ``setup("fit")`` or ``setup("validate")`` wraps ``/content/DRIVE/training``
  in a ``WholeImageDataset``. The dataset is shuffled into ``num_splits``
  KFold folds seeded by ``split_seed``; fold ``k`` is held out for validation
  and the rest is used for training. ``setup("test")`` wraps
  ``/content/DRIVE/test``.

  Args:
    k: 0-based index of the validation fold.
    split_seed: KFold shuffle seed. Keep it fixed so a given ``k`` always
      selects the same images.
    num_splits: Number of KFold folds.
    batch_size: Batch size for every dataloader.
    num_workers: Number of DataLoader worker processes.
    pin_memory: Forwarded to ``DataLoader``.
  """

  def __init__(
      self,
      k: int = 1,  # fold number
      split_seed: int = 42,  # split needs to be always the same for correct cross validation
      num_splits: int = 5,
      batch_size: int = 24,
      num_workers: int = 0,
      pin_memory: bool = True
    ):
    super().__init__()
    # this line allows to access init params with 'self.hparams' attribute
    self.save_hyperparameters(logger=False)
    self.dir = '/content/DRIVE/'
    self.train_dir = os.path.join(self.dir, "training")
    self.test_dir = os.path.join(self.dir, "test")
    self.transform = transforms.Compose([
      transforms.ToTensor()
    ])

  def setup(self, stage: str):
    # "validate" (trainer.validate) needs val_data as well as "fit".
    if stage in ("fit", "validate"):
      full = WholeImageDataset(self.train_dir, transform=self.transform)

      # KFold only needs the number of samples, so split an index range.
      kf = KFold(n_splits=self.hparams.num_splits, shuffle=True, random_state=self.hparams.split_seed)
      train_indexes, val_indexes = list(kf.split(np.arange(len(full))))[self.hparams.k]

      # The dataset's __getitem__ takes a single index, so `full[list]` raises
      # TypeError; Subset provides lazy index-based views instead.
      self.train_data = torch.utils.data.Subset(full, train_indexes.tolist())
      self.val_data = torch.utils.data.Subset(full, val_indexes.tolist())

    if stage == "test":
      self.test_data = WholeImageDataset(self.test_dir, transform=self.transform)

  def _loader(self, dataset, shuffle=False):
    """Build a DataLoader for ``dataset`` from the stored hyperparameters."""
    return DataLoader(dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers,
                      pin_memory=self.hparams.pin_memory, shuffle=shuffle)

  def train_dataloader(self):
    return self._loader(self.train_data, shuffle=True)

  def val_dataloader(self):
    return self._loader(self.val_data)

  def test_dataloader(self):
    return self._loader(self.test_data)

In [None]:
class DRIVEPatchModule(pl.LightningDataModule):
  """Patch-level DRIVE data module with a K-fold train/validation split.

  ``setup("fit")`` or ``setup("validate")`` wraps
  ``/content/DRIVE_patches/training`` in a ``PatchDataset``. The dataset is
  shuffled into ``num_splits`` KFold folds seeded by ``split_seed``; fold
  ``k`` is held out for validation. ``setup("test")`` wraps
  ``/content/DRIVE_patches/test``.

  Args:
    k: 0-based index of the validation fold.
    split_seed: KFold shuffle seed. Keep it fixed so a given ``k`` always
      selects the same patches.
    num_splits: Number of KFold folds.
    batch_size: Batch size for every dataloader.
    num_workers: Number of DataLoader worker processes.
    pin_memory: Forwarded to ``DataLoader``.
  """

  def __init__(
      self,
      k: int = 1,  # fold number
      split_seed: int = 42,  # split needs to be always the same for correct cross validation
      num_splits: int = 5,
      batch_size: int = 24,
      num_workers: int = 0,
      pin_memory: bool = True
    ):
    super().__init__()
    # Required: every other method reads its settings from self.hparams.
    self.save_hyperparameters(logger=False)
    self.dir = '/content/DRIVE_patches/'
    self.train_dir = os.path.join(self.dir, "training")
    self.test_dir = os.path.join(self.dir, "test")
    self.transform = transforms.Compose([
      transforms.ToTensor()
    ])

  def setup(self, stage: str):
    # "validate" (trainer.validate) needs val_data as well as "fit".
    if stage in ("fit", "validate"):
      full = PatchDataset(self.train_dir, transform=self.transform)

      # KFold only needs the number of samples, so split an index range.
      kf = KFold(n_splits=self.hparams.num_splits, shuffle=True, random_state=self.hparams.split_seed)
      train_indexes, val_indexes = list(kf.split(np.arange(len(full))))[self.hparams.k]

      # The dataset's __getitem__ takes a single index, so `full[list]` raises
      # TypeError; Subset provides lazy index-based views instead.
      self.train_data = torch.utils.data.Subset(full, train_indexes.tolist())
      self.val_data = torch.utils.data.Subset(full, val_indexes.tolist())

    if stage == "test":
      # Stored on self because test_dataloader() reads self.test_data.
      self.test_data = PatchDataset(self.test_dir, transform=self.transform)

  def _loader(self, dataset, shuffle=False):
    """Build a DataLoader for ``dataset`` from the stored hyperparameters."""
    return DataLoader(dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers,
                      pin_memory=self.hparams.pin_memory, shuffle=shuffle)

  def train_dataloader(self):
    return self._loader(self.train_data, shuffle=True)

  def val_dataloader(self):
    return self._loader(self.val_data)

  def test_dataloader(self):
    return self._loader(self.test_data)

In [None]:
class CHASEPatchModule(pl.LightningDataModule):
  """Patch-level CHASE data module with a K-fold train/validation split.

  ``setup("fit")`` or ``setup("validate")`` wraps
  ``/content/CHASE_patches/training`` in a ``PatchDataset``. The dataset is
  shuffled into ``num_splits`` KFold folds seeded by ``split_seed``; fold
  ``k`` is held out for validation. ``setup("test")`` wraps
  ``/content/CHASE_patches/test``.

  Args:
    k: 0-based index of the validation fold.
    split_seed: KFold shuffle seed. Keep it fixed so a given ``k`` always
      selects the same patches.
    num_splits: Number of KFold folds.
    batch_size: Batch size for every dataloader.
    num_workers: Number of DataLoader worker processes.
    pin_memory: Forwarded to ``DataLoader``.
  """

  def __init__(
      self,
      k: int = 1,  # fold number
      split_seed: int = 42,  # split needs to be always the same for correct cross validation
      num_splits: int = 5,
      batch_size: int = 24,
      num_workers: int = 0,
      pin_memory: bool = True
    ):
    super().__init__()
    # Required: every other method reads its settings from self.hparams.
    self.save_hyperparameters(logger=False)
    self.dir = '/content/CHASE_patches/'
    self.train_dir = os.path.join(self.dir, "training")
    self.test_dir = os.path.join(self.dir, "test")
    self.transform = transforms.Compose([
      transforms.ToTensor()
    ])

  def setup(self, stage: str):
    # "validate" (trainer.validate) needs val_data as well as "fit".
    if stage in ("fit", "validate"):
      full = PatchDataset(self.train_dir, transform=self.transform)

      # KFold only needs the number of samples, so split an index range.
      kf = KFold(n_splits=self.hparams.num_splits, shuffle=True, random_state=self.hparams.split_seed)
      train_indexes, val_indexes = list(kf.split(np.arange(len(full))))[self.hparams.k]

      # The dataset's __getitem__ takes a single index, so `full[list]` raises
      # TypeError; Subset provides lazy index-based views instead.
      self.train_data = torch.utils.data.Subset(full, train_indexes.tolist())
      self.val_data = torch.utils.data.Subset(full, val_indexes.tolist())

    if stage == "test":
      # Stored on self because test_dataloader() reads self.test_data.
      self.test_data = PatchDataset(self.test_dir, transform=self.transform)

  def _loader(self, dataset, shuffle=False):
    """Build a DataLoader for ``dataset`` from the stored hyperparameters."""
    return DataLoader(dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers,
                      pin_memory=self.hparams.pin_memory, shuffle=shuffle)

  def train_dataloader(self):
    return self._loader(self.train_data, shuffle=True)

  def val_dataloader(self):
    return self._loader(self.val_data)

  def test_dataloader(self):
    return self._loader(self.test_data)

In [None]:
class CHASEWholeModule(pl.LightningDataModule):
  """Whole-image CHASE data module with a K-fold train/validation split.

  ``setup("fit")`` or ``setup("validate")`` wraps ``/content/CHASE/training``
  in a ``WholeImageDataset``. The dataset is shuffled into ``num_splits``
  KFold folds seeded by ``split_seed``; fold ``k`` is held out for validation.
  ``setup("test")`` wraps ``/content/CHASE/test``.

  Args:
    k: 0-based index of the validation fold.
    split_seed: KFold shuffle seed. Keep it fixed so a given ``k`` always
      selects the same images.
    num_splits: Number of KFold folds.
    batch_size: Batch size for every dataloader.
    num_workers: Number of DataLoader worker processes.
    pin_memory: Forwarded to ``DataLoader``.
  """

  def __init__(
      self,
      k: int = 1,  # fold number
      split_seed: int = 42,  # split needs to be always the same for correct cross validation
      num_splits: int = 5,
      batch_size: int = 24,
      num_workers: int = 0,
      pin_memory: bool = True
    ):
    super().__init__()
    # Required: every other method reads its settings from self.hparams.
    self.save_hyperparameters(logger=False)
    self.dir = '/content/CHASE/'
    self.train_dir = os.path.join(self.dir, "training")
    self.test_dir = os.path.join(self.dir, "test")
    self.transform = transforms.Compose([
      transforms.ToTensor()
    ])

  def setup(self, stage: str):
    # "validate" (trainer.validate) needs val_data as well as "fit".
    if stage in ("fit", "validate"):
      full = WholeImageDataset(self.train_dir, transform=self.transform)

      # KFold only needs the number of samples, so split an index range.
      kf = KFold(n_splits=self.hparams.num_splits, shuffle=True, random_state=self.hparams.split_seed)
      train_indexes, val_indexes = list(kf.split(np.arange(len(full))))[self.hparams.k]

      # The dataset's __getitem__ takes a single index, so `full[list]` raises
      # TypeError; Subset provides lazy index-based views instead.
      self.train_data = torch.utils.data.Subset(full, train_indexes.tolist())
      self.val_data = torch.utils.data.Subset(full, val_indexes.tolist())

    if stage == "test":
      # Stored on self because test_dataloader() reads self.test_data.
      self.test_data = WholeImageDataset(self.test_dir, transform=self.transform)

  def _loader(self, dataset, shuffle=False):
    """Build a DataLoader for ``dataset`` from the stored hyperparameters."""
    return DataLoader(dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers,
                      pin_memory=self.hparams.pin_memory, shuffle=shuffle)

  def train_dataloader(self):
    return self._loader(self.train_data, shuffle=True)

  def val_dataloader(self):
    return self._loader(self.val_data)

  def test_dataloader(self):
    return self._loader(self.test_data)

In [None]:
class STAREWholeModule(pl.LightningDataModule):
  """Whole-image STARE data module with a K-fold train/validation split.

  ``setup("fit")`` or ``setup("validate")`` wraps ``/content/STARE/training``
  in a ``WholeImageDataset``. The dataset is shuffled into ``num_splits``
  KFold folds seeded by ``split_seed``; fold ``k`` is held out for validation.
  ``setup("test")`` wraps ``/content/STARE/test``.

  Args:
    k: 0-based index of the validation fold.
    split_seed: KFold shuffle seed. Keep it fixed so a given ``k`` always
      selects the same images.
    num_splits: Number of KFold folds.
    batch_size: Batch size for every dataloader.
    num_workers: Number of DataLoader worker processes.
    pin_memory: Forwarded to ``DataLoader``.
  """

  def __init__(
      self,
      k: int = 1,  # fold number
      split_seed: int = 42,  # split needs to be always the same for correct cross validation
      num_splits: int = 5,
      batch_size: int = 24,
      num_workers: int = 0,
      pin_memory: bool = True
    ):
    super().__init__()
    # Required: every other method reads its settings from self.hparams.
    self.save_hyperparameters(logger=False)
    self.dir = '/content/STARE/'
    self.train_dir = os.path.join(self.dir, "training")
    self.test_dir = os.path.join(self.dir, "test")
    self.transform = transforms.Compose([
      transforms.ToTensor()
    ])

  def setup(self, stage: str):
    # "validate" (trainer.validate) needs val_data as well as "fit".
    if stage in ("fit", "validate"):
      full = WholeImageDataset(self.train_dir, transform=self.transform)

      # KFold only needs the number of samples, so split an index range.
      kf = KFold(n_splits=self.hparams.num_splits, shuffle=True, random_state=self.hparams.split_seed)
      train_indexes, val_indexes = list(kf.split(np.arange(len(full))))[self.hparams.k]

      # The dataset's __getitem__ takes a single index, so `full[list]` raises
      # TypeError; Subset provides lazy index-based views instead.
      self.train_data = torch.utils.data.Subset(full, train_indexes.tolist())
      self.val_data = torch.utils.data.Subset(full, val_indexes.tolist())

    if stage == "test":
      # Stored on self because test_dataloader() reads self.test_data.
      self.test_data = WholeImageDataset(self.test_dir, transform=self.transform)

  def _loader(self, dataset, shuffle=False):
    """Build a DataLoader for ``dataset`` from the stored hyperparameters."""
    return DataLoader(dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers,
                      pin_memory=self.hparams.pin_memory, shuffle=shuffle)

  def train_dataloader(self):
    return self._loader(self.train_data, shuffle=True)

  def val_dataloader(self):
    return self._loader(self.val_data)

  def test_dataloader(self):
    return self._loader(self.test_data)

In [None]:
class STAREPatchModule(pl.LightningDataModule):
  """Patch-level STARE data module with a K-fold train/validation split.

  ``setup("fit")`` or ``setup("validate")`` wraps
  ``/content/STARE_patches/training`` in a ``PatchDataset``. The dataset is
  shuffled into ``num_splits`` KFold folds seeded by ``split_seed``; fold
  ``k`` is held out for validation. ``setup("test")`` wraps
  ``/content/STARE_patches/test``.

  Args:
    k: 0-based index of the validation fold.
    split_seed: KFold shuffle seed. Keep it fixed so a given ``k`` always
      selects the same patches.
    num_splits: Number of KFold folds.
    batch_size: Batch size for every dataloader.
    num_workers: Number of DataLoader worker processes.
    pin_memory: Forwarded to ``DataLoader``.
  """

  def __init__(
      self,
      k: int = 1,  # fold number
      split_seed: int = 42,  # split needs to be always the same for correct cross validation
      num_splits: int = 5,
      batch_size: int = 24,
      num_workers: int = 0,
      pin_memory: bool = True
    ):
    super().__init__()
    # Required: every other method reads its settings from self.hparams.
    self.save_hyperparameters(logger=False)
    self.dir = '/content/STARE_patches/'
    self.train_dir = os.path.join(self.dir, "training")
    self.test_dir = os.path.join(self.dir, "test")
    self.transform = transforms.Compose([
      transforms.ToTensor()
    ])

  def setup(self, stage: str):
    # "validate" (trainer.validate) needs val_data as well as "fit".
    if stage in ("fit", "validate"):
      full = PatchDataset(self.train_dir, transform=self.transform)

      # KFold only needs the number of samples, so split an index range.
      kf = KFold(n_splits=self.hparams.num_splits, shuffle=True, random_state=self.hparams.split_seed)
      train_indexes, val_indexes = list(kf.split(np.arange(len(full))))[self.hparams.k]

      # The dataset's __getitem__ takes a single index, so `full[list]` raises
      # TypeError; Subset provides lazy index-based views instead.
      self.train_data = torch.utils.data.Subset(full, train_indexes.tolist())
      self.val_data = torch.utils.data.Subset(full, val_indexes.tolist())

    if stage == "test":
      # Stored on self because test_dataloader() reads self.test_data.
      self.test_data = PatchDataset(self.test_dir, transform=self.transform)

  def _loader(self, dataset, shuffle=False):
    """Build a DataLoader for ``dataset`` from the stored hyperparameters."""
    return DataLoader(dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers,
                      pin_memory=self.hparams.pin_memory, shuffle=shuffle)

  def train_dataloader(self):
    return self._loader(self.train_data, shuffle=True)

  def val_dataloader(self):
    return self._loader(self.val_data)

  def test_dataloader(self):
    return self._loader(self.test_data)

## Building Blocks

### Downsample Block

In [None]:
# Downsample
class DownsampleBlock(nn.Module):
    """Downsampling stage: strided 4x4 conv -> BatchNorm -> LeakyReLU.

    Args:
      in_channels: Channels of the input tensor.
      out_channels: Channels produced by the convolution.
      padding: Zero padding of the convolution. The default 0 keeps the
        original geometry and maps each spatial side ``H`` to
        ``(H - 4) // 2 + 1`` (16 -> 7). ``padding=1`` halves even sizes
        exactly (16 -> 8).
    """

    def __init__(self, in_channels, out_channels, padding=0):
      super(DownsampleBlock, self).__init__()

      self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(4, 4), stride=2,
                            padding=padding, dilation=1)
      self.bn = nn.BatchNorm2d(out_channels)
      self.leaky = nn.LeakyReLU()

    def forward(self, x):
      # Each stage must consume the previous stage's output. Applying bn/leaky
      # to the raw input discards the convolution and also mismatches bn's
      # channel count whenever in_channels != out_channels.
      out = self.conv(x)
      out = self.bn(out)
      out = self.leaky(out)
      return out



### Upsample Block

In [None]:
# Upsample
class UpsampleBlock(nn.Module):
    """Upsampling stage: strided 4x4 transposed conv -> BatchNorm -> LeakyReLU.

    Uses stride 2 with no padding, so each spatial side grows from ``H`` to
    ``2 * H + 2`` (e.g. 5 -> 12). The output has ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels):
      super().__init__()
      # Registration order (conv, bn, leaky) fixes the state_dict layout.
      self.conv = nn.ConvTranspose2d(
          in_channels, out_channels, kernel_size=(4, 4), stride=2, dilation=1
      )
      self.bn = nn.BatchNorm2d(out_channels)
      self.leaky = nn.LeakyReLU()

    def forward(self, x):
      return self.leaky(self.bn(self.conv(x)))

### SFA Block

In [None]:
# SFA Block

# PyTorch has no built-in counterpart of Keras' SeparableConv2D, so it is built
# here from a depthwise conv (groups=in_channels) followed by a 1x1 pointwise conv.
# See https://stackoverflow.com/questions/65154182/implement-separableconv2d-in-pytorch
class SeparableConv2d(nn.Module):
  """Depthwise-separable 2-D convolution.

  Args:
    in_channels: Input channels (also the depthwise conv's output channels).
    out_channels: Output channels of the 1x1 pointwise conv.
    kernel_size, stride, dilation: Forwarded to the depthwise conv.
    bias: Whether both convs learn a bias term (default False).
    padding: Zero padding of the depthwise conv. The default 0 keeps the
      original behaviour, which with stride 1 shrinks each spatial side by
      ``dilation * (k - 1)``. ``dilation * (k - 1) // 2`` preserves the size
      for odd ``k`` and stride 1.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride, dilation, bias=False, padding=0):
      super(SeparableConv2d, self).__init__()
      self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
                                groups = in_channels, bias = bias, stride = stride,
                                dilation = dilation, padding = padding)
      self.pointwise = nn.Conv2d(in_channels, out_channels,
                                kernel_size = 1, bias = bias)

  def forward(self, x):
      out = self.depthwise(x)
      out = self.pointwise(out)
      return out


class SFABlock(nn.Module):
    """Spatial feature aggregation (SFA) block with two residual stages.

    Computes ``first = LeakyReLU(BN(SepConv3x3(x)))`` and ``r1 = x + first``.
    It then computes ``second = LeakyReLU(BN(Conv3x3(r1)))`` and returns
    ``r1 + second + x``. The output has the input's shape. The residual
    additions require ``in_channels == out_channels``.

    NOTE(review): a single BatchNorm layer (``self.bn``) is shared by both
    stages, so its running statistics mix two different feature maps;
    separate layers may have been intended.
    """

    def __init__(self, in_channels, out_channels):
      super(SFABlock, self).__init__()

      # padding=1 keeps H x W for the 3x3 kernels. Without it each conv shrinks
      # the map by 2 pixels and the residual additions fail on shape.
      self.sepconv = SeparableConv2d(in_channels, out_channels, kernel_size=(3, 3), stride=1, dilation=1, padding=1)
      self.bn = nn.BatchNorm2d(out_channels)
      self.leaky = nn.LeakyReLU()
      # Input is x + first, which has out_channels channels.
      self.conv = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=1, padding=1, dilation=1)

    def forward(self, x):
      first = self.leaky(self.bn(self.sepconv(x)))
      res_first = x + first

      second = self.leaky(self.bn(self.conv(res_first)))
      res_second = res_first + second

      return res_second + x

### Generator Residual Block

In [None]:
# Generator Residual
class GenResidualBlock(nn.Module):
    """Generator residual block with parallel dilation-1 and dilation-2 branches.

    ``first = LeakyReLU(BN(SepConv3x3_d1(pad1(x))))``. Two branches are then
    computed from ``first``:

    - SepConv3x3 with dilation 1, after 1-pixel reflection padding;
    - SepConv3x3 with dilation 2, after 2-pixel reflection padding.

    Each branch is followed by BN and LeakyReLU. The output is
    ``x + branch_d1 + branch_d2``. The padding matches each kernel's reach, so
    every intermediate keeps the input's spatial size. The residual sum and
    the reuse of ``sepconv_d1`` on ``first`` require
    ``in_channels == out_channels``.

    NOTE(review): ``sepconv_d1`` and ``bn`` are shared between the first stage
    and the branches (shared weights and running statistics). Confirm this is
    intended.
    """

    def __init__(self, in_channels, out_channels):
      super(GenResidualBlock, self).__init__()

      # Reflection padding sized to each 3x3 kernel's reach: 1 px for
      # dilation 1 and 2 px for dilation 2, so every conv preserves H x W.
      self.reflect_d1 = nn.ReflectionPad2d(padding = 1)
      self.reflect = nn.ReflectionPad2d(padding = 2)
      self.sepconv_d1 = SeparableConv2d(in_channels, out_channels, kernel_size=(3, 3), stride=1, dilation=1)
      self.sepconv_d2 = SeparableConv2d(in_channels, out_channels, kernel_size=(3, 3), stride=1, dilation=2)
      self.bn = nn.BatchNorm2d(out_channels)
      self.leaky = nn.LeakyReLU()

    def forward(self, x):
      first = self.leaky(self.bn(self.sepconv_d1(self.reflect_d1(x))))

      split_first = self.leaky(self.bn(self.sepconv_d1(self.reflect_d1(first))))
      split_second = self.leaky(self.bn(self.sepconv_d2(self.reflect(first))))

      # torch.add takes only two tensor operands (a third positional argument
      # is not an addend), so the three-way residual sum is written out.
      return x + split_first + split_second

### Discriminator Residual Block

In [None]:
# Discriminator Residual
class DisResidualBlock(nn.Module):
    """Discriminator block: sum of a dense-conv branch and a separable-conv branch.

    Both branches use a 2x2 kernel with dilation 2, followed by BatchNorm and
    LeakyReLU, and their outputs are added. There is no identity skip from the
    input.

    The dense conv has no padding, so its effective 3x3 footprint shrinks H and
    W by 2.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by each branch.

    NOTE(review): the branch sum requires both branches to yield the same H x W.
    This assumes SeparableConv2d also applies no padding — TODO confirm.
    """

    def __init__(self, in_channels, out_channels):
      super(DisResidualBlock, self).__init__()

      self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(2, 2), stride=1, dilation=2)
      self.sepconv = SeparableConv2d(in_channels, out_channels, kernel_size=(2, 2), stride=1, dilation=2)
      # Separate BatchNorms: the two branches produce different feature
      # distributions, and a shared BN would average their running statistics.
      self.bn = nn.BatchNorm2d(out_channels)
      self.bn_sep = nn.BatchNorm2d(out_channels)
      self.leaky = nn.LeakyReLU()

    def forward(self, x):
      split_first = self.conv(x)
      split_first = self.bn(split_first)
      split_first = self.leaky(split_first)

      split_second = self.sepconv(x)
      split_second = self.bn_sep(split_second)
      split_second = self.leaky(split_second)

      out = torch.add(split_first, split_second)

      return out

### Generator G<sub>f</sub>

In [None]:
# Generator Gf

class GeneratorFine(nn.Module):
    """Fine generator G_f.

    Pipeline:
    1. Stem conv, then one downsampling step.
    2. Two branches on the downsampled features:
       - top: an SFA block;
       - bottom: sum with the coarse-generator features ``mid``, then a stack
         of residual blocks.
    3. Concatenate both branches along channels, upsample, and apply the stem
       conv again.

    Args:
        in_channels, out_channels: forwarded to every sub-block.
        n_res_blocks: number of distinct GenResidualBlocks in the bottom branch
            (default 3, matching the three calls in the original).

    NOTE(review): several shape assumptions — TODO confirm.
    - ``mid`` must have the same shape as the downsampled features.
    - UpsampleBlock must accept the channel-concatenated tensor.
    - The final ``self.conv`` reuses the stem weights and expects
      ``in_channels`` inputs.
    """

    def __init__(self, in_channels, out_channels, n_res_blocks=3):
        super(GeneratorFine, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(2, 2), stride=1, dilation=2)
        self.downsample = DownsampleBlock(in_channels, out_channels)
        self.upsample = UpsampleBlock(in_channels, out_channels)
        self.sfa = SFABlock(in_channels, out_channels)
        # Distinct blocks: calling one module repeatedly would tie all of their weights.
        self.g_res = nn.Sequential(
            *[GenResidualBlock(in_channels, out_channels) for _ in range(n_res_blocks)]
        )

    def forward(self, x, mid):
        x = self.conv(x)
        first_out = self.downsample(x)

        top = self.sfa(first_out)

        # Inject the coarse features at the downsampled resolution so ``bot``
        # lines up spatially with ``top`` for the concatenation below.
        bot = torch.add(first_out, mid)
        bot = self.g_res(bot)

        # dim=1 is the channel axis of an NCHW tensor; dim=0 would stack along the batch.
        out = torch.cat((top, bot), dim=1)
        out = self.upsample(out)
        out = self.conv(out)

        return out

### Generator G<sub>c</sub>

In [None]:
# Generator Gc

class GeneratorCoarse(nn.Module):
    """Coarse generator G_c.

    Pipeline:
    1. Stem conv.
    2. Two downsampling levels, each with its own SFA block. The level-1 SFA
       output is a skip connection; the level-2 SFA output feeds a stack of
       residual blocks.
    3. Decoder: concatenate skip and decoder features along channels, then
       upsample, twice.

    The result is intended to be passed to G_f as ``mid``.

    Args:
        in_channels, out_channels: forwarded to every sub-block.
        n_res_blocks: number of distinct GenResidualBlocks at the deepest level
            (default 9, matching the nine calls in the original).

    NOTE(review): UpsampleBlock receives channel-concatenated tensors; confirm
    its input-channel count matches — TODO.
    """

    def __init__(self, in_channels, out_channels, n_res_blocks=9):
        super(GeneratorCoarse, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(2, 2), stride=1, dilation=2)
        # Level-1 modules.
        self.downsample = DownsampleBlock(in_channels, out_channels)
        self.upsample = UpsampleBlock(in_channels, out_channels)
        self.sfa = SFABlock(in_channels, out_channels)
        # Level-2 modules: separate instances so the two levels don't share weights.
        self.downsample2 = DownsampleBlock(in_channels, out_channels)
        self.upsample2 = UpsampleBlock(in_channels, out_channels)
        self.sfa2 = SFABlock(in_channels, out_channels)
        # Distinct blocks: calling one module repeatedly would tie all of their weights.
        self.g_res = nn.Sequential(
            *[GenResidualBlock(in_channels, out_channels) for _ in range(n_res_blocks)]
        )

    def forward(self, x):
        x = self.conv(x)
        first_out = self.downsample(x)

        top = self.sfa(first_out)

        # Level 2 downsamples the level-1 output. Downsampling ``x`` again with
        # the same module would just recompute ``first_out``.
        mid = self.downsample2(first_out)
        mid = self.sfa2(mid)

        bot = self.g_res(mid)

        # dim=1 is the channel axis of an NCHW tensor; dim=0 would stack along the batch.
        out = torch.cat((mid, bot), dim=1)
        out = self.upsample2(out)

        out = torch.cat((top, out), dim=1)
        out = self.upsample(out)
        return out # To return value for Gf

### Discriminator D<sub>f</sub>

In [None]:
# Discriminator Df

class DiscriminatorFine(nn.Module):
    """Fine discriminator D_f.

    Pipeline:
    1. Stem conv, then one downsampling step.
    2. Two branches on the downsampled features: an SFA branch and a
       residual-block branch.
    3. Concatenate both branches along channels, upsample, and apply the stem
       conv again.

    Args:
        in_channels, out_channels: forwarded to every sub-block.
        n_res_blocks: number of distinct residual blocks (default 3, matching
            the three calls in the original forward).

    NOTE(review): items to confirm — TODO.
    - The residual stack uses GenResidualBlock, as the original ``__init__``
      did; DisResidualBlock may have been intended.
    - forward returns feature maps; how they become a real/fake prediction is
      not defined here.
    """

    def __init__(self, in_channels, out_channels, n_res_blocks=3):
        super(DiscriminatorFine, self).__init__()
        # forward() uses conv and sfa; the original __init__ never created them.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(2, 2), stride=1, dilation=2)
        self.downsample = DownsampleBlock(in_channels, out_channels)
        self.upsample = UpsampleBlock(in_channels, out_channels)
        self.sfa = SFABlock(in_channels, out_channels)
        # Distinct blocks: calling one module repeatedly would tie all of their weights.
        self.d_res = nn.Sequential(
            *[GenResidualBlock(in_channels, out_channels) for _ in range(n_res_blocks)]
        )

    def forward(self, x):
        x = self.conv(x)
        first_out = self.downsample(x)

        top = self.sfa(first_out)

        # The discriminator has no coarse-feature input (the original referenced
        # an undefined ``mid``), so the residual stack reads the downsampled features.
        bot = self.d_res(first_out)

        # dim=1 is the channel axis of an NCHW tensor; dim=0 would stack along the batch.
        out = torch.cat((top, bot), dim=1)
        out = self.upsample(out)
        out = self.conv(out)

        return out

### Discriminator D<sub>c</sub>

In [None]:
# Discriminator Dc

class DiscriminatorCoarse(nn.Module):
    """Coarse discriminator D_c.

    Same two-level encoder/decoder layout as the coarse generator:
    - a stem conv;
    - two downsampling levels, each with its own SFA block;
    - a residual stack at the deepest level;
    - two concatenate-then-upsample decoder steps.

    Args:
        in_channels, out_channels: forwarded to every sub-block.
        n_res_blocks: number of distinct residual blocks at the deepest level
            (default 9, matching the nine calls in the original).

    NOTE(review): items to confirm — TODO.
    - The residual stack uses GenResidualBlock, as the original did;
      DisResidualBlock may have been intended.
    - forward returns feature maps; no real/fake output head is defined here.
    """

    def __init__(self, in_channels, out_channels, n_res_blocks=9):
        super(DiscriminatorCoarse, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(2, 2), stride=1, dilation=2)
        # Level-1 modules.
        self.downsample = DownsampleBlock(in_channels, out_channels)
        self.upsample = UpsampleBlock(in_channels, out_channels)
        self.sfa = SFABlock(in_channels, out_channels)
        # Level-2 modules: separate instances so the two levels don't share weights.
        self.downsample2 = DownsampleBlock(in_channels, out_channels)
        self.upsample2 = UpsampleBlock(in_channels, out_channels)
        self.sfa2 = SFABlock(in_channels, out_channels)
        # Distinct blocks: calling one module repeatedly would tie all of their weights.
        self.g_res = nn.Sequential(
            *[GenResidualBlock(in_channels, out_channels) for _ in range(n_res_blocks)]
        )

    def forward(self, x):
        x = self.conv(x)
        first_out = self.downsample(x)

        top = self.sfa(first_out)

        # Level 2 downsamples the level-1 output. Downsampling ``x`` again with
        # the same module would just recompute ``first_out``.
        mid = self.downsample2(first_out)
        mid = self.sfa2(mid)

        bot = self.g_res(mid)

        # dim=1 is the channel axis of an NCHW tensor; dim=0 would stack along the batch.
        out = torch.cat((mid, bot), dim=1)
        out = self.upsample2(out)

        out = torch.cat((top, out), dim=1)
        out = self.upsample(out)
        # NOTE(review): decoder features; how D_c turns these into a real/fake
        # score is not defined yet — TODO.
        return out

In [None]:
class GAN(pl.LightningModule):
    """Lightning module that trains a generator/discriminator pair with manual optimization.

    Each training step updates the generator first, then the discriminator,
    using binary cross-entropy adversarial losses.

    NOTE(review): this is still the latent-noise GAN template.
    - It builds ``Generator(latent_dim=..., img_shape=...)`` and
      ``Discriminator(img_shape=...)``, which are not defined in this section.
    - The RV-GAN blocks (GeneratorCoarse/Fine, DiscriminatorCoarse/Fine) are
      not wired in yet — TODO.

    Args:
        channels, width, height: image shape, passed to both networks as
            ``img_shape=(channels, width, height)``.
        latent_dim: length of the noise vector fed to the generator.
        lr, b1, b2: Adam learning rate and betas, shared by both optimizers.
        batch_size: stored in ``hparams`` only; not read inside this class.
    """

    def __init__(
        self,
        channels,
        width,
        height,
        latent_dim: int = 100,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = 24,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Two optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

        # networks
        data_shape = (channels, width, height)
        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)

        # Fixed noise so validation samples are comparable across epochs.
        self.validation_z = torch.randn(8, self.hparams.latent_dim)

        # Dummy input used by Lightning's model summary to report tensor shapes.
        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

    def forward(self, z):
        return self.generator(z)

    def adversarial_loss(self, y_hat, y): # TODO: Change to paper's novel loss
        # F.binary_cross_entropy expects probabilities, so the discriminator
        # output is assumed to already be in [0, 1] — TODO confirm.
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch):
        """One GAN step: update the generator, then the discriminator.

        ``batch`` is unpacked as ``(imgs, _)``; the second element is ignored.
        """
        imgs, _ = batch

        optimizer_g, optimizer_d = self.optimizers()

        # sample noise on the same device/dtype as the images
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        # train generator
        # generate images
        self.toggle_optimizer(optimizer_g)
        self.generated_imgs = self(z)

        # Log sampled images at the current step. A constant step of 0 would
        # overwrite the same TensorBoard slot every time.
        sample_imgs = self.generated_imgs[:6]
        grid = make_grid(sample_imgs)
        self.logger.experiment.add_image("generated_images", grid, self.global_step)

        # Target labels are all "real": the generator wants D to classify its
        # samples as real.
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        # Reuse the images generated above rather than running the generator again.
        g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        # train discriminator
        # Measure discriminator's ability to classify real from generated samples
        self.toggle_optimizer(optimizer_d)

        # how well can it label as real?
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

        # how well can it label as fake?
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)

        # Regenerate with the just-updated generator. detach() keeps the
        # discriminator loss from back-propagating into the generator.
        fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)

        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def configure_optimizers(self):
        """Return one Adam optimizer per network, in (generator, discriminator) order."""
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

    def on_validation_epoch_end(self):
        # NOTE(review): assumes Generator exposes ``model[0].weight``; this is
        # only used to pick the device/dtype — TODO confirm.
        z = self.validation_z.type_as(self.generator.model[0].weight)

        # log sampled images
        sample_imgs = self(z)
        grid = make_grid(sample_imgs)
        self.logger.experiment.add_image("generated_images", grid, self.current_epoch)

In [None]:
# Load data variables
drive_whole = DRIVEWholeModule()
drive_patch = DRIVEPatchModule()
chase_whole = CHASEWholeModule()
chase_patch = CHASEPatchModule()
stare_whole = STAREWholeModule()
stare_patch = STAREPatchModule()


# train
# GAN's constructor requires channels, width and height; calling GAN() with
# no arguments raises TypeError.
# NOTE(review): 3-channel 128x128 inputs are an assumption — TODO confirm
# against the shape the chosen data module actually yields.
IMG_CHANNELS = 3
IMG_WIDTH = 128
IMG_HEIGHT = 128
model = GAN(channels=IMG_CHANNELS, width=IMG_WIDTH, height=IMG_HEIGHT)


## Time to train the model!

In [None]:
# Take the callback from the same package as the Trainer (pytorch_lightning);
# mixing lightning.pytorch and pytorch_lightning objects is not supported.
trainer = pl.Trainer(
    callbacks=[pl.callbacks.ModelSummary(max_depth=1)],
    accelerator="auto",
    max_epochs=10,
)

# ``drive_data`` is never defined; train on one of the data modules created earlier.
# NOTE(review): the patch-level DRIVE module is a guess — TODO pick the intended dataset.
trainer.fit(model, datamodule=drive_patch)

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir lightning_logs/