In [60]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader 
from torch.autograd import Variable
from torch.backends import cudnn


import matplotlib.pyplot as plt
import os
from PIL import Image
import cv2
import numpy as np


# Augmenting library 

import torchvision
import torchvision.transforms.functional as TF
import albumentations as A
import torchvision.transforms as T

# Control randomness: seed every RNG the notebook touches (Python, NumPy, torch)
# and pin cuDNN to deterministic kernels so repeated runs match.
import random
import warnings

random_seed = 7
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.backends.cudnn.deterministic = True   # deterministic conv algorithms
torch.backends.cudnn.benchmark = False      # no autotuning (it picks kernels non-deterministically)

# NOTE: this silences *all* warnings, including library deprecation notices.
warnings.filterwarnings("ignore")

# GPU selection. NOTE(review): torch is already imported above, so these only take
# effect because CUDA has not been initialised yet; setting them before
# `import torch` would be more robust.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "2"})
print(torch.cuda.device_count())


# logging
import datetime
from tensorboardX import SummaryWriter
from tqdm import tqdm
import time 

1


In [61]:
import joint_transforms
from config import cod_training_root
from config import backbone_path
from datasets import ImageFolder
from misc import AvgMeter, check_mkdir
from PFNet import PFNet
from helper import *
import loss

# Config

In [62]:
ckpt_path = './ckpt'
exp_name = 'PFNet'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hyper-parameters shared by the cells below.
args = dict(
    epoch_num=200,
    train_batch_size=32,   # batch size of the training DataLoader
    last_epoch=0,
    lr=1e-4,               # base learning rate; bias parameters get 2 * lr
    lr_decay=0.9,
    weight_decay=5e-4,     # applied to the non-bias parameter group only
    momentum=0.9,          # used only when optimizer == 'SGD'
    snapshot='',
    scale=416,
    save_point=[],
    poly_train=False,
    optimizer='Adam',      # 'SGD' or 'Adam'
    amp=False,
)

In [63]:
# loss function
structure_loss = loss.structure_loss().to(device)
bce_loss = nn.BCEWithLogitsLoss().to(device)
iou_loss = loss.IOU().to(device)

In [64]:
def bce_iou_loss(pred, target):
    """Combined loss: module-level ``bce_loss`` (BCEWithLogitsLoss) plus ``iou_loss`` (loss.IOU).

    Both criteria receive the same ``pred`` / ``target`` pair and their results are summed.
    """
    return bce_loss(pred, target) + iou_loss(pred, target)



# Net

In [65]:
# Build the network; the constructor loads the ResNet-50 backbone weights from
# `backbone_path` (see the printed message below).
net = PFNet(backbone_path)

From ./backbone/resnet/resnet50-19c8e357.pth Load resnet50 Weights Succeed!


In [66]:
from torch.optim.lr_scheduler import ReduceLROnPlateau


def build_optimizer(model, cfg):
    """Create the optimizer named by ``cfg['optimizer']`` for ``model``.

    Parameters are split into two groups:
      * names ending in ``'bias'``: learning rate ``2 * cfg['lr']``, no explicit weight decay;
      * everything else: learning rate ``cfg['lr']`` with ``cfg['weight_decay']``.

    Args:
        model: module whose ``named_parameters()`` are optimised.
        cfg: config dict providing ``'optimizer'``, ``'lr'``, ``'weight_decay'`` and,
            for SGD, ``'momentum'``.

    Returns:
        ``torch.optim.SGD`` for ``'SGD'``, ``torch.optim.Adam`` for ``'Adam'``.

    Raises:
        ValueError: for any other optimizer name, instead of silently falling back to Adam.
    """
    named = list(model.named_parameters())
    param_groups = [
        {'params': [p for n, p in named if n.endswith('bias')],
         'lr': 2 * cfg['lr']},
        {'params': [p for n, p in named if not n.endswith('bias')],
         'lr': 1 * cfg['lr'], 'weight_decay': cfg['weight_decay']},
    ]

    name = cfg['optimizer']
    if name == 'SGD':
        print('SGD opt')
        return torch.optim.SGD(param_groups, momentum=cfg['momentum'])
    if name == 'Adam':
        print('Adam opt')
        return torch.optim.Adam(param_groups)
    raise ValueError(f"Unsupported optimizer {name!r}; expected 'SGD' or 'Adam'")


