In [10]:
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as utils
import torch.utils.data as data
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR

import matplotlib.pyplot as plt
import glob
import re

In [3]:
# Mount Google Drive and make the notebook folder the working directory, so the
# relative paths used later (companion notebooks, model*.pth checkpoints) resolve
# inside Drive.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

import os
# Use the absolute path under the mount point: a relative chdir only works the
# first time the cell runs, because the working directory has already moved on
# any re-run.
os.chdir('/content/gdrive/My Drive/Colab Notebooks')

Mounted at /content/gdrive


In [4]:
# Execute the companion notebooks so that the names they define become
# available in this notebook.
# NOTE(review): presumably these provide PReNet, PReNet_train_datasets and SSIM,
# which are used below but never defined in this notebook -- confirm which
# notebook defines which name.
%run Network.ipynb
%run Datasets.ipynb
%run SSIM.ipynb

In [21]:
# Find out which epoch the previous training run reached
def findLastCheckPoint():
  """Return the highest epoch number among the checkpoints in the working directory.

  Checkpoints are files named ``model<N>.pth`` where ``<N>`` is a non-negative
  integer. Files that match the glob but do not follow that exact pattern
  (for example ``model_best.pth``) are ignored.

  Returns:
    int: the largest ``<N>`` found, or -1 when there is no checkpoint at all.
  """
  epochs = []
  for file in glob.glob('model*.pth'):
    # fullmatch + escaped dot: only accept a purely numeric epoch suffix.
    match = re.fullmatch(r"model(\d+)\.pth", file)
    if match:
      epochs.append(int(match.group(1)))

  # glob order is arbitrary, so pick the maximum explicitly rather than
  # whichever file happened to be listed last.
  return max(epochs, default=-1)

In [6]:
# Define the program arguments (training hyper-parameters)
parser = argparse.ArgumentParser(description="PReNet_train")

parser.add_argument("--batch_size", type=int, default=1, help='Training Batch Size')
parser.add_argument("--test_batch_size", type=int, default=1, help="Testing Batch Size")
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
# nargs="+" keeps the value a list of ints when it is given explicitly
# (e.g. --milestone 30 50 80), matching the shape of the default.
parser.add_argument("--milestone", type=int, nargs="+", default=[30, 50, 80], help="When to decay learning rate")
parser.add_argument("--lr", type=float, default=1e-3, help="Learning Rate")
parser.add_argument("--recurrent_iteration", type=int, default=6, help="Number of recursive stages")
parser.add_argument("--log-interval", type=int, default=1, help="How many batches to wait before logging training status")
# NOTE(review): presumably absorbs the "-f <connection file>" argument the
# Jupyter kernel is launched with -- confirm.
parser.add_argument("-f")

# Inside a notebook sys.argv belongs to the kernel, not to this script, so
# unknown arguments are tolerated instead of aborting with SystemExit.
opt, _unknown_args = parser.parse_known_args()

In [7]:
# Load the datasets
def load_data(train_batch_size, test_batch_size):
  """Build the training DataLoader.

  Args:
    train_batch_size: number of samples per training batch.
    test_batch_size: accepted for interface compatibility; no test loader is
      built here, so the value is currently unused.

  Returns:
    A shuffling DataLoader over ``PReNet_train_datasets()``.
  """
  datasets = PReNet_train_datasets()

  # Wrap the dataset in a DataLoader. The batch size comes from the argument
  # rather than from the global options object, so the function honours what
  # its caller asks for.
  train_loader = DataLoader(dataset=datasets, batch_size=train_batch_size, shuffle=True, pin_memory=True, num_workers=4)

  return train_loader

