In [1]:
import os
import random
import time
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import pandas as pd
import skimage
import skimage.transform
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as tr
from IPython.display import clear_output
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader

# Local imports keep their original relative order: the star import may
# shadow names, so reordering it could change which binding wins.
from model.metric import metrics
from model.utils.utils import *
from model.nets.unet import UNet

from model.losses.loss import WeightMapBortLoss
from model.datasets.dataloader import WeightMapDataset


# Expose only GPU 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# ---- Experiment selection ----
dataset_name = "snemi3d"  # "iron"

model_class = UNet
model_name = "unet"

# One entry per loss configuration. train() parses the integers that follow
# "step" and "iter" out of each name.
loss_names = [
    'skeaw_dilate_step20_iter2_bort',
]

# The order fixes the first axis of the metrics array filled by the
# cross-validation loop.
metric_names = ['me', 'se', 'vi', 'mAp', 'ari', 'dice', 'betti', 'betti0', 'betti1']

# ---- Reproducibility and cross-validation split ----
seed_num = 2020
setup_seed(seed_num)
kf_num = 3
kf = KFold(n_splits=kf_num, shuffle=True, random_state=seed_num)
# Fraction of each training fold sampled as the validation set
# (the loop takes int(n * val_rate) + 1 files).
val_rate = 0.1
export_name = 'cv_{}_fold_{}_{}_{}'.format(kf_num, dataset_name, model_name, loss_names[0])

# ---- Per-dataset settings ----
if dataset_name == 'iron':
    z_score_norm = tr.Compose([
        tr.ToTensor(),
        tr.Normalize(mean=[0.9410404628082503],
                     std=[0.12481161024777744]),
    ])
    crop_size = 512
    file_num = 150
    zfill_num = 3
elif dataset_name == 'snemi3d':
    z_score_norm = tr.Compose([
        tr.ToTensor(),
        tr.Normalize(mean=[0.5053152359607174],
                     std=[0.16954360899089577]),
    ])
    crop_size = 512
    file_num = 100
    zfill_num = 3
else:
    # Fail here instead of with a NameError on z_score_norm / file_num later.
    raise ValueError(
        "Unknown dataset_name {!r}; expected 'iron' or 'snemi3d'".format(dataset_name)
    )
file_list = list(range(file_num))

# ---- Optimisation hyper-parameters ----
learning_rate = 1e-4
epochs = 50
num_channels = 1
use_augment = [True,  True,  True]   # In training: rand_rotation, rand_vertical_flip, rand_horizontal_flip
no_augment = [False, False, False]   # In validation and test

wmp_criterion = WeightMapBortLoss()

cwd = os.getcwd()
statistic_dir = Path(cwd, 'statistic')

def train(net, epoch, dataloader, optimizer, learning_rate, loss_idx, margin_criterion=None):
    """Train ``net`` for one epoch and advance the learning-rate scheduler.

    Args:
        net: Model to optimise; it is switched to training mode.
        epoch: Current epoch number, forwarded to the loss as ``epoch``.
        dataloader: Iterable of dict samples with the keys ``'img'``,
            ``'label'``, ``'weight'``, ``'class_weight'`` and, optionally,
            ``'skelen'``.
        optimizer: Optimizer that updates ``net``'s parameters.
        learning_rate: Unused; kept so existing callers keep working.
        loss_idx: Index into the module-level ``loss_names`` list.
        margin_criterion: Unused; kept so existing callers keep working.

    Relies on the module-level names ``loss_names``, ``wmp_criterion`` and
    ``scheduler``; ``scheduler`` must be defined before this is called.
    """
    net.train()
    # Fall back to the CPU instead of failing on unbound tensors when CUDA is
    # not available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The loss name encodes its own hyper-parameters, e.g. "..._step20_iter2_...".
    # They do not change between batches, so parse them once.
    method = loss_names[loss_idx]
    step = 20
    d_iter = 2
    if 'step' in method:
        step = int(method.split('step')[-1].split('_')[0])
    if 'iter' in method:
        d_iter = int(method.split('iter')[-1].split('_')[0])

    for sample in dataloader:
        img = sample['img'].to(device)
        label = sample['label'].to(device)
        weight = sample['weight'].to(device)
        class_weight = sample['class_weight'].to(device)
        skelen = sample['skelen'].to(device) if 'skelen' in sample else None

        output = net(img)
        loss = wmp_criterion(output, label, weight, class_weight, label_skelen=skelen,
                             method=method, step=step, epoch=epoch, d_iter=d_iter)

        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping
        nn.utils.clip_grad_norm_(net.parameters(), 5)
        optimizer.step()
    # One scheduler step per epoch (not per batch).
    scheduler.step()

