<a href="https://colab.research.google.com/github/cudaMBI/cudaMBI-Land-Cover-Change-Detection/blob/main/CDD_V1_2022-12-08/moh_FC_EF_test01.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
import os
os.chdir("/content/drive/MyDrive/CDD_V1_2022-12-08")

print(os.getcwd())

/content/drive/MyDrive/CDD_V1_2022-12-08


In [8]:
!ls

log  metadata.json  models  __pycache__  train.py  utils


In [7]:
#!git clone https://github.com/cudaMBI/cudaMBI-Land-Cover-Change-Detection
#!rm -rf "/content/cudaMBI-Land-Cover-Change-Detection" 

In [21]:
!pip install tensorboardX torchinfo

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# train file content

In [14]:
import datetime
import logging
import json
from tensorboardX import SummaryWriter
import torch
import torch.nn as nn
from tqdm import tqdm
import random
import numpy as np
import os
from utils.helper import (get_loaders, load_model, get_criterion, initialize_metrics,
                   set_metrics, get_mean_metrics)
from sklearn.metrics import precision_recall_fscore_support as prfs

"""
Initialize Parser and define arguments

import parse
from parse import get_parser_with_args

parser, metadata = get_parser_with_args()
opt = parser.parse_args()
"""
opt = {"augmentation": False,
        "num_gpus": 1,
        "num_workers": 0,
        "in_channel": 6,
        "out_channel":1,
        'loss_function':'IoULoss',
        "epochs": 2,
        "batch_size": 16,
        "learning_rate": 1e-3,
        "dataset_dir": "/content/drive/MyDrive/CDD_Data_2000imgs/",
        "log_dir": "./log/"
      }

"""
Initialize experiments log
"""
writer = SummaryWriter(opt['log_dir'] + f'/FC_EF_{datetime.datetime.now().strftime("%Y-%m-%d")}/')

"""
Set up environment: define paths, download data, and set device
"""
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('GPU AVAILABLE? ' + str(torch.cuda.is_available()))

def seed_torch(seed=0):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

seed_torch(seed=777)

train_loader, val_loader = get_loaders(opt)

"""
Load Model then define other aspects of the model
"""
print('LOADING Model')
model = load_model(opt, dev)

criterion = get_criterion(opt)
optimizer = torch.optim.AdamW(model.parameters(), lr=opt['learning_rate']) # Be careful when you adjust learning rate, you can refer to the linear scaling rule
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1)

"""
 Set starting values
"""
best_metrics = {'cd_f1scores': -1, 'cd_recalls': -1, 'cd_precisions': -1}
logging.info('STARTING training')
total_step = -1


GPU AVAILABLE? True
LOADING Model


In [15]:
batch_img1, batch_img2, label = next(iter(train_loader))
print(f"batch image 1 size: {batch_img1.size()}")
print(f"batch image 2 size: {batch_img2.size()}")
print(f"label size: {label.size()}")

# we need to add extra dimension in label 
label = torch.unsqueeze(label, dim=1)
print(f"label size: {label.size()}")

batch image 1 size: torch.Size([16, 3, 256, 256])
batch image 2 size: torch.Size([16, 3, 256, 256])
label size: torch.Size([16, 256, 256])
label size: torch.Size([16, 1, 256, 256])


### model functions

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

# A placeholder identity operator that is argument-insensitive.
Identity = nn.Identity


def get_norm_layer():
    # TODO: select appropriate norm layer
    return nn.BatchNorm2d

def get_act_layer():
    # TODO: select appropriate activation layer
    return nn.ReLU


def make_norm(*args, **kwargs):
    norm_layer = get_norm_layer()
    return norm_layer(*args, **kwargs)

def make_act(*args, **kwargs):
    act_layer = get_act_layer()
    return act_layer(*args, **kwargs)

def make_dropout(use_dropout=False):
        if use_dropout:
            return nn.Dropout2d(p=0.2)
        else:
            return Identity()

