In [1]:
import os
import pickle
import random
import re
import time
from helpers import *

import cv2
import matplotlib.pyplot as plt
import monai
import numpy as np
import torch
from monai.networks.utils import one_hot
from monai.transforms import *
from monai.losses import DiceFocalLoss, GeneralizedDiceFocalLoss
from torchsummary import summary

In [None]:
# Splits a string into alternating non-digit / digit runs (the group keeps the digits).
_nsre = re.compile("([0-9]+)")


def natural_sort_key(s):
    """Return a sort key that orders embedded numbers numerically ("img2" < "img10").

    Digit runs become ints and all other chunks are lower-cased, so the ordering
    is also case-insensitive.
    """
    key = []
    for chunk in _nsre.split(s):
        if chunk.isdigit():
            key.append(int(chunk))
        else:
            key.append(chunk.lower())
    return key


data_path = "data/320063_intact-eye"

# One sub-directory per Supervisely dataset, each with `img/` and `masks_machine/`.
# NOTE(review): os.scandir order is filesystem-dependent and the split in
# create_datasets() depends on list order. Sorting here would make splits portable,
# but it would also change them relative to runs that already exist.
dataset_paths = [f.path for f in os.scandir(data_path) if f.is_dir()]

images = []
masks = []

for dataset_path in dataset_paths:
    img_file_names = sorted(
        os.listdir(os.path.join(dataset_path, "img")), key=natural_sort_key
    )
    images.extend(
        os.path.join(dataset_path, "img", img_file_name)
        for img_file_name in img_file_names
    )
    # A mask shares its image's stem but always has a .png extension.
    masks.extend(
        os.path.join(
            dataset_path, "masks_machine", os.path.splitext(img_file_name)[0] + ".png"
        )
        for img_file_name in img_file_names
    )

data_dict = [{"img": img, "seg": mask} for img, mask in zip(images, masks)]


def _mask_has_foreground(mask_path):
    """Return True if the mask at `mask_path` is readable and has a non-zero pixel."""
    seg_img = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None for missing/unreadable files; treat those like empty masks.
    return seg_img is not None and bool(np.any(seg_img))


# Supervisely also exports masks that are empty, so drop those entries.
# Build a new list: calling data_dict.remove() while iterating over data_dict
# skips the element right after every removed one.
data_dict = [data for data in data_dict if _mask_has_foreground(data["seg"])]

print(len(data_dict))

In [3]:
def create_datasets(
    split, train_batch_size, val_batch_size, test_batch_size, num_workers=8
):
    """Partition the global `data_dict` and wrap each part in a MONAI DataLoader.

    Args:
        split: ratios for (train, val, test), passed to `partition_dataset`.
            The partition is shuffled with a fixed seed, so it is reproducible
            as long as `data_dict` has the same contents and order.
        train_batch_size, val_batch_size, test_batch_size: batch size per loader.
        num_workers: worker processes per DataLoader.

    Returns:
        (train_loader, val_loader, test_loader). Only the train loader reshuffles
        its samples every epoch.
    """
    train_data_list, val_data_list, test_data_list = monai.data.utils.partition_dataset(
        data_dict, ratios=split, shuffle=True, seed=240899
    )

    def _preprocessing():
        # Deterministic steps shared by every split: load, channel-first, border-pad
        # with spatial_border=(12, 0), rotate 90 degrees, flip. New instances are
        # built for each pipeline so no transform object is shared between them.
        return [
            LoadImaged(keys=["img", "seg"]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            BorderPadd(keys=["img", "seg"], spatial_border=(12, 0)),
            Rotate90d(keys=["img", "seg"], spatial_axes=[1, 0]),
            Flipd(keys=["img", "seg"], spatial_axis=[1]),
        ]

    train_transforms = monai.transforms.Compose(
        _preprocessing()
        + [
            # Training-only augmentation. Masks use nearest interpolation so
            # labels stay integral.
            RandFlipd(keys=["img", "seg"], prob=0.5, spatial_axis=1),
            RandRotated(
                keys=["img", "seg"],
                range_x=0.525,
                prob=0.8,
                mode=("bilinear", "nearest"),
            ),
        ]
    )
    val_transforms = monai.transforms.Compose(_preprocessing())
    test_transforms = monai.transforms.Compose(_preprocessing())

    train_dataset = monai.data.Dataset(data=train_data_list, transform=train_transforms)
    val_dataset = monai.data.Dataset(data=val_data_list, transform=val_transforms)
    test_dataset = monai.data.Dataset(data=test_data_list, transform=test_transforms)

    # Reshuffle the training set every epoch; evaluation order stays fixed.
    train_loader = monai.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    val_loader = monai.data.DataLoader(
        val_dataset, batch_size=val_batch_size, num_workers=num_workers
    )
    test_loader = monai.data.DataLoader(
        test_dataset, batch_size=test_batch_size, num_workers=num_workers
    )

    print(f"Total dataset size: {len(data_dict)}")
    print(
        f"Num. train images: {len(train_dataset)}\nNum. val images: {len(val_dataset)}\nNum. test images: {len(test_dataset)}"
    )

    # Sanity check for shapes. Fetch a single batch once instead of loading
    # (and augmenting) a fresh batch for every print.
    first_batch = monai.utils.first(train_loader)
    print(f'Train image size: {first_batch["img"].shape}')
    print(f'Train ground truth size: {first_batch["seg"].shape}')
    print(f'Class labels in ground truth: {np.unique(first_batch["seg"])}')
    return train_loader, val_loader, test_loader

In [4]:
def create_model_loss_optimizer_scheduler(
    device,
    lr,
    num_channels,
    num_res_units,
    scheduler_factor,
    scheduler_patience,
    scheduler_min_lr,
):
    """Build the 2-D UNet together with its loss, Adam optimizer and plateau scheduler.

    Args:
        device: where the model is moved (`.to(device)`).
        lr: Adam learning rate.
        num_channels: `[start, stop]` exponents. Channel widths are
            2**start, ..., 2**(stop - 1), with one stride-2 step between each
            consecutive pair.
        num_res_units: residual units per UNet block.
        scheduler_factor, scheduler_patience, scheduler_min_lr: passed to
            ReduceLROnPlateau in "min" mode.

    Returns:
        (model, loss, optimizer, lr_scheduler)
    """
    exponents = range(num_channels[0], num_channels[1])
    channel_widths = tuple(2**exp for exp in exponents)
    down_strides = (2,) * (len(exponents) - 1)

    # 1 input channel; 4 output channels (one logit map per class, background included).
    net = monai.networks.nets.UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=4,
        channels=channel_widths,
        strides=down_strides,
        kernel_size=3,
        num_res_units=num_res_units,
    ).to(device)

    # Softmax is applied inside the loss, so the model returns raw logits.
    dice_loss = monai.losses.DiceLoss(include_background=True, softmax=True)
    # Alternatives that were tried:
    # loss = DiceFocalLoss(include_background=True, softmax=True, gamma=1, weight=torch.tensor([0.0043, 0.9990, 0.9982, 0.9983]))
    # loss = GeneralizedDiceFocalLoss(include_background=True, softmax=True, gamma=0.6, weight=torch.tensor([0.005, 0.33, 0.33, 0.33]), lambda_gdl=0.5, lambda_focal=0.5)

    adam = torch.optim.Adam(model_params := net.parameters(), lr=lr)
    del model_params
    plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        adam,
        mode="min",
        factor=scheduler_factor,
        patience=scheduler_patience,
        min_lr=scheduler_min_lr,
    )
    return net, dice_loss, adam, plateau_scheduler

In [5]:
def _validate_epoch(device, val_loader, model, loss_function, preview_path):
    """Return the mean loss over `val_loader`; save a preview of the first batch to `preview_path`."""
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for val_data in val_loader:
            val_images = val_data["img"].to(device)
            val_labels = one_hot(val_data["seg"].to(device), num_classes=4)

            val_output = model(val_images)
            total_loss += loss_function(val_output, val_labels).item()
            if num_batches == 0:
                # Only the first batch of each validation pass is rendered.
                save_val_results(
                    val_labels.cpu().numpy(),
                    torch.softmax(val_output, dim=1).cpu().numpy(),
                    preview_path,
                )
            num_batches += 1
    if num_batches == 0:
        raise ValueError("val_loader yielded no batches; cannot compute a validation loss")
    return total_loss / num_batches


def train(
    device,
    val_interval,
    max_epochs,
    train_loader,
    val_loader,
    model,
    optimizer,
    loss_function,
    lr_scheduler,
    hparams_dict,
    run_name=None,
    early_stopping_patience=20,
):
    """Train `model`, validating every `val_interval` epochs, and write artifacts to runs/<run_name>/.

    Files written:
    - hparams.pickle: `hparams_dict` as given.
    - val_results_<epoch>.png: preview of the first validation batch per validation pass.
    - best.pth: weights with the lowest validation loss so far.
    - checkpoint_<epoch>_val_loss_<loss>.pth: every 50 epochs.
    - final_<epoch>_val_loss<loss>.pth: when training ends.

    Labels are one-hot encoded to 4 classes before being passed to `loss_function`.

    Args:
        run_name: directory name under runs/. Defaults to a timestamp taken when
            the function is called.
        early_stopping_patience: stop once this many epochs have passed since
            the best validation loss.

    Returns:
        (epoch_loss_values, val_loss_values). val_loss_values holds one entry
        per validation pass.

    Raises:
        FileExistsError: if runs/<run_name> already exists (runs are never overwritten).
    """
    if run_name is None:
        # Taken per call. A default in the signature is evaluated once at definition
        # time, so every default-named run would reuse the same directory.
        run_name = time.strftime("%Y%m%d-%H%M%S")

    run_path = os.path.join("runs", run_name)
    os.makedirs(run_path, exist_ok=False)

    with open(os.path.join(run_path, "hparams.pickle"), "wb") as pickle_file:
        pickle.dump(hparams_dict, pickle_file)

    best_val_loss = np.inf
    best_val_epoch = -1  # 1-based epoch of best_val_loss
    val_loss = float("nan")  # most recent validation loss; NaN until one has run
    epoch = -1  # keeps the final file name well-defined when max_epochs == 0
    epoch_loss_values = []
    val_loss_values = []
    num_train_batches = len(train_loader)

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")

        model.train()
        epoch_loss = 0.0
        step = 0
        for step, batch_data in enumerate(train_loader, start=1):
            optimizer.zero_grad()

            inputs = batch_data["img"].to(device)
            labels = one_hot(batch_data["seg"].to(device), num_classes=4)

            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            print(f"{step}/{num_train_batches}, train_loss: {batch_loss:.4f}", end="\r")
        epoch_loss /= max(step, 1)  # guard against an empty train_loader
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            val_loss = _validate_epoch(
                device,
                val_loader,
                model,
                loss_function,
                os.path.join(run_path, f"val_results_{epoch+1}.png"),
            )
            val_loss_values.append(val_loss)
            lr_scheduler.step(val_loss)
            print(f"epoch {epoch + 1} average validation loss: {val_loss:.4f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_val_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(run_path, "best.pth"))
                print(
                    f"current epoch: {epoch + 1} current val loss: {val_loss:.4f} best val loss: {best_val_loss:.4f} at epoch {best_val_epoch}"
                )
            elif (epoch + 1) - best_val_epoch >= early_stopping_patience:
                # Both sides are 1-based epoch numbers, so this fires after exactly
                # `early_stopping_patience` epochs without improvement.
                print(
                    f"best val loss not updated for {early_stopping_patience} epochs, stopping training"
                )
                break

        if (epoch + 1) % 50 == 0:
            save_path = os.path.join(
                run_path, f"checkpoint_{epoch+1}_val_loss_{val_loss:.4f}.pth"
            )
            torch.save(model.state_dict(), save_path)

    print(
        f"train completed, best val loss: {best_val_loss:.4f} at epoch: {best_val_epoch}"
    )
    save_path = os.path.join(run_path, f"final_{epoch+1}_val_loss{val_loss:.4f}.pth")
    torch.save(model.state_dict(), save_path)

    return epoch_loss_values, val_loss_values

In [6]:
def save_loss_plots(epoch_loss_values, val_loss_values, run_name, val_interval):
    """Write runs/<run_name>/loss_plot.png and runs/<run_name>/losses.csv.

    Args:
        epoch_loss_values: one average training loss per epoch.
        val_loss_values: one validation loss per validation pass. These passes
            ran at epochs val_interval, 2*val_interval, ...
        run_name: run directory under runs/.
        val_interval: epochs between validation passes.

    The CSV has one row per training epoch, with columns (train_loss, val_loss).
    Epochs without a validation pass get NaN in the val column, so the file can
    be written for any val_interval.
    """
    run_path = os.path.join("runs", run_name)
    epochs = np.arange(1, len(epoch_loss_values) + 1)
    # Validation happened at every multiple of val_interval. Early stopping can
    # shorten the list but cannot leave gaps in it.
    val_epochs = np.arange(1, len(val_loss_values) + 1) * val_interval

    fig, ax = plt.subplots()
    ax.set_title("Epoch Average Loss")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.plot(epochs, epoch_loss_values, color="red", label="train")
    ax.plot(val_epochs, val_loss_values, color="green", label="val")
    ax.legend()
    fig.savefig(os.path.join(run_path, "loss_plot.png"))
    plt.close(fig)  # avoid accumulating open figures across runs

    # np.vstack on the two raw lists fails whenever val_interval > 1 (unequal lengths).
    val_column = np.full(len(epoch_loss_values), np.nan)
    val_column[val_epochs - 1] = val_loss_values
    losses = np.column_stack([epoch_loss_values, val_column])
    np.savetxt(os.path.join(run_path, "losses.csv"), losses, delimiter=",", fmt="%f")

In [7]:
def test_and_save_results(
    device, model, test_loader, run_name, print_results=False, num_classes=4
):
    """Evaluate `model` on `test_loader` and write per-class metrics to runs/<run_name>/results.csv.

    The CSV has one row per class and the columns
    dice, iou, precision, recall, f1, accuracy, time. Each metric value is the
    mean over batches of what the per-class helper returns. `time` is the mean
    wall-clock inference time per batch (device transfer + forward + argmax)
    and is repeated on every row.

    Args:
        device: device that batches are moved to.
        model: segmentation network returning per-class logits on dim 1.
        test_loader: yields dicts with "img" and "seg" (seg has a singleton channel axis).
        run_name: run directory under runs/.
        print_results: also print the averages.
        num_classes: number of classes passed to the metric helpers.

    Raises:
        ValueError: if `test_loader` yields no batches.
    """
    # (csv column, printed label, per-class metric helper). The order defines the CSV columns.
    metrics = [
        ("dice", "Dice", dice_coefficient_per_class),
        ("iou", "IoU", iou_per_class),
        ("precision", "Precision", precision_per_class),
        ("recall", "Recall", recall_per_class),
        ("f1", "F1", f1_score_per_class),
        ("accuracy", "Accuracy", accuracy_per_class),
    ]
    per_batch_scores = {name: [] for name, _, _ in metrics}
    times = []
    sync_cuda = torch.device(device).type == "cuda"

    model.eval()
    with torch.no_grad():
        for test_data in test_loader:
            start_time = time.perf_counter()
            test_images = test_data["img"].to(device)
            test_labels = test_data["seg"].to(device)

            output = model(test_images)
            predicted_mask = torch.argmax(torch.softmax(output, dim=1), dim=1)
            if sync_cuda:
                # CUDA kernels run asynchronously; wait for them so the timing
                # measures the forward pass and not just the kernel launches.
                torch.cuda.synchronize()
            times.append(time.perf_counter() - start_time)

            # Drop only the channel axis. A bare squeeze() would also drop the
            # batch axis when the last batch holds a single image.
            target_mask = test_labels.squeeze(1)
            for name, _, metric_fn in metrics:
                per_batch_scores[name].append(
                    metric_fn(predicted_mask, target_mask, num_classes)
                )

    if not times:
        raise ValueError("test_loader yielded no batches")

    # NOTE(review): this is an unweighted mean over batches, so a smaller final
    # batch counts as much as a full one. Confirm this against the helpers' semantics.
    average_time = np.mean(times)
    averages = {
        name: np.mean(scores, axis=0) for name, scores in per_batch_scores.items()
    }

    data = np.vstack([averages[name] for name, _, _ in metrics]).T
    time_column = np.full((data.shape[0], 1), average_time)
    data = np.hstack([data, time_column])

    np.savetxt(
        os.path.join("runs", run_name, "results.csv"),
        data,
        delimiter=",",
        header="dice,iou,precision,recall,f1,accuracy,time",
        fmt="%f",
    )

    if print_results:
        print(f"average elapsed time: {average_time:.5f}")
        for name, label, _ in metrics:
            for cls in range(num_classes):
                print(f"Average {label} Score for Class {cls}: {averages[name][cls]:.5f}")

In [14]:
# Fixed configuration for a single run. epochs=1 and run_name="test_run" suggest
# it is a quick smoke-test setup. Fall back to CPU when CUDA is unavailable.
hparams = dict(
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    lr=1e-4,
    epochs=1,
    train_batch_size=16,
    val_batch_size=16,
    test_batch_size=16,
    train_val_test_split=[0.8, 0.1, 0.1],
    val_interval=1,
    num_channels=[4, 9],  # UNet widths 2**4 .. 2**8
    num_res_units=0,
    scheduler_factor=0.1,
    scheduler_patience=6,
    scheduler_min_lr=1e-9,
    run_name="test_run",
)

In [12]:
def create_random_hparams(rng=None):
    """Sample one random configuration for the hyperparameter search.

    Args:
        rng: optional `random.Random` instance for reproducible sampling.
            Defaults to the module-level `random` generator.

    Returns:
        A dict with the same keys that `train_from_hparams` reads.
    """
    rng = random if rng is None else rng

    # Draw the learning rate log-uniformly in [1e-6, 1e-2]. A plain uniform draw
    # over four orders of magnitude puts about 99% of samples above 1e-4, so small
    # learning rates were almost never tried.
    lr = 10 ** rng.uniform(-6, -2)
    scheduler_min_lr = lr / 100000
    batch_size = rng.choice([4, 8, 16])
    channel_start = rng.randint(2, 4)
    channel_len = rng.randint(3, 6)
    num_channels = [channel_start, channel_start + channel_len]
    num_res_units = rng.randint(0, 4)

    hparams = {
        "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        "lr": lr,
        "epochs": 250,
        "train_batch_size": batch_size,
        "val_batch_size": batch_size,
        "test_batch_size": batch_size,
        "train_val_test_split": [0.8, 0.1, 0.1],
        "val_interval": 1,
        "num_channels": num_channels,
        "num_res_units": num_res_units,
        "scheduler_factor": 0.1,
        "scheduler_patience": 10,
        "scheduler_min_lr": scheduler_min_lr,
        "run_name": time.strftime("%Y%m%d-%H%M%S"),
    }
    return hparams

In [13]:
def train_from_hparams(hparams, run_name=None):
    """Run the whole pipeline for one configuration: data, model, training, loss plots, test metrics.

    Args:
        hparams: dict with the keys produced by `create_random_hparams`. It is
            not modified; a copy with "run_name" filled in is what gets pickled.
        run_name: output directory under runs/. Defaults to a timestamp taken
            when the function is called.

    Returns:
        (epoch_loss_values, val_loss_values) as returned by `train`.
    """
    if run_name is None:
        # Taken per call. A default in the signature is evaluated once at definition
        # time, so repeated calls (e.g. a search loop) would all reuse one run
        # directory and the second call would hit FileExistsError.
        run_name = time.strftime("%Y%m%d-%H%M%S")
    hparams = {**hparams, "run_name": run_name}

    print("-" * 20)
    print(hparams)

    train_loader, val_loader, test_loader = create_datasets(
        hparams["train_val_test_split"],
        hparams["train_batch_size"],
        hparams["val_batch_size"],
        hparams["test_batch_size"],
    )
    model, loss_function, optimizer, lr_scheduler = (
        create_model_loss_optimizer_scheduler(
            hparams["device"],
            hparams["lr"],
            hparams["num_channels"],
            hparams["num_res_units"],
            hparams["scheduler_factor"],
            hparams["scheduler_patience"],
            hparams["scheduler_min_lr"],
        )
    )

    # torchsummary defaults to device="cuda"; pass the configured device type
    # so CPU-only runs also work.
    summary(
        model,
        input_size=(1, 1024, 1024),
        device=torch.device(hparams["device"]).type,
    )

    epoch_loss_values, val_loss_values = train(
        hparams["device"],
        hparams["val_interval"],
        hparams["epochs"],
        train_loader,
        val_loader,
        model,
        optimizer,
        loss_function,
        lr_scheduler,
        hparams,
        hparams["run_name"],
    )

    save_loss_plots(
        epoch_loss_values, val_loss_values, hparams["run_name"], hparams["val_interval"]
    )

    # Evaluate the best checkpoint rather than the final weights.
    model.load_state_dict(
        torch.load(
            os.path.join("runs", hparams["run_name"], "best.pth"),
            map_location=hparams["device"],
        )
    )

    test_and_save_results(
        hparams["device"], model, test_loader, hparams["run_name"], print_results=False
    )

    del model
    torch.cuda.empty_cache()
    return epoch_loss_values, val_loss_values

In [None]:
# Train once with the fixed configuration. The run directory is "test" plus a timestamp.
train_from_hparams(hparams, run_name=f"test{time.strftime('%Y%m%d-%H%M%S')}")

# Random hyperparameter search (disabled). Each iteration samples a fresh
# configuration and trains it under train_from_hparams' default run name.
# for i in range(15):
#     print(f"Training model {i+1}")
#     hparams = create_random_hparams()
#     train_from_hparams(hparams)
#     print(f"Finished training {i+1} models")

In [16]:
def test_run(run_path):
    """Rebuild a finished run from its hparams.pickle and visualize best.pth on the test split.

    The test split is recreated with `create_datasets`. It only matches the
    split used during training if `data_dict` is unchanged since then.

    NOTE: hparams.pickle is read with pickle, which can execute arbitrary code.
    Only point this at run directories you created yourself.

    Raises:
        ValueError: if the run was trained with `hparams_dict=None`, which leaves
            nothing to rebuild the model from.
    """
    hparams_path = os.path.join(run_path, "hparams.pickle")
    with open(hparams_path, "rb") as pickle_file:
        hparams = pickle.load(pickle_file)
    if hparams is None:
        raise ValueError(f"{hparams_path} contains no hyperparameters (trained with hparams_dict=None)")

    device = hparams["device"]

    _, _, test_loader = create_datasets(
        hparams["train_val_test_split"],
        hparams["train_batch_size"],
        hparams["val_batch_size"],
        hparams["test_batch_size"],
    )

    model, _, _, _ = create_model_loss_optimizer_scheduler(
        device,
        hparams["lr"],
        hparams["num_channels"],
        hparams["num_res_units"],
        hparams["scheduler_factor"],
        hparams["scheduler_patience"],
        hparams["scheduler_min_lr"],
    )
    model.load_state_dict(
        torch.load(os.path.join(run_path, "best.pth"), map_location=device)
    )
    # Inference only: switch to eval mode and skip autograd bookkeeping.
    model.eval()

    with torch.no_grad():
        for batch in test_loader:
            img = batch["img"].to(device)
            seg = batch["seg"].to(device)
            output = model(img)
            predicted_mask = torch.argmax(torch.softmax(output, dim=1), dim=1)
            for i in range(img.shape[0]):
                visualize_segmentation_results(
                    original_image=img[i][0].cpu(),
                    ground_truth_mask=seg[i][0].cpu(),
                    predicted_mask=predicted_mask[i].cpu(),
                )

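### Intact-eye fine-tuning

Start from the weights in `best_150_val_loss_0.4428_in_retina.pth` and fine-tune them on the intact-eye dataset.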

In [None]:
# Rebuild the architecture of the pretrained in-retina checkpoint and load its weights.
# num_channels=[4, 10] gives UNet channels (16, 32, 64, 128, 256, 512) with
# strides (2, 2, 2, 2, 2).
finetune_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, loss, optimizer, scheduler = create_model_loss_optimizer_scheduler(
    finetune_device,
    lr=3e-5,
    num_channels=[4, 10],
    num_res_units=0,
    scheduler_factor=0.1,
    scheduler_patience=10,
    scheduler_min_lr=1e-9,
)
summary(model, input_size=(1, 1024, 1024), device=finetune_device.type)

model.load_state_dict(
    torch.load("best_150_val_loss_0.4428_in_retina.pth", map_location=finetune_device)
)

# Optional: freeze the first 18 parameter tensors (intended as the encoder layers)
# before fine-tuning.
# for idx, param in enumerate(model.parameters()):
#     param.requires_grad = False
#     if idx == 17:
#         break

train_loader, val_loader, test_loader = create_datasets([0.8, 0.1, 0.1], 8, 8, 8)

In [None]:
run_name = "intact_retrain_large_lr_3"
val_interval = 1
# Train on whatever device the model already lives on instead of hard-coding "cuda".
train_device = next(model.parameters()).device

# Record the configuration so test_run() can rebuild this run later. Pickling None
# here made that impossible. These values must match the arguments passed to
# create_model_loss_optimizer_scheduler / create_datasets in the setup cell.
finetune_hparams = {
    "device": train_device,
    "lr": 3e-5,
    "epochs": 100,
    "train_batch_size": 8,
    "val_batch_size": 8,
    "test_batch_size": 8,
    "train_val_test_split": [0.8, 0.1, 0.1],
    "val_interval": val_interval,
    "num_channels": [4, 10],
    "num_res_units": 0,
    "scheduler_factor": 0.1,
    "scheduler_patience": 10,
    "scheduler_min_lr": 1e-9,
    "run_name": run_name,
    "pretrained_weights": "best_150_val_loss_0.4428_in_retina.pth",
}

epoch_loss_values, val_loss_values = train(
    train_device,
    val_interval=val_interval,
    max_epochs=finetune_hparams["epochs"],
    train_loader=train_loader,
    val_loader=val_loader,
    model=model,
    optimizer=optimizer,
    loss_function=loss,
    lr_scheduler=scheduler,
    hparams_dict=finetune_hparams,
    run_name=run_name,
)
save_loss_plots(epoch_loss_values, val_loss_values, run_name, val_interval)