In [1]:
import os
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim

from dataset import CarvanaDataset, prepare_datasets
from model import UNET
from train import train_fn
import utils

In [2]:
# --- Hardware / data loading ------------------------------------------------
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'
NUM_WORKERS = 16    # forwarded to utils.get_loaders
PIN_MEMORY = True   # forwarded to utils.get_loaders

# --- Optimisation -----------------------------------------------------------
LR = 1e-4           # learning rate handed to the Adam optimizer
BATCH_SIZE = 32
NUM_EPOCHS = 20
LOAD_MODEL = False  # not read by any cell visible in this notebook

# --- Input geometry ---------------------------------------------------------
# Target size of the Resize step in both transform pipelines.
IMAGE_HEIGHT = 320
IMAGE_WIDTH = 320

# --- Dataset layout ---------------------------------------------------------
# Images and masks live in sibling folders under DATA_DIR.
DATA_DIR = './dataset/'
TRAIN_DIR, TRAIN_MASKDIR, VAL_DIR, VAL_MASKDIR = (
    os.path.join(DATA_DIR, subdir)
    for subdir in ('train', 'train_masks', 'val', 'val_masks')
)

In [3]:
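# One-time dataset preparation step, intentionally left disabled.
# Uncomment to call prepare_datasets on DATA_DIR (presumably it builds the
# train/val folders used below -- confirm in dataset.py before running).
# prepare_datasets(DATA_DIR)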

In [4]:
# Training-time pipeline, applied in the order listed. Note that the random
# crop makes training samples 280x280, while validation samples keep the full
# IMAGE_HEIGHT x IMAGE_WIDTH size.
train_steps = [
    # Bring every sample to a common size first.
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    # Small random shift / scale / rotation, applied to half of the samples.
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
    # Random 280x280 window out of the resized sample.
    A.RandomCrop(height=280, width=280),
    # Per-channel colour jitter, applied to half of the samples.
    A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
    # Mirror augmentations: horizontal often, vertical rarely.
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    # Zero mean / unit std with max_pixel_value=255 rescales 0-255 pixels to 0-1.
    A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
    # Convert the result to a torch tensor.
    ToTensorV2(),
]
train_transform = A.Compose(train_steps)

# Validation pipeline: no random augmentation, only the resize, the same
# normalisation settings as the training pipeline, and tensor conversion.
val_steps = [
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    # Zero mean / unit std with max_pixel_value=255 rescales 0-255 pixels to 0-1.
    A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
    ToTensorV2(),
]
val_transform = A.Compose(val_steps)

In [5]:
# U-Net taking 3-channel input and producing a single output channel; the
# output is treated as logits by the BCEWithLogitsLoss below.
net = UNET(in_channels=3, out_channels=1)

# Replicate the model over the GPUs that actually exist instead of assuming
# cuda:0..cuda:3 are all present (a hard-coded list fails on hosts with fewer
# than four GPUs). The cap of four keeps behaviour identical on the machine
# this notebook was originally run on.
MAX_GPUS = 4
gpu_ids = list(range(min(torch.cuda.device_count(), MAX_GPUS)))
# With no CUDA device the list is empty; pass None so DataParallel falls back
# to its own device discovery rather than receiving an empty list.
net = nn.DataParallel(net, device_ids=gpu_ids or None)
net = net.to(DEVICE)

# Binary segmentation loss operating directly on raw logits.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr=LR)

In [6]:
# Shape sanity check: push one random batch through the network and confirm
# the output keeps the input's batch and spatial dimensions with one channel.
# Run it under no_grad so that `pred`, which stays alive as a notebook global,
# does not keep an autograd graph and its activations in memory during training.
x = torch.randn((32, 3, 160, 240), device=DEVICE)
with torch.no_grad():
    pred = net(x)
print(pred.shape)
print(x.shape)

torch.Size([32, 1, 160, 240])
torch.Size([32, 3, 160, 240])


In [7]:
# Build the training and validation loaders. Arguments are passed positionally
# in the order utils.get_loaders is called with here: image/mask directories
# for each split, batch size, the two transform pipelines, then the worker
# count and pin-memory flag.
train_loader, val_loader = utils.get_loaders(
    TRAIN_DIR, TRAIN_MASKDIR,
    VAL_DIR, VAL_MASKDIR,
    BATCH_SIZE,
    train_transform, val_transform,
    NUM_WORKERS, PIN_MEMORY,
)

In [8]:
# Gradient scaler handed to train_fn. Enable it only when CUDA is present: an
# always-enabled scaler emits a warning and disables itself on CPU-only hosts.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

for epoch in range(NUM_EPOCHS):
    print(f'Epoch [{epoch + 1}]')
    # Training step for this epoch (implemented in train.py).
    train_fn(train_loader, net, optimizer, loss_fn, scaler, device=DEVICE)

    # Optional checkpointing, left disabled. `net` is wrapped in
    # nn.DataParallel, so net.state_dict() keys carry a 'module.' prefix; save
    # net.module.state_dict() instead if the weights must load into a bare UNET.
    # The target directory must exist before saving.
    # torch.save({
    #         'model_state_dict': net.state_dict(),
    #         'optimizer_state_dict': optimizer.state_dict(),
    #         }, './model_weights/model_weights.pth')

    # Report the Dice score on the validation loader after every epoch.
    utils.check_dice(val_loader, net, device=DEVICE)

    # Optional: dump predicted masks for visual inspection (disabled).
    # utils.save_preds_img(val_loader, net)

Epoch [1]


100%|██████████| 128/128 [01:17<00:00,  1.66it/s, loss=0.172]


Validation Dice score: 0.9779
Epoch [2]


100%|██████████| 128/128 [01:18<00:00,  1.64it/s, loss=0.123]


Validation Dice score: 0.9848
Epoch [3]


100%|██████████| 128/128 [01:19<00:00,  1.61it/s, loss=0.0868]


Validation Dice score: 0.9867
Epoch [4]


100%|██████████| 128/128 [01:20<00:00,  1.60it/s, loss=0.0942]


Validation Dice score: 0.9859
Epoch [5]


100%|██████████| 128/128 [01:20<00:00,  1.59it/s, loss=0.0611]


Validation Dice score: 0.9883
Epoch [6]


100%|██████████| 128/128 [01:20<00:00,  1.59it/s, loss=0.053] 


Validation Dice score: 0.9896
Epoch [7]


100%|██████████| 128/128 [01:20<00:00,  1.58it/s, loss=0.0479]


Validation Dice score: 0.9905
Epoch [8]


100%|██████████| 128/128 [01:20<00:00,  1.59it/s, loss=0.0416]


Validation Dice score: 0.9912
Epoch [9]


100%|██████████| 128/128 [01:20<00:00,  1.59it/s, loss=0.0336]


Validation Dice score: 0.9910
Epoch [10]


100%|██████████| 128/128 [01:20<00:00,  1.58it/s, loss=0.0399]


Validation Dice score: 0.9898
Epoch [11]


100%|██████████| 128/128 [01:20<00:00,  1.59it/s, loss=0.0354]


Validation Dice score: 0.9887
Epoch [12]


100%|██████████| 128/128 [01:20<00:00,  1.59it/s, loss=0.0275]


Validation Dice score: 0.9921
Epoch [13]


100%|██████████| 128/128 [01:21<00:00,  1.58it/s, loss=0.024] 


Validation Dice score: 0.9913
Epoch [14]


100%|██████████| 128/128 [01:20<00:00,  1.58it/s, loss=0.0237]


Validation Dice score: 0.9921
Epoch [15]


100%|██████████| 128/128 [01:20<00:00,  1.58it/s, loss=0.0368]


Validation Dice score: 0.9926
Epoch [16]


100%|██████████| 128/128 [01:20<00:00,  1.59it/s, loss=0.0191]


Validation Dice score: 0.9923
Epoch [17]


100%|██████████| 128/128 [01:20<00:00,  1.59it/s, loss=0.0179]


Validation Dice score: 0.9931
Epoch [18]


100%|██████████| 128/128 [01:20<00:00,  1.59it/s, loss=0.0242]


Validation Dice score: 0.8881
Epoch [19]


100%|██████████| 128/128 [01:20<00:00,  1.58it/s, loss=0.0155]


Validation Dice score: 0.9926
Epoch [20]


100%|██████████| 128/128 [01:21<00:00,  1.57it/s, loss=0.0133]


Validation Dice score: 0.9928
