# Strt2Ftprnt dataset builder

*From street networks to predicted building-footprint layouts.*

## U-Net Architecture for Image Segmentation.

*The dataset is gathered from OSM data using the OSMnx package.*

### Importing packages ...

In [None]:
import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
import pytorch_lightning as pl

from sklearn.model_selection import train_test_split

import os
import math
import re
import pandas as pd
import geopandas as gpd
import osmnx as ox
import numpy as np
import glob
import cv2
from skimage.io import imread
import folium
from tqdm import tqdm, trange

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Geocode Ankara's boundary, project it and plot it as a quick sanity check.
city = ox.geocode_to_gdf('Ankara, Turkey')
ax = ox.project_gdf(city).plot()
_ = ax.axis('off')

In [None]:
# Download Ankara's full street network (network_type='all') and draw it
# black on white with nodes hidden (node_size=0).
Ank = ox.graph_from_place('Ankara, Turkey', network_type='all')
ox.plot_graph(Ank, bgcolor='#FFFFFF', node_size=0, edge_color='#111111', edge_linewidth=1)

## Test plot

In [None]:
# Settings for one test tile: `dist` metres around `point`, given as (lat, lon).
map_settings = dict(
    dist=200,
    edge_color='#000000',
    bgcolor='#FFFFFF',
    dpi = 300,
    point = (39.890591,32.783029),
    default_width=2,
    )

# Input image (X): the street-network figure-ground of the tile.
fig, ax = ox.plot_figure_ground(network_type='all', figsize=(8, 8), 
                                **map_settings)
fig.savefig('D:/AI in Urban Design/DL UD dir/Ankara_trainY/imgX_test.png', dpi=map_settings['dpi']) 

# Target image (Y): building footprints drawn onto the same axes and
# clipped to the same bounding box as the street tile.
gdf_bldings = ox.geometries.geometries_from_point(center_point=map_settings['point'], 
                                                  tags = {'building':True}, dist=map_settings['dist'])

fig, ax = ox.plot.plot_footprints(gdf_bldings, ax=ax, figsize=(8, 8), 
                                  color='#000000', alpha=None, bgcolor='#FFFFFF',
                                  bbox=ox.utils_geo.bbox_from_point(map_settings['point'], dist=map_settings['dist'], project_utm=False, return_crs=False),
                                  save=True, show=True, close=False, 
                                  filepath='D:/AI in Urban Design/DL UD dir/Ankara_trainY/imgY_test.png',
                                  dpi=map_settings['dpi'])

# Show the tile with a geographically correct aspect ratio. One degree of
# longitude is cos(latitude) times as long as one degree of latitude, so the
# y axis must be stretched by 1 / cos(lat). `point` is (lat, lon), so the
# latitude is point[0], and the cosine is applied exactly once.
coslat = np.cos(np.deg2rad(map_settings['point'][0]))
ax.set_aspect(1/coslat)
fig.set_figwidth(10)
fig

## Building the image dataset

### Looping over multiple locations in Ankara

In [None]:
# Bounding box (decimal degrees) of the Ankara sampling area.
north = 40.000000
south = 39.836170
east = 32.929667
west = 32.630090

# Spacing in degrees between neighbouring tile centres.
# NOTE(review): 0.003 deg of latitude is roughly 330 m, less than the 400 m
# wide tiles (dist=200) rendered later, so neighbouring tiles presumably
# overlap. Confirm this is intended.
step = 0.003

diffv = north-south
rv = int((diffv)/step)  # number of grid rows (truncated)
print(diffv, rv)

diffh = east-west
rh = int((diffh)/step)  # number of grid columns (truncated)
print(diffh, rh)

count = rv*rh
print('Expected img count: ' + str(count))

# Row latitudes, starting at the northern edge and moving south.
vlist = []
for i in range(rv):
    v = north-i*step
    vlist.append(v)
    
# Column longitudes, starting at the eastern edge and moving west.
hlist = []
for i in range(rh):
    h = east-i*step
    hlist.append(h)
    
print(vlist)
print(hlist)
# print(str(3966/99))

