In [1]:
# `IPython.core.display` is a deprecated location for these names (importing
# `display` from it triggers a DeprecationWarning); `IPython.display` is the
# public module that exports both.
from IPython.display import HTML, display

# Widen the notebook container to 90% of the browser width.
display(HTML("<style>.container { width:90% !important; }</style>"))

  from IPython.core.display import display, HTML


In [None]:
import datetime
import os
import shutil
import time
from copy import deepcopy
from glob import glob
from pathlib import Path

import albumentations as A
import habana_frameworks.torch.core as htcore
import numpy as np
import opendatasets as od
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
from albumentations.pytorch.transforms import ToTensorV2
from PIL import Image, ImageChops
from torch.utils.data import DataLoader, random_split
from tqdm.auto import tqdm

from unet import UNET
from utils import get_data, load_hpu_library, set_env_params

In [None]:
# Configure the run before any HPU work happens: lazy-mode execution on one HPU
# per node (as requested by the keyword arguments; the helper lives in utils).
set_env_params(run_lazy_mode=True, hpus_per_node=1)
# NOTE(review): presumably loads the Habana PyTorch runtime library — confirm in utils.
load_hpu_library()

In [None]:
# NOTE(review): presumably fetches/prepares the PNGs that later cells read from
# data/proc_seg/ — confirm against utils.get_data.
get_data()

In [None]:
png = ".png"


class Dataset:
    """Paired chest X-ray / segmentation-mask dataset backed by two PNG folders.

    The i-th image is paired with the i-th mask, so both file listings are
    sorted to make that pairing deterministic.

    Args:
        cxr_dir: Directory holding the X-ray ``*.png`` files.
        mask_dir: Directory holding the mask ``*.png`` files.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries.
    """

    def __init__(self, cxr_dir, mask_dir, transform=None):
        # glob() returns paths in arbitrary, filesystem-dependent order. Without
        # sorting, index i of the image list and index i of the mask list are
        # not guaranteed to refer to the same sample.
        # NOTE(review): assumes corresponding image/mask file names sort into
        # the same order — confirm against the data layout.
        self.cxr_images = sorted(glob(os.path.join(cxr_dir, "*{}".format(png))))
        self.mask_images = sorted(glob(os.path.join(mask_dir, "*{}".format(png))))
        self.transform = transform

    def __len__(self):
        """Return the number of X-ray images found."""
        return len(self.cxr_images)

    def __getitem__(self, idx):
        """Load the ``idx``-th (image, mask) pair.

        The image is read as RGB and the mask as single-channel float32, with
        mask pixels equal to 255 remapped to 1. When a transform was supplied,
        its ``"image"`` and ``"mask"`` outputs are returned instead.
        """
        cxr_png_path = Path(self.cxr_images[idx])
        mask_png_path = Path(self.mask_images[idx])
        img = np.array(Image.open(cxr_png_path).convert("RGB"))
        mask = np.array(Image.open(mask_png_path).convert("L"), dtype=np.float32)
        # Turn the 0/255 mask into a 0/1 target.
        mask[mask == 255.0] = 1.0

        if self.transform:
            augs = self.transform(image=img, mask=mask)
            img = augs["image"]
            mask = augs["mask"]

        return img, mask

In [None]:
dim = 256

# Geometric steps: square resize, then a random rotation of up to 35 degrees
# (always applied) and a horizontal flip half of the time.
resize_step = A.Resize(height=dim, width=dim, always_apply=True)
rotate_step = A.Rotate(limit=35, p=1.0)
flip_step = A.HorizontalFlip(p=0.5)

# Zero mean / unit std with max_pixel_value=255 means this step only rescales
# the pixel values by 1/255.
normalize_step = A.Normalize(
    mean=[0.0, 0.0, 0.0],
    std=[1.0, 1.0, 1.0],
    max_pixel_value=255.0,
)

# Steps run in list order; the final one converts the result to torch tensors.
transforms = A.Compose(
    [resize_step, rotate_step, flip_step, normalize_step, ToTensorV2()],
)

In [None]:
# Input folders for the X-ray PNGs and their segmentation-mask PNGs.
cxr_dir = "data/proc_seg/cxr_pngs/"
mask_dir = "data/proc_seg/mask_pngs/"
# Batch size shared by the training and validation loaders.
bs = 16

In [None]:
dataset = Dataset(cxr_dir=cxr_dir, mask_dir=mask_dir, transform=transforms)
# 80% of the samples go to training, the remainder to validation.
train_samples = int(len(dataset) * 0.8)
# Seed the split so the same indices land in the validation set on every
# Restart & Run All; an unseeded split makes validation scores incomparable
# between runs.
# NOTE(review): both subsets share `dataset`, so the random rotation/flip in
# `transforms` is also applied to validation samples.
train_data, val_data = random_split(
    dataset,
    [train_samples, len(dataset) - train_samples],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# os.cpu_count() returns None when the CPU count cannot be determined, and None
# is not a valid num_workers value; fall back to 0 (load in the main process).
loader_workers = os.cpu_count() or 0

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_data,
    batch_size=bs,
    shuffle=True,
    pin_memory=True,
    num_workers=loader_workers,
)
# Validation batches keep a fixed order.
val_loader = DataLoader(
    val_data,
    batch_size=bs,
    shuffle=False,
    pin_memory=True,
    num_workers=loader_workers,
)

