In [104]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pylab import rcParams
from scipy.ndimage.interpolation import rotate, zoom
import scipy as sp
import cv2
%matplotlib inline

In [295]:
def getImgMass(im):
    """Intensity-weighted centroids per colour channel plus second moments of channel 0.

    Pixel values act as weights. `im` is indexed as im[row, col, channel] and
    channels 0, 1 and 2 are read.

    Returns
    -------
    m : list of three two-element lists
        One centroid per channel, ordered [column coordinate, row coordinate]
        (the pair is reversed before returning). A channel whose centroid
        contains NaN -- e.g. because its values sum to zero -- is replaced by the
        channel-0 (red) centroid, so it refers to the very same list object.
        If the red centroid itself is NaN it stays NaN.
    mm : 2x2 nested list
        Central second moments of channel 0, E[a*b] - E[a]*E[b], in
        (row, column) order -- i.e. NOT reversed like `m`.
    """
    nRows, nCols = im.shape[0], im.shape[1]

    # float32 coordinate grids shaped like a single channel:
    # rowGrid[r, c] == r and colGrid[r, c] == c
    rowGrid = np.tile(np.arange(nRows).astype('float32')[:, None], (1, nCols))
    colGrid = np.tile(np.arange(nCols).astype('float32')[:, None], (1, nRows)).T
    grids = (rowGrid, colGrid)

    # Centers of mass for the three colour channels, as [row, column] for now
    centroids = []
    for ch in range(3):
        plane = im[:, :, ch]
        total = plane.sum()
        centroids.append([(plane * g).sum() / total for g in grids])

    # Second moments for the red channel only
    red = im[:, :, 0]
    redMass = red.sum()
    redCentre = centroids[0]
    mm = []
    for j in range(2):
        momentRow = []
        for i in range(2):
            raw = (red * grids[i] * grids[j]).sum() / redMass
            momentRow.append(raw - redCentre[i] * redCentre[j])
        mm.append(momentRow)

    # Flip every centroid to [column, row]; a channel without a usable centroid
    # borrows the red one (for rotation purposes, per the original note).
    centroids = [c[::-1] for c in centroids]
    fallback = centroids[0]
    m = [fallback if any(np.isnan(v) for v in c) else c for c in centroids]

    return m, mm

