In [1]:
import time
import os
import sys
import torchvision
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
# import scipy.io
# from utils.utils_metric import batch_PSNR, batch_SSIM

import options 

from datasets.main_dataset import get_dataloaders
from loss import loss_fn
from networks.dnet import get_dnet
from networks.knet import get_knet

In [2]:
print(torch.__version__)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.autograd.set_detect_anomaly(True)

USE_GPU = False
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda:1')
else:
    device = torch.device('cpu')
print('using device:', device)

1.12.1
using device: cpu


In [3]:
#args = set_opts()
args = options.set_opts_jp()
for arg in vars(args):
    print('{:<25s}: {:s}'.format(arg, str(getattr(args, arg))))

batch_size               : 2
patch_size               : 128
epochs                   : 5
pretraining_epochs       : 1
print_freq               : 1
save_model_freq          : 20
lr_C                     : 0.0001
lr_M                     : 0.0001
gamma                    : 0.1
clip_grad_C              : 10000.0
clip_grad_M              : 100000.0
train_data_path          : /data/BasesDeDatos/Camelyon/Camelyon17/training/Toy/
pre_kernel_path          : 
log_dir                  : ./log
model_dir                : ./model
resume                   : 
num_workers              : 8
sigmaRui_h_sq            : 0.001
sigmaRui_e_sq            : 0.001
theta                    : 0.5
pre_kl                   : 100.0
pre_mse                  : 0.01
code_len                 : 30
CNet                     : unet_6
MNet                     : resnet_18_in
max_size                 : 3
dirichlet_para_stretch   : 20000
prekernels               : 
nrow                     : 8
epoch_start_test         : 20
skip_

In [4]:
def adjust_learning_rate(optimizer, epoch, args):
    if epoch <= 60:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 1e-4
    elif epoch <= 80:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 1e-5
    else:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 1e-6

In [5]:
# cnet <--> dnet
# mnet <--> knet
cnet = get_dnet(args.CNet)
mnet = get_knet(args.MNet, kernel_size=3)
cnet = cnet.to(device)
mnet = mnet.to(device)
optimizer_c = optim.Adam(cnet.parameters(), args.lr_C)
optimizer_m = optim.Adam(mnet.parameters(), args.lr_M)
pretraining_epochs = args.pretraining_epochs
#writer.. = SummaryWriter(args.log_dir)


In [6]:
args.epoch_start = 0
step = 0
#step_img = {x: 0 for x in _modes}
os.makedirs(args.log_dir, exist_ok=True)
os.makedirs(args.model_dir, exist_ok=True)

dataloaders = get_dataloaders(args)
train_dataloader = dataloaders['train']
val_dataloader = dataloaders['val']
test_dataloader = dataloaders['test']

num_step_per_epoch = len(train_dataloader)
pre_optimizer_c = optim.Adam(cnet.parameters(), lr=5e-4)
pre_optimizer_m = optim.Adam(mnet.parameters(), lr=5e-4)

Available patches: 6
Available patches: 4


In [9]:
loss_per_epoch = {'loss' : 0, 'loss_mse' : 0, 'loss_kl_h' : 0, 'loss_kl_e' : 0 }
for epoch in range(args.pretraining_epochs):
    tic = time.time()
    cnet.train()
    mnet.train()

    train_loss = 0
    for ii, data in enumerate(train_dataloader):
        y = data[0].to(device) # shape: (batch_size, 3, H, W)
        mR = data[1].to(device)  # shape: (batch_size, 1, 3, 2)
        
        pre_optimizer_m.zero_grad()
        pre_optimizer_c.zero_grad()
        out_MNet_mean, out_Mnet_var = mnet(y) # shape: (batch_size, 3, 2), (batch_size, 1, 2)
        out_CNet = cnet(y) # shape: (batch_size, 2, H, W)

        loss, loss_mse, loss_kl, _, _ = loss_fn(  out_CNet, out_MNet_mean, out_Mnet_var, y, mR,
                                            args.sigmaRui_h_sq, args.sigmaRui_e_sq, 
                                            pretraining=True, pre_mse = args.pre_mse, pre_kl = args.pre_kl
                                            )

        loss.backward()
        pre_optimizer_m.step()
        pre_optimizer_c.step()
        train_loss += loss.item() / num_step_per_epoch
    
    print(f"Pretraining: Epoch: {epoch + 1}, Loss={train_loss:.4e}")

Pretraining: Epoch: 1, Loss=-inf


In [10]:
loss_per_epoch = {'loss' : 0, 'loss_mse' : 0, 'loss_kl_h' : 0, 'loss_kl_e' : 0 }

