In [1]:
#%cd /content/drive/My Drive/SegFormer
import argparse
import logging
import os
import random
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from utils import clip_gradient, AvgMeter

from glob import glob
from skimage.io import imread
import matplotlib.pyplot as plt
import pandas as pd
from collections import OrderedDict
from torch.autograd import Variable
from datetime import datetime
import torch.nn.functional as F
import cv2
from albumentations.augmentations import transforms
from albumentations.core.composition import Compose, OneOf
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import argparse
import copy
import os
import os.path as osp
import time

import mmcv
import torch
from mmcv.runner import init_dist
from mmcv.utils import Config, DictAction, get_git_hash

from mmseg import __version__
from mmseg.apis import set_random_seed, train_segmentor
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.utils import collect_env, get_root_logger

In [3]:
from mmseg.models.segmentors import CaraSegUPer_ver2 as UNet
from mmseg.models.segmentors import CaraSegUPer_wBiFPN as Net
from mmseg.models.segmentors import CaraSegUPer_wBiRAFPN as Net2

In [4]:
class Dataset(torch.utils.data.Dataset):

    def __init__(self, img_paths, mask_paths, aug=True, transform=None):
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.aug = aug
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        mask_path = self.mask_paths[idx]
        # image = imread(img_path)
        # mask = imread(mask_path)
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, 0)
        # name = self.img_paths[idx].split('/')[-1]

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
            image = cv2.resize(image, (384, 384))
            mask = cv2.resize(mask, (384, 384)) 
        else:
            image = cv2.resize(image, (384, 384))
            mask = cv2.resize(mask, (384, 384)) 

        image = image.astype('float32') / 255
        image = image.transpose((2, 0, 1))

        mask = mask[:,:,np.newaxis]
        mask = mask.astype('float32')
        mask = mask.transpose((2, 0, 1))

        return np.asarray(image), np.asarray(mask)

In [5]:
from keras import backend as K

def recall_m(y_true, y_pred):
  true_positives = torch.sum(torch.round(torch.clip(y_true * y_pred, 0, 1)))
  possible_positives = torch.sum(torch.round(torch.clip(y_true, 0, 1)))
  recall = true_positives / (possible_positives + K.epsilon())
  return recall

def precision_m(y_true, y_pred):
  true_positives = torch.sum(torch.round(torch.clip(y_true * y_pred, 0, 1)))
  predicted_positives = torch.sum(torch.round(torch.clip(y_pred, 0, 1)))
  precision = true_positives / (predicted_positives + K.epsilon())
  return precision

def dice_m(y_true, y_pred):
  precision = precision_m(y_true, y_pred)
  recall = recall_m(y_true, y_pred)
  return 2*((precision*recall)/(precision+recall+K.epsilon()))

def iou_m(y_true, y_pred):
  precision = precision_m(y_true, y_pred)
  recall = recall_m(y_true, y_pred)
  return recall*precision/(recall+precision-recall*precision +K.epsilon())

In [6]:
class FocalLossV1(nn.Module):

    def __init__(self,
                 alpha=0.25,
                 gamma=2,
                 reduction='mean',):
        super(FocalLossV1, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.crit = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, logits, label):
        # compute loss
        logits = logits.float() # use fp32 if logits is fp16
        with torch.no_grad():
            alpha = torch.empty_like(logits).fill_(1 - self.alpha)
            alpha[label == 1] = self.alpha

        probs = torch.sigmoid(logits)
        pt = torch.where(label == 1, probs, 1 - probs)
        ce_loss = self.crit(logits, label.float())
        loss = (alpha * torch.pow(1 - pt, self.gamma) * ce_loss)
        if self.reduction == 'mean':
            loss = loss.mean()
        if self.reduction == 'sum':
            loss = loss.sum()
        return loss
class TverskyLoss(nn.Module):
    def __init__(self, alpha=0.7, eps=1):
        super().__init__()

        self.eps = eps
        self.alpha = alpha

    def forward(self, pred, label):
        """
        Forward function
        :param pred: Prediction tensor containing raw network outputs (no logit) (B x C x H x W)
        :param label: Label mask tensor (B x C x H x W)
        """
        probs = torch.sigmoid(pred)
        true_pos = torch.sum(probs * label, dim=[0, 2, 3])
        false_neg = torch.sum(label * (1 - probs), dim=[0, 2, 3])
        false_pos = torch.sum(probs * (1 - label), dim=[0, 2, 3])
        return 1 - torch.mean(
            (true_pos + self.eps)
            / (
                true_pos
                + self.alpha * false_neg
                + (1 - self.alpha) * false_pos
                + self.eps
            )
        )