In [None]:
# Three input channels (the dataset loads images as RGB) and one output channel
# (a single mask logit map, matching the BCEWithLogitsLoss used for training).
model = UNET(in_channels=3, out_channels=1)

In [None]:
# model

In [None]:
# Target the Habana accelerator; the bare expression below displays the device
# as the cell output.
device = torch.device("hpu")
device

In [None]:
def permute_params(model, to_filters_last, lazy_mode):
    """Re-order the axes of every 4-D parameter of ``model`` in place.

    Converts between filters-first (KCRS) and filters-last (RSCK) layouts. The
    RSCK -> KCRS direction is what checkpoint saving uses.

    Args:
        model: Module whose parameters are rewritten; non-4-D parameters are
            left untouched.
        to_filters_last: True for KCRS -> RSCK, False for RSCK -> KCRS.
        lazy_mode: When true, ``htcore.mark_step()`` is called afterwards.
    """
    # (2, 3, 1, 0) maps KCRS -> RSCK; (3, 2, 0, 1) is its inverse.
    axis_order = (2, 3, 1, 0) if to_filters_last else (3, 2, 0, 1)

    with torch.no_grad():
        for weight in model.parameters():
            if weight.ndim != 4:
                continue
            weight.data = weight.data.permute(axis_order)

    if not lazy_mode:
        return

    import habana_frameworks.torch.core as htcore

    htcore.mark_step()

In [None]:
def permute_momentum(optimizer, to_filters_last, lazy_mode):
    """Re-order the axes of every 4-D ``"momentum_buffer"`` held by ``optimizer``.

    Used so the momentum buffers are in the expected layout before a checkpoint
    is taken. Only state entries stored under the ``"momentum_buffer"`` key are
    touched; any other optimizer state is left as is.

    Args:
        optimizer: Optimizer whose per-parameter state is updated in place.
        to_filters_last: True applies axis order (2, 3, 1, 0), False applies
            (3, 2, 0, 1).
        lazy_mode: When true, ``htcore.mark_step()`` is called afterwards.
    """
    axis_order = (2, 3, 1, 0) if to_filters_last else (3, 2, 0, 1)

    for group in optimizer.param_groups:
        for param in group["params"]:
            slot = optimizer.state[param]
            if "momentum_buffer" not in slot:
                continue
            momentum = slot["momentum_buffer"]
            if momentum.ndim != 4:
                continue
            slot["momentum_buffer"] = momentum.permute(axis_order)

    if not lazy_mode:
        return

    import habana_frameworks.torch.core as htcore

    htcore.mark_step()

In [None]:
# The loss takes raw logits, so the model output is passed to it without a sigmoid.
criterion = nn.BCEWithLogitsLoss()
# Move the model to the HPU before the optimizer is built from its parameters.
model = model.to(device)

In [None]:
# Adam over all model parameters with a fixed learning rate of 1e-3.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Switch the 4-D weights to the filters-last (RSCK) layout for training, in lazy mode.
permute_params(model, True, True)
# NOTE(review): permute_momentum only touches "momentum_buffer" state entries;
# no optimizer step has run yet, so this is presumably a no-op here — confirm.
permute_momentum(optimizer, True, True)

In [None]:
def save_checkpoint(state, is_best, filename):
    """Write ``state["state_dict"]`` to ``filename`` and optionally keep a best copy.

    Only the ``"state_dict"`` entry of ``state`` is serialized; the other
    entries are not written to disk.

    Args:
        state: Mapping with at least ``"state_dict"``; ``"epoch"`` is also read
            when ``is_best`` is true.
        is_best: When true, the saved file is additionally copied to
            ``model_best<epoch>.pth.tar`` in the current working directory.
        filename: Destination path of the checkpoint file.
    """
    # torch.save does not create missing directories, so a first save into a
    # fresh "checkpoints/" folder would fail without this.
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    torch.save(state["state_dict"], filename)
    if is_best:
        shutil.copyfile(filename, "model_best" + str(state["epoch"]) + ".pth.tar")

