In [1]:
## Standard libraries
import os
import numpy as np
from natsort import natsorted
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import albumentations as AB 
from PIL import Image
import skimage.io
import skimage.measure
import skimage.segmentation

import networkx as nx

## Imports for plotting
import seaborn as sns
sns.reset_orig()
sns.set()

import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0



## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    ! pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning import LightningDataModule
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger, MLFlowLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

# Setting the seed
pl.seed_everything(42)

# # Torch geometric
# import torch_geometric.data
# import torch_geometric.utils 
# torch geometric
try:
    import torch_geometric
except ModuleNotFoundError:
    # Installing torch geometric packages with specific CUDA+PyTorch version.
    # See https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html for details
    TORCH = torch.__version__.split('+')[0]
    CUDA = 'cu' + torch.version.cuda.replace('.','')

    ! pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    ! pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    ! pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    ! pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    ! pip install torch-geometric
    import torch_geometric
import torch_geometric.nn as geom_nn
import torch_geometric.data as geom_data
import torch_geometric.loader as geom_loader
import torch_geometric.transforms as geom_transforms
import torch_geometric.utils as geom_utils
# from torch_geometric.nn import GraphUNet
# from torch_geometric.utils import dropout_adj

from monai.data import CacheDataset, list_data_collate
from monai.config import print_config
from monai.losses import DiceLoss

from monai.utils import set_determinism
from monai.transforms import (
    Compose,
    LoadImaged, AddChanneld, Resized, ScaleIntensityd, Flipd, Rotate90d,
    RandAdjustContrastd, RandHistogramShiftd, RandGaussianNoised, RandGaussianSmoothd, RandGaussianSharpend, 
    RandAffined, RandRotate90d, RandFlipd, RandZoomd, RandSpatialCropd, RandCropByPosNegLabeld, 
    ToTensord
)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.determinstic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

  set_matplotlib_formats('svg', 'pdf') # For export
Global seed set to 42


cuda:0


In [2]:
class PairedDataModule(LightningDataModule):
    def __init__(self, 
        train_image_dirs: List[str]=['/data/train/images'],
        train_label_dirs: List[str]=['/data/train/labels'], 
        val_image_dirs: List[str]=['/data/val/images'], 
        val_label_dirs: List[str]=['/data/val/labels'],
        test_image_dirs: List[str]=['/data/test/images'],
        test_label_dirs: List[str]=['/data/test/labels'],
    ):
        """[summary]

        Args:
            batch_size (int, optional): [description]. Defaults to 32.
            train_image_dirs (List[str], optional): [description]. Defaults to ['/data/train/images'].
            train_label_dirs (List[str], optional): [description]. Defaults to ['/data/train/labels'].
            val_image_dirs (List[str], optional): [description]. Defaults to ['/data/val/images'].
            val_label_dirs (List[str], optional): [description]. Defaults to ['/data/val/labels'].
            test_image_dirs (List[str], optional): [description]. Defaults to ['/data/test/images'].
            test_label_dirs (List[str], optional): [description]. Defaults to ['/data/test/labels'].
        """
        super().__init__()
        self.train_image_dirs = train_image_dirs
        self.train_label_dirs = train_label_dirs
        self.val_image_dirs = val_image_dirs
        self.val_label_dirs = val_label_dirs
        self.test_image_dirs = test_image_dirs
        self.test_label_dirs = test_label_dirs

    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
        pass
    
    def glob_dict(
        self, 
        image_dirs: List[str], 
        label_dirs: List[str],
        ext: str='*.png',
    ) -> Dict[str, List[str]]:
        assert image_dirs is not None and label_dirs is not None
        assert len(image_dirs) == len(label_dirs)
        # Glob all image files in image_dirs
        image_paths = [Path(folder).rglob(ext) for folder in image_dirs]
        image_files = natsorted([str(path) for path_list in image_paths for path in path_list])

        # Glob all label files in label_dirs
        label_paths = [Path(folder).rglob(ext) for folder in label_dirs]
        label_files = natsorted([str(path) for path_list in label_paths for path in path_list])

        # Check that the number of image and label files match
        print(f'Found {len(image_files)} images and {len(label_files)} labels.')
        assert len(image_files) == len(label_files)

        # Create a dictionary of image and label files
        data_dicts = [
            {"image": image_file,  
             "label": label_file} for image_file, label_file in zip(image_files, label_files)
        ]
        return data_dicts

    def setup(self, stage=None):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        self.train_data_dicts = self.glob_dict(self.train_image_dirs, self.train_label_dirs, ext='*.png')
        self.val_data_dicts = self.glob_dict(self.val_image_dirs, self.val_label_dirs, ext='*.png')
        self.test_data_dicts = self.glob_dict(self.test_image_dirs, self.test_label_dirs, ext='*.png')
        # set_determinism(seed=0)

