In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
from torch_unet.dataset import BasicDataset
from torch_unet.unet import UNet
import numpy as np

from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, random_split
import torch
import logging

import torch.nn as nn
from torch import optim
from tqdm import tqdm
from torch_unet.evaluation import eval_net
from torchsummary import summary
from torch_unet.losses import DiceCoeff
from torch_unet.unet.components import Down

logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

In [2]:
DATADIR = "../Datasets/training/"
IMAGE_DIR = DATADIR + "images/"
MASK_DIR = DATADIR + "groundtruth/"
MASK_THRESHOLD = 0.25

val_percent = 0.2
batch_size=1
lr = 0.001
img_scale = 1
epochs = 5

In [3]:
dataset = BasicDataset(IMAGE_DIR, MASK_DIR, mask_treshold=MASK_THRESHOLD)

INFO: Creating dataset with 100 examples


In [4]:
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
net = UNet(n_channels=3, n_classes=1, bilinear=False)
net.to(device=device);

INFO: Using device cpu


In [7]:
writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
global_step = 0

In [8]:
optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8)
criterion = nn.BCEWithLogitsLoss()

In [None]:
for epoch in range(epochs):
    net.train()   # Sets module in training mode
    epoch_loss = 0
    with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
        for batch in train_loader:
            imgs = batch['image']
            true_masks = batch['mask']

            imgs = imgs.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=torch.float32)
            
            masks_pred = net(imgs)  # Make predictions
            loss = criterion(masks_pred, true_masks)  # Evaluate loss
            
            epoch_loss += loss.item()    # Add loss to epoch
            writer.add_scalar('Loss/train', loss.item(), global_step)  

            pbar.set_postfix(**{'loss (batch)': loss.item()})

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.update(imgs.shape[0])
            global_step += 1
            if global_step % (len(dataset) // (10 * batch_size)) == 0:
                val_score = eval_net(net, val_loader, device, n_val)
                logging.info('Validation Dice Coeff: {}'.format(val_score))
                writer.add_scalar('Dice/test', val_score, global_step)

                writer.add_images('images', imgs, global_step)
                if net.n_classes == 1:
                    writer.add_images('masks/true', true_masks, global_step)
                    writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
    if save_cp:
        try:
            os.mkdir(dir_checkpoint)
            logging.info('Created checkpoint directory')
        except OSError:
            pass
        torch.save(net.state_dict(),
                   dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
        logging.info(f'Checkpoint {epoch + 1} saved !')

writer.close()


Epoch 1/5:  12%|█▎        | 10/80 [01:01<07:04,  6.07s/img, loss (batch)=0.562]
Validation round:   0%|          | 0/20 [00:00<?, ?img/s][A
Validation round:   5%|▌         | 1/20 [00:02<00:40,  2.12s/img][A
Validation round:  10%|█         | 2/20 [00:03<00:35,  2.00s/img][A
Validation round:  15%|█▌        | 3/20 [00:05<00:31,  1.87s/img][A
Validation round:  20%|██        | 4/20 [00:07<00:28,  1.79s/img][A
Validation round:  25%|██▌       | 5/20 [00:08<00:26,  1.76s/img][A
Validation round:  30%|███       | 6/20 [00:10<00:24,  1.73s/img][A
Validation round:  35%|███▌      | 7/20 [00:12<00:23,  1.78s/img][A
Validation round:  40%|████      | 8/20 [00:13<00:20,  1.74s/img][A
Validation round:  45%|████▌     | 9/20 [00:15<00:18,  1.71s/img][A
Validation round:  50%|█████     | 10/20 [00:17<00:16,  1.69s/img][A
Validation round:  55%|█████▌    | 11/20 [00:18<00:15,  1.68s/img][A
Validation round:  60%|██████    | 12/20 [00:20<00:14,  1.78s/img][A
Validation round:  65%|█████

In [10]:
np.max(masks_pred)

torch.Size([4, 1, 400, 400])

In [None]:
import gc
gc.collect()