Skip to content
Branch: master
Find file Copy path
Find file Copy path
621 lines (533 sloc) 26.5 KB
"`Image` provides support to convert, transform and show images"
from ..torch_core import *
from ..basic_data import *
from ..layers import MSELossFlat
from io import BytesIO
from dataclasses import dataclass, field
import PIL
# Public API exported by `from ... import *`.
# NOTE(review): 'TfmList' is listed but not defined in this view — presumably re-exported from an import.
__all__ = ['PIL', 'Image', 'ImageBBox', 'ImageSegment', 'ImagePoints', 'FlowField', 'RandTransform', 'TfmAffine', 'TfmCoord',
'TfmCrop', 'TfmLighting', 'TfmPixel', 'Transform', 'bb2hw', 'image2np', 'open_image', 'open_mask', 'tis2hw',
'pil2tensor', 'scale_flow', 'show_image', 'CoordFunc', 'TfmList', 'open_mask_rle', 'rle_encode',
'rle_decode', 'ResizeMethod', 'plot_flat', 'plot_multi', 'show_multi', 'show_all']
# How `Image.apply_tfms` reaches a requested `size`.
# The functional IntEnum API numbers members from 1, so CROP=1, PAD=2, SQUISH=3, NO=4.
# `apply_tfms` relies on this ordering through `resize_method <= 2`, meaning "CROP or PAD".
ResizeMethod = IntEnum('ResizeMethod', 'CROP PAD SQUISH NO')
def pil2tensor(image:Union[NPImage,NPArray],dtype:np.dtype)->TensorImage:
    "Convert a PIL-style array `image` into a channel-first torch tensor of `dtype`."
    arr = np.asarray(image)
    # A 2-D (single channel) array gets a trailing channel axis of size 1.
    if arr.ndim == 2: arr = arr[:, :, None]
    # Move the last (channel) axis to the front: (a,b,c) -> (c,a,b).
    # This yields the same strided view as transposing with (1,0,2) and then (2,1,0).
    arr = arr.transpose(2, 0, 1)
    return torch.from_numpy(arr.astype(dtype, copy=False))
def image2np(image:Tensor)->np.ndarray:
    "Turn a channel-first torch `image` into a channel-last numpy array, dropping a lone channel axis."
    arr = image.cpu().permute(1, 2, 0).numpy()
    # Single-channel images come back 2-D, which is what matplotlib expects for grayscale.
    if arr.shape[2] == 1:
        return arr[..., 0]
    return arr
def bb2hw(a:Collection[int])->np.ndarray:
    """Convert box corners `(y_min, x_min, y_max, x_max)` to `(x_min, y_min, width, height)`.

    The output layout is what `_draw_rect` feeds to `patches.Rectangle(xy, width, height)`.
    The previous docstring ("(width,height,center) to (height,width,top,left)") did not
    match the arithmetic below.
    """
    y_min, x_min, y_max, x_max = a[0], a[1], a[2], a[3]
    return np.array([x_min, y_min, x_max-x_min, y_max-y_min])
def tis2hw(size:Union[int,TensorImageSize]) -> Tuple[int,int]:
    "Normalize `size` (an `int`, or a size whose last two entries are height and width) to a (height,width) pair."
    if type(size) is str: raise RuntimeError("Expected size to be an int or a tuple, got a string.")
    # An int describes a square; otherwise only the trailing two entries matter.
    hw = size if isinstance(size, int) else size[-2:]
    return listify(hw, 2)
def _draw_outline(o:Patch, lw:int):
    "Outline bounding box onto image `Patch`."
    # Draw a black stroke of width `lw` first, then the artist normally on top of it.
    # This keeps the (white by default) boxes/labels readable on light backgrounds.
    # The original line lost its `o.set_path_effects([patheffects.Stroke(` prefix.
    o.set_path_effects([patheffects.Stroke(
        linewidth=lw, foreground='black'), patheffects.Normal()])
def _draw_rect(ax:plt.Axes, b:Collection[int], color:str='white', text=None, text_size=14):
    "Draw bounding box `b` = (x, y, width, height) on `ax`, optionally labelled with `text`."
    patch = ax.add_patch(patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2))
    _draw_outline(patch, 4)
    if text is not None:
        patch = ax.text(*b[:2], text, verticalalignment='top', color=color, fontsize=text_size, weight='bold')
        # Outline the label like the box.
        # Previously the text artist was stored in `patch` and never used.
        _draw_outline(patch, 1)