def val(net, epoch, dataloader, loss_idx, checkpoint_path, printer, check_first_img=True):
    """Evaluate ``net`` on ``dataloader`` and return the mean VI score.

    Args:
        net: Model to evaluate; it is switched to eval mode.
        epoch: Current epoch number; the first sample is visualised only when
            ``epoch`` is a multiple of 10.
        dataloader: Iterable of dict samples with the keys ``'img'``,
            ``'label'`` and ``'weight'``. It must expose ``.dataset``.
        loss_idx: Unused; kept so existing callers keep working.
        checkpoint_path: Forwarded to ``check_result``.
        printer: Forwarded to ``check_result``.
        check_first_img: When true, pass the first sample to ``check_result``
            on every 10th epoch.

    Returns:
        The sum of ``metrics.vi`` over all batches divided by
        ``len(dataloader.dataset)``. One VI value is added per batch, so this
        is the per-image mean only when the loader uses ``batch_size=1``.
    """
    net.eval()
    # Fall back to the CPU instead of failing on an unbound ``img`` when CUDA
    # is not available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vi = 0
    is_first = True
    with torch.no_grad():
        for sample in dataloader:
            img = sample['img'].to(device)
            weight = sample['weight']

            # Predicted class per pixel = index of the largest value along dim 1.
            output = net(img).max(1)[1]
            output = output.cpu().squeeze().numpy().astype(np.int64)
            label = sample['label'].squeeze().numpy().astype(np.int64)

            temp_vi, _, _ = metrics.vi(output, label)
            vi += temp_vi
            if check_first_img and is_first and epoch % 10 == 0:
                check_result(epoch, img[0, :, :, :].cpu().numpy().transpose((1, 2, 0)).squeeze(),
                             label, output, weight[0, :, :, :].squeeze().numpy().transpose((1, 2, 0)),
                             checkpoint_path, printer, description="val_")
            is_first = False
        vi /= len(dataloader.dataset)
        return vi

experiment_names = []
# NOTE: the misspelt name is kept because the summary code after this loop reads it.
experimet_metrics = []
# Per-fold test metrics indexed as (metric, loss, fold); the metric axis follows
# metric_names (merge error, split error, vi, ...).
metrics_np = np.zeros((len(metric_names), len(loss_names), kf_num))

# Independent of fold and loss, so computed once.
imgs_dir = Path(cwd, 'data', dataset_name)
weight_dir = Path(imgs_dir, 'skeaw')
skelen_dir = Path(imgs_dir, 'skelen')

