# Imports

In [None]:
import torch
from torch.utils.data.dataset import Dataset  # For custom data-sets
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import torchvision
import matplotlib.pyplot as plt
import glob
import argparse
from torch.utils.tensorboard import SummaryWriter
import os
from tqdm import tqdm
from tensorflow import summary
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import random
import shutil
import scipy.io
from torch.autograd import Variable

In [None]:
# Reproducibility settings: seed every random number generator the notebook
# uses, and make cuDNN choose deterministic kernels (its autotuner is
# disabled because it may pick different algorithms from run to run).
_SEED = 1234

np.random.seed(_SEED)
random.seed(_SEED)
torch.manual_seed(_SEED)

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# ThermoDataset Class

In [None]:
class Mapper():
  """
  Map an N-channel colour-coded mask to a single-channel class-index mask.

  Each key of ``color_map`` is a colour; only its first three values are
  compared, against the first three channels of the mask. Each value is the
  class index written to the output for pixels of that colour. Pixels whose
  colour is not a key of ``color_map`` become 0.
  """

  # Colour map used when the caller does not supply one.
  _DEFAULT_COLOR_MAP = {(0, 0, 0): 0, (128, 0, 0): 1, (0, 128, 0): 2}

  def __init__(self, color_map: dict = None):
    """
    Initialize the Mapper class.

    Args:
      color_map (dict): python dict where
                          - keys are N-channel colour tuples (at least three
                            entries), each one representing a class
                          - values are the corresponding 1-channel class
                            value (written as uint8)
                        Defaults to {(0,0,0): 0, (128,0,0): 1, (0,128,0): 2}.

    Examples:

      # in this color_map the keys (e.g. (0,0,0)) are the N-channel class
      # values, while the values (e.g. 0) are the new 1-channel class values

      color_map = {(0,0,0) : 0, (128,0,0):1, (0,128,0):2}
      mapping = Mapper(color_map)
    """
    if color_map is None:
      color_map = self._DEFAULT_COLOR_MAP
    # Copy, so that mutating this instance's map can never change the shared
    # default (the original mutable default argument) or the caller's dict.
    self.color_map = dict(color_map)
    self.color_map_inv = {v: k for k, v in self.color_map.items()}

  def __call__(self, mask):
    """
    Call the class instance as a function (e.g. mapper(mask)) to perform the
    mapping operation.

    Args:
      mask: N-channel target image (PIL image or array-like), indexed as
        (H, W, C) with C >= 3.

    Returns:
      PIL.Image.Image: (H, W) uint8 image whose pixel values are the class
      indices.
    """
    mask = np.asarray(mask)
    new_mask = np.zeros((mask.shape[0], mask.shape[1], 1))

    for color, label in self.color_map.items():
      # A pixel belongs to `color` only if its first three channels all match.
      binary_mask = ((mask[:, :, 0] == color[0])
                     & (mask[:, :, 1] == color[1])
                     & (mask[:, :, 2] == color[2]))
      new_mask = new_mask + (binary_mask[..., np.newaxis] * label)

    return Image.fromarray(np.uint8(new_mask[:, :, 0]))


def normalization_param(dataloader):
  """
  Compute the global mean and (population) standard deviation of a dataset.

  Every element yielded by ``dataloader`` is treated as one tensor. All of
  its values are pooled with the values of every other element, so batches
  of any (even uneven) size are fine. The dataloader is iterated twice: once
  for the mean, once for the std. It must therefore be re-iterable and yield
  the same data both times.

  Args:
    dataloader: iterable of tensors, e.g. a ``torch.utils.data.DataLoader``
      whose dataset returns a single tensor per item.

  Returns:
    tuple(float, float): ``(mean, std)`` over every value seen.

  Raises:
    ValueError: if ``dataloader`` yields no values.
  """
  n_pix = 0
  total = 0.0  # python float, i.e. float64 accumulation

  for batch in dataloader:
    # Flatten everything (batch included) and accumulate in float64 so the
    # running sum does not lose precision on large datasets.
    flat = batch.reshape(-1).double()
    n_pix += flat.numel()
    total += flat.sum().item()

  if n_pix == 0:
    raise ValueError("dataloader yielded no values; cannot compute mean/std")

  mean = total / n_pix

  # Second pass: sum of squared deviations from the exact mean (more stable
  # than the single-pass E[x^2] - E[x]^2 formula).
  sq_dev = 0.0
  for batch in dataloader:
    flat = batch.reshape(-1).double()
    sq_dev += (flat - mean).pow(2).sum().item()

  std = (sq_dev / n_pix) ** 0.5

  return mean, std


