In [13]:
# change path
code_path = 'Semantic-Aware-Attention-Based-Deep-Object-Co-segmentation/'
label_save_path = 'Data_ham/HAM_LABELS.npy'

img_path1 = 'Data_ham/HAM10000_images_part_1/'
img_path2 = 'Data_ham/HAM10000_images_part_2/'
gt_path = 'Data_ham/HAM_MASK/'
el_path = 'Data_ham/HAM_ellipse/'
grab_path = 'Data_ham/HAM_grabcut/'

a1_path='checkpoints/ham_a1.pt'
a1_class_path='checkpoints/ham_a1class.pt'

num_class=45

In [2]:
import torch, math, time
import numpy as np
import cv2
import PIL
from matplotlib import pyplot as plt
from matplotlib import image
%matplotlib inline
import skimage, skimage.transform
import sys
sys.path.append(code_path)
from torchvision import models
import scipy.io as sio
from scipy.io import loadmat
import torch.nn.functional as F
# CUDA flag. Speed-up due to CUDA is mostly noticable for large batches.
cuda = True
from PIL import Image,ImageDraw
from torchvision.utils import save_image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, Scale, ToPILImage
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
#from skimage import io, transform
from model import *
import segmentation_models_pytorch as smp
import albumentations as albu
from tqdm.auto import tqdm
import torch.nn as nn
import torchvision.models.resnet as resnet_util
from segmentation_models_pytorch.encoders import get_preprocessing_fn
from torch.optim import Adam
import os
from copy import deepcopy
import random
from random import shuffle
from util import metrics

## Dataloader

In [3]:
im_dir = np.array(sorted(os.listdir(img_path1)+os.listdir(img_path2)))
gt_dir = np.array(sorted(os.listdir(gt_path)))
el_dir = np.array(sorted(os.listdir(el_path)))
grab_dir = np.array(sorted(os.listdir(grab_path)))

wrong=[1736, 3925, 5513]
im_dir=np.delete(im_dir,wrong)
gt_dir=np.delete(gt_dir,wrong)

labels=np.load(label_save_path)

In [4]:
random.seed(1)
ind = np.arange(len(im_dir))
shuffle(ind)
tr = int(len(ind)*0.7)
val = tr+int(len(ind)*0.2)

im_train_dir = im_dir[ind[:tr]]
im_val_dir = im_dir[ind[tr:val]]
im_test_dir = im_dir[ind[val:]]

In [5]:
random.seed(1)
ind = np.arange(len(im_dir))
shuffle(ind)
tr = int(len(ind)*0.7)
val = tr+int(len(ind)*0.2)
im_train_dir = im_dir[ind[:tr]]
gt_train_dir = gt_dir[ind[:tr]]
el_train_dir = el_dir[ind[:tr]]
grab_train_dir = grab_dir[ind[:tr]]
train_label=labels[ind[:tr]]

im_val_dir = im_dir[ind[tr:val]]
gt_val_dir = gt_dir[ind[tr:val]]
el_val_dir = el_dir[ind[tr:val]]
grab_val_dir = grab_dir[ind[tr:val]]
val_label=labels[ind[tr:val]]

im_test_dir = im_dir[ind[val:]]
gt_test_dir = gt_dir[ind[val:]]
el_test_dir = el_dir[ind[val:]]
grab_test_dir = grab_dir[ind[val:]]
test_label=labels[ind[val:]]

In [6]:
def get_training_augmentation():
    train_transform = [albu.Resize(120,120),albu.PadIfNeeded(384, 480)]
    return albu.Compose(train_transform)

def to_tensor(x, **kwargs):
    return torch.from_numpy(x.transpose(2, 0, 1).astype('float32'))

def get_preprocessing(preprocessing_fn):
    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)

In [7]:
class Dataset(Dataset):
    def __init__(
            self, 
            images_dir, 
            masks_dir,
            masks_path,
            label,
            augmentation=None, 
            preprocessing=None,
    ):
        self.images_fps = []
        for image_id in images_dir:
            if os.path.isfile(os.path.join(img_path1, image_id)):
                self.images_fps.append(os.path.join(img_path1, image_id))
            else:
                self.images_fps.append(os.path.join(img_path2, image_id))
        self.masks_fps = [os.path.join(masks_path, mask_id) for mask_id in masks_dir]
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.label=label
    
    def __getitem__(self, i):
        image = cv2.imread(self.images_fps[i])
        mask = cv2.imread(self.masks_fps[i],0)/255.0
        mask[mask>=0.5]=1
        mask[mask<0.5]=0
        mask = np.expand_dims(mask,2)
        sample = self.augmentation(image=image, mask=mask)
        image, mask = sample['image'], sample['mask'].reshape(384,480,1)
        sample = self.preprocessing(image=image, mask=mask)
        image, mask = sample['image'], sample['mask']
        return image, mask, torch.tensor(self.label[i])
        
    def __len__(self):
        return len(self.masks_fps)

