In [1]:
import torch
from torch import nn
from datetime import datetime

In [2]:
import sys
import os
sys.path.append('../')
import model
from utils import evaluate_psnr

Using CPU ❌ 😭


In [3]:
# Pick the compute device once: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using GPU. ✅' if device.type == 'cuda' else 'Using CPU ❌ 😭')

Using CPU ❌ 😭


Create the folder where trained model checkpoints are saved (if it does not already exist):

In [4]:
# Create the checkpoint directory if needed. exist_ok=True replaces the
# separate exists()-then-mkdir() check, which is racy and more verbose.
os.makedirs('../models', exist_ok=True)

# Load data

In [5]:
# Load train data: the file holds a pair of tensors (noisy input, noisy target).
# NOTE(review): torch.load unpickles the file, so only load trusted data.
TRAIN_PATH = '../../miniproject_dataset/train_data.pkl'
train_noisy_imgs_input, train_noisy_imgs_target = torch.load(TRAIN_PATH, map_location=device)
# The input tensor is 4-D; its leading dimension is the number of images.
training_set_size, num_channels, y_size, x_size = train_noisy_imgs_input.shape
print('Number of loaded images: ', training_set_size)

Number of loaded images:  50000


In [6]:
# Load validation data: a pair of tensors (noisy images, clean reference images).
# NOTE(review): torch.load unpickles the file, so only load trusted data.
VALIDATION_PATH = '../../miniproject_dataset/val_data.pkl'
validation_noisy_imgs, validation_clean_imgs = torch.load(
    VALIDATION_PATH,
    map_location=device,
)
print('Number of validation images: ', validation_clean_imgs.size(0))

Number of validation images:  1000


# Specify which model to train

In [7]:
def set_model(model_name):
    """Create a ``model.Model`` wrapper configured for ``model_name``.

    Supported names: 'BaseNet', 'UNet', 'TNet', 'TNet_batchnorm', 'TNet_L1'.
    The chosen network (``net.model``) and loss (``net.criterion``) are moved
    to the notebook-level ``device`` before the wrapper is returned.

    Raises:
        ValueError: if ``model_name`` is not a supported name.
    """
    # name -> (network factory, loss class). The lambdas keep construction
    # lazy, so only the requested network is ever instantiated.
    builders = {
        'BaseNet': (lambda: model.BaseNet(), torch.nn.MSELoss),
        'UNet': (lambda: model.UNet(), torch.nn.MSELoss),
        'TNet': (lambda: model.TNet(), torch.nn.MSELoss),
        'TNet_batchnorm': (lambda: model.TNet(batchnorm=True), torch.nn.MSELoss),
        'TNet_L1': (lambda: model.TNet(), torch.nn.L1Loss),
    }
    net = model.Model()
    # NOTE(review): net.model / net.criterion are replaced after Model() is
    # built. If Model.__init__ creates an optimizer over its own default
    # network, that optimizer would not track the new network. Confirm in model.py.
    if model_name not in builders:
        raise ValueError('Invalid model name!')
    make_network, loss_cls = builders[model_name]
    net.model = make_network()
    net.criterion = loss_cls()
    net.model.to(device)
    net.criterion.to(device)
    return net

In [8]:
# Model to train; must be one of the names handled by set_model:
# 'BaseNet', 'UNet', 'TNet', 'TNet_batchnorm', 'TNet_L1'.
model_name = 'TNet'

In [9]:
# Train NB_ROUNDS independent models and record each one's validation PSNR.
# Each round is checkpointed, and all PSNRs are written to a text file.
NB_ROUNDS = 5        # number of independent training runs
nb_epochs = 50
models_dir = '../models'
accuracies = []      # validation PSNR (dB) per round, stored as Python floats

os.makedirs(models_dir, exist_ok=True)  # lets this cell run on its own

for round_idx in range(NB_ROUNDS):
    print(f'Training round  {round_idx + 1}  out of {NB_ROUNDS}:')
    # A distinct but fixed seed per round: rounds still differ from each other,
    # yet every round is reproducible on re-run.
    torch.manual_seed(round_idx)

    train_start_time = datetime.now()
    net = set_model(model_name)
    net.model.train()
    net.train(train_noisy_imgs_input, train_noisy_imgs_target, nb_epochs)
    training_time = datetime.now() - train_start_time
    print('Training took: ', training_time)

    # Evaluate on the validation set without tracking gradients.
    net.model.eval()
    with torch.no_grad():
        psnr = evaluate_psnr(net, validation_noisy_imgs, validation_clean_imgs)
    psnr_db = psnr.item()
    print('Peak signal-to-noise = ', psnr_db, ' dB')
    # Append a float, not the tensor: str() of a list of tensors would write
    # "tensor(...)" reprs into the accuracies file.
    accuracies.append(psnr_db)

    # Save model
    checkpoint = {
        'model': net.model.state_dict(),
        'optimizer': net.optimizer.state_dict(),
        'lr': net.lr,
        'batch_size': net.batch_size,
        'losses': net.losses,
    }
    torch.save(checkpoint, os.path.join(models_dir, f'{model_name}_{round_idx}.pth'))

with open(os.path.join(models_dir, f'{model_name}_accuracies.txt'), 'w') as acc_file:
    acc_file.write('accuracies = ' + str(accuracies))

Training round  1  out of 5:


KeyboardInterrupt: 