class ThermoDatasetPfs(Dataset):
  """
  Per-pixel dataset over a folder of videos that share one annotation mask.

  Every ``*.mat`` file matched by ``image_paths`` is loaded eagerly (array
  under key ``'A'``, given a leading axis of size 1). Each video is
  optionally normalized with ``mean``/``std`` and zero-padded by one pixel on
  both spatial borders. A single mask, ``target_paths + "mask_final.tiff"``,
  is converted to class indices ((0,0,0) -> 0, (128,0,0) -> 1) and used for
  every video.

  Item ``index`` addresses video ``index // (H*W)`` and, inside it, pixel
  ``(row, col) = divmod(index % (H*W), W)``. The item is returned as the
  pixel's 3x3 spatial neighbourhood over all frames, together with that
  pixel's class label.
  """

  def __init__(self, image_paths, target_paths, mean=None, std=None, train=True):
    """
    Args:
      image_paths (str): prefix (including any trailing separator) that is
        concatenated with ``"*.mat"`` to find the videos.
      target_paths (str): prefix (including any trailing separator) that is
        concatenated with ``"mask_final.tiff"``.
      mean (float, optional): subtracted from every video; skipped when None.
      std (float, optional): every video is divided by it; skipped when None.
      train (bool): currently unused; kept for interface compatibility.

    Raises:
      FileNotFoundError: if no ``*.mat`` file matches ``image_paths``.
    """
    self.image_paths = sorted(glob.glob(image_paths + "*.mat"))
    if not self.image_paths:
      raise FileNotFoundError("No .mat files found with prefix " + image_paths)

    self.target_paths = target_paths + "mask_final.tiff"
    self.mapping = Mapper({(0, 0, 0): 0, (128, 0, 0): 1})
    self.mean = mean
    self.std = std

    # Convert inside `with` so the TIFF file handle is released right away;
    # the Mapper reads the pixel data via np.asarray before the file closes.
    with Image.open(self.target_paths) as raw_mask:
      mapped = self.mapping(raw_mask)
    self.mask = torch.from_numpy(np.asarray(mapped)).long()

    # Pad only the last two (spatial) axes so border pixels also get a full
    # 3x3 neighbourhood.
    npad = ((0, 0), (0, 0), (1, 1), (1, 1))
    self.videos = []
    for path in self.image_paths:
      # 'A' is presumably (frames, height, width) -- TODO confirm.
      v = scipy.io.loadmat(path)['A'][np.newaxis, ...]
      if self.mean is not None:
        v = v - self.mean
      if self.std is not None:
        v = v / self.std
      v = np.pad(v, pad_width=npad, mode='constant', constant_values=0)
      self.videos.append(v)

    # Each video is (1, T, H + 2, W + 2); H and W exclude the padding.
    # (Previously T was taken from the size-1 leading axis.)
    _, self.T, self.H, self.W = self.videos[0].shape
    self.H -= 2
    self.W -= 2

  def __getitem__(self, index):
    """
    Return ``(pixels, label)`` for flat index ``index``.

    ``pixels`` is a (1, T, 3, 3) window of the padded video and ``label`` is a
    1-element long tensor.
    """
    n_vid, pix = divmod(index, self.H * self.W)
    n_row, n_col = divmod(pix, self.W)

    # In padded coordinates the 3x3 window starting at (n_row, n_col) is
    # centred on unpadded pixel (n_row, n_col).
    pixels = self.videos[n_vid][:, :, n_row:n_row + 3, n_col:n_col + 3]
    mask_px = self.mask[n_row, n_col][np.newaxis, ...]

    return pixels, mask_px

  def __len__(self):
    """Number of pixels over all videos: videos * H * W."""
    return len(self.image_paths) * self.H * self.W

  def get_size(self):
    """Unpadded spatial size ``(H, W)`` of every video."""
    return self.H, self.W


# Model

## Losses

## CODE

In [None]:
class BlockTemporal(nn.Sequential):
  """
  A single bias-free ``nn.Conv1d`` wrapped in ``nn.Sequential``.

  It is used as the ``conv_builder`` of the temporal residual blocks, which
  call it as ``BlockTemporal(kernel, in_planes, out_planes, stride, padding)``.
  """

  def __init__(self, kernel, in_planes, out_planes, stride=1, padding=1):
    conv = nn.Conv1d(
        in_planes,
        out_planes,
        kernel_size=kernel,
        stride=stride,
        padding=padding,
        bias=False,
    )
    # Passed positionally so the weight is registered as "0.weight", which
    # keeps checkpoint keys unchanged.
    super(BlockTemporal, self).__init__(conv)

  @staticmethod
  def get_downsample_stride(stride):
    """Stride for the shortcut projection: the same as the conv stride."""
    return stride


