## train.py

In [1]:
import argparse
import logging
import os
import os.path as osp
import sys

import numpy as np
import torch
import torch.nn as nn
import torchvision as tv
from torch import optim
from torch.cuda import amp
from torch.nn.modules import activation
from torch.nn.modules.activation import Threshold
from tqdm import tqdm

from eval import eval_net

from torch.utils.data import DataLoader, random_split
from torch.utils.data.distributed import DistributedSampler
import utils
import models
from utils.dataset import BasicDataset

  warn(f"Failed to load image Python extension: {e}")


In [2]:
logger = logging.getLogger(__name__)

# dir_img = osp.join("..", "unet_dataset", "images", "trainval")
# dir_mask = osp.join("..", "unet_dataset", "labels", "trainval")

dir_img = osp.join(".", "dataset0", "train")
dir_mask = osp.join(".", "dataset0", "train_GT")

In [3]:
def is_parallel(model):
    return type(model) in (nn.parallel.DataParallel,
                           nn.parallel.DistributedDataParallel)

In [4]:
def get_args():
    parser = argparse.ArgumentParser(
        description='Train the UNet on images and target masks',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e',
                        '--epochs',
                        metavar='E',
                        type=int,
                        default=5,
                        help='Number of epochs',
                        dest='epochs')
    parser.add_argument('-b',
                        '--batch_size',
                        metavar='B',
                        type=int,
                        nargs='?',
                        default=1,
                        help='Batch size',
                        dest='batchsize')
    parser.add_argument('-l',
                        '--learning_rate',
                        metavar='LR',
                        type=float,
                        nargs='?',
                        default=0.0001,
                        help='Learning rate',
                        dest='lr')
    parser.add_argument('-f',
                        '--load',
                        dest='load',
                        default=False,
                        action='store_true',
                        help='Load model from a .pth file')
    parser.add_argument('-s',
                        '--scale',
                        dest='scale',
                        type=float,
                        default=0.5,
                        help='Downscaling factor of the images')
    parser.add_argument('-v',
                        '--validation',
                        dest='val',
                        type=float,
                        default=0.1,
                        help='Percent of the data \
                              that is used as validation (0-100)')
    parser.add_argument('-d',
                        '--device',
                        default='cpu',
                        help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--local_rank',
                        type=int,
                        default=-1,
                        help='DDP parameter, do not modify')
    parser.add_argument('--model_type',
                        type=str,
                        default='utrans',
                        help="Model which choosed.")
    parser.add_argument('--split_seed', type=int, default=None, help='')
    return parser.parse_args(args=[])


In [5]:

def select_device(device='', batch_size=None):
    # device = 'cpu' or '0' or '0,1,2,3'
    s = f'UNetHX torch {torch.__version__} '
    cpu = device.lower() == 'cpu'
    if cpu:
        os.environ[
            'CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(
        ), f'CUDA unavailable, invalid device {device} requested'  # check availability

    cuda = not cpu and torch.cuda.is_available()
    if cuda:
        n = torch.cuda.device_count()
        if n > 1 and batch_size:  # check that batch_size is compatible with device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * len(s)
        for i, d in enumerate(device.split(',') if device else range(n)):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
    else:
        s += 'CPU\n'

    logger.info(s)  # skip a line
    return torch.device('cuda:0' if cuda else 'cpu')