In [None]:
# Render one image pair per grid point: the street network (X) and the
# building footprints (Y). `a` numbers the saved pairs consecutively from 1.
a = 1
for i in vlist:
    for j in hlist:
        try:
            map_settings = dict(
                            dist=200,
                            edge_color='#000000',
                            bgcolor='#FFFFFF',
                            dpi = 300,
                            point = (i,j),
                            default_width=2,
                            )

            # Query the footprints before writing anything. If the query
            # raises, no X image is left on disk without a matching Y image.
            gdf_bldings = ox.geometries.geometries_from_point(center_point=map_settings['point'], 
                                                              tags = {'building':True}, dist=map_settings['dist'])

            fig, ax = ox.plot_figure_ground(network_type='all', figsize=(8, 8), 
                                            **map_settings)
            fig.savefig('D:/AI in Urban Design/DL UD dir/Ankara_trainX/img_X_' + str(a) + '.png', dpi=map_settings['dpi'])

            fig, ax = ox.plot.plot_footprints(gdf_bldings, ax=ax, figsize=(8, 8), 
                                              color='#000000', alpha=None, bgcolor='#FFFFFF',
                                              bbox=ox.utils_geo.bbox_from_point(map_settings['point'], dist=map_settings['dist'], project_utm=False, return_crs=False),
                                              save=True, show=True, close=False, 
                                              filepath='D:/AI in Urban Design/DL UD dir/Ankara_trainY/img_Y_' + str(a) + '.png',
                                              dpi=map_settings['dpi'])
            a += 1
        except ValueError:
            # Presumably raised by OSMnx when a point has no street or
            # building data. Skip the point without consuming an index.
            continue
        finally:
            # Each iteration opens a figure with close=False. Close it so
            # thousands of iterations do not accumulate open figures,
            # whatever matplotlib backend is active.
            plt.close('all')
        

### Gathering more data from other cities around the world

In [None]:
# SET LOCATION COORDINATES
north = 53.568016
south = 53.541042
east = 10.115901
west = 9.922825

step = 0.003

diffv = north-south
rv = int((diffv)/step)
print(diffv, rv)

diffh = east-west
rh = int((diffh)/step)
print(diffh, rh)

count = rv*rh
print('Expected img count: ' + str(count))

vlist = []
for i in range(rv):
    v = north-i*step
    vlist.append(v)
    
hlist = []
for i in range(rh):
    h = east-i*step
    hlist.append(h)
    
print(vlist)
print(hlist)

In [None]:
# a = 13647 Extra starts here. 13856. 14096.
# 2191 + 3498
# Render one image pair per grid point into the Extra_X / Extra_Y folders.
# Numbering continues from the manually tracked index below, so it does not
# clash with images already saved there.
a = 21188
for i in vlist:
    for j in hlist:
        try:
            map_settings = dict(
                            dist=200,
                            edge_color='#000000',
                            bgcolor='#FFFFFF',
                            dpi = 300,
                            point = (i,j),
                            default_width=2,
                            )

            # Query the footprints before writing anything. If the query
            # raises, no X image is left on disk without a matching Y image.
            gdf_bldings = ox.geometries.geometries_from_point(center_point=map_settings['point'], 
                                                              tags = {'building':True}, dist=map_settings['dist'])

            fig, ax = ox.plot_figure_ground(network_type='all', figsize=(8, 8), 
                                            **map_settings)
            fig.savefig('D:/AI in Urban Design/DL UD dir/Extra_X/img_X_' + str(a) + '.png', dpi=map_settings['dpi'])

            fig, ax = ox.plot.plot_footprints(gdf_bldings, ax=ax, figsize=(8, 8), 
                                              color='#000000', alpha=None, bgcolor='#FFFFFF',
                                              bbox=ox.utils_geo.bbox_from_point(map_settings['point'], dist=map_settings['dist'], project_utm=False, return_crs=False),
                                              save=True, show=True, close=False, 
                                              filepath='D:/AI in Urban Design/DL UD dir/Extra_Y/img_Y_' + str(a) + '.png',
                                              dpi=map_settings['dpi'])
            a += 1
        except ValueError:
            # Presumably raised by OSMnx when a point has no street or
            # building data. Skip the point without consuming an index.
            continue
        finally:
            # Close this iteration's figures so memory stays bounded.
            plt.close('all')

## Data Preprocessing

In [None]:
# Side length in pixels that make_data resizes every image to.
img_size = 256
# raw data directories
X_img_path = "D:/AI in Urban Design/DL UD dir/strtTOftprnt_X"
Y_img_path = "D:/AI in Urban Design/DL UD dir/strtTOftprnt_Y"

