# Code for Training the U-Net

In [None]:
import torch
# Report the CUDA environment.
# Availability is checked first because the per-device queries raise on a
# machine without a usable GPU. The values are printed explicitly because
# this cell continues with import statements, so a bare expression here
# would never be displayed.
# (The former bare `torch.cuda.device(0)` call only constructed a context
# manager without entering it, so it selected nothing and is dropped.)
cuda_available = torch.cuda.is_available()
print('CUDA available:', cuda_available)
if cuda_available:
    print('CUDA device count:', torch.cuda.device_count())
    print('Current CUDA device index:', torch.cuda.current_device())
    print('CUDA device 0 name:', torch.cuda.get_device_name(0))

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader, TensorDataset
from torchsummary import summary
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np
from numpy import inf
import scipy.io as sio
import cv2
import h5py
import time


In [None]:
# Load the training data and print its shape

# Read both datasets fully into memory inside a context manager so the HDF5
# file handle is closed as soon as the arrays have been copied out, instead
# of staying open for the lifetime of the kernel.
with h5py.File('ALL_TRAIN_1000.mat', 'r') as mat:
    data = np.array(mat['all_multi_all'])
    target = np.array(mat['all_envdb_norm_all'])

# Insert a singleton axis at position 1 of the labels.
# NOTE(review): presumably this is the channel axis matching the network's
# single output channel -- confirm against the printed shapes below.
target = np.expand_dims(target, axis=1)

print('RAW data: ', data.shape)
print('Label images: ', target.shape)


In [None]:
# Hyper parameters

LR = 1e-4          # learning rate
BATCH_SIZE = 16    # samples per mini-batch
EPOCH = 200        # number of passes over the training set

# Convert the NumPy inputs and labels to tensors and pair them sample by sample.
torch_data, torch_delayed = (torch.from_numpy(arr) for arr in (data, target))
train_dataset = TensorDataset(torch_data, torch_delayed)

# Shuffled mini-batch iterator used by the training loop.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)


In [None]:
# U-Net implementation in PyTorch

