In [1]:
#!/usr/bin/python3

import argparse
import itertools
import os
import time
import warnings

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.datasets import Art_nosie_Dataset
from models.network import SMCSWT
from models.utils import LambdaLR
from models.utils import batch_PSNR
from models.utils import weights_init_normal

warnings.filterwarnings("ignore", category=UserWarning)

# Run configuration. Every hyper-parameter and path used by the training cell
# is read from `opt`, so this is the single place to tune a run.
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs of training')
parser.add_argument('--batchSize', type=int, default=4, help='size of the batches')
parser.add_argument('--dataroot', type=str, default='data/', help='root directory of the dataset')
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate')
parser.add_argument('--decay_epoch', type=int, default=50, help='epoch to start linearly decaying the learning rate to 0')
# GPU use is on by default; `--cuda` is kept for backward compatibility and
# `--no_cuda` writes to the same destination so it can actually be disabled.
parser.add_argument('--cuda', action='store_true', default=True, help='use GPU computation')
parser.add_argument('--no_cuda', dest='cuda', action='store_false', help='disable GPU computation')
# Parsed as a list of ints so that the first entry can be used as the primary
# device index and the whole list as DataParallel device ids.
parser.add_argument('--GPU_id', type=int, nargs='+', default=[0], help='Id of GPUs (space separated, e.g. --GPU_id 0 1)')
parser.add_argument('--n_cpu', type=int, default=10, help='number of cpu threads to use during batch generation')
parser.add_argument('--net', type=str, default='output/net_smcswt.pth', help='A2B generator checkpoint file')
# parse_known_args() ignores arguments this parser does not declare, so the
# cell still runs when sys.argv carries unrelated arguments (e.g. those a
# notebook kernel launcher adds).
opt = parser.parse_known_args()[0]
print(opt)

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

Namespace(epoch=0, n_epochs=100, batchSize=16, dataroot='data/', lr=0.0002, decay_epoch=50, cuda=True, GPU_id=[4, 5, 6, 7], n_cpu=10, net='output/net_smcswt.pth')


In [None]:
###### Definition of variables ######
# Builds the SMCSWT network, loss, optimizer, LR schedule and data loader from
# `opt`, then trains for epochs [opt.epoch, opt.n_epochs) and writes a
# checkpoint to `opt.net` after every epoch.

# Networks
Net = SMCSWT(window_size1=[8,8], depth1=[1,1], 
                 window_size2=[16,16,16], depth2=[1,1,1], 
                 window_size3=[32,32,32,32], depth3=[1,1,1,1])


# Accept GPU ids either as a list of ints or as a string such as "0,1" / "0 1".
gpu_ids = opt.GPU_id
if isinstance(gpu_ids, str):
    gpu_ids = gpu_ids.replace(',', ' ').split()
device_ids = [int(gpu_id) for gpu_id in gpu_ids]
device_index = device_ids[0]

# Only target a CUDA device when it was requested AND is actually present, so
# the model and the batches always end up on the same device.
use_cuda = opt.cuda and torch.cuda.is_available()
if opt.cuda and not use_cuda:
    print('WARNING: CUDA was requested but is not available, falling back to CPU')
device = torch.device(f'cuda:{device_index}' if use_cuda else 'cpu')

Net.to(device)
if use_cuda:
    Net = nn.DataParallel(Net, device_ids=device_ids)
Net.apply(weights_init_normal)


pytorch_total_params = sum(p.numel() for p in Net.parameters() if p.requires_grad)
print('Total Number of Parameters:', pytorch_total_params)


# Losses
# Summed squared error; it is divided by (2 * batch size) in the training loop.
criterion = torch.nn.MSELoss(reduction='sum')


# Optimizers & LR schedulers
optimizer = torch.optim.Adam(Net.parameters(), lr=opt.lr, betas=(0.85, 0.999))
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

# Dataset loader
dataloader = DataLoader(Art_nosie_Dataset(opt.dataroot), 
                        batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)
print(len(dataloader))

# Create the checkpoint directory up front so that a missing folder fails (or
# is fixed) before an epoch of training has been spent.
checkpoint_dir = os.path.dirname(opt.net)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)


train_losses = []  # mean training loss of each completed epoch
test_losses = []   # not populated in this cell

###### Training ######
for epoch in range(opt.epoch, opt.n_epochs):
    loss_sigma = 0   # running sum of per-batch losses in this epoch
    All_psnr = 0     # running sum of per-batch PSNR values in this epoch
    batch_sizes = 0  # number of batches processed so far in this epoch
    Net.train()
    with tqdm(dataloader, unit='batch', dynamic_ncols=True) as tepoch:
        for batch in tepoch:
            # Set model input: batch['B'] is fed to the network and
            # batch['A'] is the target it is compared against.
            clean = batch['A'].to(device)
            noise = batch['B'].to(device)

            ###### Model A ######
            optimizer.zero_grad()

            out = Net(noise)
            loss = criterion(out, clean) / (clean.size(0) * 2)
            loss.backward()
            optimizer.step()

            ###### Psnr ########
            batch_sizes += 1
            loss_sigma += loss.item()
            # Detached so the metric never holds on to the autograd graph.
            # NOTE(review): the third argument is presumably the data range
            # expected by batch_PSNR -- confirm against models.utils.
            psnr = batch_PSNR(out.detach(), clean, 3)
            All_psnr += psnr
            desc1 = '[%d/%d]' % (epoch+1, opt.n_epochs)
            desc2 = 'Psnr:%.4f' % (All_psnr / batch_sizes)
            desc3 = 'Loss:%.4f' % (loss_sigma / batch_sizes)
            tepoch.set_description(desc1 + ' ' + desc2 + ' ' + desc3)

    # Record the epoch's mean loss (skipped if the loader yielded no batches).
    if batch_sizes:
        train_losses.append(loss_sigma / batch_sizes)

    lr_scheduler.step()

    # When Net is wrapped in nn.DataParallel the saved keys carry the
    # `module.` prefix; loaders of this checkpoint must account for that.
    torch.save(Net.state_dict(), opt.net)

Total Number of Parameters: 1363133
641


  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
[1/100] Psnr:28.0337 Loss:538.8910:   1%|▏               | 6/641 [00:05<07:52,  1.34batch/s]