In [1]:
##############################
#Training loop for NNs       #
#Maintainer: Christopher Chan#
#Version: 0.1.0              #
#Date: 2022-02-23            #
##############################

import os
import sys
import torch
import pathlib
import time
import re
import PIL
import numpy as np
import torch.nn as nn
import segmentation_models_pytorch as smp
from torch import optim
from tqdm import tqdm
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from Networks import Five_UNet
from dataloader import BuildingDataset

device = (torch.device("cuda") if torch.cuda.is_available()
          else torch.device("cpu"))

print(f"Training on device {device}.")

td_KBY = os.path.abspath("/home/chris/Dropbox/HOTOSM/SAMPLE/td_KBY")
td_DZK = os.path.abspath("/home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK")
td_DZKN = os.path.abspath("/home/chris/Dropbox/HOTOSM/SAMPLE/td_DZKN")

#td_KBY = os.path.abspath("/home/mnt/HOTOSM_data/Kakuma/Kalobeyei/td_KBY")
#td_DZK = os.path.abspath("/home/mnt/HOTOSM_data/Dzaleka/td_DZK")
#td_DZKN = os.path.abspath("/home/mnt/HOTOSM_data/Dzaleka_N/td_DZKN")


Training on device cpu.


In [2]:
# Remove non square (512 x 512 imgs)

for root, dirname, filename in os.walk(os.path.join(td_KBY, "LBL")):
    for i in filename:
        if i.endswith("png"):
            with PIL.Image.open(root + "/" + i) as img:
                print(img.size)
                
                if img.size != (512, 512):
                    os.remove(root + "/" + i)


(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)


### Train Val Test split

In [3]:

# Below is a set of relatively complex functions which:
# Perform the train, val, test split at a rounded ratio of 62%, 27%, and 10% based on each sets of imagery
# This will be followed by first pseudo changing the name of _LBL_ to _IMG_ to match the split imagery
# Lastly, once the correct LBL files are matched, 

def tvt_split(td):
    
    img_ls = []
    
    for root, dirs, filename in os.walk(os.path.join(td, "IMG")):
        for i in filename:
            if i.endswith(".png"):
                img_ls.append(root + "/" + i)
        
        img_ls = BuildingDataset(img_ls, _)
        
        train_IMG, val_IMG, test_IMG = random_split(img_ls.png_dir, [int(round(0.6 * len(img_ls.png_dir))),
                                                                     int(round(0.3 * len(img_ls.png_dir))),
                                                                     int(round(0.1 * len(img_ls.png_dir)))])
        
        return train_IMG, val_IMG, test_IMG

DZK_train, DZK_val, DZK_test = tvt_split(td_DZK)
KBY_train, KBY_val, KBY_test = tvt_split(td_KBY)
DZKN_train, DZKN_val, DZKN_test = tvt_split(td_DZKN)

##################
# TOO COMPLICATED#
##################

#def match_LBL(td, imgs):
#    
#    lbl_ls = []
#    img_ls = []
#    match_ls = []
#    
#    imgs = list(imgs)
#    
#    for root, dirs, filename in os.walk(os.path.join(td, "LBL")):
#        for j in filename:
#            if j.endswith(".png"):
#                ps_name = j.rsplit("_LBL_")[0] + "_IMG_" + j.rsplit("_LBL_")[1] # Parse the string, PSEUDO-CHANGE _LBL_ to _IMG_
#                lbl_ls.append(ps_name)
#    
#    for k in imgs:
#        names = os.path.basename(k)
#        img_ls.append(names)
#        
#    def common(a, b):
#        a_set = set(a)
#        b_set = set(b)
#        if (a_set & b_set):
#            return (a_set & b_set)
#        else:
#            print("No common elements")
#            
#            
#    match_ls = common(img_ls, lbl_ls)
#        
#    match_ls = [(root + "/" + n.replace("_IMG_", "_LBL_")) for n in match_ls] # Change the _IMG_ back to _LBL_
#    
#    print("For the selected dataset of {0}, There are: {1} images, {2} labels, and {3} matching image/label pairs.".format(os.path.basename(td), len(img_ls), len(lbl_ls), len(match_ls)))
#    
#    return match_ls

#########################################
# Assign matched LBL to new LBL datasets#
#########################################

