In [None]:
import sys
sys.path.append('../..')
from utils import (
    show_sbs,
    load_config,
    _print,
)

import os
import ipywidgets as widgets

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchmetrics
from torch.optim import Adam, SGD
from losses import DiceLoss, DiceLossWithLogtis
from torch.nn import BCELoss, CrossEntropyLoss
import ali_utils

# Set the directory path
directory = '../../saved_models/'

# Each saved run is looked up later as `{directory}/{run_name}/config.yaml`,
# so only sub-directories are valid choices. Plain files that happen to sit in
# the folder (e.g. .DS_Store, .gitkeep) are filtered out instead of being
# offered in the dropdown and failing at load time.
file_list = [
    entry
    for entry in os.listdir(directory)
    if os.path.isdir(os.path.join(directory, entry))
]


# Create a dropdown widget (the bare name on the last line renders it)
dropdown = widgets.Dropdown(options=sorted(file_list), description='Select a file:')
dropdown


In [None]:
# Debug aid: print the kernel's working directory; the relative paths used in
# this notebook (e.g. '../../saved_models/') are resolved against it.
!pwd

In [None]:
# Run chosen in the dropdown widget; its folder holds the training config.
run_name = dropdown.value

config = load_config(f'{directory}/{run_name}/config.yaml')

# Use the GPU when one is visible, otherwise fall back to the CPU.
# (The original assigned 'cpu' first and immediately overwrote it.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Torch device: {device}")

# Look up the model class by the name stored in the config and build it with
# the stored constructor parameters.
model_cls = ali_utils.class_by_name(config['model']['class'])
best_model = model_cls(**config['model']['params'])

torch.cuda.empty_cache()
best_model = best_model.to(device)

fn = "best_model_state_dict.pt"
model_path = f"{config['model']['save_dir']}/{fn}"

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was written on a GPU machine still loads on a CPU-only one.
best_model.load_state_dict(torch.load(model_path, map_location=device))

# Switch to inference mode so dropout / batch-norm layers (if the model has
# any) behave deterministically in the evaluation cells that follow.
best_model.eval()
print("Loaded best model weights...")

In [None]:
import json

# Imported here because this cell is the first one that plots; the original
# relied on a later cell having imported `plt`, which fails on Run All.
import matplotlib.pyplot as plt

result_file_path = f"{config['model']['save_dir']}/result.json"
with open(result_file_path, 'r') as f:
    results = json.load(f)
epochs_info = results['epochs_info']

# One value per entry of `epochs_info` for every tracked quantity.
tr_losses = [d['tr_loss'] for d in epochs_info]
vl_losses = [d['vl_loss'] for d in epochs_info]
tr_dice = [d['tr_metrics']['train_metrics/Dice'] for d in epochs_info]
vl_dice = [d['vl_metrics']['valid_metrics/Dice'] for d in epochs_info]
tr_js = [d['tr_metrics']['train_metrics/JaccardIndex'] for d in epochs_info]
vl_js = [d['vl_metrics']['valid_metrics/JaccardIndex'] for d in epochs_info]
tr_acc = [d['tr_metrics']['train_metrics/Accuracy'] for d in epochs_info]
vl_acc = [d['vl_metrics']['valid_metrics/Accuracy'] for d in epochs_info]

# (title, train series, train label, validation series, validation label)
panels = [
    ("Loss", tr_losses, "train loss", vl_losses, "validation loss"),
    ("Dice score", tr_dice, "train dice", vl_dice, "validation dice"),
    ("Jaccard Similarity", tr_js, "train JaccardIndex", vl_js, "validation JaccardIndex"),
    ("Accuracy", tr_acc, "train Accuracy", vl_acc, "validation Accuracy"),
]

fig, axs = plt.subplots(1, len(panels), figsize=[16,3])

# Train in red, validation in blue, one panel per quantity.
for ax, (title, tr_series, tr_label, vl_series, vl_label) in zip(axs, panels):
    ax.set_title(title)
    ax.plot(tr_series, 'r-', label=tr_label)
    ax.plot(vl_series, 'b-', label=vl_label)
    ax.set_xlabel("epoch")
    ax.legend()

plt.show()

In [None]:
from tqdm import tqdm
import numpy as np
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
# Image and mask each get their own pipeline; both consist of a single
# ToTensor conversion.
img_transform = transforms.Compose([transforms.ToTensor()])
msk_transform = transforms.Compose([transforms.ToTensor()])

