In [None]:
import itertools
from pathlib import Path
import copy

import cv2
import torch
import torchvision
import torchmetrics
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import seaborn as sns

# Switch matplotlib to seaborn's default styling for every figure drawn below.
sns.set()

In [None]:
# Dataset root and its three split directories.
# NOTE(review): absolute, user-specific path; adjust when running elsewhere.
data_dir = Path("/nfs/home/rafman23/jupyter/FlyingObjectDataset_10K")
training_dir = data_dir / "training"
validation_dir = data_dir / "validation"
testing_dir = data_dir / "testing"

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Autograd anomaly detection is a debugging aid; the PyTorch documentation
# warns that it slows execution, so consider disabling it for long runs.
torch.autograd.set_detect_anomaly(True)

# TensorBoard event files are written under runs/seg.
writer = SummaryWriter("runs/seg")

In [None]:
class SegmentationFlyingObjectsDataset(torch.utils.data.Dataset):
    """Dataset of (image, ground-truth) pairs read from one split directory.

    Every file found in ``<root>/image`` is paired with the file
    ``<root>/gt_image/gt_<same file name>``. Existence of the ground-truth
    file is not checked until the item is actually read.

    Args:
        root: directory of the split (contains ``image`` and ``gt_image``).
        transform: callable invoked as ``transform(image=array)`` that returns
            a mapping with the keys ``"image"`` and ``"replay"`` and exposes
            ``replay(saved, image=array)``; the notebook passes an
            ``albumentations.ReplayCompose``.
    """

    def __init__(self, root, transform):
        super().__init__()
        self.root = Path(root)
        self.transform = transform
        image_dir = self.root.joinpath("image")
        # Sorted so that indices are stable across runs and file systems.
        self.image_paths = sorted(image_dir.glob("*"))
        self.seg_paths = [
            self.root.joinpath("gt_image", p.relative_to(image_dir)).with_name(f"gt_{p.name}")
            for p in self.image_paths
        ]

    def _read_image(self, path):
        """Load the file at ``path`` and return its pixels as a numpy array."""
        # Image.open is lazy and keeps the file open; converting inside the
        # context manager forces the pixel data to load and then closes the
        # handle instead of leaving it to the garbage collector.
        with Image.open(path) as img:
            return np.array(img)

    def __getitem__(self, index):
        """Return ``(image, ground_truth)`` after applying the transform.

        The random parameters sampled for the image are replayed on the
        ground truth so that both receive the same augmentation.
        """
        image = self._read_image(self.image_paths[index])
        seg = self._read_image(self.seg_paths[index])

        transform = self.transform(image=image)

        image_tr = transform["image"]
        seg_tr = self.transform.replay(transform["replay"], image=seg)["image"]
        return image_tr, seg_tr

    def __len__(self):
        """Number of files found in ``<root>/image``."""
        return len(self.image_paths)

In [None]:
# Training pipeline. ReplayCompose records the parameters it samples so the
# dataset can re-apply exactly the same augmentation to the ground truth.
train_transform = A.ReplayCompose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    # NOTE(review): the upper bound of `scale` is above 1.0 -- confirm that the
    # installed albumentations version accepts it. The ground truth is replayed
    # through this same step, so it is resampled with Lanczos interpolation too.
    A.RandomResizedCrop(128, 128, scale=(0.95, 1.05), ratio=(1.0, 1.0), interpolation=cv2.INTER_LANCZOS4),
    # mean=0, std=1: presumably only rescales 8-bit pixels to [0, 1] via
    # albumentations' default max_pixel_value -- TODO confirm.
    A.Normalize(mean=0.0, std=1.0),
    ToTensorV2(),
])
# Evaluation pipeline: no random augmentation, only normalisation and tensor conversion.
test_transform = A.ReplayCompose([
    A.Normalize(mean=0.0, std=1.0),
    ToTensorV2(),
])

In [None]:
# One dataset per split; only the training split uses the augmenting transform.
train_dataset = SegmentationFlyingObjectsDataset(training_dir, train_transform)
valid_dataset = SegmentationFlyingObjectsDataset(validation_dir, test_transform)
test_dataset = SegmentationFlyingObjectsDataset(testing_dir, test_transform)

In [None]:
# Draw the first 100 (augmented) training pairs once, then batch images and
# ground truths into two separate tensors.
sample_pairs = [train_dataset[i] for i in range(100)]
sample_image = torch.stack([image for image, _ in sample_pairs])
sample_seg = torch.stack([seg for _, seg in sample_pairs])