optimizer = build_optimizer(net, args)

# Alternative schedulers (currently disabled):
# scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, \
#                                 threshold=0.05, threshold_mode='rel', cooldown=5, min_lr = 1e-4)

# scheduler = CosineAnnealingWarmupRestarts(optimizer, first_cycle_steps=200, cycle_mult=1.0, max_lr=0.1, min_lr=0.001, warmup_steps=50, gamma=0.5)

Adam opt


## Build the train/test index lists

In [67]:
# mask_files = os.walk("/home/sklab2/workspace/datashared/SS-OCT/vessel_segmentation/exp/masked")
# mask_idx = []
# for (root, dirs, files) in mask_files:
#     if len(files) > 0 :
#         mask_idx.append(files)

# mask_idxs = [element for array in mask_idx for element in array]
# len(mask_idxs)

# # 1~ 11 / 12, 13, 14  , 40, 41, 43, 44, 46, 49,  50, 53, 54, 55 
# train_indexs = []
# test_indexs = []
# for idx, data in enumerate(mask_idxs):
#     tmp = mask_idxs[idx].split('_')
#     test_indexs.append([tmp[0], tmp[1].split('.')[0]])

In [68]:
MASK_DIR = "/home/sklab2/workspace/datashared/SS-OCT/vessel_segmentation/masked"
TEST_SUBJECT_START = 45  # masks whose first filename field is >= 45 go to the test split
# earlier split notes: 1~ 11 / 12, 13, 14  , 40, 41, 43, 44, 46, 49,  50, 53, 54, 55


def split_mask_index(filenames, test_subject_start=TEST_SUBJECT_START):
    """Split mask file names of the form ``'<a>_<b>.png'`` into train/test ``[a, b]`` pairs.

    Names are processed in sorted order, so the resulting lists (and any later
    index-based lookup such as a fixed plotting offset) do not depend on the
    filesystem's listing order. Names that are not ``.png`` files, or that do not
    split into exactly two ``'_'``-separated fields (e.g. ``'10_L_112_L.png'``),
    are skipped.

    Args:
        filenames: iterable of bare file names.
        test_subject_start: pairs with ``int(a) >= test_subject_start`` go to test.

    Returns:
        ``(train, test)``: two lists of ``[a, b]`` string pairs.

    Raises:
        ValueError: if a kept name's first field is not an integer.
    """
    train, test = [], []
    for fname in sorted(filenames):
        if not fname.endswith('.png'):
            continue
        parts = fname.split('_')
        if len(parts) != 2:
            continue
        pair = [parts[0], parts[1].split('.')[0]]
        (train if int(parts[0]) < test_subject_start else test).append(pair)
    return train, test


mask_idxs = sorted(name for _, _, files in os.walk(MASK_DIR) for name in files)
train_indexs, test_indexs = split_mask_index(mask_idxs)

In [69]:
# Spot-check one test entry: the first and second '_'-separated fields of a mask file name.
test_indexs[5]

['46', '21']

# Dataloader

In [70]:
# import albumentations.augmentations.functional as AF

# PATH = '/home/sklab2/workspace/datashared/SS-OCT/vessel_segmentation/exp/'
# class VesselDataset(Dataset):
#     def __init__(self, index, transforms):
#         self.index = index
#         self.transforms = transforms
        
#     def __len__(self):
#         return len(self.index)
        
#     def __getitem__(self, idx):
#         s_1 = self.index[idx][0]
#         s_2 = self.index[idx][1]

#         # '1_L_0.jpg', 
#         image = Image.open(PATH+'origin/' + s_1+'_'+s_2+'.jpg').resize((416, 416),Image.Resampling.BILINEAR)
#         #'10_L_112_L.png', 
#         mask = Image.open(PATH+'masked/' +  s_1+'_'+s_2+'.png').resize((416, 416),Image.Resampling.BILINEAR)
        
#         image = np.array(image, dtype=np.uint8) #RGB
#         mask = np.array(mask, dtype=np.uint8)   # HWC
#         mask_o = mask / 255        # CHW


#         lower_red = np.array([-10, 100, 100]) 
#         upper_red = np.array([10, 255, 255]) 

#         mask_hsv = cv2.cvtColor(mask, cv2.COLOR_RGB2HSV)
#         mask = cv2.inRange(mask_hsv, lower_red, upper_red)

