In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import glob
from collections import namedtuple
from typing import Optional, Union, List
import copy
from tqdm.auto import tqdm
import kornia
import time

import segmentation_models_pytorch as smp


# Dataset partitions accepted by CholecGonogoDataset(split=...)
SPLIT_TYPES = ['train', 'test']

# One filename glob per source video; frames are grouped (and split) by the
# video they match.
VIDEO_GLOBS_PUBLIC = [
    *(f"cholec80_video{i:02d}_*" for i in range(1, 81)),
    *(f"M2CCAI2016_video{i}_*" for i in range(81, 122)),
]

VIDEO_GLOBS_PRIVATE = [
    *(f"AdnanSet_LC_{i}_*" for i in range(1, 165)),
    *(f"AminSet_LC_{i}_*" for i in range(1, 11)),
    "HokkaidoSet_LC_1_*",
    "HokkaidoSet_LC_2_*",
    *(f"UTSWSet_Case_{i}_*" for i in range(1, 13)),
    "WashUSet_LC_01_*",
]

# Lookup used when `video_globs` is passed to the dataset as a string key.
VIDEO_GLOBS_DICT = dict(
    public=VIDEO_GLOBS_PUBLIC,
    private=VIDEO_GLOBS_PRIVATE,
)
#
class CholecGonogoDataset(Dataset):
    
    # Human-readable class names for the go/no-go label maps.
    # NOTE(review): presumably list position == integer class id stored in the
    # label image -- confirm against the label files.
    gonogo_names: List[str] = [
        "Background",
        "Go",
        "Nogo"
    ]

    # Human-readable class names for the organ label maps (same presumed
    # position == class id convention as above).
    organ_names: List[str] = [
        "Background",
        "Liver",
        "Gallbladder",
        "Hepatocystic Triangle"
    ]
    
    def __init__(
        self,
        data_dir: str,
        images_dir: str = "images",
        gonogo_labels_dir: str = "gonogo_labels",
        organ_labels_dir: str = "organ_labels",
        split: str = "train",
        split_filepath: Optional[str] = None,
        video_globs: Union[str, list] = 'public',
        train_ratio: float = 0.8,
        image_height: int = 384,
        image_width: int = 640,
        train_angle_max: float = 60.0,
        image_transforms = None,
        label_transforms = None,
        download: bool = False,
        gen_seed: int = 1234,
        augmentations: bool = False,
        pretransform: bool = False
    ):
        """Index the image files of one split and set up the transforms.

        Args:
            data_dir: Root directory holding the image and label sub-directories.
            images_dir: Sub-directory of `data_dir` with the input images.
            gonogo_labels_dir: Sub-directory with the go/no-go label images.
            organ_labels_dir: Sub-directory with the organ label images.
            split: One of `SPLIT_TYPES`.
            split_filepath: Optional text file with one image filename per
                line. If it is given and exists it defines the split;
                otherwise (including when the path does not exist) the split
                is drawn at random over videos.
            video_globs: Key of `VIDEO_GLOBS_DICT` or an explicit list of
                filename globs, one per video.
            train_ratio: Fraction of videos assigned to the train split when
                the split is drawn at random.
            image_height: Target height used by the default transforms.
            image_width: Target width used by the default transforms.
            train_angle_max: Rotation angles are drawn uniformly from
                [-train_angle_max, train_angle_max) degrees.
            image_transforms: Optional callable applied to each PIL image.
            label_transforms: Optional callable applied to each PIL label.
            download: Not implemented; must be False.
            gen_seed: Seed for the video split and for the rotation angles.
            augmentations: Whether samples are rotated by their random angle.
            pretransform: If True, load and transform the whole split now.

        Raises:
            ValueError: If `download` is True or `video_globs` is an unknown key.
            NotImplementedError: If `video_globs` is neither a str nor a list.
            AssertionError: If a data directory is missing, `split` is not in
                `SPLIT_TYPES`, or a supplied transform is not callable.
        """
        if download:
            raise ValueError("download not implemented")

        self.data_dir = data_dir
        self.images_dir = os.path.join(data_dir, images_dir)
        self.gonogo_labels_dir = os.path.join(data_dir, gonogo_labels_dir)
        self.organ_labels_dir = os.path.join(data_dir, organ_labels_dir)

        assert os.path.isdir(self.images_dir)
        assert os.path.isdir(self.gonogo_labels_dir)
        assert os.path.isdir(self.organ_labels_dir)
        assert split in SPLIT_TYPES
        self.split = split
        self.augmentations = augmentations

        gen = torch.Generator()
        gen.manual_seed(gen_seed)

        # Resolve the per-video globs: a named preset or a caller-supplied list.
        if isinstance(video_globs, str):
            if video_globs not in VIDEO_GLOBS_DICT:
                raise ValueError(f'video_globs {video_globs} does not exist, please pass in one of {VIDEO_GLOBS_DICT.keys()} or a list.')
            selected_globs = VIDEO_GLOBS_DICT[video_globs]
        elif isinstance(video_globs, list):
            selected_globs = video_globs
        else:
            raise NotImplementedError()
        self.video_globs = selected_globs

        if split_filepath is not None and os.path.isfile(split_filepath):
            # Explicit split: one filename per line. Blank lines are skipped so
            # they do not turn into empty filenames.
            with open(split_filepath, 'rt') as input_file:
                self.image_files = [line.strip() for line in input_file if line.strip()]

        elif split == "train" or split == "test":
            # Random split over videos (not frames), so frames of one video
            # never end up on both sides.
            num_all = len(selected_globs)
            num_train = int(num_all * train_ratio)
            perm = torch.randperm(num_all, generator=gen)
            idxs = perm[:num_train] if split == "train" else perm[num_train:]

            image_files = []
            for i in idxs.tolist():
                pattern = os.path.join(self.images_dir, selected_globs[i])
                image_files.extend(os.path.basename(path) for path in glob.glob(pattern))
            self.image_files = sorted(image_files)

        else:
            raise NotImplementedError()

        # Re-seed so the angles do not depend on how many numbers the split
        # consumed, and draw them from `gen` so they are reproducible for a
        # given `gen_seed`.
        gen.manual_seed(gen_seed)
        self.train_angle_max = train_angle_max
        self.random_angles = train_angle_max * (
            torch.rand(len(self.image_files), generator=gen) * 2 - 1
        )

        self.image_height = image_height
        self.image_width = image_width

        # Image transforms
        if image_transforms is None:
            self.image_transforms = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize((image_height, image_width), antialias=True),
            ])
        else:
            assert callable(image_transforms)
            self.image_transforms = image_transforms

        # Label transforms: nearest-neighbour resize, because interpolating
        # between class ids would create ids that are not in the label image.
        if label_transforms is None:
            self.label_transforms = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize(
                    (image_height, image_width),
                    interpolation=transforms.InterpolationMode.NEAREST,
                ),
            ])
        else:
            assert callable(label_transforms)
            self.label_transforms = label_transforms

        self.pretransform = pretransform
        if pretransform:
            self.pretransform_data()

    # import torch
    # from tqdm import tqdm
    # from torch.utils.data import DataLoader, TensorDataset

    def pretransform_data(self, batch_size=8):
        """Load and transform every sample once and cache the result as tensors.

        Fills `images_pretransformed`, `organs_pretransformed` and
        `gonogos_pretransformed` with the transformed samples, and
        `images_aug`, `organs_aug` and `gonogos_aug` with the same samples
        rotated by `self.random_angles` (or left unrotated when
        `self.augmentations` is False).

        Args:
            batch_size: Number of samples stacked and rotated together.
        """
        print('pretransforming data')
        images, organs, gonogos = [], [], []
        images_aug, organs_aug, gonogos_aug = [], [], []

        num_samples = len(self.image_files)
        for batch_start in tqdm(range(0, num_samples, batch_size)):
            batch_end = min(batch_start + batch_size, num_samples)
            batch_images, batch_organs, batch_gonogos = [], [], []

            # Load and transform each sample of the batch
            for filename in self.image_files[batch_start:batch_end]:
                with Image.open(os.path.join(self.images_dir, filename)) as img:
                    batch_images.append(self.image_transforms(img.convert("RGB")))
                with Image.open(os.path.join(self.organ_labels_dir, filename)) as img:
                    batch_organs.append(self.label_transforms(img.convert("L")))
                with Image.open(os.path.join(self.gonogo_labels_dir, filename)) as img:
                    batch_gonogos.append(self.label_transforms(img.convert("L")))

            batch_images = torch.stack(batch_images)
            batch_organs = torch.stack(batch_organs)
            batch_gonogos = torch.stack(batch_gonogos)

            images.append(batch_images)
            organs.append(batch_organs)
            gonogos.append(batch_gonogos)

            if self.augmentations:
                # One angle per sample in the batch
                angles = self.random_angles[batch_start:batch_end]

                batch_images = kornia.geometry.transform.rotate(batch_images, angles)
                # Labels hold class ids: rotate with nearest-neighbour sampling
                # so no in-between ids are created (this also matches the
                # nearest-neighbour rotation used in __getitem__).
                batch_organs = kornia.geometry.transform.rotate(
                    batch_organs, angles, mode='nearest')
                batch_gonogos = kornia.geometry.transform.rotate(
                    batch_gonogos, angles, mode='nearest')

            images_aug.append(batch_images)
            organs_aug.append(batch_organs)
            gonogos_aug.append(batch_gonogos)

        # Concatenate all batches into a single tensor per field
        self.images_pretransformed = torch.cat(images)
        self.organs_pretransformed = torch.cat(organs)
        self.gonogos_pretransformed = torch.cat(gonogos)

        self.images_aug = torch.cat(images_aug)
        self.organs_aug = torch.cat(organs_aug)
        self.gonogos_aug = torch.cat(gonogos_aug)
        print('pretransforming data done')

    def augment_data(self, batch_size=8):
        """Rebuild the cached rotated tensors from the cached unrotated ones.

        Reads `images_pretransformed`, `organs_pretransformed` and
        `gonogos_pretransformed`, rotates each sample by its entry in
        `self.random_angles`, and stores the result in `images_aug`,
        `organs_aug` and `gonogos_aug`.

        Args:
            batch_size: Number of samples rotated together.
        """
        images, organs, gonogos = [], [], []
        num_samples = len(self.image_files)
        for batch_start in tqdm(range(0, num_samples, batch_size)):
            batch = slice(batch_start, min(batch_start + batch_size, num_samples))
            angles = self.random_angles[batch]

            images.append(kornia.geometry.transform.rotate(
                self.images_pretransformed[batch], angles))
            # Labels hold class ids: nearest-neighbour sampling keeps them
            # from being blended into ids that are not in the label image.
            organs.append(kornia.geometry.transform.rotate(
                self.organs_pretransformed[batch], angles, mode='nearest'))
            gonogos.append(kornia.geometry.transform.rotate(
                self.gonogos_pretransformed[batch], angles, mode='nearest'))

        self.images_aug = torch.cat(images)
        self.organs_aug = torch.cat(organs)
        self.gonogos_aug = torch.cat(gonogos)
        
    def reset_random_angles(self, seed=None):
        if seed is not None:
            torch.manual_seed(seed)
        self.random_angles = self.train_angle_max * (torch.rand(len(self.image_files)) * 2 - 1)
        print('angles reset')

    def __len__(self):
        return len(self.image_files)


    def __getitem__(self, idx):
        verbose = False
        if self.pretransform:
            image = self.images_pretransformed[idx]
            organ_label = self.organs_pretransformed[idx]
            gonogo_label = self.gonogos_pretransformed[idx]
        else:
            start = time.time()
            image_file = os.path.join(self.images_dir, self.image_files[idx])
            organ_label_file = os.path.join(self.organ_labels_dir, self.image_files[idx])
            gonogo_label_file = os.path.join(self.gonogo_labels_dir, self.image_files[idx])
            if verbose:
                print('time in data1', time.time() - start)
            start = time.time()
            # Read image and label
            image = Image.open(image_file).convert("RGB")
            organ_label = Image.open(organ_label_file).convert("L") # L is grayscale
            gonogo_label = Image.open(gonogo_label_file).convert("L")
            if verbose:
                print('time in data2', time.time() - start)
            start = time.time()

            image = self.image_transforms(image)
            organ_label = self.label_transforms(organ_label)
            gonogo_label = self.label_transforms(gonogo_label)
            if verbose:
                print('time in data3', time.time() - start)
            start = time.time()
                
            if self.augmentations: #self.split.startswith("train"): we should allow choice for using augmentation or not for both train and test
                # Apply the random rotation
                angle = self.random_angles[idx].item()
                if idx == 0:
                    print('angle', angle)
                image = transforms.functional.rotate(image, angle)
                organ_label = transforms.functional.rotate(organ_label, angle)
                gonogo_label = transforms.functional.rotate(gonogo_label, angle)

        if verbose:
            print('time in data4', time.time() - start)
        start = time.time()
        organ_label = (organ_label * 255).round().long()
        gonogo_label = (gonogo_label * 255).round().long()
        if verbose:
            print('time in data', time.time() - start)
        return {
            'image': image, 
            'organs': organ_label,
            'gonogo': gonogo_label,
            'idx': idx
        }
    
    def get_cv_splits(self, cv_fold=5, gen_seed=0): # need to test
        dataset = self
        assert cv_fold > 1
        assert len(dataset.image_files) > cv_fold

        # get a subset of train_ratio split by videos according to a random seed
        num_all, num_train = len(dataset.image_files), int(len(dataset.image_files) / cv_fold)
        gen = torch.Generator()
        gen.manual_seed(gen_seed)
        perm = torch.randperm(num_all, generator=gen)
        # idxs should be all the folds
        idxs = perm[:num_train * cv_fold].reshape(cv_fold, num_train)

        # set the new dataset's image_files and video_globs
        datasets = []
        for i in range(cv_fold):
            image_files_new = [dataset.image_files[j] for j in idxs[i]]
            new_dataset = copy.deepcopy(dataset)
            new_dataset.image_files = image_files_new
            datasets.append(new_dataset)
        
        return datasets

    # get a subset
    def get_subset(self, split='train', train_ratio=0.9, gen_seed=0):
        dataset = self
        assert split in ['train', 'val']
        video_globs_used = []
        used_image_files_all = {}
        for i in range(len(dataset.video_globs)):
            image_files_curr = [os.path.basename(path) for path in glob.glob(os.path.join(dataset.images_dir, dataset.video_globs[i]))]
            used_image_files = sorted(set(image_files_curr) & set(dataset.image_files))
            if len(used_image_files) > 0:
                used_image_files_all[dataset.video_globs[i]] = used_image_files
                video_globs_used.append(dataset.video_globs[i])

        # get a subset of train_ratio split by videos according to a random seed
        num_all, num_train = len(video_globs_used), int(len(video_globs_used) * train_ratio)
        gen = torch.Generator()
        gen.manual_seed(gen_seed)
        perm = torch.randperm(num_all, generator=gen)
        idxs = perm[:num_train] if "train" in split else perm[num_train:]

        # set the new dataset's image_files and video_globs
        image_files_new = []
        video_globs_new = []
        for i in idxs:
            image_files_new += [os.path.basename(path) for path in \
                glob.glob(os.path.join(dataset.images_dir, video_globs_used[i]))]
            video_globs_new.append(video_globs_used[i])
        image_files_new = sorted(image_files_new)
        image_files_idxs = [dataset.image_files.index(image_file) for image_file in image_files_new]
        new_dataset = copy.deepcopy(dataset)
        new_dataset.image_files = image_files_new
        new_dataset.video_globs = video_globs_new

        return new_dataset
    
    def concat_set(self, other):
        dataset = copy.deepcopy(self)
        dataset.image_files = [filename for filename in self.image_files] + \
            [filename for filename in other.image_files]
        dataset.video_globs = sorted(set([filename for filename in self.video_globs] + \
            [filename for filename in other.video_globs]))
        dataset.random_angles = torch.cat([self.random_angles, other.random_angles])
        return dataset