In [None]:
# Augmentation: save a horizontally mirrored (cv2.flip with flipCode=1)
# grayscale copy of every X and Y image into Extra_X / Extra_Y, numbered
# from 34729 upwards.
# NOTE(review): X and Y copies are paired only by their position in the
# sorted directory listings, so both folders must contain matching file
# sets. The start index presumably continues after the existing images;
# confirm before re-running.
a = 34728
for i in tqdm(sorted(os.listdir(X_img_path)), ncols=100, disable=False):
    a += 1
    path = os.path.join(X_img_path, i)
    im = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    flipped_img = cv2.flip(im, 1)
    cv2.imwrite('D:/AI in Urban Design/DL UD dir/Extra_X/img_X_' 
                + str(a) + '.png', flipped_img)
    
b = 34728
for i in tqdm(sorted(os.listdir(Y_img_path)), ncols=100, disable=False):
    b += 1
    path = os.path.join(Y_img_path, i)
    im = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    flipped_img = cv2.flip(im, 1)
    cv2.imwrite('D:/AI in Urban Design/DL UD dir/Extra_Y/img_Y_' 
                + str(b) + '.png', flipped_img)

In [None]:
# Compare the raw directory listing with its sorted order.
# list.sort() sorts in place and returns None, so use sorted() to keep both.
l = os.listdir(X_img_path)
lsorted = sorted(l)
print(l)
print(lsorted)

In [None]:
# Crop every X image to the bounding box of its non-transparent pixels
# (alpha > 0) and save it to Extra_Xcrop under the file's numeric index.
index_pattern = re.compile(r'\d+')  # first run of digits in the file name
a = 0
for i in tqdm(sorted(os.listdir(X_img_path)), ncols=100, disable=False):
    a += 1
    path = os.path.join(X_img_path, i)
    # IMREAD_UNCHANGED keeps the alpha channel; channel 3 is assumed to be
    # alpha (4-channel PNGs).
    im = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    y,x = im[:,:,3].nonzero() # get the nonzero alpha coordinates
    if x.size == 0:
        # Fully transparent: there is no bounding box (np.min would raise).
        tqdm.write('Skipping fully transparent image: ' + path)
        continue
    minx = np.min(x)
    miny = np.min(y)
    maxx = np.max(x)
    maxy = np.max(y)
    # Slice ends are exclusive, so +1 keeps the last opaque row and column.
    cropImg = im[miny:maxy + 1, minx:maxx + 1]
    cv2.imwrite('D:/AI in Urban Design/DL UD dir/Extra_Xcrop/img_X_' 
                + str(int(index_pattern.search(i).group())) + '.png', cropImg)

In [None]:
def _load_image_stack(img_dir, img_size):
    """Load every image in ``img_dir``, in sorted filename order.

    Each image is read as grayscale, resized to ``(img_size, img_size)``
    with area interpolation, reshaped to ``(1, img_size, img_size)``,
    scaled to [0, 1] and inverted (``1 - value``). Black strokes on a white
    background therefore become values near 1.

    Raises:
        ValueError: if a file in ``img_dir`` cannot be decoded as an image.
    """
    images = []
    for name in tqdm(sorted(os.listdir(img_dir)), ncols=100, disable=False):
        path = os.path.join(img_dir, name)
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread returns None instead of raising on unreadable files.
            raise ValueError('Could not read image file: ' + path)
        img = cv2.resize(img, (img_size, img_size), interpolation = cv2.INTER_AREA)
        img = np.array(img).reshape((1, img_size, img_size)) / 255
        images.append(1 - img)
    return images


def make_data(X_img_path, Y_img_path, img_size):
    """Load paired input/target images and split them 60/20/20.

    Args:
        X_img_path: directory of input (street network) images.
        Y_img_path: directory of target (building footprint) images.
        img_size: side length in pixels each image is resized to.

    Returns:
        ``(X_train, Y_train, X_val, Y_val, X_test, Y_test)``. Each is a list
        of float arrays of shape ``(1, img_size, img_size)`` in [0, 1].

    Both directories are read in sorted filename order, so the i-th input is
    paired with the i-th target. ``os.listdir`` alone guarantees no order,
    which could silently mismatch pairs. The directories must therefore hold
    matching file sets (e.g. ``img_X_<n>`` / ``img_Y_<n>``).
    """
    X_data = _load_image_stack(X_img_path, img_size)
    Y_data = _load_image_stack(Y_img_path, img_size)

    print('X Image_count:' + str(len(X_data)))
    print('Y Image_count:' + str(len(Y_data)))

    # train, val, test : 0.6, 0.2, 0.2 Split
    X_train, X_test, Y_train, Y_test = train_test_split(X_data, Y_data, test_size=0.2, random_state=1)
    X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.25, random_state=1) # 0.25 x 0.8 = 0.2

    print('X_train Image_count:' + str(len(X_train))) 
    print('Y_train Image_count:' + str(len(Y_train)))
    print('X_val Image_count:' + str(len(X_val))) 
    print('Y_val Image_count:' + str(len(Y_val)))
    print('X_test Image_count:' + str(len(X_test))) 
    print('Y_test Image_count:' + str(len(Y_test)))

    return X_train, Y_train, X_val, Y_val, X_test, Y_test

