In [1]:
##############################
#Training loop for NNs       #
#Maintainer: Christopher Chan#
#Version: 0.0.6              #
#Date: 2022-02-17            #
##############################

import os
import sys
import torch
import pathlib
import time
import re
import numpy as np
import torch.nn as nn
import segmentation_models_pytorch as smp
from torch import optim
from tqdm import tqdm
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from Networks import Five_UNet
from dataloader import BuildingDataset

device = (torch.device("cuda") if torch.cuda.is_available()
          else torch.device("cpu"))

print(f"Training on device {device}.")

Training on device cpu.


### Train Val Test split

In [2]:
td_KBY = os.path.abspath("/home/chris/Dropbox/HOTOSM/SAMPLE/td_KBY")
td_DZK = os.path.abspath("/home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK")
td_DZKN = os.path.abspath("/home/chris/Dropbox/HOTOSM/SAMPLE/td_DZKN")

#td_KBY = os.path.abspath("/home/mnt/HOTOSM_data/Kakuma/Kalobeyei/td_KBY")
#td_DZK = os.path.abspath("/home/mnt/HOTOSM_data/Dzaleka/td_DZK")
#td_DZKN = os.path.abspath("/home/mnt/HOTOSM_data/Dzaleka_N/td_DZKN")
#
# Below is a set of relatively complex functions which:
# Perform the train, val, test split at a rounded ratio of 62%, 27%, and 10% based on each sets of imagery
# This will be followed by first pseudo changing the name of _LBL_ to _IMG_ to match the split imagery
# Lastly, once the correct LBL files are matched, 

def tvt_split(td):
    
    img_ls = []
    
    for root, dirs, filename in os.walk(os.path.join(td, "IMG")):
        for i in filename:
            if i.endswith(".png"):
                img_ls.append(root + "/" + i)
        
        img_ls = BuildingDataset(img_ls, _)
        
        train_IMG, val_IMG, test_IMG = random_split(img_ls.png_dir, [int(round(0.6 * len(img_ls.png_dir))),
                                                                     int(round(0.3 * len(img_ls.png_dir))),
                                                                     int(round(0.1 * len(img_ls.png_dir)))])
        
        return train_IMG, val_IMG, test_IMG

DZK_train, DZK_val, DZK_test = tvt_split(td_DZK)
KBY_train, KBY_val, KBY_test = tvt_split(td_KBY)
DZKN_train, DZKN_val, DZKN_test = tvt_split(td_DZKN)

##################
# TOO COMPLICATED#
##################

#def match_LBL(td, imgs):
#    
#    lbl_ls = []
#    img_ls = []
#    match_ls = []
#    
#    imgs = list(imgs)
#    
#    for root, dirs, filename in os.walk(os.path.join(td, "LBL")):
#        for j in filename:
#            if j.endswith(".png"):
#                ps_name = j.rsplit("_LBL_")[0] + "_IMG_" + j.rsplit("_LBL_")[1] # Parse the string, PSEUDO-CHANGE _LBL_ to _IMG_
#                lbl_ls.append(ps_name)
#    
#    for k in imgs:
#        names = os.path.basename(k)
#        img_ls.append(names)
#        
#    def common(a, b):
#        a_set = set(a)
#        b_set = set(b)
#        if (a_set & b_set):
#            return (a_set & b_set)
#        else:
#            print("No common elements")
#            
#            
#    match_ls = common(img_ls, lbl_ls)
#        
#    match_ls = [(root + "/" + n.replace("_IMG_", "_LBL_")) for n in match_ls] # Change the _IMG_ back to _LBL_
#    
#    print("For the selected dataset of {0}, There are: {1} images, {2} labels, and {3} matching image/label pairs.".format(os.path.basename(td), len(img_ls), len(lbl_ls), len(match_ls)))
#    
#    return match_ls

#########################################
# Assign matched LBL to new LBL datasets#
#########################################

#DZKLBL_Train = match_LBL(td_DZK, DZK_train)
#DZKLBL_Val = match_LBL(td_DZK, DZK_val)
#DZKLBL_Test = match_LBL(td_DZK, DZK_test)
#DZKNLBL_Train = match_LBL(td_DZKN, DZKN_train)
#DZKNLBL_Val = match_LBL(td_DZKN, DZKN_val)
#DZKNLBL_Test = match_LBL(td_DZKN, DZKN_test)
#KBYLBL_Train = match_LBL(td_KBY, KBY_train)
#KBYLBL_Val = match_LBL(td_KBY, KBY_val)
#KBYLBL_Test = match_LBL(td_KBY, KBY_test)