In [None]:
# Tile the sample images 20 per row; permute moves the channel axis last,
# which is the layout imshow expects.
image_grid = torchvision.utils.make_grid(sample_image, nrow=20)
plt.figure(figsize=(20, 10))
plt.imshow(image_grid.permute(1, 2, 0))
plt.show()

In [None]:
# Same layout as the image grid, for the matching ground-truth tensors.
seg_grid = torchvision.utils.make_grid(sample_seg, nrow=20)
plt.figure(figsize=(20, 10))
plt.imshow(seg_grid.permute(1, 2, 0))
plt.show()

In [None]:
# Stack along a new axis 1 and flatten it into the batch axis, which interleaves
# the samples as image0, seg0, image1, seg1, ... (each 3 x 128 x 128).
sample_image_seg = torch.stack([sample_image, sample_seg], dim=1).view(-1, 3, 128, 128)

In [None]:
# Show the first 100 interleaved tiles (50 image / ground-truth pairs), 10 per
# row, so every row holds five pairs side by side.
# The former `idx = np.random.choice(np.arange(len(x)), ...)` line was removed:
# `x` is not defined anywhere in the notebook (NameError on a fresh kernel) and
# `idx` was never used.
plt.figure(figsize=(10, 50))
plt.axis("off")
plt.imshow(torchvision.utils.make_grid(sample_image_seg[:100], nrow=10, pad_value=0.5, padding=10).permute(1, 2, 0))
plt.show()

In [None]:
batch_size = 32

# Only the training loader shuffles; validation and test iterate in dataset order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

