In [1]:
import numpy as np
import os
import cv2 as cv
from typing import Iterable
import matplotlib.pyplot as plt
import glob
import time

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
import torchvision.transforms as T
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils import data
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
from models import MixtureOfExpertsEncoder, Traditional2dSegmenter, SegmentationDecoder
from models.utils.util_classes import SplitTensor, AddInQuadrature, DepthSum, ConvWH, ConvDW, ConvDH

dev = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(dev)
print(device)

cpu


In [2]:
class FftLayer(nn.Module):
    def __init__(self, kernel_size, stride, padding, num_orders=4, inverse=False):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.num_orders = num_orders

In [3]:
root_directory = '/media/shiva/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/'
image_directory = os.path.join(root_directory, 'Image')
mask_directory = os.path.join(root_directory, 'Mask')
image_paths = glob.glob(image_directory + '/*.jpg')
mask_paths = glob.glob(mask_directory + '/*.png')
print(image_paths)
print(mask_paths)

['/media/shiva/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/0.jpg', '/media/shiva/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/1.jpg', '/media/shiva/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/10.jpg', '/media/shiva/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/1000.jpg', '/media/shiva/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/1001.jpg', '/media/shiva/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/1002.jpg', '/media/shiva/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/1003.jpg', '/media/shiva/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/1004.jpg', '/media/shiva/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/1005.jpg', '/media/shiva/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/Image/1006.jpg', '/media/shiva/BigBertha/Datasets/ComputerVision/ImageSegmentationFloodZones/I

In [15]:
im_transform = T.Compose([
    T.ToTensor(),
    T.Resize((256, 256)),
])

def to_image(inp):
    if inp.ndim == 3:
        out = inp.numpy()
        out = np.transpose(out, (1, 2, 0))
    else:
        out = inp.numpy()
        out = np.transpose(out, (0, 2, 3, 1))
    return out

def calculate_output_size(input_size, kernel_size, stride, padding):
    return (input_size - 1) * stride - 2 * padding + (kernel_size - 1) + 1

In [16]:
class FloodZoneDataset(data.Dataset):
    def __init__(self, root_directory, im_transform=im_transform):
        self.image_directory = os.path.join(root_directory, 'Image')
        self.mask_directory = os.path.join(root_directory, 'Mask')
        self.im_transform = im_transform
        self.image_paths = sorted(glob.glob(os.path.join(self.image_directory, '*.jpg')))
        self.mask_paths = sorted(glob.glob(os.path.join(self.mask_directory, '*.png')))
    
    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]
        image = cv.imread(image_path)
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
        mask = cv.imread(mask_path, cv.IMREAD_GRAYSCALE)
        image = self.im_transform(image)
        mask = self.im_transform(mask)
        return image, 1.0 - mask
    
    def __len__(self):
        return len(self.image_paths)

In [17]:
dataset = FloodZoneDataset(root_directory)
dataloader = data.DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)

X, y = next(iter(dataloader))
print(X.shape, y.shape)

torch.Size([4, 3, 256, 256]) torch.Size([4, 1, 256, 256])


In [18]:
class TestNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = MixtureOfExpertsEncoder.MixtureOfExpertsSegmentationEncoder()
        self.decoder = SegmentationDecoder.Decoder()
    
    def forward(self, X):
        out = self.encoder(X)
        out = self.decoder(out)
        return out
    
class ControlNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Traditional2dSegmenter.ControlSegmentationModel()
        self.decoder = SegmentationDecoder.Decoder()
    
    def forward(self, X):
        out = self.encoder(X)
        out = self.decoder(out)
        return out

In [19]:
test_network = TestNetwork()
ctrl_network = ControlNetwork()

start = time.time()
out = ctrl_network(X)
print(time.time() - start)
print(out.shape)
print(out.min(), out.max())

start = time.time()
out = test_network(X)
print(time.time() - start)
print(out.shape)
print(out.min(), out.max())

5.014561653137207
torch.Size([4, 1, 256, 256])
tensor(0., grad_fn=<MinBackward1>) tensor(1., grad_fn=<MaxBackward1>)
4.338047981262207
torch.Size([4, 1, 256, 256])
tensor(0., grad_fn=<MinBackward1>) tensor(1., grad_fn=<MaxBackward1>)


In [20]:
test_loss_fn = nn.MSELoss()
test_optimizer = optim.Adam(test_network.parameters(), lr=1e-5)
ctrl_loss_fn = nn.MSELoss()
ctrl_optimizer = optim.Adam(ctrl_network.parameters(), lr=1e-5)

In [21]:
num_epochs = 10
test_losses, ctrl_losses = [], []
test_network = test_network.to(device)
ctrl_network = ctrl_network.to(device)

for epoch in range(num_epochs):
    test_running_loss, ctrl_running_loss = 0.0, 0.0
    for i, (X, y) in enumerate(dataloader):
        X = X.to(device)
        y = y.to(device)
        test_optimizer.zero_grad()
        ctrl_optimizer.zero_grad()

        test_out = test_network(X)
        ctrl_out = ctrl_network(X)

        print(test_out.shape, y.shape)
        print(ctrl_out.shape, y.shape)

        test_loss = test_loss_fn(test_out, y)
        ctrl_loss = ctrl_loss_fn(ctrl_out, y)

        test_optimizer.zero_grad()
        ctrl_optimizer.zero_grad()
        test_loss.backward()
        ctrl_loss.backward()
        test_optimizer.step()
        ctrl_optimizer.step()

        test_running_loss += test_loss.item()
        ctrl_running_loss += ctrl_loss.item()

        if i % 100 == 99:
            print(f'Epoch {epoch + 1}, Batch {i + 1}, Test Loss: {test_running_loss / 100:.3f}, Control Loss: {ctrl_running_loss / 100:.3f}')
            test_losses.append(test_running_loss / 100)
            ctrl_losses.append(ctrl_running_loss / 100)
            test_running_loss, ctrl_running_loss = 0.0, 0.0

torch.Size([4, 1, 256, 256]) torch.Size([4, 1, 256, 256])
torch.Size([4, 1, 256, 256]) torch.Size([4, 1, 256, 256])
torch.Size([4, 1, 256, 256]) torch.Size([4, 1, 256, 256])
torch.Size([4, 1, 256, 256]) torch.Size([4, 1, 256, 256])


KeyboardInterrupt: 