In [None]:
# Load, normalise and split the whole dataset into train / val / test lists.
X_train, Y_train, X_val, Y_val, X_test, Y_test = make_data(X_img_path, Y_img_path, img_size)

In [None]:
# Visualizing images
# Show one training pair as stored by make_data, i.e. inverted: the
# originally black streets/buildings appear bright.
index = 1345
fig = plt.figure()

ax = fig.add_subplot(1, 2, 1)
imgplot = plt.imshow(X_train[index].reshape((img_size, img_size)), cmap='gray')
ax.set_title('Input Sample')

ax = fig.add_subplot(1, 2, 2)
imgplot = plt.imshow(Y_train[index].reshape((img_size, img_size)), cmap='gray')
ax.set_title('Ground Truth')

In [None]:
# Same pair with the inversion undone (1 - value), i.e. black on white as
# originally rendered.
index = 1345
fig = plt.figure()

ax = fig.add_subplot(1, 2, 1)
imgplot = plt.imshow((1 - X_train[index]).reshape((img_size, img_size)), cmap='gray')
ax.set_title('Input Sample')

ax = fig.add_subplot(1, 2, 2)
imgplot = plt.imshow((1 - Y_train[index]).reshape((img_size, img_size)), cmap='gray')
ax.set_title('Ground Truth')

In [None]:
# Get the shape of the data
# (shape of the first sample of each split)
print('Training data shape is: ' + str(X_train[0].shape))
print('Validation data shape is: ' + str(X_val[0].shape))
print('Test data shape is: ' + str(X_test[0].shape))

In [None]:
class SegmentationDataSet(data.Dataset):
    """Map-style dataset that pairs in-memory input arrays with target arrays.

    Args:
        inputs: sequence of NumPy arrays used as model inputs.
        targets: sequence of NumPy arrays of the same length as ``inputs``.
        transform: optional callable ``(x, y) -> (x, y)`` applied to each
            pair before it is converted to tensors.

    ``__getitem__`` returns ``(x, y)`` as ``torch.float32`` tensors.
    """

    def __init__(self,
                 inputs: list,
                 targets: list,
                 transform=None
                 ):
        self.inputs, self.targets, self.transform = inputs, targets, transform
        # Both tensors are cast to float32 when a sample is fetched.
        self.inputs_dtype = self.targets_dtype = torch.float32

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index: int):
        sample, mask = self.inputs[index], self.targets[index]

        if self.transform is not None:
            sample, mask = self.transform(sample, mask)

        sample_tensor = torch.from_numpy(sample).type(self.inputs_dtype)
        mask_tensor = torch.from_numpy(mask).type(self.targets_dtype)
        return sample_tensor, mask_tensor

### Create Dataloaders

In [None]:
# Wrap the training split in a shuffling DataLoader (batch size 2) and print
# the shape, dtype and value range of one batch as a sanity check.
training_dataset = SegmentationDataSet(inputs = X_train, targets = Y_train) 

training_dataloader = data.DataLoader(dataset=training_dataset, batch_size = 2, shuffle=True)
x, y = next(iter(training_dataloader))

print(f'x = shape: {x.shape}; type: {x.dtype}')
print(f'x = min: {x.min()}; max: {x.max()}')
print(f'y = shape: {y.shape}; type: {y.dtype}')
print(f'y = min: {y.min()}; max: {y.max()}')

In [None]:
# Validation DataLoader (batch size 2, not shuffled) plus the same one-batch
# sanity check.
val_dataset = SegmentationDataSet(inputs = X_val, targets = Y_val)