#         aft_mask = mask / 255
        
#         # aft_mask = cv2.resize(aft_mask, (416, 416), interpolation=cv2.INTER_NEAREST)
#         masks = [aft_mask, mask_o]  # target, original

#         # for num in range(3): #### 3번 이터레이션이 왜들어갔지?
#         if self.transforms:
#             transformed = self.transforms(image=image, masks=masks)
#             image, masks = transformed['image'], transformed['masks']
#         # urls.append(s_1+'_'+s_2)
#         assert sum(masks[0]==0).sum() + sum(masks[0]==1).sum() == 416*416   # mask가 0 또는 1이 아닐경우 스탑
                    
#         return image, masks, aft_mask, s_1+'_'+s_2

In [71]:
import albumentations.augmentations.functional as AF

PATH = '/home/sklab2/workspace/datashared/SS-OCT/vessel_segmentation/'


class VesselDataset(Dataset):
    """Vessel-segmentation dataset built from ``PATH/origin`` images and ``PATH/masked`` annotations.

    Item ``i`` of ``index`` is a ``[s_1, s_2]`` pair of strings. It loads
    ``PATH/origin/{s_1}_L_{s_2}.jpg`` as the input and ``PATH/masked/{s_1}_{s_2}.png``
    as a colour annotation. The red pixels of the annotation form the binary target.

    Args:
        index: sequence of ``[s_1, s_2]`` string pairs.
        transforms: albumentations-style callable taking ``image=`` and ``masks=``
            keyword arguments, or a falsy value to skip augmentation.
        image_size: side length both files are resized to (default 416).

    ``__getitem__`` returns ``(image, masks, aft_mask, name)``:
        * ``image``: RGB input after ``transforms``;
        * ``masks``: ``[target, annotation]`` after ``transforms``. ``target`` is the
          {0, 1} red-pixel mask and ``annotation`` is the RGB annotation / 255 (HWC);
        * ``aft_mask``: the {0, 1} target *before* ``transforms``;
        * ``name``: ``s_1 + '_' + s_2``.
    """

    # OpenCV 8-bit HSV stores hue in [0, 179], and red wraps around both ends.
    # A single range with a negative lower hue (-10..10) saturates to 0..10 and
    # silently drops the 170..179 half, so both halves are matched explicitly.
    RED_HSV_RANGES = (
        (np.array([0, 100, 100], dtype=np.uint8), np.array([10, 255, 255], dtype=np.uint8)),
        (np.array([170, 100, 100], dtype=np.uint8), np.array([179, 255, 255], dtype=np.uint8)),
    )

    def __init__(self, index, transforms, image_size=416):
        self.index = index
        self.transforms = transforms
        self.image_size = image_size

    def __len__(self):
        return len(self.index)

    @staticmethod
    def _load_rgb(path, size):
        """Open ``path``, force 3-channel RGB, resize bilinearly; return a uint8 HWC array."""
        # convert('RGB') is a no-op for RGB files but keeps RGBA / palette PNGs
        # from crashing cv2.cvtColor(COLOR_RGB2HSV) below.
        with Image.open(path) as img:
            return np.array(img.convert('RGB').resize(size, Image.Resampling.BILINEAR), dtype=np.uint8)

    def __getitem__(self, idx):
        s_1, s_2 = self.index[idx][0], self.index[idx][1]
        size = (self.image_size, self.image_size)

        image = self._load_rgb(PATH + 'origin/' + s_1 + '_L_' + s_2 + '.jpg', size)  # e.g. '1_L_0.jpg'
        mask = self._load_rgb(PATH + 'masked/' + s_1 + '_' + s_2 + '.png', size)     # e.g. '1_0.png'
        mask_o = mask / 255  # original annotation scaled to [0, 1], HWC

        # Binary target: pixels whose HSV colour falls in either red range.
        mask_hsv = cv2.cvtColor(mask, cv2.COLOR_RGB2HSV)
        red = np.zeros(mask_hsv.shape[:2], dtype=np.uint8)
        for lower, upper in self.RED_HSV_RANGES:
            red |= cv2.inRange(mask_hsv, lower, upper)
        aft_mask = red / 255  # {0., 1.}

        masks = [aft_mask, mask_o]  # [target, original annotation]
        if self.transforms:
            transformed = self.transforms(image=image, masks=masks)
            image, masks = transformed['image'], transformed['masks']

        # Stop if resizing/augmentation left non-binary values in the target.
        # The expression works for both numpy arrays and torch tensors.
        target = masks[0]
        n_binary = int(((target == 0) | (target == 1)).sum())
        assert n_binary == self.image_size * self.image_size, 'target mask must contain only 0/1 values'

        # NOTE(review): aft_mask is the target *before* augmentation, so it is not
        # spatially aligned with `image` whenever a geometric transform fires.
        return image, masks, aft_mask, s_1 + '_' + s_2

