In [1]:
from fastai.vision import *

In [None]:
class SegmentationProcessor(PreProcessor):
    "`PreProcessor` that remembers the classes of the dataset it was built from and re-applies them."

    def __init__(self, ds:ItemList):
        # Snapshot the class list of the source dataset at construction time.
        self.classes = ds.classes

    def process(self, ds:ItemList):
        "Write the stored classes onto `ds` and set `ds.c` to how many there are."
        stored = self.classes
        # Compute the count before touching `ds` so a failure leaves `ds` unmodified.
        n_classes = len(stored)
        ds.classes = stored
        ds.c = n_classes

class SegLabelList(ImageList):
    "`ItemList` for segmentation masks."
    _processor=SegmentationProcessor
    def __init__(self, items:Iterator, classes:Collection=None, ai = None, aj = None, **kwargs):
        """Build the label list.

        `classes` is the collection of segmentation classes. `ai` and `aj` are
        stored on the instance unchanged; nothing in this class reads them, so
        their meaning is defined by the caller. Remaining `kwargs` go to
        `ImageList.__init__`.
        """
        super().__init__(items, **kwargs)
        # fastai's `ItemList.new` re-creates the list by passing every attribute
        # named in `copy_new` back to the constructor. `ai`/`aj` must be listed
        # alongside `classes`, otherwise they fall back to `None` on each copy.
        self.copy_new.extend(['classes', 'ai', 'aj'])
        self.classes,self.loss_func = classes,CrossEntropyFlat(axis=1)
        self.ai = ai
        self.aj = aj

    def open(self, fn):
        "Read `fn` with `imread`, pass it through `convert_to_float`, add a leading axis and wrap it in an `ImageSegment`."
        m = imread(fn)
        m = convert_to_float(m)
        # Insert a new axis at position 0 in front of the array loaded above.
        m = np.expand_dims(m, 0)
        return ImageSegment(tensor(m))

    def analyze_pred(self, pred, thresh:float=0.5):
        "Return the argmax of `pred` over dim 0 with a leading axis re-added; `thresh` is accepted but not used."
        return pred.argmax(dim=0)[None]

    def reconstruct(self, t:Tensor):
        "Wrap the tensor `t` in an `ImageSegment`."
        return ImageSegment(t)

class SegItemList(ImageList):
    "`ItemList` whose items are a file and its `get_t_fn` companion, stacked into a single `Image`."
    _label_cls = SegLabelList
    _square_show_res = False

    def open(self, fn):
        "Load `fn` and the file named by `get_t_fn(fn)`, then stack the two along a new first axis."
        def _load(path):
            # Read one file and run it through the float conversion helper.
            return convert_to_float(imread(path))

        primary = _load(fn)
        companion = _load(get_t_fn(fn))
        # `fn` goes at index 0 and its companion at index 1 of the new axis.
        stacked = torch.stack([tensor(primary), tensor(companion)], dim=0)
        return Image(tensor(stacked))


In [1]:
from skimage.transform import rescale, resize, downscale_local_mean

In [None]:
class SegItemListSmall(ImageList):
    "Like a two-image segmentation `ItemList`, but each array is downscaled with `downscale_local_mean` before stacking."
    _label_cls = SegLabelList
    _square_show_res = False

    def open(self, fn):
        "Load `fn` and the file named by `get_t_fn(fn)`, downscale both by factors (2, 2), and stack them along a new first axis."
        def _load_downscaled(path):
            # Read, convert to float, then apply local-mean downscaling with factors (2, 2).
            full_res = convert_to_float(imread(path))
            return downscale_local_mean(full_res, (2,2))

        primary = _load_downscaled(fn)
        companion = _load_downscaled(get_t_fn(fn))
        # `fn` goes at index 0 and its companion at index 1 of the new axis.
        stacked = torch.stack([tensor(primary), tensor(companion)], dim=0)
        return Image(tensor(stacked))