In [1]:
# hyperPara config
batch_size: int = 16
epochs: int = 5

# optimizer config
learning_rate: float = 1.0e-4

# dataset config
is_conv = False
model_save_dir: str = './models'
model_name: str = 'unet_v1'
num_workers: int = 12

In [2]:
import warnings
warnings.filterwarnings(action='ignore')

from datetime import datetime
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
from torch import nn
from tqdm import tqdm
import cv2
import numpy as np
import os

import codes.utils as util

conv_image = None
if is_conv:
    def conv_image(src):
        return conv.binary(src, val=255/2)

torch.manual_seed(0)
totensor = torchvision.transforms.ToTensor()

start_time = datetime.now()
time_str = start_time.strftime("%b%d_%H-%M-%S")
log_dir = os.path.join("models", f"{time_str}_{model_name}")
writer = SummaryWriter(log_dir=log_dir)

assert(torch.cuda.is_available() == True)

print(f"Using CUDA device")
print(f"log at: {log_dir}")

Using CUDA device
log at: models/Dec03_20-17-01_unet_v1


In [3]:
# 1. load model
import codes.networks as network

model = network.unet_v1()
model = model.cuda()

print("Model Loaded...")

Model Loaded...


In [4]:
# 2. load dataset
from torch.utils.data import DataLoader
import codes.datas as data
import codes.convs as conv

train_dataset = data.nyu_v2_kaggle(dir='data/nyu2_train.csv', y_res=(60, 80), subset=conv_image)
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True, shuffle=True)

val_dataset = data.nyu_v2_kaggle(dir='data/nyu2_test.csv', y_res=(60, 80), subset=conv_image)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True, shuffle=True)

print("Dataset Loaded...")

Dataset Loaded...


In [5]:
# 3. load loss func
import codes.loss as loss

loss_fn = loss.RMSELoss().cuda()

print("Loss Function Loaded...")

Loss Function Loaded...


In [6]:
# 4. load optimizer
import torch.optim as optim

optimizer = optim.AdamW(model.parameters(), lr=learning_rate, eps=1.0e-8)

print("Optimizer Loaded...")

Optimizer Loaded...


In [7]:
import codes.utils as util

x1 = util.read_image("data/nyu2_test/00001_colors.png", 1, c=conv_image)
x2 = util.read_image("data/nyu2_test/00974_colors.png", 1, c=conv_image)
x1_n = cv2.resize(cv2.imread("data/nyu2_test/00001_colors.png", 1), (160, 120))
x2_n = cv2.resize(cv2.imread("data/nyu2_test/00974_colors.png", 1), (160, 120))

In [8]:
# 5. training
model.train()
print(f"epochs: {epochs}")
for epoch in range(1, epochs + 1):
    # model train
    model.train()
    loss_sum = 0.0
    batch_len = len(train_loader)
    batch_runner = tqdm(train_loader, desc=f"Epoch {epoch}", unit="batch")
    for batch, data in enumerate(batch_runner, start=1):
        t = torch.cuda.FloatTensor
        x, y = data
        x, y = Variable(x.type(t)), Variable(y.type(t))

        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        batch_runner.set_postfix(loss=f"{loss_sum / batch:.04f}")

    loss_avg = loss_sum / batch_len
    writer.add_scalar('(Train) Loss/epoch', loss_avg, epoch)

    # model val
    model.eval()
    error_sum = 0.0
    batch_len = len(val_loader)
    batch_runner = tqdm(val_loader, desc=f"Validation {epoch}", unit="batch")
    with torch.no_grad():
        for batch, data in enumerate(batch_runner, start=1):
            t = torch.cuda.FloatTensor
            x, y = data
            x, y = Variable(x.type(t)), Variable(y.type(t))

            y_pred = model(x)
            error = loss_fn(y_pred, y)

            error_sum += error.item()
            batch_runner.set_postfix(Error=f"{error_sum / batch:.04f}")

        loss_avg = loss_sum / batch_len
        error_avg = error_sum / batch_len
        writer.add_scalar('(Val) Error/epoch', error_avg, epoch)

    y1, y2 = model(x1.cuda()).cpu(), model(x2.cuda()).cpu()
    y1, y2 = y1.detach().numpy().squeeze(), y2.detach().numpy().squeeze()
    fig = util.make_plot(x1_n, y1, x2_n, y2)
    fig = util.plot_to_img(fig)
    writer.add_image("result(1, 974)", fig, epoch)
    
    writer.flush()

    # model save
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
        
    dir_to_save = os.path.join(log_dir, f'epoch_{epoch}.pth')
    torch.save(model.state_dict(), dir_to_save)


writer.close()

epochs: 5


Epoch 1: 100%|██████████| 3168/3168 [02:50<00:00, 18.53batch/s, loss=0.7378]
Validation 1: 100%|██████████| 40/40 [00:02<00:00, 19.97batch/s, Error=0.4450]
Epoch 2: 100%|██████████| 3168/3168 [02:51<00:00, 18.52batch/s, loss=0.3825]
Validation 2: 100%|██████████| 40/40 [00:02<00:00, 19.88batch/s, Error=0.2648]
Epoch 3: 100%|██████████| 3168/3168 [02:50<00:00, 18.54batch/s, loss=0.1029]
Validation 3: 100%|██████████| 40/40 [00:02<00:00, 19.70batch/s, Error=0.2347]
Epoch 4: 100%|██████████| 3168/3168 [02:50<00:00, 18.56batch/s, loss=0.0457]
Validation 4: 100%|██████████| 40/40 [00:02<00:00, 18.40batch/s, Error=0.2512]
Epoch 5: 100%|██████████| 3168/3168 [02:50<00:00, 18.55batch/s, loss=0.0431]
Validation 5: 100%|██████████| 40/40 [00:02<00:00, 19.37batch/s, Error=0.2492]
