In [5]:
import torch
import numpy as np
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from model import UNET
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs,
)

# Hyperparameters etc.
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 10
NUM_WORKERS = 2
IMAGE_HEIGHT = 384  # 1280 originally
IMAGE_WIDTH = 512  # 1918 originally
PIN_MEMORY = True
LOAD_MODEL = False


ModuleNotFoundError: No module named 'torchvision'

In [None]:
from sklearn.model_selection import train_test_split
import shutil

if os.path.exists('saved_images/') == False:
  os.mkdir("saved_images/")

if os.path.exists('/content/data/train_images') == False: 
  
  images = np.array(os.listdir('/content/data/images'))
  masks = np.array(os.listdir('/content/data/masks'))
  train_images, val_images = train_test_split(images, test_size=0.3, random_state=25)
  train_masks, val_masks = train_test_split(masks, test_size=0.3, random_state=25)
  train_images_dir = np.char.add(np.array (train_images.shape[0]*['/content/data/images/']),train_images)
  val_images_dir = np.char.add(np.array (val_images.shape[0]*['/content/data/images/']),val_images)
  if os.path.exists("data/train_images/") == False:
    os.mkdir("data/train_images/")
  if os.path.exists('data/train_masks/') == False:
    os.mkdir("data/train_masks/")
  if os.path.exists('data/val_images/') == False:
    os.mkdir("data/val_images/")
  if os.path.exists('data/val_masks/') == False:
    os.mkdir("data/val_masks/")



  [shutil.copy(file, "data/train_images/") for file in train_images_dir]
  [shutil.copy(file, "data/train_masks/") for file in train_images_dir]

  [shutil.copy(file, "data/val_images/") for file in val_images_dir]
  [shutil.copy(file, "data/val_masks/") for file in val_images_dir]

  if os.path.exists('data/images/') == True:
    shutil.rmtree('data/images/')

  if os.path.exists('data/masks/') == True:
    shutil.rmtree('data/masks/')
TRAIN_IMG_DIR = "data/train_images/"
TRAIN_MASK_DIR = "data/train_masks/"
VAL_IMG_DIR = "data/val_images/"
VAL_MASK_DIR = "data/val_masks/"



In [None]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())


def main():
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    if LOAD_MODEL:
        load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)


    check_accuracy(val_loader, model, device=DEVICE)
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # save model
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer":optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        # check accuracy
        check_accuracy(val_loader, model, device=DEVICE)

        # print some examples to a folder
        if epoch % 5 == 0:
          save_predictions_as_imgs(
              val_loader, model, folder="saved_images", device=DEVICE
          )


if __name__ == "__main__":
    main()

Got 2939207/70385664 with acc 4.18
Dice score: 1.9712238311767578


100%|██████████| 53/53 [01:00<00:00,  1.15s/it, loss=-173]


=> Saving checkpoint
Got 2938156/70385664 with acc 4.17
Dice score: 1.714235782623291


100%|██████████| 53/53 [01:03<00:00,  1.19s/it, loss=-190]


=> Saving checkpoint
Got 493834/70385664 with acc 0.70
Dice score: 1.8519206047058105


100%|██████████| 53/53 [01:03<00:00,  1.20s/it, loss=-225]


=> Saving checkpoint
Got 2936641/70385664 with acc 4.17
Dice score: 1.9633547067642212


100%|██████████| 53/53 [01:03<00:00,  1.20s/it, loss=-246]


=> Saving checkpoint
Got 609777/70385664 with acc 0.87
Dice score: 1.9586812257766724


100%|██████████| 53/53 [01:03<00:00,  1.20s/it, loss=-282]


=> Saving checkpoint
Got 2938769/70385664 with acc 4.18
Dice score: 1.9561361074447632


100%|██████████| 53/53 [01:03<00:00,  1.20s/it, loss=-265]


=> Saving checkpoint
Got 2270868/70385664 with acc 3.23
Dice score: 1.9585925340652466


100%|██████████| 53/53 [01:02<00:00,  1.18s/it, loss=-282]


=> Saving checkpoint
Got 1905391/70385664 with acc 2.71
Dice score: 1.9710873365402222


100%|██████████| 53/53 [01:02<00:00,  1.18s/it, loss=-326]


=> Saving checkpoint
Got 2939207/70385664 with acc 4.18
Dice score: 1.9712238311767578


100%|██████████| 53/53 [01:03<00:00,  1.19s/it, loss=-322]


=> Saving checkpoint
Got 2109569/70385664 with acc 3.00
Dice score: 1.9712047576904297


100%|██████████| 53/53 [01:03<00:00,  1.19s/it, loss=-367]


=> Saving checkpoint
Got 2939207/70385664 with acc 4.18
Dice score: 1.9712238311767578