def _get_default_args(func:Callable):
return {k: v.default
for k, v in inspect.signature(func).parameters.items()
if v.default is not inspect.Parameter.empty}
@dataclass
class FlowField():
    "Wrap together some coords `flow` with a `size`."
    # The class needs fields: every caller builds it as `FlowField(size, flow)`.
    # (height, width) of the image the coordinates refer to.
    size:Tuple[int,int]
    # Coordinates tensor.
    # In [-1,1] for grids from `_affine_grid` or points passed through `scale_flow(..., to_unit=True)`.
    flow:Tensor

CoordFunc = Callable[[FlowField, ArgStar, KWArgs], LogitTensorImage]
class Image(ItemBase):
    "Support applying transforms to image data in `px`."
    def __init__(self, px:Tensor):
        self._px = px
        # Pending-transform state read by `refresh`, `flow`, `affine_mat` and `logit_px`.
        # It must exist up front, otherwise the first access raises AttributeError.
        self._logit_px = None
        self._flow = None
        self._affine_mat = None
        self.sample_kwargs = {}

    def set_sample(self, **kwargs)->'ImageBase':
        "Set parameters that control how we `grid_sample` the image after transforms are applied."
        self.sample_kwargs = kwargs
        return self

    def clone(self):
        "Mimic the behavior of torch.clone for `Image` objects."
        return self.__class__(self.px.clone())

    @property
    def shape(self)->Tuple[int,int,int]: return self._px.shape
    @property
    def size(self)->Tuple[int,int]: return self.shape[-2:]
    @property
    def device(self)->torch.device: return self._px.device

    def __repr__(self): return f'{self.__class__.__name__} {tuple(self.shape)}'
    def _repr_png_(self): return self._repr_image_format('png')
    def _repr_jpeg_(self): return self._repr_image_format('jpeg')

    def _repr_image_format(self, format_str):
        "Encode the pixels as `format_str` bytes (used by the notebook `_repr_*_` hooks)."
        with BytesIO() as str_buffer:
            plt.imsave(str_buffer, image2np(self.px), format=format_str)
            return str_buffer.getvalue()

    def apply_tfms(self, tfms:TfmList, do_resolve:bool=True, xtra:Optional[Dict[Callable,dict]]=None,
                   size:Optional[Union[int,TensorImageSize]]=None, resize_method:ResizeMethod=None,
                   mult:int=None, padding_mode:str='reflection', mode:str='bilinear', remove_out:bool=True)->TensorImage:
        "Apply all `tfms` to the `Image`, if `do_resolve` picks value for random args."
        if not (tfms or xtra or size): return self
        tfms = listify(tfms)
        xtra = ifnone(xtra, {})
        default_rsz = ResizeMethod.SQUISH if (size is not None and is_listy(size)) else ResizeMethod.CROP
        resize_method = ifnone(resize_method, default_rsz)
        # CROP/PAD (values 1 and 2) need a crop/pad transform to reach the exact target size.
        # NOTE(review): `_maybe_add_crop_pad` is not defined in this view — confirm it is provided elsewhere.
        if resize_method <= 2 and size is not None: tfms = self._maybe_add_crop_pad(tfms)
        tfms = sorted(tfms, key=lambda o: o.tfm.order)
        if do_resolve: _resolve_tfms(tfms)
        x = self.clone()
        x.set_sample(padding_mode=padding_mode, mode=mode, remove_out=remove_out)
        if size is not None:
            crop_target = _get_crop_target(size, mult=mult)
            if resize_method in (ResizeMethod.CROP,ResizeMethod.PAD):
                # Scale so one side matches the target. The other side overshoots (CROP) or
                # undershoots (PAD), and the crop/pad tfm below fixes it.
                # `target` used to be computed and then ignored.
                target = _get_resize_target(x, crop_target, do_crop=(resize_method==ResizeMethod.CROP))
                x.resize(target)
            elif resize_method==ResizeMethod.SQUISH: x.resize((x.shape[0],) + crop_target)
        else: size = x.size
        size_tfms = [o for o in tfms if isinstance(o.tfm,TfmCrop)]
        for tfm in tfms:
            if tfm.tfm in xtra: x = tfm(x, **xtra[tfm.tfm])
            elif tfm in size_tfms:
                if resize_method in (ResizeMethod.CROP,ResizeMethod.PAD):
                    x = tfm(x, size=_get_crop_target(size,mult=mult), padding_mode=padding_mode)
            else: x = tfm(x)
        return x.refresh()

    def refresh(self)->'Image':
        "Apply any logit, flow, or affine transforms that have been sent to the `Image`."
        if self._logit_px is not None:
            self._px = self._logit_px.sigmoid_()
            self._logit_px = None
        if self._affine_mat is not None or self._flow is not None:
            self._px = _grid_sample(self._px, self.flow, **self.sample_kwargs)
            self.sample_kwargs = {}
            self._flow = None
        return self

    def save(self, fn:PathOrStr):
        "Save the image to `fn`."
        import PIL.Image  # `import PIL` alone does not guarantee the `PIL.Image` submodule is loaded
        # The previous body lost `` and never actually wrote the file.
        x = image2np(self.data*255).astype(np.uint8)
        PIL.Image.fromarray(x).save(fn)

    @property
    def px(self)->TensorImage:
        "Get the tensor pixel buffer."
        return self._px
    @px.setter
    def px(self,v:TensorImage)->None:
        "Set the pixel buffer to `v`."
        self._px=v

    @property
    def flow(self)->FlowField:
        "Access the flow-field grid after applying queued affine transforms."
        if self._flow is None:
            self._flow = _affine_grid(self.shape)
        if self._affine_mat is not None:
            self._flow = _affine_mult(self._flow,self._affine_mat)
            self._affine_mat = None
        return self._flow
    @flow.setter
    def flow(self,v:FlowField): self._flow=v

    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any):
        "Equivalent to `image = sigmoid(func(logit(image)))`."
        self.logit_px = func(self.logit_px, *args, **kwargs)
        return self

    def pixel(self, func:PixelFunc, *args, **kwargs)->'Image':
        "Equivalent to `image.px = func(image.px)`."
        self.px = func(self.px, *args, **kwargs)
        return self

    def coord(self, func:CoordFunc, *args, **kwargs)->'Image':
        "Equivalent to `image.flow = func(image.flow, image.size)`."
        self.flow = func(self.flow, *args, **kwargs)
        return self

    def affine(self, func:AffineFunc, *args, **kwargs)->'Image':
        "Equivalent to `image.affine_mat = image.affine_mat @ func()`."
        m = tensor(func(*args, **kwargs)).to(self.device)
        self.affine_mat = self.affine_mat @ m
        return self

    def resize(self, size:Union[int,TensorImageSize])->'Image':
        "Resize the image to `size`, size can be a single int."
        assert self._flow is None
        if isinstance(size, int): size=(self.shape[0], size, size)
        if tuple(size)==tuple(self.shape): return self
        # Queue an identity grid of the new size; `refresh` resamples the pixels into it.
        self.flow = _affine_grid(size)
        return self

    @property
    def affine_mat(self)->AffineMatrix:
        "Get the affine matrix that will be applied by `refresh`."
        if self._affine_mat is None:
            self._affine_mat = torch.eye(3).to(self.device)
        return self._affine_mat
    @affine_mat.setter
    def affine_mat(self,v)->None: self._affine_mat=v

    @property
    def logit_px(self)->LogitTensorImage:
        "Get logit(image.px)."
        if self._logit_px is None: self._logit_px = logit_(self.px)
        return self._logit_px
    @logit_px.setter
    def logit_px(self,v:LogitTensorImage)->None: self._logit_px=v

    @property
    def data(self)->TensorImage:
        "Return this images pixels as a tensor."
        return self.px

    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
              cmap:str=None, y:Any=None, **kwargs):
        "Show image on `ax` with `title`, using `cmap` if single-channel, overlaid with optional `y`"
        cmap = ifnone(cmap, defaults.cmap)
        ax = show_image(self, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize)
        # Draw the target on the same axes; extra kwargs go to its own `show`.
        if y is not None: y.show(ax=ax, **kwargs)
        if title is not None: ax.set_title(title)