############
# Try again#
############

TrainLBL_ls = []
ValLBL_ls = []
TestLBL_ls = []

TrainIMG_ls = list(DZK_train + KBY_train + DZKN_train)
ValIMG_ls = list(DZK_val + KBY_val + DZKN_val)
TestIMG_ls = list(DZK_test + KBY_test + DZKN_test)

for i in TrainIMG_ls:
    i = re.sub("IMG", "LBL", i, count = 2)
    TrainLBL_ls.append(i)

for i in ValIMG_ls:
    i = re.sub("IMG", "LBL", i, count = 2)
    ValLBL_ls.append(i)

for i in TestIMG_ls:
    i = re.sub("IMG", "LBL", i, count = 2)
    TestLBL_ls.append(i)

Train = BuildingDataset(png_dir = TrainIMG_ls,
                        lbl_dir = TrainLBL_ls)

Val = BuildingDataset(png_dir = ValIMG_ls,
                      lbl_dir = ValLBL_ls)

Test = BuildingDataset(png_dir = TestIMG_ls,
                       lbl_dir = TestLBL_ls)

print(len(Train.png_dir) == len(Train.lbl_dir))
print(type(Train.png_dir), type(Train.lbl_dir))

print("Total images and labels pair in DataLoader: {0}".format(len(Train.png_dir) + len(Val.png_dir) + len(Test.png_dir)))

print("Concatenated TRAINING images and labels pair: {0} :".format(len(Train.png_dir)))
for x, y in zip(Train.png_dir, Train.lbl_dir):    
    print(f"Image: {x}", f"Label: {y}")

print("Concatenated VALIDATION images and labels pair: {0} :".format(len(Val.png_dir)))
for x, y in zip(Val.png_dir, Val.lbl_dir):
    print(f"Image: {x}", f"Label: {y}")

print("Concatenated TESTING images: {0} and labels pair: {0} :".format(len(Test.png_dir)))
for x, y in zip(Test.png_dir, Test.lbl_dir):
    print(f"Image: {x}", f"Label: {y}")


True
<class 'list'> <class 'list'>
Total images and labels pair in DataLoader: 30
Concatenated TRAINING images and labels pair: 18 :
Image: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/IMG/DZK_IMG_1364-5456.png Label: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/LBL/DZK_LBL_1364-5456.png
Image: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/IMG/DZK_IMG_1705-4774.png Label: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/LBL/DZK_LBL_1705-4774.png
Image: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/IMG/DZK_IMG_1364-5797.png Label: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/LBL/DZK_LBL_1364-5797.png
Image: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/IMG/DZK_IMG_1364-6138.png Label: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/LBL/DZK_LBL_1364-6138.png
Image: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/IMG/DZK_IMG_1364-11594.png Label: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/LBL/DZK_LBL_1364-11594.png
Image: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/IMG/DZK_IMG_1705-4433.png Label: /home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/LBL

import datetime