# K-fold Cross Validation
for kf_idx, (train_index, test_index) in enumerate(kf.split(np.array(file_list))):
    # Dataset split
    # K-fold CV yields the train and test sets; a val_rate fraction of the
    # train set (plus one file) is then sampled without replacement as the
    # validation set.
    random.seed(seed_num)
    val_index = random.sample(list(train_index), int(len(train_index) * val_rate) + 1)
    train_index = sorted(set(train_index).difference(val_index))
    val_index.sort()

    # convert index to file name (zfill): 10->010
    train_names = file_name_convert(train_index, zfill_num)
    val_names = file_name_convert(val_index, zfill_num)
    test_names = file_name_convert(test_index, zfill_num)
    test_inference_names = [item + '.png' for item in test_names]

    assert len(train_names) + len(val_names) + len(test_names) == file_num, 'The sum of three dataset should equal to the total number of dataset'
    assert is_interact(train_names, val_names) is False and is_interact(test_names, val_names) is False and is_interact(train_names, test_names) is False, 'The three dataset should not interact'

    for loss_idx, loss_name in enumerate(loss_names):
        experiment_name = 'cv_' + str(kf_idx + 1) + '_' + dataset_name + '_minValVI_' + model_name + '_' + loss_name
        experiment_names.append(experiment_name)

        # makedirs also creates a missing 'parameters' parent directory, which
        # os.mkdir would fail on; exist_ok makes re-runs safe.
        checkpoint_path = Path(cwd, 'parameters', experiment_name)
        output_save_path = os.path.join(str(checkpoint_path), 'results')
        os.makedirs(output_save_path, exist_ok=True)
        printer = Printer(True, str(Path(checkpoint_path, "loss.txt")))
        printer.print_and_log("This experiments is: {}'s".format(experiment_name))

        print('weight_dir is ', weight_dir, ' skelen_dir is ', skelen_dir)

        # train setting
        setup_seed(seed_num)
        batch_size = 10
        print('batch_size is ', batch_size, ' lr: ', learning_rate)

        train_dataset = WeightMapDataset(imgs_dir, train_names, weight_dir, skelen_dir=skelen_dir, use_augment=use_augment, crop_size=crop_size, norm_transform=z_score_norm, dataset_name=dataset_name)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        # Single-sample loader over the training set, used only to measure training VI.
        train_one_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
        val_dataset = WeightMapDataset(imgs_dir, val_names, weight_dir, skelen_dir=skelen_dir, use_augment=no_augment, crop_size=crop_size, norm_transform=z_score_norm, dataset_name=dataset_name)
        val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
        print('train_dataset ', len(train_dataset), ' val_dataset ', len(val_dataset))

        num_classes = 2
        print('num_class is ', num_classes)
        model = model_class(num_channels=num_channels, num_classes=num_classes)
        if torch.cuda.is_available():
            model = nn.DataParallel(model).cuda()

        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        # train() reads this module-level scheduler and steps it once per epoch.
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)

        train_vi_list = []
        val_vi_list = []
        val_baseline = 10000
        val_best_epoch = 0

        # Checkpoints are only written after this many epochs. The value is
        # parsed from the loss name ("..._step20_...") and does not change
        # during training.
        # NOTE(review): the fallback here is 10 while train() falls back to 20
        # - confirm which one is intended.
        min_save_epoch = 10
        if 'step' in loss_name:
            min_save_epoch = int(loss_name.split('step')[-1].split('_')[0])

        # Training
        st_total = time.time()
        printer.print_and_log("Training:")
        for epoch in range(1, epochs + 1):
            print("Experiment name: " + experiment_name)
            st = time.time()
            train(model, epoch, train_loader, optimizer, learning_rate, loss_idx)
            train_vi = val(model, epoch, train_one_loader, loss_idx, checkpoint_path, printer, check_first_img=False)
            val_vi = val(model, epoch, val_loader, loss_idx, checkpoint_path, printer, check_first_img=True)

            train_vi_list.append(train_vi)
            val_vi_list.append(val_vi)

            printer.print_and_log("Epoch {}: train_vi {:.4f}; val_vi {:.4f} \n".format(epoch, train_vi, val_vi))
            plot(epoch, train_vi_list, val_vi_list, checkpoint_path, curve_name='vi')

            # Keep the weights with the lowest validation VI seen so far.
            if val_vi < val_baseline and epoch > min_save_epoch:
                val_baseline = val_vi
                val_best_epoch = epoch
                torch.save(model.state_dict(), str(Path(checkpoint_path, "best_model_state.pth")))
            ed = time.time()
            printer.print_and_log("Epoch Duration: {}'s".format(ed - st))
        ed_total = time.time()
        printer.print_and_log("Total duration is: {}'s".format(ed_total - st_total))
        printer.print_and_log("The best epoch is at: {} th epoch. batch_size: {}, lr: {}".format(val_best_epoch, batch_size, learning_rate))
        printer.print_and_log("Train VI list is: {}".format(train_vi_list))
        printer.print_and_log("Val VI list is: {}".format(val_vi_list))
        if val_best_epoch == 0:
            # Nothing was saved in this run, so the load below either fails or
            # picks up a checkpoint left by an earlier run. Make that visible.
            printer.print_and_log("Warning: no checkpoint was saved in this run; testing uses any pre-existing best_model_state.pth")

        # Testing
        totals = dict.fromkeys(metric_names, 0.0)  # running sums, averaged after the loop
        printer.print_and_log("Testing:")
        model.load_state_dict(torch.load(str(Path(checkpoint_path, 'best_model_state.pth'))))
        model.eval()
        with torch.no_grad():
            for img_name in test_inference_names:
                printer.print_and_log(img_name)
                img = load_img(Path(imgs_dir, 'images', img_name))
                label = load_img(Path(imgs_dir, 'labels', img_name))
                img_origin = img.copy()
                origin_h, origin_w = img.shape[:2]

                # Round each side up to the next multiple of 32; a side that is
                # already aligned is left unchanged.
                resize_h = origin_h + (-origin_h) % 32
                resize_w = origin_w + (-origin_w) % 32
                needs_resize = (resize_h, resize_w) != (origin_h, origin_w)
                if needs_resize:
                    # NOTE(review): torchvision's ToTensor only rescales uint8
                    # input to [0, 1]; this float32 array keeps its original
                    # range. Confirm load_img's dtype/range so resized and
                    # non-resized images are normalised alike.
                    img = skimage.transform.resize(img, (resize_h, resize_w), preserve_range=True).astype(np.float32)
                img = z_score_norm(img).unsqueeze(0)
                if torch.cuda.is_available():
                    img = img.cuda()

                # Predicted class per pixel = index of the largest value along dim 1.
                output = model(img).max(1)[1]
                output = output.cpu().squeeze().numpy()
                if needs_resize:
                    # order=0 (nearest neighbour) keeps the prediction a label map.
                    output = skimage.transform.resize(output, (origin_h, origin_w), order=0, preserve_range=True)
                output = output.astype(np.int64)
                label = label.astype(np.int64)

                # Labels are 0/1 (num_classes == 2), so the saved mask is 0/255;
                # the explicit uint8 cast gives OpenCV a dtype it writes natively.
                cv2.imwrite(os.path.join(output_save_path, img_name), (output * 255).astype(np.uint8))

                # visualization
                fig = plt.figure(figsize=(20, 20))
                panels = ((img_origin, 'img'), (label, 'label'), (output, 'output'))
                for panel_idx, (panel, title) in enumerate(panels, start=1):
                    plt.subplot(1, 3, panel_idx)
                    plt.imshow(panel, cmap="gray")
                    plt.title(title)
                    plt.axis("off")
                plt.show()
                plt.close(fig)  # do not accumulate open figures across test images

                # evaluation
                output = metrics.post_process_output(output)
                label = metrics.post_process_label(label)

                # NOTE(review): hard-coded crop; presumably drops a 5-pixel
                # border of 1024x1024 images - confirm for other image sizes.
                label = label[5:1019, 5:1019]
                output = output[5:1019, 5:1019]
                temp_vi, temp_me, temp_se = metrics.vi(output, label)
                printer.print_and_log("temp_me:{:.4f} temp_se:{:.4f} temp_vi:{:.4f}".format(temp_me, temp_se, temp_vi))
                totals['vi'] += temp_vi
                totals['me'] += temp_me
                totals['se'] += temp_se

                totals['mAp'] += metrics.map_2018kdsb(output, label)
                totals['ari'] += metrics.ari(output, label)

                # Dice and Betti numbers are computed on the pruned masks.
                output = metrics.prun(output, 4)
                label = metrics.prun(label, 4)
                totals['dice'] += metrics.mdice(output, label)
                temp_betti, temp_betti0, temp_betti1 = metrics.compute_bettis_own(output, label, filter_small_holes=True)
                totals['betti'] += temp_betti
                totals['betti0'] += temp_betti0
                totals['betti1'] += temp_betti1
                torch.cuda.empty_cache()

        # Average over the test images, in metric_names order.
        num_test = len(test_inference_names)
        averages = [totals[name] / num_test for name in metric_names]
        experimet_metrics.append(averages)
        metrics_np[:, loss_idx, kf_idx] = averages
        # Save after every experiment so partial results survive an interrupted
        # run; create the output directory first because np.save will not.
        os.makedirs(str(statistic_dir), exist_ok=True)
        np.save(str(Path(statistic_dir, export_name + '.npy')), metrics_np)

        printer.print_and_log("Experiment_name: " + experiment_name)
        printer.print_and_log("Total Evaluation:")
        printer.print_and_log("me:{:.4f} se:{:.4f} vi:{:.4f} mAp:{:.4f} ari:{:.4f} dice:{:.4f} betti:{:.4f} betti0:{:.4f} betti1:{:.4f}".format(*averages))