class BasicConv(nn.Module):
    def __init__(
            self, in_ch, out_ch,
            kernel_size, pad_mode='Zero',
            bias='auto', norm=False, act=False, **kwargs
        ):
        super().__init__()
        seq = []
        if kernel_size >= 2:
            seq.append( getattr(nn, pad_mode.capitalize()+'Pad2d')(kernel_size//2) )
        seq.append(
            nn.Conv2d(
                in_ch, out_ch, kernel_size, stride=1, padding=0,
                bias=(False if norm else True) if bias=='auto' else bias,
                **kwargs
            )    
        )
        if norm:
            if norm is True:
                norm = make_norm(out_ch)
            seq.append(norm)
        if act:
            if act is True:
                act = make_act()
            seq.append(act)
        self.seq = nn.Sequential(*seq)
        
    def forward(self, x):
        return self.seq(x)
 
       
class Conv3x3(BasicConv):
    def __init__(self, in_ch, out_ch, pad_mode='Zero', bias='auto', norm=False, act=False, **kwargs):
        super().__init__(in_ch, out_ch, kernel_size=3, pad_mode=pad_mode, bias=bias, norm=norm, act=act, **kwargs)

class MaxPool2x2(nn.MaxPool2d):
    def __init__(self, **kwargs):
        super().__init__(kernel_size=2, stride=(2,2), padding=(0,0), **kwargs)

class MaxUnPool2x2(nn.MaxUnpool2d):
    def __init__(self, **kwargs):
        super().__init__(kernel_size=2, stride=(2,2), padding=(0,0), **kwargs)
        

class ConvTransposed3x3(nn.Module):
    def __init__(self, in_ch, out_ch, bias='auto', norm=False, act=False, **kwargs):
        super().__init__()
        seq = []
        seq.append(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1,
                                   bias=(False if norm else True) if bias=='auto' else bias,
                                   **kwargs
                        )
            )        
        if norm:
            if norm is True:
                norm = make_norm(out_ch)
            seq.append(norm)
        if act:
            if act is True:
                act = make_act()
            seq.append(act)
        self.seq = nn.Sequential(*seq) 
        
    def forward(self, x):
        return self.seq(x)

### model

In [57]:
class UNet(nn.Module):
    def __init__(self, in_ch, out_ch, use_dropout=False):
        super().__init__()
        
        self.use_dropout = use_dropout
        
        self.conv11 = Conv3x3(in_ch, out_ch=16, norm=True, act=True)
        self.do11 = self.make_dropout()
        self.conv12 = Conv3x3(16, 16, norm=True, act=True)
        self.do12 = self.make_dropout()
        self.pool1 = MaxPool2x2()
        
        self.conv21 = Conv3x3(16, 32, norm=True, act=True)
        self.do21 = self.make_dropout()
        self.conv22 = Conv3x3(32, 32, norm=True, act=True)
        self.do22 = self.make_dropout()
        self.pool2 = MaxPool2x2()

        self.conv31 = Conv3x3(32, 64, norm=True, act=True)
        self.do31 = self.make_dropout()
        self.conv32 = Conv3x3(64, 64, norm=True, act=True)
        self.do32 = self.make_dropout()
        self.conv33 = Conv3x3(64, 64, norm=True, act=True)
        self.do33 = self.make_dropout()
        self.pool3 = MaxPool2x2()

        self.conv41 = Conv3x3(64, 128, norm=True, act=True)
        self.do41 = self.make_dropout()
        self.conv42 = Conv3x3(128, 128, norm=True, act=True)
        self.do42 = self.make_dropout()
        self.conv43 = Conv3x3(128, 128, norm=True, act=True)
        self.do43 = self.make_dropout()
        self.pool4 = MaxPool2x2()

        self.upconv4 = ConvTransposed3x3(128, 128, output_padding=1)

        self.conv43d = Conv3x3(256, 128, norm=True, act=True)
        self.do43d = self.make_dropout()
        self.conv42d = Conv3x3(128, 128, norm=True, act=True)
        self.do42d = self.make_dropout()
        self.conv41d = Conv3x3(128, 64, norm=True, act=True)
        self.do41d = self.make_dropout()

        self.upconv3 = ConvTransposed3x3(64, 64, output_padding=1)

        self.conv33d = Conv3x3(128, 64, norm=True, act=True)
        self.do33d = self.make_dropout()
        self.conv32d = Conv3x3(64, 64, norm=True, act=True)
        self.do32d = self.make_dropout()
        self.conv31d = Conv3x3(64, 32, norm=True, act=True)
        self.do31d = self.make_dropout()

        self.upconv2 = ConvTransposed3x3(in_ch=32, out_ch=32, output_padding=1)

        self.conv22d = Conv3x3(64, 32, norm=True, act=True)
        self.do22d = self.make_dropout()
        self.conv21d = Conv3x3(32, 16, norm=True, act=True)
        self.do21d = self.make_dropout()

        self.upconv1 = ConvTransposed3x3(in_ch=16, out_ch=16, output_padding=1)

        self.conv12d = Conv3x3(32, 16, norm=True, act=True)
        self.do12d = self.make_dropout()
        self.conv11d = Conv3x3(16, out_ch)

    def forward(self, x):
        
        #x = torch.cat([t1, t2], dim=1)
        
        # Stage 1
        x11 = self.do11(self.conv11(x))
        x12 = self.do12(self.conv12(x11))
        x1p = self.pool1(x12)

        # Stage 2
        x21 = self.do21(self.conv21(x1p))
        x22 = self.do22(self.conv22(x21))
        x2p = self.pool2(x22)

        # Stage 3
        x31 = self.do31(self.conv31(x2p))
        x32 = self.do32(self.conv32(x31))
        x33 = self.do33(self.conv33(x32))
        x3p = self.pool3(x33)

        # Stage 4
        x41 = self.do41(self.conv41(x3p))
        x42 = self.do42(self.conv42(x41))
        x43 = self.do43(self.conv43(x42))
        x4p = self.pool4(x43)

        # Stage 4d
        x4d = self.upconv4(x4p)
        pad4 = (0, x43.shape[3]-x4d.shape[3], 0, x43.shape[2]-x4d.shape[2])
        x4d = torch.cat([F.pad(x4d, pad=pad4, mode='replicate'), x43], 1)
        x43d = self.do43d(self.conv43d(x4d))
        x42d = self.do42d(self.conv42d(x43d))
        x41d = self.do41d(self.conv41d(x42d))

        # Stage 3d
        x3d = self.upconv3(x41d)
        pad3 = (0, x33.shape[3]-x3d.shape[3], 0, x33.shape[2]-x3d.shape[2])
        x3d = torch.cat([F.pad(x3d, pad=pad3, mode='replicate'), x33], 1)
        x33d = self.do33d(self.conv33d(x3d))
        x32d = self.do32d(self.conv32d(x33d))
        x31d = self.do31d(self.conv31d(x32d))

        # Stage 2d
        x2d = self.upconv2(x31d)
        pad2 = (0, x22.shape[3]-x2d.shape[3], 0, x22.shape[2]-x2d.shape[2])
        x2d = torch.cat([F.pad(x2d, pad=pad2, mode='replicate'), x22], 1)
        x22d = self.do22d(self.conv22d(x2d))
        x21d = self.do21d(self.conv21d(x22d))

        # Stage 1d
        x1d = self.upconv1(x21d)
        pad1 = (0, x12.shape[3]-x1d.shape[3], 0, x12.shape[2]-x1d.shape[2])
        x1d = torch.cat([F.pad(x1d, pad=pad1, mode='replicate'), x12], 1)
        x12d = self.do12d(self.conv12d(x1d))
        x11d = self.conv11d(x12d)

        return x11d
        
        
    def make_dropout(self):
        if self.use_dropout:
            return nn.Dropout2d(p=0.2)
        else:
            return Identity()
            
unet_model = UNet(in_ch=6, out_ch=2, use_dropout=False)
print(summary(unet_model, input_size=((64, 6, 256, 256))))            

Layer (type:depth-idx)                   Output Shape              Param #
UNet                                     [64, 2, 256, 256]         --
├─Conv3x3: 1-1                           [64, 16, 256, 256]        --
│    └─Sequential: 2-1                   [64, 16, 256, 256]        --
│    │    └─ZeroPad2d: 3-1               [64, 6, 258, 258]         --
│    │    └─Conv2d: 3-2                  [64, 16, 256, 256]        864
│    │    └─BatchNorm2d: 3-3             [64, 16, 256, 256]        32
│    │    └─ReLU: 3-4                    [64, 16, 256, 256]        --
├─Identity: 1-2                          [64, 16, 256, 256]        --
├─Conv3x3: 1-3                           [64, 16, 256, 256]        --
│    └─Sequential: 2-2                   [64, 16, 256, 256]        --
│    │    └─ZeroPad2d: 3-5               [64, 16, 258, 258]        --
│    │    └─Conv2d: 3-6                  [64, 16, 256, 256]        2,304
│    │    └─BatchNorm2d: 3-7             [64, 16, 256, 256]        32
│    │    └

In [55]:
x = torch.cat([batch_img1, batch_img1], dim=1)
x.size()

torch.Size([16, 6, 256, 256])

### training loop

In [22]:
batch_img1, batch_img2, labels = next(iter(train_loader))
print(f"batch image 1 size: {batch_img1.size()}")
print(f"batch image 2 size: {batch_img2.size()}")
print(f"label size: {labels.size()}") # we need to add extra dimension in label 

labels = torch.unsqueeze(labels, dim=1)
labels = labels.long().to(dev)
print(f"label size: {labels.size()}")

inputs = torch.cat([batch_img1, batch_img1], dim=1).to(dev)
print(f"concate batch 1 and batch 2 images: {inputs.size()}")

model = model.to(dev)

cd_preds = model(inputs)
print(f"output size: {cd_preds.size()}")

print(cd_preds[-1].size())

batch image 1 size: torch.Size([16, 3, 256, 256])
batch image 2 size: torch.Size([16, 3, 256, 256])
label size: torch.Size([16, 256, 256])
label size: torch.Size([16, 1, 256, 256])
concate batch 1 and batch 2 images: torch.Size([16, 6, 256, 256])
output size: torch.Size([16, 1, 256, 256])
torch.Size([1, 256, 256])


In [17]:
print(labels[-1])

tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1]]], device='cuda:0')