In [8]:
def train(model, optimizer, epoch, train_loader, log_interval, use_gpu=False, criterion=None):
  """Train ``model`` for one pass over ``train_loader``.

  Args:
    model: the network to optimise; it is switched to training mode.
    optimizer: optimizer holding ``model``'s parameters.
    epoch: epoch number, used only in the progress message.
    train_loader: iterable yielding ``(input, target)`` batches; it must
      support ``len()`` and expose a ``dataset`` attribute for the progress
      message.
    log_interval: print progress every ``log_interval`` batches.
    use_gpu: when True, each batch is moved to CUDA before the forward pass.
    criterion: callable ``criterion(target, output)`` returning a score to
      maximise (the loss is its negation). When omitted, the module-level
      global named ``criterion`` is used, as before.

  Returns:
    The mean loss over the batches processed, or None if the loader was empty.

  Raises:
    NameError: if ``criterion`` is not passed and no global ``criterion`` exists.
  """
  if criterion is None:
    # Backward-compatible fallback for callers that rely on the global.
    try:
      criterion = globals()["criterion"]
    except KeyError:
      raise NameError("name 'criterion' is not defined") from None

  # switch to training mode
  model.train()

  total_loss = 0.0
  num_batches = 0

  # Start iterating over the batches. The locals are named inputs/targets so
  # they do not shadow the `data` module alias imported at the top of the file.
  for batch_idx, (inputs, targets) in enumerate(train_loader):

    if use_gpu:
      inputs = inputs.cuda()
      targets = targets.cuda()

    # Clear the gradients
    optimizer.zero_grad()

    # Forward Propagation
    output = model(inputs)

    # Calculate loss: the criterion is a score to maximise, so negate it
    pixel_metric = criterion(targets, output)
    loss = -pixel_metric

    # Backward Propagation
    loss.backward()

    # Update the parameters
    optimizer.step()

    loss_value = loss.item()
    total_loss += loss_value
    num_batches += 1

    # Print the training progress
    if batch_idx % log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(inputs), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss_value))

  return total_loss / num_batches if num_batches else None

In [None]:
if __name__ == "__main__":

  # check cuda; always define use_gpu so CPU-only machines do not hit a NameError
  use_gpu = torch.cuda.is_available()
  if use_gpu:
    print('cuda is available')

  # model
  model = PReNet(recurrent=opt.recurrent_iteration, use_GPU=use_gpu)
  #print(model)

  # Loss function (train() negates the criterion value to obtain the loss)
  criterion = SSIM()

  # Move model and criterion to the GPU
  if use_gpu:
    model = model.cuda()
    criterion.cuda()

  # Optimizer
  optimizer = optim.Adam(model.parameters(), lr=opt.lr)
  scheduler = MultiStepLR(optimizer, milestones=opt.milestone, gamma=0.2)

  # loading data
  train_loader = load_data(opt.batch_size, opt.test_batch_size)

  # Resume from the previous progress; a negative value means "no checkpoint"
  initial_epoch = findLastCheckPoint()

  if initial_epoch >= 0:
    print('start at epoch %d' %(initial_epoch))

    file_name = 'model' + str(initial_epoch) + '.pth'
    print('load file: ', file_name)
    # Load onto the CPU first so a checkpoint saved on a GPU machine can also be
    # restored on a CPU-only one; load_state_dict copies the values into the
    # model's existing parameters wherever they live.
    model.load_state_dict(torch.load(file_name, map_location='cpu'))
  else:
    initial_epoch = 0

  # Fast-forward the scheduler over the epochs that were already trained, so
  # the first resumed epoch runs at the decayed learning rate rather than the
  # initial one. (PyTorch may warn once that the scheduler was stepped before
  # the optimizer; that is expected here.)
  for _ in range(initial_epoch):
    scheduler.step()

  # training: model<N>.pth means N epochs are done, so run up to and including
  # opt.epochs to complete the requested number of epochs
  for epoch in range(initial_epoch+1, opt.epochs+1):

    # output learning rate
    for param_group in optimizer.param_groups:
      print('learning rate %f' %(param_group['lr']))

    # epoch start training
    train(model, optimizer, epoch, train_loader, log_interval=opt.log_interval, use_gpu=use_gpu)

    # update learning rate (the epoch argument of step() is deprecated)
    scheduler.step()

    # Save the model parameters after every epoch. Write the new checkpoint
    # before deleting the old one so an interrupted save never leaves the run
    # without any checkpoint.
    torch.save(model.state_dict(), 'model%d.pth' %(epoch))
    print('save model%d.pth...' %(epoch))

    # On a fresh start there is no previous checkpoint to delete.
    previous_file = 'model%d.pth' %(epoch-1)
    if os.path.exists(previous_file):
      os.remove(previous_file)
      print('remove model%d.pth' %(epoch-1))

cuda is available
start at epoch 14
load file:  model14.pth
learning rate 0.001000
network cuda is available
iteration:  0
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is available
iteration:  0
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is available
iteration:  0
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is available
iteration:  0
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is available
iteration:  0
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is available
iteration:  0
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is available
iteration:  0
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is available
iteration:  0
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is available
iteration: 



[1;30;43m串流輸出內容已截斷至最後 5000 行。[0m
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is available
iteration:  0
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is available
iteration:  0
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is available
iteration:  0
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is available
iteration:  0
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is available
iteration:  0
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is available
iteration:  0
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is available
iteration:  0
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is available
iteration:  0
iteration:  1
iteration:  2
iteration:  3
iteration:  4
iteration:  5
network cuda is