In [2]:
from pathlib import Path
from PIL import Image
from IPython.display import display
import pandas as pd

jpeg_p = Path("jpegs_small")
mask_p = Path("masks_small")

df = pd.DataFrame(columns=['masks', 'label'])

X = []
Y = []

dt = 4 # Past images
dl = 1 # Future images to predict

for folder in mask_p.glob("*/*"):
    imgs = list(folder.glob("*"))
    for i in range(len(imgs) - dt - dl + 1):
        # img = imgs[img]
        # print(img)
        # img_mask = mask_p / Path(*img.parts[1:])
        # print(img_mask)
        # display(Image.open(img))
        # display(Image.open(img_mask))
        x = ",".join(str(img) for img in imgs[i: i+dt])
        y = imgs[i + dt + dl - 1]
        X.append(x)
        Y.append(y)

new_row = pd.DataFrame({'masks': X, 'label': Y})    
df = pd.concat([df, new_row], ignore_index=True)
df.to_csv('training.csv', index=False)

In [6]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
import albumentations as a
from pathlib import Path
from sklearn.model_selection import train_test_split
import numpy as np
import torch

df = pd.read_csv('../training.csv')


trainval, test = train_test_split(df, random_state=42, test_size=0.2)
train, validation = train_test_split(df, random_state=42, test_size=0.25)

class CustomDataset(Dataset):
    def __init__(self, df, transforms):
        self.df = df
        self.transforms = transforms

    def _load_image(self, path):
        return np.asarray(Image.open(path))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, item):
        inputs, label = self.df.iloc[item]
        inputs = [Path("../") / Path(p) for p in inputs.split(",")]
        label = self._load_image(Path("../") / Path(label))
        inputs = [self._load_image(img) for img in inputs]
        # print(np.unique(inputs[0]))
        if self.transforms:
            inputs = [self.transforms(image = img)["image"] / 255 for img in inputs]
            label = self.transforms(image = label)["image"] / 255
        # print(label.squeeze().shape)
        # print(np.unique(label))
        # print(torch.stack(inputs, dim=0).squeeze(1).shape)

        return torch.stack(inputs, dim=0).squeeze(1), label


train_transform = a.Compose([
        a.HorizontalFlip(),
        a.VerticalFlip(),
        a.RandomRotate90(),
        a.Transpose(),
        # a.Normalize(
        #     mean=[0.485, 0.456, 0.406],
        #     std=[0.229, 0.224, 0.225],
        #     max_pixel_value=255
        # ),
        a.ToTensorV2()
    ])

val_transform = a.Compose([
#         a.Normalize(
#             mean=[0.485, 0.456, 0.406],
#             std=[0.229, 0.224, 0.225],
#             max_pixel_value=255
#         ),
        a.ToTensorV2()
    ])


ds = CustomDataset(train, train_transform)
test1, test2 = ds[0]

training_ds = CustomDataset(train, train_transform)
test_ds = CustomDataset(test, train_transform)
val_ds = CustomDataset(validation, train_transform)

print(test1.shape)
print(test2.shape)

print(len(training_ds))

torch.Size([4, 224, 224])
torch.Size([1, 224, 224])
2033


In [13]:
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.double_conv(x)

# Sanity check
test_input = torch.rand((1, 1, 224, 224))
dc = DoubleConv(1, 2)
print(dc(test_input).shape)


class Downward(nn.Module):
    def __init__(self):
        super(Downward, self).__init__()
        self.conv1 = DoubleConv(1, 16)
        self.conv2 = DoubleConv(16, 32)
        self.conv3 = DoubleConv(32, 64)
        self.conv4 = DoubleConv(64, 128)
        self.maxpool = nn.MaxPool2d(2)
        # self.fc = nn.Linear(128*28*28, 768)
        self.features = {}

    def forward(self, x):
        # B, 1, 224, 224
        x = self.conv1(x)
        self.features['conv1'] = x
        # B, 16, 224, 224
        x = self.maxpool(x)
        # B, 16, 112, 112
        x = self.conv2(x)
        self.features['conv2'] = x
        # B, 32, 112, 112
        x = self.maxpool(x)
        # B, 32, 56, 56
        x = self.conv3(x)
        self.features['conv3'] = x
        # B, 64, 56, 56
        x = self.maxpool(x)
        # B, 64, 28, 28
        x = self.conv4(x)
        # B, 128, 28, 28
        
        # B, _, _, _ = x.shape
        # return self.fc(x.reshape(B, -1)) # [B, 768]
        return x

test_input = torch.rand((6, 1, 224, 224))
dc = Downward()
y = dc(test_input)
print(y.shape)

# print(dc.features['conv2'].shape)

torch.Size([1, 2, 224, 224])
torch.Size([6, 128, 28, 28])


In [15]:
class Upward(nn.Module):
    def __init__(self):
        super(Upward, self).__init__()
        # self.fc = nn.Linear(768, 128*28*28)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.up3 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)

        self.conv1 = DoubleConv(128, 64)
        self.conv2 = DoubleConv(64, 32)
        self.conv3 = DoubleConv(32, 16)

        self.out = nn.Conv2d(16, 1, kernel_size=3, padding=1, bias=False)

        self.sigmoid = nn.Sigmoid()
        

    def forward(self, x, x_res):
        # x_res contains intermediate conv faetures conv1, conv2, conv3
        
        # # B, 768
        # x = self.fc(x)
        # B, _ = x.shape
        # x = x.reshape(B, 128, 28, 28)
        
        # B, 128, 28, 28
        x = self.up1(x)
        # B, 64, 56, 56
        x = torch.cat((x, x_res['conv3']), dim=1)
        # B, 128, 56, 56
        x = self.conv1(x)
        # B, 64, 56, 56
        x = self.up2(x)
        # B, 32, 112, 112
        x = torch.cat((x, x_res['conv2']), dim=1)
        # B, 64, 112, 112
        x = self.conv2(x)
        # B, 32, 112, 112
        x = self.up3(x)
        # B, 16, 224, 224
        x = torch.cat((x, x_res['conv1']), dim=1)
        # B, 32, 224, 224
        x = self.conv3(x)
        # B, 16, 224, 224

        x = self.out(x) # B, 1, 224, 224

        return self.sigmoid(x)

# Sanity check
# test = torch.rand((6, 768))
# net = Upward()
# probs = net(test, dc.features)
# y = torch.round(probs)
# print(y.shape)
# print(torch.unique(y))
# to_pil = transforms.ToPILImage()
# img = to_pil(y[0, :, :, :].squeeze(0))
# display(img)

In [16]:
class TestCustomUNet(nn.Module):
    def __init__(self):
        super(TestCustomUNet, self).__init__()
        self.up = Upward()
        self.down = Downward()

    def forward(self, x):
        x = self.down(x)
        x = self.up(x, self.down.features)
        return x

c = TestCustomUNet()
tst = torch.randn(16, 1, 224, 224)
y  = c(tst)
print(y.shape)

torch.Size([16, 1, 224, 224])


In [11]:
class IoULoss(nn.Module):
    def __init__(self):
        super(IoULoss, self).__init__()

    def forward(self, x, target):
        B, _, _, _ = x.shape
        intersection = (x * target).sum((2, 3))
        union = x.sum((2, 3)) + target.sum((2, 3)) - intersection
        iou = (intersection + 1e-6) / (union + 1e-6)
        return 1 - iou.mean()
        