class BasicBlockTemporal(nn.Module):
  """
  1-D residual block.

  Residual branch: conv0 = Conv1d(inplanes -> inplanes, stride 1) + BN + ReLU,
  then conv1 = conv_builder(..., stride) + BN + ReLU. The output is
  ReLU(branch(x) + shortcut(x)), where the shortcut is ``downsample(x)`` if
  given and ``x`` otherwise.
  """

  def __init__(self, inplanes, planes, conv_builder, kernel=3, stride=1, padding=1, downsample=None):
    """
    Args:
      inplanes (int): input channels.
      planes (int): output channels.
      conv_builder: callable ``(kernel, inplanes, planes, stride, padding)``
        that returns the second convolution.
      kernel (int): kernel size of both convolutions.
      stride (int): stride of the second convolution.
      padding (int): padding of both convolutions.
      downsample (nn.Module, optional): maps the input to the shortcut. It is
        needed whenever the branch changes the channel count or the length.
    """
    super(BasicBlockTemporal, self).__init__()

    self.conv0 = nn.Sequential(
        nn.Conv1d(inplanes, inplanes, kernel_size=kernel,
                  stride=1, padding=padding,
                  bias=False),
        nn.BatchNorm1d(inplanes),
        nn.ReLU()
    )
    self.conv1 = nn.Sequential(
        conv_builder(kernel, inplanes, planes, stride, padding),
        nn.BatchNorm1d(planes),
        nn.ReLU()
    )
    self.relu = nn.ReLU()
    self.downsample = downsample
    self.stride = stride

  def forward(self, x):
    out = self.conv1(self.conv0(x))
    residual = x if self.downsample is None else self.downsample(x)

    # Out-of-place add: `out` is a ReLU output, and ReLU's backward pass
    # reads that saved output, so `out += residual` would make autograd
    # raise during loss.backward().
    out = out + residual
    return self.relu(out)