val_dataloader = data.DataLoader(dataset=val_dataset, batch_size = 2, shuffle=False)
x, y = next(iter(val_dataloader))

print(f'x = shape: {x.shape}; type: {x.dtype}')
print(f'x = min: {x.min()}; max: {x.max()}')
print(f'y = shape: {y.shape}; type: {y.dtype}')
print(f'y = min: {y.min()}; max: {y.max()}')

In [None]:
# Print the raw values of the first sample of a fresh training batch.
x, y = next(iter(training_dataloader))
print(x[0, :, :, :])
print(y[0, :, :, :])

## Defining the U-Net architecture

In [None]:
def double_conv(in_c, out_c, seperable=True):
    """Build two (convolution -> BatchNorm -> ReLU) stages, ``in_c`` -> ``out_c``.

    With ``seperable`` truthy, each stage's convolution is a 1x1 pointwise
    conv followed by a 3x3 depthwise conv (``groups=out_c``). Otherwise it
    is a single dense 3x3 conv. Padding 1 keeps the spatial size unchanged.
    """
    def stage(channels_in):
        if seperable:
            convs = [
                nn.Conv2d(channels_in, out_c, kernel_size=1, bias=True),
                nn.Conv2d(out_c, out_c, kernel_size=3, padding=1, groups=out_c, bias=True),
            ]
        else:
            convs = [nn.Conv2d(channels_in, out_c, kernel_size=3, padding=1, bias=True)]
        return convs + [nn.BatchNorm2d(out_c), nn.ReLU(inplace=True)]

    # The first stage changes the width; the second keeps it at out_c.
    return nn.Sequential(*stage(in_c), *stage(out_c))

In [None]:
def crop_img(tensor, target_tensor):
    """Center-crop ``tensor`` to the spatial size of ``target_tensor``.

    Both arguments are indexed as (batch, channels, height, width). Only the
    last two dimensions are cropped, and height and width are handled
    independently. When the size difference is odd, the extra row/column is
    dropped at the bottom/right, so the result always matches the target
    exactly. (Slicing ``delta:size - delta`` would be one pixel too large in
    that case.)

    Raises:
        ValueError: if ``target_tensor`` is larger than ``tensor`` in either
            spatial dimension.
    """
    target_h, target_w = target_tensor.size()[2], target_tensor.size()[3]
    tensor_h, tensor_w = tensor.size()[2], tensor.size()[3]
    if target_h > tensor_h or target_w > tensor_w:
        raise ValueError(
            f'cannot crop a {tensor_h}x{tensor_w} tensor to {target_h}x{target_w}')
    top = (tensor_h - target_h) // 2
    left = (tensor_w - target_w) // 2
    return tensor[:, :, top:top + target_h, left:left + target_w]

In [None]:
class Unet(nn.Module):
    """Nine-level U-Net mapping a 1-channel image to a 1-channel sigmoid map.

    Encoder widths double at every level (1 -> 8 -> ... -> 2048), with 2x2
    max pooling between levels 1-8 and level 9 as the bottleneck. Each of
    the eight decoder levels upsamples with a stride-2 transposed conv,
    concatenates the matching encoder output (center-cropped by
    ``crop_img``) and applies ``double_conv``. Encoder levels 2-9 and the
    three deepest decoder levels use the separable ``double_conv`` variant.
    """

    def __init__(self):
        super(Unet, self).__init__()
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # widths[k] is the channel count produced by encoder level k.
        widths = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]

        # Registered in the same order as down_conv1 ... down_conv9.
        for level in range(1, 10):
            setattr(self, f'down_conv{level}',
                    double_conv(widths[level - 1], widths[level], seperable=level != 1))

        # up_trans<k> then up_conv<k>, interleaved, deepest level first.
        for level in range(1, 9):
            wide, narrow = widths[10 - level], widths[9 - level]
            setattr(self, f'up_trans{level}',
                    nn.ConvTranspose2d(wide, narrow, kernel_size=2, stride=2, bias=False))
            setattr(self, f'up_conv{level}',
                    double_conv(wide, narrow, seperable=level <= 3))

        self.num_classes = 1
        self.out = nn.Conv2d(8, self.num_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, image):
        skips = []
        features = image

        # Encoder: keep each level's output for the decoder, then downsample.
        for level in range(1, 9):
            features = getattr(self, f'down_conv{level}')(features)
            skips.append(features)
            features = self.max_pool_2x2(features)

        # Bottleneck (no pooling afterwards).
        features = self.down_conv9(features)

        # Decoder: upsample, concatenate the deepest remaining skip, convolve.
        for level in range(1, 9):
            features = getattr(self, f'up_trans{level}')(features)
            skip = crop_img(skips.pop(), features)
            features = getattr(self, f'up_conv{level}')(torch.cat([features, skip], 1))

        # 1x1 projection to num_classes channels, then sigmoid.
        return self.sigmoid(self.out(features))