class imgCropRotate():
    """Load an image with its colour mask and crop a square around the mask's red centroid.

    The mask file name is derived from `imgPath` by replacing "w<level>_" with
    "m<level>_"; the crop destination replaces it with "w<level+1>_".

    Attributes set by __init__:
        x        -- the image as a numpy array
        maskOrig -- the mask resized to the image size, as a numpy array
        mask     -- copy of maskOrig with values <= thresh zeroed and, when
                    justRed == 1, every pixel failing the red test zeroed
        m, mm    -- centroids / second moments returned by getImgMass(mask)
        dstPath  -- where crop() saves its result
    """

    def __init__(self,imgPath,level,justRed,thresh=50,levelCropDict={1:600,2:200,3:100},figsize=(20,10)):
        """
        imgPath       -- path of the image; must contain "w<level>_"
        level         -- integer level; selects the mask name, the output name
                         and the entry of levelCropDict used by crop()
        justRed       -- 1 to keep only mask pixels with R >= 150, G <= 100, B <= 100
        thresh        -- mask values <= thresh are set to 0
        levelCropDict -- maps level -> half-width (pixels) of the crop window;
                         the default dict is shared and is never modified here
        figsize       -- matplotlib figure size used by show()
        """
        # Validate before `level` is used to build the file names below.
        assert level, "Please specify which image level"

        self.levelCropDict = levelCropDict
        self.imgPath = imgPath
        self.level = level
        self.figsize = figsize

        levelName = "w" + str(level) + "_"
        maskName = imgPath.replace(levelName,"m"+str(level)+"_")
        self.dstPath = imgPath.replace(levelName,"w"+str(level+1)+"_")

        # Image.open is lazy; convert to arrays inside the `with` so both file
        # handles are released as soon as the pixel data has been read.
        with Image.open(imgPath) as img, Image.open(maskName) as maskImg:
            w, h = img.size
            self.maskOrig = np.array(maskImg.resize((w,h)))
            self.x = np.array(img)

        self.mask = self.maskOrig.copy()
        # Use the `thresh` argument (not a notebook-level global) for the cut-off.
        self.mask[self.maskOrig<=thresh] = 0

        if justRed == 1: #I.e. just the head
            notRed = (self.mask[:,:,0] < 150) | (self.mask[:,:,1] > 100) | (self.mask[:,:,2] > 100)
            self.mask[notRed] = 0

        self.m, self.mm = getImgMass(self.mask)

    @staticmethod
    def _cropBounds(center, half, limit):
        """Return (lo, hi) slice bounds of a 2*half wide window around `center`,
        shifted to stay inside [0, limit] and shrunk only if limit < 2*half."""
        size = 2 * half
        lo = center - half
        hi = center + half
        if lo < 0:
            lo, hi = 0, size
        elif hi > limit:
            # slice ends are exclusive, so `limit` itself is the last valid end
            hi = limit
            lo = hi - size
        # A negative start would wrap around when slicing; clamp both ends.
        return max(lo, 0), min(hi, limit)

    def show(self):
        """Display the image, the resized mask and the processed mask side by side."""
        plt.figure(figsize=self.figsize)
        toDis = np.hstack((self.x,self.maskOrig,self.mask))
        plt.title("Original and mask")
        plt.imshow(toDis)
        plt.show()

    def showCentroids(self):
        """Display the image with circles at the channel-0 (red circle) and
        channel-1 (yellow circle) centroids. A centroid containing NaN is skipped
        so that crop() can still run and report the failure."""
        cpy = self.x.copy()
        colours = [(255,0,0), (255,255,0)]
        for i, c in enumerate(colours):
            centre = self.m[i]
            if any(np.isnan(v) for v in centre):
                continue
            cv2.circle(cpy,tuple(map(int,centre)),150 + i*50,c,40)
        plt.imshow(cpy)
        plt.show()

    def crop(self,save,show=0):
        """Crop a square of half-width levelCropDict[level] around the channel-0 centroid.

        save -- 1 to write the crop to self.dstPath
        show -- 1 to display the crop

        The window is shifted to stay inside the image. If a ValueError occurs
        (e.g. the centroid is NaN and cannot be converted to int) the image and
        mask are displayed with a "(COULDN'T CROP)" title instead.
        """
        try:
            maxY, maxX = self.x.shape[0], self.x.shape[1]

            # self.m[0] is [column, row], i.e. (x, y)
            mx, my = map(int,self.m[0])
            w2 = self.levelCropDict[self.level]

            x1, x2 = self._cropBounds(mx, w2, maxX)
            y1, y2 = self._cropBounds(my, w2, maxY)

            print("x1 = %d,x2 = %d,y1 = %d,y2 = %d" % (x1,x2,y1,y2))

            cropped = self.x[y1:y2,x1:x2]
            print(cropped.shape)
            cropped = Image.fromarray(cropped)

            if save == 1:
                cropped.save(self.dstPath)

            if show == 1:
                plt.imshow(cropped)
                plt.title(self.imgPath)
                plt.show()

        except ValueError:
            plt.title(self.imgPath + " (COULDN'T CROP)")
            plt.imshow(np.hstack((self.x,self.mask)))
            plt.show()


In [296]:
test = False   # True -> glob every sub-directory of ../imgs, False -> only the "whale_*" ones
toy = True     # True -> replace level1 with the paths listed in the two CV csv files

mid = "*" if test else "whale_*"
#mid = "whale_36648"

# glob returns matches in arbitrary order; sort so every run visits files in the same order.
level1 = sorted(glob.glob("../imgs/"+mid+"/w1_*")) # level 1
level2 = sorted(glob.glob("../imgs/"+mid+"/w2_*")) # level 2
level3 = sorted(glob.glob("../imgs/"+mid+"/w3_*")) # level 3


if toy:
    import pandas as pd

    def getPath(row):
        """Build the level-1 image path for one csv row from its whaleID and Image fields."""
        return "../imgs/"+ row.whaleID + "/" + row.Image.replace("w_","w1_")

    # Named trainCV/testCV so the boolean `test` switch above is not
    # overwritten by a DataFrame.
    trainCV = pd.read_csv("../trainCV.csv")
    testCV = pd.read_csv("../testCV.csv")
    toyPaths = pd.concat([trainCV,testCV]).apply(getPath,axis=1)

    level1 = list(toyPaths)

#print(level1[:100])

In [None]:
# Settings for this cropping pass
level = 1                              # selects the "w1_" images / "m1_" masks
justRed = 1                            # 1 -> only the red part of the mask is kept
threshold = 0                          # mask cut-off, passed on as `thresh`
levelCropDict = {1:700,2:200,3:100}    # level -> half-width of the crop window

imgPaths = level1
nObs = len(imgPaths)

for i, imgPath in enumerate(imgPaths):

    print(imgPath)
    im = imgCropRotate(imgPath, level, justRed,
                       thresh=threshold, levelCropDict=levelCropDict)

    # progress report on every fifth image
    if i % 5 == 0:
        print("%d of %d" %(i,nObs))

    show = 1
    im.show()
    im.showCentroids()

    im.crop(save=1, show=show)
   