class ImageSegment(Image):
    "Support applying transforms to segmentation masks data in `px`."
    # Lighting transforms are skipped for masks.
    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any)->'Image': return self

    def refresh(self):
        # Mask values are labels, so resample with nearest-neighbour rather than
        # interpolating in-between values.
        self.sample_kwargs['mode'] = 'nearest'
        return super().refresh()

    # Must be a property, matching how `data` is read as an attribute.
    # A plain method here would shadow the base-class property.
    @property
    def data(self)->TensorImage:
        "Return this image pixels as a `LongTensor`."
        return self.px.long()

    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
        cmap:str='tab20', alpha:float=0.5, **kwargs):
        "Show the `ImageSegment` on `ax`."
        ax = show_image(self, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize,
                        interpolation='nearest', alpha=alpha, vmin=0)
        if title: ax.set_title(title)

    def reconstruct(self, t:Tensor): return ImageSegment(t)
class ImagePoints(Image):
    "Support applying transforms to a `flow` of points."
    def __init__(self, flow:FlowField, scale:bool=True, y_first:bool=True):
        # Optionally rescale pixel coords to [-1,1].
        # Y-first input is flipped so points are stored x-first.
        if scale: flow = scale_flow(flow)
        if y_first: flow.flow = flow.flow.flip(1)
        self._flow = flow
        self._affine_mat = None
        # Coord transforms are queued here and applied lazily by the `flow` property.
        self.flow_func = []
        self.sample_kwargs = {}
        self.transformed = False
        self.loss_func = MSELossFlat()

    def clone(self):
        "Mimic the behavior of torch.clone for `ImagePoints` objects."
        return self.__class__(FlowField(self.size, self.flow.flow.clone()), scale=False, y_first=False)

    @property
    def shape(self)->Tuple[int,int,int]: return (1, *self._flow.size)

    @property
    def size(self)->Tuple[int,int]: return self._flow.size
    @size.setter
    def size(self, sz:int): self._flow.size=sz

    @property
    def device(self)->torch.device: return self._flow.flow.device

    def __repr__(self): return f'{self.__class__.__name__} {tuple(self.size)}'
    def _repr_image_format(self, format_str): return None

    @property
    def flow(self)->FlowField:
        "Access the flow-field grid after applying queued affine and coord transforms."
        if self._affine_mat is not None:
            # Points use the inverse of the matrix that `Image` applies to its sampling grid.
            self._flow = _affine_inv_mult(self._flow, self._affine_mat)
            self._affine_mat = None
            self.transformed = True
        if self.flow_func:
            # Apply queued coord funcs in reverse order of registration.
            for f in self.flow_func[::-1]: self._flow = f(self._flow)
            self.transformed = True
            self.flow_func = []
        return self._flow
    @flow.setter
    def flow(self,v:FlowField): self._flow=v

    def coord(self, func:CoordFunc, *args, **kwargs)->'ImagePoints':
        "Put `func` with `args` and `kwargs` in `self.flow_func` for later."
        # Coord funcs that accept `invert` are run inverted for points.
        if 'invert' in kwargs: kwargs['invert'] = True
        else: warn(f"{func.__name__} isn't implemented for {self.__class__}.")
        self.flow_func.append(partial(func, *args, **kwargs))
        return self

    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any)->'ImagePoints': return self

    def pixel(self, func:PixelFunc, *args, **kwargs)->'ImagePoints':
        "Equivalent to `self = func_flow(self)`."
        self = func(self, *args, **kwargs)
        return self

    def refresh(self) -> 'ImagePoints':
        return self

    def resize(self, size:Union[int,TensorImageSize]) -> 'ImagePoints':
        "Resize the image to `size`, size can be a single int."
        if isinstance(size, int): size=(1, size, size)
        self._flow.size = size[1:]
        return self

    @property
    def data(self)->Tensor:
        "Return the points associated to this object."
        flow = self.flow #This updates flow before we test if some transforms happened
        # Drop points that left the [-1,1] frame, unless the caller set `remove_out=False`.
        if self.transformed and self.sample_kwargs.get('remove_out', True):
            flow = _remove_points_out(flow)
        return flow.flow.flip(1)

    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True, **kwargs):
        "Show the `ImagePoints` on `ax`."
        if ax is None: _,ax = plt.subplots(figsize=figsize)
        # `data` is y-first in [-1,1]. Scale back to pixels, then flip to (x,y) for scatter.
        pnt = scale_flow(FlowField(self.size, self.data), to_unit=False).flow.flip(1)
        params = {'s': 10, 'marker': '.', 'c': 'r', **kwargs}
        ax.scatter(pnt[:, 0], pnt[:, 1], **params)
        if hide_axis: ax.axis('off')
        if title: ax.set_title(title)
