# Mask GAN

I want to try to use a GAN to make the output masks better. The generator will be the UNet that outputs the masks, while the discriminator will try to distinguish between real masks and generated masks.

In [1]:
import monai
from monai.networks.nets import UNet, Discriminator
from monai.transforms import (
    Compose,
    LoadNiftid,
    ScaleIntensityd,
    NormalizeIntensityd,
    AddChanneld,
    ToTensord,
    RandSpatialCropd,
    RandCropByPosNegLabeld,
    CropForegroundd,
    Identityd,
)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import pytorch_lightning as pl
from sklearn.model_selection import train_test_split

In [10]:
class MaskGAN(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        
        self.generator = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=2,
            channels=(64, 128, 258, 512, 1024),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
            dropout=0,
        )
        
        self.discriminator = Discriminator(
            in_shape=hparams.patch_size,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
        )
        
        self.generated_masks = None
    
    # Data setup
    def setup(self, stage):
        data_dir = 'data/'
        
        # Train imgs/masks
        train_imgs = []
        with open(data_dir + 'train_imgs.txt', 'r') as f:
            train_imgs = [image.rstrip() for image in f.readlines()]

        train_masks = []
        with open(data_dir + 'train_masks.txt', 'r') as f:
            train_masks = [mask.rstrip() for mask in f.readlines()]
        
        train_dicts = [{'image': image, 'mask': mask} for (image, mask) in zip(train_imgs, train_masks)]
        
        train_dicts, val_dicts = train_test_split(train_dicts, test_size=0.2)
        
        # Basic transforms
        data_keys = ["image", "mask"]
        data_transforms = Compose(
            [
                LoadNiftid(keys=data_keys),
                AddChanneld(keys=data_keys),
                NormalizeIntensityd(keys="image"),
                RandCropByPosNegLabeld(
                    keys=data_keys,
                    label_key="mask",
                    spatial_size=self.hparams.patch_size,
                    num_samples=4,
                    image_key="image"
                ),
            ]
        )
        
        self.train_dataset = monai.data.CacheDataset(
            data=train_dicts,
            transform=Compose(
                [
                    data_transforms,
                    ToTensord(keys=data_keys)
                ]
            ),
            cache_rate=1.0
        )
        
        self.val_dataset = monai.data.CacheDataset(
            data=val_dicts,
            transform=Compose(
                [
                    data_transforms,
                    ToTensord(keys=data_keys)
                ]
            ),
            cache_rate=1.0
        )
        
    def train_dataloader(self):
        return monai.data.DataLoader(
            self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True, num_workers=hparams.num_workers
        )

    def val_dataloader(self):
        return monai.data.DataLoader(
            self.val_dataset, batch_size=self.hparams.batch_size, shuffle=True, num_workers=hparams.num_workers
        )
    
    # Training setup
    def forward(self, image):
        return self.generator(image)
    
    def generator_loss(self, y_hat, y):
        dice_loss = monai.losses.DiceLoss(
            to_onehot_y=True,
            softmax=True
        )
        return dice_loss(y_hat, y)
    
    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)
    
    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs, labels = batch['image'], batch['mask']
        batch_size = inputs.size(0)
        # Generator training
        if optimizer_idx == 0:
            self.generated_masks = self(inputs)
            
            # Loss from difference between real and generated masks
            g_loss = generator_loss(
                self.generated_masks,
                labels
            )
            
            # Loss from discriminator
            # The generator wants the discriminator to be wrong,
            # so the wrong labels are used
            fake_labels = torch.ones(batch_size, 1)
            d_loss = self.adversarial_loss(
                self.discriminator(self.generated_masks),
                fake_labels
            )
            
            avg_loss = (g_loss + d_loss) / 2
            
            tensorboard_logs = {
                "g_train/g_loss": g_loss,
                "g_train/d_loss": d_loss,
                "g_train/loss": avg_loss
            }
            return {'loss': avg_loss, 'logs': tensorboard_logs}
            
        # Discriminator trainig
        else:
            # Learning real masks
            real_labels = torch.ones(batch_size, 1)
            real_loss = self.adversarial_loss(
                self.discriminator(labels),
                real_labels
            )
            
            # Learning "fake" masks
            fake_labels = torch.zeros(batch_size, 1)
            fake_loss = self.adversarial_loss(
                self.discriminator(self.generated_masks),
                fake_labels
            )
            
            avg_loss = (real_loss + fake_loss) / 2
            tensorboard_logs = {
                "d_train/r_loss": real_loss,
                "d_train/f_loss": fake_loss,
                "d_train/loss": avg_loss
            }
            return {'loss': avg_loss, 'logs': tensorboard_logs}
    
    def configure_optimizers(self):
        lr = self.hparams.lr
        g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=lr)
        d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr)
    
    def validation_step(self, batch, batch_idx):
        inputs, labels = (
            batch["image"],
            batch["mask"],
        )
        outputs = self(inputs)
        
        # Sample masks
        sample_masks = []
        for i in range(4):
            sample_masks.append(outputs[i, 0, :, :, 8].argmax(1).squeeze(0).detach())
            grid = torchvision.utils.make_grid(sample_masks)
            self.logger.experiment.add_image('generated_masks', grid, 0)
        
        loss = self.generator_loss(outputs, labels)
        return {"val_loss": loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        tensorboard_logs = {"val/loss": avg_loss}
        return {"val_loss": avg_loss, "log": tensorboard_logs}

In [11]:
from argparse import Namespace

args = {
    'batch_size': 4,
    'lr': 0.001,
    'patch_size': (256, 256, 16),
    'num_workers': 6,
}

hparams = Namespace(**args)

In [12]:
model = MaskGAN(hparams)

In [13]:
NAME = 'models/7-03-2020_MaskGAN/'
logger = pl.loggers.TensorBoardLogger(NAME + "tb_logs/", name='')

# Callbacks
early_stopping = pl.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10
)

checkpoint_callback = pl.callbacks.ModelCheckpoint(filepath=NAME + 'checkpoints/')


trainer = pl.Trainer(
    checkpoint_callback=checkpoint_callback,
    early_stop_callback=early_stopping,
    check_val_every_n_epoch=5,
    gpus=1,
    max_epochs=1000,
    logger=logger,
)

trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type          | Params
------------------------------------------------
0 | generator     | UNet          | 31 M  
1 | discriminator | Discriminator | 466 K 




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layoutâ€¦

RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 5.94 GiB total capacity; 3.49 GiB already allocated; 921.94 MiB free; 4.26 GiB reserved in total by PyTorch)