for epoch in range(args.epochs):

    adjust_learning_rate(optimizer_m, epoch, args)
    adjust_learning_rate(optimizer_c, epoch, args)
    grad_norm_C = grad_norm_M = 0
    lr_C = optimizer_c.param_groups[0]['lr']
    lr_M = optimizer_m.param_groups[0]['lr']
    loss_mse_per_epoch = 0
    for ii, data in enumerate(train_dataloader):

        y = data[0].to(device) # shape: (batch_size, 3, H, W)
        mR = data[1].to(device)  # shape: (batch_size, 1, 3, 2)

        optimizer_m.zero_grad()
        optimizer_c.zero_grad()

        out_MNet_mean, out_Mnet_var = mnet(y) # shape: (batch_size, 3, 2), (batch_size, 1, 2)
        out_CNet = cnet(y) # shape: (batch_size, 2, H, W)

        loss, loss_mse, loss_kl, loss_kl_h, loss_kl_e = loss_fn(  out_CNet, out_MNet_mean, out_Mnet_var, y, mR,
                                            args.sigmaRui_h_sq, args.sigmaRui_e_sq
                                            )
        
        loss.backward()
        total_norm_C = nn.utils.clip_grad_norm_(cnet.parameters(), args.clip_grad_C)
        total_norm_M = nn.utils.clip_grad_norm_(mnet.parameters(), args.clip_grad_M)
        grad_norm_C = grad_norm_C + total_norm_C / num_step_per_epoch
        grad_norm_M = grad_norm_M + total_norm_M / num_step_per_epoch
        optimizer_c.step()
        optimizer_m.step()

        loss_per_epoch['loss'] += loss.item() / num_step_per_epoch
        loss_per_epoch['loss_mse'] += loss_mse.item() / num_step_per_epoch
        loss_per_epoch['loss_kl_h'] += loss_kl_h.item() / num_step_per_epoch
        loss_per_epoch['loss_kl_e'] += loss_kl_e.item() / num_step_per_epoch

        r_con_e = torch.clamp(out_CNet[:, :1, :, : ].detach().data, 0.0, 1.0)
        r_con_h = torch.clamp(out_CNet[:, 1:, :, :].detach().data, 0.0, 1.0)


        # mse = F.mse_loss(y[0,:,:], out_CNet[0,0,:,:])
        loss_mse_per_epoch += loss_mse / num_step_per_epoch

        if (ii + 1) % args.print_freq == 0:
            print(f"[Epoch {epoch + 1:0>4d}/{args.epochs:0>4d}], Iter {ii + 1:0>5d}/{num_step_per_epoch:0>5d}, "
                  f"loss={loss.item():.4e}, loss_mse={loss_mse.item():.4e}, loss_kl_h={loss_kl_h.item():.4e}, loss_kl_e={loss_kl_e.item():.4e}, "
                  f"grad_norm_C={args.clip_grad_C:.2e}/{total_norm_C:.2e}, grad_norm_M={args.clip_grad_M:.2e}/{total_norm_M:.2e}"
                  f"lr_C={lr_C:.1e}, lr_M={lr_M:.1e}")

            step += 1

    print(f"[Epoch {epoch + 1:0>4d}/{args.epochs:0>4d}], Iter {ii + 1:0>5d}/{num_step_per_epoch:0>5d}, "
          f"loss={loss_per_epoch['loss']:.4e}, loss_mse={loss_per_epoch['loss_mse']:.4e}, loss_kl_h={loss_per_epoch['loss_kl_h']:.4e}, loss_kl_e={loss_per_epoch['loss_kl_e']:.4e}, "
          )

    toc = time.time()
    print(f'This epoch took time {toc - tic:.2f} s.')
#writer.close()
print('Reached the maximal epochs! Finish training')

[Epoch 0001/0005], Iter 00001/00003, loss=-4.5099e+02, loss_mse=5.6816e+03, loss_kl_h=-1.0563e+00, loss_kl_e=-6.5826e+03, grad_norm_C=1.00e+04/3.79e+05, grad_norm_M=1.00e+05/2.04e+06lr_C=1.0e-04, lr_M=1.0e-04
[Epoch 0001/0005], Iter 00002/00003, loss=-inf, loss_mse=1.0440e+03, loss_kl_h=-inf, loss_kl_e=-6.7920e+03, grad_norm_C=1.00e+04/2.20e+04, grad_norm_M=1.00e+05/1.64e+04lr_C=1.0e-04, lr_M=1.0e-04
[Epoch 0001/0005], Iter 00003/00003, loss=-inf, loss_mse=3.6944e+02, loss_kl_h=-inf, loss_kl_e=-6.5659e+03, grad_norm_C=1.00e+04/1.18e+04, grad_norm_M=1.00e+05/1.64e+04lr_C=1.0e-04, lr_M=1.0e-04
[Epoch 0001/0005], Iter 00003/00003, loss=-inf, loss_mse=2.3650e+03, loss_kl_h=-inf, loss_kl_e=-6.6468e+03, 
This epoch took time 23.46 s.
[Epoch 0002/0005], Iter 00001/00003, loss=-inf, loss_mse=1.0729e+03, loss_kl_h=-inf, loss_kl_e=-6.7697e+03, grad_norm_C=1.00e+04/1.49e+04, grad_norm_M=1.00e+05/1.67e+04lr_C=1.0e-04, lr_M=1.0e-04
[Epoch 0002/0005], Iter 00002/00003, loss=-inf, loss_mse=2.1266e+02