In [8]:
propro=get_preprocessing(get_preprocessing_fn('resnet101', pretrained='imagenet'))
train_dataset = Dataset(
    im_train_dir, 
    el_train_dir, #grab_train_dir, 
    el_path, #grab_path
    train_label,
    augmentation=get_training_augmentation(),
    preprocessing=propro
)
val_dataset = Dataset(
    im_val_dir, 
    el_val_dir, #grab_val_dir, 
    el_path, #grab_path,
    val_label,
    augmentation=get_training_augmentation(),
    preprocessing=propro
)
test_dataset = Dataset(
    im_test_dir, 
    gt_test_dir, 
    gt_path,
    test_label,
    augmentation=get_training_augmentation(),
    preprocessing=propro
)

In [9]:
bs = 10
train_loader = torch.utils.data.DataLoader(train_dataset,
                                             batch_size=bs, shuffle=True,
                                             num_workers=1)
val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=bs, shuffle=True,
                                             num_workers=1)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=bs, shuffle=False,
                                             num_workers=1)

## A1 Train

In [11]:
def train(model, opt, criterion, device, checkpoint_save_path, max_patient=5):
    min_val_loss = float('inf')
    m=0
    for epoch in tqdm(range(50)):
        # train
        model.train()
        train_loss=[]
        for batch_idx, sam in enumerate(train_loader):
            # send to device
            data, target = sam[0].to(device=device,dtype=torch.float), sam[1].to(device=device,dtype=torch.float)
            
            opt.zero_grad()
            
            out=model(data)
        
            loss = criterion(out, target) * 100.0
            loss.backward()
            opt.step()
            
            train_loss.append(loss.item())
            
            if batch_idx % 500 == 0:
                avg_loss = sum(train_loss)/len(train_loss)
                print('Step {} avg train loss = {:.{prec}f}'.format(batch_idx, avg_loss, prec=4))
                train_loss = []
                
        # valid
        valid_loss=[]
        model.eval()
        with torch.no_grad():
            for batch_idx, sam in enumerate(val_loader):
                data, target = sam[0].to(device=device,dtype=torch.float), sam[1].to(device=device,dtype=torch.float)
                
                out=model(data)
                
                loss = criterion(out,target) * 100.0
                valid_loss.append(loss.item())
            avg_val_loss = sum(valid_loss) / len(valid_loss)
            print('Validation loss after {} epoch = {:.{prec}f}'.format(epoch, avg_val_loss, prec=4))
        
        if avg_val_loss < min_val_loss:
            min_val_loss = avg_val_loss
            torch.save({'model_dict': model.state_dict()},checkpoint_save_path)
            print('model saved')
            m=0
        else:
            m+=1
            if m >= max_patient:
                return
                
    return

In [15]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = smp.DeepLabV3Plus(encoder_name='resnet101')
model.load_state_dict(checkpoint['model_dict'])

<All keys matched successfully>

In [16]:
lr=0.001
model = model.to(device)

opt = Adam(model.parameters(),lr=lr)
bce_loss = nn.BCEWithLogitsLoss()

train(model,opt,bce_loss,device, a1_path)

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

Step 0 avg train loss = 217.0010
Step 500 avg train loss = 18.4547
Validation loss after 0 epoch = 13.1834
model saved
Step 0 avg train loss = 13.5415
Step 500 avg train loss = 12.8940
Validation loss after 1 epoch = 12.1508
model saved
Step 0 avg train loss = 8.3327
Step 500 avg train loss = 11.9451
Validation loss after 2 epoch = 12.1470
model saved
Step 0 avg train loss = 20.2168
Step 500 avg train loss = 10.8984
Validation loss after 3 epoch = 11.0995
model saved
Step 0 avg train loss = 13.0823
Step 500 avg train loss = 10.4339
Validation loss after 4 epoch = 10.9708
model saved
Step 0 avg train loss = 9.3801
Step 500 avg train loss = 9.9595
Validation loss after 5 epoch = 10.6551
model saved
Step 0 avg train loss = 10.7799
Step 500 avg train loss = 9.3843
Validation loss after 6 epoch = 11.1890
Step 0 avg train loss = 6.0398
Step 500 avg train loss = 8.6909
Validation loss after 7 epoch = 10.8962
Step 0 avg train loss = 12.0898
Step 500 avg train loss = 8.3317
Validation loss afte