class FocalTverskyLoss(TverskyLoss):
    def __init__(self, gamma=4 / 3, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma

    def forward(self, pred, label):
        probs = torch.sigmoid(pred)
        true_pos = torch.sum(probs * label, dim=[0, 2, 3])
        false_neg = torch.sum(label * (1 - probs), dim=[0, 2, 3])
        false_pos = torch.sum(probs * (1 - label), dim=[0, 2, 3])

        t = (true_pos + self.eps) / (
            true_pos + self.alpha * false_neg + (1 - self.alpha) * false_pos + self.eps
        )

        x = torch.pow(1 - t, 1 / self.gamma)

        return x #torch.sum(x)
def structure_loss_v2(pred, mask):
    weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    wfocal = FocalLossV1()(pred, mask)
    wfocal = (wfocal*weit).sum(dim=(2,3)) / weit.sum(dim=(2, 3))

    
    wiou = FocalTverskyLoss()(pred,mask)
    return (wfocal + wiou).mean()
def structure_loss(pred, mask):
    weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    wfocal = FocalLossV1()(pred, mask)
    wfocal = (wfocal*weit).sum(dim=(2,3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask)*weit).sum(dim=(2, 3))
    union = ((pred + mask)*weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1)/(union - inter+1)
    return (wfocal + wiou).mean()

In [7]:
def train(train_loader, model, optimizer, epoch, lr_scheduler, deep=False):
    model.train()
    # ---- multi-scale training ----
    size_rates = [256/384, 1, 512/384]
    loss_record = AvgMeter()
    dice, iou = AvgMeter(), AvgMeter()
    for i, pack in enumerate(train_loader, start=1):
        if epoch <= 1:
                optimizer.param_groups[0]["lr"] = (epoch * i) / (1.0 * total_step) * init_lr
        else:
            lr_scheduler.step()

        for rate in size_rates: 
            optimizer.zero_grad()
            # ---- data prepare ----
            images, gts = pack
            images = Variable(images).cuda()
            gts = Variable(gts).cuda()
            # ---- rescale ----
            trainsize = int(round(trainsize_init*rate/32)*32)
            images = F.upsample(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            gts = F.upsample(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            # ---- forward ----
            map4, map3, map2, map1 = model(images)
           # print(map4.shape, map3.shape, map2.shape, map1.shape)
            map1 = F.upsample(map1, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            map2 = F.upsample(map2, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            map3 = F.upsample(map3, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            map4 = F.upsample(map4, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            loss = structure_loss(map1, gts) + structure_loss(map2, gts) + structure_loss(map3, gts) + structure_loss(map4, gts)
            # ---- metrics ----
            dice_score = dice_m(map2, gts)
            iou_score = iou_m(map2, gts)
            # ---- backward ----
            loss.backward()
            clip_gradient(optimizer, clip)
            optimizer.step()
            # ---- recording loss ----
            if rate == 1:
                loss_record.update(loss.data, batchsize)
                dice.update(dice_score.data, batchsize)
                iou.update(iou_score.data, batchsize)

        # ---- train visualization ----
        if i == total_step:
            print('{} Training Epoch [{:03d}/{:03d}], '
                  '[loss: {:0.4f}, dice: {:0.4f}, iou: {:0.4f}]'.
                  format(datetime.now(), epoch, num_epochs,\
                         loss_record.show(), dice.show(), iou.show()))

    ckpt_path = save_path + 'last.pth'
    print('[Saving Checkpoint:]', ckpt_path)
    checkpoint = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': lr_scheduler.state_dict()
    }
    torch.save(checkpoint, ckpt_path)

    log = OrderedDict([
        ('loss', loss_record.show()), ('dice', dice.show()), ('iou', iou.show()),
    ])

    return log

In [8]:
def recall_np(y_true, y_pred):
    true_positives = np.sum(np.round(np.clip(y_true * y_pred, 0, 1)))
    possible_positives = np.sum(np.round(np.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

def precision_np(y_true, y_pred):
    true_positives = np.sum(np.round(np.clip(y_true * y_pred, 0, 1)))
    predicted_positives = np.sum(np.round(np.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision

def dice_np(y_true, y_pred):
    precision = precision_np(y_true, y_pred)
    recall = recall_np(y_true, y_pred)
    return 2*((precision*recall)/(precision+recall+K.epsilon()))

def iou_np(y_true, y_pred):
    intersection = np.sum(np.round(np.clip(y_true * y_pred, 0, 1)))
    union = np.sum(y_true)+np.sum(y_pred)-intersection
    return intersection/(union+K.epsilon())

def get_scores(gts, prs):
    mean_precision = 0
    mean_recall = 0
    mean_iou = 0
    mean_dice = 0
    for gt, pr in zip(gts, prs):
        mean_precision += precision_np(gt, pr)
        mean_recall += recall_np(gt, pr)
        mean_iou += iou_np(gt, pr)
        mean_dice += dice_np(gt, pr)

    mean_precision /= len(gts)
    mean_recall /= len(gts)
    mean_iou /= len(gts)
    mean_dice /= len(gts)        
    
    print(f"scores: dice={mean_dice}, miou={mean_iou}, precision={mean_precision}, recall={mean_recall}")

    return (mean_iou, mean_dice, mean_precision, mean_recall)



def inference(model,writer=None,epoch=None,test_dataset=None,dataset_name="test"):
    print("#"*20)
    model.eval()
    if True:
#     dataset_names = ['Kvasir', 'CVC-ClinicDB', 'CVC-ColonDB', 'CVC-300', 'ETIS-LaribPolypDB']
#     for dataset_name in dataset_names:
#         data_path = f'../dataset/scenario_4/all_datasets//TestDataset/{dataset_name}'
#         print(data_path)
#         X_test = glob('{}/images/*'.format(data_path))
#         X_test.sort()
#         y_test = glob('{}/masks/*'.format(data_path))
#         y_test.sort()

#         test_dataset = Dataset(X_test, y_test)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            pin_memory=True,
            drop_last=False)

        print('Dataset_name:', dataset_name)
        tp_all = 0
        fp_all = 0
        fn_all = 0
        mean_iou = 0
        gts = []
        prs = []
        losses = []
        for i, pack in enumerate(test_loader, start=1):
            image, gt_ = pack
            # name = name[0]
            gt = gt_[0][0]
            gt = np.asarray(gt, np.float32)
            image = image.cuda()

            res, res2, res3, res4 = model(image)
            res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False)
            loss = structure_loss_v2(res.cpu(), gt_)
            losses.append(loss.cpu().detach().numpy().squeeze())
            res = res.sigmoid().data.cpu().numpy().squeeze()
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)
            pr = res.round()
            # cv2.imwrite(os.path.join(save_path, dataset_name, name), res)
            gts.append(gt)
            prs.append(pr)
        mean_iou, mean_dice, mean_precision, mean_recall = get_scores(gts, prs)
        writer.add_scalar(dataset_name + ' val loss',
                            np.mean(losses),
                            epoch )
        writer.add_scalar(dataset_name + ' val dice',
                            mean_dice,
                            epoch )
        writer.add_scalar(dataset_name + ' val iou',
                            mean_iou,
                            epoch )
    print("#"*20)

In [9]:
#from torchvision.transforms.functional import *
#import torchvision.transforms as transforms
from albumentations.augmentations.geometric import  resize,rotate
import albumentations.augmentations.crops.transforms as crop
import albumentations.augmentations.transforms as transforms
train_transform = Compose([
            rotate.RandomRotate90(),
            transforms.Flip(),
            transforms.HueSaturationValue(),
            transforms.RandomBrightnessContrast(),
            transforms.GaussianBlur(),
            transforms.Transpose(),
            OneOf([
                crop.RandomCrop(224, 224, p=1),
                crop.CenterCrop(224, 224, p=1)
            ], p=0.2),
            resize.Resize(384, 384)
        ], p=0.5)

In [10]:
scenario = 6
import random
if scenario == 1:
    train_img_paths = glob('../dataset/CVC_Colon/images/*')+glob('../dataset/ETIS/images/*')
    train_mask_paths = glob('../dataset/CVC_Colon/masks/*')+glob('../dataset/ETIS/masks/*')
    train_img_paths.sort()
    train_mask_paths.sort()
    print(len(train_img_paths),len(train_mask_paths))
    logswriter = 'logsBiFPN_scenario1_augment/'
    trainsave = 'Scenario1-augment-BiFPN'
    X_test = glob('../dataset/CVC_Clinic/images/*')
    y_test = glob('../dataset/CVC_Clinic/masks/*')
    X_test.sort()
    y_test.sort()
    print(len(X_test),len(y_test))
    n=random.randint(0,500)
    print(train_img_paths[n],train_mask_paths[n])
    print(X_test[n],y_test[n])
    test_dataset = Dataset(X_test, y_test)
    #bifpn và RA deu 30 epoch
elif scenario ==2:
    train_img_paths = glob('../dataset/CVC_Colon/images/*')#+glob('../dataset/ETIS/images/*')
    train_mask_paths = glob('../dataset/CVC_Colon/masks/*')#+glob('../dataset/ETIS/masks/*')
    train_img_paths.sort()
    train_mask_paths.sort()
    print(len(train_img_paths),len(train_mask_paths))
    logswriter = 'logsRA_scenario2_augment/'
    trainsave = 'Scenario2-augment-RA'
    X_test = glob('../dataset/CVC_Clinic/images/*')
    y_test = glob('../dataset/CVC_Clinic/masks/*')
    X_test.sort()
    y_test.sort()
    print(len(X_test),len(y_test))
    n=random.randint(0,500)
    print(train_img_paths[n],train_mask_paths[n])
    print(X_test[n],y_test[n])
    test_dataset = Dataset(X_test, y_test)
    #biFPN 30 epoch
elif scenario ==3:
    train_img_paths = glob('../dataset/CVC_Clinic/images/*')#glob('../dataset/CVC_Colon/images/*')#+
    train_mask_paths = glob('../dataset/CVC_Clinic/masks/*')#glob('../dataset/CVC_Colon/masks/*')#+glob('../dataset/ETIS/masks/*')
    train_img_paths.sort()
    train_mask_paths.sort()
    print(len(train_img_paths),len(train_mask_paths))
    logswriter = 'logsBiFPN_scenario3_augment/'
    trainsave = 'Scenario3-augment-BiFPN'
    X_test = glob('../dataset/ETIS/images/*')
    y_test = glob('../dataset/ETIS/masks/*')
    X_test.sort()
    y_test.sort()
    print(len(X_test),len(y_test))
    n=random.randint(0,100)
    print(train_img_paths[n],train_mask_paths[n])
    print(X_test[n],y_test[n])
    test_dataset = Dataset(X_test, y_test)
    #biFPN 30 epoch
elif scenario ==5:
    fold = 4
    li = [0,1,2,3,4]
    train_img_paths = []
    train_mask_paths =[]
    for i in li:
        if fold == i: continue
        train_img_paths += glob(f"../dataset/Clinic_fold_new/fold_{i}/images/*")
        train_mask_paths += glob(f"../dataset/Clinic_fold_new/fold_{i}/masks/*")
    X_test = glob(f"../dataset/Clinic_fold_new/fold_{fold}/images/*") 
    y_test = glob(f"../dataset/Clinic_fold_new/fold_{fold}/masks/*")
    print(len(train_img_paths),len(train_mask_paths))
    logswriter = f"logsRA_scenario5_augment/fold{fold}_"
    trainsave = f"Scenario5-augment-RA-fold{fold}"
    X_test.sort()
    y_test.sort()
    print(len(X_test),len(y_test))
    n=random.randint(0,100)
    print(train_img_paths[n],train_mask_paths[n])
    print(X_test[n],y_test[n])
    test_dataset = Dataset(X_test, y_test)
elif scenario ==6:
    fold = 4
    li = [0,1,2,3,4]
    train_img_paths = []
    train_mask_paths =[]
    for i in li:
        if fold == i: continue
        train_img_paths += glob(f"../dataset/Kvasir_fold/fold_{i}/images/*")
        train_mask_paths += glob(f"../dataset/Kvasir_fold/fold_{i}/masks/*")
    X_test = glob(f"../dataset/Kvasir_fold/fold_{fold}/images/*") 
    y_test = glob(f"../dataset/Kvasir_fold/fold_{fold}/masks/*")
    print(len(train_img_paths),len(train_mask_paths))
    logswriter = f"logsRA_scenario6_augment/fold{fold}_"
    trainsave = f"Scenario6-augment-RA-fold{fold}"
    X_test.sort()
    y_test.sort()
    print(len(X_test),len(y_test))
    n=random.randint(0,100)
    print(train_img_paths[n],train_mask_paths[n])
    print(X_test[n],y_test[n])
    test_dataset = Dataset(X_test, y_test)

800 800
200 200
../dataset/Kvasir_fold/fold_0/images/cju87xn2snfmv0987sc3d9xnq.png ../dataset/Kvasir_fold/fold_0/masks/cju87xn2snfmv0987sc3d9xnq.png
../dataset/Kvasir_fold/fold_4/images/cju3y9difj6th0801kd1rqm3w.png ../dataset/Kvasir_fold/fold_4/masks/cju3y9difj6th0801kd1rqm3w.png


In [11]:
scenario = 6
import random
if scenario == 1:
    train_img_paths = glob('../dataset/CVC_Colon/images/*')+glob('../dataset/ETIS/images/*')
    train_mask_paths = glob('../dataset/CVC_Colon/masks/*')+glob('../dataset/ETIS/masks/*')
    train_img_paths.sort()
    train_mask_paths.sort()
    print(len(train_img_paths),len(train_mask_paths))
    logswriter = 'logsRA_scenario1_augment_bottleneck/'
    trainsave = 'Scenario1-augment-RA-bottleneck'
    X_test = glob('../dataset/CVC_Clinic/images/*')
    y_test = glob('../dataset/CVC_Clinic/masks/*')
    X_test.sort()
    y_test.sort()
    print(len(X_test),len(y_test))
    n=random.randint(0,500)
    print(train_img_paths[n],train_mask_paths[n])
    print(X_test[n],y_test[n])
    test_dataset = Dataset(X_test, y_test)
    #bifpn và RA deu 30 epoch
elif scenario ==2:
    train_img_paths = glob('../dataset/CVC_Colon/images/*')#+glob('../dataset/ETIS/images/*')
    train_mask_paths = glob('../dataset/CVC_Colon/masks/*')#+glob('../dataset/ETIS/masks/*')
    train_img_paths.sort()
    train_mask_paths.sort()
    print(len(train_img_paths),len(train_mask_paths))
    logswriter = 'logsRA_scenario2_augment_bottleneck/'
    trainsave = 'Scenario2-augment-RA-bottleneck'
    X_test = glob('../dataset/CVC_Clinic/images/*')
    y_test = glob('../dataset/CVC_Clinic/masks/*')
    X_test.sort()
    y_test.sort()
    print(len(X_test),len(y_test))
    n=random.randint(0,100)
    print(train_img_paths[n],train_mask_paths[n])
    print(X_test[n],y_test[n])
    test_dataset = Dataset(X_test, y_test)
    #biFPN 30 epoch
elif scenario ==3:
    train_img_paths = glob('../dataset/CVC_Clinic/images/*')#glob('../dataset/CVC_Colon/images/*')#+
    train_mask_paths = glob('../dataset/CVC_Clinic/masks/*')#glob('../dataset/CVC_Colon/masks/*')#+glob('../dataset/ETIS/masks/*')
    train_img_paths.sort()
    train_mask_paths.sort()
    print(len(train_img_paths),len(train_mask_paths))
    logswriter = 'logsRA_scenario3_augment_bottleneck/'
    trainsave = 'Scenario3-augment-RA-bottleneck'
    X_test = glob('../dataset/ETIS/images/*')
    y_test = glob('../dataset/ETIS/masks/*')
    X_test.sort()
    y_test.sort()
    print(len(X_test),len(y_test))
    n=random.randint(0,100)
    print(train_img_paths[n],train_mask_paths[n])
    print(X_test[n],y_test[n])
    test_dataset = Dataset(X_test, y_test)
    #biFPN 30 epoch
elif scenario ==5:
    fold = 1
    li = [0,1,2,3,4]
    train_img_paths = []
    train_mask_paths =[]
    for i in li:
        if fold == i: continue
        train_img_paths += glob(f"../dataset/Clinic_fold_new/fold_{i}/images/*")
        train_mask_paths += glob(f"../dataset/Clinic_fold_new/fold_{i}/masks/*")
    X_test = glob(f"../dataset/Clinic_fold_new/fold_{fold}/images/*") 
    y_test = glob(f"../dataset/Clinic_fold_new/fold_{fold}/masks/*")
    print(len(train_img_paths),len(train_mask_paths))
    logswriter = f"logsRA_scenario5_augment_bottleneck/fold{fold}_"
    trainsave = f"Scenario5-augment-RA-bottleneck-fold{fold}"
    X_test.sort()
    y_test.sort()
    print(len(X_test),len(y_test))
    n=random.randint(0,100)
    print(train_img_paths[n],train_mask_paths[n])
    print(X_test[n],y_test[n])
    test_dataset = Dataset(X_test, y_test)
elif scenario ==6:
    fold = 4
    li = [0,1,2,3,4]
    train_img_paths = []
    train_mask_paths =[]
    for i in li:
        if fold == i: continue
        train_img_paths += glob(f"../dataset/Kvasir_fold/fold_{i}/images/*")
        train_mask_paths += glob(f"../dataset/Kvasir_fold/fold_{i}/masks/*")
    X_test = glob(f"../dataset/Kvasir_fold/fold_{fold}/images/*") 
    y_test = glob(f"../dataset/Kvasir_fold/fold_{fold}/masks/*")
    print(len(train_img_paths),len(train_mask_paths))
    logswriter = f"logsRA_scenario6_augment_bottleneck/fold{fold}_"
    trainsave = f"Scenario6-augment-RA-bottleneck-fold{fold}"
    X_test.sort()
    y_test.sort()
    print(len(X_test),len(y_test))
    n=random.randint(0,100)
    print(train_img_paths[n],train_mask_paths[n])
    print(X_test[n],y_test[n])
    test_dataset = Dataset(X_test, y_test)

800 800
200 200
../dataset/Kvasir_fold/fold_0/images/cju40jl7skiuo0817p0smlgg8.png ../dataset/Kvasir_fold/fold_0/masks/cju40jl7skiuo0817p0smlgg8.png
../dataset/Kvasir_fold/fold_4/images/cju15jr8jz8sb0855ukmkswkz.png ../dataset/Kvasir_fold/fold_4/masks/cju15jr8jz8sb0855ukmkswkz.png


In [12]:
from torch.utils.tensorboard import SummaryWriter

# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter(logswriter+"lan1")

init_lr = 1e-4
batchsize = 8
trainsize_init = 384
clip = 0.5
num_epochs= 30
train_save = trainsave+"_lan1"

save_path = './snapshots/{}/'.format(train_save)
if not os.path.exists(save_path):
    os.makedirs(save_path, exist_ok=True)
else:
    print("Save path existed")
#     sys.exit(1)


log = pd.DataFrame(index=[], columns=[
    'epoch', 'lr', 'loss', 'dice', 'iou', 'val_loss', 'val_dice', 'val_iou'
])
# train_img_paths = []
# train_mask_paths = []

# train_img_paths.sort()
# train_mask_paths.sort()

train_dataset = Dataset(train_img_paths, train_mask_paths,transform =train_transform)#train_transform
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batchsize,
    shuffle=True,
    pin_memory=True,
    drop_last=True
)

total_step = len(train_loader)
model = Net2( backbone=dict(
                  type='mit_b3',
                  style='pytorch'), 
              decode_head=dict(
                  type='UPerHead',
                  in_channels=[64, 128, 320, 448],
                  in_index=[0, 1, 2, 3],
                  channels=128,
                  dropout_ratio=0.1,
                  num_classes=1,
                  norm_cfg=dict(type='BN', requires_grad=True),
                   compound_coef=4,
                  align_corners=False,
                  decoder_params=dict(embed_dim=768),
                  loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
              neck=None,
              auxiliary_head=None,
              train_cfg=dict(),
              test_cfg=dict(mode='whole'),
              pretrained='pretrained/mit_b3.pth').cuda()
model.load_state_dict(torch.load(save_path+"last.pth")['state_dict'])
inference(model,None,0,test_dataset)
# ---- flops and params ----
params = model.parameters()
optimizer = torch.optim.Adam(params, init_lr)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 
                                    T_max=len(train_loader)*num_epochs,
                                    eta_min=init_lr/1000)



start_epoch = 1

ckpt_path = ''
if ckpt_path != '':
    log = pd.read_csv(ckpt_path.replace('last.pth', 'log.csv'))
    checkpoint = torch.load(ckpt_path)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    lr_scheduler.load_state_dict(checkpoint['scheduler'])
    optimizer.load_state_dict(checkpoint['optimizer'])

print("#"*20, f"Start Training", "#"*20)
for epoch in range(start_epoch, num_epochs+1):
    #inference(model,writer,epoch)
    train_log = train(train_loader, model, optimizer, epoch, lr_scheduler)

    log_tmp = pd.Series([epoch, optimizer.param_groups[0]["lr"], 
            train_log['loss'].item(), train_log['dice'].item(), train_log['iou'].item(),  
    ], index=['epoch', 'lr', 'loss', 'dice', 'iou'])
    log = log.append(log_tmp, ignore_index=True)
    log.to_csv(f'./snapshots/{train_save}/log.csv', index=False)
    writer.add_scalar('training loss',
                            train_log['loss'].item(),
                            epoch )
    writer.add_scalar('training dice',
                            train_log['dice'].item(),
                            epoch )
    writer.add_scalar('training iou',
                            train_log['iou'].item(),
                            epoch )
    
    if epoch >= num_epochs-15:
        inference(model,writer,epoch,test_dataset)

Save path existed


2022-07-09 15:04:10,949 - mmseg - INFO - Use load_from_local loader

unexpected key in source state_dict: head.weight, head.bias



RuntimeError: Error(s) in loading state_dict for CaraSegUPer_wBiRAFPN:
	Unexpected key(s) in state_dict: "bifpn.5.p6_w1", "bifpn.5.p5_w1", "bifpn.5.p4_w1", "bifpn.5.p3_w1", "bifpn.5.p4_w2", "bifpn.5.p5_w2", "bifpn.5.p6_w2", "bifpn.5.p7_w2", "bifpn.5.conv6_up.depthwise_conv.conv.weight", "bifpn.5.conv6_up.pointwise_conv.conv.weight", "bifpn.5.conv6_up.pointwise_conv.conv.bias", "bifpn.5.conv6_up.bn.weight", "bifpn.5.conv6_up.bn.bias", "bifpn.5.conv6_up.bn.running_mean", "bifpn.5.conv6_up.bn.running_var", "bifpn.5.conv6_up.bn.num_batches_tracked", "bifpn.5.conv5_up.depthwise_conv.conv.weight", "bifpn.5.conv5_up.pointwise_conv.conv.weight", "bifpn.5.conv5_up.pointwise_conv.conv.bias", "bifpn.5.conv5_up.bn.weight", "bifpn.5.conv5_up.bn.bias", "bifpn.5.conv5_up.bn.running_mean", "bifpn.5.conv5_up.bn.running_var", "bifpn.5.conv5_up.bn.num_batches_tracked", "bifpn.5.conv4_up.depthwise_conv.conv.weight", "bifpn.5.conv4_up.pointwise_conv.conv.weight", "bifpn.5.conv4_up.pointwise_conv.conv.bias", "bifpn.5.conv4_up.bn.weight", "bifpn.5.conv4_up.bn.bias", "bifpn.5.conv4_up.bn.running_mean", "bifpn.5.conv4_up.bn.running_var", "bifpn.5.conv4_up.bn.num_batches_tracked", "bifpn.5.conv3_up.depthwise_conv.conv.weight", "bifpn.5.conv3_up.pointwise_conv.conv.weight", "bifpn.5.conv3_up.pointwise_conv.conv.bias", "bifpn.5.conv3_up.bn.weight", "bifpn.5.conv3_up.bn.bias", "bifpn.5.conv3_up.bn.running_mean", "bifpn.5.conv3_up.bn.running_var", "bifpn.5.conv3_up.bn.num_batches_tracked", "bifpn.5.conv4_down.depthwise_conv.conv.weight", "bifpn.5.conv4_down.pointwise_conv.conv.weight", "bifpn.5.conv4_down.pointwise_conv.conv.bias", "bifpn.5.conv4_down.bn.weight", "bifpn.5.conv4_down.bn.bias", "bifpn.5.conv4_down.bn.running_mean", "bifpn.5.conv4_down.bn.running_var", "bifpn.5.conv4_down.bn.num_batches_tracked", "bifpn.5.conv5_down.depthwise_conv.conv.weight", "bifpn.5.conv5_down.pointwise_conv.conv.weight", "bifpn.5.conv5_down.pointwise_conv.conv.bias", "bifpn.5.conv5_down.bn.weight", "bifpn.5.conv5_down.bn.bias", "bifpn.5.conv5_down.bn.running_mean", "bifpn.5.conv5_down.bn.running_var", "bifpn.5.conv5_down.bn.num_batches_tracked", "bifpn.5.conv6_down.depthwise_conv.conv.weight", "bifpn.5.conv6_down.pointwise_conv.conv.weight", "bifpn.5.conv6_down.pointwise_conv.conv.bias", "bifpn.5.conv6_down.bn.weight", "bifpn.5.conv6_down.bn.bias", "bifpn.5.conv6_down.bn.running_mean", "bifpn.5.conv6_down.bn.running_var", "bifpn.5.conv6_down.bn.num_batches_tracked", "bifpn.5.conv7_down.depthwise_conv.conv.weight", "bifpn.5.conv7_down.pointwise_conv.conv.weight", "bifpn.5.conv7_down.pointwise_conv.conv.bias", "bifpn.5.conv7_down.bn.weight", "bifpn.5.conv7_down.bn.bias", "bifpn.5.conv7_down.bn.running_mean", "bifpn.5.conv7_down.bn.running_var", "bifpn.5.conv7_down.bn.num_batches_tracked", "bifpn.5.RA_p7_6.ra_conv1.conv.weight", "bifpn.5.RA_p7_6.ra_conv1.bn.weight", "bifpn.5.RA_p7_6.ra_conv1.bn.bias", "bifpn.5.RA_p7_6.ra_conv1.bn.running_mean", "bifpn.5.RA_p7_6.ra_conv1.bn.running_var", "bifpn.5.RA_p7_6.ra_conv1.bn.num_batches_tracked", "bifpn.5.RA_p7_6.ra_conv2.conv.weight", "bifpn.5.RA_p7_6.ra_conv2.bn.weight", "bifpn.5.RA_p7_6.ra_conv2.bn.bias", "bifpn.5.RA_p7_6.ra_conv2.bn.running_mean", "bifpn.5.RA_p7_6.ra_conv2.bn.running_var", "bifpn.5.RA_p7_6.ra_conv2.bn.num_batches_tracked", "bifpn.5.RA_p7_6.latenconv.conv.weight", "bifpn.5.RA_p7_6.latenconv.bn.weight", "bifpn.5.RA_p7_6.latenconv.bn.bias", "bifpn.5.RA_p7_6.latenconv.bn.running_mean", "bifpn.5.RA_p7_6.latenconv.bn.running_var", "bifpn.5.RA_p7_6.latenconv.bn.num_batches_tracked", "bifpn.5.RA_p7_6.ra_conv3.conv.weight", "bifpn.5.RA_p7_6.ra_conv3.bn.weight", "bifpn.5.RA_p7_6.ra_conv3.bn.bias", "bifpn.5.RA_p7_6.ra_conv3.bn.running_mean", "bifpn.5.RA_p7_6.ra_conv3.bn.running_var", "bifpn.5.RA_p7_6.ra_conv3.bn.num_batches_tracked", "bifpn.5.RA_p6_5.ra_conv1.conv.weight", "bifpn.5.RA_p6_5.ra_conv1.bn.weight", "bifpn.5.RA_p6_5.ra_conv1.bn.bias", "bifpn.5.RA_p6_5.ra_conv1.bn.running_mean", "bifpn.5.RA_p6_5.ra_conv1.bn.running_var", "bifpn.5.RA_p6_5.ra_conv1.bn.num_batches_tracked", "bifpn.5.RA_p6_5.ra_conv2.conv.weight", "bifpn.5.RA_p6_5.ra_conv2.bn.weight", "bifpn.5.RA_p6_5.ra_conv2.bn.bias", "bifpn.5.RA_p6_5.ra_conv2.bn.running_mean", "bifpn.5.RA_p6_5.ra_conv2.bn.running_var", "bifpn.5.RA_p6_5.ra_conv2.bn.num_batches_tracked", "bifpn.5.RA_p6_5.latenconv.conv.weight", "bifpn.5.RA_p6_5.latenconv.bn.weight", "bifpn.5.RA_p6_5.latenconv.bn.bias", "bifpn.5.RA_p6_5.latenconv.bn.running_mean", "bifpn.5.RA_p6_5.latenconv.bn.running_var", "bifpn.5.RA_p6_5.latenconv.bn.num_batches_tracked", "bifpn.5.RA_p6_5.ra_conv3.conv.weight", "bifpn.5.RA_p6_5.ra_conv3.bn.weight", "bifpn.5.RA_p6_5.ra_conv3.bn.bias", "bifpn.5.RA_p6_5.ra_conv3.bn.running_mean", "bifpn.5.RA_p6_5.ra_conv3.bn.running_var", "bifpn.5.RA_p6_5.ra_conv3.bn.num_batches_tracked", "bifpn.5.RA_p5_4.ra_conv1.conv.weight", "bifpn.5.RA_p5_4.ra_conv1.bn.weight", "bifpn.5.RA_p5_4.ra_conv1.bn.bias", "bifpn.5.RA_p5_4.ra_conv1.bn.running_mean", "bifpn.5.RA_p5_4.ra_conv1.bn.running_var", "bifpn.5.RA_p5_4.ra_conv1.bn.num_batches_tracked", "bifpn.5.RA_p5_4.ra_conv2.conv.weight", "bifpn.5.RA_p5_4.ra_conv2.bn.weight", "bifpn.5.RA_p5_4.ra_conv2.bn.bias", "bifpn.5.RA_p5_4.ra_conv2.bn.running_mean", "bifpn.5.RA_p5_4.ra_conv2.bn.running_var", "bifpn.5.RA_p5_4.ra_conv2.bn.num_batches_tracked", "bifpn.5.RA_p5_4.latenconv.conv.weight", "bifpn.5.RA_p5_4.latenconv.bn.weight", "bifpn.5.RA_p5_4.latenconv.bn.bias", "bifpn.5.RA_p5_4.latenconv.bn.running_mean", "bifpn.5.RA_p5_4.latenconv.bn.running_var", "bifpn.5.RA_p5_4.latenconv.bn.num_batches_tracked", "bifpn.5.RA_p5_4.ra_conv3.conv.weight", "bifpn.5.RA_p5_4.ra_conv3.bn.weight", "bifpn.5.RA_p5_4.ra_conv3.bn.bias", "bifpn.5.RA_p5_4.ra_conv3.bn.running_mean", "bifpn.5.RA_p5_4.ra_conv3.bn.running_var", "bifpn.5.RA_p5_4.ra_conv3.bn.num_batches_tracked", "bifpn.5.RA_p4_3.ra_conv1.conv.weight", "bifpn.5.RA_p4_3.ra_conv1.bn.weight", "bifpn.5.RA_p4_3.ra_conv1.bn.bias", "bifpn.5.RA_p4_3.ra_conv1.bn.running_mean", "bifpn.5.RA_p4_3.ra_conv1.bn.running_var", "bifpn.5.RA_p4_3.ra_conv1.bn.num_batches_tracked", "bifpn.5.RA_p4_3.ra_conv2.conv.weight", "bifpn.5.RA_p4_3.ra_conv2.bn.weight", "bifpn.5.RA_p4_3.ra_conv2.bn.bias", "bifpn.5.RA_p4_3.ra_conv2.bn.running_mean", "bifpn.5.RA_p4_3.ra_conv2.bn.running_var", "bifpn.5.RA_p4_3.ra_conv2.bn.num_batches_tracked", "bifpn.5.RA_p4_3.latenconv.conv.weight", "bifpn.5.RA_p4_3.latenconv.bn.weight", "bifpn.5.RA_p4_3.latenconv.bn.bias", "bifpn.5.RA_p4_3.latenconv.bn.running_mean", "bifpn.5.RA_p4_3.latenconv.bn.running_var", "bifpn.5.RA_p4_3.latenconv.bn.num_batches_tracked", "bifpn.5.RA_p4_3.ra_conv3.conv.weight", "bifpn.5.RA_p4_3.ra_conv3.bn.weight", "bifpn.5.RA_p4_3.ra_conv3.bn.bias", "bifpn.5.RA_p4_3.ra_conv3.bn.running_mean", "bifpn.5.RA_p4_3.ra_conv3.bn.running_var", "bifpn.5.RA_p4_3.ra_conv3.bn.num_batches_tracked", "bifpn.6.p6_w1", "bifpn.6.p5_w1", "bifpn.6.p4_w1", "bifpn.6.p3_w1", "bifpn.6.p4_w2", "bifpn.6.p5_w2", "bifpn.6.p6_w2", "bifpn.6.p7_w2", "bifpn.6.conv6_up.depthwise_conv.conv.weight", "bifpn.6.conv6_up.pointwise_conv.conv.weight", "bifpn.6.conv6_up.pointwise_conv.conv.bias", "bifpn.6.conv6_up.bn.weight", "bifpn.6.conv6_up.bn.bias", "bifpn.6.conv6_up.bn.running_mean", "bifpn.6.conv6_up.bn.running_var", "bifpn.6.conv6_up.bn.num_batches_tracked", "bifpn.6.conv5_up.depthwise_conv.conv.weight", "bifpn.6.conv5_up.pointwise_conv.conv.weight", "bifpn.6.conv5_up.pointwise_conv.conv.bias", "bifpn.6.conv5_up.bn.weight", "bifpn.6.conv5_up.bn.bias", "bifpn.6.conv5_up.bn.running_mean", "bifpn.6.conv5_up.bn.running_var", "bifpn.6.conv5_up.bn.num_batches_tracked", "bifpn.6.conv4_up.depthwise_conv.conv.weight", "bifpn.6.conv4_up.pointwise_conv.conv.weight", "bifpn.6.conv4_up.pointwise_conv.conv.bias", "bifpn.6.conv4_up.bn.weight", "bifpn.6.conv4_up.bn.bias", "bifpn.6.conv4_up.bn.running_mean", "bifpn.6.conv4_up.bn.running_var", "bifpn.6.conv4_up.bn.num_batches_tracked", "bifpn.6.conv3_up.depthwise_conv.conv.weight", "bifpn.6.conv3_up.pointwise_conv.conv.weight", "bifpn.6.conv3_up.pointwise_conv.conv.bias", "bifpn.6.conv3_up.bn.weight", "bifpn.6.conv3_up.bn.bias", "bifpn.6.conv3_up.bn.running_mean", "bifpn.6.conv3_up.bn.running_var", "bifpn.6.conv3_up.bn.num_batches_tracked", "bifpn.6.conv4_down.depthwise_conv.conv.weight", "bifpn.6.conv4_down.pointwise_conv.conv.weight", "bifpn.6.conv4_down.pointwise_conv.conv.bias", "bifpn.6.conv4_down.bn.weight", "bifpn.6.conv4_down.bn.bias", "bifpn.6.conv4_down.bn.running_mean", "bifpn.6.conv4_down.bn.running_var", "bifpn.6.conv4_down.bn.num_batches_tracked", "bifpn.6.conv5_down.depthwise_conv.conv.weight", "bifpn.6.conv5_down.pointwise_conv.conv.weight", "bifpn.6.conv5_down.pointwise_conv.conv.bias", "bifpn.6.conv5_down.bn.weight", "bifpn.6.conv5_down.bn.bias", "bifpn.6.conv5_down.bn.running_mean", "bifpn.6.conv5_down.bn.running_var", "bifpn.6.conv5_down.bn.num_batches_tracked", "bifpn.6.conv6_down.depthwise_conv.conv.weight", "bifpn.6.conv6_down.pointwise_conv.conv.weight", "bifpn.6.conv6_down.pointwise_conv.conv.bias", "bifpn.6.conv6_down.bn.weight", "bifpn.6.conv6_down.bn.bias", "bifpn.6.conv6_down.bn.running_mean", "bifpn.6.conv6_down.bn.running_var", "bifpn.6.conv6_down.bn.num_batches_tracked", "bifpn.6.conv7_down.depthwise_conv.conv.weight", "bifpn.6.conv7_down.pointwise_conv.conv.weight", "bifpn.6.conv7_down.pointwise_conv.conv.bias", "bifpn.6.conv7_down.bn.weight", "bifpn.6.conv7_down.bn.bias", "bifpn.6.conv7_down.bn.running_mean", "bifpn.6.conv7_down.bn.running_var", "bifpn.6.conv7_down.bn.num_batches_tracked", "bifpn.6.RA_p7_6.ra_conv1.conv.weight", "bifpn.6.RA_p7_6.ra_conv1.bn.weight", "bifpn.6.RA_p7_6.ra_conv1.bn.bias", "bifpn.6.RA_p7_6.ra_conv1.bn.running_mean", "bifpn.6.RA_p7_6.ra_conv1.bn.running_var", "bifpn.6.RA_p7_6.ra_conv1.bn.num_batches_tracked", "bifpn.6.RA_p7_6.ra_conv2.conv.weight", "bifpn.6.RA_p7_6.ra_conv2.bn.weight", "bifpn.6.RA_p7_6.ra_conv2.bn.bias", "bifpn.6.RA_p7_6.ra_conv2.bn.running_mean", "bifpn.6.RA_p7_6.ra_conv2.bn.running_var", "bifpn.6.RA_p7_6.ra_conv2.bn.num_batches_tracked", "bifpn.6.RA_p7_6.latenconv.conv.weight", "bifpn.6.RA_p7_6.latenconv.bn.weight", "bifpn.6.RA_p7_6.latenconv.bn.bias", "bifpn.6.RA_p7_6.latenconv.bn.running_mean", "bifpn.6.RA_p7_6.latenconv.bn.running_var", "bifpn.6.RA_p7_6.latenconv.bn.num_batches_tracked", "bifpn.6.RA_p7_6.ra_conv3.conv.weight", "bifpn.6.RA_p7_6.ra_conv3.bn.weight", "bifpn.6.RA_p7_6.ra_conv3.bn.bias", "bifpn.6.RA_p7_6.ra_conv3.bn.running_mean", "bifpn.6.RA_p7_6.ra_conv3.bn.running_var", "bifpn.6.RA_p7_6.ra_conv3.bn.num_batches_tracked", "bifpn.6.RA_p6_5.ra_conv1.conv.weight", "bifpn.6.RA_p6_5.ra_conv1.bn.weight", "bifpn.6.RA_p6_5.ra_conv1.bn.bias", "bifpn.6.RA_p6_5.ra_conv1.bn.running_mean", "bifpn.6.RA_p6_5.ra_conv1.bn.running_var", "bifpn.6.RA_p6_5.ra_conv1.bn.num_batches_tracked", "bifpn.6.RA_p6_5.ra_conv2.conv.weight", "bifpn.6.RA_p6_5.ra_conv2.bn.weight", "bifpn.6.RA_p6_5.ra_conv2.bn.bias", "bifpn.6.RA_p6_5.ra_conv2.bn.running_mean", "bifpn.6.RA_p6_5.ra_conv2.bn.running_var", "bifpn.6.RA_p6_5.ra_conv2.bn.num_batches_tracked", "bifpn.6.RA_p6_5.latenconv.conv.weight", "bifpn.6.RA_p6_5.latenconv.bn.weight", "bifpn.6.RA_p6_5.latenconv.bn.bias", "bifpn.6.RA_p6_5.latenconv.bn.running_mean", "bifpn.6.RA_p6_5.latenconv.bn.running_var", "bifpn.6.RA_p6_5.latenconv.bn.num_batches_tracked", "bifpn.6.RA_p6_5.ra_conv3.conv.weight", "bifpn.6.RA_p6_5.ra_conv3.bn.weight", "bifpn.6.RA_p6_5.ra_conv3.bn.bias", "bifpn.6.RA_p6_5.ra_conv3.bn.running_mean", "bifpn.6.RA_p6_5.ra_conv3.bn.running_var", "bifpn.6.RA_p6_5.ra_conv3.bn.num_batches_tracked", "bifpn.6.RA_p5_4.ra_conv1.conv.weight", "bifpn.6.RA_p5_4.ra_conv1.bn.weight", "bifpn.6.RA_p5_4.ra_conv1.bn.bias", "bifpn.6.RA_p5_4.ra_conv1.bn.running_mean", "bifpn.6.RA_p5_4.ra_conv1.bn.running_var", "bifpn.6.RA_p5_4.ra_conv1.bn.num_batches_tracked", "bifpn.6.RA_p5_4.ra_conv2.conv.weight", "bifpn.6.RA_p5_4.ra_conv2.bn.weight", "bifpn.6.RA_p5_4.ra_conv2.bn.bias", "bifpn.6.RA_p5_4.ra_conv2.bn.running_mean", "bifpn.6.RA_p5_4.ra_conv2.bn.running_var", "bifpn.6.RA_p5_4.ra_conv2.bn.num_batches_tracked", "bifpn.6.RA_p5_4.latenconv.conv.weight", "bifpn.6.RA_p5_4.latenconv.bn.weight", "bifpn.6.RA_p5_4.latenconv.bn.bias", "bifpn.6.RA_p5_4.latenconv.bn.running_mean", "bifpn.6.RA_p5_4.latenconv.bn.running_var", "bifpn.6.RA_p5_4.latenconv.bn.num_batches_tracked", "bifpn.6.RA_p5_4.ra_conv3.conv.weight", "bifpn.6.RA_p5_4.ra_conv3.bn.weight", "bifpn.6.RA_p5_4.ra_conv3.bn.bias", "bifpn.6.RA_p5_4.ra_conv3.bn.running_mean", "bifpn.6.RA_p5_4.ra_conv3.bn.running_var", "bifpn.6.RA_p5_4.ra_conv3.bn.num_batches_tracked", "bifpn.6.RA_p4_3.ra_conv1.conv.weight", "bifpn.6.RA_p4_3.ra_conv1.bn.weight", "bifpn.6.RA_p4_3.ra_conv1.bn.bias", "bifpn.6.RA_p4_3.ra_conv1.bn.running_mean", "bifpn.6.RA_p4_3.ra_conv1.bn.running_var", "bifpn.6.RA_p4_3.ra_conv1.bn.num_batches_tracked", "bifpn.6.RA_p4_3.ra_conv2.conv.weight", "bifpn.6.RA_p4_3.ra_conv2.bn.weight", "bifpn.6.RA_p4_3.ra_conv2.bn.bias", "bifpn.6.RA_p4_3.ra_conv2.bn.running_mean", "bifpn.6.RA_p4_3.ra_conv2.bn.running_var", "bifpn.6.RA_p4_3.ra_conv2.bn.num_batches_tracked", "bifpn.6.RA_p4_3.latenconv.conv.weight", "bifpn.6.RA_p4_3.latenconv.bn.weight", "bifpn.6.RA_p4_3.latenconv.bn.bias", "bifpn.6.RA_p4_3.latenconv.bn.running_mean", "bifpn.6.RA_p4_3.latenconv.bn.running_var", "bifpn.6.RA_p4_3.latenconv.bn.num_batches_tracked", "bifpn.6.RA_p4_3.ra_conv3.conv.weight", "bifpn.6.RA_p4_3.ra_conv3.bn.weight", "bifpn.6.RA_p4_3.ra_conv3.bn.bias", "bifpn.6.RA_p4_3.ra_conv3.bn.running_mean", "bifpn.6.RA_p4_3.ra_conv3.bn.running_var", "bifpn.6.RA_p4_3.ra_conv3.bn.num_batches_tracked". 

In [22]:
from torch.utils.tensorboard import SummaryWriter
del model
# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter(logswriter+"lan2")

init_lr = 1e-4
batchsize = 8
trainsize_init = 384
clip = 0.5
num_epochs= 30
train_save = trainsave+"_lan2"

save_path = './snapshots/{}/'.format(train_save)
if not os.path.exists(save_path):
    os.makedirs(save_path, exist_ok=True)
else:
    print("Save path existed")
#     sys.exit(1)


log = pd.DataFrame(index=[], columns=[
    'epoch', 'lr', 'loss', 'dice', 'iou', 'val_loss', 'val_dice', 'val_iou'
])
# train_img_paths = []
# train_mask_paths = []

# train_img_paths.sort()
# train_mask_paths.sort()

train_dataset = Dataset(train_img_paths, train_mask_paths,transform =train_transform)#train_transform
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batchsize,
    shuffle=True,
    pin_memory=True,
    drop_last=True
)

total_step = len(train_loader)
model = Net2( backbone=dict(
                  type='mit_b3',
                  style='pytorch'), 
              decode_head=dict(
                  type='UPerHead',
                  in_channels=[64, 128, 320, 448],
                  in_index=[0, 1, 2, 3],
                  channels=128,
                  dropout_ratio=0.1,
                  num_classes=1,
                  norm_cfg=dict(type='BN', requires_grad=True),
                   compound_coef=4,
                  align_corners=False,
                  decoder_params=dict(embed_dim=768),
                  loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
              neck=None,
              auxiliary_head=None,
              train_cfg=dict(),
              test_cfg=dict(mode='whole'),
              pretrained='pretrained/mit_b3.pth').cuda()

# ---- flops and params ----
params = model.parameters()
optimizer = torch.optim.Adam(params, init_lr)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 
                                    T_max=len(train_loader)*num_epochs,
                                    eta_min=init_lr/1000)



start_epoch = 1

ckpt_path = ''
if ckpt_path != '':
    log = pd.read_csv(ckpt_path.replace('last.pth', 'log.csv'))
    checkpoint = torch.load(ckpt_path)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    lr_scheduler.load_state_dict(checkpoint['scheduler'])
    optimizer.load_state_dict(checkpoint['optimizer'])

print("#"*20, f"Start Training", "#"*20)
for epoch in range(start_epoch, num_epochs+1):
    #inference(model,writer,epoch)
    train_log = train(train_loader, model, optimizer, epoch, lr_scheduler)

    log_tmp = pd.Series([epoch, optimizer.param_groups[0]["lr"], 
            train_log['loss'].item(), train_log['dice'].item(), train_log['iou'].item(),  
    ], index=['epoch', 'lr', 'loss', 'dice', 'iou'])
    log = log.append(log_tmp, ignore_index=True)
    log.to_csv(f'./snapshots/{train_save}/log.csv', index=False)
    writer.add_scalar('training loss',
                            train_log['loss'].item(),
                            epoch )
    writer.add_scalar('training dice',
                            train_log['dice'].item(),
                            epoch )
    writer.add_scalar('training iou',
                            train_log['iou'].item(),
                            epoch )
    
    if epoch >= num_epochs-20:
        inference(model,writer,epoch,test_dataset)

2022-06-21 00:35:05,954 - mmseg - INFO - Use load_from_local loader

unexpected key in source state_dict: head.weight, head.bias



#################### Start Training ####################
2022-06-21 00:36:30.038457 Training Epoch [001/030], [loss: 3.6158, dice: 0.5440, iou: 0.3914]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan2/last.pth
2022-06-21 00:37:54.720315 Training Epoch [002/030], [loss: 2.5368, dice: 0.7900, iou: 0.6570]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan2/last.pth
2022-06-21 00:39:19.704028 Training Epoch [003/030], [loss: 1.8802, dice: 0.8369, iou: 0.7229]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan2/last.pth
2022-06-21 00:40:44.584577 Training Epoch [004/030], [loss: 1.4939, dice: 0.8539, iou: 0.7484]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan2/last.pth
2022-06-21 00:42:09.448239 Training Epoch [005/030], [loss: 1.2085, dice: 0.8912, iou: 0.8062]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan2/last.pth
2022-06-21 00:43:34.371034 Training Epoch [006/030], [loss: 1.2150, dice:

2022-06-21 01:21:34.696569 Training Epoch [029/030], [loss: 0.5777, dice: 0.9561, iou: 0.9168]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan2/last.pth
####################
Dataset_name: test
scores: dice=0.8228964661015761, miou=0.7367480601279751, precision=0.7818681446684611, recall=0.9250618677228164
####################
2022-06-21 01:23:16.790150 Training Epoch [030/030], [loss: 0.5309, dice: 0.9632, iou: 0.9292]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan2/last.pth
####################
Dataset_name: test
scores: dice=0.8223592247302339, miou=0.7356498385381296, precision=0.7821599559036893, recall=0.9227995965935102
####################


In [23]:
from torch.utils.tensorboard import SummaryWriter

# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter(logswriter+"lan3")

init_lr = 1e-4
batchsize = 8
trainsize_init = 384
clip = 0.5
num_epochs= 30
train_save = trainsave+"_lan3"

save_path = './snapshots/{}/'.format(train_save)
if not os.path.exists(save_path):
    os.makedirs(save_path, exist_ok=True)
else:
    print("Save path existed")
#     sys.exit(1)


log = pd.DataFrame(index=[], columns=[
    'epoch', 'lr', 'loss', 'dice', 'iou', 'val_loss', 'val_dice', 'val_iou'
])
# train_img_paths = []
# train_mask_paths = []

# train_img_paths.sort()
# train_mask_paths.sort()

train_dataset = Dataset(train_img_paths, train_mask_paths,transform =train_transform)#train_transform
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batchsize,
    shuffle=True,
    pin_memory=True,
    drop_last=True
)

total_step = len(train_loader)
model = Net2( backbone=dict(
                  type='mit_b3',
                  style='pytorch'), 
              decode_head=dict(
                  type='UPerHead',
                  in_channels=[64, 128, 320, 448],
                  in_index=[0, 1, 2, 3],
                  channels=128,
                  dropout_ratio=0.1,
                  num_classes=1,
                  norm_cfg=dict(type='BN', requires_grad=True),
                   compound_coef=4,
                  align_corners=False,
                  decoder_params=dict(embed_dim=768),
                  loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
              neck=None,
              auxiliary_head=None,
              train_cfg=dict(),
              test_cfg=dict(mode='whole'),
              pretrained='pretrained/mit_b3.pth').cuda()

# ---- flops and params ----
params = model.parameters()
optimizer = torch.optim.Adam(params, init_lr)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 
                                    T_max=len(train_loader)*num_epochs,
                                    eta_min=init_lr/1000)



start_epoch = 1

ckpt_path = ''
if ckpt_path != '':
    log = pd.read_csv(ckpt_path.replace('last.pth', 'log.csv'))
    checkpoint = torch.load(ckpt_path)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    lr_scheduler.load_state_dict(checkpoint['scheduler'])
    optimizer.load_state_dict(checkpoint['optimizer'])

print("#"*20, f"Start Training", "#"*20)
for epoch in range(start_epoch, num_epochs+1):
    #inference(model,writer,epoch)
    train_log = train(train_loader, model, optimizer, epoch, lr_scheduler)

    log_tmp = pd.Series([epoch, optimizer.param_groups[0]["lr"], 
            train_log['loss'].item(), train_log['dice'].item(), train_log['iou'].item(),  
    ], index=['epoch', 'lr', 'loss', 'dice', 'iou'])
    log = log.append(log_tmp, ignore_index=True)
    log.to_csv(f'./snapshots/{train_save}/log.csv', index=False)
    writer.add_scalar('training loss',
                            train_log['loss'].item(),
                            epoch )
    writer.add_scalar('training dice',
                            train_log['dice'].item(),
                            epoch )
    writer.add_scalar('training iou',
                            train_log['iou'].item(),
                            epoch )
    
    if epoch >= num_epochs-20:
        inference(model,writer,epoch,test_dataset)

2022-06-21 01:23:35,767 - mmseg - INFO - Use load_from_local loader

unexpected key in source state_dict: head.weight, head.bias



#################### Start Training ####################
2022-06-21 01:24:59.923030 Training Epoch [001/030], [loss: 3.7020, dice: 0.5534, iou: 0.3921]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan3/last.pth
2022-06-21 01:26:24.556482 Training Epoch [002/030], [loss: 2.5820, dice: 0.7869, iou: 0.6529]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan3/last.pth
2022-06-21 01:27:49.498570 Training Epoch [003/030], [loss: 1.9716, dice: 0.8177, iou: 0.6971]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan3/last.pth
2022-06-21 01:29:14.231111 Training Epoch [004/030], [loss: 1.5670, dice: 0.8517, iou: 0.7461]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan3/last.pth
2022-06-21 01:30:39.123750 Training Epoch [005/030], [loss: 1.2639, dice: 0.8812, iou: 0.7903]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan3/last.pth
2022-06-21 01:32:04.050906 Training Epoch [006/030], [loss: 1.1102, dice:

2022-06-21 02:10:05.732088 Training Epoch [029/030], [loss: 0.5835, dice: 0.9570, iou: 0.9177]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan3/last.pth
####################
Dataset_name: test
scores: dice=0.8201866639651678, miou=0.7411482242939988, precision=0.802573955959974, recall=0.8888733965361144
####################
2022-06-21 02:11:47.872948 Training Epoch [030/030], [loss: 0.5636, dice: 0.9588, iou: 0.9210]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan3/last.pth
####################
Dataset_name: test
scores: dice=0.8189757631844345, miou=0.7393397663142814, precision=0.7952464902247905, recall=0.8994997082535378
####################


In [24]:
from torch.utils.tensorboard import SummaryWriter

# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter(logswriter+"lan4")

init_lr = 1e-4
batchsize = 8
trainsize_init = 384
clip = 0.5
num_epochs= 30
train_save = trainsave+"_lan4"

save_path = './snapshots/{}/'.format(train_save)
if not os.path.exists(save_path):
    os.makedirs(save_path, exist_ok=True)
else:
    print("Save path existed")
#     sys.exit(1)


log = pd.DataFrame(index=[], columns=[
    'epoch', 'lr', 'loss', 'dice', 'iou', 'val_loss', 'val_dice', 'val_iou'
])
# train_img_paths = []
# train_mask_paths = []

# train_img_paths.sort()
# train_mask_paths.sort()

train_dataset = Dataset(train_img_paths, train_mask_paths,transform =train_transform)#train_transform
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batchsize,
    shuffle=True,
    pin_memory=True,
    drop_last=True
)

total_step = len(train_loader)
model = Net2( backbone=dict(
                  type='mit_b3',
                  style='pytorch'), 
              decode_head=dict(
                  type='UPerHead',
                  in_channels=[64, 128, 320, 448],
                  in_index=[0, 1, 2, 3],
                  channels=128,
                  dropout_ratio=0.1,
                  num_classes=1,
                  norm_cfg=dict(type='BN', requires_grad=True),
                   compound_coef=4,
                  align_corners=False,
                  decoder_params=dict(embed_dim=768),
                  loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
              neck=None,
              auxiliary_head=None,
              train_cfg=dict(),
              test_cfg=dict(mode='whole'),
              pretrained='pretrained/mit_b3.pth').cuda()

# ---- flops and params ----
params = model.parameters()
optimizer = torch.optim.Adam(params, init_lr)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 
                                    T_max=len(train_loader)*num_epochs,
                                    eta_min=init_lr/1000)



start_epoch = 1

ckpt_path = ''
if ckpt_path != '':
    log = pd.read_csv(ckpt_path.replace('last.pth', 'log.csv'))
    checkpoint = torch.load(ckpt_path)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    lr_scheduler.load_state_dict(checkpoint['scheduler'])
    optimizer.load_state_dict(checkpoint['optimizer'])

print("#"*20, f"Start Training", "#"*20)
for epoch in range(start_epoch, num_epochs+1):
    #inference(model,writer,epoch)
    train_log = train(train_loader, model, optimizer, epoch, lr_scheduler)

    log_tmp = pd.Series([epoch, optimizer.param_groups[0]["lr"], 
            train_log['loss'].item(), train_log['dice'].item(), train_log['iou'].item(),  
    ], index=['epoch', 'lr', 'loss', 'dice', 'iou'])
    log = log.append(log_tmp, ignore_index=True)
    log.to_csv(f'./snapshots/{train_save}/log.csv', index=False)
    writer.add_scalar('training loss',
                            train_log['loss'].item(),
                            epoch )
    writer.add_scalar('training dice',
                            train_log['dice'].item(),
                            epoch )
    writer.add_scalar('training iou',
                            train_log['iou'].item(),
                            epoch )
    
    if epoch >= num_epochs-20:
        inference(model,writer,epoch,test_dataset)

2022-06-21 02:12:06,691 - mmseg - INFO - Use load_from_local loader

unexpected key in source state_dict: head.weight, head.bias



#################### Start Training ####################
2022-06-21 02:13:30.669431 Training Epoch [001/030], [loss: 3.5322, dice: 0.6361, iou: 0.4757]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan4/last.pth
2022-06-21 02:14:55.356607 Training Epoch [002/030], [loss: 2.3163, dice: 0.8028, iou: 0.6760]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan4/last.pth
2022-06-21 02:16:20.198027 Training Epoch [003/030], [loss: 1.8201, dice: 0.8147, iou: 0.6951]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan4/last.pth
2022-06-21 02:17:45.354240 Training Epoch [004/030], [loss: 1.4475, dice: 0.8589, iou: 0.7595]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan4/last.pth
2022-06-21 02:19:10.202778 Training Epoch [005/030], [loss: 1.2510, dice: 0.8819, iou: 0.7924]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan4/last.pth
2022-06-21 02:20:35.005263 Training Epoch [006/030], [loss: 1.1224, dice:

2022-06-21 02:58:34.481499 Training Epoch [029/030], [loss: 0.5220, dice: 0.9618, iou: 0.9267]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan4/last.pth
####################
Dataset_name: test
scores: dice=0.7989318730387039, miou=0.7169741887450334, precision=0.7523802312247824, recall=0.9269903945086456
####################
2022-06-21 03:00:16.558143 Training Epoch [030/030], [loss: 0.5260, dice: 0.9628, iou: 0.9284]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan4/last.pth
####################
Dataset_name: test
scores: dice=0.7963891466623289, miou=0.7144073012196092, precision=0.7486533495201166, recall=0.9283074857889155
####################


In [25]:
from torch.utils.tensorboard import SummaryWriter

# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter(logswriter+"lan5")

init_lr = 1e-4
batchsize = 8
trainsize_init = 384
clip = 0.5
num_epochs= 30
train_save = trainsave+"_lan5"

save_path = './snapshots/{}/'.format(train_save)
if not os.path.exists(save_path):
    os.makedirs(save_path, exist_ok=True)
else:
    print("Save path existed")
#     sys.exit(1)


log = pd.DataFrame(index=[], columns=[
    'epoch', 'lr', 'loss', 'dice', 'iou', 'val_loss', 'val_dice', 'val_iou'
])
# train_img_paths = []
# train_mask_paths = []

# train_img_paths.sort()
# train_mask_paths.sort()

train_dataset = Dataset(train_img_paths, train_mask_paths,transform =train_transform)#train_transform
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batchsize,
    shuffle=True,
    pin_memory=True,
    drop_last=True
)

total_step = len(train_loader)
model = Net2( backbone=dict(
                  type='mit_b3',
                  style='pytorch'), 
              decode_head=dict(
                  type='UPerHead',
                  in_channels=[64, 128, 320, 448],
                  in_index=[0, 1, 2, 3],
                  channels=128,
                  dropout_ratio=0.1,
                  num_classes=1,
                  norm_cfg=dict(type='BN', requires_grad=True),
                   compound_coef=4,
                  align_corners=False,
                  decoder_params=dict(embed_dim=768),
                  loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
              neck=None,
              auxiliary_head=None,
              train_cfg=dict(),
              test_cfg=dict(mode='whole'),
              pretrained='pretrained/mit_b3.pth').cuda()

# ---- flops and params ----
params = model.parameters()
optimizer = torch.optim.Adam(params, init_lr)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 
                                    T_max=len(train_loader)*num_epochs,
                                    eta_min=init_lr/1000)



start_epoch = 1

ckpt_path = ''
if ckpt_path != '':
    log = pd.read_csv(ckpt_path.replace('last.pth', 'log.csv'))
    checkpoint = torch.load(ckpt_path)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    lr_scheduler.load_state_dict(checkpoint['scheduler'])
    optimizer.load_state_dict(checkpoint['optimizer'])

print("#"*20, f"Start Training", "#"*20)
for epoch in range(start_epoch, num_epochs+1):
    #inference(model,writer,epoch)
    train_log = train(train_loader, model, optimizer, epoch, lr_scheduler)

    log_tmp = pd.Series([epoch, optimizer.param_groups[0]["lr"], 
            train_log['loss'].item(), train_log['dice'].item(), train_log['iou'].item(),  
    ], index=['epoch', 'lr', 'loss', 'dice', 'iou'])
    log = log.append(log_tmp, ignore_index=True)
    log.to_csv(f'./snapshots/{train_save}/log.csv', index=False)
    writer.add_scalar('training loss',
                            train_log['loss'].item(),
                            epoch )
    writer.add_scalar('training dice',
                            train_log['dice'].item(),
                            epoch )
    writer.add_scalar('training iou',
                            train_log['iou'].item(),
                            epoch )
    
    if epoch >= num_epochs-20:
        inference(model,writer,epoch,test_dataset)

2022-06-21 03:00:35,296 - mmseg - INFO - Use load_from_local loader

unexpected key in source state_dict: head.weight, head.bias



#################### Start Training ####################
2022-06-21 03:01:59.466071 Training Epoch [001/030], [loss: 3.5863, dice: 0.5781, iou: 0.4186]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan5/last.pth
2022-06-21 03:03:24.030854 Training Epoch [002/030], [loss: 2.5338, dice: 0.7749, iou: 0.6392]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan5/last.pth
2022-06-21 03:04:49.033317 Training Epoch [003/030], [loss: 1.7036, dice: 0.8565, iou: 0.7512]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan5/last.pth
2022-06-21 03:06:14.026657 Training Epoch [004/030], [loss: 1.6429, dice: 0.8273, iou: 0.7126]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan5/last.pth
2022-06-21 03:07:38.906961 Training Epoch [005/030], [loss: 1.1725, dice: 0.8973, iou: 0.8155]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan5/last.pth
2022-06-21 03:09:03.816865 Training Epoch [006/030], [loss: 1.1695, dice:

2022-06-21 03:47:05.338383 Training Epoch [029/030], [loss: 0.5939, dice: 0.9578, iou: 0.9192]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan5/last.pth
####################
Dataset_name: test
scores: dice=0.8150001143610538, miou=0.7317286362604741, precision=0.7861968431417551, recall=0.9114107482728385
####################
2022-06-21 03:48:47.602660 Training Epoch [030/030], [loss: 0.5461, dice: 0.9624, iou: 0.9278]
[Saving Checkpoint:] ./snapshots/Scenario3-augment-RA-bottleneck_lan5/last.pth
####################
Dataset_name: test
scores: dice=0.8117646976033528, miou=0.7286407001681064, precision=0.7804056016810214, recall=0.9137684838387311
####################


In [18]:
from torchsummary import summary
summary(model, (3, 384, 384))

Layer (type:depth-idx)                                  Output Shape              Param #
├─mit_b3: 1-1                                           [-1, 64, 96, 96]          --
|    └─OverlapPatchEmbed: 2-1                           [-1, 9216, 64]            --
|    |    └─Conv2d: 3-1                                 [-1, 64, 96, 96]          9,472
|    |    └─LayerNorm: 3-2                              [-1, 9216, 64]            128
|    └─ModuleList: 2                                    []                        --
|    |    └─Block: 3-3                                  [-1, 9216, 64]            314,880
|    |    └─Block: 3-4                                  [-1, 9216, 64]            314,880
|    |    └─Block: 3-5                                  [-1, 9216, 64]            314,880
|    └─LayerNorm: 2-2                                   [-1, 9216, 64]            128
|    └─OverlapPatchEmbed: 2-3                           [-1, 2304, 128]           --
|    |    └─Conv2d: 3-6                 

Layer (type:depth-idx)                                  Output Shape              Param #
├─mit_b3: 1-1                                           [-1, 64, 96, 96]          --
|    └─OverlapPatchEmbed: 2-1                           [-1, 9216, 64]            --
|    |    └─Conv2d: 3-1                                 [-1, 64, 96, 96]          9,472
|    |    └─LayerNorm: 3-2                              [-1, 9216, 64]            128
|    └─ModuleList: 2                                    []                        --
|    |    └─Block: 3-3                                  [-1, 9216, 64]            314,880
|    |    └─Block: 3-4                                  [-1, 9216, 64]            314,880
|    |    └─Block: 3-5                                  [-1, 9216, 64]            314,880
|    └─LayerNorm: 2-2                                   [-1, 9216, 64]            128
|    └─OverlapPatchEmbed: 2-3                           [-1, 2304, 128]           --
|    |    └─Conv2d: 3-6                 