def training_loop(n_epochs, optimizer, model, loss_fn, train_loader):
    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0
        for imgs, labels in train_loader:
            imgs = imgs.to(device = device)
            labels = labels.to(device = device)
            outputs = model(imgs)
            loss = loss_fn(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_train += loss.item()

        if epoch == 1 or epoch % 10 == 0:
            print("{} Epoch {}, Training loss {}".format(
                datetime.datetime.now(), epoch,
                loss_train/len(train_loader)))

train_loader = DataLoader(Train, batch_size = 1, shuffle = False)
val_loader = DataLoader(Val, batch_size = 1, shuffle = False)

model = Five_UNet
model.to(device = device)
optimizer = optim.Adam(model.parameters(), lr = 1e-3)
loss_fn = nn.CrossEntropyLoss()

def validate(model, train_loader, val_loader):
    for name, loader in [("train", train_loader), ("val", val_loader)]:
        correct = 0
        total = 0

        with torch.no_grad():
            for imgs, labels in loader:
                outputs = model(imgs)
                _, predicted = torch.max(outputs, dim = 1)
                total += labels.shape[0]
                correct += int((predicted == labels).sum())

        print("Accuracy {}: {:.2f}".format(name, correct/total))

In [3]:
def training_loop(n_epochs, optimizer, model, xp_name, loss_fn, 
                  train_loader, val_loader, checkpoint_freq, val_freq):
    for epoch in tqdm(range(1, n_epochs + 1)):

        model = model
        
        log_dir = os.path.abspath("/home/chris/Dropbox/HOTOSM/log")
        checkpointdir = os.path.abspath("/home/chris/Dropbox/HOTOSM/checkpoints")
        writer = SummaryWriter(os.path.join(log_dir, xp_name))

        for i, batch in tqdm(enumerate(train_loader), total = len(Train.png_dir)):
            img = batch["png_dir"].float32()
            img = img.to(device = device)
            lbl = batch["lbl_dir"].float32()
            lbl = lbl.to(device = device)

            optimizer.zero_grad()
        
            loss = loss_fn(prediction.squeeze(), target)

            loss.backward()
            optimizer.step()
            global_step = epoch * len(train_loader) + i

            if global_step % 10 == 0:
                writer.add_scalar('train/loss', loss.item(), global_step=global_step)

        val_loss = 0

        if epoch % val_freq == 0:
            with torch.no_grad():
                for i, batch in tqdm(enumerate(val_loader), total = len(Val.png_dir)):

                    img = batch['png_dir'].float32()
                    img = img.to(device = device)
                    lbl = batch['lbl_dir'].float32()
                    lbl = lbl.to(device = device)

                    prediction = model.squeeze()
                    val_loss += loss_fn(prediction, target)

                    if i == 0:
                        writer.add_images('val/samples_img', img[:, 1:4, :, :] / img.max(), global_step = global_step)
                        target = target.cpu().detach().numpy()
                        writer.add_images('val/samples_lbl', np.rollaxis(target.astype(np.uint8), 3, 1), global_step = global_step)

                        writer.add_images('val/samples_pred_conf', torch.sigmoid(prediction).unsqueeze(1), global_step = global_step)
                        prediction = (torch.sigmoid(prediction).cpu().detach().numpy() > 0.5).astype(np.uint8)
                        writer.add_images('val/samples_pred', np.rollaxis(prediction, 3, 1), global_step = global_step)

                # print('val loss after epoch', epoch, '=', val_loss.item()/len(val_set))
                writer.add_scalar('val/loss', val_loss.item(), global_step = global_step)

        if epoch % checkpoint_freq == 0:
            os.makedirs(os.path.join(checkpointdir, xp_name), exist_ok = True)
            checkpoint_file = os.path.join(checkpointdir, xp_name, xp_name + '_iter_' + str(global_step).zfill(6) + '.pth')
            model_states = {}
            for model_id in model.keys():
                model_states[model_id] = model[model_id].state_dict()
            state = {'models': model_states, 'epoch': epoch, 'step': global_step}
            torch.save(state, checkpoint_file)

In [4]:
Net = Five_UNet()

n_params = [p.numel() for p in Net.parameters() if p.requires_grad == True]

print(Net)
print('Trainable parameters in current model:', n_params)

loss_fn = nn.CrossEntropyLoss()
train_loader = DataLoader(Train, batch_size = 3, shuffle = False) 
val_loader = DataLoader(Val, batch_size = 3, shuffle = False)

Five_UNet(
  (encoder1): Sequential(
    (enc1conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc1norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc1relu1): ReLU(inplace=True)
    (enc1conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc1norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc1relu2): ReLU(inplace=True)
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (encoder2): Sequential(
    (enc2conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc2norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc2relu1): ReLU(inplace=True)
    (enc2conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc2norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True

In [5]:
    print(Train.png_dir[3])

/home/chris/Dropbox/HOTOSM/SAMPLE/td_DZK/IMG/DZK_IMG_1364-6138.png


In [6]:
training_loop(n_epochs = 500,
              optimizer = torch.optim.Adam(Net.parameters(), lr = 1e-3, weight_decay = 1e-3),
              model = Net,
              xp_name = "18:9_Adam1e-3_wd1e-3_b1_ep100",
              loss_fn = loss_fn,
              train_loader = train_loader,
              val_loader = val_loader,
              checkpoint_freq = 10,
              val_freq = 10)

  0%|          | 0/18 [00:00<?, ?it/s]]
  0%|          | 0/500 [00:00<?, ?it/s]


TypeError: 'tuple' object cannot be interpreted as an integer