class StemTemporal4(nn.Sequential):
    """Network stem: a single 3-D convolution that mixes time and space.

    Conv3d(1 -> 128, kernel (13, 3, 3), stride 1, padding (6, 0, 0), no bias)
    followed by BatchNorm3d and an in-place ReLU. The temporal padding keeps
    the number of frames unchanged. The unpadded 3x3 spatial kernel shrinks
    each spatial dimension by 2, so a 3x3 input becomes 1x1.
    """

    def __init__(self):
        temporal_kernel = 13
        layers = [
            nn.Conv3d(1, 128,
                      kernel_size=(temporal_kernel, 3, 3),
                      stride=1,
                      padding=(temporal_kernel // 2, 0, 0),
                      bias=False),
            nn.BatchNorm3d(128),
            nn.ReLU(inplace=True),
        ]
        super(StemTemporal4, self).__init__(*layers)


class TimeResNetTemporal(nn.Module):
    """Per-pixel temporal classifier.

    The stem is expected to output (B, 128, T', 1, 1); the two trailing
    singleton axes are dropped to get a (B, 128, T') signal. Five residual
    stages (stride 2 each, 128 -> 128 four times, then 128 -> 256) shorten it,
    and a final kernel-27 Conv1d maps it to ``num_classes`` channels. The
    output is (B, num_classes, L) with L = (length after layer5) - 26,
    presumably 1 for the intended input length (TODO confirm).
    """

    def __init__(self, block, conv_makers,
                 stem, num_classes=2,
                 zero_init_residual=False, hdim=128):
        """
        Args:
          block: residual block class, called as ``block(inplanes, planes,
            conv_builder, kernel, stride, padding, downsample)``.
          conv_makers: sequence of at least five conv builders, one per
            stage; each must provide ``get_downsample_stride(stride)``.
          stem: zero-argument callable returning the stem module; its output
            must have 128 channels and two trailing singleton axes.
          num_classes (int): output channels of the final convolution.
          zero_init_residual (bool): if True, zero the weight of the last
            BatchNorm inside every ``block`` instance, ignoring modules under
            its ``downsample`` attribute.
          hdim (int): unused; kept for interface compatibility.
        """
        super(TimeResNetTemporal, self).__init__()

        self.stem = stem()

        self.layer1 = self._make_layer(block, conv_makers[0], 128, 128, stride=2, padding=1)
        self.layer2 = self._make_layer(block, conv_makers[1], 128, 128, stride=2, padding=1)
        self.layer3 = self._make_layer(block, conv_makers[2], 128, 128, stride=2, padding=1)
        self.layer4 = self._make_layer(block, conv_makers[3], 128, 128, stride=2, padding=1)
        self.layer5 = self._make_layer(block, conv_makers[4], 128, 256, stride=2, padding=1)

        # Unpadded kernel of 27: collapses a length-27 signal to one output
        # position, i.e. one logit vector per input pixel.
        self.last_conv = nn.Conv1d(256, num_classes, kernel_size=27,
                                   stride=1, padding=0,
                                   bias=False)

        # init weights
        self._initialize_weights()

        if zero_init_residual:
            self._zero_init_residual_branches(block)

    def forward(self, x):
        # Drop only the two trailing spatial axes. A bare `.squeeze()` would
        # also drop the batch axis when B == 1, which breaks BatchNorm1d.
        x = self.stem(x).squeeze(-1).squeeze(-1)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.last_conv(x)

        return x

    def _make_layer(self, block, conv_builder, inplanes, planes, stride, padding, kernel=3):
        """Build one stage containing a single ``block``.

        A 1x1 Conv1d + BatchNorm1d projection shortcut is created whenever the
        stride is not 1 or the channel count changes.
        """
        downsample = None

        if stride != 1 or inplanes != planes:
            ds_stride = conv_builder.get_downsample_stride(stride)
            downsample = nn.Sequential(
                nn.Conv1d(inplanes, planes,
                          kernel_size=1, stride=ds_stride, bias=False),
                nn.BatchNorm1d(planes)
            )
        layers = []
        layers.append(block(inplanes, planes, conv_builder, kernel, stride,
                            padding, downsample))

        return nn.Sequential(*layers)

    def _zero_init_residual_branches(self, block):
        """Zero the last BatchNorm weight inside each ``block`` (not its shortcut)."""
        bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
        for m in self.modules():
            if not isinstance(m, block):
                continue
            branch_bns = [sub for name, sub in m.named_modules()
                          if isinstance(sub, bn_types)
                          and sub.weight is not None
                          and not name.startswith("downsample")]
            if branch_bns:
                nn.init.constant_(branch_bns[-1].weight, 0)

    def _initialize_weights(self):
        # NOTE: only Conv3d / BatchNorm3d / Linear modules are matched, so the
        # Conv1d and BatchNorm1d layers of the stages and last_conv keep
        # PyTorch's default initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def resnetTemporal34(pretrained=False, progress=True):
    """Build the default per-pixel temporal network.

    Uses BasicBlockTemporal residual blocks, a BlockTemporal conv builder for
    each of the five stages, the StemTemporal4 stem and 2 output classes.

    Args:
      pretrained: unused.
      progress: unused.
    """
    builders = [BlockTemporal for _ in range(5)]
    return TimeResNetTemporal(
        block=BasicBlockTemporal,
        conv_makers=builders,
        stem=StemTemporal4,
        num_classes=2,
    )


# Metrics

In [None]:
class Iou():
  """
  Accumulate intersection-over-union (IoU) per class and globally.

  ``update`` scores every sample of a batch separately. For class ``c`` the
  sample IoU is ``(|P_c & T_c| + SMOOTH) / (|P_c | T_c| + SMOOTH)``, where
  ``P_c`` / ``T_c`` are the pixels predicted / labelled as ``c``. Per-sample
  values are summed across calls and averaged over the number of samples
  seen. Because of SMOOTH, a class absent from both prediction and target
  scores 1 for that sample.
  """

  def __init__(self, n_classes: int):
    """
    Initialize the accumulators.

    Args:
      n_classes (int): number of semantic segmentation classes
    """
    self.SMOOTH = 1e-10
    self.num_ex = 0
    self.n_classes = n_classes
    # Running sum over samples of the per-sample IoU, one entry per class.
    self.ious = np.zeros((n_classes), dtype=np.float32)

  @staticmethod
  def _to_labels(t):
    """
    Reduce a 4-D tensor to (B, H, W) class labels.

    (B, C, H, W) scores with C > 1 are arg-maxed over C, and a singleton
    channel (B, 1, H, W) is dropped. Other tensors are returned unchanged.
    """
    if t.dim() == 4:
      return torch.argmax(t, dim=1) if t.shape[1] != 1 else t[:, 0]
    return t

  def update(self, out, mask):
    """
    Add one batch to the running IoU.

    Args:
      out (torch.Tensor): network output, (B, H, W) labels or (B, C, H, W)
        scores.
      mask (torch.Tensor): network target, (B, H, W) labels or (B, C, H, W)
        one-hot / scores.
    """
    # Each tensor is reduced according to its own rank (the mask used to be
    # reduced based on out's rank, so one-hot masks were never reduced).
    out = self._to_labels(out)
    mask = self._to_labels(mask)

    # updating number of samples
    self.num_ex += out.shape[0]

    out = out.cpu().detach().numpy()
    mask = mask.cpu().detach().numpy()

    for c in range(self.n_classes):
      o_mask = out == c
      m_mask = mask == c

      # Summing over H and W gives one intersection / union per sample.
      i = (o_mask & m_mask).sum((1, 2)) + self.SMOOTH
      u = (o_mask | m_mask).sum((1, 2)) + self.SMOOTH

      # sum over the batch
      self.ious[c] += (i / u).sum(0)

  def get_iou_classes(self):
    """
    Return the mean IoU of each class over all samples seen.

    Returns:
      numpy array of length n_classes (nan if no sample has been seen)
    """
    return self.ious / self.num_ex

  def get_iou_global(self):
    """
    Return the global IoU: the mean of the per-class mean IoUs.

    Returns:
      iou
    """
    iou = (self.ious).sum((0)) / (self.num_ex * self.n_classes)
    return iou


class ClassAccuracy:
  """
  Accumulate pixel-level agreement statistics over several batches.

  For each class c, ``t[c]`` counts pixels where "prediction is c" and
  "target is c" agree (both true or both false), and ``f[c]`` counts pixels
  where they disagree. With two classes and labels in {0, 1} both entries
  are therefore equal, and both give the overall pixel accuracy.
  ``tp``/``tn``/``fp``/``fn`` treat label 1 as positive and label 0 as
  negative. All counters accumulate across ``update`` calls.
  """

  def __init__(self, n_classes: int):
    """
    Initialize the counters.

    Args:
      n_classes (int): number of semantic segmentation classes
    """
    self.num_ex = 0
    self.n_classes = n_classes
    # float64 so that large pixel counts stay exact when accumulated.
    self.t = np.zeros((n_classes), dtype=np.float64)
    self.f = np.zeros((n_classes), dtype=np.float64)
    self.tp = 0
    self.tn = 0
    self.fp = 0
    self.fn = 0

  @staticmethod
  def _to_labels(t):
    """
    Reduce a 4-D tensor to (B, H, W) class labels.

    (B, C, H, W) scores with C > 1 are arg-maxed over C, and a singleton
    channel (B, 1, H, W) is dropped. Other tensors are returned unchanged.
    """
    if t.dim() == 4:
      return torch.argmax(t, dim=1) if t.shape[1] != 1 else t[:, 0]
    return t

  def update(self, out, mask):
    """
    Add one batch to the running counts.

    Args:
      out (torch.Tensor): network output, (B, H, W) labels or (B, C, H, W)
        scores.
      mask (torch.Tensor): network target, (B, H, W) labels or (B, C, H, W)
        one-hot / scores.

    Raises:
      ValueError: if ``out`` is not (B, H, W) after label reduction.
    """
    out = self._to_labels(out)
    mask = self._to_labels(mask)
    if out.dim() != 3:
      raise ValueError(
          "expected (B, H, W) labels or (B, C, H, W) scores, got shape {}".format(
              tuple(out.shape)))

    # updating number of pixels
    self.num_ex += out.shape[0] * out.shape[1] * out.shape[2]

    out = out.cpu().detach().numpy()
    mask = mask.cpu().detach().numpy()

    for c in range(self.n_classes):
      agree = (out == c) == (mask == c)
      self.t[c] += np.count_nonzero(agree)
      self.f[c] += np.count_nonzero(~agree)

    self.tp += int(np.sum(np.logical_and(out == 1, mask == 1)))
    self.tn += int(np.sum(np.logical_and(out == 0, mask == 0)))
    self.fp += int(np.sum(np.logical_and(out == 1, mask == 0)))
    self.fn += int(np.sum(np.logical_and(out == 0, mask == 1)))

  def get_acc_classes(self):
    """
    Return the per-class agreement ratio ``t / (t + f)``.

    Returns:
      numpy array of length n_classes (nan before any update)
    """
    return self.t / (self.t + self.f)

  def get_metrics(self):
    """Return the accumulated ``(tp, tn, fp, fn)`` counts."""
    return self.tp, self.tn, self.fp, self.fn


In [None]:
class Ioupf():
  """
  IoU accumulator for per-pixel (flattened) outputs.

  A 3-D tensor whose dim 1 is larger than 1 is treated as (B, C, L) scores
  and arg-maxed over C. This presumably targets the (B, C, 1) output of the
  per-pixel model with (B, 1) targets -- TODO confirm. For each sample and
  class c, IoU is ``(|P_c & T_c| + SMOOTH) / (|P_c | T_c| + SMOOTH)`` over
  all non-batch positions. Values are summed across calls and averaged over
  the number of samples. A class absent from both prediction and target
  scores 1 for that sample.
  """

  def __init__(self, n_classes: int):
    """
    Initialize the accumulators.

    Args:
      n_classes (int): number of semantic segmentation classes
    """
    self.SMOOTH = 1e-10
    self.num_ex = 0
    self.n_classes = n_classes
    # Running sum over samples of the per-sample IoU, one entry per class.
    self.ious = np.zeros((n_classes), dtype=np.float32)

  def update(self, out, mask):
    """
    Add one batch to the running IoU.

    Args:
      out (torch.Tensor): network output; (B, C, L) scores are arg-maxed.
      mask (torch.Tensor): network target; (B, C, L) scores/one-hot are
        arg-maxed, lower-rank label tensors are used as-is.
    """
    # Each tensor is reduced according to its own rank (the mask used to be
    # checked against out's rank instead).
    if out.dim() == 3 and out.shape[1] != 1:
      out = torch.argmax(out, dim=1)
    if mask.dim() == 3 and mask.shape[1] != 1:
      mask = torch.argmax(mask, dim=1)

    # updating number of samples
    self.num_ex += out.shape[0]

    out = out.cpu().detach().numpy()
    mask = mask.cpu().detach().numpy()

    for c in range(self.n_classes):
      o_mask = out == c
      m_mask = mask == c

      intersection = o_mask & m_mask
      union = o_mask | m_mask
      # Reduce over every non-batch axis, whatever rank is left after argmax
      # (a fixed (1, 2) fails once (B, C, L) scores become (B, L) labels).
      axes = tuple(range(1, intersection.ndim))
      i = intersection.sum(axes) + self.SMOOTH
      u = union.sum(axes) + self.SMOOTH

      # sum over the batch
      self.ious[c] += (i / u).sum(0)

  def get_iou_classes(self):
    """
    Return the mean IoU of each class over all samples seen.

    Returns:
      numpy array of length n_classes (nan if no sample has been seen)
    """
    return self.ious / self.num_ex

  def get_iou_global(self):
    """
    Return the global IoU: the mean of the per-class mean IoUs.

    Returns:
      iou
    """
    iou = (self.ious).sum((0)) / (self.num_ex * self.n_classes)
    return iou


# Train/Validation functions

In [None]:
def train_cnn(dataloader_train, model, criterion, optim, epoch, logger):
    """
    Run one training epoch and log the loss of every batch.

    Args:
      dataloader_train: iterable of ``(image, mask)`` batches that supports
        ``len``.
      model (nn.Module): network to train. Batches are moved to the device
        of its first parameter instead of a hard-coded ``.cuda()``.
      criterion: loss, called as ``criterion(model(image), mask)``.
      optim: optimizer over ``model``'s parameters.
      epoch (int): zero-based epoch index, used to compute the global step.
      logger: writer exposing ``add_scalar`` (e.g. a ``SummaryWriter``).
    """
    model.train()
    device = next(model.parameters()).device
    n_batches = len(dataloader_train)

    for it, (image, mask) in enumerate(tqdm(dataloader_train)):
        global_it = epoch * n_batches + it

        image = image.to(device).float()
        mask = mask.to(device)

        out = model(image)
        loss = criterion(out, mask)

        optim.zero_grad()
        loss.backward()
        optim.step()

        logger.add_scalar('Loss/train', loss.item(), global_it)


def _log_running_metrics(iou_meter, acc_meter, logger, epoch):
  """Log to ``logger`` and print the running IoU / accuracy / confusion counts."""
  iou_global = iou_meter.get_iou_global()
  iou_classes = iou_meter.get_iou_classes()
  acc_classes = acc_meter.get_acc_classes()
  tp, tn, fp, fn = acc_meter.get_metrics()

  logger.add_scalar('Iou/global', iou_global, epoch)
  logger.add_scalar('Iou/background', iou_classes[0], epoch)
  logger.add_scalar('Iou/inclusion', iou_classes[1], epoch)
  logger.add_scalar('Acc/background', acc_classes[0], epoch)
  logger.add_scalar('Acc/inclusion', acc_classes[1], epoch)

  print("IOU global: ", iou_global)
  print("IOU background: ", iou_classes[0])
  print("IOU inclusion: ", iou_classes[1])
  print("Acc background: ", acc_classes[0])
  print("Acc inclusion: ", acc_classes[1])
  print("TP: {}, TN: {}, FP: {}, FN: {} ".format(tp, tn, fp, fn))


def validatepfs_cnn_traintest(dataloader_val, model, criterion, epoch, logger):
  """
  Evaluate a per-pixel model image by image and log loss, IoU and accuracy.

  ``dataloader_val.dataset`` must provide ``get_size() -> (H, W)``. Its pixels
  must arrive image after image, each image in row-major order, so the loader
  must not shuffle. Predictions are buffered until ``H * W`` pixels are
  available. Each complete image is then scored, its colour-coded prediction
  is logged, and the running metrics are logged and printed. Batches do not
  need to align with image boundaries; pixels of a trailing incomplete image
  are ignored.

  Args:
    dataloader_val: loader of ``(image, mask)`` batches.
    model (nn.Module): network; batches are moved to its device.
    criterion: loss, called as ``criterion(model(image), mask)``.
    epoch (int): zero-based epoch index used for logging steps.
    logger: writer exposing ``add_scalar`` and ``add_image``.

  Returns:
    tuple: ``(global IoU, background accuracy, inclusion accuracy)`` from the
    running meters.
  """
  model.eval()
  iou_meter = Iou(n_classes=2)
  acc_meter = ClassAccuracy(n_classes=2)

  H, W = dataloader_val.dataset.get_size()
  n_px = H * W
  device = next(model.parameters()).device

  # Flattened predictions / targets not yet assigned to a complete image.
  pred_buf, mask_buf = [], []
  n_buffered = 0
  n_image = 1

  with torch.no_grad():
    for it, (image, mask) in enumerate(tqdm(dataloader_val)):
      global_it = epoch * len(dataloader_val) + it

      videopx = image.to(device).float()
      maskpx = mask.to(device)

      outpx = model(videopx)
      loss = criterion(outpx, maskpx)
      logger.add_scalar('Loss/test', loss.item(), global_it)

      pred_buf.append(torch.argmax(outpx, dim=1).reshape(-1))
      mask_buf.append(maskpx.reshape(-1))
      n_buffered += pred_buf[-1].numel()

      # A batch may complete an image part-way through; the leftover pixels
      # are kept for the next image instead of being mis-assembled.
      while n_buffered >= n_px:
        preds = torch.cat(pred_buf)
        targets = torch.cat(mask_buf)
        om1 = preds[:n_px].reshape(1, H, W)
        msk1 = targets[:n_px].reshape(1, H, W)
        pred_buf, mask_buf = [preds[n_px:]], [targets[n_px:]]
        n_buffered -= n_px

        print("Image: ", n_image)

        iou_meter.update(om1, msk1)
        acc_meter.update(om1, msk1)

        logger.add_image("image/out_rgb", decode_segmap(om1)[0], global_it)
        _log_running_metrics(iou_meter, acc_meter, logger, epoch)

        n_image += 1

  return (iou_meter.get_iou_global(), acc_meter.get_acc_classes()[0], acc_meter.get_acc_classes()[1])


def save_examples(image, mask, out, *, save=False):
  """
  Show (and optionally save) input / target / prediction triplets.

  Args:
    image, mask, out: sequences iterated in lockstep whose items are indexed
      (C, H, W). ``.transpose(1, 2, 0)`` implies numpy arrays (a torch
      tensor's ``transpose`` takes two dims) -- TODO confirm. ``image`` items
      must have a single channel.
    save (bool): if True, also write each figure as ``example_<k>.png`` into
      ``master_folder + "examples/"``.

  Note:
    Relies on the notebook-level global ``master_folder``.
  """
  path = master_folder + "examples/"
  os.makedirs(path, exist_ok=True)
  for k, (img, msk, pred) in enumerate(zip(image, mask, out)):
    fig, axs = plt.subplots(1, 3, figsize=(15,15))

    axs[0].set_title("Input Image")
    axs[0].imshow(img.transpose(1,2,0).squeeze(2))

    axs[1].set_title("Input Mask")
    axs[1].imshow(msk.transpose(1,2,0))

    axs[2].set_title("Output Mask")
    axs[2].imshow(pred.transpose(1,2,0))

    if save:
      # Save before show(): some backends clear the figure when it is shown.
      fig.savefig(os.path.join(path, "example_{}.png".format(k)))

    plt.show()
    # Release the figure; otherwise one figure per example stays open.
    plt.close(fig)


def decode_segmap(image, nc=2):
  """
  Colour-code a tensor of integer class labels.

  Args:
    image (torch.Tensor): integer class labels, e.g. (B, H, W).
    nc (int): how many classes (starting from 0) to colour, at most 3.
      Any other label keeps colour (0, 0, 0).

  Returns:
    torch.Tensor: uint8 tensor with an R, G, B axis inserted at dim 1
    (e.g. (B, 3, H, W)).
  """
  palette = np.array([
      (0, 0, 0),     # 0 = background
      (128, 0, 0),   # 1 = inclusion
      (0, 128, 0),   # 2 = deformation
  ])

  channels = [torch.zeros_like(image, dtype=torch.uint8) for _ in range(3)]
  for label in range(nc):
    hit = image == label
    for ch, plane in enumerate(channels):
      plane[hit] = palette[label, ch]

  return torch.stack(channels, dim=1)

def save_checkpoint(state, is_best, check_folder, project_name):
  """
  Save ``state`` to ``<check_folder><project_name>.pth``. When ``is_best`` is
  true, also copy it to ``<check_folder><project_name>_model_best.pth``.

  Paths are built by plain string concatenation, so ``check_folder`` must
  end with a path separator.
  """
  stem = check_folder + project_name
  latest = stem + ".pth"
  torch.save(state, latest)
  if not is_best:
    return
  shutil.copyfile(latest, stem + "_model_best.pth")


def load_checkpoint(filename):
  """
  Load a checkpoint written by ``torch.save``, mapping all tensors to CPU.

  Security: ``torch.load`` unpickles the file (fully, unless the installed
  PyTorch defaults to ``weights_only=True``), so only load trusted files.
  """
  cpu = torch.device("cpu")
  return torch.load(filename, map_location=cpu)


# install logger
# !pip install -q tf-nightly-2.0-preview
%load_ext tensorboard

# Mounting Drive

In [None]:
# Mount Google Drive at /gdrive (Colab-only API). Presumably the data,
# log and checkpoint folders configured in the settings cells live there.
from google.colab import drive
drive.mount('/gdrive')
%cd /gdrive

# Settings

In [None]:
#@title Compute std and mean
compute = True #@param {type:"boolean"}
all_dataset = "dataset_folder" #@param {type:"string"}

# Precomputed normalization statistics: used as-is when `compute` is False,
# and overwritten below when it is True.
# dataset_pf_median
mean = 46.406394958496094
std = 31.089099884033203

if compute:
  # NOTE(review): ThermoDataset_base is not defined in the visible notebook,
  # so this branch raises NameError unless it is defined elsewhere.
  # normalization_param also expects each loader item to be a bare tensor
  # (not an (input, target) tuple).
  stooge_dataset = ThermoDataset_base(all_dataset)
  stooge_loader = torch.utils.data.DataLoader(stooge_dataset, batch_size=1, shuffle=False)
  mean, std = normalization_param(stooge_loader)

  print("mean :", mean)
  print("std: ", std)


In [None]:
#@title Data, Checkpoint and Annotation folders
master_folder = "master_folder" #@param {type:"string"}

image_folder = "image_folder" #@param {type:"string"}
annotation_folder = "annotation_folder" #@param {type:"string"}
logger_folder = "logs_folder" #@param {type:"string"}
project_name = "project_name" #@param {type:"string"}
check_folder = "checkpoint_folder" #@param {type:"string"}
# When True, the training cell only runs one evaluation pass on the test set.
test = False #@param {type:"boolean"}


# relative to absolute path
# NOTE: paths are joined by plain string concatenation, so master_folder,
# image_folder and annotation_folder are expected to end with "/" for the
# paths built below to be valid ("/" is appended only to logger_folder and
# check_folder).
image_folder = master_folder+image_folder
annotation_folder = master_folder+annotation_folder
logger_folder = master_folder+logger_folder + "/"
check_folder = master_folder+check_folder + "/"

# Every *.pth file in the checkpoint folder is offered for resuming.
possible_check = glob.glob(check_folder + "*.pth")

import ipywidgets as widgets
from IPython.display import display

# Sentinel option (also the dropdown default) meaning "start from scratch";
# the training cell compares check_file.value against this exact string.
possible_check.append('No check available')

print("Select yout checkpoint here: ")
print()

check_file = widgets.Dropdown(
    options=possible_check,
    value='No check available',
    description='Checkpoint Available:',
    disabled=False,
    layout=widgets.Layout(width='100%')
)
display(check_file)
print()

# Expected layout: <image_folder>{train,test}/data/*.mat and
# <annotation_folder>{train,test}/annotations/mask_final.tiff.
train_image_paths = image_folder + "train/data/"
test_image_paths =  image_folder + "test/data/"

train_mask_paths =  annotation_folder + "train/annotations/"
test_mask_paths =  annotation_folder + "test/annotations/"

# Per-pixel datasets normalized with the mean/std from the settings cell.
# Only the training loader shuffles; the validation routine reassembles whole
# images from consecutive test pixels, so the test loader must keep order.
train_dataset = ThermoDatasetPfs(train_image_paths, train_mask_paths, mean=mean, std=std)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1200, shuffle=True)