def get_cholecgonogo_datasets(mode, train_test_seed=0, train_val_seed=0, augmentations=False,
                            data_dir='/shared_data0/weiqiuy/real_drs/data/abdomen_exlib',
                            public_split_filepath_train='/shared_data0/weiqiuy/real_drs/data/splits/public_train0.txt',
                            public_split_filepath_test='/shared_data0/weiqiuy/real_drs/data/splits/public_test0.txt',
                            private_split_filepath_train=None,
                            private_split_filepath_test=None):
    """Build the train/val/test datasets for the public, private or combined data.

    The public and private train and test datasets are always constructed
    (with `gen_seed=train_test_seed`), and each train dataset is divided into
    train and val parts with `get_subset(..., gen_seed=train_val_seed)`.
    `mode` selects which of them are returned; for 'all' the public and
    private parts are joined with `concat_set` (public first).

    NOTE(review): `augmentations` is accepted but never used in this function;
    confirm which datasets it was meant to apply to.

    Returns:
        dict with keys 'train_dataset', 'val_dataset' and 'test_dataset'.
    """
    assert mode in ['public', 'private', 'all']

    # Keyed by (split, source); insertion order is the construction order.
    split_files = {
        ('train', 'public'): public_split_filepath_train,
        ('train', 'private'): private_split_filepath_train,
        ('test', 'public'): public_split_filepath_test,
        ('test', 'private'): private_split_filepath_test,
    }
    base = {}
    for (split, source), filepath in split_files.items():
        base[(split, source)] = CholecGonogoDataset(
            data_dir=data_dir, split=split, video_globs=source,
            split_filepath=filepath, gen_seed=train_test_seed)

    # Divide each train dataset into train and val
    parts = {}
    for source in ('public', 'private'):
        full_train = base[('train', source)]
        parts[source] = {
            'train_dataset': full_train.get_subset(split='train', gen_seed=train_val_seed),
            'val_dataset': full_train.get_subset(split='val', gen_seed=train_val_seed),
            'test_dataset': base[('test', source)],
        }

    if mode in parts:
        return parts[mode]
    if mode == 'all':
        return {
            key: parts['public'][key].concat_set(parts['private'][key])
            for key in ('train_dataset', 'val_dataset', 'test_dataset')
        }
    raise NotImplementedError()