## A1 Evaluate

In [22]:
def test(model,loader):
    test_loss=[]
    #imgs = torch.tensor([])
    mask,pred = [],[]
    with torch.no_grad():
        for batch_idx, sam in enumerate(loader):
            data, target = sam[0].to(device=device,dtype=torch.float), sam[1].to(device=device,dtype=torch.float)
            out = model(data)
            loss = bce_loss(out,target) * 100.0
            test_loss.append(loss.item())
            mask.append(target.cpu().detach().numpy())
            pred.append(out.cpu().detach().numpy())
        avg_test_loss = sum(test_loss) / len(test_loss)
        print('Test loss = {:.{prec}f}'.format(avg_test_loss, prec=4))
    return mask,pred

In [23]:
#checkpoint_save_path1='/misc/vlgscratch4/LakeGroup/shared_data/wsol_checkpoint_/ham_ell.pt'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#optimal_model = smp.DeepLabV3Plus(encoder_name='resnet101').to(device)
optimal_model = smp.Unet(encoder_name='resnet101').to(device)
checkpoint = torch.load(a1_path)
optimal_model.load_state_dict(checkpoint['model_dict'])
optimal_model.eval()
bce_loss = nn.BCEWithLogitsLoss()

mask,pred = test(optimal_model,test_loader)
gt_map = np.concatenate(mask, axis=0)
pred_map = np.concatenate(pred, axis=0)

Test loss = 14.8359


In [31]:
pred_map_sig = nn.Sigmoid()(torch.tensor(pred_map))
threshold = 0.5
predict = deepcopy(pred_map_sig)
predict[pred_map_sig>=threshold]=1
predict[pred_map_sig<threshold]=0

cut_pred1 = []
cut_gt = []
for i in range(predict.shape[0]):
    cut_pred1.append(np.uint8(predict[i,0,132:252,180:300]))
    cut_gt.append(np.uint8(gt_map[i,0,132:252,180:300]))

metrics(cut_gt,cut_pred1)

Best test IOU is 0.825251476297906
Best test DICE is 0.8992303889649864
Test center error is 2.6270569717730394
Test circumstance error is 8.862275449101796
Test AVD is  10.387944143022278
Test VS is  0.918178131644553


## A1+class Train

In [11]:
def train(model, opt, criterion_grab, criterion_class, device, checkpoint_save_path, max_patient=5):
    min_val_loss = float('inf')
    m=0
    for epoch in tqdm(range(50)):
        # train
        model.train()
        train_loss={'Total':[],'Mask':[],'Label':[]}
        for batch_idx, sam in enumerate(train_loader):
            # send to device
            data, label, grabcut = sam[0].to(device=device,dtype=torch.float), sam[2].to(device=device,dtype=torch.long), sam[1].to(device=device,dtype=torch.float)
            
            opt.zero_grad()
            
            out_grab, out_class = model(data)
        
            loss_grab = criterion_grab(out_grab, grabcut) * 50.0
            loss_class = criterion_class(out_class, label)
            loss = loss_grab+loss_class
            loss.backward()
            opt.step()
            
            train_loss['Total'].append(loss.item())
            train_loss['Mask'].append(loss_grab.item())
            train_loss['Label'].append(loss_class.item())
            
            if batch_idx % 500 == 0:
                total_loss = sum(train_loss['Total'])/len(train_loss['Total'])
                mask_loss = sum(train_loss['Mask'])/len(train_loss['Mask'])
                label_loss = sum(train_loss['Label'])/len(train_loss['Label'])
                print('Step %s avg train total loss = %.4f, avg train Mask loss = %.4f, avg train label loss = %.4f,'%(batch_idx, total_loss, mask_loss, label_loss))
                train_loss = {'Total':[],'Mask':[],'Label':[]}
                
        # valid
        valid_loss={'Total':[],'Mask':[],'Label':[]}
        model.eval()
        with torch.no_grad():
            for batch_idx, sam in enumerate(val_loader):
                data, label, grabcut = sam[0].to(device=device,dtype=torch.float), sam[2].to(device=device,dtype=torch.long), sam[1].to(device=device,dtype=torch.float)
                
                out_grab,out_class=model(data)
                
                loss_grab = criterion_grab(out_grab,grabcut) * 50.0
                loss_class = criterion_class(out_class, label)
                loss = loss_grab+loss_class
                
                valid_loss['Total'].append(loss.item())
                valid_loss['Mask'].append(loss_grab.item())
                valid_loss['Label'].append(loss_class.item())
                
            total_loss = sum(valid_loss['Total'])/len(valid_loss['Total'])
            mask_loss = sum(valid_loss['Mask'])/len(valid_loss['Mask'])
            label_loss = sum(valid_loss['Label'])/len(valid_loss['Label'])
            print('At Epoch %s avg validation total loss = %.4f, Mask loss = %.4f, label loss = %.4f,'%(epoch, total_loss, mask_loss, label_loss))
        
        if mask_loss < min_val_loss:
            min_val_loss = mask_loss
            torch.save({'model_dict': model.state_dict()},checkpoint_save_path)
            print('model saved')
            m=0
        else:
            m+=1
            if m >= max_patient:
                torch.save({'model_dict': model.state_dict()},checkpoint_save_path)
                return
                
    return