test_dataset = ThermoDatasetPfs(test_image_paths, test_mask_paths, mean=mean, std=std)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1200, shuffle=False)

# Lengths are pixel counts (videos * H * W), not video counts.
train_len = len(train_dataset)
test_len =  len(test_dataset)
len_data =  train_len + test_len
print("Dataset lenght: ", len_data)
print("dataset data - train: ", train_len, "  test: ", test_len)
os.makedirs(logger_folder + project_name, exist_ok=True)
os.makedirs(check_folder, exist_ok=True)

In [None]:
#@title Network parameters

# AdamW learning rate used by the training cell.
learning_rate = 0.000001 #@param {type:"number"}
# First epoch to run; replaced by (saved epoch + 1) when resuming from a checkpoint.
start_epoch = 0 #@param {type:"integer"}
# Exclusive upper bound of the epoch loop: range(start_epoch, max_epoch).
max_epoch =  2046#@param {type:"integer"}




# TensorBoard

In [None]:
%reload_ext tensorboard
%tensorboard --logdir "<logdir>"

# Training/Test execution

In [None]:
logger = SummaryWriter(logger_folder+project_name,flush_secs=20)
torch.cuda.set_device(0)

print("=> creating model")
model = resnetTemporal34()
model = model.cuda()

# Class weights for CrossEntropyLoss: mistakes on class 1 (inclusion) cost
# 100x more than mistakes on class 0 (background).
weights = torch.tensor([1.0, 100.0]).cuda()
criterion = torch.nn.CrossEntropyLoss(weights).cuda()

