In [1]:
import pytorch_lightning as pl
import time
import os
import sys
import torchvision
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchmetrics import PeakSignalNoiseRatio

from tqdm import tqdm

#from torch.utils.tensorboard import SummaryWriter
# import scipy.io
# from utils.utils_metric import batch_PSNR, batch_SSIM

import options 

from datasets.main_dataset import get_dataloaders
from loss import loss_fn
from networks.dnet import get_dnet
from networks.knet import get_knet

# --- Global runtime configuration -------------------------------------------
print(torch.__version__)

# cuDNN on, with auto-tuned kernels; autograd anomaly detection is switched on
# globally for every backward pass run after this cell.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.autograd.set_detect_anomaly(True)

# Flip USE_GPU to True to request the second CUDA device (index 1) when CUDA
# is available; otherwise everything runs on the CPU.
USE_GPU = False
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
device = torch.device('cuda:1' if USE_GPU and torch.cuda.is_available() else 'cpu')
print('using device:', device)

# --- Experiment options -----------------------------------------------------
# Options come from the project's `options` module; each one is echoed as a
# left-aligned "name: value" line.
args = options.set_opts_jp()
for arg, value in vars(args).items():
    print('{:<25s}: {:s}'.format(arg, str(value)))

1.12.1
using device: cpu
batch_size               : 2
patch_size               : 128
epochs                   : 5
pretraining_epochs       : 1
print_freq               : 1
save_model_freq          : 20
val_props                : 0.1
lr_C                     : 0.0001
lr_M                     : 0.0001
gamma                    : 0.1
clip_grad_C              : 10000.0
clip_grad_M              : 100000.0
train_data_path          : /data/BasesDeDatos/Camelyon/Camelyon17/training/Toy/
pre_kernel_path          : 
log_dir                  : ./log
model_dir                : ./model
resume                   : 
num_workers              : 8
sigmaRui_h_sq            : 0.001
sigmaRui_e_sq            : 0.001
theta                    : 0.5
pre_kl                   : 100.0
pre_mse                  : 0.01
code_len                 : 30
CNet                     : unet_6
MNet                     : resnet_18_in
max_size                 : 3
dirichlet_para_stretch   : 20000
prekernels               : 
nrow    

