# 0. Import Libraries + Hyperparams

In [1]:
!pip install segmentation_models_pytorch
!pip install kornia

Collecting segmentation_models_pytorch
  Obtaining dependency information for segmentation_models_pytorch from https://files.pythonhosted.org/packages/cb/70/4aac1b240b399b108ce58029ae54bc14497e1bbc275dfab8fd3c84c1e35d/segmentation_models_pytorch-0.3.3-py3-none-any.whl.metadata
  Downloading segmentation_models_pytorch-0.3.3-py3-none-any.whl.metadata (30 kB)
Collecting pretrainedmodels==0.7.4 (from segmentation_models_pytorch)
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting efficientnet-pytorch==0.7.1 (from segmentation_models_pytorch)
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting timm==0.9.2 (from segmentation_models_pytorch)
  Obtaining dependency information for timm==0.9.2 from https://files.pythonhosted.org/packages/29

In [2]:
!pip install torchmetrics



In [3]:
import torch
import os
import cv2
import logging
import sys
import time
import torchvision.transforms as T
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from kornia.losses import focal_loss
from torchmetrics import Accuracy, Precision, Recall, FBetaScore, Dice, JaccardIndex



In [4]:
NUM_WORKERS = 0
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#Solver
CLASSES = {0: "Benign", 1: "Malignant", 2: "Normal"}
INPUT_SIZE = (448,448)
BATCH_SIZE = 8
BASE_LR = 0.01
MAX_EPOCHS = 3
SAVE_INTERVAL = 10
PATIENCE = 300
N_FOLDS = 3


#Model
ARCH = "deeplabv3plus" # chọn giữa ['unet', 'unetpp', , 'fpn', 'deeplabv3plus']
ENCODER_NAME = "efficientnet-b4" # chọn giữa các kiến trúc ['resnet50', 'resnext50_32x4d', 'tu-wide_resnet50_2', 'efficientnet-b4']
IN_CHANNELS = 3
SEG_NUM_CLASSES = 2
CLA_NUM_CLASSES = 3
OUTPUT_ACTIVATION = None #None for logits

#Loss coefficient weight
ALPHA = 0.7

#Path
OUTPUT_DIR = r"/kaggle/working/output"
DATASET_DIR = r"/kaggle/input/medical-od-segmentation-homeowork/content/data/train"
CHECKPOINT = None

#Eval
WEIGHT = r""

# 1. Utils

In [5]:
class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