# todo: cross validation split # need to test
# train/val is cross validation, test is the same
def get_cholecgonogo_cv_datasets(mode, cv_fold=5, gen_seed=0, train_test_seed=0, train_val_seed=0, augmentations=False,
                            data_dir='/shared_data0/weiqiuy/real_drs/data/abdomen_exlib',
                            public_split_filepath_train='/shared_data0/weiqiuy/real_drs/data/splits/public_train0.txt',
                            public_split_filepath_test='/shared_data0/weiqiuy/real_drs/data/splits/public_test0.txt',
                            private_split_filepath_train=None,
                            private_split_filepath_test=None):
    """Build `cv_fold` train/val/test dataset triples for cross validation.

    The train data of the selected source(s) is cut into `cv_fold` folds with
    `get_cv_splits(cv_fold=cv_fold, gen_seed=gen_seed)`. For triple `i`,
    'train_dataset' is fold `i`, 'val_dataset' is fold `(i + 1) % cv_fold`,
    and 'test_dataset' is the test split, which is the same data for every
    triple. For mode 'all' the public and private parts are joined with
    `concat_set` (public first).

    Folds are only computed for the sources that `mode` uses, so an empty or
    very small unused source does not trip the fold-size check.

    NOTE(review): each triple trains on a single fold rather than on the
    remaining `cv_fold - 1` folds -- confirm this is intended.
    NOTE(review): `train_val_seed` and `augmentations` are accepted but never
    used in this function.

    Returns:
        List of `cv_fold` dicts with keys 'train_dataset', 'val_dataset' and
        'test_dataset'.
    """
    assert mode in ['public', 'private', 'all']
    # train sets
    train_public_dataset = CholecGonogoDataset(data_dir=data_dir, split='train', video_globs='public',
        split_filepath=public_split_filepath_train, gen_seed=train_test_seed)
    train_private_dataset = CholecGonogoDataset(data_dir=data_dir, split='train', video_globs='private',
        split_filepath=private_split_filepath_train, gen_seed=train_test_seed)

    # test sets
    test_public_dataset = CholecGonogoDataset(data_dir=data_dir, split='test', video_globs='public',
        split_filepath=public_split_filepath_test, gen_seed=train_test_seed)
    test_private_dataset = CholecGonogoDataset(data_dir=data_dir, split='test', video_globs='private',
        split_filepath=private_split_filepath_test, gen_seed=train_test_seed)

    # Cut the train data into folds, but only for the sources this mode uses:
    # get_cv_splits asserts on the dataset size and would otherwise fail for
    # an unused source that has no (or too few) files.
    train_public_datasets = None
    train_private_datasets = None
    if mode in ('public', 'all'):
        train_public_datasets = train_public_dataset.get_cv_splits(cv_fold=cv_fold, gen_seed=gen_seed)
    if mode in ('private', 'all'):
        train_private_datasets = train_private_dataset.get_cv_splits(cv_fold=cv_fold, gen_seed=gen_seed)

    datasets = []
    for i in range(cv_fold):
        j = (i + 1) % cv_fold  # fold used for validation
        if mode == 'public':
            train_dataset = train_public_datasets[i]
            val_dataset = train_public_datasets[j]
            test_dataset = test_public_dataset
        elif mode == 'private':
            train_dataset = train_private_datasets[i]
            val_dataset = train_private_datasets[j]
            test_dataset = test_private_dataset
        elif mode == 'all':
            train_dataset = train_public_datasets[i].concat_set(train_private_datasets[i])
            val_dataset = train_public_datasets[j].concat_set(train_private_datasets[j])
            test_dataset = test_public_dataset.concat_set(test_private_dataset)
        datasets.append({
            'train_dataset': train_dataset,
            'val_dataset': val_dataset,
            'test_dataset': test_dataset
        })

    return datasets


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import glob
from collections import namedtuple
from typing import Optional, Union, List