In [23]:
cd_preds = model(inputs)
print(cd_preds[-1])

_, cd_preds = torch.max(cd_preds, 1)
print(cd_preds)
print(cd_preds.size())

tensor([[[0.1570, 0.1698, 0.1766,  ..., 0.1441, 0.0974, 0.0494],
         [0.3115, 0.3727, 0.4215,  ..., 0.3042, 0.2431, 0.0597],
         [0.3702, 0.4333, 0.5130,  ..., 0.3504, 0.2834, 0.0868],
         ...,
         [0.2603, 0.3804, 0.4135,  ..., 0.3281, 0.2463, 0.0799],
         [0.2627, 0.4113, 0.4421,  ..., 0.3560, 0.2892, 0.1354],
         [0.1902, 0.3297, 0.3475,  ..., 0.2474, 0.2135, 0.1074]]],
       device='cuda:0', grad_fn=<SelectBackward0>)
tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  .

In [20]:
for epoch in range(opt['epochs']):  # loop over the dataset multiple times
    
    train_metrics = initialize_metrics()
    val_metrics = initialize_metrics()

    """
    Begin Training
    """
    model.train()
    print('SET model mode to train!')
    batch_iter = 0
    #tbar = tqdm(train_loader)

    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [batch_img1, batch_img2, labels]
        batch_img1, batch_img2, labels = data
        
        # Set variables for training
        batch_img1 = batch_img1.float()
        batch_img2 = batch_img2.float()
        
        inputs = torch.cat([batch_img1, batch_img2], dim=1).to(dev)

        #labels = torch.unsqueeze(labels, dim=1)
        labels = labels.long().to(dev)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        # Get model predictions, calculate loss, backprop
        cd_preds = model(inputs)

        cd_loss = criterion(cd_preds, labels)
        loss = cd_loss
        loss.backward()
        optimizer.step()
        
        cd_preds = cd_preds[-1]
        _, cd_preds = torch.max(cd_preds, 1)
        

        # clear batch variables from memory
        del batch_img1, batch_img2, inputs, labels
        
        
        # print statistics
        print(f"[{epoch + 1}, {i + 1:5d}] loss: {cd_loss:.5f}")

    scheduler.step()
    logging.info("EPOCH {} TRAIN METRICS".format(epoch) + str(cd_loss))

