In [0]:
try:
    from google.colab import drive
    drive.mount('/content/drive')
    #!git clone https://github.com/mjwock/DeepFLaSH_Pytorch.git /content/drive/My\ Drive/DeepFLaSH_Pytorch/FastAI/
    %cd /content/drive/My\ Drive/DeepFLaSH_Pytorch/FastAI2
    #!git pull
except:
    pass

In [0]:
%load_ext autoreload
%autoreload 2
%matplotlib inline 

!pip install elasticdeform

import os
import pandas as pd
import matplotlib.pyplot as plt

from pathlib import Path

from torch import nn
from torchsummary import summary
from sklearn.model_selection import StratifiedKFold, KFold

from datetime import datetime

#imports for Tile
from tqdm import tqdm
from math import ceil

from deepflash import preproc, unetadaption, utility
from deepflash.fastai_extension import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [0]:
MODEL_DIR = '/content/drive/My Drive/DeepFLaSH_Pytorch/FastAI2/data/model'

TEST_IMAGES_DIR = '/content/drive/My Drive/DeepFLaSH_Pytorch/FastAI2/data/images/red/'
LABEL_DIR = '/content/drive/My Drive/DeepFLaSH_Pytorch/FastAI2/data/temp_data/labels/'

IMAGE_TYPE = 'L'    # 'L' for greyscale, 'RGB'for color, 'P' for palette images

TILE_SHAPE = (540,540)    # desired input size
MASK_SHAPE = (356,356)       # 540-184

SEED = 42

Model

In [0]:
model_id = '201903-0954'
model_name = 'final_model.pkl'

##Load Model

In [0]:
learn = load_learner(f'{MODEL_DIR}/{model_id}',model_name)

In [0]:
data = (CustomSegmentationItemList.from_folder(TEST_IMAGES_DIR, convert_mode=IMAGE_TYPE))
labels = (CustomSegmentationLabelList.from_folder(LABEL_DIR, convert_mode='L'))

##**TileGenerator**

