In [8]:
from importnb import Notebook

with Notebook():
    from CustomDataClass import CarsDataSet

In [9]:
import os

import torch
import torchvision
from torch.utils.data import DataLoader

In [10]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Write ``state`` to ``filename`` with ``torch.save``.

    Args:
        state: Object to serialize; passed to ``torch.save`` unchanged.
        filename: Destination path for the checkpoint file.

    A short progress message is printed before the file is written.
    """
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)

In [11]:
def load_checkpoint(checkpoint, model):
    """Restore ``model`` parameters from a checkpoint mapping.

    Args:
        checkpoint: Mapping whose ``"state_dict"`` entry is handed to
            ``model.load_state_dict``. A missing key raises ``KeyError``.
        model: Object exposing ``load_state_dict``; it is updated in place.

    A short progress message is printed before the weights are loaded.
    """
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

In [12]:
def get_loaders(
    train_dir,
    train_maskdir,
    batch_size,
    transform,
):
    """Build training and validation ``DataLoader``s from a single dataset.

    A ``CarsDataSet`` is constructed from the three dataset arguments and
    randomly split 80/20 into a training and a validation subset.

    Args:
        train_dir: First positional argument forwarded to ``CarsDataSet``
            (presumably the image directory -- confirm against that class).
        train_maskdir: Second positional argument forwarded to ``CarsDataSet``
            (presumably the mask directory -- confirm against that class).
        batch_size: Batch size used for both loaders.
        transform: Third positional argument forwarded to ``CarsDataSet``.

    Returns:
        Tuple ``(train_loader, val_loader)``.

    Note:
        ``random_split`` draws from the global torch RNG; call
        ``torch.manual_seed`` beforehand if a reproducible split is needed.
    """
    dataset = CarsDataSet(train_dir, train_maskdir, transform)

    # The validation share takes the remainder so both sizes always add up
    # to len(dataset), whatever the rounding of the 80% share.
    total = len(dataset)
    train_size = int(total * 0.8)
    val_size = total - train_size

    train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(
        dataset=train_set,
        batch_size=batch_size,
        shuffle=True,
    )

    # Validation batches keep a fixed order: shuffling does not help
    # evaluation, and a stable order keeps batch indices comparable between
    # successive passes over the loader.
    val_loader = DataLoader(
        dataset=val_set,
        batch_size=batch_size,
        shuffle=False,
    )

    return train_loader, val_loader

In [15]:
def check_accuracy(loader, model, device="cuda"):
    """Print and return pixel accuracy and mean per-batch Dice score.

    The model output is passed through a sigmoid and thresholded at 0.5
    before being compared with the targets. The model is switched to eval
    mode for the duration of the call and then put back into whichever mode
    it was in on entry, even if evaluation raises.

    Args:
        loader: Iterable of ``(x, y)`` batches that also supports ``len()``.
            ``y`` gets an extra dimension inserted at position 1 before the
            comparison (assumes targets arrive without a channel axis --
            confirm against the dataset).
        model: ``torch.nn.Module`` to evaluate.
        device: Device both ``x`` and ``y`` are moved to.

    Returns:
        Tuple ``(acc, dice_scores)``: accuracy in percent and the Dice score
        averaged over the number of batches.

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0

    # Remember the caller's mode: unconditionally calling model.train()
    # afterwards would flip a model that was deliberately in eval mode.
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()
                num_correct += (preds == y).sum()
                num_pixels += torch.numel(preds)
                # Small epsilon avoids 0/0 when both prediction and target
                # are empty for a batch.
                dice_score += (2 * (preds * y).sum()) / (
                    (preds + y).sum() + 1e-8
                )

        acc = num_correct / num_pixels * 100
        dice_scores = dice_score / len(loader)

        print(
            f"Got {num_correct}/{num_pixels} with acc {acc:.2f}"
        )
        print(f"Dice score: {dice_scores}")
    finally:
        model.train(was_training)

    return acc, dice_scores
    
    

In [14]:
def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
    """Save thresholded predictions and their targets as PNG files.

    For every batch index ``idx`` two files are written inside ``folder``:
    ``pred_{idx}.png`` with the sigmoid output thresholded at 0.5, and
    ``{idx}.png`` with the target after inserting a dimension at position 1.
    The model is switched to eval mode for the duration of the call and then
    put back into whichever mode it was in on entry, even if saving raises.

    Args:
        loader: Iterable of ``(x, y)`` batches.
        model: ``torch.nn.Module`` used to produce predictions.
        folder: Output directory; created if it does not exist. A trailing
            slash is optional.
        device: Device the inputs are moved to before the forward pass.
    """
    # save_image does not create directories, so make sure the target exists.
    if folder:
        os.makedirs(folder, exist_ok=True)

    # Remember the caller's mode instead of forcing train mode afterwards.
    was_training = model.training
    model.eval()

    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()
            # Both paths go through os.path.join so they land in `folder`
            # whether or not it was passed with a trailing separator.
            torchvision.utils.save_image(
                preds, os.path.join(folder, f"pred_{idx}.png")
            )
            torchvision.utils.save_image(
                y.unsqueeze(1), os.path.join(folder, f"{idx}.png")
            )
    finally:
        model.train(was_training)