In [1]:
import os
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim

from dataset import CarvanaDataset, prepare_datasets
from model import UNET
from train import train_fn
import utils

In [2]:
# --- Runtime / hardware ---
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Pinned host memory only helps host->GPU copies; it is useless on CPU-only
# machines (and recent PyTorch warns about it), so tie it to CUDA availability.
PIN_MEMORY = torch.cuda.is_available()
# Never ask for more DataLoader workers than there are cores; os.cpu_count() may be None.
NUM_WORKERS = min(16, os.cpu_count() or 1)
# Seed torch's RNG (weight init, random tensors, DataLoader shuffling / worker seeds)
# so a Restart & Run All reproduces the same run.
SEED = 42
torch.manual_seed(SEED)

# --- Training hyper-parameters ---
LR = 1e-4
BATCH_SIZE = 32
NUM_EPOCHS = 2
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240
LOAD_MODEL = False

# --- Data layout: image folders and their matching mask folders ---
DATA_DIR = './dataset/'
TRAIN_DIR = os.path.join(DATA_DIR, 'train')
TRAIN_MASKDIR = os.path.join(DATA_DIR, 'train_masks')
VAL_DIR = os.path.join(DATA_DIR, 'val')
VAL_MASKDIR = os.path.join(DATA_DIR, 'val_masks')

In [3]:
# One-time data preparation. prepare_datasets(DATA_DIR) is presumably what creates the
# val/ and val_masks/ directories used below (TODO confirm in dataset.py).
# Set to True on a fresh checkout, run once, then set back to False so re-runs
# don't repeat the preparation.
RUN_PREPARE_DATASETS = False
if RUN_PREPARE_DATASETS:
    prepare_datasets(DATA_DIR)

In [4]:
def build_transform(augment: bool) -> A.Compose:
    """Build the albumentations pipeline for one data split.

    Train and val share one definition of resize + normalization so their
    preprocessing cannot drift apart; only training adds random augmentation.

    Args:
        augment: if True, insert the training-only random rotation and flips.

    Returns:
        An ``A.Compose`` that resizes to (IMAGE_HEIGHT, IMAGE_WIDTH), optionally
        augments, divides pixel values by 255 and converts to a torch tensor.
    """
    steps = [A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH)]
    if augment:
        steps += [
            # p=1.0: every training sample gets a random rotation within +/-35 degrees.
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
        ]
    steps += [
        # With mean=0, std=1, max_pixel_value=255, Normalize reduces to img / 255
        # (i.e. [0, 1] for uint8 input).
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose(steps)


train_transform = build_transform(augment=True)
val_transform = build_transform(augment=False)

In [5]:
# Single output channel + BCEWithLogitsLoss: the net emits raw logits for a binary mask.
net = UNET(in_channels=3, out_channels=1).to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr=LR)

# Honour the LOAD_MODEL config flag, which nothing previously read. The checkpoint uses
# the same 'model_state_dict' / 'optimizer_state_dict' keys as the torch.save call in
# the training cell.
CHECKPOINT_PATH = './model_weights/model_weights.pth'
if LOAD_MODEL:
    # torch.load unpickles the file: only load checkpoints you produced yourself.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    net.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [6]:
# Smoke test: a dummy batch at the configured resolution must come back with the same
# spatial size and a single logit channel.
dummy_batch = torch.randn((BATCH_SIZE, 3, IMAGE_HEIGHT, IMAGE_WIDTH), device=DEVICE)
was_training = net.training
# eval(): keep random noise from updating normalization running stats
# (relevant if UNET contains BatchNorm — TODO confirm in model.py).
net.eval()
with torch.no_grad():  # shape check only; no autograd graph needed
    dummy_pred = net(dummy_batch)
net.train(was_training)
print(dummy_pred.shape)
print(dummy_batch.shape)
assert dummy_pred.shape == (BATCH_SIZE, 1, IMAGE_HEIGHT, IMAGE_WIDTH), dummy_pred.shape

torch.Size([32, 1, 160, 240])
torch.Size([32, 3, 160, 240])


In [7]:
# Fail fast with an actionable message instead of a deep DataLoader/dataset traceback
# when a split directory is missing (e.g. the dataset preparation step was skipped).
missing_dirs = [d for d in (TRAIN_DIR, TRAIN_MASKDIR, VAL_DIR, VAL_MASKDIR)
                if not os.path.isdir(d)]
if missing_dirs:
    raise FileNotFoundError(
        f"Missing data directories {missing_dirs}; "
        "has the dataset been prepared (prepare_datasets(DATA_DIR))?"
    )

# utils.get_loaders takes its arguments positionally in this order.
train_loader, val_loader = utils.get_loaders(
    TRAIN_DIR, TRAIN_MASKDIR,        # training images / masks
    VAL_DIR, VAL_MASKDIR,            # validation images / masks
    BATCH_SIZE,
    train_transform, val_transform,  # per-split albumentations pipelines
    NUM_WORKERS, PIN_MEMORY,
)

In [8]:
SAVE_CHECKPOINT = False   # set True to write model + optimizer state after every epoch
SAVE_PRED_IMAGES = False  # set True to dump validation predictions via utils.save_preds_img
CHECKPOINT_PATH = './model_weights/model_weights.pth'

# Loss scaling only applies on CUDA; enabling it explicitly gives a silent no-op scaler
# on CPU instead of a "CUDA is not available" warning.
scaler = torch.cuda.amp.GradScaler(enabled=DEVICE.startswith('cuda'))
for epoch in range(NUM_EPOCHS):
    train_fn(train_loader, net, optimizer, loss_fn, scaler, device=DEVICE)

    if SAVE_CHECKPOINT:
        # Save `net` (the model built above); the old commented-out version referenced
        # an undefined `model` and would have raised NameError.
        os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
        torch.save({
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, CHECKPOINT_PATH)

    utils.check_dice(val_loader, net, device=DEVICE)

    if SAVE_PRED_IMAGES:
        utils.save_preds_img(val_loader, net)

100%|██████████| 128/128 [01:57<00:00,  1.09it/s, loss=0.134]


Validation Dice score: 0.9762299656867981


100%|██████████| 128/128 [02:06<00:00,  1.01it/s, loss=0.105]


Validation Dice score: 0.9790133237838745