class UnNormalize(object):
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        Returns:
            Tensor: Normalized image.
        """
        for t, m, s in zip(tensor, self.mean, self.std):
            t.mul_(s).add_(m)
            # The normalize code -> t.sub_(m).div_(s)
        return tensor

In [6]:
def calculate_overlap_metrics(pred, gt,eps=1e-5):
    output = pred.view(-1,)
    target = gt.view(-1,).float()

    tp = torch.sum(output * target)  # TP
    fp = torch.sum(output * (1 - target))  # FP
    fn = torch.sum((1 - output) * target)  # FN
    tn = torch.sum((1 - output) * (1 - target))  # TN

    # pixel_acc = (tp + tn + eps) / (tp + tn + fp + fn + eps)
    dice = (2 * tp + eps) / (2 * tp + fp + fn + eps)
    iou = ( tp + eps) / ( tp + fp + fn + eps)
    precision = (tp + eps) / (tp + fp + eps)
    recall = (tp + eps) / (tp + fn + eps)
#     specificity = (tn + eps) / (tn + fp + eps)

    return iou, dice, precision, recall

In [7]:

def setup_logger(logger_name, output_dir):
    import os
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)
    # create file handler which logs even debug messages
    fh = logging.FileHandler(os.path.join(output_dir, 'log.log'))
    fh.setLevel(logging.DEBUG)
    # create console handler with a higher log level
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    # create formatter and add it to the handlers
    formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s: %(message)s')
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)
    # add the handlers to logger
    logger.addHandler(fh)
    logger.addHandler(ch)
    return logger


def logging_hyperparameters(logger):
    logger.info("==========Hyperparameters==========")
    logger.info(f"Device: {DEVICE}")
    logger.info(f"Architecture: {ARCH}")
    logger.info(f"Encoder: {ENCODER_NAME}")
    logger.info(f"Encoder weight: imagenet")
    logger.info(f"Input size: {INPUT_SIZE}")
    logger.info(f"Batch size: {BATCH_SIZE}")
    logger.info(f"Base learning rate: {BASE_LR}")
    logger.info(f"Max epochs: {MAX_EPOCHS}")
    logger.info(f"Weight decay: {1e-5}")
    logger.info("===================================")


def init_path(task):
    #Task == classification
    if task == "classification":
        weight_dir = os.path.join(OUTPUT_DIR, task, ENCODER_NAME)
        os.makedirs(weight_dir, exist_ok=True)
        log_dir = weight_dir
        logger_name = f"{task}_{ENCODER_NAME}"
    elif task == "segmentation":
        weight_dir = os.path.join(OUTPUT_DIR, task, f"{ENCODER_NAME}_{ARCH}")
        os.makedirs(weight_dir, exist_ok=True)
        log_dir = weight_dir
        logger_name = f"{task}_{ENCODER_NAME}_{ARCH}"
    elif task == "multitask":
        weight_dir = os.path.join(OUTPUT_DIR, f"{ENCODER_NAME}_{ARCH}")
        os.makedirs(weight_dir, exist_ok=True)
        log_dir = weight_dir
        logger_name = f"{task}_{ENCODER_NAME}_{ARCH}"
    return weight_dir, log_dir, logger_name

# 2. Setup Data

### 2.1. Download Dataset

In [8]:
!mkdir output

### 2.2. Setup dataloader

In [9]:
from sklearn.model_selection import StratifiedKFold
import numpy as np


def split_dataset(dataset_dir):
    benign, malignant, normal = [], [], []
    benign_images = [os.path.join(dataset_dir, 'benign', file) for file in os.listdir(os.path.join(dataset_dir, 'benign')) if file.endswith('.png')]
    malignant_images = [os.path.join(dataset_dir, 'malignant', file) for file in os.listdir(os.path.join(dataset_dir, 'malignant')) if file.endswith('.png')]
    normal_images = [os.path.join(dataset_dir, 'normal', file) for file in os.listdir(os.path.join(dataset_dir, 'normal')) if file.endswith('.png')]

    for mask in benign_images:
        if "_mask" in mask:
            image = mask.replace('_mask.png', '.png')
            benign.append((0, image, mask))
    for mask in malignant_images:
        if "_mask" in mask:
            image = mask.replace('_mask.png', '.png')
            malignant.append((1, image, mask))
    for mask in normal_images:
        if "_mask" in mask:
            image = mask.replace('_mask.png', '.png')
            normal.append((2, image, mask))

    all_data = benign + malignant + normal
    labels = [item[0] for item in all_data]

    kf = StratifiedKFold(n_splits=3)

    folds = []

    # Splitting data into folds
    for train_index, val_index in kf.split(np.zeros(len(labels)), labels):
        train_set = [all_data[i] for i in train_index]
        val_set = [all_data[i] for i in val_index]
        folds.append((train_set, val_set))

    return folds

In [10]:
class BUSI(Dataset):
    def __init__(self, dataset_dir, train_set, val_set, input_size=(512,512), transform=None, target_transform=None, is_train=True):
        self.input_size = input_size
        self.dataset_dir = dataset_dir
        self.is_train = is_train
        if not os.path.exists(self.dataset_dir):
            raise ValueError('BUSI dataset not found at {}'.format(self.dataset_dir))

        for _, _, files in os.walk(self.dataset_dir):
            for file in files:
                if "_mask_1" in file:
                    raise Exception("This class requires BUSI dataset with combined mask. It can be done by running the BUSI() function in the process_data.py at utils folder")

        self.transform = transform
        self.target_transform = target_transform
        self.train_set = train_set
        self.val_set = val_set
        if self.is_train:
            self.images = train_set
        else:
            self.images = val_set


    def __len__(self):
        if self.is_train:
            return len(self.train_set)
        else:
            return len(self.val_set)

    def __getitem__(self, idx):
        label, image_path, mask_path = self.images[idx]
        image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)


        image = cv2.resize(image, self.input_size, interpolation=cv2.INTER_NEAREST)
        mask = cv2.resize(mask, self.input_size, interpolation=cv2.INTER_NEAREST)


        #Normalize
        mask = mask/255
        mask = torch.from_numpy(mask).long()
        mask = torch.nn.functional.one_hot(mask, num_classes=2).permute(2,0,1).long()

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)

        return image, mask, label

    @property
    def info(self):
        print(f"Dataset: BUSI")
        print(f"Train: {len(self.train_set)} images")
        print("-"*20)
        print(f"Benign: {len([image for image in self.train_set if image[0] == 0])} images")
        print(f"Malignant: {len([image for image in self.train_set if image[0] == 1])} images")
        print(f"Normal: {len([image for image in self.train_set if image[0] == 2])} images")
        print("-"*20)
        print(f"Val: {len(self.val_set)} images")
        print("-"*20)
        print(f"Benign: {len([image for image in self.val_set if image[0] == 0])} images")
        print(f"Malignant: {len([image for image in self.val_set if image[0] == 1])} images")
        print(f"Normal: {len([image for image in self.val_set if image[0] == 2])} images")
        print("-"*20)

    def _get_images(self):
        benign, malignant, normal = [], [], []
        benign_images = [os.path.join(self.dataset_dir, 'benign', file) for file in os.listdir(os.path.join(self.dataset_dir, 'benign')) if file.endswith('.png')]
        malignant_images = [os.path.join(self.dataset_dir, 'malignant', file) for file in os.listdir(os.path.join(self.dataset_dir, 'malignant')) if file.endswith('.png')]
        normal_images = [os.path.join(self.dataset_dir, 'normal', file) for file in os.listdir(os.path.join(self.dataset_dir, 'normal')) if file.endswith('.png')]

        for mask in benign_images:
            if "_mask" in mask:
                image = mask.replace('_mask.png', '.png')
                benign.append((0, image, mask))
        for mask in malignant_images:
            if "_mask" in mask:
                image = mask.replace('_mask.png', '.png')
                malignant.append((1, image, mask))
        for mask in normal_images:
            if "_mask" in mask:
                image = mask.replace('_mask.png', '.png')
                normal.append((2, image, mask))

        self.b_train_set, self.b_val_set = train_test_split(benign, test_size=0.2, random_state=42)
        self.m_train_set, self.m_val_set = train_test_split(malignant, test_size=0.2, random_state=42)
        self.n_train_set, self.n_val_set = train_test_split(normal, test_size=0.2, random_state=42)

        train_set = self.b_train_set + self.m_train_set + self.n_train_set
        val_set = self.b_val_set + self.m_val_set + self.n_val_set
        # # without normal class
        # train_set = b_train_set + m_train_set
        # val_set = b_val_set + m_val_set
        return train_set, val_set

In [11]:
folds = split_dataset(DATASET_DIR)

In [12]:
len(folds)

3

In [13]:
transform = T.Compose([
                    T.ToTensor(),
                    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
              ])

# 3. Single Model

### 3.1 Setup model

In [14]:
PRETRAINED_WEIGHT_URL = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'tu-wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'efficientnet-b4': 'https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth',
}

def segmentation_model(aux_param=None):
    assert ARCH in ['unet', 'unetpp', 'deeplabv3plus', 'fpn'], "Invalid architecture, must be ['unet', 'unetpp', 'deeplabv3plus', 'fpn']"
    assert ENCODER_NAME in ['resnet50', 'resnext50_32x4d', 'tu-wide_resnet50_2', 'efficientnet-b4'], "Invalid encoder name, must be ['resnet50', 'resnext50_32x4d', 'tu-wide_resnet50_2', 'efficientnet-b4']"
    #Params

    params = dict(
        encoder_name = ENCODER_NAME,
        encoder_depth = 5,
        encoder_weights = "imagenet",
        in_channels = IN_CHANNELS,
        classes = SEG_NUM_CLASSES,
        activation = OUTPUT_ACTIVATION,
        aux_params = aux_param
    )
    MODELS = {
        'unet':smp.Unet(**params),
        'unetpp': smp.UnetPlusPlus(**params),
        'deeplabv3plus': smp.DeepLabV3Plus(**params),
        'fpn': smp.FPN(**params),

    }
    return MODELS[ARCH]

def classification_model():
    MODELS = {
        'resnet50': torchvision.models.resnet50(weights='DEFAULT'),
        'resnext50_32x4d': torchvision.models.resnext50_32x4d(weights='DEFAULT'),
        'tu-wide_resnet50_2': torchvision.models.wide_resnet50_2(weights='DEFAULT'),
        'efficientnet-b4': torchvision.models.efficientnet_b4(weights=None),
    }
    model = MODELS[ENCODER_NAME]

    # Replace the last layer
    if ENCODER_NAME == "efficientnet-b4":
        state_dict = torch.hub.load_state_dict_from_url(PRETRAINED_WEIGHT_URL[ENCODER_NAME])
        model.load_state_dict(state_dict)
        model.classifier = torch.nn.Linear(1792, CLA_NUM_CLASSES)
    else:
        model.fc = torch.nn.Linear(2048, CLA_NUM_CLASSES)

    return model

class TwoSingleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.seg_model = segmentation_model()
        self.cla_model = classification_model()

    def forward(self, x):
        seg_out = self.seg_model(x)
        cla_out = self.cla_model(x)
        return seg_out, cla_out


### 3.2. Training Classification

In [21]:
def train(folds):
    
    for fold in range(N_FOLDS):
    
            #TASK
        TASK = "classification"

        #Path
        weight_dir, log_dir, logger_name = init_path(TASK)

        #Model
        model = classification_model().to(DEVICE)

        #Loss & Optimizer
        model = model.to(DEVICE)
        optimizer = optim.Adam(model.parameters(), lr=BASE_LR, weight_decay=1e-5)


        #Meters
        train_loss_meter = AverageMeter()
        val_loss_meter = AverageMeter()
        acc_meter = AverageMeter()
        precision_meter = AverageMeter()
        recall_meter = AverageMeter()
        f1_score_meter = AverageMeter()
        
        train_images, val_images = folds[fold]
        
        train_set = BUSI(DATASET_DIR, train_images, val_images, input_size=INPUT_SIZE,transform=transform, target_transform=None, is_train=True)
        val_set = BUSI(DATASET_DIR, train_images, val_images, input_size=INPUT_SIZE,transform=transform, target_transform=None, is_train=False)

        train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
        val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

        logger = setup_logger(logger_name, log_dir)
        best_f1 = 0
        stale = 0
        start_epoch = 1

        if CHECKPOINT is not None:
            if os.path.exists(CHECKPOINT):
                checkpoint = torch.load(CHECKPOINT)
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_f1 = checkpoint['best_f1']
                logger.info(f"Resume training from epoch {start_epoch}")
            else:
                logger.info(f"Checkpoint not found, start training from epoch 1")

        #Logging hyperparameters
        logging_hyperparameters(logger)


        for epoch in range(start_epoch, 1+MAX_EPOCHS):
            #Start time
            start_time = time.time()
            #Train
            model.train()
            #Reset meters
            train_loss_meter.reset()
            precision_meter.reset()
            recall_meter.reset()
            f1_score_meter.reset()
            acc_meter.reset()

            logger.info("Start training")
            for batch_idx, (image, _, label) in enumerate(train_loader):
                n = image.shape[0]
                optimizer.zero_grad()
                image = image.to(DEVICE)
                label = label.to(DEVICE)

                output = model(image) #Logits (batch_size,num_classes)
                #Cal loss
                train_loss = focal_loss(output, label, alpha=0.25, gamma=2, reduction='mean')
                train_loss.backward()
                optimizer.step()

                train_loss_meter.update(train_loss.item(),n)

                if batch_idx % 10 == 0:
                    logger.info(f"Epoch[{epoch}] - Fold[{fold}] - Iteration[{batch_idx}/{len(train_loader)}] Loss: {train_loss:.3f}")
            end_time = time.time()
            logger.info(f"Training Result: Epoch {epoch}/{MAX_EPOCHS} Fold {fold}/{N_FOLDS}, Loss: {train_loss_meter.avg:.3f}, Time epoch: {end_time-start_time:.3f}s")

            #Valid
            model.eval()
            with torch.no_grad():
                for batch_idx, (image, _, label) in enumerate(val_loader):
                    n = image.shape[0]
                    image = image.to(DEVICE)
                    label = label.to(DEVICE)

                    output = model(image)
                    val_loss = focal_loss(output, label, alpha=0.25, gamma=2,reduction='mean')

                    #Calculate metrics
                    #P, R and F1
                    label = label.detach().cpu().numpy()
                    output = output.argmax(1).detach().cpu().numpy()

                    p_score = precision_score(label, output, average='macro', zero_division=0)
                    r_score = recall_score(label, output, average='macro', zero_division=0)
                    _f1_score = f1_score(label, output, average='macro')
                    acc = accuracy_score(label, output)

                    #Update meters
                    val_loss_meter.update(val_loss.item(), n)
                    acc_meter.update(acc.item(),n)
                    precision_meter.update(p_score.item(), n)
                    recall_meter.update(r_score.item(), n)
                    f1_score_meter.update(_f1_score.item(), n)

            logger.info(f"Validation Result: Loss: {val_loss_meter.avg:.3f}, Accuracy: {acc_meter.avg:.3f} F1-Score: {f1_score_meter.avg:.3f}, Precision: {precision_meter.avg:.3f}, Recall: {recall_meter.avg:.3f}")

            #Save best model
            to_save = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_f1': best_f1,
                }
            if f1_score_meter.avg > best_f1: # best base on IoU score
                logger.info(f"Best model found at epoch {epoch}, saving model")
                torch.save(to_save, os.path.join(weight_dir,f"best_epoch{epoch}_fold{fold}_{INPUT_SIZE[0]}_BS={BATCH_SIZE}_f1={f1_score_meter.avg:.3f}.pth")) # only save best to prevent output memory exceed error
                best_f1 = f1_score_meter.avg
                stale = 0
            else:
                stale += 1
                if stale > 300:
                    logger.info(f"No improvement {300} consecutive epochs, early stopping")
                    break
            if epoch % SAVE_INTERVAL == 0 or epoch == MAX_EPOCHS:
                logger.info(f"Save model at epoch {epoch}, saving model")
                torch.save(to_save, os.path.join(weight_dir,f"epoch_{epoch}_{fold}.pth"))


In [22]:
train(folds)

2024-01-11 05:07:40,486 classification_efficientnet-b4 INFO: Device: cuda
2024-01-11 05:07:40,486 classification_efficientnet-b4 INFO: Device: cuda
2024-01-11 05:07:40,486 classification_efficientnet-b4 INFO: Device: cuda
2024-01-11 05:07:40,488 classification_efficientnet-b4 INFO: Architecture: deeplabv3plus
2024-01-11 05:07:40,488 classification_efficientnet-b4 INFO: Architecture: deeplabv3plus
2024-01-11 05:07:40,488 classification_efficientnet-b4 INFO: Architecture: deeplabv3plus
2024-01-11 05:07:40,490 classification_efficientnet-b4 INFO: Encoder: efficientnet-b4
2024-01-11 05:07:40,490 classification_efficientnet-b4 INFO: Encoder: efficientnet-b4
2024-01-11 05:07:40,490 classification_efficientnet-b4 INFO: Encoder: efficientnet-b4
2024-01-11 05:07:40,493 classification_efficientnet-b4 INFO: Encoder weight: imagenet
2024-01-11 05:07:40,493 classification_efficientnet-b4 INFO: Encoder weight: imagenet
2024-01-11 05:07:40,493 classification_efficientnet-b4 INFO: Encoder weight: imag

### 3.4. Train Segmentation

In [20]:

def train(folds):
    for fold in range(N_FOLDS):
        #TASK
        TASK = "segmentation"

        #Path
        weight_dir, log_dir, logger_name = init_path(TASK)


        #Model
        model = segmentation_model().to(DEVICE)

        #Loss & Optimizer
        model = model.to(DEVICE)
        dice_loss = smp.losses.DiceLoss(mode='binary', from_logits=True)
        optimizer = optim.Adam(model.parameters(), lr=BASE_LR, weight_decay=1e-5)


        #Meters
        overall_meter = AverageMeter()
        iou_meter = AverageMeter()
        dice_meter = AverageMeter()
        train_loss_meter = AverageMeter()
        val_loss_meter = AverageMeter()
        precision_meter = AverageMeter()
        recall_meter = AverageMeter()
        f1_score_meter = AverageMeter()
        
        train_images, val_images = folds[fold]
        train_set = BUSI(DATASET_DIR, train_images, val_images, input_size=INPUT_SIZE,transform=transform, target_transform=None, is_train=True)
        val_set = BUSI(DATASET_DIR, train_images, val_images, input_size=INPUT_SIZE,transform=transform, target_transform=None, is_train=False)

        train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
        val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

        logger = setup_logger(logger_name, log_dir)
        stale = 0
        best_overall = 0
        start_epoch = 1

        if CHECKPOINT is not None:
            if os.path.exists(CHECKPOINT):
                checkpoint = torch.load(CHECKPOINT)
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_overall = checkpoint['best_overall']
                logger.info(f"Resume training from epoch {start_epoch}")
            else:
                logger.info(f"Checkpoint not found, start training from epoch 1")
        #Logging hyperparameters
        logging_hyperparameters(logger)

        for epoch in range(start_epoch, 1+MAX_EPOCHS):
            start_time = time.time()
            #Train
            model.train()
            #Reset meters
            overall_meter.reset()
            train_loss_meter.reset()
            val_loss_meter.reset()

            iou_meter.reset()
            dice_meter.reset()
            precision_meter.reset()
            recall_meter.reset()
            f1_score_meter.reset()

            logger.info("Start training")
            for batch_idx, (image, mask, _) in enumerate(train_loader):
                n = image.shape[0]
                optimizer.zero_grad()
                image = image.to(DEVICE)
                mask = mask.to(DEVICE)

                output = model(image) #Logits
                #Cal loss
                train_loss = dice_loss(output, mask)
                train_loss.backward()
                optimizer.step()

                train_loss_meter.update(train_loss.item(),n)

                if batch_idx % 10 == 0:
                    logger.info(f"Epoch[{epoch}] - Fold[{fold}] - Iteration[{batch_idx}/{len(train_loader)}] Loss: {train_loss:.3f}")
            end_time = time.time()
            logger.info(f"Training Result: Epoch {epoch}/{MAX_EPOCHS} - Fold {fold}/{N_FOLDS}, Loss: {train_loss_meter.avg:.3f}, Time epoch: {end_time-start_time:.3f}s")

            #Valid
            model.eval()
            with torch.no_grad():
                for batch_idx, (image, mask, _) in enumerate(val_loader):
                    n = image.shape[0]
                    image = image.to(DEVICE)
                    mask = mask.to(DEVICE)

                    output = model(image)
                    val_loss = dice_loss(output, mask)

                    # #Calculate metrics
                    mask = F.sigmoid(mask).round().long()
                    tp, fp, fn, tn = smp.metrics.get_stats(output, mask, mode='binary', threshold=0.5)


                    iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="macro")

                    dice_score = torch.mean((2*tp.sum(0)/(2*tp.sum(0) + fp.sum(0) + fn.sum(0) + 1e-5)))
                    precision_score = smp.metrics.precision(tp, fp, fn, tn, reduction="macro")
                    recall_score = smp.metrics.recall(tp, fp, fn, tn, reduction="macro")
                    f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="macro")


                    #Update meters
                    val_loss_meter.update(val_loss.item(), n)

                    iou_meter.update(iou_score.item(), n)
                    dice_meter.update(dice_score.item(), n)
                    precision_meter.update(precision_score.item(), n)
                    recall_meter.update(recall_score.item(), n)
                    f1_score_meter.update(f1_score.item(), n)

                    #Overall score
                    overall_score = (iou_score + dice_score + f1_score)/3
                    overall_meter.update(overall_score.item(), n)

            logger.info(f"Validation Result: Dice Loss: {val_loss_meter.avg:.3f}, IoU: {iou_meter.avg:.3f}, Dice Score: {dice_meter.avg:.3f}, F1-Score: {f1_score_meter.avg:.3f}, Average Score: {overall_meter.avg:.3f}")

            #Save best model
            to_save = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_overall': best_overall,
                }
            if overall_meter.avg > best_overall: # best base on IoU score
                logger.info(f"Best model found at epoch {epoch}, saving model")

                torch.save(to_save, os.path.join(weight_dir,f"best_epoch{epoch}_fold{fold}_{INPUT_SIZE[0]}_BS={BATCH_SIZE}_average={overall_meter.avg:.3f}.pth"))
                best_overall = overall_meter.avg
                stale = 0
            else:
                stale += 1
                if stale > 300:
                    logger.info(f"No improvement {300} consecutive epochs, early stopping")
                    break
            if epoch % SAVE_INTERVAL == 0 or epoch == MAX_EPOCHS:
                logger.info(f"Save model at epoch {epoch}, saving model")
                torch.save(to_save, os.path.join(weight_dir,f"epoch{epoch}_fold{fold}.pth"))


In [None]:
train(folds)

2024-01-11 04:53:18,919 segmentation_efficientnet-b4_deeplabv3plus INFO: Device: cuda
2024-01-11 04:53:18,920 segmentation_efficientnet-b4_deeplabv3plus INFO: Architecture: deeplabv3plus
2024-01-11 04:53:18,921 segmentation_efficientnet-b4_deeplabv3plus INFO: Encoder: efficientnet-b4
2024-01-11 04:53:18,921 segmentation_efficientnet-b4_deeplabv3plus INFO: Encoder weight: imagenet
2024-01-11 04:53:18,922 segmentation_efficientnet-b4_deeplabv3plus INFO: Input size: (448, 448)
2024-01-11 04:53:18,923 segmentation_efficientnet-b4_deeplabv3plus INFO: Batch size: 8
2024-01-11 04:53:18,925 segmentation_efficientnet-b4_deeplabv3plus INFO: Base learning rate: 0.01
2024-01-11 04:53:18,925 segmentation_efficientnet-b4_deeplabv3plus INFO: Max epochs: 10
2024-01-11 04:53:18,926 segmentation_efficientnet-b4_deeplabv3plus INFO: Weight decay: 1e-05
2024-01-11 04:53:18,932 segmentation_efficientnet-b4_deeplabv3plus INFO: Start training
2024-01-11 04:53:19,705 segmentation_efficientnet-b4_deeplabv3plus 

# 4. Multitask Model

In [None]:
RESNET50_ENCODER_WEIGHTS_URL = "https://download.pytorch.org/models/resnet50-19c8e357.pth"

def multitask_model():
    aux_param=dict(
                    pooling='avg',             # one of 'avg', 'max'
                    dropout=0.5,               # dropout ratio, default is None
                    # activation='sigmoid',      # activation function, default is None
                    classes=CLA_NUM_CLASSES,      # define number of output labels
                )
    model = segmentation_model(aux_param=aux_param)
    return model

In [None]:
def train(folds):
    for fold in range(N_FOLDS):
        #Task
        TASK = "multitask"

        #Path
        weight_dir, log_dir, logger_name = init_path(TASK)


        #Model
        model = multitask_model().to(DEVICE)

        #Loss & Optimizer
        dice_loss = smp.losses.DiceLoss(mode='binary', from_logits=True)
        # CE_loss = torch.nn.CrossEntropyLoss()

        optimizer = optim.Adam(model.parameters(), lr=BASE_LR, weight_decay=1e-5)



        #Common meter
        overall_meter = AverageMeter()
        train_loss_meter = AverageMeter()
        val_loss_meter = AverageMeter()

        #Meters segmentation
        seg_train_loss_meter = AverageMeter()
        seg_val_loss_meter = AverageMeter()
        seg_iou_meter = AverageMeter()
        seg_dice_meter = AverageMeter()
        seg_precision_meter = AverageMeter()
        seg_recall_meter = AverageMeter()
        seg_f1_score_meter = AverageMeter()

        #Meters classification
        cla_train_loss_meter = AverageMeter()
        cla_val_loss_meter = AverageMeter()
        cla_acc_meter = AverageMeter()
        cla_precision_meter = AverageMeter()
        cla_recall_meter = AverageMeter()
        cla_f1_score_meter = AverageMeter()

        train_images, val_images = folds[fold]
        train_set = BUSI(DATASET_DIR, train_images, val_images, input_size=INPUT_SIZE,transform=transform, target_transform=None, is_train=True)
        val_set = BUSI(DATASET_DIR, train_images, val_images, input_size=INPUT_SIZE,transform=transform, target_transform=None, is_train=False)

        train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
        val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

                #Setup logging
        logger = setup_logger(logger_name, log_dir)

        start_epoch=1
        best_overall = 0
        stale = 0

        if CHECKPOINT is not None:
            if os.path.exists(CHECKPOINT):
                checkpoint = torch.load(CHECKPOINT)
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_overall = checkpoint['best_overall']
                print(f"Resume training from epoch {start_epoch}")
            else:
                print(f"Checkpoint not found, start training from epoch 1")

        #Logging hyperparameters
        logging_hyperparameters(logger)

        for epoch in range(start_epoch, 1+MAX_EPOCHS):
            start_time = time.time()
            #Train
            model.train()

            #Reset meters
            #Common meter
            train_loss_meter.reset()
            val_loss_meter.reset()
            overall_meter.reset()

            #Meters segmentation
            seg_iou_meter.reset()
            seg_dice_meter.reset()
            seg_precision_meter.reset()
            seg_recall_meter.reset()
            seg_f1_score_meter.reset()

            #Meters classification
            cla_acc_meter.reset()
            cla_precision_meter.reset()
            cla_recall_meter.reset()
            cla_f1_score_meter.reset()

            logger.info("Start training")
            for batch_idx, (image, mask, label) in enumerate(train_loader):
                n = image.shape[0]
                optimizer.zero_grad()
                image = image.to(DEVICE)
                mask = mask.to(DEVICE)
                label = label.to(DEVICE)

                #Forward
                output_mask, output_classification = model(image)

                #Cal loss
                loss_segmentation = dice_loss(output_mask, mask)
                loss_classification = focal_loss(output_classification, label, alpha=0.25, gamma=2,reduction='mean')
                train_loss = ALPHA*loss_segmentation + (1 - ALPHA)*loss_classification

                train_loss.backward()
                optimizer.step()
                train_loss_meter.update(train_loss.item(), n)
                seg_train_loss_meter.update(loss_segmentation.item(), n)
                cla_train_loss_meter.update(loss_classification.item(), n)
                if batch_idx % 10 == 0:
                    logger.info(f"Epoch[{epoch}] Iteration[{batch_idx}/{len(train_loader)}] Loss: {train_loss:.3f}")

            end_time = time.time()
            logger.info(f"Training Result: Epoch {epoch}/{MAX_EPOCHS}, Loss: {train_loss_meter.avg:.3f}  Segmentation loss: {seg_train_loss_meter.avg:.3f} Classification loss: {cla_train_loss_meter.avg:.3f} Time epoch: {end_time-start_time:.3f}s")

            #Valid
            model.eval()
            with torch.no_grad():
                for batch_idx, (image, mask, label) in enumerate(val_loader):
                    n = image.shape[0]
                    image = image.to(DEVICE)
                    mask = mask.to(DEVICE)
                    label = label.to(DEVICE)

                    #Forward
                    output_mask, output_classification = model(image)

                    #Cal loss
                    loss_segmentation = dice_loss(output_mask, mask)
                    loss_classification = focal_loss(output_classification, label, alpha=0.25, gamma=2,reduction='mean')
                    val_loss = ALPHA*loss_segmentation + (1 - ALPHA)*loss_classification


                    #Calculate metrics
                    #Segmentation: iou, dice, p, r, f1

                    mask = F.sigmoid(mask).round().long()
                    tp, fp, fn, tn = smp.metrics.get_stats(output_mask, mask, mode='binary', threshold=0.5)


                    seg_iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="macro")

                    seg_dice_score = torch.mean((2*tp.sum(0)/(2*tp.sum(0) + fp.sum(0) + fn.sum(0) + 1e-5)))
                    seg_precision_score = smp.metrics.precision(tp, fp, fn, tn, reduction="macro")
                    seg_recall_score = smp.metrics.recall(tp, fp, fn, tn, reduction="macro")
                    seg_f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="macro")


                    #Classification: acc, p, r, f1
                    label = label.detach().cpu().numpy()
                    output_classification = output_classification.argmax(1).detach().cpu().numpy()

                    cla_acc = accuracy_score(label, output_classification)
                    cla_precision_score = precision_score(label, output_classification, average='macro', zero_division=0)
                    cla_recall_score = recall_score(label, output_classification, average='macro', zero_division=0)
                    cla_f1_score = f1_score(label, output_classification, average='macro')


                    #Update meters
                    val_loss_meter.update(val_loss.item(), n)

                    #Segmentation
                    seg_val_loss_meter.update(loss_segmentation.item(), n)
                    seg_iou_meter.update(seg_iou_score.item(), n)
                    seg_dice_meter.update(seg_dice_score.item(), n)
                    seg_precision_meter.update(seg_precision_score.item(), n)
                    seg_recall_meter.update(seg_recall_score.item(), n)
                    seg_f1_score_meter.update(seg_f1_score.item(), n)

                    #Classification
                    cla_val_loss_meter.update(loss_classification.item(), n)
                    cla_acc_meter.update(cla_acc.item(),n)
                    cla_precision_meter.update(cla_precision_score.item(), n)
                    cla_recall_meter.update(cla_recall_score.item(), n)
                    cla_f1_score_meter.update(cla_f1_score.item(), n)

                    #Common
                    overall_score = ((seg_iou_score + seg_dice_score + seg_f1_score)/3 + cla_f1_score)/2
                    overall_meter.update(overall_score.item(), n)

            logger.info(f"Validation Result: Loss: {val_loss_meter.avg:.3f}, Segmentation loss: {seg_val_loss_meter.avg:.3f} Classification loss: {cla_val_loss_meter.avg:.3f} Overal Score: {overall_meter.avg:.3f}")
            logger.info(f"Classification: Accuracy: {cla_acc_meter.avg:.3f}, F1-Score: {cla_f1_score_meter.avg:.3f}, Precision: {cla_precision_meter.avg:.3f}, Recall: {cla_recall_meter.avg:.3f}")
            logger.info(f"Segmentation: IoU: {seg_iou_meter.avg:.3f} Dice: {seg_dice_meter.avg:.3f}, F1-score: {seg_f1_score_meter.avg:.3f}, Precision: {seg_precision_meter.avg:.3f}, Recall: {seg_recall_meter.avg:.3f}")
            #Save best model
            to_save = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_overall': best_overall
            }
            if overall_meter.avg > best_overall: # best base on IoU score
                logger.info(f"Best model found at epoch {epoch}, saving model")
                torch.save(to_save, os.path.join(weight_dir,f"best_{epoch}_{INPUT_SIZE[0]}_BS={BATCH_SIZE}_overal={overall_meter.avg:.3f}.pth"))
                best_overall = overall_meter.avg
                stale = 0
            else:
                stale += 1
                if stale > 300:
                    logger.info(f"No improvement {300} consecutive epochs, early stopping")
                    break
            if epoch % SAVE_INTERVAL == 0 or epoch == MAX_EPOCHS:
                logger.info(f"Save model at epoch {epoch}, saving model")

                torch.save(to_save, os.path.join(weight_dir,f"epoch_{epoch}.pth"))


In [None]:
train()