#DZKLBL_Train = match_LBL(td_DZK, DZK_train)
#DZKLBL_Val = match_LBL(td_DZK, DZK_val)
#DZKLBL_Test = match_LBL(td_DZK, DZK_test)
#DZKNLBL_Train = match_LBL(td_DZKN, DZKN_train)
#DZKNLBL_Val = match_LBL(td_DZKN, DZKN_val)
#DZKNLBL_Test = match_LBL(td_DZKN, DZKN_test)
#KBYLBL_Train = match_LBL(td_KBY, KBY_train)
#KBYLBL_Val = match_LBL(td_KBY, KBY_val)
#KBYLBL_Test = match_LBL(td_KBY, KBY_test)

############
# Try again#
############

TrainLBL_ls = []
ValLBL_ls = []
TestLBL_ls = []

TrainIMG_ls = list(DZK_train + KBY_train + DZKN_train)
ValIMG_ls = list(DZK_val + KBY_val + DZKN_val)
TestIMG_ls = list(DZK_test + KBY_test + DZKN_test)

for i in TrainIMG_ls:
    i = re.sub("IMG", "LBL", i, count = 2)
    TrainLBL_ls.append(i)

for i in ValIMG_ls:
    i = re.sub("IMG", "LBL", i, count = 2)
    ValLBL_ls.append(i)

for i in TestIMG_ls:
    i = re.sub("IMG", "LBL", i, count = 2)
    TestLBL_ls.append(i)

Train = BuildingDataset(png_dir = TrainIMG_ls,
                        lbl_dir = TrainLBL_ls)

Val = BuildingDataset(png_dir = ValIMG_ls,
                      lbl_dir = ValLBL_ls)

Test = BuildingDataset(png_dir = TestIMG_ls,
                       lbl_dir = TestLBL_ls)

print(len(Train.png_dir) == len(Train.lbl_dir))
print(type(Train.png_dir), type(Train.lbl_dir))

print("Total images and labels pair in DataLoader: {0}".format(len(Train.png_dir) + len(Val.png_dir) + len(Test.png_dir)))

print("Concatenated TRAINING images and labels pair: {0} :".format(len(Train.png_dir)))
for x, y in zip(Train.png_dir, Train.lbl_dir):    
    print(f"Image: {x}", f"Label: {y}")

print("Concatenated VALIDATION images and labels pair: {0} :".format(len(Val.png_dir)))
for x, y in zip(Val.png_dir, Val.lbl_dir):
    print(f"Image: {x}", f"Label: {y}")

print("Concatenated TESTING images: {0} and labels pair: {0} :".format(len(Test.png_dir)))
for x, y in zip(Test.png_dir, Test.lbl_dir):
    print(f"Image: {x}", f"Label: {y}")


True
<class 'list'> <class 'list'>
Total images and labels pair in DataLoader: 29
Concatenated TRAINING images and labels pair: 17 :
Image: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/IMG/DZK_IMG_1705-6138.png Label: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/LBL/DZK_LBL_1705-6138.png
Image: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/IMG/DZK_IMG_1705-5797.png Label: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/LBL/DZK_LBL_1705-5797.png
Image: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/IMG/DZK_IMG_1705-5456.png Label: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/LBL/DZK_LBL_1705-5456.png
Image: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/IMG/DZK_IMG_1364-5797.png Label: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/LBL/DZK_LBL_1364-5797.png
Image: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/IMG/DZK_IMG_1705-4774.png Label: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/LBL/DZK_LBL_1705-4774.png
Image: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/IMG/DZK_IMG_1364-5456.png Label: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/LBL/D