### Residual U-Net with a sigmoid output (`Sigmoid_Unet`)

In [None]:
def initial_block(in_c, out_c):
    """First encoder stage: two (3x3 conv -> ReLU -> BatchNorm) units.

    Channels go ``in_c`` -> ``out_c`` -> ``out_c``. Padding 1 with stride 1
    keeps the spatial size unchanged.
    """
    layers = []
    for channels_in in (in_c, out_c):
        layers.extend([
            nn.Conv2d(channels_in, out_c, kernel_size=3, padding=1, stride=1, bias=True),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_c),
        ])
    return nn.Sequential(*layers)

In [None]:
def enc_block(in_c, out_c):
    """Encoder stage: two (3x3 conv -> ReLU -> BatchNorm) units, ``in_c`` -> ``out_c``.

    Spatial size is preserved (padding 1, stride 1). Downsampling is done by
    the caller.
    """
    def unit(channels_in):
        return (
            nn.Conv2d(channels_in, out_c, kernel_size=3, padding=1, stride=1, bias=True),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_c),
        )

    return nn.Sequential(*unit(in_c), *unit(out_c))

In [None]:
def mid_block(out_c):
    """Bottleneck stage: two (7x7 conv -> ReLU -> BatchNorm) units at width ``out_c``.

    Padding 3 with stride 1 keeps the spatial size unchanged.
    """
    conv7 = dict(kernel_size=7, padding=3, stride=1, bias=True)
    modules = [
        module
        for _ in range(2)
        for module in (nn.Conv2d(out_c, out_c, **conv7),
                       nn.ReLU(inplace=True),
                       nn.BatchNorm2d(out_c))
    ]
    return nn.Sequential(*modules)

In [None]:
def dec_block(out_c):
    """Decoder stage at width ``out_c``: BatchNorm -> 7x7 conv -> ReLU -> BatchNorm.

    Padding 3 with stride 1 keeps the spatial size unchanged.
    """
    pre_norm = nn.BatchNorm2d(out_c)
    conv = nn.Conv2d(out_c, out_c, kernel_size=7, padding=3, stride=1, bias=True)
    activation = nn.ReLU(inplace=True)
    post_norm = nn.BatchNorm2d(out_c)
    return nn.Sequential(pre_norm, conv, activation, post_norm)

In [None]:
def end_block(in_c, out_c):
    """Final stage: two (7x7 conv -> ReLU -> BatchNorm) units, ``in_c`` -> ``out_c``.

    Padding 3 with stride 1 keeps the spatial size unchanged.
    """
    layers = []
    for channels_in in (in_c, out_c):
        layers.append(nn.Conv2d(channels_in, out_c, kernel_size=7, padding=3, stride=1, bias=True))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.BatchNorm2d(out_c))
    return nn.Sequential(*layers)
def sigmoid_block(out_c):
    """Output head: 1x1 conv from ``out_c`` channels to 1, followed by a sigmoid."""
    projection = nn.Conv2d(out_c, 1, kernel_size=1, padding=0, stride=1, bias=True)
    return nn.Sequential(projection, nn.Sigmoid())
def conv_res(in_c, out_c):
    """Return a biased 1x1 convolution mapping ``in_c`` to ``out_c`` channels.

    Sigmoid_Unet uses it as the projection shortcut of its residual
    connections.
    """
    return nn.Conv2d(in_c, out_c, kernel_size=1, padding=0, stride=1, bias=True)