In [None]:
class UnetBlock(torch.nn.Module):
    """One U-Net stage: two convolutions with optional down/up-sampling.

    Layer order inside ``self.main``:
    ``[MaxPool2d] -> Conv2d -> LeakyReLU -> BatchNorm2d -> Conv2d -> LeakyReLU
    -> [ConvTranspose2d]``.

    Args:
        input_channels: channels of the incoming tensor.
        output_channels: channels produced by both convolutions.
        kernel_size: convolution kernel size; padding is ``(kernel_size - 1) // 2``.
        predown: if true, halve the spatial size with a 2x2 max-pool first.
        postup: if true, finish with a stride-2 transposed convolution that
            doubles the spatial size and halves the channel count.
    """

    def __init__(self, input_channels, output_channels, kernel_size=3, predown=False, postup=False):
        super().__init__()
        pad = (kernel_size - 1) // 2

        layers = []
        if predown:
            layers.append(torch.nn.MaxPool2d(2, 2))

        # Created before the convolutions so that parameter initialisation
        # draws from the random generator in the same order as before.
        upsample = None
        if postup:
            upsample = torch.nn.ConvTranspose2d(output_channels, output_channels // 2, 2, 2)

        layers.extend([
            torch.nn.Conv2d(input_channels, output_channels, kernel_size, 1, pad),
            torch.nn.LeakyReLU(0.2),
            torch.nn.BatchNorm2d(output_channels),
            torch.nn.Conv2d(output_channels, output_channels, kernel_size, 1, pad),
            torch.nn.LeakyReLU(0.2),
        ])
        if upsample is not None:
            layers.append(upsample)

        self.main = torch.nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the stage's layer stack."""
        return self.main(input)
            

In [None]:
class Unet(torch.nn.Module):
    """U-Net with five encoder stages, a bottleneck and five decoder stages.

    Each decoder stage receives the channel-wise concatenation of the matching
    encoder output (skip connection) and the upsampled output of the stage
    below. The result is passed through a sigmoid, so outputs lie in (0, 1).

    The network pools five times, so for the skip connections to line up the
    input height and width must be multiples of 32.

    Args:
        n_features: channels of the first encoder stage; doubled at each
            deeper stage.
        n_outputs: number of output channels.
        n_inputs: number of input channels. Defaults to 3, the value that was
            previously hard-coded.
    """

    def __init__(self, n_features=64, n_outputs=3, n_inputs=3):
        super().__init__()
        self.downblock1 = UnetBlock(n_inputs, n_features)
        self.downblock2 = UnetBlock(n_features, n_features * 2, predown=True)
        self.downblock3 = UnetBlock(n_features * 2, n_features * 4, predown=True)
        self.downblock4 = UnetBlock(n_features * 4, n_features * 8, predown=True)
        self.downblock5 = UnetBlock(n_features * 8, n_features * 16, predown=True)

        # 1x1 convolutions at the lowest resolution; the transposed convolution
        # brings the result back to downblock5's resolution and channel count.
        self.bottom = UnetBlock(n_features * 16, n_features * 32, kernel_size=1, predown=True, postup=True)

        # Input channels are doubled relative to the stage below because of
        # the concatenated skip connection.
        self.upblock5 = UnetBlock(n_features * 32, n_features * 16, postup=True)
        self.upblock4 = UnetBlock(n_features * 16, n_features * 8, postup=True)
        self.upblock3 = UnetBlock(n_features * 8, n_features * 4, postup=True)
        self.upblock2 = UnetBlock(n_features * 4, n_features * 2, postup=True)
        self.upblock1 = UnetBlock(n_features * 2, n_outputs)

    def forward(self, input):
        """Return the sigmoid-activated prediction for a batch ``input``."""
        downblock1 = self.downblock1(input)
        downblock2 = self.downblock2(downblock1)
        downblock3 = self.downblock3(downblock2)
        downblock4 = self.downblock4(downblock3)
        downblock5 = self.downblock5(downblock4)

        bottom = self.bottom(downblock5)

        upblock5 = self.upblock5(torch.cat([downblock5, bottom], dim=1))
        upblock4 = self.upblock4(torch.cat([downblock4, upblock5], dim=1))
        upblock3 = self.upblock3(torch.cat([downblock3, upblock4], dim=1))
        upblock2 = self.upblock2(torch.cat([downblock2, upblock3], dim=1))
        upblock1 = self.upblock1(torch.cat([downblock1, upblock2], dim=1))
        # torch.sigmoid replaces torch.nn.functional.sigmoid, which some
        # PyTorch releases report as deprecated; the result is the same.
        return torch.sigmoid(upblock1)

In [None]:
def train_epoch(
    epoch,
    optimizer:torch.optim.Optimizer, 
    loss_fn: torch.nn.Module, 
    model: torch.nn.Module, 
    train_loader: torch.utils.data.DataLoader,
    writer,
):
    """Train ``model`` for one pass over ``train_loader``.

    Each batch loss is logged to ``writer`` as ``training_loss_step`` at the
    global step ``idx + epoch * len(train_loader)``.

    Returns:
        The epoch loss averaged per sample (batch losses weighted by batch size).
    """
    model.train()
    loss_sum = 0
    seen = 0

    for step, (inputs, labels) in enumerate(tqdm(train_loader)):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_items = len(inputs)
        loss_sum += batch_loss * batch_items
        seen += batch_items

        global_step = step + epoch * len(train_loader)
        writer.add_scalar("training_loss_step", batch_loss, global_step)

    return loss_sum / seen

def validate_epoch(
    loss_fn: torch.nn.Module, 
    model: torch.nn.Module, 
    val_loader: torch.utils.data.DataLoader
):
    """Compute the per-sample average of ``loss_fn`` over ``val_loader``.

    The model is switched to evaluation mode and gradients are disabled.
    """
    model.eval()
    loss_sum = 0
    seen = 0

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            batch_items = len(inputs)
            batch_loss = loss_fn(model(inputs), labels).item()

            # Weight by batch size so a smaller final batch does not skew the mean.
            loss_sum += batch_loss * batch_items
            seen += batch_items

    return loss_sum / seen

def training_loop(num_epoch, writer, model, optimizer, loss_fn, train_loader, val_loader, model_path):
    """Train for ``num_epoch`` epochs and keep the model with the lowest validation loss.

    After every epoch the training and validation losses are logged to
    ``writer``. Whenever the validation loss improves, the whole model is
    saved to ``model_path`` and a deep copy is kept in memory.

    Returns:
        ``(best_model, train_losses, val_losses)``. ``best_model`` is ``None``
        if no epoch ever improved on the initial infinite loss (for example
        when ``num_epoch`` is 0).
    """
    train_losses, val_losses = [], []
    best_model, best_val_loss = None, np.inf

    for epoch in range(num_epoch):
        train_loss = train_epoch(epoch, optimizer, loss_fn, model, train_loader, writer)
        train_losses.append(train_loss)
        writer.add_scalar("training_loss_epoch", train_loss, epoch)

        val_loss = validate_epoch(loss_fn, model, val_loader)
        val_losses.append(val_loss)
        writer.add_scalar("validation_loss_epoch", val_loss, epoch)

        improved = val_loss < best_val_loss
        if improved:
            # Checkpoint on disk and snapshot in memory, then raise the bar.
            torch.save(model, model_path)
            best_model = copy.deepcopy(model)
            best_val_loss = val_loss

        print(f"epoch {epoch + 1}: loss: {train_loss:0.4f} val loss: {val_loss:0.4f}")

    return best_model, train_losses, val_losses

def predict(model: torch.nn.Module, test_loader: torch.utils.data.DataLoader):
    """Run ``model`` over ``test_loader`` and collect targets and predictions.

    Args:
        model: network to evaluate; it is switched to evaluation mode.
        test_loader: loader yielding ``(inputs, labels)`` batches.

    Returns:
        ``(true, pred)``: CPU tensors with all labels and all model outputs
        concatenated along the batch axis, in loader order.
    """
    # Evaluation mode, as in validate_epoch/evaluate: without it, BatchNorm
    # layers of a model left in training mode would normalise with batch
    # statistics and update their running statistics during inference.
    model.eval()

    true = []
    pred = []
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader):
            outputs = model(inputs.to(device))

            # Move each batch to the CPU right away so the accelerator never
            # has to hold the predictions for the whole dataset at once.
            true.append(labels.detach().cpu())
            pred.append(outputs.detach().cpu())

    return torch.cat(true), torch.cat(pred)