class ImageBBox(ImagePoints):
"Support applying transforms to a `flow` of bounding boxes."
def __init__(self, flow:FlowField, scale:bool=True, y_first:bool=True, labels:Collection=None,
classes:dict=None, pad_idx:int=0):
super().__init__(flow, scale, y_first)
self.pad_idx = pad_idx
if labels is not None and len(labels)>0 and not isinstance(labels[0],Category):
labels = array([Category(l,classes[l]) for l in labels])
self.labels = labels
def clone(self) -> 'ImageBBox':
"Mimic the behavior of torch.clone for `Image` objects."
flow = FlowField(self.size, self.flow.flow.clone())
return self.__class__(flow, scale=False, y_first=False, labels=self.labels, pad_idx=self.pad_idx)
def create(cls, h:int, w:int, bboxes:Collection[Collection[int]], labels:Collection=None, classes:dict=None,
pad_idx:int=0, scale:bool=True)->'ImageBBox':
"Create an ImageBBox object from `bboxes`."
if isinstance(bboxes, np.ndarray) and bboxes.dtype == np.object: bboxes = np.array([bb for bb in bboxes])
bboxes = tensor(bboxes).float()
tr_corners =[bboxes[:,0][:,None], bboxes[:,3][:,None]], 1)
bl_corners = bboxes[:,1:3].flip(1)
bboxes =[bboxes[:,:2], tr_corners, bl_corners, bboxes[:,2:]], 1)
flow = FlowField((h,w), bboxes.view(-1,2))
return cls(flow, labels=labels, classes=classes, pad_idx=pad_idx, y_first=True, scale=scale)
def _compute_boxes(self) -> Tuple[LongTensor, LongTensor]:
bboxes = self.flow.flow.flip(1).view(-1, 4, 2).contiguous().clamp(min=-1, max=1)
mins, maxes = bboxes.min(dim=1)[0], bboxes.max(dim=1)[0]
bboxes =[mins, maxes], 1)
mask = (bboxes[:,2]-bboxes[:,0] > 0) * (bboxes[:,3]-bboxes[:,1] > 0)
if len(mask) == 0: return tensor([self.pad_idx] * 4), tensor([self.pad_idx])
res = bboxes[mask]
if self.labels is None: return res,None
return res, self.labels[to_np(mask).astype(bool)]
def data(self)->Union[FloatTensor, Tuple[FloatTensor,LongTensor]]:
bboxes,lbls = self._compute_boxes()
lbls = np.array([ for o in lbls]) if lbls is not None else None
return bboxes if lbls is None else (bboxes, lbls)
def show(self, y:Image=None, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
color:str='white', **kwargs):
"Show the `ImageBBox` on `ax`."
if ax is None: _,ax = plt.subplots(figsize=figsize)
bboxes, lbls = self._compute_boxes()
h,w = self.flow.size
bboxes.add_(1).mul_(torch.tensor([h/2, w/2, h/2, w/2])).long()
for i, bbox in enumerate(bboxes):
if lbls is not None: text = str(lbls[i])
else: text=None
_draw_rect(ax, bb2hw(bbox), text=text, color=color)
def open_image(fn:PathOrStr, div:bool=True, convert_mode:str='RGB', cls:type=Image,
"Return `Image` object created from image in file `fn`."
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning) # EXIF warning from TiffPlugin
x =
if after_open: x = after_open(x)
x = pil2tensor(x,np.float32)
if div: x.div_(255)
return cls(x)
def open_mask(fn:PathOrStr, div=False, convert_mode='L', after_open:Callable=None)->ImageSegment:
    "Load the mask in file `fn` as an `ImageSegment` (PIL mode `convert_mode`); pixel values are divided by 255 only if `div`."
    return open_image(fn, cls=ImageSegment, div=div, convert_mode=convert_mode, after_open=after_open)