class ImageGrid(object):
    def __init__(self, array=None, diff_edge=False, seed=32):
        """[summary]

        Args:
            array ([numpy array], optional): [H W C=1,3 tensor]. Defaults to None.
            diff_edge (bool, optional): [description]. Defaults to False.
        """
        self.reset()
        # self.set_array(array, diff_edge=diff_edge)
        self.array = array.astype(np.float32) if array.ndim == 3 else np.expand_dims(array, axis=2)
        self.height, self.width = array.shape[:2]
        self.seed = seed
        # Create the graph
        self.graph = nx.grid_2d_graph(self.height, self.width)
        
        self.set_nodes(weight=None)
        self.set_edges(weight=None, diff_edge=diff_edge)
        # for n, node in enumerate(self.graph.nodes):
        #     print(self.graph.nodes[node])
        self.number_of_nodes = self.graph.number_of_nodes()
        self.number_of_edges = self.graph.number_of_edges()

    def reset(self):
        self.height = 0
        self.width = 0
        self.array = None
        self.graph = None

    def set_edges(self, weight=None, diff_edge=True, cc=8):
        # for e, edge in enumerate(self.graph.edges):
        #     self.graph.edges[edge]['weight'] = 0.5

        if cc==8:
            k=1
            self.graph.add_edges_from([
                ((x+0, y+0), (x+0, y+k))
                for x in range(self.width-k)
                for y in range(self.height-k)
            ] + [
                ((x+0, y+0), (x+k, y+0))
                for x in range(self.width-k)
                for y in range(self.height-k)
            ] + [
                ((x+k, y+0), (x+k, y+k))
                for x in range(self.width-k)
                for y in range(self.height-k)
            ] + [
                ((x+0, y+k), (x+k, y+k))
                for x in range(self.width-k)
                for y in range(self.height-k)
            ], weight=0.5)

            # diagonal edges
            self.graph.add_edges_from([
                ((x+0, y+0), (x+k, y+k))
                for x in range(self.width-k)
                for y in range(self.height-k)
            ] + [
                ((x+k, y+0), (x+0, y+k))
                for x in range(self.width-k)
                for y in range(self.height-k)
            ], weight=0.5)
            
            np.random.seed(self.seed)
            # k = np.random.choice([2, 4, 8, 16])
            # straight edges
            x = np.random.choice(self.width, size=(self.height*self.width//4), replace=True)
            y = np.random.choice(self.height,  size=(self.height*self.width//4), replace=True)
            u = np.random.choice(self.width, size=(self.height*self.width//4), replace=True)
            v = np.random.choice(self.height,  size=(self.height*self.width//4), replace=True)
            self.graph.add_edges_from([
                ((x, y), (u, v)) for x, y, u, v in zip(x, y, u, v)
            ], weight=0.5)

            # # diagonal edges
            # self.graph.add_edges_from([
            #     ((x+0, y+0), (x+k, y+k))
            #     for x in range(self.width-k)
            #     for y in range(self.height-k)
            # ] + [
            #     ((x+k, y+0), (x+0, y+k))
            #     for x in range(self.width-k)
            #     for y in range(self.height-k)
            # ], weight=0.5)
            
        elif cc==4: 
            # connected component is equal to 4 already in the grid
            pass

        # Add edge to master node 
        self.graph.add_edges_from([
            ((x, y), (-1, 0))
            for x in range(self.width)
            for y in range(self.height)
        ], weight=0.5)

        self.graph.add_edges_from([
            ((x, y), (0, -1))
            for x in range(self.width)
            for y in range(self.height)
        ], weight=0.5)


        if diff_edge=="mean":
            for e, edge in enumerate(self.graph.edges):
                # Compute the mean of the affinity
                self.graph.edges[edge]['weight'] = (self.array[edge[0]] - self.array[edge[1]]) * 0.5
        elif diff_edge=="diff":
            for e, edge in enumerate(self.graph.edges):
                # Compute the affinity, 1 - difference
                self.graph.edges[edge]['weight'] = ( 1 - np.abs(self.array[edge[0]] - self.array[edge[1]]) )
                # if self.array[edge[0]] == self.array[edge[1]] == 0: # We dont care the connection of background
                #     self.graph.edges[edge]['weight'] = np.abs(self.array[edge[0]] - self.array[edge[1]]) # 0
                # else:
                #     self.graph.edges[edge]['weight'] = ( 1 - np.abs(self.array[edge[0]] - self.array[edge[1]]) )

    def set_nodes(self, weight=None):
        for n, node in enumerate(self.graph.nodes):
            y = n // self.width
            x = n % self.width
            self.graph.nodes[node]['weight'] = self.array[y,x,:] \
                if weight is None else np.array([weight], dtype=np.float32)
        # Set master node 
        self.graph.add_node((-1, 0), weight=np.array([1.0], dtype=np.float32) if weight is None else np.array([weight], dtype=np.float32)) 
        self.graph.add_node((0, -1), weight=np.array([0.0], dtype=np.float32) if weight is None else np.array([weight], dtype=np.float32)) 


class GraphBasedDataset(Dataset):
    def __init__(self, data_dicts, transforms=None):
        self.data_dicts = data_dicts
        self.transforms = transforms
        
    def __getitem__(self, index):
        # If use AB
        image_file = self.data_dicts[index]['image']
        label_file = self.data_dicts[index]['label']
        image_array = skimage.io.imread(image_file).astype(np.uint8) 
        label_array = skimage.io.imread(label_file).astype(np.uint8) 

        if self.transforms:
            transformed = self.transforms(image=image_array, mask=label_array)
            image = transformed['image'].astype(np.float32) / 255.0
            label = transformed['mask'].astype(np.float32) / 255.0
        else:
            image = image_array.astype(np.float32) / 255.0
            label = label_array.astype(np.float32) / 255.0
        
        # If use Monai
        # data_dict = self.data_dicts[index]
        # if self.transforms:
        #     # print(data_dict)
        #     data_dict = self.transforms(data_dict)
        #     image = data_dict['image'].squeeze()
        #     label = data_dict['label'].squeeze()
        #     # print(image.shape, label.shape)
        # else:
        #     image, label = data_dict["image"], data_dict["label"]
        seed = index
        image_grid = ImageGrid(image, diff_edge="diff", seed=seed)
        label_grid = ImageGrid(label, diff_edge="diff", seed=seed)

        image_graph = torch_geometric.utils.from_networkx(image_grid.graph, 
                                                          group_node_attrs=['weight'],
                                                          group_edge_attrs=['weight'])
        label_graph = torch_geometric.utils.from_networkx(label_grid.graph, 
                                                          group_node_attrs=['weight'],
                                                          group_edge_attrs=['weight'])
        # Normalize coordinate
        # image_graph.x = geom_transforms.NormalizeScale()(image_graph.x)
        # label_graph.x = geom_transforms.NormalizeScale()(label_graph.x)
        
        # if self.transforms:
        #     pass
        # print(image_graph.x.shape, image_graph.edge_index.shape, image_graph.edge_attr.shape)
        # print(label_graph.x.shape, label_graph.edge_index.shape, label_graph.edge_attr.shape)
        # return (image_graph.x, image_graph.edge_index), (label_graph.x, label_graph.edge_index)
        return {"image": image_graph, "label": label_graph}
    
    def __len__(self):
        return len(self.data_dicts)

class GraphBasedDataModule(PairedDataModule):
    def __init__(self, 
        batch_size: int=32,
        train_image_dirs: List[str]=['/data/train/images'],
        train_label_dirs: List[str]=['/data/train/labels'], 
        val_image_dirs: List[str]=['/data/val/images'], 
        val_label_dirs: List[str]=['/data/val/labels'],
        test_image_dirs: List[str]=['/data/test/images'],
        test_label_dirs: List[str]=['/data/test/labels'],
    ):
        super().__init__()
        self.batch_size = batch_size
        self.train_image_dirs = train_image_dirs
        self.train_label_dirs = train_label_dirs
        self.val_image_dirs = val_image_dirs
        self.val_label_dirs = val_label_dirs
        self.test_image_dirs = test_image_dirs
        self.test_label_dirs = test_label_dirs

    def _shared_dataloader(self, data_dicts, transforms=None, shuffle=True, drop_last=False, num_workers=8):
        dataset = GraphBasedDataset(data_dicts, transforms=transforms)
        dataloader = geom_loader.DataLoader(
            dataset=dataset, 
            batch_size=self.batch_size, 
            num_workers=num_workers, 
            # collate_fn=list_data_collate,
            shuffle=shuffle,
        )
        return dataloader

    def train_dataloader(self):
        train_transforms = AB.Compose([
            AB.OneOf([
                AB.RandomSizedCrop(min_max_height=(160, 250), height=256, width=256, p=0.5),
                # AB.PadIfNeeded(min_height=256, min_width=256, p=0.5)
            ], p=1),    
            AB.VerticalFlip(p=0.5),              
            AB.RandomRotate90(p=0.5),
            # AB.OneOf([
            #     AB.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.5),
            #     AB.GridDistortion(p=0.5),
            #     AB.OpticalDistortion(distort_limit=2, shift_limit=0.5, p=1)                  
            #     ], p=0.8),
            AB.CLAHE(p=0.8),
            AB.RandomBrightnessContrast(p=0.8),    
            AB.RandomGamma(p=0.8), 
            AB.ToFloat(max_value=1.0, p=1.0),
            AB.Resize(height=128, width=128, p=1),
        ])
        return self._shared_dataloader(self.train_data_dicts, 
            transforms=train_transforms, 
            shuffle=True,
            drop_last=False,
            num_workers=12
        )
    
    def val_dataloader(self):
        val_transforms = AB.Compose([
            AB.ToFloat(max_value=1.0, p=1.0),
            AB.Resize(height=128, width=128, p=1),
        ])
        return self._shared_dataloader(self.val_data_dicts, 
            transforms=val_transforms, 
            shuffle=False,
            drop_last=False,
            num_workers=4
        )

    def test_dataloader(self):
        test_transforms = AB.Compose([
            AB.ToFloat(max_value=1.0, p=1.0),
            AB.Resize(height=128, width=128, p=1),
        ])
        return self._shared_dataloader(self.test_data_dicts, 
            transforms=test_transforms, 
            shuffle=False,
            drop_last=False,
            num_workers=4
        )

class GNNModel(nn.Module):
    def __init__(self, c_in, c_hidden, c_out, num_layers=3, dp_rate=0.5, **kwargs):
        """
        Inputs:
            c_in - Dimension of input features
            c_hidden - Dimension of hidden features
            c_out - Dimension of the output features. Usually number of classes in classification
            num_layers - Number of "hidden" graph layers
            dp_rate - Dropout rate to apply throughout the network
            kwargs - Additional arguments for the graph layer (e.g. number of heads for GAT)
        """
        super().__init__()
        # gnn_layer = gnn_layer_by_name[layer_name]
        node_layer = geom_nn.GATConv
        # node_layer = geom_nn.GATv2Conv
        # node_layer = geom_nn.ResGatedGraphConv

        node_model = []
        in_channels, out_channels = c_in, c_hidden
        for _ in range(num_layers-1):
            node_model += [
                node_layer(in_channels=in_channels,
                          out_channels=out_channels,
                          #dropout=dp_rate,
                          **kwargs),
                nn.ReLU(inplace=True),
                nn.Dropout(dp_rate)
            ]
            in_channels = c_hidden
        node_model += [node_layer(in_channels=in_channels,
                             out_channels=c_out,
                             #dropout=dp_rate,
                             **kwargs)]
        self.node_model = nn.ModuleList(node_model)

    def forward(self, x, edge_index, edge_attr=None):
        """
        Inputs:
            x - Input features per node
            edge_index - List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)
        """
        for node_layer in self.node_model:
            # For graph layers, we need to add the "edge_index" tensor as additional input
            # All PyTorch Geometric graph layer inherit the class "MessagePassing", hence
            # we can simply check the class type.
            if isinstance(node_layer, geom_nn.MessagePassing):
                x, (edge_index, edge_attr) = node_layer(x, edge_index, edge_attr, return_attention_weights=True)
                # edge_attr = alpha*( (1 - torch.abs(x[edge_index[0]] - x[edge_index[1]])) )
            else:
                x = node_layer(x)
        return x, edge_attr

class GridGNN(pl.LightningModule):
    def __init__(self, **model_kwargs):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()

        self.model = GNNModel(**model_kwargs)
        self.loss_function = DiceLoss(to_onehot_y=False, 
                                      sigmoid=False, 
                                      squared_pred=False)

    def forward(self, data, mode="train"):
        output, edge_attr = self.model(data.x, data.edge_index, data.edge_attr)
        return torch.sigmoid(output), torch.sigmoid(edge_attr)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.0001)
        return optimizer

    def _shared_step(self, batch, batch_idx, stage: Optional[str]='_shared'):
        images, labels = batch["image"], batch["label"]
        output, attrs = self.forward(images)
        loss = 0.5*(self.loss_function(output, labels.x) \
                  + self.loss_function(attrs, labels.edge_attr))
        return {"loss": loss}

    def training_step(self, batch, batch_idx, stage: Optional[str]='train'):
        images, labels = batch["image"], batch["label"]
        output, attrs = self.forward(images)
        loss = 0.5*(self.loss_function(output, labels.x) \
                  + self.loss_function(attrs, labels.edge_attr))
        if batch_idx==0:
            viz = torch.cat([images.x[:16384].reshape([128, 128]), 
                             labels.x[:16384].reshape([128, 128]), 
                             output[:16384].reshape([128, 128]), 
                            ], dim=-1)#[:8]
            grid = torchvision.utils.make_grid(viz, nrow=5, padding=0)
            tensorboard = self.logger[0].experiment
            tensorboard.add_image(f'{stage}_samples', grid, self.current_epoch)

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, stage='val')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, stage='test')
    
    def _shared_epoch_end(self, outputs, stage: Optional[str]='_shared'):
        loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log(f'{stage}_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def validation_epoch_end(self, outputs):
        return self._shared_epoch_end(outputs, stage='val')
    
    def test_epoch_end(self, outputs):
        return self._shared_epoch_end(outputs, stage='test')


log_dir = "logs"
tsb_logger = TensorBoardLogger(save_dir=os.path.join(log_dir, 'tsb'))
batch_size=5
lr=1e-4
max_epochs=201

datamodule = GraphBasedDataModule(
    batch_size=batch_size,
    train_image_dirs=['data/train/images/',],
    train_label_dirs=['data/train/labels/',],
    val_image_dirs=['data/test/images/',],
    val_label_dirs=['data/test/labels/',],
    test_image_dirs=['data/test/images/',],
    test_label_dirs=['data/test/labels/',],
)
datamodule.prepare_data()
datamodule.setup()

model = GridGNN(c_in=1, c_hidden=100, c_out=1, num_layers=8, edge_dim=1, add_self_loops=False, dp_rate=0.5)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=log_dir,
    filename='{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)

trainer = Trainer(
    gpus=-1,
    # tpu_cores=8,
    max_epochs=max_epochs,
    logger=[tsb_logger],
    callbacks=[checkpoint_callback],
    num_sanity_val_steps=1,
    # profiler="advanced",
)

trainer.fit(model, datamodule)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
DataModule.setup has already been called, so it will not be called again. In v1.6 this behavior will change to always call DataModule.setup.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Found 3000 images and 3000 labels.
Found 20 images and 20 labels.
Found 20 images and 20 labels.



  | Name          | Type     | Params
-------------------------------------------
0 | model         | GNNModel | 63.7 K
1 | loss_function | DiceLoss | 0     
-------------------------------------------
63.7 K    Trainable params
0         Non-trainable params
63.7 K    Total params
0.255     Total estimated model params size (MB)
Checkpoint directory /home/qtran/graph/logs exists and is not empty.


                                                                      

Global seed set to 42


Epoch 200: 100%|██████████| 604/604 [21:24<00:00,  2.13s/it, loss=0.518, v_num=0, val_loss=0.492]