In [None]:
# Build the network with its default sizes and move it to the selected device.
unet = Unet().to(device)
# Adam over all network parameters.
# NOTE(review): lr=0.01 is ten times Adam's default of 1e-3 -- confirm it is intended.
optimizer = torch.optim.Adam(unet.parameters(), lr=0.01)

def dice_loss(pred, true, smooth=1.0):
    """Soft Dice loss, ``1 - Dice coefficient``, averaged over channels and batch.

    Both tensors are reduced over axes 2 and 3, so they are expected to have
    at least four dimensions (batch, channel and two spatial axes).

    Args:
        pred: predicted values.
        true: target values, same shape as ``pred``.
        smooth: constant added to numerator and denominator; it keeps the
            ratio defined when both tensors are all zero.

    Returns:
        Scalar tensor with the mean loss.
    """
    spatial_axes = [2, 3]
    overlap = torch.sum(pred * true, dim=spatial_axes)
    # Sum of both masses (not a set union): the Dice denominator.
    total_mass = torch.sum(pred, dim=spatial_axes) + torch.sum(true, dim=spatial_axes)

    dice_per_channel = (2.0 * overlap + smooth) / (total_mass + smooth)
    loss_per_sample = 1 - dice_per_channel.mean(dim=1)
    return loss_per_sample.mean()

def calculate_loss(pred, true, bce_weight=0.5):
    """Weighted sum of Dice loss and binary cross-entropy.

    Args:
        pred: predicted probabilities in [0, 1] (binary cross-entropy requires this).
        true: targets in [0, 1], same shape as ``pred``.
        bce_weight: weight of the cross-entropy term; the Dice term gets
            ``1 - bce_weight``.

    Returns:
        Scalar tensor ``dice * (1 - bce_weight) + bce * bce_weight``.
    """
    dice = dice_loss(pred, true)
    bce = torch.nn.functional.binary_cross_entropy(pred, true)
    # The second, unreachable `return` of the plain BCE loss that used to
    # follow this line was dead code and has been removed.
    return dice * (1 - bce_weight) + bce * bce_weight
    
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()
# Log the network graph to TensorBoard, using the sample images as example input.
writer.add_graph(unet, sample_image.to(device))

In [None]:
# Train for 200 epochs; the best model (lowest validation loss) is also
# written to best_model.pth.
best_model, training_loss, validation_loss = training_loop(
    num_epoch=200,
    writer=writer,
    model=unet,
    optimizer=optimizer,
    loss_fn=calculate_loss,
    train_loader=train_loader,
    val_loader=valid_loader,
    model_path="best_model.pth",
)

In [None]:
# Per-epoch learning curves for both splits on a single axis.
plt.figure(figsize=(15, 7.5))
for curve, curve_label in ((training_loss, "training"), (validation_loss, "validation")):
    plt.plot(curve, label=curve_label)
