In [None]:
!git clone https://github.com/doermindset/face-segmentation.git
%cd face-segmentation
!git checkout develop
%mkdir checkpoints

In [None]:
!pip install wandb -Uq

In [None]:
import numpy as np
import torch.cuda
import torch.nn as nn
import torch.optim as optim
import wandb
from data.lfw_dataset import LFWDataset
from torch.utils.data import DataLoader
from metrics.segmentation_metrics import compute_metrics
from models.uNet import UNet
from tqdm import tqdm
from model_checkpoint import ModelCheckpoint

# Authenticate with Weights & Biases before the wandb.init / wandb.sweep /
# wandb.agent calls in the cells below.
wandb.login()

In [None]:
# Global count of training images processed so far; used as the W&B x-axis by
# train/val/test and reset at the start of every model_pipeline run.
step: int = 0

In [None]:
def test(model, test_loader, device):
    """Evaluate ``model`` on the test split and log aggregate metrics to W&B.

    Args:
        model: Segmentation network; switched to eval mode here.
        test_loader: DataLoader yielding dicts with ``"image"`` and ``"seg"``
            tensors.
        device: Device the batches are moved to (``"cuda"`` or ``"cpu"``).

    Returns:
        The dict of metrics that was logged. Each value is the unweighted mean
        of the per-batch values returned by ``compute_metrics``.
    """
    model.eval()
    mean_accuracy, mean_iou, mean_fw_iou = [], [], []

    with torch.no_grad():
        # total= gives the bar a length/ETA, like the train/val progress bars.
        for data in tqdm(test_loader, desc="evaluate_unet", total=len(test_loader)):
            imgs = data["image"].to(device)
            segs = data["seg"].to(device)

            segs_pred = model(imgs)

            mpa, m_iou, m_fw_iou = compute_metrics(segs, segs_pred)
            mean_accuracy.append(mpa)
            mean_iou.append(m_iou)
            mean_fw_iou.append(m_fw_iou)

    metrics = {"Test Mean Pixel Acc": np.mean(mean_accuracy),
               "Test Mean IoU": np.mean(mean_iou),
               "Test Frequency Weighted IoU": np.mean(mean_fw_iou)}
    # Log at the global image counter so test results sit at the end of the
    # training curve.
    wandb.log(metrics, step=step)
    return metrics

In [None]:
def val(model, val_loader, criterion, config, device, epoch, model_ckpt):
    """Run one validation pass, log metrics and sample images to W&B, checkpoint.

    Args:
        model: Segmentation network; switched to eval mode here.
        val_loader: DataLoader yielding dicts with ``"image"`` and ``"seg"``
            tensors.
        criterion: Loss called as ``criterion(prediction, target)``.
        config: Run config; only ``batch_size`` is read (progress-bar scale).
        device: Device the batches are moved to.
        epoch: Epoch index shown in the progress bar and passed to ``model_ckpt``.
        model_ckpt: Callable invoked once as ``model_ckpt(model, epoch, mean_iou)``.

    All W&B logging uses the module-level ``step`` counter, which is only read
    here, so validation points line up with the training curve.
    """
    running_loss = 0.0
    mean_accuracy, mean_iou, mean_fw_iou = [], [], []
    table = wandb.Table(columns=["id", "image", "pred", "gt"])

    model.eval()

    pbar = tqdm(enumerate(val_loader),
                unit=' images',
                unit_scale=config.batch_size,
                total=len(val_loader),
                smoothing=0,
                disable=False)

    with torch.no_grad():
        for batch_idx, data in pbar:
            imgs = data["image"].to(device)
            segs = data["seg"].to(device)

            segs_pred = model(imgs)
            batch_loss = float(criterion(segs_pred, segs))

            # Keep the first sample of the first five batches as a visual check.
            if batch_idx < 5:
                table.add_data(f'{step}_{batch_idx}',
                               wandb.Image(imgs[0]),
                               wandb.Image(segs_pred[0]),
                               wandb.Image(segs[0]))

            running_loss += batch_loss
            val_loss = running_loss / (batch_idx + 1)

            pbar.set_description(f'Validation [ E {epoch}, L {batch_loss:.4f}, L_Avg {val_loss:.4f}')

            mpa, m_iou, m_fw_iou = compute_metrics(segs, segs_pred)
            mean_accuracy.append(mpa)
            mean_iou.append(m_iou)
            mean_fw_iou.append(m_fw_iou)

        val_loss = running_loss / len(val_loader)
        epoch_mean_iou = np.mean(mean_iou)

        wandb.log({"Validation Loss": val_loss,
                   "Validation Mean Pixel Acc": np.mean(mean_accuracy),
                   "Validation Mean IoU": epoch_mean_iou,
                   "Validation Frequency Weighted IoU": np.mean(mean_fw_iou)}, step=step)

        # Log the table at the same explicit step as the metrics above, so it
        # does not depend on W&B's implicit auto-incrementing step.
        wandb.log({"Images Data": table}, step=step)

        model_ckpt(model, epoch, epoch_mean_iou)

