Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
44 changed files
with
3,762 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from easydict import EasyDict as edict | ||
import numpy as np | ||
|
||
__C = edict() | ||
cfg = __C | ||
|
||
|
||
# ============== general training config ===================== | ||
__C.TRAIN = edict() | ||
|
||
# __C.TRAIN.NET = "unet.U_Net" | ||
__C.TRAIN.NET = "unet_df.U_NetDF" | ||
|
||
__C.TRAIN.LR = 0.001 | ||
__C.TRAIN.LR_CLIP = 0.00001 | ||
__C.TRAIN.DECAY_STEP_LIST = [60, 100, 150, 180] | ||
__C.TRAIN.LR_DECAY = 0.5 | ||
|
||
__C.TRAIN.GRAD_NORM_CLIP = 1.0 | ||
|
||
__C.TRAIN.OPTIMIZER = 'adam' | ||
__C.TRAIN.WEIGHT_DECAY = 0 # "L2 regularization coeff [default: 0.0]" | ||
__C.TRAIN.MOMENTUM = 0.9 | ||
|
||
# =============== model config ======================== | ||
__C.MODEL = edict() | ||
|
||
__C.MODEL.SELFEATURE = True | ||
__C.MODEL.SHIFT_N = 1 | ||
__C.MODEL.AUXSEG = True | ||
|
||
# ================= dataset config ========================== | ||
__C.DATASET = edict() | ||
|
||
__C.DATASET.NAME = "acdc" | ||
__C.DATASET.MEAN = 63.19523533061758 | ||
__C.DATASET.STD = 70.74166957523165 | ||
|
||
__C.DATASET.NUM_CLASS = 4 | ||
|
||
__C.DATASET.DF_USED = True | ||
__C.DATASET.DF_NORM = True | ||
__C.DATASET.BOUNDARY = False | ||
|
||
__C.DATASET.TRAIN_LIST = "libs/datasets/jsonLists/acdcList/train.json" | ||
__C.DATASET.TEST_LIST = "libs/datasets/jsonLists/acdcList/test.json" | ||
|
||
__C.DATASET.TEST_PERSON_LIST = "libs/datasets/personList/AcdcTestPersonCarname.json" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .cardia_dataset import CardiaDataset | ||
from .acdc_dataset import AcdcDataset | ||
from .lvsc_dataset import LvscDataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import torch | ||
from torch.utils.data import Dataset | ||
from torchvision import transforms as T | ||
|
||
import os | ||
import json | ||
import numpy as np | ||
from PIL import Image | ||
import h5py | ||
|
||
import sys | ||
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | ||
sys.path.append(os.path.join(BASE_DIR, '../../')) | ||
from utils.direct_field.df_cardia import direct_field | ||
from libs.datasets import augment as standard_aug | ||
from utils.direct_field.utils_df import class2dist | ||
|
||
|
||
class AcdcDataset(Dataset): | ||
def __init__(self, data_list, df_used=False, joint_augment=None, augment=None, target_augment=None, df_norm=True, boundary=False): | ||
self.joint_augment = joint_augment | ||
self.augment = augment | ||
self.target_augment = target_augment | ||
self.data_list = data_list | ||
self.df_used = df_used | ||
self.df_norm = df_norm | ||
self.boundary = boundary | ||
|
||
with open(data_list, 'r') as f: | ||
self.data_infos = json.load(f) | ||
|
||
def __len__(self): | ||
return len(self.data_infos) | ||
|
||
def __getitem__(self,index): | ||
img = h5py.File(self.data_infos[index],'r')['image'] | ||
gt = h5py.File(self.data_infos[index],'r')['label'] | ||
# print(np.unique(gt)) | ||
img = np.array(img)[:,:,None].astype(np.float32) | ||
gt = np.array(gt)[:,:,None].astype(np.float32) | ||
# print(np.unique(gt)) | ||
|
||
if self.joint_augment is not None: | ||
img, gt = self.joint_augment(img, gt) | ||
if self.augment is not None: | ||
img = self.augment(img) | ||
if self.target_augment is not None: | ||
gt = self.target_augment(gt) | ||
|
||
if self.df_used: | ||
gt_df = direct_field(gt.numpy()[0], norm=self.df_norm) | ||
gt_df = torch.from_numpy(gt_df) | ||
else: | ||
gt_df = None | ||
|
||
if self.boundary: | ||
dist_map = torch.from_numpy(class2dist(gt.numpy()[0], C=4)) | ||
else: | ||
dist_map = None | ||
|
||
return img, gt, gt_df, dist_map | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
|
||
import numpy as np | ||
import random | ||
import torch | ||
import numpy as np | ||
from PIL import Image, ImageEnhance | ||
|
||
class Compose(): | ||
def __init__(self, transforms): | ||
self.transforms = transforms | ||
|
||
def __call__(self, img): | ||
for t in self.transforms: | ||
img = t(img) | ||
return img | ||
|
||
class to_Tensor(): | ||
def __call__(self,arr): | ||
if len(np.array(arr).shape) == 2: | ||
arr = np.array(arr)[:,:,None] | ||
arr = torch.from_numpy(np.array(arr).transpose(2,0,1)) | ||
return arr | ||
|
||
def imresize(im, size, interp='bilinear'): | ||
if interp == 'nearest': | ||
resample = Image.NEAREST | ||
elif interp == 'bilinear': | ||
resample = Image.BILINEAR | ||
elif interp == 'bicubic': | ||
resample = Image.BICUBIC | ||
else: | ||
raise Exception('resample method undefined!') | ||
return im.resize(size, resample) | ||
|
||
class To_PIL_Image(): | ||
def __call__(self, img): | ||
return to_pil_image(img) | ||
|
||
class normalize(): | ||
def __init__(self,mean,std): | ||
self.mean = torch.tensor(mean) | ||
self.std = torch.tensor(std) | ||
def __call__(self,img): | ||
self.mean = torch.as_tensor(self.mean,dtype=img.dtype,device=img.device) | ||
self.std = torch.as_tensor(self.std,dtype=img.dtype,device=img.device) | ||
return (img-self.mean)/self.std | ||
|
||
class RandomVerticalFlip(): | ||
def __init__(self, prob): | ||
self.prob = prob | ||
|
||
def __call__(self, img): | ||
if random.random() < self.prob: | ||
if isinstance(img, Image.Image): | ||
return img.transpose(Image.FLIP_TOP_BOTTOM) | ||
if isinstance(img, np.ndarray): | ||
return np.flip(img, axis=0) | ||
return img | ||
|
||
class RandomHorizontallyFlip(): | ||
def __init__(self, prob=0.5): | ||
self.prob = prob | ||
|
||
def __call__(self, img): | ||
if random.random() < self.prob: | ||
if isinstance(img, Image.Image): | ||
return img.transpose(Image.FLIP_LEFT_RIGHT) | ||
if isinstance(img, np.ndarray): | ||
return np.flip(img, axis=1) | ||
return img | ||
|
||
class RandomRotate(): | ||
def __init__(self, degree, prob=0.5): | ||
self.prob = prob | ||
self.degree = degree | ||
|
||
def __call__(self, img, interpolation=Image.BILINEAR): | ||
if random.random() < self.prob: | ||
rotate_detree = random.random() * 2 * self.degree - self.degree | ||
return img.rotate(rotate_detree, interpolation) | ||
return img | ||
|
||
class RandomBrightness(): | ||
def __init__(self, min_factor, max_factor, prob=0.5): | ||
""" :param min_factor: The value between 0.0 and max_factor | ||
that define the minimum adjustment of image brightness. | ||
The value 0.0 gives a black image,The value 1.0 gives the original image, value bigger than 1.0 gives more bright image. | ||
:param max_factor: A value should be bigger than min_factor. | ||
that define the maximum adjustment of image brightness. | ||
The value 0.0 gives a black image, value 1.0 gives the original image, value bigger than 1.0 gives more bright image. | ||
""" | ||
self.prob = prob | ||
self.min_factor = min_factor | ||
self.max_factor = max_factor | ||
|
||
# def __brightness(self, img, factor): | ||
# return img * (1.0 - factor) + img * factor | ||
|
||
# def __call__(self, img): | ||
# if random.random() < self.prob: | ||
# factor = np.random.uniform(self.min_factor, self.max_factor) | ||
# return self.__brightness(img, factor) | ||
|
||
def __call__(self, img): | ||
if random.random() < self.prob: | ||
factor = np.random.uniform(self.min_factor, self.max_factor) | ||
enhancer_brightness = ImageEnhance.Brightness(img) | ||
return enhancer_brightness.enhance(factor) | ||
|
||
return img | ||
|
||
class RandomContrast(): | ||
def __init__(self, min_factor, max_factor, prob=0.5): | ||
""" :param min_factor: The value between 0.0 and max_factor | ||
that define the minimum adjustment of image contrast. | ||
The value 0.0 gives s solid grey image, value 1.0 gives the original image. | ||
:param max_factor: A value should be bigger than min_factor. | ||
that define the maximum adjustment of image contrast. | ||
The value 0.0 gives s solid grey image, value 1.0 gives the original image. | ||
""" | ||
self.prob = prob | ||
self.min_factor = min_factor | ||
self.max_factor = max_factor | ||
|
||
def __call__(self, img): | ||
if random.random() < self.prob: | ||
factor = np.random.uniform(self.min_factor, self.max_factor) | ||
enhance_contrast = ImageEnhance.Contrast(img) | ||
return enhance_contrast.enhance(factor) | ||
return img | ||
|
||
def to_pil_image(pic, mode=None): | ||
"""Convert a tensor or an ndarray to PIL Image. | ||
See :class:`~torchvision.transforms.ToPIlImage` for more details. | ||
Args: | ||
pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. | ||
mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). | ||
.. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes | ||
Returns: | ||
PIL Image: Image converted to PIL Image. | ||
""" | ||
# if not(_is_numpy_image(pic) or _is_tensor_image(pic)): | ||
# raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) | ||
|
||
npimg = pic | ||
if isinstance(pic, torch.FloatTensor): | ||
pic = pic.mul(255).byte() | ||
if torch.is_tensor(pic): | ||
npimg = np.transpose(pic.numpy(), (1, 2, 0)) | ||
|
||
if not isinstance(npimg, np.ndarray): | ||
raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + | ||
'not {}'.format(type(npimg))) | ||
|
||
if npimg.shape[2] == 1: | ||
expected_mode = None | ||
npimg = npimg[:, :, 0] | ||
if npimg.dtype == np.uint8: | ||
expected_mode = 'L' | ||
elif npimg.dtype == np.int16: | ||
expected_mode = 'I;16' | ||
elif npimg.dtype == np.int32: | ||
expected_mode = 'I' | ||
elif npimg.dtype == np.float32: | ||
expected_mode = 'F' | ||
if mode is not None and mode != expected_mode: | ||
raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}" | ||
.format(mode, np.dtype, expected_mode)) | ||
mode = expected_mode | ||
|
||
elif npimg.shape[2] == 4: | ||
permitted_4_channel_modes = ['RGBA', 'CMYK'] | ||
if mode is not None and mode not in permitted_4_channel_modes: | ||
raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) | ||
|
||
if mode is None and npimg.dtype == np.uint8: | ||
mode = 'RGBA' | ||
else: | ||
permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] | ||
if mode is not None and mode not in permitted_3_channel_modes: | ||
raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes)) | ||
if mode is None and npimg.dtype == np.uint8: | ||
mode = 'RGB' | ||
|
||
if mode is None: | ||
raise TypeError('Input type {} is not supported'.format(npimg.dtype)) | ||
|
||
return Image.fromarray(npimg, mode=mode) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | ||
from utils.image_list import to_image_list | ||
|
||
class BatchCollator(object): | ||
""" | ||
From a list of samples from the dataset, | ||
returns the batched images and targets. | ||
This should be passed to the DataLoader | ||
""" | ||
|
||
def __init__(self, size_divisible=0, df_used=False, boundary=False): | ||
self.size_divisible = size_divisible | ||
self.df_used = df_used | ||
self.boundary = boundary | ||
|
||
def __call__(self, batch): | ||
transposed_batch = list(zip(*batch)) | ||
images = to_image_list(transposed_batch[0], self.size_divisible) | ||
targets = to_image_list(transposed_batch[1], self.size_divisible) | ||
|
||
dfs = to_image_list(transposed_batch[2], self.size_divisible) if self.df_used else None | ||
|
||
dist_maps = to_image_list(transposed_batch[3], self.size_divisible) if self.boundary else None | ||
|
||
return images, targets, dfs, dist_maps | ||
|
Oops, something went wrong.