In [2]:
class DVBCDModel(pl.LightningModule):
    """LightningModule wrapping the two sub-networks ``cnet`` and ``mnet``.

    Both networks receive the same input ``y``. Their outputs, together with
    ``y`` and the second batch element ``mR``, are passed to ``loss_fn``.
    Optimization is manual: each network has its own Adam optimizer and StepLR
    scheduler, and both optimizers are stepped from a single backward pass.

    Args:
        args: options namespace. This class reads ``CNet``, ``MNet``,
            ``sigmaRui_h_sq``, ``sigmaRui_e_sq``, ``pre_mse`` and ``pre_kl``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.cnet = get_dnet(args.CNet)
        self.mnet = get_knet(args.MNet, kernel_size=3)
        self.loss_fn = loss_fn
        # Two optimizers are driven by hand in training_step.
        self.automatic_optimization = False

    def forward(self, y):
        """Run both networks on ``y``.

        Returns:
            Tuple ``(out_MNet_mean, out_Mnet_var, out_CNet)``.
        """
        # Shapes per the original author's notes (not verifiable from this
        # file): (batch_size, 3, 2) and (batch_size, 1, 2).
        out_MNet_mean, out_Mnet_var = self.mnet(y)
        # Shape per the original author's notes: (batch_size, 2, H, W).
        out_CNet = self.cnet(y)
        return out_MNet_mean, out_Mnet_var, out_CNet

    def configure_optimizers(self):
        """Create one Adam optimizer and one StepLR scheduler per network.

        Returns:
            ``([CNet_opt, MNet_opt], [CNet_sch, MNet_sch])``.
        """
        # NOTE(review): learning rate, step size and gamma are hard-coded here;
        # the lr_C / lr_M / gamma options are not read. Confirm whether that is
        # intentional.
        CNet_opt = torch.optim.Adam(self.cnet.parameters(), lr=5e-4)
        MNet_opt = torch.optim.Adam(self.mnet.parameters(), lr=5e-4)
        CNet_sch = torch.optim.lr_scheduler.StepLR(CNet_opt, step_size=20, gamma=0.1)
        MNet_sch = torch.optim.lr_scheduler.StepLR(MNet_opt, step_size=20, gamma=0.1)
        return [CNet_opt, MNet_opt], [CNet_sch, MNet_sch]

    def training_step(self, batch, batch_idx):
        """Perform one manual optimization step on both networks.

        Args:
            batch: pair ``(y, mR)``.
            batch_idx: index of the batch within the epoch (unused).

        Returns:
            Dict with the keys ``loss``, ``loss_mse`` and ``loss_kl``.
        """
        CNet_opt, MNet_opt = self.optimizers()
        y, mR = batch

        CNet_opt.zero_grad()
        MNet_opt.zero_grad()

        out_MNet_mean, out_Mnet_var, out_CNet = self(y)

        # NOTE(review): the loss is always evaluated with pretraining=True; the
        # pretraining_epochs option is not consulted. Confirm this is intended.
        loss, loss_mse, loss_kl, _, _ = self.loss_fn(
            out_CNet, out_MNet_mean, out_Mnet_var, y, mR,
            self.args.sigmaRui_h_sq, self.args.sigmaRui_e_sq,
            pretraining=True, pre_mse=self.args.pre_mse, pre_kl=self.args.pre_kl,
        )

        self.manual_backward(loss)
        CNet_opt.step()
        MNet_opt.step()

        # StepLR counts calls to step(); stepping on every batch would decay
        # the learning rate after 20 batches instead of 20 epochs. With manual
        # optimization the schedulers must be stepped by hand, so do it once
        # per epoch, on the last training batch.
        if self.trainer.is_last_batch:
            CNet_sch, MNet_sch = self.lr_schedulers()
            CNet_sch.step()
            MNet_sch.step()

        metrics = {'loss' : loss, 'loss_mse' : loss_mse, 'loss_kl' : loss_kl}
        self.log_dict(metrics, logger=True, prog_bar=True)
        return metrics

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch; metrics are logged with a ``val_`` prefix."""
        self._shared_eval(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch; metrics are logged with a ``test_`` prefix."""
        self._shared_eval(batch, batch_idx, 'test')

    def _shared_eval(self, batch, batch_idx, prefix):
        """Compute and log the losses for one evaluation batch.

        Args:
            batch: pair ``(y, mR)``.
            batch_idx: index of the batch (unused).
            prefix: ``'val'`` or ``'test'``; prepended to every metric name.
        """
        y, mR = batch

        out_MNet_mean, out_Mnet_var, out_CNet = self(y)

        # Unlike training_step, no pretraining arguments are passed, so
        # loss_fn runs with its own defaults.
        loss, loss_mse, loss_kl, _, _ = self.loss_fn(
            out_CNet, out_MNet_mean, out_Mnet_var, y, mR,
            self.args.sigmaRui_h_sq, self.args.sigmaRui_e_sq,
        )

        # Only validation metrics go to the logger and the progress bar.
        is_val = prefix == 'val'
        self.log_dict(
            {f'{prefix}_loss' : loss, f'{prefix}_loss_mse' : loss_mse, f'{prefix}_loss_kl' : loss_kl},
            logger=is_val,
            prog_bar=is_val,
        )

In [3]:
# Build the loaders once, then pull out the three splits by name.
dataloaders = get_dataloaders(args, val_prop=0.3)
train_dataloader, val_dataloader, test_dataloader = (
    dataloaders[split] for split in ('train', 'val', 'test')
)
# Number of batches in the training and validation loaders.
print(len(train_dataloader), len(val_dataloader))

Available patches: 6
Available patches: 4
3 1


In [4]:
from pytorch_lightning.callbacks import EarlyStopping, TQDMProgressBar

In [6]:
# Stop once val_loss has not improved for 5 validation runs; refresh the
# progress bar on every batch.
early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=5)
progress_bar = TQDMProgressBar(refresh_rate=1)
callbacks = [early_stopping, progress_bar]

trainer = pl.Trainer(
    accelerator="cpu",
    callbacks=callbacks,
    max_epochs=10,
    enable_progress_bar=True,
    log_every_n_steps=1,
)

model = DVBCDModel(args)
trainer.fit(model, train_dataloader, val_dataloader)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name | Type       | Params
------------------------------------
0 | cnet | UNet       | 5.0 M 
1 | mnet | ResNet18IN | 11.2 M
------------------------------------
16.2 M    Trainable params
0         Non-trainable params
16.2 M    Total params
64.741    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]