In [87]:
%reload_ext autoreload
%autoreload 2

# Block to allow debug in Pycharm: put the local fastai2 checkout first on sys.path
# so it takes precedence over any installed fastai package.
import sys
FASTAI_DEV_PATH = 'D:\\Projects\\fastai2'
# insert(0, ...) instead of append() + reverse(): reversing the whole list also flips the
# precedence of every other entry (e.g. site-packages vs. stdlib), and every re-run of the
# cell appended another duplicate and flipped the order again. This version is idempotent.
sys.path[:] = [p for p in sys.path if p != FASTAI_DEV_PATH]
sys.path.insert(0, FASTAI_DEV_PATH)

In [88]:
from fastai.basics import *
from fastai.vision.all import *
from fastai.vision.core import *
from fastai.vision.data import *
from fastai.data.all import *

from pathlib import Path
import gdal

import pdb

import numpy as np
import matplotlib.pyplot as plt

ModuleNotFoundError: No module named 'gdal'

In [None]:
path = Path('./Data')

images_path = path/'images'
labels_path = path/'labels'

img_files = sorted(images_path.iterdir())
lbl_files = sorted(labels_path.iterdir())
print(f'Checking number of files - images: {len(img_files)} masks: {len(lbl_files)}')

# Checking file shapes.
# iterdir() order is arbitrary, so indexing the two directory listings independently can
# pair an image with the wrong mask. Pick the image from the sorted listing and derive the
# mask from its file name (same images/ -> labels/ naming that get_lbl_fn relies on).
idx = 22
img_path = img_files[idx]
msk_path = labels_path/img_path.name

img = np.load(str(img_path))
msk = np.load(str(msk_path))

print(f'Checking shapes - image: {img.shape} mask: {msk.shape}')

In [None]:
# Plotting a sample: bands [3, 2, 1] as a 3-channel composite (brightened x3) next to its mask.
_, ax = plt.subplots(1, 2, figsize=(10, 5))
# img is transposed from (C, H, W) to (H, W, C) so imshow reads the last axis as channels.
rgb = img.transpose((1, 2, 0))[..., [3, 2, 1]] * 3.0
# imshow clips float RGB data outside [0, 1] anyway but emits a warning; clip explicitly.
ax[0].imshow(np.clip(rgb, 0, 1))
ax[0].set_title('image (bands 3, 2, 1)')
ax[1].imshow(msk)
ax[1].set_title('mask')

In [None]:
def open_npy(fn, chnls=None, cls=torch.Tensor):
    """Load a ``.npy`` file as a float32 tensor and wrap it in ``cls``.

    Args:
        fn: path (``str`` or ``Path``) of the ``.npy`` file.
        chnls: optional index (int, list, slice) applied to the first axis to
            keep only some channels; ``None`` keeps everything.
        cls: callable used to wrap the resulting tensor (e.g. a tensor subclass).
    """
    arr = np.load(str(fn))
    data = torch.from_numpy(arr).to(torch.float32)
    selected = data if chnls is None else data[chnls]
    return cls(selected)