In [None]:
def train(train_loader, epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and
    ``device``. Progress and the per-batch loss are reported on a tqdm bar.

    Args:
        train_loader: Iterable yielding ``(images, target)`` batches.
        epoch: Epoch number, shown in the progress-bar postfix only.
    """
    pbar = tqdm(train_loader)
    for images, target in pbar:
        pbar.set_description("Training")

        # Move the batch to the accelerator; the target gains a channel axis at
        # dim 1 before it is compared with the model output.
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True).unsqueeze(1)
        images = images.contiguous(memory_format=torch.channels_last)
        htcore.mark_step()

        # Forward pass and loss.
        output = model(images)
        loss = criterion(output, target)

        # Backward pass and parameter update, with a mark_step after each.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        htcore.mark_step()
        optimizer.step()
        htcore.mark_step()

        progress = {
            "Train Epoch": epoch,
            "Train Loss": loss.item(),
        }
        pbar.set_postfix(progress)

In [None]:
def validate(val_loader, model, criterion, device):
    """Evaluate ``model`` on ``val_loader`` and return the mean per-batch Dice score.

    Predictions are the sigmoid of the model output thresholded at 0.5. The
    progress bar shows the batch loss and the running mean Dice score. The
    epoch shown in the postfix is read from the module-level ``epoch`` variable
    set by the training loop.

    Args:
        val_loader: Iterable yielding ``(images, target)`` batches.
        model: Model to evaluate; the caller is responsible for ``model.eval()``.
        criterion: Loss applied to the raw model output and the target.
        device: Device the batches are moved to.

    Returns:
        The accumulated Dice score divided by ``len(val_loader)``.

    Raises:
        ZeroDivisionError: If ``val_loader`` is empty.
    """
    dice_score = 0
    with torch.no_grad():
        pbar = tqdm(val_loader)
        # Count batches from 1 so the running mean below divides by the number
        # of batches seen so far. Dividing by a zero-based index divided by
        # zero on the first batch and over-stated every later value.
        for batches_seen, (images, target) in enumerate(pbar, start=1):
            pbar.set_description("Validating")
            images, target = images.to(device, non_blocking=True), target.to(
                device, non_blocking=True
            ).unsqueeze(1)
            images = images.contiguous(memory_format=torch.channels_last)
            htcore.mark_step()
            # compute output
            output = model(images)
            loss = criterion(output, target)
            preds = (torch.sigmoid(output) > 0.5).float()
            # Dice = 2 * |P ∩ T| / (|P| + |T|); the epsilon avoids 0/0 when both
            # prediction and target are empty.
            dice_score += (2 * (preds * target).sum()) / ((preds + target).sum() + 1e-7)
            pbar.set_postfix(
                {
                    "Validation Epoch": epoch,
                    "Validation Loss": loss.item(),
                    "Dice Score": (dice_score / batches_seen).item(),
                }
            )
    return dice_score / len(val_loader)

In [None]:
start_time = time.time()
e_time = start_time
best_acc1 = 0

# Train for a fixed number of epochs; after each one, validate and write a
# checkpoint whenever the validation Dice score improves on the best so far.
for epoch in range(30):
    model.train()
    end = time.time()
    train(train_loader, epoch)
    # evaluate on validation set
    # switch to evaluate mode
    model.eval()
    model_for_eval = model
    acc1 = validate(val_loader, model_for_eval, criterion, device)

    # remember best acc@1 and save checkpoint
    is_best = acc1 > best_acc1
    if is_best:
        print(
            f"Dice score inreased from {best_acc1} --> {acc1} --> Saving checkpoint epoch {epoch} "
        )
        # Permute model parameters from RSCK to KCRS
        model_without_ddp = model
        permute_params(model_without_ddp, False, True)
        # Use this model only to copy the state_dict of the actual model
        copy_model = UNET(in_channels=3, out_channels=1)
        state_dict = model_without_ddp.state_dict()
        # 1-D "num_batches_tracked" entries are squeezed to scalars before they
        # are loaded into the fresh copy.
        for k, v in state_dict.items():
            if "num_batches_tracked" in k and v.dim() == 1:
                state_dict[k] = v.squeeze(0)

        copy_model.load_state_dict(state_dict)
        # Permute the weight momentum buffer before saving in checkpoint
        permute_momentum(optimizer, False, True)

        # Bring all model parameters and optimizer parameters to CPU
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to("cpu")

        # Save model parameters in checkpoint
        dir_ = "checkpoints/"
        # torch.save does not create missing directories; make sure the target
        # folder exists so the first checkpoint of a fresh run does not fail.
        os.makedirs(dir_, exist_ok=True)
        filename = dir_ + "checkpoint_" + str(epoch) + "_" + "hpu" + ".pth"
        # Only the "state_dict" entry is written by save_checkpoint; the other
        # keys are carried along but not serialized.
        save_checkpoint(
            {
                "epoch": epoch,
                "arch": model,
                "state_dict": copy_model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            filename,
        )

        # Take back model parameters and optimizer parameters to HPU
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to("hpu")
        # Permute back from KCRS to RSCK
        permute_params(model, True, True)
        permute_momentum(optimizer, True, True)
        best_acc1 = max(acc1, best_acc1)

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Training time {}".format(total_time_str))

In [None]:
# `state` at this point is the leftover loop variable from iterating
# `optimizer.state.values()` in the training loop (per-parameter optimizer
# state), not a checkpoint dict, so `state["state_dict"]` cannot be looked up.
# Save the weights the training loop put under the "state_dict" key instead:
# the CPU copy of the best-scoring model.
# NOTE(review): the destination is the literal file name "filename" — confirm
# whether a real path was intended.
torch.save(copy_model.state_dict(), "filename")