In [1]:
import numpy as np
import os
import cv2 as cv
from typing import Iterable
import matplotlib.pyplot as plt
import glob
import time

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
import torchvision.transforms as T
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils import data
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
from models import MixtureOfExpertsEncoder, Traditional2dSegmenter, SegmentationDecoder
from models.utils.util_classes import SplitTensor, AddInQuadrature, DepthSum, ConvWH, ConvDW, ConvDH

dev = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(dev)
print(device)

cuda


In [2]:
class FftLayer(nn.Module):
    def __init__(self, kernel_size, stride, padding, num_orders=4, inverse=False):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.num_orders = num_orders

In [3]:
root_directory = '/media/polarbear/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/'
image_directory = os.path.join(root_directory, 'Image')
mask_directory = os.path.join(root_directory, 'Mask')
image_paths = glob.glob(image_directory + '/*.jpg')
mask_paths = glob.glob(mask_directory + '/*.png')
print(image_paths)
print(mask_paths)

['/media/polarbear/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/0.jpg', '/media/polarbear/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/1.jpg', '/media/polarbear/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/10.jpg', '/media/polarbear/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/1000.jpg', '/media/polarbear/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/1001.jpg', '/media/polarbear/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/1002.jpg', '/media/polarbear/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/1003.jpg', '/media/polarbear/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/1004.jpg', '/media/polarbear/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/1005.jpg', '/media/polarbear/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/1006.jpg', '/media/polarbear/BigBertha/Datasets/

In [4]:
im_transform = T.Compose([
    T.ToTensor(),
    T.Resize((256, 256), antialias=False),
])

def to_image(inp):
    if inp.ndim == 3:
        out = inp.numpy()
        out = np.transpose(out, (1, 2, 0))
    else:
        out = inp.numpy()
        out = np.transpose(out, (0, 2, 3, 1))
    return out

def calculate_output_size(input_size, kernel_size, stride, padding):
    return (input_size - 1) * stride - 2 * padding + (kernel_size - 1) + 1

In [5]:
class FloodZoneDataset(data.Dataset):
    def __init__(self, root_directory, im_transform=im_transform):
        self.image_directory = os.path.join(root_directory, 'Image')
        self.mask_directory = os.path.join(root_directory, 'Mask')
        self.im_transform = im_transform
        self.image_paths = sorted(glob.glob(os.path.join(self.image_directory, '*.jpg')))
        self.mask_paths = sorted(glob.glob(os.path.join(self.mask_directory, '*.png')))
    
    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]
        image = cv.imread(image_path)
        if image is not None:
            image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
        else:
            image = np.zeros((256, 256, 3), dtype=np.uint8)
        mask = cv.imread(mask_path, cv.IMREAD_GRAYSCALE)
        image = self.im_transform(image)
        mask = self.im_transform(mask)
        return image, 1.0 - mask
    
    def __len__(self):
        return len(self.image_paths)

In [6]:
dataset = FloodZoneDataset(root_directory)
dataloader = data.DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)

X, y = next(iter(dataloader))
print(X.shape, y.shape)

torch.Size([4, 3, 256, 256]) torch.Size([4, 1, 256, 256])


In [7]:
class TestNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = MixtureOfExpertsEncoder.MixtureOfExpertsSegmentationEncoder()
        self.decoder = SegmentationDecoder.Decoder()
    
    def forward(self, X):
        out = self.encoder(X)
        out = self.decoder(out)
        return out
    
class ControlNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Traditional2dSegmenter.ControlSegmentationModel()
        self.decoder = SegmentationDecoder.Decoder()
    
    def forward(self, X):
        out = self.encoder(X)
        out = self.decoder(out)
        return out

In [8]:
test_network = TestNetwork()
ctrl_network = ControlNetwork()

start = time.time()
out = ctrl_network(X)
print(time.time() - start)
print(out.shape)
print(out.min(), out.max())

start = time.time()
out = test_network(X)
print(time.time() - start)
print(out.shape)
print(out.min(), out.max())

2.1029884815216064
torch.Size([4, 1, 256, 256])
tensor(0., grad_fn=<MinBackward1>) tensor(1., grad_fn=<MaxBackward1>)
2.1359446048736572
torch.Size([4, 1, 256, 256])
tensor(0., grad_fn=<MinBackward1>) tensor(1., grad_fn=<MaxBackward1>)


In [9]:
num_epochs = 100
test_losses, ctrl_losses = [], []
test_loss_fn = nn.MSELoss()
test_optimizer = optim.Adam(test_network.parameters(), lr=1e-5)
ctrl_loss_fn = nn.MSELoss()
ctrl_optimizer = optim.Adam(ctrl_network.parameters(), lr=1e-5)
test_network = test_network.to(device)
ctrl_network = ctrl_network.to(device)