In [None]:
def train(model, train_loader, criterion, optimizer, config, device, epoch):
    """Train ``model`` for one epoch and periodically log the running loss.

    Args:
        model: Segmentation network; switched to train mode here.
        train_loader: DataLoader yielding dicts with ``"image"`` and ``"seg"``
            tensors.
        criterion: Loss called as ``criterion(prediction, target)``; the result
            is back-propagated.
        optimizer: Optimizer stepped once per batch.
        config: Run config; reads ``batch_size`` (progress-bar scale) and
            ``log_freq`` (log the running loss every ``log_freq`` batches).
        device: Device the batches are moved to.
        epoch: Epoch index shown in the progress bar.

    Side effects:
        Advances the module-level ``step`` counter by the number of images in
        each batch.
    """
    global step  # advanced below; val()/test() read it as the W&B x-axis
    running_loss = 0.0

    model.train()
    pbar = tqdm(enumerate(train_loader),
                unit=' images',
                unit_scale=config.batch_size,
                total=len(train_loader),
                smoothing=0,
                disable=False)

    for batch_idx, data in pbar:
        imgs = data["image"].to(device)
        segs = data["seg"].to(device)

        optimizer.zero_grad()
        segs_pred = model(imgs)
        loss = criterion(segs_pred, segs)

        loss.backward()
        optimizer.step()

        batch_loss = float(loss)
        running_loss += batch_loss
        step += len(imgs)
        train_loss = running_loss / (batch_idx + 1)
        pbar.set_description(f'Training [ E {epoch}, L {batch_loss:.4f}, L_Avg {train_loss:.4f}')

        # batch_idx is 0-based, so this fires after every log_freq-th batch.
        if (batch_idx + 1) % config.log_freq == 0:
            wandb.log({"Training Loss": train_loss}, step=step)

In [None]:
def _make_loader(dataset, batch_size, shuffle, pin_memory):
    """Build a single-process DataLoader with the settings shared by all splits."""
    return DataLoader(dataset,
                      batch_size=batch_size,
                      pin_memory=pin_memory,
                      shuffle=shuffle,
                      sampler=None,
                      num_workers=0)


def _make_optimizer(name, params, learning_rate):
    """Return the torch optimizer selected by ``name`` ("adam" or "sgd")."""
    if name == "adam":
        return optim.Adam(params, lr=learning_rate)
    if name == "sgd":
        return optim.SGD(params, lr=learning_rate)
    raise ValueError(f"Unsupported optimizer {name!r}; expected 'adam' or 'sgd'.")


