In [1]:
import time
import os
import sys
import torchvision
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchmetrics import PeakSignalNoiseRatio

from tqdm import tqdm

#from torch.utils.tensorboard import SummaryWriter
# import scipy.io
# from utils.utils_metric import batch_PSNR, batch_SSIM

import options 

from datasets.main_dataset import get_dataloaders
from loss import loss_fn, loss_BCD
from networks.cnet import get_cnet
from networks.mnet import get_mnet

print(torch.__version__)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.autograd.set_detect_anomaly(True)

USE_GPU = True
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
if USE_GPU and torch.cuda.is_available():
    DEVICE = torch.device('cuda:1')
else:
    DEVICE = torch.device('cpu')
print('using device:', DEVICE)
#args = set_opts()
args = options.set_opts_jp()
for arg in vars(args):
    print('{:<25s}: {:s}'.format(arg, str(getattr(args, arg))))

1.12.1
using device: cpu
batch_size               : 2
patch_size               : 128
epochs                   : 5
pretraining_epochs       : 1
print_freq               : 1
save_model_freq          : 20
val_props                : 0.1
lr_C                     : 0.0001
lr_M                     : 0.0001
gamma                    : 0.1
clip_grad_C              : 10000.0
clip_grad_M              : 100000.0
train_data_path          : /data/BasesDeDatos/Camelyon/Camelyon17/training/Toy/
pre_kernel_path          : 
log_dir                  : ./log
model_dir                : ./model
resume                   : 
num_workers              : 8
sigmaRui_h_sq            : 0.001
sigmaRui_e_sq            : 0.001
theta                    : 0.5
pre_kl                   : 100.0
pre_mse                  : 0.01
code_len                 : 30
CNet                     : unet_6
MNet                     : resnet_18_in
max_size                 : 3
dirichlet_para_stretch   : 20000
prekernels               : 
nrow    

