In [1]:
import os
import torch
import cv2
import glob
import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import metrics as smp_metrics

import albumentations as A
import albumentations.augmentations.functional as F
from albumentations.pytorch import ToTensorV2
import wandb
from pprint import pprint
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from tqdm import tqdm
from shutil import copyfile
from copy import deepcopy
from PIL import Image, ImageFile
from pytorch_lightning import  seed_everything
from sklearn.model_selection import train_test_split
import random
from mcode.copy_paste import CopyPaste
ImageFile.LOAD_TRUNCATED_IMAGES = True


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Setup random seed

def set_seed_everything(seed: int):
    """Seed every random number generator the notebook uses and make cuDNN deterministic.

    Args:
        seed: Value handed to Python's ``random``, the ``PYTHONHASHSEED``
            environment variable, NumPy, PyTorch (CPU and CUDA) and
            PyTorch Lightning's ``seed_everything``.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Cover every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick kernels per run, which can
    # change results between runs; it must be off for reproducibility.
    torch.backends.cudnn.benchmark = False
    seed_everything(seed)
    
# Seed all RNGs with 42 for this notebook run.
set_seed_everything(42)

Global seed set to 42


In [3]:


# Set the WANDB_SILENT environment variable for this process
# (presumably to quiet Weights & Biases console output — confirm against wandb docs).
os.environ["WANDB_SILENT"] = "True"
# IPython magic: set JOBLIB_TEMP_FOLDER to /tmp for this kernel.
%env JOBLIB_TEMP_FOLDER=/tmp


env: JOBLIB_TEMP_FOLDER=/tmp


In [4]:
class PolypDataset(Dataset):
    """Paired image/mask dataset for polyp segmentation tasks.

    Args:
        image_root: Iterable of image file paths (a list of paths, not a directory).
        gt_root: Iterable of ground-truth mask file paths.
        trainsize: Side length in pixels; every image and mask is resized to
            ``trainsize x trainsize`` when it is loaded.
        transform: Callable invoked as ``transform(image=..., mask=...)`` that
            returns a mapping with ``"image"`` and ``"mask"`` entries, or
            ``None`` to skip augmentation.
    """

    def __init__(self, image_root, gt_root, trainsize, transform):
        self.trainsize = trainsize
        # Images and masks are paired purely by their position after sorting.
        self.images = sorted(image_root)
        self.masks = sorted(gt_root)
        self.filter_files()
        self.size = len(self.images)
        self.transform = transform

    def __getitem__(self, index):
        """Return a dict with ``image``, ``mask`` (1 x H x W, values in 0..1) and both file paths."""
        image = self.rgb_loader(self.images[index])
        mask = self.binary_loader(self.masks[index])

        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]

        # Scale 0..255 to 0..1 whether or not a transform ran. ``as_tensor``
        # leaves an existing tensor untouched and wraps a NumPy mask, so the
        # ``unsqueeze`` below also works when ``transform`` is None.
        mask = torch.as_tensor(mask) / 255

        sample = dict(image=image, mask=mask.unsqueeze(0), image_path=self.images[index], mask_path=self.masks[index])

        return sample

    def filter_files(self):
        """Drop every (image, mask) pair whose two files differ in pixel size."""
        assert len(self.images) == len(self.masks), \
            "number of images and masks differ: {} vs {}".format(len(self.images), len(self.masks))
        images = []
        masks = []
        for img_path, mask_path in zip(self.images, self.masks):
            # Context managers close both file handles right after the size check.
            with Image.open(img_path) as img, Image.open(mask_path) as mask:
                same_size = img.size == mask.size
            if same_size:
                images.append(img_path)
                masks.append(mask_path)
        self.images = images
        self.masks = masks

    def rgb_loader(self, path):
        """Load ``path``, resize it with bilinear resampling and return an RGB NumPy array."""
        with open(path, 'rb') as f:
            img = Image.open(f).resize((self.trainsize, self.trainsize), Image.Resampling.BILINEAR)
            return np.array(img.convert('RGB'))

    def binary_loader(self, path):
        """Load ``path``, resize it with nearest-neighbour resampling and return a single-channel ('L') NumPy array."""
        with open(path, 'rb') as f:
            img = Image.open(f).resize((self.trainsize, self.trainsize), Image.Resampling.NEAREST)
            img = np.array(img.convert('L'))
            return img

    def __len__(self):
        return self.size

In [5]:
class AvgMeter(object):
    """Running-average tracker that also keeps the full history of recorded values.

    Args:
        num: Number of most recent values averaged by ``show``.
    """

    def __init__(self, num=40):
        self.num = num
        self.reset()

    def reset(self):
        """Clear the running statistics and the recorded history."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.losses = []

    def update(self, val, n=1):
        """Record ``val`` (weighted by ``n`` in the running average) and refresh ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.losses.append(val)

    def show(self):
        """Return the mean of the last ``num`` recorded values as a tensor.

        Values may be tensors or plain numbers (``full_val`` feeds this class
        floats); numbers are wrapped in tensors before stacking. An empty
        history yields ``tensor(0.)``, matching the zeros set by ``reset``.
        """
        recent = self.losses[max(len(self.losses) - self.num, 0):]
        if not recent:
            return torch.tensor(0.0)
        as_tensors = [v if torch.is_tensor(v) else torch.tensor(float(v)) for v in recent]
        return torch.mean(torch.stack(as_tensors))

In [6]:
# Side length (pixels) of the square that PolypDataset resizes every image and mask to.
# 352 is a multiple of 32, which PetModel.compute_loss asserts on the input height and width.
trainsize = 352

In [7]:
# Weak augmentation for labeled training data.
# NOTE(review): defined here but not passed to any dataset in this cell —
# confirm whether train_dataset was meant to use it instead of semi_transform.
train_transform = A.Compose(
    [
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

# Semi-supervised transform for unlabeled data with strong augmentation.
semi_transform = A.Compose(
    [
        A.Flip(p=0.5),
        # Crop size follows `trainsize` so it cannot drift from the resize done in PolypDataset.
        A.RandomCrop(width=trainsize, height=trainsize),
        A.OneOf([A.RandomGamma(), A.RandomBrightness()]),
        A.OneOf([A.Blur(), A.GaussianBlur(), A.GlassBlur(), A.MotionBlur(), A.GaussNoise(), A.MedianBlur()]),
        A.Cutout(p=0.3, max_h_size=25, max_w_size=25, fill_value=255),
        A.ShiftScaleRotate(p=0.3, border_mode=cv2.BORDER_CONSTANT, shift_limit=0.15, scale_limit=0.11),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

train_dataset = PolypDataset(
    image_root=glob.glob('/home/nguyen.van.quan/scatsimclr/newdataset/*/image/*'),
    gt_root=glob.glob('/home/nguyen.van.quan/scatsimclr/newdataset/*/mask/*'),
    trainsize=trainsize,
    transform=semi_transform
)

# Evaluation only normalizes and converts to tensors; no random augmentation.
val_transform = A.Compose(
    [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ]
)

# glob returns paths in arbitrary order. Sort BEFORE slicing so the first 100
# images and the first 100 masks refer to the same files; slicing two unsorted
# lists could select unrelated subsets that PolypDataset would then mis-pair.
t_images = sorted(glob.glob('/home/nguyen.van.quan/scatsimclr/TestDataset/*/images/*'))
t_masks = sorted(glob.glob('/home/nguyen.van.quan/scatsimclr/TestDataset/*/masks/*'))

test_dataset = PolypDataset(
    image_root=t_images[:100],
    gt_root=t_masks[:100],
    trainsize=trainsize,
    transform=val_transform
)



In [8]:
# Number of CPUs reported by the OS; reused later as the worker count in full_val.
n_cpu = os.cpu_count()

print(f"Train size: {len(train_dataset)}")
print(f"Test size: {len(test_dataset)}")

# Shuffled mini-batches of 8 for training.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=8,
)
# One sample at a time, in fixed order, for evaluation.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=8,
)

Train size: 1450
Test size: 100


In [9]:
class UnNormalize(object):
    """Invert a per-channel ``(x - mean) / std`` normalization.

    Args:
        mean: Per-channel means that were subtracted during normalization.
        std: Per-channel standard deviations that were divided by.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Normalized image whose first dimension is the channel axis.
        Returns:
            Tensor: A new tensor with ``x * std + mean`` applied per channel.
            The input is left untouched, so calling this twice on the same
            sample (e.g. when a notebook cell is re-run) gives the same result.
        """
        restored = tensor.clone()
        for channel, m, s in zip(restored, self.mean, self.std):
            # Inverse of the normalize step `t.sub_(m).div_(s)`.
            channel.mul_(s).add_(m)
        return restored

unorm = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

In [10]:
def visualize(image, mask, original_image=None, original_mask=None):
    """Show an image and its mask, optionally beside the originals.

    With neither original supplied, draws an untitled 2x1 figure (image above
    mask). Otherwise draws a titled 2x2 grid: originals in the left column,
    transformed versions in the right column.
    """
    title_size = 18

    # Only the transformed pair was given.
    if original_image is None and original_mask is None:
        _, axes = plt.subplots(2, 1, figsize=(8, 8))
        axes[0].imshow(image)
        axes[1].imshow(mask)
        return

    _, axes = plt.subplots(2, 2, figsize=(8, 8))
    panels = (
        ((0, 0), original_image, 'Original image'),
        ((1, 0), original_mask, 'Original mask'),
        ((0, 1), image, 'Transformed image'),
        ((1, 1), mask, 'Transformed mask'),
    )
    for (row, col), picture, title in panels:
        axes[row, col].imshow(picture)
        axes[row, col].set_title(title, fontsize=title_size)


In [11]:
# Fetch one sample from the training set.
sample = train_dataset[2]
# Undo the mean/std normalization, then move the first axis last
# (presumably channel-first -> channel-last for plotting — confirm).
image = unorm(sample["image"]).permute(1, 2, 0)
# The dataset adds a leading axis to the mask with unsqueeze(0); squeeze removes it again.
mask = sample["mask"].squeeze()

AssertionError: CopyPaste requires ['masks', 'paste_image', 'paste_masks']

In [None]:
# Preview a single augmentation (random gamma) on the sample prepared in the previous cell.
aug = A.RandomGamma(gamma_limit=(50, 150), always_apply=True, p=0.6)
# aug = A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5)
# NOTE(review): `image` and `mask` are torch tensors at this point; albumentations
# transforms are normally given NumPy arrays — confirm this call works as intended.
augmented = aug(image=image, mask=mask)

# Despite the `_padded` suffix, these are simply the outputs of the gamma transform above.
image_padded = augmented['image']
mask_padded = augmented['mask']

print(image_padded.shape, mask_padded.shape)

# Originals in the left column, augmented versions in the right column.
visualize(image_padded, mask_padded, original_image=image, original_mask=mask)

In [None]:
class PetModel(pl.LightningModule):
    """Lightning module around an ``smp`` segmentation network.

    Optionally keeps a frozen momentum (exponential-moving-average) copy of the
    network and, in semi-supervised mode, trains on teacher-generated pseudo
    labels in addition to the labeled batch.

    Args:
        arch: Architecture name passed to ``smp.create_model``.
        encoder_name: Encoder name passed to ``smp.create_model``.
        in_channels: Number of input channels.
        out_classes: Number of output classes (``classes`` in smp).
        checkpoint_path: Where the student network is saved at every epoch end;
            an empty string disables saving.
        mm_checkpoint_path: Where the momentum network is saved when its
            validation IoU improves; an empty string disables saving.
        labeled_dataloader: Loader returned under the ``"labeled"`` key by
            ``train_dataloader``.
        unlabeled_dataloader: Loader returned under ``"unlabeled"`` when
            ``is_semi`` is true.
        momentum: Weight kept from the old parameters in the momentum and
            teacher updates.
        use_momentum: Create and maintain the momentum copy of the network.
        is_semi: Train on pseudo-labeled unlabeled batches as well.
        is_online: Periodically pull the teacher towards the student.
        teacher: Network that produces pseudo labels in semi-supervised mode.
        use_soft_label: Use teacher probabilities (True) or thresholded 0/1
            masks (False) as pseudo labels.
        **kwargs: Forwarded to ``smp.create_model``.
    """

    def __init__(self, arch='', encoder_name='', in_channels=3, out_classes=1, checkpoint_path='', mm_checkpoint_path='', labeled_dataloader=None,
                 unlabeled_dataloader=None, momentum=0.99, use_momentum=False, is_semi=False, is_online=False, teacher=None, use_soft_label=True, **kwargs):
        super().__init__()

        self.model = smp.create_model(
            arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
        )

        # For image segmentation dice loss could be the best first choice.
        self.loss_fn_1 = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
        # NOTE(review): `active_contour_loss` is neither defined nor imported in the
        # visible part of this notebook — confirm it is in scope before instantiating.
        self.loss_fn_2 = active_contour_loss

        # Best IoU seen so far for the original model and for the momentum model.
        self.m_max_iou = 0
        self.mm_max_iou = 0
        self.checkpoint_path = checkpoint_path
        self.mm_checkpoint_path = mm_checkpoint_path

        self.labeled_dataloader = labeled_dataloader
        self.unlabeled_dataloader = unlabeled_dataloader
        self.momentum = momentum
        self.teacher = teacher
        self.is_semi = is_semi
        self.is_online = is_online
        self.use_momentum = use_momentum
        self.use_soft_label = use_soft_label
        # Weight of the unlabeled loss in the total loss.
        self.beta = 0.3

        if self.use_momentum:
            self.momentum_model = smp.create_model(
                arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
            )
            # The momentum copy is only ever updated by averaging, never by gradients.
            for param in self.momentum_model.parameters():
                param.requires_grad = False
        else:
            self.momentum_model = None
        # Never back-propagate into the teacher.
        if self.is_semi and self.teacher:
            for param in self.teacher.parameters():
                param.requires_grad = False

    @staticmethod
    def _binarize_logits(logits):
        """Sigmoid, min-max rescale over the whole tensor, then round to a 0/1 mask (detached)."""
        prob = logits.sigmoid().detach()
        prob = (prob - prob.min()) / (prob.max() - prob.min() + 1e-8)
        return prob.round()

    @staticmethod
    def _save_module(module, path):
        """Pickle ``module`` to ``path``; do nothing for an empty path, create the folder if missing."""
        if not path:
            return
        folder = os.path.dirname(path)
        if folder:
            os.makedirs(folder, exist_ok=True)
        torch.save(module, path)

    def forward(self, image):
        return self.model(image)

    def mm_forward(self, image):
        """Run the momentum network on ``image`` and return its logits."""
        # Follow the batch's device instead of forcing the default CUDA device.
        mask = self.momentum_model.to(image.device)(image)
        return mask

    @torch.no_grad()
    def make_pseudo_label(self, image):
        """Return teacher predictions for ``image``: probabilities if ``use_soft_label``, else 0/1 masks."""
        logits = self.teacher.to(image.device)(image)
        prob_mask = logits.sigmoid()
        if self.use_soft_label:
            return prob_mask
        # Hard label: threshold the probabilities at 0.5.
        return (prob_mask > 0.5).float()

    @torch.no_grad()
    def _update_momentum_network(self):
        """Momentum update of the momentum model with the student weights."""
        if self.use_momentum and self.momentum_model:
            for param_model, param_momentum_model in zip(self.model.parameters(), self.momentum_model.parameters()):
                param_momentum_model.data = param_momentum_model.data * self.momentum + param_model.data * (1.0 - self.momentum)

    @torch.no_grad()
    def _update_teacher_network(self):
        """Momentum update of the teacher model with the student weights."""
        # Skip quietly when no teacher was supplied rather than failing on None.
        if self.is_semi and self.is_online and self.teacher is not None:
            for param_student, param_teacher in zip(self.model.parameters(), self.teacher.parameters()):
                param_teacher.data = param_teacher.data * self.momentum + param_student.data * (1.0 - self.momentum)

    def compute_loss(self, image, mask):
        """Forward ``image`` and return a dict with the loss and the tp/fp/fn/tn statistics.

        The loss is the equal-weight mix of ``loss_fn_1`` (Dice) and ``loss_fn_2``.
        """
        assert image.ndim == 4
        assert mask.ndim == 4

        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0
        # Check that mask values are between 0 and 1, NOT 0 and 255, for binary segmentation.
        assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)
        # Predicted mask contains logits, and loss_fn param `from_logits` is set to True.
        loss_1 = self.loss_fn_1(logits_mask, mask)
        loss_2 = self.loss_fn_2(logits_mask, mask)

        loss = 0.5*loss_1 + 0.5*loss_2

        pred_mask = self._binarize_logits(logits_mask)

        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")

        result = {
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
            "loss": loss
        }
        return result

    def shared_step(self, batch, stage):
        """Compute the loss and statistics for one batch of any stage.

        During training the batch is a dict keyed by ``"labeled"`` (and
        ``"unlabeled"`` in semi-supervised mode); otherwise it is the labeled
        batch itself. The reported statistics always come from the labeled batch.
        """
        batch_labeled = batch['labeled'] if stage == 'train' else batch

        # Batch of labeled data.
        l_image = batch_labeled["image"]
        l_mask = batch_labeled["mask"]

        result = self.compute_loss(l_image, l_mask)
        l_loss = result['loss']

        # Unlabeled loss stays 0 unless a pseudo-labeled batch is available.
        u_loss = 0
        if self.is_semi and stage == 'train':
            u_image = batch['unlabeled']["image"]
            # Pseudo mask for the unlabeled images, produced by the teacher model.
            u_mask = self.make_pseudo_label(u_image)
            u_loss = self.compute_loss(u_image, u_mask)['loss']

        # Total loss.
        result['loss'] = l_loss + self.beta * u_loss

        if self.use_momentum:
            # Metrics only; no gradients are needed through the momentum network.
            with torch.no_grad():
                logits_mm_mask = self.mm_forward(l_image)
            pred_mm_mask = self._binarize_logits(logits_mm_mask)
            mm_tp, mm_fp, mm_fn, mm_tn = smp.metrics.get_stats(pred_mm_mask.long(), l_mask.long(), mode="binary")

            result['mm_tp'] = mm_tp
            result['mm_fp'] = mm_fp
            result['mm_fn'] = mm_fn
            result['mm_tn'] = mm_tn

        return result

    def shared_epoch_end(self, outputs, stage):
        """Aggregate step statistics into IoU metrics, log them and save checkpoints."""
        # Aggregate step metrics.
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        if self.use_momentum:
            # Same aggregation for the momentum model's statistics.
            tp = torch.cat([x["mm_tp"] for x in outputs])
            fp = torch.cat([x["mm_fp"] for x in outputs])
            fn = torch.cat([x["mm_fn"] for x in outputs])
            tn = torch.cat([x["mm_tn"] for x in outputs])

            mm_per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
            mm_dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        # `use_wandb` is an optional notebook-level switch; a missing definition
        # is treated as "off" instead of raising NameError at the end of an epoch.
        # NOTE(review): the key says "train_dice" but the value is the dataset-level
        # IoU and it is logged for every stage — confirm the intended metric name.
        if globals().get('use_wandb', False):
            wandb.log({'train_dice': dataset_iou})

        if stage == 'valid':
            # Save best checkpoint
            # if per_image_iou > self.m_max_iou:
            print('\nSave origin model with checkpoint loss = {}'.format(per_image_iou))
            self.m_max_iou = per_image_iou

            if self.use_momentum:
                if mm_per_image_iou > self.mm_max_iou:
                    if self.mm_checkpoint_path:
                        print('\nSave momentum model with checkpoint loss = {}'.format(mm_per_image_iou))
                        self._save_module(self.momentum_model, self.mm_checkpoint_path)
                    self.mm_max_iou = mm_per_image_iou

        if self.use_momentum:

            metrics = {
                f"{stage}_per_image_iou": per_image_iou,
                f"{stage}_mm_per_image_iou": mm_per_image_iou,
            }
        else:
            metrics = {
                f"{stage}_per_image_iou": per_image_iou,
            }
        # The student network is written at the end of every epoch of every stage;
        # the default empty checkpoint_path skips the write.
        self._save_module(self.model, self.checkpoint_path)

        self.log_dict(metrics, prog_bar=True)

    def train_dataloader(self):
        """Return the training loaders keyed by ``"labeled"`` (and ``"unlabeled"`` when semi-supervised)."""
        if self.is_semi:
            loaders = {"labeled": self.labeled_dataloader, "unlabeled": self.unlabeled_dataloader}
        else:
            loaders = {"labeled": self.labeled_dataloader}

        return loaders

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")

    # NOTE(review): the `*_epoch_end(outputs)` hooks below exist in PyTorch Lightning 1.x;
    # confirm the installed version still calls them.
    def training_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "valid")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")

    def test_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)

    def on_train_batch_end(self, *args, **kwargs):
        """Refresh the teacher every 100 global steps and the momentum model every 25."""
        global_step = self.global_step

        if global_step % 100 == 0 and self.is_online and self.is_semi:
            # Update teacher network.
            self._update_teacher_network()

        if global_step % 25 == 0 and self.use_momentum:
            self._update_momentum_network()

In [None]:
def get_testdataset(dataset):
    """Build a ``PolypDataset`` over one sub-folder of ``TestDataset``.

    Args:
        dataset: Name of the sub-folder; images are globbed from
            ``TestDataset/<dataset>/images`` and masks from
            ``TestDataset/<dataset>/masks`` (relative to the working directory).

    Returns:
        A ``PolypDataset`` using the notebook-level ``trainsize`` and ``val_transform``.
    """
    dataset_dir = f'TestDataset/{dataset}'
    image_paths = glob.glob(f'{dataset_dir}/images/*')
    mask_paths = glob.glob(f'{dataset_dir}/masks/*')
    return PolypDataset(
        image_root=image_paths,
        gt_root=mask_paths,
        trainsize=trainsize,
        transform=val_transform,
    )

In [None]:
from src.evaluation.metric import get_scores
from tabulate import tabulate
def full_val(model, device='cuda:1', use_wandb=False):
    """Evaluate ``model`` on every sub-folder of ``TestDataset/`` and print a score table.

    Args:
        model: Network called as ``model(image)``; it must already live on ``device``.
        device: Device each input batch is moved to.
        use_wandb: When true, also log each dataset's Dice and IoU to Weights & Biases.

    Returns:
        Tuple ``(mean IoU, mean Dice)`` averaged over the datasets.
    """
    print("#" * 20)
    model.eval()

    # Sorted so the table rows come out in the same order on every run.
    dataset_names = sorted(os.listdir('TestDataset/'))
    table = []
    headers = ['Dataset', 'IoU', 'Dice']
    ious, dices = AvgMeter(), AvgMeter()

    for dataset_name in dataset_names:
        tmp_dataset = get_testdataset(dataset_name)
        test_loader = DataLoader(tmp_dataset, batch_size=1, shuffle=False, num_workers=n_cpu)

        gts = []
        prs = []
        # Pure inference: disable autograd so no graph is built or kept per batch.
        with torch.no_grad():
            for pack in test_loader:
                image, gt = pack["image"], pack["mask"]
                # Drop the two leading axes of the batch-of-one mask.
                gt = gt[0][0]
                gt = np.asarray(gt, np.float32)
                image = image.to(device)

                res = model(image)[0]
                # res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
                res = res.sigmoid().detach().cpu().numpy().squeeze()
                # Min-max rescale the probabilities, then round to a 0/1 prediction.
                res = (res - res.min()) / (res.max() - res.min() + 1e-8)
                pr = res.round()
                gts.append(gt)
                prs.append(pr)
        mean_iou, mean_dice, _, _ = get_scores(gts, prs)
        ious.update(mean_iou)
        dices.update(mean_dice)
        if use_wandb:
            wandb.log({f'{dataset_name}_dice': mean_dice})
            wandb.log({f'{dataset_name}_iou': mean_iou})
        table.append([dataset_name, mean_iou, mean_dice])
    table.append(['Total', ious.avg, dices.avg])

    print(tabulate(table, headers=headers, tablefmt="fancy_grid"))
    return ious.avg, dices.avg

In [None]:
# Uncomment when retrain new teacher 
# Supervised run: is_semi and use_momentum keep their False defaults, so only the
# labeled loader is used and `momentum=0.95` is not read by any update step.
model = PetModel("FPN", "densenet169", in_channels=3, out_classes=1, momentum=0.95,
                 labeled_dataloader=train_dataloader, unlabeled_dataloader=None, 
                 checkpoint_path='runs/checkpoints/fpn_densenet169_full.pth')

# Train on GPU index 0 for 100 epochs.
trainer = pl.Trainer(
    accelerator="gpu", devices=[0],
    max_epochs=100
)

# No loaders are passed for training: the module supplies them via its own
# train_dataloader(). No validation loader is given.
trainer.fit(
    model, 
    val_dataloaders=None
) 







In [None]:
# Build a fresh wrapper, then replace its network with the one saved during training
# (shared_epoch_end saves the whole module object with torch.save, not a state_dict).
model = PetModel("FPN", "densenet169", in_channels=3, out_classes=1, checkpoint_path='')
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints you trust.
# The tensors are mapped to cuda:1 here.
model.model = torch.load('runs/checkpoints/fpn_densenet169_full.pth', map_location='cuda:1')

In [None]:
# Evaluate the restored network; both the network and the input batches are placed on cuda:0.
full_val(model.model.to('cuda:0'), device='cuda:0')

In [None]:
from PIL import Image
from matplotlib import pyplot as plt
from src.data.augmentor import augmentations

# Open one training image as single-channel ('L') and pass it through the project's
# `augmentations` callable (behaviour defined in src.data.augmentor, not visible here).
# Note that this rebinds the notebook-level name `image` used by earlier cells.
image = Image.open('TrainDataset/0/image/415.png').convert('L')
img_aug = augmentations(image)


In [None]:
from src.data.polyp_dataset import ActiveDataset
from src.data.reconstruct_dataset import SemanticGenesis_Dataset
import glob
from src.data.augmentor import augmentations
# NOTE(review): this rebinds `train_dataset`. The `train_dataloader` created earlier
# still wraps the previous PolypDataset, so re-running earlier cells after this one
# will see a different object under the same name.
train_dataset = ActiveDataset(
    image_paths=glob.glob('TrainDataset/*/image/*'), 
    gt_paths=glob.glob('TrainDataset/*/mask/*'), 
    trainsize=256, 
    transform=augmentations
)

# Wrap the dataset above; the meaning of `transform=True` is defined in
# src.data.reconstruct_dataset, not visible here.
dataset = SemanticGenesis_Dataset(train_dataset, transform=True)

In [None]:
# Fetch sample 100; the three-way unpack assumes SemanticGenesis_Dataset returns a
# 3-tuple per item (its definition is not visible here — confirm).
x, x_ori, y_trans = dataset[100]