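# Compression Ratio 8:1 Neural Network Training

This notebook builds and trains the decoder for an 8:1 compression ratio (64 values compressed to 8). The decoder takes an 8-channel 64×64 input and reconstructs a single-channel 512×512 image. Each spatial input location (8 values) is therefore expanded into an 8×8 patch (64 values) of the output.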

# Libraries

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import trange, tqdm
from mpl_toolkits.axes_grid1 import make_axes_locatable
import os
import cv2 as cv
import glob
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision
import torch.optim as optim
import torchvision
import torchvision.transforms.functional as TF
from torchinfo import summary
from NN_module import Upsample, Residual
from utils import (SampleDataset, save_checkpoint, load_checkpoint,
                   get_loaders, check_accuracy, save_predicitons_as_imgs, init_weights)

%matplotlib inline
plt.style.use('default')
plt.rcParams['figure.figsize'] = [8, 8]
plt.rcParams.update({'font.size': 20})
plt.rcParams['figure.dpi'] = 150

# Network architecture

In [2]:
class Decoder(nn.Module):
    """Convolutional decoder that upsamples its input 8x spatially to one channel.

    Structure: a small U-Net-style core (two stride-2 downsampling stages,
    a residual bridge, and two upsampling stages with skip connections back
    to the encoder), followed by three further 2x upsampling stages and a
    stack of channel-reducing residual blocks, ending in a 1x1 convolution
    that produces a single output channel. With the notebook's setup an
    input of shape (N, 8, 64, 64) yields an output of shape (N, 1, 512, 512).

    Args:
        channel: number of input channels.
        filters: channel widths of the four encoder levels. Needs at least
            four entries and ``filters[0] >= 32``, because the tail of the
            network repeatedly halves ``filters[0]`` down to ``filters[0] // 32``.

    Raises:
        ValueError: if ``filters`` is too short or ``filters[0] < 32``.
    """

    def __init__(self, channel, filters=(64, 128, 256, 512)):
        super().__init__()
        if len(filters) < 4:
            raise ValueError(
                f"filters must have at least 4 entries, got {len(filters)}")
        if filters[0] < 32:
            raise ValueError(
                "filters[0] must be >= 32 (the last residual block has "
                f"filters[0] // 32 output channels), got {filters[0]}")

        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, filters[0],
                      kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(filters[0]),
            nn.LeakyReLU(),
            nn.Conv2d(filters[0], filters[0],
                      kernel_size=3, padding=1, bias=False),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, filters[0],
                      kernel_size=3, padding=1, bias=False)
        )

        self.residual_conv_1 = Residual(
            filters[0], filters[1], use_1x1conv=True)
        self.downsample_1 = nn.Sequential(
            nn.Conv2d(filters[1], filters[1], kernel_size=3,
                      padding=1, stride=2, bias=False),
            nn.BatchNorm2d(filters[1]),
        )

        self.residual_conv_2 = Residual(
            filters[1], filters[2], use_1x1conv=True)
        self.downsample_2 = nn.Sequential(
            nn.Conv2d(filters[2], filters[2], kernel_size=3,
                      padding=1, stride=2, bias=False),
            nn.BatchNorm2d(filters[2]),
        )

        self.bridge = Residual(filters[2], filters[3], use_1x1conv=True)

        # Upsampling stages that concatenate an encoder skip connection.
        self.upsample_1 = Upsample(filters[3], filters[3], 2, 2)
        self.up_residual_conv1 = Residual(
            filters[3] + filters[2], filters[2],
            paddings=1, strides=1, use_1x1conv=True
        )

        self.upsample_2 = Upsample(filters[2], filters[2], 2, 2)
        self.up_residual_conv2 = Residual(
            filters[2] + filters[1], filters[1],
            paddings=1, strides=1, use_1x1conv=True
        )

        # Upsampling stages beyond the input resolution (no skips).
        self.upsample_3 = Upsample(filters[1], filters[1], 2, 2)
        self.up_residual_conv3 = Residual(
            filters[1], filters[0],
            paddings=1, strides=1, use_1x1conv=True
        )

        self.upsample_4 = Upsample(filters[0], filters[0], 2, 2)
        self.up_residual_conv4 = Residual(
            filters[0], filters[0]//2,
            paddings=1, strides=1, use_1x1conv=True
        )

        self.upsample_5 = Upsample(filters[0]//2, filters[0]//2, 2, 2)
        self.up_residual_conv5 = Residual(
            filters[0]//2, filters[0]//4,
            paddings=1, strides=1, use_1x1conv=True
        )

        # Full-resolution residual blocks that keep halving the channel count.
        self.up_residual_conv6 = Residual(
            filters[0]//4, filters[0]//8,
            paddings=1, strides=1, use_1x1conv=True
        )
        self.up_residual_conv7 = Residual(
            filters[0]//8, filters[0]//16,
            paddings=1, strides=1, use_1x1conv=True
        )
        self.up_residual_conv8 = Residual(
            filters[0]//16, filters[0]//32,
            paddings=1, strides=1, use_1x1conv=True
        )

        self.output_layer = nn.Conv2d(filters[0]//32, 1, 1, 1, bias=True)

    def forward(self, x):
        """Map an (N, channel, H, W) tensor to (N, 1, 8H, 8W).

        Shape comments below are (C, H, W) for the default ``filters`` and a
        64x64 input. H and W presumably need to be divisible by 4 so the two
        skip connections line up after downsampling — TODO confirm against
        ``Upsample``.
        """
        # Encoder
        x1 = self.input_layer(x) + self.input_skip(x)  # 64, 64, 64
        x1 = self.residual_conv_1(x1)  # 128, 64, 64
        x2 = self.downsample_1(x1)  # 128, 32, 32
        x2 = self.residual_conv_2(x2)  # 256, 32, 32
        x3 = self.downsample_2(x2)  # 256, 16, 16

        # Bridge
        x4 = self.bridge(x3)  # 512, 16, 16

        # Decoder with skip connections back to the encoder
        x5 = self.upsample_1(x4)  # 512, 32, 32
        x5 = torch.cat([x5, x2], dim=1)  # 768, 32, 32
        x5 = self.up_residual_conv1(x5)  # 256, 32, 32

        x6 = self.upsample_2(x5)  # 256, 64, 64
        x6 = torch.cat([x6, x1], dim=1)  # 384, 64, 64
        x6 = self.up_residual_conv2(x6)  # 128, 64, 64

        # Upsampling beyond the input resolution
        x7 = self.upsample_3(x6)  # 128, 128, 128
        x7 = self.up_residual_conv3(x7)  # 64, 128, 128

        x8 = self.upsample_4(x7)  # 64, 256, 256
        x8 = self.up_residual_conv4(x8)  # 32, 256, 256

        x9 = self.upsample_5(x8)  # 32, 512, 512
        x9 = self.up_residual_conv5(x9)  # 16, 512, 512

        # Channel-reducing residual blocks at full resolution
        x10 = self.up_residual_conv6(x9)  # 8, 512, 512
        x11 = self.up_residual_conv7(x10)  # 4, 512, 512
        x12 = self.up_residual_conv8(x11)  # 2, 512, 512

        output = self.output_layer(x12)  # 1, 512, 512

        return output

In [3]:
def test(channel=8, size=64):
    """Smoke-test the Decoder on CPU and print its shapes and layer summary.

    Builds a fresh, randomly initialised ``Decoder(channel=channel)`` and runs
    one random ``(1, channel, size, size)`` tensor through it. It prints the
    input and output shapes, then a ``torchinfo`` summary for the same input size.

    Args:
        channel: number of input channels (default 8, as used for training).
        size: height and width of the random input (default 64).
    """
    input_size = (1, channel, size, size)
    x = torch.randn(input_size, dtype=torch.float, device='cpu')
    model = Decoder(channel=channel).to(device='cpu', dtype=torch.float)
    # A shape check needs no gradients; skipping autograd avoids holding the
    # graph for the full-resolution activations in memory.
    with torch.no_grad():
        preds = model(x)
    print('input:', x.shape)
    print('output:', preds.shape)
    print()
    print(summary(model, input_size=input_size, device='cpu'))

In [4]:
## show network architecture
# Smoke test: builds a CPU Decoder, pushes one random (1, 8, 64, 64) tensor
# through it, and prints the input/output shapes plus a torchinfo layer summary.
test()

input: torch.Size([1, 8, 64, 64])
output: torch.Size([1, 1, 512, 512])

Layer (type:depth-idx)                   Output Shape              Param #
Decoder                                  [1, 1, 512, 512]          --
├─Sequential: 1-1                        [1, 64, 64, 64]           --
│    └─Conv2d: 2-1                       [1, 64, 64, 64]           4,608
│    └─BatchNorm2d: 2-2                  [1, 64, 64, 64]           128
│    └─LeakyReLU: 2-3                    [1, 64, 64, 64]           --
│    └─Conv2d: 2-4                       [1, 64, 64, 64]           36,864
├─Sequential: 1-2                        [1, 64, 64, 64]           --
│    └─Conv2d: 2-5                       [1, 64, 64, 64]           4,608
├─Residual: 1-3                          [1, 128, 64, 64]          --
│    └─Conv2d: 2-6                       [1, 128, 64, 64]          73,728
│    └─BatchNorm2d: 2-7                  [1, 128, 64, 64]          256
│    └─Conv2d: 2-8                       [1, 128, 64, 64]          

# Train the network

## parameters

In [7]:
# Hyperparameters
LEARNING_RATE = 0.001
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 50
NUM_EPOCHS = 800
NUM_WORKERS = 4
# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only machines.
PIN_MEMORY = DEVICE == "cuda"
LOAD_MODEL = False

# Seed the RNGs used for weight init and data shuffling. This makes runs
# repeatable, though cuDNN kernels may still be nondeterministic on GPU.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Loss histories filled during training: train_fn appends one value per
# training batch, train_main appends one validation value per epoch (plus one
# before training starts). Re-run this cell to reset them before retraining.
valid_loss_list = []
train_loss_list = []

## training, validation, and test dataset
TRAIN_SAMPLE_DIR = "path to your training dataset sample directory"
TRAIN_TARGET_DIR = "path to your training dataset target directory"
VAL_SAMPLE_DIR = "path to your validation dataset sample directory"
VAL_TARGET_DIR = "path to your validation dataset target directory"
TEST_SAMPLE_DIR = "path to your test dataset sample directory"
TEST_TSAMPLE_DIR = TEST_SAMPLE_DIR  # backward-compatible alias for the old misspelled name
TEST_TARGET_DIR = "path to your test dataset target directory"

## folder to save predicted images
pred_dirs = 'path to your prediction directory'
os.makedirs(pred_dirs, exist_ok=True)

## train functions

In [8]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one training epoch over ``loader``, using mixed precision on CUDA.

    Each batch is moved to ``DEVICE``. A channel axis is inserted into the
    targets at dim 1 so they line up with the model's single-channel output
    (assumes targets arrive as (N, H, W) — TODO confirm against SampleDataset).
    Every per-batch loss is appended to the module-level ``train_loss_list``.

    Args:
        loader: iterable of ``(data, targets)`` batches.
        model: network to train; it is switched to training mode first.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        scaler: gradient scaler used for the backward pass and optimizer step.

    Returns:
        Mean of this epoch's per-batch losses (``nan`` if ``loader`` is empty).
    """
    # Validation helpers may leave the model in eval mode; make sure
    # BatchNorm layers use batch statistics while training.
    model.train()
    # Autocast only on CUDA; on CPU-only machines the old
    # torch.cuda.amp.autocast() disabled itself, so behaviour there is unchanged.
    use_amp = torch.device(DEVICE).type == 'cuda'
    epoch_losses = []
    loop = tqdm(loader)

    for data, targets in loop:
        data = data.to(device=DEVICE, dtype=torch.float32)
        targets = targets.unsqueeze(1).to(device=DEVICE, dtype=torch.float32)

        # forward
        with torch.autocast(device_type='cuda', enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # One device-to-host sync per batch, reused for history and progress bar.
        loss_value = loss.item()
        epoch_losses.append(loss_value)
        train_loss_list.append(loss_value)
        loop.set_postfix(loss=loss_value)

    return float(np.mean(epoch_losses)) if epoch_losses else float('nan')

def train_main():
    """Train a fresh Decoder for ``NUM_EPOCHS`` epochs, checkpointing every epoch.

    Each epoch:
    - runs ``train_fn``;
    - saves the latest checkpoint (``.pth.tar`` state dicts and a whole-model ``.pt``);
    - evaluates on the validation set and appends the value to ``valid_loss_list``;
    - saves a ``_best`` copy whenever that value is the lowest so far;
    - every 50 epochs, writes example predictions to ``pred_dirs``.

    The learning rate is cut by 10x every 500 epochs.
    """
    checkpoint_stem = 'my_checkpoint_Decoder_64in8out_real'

    model = Decoder(channel=8).to(device=DEVICE, dtype=torch.float32)
    model.apply(init_weights)
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # The current lr is printed every epoch below, so StepLR's (deprecated)
    # verbose flag is not needed.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.1)

    train_loader, val_loader = get_loaders(
        TRAIN_SAMPLE_DIR,
        TRAIN_TARGET_DIR,
        VAL_SAMPLE_DIR,
        VAL_TARGET_DIR,
        BATCH_SIZE,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    if LOAD_MODEL:
        resume_path = f'{checkpoint_stem}.pth.tar'
        try:
            # map_location lets a checkpoint written on GPU be resumed on CPU.
            load_checkpoint(torch.load(resume_path, map_location=DEVICE), model)
        except Exception as err:
            # Best effort: keep training from the fresh weights, but say so
            # instead of silently ignoring the failure.
            print(f"Could not load checkpoint '{resume_path}' ({err!r}); "
                  "training from scratch.")

    # Entry 0 of the history is the metric before any training.
    valid_loss_list.append(check_accuracy(val_loader, model, device=DEVICE))

    # Loss scaling only matters for CUDA autocast; it is a no-op elsewhere.
    scaler = torch.cuda.amp.GradScaler(
        enabled=torch.device(DEVICE).type == 'cuda')

    for epoch in range(NUM_EPOCHS):
        print('present learning rate: {0}'.format(
            optimizer.param_groups[0]['lr']))
        print(f'Epoch #{epoch}')
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # save the latest model
        checkpoint = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        save_checkpoint(checkpoint, filename=f'{checkpoint_stem}.pth.tar')
        torch.save(model, f'{checkpoint_stem}.pt')

        # check accuracy
        valid_loss_list.append(check_accuracy(val_loader, model, device=DEVICE))

        # save the best model (lower is better — assumes check_accuracy
        # returns a loss; TODO confirm in utils)
        if np.min(valid_loss_list) == valid_loss_list[-1]:
            torch.save(model, f'{checkpoint_stem}_best.pt')
            save_checkpoint(
                checkpoint, filename=f'{checkpoint_stem}_best.pth.tar')

        # output examples to a folder
        if (epoch + 1) % 50 == 0:
            save_predicitons_as_imgs(
                val_loader, model, folder=pred_dirs, device=DEVICE
            )

        scheduler.step()
        print()

In [12]:
## train the network
# Runs the full NUM_EPOCHS loop in train_main(): writes checkpoint files to the
# working directory each epoch and fills train_loss_list / valid_loss_list.
# NOTE(review): run this before the loss-plotting cell below; the saved
# execution counts (In [12] here, In [11] for the plot) show they ran out of order.
train_main()

In [11]:
## plot MSE loss of training and validation dataset
# train_loss_list holds one value per training *batch*, while valid_loss_list
# holds one value per *epoch* (index 0 = before training). Plotting both on a
# shared "epoch" axis would misalign them, so each gets its own x-axis.
fig, (ax_train, ax_valid) = plt.subplots(1, 2, figsize=(21, 9))

ax_train.plot(train_loss_list, '-', label='train')
ax_train.set(title='Training loss (per batch)', xlabel='iteration',
             ylabel='MSE loss')
ax_train.legend(loc='upper right')

ax_valid.plot(valid_loss_list, '-o', color='tab:orange', label='valid')
ax_valid.set(title='Validation loss (per epoch)', xlabel='epoch',
             ylabel='validation loss')
ax_valid.legend(loc='upper right')

fig.tight_layout()
plt.show()