class UNet(nn.Module):
    """Three-level U-Net with a sigmoid output.

    Encoder stages produce 64, 128 and 256 feature channels, each followed by
    a 2x2 max-pool; the bottleneck uses 512 channels. Every decoder stage
    concatenates the up-sampled features with the matching encoder output
    (skip connection) along the channel axis. The last layer is a Sigmoid, so
    outputs lie in the interval (0, 1).

    When the input height and width are multiples of 8 the output has the
    same spatial size as the input. For other sizes the skip connections are
    cropped to the up-sampled feature map, so the output is slightly smaller
    than the input.

    Args:
        in_channel: number of channels of the input tensor.
        out_channel: number of channels of the output tensor.
    """

    @staticmethod
    def _double_conv(in_channels, out_channels, kernel_size=3):
        """Return the layers [Conv, BN, ReLU, Conv, BN, ReLU] as a list.

        Shared by every block builder so the repeated stack is defined once.
        The layer order is what fixes the numeric keys in ``state_dict()``.
        """
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        """Build an encoder stage: two Conv-BN-ReLU units."""
        return nn.Sequential(*self._double_conv(in_channels, out_channels, kernel_size))

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Build a decoder stage: two Conv-BN-ReLU units, then a 2x up-sampling.

        The transposed convolution (kernel 3, stride 2, padding 1,
        output_padding 1) doubles the spatial size and maps ``mid_channel``
        to ``out_channels``.
        """
        return nn.Sequential(
            *self._double_conv(in_channels, mid_channel, kernel_size),
            nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels,
                               kernel_size=3, stride=2, padding=1, output_padding=1),
        )

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Build the output head: two Conv-BN-ReLU units, a conv, and a Sigmoid."""
        return nn.Sequential(
            *self._double_conv(in_channels, mid_channel, kernel_size),
            nn.Conv2d(mid_channel, out_channels, kernel_size=kernel_size, padding=1),
            nn.Sigmoid(),
        )

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # Encode
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_maxpool3 = nn.MaxPool2d(kernel_size=2)
        # Bottleneck: same layer layout as an expansive block (256 -> 512 -> 256),
        # so it is built with the same helper.
        self.bottleneck = self.expansive_block(256, 512, 256)
        # Decode: input channel counts are doubled by the skip concatenation.
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)
        self.final_layer = self.final_block(128, 64, out_channel)

    def crop_and_concat(self, upsampled, bypass, crop=False):
        """Concatenate ``upsampled`` and ``bypass`` along the channel axis.

        Args:
            upsampled: decoder feature map, indexed as (N, C, H, W).
            bypass: encoder feature map for the skip connection.
            crop: when True, resize ``bypass`` spatially to match ``upsampled``
                before concatenating.

        Height and width are handled independently, and an odd size
        difference puts the extra row/column on the bottom/right, so
        non-square inputs and sizes that are not multiples of 8 concatenate
        correctly. For an even, equal difference in both dimensions the
        result is identical to a symmetric centre crop.

        Returns:
            Tensor with ``upsampled`` channels followed by ``bypass`` channels.
        """
        if crop:
            diff_h = bypass.size(2) - upsampled.size(2)
            diff_w = bypass.size(3) - upsampled.size(3)
            top, left = diff_h // 2, diff_w // 2
            bottom, right = diff_h - top, diff_w - left
            # Negative padding removes rows/columns; F.pad order is
            # (left, right, top, bottom) for the last two dimensions.
            bypass = F.pad(bypass, (-left, -right, -top, -bottom))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder and return the sigmoid output."""
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        # Decode
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
        final_layer = self.final_layer(decode_block1)
        return final_layer
    
# Build the network (128 input channels, 1 output channel) and move it to the GPU.
# NOTE(review): 128 presumably equals data.shape[1] -- confirm against the loaded data.
model = UNet(128, 1)
model = model.cuda()

# Mean absolute error between the prediction and the label image.
criterion = torch.nn.L1Loss().cuda()

# Adam with a small weight decay; the learning rate is multiplied by 0.3 every 50 epochs.
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, gamma=0.3, step_size=50)


In [None]:
# Train the U-Net model

num = 1                                 # number of timed training runs
loss_count = np.zeros((EPOCH, num))     # mean training loss: one row per epoch, one column per run
time_count = np.zeros((1, num))         # wall-clock training time of each run, in seconds

# NOTE(review): the model, optimizer and scheduler are not re-created inside
# this loop, so with num > 1 each run would continue training the same
# weights -- confirm that is intended before raising `num`.
for i in range(num):

    # BatchNorm layers must use batch statistics while training; set the mode
    # explicitly so re-running this cell after an evaluation cell is safe.
    model.train()

    time_start = time.time()

    for epoch in range(EPOCH):
        loss_sum = 0.0      # sum of per-sample losses over this epoch
        sample_count = 0

        for step, (a, b) in enumerate(train_loader):
            model.zero_grad()               # clear gradients for this training step
            decoded = model(a.float().cuda())
            loss = criterion(decoded, b.float().cuda())
            loss.backward()                 # backpropagation, compute gradients
            optimizer.step()                # apply gradients

            # Accumulate on the tensor's device (no per-step host sync) and
            # weight by batch size so a smaller final batch is not over-counted.
            batch_n = a.size(0)
            loss_sum = loss_sum + loss.detach() * batch_n
            sample_count += batch_n

        scheduler.step()

        # Record the epoch-mean loss rather than only the last mini-batch's
        # loss, which is noisy and undefined for an empty loader.
        epoch_loss = float(loss_sum) / max(sample_count, 1)
        loss_count[epoch, i] = epoch_loss
        print('Epoch: ', epoch, '| train loss: %.4f' % epoch_loss)

    time_end = time.time()
    time_count[0, i] = time_end - time_start
    print('time cost', time_count[0, i], 's')

    plt.figure(figsize=(6, 3))
    plt.plot(loss_count[:, i])
    plt.grid()
    plt.title('Loss During Training')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.show()


In [None]:
# Persist the trained model's parameter dictionary (weights only, not the module object)

checkpoint_path = 'UNet_dict.pth'
torch.save(model.state_dict(), checkpoint_path)