: 

In [59]:
import albumentations.pytorch as AP

# Training augmentation. ToTensorV2(transpose_mask=False) leaves masks in HWC layout.
train_transform = A.Compose(
    [
        A.RandomRotate90(p=0.25),
        A.RandomResizedCrop(
            416, 416, scale=(0.5, 1.0), ratio=(0.8, 1.2), interpolation=cv2.INTER_AREA, p=0.25
        ),
        # Non-rigid warps: at most one of the three, applied with probability 0.5.
        A.OneOf(
            [
                A.OpticalDistortion(p=1, interpolation=cv2.INTER_AREA),
                A.GridDistortion(p=1, interpolation=cv2.INTER_AREA),
                A.ElasticTransform(
                    p=1, alpha=100, sigma=100 * 0.05, alpha_affine=100 * 0.03,
                    interpolation=cv2.INTER_AREA,
                ),
            ],
            p=0.5,
        ),
        # Normalisation is currently disabled:
        # A.Normalize(mean=(126.71482973095203, 126.6879562017254, 126.85466873988524), std = (32.9434, 33.0122, 32.9186)),
        # Photometric jitter: brightness and contrast changes.
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.25),
        AP.ToTensorV2(transpose_mask=False),
    ]
)

# Evaluation: tensor conversion only, no augmentation.
test_transform = A.Compose([AP.ToTensorV2(always_apply=True)])

