In [1]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from pclib.nn.models import UNet
from pclib.nn.layers import ConvTranspose2d
from examples.carvana.dataset import CarvanaDatasetLoaded, CarvanaDataset
from examples.carvana.train import train
from examples.carvana.utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs,
)

In [2]:
# Hyperparameters and run configuration: every value a reader might tune lives here.

# --- Optimisation ---
LEARNING_RATE = 3e-4
BATCH_SIZE = 12
NUM_EPOCHS = 10

# --- Hardware / data loading ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 2
# Page-locked host memory only helps host-to-GPU copies, so request it only when
# we are actually running on CUDA (avoids a useless setting on CPU-only hosts).
PIN_MEMORY = DEVICE == "cuda"

# --- Input resolution (images are downscaled before being fed to the model) ---
IMAGE_HEIGHT = 160  # 1280 originally
IMAGE_WIDTH = 256  # 1918 originally

# --- Checkpointing ---
LOAD_MODEL = False

# --- Dataset ---
VAL_RATIO = 0.01  # presumably the fraction of samples held out for validation; confirm in get_loaders
IMG_DIR = "../Datasets/carvana/train/"
MASK_DIR = "../Datasets/carvana/train_masks/"

In [3]:
def _resize_step():
    """Return a fresh resize transform targeting the configured working resolution."""
    return A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH)


def _tensor_tail():
    """Return the closing steps shared by both pipelines: normalise, then convert to a tensor.

    New transform instances are created on every call so the train and
    validation pipelines never share objects.
    """
    normalise = A.Normalize(
        mean=[0.0, 0.0, 0.0],
        std=[1.0, 1.0, 1.0],
        max_pixel_value=255.0,
    )
    return [normalise, ToTensorV2()]


# Training pipeline: resize, geometric + photometric augmentation, then the shared tail.
train_transform = A.Compose([
    _resize_step(),
    A.Rotate(limit=35, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    A.ColorJitter(),
    A.GaussNoise(),
    *_tensor_tail(),
])

# Validation pipeline: no augmentation, only resize and the shared tail.
val_transform = A.Compose([
    _resize_step(),
    *_tensor_tail(),
])

In [4]:
# Build the network on the selected device.
model = UNet(steps=60, device=DEVICE)

# Create the train / validation loaders from the configured directories,
# split ratio, batch size, transforms and DataLoader options.
train_loader, val_loader = get_loaders(
    IMG_DIR,
    MASK_DIR,
    VAL_RATIO,
    BATCH_SIZE,
    train_transform,
    val_transform,
    NUM_WORKERS,
    PIN_MEMORY,
)

# Training bookkeeping handed to train(); starts from scratch.
step = 0
stats = None

if LOAD_MODEL:
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine still loads on a CPU-only one.
    load_checkpoint(
        torch.load("my_checkpoint.pth.tar", map_location=DEVICE),
        model,
    )


In [7]:
# Smoke test: push a single training batch through the model and inspect the result.
# next(iter(...)) fetches exactly one batch and fails right here if the loader is
# empty, instead of leaving `out` undefined for the final line.
images, masks = next(iter(train_loader))
images = images.to(DEVICE)
# Insert a singleton axis at position 1 before moving the masks to the device
# (presumably the channel axis, matching the model output layout; confirm).
masks = masks.unsqueeze(1).to(DEVICE)
out, state = model(images, y=masks)

# Displayed as the cell output: the output shape and the model's vfe() value for the state.
out.shape, model.vfe(state)

(torch.Size([12, 1, 160, 256]),
 tensor(3093.7871, device='cuda:0', grad_fn=<MeanBackward0>))

In [6]:
# Keyword options for the training run, collected in one place for readability.
train_options = dict(
    lr=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    reg_coeff=0.01,  # NOTE(review): magic number; presumably a regularisation weight, confirm in train()
    step=step,
    stats=stats,
    device=DEVICE,
    optim='AdamW',
    save_every=None,
)

# Run training for NUM_EPOCHS over the train loader, with the validation loader passed alongside.
train(model, train_loader, val_loader, NUM_EPOCHS, **train_options)

  5%|▍         | 19/420 [02:17<48:27,  7.25s/it]


KeyboardInterrupt: 