In [None]:
class Sigmoid_Unet(nn.Module):
    """Four-level residual U-Net producing a single-channel sigmoid map.

    Every stage is wrapped as ``relu(block(x) + shortcut(x))``, where the
    shortcut is a 1x1 convolution. The encoder max-pools after each stage.
    Each decoder step concatenates its features with the pooled encoder map
    of the same resolution, then upsamples with a stride-2 transposed
    convolution.
    """

    def __init__(self):
        super(Sigmoid_Unet, self).__init__()
        # Stage blocks (registration order kept identical).
        self.initial = initial_block(1, 32)
        self.encblock1 = enc_block(32, 64)
        self.encblock2 = enc_block(64, 128)
        self.encblock3 = enc_block(128, 256)
        self.mid = mid_block(256)
        self.decblock1 = dec_block(128)
        self.decblock2 = dec_block(64)
        self.decblock3 = dec_block(32)
        self.end = end_block(16, 16)
        self.sigmoid = sigmoid_block(16)

        # Upsampling layers. Input width = decoder features + pooled skip.
        self.transpose1 = nn.ConvTranspose2d(512, 128, kernel_size=2, stride=2, bias=False)
        self.transpose2 = nn.ConvTranspose2d(256, 64, kernel_size=2, stride=2, bias=False)
        self.transpose3 = nn.ConvTranspose2d(128, 32, kernel_size=2, stride=2, bias=False)
        self.transpose4 = nn.ConvTranspose2d(64, 16, kernel_size=2, stride=2, bias=False)

        # 1x1 shortcut projections, one per residual stage.
        self.res1 = conv_res(1, 32)
        self.res2 = conv_res(32, 64)
        self.res3 = conv_res(64, 128)
        self.res4 = conv_res(128, 256)
        self.res5 = conv_res(256, 256)
        self.res6 = conv_res(128, 128)
        self.res7 = conv_res(64, 64)
        self.res8 = conv_res(32, 32)
        self.res9 = conv_res(16, 16)

        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, image):
        def residual(block, shortcut, features):
            # Block output plus a 1x1 projection of the block's input.
            return self.relu(block(features) + shortcut(features))

        # Encoder: residual stage, then 2x2 max pool. The pooled maps are
        # kept (finest first) for the decoder's concatenations.
        encoder = (
            (self.initial, self.res1),
            (self.encblock1, self.res2),
            (self.encblock2, self.res3),
            (self.encblock3, self.res4),
        )
        pooled = []
        features = image
        for block, shortcut in encoder:
            features = self.maxpool(residual(block, shortcut, features))
            pooled.append(features)

        # Bottleneck, concatenated with its own pooled input, then upsampled.
        bottleneck = residual(self.mid, self.res5, features)
        features = self.transpose1(torch.cat([bottleneck, pooled[3]], 1))

        # Decoder: residual stage, concatenate the matching pooled map, upsample.
        decoder = (
            (self.decblock1, self.res6, self.transpose2, pooled[2]),
            (self.decblock2, self.res7, self.transpose3, pooled[1]),
            (self.decblock3, self.res8, self.transpose4, pooled[0]),
        )
        for block, shortcut, upsample, skip in decoder:
            features = upsample(torch.cat([residual(block, shortcut, features), skip], 1))

        # Last residual stage, then the 1x1 conv + sigmoid head.
        return self.sigmoid(residual(self.end, self.res9, features))

In [None]:
model = Unet()
# Call torchsummary's summary() without rebinding the name. Assigning its
# result to `summary` would shadow the imported function, so re-running
# this cell would fail. The model has not been moved to a GPU here, so run
# the summary on the CPU as well (torchsummary defaults to "cuda").
summary(model, (1, 256, 256), device='cpu')

### Hyperparameters ...

In [None]:
# Training hyperparameters.
num_epochs = 2
learning_rate = 1e-3
lr_decay = 0.1       # presumably the multiplicative LR decay factor; confirm where it is used
lr_decay_iter = 2    # presumably the decay interval (epochs/iterations); confirm

### Finding the learning rate 

In [None]:
# NOTE(review): `pytorch_lr_finder` is an external package, and `criterion` /
# `optimizer` are not defined in the visible notebook cells. The
# LearningRateFinder class defined in the next cell also requires a `device`
# argument. Confirm which implementation is intended before running this.
from pytorch_lr_finder import LearningRateFinder 
lrf = LearningRateFinder(model, criterion, optimizer)
lrf.fit(training_dataloader)
lrf.plot()