In [12]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
lr=0.001 
#num_class=45

model = smp.DeepLabV3Plus(encoder_name='resnet101',aux_params={'classes':num_class})
model = model.to(device)

opt = Adam(model.parameters(),lr=lr)

bce_loss = nn.BCEWithLogitsLoss()
ce_loss = nn.CrossEntropyLoss()

train(model,opt,bce_loss,ce_loss,device, a1_class_path, max_patient=5)

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

Step 0 avg train total loss = 42.7998, avg train Mask loss = 38.7969, avg train label loss = 4.0029,
Step 500 avg train total loss = 14.1047, avg train Mask loss = 10.5223, avg train label loss = 3.5824,
At Epoch 0 avg validation total loss = 11.4158, Mask loss = 8.4680, label loss = 2.9478,
model saved
Step 0 avg train total loss = 7.8734, avg train Mask loss = 5.4330, avg train label loss = 2.4404,
Step 500 avg train total loss = 9.9644, avg train Mask loss = 7.2222, avg train label loss = 2.7422,
At Epoch 1 avg validation total loss = 9.2249, Mask loss = 6.7104, label loss = 2.5145,
model saved
Step 0 avg train total loss = 9.6563, avg train Mask loss = 6.7434, avg train label loss = 2.9128,
Step 500 avg train total loss = 8.9716, avg train Mask loss = 6.4940, avg train label loss = 2.4776,
At Epoch 2 avg validation total loss = 9.4882, Mask loss = 6.9155, label loss = 2.5727,
Step 0 avg train total loss = 5.9018, avg train Mask loss = 4.0936, avg train label loss = 1.8083,
Step 500

## A1+class Evaluate

In [14]:
def test(model,loader):
    test_loss=[]
    imgs = torch.tensor([])
    mask,pred = [],[]
    with torch.no_grad():
        for batch_idx, sam in enumerate(loader):
            data, target, label = sam[0].to(device=device,dtype=torch.float), sam[1].to(device=device,dtype=torch.float), sam[2].to(device=device,dtype=torch.long)
            out, c = model(data)
            loss = bce_loss(out,target) * 100.0
            test_loss.append(loss.item())
            
            mask.append(target.cpu().detach().numpy())
            pred.append(out.cpu().detach().numpy())
        avg_test_loss = sum(test_loss) / len(test_loss)
        print('Test loss = {:.{prec}f}'.format(avg_test_loss, prec=4))
    return mask,pred

In [15]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
optimal_model = smp.DeepLabV3Plus(encoder_name='resnet101',aux_params={'classes':num_class}).to(device)
checkpoint = torch.load(a1_class_path)
optimal_model.load_state_dict(checkpoint['model_dict'])
optimal_model.eval()
bce_loss = nn.BCEWithLogitsLoss()
mask,pred = test(optimal_model,test_loader)
gt_map = np.concatenate(mask, axis=0)
pred_map = np.concatenate(pred, axis=0)

Test loss = 13.5269


In [16]:
pred_map_sig = nn.Sigmoid()(torch.tensor(pred_map))
threshold = 0.5
predict = deepcopy(pred_map_sig)
predict[pred_map_sig>=threshold]=1
predict[pred_map_sig<threshold]=0

cut_pred = []
cut_gt = []
for i in range(predict.shape[0]):
    cut_pred.append(np.uint8(predict[i,0,132:252,180:300]))
    cut_gt.append(np.uint8(gt_map[i,0,132:252,180:300]))
    
metrics(cut_gt,cut_pred)

Best test IOU is 0.8387841875213979
Best test DICE is 0.9084863046783714
Test center error is 2.495279991214853
Test circumstance error is 9.37624750499002
Test AVD is  9.73292352546333
Test VS is  0.9299433506190851