In [6]:
def train_net(model,
              device,
              epochs=5,
              batch_size=1,
              lr=0.001,
              val_percent=0.1,
              save_all_cp=True,
              dir_checkpoint='runs',
              split_seed=None):
    transform_valid = tv.transforms.Compose([tv.transforms.ToTensor(),
                                             tv.transforms.RandomCrop((400, 400))
                                             ])
    dataset = BasicDataset(dir_img, dir_mask, transform=transform_valid)
    n_val = int(len(dataset) *
                val_percent) if val_percent < 1 else int(val_percent)
    n_train = len(dataset) - n_val
    if split_seed:
        train, val = random_split(
            dataset, [n_train, n_val],
            generator=torch.Generator().manual_seed(split_seed))
    else:
        train, val = random_split(dataset, [n_train, n_val])
    if type(model) == nn.parallel.DistributedDataParallel:
        train_loader = DataLoader(train,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=0,
                                  pin_memory=True,
                                  sampler=DistributedSampler(train))
        val_loader = DataLoader(val,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=0,
                                pin_memory=True,
                                drop_last=True,
                                sampler=DistributedSampler(val))
    else:
        train_loader = DataLoader(train,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=0,
                                  pin_memory=True)
        val_loader = DataLoader(val,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=0,
                                pin_memory=True,
                                drop_last=True)

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_all_cp}
        Device:          {device.type}
    ''')

    # loss = nn.BCEWithLogitsLoss()
    # loss.__name__ = 'BCEWithLogitLoss'
    # loss = nn.BCELoss()
    # loss.__name__ = 'BCELoss'
    loss = utils.losses.NoiseRobustDiceLoss(eps=1e-7, activation='sigmoid')
    metrics = [
        utils.metrics.Dice(threshold=0.5, activation='sigmoid'),
        utils.metrics.Fscore(threshold=None, activation='sigmoid')
    ]
    optimizer = torch.optim.Adam([
        dict(params=model.parameters(), lr=lr),
    ])

    train_epoch = utils.train.TrainEpoch(
        model,
        loss=loss,
        metrics=metrics,
        optimizer=optimizer,
        device=device,
        verbose=True,
    )
    valid_epoch = utils.train.ValidEpoch(
        model,
        loss=loss,
        metrics=metrics,
        device=device,
        verbose=True,
    )

    max_score = 0
    os.makedirs(dir_checkpoint, exist_ok=True)
    for i in range(0, epochs):
        print('\nEpoch: {}'.format(i + 1))
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(val_loader)

        # do something (save model, change lr, etc.)
        if max_score < valid_logs['dice_score']:
            max_score = valid_logs['dice_score']
            torch.save(model, osp.join(dir_checkpoint, 'best_model.pt'))
            torch.save(model.state_dict(),
                       osp.join(dir_checkpoint, 'best_model_dict.pth'))
            print('Model saved!')

        if save_all_cp:
            torch.save(model.state_dict(),
                       osp.join(dir_checkpoint, f'CP_epoch{i + 1}.pth'))

## if __name__ == '__main__':


In [7]:
logging.basicConfig(level=logging.INFO,
                    format='%(levelname)s: %(message)s')


In [8]:
args = get_args()
args

Namespace(batchsize=1, device='cpu', epochs=5, load=False, local_rank=-1, lr=0.0001, model_type='utrans', scale=0.5, split_seed=None, val=0.1)

In [9]:
device = select_device(args.device, batch_size=args.batchsize)
device

INFO: UNetHX torch 1.12.0 CPU



device(type='cpu')

In [10]:
logging.info(f'Using device {device}')

INFO: Using device cpu


In [11]:
import socket
from datetime import datetime


In [12]:
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
current_time

'Aug16_01-30-05'

In [13]:
comment = f'MT_{args.model_type}_SS_{args.split_seed}_LR_{args.lr}_BS_{args.batchsize}'
comment


'MT_utrans_SS_None_LR_0.0001_BS_1'

In [14]:
dir_checkpoint = osp.join(
    ".", "checkpoints",
    f"{current_time}_{socket.gethostname()}_" + comment)
dir_checkpoint

'.\\checkpoints\\Aug16_01-30-05_DESKTOP-V4LBDUN_MT_utrans_SS_None_LR_0.0001_BS_1'

In [15]:
nets = {
    # "unet": models.UNet,
    # "inunet": InUNet,
    # "attunet": AttU_Net,
    # "inattunet": InAttU_Net,
    # "att2uneta": Att2U_NetA,
    # "att2unetb": Att2U_NetB,
    # "att2unetc": Att2U_NetC,
    # "ecaunet": ECAU_Net,
    # "gsaunet": GsAUNet,
    # "utnet": U_Transformer,
    #"ddrnet": models.DualResNet,
    "utrans": models.U_Transformer
}


In [16]:
net_type = nets[args.model_type.lower()]
net_type

models.utransformer.U_Transformer.U_Transformer

In [17]:
net = net_type(in_channels=3, classes=1)

In [18]:
net.to(device=device)

U_Transformer(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-0

In [19]:
cuda = device.type != 'cpu'
cuda

False

In [20]:
net = net.to(device=device)
net = net.module if is_parallel(net) else net
net = net.to(device=device)

In [21]:
if args.load:
    net.load_state_dict(torch.load(args.load, map_location=device))
    logging.info(f'Model loaded from {args.load}')

In [22]:
train_net(model=net,
          epochs=args.epochs,
          batch_size=args.batchsize,
          lr=args.lr,
          device=device,
          val_percent=args.val,
          dir_checkpoint=dir_checkpoint,
          split_seed=args.split_seed,
          save_all_cp=True)

INFO: Creating dataset with 2074 examples
INFO: Starting training:
        Epochs:          5
        Batch size:      1
        Learning rate:   0.0001
        Training size:   1867
        Validation size: 207
        Checkpoints:     True
        Device:          cpu
    



Epoch: 1
train:   0%|          | 2/1867 [25:25<394:49:37, 762.13s/it, noise_robust_dice_loss - 0.75, dice_score - 0.4646, f_score - 0.3646]  

## train_net

In [137]:
model=net
epochs=args.epochs
batch_size=args.batchsize
lr=args.lr
device=device
val_percent=args.val
dir_checkpoint=dir_checkpoint
split_seed=args.split_seed
save_all_cp=True

In [138]:
transform_valid = tv.transforms.Compose([tv.transforms.ToTensor(),
                                         tv.transforms.RandomCrop((400, 400))
                                         ])

In [139]:
dataset = BasicDataset(dir_img, dir_mask, transform=transform_valid)
dataset

INFO: Creating dataset with 2074 examples


<utils.dataset.BasicDataset at 0x2223bc43f48>

In [140]:
n_val = int(len(dataset) *
            val_percent) if val_percent < 1 else int(val_percent)
n_train = len(dataset) - n_val

In [141]:
if split_seed:
    train, val = random_split(
        dataset, [n_train, n_val],
        generator=torch.Generator().manual_seed(split_seed))
else:
    train, val = random_split(dataset, [n_train, n_val])
print(type(train))

<class 'torch.utils.data.dataset.Subset'>


In [143]:
if type(model) == nn.parallel.DistributedDataParallel:
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=0,
                              pin_memory=True,
                              sampler=DistributedSampler(train))
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0,
                            pin_memory=True,
                            drop_last=True,
                            sampler=DistributedSampler(val))
else:
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=1,
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=1,
                            pin_memory=True,
                            drop_last=True)

In [144]:
for i, (train, GT) in enumerate(train_loader):
    print(train.size)
    break

<built-in method size of Tensor object at 0x000002223BC3FE08>


## DataSet



In [211]:
import os
import os.path as osp
import numpy as np
from glob import glob
import torch
from torch.utils.data import Dataset
import logging
from PIL import Image

In [212]:
imgs_dir = dir_img
masks_dir = dir_mask
mask_suffix = '_segmentation'
transform = transform_valid
ids = [
            osp.splitext(file)[0] for file in os.listdir(imgs_dir)
            if not file.startswith('.')
        ]
logging.info(f'Creating dataset with {len(ids)} examples')

INFO: Creating dataset with 2074 examples


In [226]:
def preprocess(img_nd):
    if len(img_nd.shape) == 2:
        img_nd = np.expand_dims(img_nd, axis=2)
        # img_nd = np.repeat(img_nd, 3, 2) # make 1 channel pic to 3 channels pic

    # HWC to CHW
    # img_trans = img_nd.transpose((2, 0, 1))
    img_trans = img_nd

    if 255 >= img_trans.max() > 1 and img_trans.min() > 0:
        # Normally UINT8 pic
        img_trans = img_trans / 255.0
    elif 0 < img_trans.all() <= 1:
        # Normally FLOAT pic
        pass
    else:
        # DICOM pic
        pass

    return img_trans

In [214]:
idx = ids[0]
mask_file = glob(
    osp.join(masks_dir, idx + mask_suffix + '.*'))
img_file = glob(osp.join(imgs_dir, idx + '.*'))

In [251]:
mask = np.array(Image.open(mask_file[0]))
img = np.array(Image.open(img_file[0]))

In [252]:
print(img.shape)

(767, 1022, 3)


In [253]:
seed = torch.random.seed()

if transform:
    torch.random.manual_seed(seed)
    img = transform(img)
    torch.random.manual_seed(seed)
    mask = transform(mask)

In [254]:
img.shape

torch.Size([3, 400, 400])

In [224]:
trans = tv.transforms.ToPILImage()

In [55]:
image = trans(img)
label = trans(mask)
image.show()
label.show()

In [228]:
img = preprocess(img)
mask = preprocess(mask)