def open_mask_rle(mask_rle:str, shape:Tuple[int, int])->ImageSegment:
    "Return `ImageSegment` object created from run-length encoded string in `mask_rle` with size in `shape`."
    decoded = rle_decode(str(mask_rle), shape).astype(np.uint8)
    # Reinterpret the row-major buffer as (shape[1], shape[0], 1), then move channels first.
    pixels = FloatTensor(decoded).view(shape[1], shape[0], -1)
    return ImageSegment(pixels.permute(2, 0, 1))
def rle_encode(img:NPArrayMask)->str:
    "Run-length encode binary mask `img` as space-separated '1-based-start length' pairs."
    padded = np.concatenate([[0], img.flatten(), [0]])
    # 1-based positions where the value changes: alternating run starts and run ends.
    changes = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Replace each run end with the run length.
    changes[1::2] -= changes[::2]
    return ' '.join(map(str, changes))
def rle_decode(mask_rle:str, shape:Tuple[int,int])->NPArrayMask:
    "Rebuild a 0/1 mask of `shape` from the '1-based-start length' pairs in `mask_rle`."
    tokens = mask_rle.split()
    starts = np.asarray(tokens[0::2], dtype=int)
    lengths = np.asarray(tokens[1::2], dtype=int)
    starts -= 1  # convert to 0-based offsets
    flat = np.zeros(shape[0]*shape[1], dtype=np.uint)
    for lo, hi in zip(starts, starts + lengths):
        flat[lo:hi] = 1
    return flat.reshape(shape)