def model_pipeline(hyperparameters=None):
    """Run one W&B-tracked experiment: build data and model, train/validate, test.

    Args:
        hyperparameters: Dict passed to ``wandb.init(config=...)``. Left as
            ``None`` when invoked by ``wandb.agent``, which supplies the sweep
            config. Required keys: ``batch_size``, ``epochs``,
            ``learning_rate`` and ``log_freq`` (read by ``train``). Optional
            keys: ``bilinear``, ``classes`` (default 3) and ``optimizer``
            (default ``"adam"``).

    Raises:
        ValueError: If ``optimizer`` is neither ``"adam"`` nor ``"sgd"``.
    """
    with wandb.init(config=hyperparameters):
        global step
        step = 0  # each run's W&B x-axis starts from zero
        config = wandb.config

        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Pinned host memory only helps host->GPU copies; on CPU it just warns.
        pin_memory = device == "cuda"

        # NOTE(review): the train split downloads into 'data/lfw_dataset', but
        # val/test read from 'lfw_dataset'. Confirm LFWDataset resolves both to
        # the same location; otherwise val/test will not find the data.
        train_dataset = LFWDataset(download=True, base_folder='data/lfw_dataset', split_name="train")
        val_dataset = LFWDataset(download=False, base_folder='lfw_dataset', split_name="val")
        test_dataset = LFWDataset(download=False, base_folder='lfw_dataset', split_name="test")

        # Shuffle only the training split so each epoch sees a fresh batch order.
        # Val/test order stays fixed, so the sample images val() logs from its
        # first batches are the same every epoch.
        train_loader = _make_loader(train_dataset, config.batch_size, shuffle=True, pin_memory=pin_memory)
        val_loader = _make_loader(val_dataset, config.batch_size, shuffle=False, pin_memory=pin_memory)
        test_loader = _make_loader(test_dataset, config.batch_size, shuffle=False, pin_memory=pin_memory)

        # None of this notebook's configs define "bilinear". Forward it only
        # when set, so UNet's own default applies otherwise (reading
        # config.bilinear directly would fail).
        unet_kwargs = {}
        bilinear = config.get("bilinear")
        if bilinear is not None:
            unet_kwargs["bilinear"] = bilinear
        model = UNet(n_channels=3, n_classes=config.get("classes", 3), **unet_kwargs)
        model = model.to(device)

        criterion = nn.CrossEntropyLoss()
        # Configs without an "optimizer" key fall back to Adam; unknown names
        # raise instead of leaving `optimizer` undefined.
        optimizer = _make_optimizer(config.get("optimizer", "adam"), model.parameters(), config.learning_rate)

        model_ckpt = ModelCheckpoint(0.0, True, 5, "mean_iou")
        for epoch in range(config.epochs):
            # Train first, so validation and checkpointing score this epoch's
            # weights and the final trained model is validated as well.
            train(model, train_loader, criterion, optimizer, config, device, epoch)
            val(model, val_loader, criterion, config, device, epoch, model_ckpt)

        test(model, test_loader, device)
    # Leaving the `with wandb.init(...)` block finishes the run, so no separate
    # wandb.finish() call is needed.

In [None]:
# Random search over learning rate, batch size and optimizer.
# NOTE: the grid-sweep cell that follows reassigns `sweep_config` and
# `sweep_id`, so wandb.agent uses the grid sweep unless this cell runs after it.
sweep_config = {
    'method': 'random',
    'metric': {'name': 'Validation Loss', 'goal': 'minimize'},
    'parameters': {
        # Log-uniform over a strictly positive range. A uniform draw on
        # [0, 0.1] can yield lr=0 (no learning) and almost never reaches the
        # 1e-4..1e-3 range used by the other configs in this notebook.
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 1e-5,
            'max': 1e-1
        },
        'batch_size': {
            'distribution': 'q_log_uniform_values',
            'q': 8,
            'min': 8,
            'max': 32,
        },
        'epochs': {'value': 30},
        'classes': {'value': 3},
        'log_freq': {'value': 10},
        # model_pipeline reads config.optimizer; offer the same choices as the grid sweep.
        'optimizer': {
            'values': ['adam', 'sgd']
        }
    }
}
sweep_id = wandb.sweep(sweep_config, project="face-segmentation-sweeps-grid")

In [None]:
# Exhaustive grid: learning rate x batch size x optimizer (3 * 2 * 2 = 12
# combinations); every other hyperparameter is pinned to a single value.
sweep_config = dict(
    method='grid',
    metric=dict(name='Validation Loss', goal='minimize'),
    parameters=dict(
        learning_rate=dict(values=[0.01, 0.001, 0.0005]),
        batch_size=dict(values=[16, 32]),
        epochs=dict(value=30),
        classes=dict(value=3),
        log_freq=dict(value=10),
        optimizer=dict(values=['adam', 'sgd']),
    ),
)
sweep_id = wandb.sweep(sweep_config, project="face-segmentation-sweeps-grid")

In [None]:
# Launch up to five runs of whichever sweep `sweep_id` currently refers to.
# NOTE(review): the grid sweep has 12 combinations, so count=5 covers only part of it.
wandb.agent(sweep_id, function=model_pipeline, count=5)

In [None]:
# Hyperparameters for a single standalone (non-sweep) training run.
config = {
    "epochs": 50,
    "classes": 3,
    "batch_size": 32,
    "learning_rate": 0.0005,
    "dataset": "LFW",
    "architecture": "UNet",
    "log_freq": 10,
    "optimizer": "adam",
}

model_pipeline(config)