import segmentation_models_pytorch as smp

# Container returned by Unet.forward when return_tuple=True:
# `logits` is the segmentation-head output, `pooler_output` the decoder output.
ModelOutput = namedtuple("ModelOutput", ["logits", "pooler_output"])


class Unet(smp.Unet):
    """`smp.Unet` whose forward pass can also return the decoder output."""

    def __init__(
        self,
        encoder_name: str = "resnet50",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_use_batchnorm: bool = True,
        decoder_channels: List[int] = (256, 128, 64, 32, 16),
        decoder_attention_type: Optional[str] = None,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        aux_params: Optional[dict] = None,
    ):
        """Forward every argument unchanged to `smp.Unet.__init__`."""
        smp_kwargs = dict(
            encoder_name=encoder_name,
            encoder_depth=encoder_depth,
            encoder_weights=encoder_weights,
            decoder_use_batchnorm=decoder_use_batchnorm,
            decoder_channels=decoder_channels,
            decoder_attention_type=decoder_attention_type,
            in_channels=in_channels,
            classes=classes,
            activation=activation,
            aux_params=aux_params,
        )
        super().__init__(**smp_kwargs)

    def forward(self, x, return_tuple=False):
        """Run the encoder, decoder and heads on `x`.

        Returns:
            `(masks, labels)` when the model has a classification head
            (`return_tuple` is ignored in that case); otherwise a
            `ModelOutput(logits=masks, pooler_output=decoder_output)` when
            `return_tuple` is True, else just the masks.
        """
        self.check_input_shape(x)

        encoder_features = self.encoder(x)
        decoded = self.decoder(*encoder_features)
        seg_masks = self.segmentation_head(decoded)

        if self.classification_head is not None:
            # The classification head reads the deepest encoder feature map.
            return seg_masks, self.classification_head(encoder_features[-1])

        if not return_tuple:
            return seg_masks
        return ModelOutput(logits=seg_masks, pooler_output=decoded)


