In [1]:
import tensorflow as tf
import cv2 
import numpy as np
import numpy.random as rng
import matplotlib.pyplot as plt
import matplotlib.cm as cm
%matplotlib inline
import glob
import os
os.chdir("../")



In [101]:
class dataGenerator():
    def __init__(self,cvSplit=0.8,batchSize=5,inputDim=(128,128),outputDim=(32,32)):
        rng.seed(1006)
        self.trainPaths, self.testPaths = [glob.glob(s+"/*[0-9].tif") for s in ["train","test"]]
        self.batchSize = batchSize
        self.inputDim = inputDim
        self.outputDim = outputDim
        print("%d train paths and %d test paths" % (len(self.trainPaths),len(self.testPaths)))
        
        # Split train into CV and non CV (cross validation)
        rng.shuffle(self.trainPaths)
        cvSplitPoint = int(cvSplit*len(self.trainPaths))
        self.trainPathsCV, self.testPathsCV = self.trainPaths[:cvSplitPoint], self.trainPaths[cvSplitPoint:]
        assert len(set(self.trainPathsCV).intersection(set(self.testPathsCV))) == 0 
        
        print("Train set split into %d train CV paths and %d test CV paths" % (len(self.trainPathsCV),len(self.testPathsCV)))     

    def showImg(self,img):
        plt.imshow(img,cmap=cm.gray)
        plt.show()

    def loadImg(self,path,train,augment=0,method=cv2.INTER_CUBIC):
        img = cv2.imread(path,0)
        maskPath = path.replace(".tif","_mask.tif")
        if os.path.exists(path = maskPath):
            mask = cv2.imread(maskPath,0)
        
        if augment == 1 and train == 1:
            rows,cols = img.shape
            
            M = cv2.getRotationMatrix2D((cols/2,rows/2),np.random.uniform(-5,5),1)
            tX, tY = np.random.randint(0,10,2)
            M[0,2] = tX
            M[1,2] = tY
            img,mask = [im[5:rows-5, 5:cols-5] for im in [img,mask]]
            img,mask = [cv2.warpAffine(im,M,(cols,rows),borderMode = 1) for im in [img,mask]]
            
            img = cv2.resize(img,self.inputDim, interpolation = method)
            mask = cv2.resize(mask,self.outputDim, interpolation = method)
            return img,mask
        elif augment == 0 and train == 1:
            img = cv2.resize(img,self.inputDim, interpolation = method)
            mask = cv2.resize(mask,self.outputDim, interpolation = method)
            return img,mask
        elif train == 0:
            img = cv2.resize(img,self.inputDim, interpolation = method)
            return img, _
        
    def gen(self,train):
        if train==1:
            paths = self.trainPathsCV
            nObs = len(paths)
            augment = 1
            print("Training paths length =  %d" % nObs)
        elif train == 0:
            paths = self.testPathsCV
            nObs = len(paths)
            augment = 1
            print("Testing paths length =  %d" % nObs)
        self.idx = 0
        while True:
            batchX = np.empty((self.batchSize,self.inputDim[0],self.inputDim[1],1))
            batchY = np.empty((self.batchSize,self.outputDim[0],self.outputDim[1],1))
            idx = 0
            for i in range(self.idx,min(self.batchSize+self.idx,nObs)):
                x,y = self.loadImg(paths[i],train=train,augment=augment)  
                x.resize(self.inputDim[0],self.inputDim[1],1), y.resize(self.outputDim[0],self.outputDim[1],1)
                batchX[idx],batchY[idx] = x, y
                idx += 1
            self.idx += self.batchSize
            if self.idx >= nObs:
                self.idx = 0
            yield batchX,batchY

     
            
        
        

In [77]:
dataGen = dataGenerator(cvSplit = 0.01,batchSize=30)

5635 train paths and 5508 test paths
Train set split into 56 train CV paths and 5579 test CV paths


In [88]:
x = dataGen.gen(train=1)

In [100]:
x.next()

(array([[[[ 220.],
          [ 220.],
          [ 219.],
          ..., 
          [ 183.],
          [ 189.],
          [ 191.]],
 
         [[ 220.],
          [ 220.],
          [ 219.],
          ..., 
          [ 184.],
          [ 190.],
          [ 190.]],
 
         [[ 217.],
          [ 220.],
          [ 220.],
          ..., 
          [ 185.],
          [ 190.],
          [ 189.]],
 
         ..., 
         [[  34.],
          [  41.],
          [  43.],
          ..., 
          [  37.],
          [  39.],
          [  42.]],
 
         [[  36.],
          [  39.],
          [  40.],
          ..., 
          [  32.],
          [  32.],
          [  32.]],
 
         [[  59.],
          [  55.],
          [  52.],
          ..., 
          [  31.],
          [  31.],
          [  31.]]],
 
 
        [[[ 197.],
          [ 197.],
          [ 182.],
          ..., 
          [ 201.],
          [ 200.],
          [ 200.]],
 
         [[ 199.],
          [ 197.],
          [ 1