In [None]:
import torch
from torch.utils.data import Dataset
import glob
import os
from PIL import Image
from torch.utils.data import DataLoader
import numpy as np
import wandb

from pl_bolts.models.vision import UNet
import pytorch_lightning as pl
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch import nn
import matplotlib.pyplot as plt

# Make CUDA kernel launches synchronous so device-side errors surface at the
# offending call. NOTE(review): this is a debugging aid and slows GPU training;
# consider dropping it once the pipeline is stable.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

In [None]:
# Expected layout: each dataset is extracted under /content/<name>/ with one
# folder of PIV inputs and one folder of masks, e.g.
#   /content/bird/piv/*   /content/bird/mask/*
#   /content/fish/piv/*   /content/fish/mask/*
# Add more names to `data` to train on several datasets at once.

data = ['bird']

# Inputs and masks are paired purely by position after sorting, so both
# folders must use file names that sort into the same order.
piv = sorted(
    path
    for d in data
    for path in glob.glob(os.path.join("/content", d, "piv", "*"))
)
mask = sorted(
    path
    for d in data
    for path in glob.glob(os.path.join("/content", d, "mask", "*"))
)

assert len(piv) == len(mask), (
    f"found {len(piv)} PIV images but {len(mask)} masks; they are paired by index"
)

print(piv[:5])
print(mask[:5])

print(len(piv))

In [None]:
class DatasetSegmentation(Dataset):
    """Paired grayscale image / mask dataset for segmentation.

    ``img_files[i]`` is paired with ``mask_files[i]``; pairing is purely
    positional.

    Args:
        image_path: A list of image file paths (used as-is, order preserved),
            or a directory whose ``*.png`` files are used in sorted order.
        mask_path: Same as ``image_path``, for the masks.
        transform: Optional callable applied to each loaded image array.
        target_transform: Optional callable applied to each loaded mask array.

    Raises:
        ValueError: If the number of images and masks differ.
    """

    def __init__(self, image_path, mask_path, transform=None, target_transform=None):
        super().__init__()

        self.img_files = self._resolve_files(image_path)
        self.mask_files = self._resolve_files(mask_path)
        if len(self.img_files) != len(self.mask_files):
            raise ValueError(
                f"got {len(self.img_files)} images but {len(self.mask_files)} masks"
            )

        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _resolve_files(path):
        """Return ``path`` if it is a list, else the sorted ``*.png`` files inside it."""
        if isinstance(path, list):
            return path
        # glob order is filesystem-dependent; sort so images and masks line up.
        return sorted(glob.glob(os.path.join(path, '*.png')))

    @staticmethod
    def _load_gray(path):
        """Load an image file converted to PIL "L" (grayscale) as a writable array."""
        # np.array copies, so the result is writable (np.asarray on a PIL image
        # can be read-only, which makes torch warn in ToTensor). The context
        # manager closes the file handle promptly.
        with Image.open(path) as im:
            return np.array(im.convert("L"))

    def __getitem__(self, index):
        data = self._load_gray(self.img_files[index])
        label = self._load_gray(self.mask_files[index])

        if self.transform is not None:
            data = self.transform(data)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return data, label

    def __len__(self):
        return len(self.img_files)

In [None]:
from torch.utils.data import random_split
from sklearn.model_selection import train_test_split