In [0]:
class TileGenerator:
  '''
  TileGenerator to split input images into multiple tiles to be used by learner. 
  Images can be padded to get a prediction of complete image
  '''
  def __init__(
      self,
      data,
      learner,
      tile_shape=(540,540),
      mask_shape=(356,356),
      same_size = True,
      padding_mode = None   # None, 'zeros', 'border' or 'reflection'
  ):
    
    self.data = [image for image in tqdm(data,desc='Loading Data: ')]

    self.learner = learner
    self.tile_shape = tile_shape
    self.mask_shape = mask_shape
    self.same_size = same_size
    
    if padding_mode:
      padding = tuple(np.subtract(tile_shape,mask_shape)//2)
      self.pad_images(padding, padding_mode)
    
    self.data_shapes = [self.data[0].shape]*len(data) if same_size else [item.shape for item in self.data]

    self.tiles = None
    self.tile_dimensions = None
    self.predictions = None
    self.split_tiles()

  def __len__(self):
      return len(self.data)

  def __getitem__(self, index):
      if self.tiles == None:
        return self.data[index]
      else:
        return self.tiles[index]

  def tile_splitter(self,input_shape:tuple):
    """
    Gets the pixelwise regions of the tiles for cropping the image. Tiles need 
    to overlap, if your model doesn't use padding to compensate the cropping of 
    Convolutions.
    :param input_shape: Input shape of the Image in form of (C,H,W)

    :return tiles: pixelwise tile areas (xs,xf,ys,yf) with S(xs,ys) being left 
    upper corner and F(xf,yf) being the lower right corner of our rectangular tile
    :return tile_dimensions: decimal numbers of how many tiles are needed for given
    input image shape in form of (xtiles,ytiles)
    """
    tx, ty = self.tile_shape
    mx, my = self.mask_shape
    px, py = np.subtract((tx,ty),(mx,my)) # x and y padding

    _, dx, dy = input_shape

    # how many tiles are needed (decimal precision)
    xtiles = (dx-px)/mx 
    ytiles = (dy-py)/my 

    # add starting points for full tiles with spacing mx and my
    x_start = [0+mx*ix for ix in range(int(xtiles))]
    y_start = [0+my*iy for iy in range(int(ytiles))]

    # add a last x, y starting point for none-integer tiles
    if not xtiles%1==0:
      x_start.append(dx-tx)

    if not ytiles%1==0:
      y_start.append(dy-ty)

    # build tiles with width tx and height ty
    tiles = []
    for y in y_start:
      for x in x_start:
        tiles.append((x,x+tx,y,y+ty))

    tile_dimensions = (xtiles,ytiles)
    
    return tiles, tile_dimensions

  def split_tiles(self):
    '''
    Splits the images into tiles and saves them to 'self.tiles' as well as the float 
    precision of the needed amount of tiles into 'self.tile_dimensions' in form of 
    a list of tuples (xtiles,ytiles)
    '''
    ds = self.data_shapes

    # if all input images have the same shape
    if self.same_size:
      
      # call tile_splitter to split all images depending on their shape
      tile_regions, tile_dimension =  self.tile_splitter(ds[0])

      tiles = []

      for img in tqdm(self.data,desc='Building Tiles'):
        img_tiles = []
        for region in tile_regions:
          img_tiles.append(self.crop_to_tile(img,region))
        tiles.append(img_tiles)

      self.tiles = tiles
      self.tile_dimensions = [tile_dimension]*len(self.data)
    
    #if input images have different shapes
    else:
      
      tiles = []
      tile_dimensions =[]

      for i, img in enumerate(self.data,desc='Building Tiles'):
        
        # call tile_splitter to split image depending on image shape
        tile_regions, tile_dimension =  self.tile_splitter(ds[i])
        img_tiles = []

        for region in tile_regions:
          img_tiles.append(self.crop_to_tile(img, region))

        tiles.append(img_tiles)
        tile_dimensions.append(tile_dimension)

      self.tiles = tiles
      self.tile_dimensions = tile_dimensions

  def crop_to_tile(self, img:Image, region):
    '''
    Crops input image to rectangular region from corner (xs,ys) to corner (xf,yf)
      :param img: The Image to be cropped
      :param region: tuple defining pixels for cropped region in form (xs,xf,ys,yf),
      s denoting the starting pixels and f denoting the final pixels

      :return: Cropped Image
    '''
    xs,xf,ys,yf = region
    return Image(img.data[:,ys:yf,xs:xf])
  
  def pad_images(self,padding,padding_mode): 
    '''
    Pads all images in self.data 
      :param padding: magnitude of padding
    '''
    assert padding[0] == padding[1], 'For padding: tile_shape and mask_shape need to be squares.'
    self.data = [image.pad(padding[0],padding_mode) for image in tqdm(self.data,desc='Padding images')]

  def display_tiles(self,batch:Iterator,shape:tuple=(3,3),figsize:tuple=(9,9)):
    '''
    Displays given tiles in 'batch' in subplots arranged by 'shape' within figure
    with figsize of 'figsize'
    :param batch: list of Images
    :param shape: rows and colums of subplots (tuple)
    :param figsize: figsize of plt figure
    '''
    f, axarr = plt.subplots(shape[0], shape[1], figsize=figsize)
    f.tight_layout(pad = 0)

    for ax,tile in zip(axarr.flatten(), batch):
      plt.subplots_adjust(wspace=0, hspace=0)
      ax.set_xticklabels([])
      ax.set_yticklabels([])
      tile.show(ax)

  def stitch_helper(self,dimensions,mask_shape,reshape = True):
    '''
    creates a lookup matrix on how the tiles need to be cropped for stitching.
      :args dimensions: 'tile_dimension' on how many tiles were needed for this 
                        image with float precision (N,M)
      :args mask_shape: the shape of the prediction output (W,H)
      :args reshap: reshape to an indexable array with length NxM

      :return: lookup matrix (np.array (NxM,R) or (N,M,R)) with R being the region 
               argument for crop_to_tile()
    '''
    xd,yd = dimensions
    xs,ys = mask_shape
    rows  = ceil(xd)*ceil(yd)
    lookup_matrix = np.ones((ceil(xd),ceil(yd),4))*(0,xs,0,ys)
    
    for ix in range(ceil(xd)):
      xd -= 1
      if xd<0:
        lookup_matrix[ix,:,2] = -xd*xs
      else:
        lookup_matrix[ix,:,2] = 0
    
    for iy in range(ceil(yd)):
      yd -= 1
      if yd<0:
        lookup_matrix[:,iy,0] = -yd*ys
      else:
        lookup_matrix[:,iy,0] = 0
    if reshape:
      return np.reshape(lookup_matrix,(rows,4)).astype(int)
    else:
      return lookup_matrix.astype(int)

  def show_tiles(self, index, crop_to = 'mask', base_size=3):
    '''
    Displays tiles on a set of subplots.
      :param index: index of instance to be displayed (int)
      :param crop_to: crop is either None, 'mask' or 'original' {None,'mask','original'}
      :base_size: base size that each tile
    '''
    batch = copy(self.tiles[index])
    dimensions = self.tile_dimensions[index]
    x, y = np.ceil(dimensions).astype(int)
    shape = (x,y)
    figsize = (x*base_size,y*base_size)

    if crop_to == 'mask':
      for tile in batch:
        tile = tile.crop(self.mask_shape)

    elif crop_to == 'original':  
      lookup_matrix = self.stitch_helper(dimensions,self.mask_shape)  

      for region, tile in zip(lookup_matrix,batch):
        tile = tile.crop(self.mask_shape) 
        tile = self.crop_to_tile(tile,region)

    self.display_tiles(batch,shape,figsize)

  def show_predictions(self):
    print('Not implemented')

  def predict_instance(self, x:Image):
    '''
    Calls FastAIs Learner.predict() function and transforms it into a prediction 
    mask.
    :args x: tile to be predicted (Image)

    :return: prediction (Image)
    :return: probabilities (Tensor)
    '''
    raw = self.learner.predict(x)
    probabilities, values = raw[1].max(0)

    return Image(values.unsqueeze(0)), probabilities.unsqueeze(0)

  def predict_all(self):
    '''
    Predicts all tiles from all images by calling predict_instance() and saves 
    them together with the pixelwise probabilities to self.predictions.
    '''
    solutions = []
    probabilities = []
    
    # iterate all images
    for batch in tqdm(self.tiles,desc='Predicting Tiles'):
      batch_solutions = []
      batch_probabilities = []

      # iterate all tiles in image and predict
      for tile in batch:

        tile_prediction, tile_probabilities = self.predict_instance(tile)
        batch_solutions.append(tile_prediction)
        batch_probabilities.append(tile_probabilities)
      solutions.append(batch_solutions)
      probabilities.append(batch_probabilities)

    self.predictions = [solutions,probabilities]
  
  def stitch_image(self,tile_list,lookup_matrix,is_img=True):
    '''
    Stitch image by concatenating and cropping tiles, depending on lookup_matrix
      :args tile_list: list of tiles making out image
      :args lookup_matrix: list

      :return: concatenated Image 
    '''

    image = None
    
    for brow,lmrow in zip(tile_list, lookup_matrix):

      row = None
      for tile, region in zip(brow,lmrow):

        # crops tile region if region of interest is smaller than given input
        if not np.array_equal(region,[0,self.mask_shape[0],0,self.mask_shape[1]]):
          if is_img:
            tile = self.crop_to_tile(tile, region)
          else:
            xs,xf,ys,yf = region
            tile = tile[:,ys:yf,xs:xf]
        
        #concat columns

        if row is None:
          row = tile.data if is_img else tile
        else:
          row = torch.cat((row,tile.data if is_img else tile),2)

      #concat rows
      if image is None:
        image = row
      else:
        image = torch.cat((image,row),1)
    
    return Image(image) if is_img else image

  def stitch_results(self):
    '''
    Stitches the results in self.predictions
    '''
    
    xs,ys = self.mask_shape
    results = []
    for tiles, probabilities, dimensions in zip(self.predictions[0],self.predictions[1],self.tile_dimensions):
      xd,yd = np.ceil(dimensions)
      lookup_matrix = self.stitch_helper(dimensions,(xs,ys),reshape=False) 
      _reshape = lookup_matrix.shape[0:2]

      tiles_array = np.reshape(tiles,_reshape) 
      prob_array = [probabilities[i:i+xd] for i in range(0, len(data_list), 3)]
      
      image = self.stitch_image(tiles_array,lookup_matrix)
      prob_map = self.stitch_image(prob_array,lookup_matrix,is_img=False)

      results.append(image,prob_map)

    return results

In [0]:
thisTile = TileGenerator(data,learn, TILE_SHAPE,MASK_SHAPE,same_size=True,padding_mode='reflection')

Loading Data: 100%|██████████| 36/36 [00:00<00:00, 65.15it/s]
Padding images: 100%|██████████| 36/36 [00:00<00:00, 204.86it/s]
Building Tiles: 100%|██████████| 36/36 [00:00<00:00, 5896.17it/s]


In [0]:
thisTile.predict_all()
predictions = thisTile.predictions
#thisTile.predictions = predictions

Predicting Tiles: 100%|██████████| 36/36 [01:40<00:00,  2.80s/it]


In [0]:
thisTile.stitch_results()

NameError: ignored

In [0]:
      rows = None
      for brow,lmrow in zip(batch_array, lookup_matrix):

        row = None
        for tile, region in zip(brow,lmrow):

          # crops tile region if region of interest is smaller than given input
          if not np.array_equal(region,[0,xs,0,ys]):
            print(f'Region: {region}, ({type(region)})')
            tile = self.crop_to_tile(tile, region)
          
          #concat columns
          if row is None:
            row = tile.data
          else:
            row = torch.cat((row,tile.data),2)

        #concat rows
        if rows is None:
          rows = row
        else:
          rows = torch.cat((rows,row),1)

In [0]:
thisTile.show_tiles(1,crop_to='original')

In [0]:
a[1]

In [0]:
ax.add_patch(copy(bbox))
                        #(x-offset,y-offset)      ,width      ,height
bbox = patches.Rectangle((padding-1,padding-1),wdt-2*padding,hgt-2*padding,edgecolor='r',linewidth=1,facecolor='none')

pil2tensor(x,np.float32)