plt.title("0.5 Dice + 0.5 BCE Loss")
plt.legend()
plt.show()

In [None]:
# Combined Dice + BCE loss of the best model on the held-out test split.
test_loss = validate_epoch(calculate_loss, best_model, test_loader)
test_loss

In [None]:
def evaluate(
    metrics, 
    model: torch.nn.Module, 
    loader: torch.utils.data.DataLoader
):
    """Average several metrics of ``model`` over ``loader``.

    Args:
        metrics: mapping of metric name to a callable ``fn(outputs, labels)``
            returning a scalar tensor.
        model: network to evaluate; it is switched to evaluation mode.
        loader: loader yielding ``(inputs, labels)`` batches.

    Returns:
        Dict with the same keys as ``metrics`` holding per-sample averages
        (batch values weighted by batch size).
    """
    sums = dict.fromkeys(metrics.keys(), 0)
    seen = 0
    model.eval()

    with torch.no_grad():
        for inputs, labels in tqdm(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            batch_items = len(inputs)
            for name in sums:
                sums[name] += metrics[name](outputs, labels).item() * batch_items
            seen += batch_items

    # Turn the weighted sums into means, in place.
    for name in sums:
        sums[name] /= seen

    return sums

In [None]:
# Metric name -> callable(pred, true) returning a scalar tensor.
metrics = {
    "dice": dice_loss,
    "bce": torch.nn.functional.binary_cross_entropy,
    "mse": torch.nn.functional.mse_loss,
    # Pixel accuracy with BOTH tensors thresholded at 0.5. Comparing the
    # boolean prediction mask directly to the float target only matches where
    # the target is exactly 0.0 or 1.0, so any rounding or interpolated value
    # in the target would be counted as an error. For exactly binary targets
    # the result is unchanged.
    "acc": lambda p, t: ((p > 0.5) == (t > 0.5)).float().mean()
}

In [None]:
# Re-read the training split WITHOUT augmentation so its metrics are
# comparable with the validation and test numbers.
train_eval_dataset = SegmentationFlyingObjectsDataset(training_dir, test_transform)
train_eval_loader = torch.utils.data.DataLoader(train_eval_dataset, batch_size=batch_size)
evaluate(metrics, best_model, train_eval_loader)

In [None]:
# Metrics of the best model on the validation split.
evaluate(metrics, best_model, valid_loader)

In [None]:
# Metrics of the best model on the test split.
evaluate(metrics, best_model, test_loader)

In [None]:
# Targets and predictions for the whole test split, as CPU tensors.
test_true, test_pred = predict(best_model, test_loader)
# Strip of ones, two entries wide along the last axis, used as a separator.
padding = torch.ones(*test_true.shape[:-1], 2)
# Side-by-side tiles along axis 3: prediction, separator, ground truth.
test_true_pred = torch.cat([test_pred, padding, test_true], dim=3)

In [None]:
# Sum of squared errors per test sample (reduced over axes 1, 2 and 3).
error = ((test_true - test_pred) ** 2).sum(dim=[1,2,3])
# The 25 samples with the largest error.
error_values, error_indices = error.topk(25)
# The 100 samples with the smallest error (top-k of the negated error).
good_values, good_indices = (-error).topk(100)

In [None]:
# The 25 test samples with the largest squared error; each tile is taken from
# test_true_pred (prediction, separator, ground truth side by side).
# The undefined name `x` (NameError on a fresh kernel) was replaced by
# test_true_pred, and the unused `idx` line that also read `x` was removed.
plt.figure(figsize=(15, 50))
plt.title("Bad Results")
plt.axis("off")
plt.imshow(torchvision.utils.make_grid(test_true_pred[error_indices], nrow=5, pad_value=1.0, padding=10).permute(1, 2, 0))
plt.show()

In [None]:
# The last 25 entries of good_indices (the 100 lowest-error test samples);
# each tile is taken from test_true_pred (prediction, separator, ground truth
# side by side).
# The undefined name `x` (NameError on a fresh kernel) was replaced by
# test_true_pred, and the unused `idx` line that also read `x` was removed.
plt.figure(figsize=(15, 50))
plt.title("Good Results")
plt.axis("off")
plt.imshow(torchvision.utils.make_grid(test_true_pred[good_indices[-25:]], nrow=5, pad_value=1.0, padding=10).permute(1, 2, 0))
plt.show()