class SegmentationDatamodule(pl.LightningDataModule):
    """Lightning data module that splits paired image/mask files 60/20/20.

    Args:
        image_path: A list of image paths (order preserved), or a directory
            whose ``*.png`` files are used in sorted order.
        mask_path: Same as ``image_path``, for the masks.
        batch_size: Batch size used by all three data loaders.
        **kwargs: Forwarded to ``DatasetSegmentation`` (e.g. ``transform``,
            ``target_transform``).

    Raises:
        ValueError: If the number of images and masks differ.
    """

    def __init__(self, image_path, mask_path, batch_size: int = 32, **kwargs):
        super().__init__()
        self.img_files = self._resolve_files(image_path)
        self.mask_files = self._resolve_files(mask_path)
        if len(self.img_files) != len(self.mask_files):
            raise ValueError(
                f"got {len(self.img_files)} images but {len(self.mask_files)} masks"
            )

        self.batch_size = batch_size
        self.kwargs = kwargs

        # NOTE(review): not read anywhere in this notebook.
        self._has_setup_train_only = True

    @staticmethod
    def _resolve_files(path):
        """Return ``path`` if it is a list, else the sorted ``*.png`` files inside it."""
        if isinstance(path, list):
            return path
        # glob order is filesystem-dependent; sort so images and masks line up.
        return sorted(glob.glob(os.path.join(path, '*.png')))

    def setup(self, stage=None):
        """Deterministically split the files into train (60%), val (20%), test (20%)."""
        SEED = 0
        indices = np.arange(0, len(self.img_files))

        # RANDOM SPLIT 60/20/20: hold out 40%, then halve it into val and test.
        train_indices, val_test_indices = train_test_split(
            indices, test_size=.4, random_state=SEED,
        )
        val_indices, test_indices = train_test_split(
            val_test_indices, test_size=.5, random_state=SEED,
        )

        self.train_data, self.train_mask = self._subset_paths(train_indices)
        self.val_data, self.val_mask = self._subset_paths(val_indices)
        self.test_data, self.test_mask = self._subset_paths(test_indices)

    def _subset_paths(self, indices):
        """Return the (image paths, mask paths) lists selected by ``indices``."""
        images = [self.img_files[i] for i in indices]
        masks = [self.mask_files[i] for i in indices]
        return images, masks

    def _make_loader(self, images, masks, shuffle):
        """Wrap the given paths in a ``DatasetSegmentation`` and a ``DataLoader``."""
        dataset = DatasetSegmentation(images, masks, **self.kwargs)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        return self._make_loader(self.train_data, self.train_mask, shuffle=True)

    def val_dataloader(self):
        # Uses the validation split (previously this returned the test split).
        return self._make_loader(self.val_data, self.val_mask, shuffle=False)

    def test_dataloader(self):
        # Uses the test split (previously this returned the validation split).
        return self._make_loader(self.test_data, self.test_mask, shuffle=False)

In [None]:
# Input transform: HxW uint8 array -> float tensor in [0, 1], resized to 256x256.
t = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
])


def mask_to_one_hot(mask_tensor):
    """Turn a (1, H, W) float mask tensor into a (2, H, W) one-hot tensor."""
    # .long() truncates: after ToTensor's /255 scaling only pixels equal to 1.0
    # become class 1. NOTE(review): assumes masks are stored as 0/255 — confirm.
    labels = mask_tensor.long()
    return F.one_hot(labels, 2).permute(0, 3, 1, 2).squeeze(0)


# Target transform: resize with NEAREST so every pixel keeps an exact label.
# Bilinear resizing (the default) blends boundary pixels into fractions that
# the .long() truncation then turns into background, eroding mask edges.
tt = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
    mask_to_one_hot,
])

dm = SegmentationDatamodule(
    piv, mask,
    transform=t, target_transform=tt, batch_size=16,
)

dm.setup()

# Sanity-check one training batch: shapes, value ranges, dtype and device.
train_dl = dm.train_dataloader()
x, y = next(iter(train_dl))
print(x.shape)
print(y.shape)

print(torch.max(x))
print(torch.min(x))

print(torch.max(y))
print(torch.min(y))

print(y.dtype)
print(x.device)

# Show one input/mask pair; clamp the index in case the batch is short.
i = min(15, x.shape[0] - 1)
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

input_img = np.uint8(x[i, 0].detach().cpu().numpy() * 255)
axs[0].imshow(input_img, cmap="gray", vmin=0, vmax=255)
axs[0].set_title("input")

mask_img = np.uint8(np.argmax(y[i].detach().cpu().numpy(), axis=0) * 255)
axs[1].imshow(mask_img, cmap="gray", vmin=0, vmax=255)
axs[1].set_title("mask")
plt.show()

In [None]:
# https://github.com/jvanvugt/pytorch-unet
# https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/models/vision/unet.py