In [4]:
# Trimmed down version
def training_loop1(n_epochs, optimizer, model, xp_name,
                   loss_fn, in_channels, out_channels, train_loader,
                   val_loader, checkpoint_freq, val_freq):

    model = model.train()

    for epoch in tqdm(range(1, n_epochs + 1)):
        
        log_dir = os.path.abspath("/home/chris/Dropbox/HOTOSM/log")
        checkpointdir = os.path.abspath("/home/chris/Dropbox/HOTOSM/checkpoints")
        writer = SummaryWriter(os.path.join(log_dir, xp_name))

        loss_train = 0.0

        for i, (imgs, labels) in tqdm(enumerate(train_loader), total = len(train_loader)):
            imgs = imgs.to(device = device, dtype = torch.float32)
            labels = labels.to(device = device, dtype = torch.float32)

            optimizer.zero_grad()
            prediction = model(imgs)
            loss = loss_fn(prediction.squeeze(), labels)

            loss.backward()
            optimizer.step()
            loss_train += loss.item()
            global_step = epoch * len(train_loader) + i

            if global_step % 10 == 0:
                writer.add_scalar("Train/Loss", loss.item(), global_step = global_step)

        # Validation

        if epoch % val_freq == 0:
            
            model = model.eval()
            val_loss = 0.0

            with torch.no_grad():
                for i, (imgs, labels) in tqdm(enumerate(val_loader), total = len(val_loader)):
                    imgs = imgs.to(device = device, dtype = torch.float32)
                    labels = labels.to(device = device, dtype = torch.float32)

                    val_outIMG = model(imgs)
                    prediction = torch.argmax(val_outIMG)
                    val_loss += loss_fn(prediction, labels)
                    assert val_loss.requires_grad == False

                    if i == 0:
                        labels = labels.cpu().detach().numpy()
                        writer.add_images("Val/Sample_LBL", np.rollaxis(labels.astype(np.uint8), 3, 1),
                                          global_step = global_step)

                        if out_channels > 1:
                            writer.add_images("Val/Sample_conf_1", prediction[:, 0, :, :].unsqueeze(1), 
                                              global_step = global_step)
                            writer.add_images("Val/Sample_conf_2", prediction[:, 1, :, :].unsqueeze(1), 
                                              global_step = global_step)

                            confidence = prediction[:, 0, :, :] - prediction[:, 1, :, :]
                            writer.add_images("Val/Sample_conf", confidence.unsqueeze(1), 
                                              global_step = global_step)

                            prediction = torch.argmax(prediction, 1).cpu().detach().numpy()
                        else:
                            writer.add_images("Val/Sample_conf", confidence.unsqueeze(1), 
                                              global_step = global_step)

                            prediction = (torch.sigmoid(prediction).cpu().detach().numpy() > 0.5).astype(np.uint8)

                        writer.add_images("Val/Sample_pred", np.rollaxis(prediction, 3, 1), 
                                          global_step = global_step)

                if global_step % 10 == 0:
                    writer.add_scalar("Val/Loss", val_loss.item(), global_step = global_step)


            if epoch % checkpoint_freq == 0:
                os.makedirs(os.path.join(checkpointdir, xp_name), exist_ok = True)
                checkpoint_file = os.path.join(checkpointdir, xp_name, xp_name + "_iter_" + str(global_step).zfill(6) + ".pth")
                model_states = {}

                model_state = model.state_dict()

                state = {"Model:": model_state, "Epoch:": epoch, "Steps:": global_step}
                torch.save(state, checkpoint_file)

In [5]:
train_loader = DataLoader(Train, batch_size = 3, shuffle = False) 
val_loader = DataLoader(Val, batch_size = 3, shuffle = False)

Net = Five_UNet()

n_params = [p.numel() for p in Net.parameters() if p.requires_grad == True]

print(Net)
print('Trainable parameters in current model:', n_params)

Five_UNet(
  (encoder1): Sequential(
    (enc1conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc1norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc1relu1): ReLU(inplace=True)
    (enc1conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc1norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc1relu2): ReLU(inplace=True)
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (encoder2): Sequential(
    (enc2conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc2norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc2relu1): ReLU(inplace=True)
    (enc2conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc2norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True

In [6]:
training_loop1(n_epochs = 100,
              optimizer = torch.optim.Adam(Net.parameters(), lr = 1e-3, weight_decay = 1e-3),
              model = Net,
              in_channels = 3,
              out_channels = 2,
              xp_name = "FIRST17:9_Adam1e-3_wd1e-3_b1_ep100_KLDivLoss",
              loss_fn = nn.KLDivLoss(reduction = 'batchmean'),
              train_loader = train_loader,
              val_loader = val_loader,
              checkpoint_freq = 10,
              val_freq = 10)

100%|██████████| 6/6 [01:23<00:00, 13.83s/it]
100%|██████████| 6/6 [01:26<00:00, 14.41s/it]/it]
  2%|▏         | 2/100 [02:49<2:18:53, 85.04s/it]