def show_image(img:Image, ax:plt.Axes=None, figsize:tuple=(3,3), hide_axis:bool=True, cmap:str='binary',
alpha:float=None, **kwargs)->plt.Axes:
"Display `Image` in notebook."
if ax is None: fig,ax = plt.subplots(figsize=figsize)
ax.imshow(image2np(, cmap=cmap, alpha=alpha, **kwargs)
if hide_axis: ax.axis('off')
return ax
def scale_flow(flow, to_unit=True):
    "Rescale `flow.flow` in place between pixel coords and [-1,1] (direction set by `to_unit`); returns `flow`."
    # Column 0 is scaled by size[0], column 1 by size[1].
    half_size = tensor([flow.size[0]/2, flow.size[1]/2])[None]
    flow.flow = flow.flow/half_size - 1 if to_unit else (flow.flow + 1)*half_size
    return flow
def _remove_points_out(flow:FlowField):
    "Drop, in place, every point of `flow` whose first two coordinates are outside [-1,1]; returns `flow`."
    pts = flow.flow
    xs_ok = (pts[:,0] >= -1) * (pts[:,0] <= 1)
    ys_ok = (pts[:,1] >= -1) * (pts[:,1] <= 1)
    flow.flow = pts[xs_ok * ys_ok]
    return flow
class Transform():
"Utility class for adding probability and wrapping support to transform `func`."
def __init__(self, func:Callable, order:Optional[int]=None):
"Create a transform for `func` and assign it an priority `order`, attach to `Image` class."
if order is not None: self.order=order
self.func.__name__ = func.__name__[1:] #To remove the _ that begins every transform function.
functools.update_wrapper(self, self.func)
self.func.__annotations__['return'] = Image
self.params = copy(func.__annotations__)
self.def_args = _get_default_args(func)
setattr(Image, func.__name__,
lambda x, *args, **kwargs: self.calc(x, *args, **kwargs))
def __call__(self, *args:Any, p:float=1., is_random:bool=True, use_on_y:bool=True, **kwargs:Any)->Image:
"Calc now if `args` passed; else create a transform called prob `p` if `random`."
if args: return self.calc(*args, **kwargs)
else: return RandTransform(self, kwargs=kwargs, is_random=is_random, use_on_y=use_on_y, p=p)
def calc(self, x:Image, *args:Any, **kwargs:Any)->Image:
"Apply to image `x`, wrapping it if necessary."
if self._wrap: return getattr(x, self._wrap)(self.func, *args, **kwargs)
else: return self.func(x, *args, **kwargs)
def name(self)->str: return self.__class__.__name__
def __repr__(self)->str: return f'{} ({self.func.__name__})'
@dataclass
class RandTransform():
    "Wrap `Transform` to add randomized execution."
    # `Transform.__call__` builds this as RandTransform(tfm, kwargs=..., is_random=..., use_on_y=..., p=...).
    tfm:Transform
    kwargs:dict
    p:float = 1.0
    resolved:dict = field(default_factory=dict)
    do_run:bool = True
    is_random:bool = True
    use_on_y:bool = True
    def __post_init__(self): functools.update_wrapper(self, self.tfm)

    def resolve(self)->None:
        "Bind any random variables in the transform."
        if not self.is_random:
            # Deterministic: defaults overridden by explicit kwargs, nothing sampled.
            # The missing `return` let the random branch below overwrite this.
            self.resolved = {**self.tfm.def_args, **self.kwargs}
            return

        self.resolved = {}
        # for each param passed to tfm...
        for k,v in self.kwargs.items():
            # ...if it's annotated, call that fn...
            if k in self.tfm.params:
                rand_func = self.tfm.params[k]
                self.resolved[k] = rand_func(*listify(v))
            # ...otherwise use the value directly
            else: self.resolved[k] = v
        # use defaults for any args not filled in yet
        for k,v in self.tfm.def_args.items():
            if k not in self.resolved: self.resolved[k]=v
        # anything left over must be callable without params
        for k,v in self.tfm.params.items():
            if k not in self.resolved and k!='return': self.resolved[k]=v()

        self.do_run = rand_bool(self.p)

    @property
    def order(self)->int: return self.tfm.order

    def __call__(self, x:Image, *args, **kwargs)->Image:
        "Randomly execute our tfm on `x`."
        return self.tfm(x, *args, **{**self.resolved, **kwargs}) if self.do_run else x
def _resolve_tfms(tfms:TfmList):
    "Call `resolve` on each transform in `tfms` so its random arguments get drawn."
    for tfm in listify(tfms):
        tfm.resolve()
def _grid_sample(x:TensorImage, coords:FlowField, mode:str='bilinear', padding_mode:str='reflection', remove_out:bool=True)->TensorImage:
    "Resample pixels in `coords` from `x` by `mode`, with `padding_mode` in ('reflection','border','zeros')."
    # `remove_out` is unused here. It is accepted so `Image.refresh` can forward all of
    # `sample_kwargs` (which `apply_tfms` fills with `remove_out`) without a TypeError.
    # NOTE(review): `x` is presumably (C,H,W) and `coords.flow` (N,H,W,2), going by the
    # indexing below — confirm.
    coords = coords.flow.permute(0, 3, 1, 2).contiguous().permute(0, 2, 3, 1) # optimize layout for grid_sample
    if mode=='bilinear': # hack to get smoother downwards resampling
        mn,mx = coords.min(),coords.max()
        # max amount we're affine zooming by (>1 means zooming in)
        z = 1/(mx-mn).item()*2
        # amount we're resizing by, with 100% extra margin
        d = min(x.shape[1]/coords.shape[1], x.shape[2]/coords.shape[2])/2
        # If the source is more than 2x the output size and the zoom does not already account
        # for it, area-downsample first; scale_factor=1/d < 1 shrinks `x`.
        if d>1 and d>z: x = F.interpolate(x[None], scale_factor=1/d, mode='area')[0]
    # `x[None]` adds the batch axis grid_sample needs; `[0]` removes it again.
    return F.grid_sample(x[None], coords, mode=mode, padding_mode=padding_mode)[0]
def _affine_grid(size:TensorImageSize)->FlowField:
    "Build the identity sampling grid for `size` (C,H,W): a (1,H,W,2) grid whose last axis holds (x,y) in [-1,1]."
    full_size = ((1,)+size)
    n, _, h, w = full_size
    grid = FloatTensor(n, h, w, 2)
    # Channel 0 varies along the width...
    xs = torch.linspace(-1, 1, w) if w > 1 else tensor([-1])
    grid[:, :, :, 0] = torch.ger(torch.ones(h), xs).expand_as(grid[:, :, :, 0])
    # ...channel 1 along the height.
    ys = torch.linspace(-1, 1, h) if h > 1 else tensor([-1])
    grid[:, :, :, 1] = torch.ger(ys, torch.ones(w)).expand_as(grid[:, :, :, 1])
    return FlowField(full_size[2:], grid)
def _affine_mult(c:FlowField,m:AffineMatrix)->FlowField:
    """Multiply `c` by `m` - can adjust for rectangular shaped `c`.

    `c.flow` is replaced, in place, by `p @ A.T + t` for each 2-D point `p`. Here `A` is `m[:2,:2]`
    and `t` is `m[:2,2]`. `c` is returned. `m` itself is no longer modified.
    """
    if m is None: return c
    size = c.flow.size()
    h,w = c.size
    # Work on a copy so the aspect-ratio correction below is not written back into the
    # caller's matrix (previously `m` was mutated in place).
    m = m.clone()
    # Presumably compensates for x and y both being normalised to [-1,1] on non-square images.
    m[0,1] *= h/w
    m[1,0] *= w/h
    c.flow = c.flow.view(-1,2)
    c.flow = torch.addmm(m[:2,2], c.flow, m[:2,:2].t()).view(size)
    return c
def _affine_inv_mult(c, m):
"Applies the inverse affine transform described in `m` to `c`."
size = c.flow.size()
h,w = c.size
m[0,1] *= h/w
m[1,0] *= w/h
c.flow = c.flow.view(-1,2)
a = torch.inverse(m[:2,:2].t())
c.flow = - m[:2,2], a).view(size)
return c
class TfmAffine(Transform):
    "Decorator for affine tfm funcs."
    # Applied through `Image.affine`; `apply_tfms` sorts transforms by ascending `order`.
    order = 5
    _wrap = 'affine'