class MyUnet(pl.LightningModule):
    """Lightning wrapper around pl_bolts' ``UNet`` trained with cross-entropy.

    Args:
        unet_args: Keyword arguments forwarded to ``UNet`` (e.g.
            ``num_classes``, ``input_channels``). ``None`` means no arguments.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, unet_args=None, lr=0.001):
        super().__init__()
        # Record the architecture arguments as well as lr, so that
        # load_from_checkpoint can rebuild the same network.
        self.save_hyperparameters('unet_args', 'lr')

        self.unet = UNet(**(unet_args or {}))
        self.learning_rate = lr
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.unet(x)

    def _shared_step(self, batch, stage):
        """Compute the cross-entropy loss for a batch and log it as ``"<stage>/loss"``."""
        x, y = batch
        y_pred = self(x)
        # The notebook's target transform yields one-hot masks stacked as
        # (B, C, H, W); cross-entropy expects class indices of shape (B, H, W).
        target = torch.argmax(y, dim=1)
        loss = self.loss(y_pred, target)
        self.log(f"{stage}/loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:
from pytorch_lightning.loggers import WandbLogger

unet_args = {
    'num_classes': 2,
    'input_channels': 1,
    'num_layers': 3,
    'features_start': 4,
}
epoch = 250

# Seed weight initialisation and data shuffling so reruns are reproducible.
pl.seed_everything(0)

model = MyUnet(unet_args=unet_args)

# Overwrite rather than append, so re-running the cell does not accumulate
# duplicate dumps; the `with` block closes the file handle.
with open("unet_layers.txt", "w") as layers_file:
    print(model, file=layers_file)

# NOTE(review): WandbLogger is imported but not passed to the Trainer, so
# nothing is sent to Weights & Biases. `gpus=` is deprecated in newer Lightning
# releases (use accelerator="gpu", devices=1); kept for the installed version.
trainer = pl.Trainer(max_epochs=epoch, gpus=1, log_every_n_steps=1)
trainer.fit(model, dm)
trainer.test(model, dm)

In [None]:
# Run the trained model over the entire test split and stack inputs, targets
# and logits into numpy arrays for evaluation.
model.eval()

inputs, targets, logits = [], [], []
with torch.no_grad():
    for xx, yy in dm.test_dataloader():
        # Send each batch to wherever the model's parameters live (CPU or GPU).
        logits.append(model(xx.to(model.device)).cpu())
        inputs.append(xx)
        targets.append(yy)

assert logits, "test split is empty; nothing to evaluate"

# Concatenate once at the end instead of re-growing tensors each iteration.
x = torch.cat(inputs, dim=0).cpu().numpy()
y = torch.cat(targets, dim=0).cpu().numpy()
out = torch.cat(logits, dim=0).numpy()

print(np.shape(x))
print(np.shape(out))

In [None]:
def iou(actual, predicted):
    """Intersection-over-union of two binary masks.

    Both inputs are cast to bool (non-zero = foreground) and compared
    element-wise over the whole array. Stacked masks therefore yield a single
    pooled IoU.

    Args:
        actual: Ground-truth mask (array-like).
        predicted: Predicted mask (array-like), same shape as ``actual``.

    Returns:
        float: ``|actual AND predicted| / |actual OR predicted|``. Returns 1.0
        when both masks are empty (perfect agreement), instead of 0/0 = nan.

    Raises:
        ValueError: If the shapes differ; they would otherwise broadcast silently.
    """
    actual = np.asarray(actual, dtype=bool)
    predicted = np.asarray(predicted, dtype=bool)
    if actual.shape != predicted.shape:
        raise ValueError(f"shape mismatch: {actual.shape} vs {predicted.shape}")

    intersection = np.logical_and(actual, predicted).sum()
    union = np.logical_or(actual, predicted).sum()
    if union == 0:
        return 1.0
    return float(intersection) / float(union)

# Collapse the channel axis (axis 1) of the one-hot targets and of the logits
# into per-pixel class indices, then score foreground overlap over the test set.
iou_res = iou(
    actual=np.argmax(y, axis=1),
    predicted=np.argmax(out, axis=1),
)
print(iou_res)