In [2]:
class DVBCDModel():
    def __init__(self, args, device="cpu"):
        self.args = args
        self.cnet = get_cnet(args.CNet)
        self.mnet = get_mnet(args.MNet, kernel_size=3)
        self.loss_fn = loss_BCD
        self.device = device

        self.sigmaRui_sq = torch.tensor([args.sigmaRui_h_sq, args.sigmaRui_e_sq])
        self.theta = 0.5
        self.pre_mse = args.pre_mse
        self.pre_kl = args.pre_kl

        self.optim_initiated = False
        self.optimizers = None
        self.lr_schedulers = None
        
    def forward(self, y):
        out_MNet_mean, out_Mnet_var = self.mnet(y) # shape: (batch_size, 3, 2), (batch_size, 1, 2)
        out_CNet = self.cnet(y) # shape: (batch_size, 2, H, W)
        return out_MNet_mean, out_Mnet_var, out_CNet

    def train(self):
        self.cnet.train()
        self.mnet.train()
    
    def eval(self):
        self.cnet.eval()
        self.mnet.eval()
    
    def to(self, device):
        self.cnet.to(device)
        self.mnet.to(device)

    def init_optimizers(self):
        #pre_optimizer_CNet = optim.Adam(self.cnet.parameters(), lr=5e-4)
        #pre_optimizer_MNet = optim.Adam(self.mnet.parameters(), lr=5e-4)
        Cnet_opt = torch.optim.Adam(self.cnet.parameters(), lr=5e-4)
        Mnet_opt = torch.optim.Adam(self.mnet.parameters(), lr=5e-4)
        Cnet_sch = torch.optim.lr_scheduler.StepLR(Cnet_opt, step_size=20, gamma=0.1)
        Mnet_sch = torch.optim.lr_scheduler.StepLR(Mnet_opt, step_size=20, gamma=0.1)

        self.optimizers = (Cnet_opt, Mnet_opt)
        self.lr_schedulers = (Cnet_sch, Mnet_sch)
        self.optim_initiated = True

    def training_step(self, batch, batch_idx, pretraining=False):

        Cnet_opt, Mnet_opt = self.optimizers

        y, mR = batch
        y = y.to(self.device)
        mR = mR.to(self.device)

        Cnet_opt.zero_grad()
        Mnet_opt.zero_grad()

        out_Mnet_mean, out_Mnet_var = self.mnet(y) # shape: (batch_size, 3, 2), (batch_size, 1, 2)
        out_Cnet = self.cnet(y) # shape: (batch_size, 2, H, W)

        #loss, loss_mse, loss_kl, loss_kl_h, loss_kl_e = self.loss_fn(  out_CNet, out_MNet_mean, out_Mnet_var, y, mR,
        #                                    self.args.sigmaRui_h_sq, self.args.sigmaRui_e_sq, 
        #                                    pretraining=True, pre_mse = self.args.pre_mse, pre_kl = self.args.pre_kl
        #                                    )
        
        loss, loss_kl, loss_mse = loss_BCD(out_Cnet, out_Mnet_mean, out_Mnet_var, y, self.sigmaRui_sq, mR, self.theta, pretraining=pretraining, pre_mse = self.pre_mse, pre_kl = self.pre_kl)

        loss.backward()
        Cnet_opt.step()
        Mnet_opt.step()

        CNet_sch, MNet_sch = self.lr_schedulers
        CNet_sch.step()
        MNet_sch.step()

        return {'train_loss' : loss.item(), 'train_loss_mse' : loss_mse.item(), 'train_loss_kl' : loss_kl.item()}
    
    def validation_step(self, batch, batch_idx):
        return self._shared_eval_step(batch, batch_idx, 'val')
    
    def test_step(self, batch, batch_idx):
        return self._shared_eval_step(batch, batch_idx, 'test')
    
    def _shared_eval_step(self, batch, batch_idx, prefix):

        y, mR = batch
        y = y.to(self.device)
        mR = mR.to(self.device)

        out_Mnet_mean, out_Mnet_var = self.mnet(y) # shape: (batch_size, 3, 2), (batch_size, 1, 2)
        out_Cnet = self.cnet(y) # shape: (batch_size, 2, H, W)

        #loss, loss_mse, loss_kl, loss_kl_h, loss_kl_e = self.loss_fn(  out_CNet, out_MNet_mean, out_Mnet_var, y, mR,
        #                                    self.args.sigmaRui_h_sq, self.args.sigmaRui_e_sq, 
        #                                    pretraining=True, pre_mse = self.args.pre_mse, pre_kl = self.args.pre_kl
        #                                    )
        
        loss, loss_kl, loss_mse = loss_BCD(out_Cnet, out_Mnet_mean, out_Mnet_var, y, self.sigmaRui_sq, mR, self.theta)

        #psnr = PeakSignalNoiseRatio()
        return {f'{prefix}_loss' : loss.item(), f'{prefix}_loss_mse' : loss_mse.item(), f'{prefix}_loss_kl' : loss_kl.item()}
    
    def fit(self, max_epochs, train_dataloader, val_dataloader=None, pretraining=False):
        if val_dataloader is None:
            val_dataloader = train_dataloader
        
        if not self.optim_initiated:
            self.init_optimizers()
        
        self.sigmaRui_sq.to(self.device)
        self.to(self.device)
        for epoch in range(1, max_epochs + 1):
            print(f"Epoch {epoch}")
            
            # Trainining loop:
            self.train()
            pbar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
            pbar.set_description(f"Epoch {epoch} - Training")
            for batch_idx, batch in pbar:
                m_dic = self.training_step(batch, batch_idx, pretraining=pretraining)
                pbar.set_postfix(m_dic)

            # Eval loop
            self.eval()
            pbar = tqdm(enumerate(val_dataloader), total=len(val_dataloader))
            pbar.set_description(f"Epoch {epoch} - Validation")
            for batch_idx, batch in pbar:
                m_dic = self.validation_step(batch, batch_idx)
                pbar.set_postfix(m_dic)

    def evaluate(self, test_dataloader):
        self.eval()
        pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
        pbar.set_description(f"Testing")
        for batch_idx, batch in pbar:
            m_dic = self.test_step(batch, batch_idx)
            pbar.set_postfix(m_dic)
        return m_dic

In [3]:
#args.train_data_path = '/data/BasesDeDatos/Camelyon/Camelyon17/training/Toy/'
args.train_data_path = '/data/BasesDeDatos/Camelyon/Camelyon17/training/patches_224/'

dataloaders = get_dataloaders(args, val_prop=0.3)
train_dataloader = dataloaders['train']
val_dataloader = dataloaders['val']
print(len(train_dataloader), len(val_dataloader))
test_dataloader = dataloaders['test']

Available patches: 1247945
Available patches: 1813651
70000 30000


In [4]:
model = DVBCDModel(args, DEVICE)
model.fit(5, train_dataloader, val_dataloader)

Epoch 1


Epoch 1 - Training:   0%|          | 42/70000 [00:30<14:04:39,  1.38it/s, train_loss=184, train_loss_mse=66.8, train_loss_kl=302]          