# Presumably per-channel dataset mean / std measured earlier:
# tensor([127.5388, 127.5482, 127.6733])
# tensor([57.4250, 57.6999, 57.5387])
train_dataset = VesselDataset(index=train_indexs, transforms=train_transform)
test_dataset = VesselDataset(index=test_indexs, transforms=test_transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=args['train_batch_size'], shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

# Sanity check: shapes of one training batch.
image, masks, aft_mask, _ = next(iter(train_loader))
print(image.shape, masks[0].shape, masks[1].shape, aft_mask.shape)

ValueError: too many values to unpack (expected 3)

In [None]:
# Preview the first five input images of the sanity-check batch as one grid.
TF.to_pil_image(
    torchvision.utils.make_grid(image[:5])
)

In [None]:
# Preview the original colour annotations of the same five samples.
# `mask_o` is not defined at this point in a top-to-bottom run; the batch's
# annotations are masks[1], which is HWC (ToTensorV2(transpose_mask=False)),
# so it is permuted to (B, C, H, W) as make_grid expects.
TF.to_pil_image(torchvision.utils.make_grid(masks[1][:5].permute(0, 3, 1, 2)))

In [40]:
from sklearn.metrics import jaccard_score, precision_score, recall_score
def calc_metric(labels, preds, threshold=0.5):
    """Compute binary-segmentation metrics over all pixels of ``labels`` vs ``preds``.

    Both inputs are binarised with ``> threshold`` and flattened *before* any metric
    is computed, so every metric sees the same masks. Soft inputs and any array
    shape are accepted as long as both have the same number of elements.

    Args:
        labels: ground-truth mask(s), array-like.
        preds: predicted mask(s), array-like, same number of elements as ``labels``.
        threshold: values strictly above this count as foreground (default 0.5).

    Returns:
        ``(jaccard, f1, recall, precision, accuracy)`` as floats. A metric whose
        denominator is zero (e.g. no predicted foreground for precision) is
        reported as 0.0 instead of NaN.

    Raises:
        ValueError: if ``labels`` and ``preds`` differ in element count.
    """
    y_true = (np.asarray(labels) > threshold).reshape(-1)
    y_pred = (np.asarray(preds) > threshold).reshape(-1)
    if y_true.size != y_pred.size:
        raise ValueError(f'labels has {y_true.size} elements but preds has {y_pred.size}')

    tp = int(np.count_nonzero(y_true & y_pred))
    fp = int(np.count_nonzero(~y_true & y_pred))
    fn = int(np.count_nonzero(y_true & ~y_pred))

    def _ratio(num, den):
        # Zero-safe division: an undefined metric is reported as 0.0.
        return num / den if den else 0.0

    accuracy = _ratio(y_true.size - fp - fn, y_true.size)
    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1 = _ratio(2 * precision * recall, precision + recall)
    score_jaccard = _ratio(tp, tp + fp + fn)  # |A ∩ B| / |A ∪ B|

    print('jaccard, f1, recall, precision, acc')
    print(score_jaccard, f1, recall, precision, accuracy)
    return score_jaccard, f1, recall, precision, accuracy

# Evaluation on the test split

In [None]:
print("Testing...")
net = PFNet(backbone_path)
load_from = '/home/sklab2/workspace/code_only/junsu/model/vessel_PFNet_base_b32_e200_220930_22:41.pt'
# load_from = '/home/sklab2/workspace/code_only/junsu/model/vessel_PFNet_aug50+focal+coslr_b32_e300_220928_01:55.pt'
if torch.cuda.device_count() > 1:
    print(f'Using {torch.cuda.device_count()} GPUs.')
# Wrapped unconditionally (not only for multi-GPU). Presumably the checkpoint was saved
# from a DataParallel model, whose state-dict keys carry a 'module.' prefix — confirm.
net = nn.DataParallel(net)
# map_location lets a checkpoint saved on GPU be restored on whichever `device` is available.
net.load_state_dict(torch.load(load_from, map_location=device))
net.to(device)
net.eval()

# Per-batch results (test_loader uses batch_size=1), stacked into arrays below.
images, preds, labels, label_os, urls_list = [], [], [], [], []
with torch.no_grad():
    for image, masks, _, urls in tqdm(test_loader):
        image = image.float().to(device)
        _, _, _, pred = net(image)  # only the last of the four outputs is evaluated

        images.append(image.cpu().numpy())
        labels.append(masks[0].float().numpy())     # binary target
        label_os.append(masks[1].float().numpy())   # original colour annotation
        preds.append(pred.cpu().numpy())
        urls_list.append(urls)

images = np.array(images).squeeze(1)
preds = np.array(preds).squeeze(1)
labels = np.array(labels)
label_os = np.array(label_os)
# NOTE(review): training uses BCEWithLogitsLoss, so if `pred` is a raw logit the
# intended cut is probably sigmoid(pred) > 0.5, i.e. pred > 0 — confirm.
preds = np.where(preds > 0.5, 1, 0)
labels = np.where(labels > 0.5, 1, 0)

# score_jaccard, score_f1, score_recall, score_precision, score_acc = calc_metric(labels=labels, preds=preds)

In [None]:
# Show N_ROWS consecutive test samples: input | prediction | binary target | raw annotation.
# randnum = np.random.randint(0, len(test_dataset)-10)
N_ROWS = 10
randnum = 218
# Clamp the window into the test set so a smaller split does not raise IndexError.
randnum = max(0, min(randnum, len(images) - N_ROWS))
row_ids = range(randnum, min(randnum + N_ROWS, len(images)))

fig, axes = plt.subplots(N_ROWS, 4, figsize=(12, 3.6 * N_ROWS))
for c_ax in axes.flatten():
    c_ax.axis('off')

for idx, (img_ax, pred_ax, target_ax, mask_o_ax) in zip(row_ids, axes):
    image = images[idx].astype(int).transpose(1, 2, 0)  # CHW -> HWC for imshow
    img_pred = preds[idx].transpose(1, 2, 0)
    img_mask = labels[idx].transpose(1, 2, 0)
    img_mask_o = label_os[idx][0]  # drop the batch axis; the annotation is already HWC

    img_ax.imshow(np.clip(image, 0, 255))
    mask_o_ax.imshow(img_mask_o)
    target_ax.imshow(img_mask)
    pred_ax.imshow(img_pred)

    img_ax.set_title(f'Test num: {idx}')
    mask_o_ax.set_title(f'Annotation: {idx}')
    target_ax.set_title(f'Ground Truth: {idx}')
    pred_ax.set_title(f'Predicted: {idx}')
plt.tight_layout()
plt.show()