for epoch in range(num_epochs):
    print(epoch)
    test_running_loss = 0.0
    for i, (X, y) in enumerate(dataloader):
        X = X.to(device)
        y = y.to(device)
        test_optimizer.zero_grad()
        test_out = test_network(X)

        test_loss = test_loss_fn(test_out, y)
        test_loss.backward()
        test_optimizer.step()

        test_running_loss += test_loss.item()

        if i % 100 == 0:
            print(
                f'Epoch {epoch + 1}, Batch {i + 1}, Test Loss: {test_running_loss / 100:.3f}'
            )
            test_losses.append(test_running_loss / 100)
            test_running_loss = 0.0

0
Epoch 1, Batch 1, Test Loss: 0.005
1
Epoch 2, Batch 1, Test Loss: 0.004
2
Epoch 3, Batch 1, Test Loss: 0.004
3
Epoch 4, Batch 1, Test Loss: 0.004
4
Epoch 5, Batch 1, Test Loss: 0.004
5
Epoch 6, Batch 1, Test Loss: 0.004
6
Epoch 7, Batch 1, Test Loss: 0.004
7
Epoch 8, Batch 1, Test Loss: 0.004
8
Epoch 9, Batch 1, Test Loss: 0.003
9
Epoch 10, Batch 1, Test Loss: 0.002
10
Epoch 11, Batch 1, Test Loss: 0.003
11
Epoch 12, Batch 1, Test Loss: 0.002
12
Epoch 13, Batch 1, Test Loss: 0.002
13
Epoch 14, Batch 1, Test Loss: 0.002
14
Epoch 15, Batch 1, Test Loss: 0.002
15
Epoch 16, Batch 1, Test Loss: 0.001
16
Epoch 17, Batch 1, Test Loss: 0.002
17
Epoch 18, Batch 1, Test Loss: 0.002
18
Epoch 19, Batch 1, Test Loss: 0.001
19
Epoch 20, Batch 1, Test Loss: 0.002
20
Epoch 21, Batch 1, Test Loss: 0.002
21
Epoch 22, Batch 1, Test Loss: 0.002
22
Epoch 23, Batch 1, Test Loss: 0.001
23
Epoch 24, Batch 1, Test Loss: 0.002
24
Epoch 25, Batch 1, Test Loss: 0.001
25
Epoch 26, Batch 1, Test Loss: 0.001
26
Ep

In [10]:
for epoch in range(num_epochs):
    print(epoch)
    ctrl_running_loss = 0.0
    for i, (X, y) in enumerate(dataloader):
        X = X.to(device)
        y = y.to(device)
        ctrl_optimizer.zero_grad()
        ctrl_out = ctrl_network(X)

        ctrl_loss = ctrl_loss_fn(ctrl_out, y)
        ctrl_loss.backward()
        ctrl_optimizer.step()

        ctrl_running_loss += ctrl_loss.item()

        if i % 100 == 0:
            print(
                f'Epoch {epoch + 1}, Batch {i + 1}, Control Loss: {ctrl_running_loss / 100:.3f}'
            )
            ctrl_losses.append(ctrl_running_loss / 100)
            ctrl_running_loss = 0.0

0
Epoch 1, Batch 1, Control Loss: 0.008


1
Epoch 2, Batch 1, Control Loss: 0.005
2
Epoch 3, Batch 1, Control Loss: 0.004
3
Epoch 4, Batch 1, Control Loss: 0.005
4
Epoch 5, Batch 1, Control Loss: 0.003
5
Epoch 6, Batch 1, Control Loss: 0.004
6
Epoch 7, Batch 1, Control Loss: 0.003
7
Epoch 8, Batch 1, Control Loss: 0.003
8
Epoch 9, Batch 1, Control Loss: 0.003
9
Epoch 10, Batch 1, Control Loss: 0.003
10
Epoch 11, Batch 1, Control Loss: 0.003
11
Epoch 12, Batch 1, Control Loss: 0.003
12
Epoch 13, Batch 1, Control Loss: 0.003
13
Epoch 14, Batch 1, Control Loss: 0.003
14
Epoch 15, Batch 1, Control Loss: 0.002
15
Epoch 16, Batch 1, Control Loss: 0.003
16
Epoch 17, Batch 1, Control Loss: 0.002
17
Epoch 18, Batch 1, Control Loss: 0.002
18
Epoch 19, Batch 1, Control Loss: 0.002
19
Epoch 20, Batch 1, Control Loss: 0.002
20
Epoch 21, Batch 1, Control Loss: 0.002
21
Epoch 22, Batch 1, Control Loss: 0.002
22
Epoch 23, Batch 1, Control Loss: 0.002
23
Epoch 24, Batch 1, Control Loss: 0.002
24
Epoch 25, Batch 1, Control Loss: 0.002
25
Epoch 

In [12]:
print(f'Test: {np.max(test_losses)-np.min(test_losses)}, Control: {np.max(ctrl_losses)-np.min(ctrl_losses)}')

Test: 0.00461292177438736, Control: 0.0069376127421855935


In [13]:
print(np.max(test_losses), np.min(test_losses))
print(np.max(ctrl_losses), np.min(ctrl_losses))

0.005260730981826782 0.0006478092074394226
0.007513035535812378 0.0005754227936267853