print("=> selecting optimizer")
optim = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)


best_iou = 0

if check_file.value != "No check available":
  # Resume: restore weights, optimizer state, epoch counter and best IoU.
  checkpoint = load_checkpoint(check_file.value)
  model.load_state_dict(checkpoint['model'])
  optim.load_state_dict(checkpoint['optimizer'])
  start_epoch = checkpoint['epoch'] + 1
  best_iou = checkpoint['best_iou']
  # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

try:
  if test:
    print("Testing")
    print("Epoch: ", start_epoch)
    validatepfs_cnn_traintest(test_loader, model, criterion, start_epoch, logger)
  else:
    for e in range(start_epoch, max_epoch):
      print("Epoch: ", e)
      # !nvidia-smi
      train_cnn(train_loader, model, criterion, optim, e, logger)
      iou, accb, acci = validatepfs_cnn_traintest(test_loader, model, criterion, e, logger)

      # Update the running best *before* building the checkpoint, so that the
      # stored 'best_iou' really is the best so far. Storing this epoch's IoU
      # instead would let a resumed run overwrite _model_best.pth with a
      # worse model.
      is_best = iou > best_iou
      if is_best:
        best_iou = iou

      checkpoint = {
          'model': model.state_dict(),
          'optimizer': optim.state_dict(),
          # Plain float, so the checkpoint also loads with weights_only=True.
          'best_iou': float(best_iou),
          'epoch': e,
      }
      save_checkpoint(checkpoint, is_best, check_folder=check_folder, project_name=project_name)
finally:
  # Flush and close the TensorBoard files even if the run is interrupted.
  logger.close()