class TfmPixel(Transform):
    "Decorator for pixel tfm funcs."
    # Applied through `Image.pixel`; `apply_tfms` sorts transforms by ascending `order`.
    order = 10
    _wrap = 'pixel'
class TfmCoord(Transform):
    "Decorator for coord tfm funcs."
    # Applied through `Image.coord`; `apply_tfms` sorts transforms by ascending `order`.
    order = 4
    _wrap = 'coord'
class TfmCrop(TfmPixel):
    "Decorator for crop tfm funcs."
    # Inherits `order` and `_wrap='pixel'` from `TfmPixel`. `Image.apply_tfms` recognizes these via
    # `isinstance(o.tfm, TfmCrop)` and, for CROP/PAD resizing, passes them the target `size`.
    # NOTE(review): as written, crops share TfmPixel's order (10). Confirm whether a later order
    # was intended so crops run after the other pixel transforms.
class TfmLighting(Transform):
    "Decorator for lighting tfm funcs."
    # Applied through `Image.lighting`; `apply_tfms` sorts transforms by ascending `order`.
    order = 8
    _wrap = 'lighting'
def _round_multiple(x:int, mult:int=None)->int:
"Calc `x` to nearest multiple of `mult`."
return (int(x/mult+0.5)*mult) if mult is not None else x
def _get_crop_target(target_px:Union[int,TensorImageSize], mult:int=None)->Tuple[int,int]:
    "Return the (rows, cols) crop size for `target_px`, each rounded to the nearest multiple of `mult`."
    rows, cols = tis2hw(target_px)
    return tuple(_round_multiple(v, mult) for v in (rows, cols))