# Basic classification model
class AbdomenClsModel(nn.Module):
    """Small convolutional classifier operating on 64x64 resized inputs.

    Four stride-2 conv / batch-norm / ReLU stages followed by a linear layer
    that outputs `num_classes` scores per input.

    Args:
        num_classes: Size of the output score vector.
        in_channels: Number of channels of the input images.
    """

    def __init__(self, num_classes, in_channels=3):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes

        self.layers = nn.Sequential(
            # The first conv honours `in_channels` instead of assuming RGB.
            nn.Conv2d(in_channels, 256, kernel_size=3, stride=2, padding=1),   # (N,256,32,32)
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 128, kernel_size=3, stride=2, padding=1), # (N,128,16,16)
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, stride=2, padding=1),  # (N,64,8,8)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, stride=2, padding=1),   # (N,32,4,4)
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(1),
            nn.Linear(32*4*4, num_classes)
        )

    def forward(self, x):
        """Resize `x` (N, in_channels, H, W) to 64x64 and return (N, num_classes) scores."""
        N, C, H, W = x.shape
        assert C == self.in_channels, f"expected {self.in_channels} input channels, got {C}"
        # The linear layer is sized for a 4x4 final map, i.e. a 64x64 input.
        x = F.interpolate(x, size=[64,64])
        return self.layers(x)


# Basic segmentation model
class AbdomenSegModel(nn.Module):
    """Segmentation model that wraps the `Unet` defined in this file.

    Args:
        num_classes: Number of output channels of the segmentation head.
        in_channels: Number of channels of the input images.
        encoder_name: Encoder backbone name passed to `Unet`.
        encoder_weights: Pretrained-weights name passed to `Unet`.
        activation: Activation passed to `Unet` for the segmentation head.
    """

    def __init__(self, num_classes, in_channels=3,
                 encoder_name="resnet50", encoder_weights="imagenet",
                 activation="softmax2d"):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        unet_kwargs = dict(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=num_classes,
            activation=activation,
        )
        self.unet = Unet(**unet_kwargs)

    def forward(self, x, return_tuple=True):
        """Run the wrapped U-Net on `x` (N, C, H, W).

        H and W must both be multiples of 32. `return_tuple` is forwarded to
        `Unet.forward` and defaults to True here.
        """
        _, _, height, width = x.shape
        assert height % 32 == 0 and width % 32 == 0
        return self.unet(x, return_tuple=return_tuple)