In [None]:
class LearningRateFinder:
    """
    Train a model using different learning rates within a range to find the optimal learning rate.

    The range test trains ``model`` in place with increasingly large learning
    rates, which usually leaves it in a poor state. Deep-copied snapshots of
    the model and optimizer are taken at construction time; call ``reset()``
    after ``fit()`` to restore them before real training.
    """

    def __init__(self,
                 model: nn.Module,
                 criterion,
                 optimizer,
                 device
                 ):
        import copy

        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.loss_history = {}
        # state_dict() returns references to the live tensors. Without a
        # deep copy, the "initial" snapshot would change as fit() trains.
        self._model_init = copy.deepcopy(model.state_dict())
        self._opt_init = copy.deepcopy(optimizer.state_dict())
        self.device = device

    def reset(self):
        """Restore the model and optimizer to the state captured in ``__init__``."""
        import copy

        self.model.load_state_dict(self._model_init)
        # Pass a copy so the optimizer never shares (and later mutates) the
        # snapshot's state tensors; reset() can then be called repeatedly.
        self.optimizer.load_state_dict(copy.deepcopy(self._opt_init))

    def fit(self,
            data_loader: torch.utils.data.DataLoader,
            steps=100,
            min_lr=1e-7,
            max_lr=1,
            constant_increment=False
            ):
        """
        Train for exactly ``steps`` optimizer steps while sweeping the learning rate.

        The learning rate starts at ``min_lr`` and grows after every step,
        either geometrically (default) or linearly (``constant_increment``),
        towards ``max_lr``. The loss of each step is recorded in
        ``self.loss_history`` keyed by the learning rate used. The loader is
        re-iterated as often as needed to reach ``steps``.

        Raises:
            ValueError: if ``data_loader`` yields no batches.
        """
        if len(data_loader) == 0:
            raise ValueError('data_loader is empty; cannot run the learning-rate range test')

        self.loss_history = {}
        self.model.train()
        current_lr = min_lr
        steps_done = 0
        epochs = math.ceil(steps / len(data_loader))

        progressbar = trange(epochs, desc='Progress')
        for _ in progressbar:
            batch_iter = tqdm(data_loader, 'Training', total=len(data_loader),
                              leave=False)

            for x, y in batch_iter:
                # Stop after exactly `steps` updates, whether or not the
                # loader length divides `steps`.
                if steps_done >= steps:
                    break
                x, y = x.to(self.device), y.to(self.device)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = current_lr
                self.optimizer.zero_grad()
                out = self.model(x)
                loss = self.criterion(out, y)
                loss.backward()
                self.optimizer.step()
                self.loss_history[current_lr] = loss.item()
                steps_done += 1

                if constant_increment:
                    current_lr += (max_lr - min_lr) / steps
                else:
                    current_lr = current_lr * (max_lr / min_lr) ** (1 / steps)

    def _loss_curve(self,
                    smoothing=True,
                    clipping=True,
                    smoothing_factor=0.1
                    ):
        """Return ``(learning_rates, losses)`` exactly as ``plot()`` draws them.

        Smoothing uses pandas' exponentially weighted mean with its default
        ``adjust=True``, which is already the bias-corrected EMA. No further
        ``1 - (1 - alpha) ** t`` correction is applied, because that would
        inflate the early values. Clipping drops the first 10 and last 5
        points, but only when more than 15 points exist; otherwise the curve
        would be empty.
        """
        lr_list = list(self.loss_history.keys())
        loss_data = pd.Series(list(self.loss_history.values()), dtype=float)
        if smoothing:
            loss_data = loss_data.ewm(alpha=smoothing_factor).mean()
        if clipping and len(loss_data) > 15:
            loss_data = loss_data.iloc[10:-5]
            lr_list = lr_list[10:-5]
        return lr_list, loss_data

    def plot(self,
             smoothing=True,
             clipping=True,
             smoothing_factor=0.1
             ):
        """
        Shows loss vs learning rate(log scale) in a matplotlib plot
        """
        lr_list, loss_data = self._loss_curve(smoothing, clipping, smoothing_factor)
        plt.plot(lr_list, loss_data)
        plt.xscale('log')
        plt.title('Loss vs Learning rate')
        plt.xlabel('Learning rate (log scale)')
        plt.ylabel('Loss (exponential moving average)')
        plt.show()


In [None]:
# Run the learning-rate range test with the LearningRateFinder class defined
# in this notebook and plot loss against learning rate.
# NOTE(review): `criterion`, `optimizer` and `device` are not defined in the
# visible cells; they must exist (and the model must be on `device`) first.
lrf = LearningRateFinder(model, criterion, optimizer, device)
lrf.fit(training_dataloader)
lrf.plot()