class MSTensorImage(TensorImage):
    """Multispectral image tensor whose ``show`` can display arbitrary band combinations.

    ``chnls_first`` records whether the channel axis is the first axis
    (``(C, H, W)``) or the last one (``(H, W, C)``); ``show`` uses it to move
    channels last before plotting.
    """

    def __init__(self, x, chnls_first=False):
        # x is not used here; presumably the tensor data is consumed by the
        # TensorImage base class when the instance is created. Only the layout
        # flag is recorded.
        self.chnls_first = chnls_first

    @classmethod
    def create(cls, data:(Path,str,ndarray), chnls=None, chnls_first=True):
        """Build an MSTensorImage from a ``.npy`` path, a numpy array or a tensor.

        Args:
            data: path (``Path``/``str``) to a ``.npy`` file, a numpy array, or
                any other object, which is passed to the constructor unchanged.
            chnls: optional index/list of channels to keep, applied to the first
                axis (same convention as ``open_npy``). ``None`` keeps all.
            chnls_first: stored on the instance; tells ``show`` whether the
                channel axis is first.

        Raises:
            ValueError: if ``data`` is a path that does not end in ``npy``
                (previously this surfaced as an UnboundLocalError).
        """
        if isinstance(data, (Path, str)):
            if not str(data).endswith('npy'):
                raise ValueError(f'MSTensorImage.create: unsupported file type: {data}')
            # open_npy already applies the channel selection.
            im = open_npy(fn=data, chnls=chnls, cls=torch.Tensor)
        else:
            im = torch.from_numpy(data) if isinstance(data, ndarray) else data
            # Apply the channel selection here too instead of silently ignoring it.
            if chnls is not None:
                im = im[chnls]

        return cls(im, chnls_first=chnls_first)

    def show(self, chnls=[3, 2, 1], bright=1., ctx=None):
        """Display the selected channels with matplotlib.

        Args:
            chnls: channel index or list of channel indices to display
                (three indices give a colour composite, one gives a single band).
                Not mutated, so the list default is safe.
            bright: multiplicative brightness factor applied before clipping
                the values to [0, 1].
            ctx: matplotlib Axes to draw on; the current axes are used when None.

        Returns:
            The Axes the image was drawn on.
        """
        # Use self (not a notebook-global variable) to decide whether there is a channel axis.
        if self.ndim > 2:
            # Move channels last so matplotlib reads the result as (H, W[, C]).
            hwc = self.permute(1, 2, 0) if self.chnls_first else self
            visu_img = hwc[..., chnls]
        else:
            visu_img = self

        # Out-of-place ops: visu_img may share storage with self (e.g. the 2-D case),
        # so the previous in-place `*=` could silently modify the image itself.
        visu_img = (visu_img.squeeze() * bright).clamp(0, 1)

        ax = plt.gca() if ctx is None else ctx
        ax.imshow(visu_img.detach().cpu().numpy())
        return ax

    def __repr__(self):
        return f'MSTensorImage: {self.shape}'

In [None]:
img = MSTensorImage.create(img_path)
print(img)

# Three views of the same image: default bands brightened x3, bands 2/7/10, and band 11 alone.
show_kwargs = [dict(bright=3.), dict(chnls=[2, 7, 10]), dict(chnls=[11])]
_, ax = plt.subplots(1, 3, figsize=(12, 4))
for axis, kwargs in zip(ax, show_kwargs):
    img.show(ctx=axis, **kwargs)

In [None]:
# Load the mask directly as a TensorMask (same wrapping the DataBlock below uses).
mask = open_npy(msk_path, cls=TensorMask)
print(mask.shape)

# Image (brightened) and its mask side by side.
_, ax = plt.subplots(1, 2, figsize=(10, 5))
img.show(bright=3., ctx=ax[0])
mask.show(ctx=ax[1])

In [None]:
def get_lbl_fn(img_fn: Path):
    """Return the label file matching an image file.

    The label is expected to have the same file name as the image, in a
    ``labels`` directory that sits next to the image's parent directory
    (``<root>/images/x.npy`` -> ``<root>/labels/x.npy``).
    """
    return img_fn.parent.parent.joinpath('labels', img_fn.name)

# Inputs: multispectral images created from the .npy path.
# Targets: the matching label file (via get_lbl_fn), loaded as a TensorMask with 3 codes.
db = DataBlock(blocks=(TransformBlock(type_tfms=partial(MSTensorImage.create, chnls_first=True)),
                       TransformBlock(type_tfms=[get_lbl_fn, partial(open_npy, cls=TensorMask)],
                                      item_tfms=AddMaskCodes(codes=['clear', 'water', 'shadow'])),
                      ),
               get_items=partial(get_files, extensions='.npy'),
               # Fixed seed so the train/valid split is reproducible across runs.
               splitter=RandomSplitter(valid_pct=0.1, seed=42)
              )

# Items are collected from the images folder only; labels are reached through get_lbl_fn.
db.summary(source=images_path)
               