SET model mode to train!
[1,     1] loss: 1.01762
[1,     2] loss: 0.86002
[1,     3] loss: 0.78050
[1,     4] loss: 0.85326
[1,     5] loss: -0.24153
[1,     6] loss: 0.40890
[1,     7] loss: -0.35814
[1,     8] loss: -1.73874
[1,     9] loss: 0.00880
[1,    10] loss: 1.95581
[1,    11] loss: -0.76622
[1,    12] loss: 1.29365
[1,    13] loss: 0.98158
[1,    14] loss: 1.27434
[1,    15] loss: 1.00376
[1,    16] loss: 1.14439
[1,    17] loss: 0.92594
[1,    18] loss: 1.04532
[1,    19] loss: 1.13380
[1,    20] loss: 1.00727
[1,    21] loss: 0.96440
[1,    22] loss: 0.94631
[1,    23] loss: 0.90711
[1,    24] loss: 0.87282
[1,    25] loss: 0.94549
[1,    26] loss: 0.96340
[1,    27] loss: 0.92988
[1,    28] loss: 0.92060
[1,    29] loss: 0.91616
[1,    30] loss: 0.89790
[1,    31] loss: 0.88083
[1,    32] loss: 0.88495
[1,    33] loss: 0.85797
[1,    34] loss: 0.90238
[1,    35] loss: 0.92786
[1,    36] loss: 0.81383
[1,    37] loss: 0.57348
[1,    38] loss: 0.59839
[1,    39] loss: 0.79

In [32]:
model = UNet(in_ch=6, out_ch=1, use_dropout=False).to(dev)

