In [1]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from Deep_Learning.UNets.image_segmentation.model import UNet
from Deep_Learning.UNets.image_segmentation.dataset import CarvanaDatasetLoaded, CarvanaDataset
from Deep_Learning.UNets.image_segmentation.train import train
from Deep_Learning.UNets.image_segmentation.utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs,
)

In [6]:
# Hyperparameters and run configuration: everything a reader might want to tune.
LEARNING_RATE = 3e-4
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64
NUM_EPOCHS = 10
NUM_WORKERS = 2
IMAGE_HEIGHT = 160  # 1280 originally
IMAGE_WIDTH = 240  # 1918 originally
# Pinned (page-locked) host memory only speeds up host-to-GPU copies; asking
# for it on a CPU-only machine just makes DataLoader emit a warning, so it is
# tied to the selected device instead of being hard-coded to True.
PIN_MEMORY = DEVICE == "cuda"
# When True, the setup cell restores weights from a saved checkpoint.
LOAD_MODEL = False
# Passed to get_loaders; presumably the fraction of samples held out for
# validation (TODO confirm against get_loaders).
VAL_RATIO = 0.01
IMG_DIR = "../Datasets/carvana/train/"
MASK_DIR = "../Datasets/carvana/train_masks/"

In [7]:
def _resize_step():
    """Return a new resize step targeting IMAGE_HEIGHT x IMAGE_WIDTH."""
    return A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH)


def _normalize_step():
    """Return a new Normalize step with zero mean, unit std and max pixel 255.

    NOTE(review): with these arguments this looks like a plain divide-by-255
    rescale; confirm against the albumentations Normalize documentation.
    """
    return A.Normalize(
        mean=[0.0, 0.0, 0.0],
        std=[1.0, 1.0, 1.0],
        max_pixel_value=255.0,
    )


# Training pipeline: resize, then random geometric and photometric
# augmentations, then normalization and conversion to a tensor.
train_transform = A.Compose(
    [
        _resize_step(),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.ColorJitter(),
        A.GaussNoise(),
        _normalize_step(),
        ToTensorV2(),
    ]
)

# Validation pipeline: the same resize / normalize / tensor conversion, with
# no random augmentation in between.
val_transform = A.Compose(
    [
        _resize_step(),
        _normalize_step(),
        ToTensorV2(),
    ]
)

In [8]:
# Build the network, loss and optimizer. The model emits a single output
# channel and is trained with BCEWithLogitsLoss, which takes raw logits.
model = UNet(in_channels=3, out_channels=1).to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Train/validation loaders come from the project helper; arguments are passed
# positionally, so their order must match get_loaders' signature.
train_loader, val_loader = get_loaders(
    IMG_DIR,
    MASK_DIR,
    VAL_RATIO,
    BATCH_SIZE,
    train_transform,
    val_transform,
    NUM_WORKERS,
    PIN_MEMORY,
)

if LOAD_MODEL:
    # map_location remaps the stored tensors onto the device selected for this
    # run; without it, a checkpoint saved on a GPU cannot be deserialized on a
    # CPU-only machine.
    load_checkpoint(
        torch.load("my_checkpoint.pth.tar", map_location=DEVICE),
        model,
    )

# Gradient scaling is only meaningful for CUDA mixed precision. Disabling it
# explicitly on CPU turns the scaler into a pass-through and avoids the
# "CUDA is not available" warning.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

In [9]:
# Launch training via the project's train() helper. Arguments are positional,
# so their order must stay exactly as listed here.
train(
    NUM_EPOCHS,
    train_loader,
    val_loader,
    model,
    optimizer,
    loss_fn,
    scaler,
    DEVICE,
)

Epoch [0/10]:  23%|██▎       | 18/79 [00:13<00:46,  1.30it/s, loss=0.32] 


KeyboardInterrupt: 