def _get_resize_target(img, crop_target, do_crop=False)->TensorImageSize:
    "Return the (channels, rows, cols) size that scales `img` so one side matches `crop_target` (covering it if `do_crop`, fitting inside otherwise)."
    if crop_target is None: return None
    ch, rows, cols = img.shape
    tgt_rows, tgt_cols = crop_target
    # Dividing by the smaller ratio makes both sides >= target (crop); the larger makes both <= target (pad).
    pick = min if do_crop else max
    ratio = pick(rows/tgt_rows, cols/tgt_cols)
    # round() may return numpy scalars here, hence the explicit int().
    return ch, int(round(rows/ratio)), int(round(cols/ratio))
def plot_flat(r, c, figsize):
    "Shortcut for `enumerate(subplots.flatten())`: yields `(index, ax)` over an `r` x `c` grid."
    # squeeze=False always returns a 2-D array of Axes. Without it, a 1x1 grid returns a bare
    # Axes with no `.flatten()`. Flattening order is unchanged for every other grid.
    return enumerate(plt.subplots(r, c, figsize=figsize, squeeze=False)[1].flatten())
def plot_multi(func:Callable[[int,int,plt.Axes],None], r:int=1, c:int=1, figsize:Tuple=(12,6)):
    "Call `func` for every combination of `r,c` on a subplot"
    # squeeze=False keeps `axes` 2-D even when r or c is 1 (including the defaults), so
    # `axes[i,j]` always works. Before, it failed on a bare Axes or a 1-D array.
    axes = plt.subplots(r, c, figsize=figsize, squeeze=False)[1]
    for i in range(r):
        for j in range(c): func(i,j,axes[i,j])
def show_multi(func:Callable[[int,int],Image], r:int=1, c:int=1, figsize:Tuple=(9,9)):
    "Build an `r` x `c` grid and show `func(i,j)` in cell `(i,j)`."
    def _show_cell(i, j, ax):
        func(i, j).show(ax)
    plot_multi(_show_cell, r, c, figsize=figsize)
def show_all(imgs:Collection[Image], r:int=1, c:Optional[int]=None, figsize=(12,6)):
    "Show `imgs` on a grid of `r` rows; `c` columns default to `len(imgs)//r`."
    imgs = listify(imgs)
    n_cols = len(imgs)//r if c is None else c
    for idx, ax in plot_flat(r, n_cols, figsize):
        imgs[idx].show(ax)
You can’t perform that action at this time.