for epoch in range(opt['epochs']):  # loop over the dataset multiple times
    
    train_metrics = initialize_metrics()
    val_metrics = initialize_metrics()

    """
    Begin Training
    """
    model.train()
    logging.info('SET model mode to train!')
    batch_iter = 0
    #tbar = tqdm(train_loader)

    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [batch_img1, batch_img2, labels]
        batch_img1, batch_img2, labels = data
        
        # Set variables for training
        batch_img1 = batch_img1.float()
        batch_img2 = batch_img2.float()
        
        inputs = torch.cat([batch_img1, batch_img2], dim=1).to(dev)

        labels = torch.unsqueeze(labels, dim=1)
        labels = labels.long().to(dev)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        # Get model predictions, calculate loss, backprop
        cd_preds = model(inputs)

        cd_loss = criterion(cd_preds, labels)
        loss = cd_loss
        loss.backward()
        optimizer.step()
        
        cd_preds = cd_preds[-1]
        _, cd_preds = torch.max(cd_preds, 1)
        
        # Calculate and log other batch metrics
        cd_corrects = (100 *
                       (cd_preds.squeeze().byte() == labels.squeeze().byte()).sum() /
                       (labels.size()[0] * (opt['batch_size']**2)))
        
        cd_train_report = prfs(labels.data.cpu().numpy().flatten(),
                               cd_preds.data.cpu().numpy().flatten(),
                               average='binary',
                               zero_division=0,
                               pos_label=1)
        
        train_metrics = set_metrics(train_metrics,
                                    cd_loss,
                                    cd_corrects,
                                    cd_train_report,
                                    scheduler.get_last_lr())
        
        # log the batch mean metrics
        mean_train_metrics = get_mean_metrics(train_metrics)
        
        for k, v in mean_train_metrics.items():
            writer.add_scalars(str(k), {'train': v}, total_step)

        # clear batch variables from memory
        del batch_img1, batch_img2, labels
        
        
        # print statistics
        if i % 1000 == 0:    # print every 2000 mini-batches
            print(f"[{opt['epoch'] + 1}, {i + 1:5d}] loss: {mean_train_metrics['cd_losses']:.5f}")

    scheduler.step()
    logging.info("EPOCH {} TRAIN METRICS".format(epoch) + str(mean_train_metrics))
    
    break 

    """
    Begin Validation
    """
    model.eval()
    with torch.no_grad():
        for batch_img1, batch_img2, labels in val_loader:
            # Set variables for training
            batch_img1 = batch_img1.float()
            batch_img2 = batch_img2.float()
            
            inputs = torch.cat([batch_img1, batch_img2], dim=1).to(dev)
            labels = labels.long().to(dev)

            # Get predictions and calculate loss
            cd_preds = model(inputs)

            cd_loss = criterion(cd_preds, labels)

            cd_preds = cd_preds[-1]
            _, cd_preds = torch.max(cd_preds, 1)

            # Calculate and log other batch metrics
            cd_corrects = (100 *
                           (cd_preds.squeeze().byte() == labels.squeeze().byte()).sum() /
                           (labels.size()[0] * (opt.patch_size**2)))

            cd_val_report = prfs(labels.data.cpu().numpy().flatten(),
                                 cd_preds.data.cpu().numpy().flatten(),
                                 average='binary',
                                 zero_division=0,
                                 pos_label=1)

            val_metrics = set_metrics(val_metrics,
                                      cd_loss,
                                      cd_corrects,
                                      cd_val_report,
                                      scheduler.get_last_lr())

            # log the batch mean metrics
            mean_val_metrics = get_mean_metrics(val_metrics)

            for k, v in mean_train_metrics.items():
                writer.add_scalars(str(k), {'val': v}, total_step)

            # clear batch variables from memory
            del batch_img1, batch_img2, labels    

        
        """
        Store the weights of good epochs based on validation results
        """
        if ((mean_val_metrics['cd_precisions'] > best_metrics['cd_precisions'])
                or
                (mean_val_metrics['cd_recalls'] > best_metrics['cd_recalls'])
                or
                (mean_val_metrics['cd_f1scores'] > best_metrics['cd_f1scores'])):

            # Insert training and epoch information to metadata dictionary
            logging.info('updata the model')
            opt['validation_metrics'] = mean_val_metrics

            # Save model and log
            if not os.path.exists('./tmp'):
                os.mkdir('./tmp')
            with open('./tmp/metadata_epoch_' + str(epoch) + '.json', 'w') as fout:
                json.dump(opt, fout)

            torch.save(model, './tmp/checkpoint_epoch_'+str(epoch)+'.pt')

            # comet.log_asset(upload_metadata_file_path)
            best_metrics = mean_val_metrics   
            
            
        print('An epoch finished.')
        
writer.close()  # close tensor board
logging.info('Done!')

RuntimeError: ignored