# Resolve the dataset class from the name stored in the config, then build it
# in mode "te" with one-hot masks, forwarding the whole dataset config section
# as keyword arguments.
dataset_cls = ali_utils.class_by_name(config['dataset']['class'])
te_dataset = dataset_cls(
    mode="te",
    one_hot=True,
    **config['dataset'],
    img_transform=img_transform,
    msk_transform=msk_transform,
)

# Loader options come straight from the 'test' entry of the data_loader config.
te_loader_kwargs = config['data_loader']['test']
te_dataloader = DataLoader(te_dataset, **te_loader_kwargs)
from PIL import Image
import cv2
def skin_plot(img, gt, pred):
    """Draw the outlines of the predicted and ground-truth masks on an image.

    Args:
        img: image to draw on. It is converted with ``np.array`` first, and
            the drawing happens on that converted array.
        gt: ground-truth mask; edge-detected with ``cv2.Canny`` (thresholds
            100 / 255) and traced with ``cv2.findContours``.
        pred: predicted mask, processed exactly like ``gt``.

    Returns:
        The converted image with the prediction contours drawn in
        (255, 0, 0) and the ground-truth contours drawn afterwards in
        (0, 255, 0), both with thickness 1.
    """
    def _outlines(mask):
        # Edge map first, then the contours found in that edge map.
        edges = cv2.Canny(mask, 100, 255)
        found, _hierarchy = cv2.findContours(edges, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        return found

    canvas = np.array(img)
    pred_outlines = _outlines(np.array(pred))
    gt_outlines = _outlines(np.array(gt))

    # Prediction is drawn before ground truth, so where the two outlines
    # coincide the ground-truth colour ends up on top.
    layers = (((255, 0, 0), pred_outlines), ((0, 255, 0), gt_outlines))
    for color, outlines in layers:
        for outline in outlines:
            cv2.drawContours(canvas, [outline], -1, color, 1)
    return canvas

#---------------------------------------------------------------------------------------------
import random

save_imgs_dir = f"{config['model']['save_dir']}/visualized"

# makedirs(exist_ok=True) creates missing parent folders too and is safe to
# re-run, unlike the isdir()/mkdir() pair it replaces.
os.makedirs(save_imgs_dir, exist_ok=True)

# Imported once up front; the original re-ran this import for every batch.
import ipyplot

# Samples previewed per batch. The original collected six overlays and then
# displayed only four of them (max_images=4); now exactly the shown ones are built.
PREVIEW_PER_BATCH = 4

# Make sure dropout / batch-norm layers (if any) run in inference mode.
best_model.eval()

with torch.no_grad():

    for batch in tqdm(te_dataloader):
        imgs = batch['image']
        msks = batch['mask']

        preds = best_model(imgs.to(device))

        txm = imgs.cpu().numpy()
        # Collapse the channel axis (dim 1) of masks and predictions to label maps.
        tbm = torch.argmax(msks, 1).cpu().numpy()
        tpm = torch.argmax(preds, 1).cpu().numpy()

        debug_imgs = []
        for idx in range(min(len(tbm), PREVIEW_PER_BATCH)):
            # First three channels, channel-last, scaled by 255 and cast to uint8.
            # NOTE(review): assumes image values lie in [0, 1] — confirm.
            img = np.moveaxis(txm[idx, :3], 0, -1)*255.
            img = np.ascontiguousarray(img, dtype=np.uint8)
            gt = np.uint8(tbm[idx]*255.)
            # Any non-zero label counts as foreground in the predicted mask.
            pred = np.where(tpm[idx]>0.5, 255, 0)
            pred = np.ascontiguousarray(pred, dtype=np.uint8)

            res_img = skin_plot(img, gt, pred)

            # Optional: persist the images (disabled, as in the original).
            # fid = batch['id'][idx]
            # Image.fromarray(img).save(f"{save_imgs_dir}/{fid}_img.png")
            # Image.fromarray(res_img).save(f"{save_imgs_dir}/{fid}_img_gt_pred.png")
            debug_imgs.append(Image.fromarray(res_img))

        ipyplot.plot_images(debug_imgs, max_images=4, img_width=150)

In [None]:
from tqdm import tqdm
import numpy as np
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import numpy as np
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
import matplotlib.pyplot as plt

# Image and mask each get their own pipeline; both consist of a single
# ToTensor conversion.
img_transform = transforms.Compose([transforms.ToTensor()])
msk_transform = transforms.Compose([transforms.ToTensor()])

# Resolve the dataset class from the name stored in the config, then build it
# in mode "te" with one-hot masks, forwarding the whole dataset config section
# as keyword arguments.
dataset_cls = ali_utils.class_by_name(config['dataset']['class'])
te_dataset = dataset_cls(
    mode="te",
    one_hot=True,
    **config['dataset'],
    img_transform=img_transform,
    msk_transform=msk_transform,
)

# Loader options come straight from the 'test' entry of the data_loader config.
te_loader_kwargs = config['data_loader']['test']
te_dataloader = DataLoader(te_dataset, **te_loader_kwargs)
from PIL import Image
import cv2
def skin_plot(img, gt, pred):
    """Draw the outlines of the predicted and ground-truth masks on an image.

    This redefines the ``skin_plot`` from the previous cell; the only
    difference is the prediction colour, (0, 0, 255) here.

    Args:
        img: image to draw on. It is converted with ``np.array`` first, and
            the drawing happens on that converted array.
        gt: ground-truth mask; edge-detected with ``cv2.Canny`` (thresholds
            100 / 255) and traced with ``cv2.findContours``.
        pred: predicted mask, processed exactly like ``gt``.

    Returns:
        The converted image with the prediction contours drawn in
        (0, 0, 255) and the ground-truth contours drawn afterwards in
        (0, 255, 0), both with thickness 1.
    """
    def _outlines(mask):
        # Edge map first, then the contours found in that edge map.
        edges = cv2.Canny(mask, 100, 255)
        found, _hierarchy = cv2.findContours(edges, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        return found

    canvas = np.array(img)
    pred_outlines = _outlines(np.array(pred))
    gt_outlines = _outlines(np.array(gt))

    # Prediction is drawn before ground truth, so where the two outlines
    # coincide the ground-truth colour ends up on top.
    layers = (((0, 0, 255), pred_outlines), ((0, 255, 0), gt_outlines))
    for color, outlines in layers:
        for outline in outlines:
            cv2.drawContours(canvas, [outline], -1, color, 1)
    return canvas

#---------------------------------------------------------------------------------------------
import random

save_imgs_dir = f"{config['model']['save_dir']}/visualized"

# makedirs(exist_ok=True) creates missing parent folders too and is safe to
# re-run, unlike the isdir()/mkdir() pair it replaces.
os.makedirs(save_imgs_dir, exist_ok=True)

# Imported once up front; the original re-ran this import for every batch.
import ipyplot

# At most this many samples per batch are rendered (the original broke out of
# the loop after the sixth sample).
MAX_SAMPLES_PER_BATCH = 6

# Colormap used for the per-channel maps; looked up once instead of per sample.
color_map = plt.get_cmap('turbo')

# Make sure dropout / batch-norm layers (if any) run in inference mode.
best_model.eval()

with torch.no_grad():

    for batch in tqdm(te_dataloader):
        imgs = batch['image']
        msks = batch['mask']

        # Output of the model's DCFD.bases_net sub-module, not the full model.
        preds = best_model.DCFD.bases_net(imgs.to(device))

        txm = imgs.cpu().numpy()
        tpm2 = preds.cpu().numpy()
        tbm = torch.argmax(msks, 1).cpu().numpy()

        debug_imgs = []
        for idx in range(min(len(tbm), MAX_SAMPLES_PER_BATCH)):
            # First three channels, channel-last, scaled by 255 and cast to uint8.
            img = np.moveaxis(txm[idx, :3], 0, -1)*255.
            img = np.ascontiguousarray(img, dtype=np.uint8)

            # All bases_net channels, channel-last, scaled by 255 and cast to uint8.
            # NOTE(review): the cast assumes the outputs lie in [0, 1]; values
            # outside that range will not map cleanly to uint8 — confirm.
            img2 = np.moveaxis(tpm2[idx, :], 0, -1)*255.
            img2 = np.ascontiguousarray(img2, dtype=np.uint8)

            gt = np.uint8(tbm[idx]*255.)
            # The ground truth is passed as the prediction too, so only the
            # ground-truth outline ends up visible on the reference image.
            res_img = skin_plot(img, gt, gt)
            debug_imgs.append(Image.fromarray(res_img))

            # One colour-mapped image per channel; the alpha channel returned
            # by the colormap is dropped.
            for channel in range(img2.shape[2]):
                heatmap = (color_map(img2[:, :, channel])[:, :, :3] * 255).astype(np.uint8)
                debug_imgs.append(Image.fromarray(heatmap))

        ipyplot.plot_images(debug_imgs, img_width=150)

In [None]:
pip install ipyplot