main source code
c-feng committed Jul 14, 2020
1 parent 397bc85 commit 292d84c
Showing 44 changed files with 3,762 additions and 0 deletions.
48 changes: 48 additions & 0 deletions libs/configs/config_acdc.py
@@ -0,0 +1,48 @@
from easydict import EasyDict as edict
import numpy as np

__C = edict()
cfg = __C


# ============== general training config =====================
__C.TRAIN = edict()

# __C.TRAIN.NET = "unet.U_Net"
__C.TRAIN.NET = "unet_df.U_NetDF"

__C.TRAIN.LR = 0.001
__C.TRAIN.LR_CLIP = 0.00001
__C.TRAIN.DECAY_STEP_LIST = [60, 100, 150, 180]
__C.TRAIN.LR_DECAY = 0.5

__C.TRAIN.GRAD_NORM_CLIP = 1.0

__C.TRAIN.OPTIMIZER = 'adam'
__C.TRAIN.WEIGHT_DECAY = 0 # "L2 regularization coeff [default: 0.0]"
__C.TRAIN.MOMENTUM = 0.9

# =============== model config ========================
__C.MODEL = edict()

__C.MODEL.SELFEATURE = True
__C.MODEL.SHIFT_N = 1
__C.MODEL.AUXSEG = True

# ================= dataset config ==========================
__C.DATASET = edict()

__C.DATASET.NAME = "acdc"
__C.DATASET.MEAN = 63.19523533061758
__C.DATASET.STD = 70.74166957523165

__C.DATASET.NUM_CLASS = 4

__C.DATASET.DF_USED = True
__C.DATASET.DF_NORM = True
__C.DATASET.BOUNDARY = False

__C.DATASET.TRAIN_LIST = "libs/datasets/jsonLists/acdcList/train.json"
__C.DATASET.TEST_LIST = "libs/datasets/jsonLists/acdcList/test.json"

__C.DATASET.TEST_PERSON_LIST = "libs/datasets/personList/AcdcTestPersonCarname.json"
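
For orientation, a minimal usage sketch (hypothetical; this commit may wire the config into the training script differently). cfg is an alias for the same EasyDict, so nested keys read and write as attributes:

# Hypothetical usage sketch; the calls below are illustrative only.
from libs.configs.config_acdc import cfg

print(cfg.TRAIN.NET)          # "unet_df.U_NetDF"
print(cfg.DATASET.NUM_CLASS)  # 4
cfg.TRAIN.LR = 0.0005         # EasyDict also allows attribute-style overrides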
3 changes: 3 additions & 0 deletions libs/datasets/__init__.py
@@ -0,0 +1,3 @@
from .cardia_dataset import CardiaDataset
from .acdc_dataset import AcdcDataset
from .lvsc_dataset import LvscDataset
63 changes: 63 additions & 0 deletions libs/datasets/acdc_dataset.py
@@ -0,0 +1,63 @@
import torch
from torch.utils.data import Dataset
from torchvision import transforms as T

import os
import json
import numpy as np
from PIL import Image
import h5py

import sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(BASE_DIR, '../../'))
from utils.direct_field.df_cardia import direct_field
from libs.datasets import augment as standard_aug
from utils.direct_field.utils_df import class2dist


class AcdcDataset(Dataset):
    def __init__(self, data_list, df_used=False, joint_augment=None, augment=None,
                 target_augment=None, df_norm=True, boundary=False):
        self.joint_augment = joint_augment
        self.augment = augment
        self.target_augment = target_augment
        self.data_list = data_list
        self.df_used = df_used
        self.df_norm = df_norm
        self.boundary = boundary

        with open(data_list, 'r') as f:
            self.data_infos = json.load(f)

    def __len__(self):
        return len(self.data_infos)

    def __getitem__(self, index):
        # Open the HDF5 file once and close it promptly, instead of leaking
        # two separate handles for the image and the label.
        with h5py.File(self.data_infos[index], 'r') as h5f:
            img = np.array(h5f['image'])[:, :, None].astype(np.float32)
            gt = np.array(h5f['label'])[:, :, None].astype(np.float32)

        if self.joint_augment is not None:
            img, gt = self.joint_augment(img, gt)
        if self.augment is not None:
            img = self.augment(img)
        if self.target_augment is not None:
            gt = self.target_augment(gt)

        if self.df_used:
            gt_df = direct_field(gt.numpy()[0], norm=self.df_norm)
            gt_df = torch.from_numpy(gt_df)
        else:
            gt_df = None

        if self.boundary:
            dist_map = torch.from_numpy(class2dist(gt.numpy()[0], C=4))
        else:
            dist_map = None

        return img, gt, gt_df, dist_map
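
A hedged construction sketch (the transform choices are assumptions; the JSON list path comes from the config above). When df_used is set, the ground truth must already be a tensor by the time the direction field is computed, so a to_Tensor transform is applied to both image and label here:

# Hypothetical usage sketch; transforms are illustrative only.
from libs.datasets import AcdcDataset
from libs.datasets import augment as standard_aug

to_tensor = standard_aug.to_Tensor()
train_set = AcdcDataset("libs/datasets/jsonLists/acdcList/train.json",
                        df_used=True, df_norm=True,
                        augment=to_tensor, target_augment=to_tensor)
img, gt, gt_df, dist_map = train_set[0]  # dist_map is None since boundary=False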


193 changes: 193 additions & 0 deletions libs/datasets/augment.py
@@ -0,0 +1,193 @@

import numpy as np
import random
import torch
from PIL import Image, ImageEnhance

class Compose():
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

class to_Tensor():
    def __call__(self, arr):
        # Promote (H, W) arrays to (H, W, 1), then emit a channel-first tensor.
        if len(np.array(arr).shape) == 2:
            arr = np.array(arr)[:, :, None]
        arr = torch.from_numpy(np.array(arr).transpose(2, 0, 1))
        return arr

def imresize(im, size, interp='bilinear'):
    if interp == 'nearest':
        resample = Image.NEAREST
    elif interp == 'bilinear':
        resample = Image.BILINEAR
    elif interp == 'bicubic':
        resample = Image.BICUBIC
    else:
        raise Exception('resample method undefined!')
    return im.resize(size, resample)

class To_PIL_Image():
    def __call__(self, img):
        return to_pil_image(img)

class normalize():
    def __init__(self, mean, std):
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)

    def __call__(self, img):
        self.mean = torch.as_tensor(self.mean, dtype=img.dtype, device=img.device)
        self.std = torch.as_tensor(self.std, dtype=img.dtype, device=img.device)
        return (img - self.mean) / self.std

class RandomVerticalFlip():
    def __init__(self, prob):
        self.prob = prob

    def __call__(self, img):
        if random.random() < self.prob:
            if isinstance(img, Image.Image):
                return img.transpose(Image.FLIP_TOP_BOTTOM)
            if isinstance(img, np.ndarray):
                return np.flip(img, axis=0)
        return img

class RandomHorizontallyFlip():
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, img):
        if random.random() < self.prob:
            if isinstance(img, Image.Image):
                return img.transpose(Image.FLIP_LEFT_RIGHT)
            if isinstance(img, np.ndarray):
                return np.flip(img, axis=1)
        return img

class RandomRotate():
    def __init__(self, degree, prob=0.5):
        self.prob = prob
        self.degree = degree

    def __call__(self, img, interpolation=Image.BILINEAR):
        if random.random() < self.prob:
            # Sample a rotation uniformly from [-degree, degree].
            rotate_degree = random.random() * 2 * self.degree - self.degree
            return img.rotate(rotate_degree, interpolation)
        return img

class RandomBrightness():
    def __init__(self, min_factor, max_factor, prob=0.5):
        """
        :param min_factor: value between 0.0 and max_factor that defines the
            minimum brightness adjustment. 0.0 gives a black image, 1.0 gives
            the original image, and values above 1.0 brighten the image.
        :param max_factor: value no smaller than min_factor that defines the
            maximum brightness adjustment, on the same scale.
        """
        self.prob = prob
        self.min_factor = min_factor
        self.max_factor = max_factor

    def __call__(self, img):
        if random.random() < self.prob:
            factor = np.random.uniform(self.min_factor, self.max_factor)
            enhancer_brightness = ImageEnhance.Brightness(img)
            return enhancer_brightness.enhance(factor)
        return img

class RandomContrast():
    def __init__(self, min_factor, max_factor, prob=0.5):
        """
        :param min_factor: value between 0.0 and max_factor that defines the
            minimum contrast adjustment. 0.0 gives a solid grey image and 1.0
            gives the original image.
        :param max_factor: value no smaller than min_factor that defines the
            maximum contrast adjustment, on the same scale.
        """
        self.prob = prob
        self.min_factor = min_factor
        self.max_factor = max_factor

    def __call__(self, img):
        if random.random() < self.prob:
            factor = np.random.uniform(self.min_factor, self.max_factor)
            enhance_contrast = ImageEnhance.Contrast(img)
            return enhance_contrast.enhance(factor)
        return img

def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to a PIL Image.
    See :class:`~torchvision.transforms.ToPILImage` for more details.
    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
    .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes
    Returns:
        PIL Image: Image converted to PIL Image.
    """
    npimg = pic
    if isinstance(pic, torch.FloatTensor):
        pic = pic.mul(255).byte()
    if torch.is_tensor(pic):
        npimg = np.transpose(pic.numpy(), (1, 2, 0))

    if not isinstance(npimg, np.ndarray):
        raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, '
                        'not {}'.format(type(npimg)))

    if npimg.shape[2] == 1:
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = 'L'
        elif npimg.dtype == np.int16:
            expected_mode = 'I;16'
        elif npimg.dtype == np.int32:
            expected_mode = 'I'
        elif npimg.dtype == np.float32:
            expected_mode = 'F'
        if mode is not None and mode != expected_mode:
            raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
                             .format(mode, npimg.dtype, expected_mode))
        mode = expected_mode

    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ['RGBA', 'CMYK']
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError("Only modes {} are supported for 4-channel inputs".format(permitted_4_channel_modes))

        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError("Only modes {} are supported for 3-channel inputs".format(permitted_3_channel_modes))
        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGB'

    if mode is None:
        raise TypeError('Input type {} is not supported'.format(npimg.dtype))

    return Image.fromarray(npimg, mode=mode)
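
These transforms compose in the usual torchvision style: PIL-based augmentations first, then conversion to a tensor, then tensor-space normalization. A sketch, assuming an 8-bit grayscale PIL image as input and illustrative parameter values (the mean/std are rounded from the ACDC config above):

# Hypothetical usage sketch; parameter values are illustrative.
transform = Compose([
    RandomHorizontallyFlip(prob=0.5),
    RandomBrightness(min_factor=0.8, max_factor=1.2),
    to_Tensor(),
    normalize(mean=[63.195], std=[70.742]),
])
augmented = transform(pil_image)  # pil_image: an 8-bit grayscale PIL.Image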
26 changes: 26 additions & 0 deletions libs/datasets/collate_batch.py
@@ -0,0 +1,26 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from utils.image_list import to_image_list

class BatchCollator(object):
    """
    From a list of samples from the dataset,
    returns the batched images and targets.
    This should be passed to the DataLoader
    """

    def __init__(self, size_divisible=0, df_used=False, boundary=False):
        self.size_divisible = size_divisible
        self.df_used = df_used
        self.boundary = boundary

    def __call__(self, batch):
        # batch is a list of (img, gt, gt_df, dist_map) tuples from the dataset;
        # zip(*batch) groups each field across the whole batch.
        transposed_batch = list(zip(*batch))
        images = to_image_list(transposed_batch[0], self.size_divisible)
        targets = to_image_list(transposed_batch[1], self.size_divisible)

        dfs = to_image_list(transposed_batch[2], self.size_divisible) if self.df_used else None

        dist_maps = to_image_list(transposed_batch[3], self.size_divisible) if self.boundary else None

        return images, targets, dfs, dist_maps
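
A sketch of how the collator plugs into a DataLoader (batch size and worker count are assumptions; train_set is the dataset sketched earlier):

# Hypothetical usage sketch; batch_size and num_workers are illustrative.
from torch.utils.data import DataLoader

collator = BatchCollator(size_divisible=0, df_used=True, boundary=False)
loader = DataLoader(train_set, batch_size=8, shuffle=True,
                    num_workers=4, collate_fn=collator)
images, targets, dfs, dist_maps = next(iter(loader))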