# Per-experiment recap: one name line followed by one line of averaged metrics.
recap_template = "me:{:.4f} se:{:.4f} vi:{:.4f} mAp {:.4f} ari {:.4f} dice:{:.4f} betti:{:.4f} betti0:{:.4f} betti1:{:.4f}"
for position, name in enumerate(experiment_names):
    print(name)
    row = experimet_metrics[position]
    print(recap_template.format(*row[:9]))

# Statistics across folds (last axis of metrics_np).
mean_np = metrics_np.mean(axis=2)
std_np = metrics_np.std(axis=2)

# One "mean ± std" cell per (metric, loss) pair.
metrics_df_brief = pd.DataFrame(index=metric_names, columns=loss_names)
cell_template = '{:.4f} ± {:.4f}'
for row_pos in range(len(metric_names)):
    for col_pos in range(len(loss_names)):
        cell_text = cell_template.format(mean_np[row_pos, col_pos], std_np[row_pos, col_pos])
        metrics_df_brief.iloc[row_pos, col_pos] = cell_text

# Show the summary table, then export it to Excel.
print(export_name)
print(metrics_df_brief)
statistic_path = statistic_dir / (export_name + '.xlsx')
metrics_df_brief.to_excel(str(statistic_path))

KeyboardInterrupt: 