In [2]:
import os
import time
from glob import glob

import torch
from torch.utils.data import DataLoader
import torch.nn as nn

from data import DriveDataset
from model import build_unet
from loss import DiceLoss, DiceBCELoss
from utils import seeding, create_dir, epoch_time

In [3]:
def load_data(data_path="../new_data"):
    """Collect sorted image/mask file paths for the training and validation splits.

    The validation split is read from ``<data_path>/test`` (there is no separate
    validation directory in this layout).

    Args:
        data_path: Root directory holding ``train/`` and ``test/`` sub-folders,
            each containing ``images/`` and ``masks/``. Defaults to
            ``"../new_data"``, which reproduces the original hard-coded paths.

    Returns:
        Tuple ``(train_x, train_y, valid_x, valid_y)`` of sorted path lists.

    Raises:
        ValueError: if a split has a different number of images and masks.
            The lists are handed to the dataset as two parallel lists, so a
            count mismatch would (presumably) pair images with the wrong masks.
    """
    def _split(name):
        # Sorting keeps image i aligned with mask i, assuming matching filenames.
        images = sorted(glob(f"{data_path}/{name}/images/*"))
        masks = sorted(glob(f"{data_path}/{name}/masks/*"))
        if len(images) != len(masks):
            raise ValueError(
                f"{name} split has {len(images)} images but {len(masks)} masks "
                f"under {data_path!r}"
            )
        return images, masks

    train_x, train_y = _split("train")
    valid_x, valid_y = _split("test")

    return train_x, train_y, valid_x, valid_y

In [4]:
def train(model,loader,optimizer,loss_fn,device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: torch module being optimised; switched to train mode here.
        loader: iterable of ``(x, y)`` batches (e.g. a ``DataLoader``). It does
            not need ``__len__`` because batches are counted while iterating.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        loss_fn: callable ``loss_fn(y_pred, y)`` returning a scalar tensor.
        device: device each batch is moved to (as float32) before the forward pass.

    Returns:
        Mean of the per-batch loss values. Every batch is weighted equally, so a
        smaller final batch counts as much as a full one.

    Raises:
        ValueError: if ``loader`` yields no batches (previously a ZeroDivisionError).
    """
    model.train()

    total_loss = 0.0
    num_batches = 0

    for x, y in loader:
        x = x.to(device, dtype=torch.float32)
        y = y.to(device, dtype=torch.float32)

        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()

        # .item() detaches the value so the graph is not kept alive across batches.
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train(): loader yielded no batches")
    return total_loss / num_batches

In [5]:
def evaluate(model,loader,loss_fn,device):
    """Compute the mean per-batch loss over ``loader`` without updating the model.

    Args:
        model: torch module to evaluate; switched to eval mode here.
        loader: iterable of ``(x, y)`` batches (e.g. a ``DataLoader``). It does
            not need ``__len__`` because batches are counted while iterating.
        loss_fn: callable ``loss_fn(y_pred, y)`` returning a scalar tensor.
        device: device each batch is moved to (as float32) before the forward pass.

    Returns:
        Mean of the per-batch loss values (each batch weighted equally).

    Raises:
        ValueError: if ``loader`` yields no batches (previously a ZeroDivisionError).
    """
    model.eval()

    total_loss = 0.0
    num_batches = 0

    # No autograd graph is needed for validation; saves memory and time.
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device, dtype=torch.float32)
            y = y.to(device, dtype=torch.float32)

            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate(): loader yielded no batches")
    return total_loss / num_batches

In [6]:
if __name__ == "__main__":
    # Seeding
    seeding(42)

    # Directories
    create_dir("files")

    # Load dataset
    train_x, train_y, valid_x, valid_y = load_data()
    print("Train: ", len(train_x), "Valid: ", len(valid_x))

    # Hyperparameters
    H = 512
    W = 512
    # NOTE(review): `size` is not passed to DriveDataset below; confirm the
    # preprocessed images are already H x W.
    size = (H, W)
    batch_size = 2
    num_epochs = 50
    lr = 1e-4
    checkpoint_path = "files/checkpoint.pth"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset and DataLoader
    train_dataset = DriveDataset(train_x, train_y)
    valid_dataset = DriveDataset(valid_x, valid_y)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,  # number of worker subprocesses for data loading (0 = main process)
    )

    valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
    )

    print("Device: ", device)
    model = build_unet()
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Multiply the LR by 0.3 after 5 epochs without validation-loss improvement.
    # `verbose` is deprecated in recent PyTorch, so LR changes are logged below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=5, factor=0.3)
    loss_fn = DiceBCELoss()

    # Training
    best_valid_loss = float("inf")

    for epoch in range(num_epochs):
        start_time = time.time()

        train_loss = train(model, train_loader, optimizer, loss_fn, device)
        valid_loss = evaluate(model, valid_loader, loss_fn, device)

        # The plateau scheduler only acts when it is fed the monitored metric.
        prev_lr = optimizer.param_groups[0]["lr"]
        scheduler.step(valid_loss)
        new_lr = optimizer.param_groups[0]["lr"]
        if new_lr != prev_lr:
            print(f"Learning rate reduced from {prev_lr:.2e} to {new_lr:.2e}")

        # Save the best model
        if valid_loss < best_valid_loss:
            data_str = f"Valid loss improved from {best_valid_loss:2.4f} to {valid_loss:2.4f}. Saving checkpoint:{checkpoint_path}"
            print(data_str)
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), checkpoint_path)

        end_time = time.time()
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        data_str = f"Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s"
        data_str += f"\tTrain Loss: {train_loss:2.4f}"
        data_str += f"\tVal. Loss: {valid_loss:2.4f}"
        print(data_str)


Train:  80 20
Device:  cuda
Valid loss improved from inf to 1.3384. Saving checkpoint:files/checkpoint.pth.tar
Epoch: 01 | Epoch Time: 1m 51s	Train Loss: 1.0638	Val. Loss: 1.3384
Valid loss improved from 1.3384 to 1.2847. Saving checkpoint:files/checkpoint.pth.tar
Epoch: 02 | Epoch Time: 1m 37s	Train Loss: 0.8926	Val. Loss: 1.2847
Valid loss improved from 1.2847 to 1.2555. Saving checkpoint:files/checkpoint.pth.tar
Epoch: 03 | Epoch Time: 1m 36s	Train Loss: 0.8067	Val. Loss: 1.2555
Valid loss improved from 1.2555 to 1.2413. Saving checkpoint:files/checkpoint.pth.tar
Epoch: 04 | Epoch Time: 1m 35s	Train Loss: 0.7376	Val. Loss: 1.2413
Valid loss improved from 1.2413 to 1.2384. Saving checkpoint:files/checkpoint.pth.tar
Epoch: 05 | Epoch Time: 1m 36s	Train Loss: 0.6807	Val. Loss: 1.2384
Valid loss improved from 1.2384 to 1.2161. Saving checkpoint:files/checkpoint.pth.tar
Epoch: 06 | Epoch Time: 1m 37s	Train Loss: 0.6272	Val. Loss: 1.2161
Epoch: 07 | Epoch Time: 1m 37s	Train Loss: 0.5815	V