In [None]:
import numpy as np
import pandas as pd

import json
import matplotlib.pyplot as plt

import skimage
import skimage.io as io
import skimage.transform as tr
import skimage.color as color

import skimage.morphology as morph

from matplotlib.patches import Rectangle
import numpy as np
import pandas as pd

from scipy.optimize import differential_evolution,shgo

import xml.etree.ElementTree as ET
import os

In [None]:
def get_mask(bnd_box_array,shape=None):
    """Rasterise bounding boxes into a binary mask.

    Parameters
    ----------
    bnd_box_array : array-like of shape (n_boxes, 4)
        One box per row as ``[row_start, row_end, col_start, col_end]``.
        Values are truncated to integers before being used as slice bounds.
    shape : sequence of int, optional
        Shape of the returned mask. When omitted, the mask extends to the
        ``row_end`` / ``col_end`` of the *last* box in ``bnd_box_array``.

    Returns
    -------
    numpy.ndarray
        Float array holding 1 inside every box and 0 elsewhere.
    """
    # Convert once up front: accepts lists as well as arrays, and guarantees
    # integer extents for np.zeros even when the boxes are stored as floats.
    boxes = np.asarray(bnd_box_array).astype(int)

    # `is None` instead of `== None`: the latter is evaluated element-wise
    # (and then fails inside `if`) when `shape` is itself a numpy array.
    if shape is None:
        mask = np.zeros((boxes[-1][1],boxes[-1][3]))
    else:
        mask = np.zeros(shape)

    for box in boxes:
        mask[box[0]:box[1],box[2]:box[3]] = 1
    return mask

def get_ortho_and_mask(img_date):
    """Load one orthomosaic together with its greenness map and ground truth.

    Parameters
    ----------
    img_date : str
        File stem of the image; it is inserted into both hard-coded paths.

    Returns
    -------
    tuple
        ``(ortho, ortho_greeness, grd_truth)`` where ``ortho`` is the image
        converted to float32, ``ortho_greeness`` is ``2*G - R - B`` computed
        from channels 1, 0 and 2 of ``ortho``, and ``grd_truth`` is the
        ground-truth image read as grayscale.
    """
    ortho_path = 'Dataset/' + img_date + '.png'
    truth_path = 'Dataset/Ground truth/images/'+img_date+'.png'

    ortho = skimage.img_as_float32(io.imread(ortho_path))

    # Channel order is taken as index 0, 1, 2 = R, G, B.
    red = ortho[:,:,0]
    green = ortho[:,:,1]
    blue = ortho[:,:,2]
    ortho_greeness = 2*green - red - blue

    grd_truth = io.imread(truth_path,as_gray=True)

    return ortho,ortho_greeness,grd_truth

In [None]:
def read_field_structure(json_file):
    """Parse the JSON file at ``json_file`` and return the decoded object."""
    with open(json_file, 'r') as handle:
        return json.load(handle)

def get_unit_field_structure(field_struct_dict):
    """Build the plot boxes of a field in field units.

    Parameters
    ----------
    field_struct_dict : dict
        Must provide ``plot_size_along_range``, ``plot_size_across_range``,
        ``range_gaps``, ``plot_gaps``, ``n_plots_per_range`` and ``n_ranges``.
        Only the first entry of ``plot_gaps`` and of ``n_plots_per_range`` is
        used, i.e. every range gets the same number of evenly spaced plots.
        ``range_gaps[i-1]`` is the gap between range ``i-1`` and range ``i``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_ranges * plots_per_range, 4)`` with one row per
        plot: ``[start_along, end_along, start_across, end_across]``. Plots
        are ordered range by range.
    """
    box_height = field_struct_dict['plot_size_along_range']
    box_width = field_struct_dict['plot_size_across_range']

    range_gaps = field_struct_dict['range_gaps']
    plot_gap = field_struct_dict['plot_gaps']

    # Number of plots in one range. The slicing below previously hard-coded 39,
    # which only worked for fields with exactly 39 plots per range.
    plots_per_range = field_struct_dict['n_plots_per_range'][0]
    n_ranges = field_struct_dict['n_ranges']

    # One range: plots stacked along the range, all starting at across = 0.
    plot_pitch = plot_gap[0] + box_height
    rep_col = np.tile(np.array([0,0,0,box_width]),(plots_per_range,1))
    rep_col[:,0] = np.arange(plots_per_range)*plot_pitch
    rep_col[:,1] = rep_col[:,0] + box_height

    whole_field = np.tile(rep_col,(n_ranges,1))

    # Shift every range after the first so that it starts one gap past the
    # end of the range before it.
    for i in range(1,len(range_gaps)+1):
        cur_range = slice(i*plots_per_range,(i+1)*plots_per_range)
        prev_range = slice((i-1)*plots_per_range,i*plots_per_range)
        whole_field[cur_range,2] = range_gaps[i-1] + whole_field[prev_range,3]
        whole_field[cur_range,3] = box_width + whole_field[cur_range,2]

    return whole_field

def get_scaled_bboxes_for_ortho(unit_field_size_array,corner_points_ortho):
    """Rescale field-unit boxes to the pixel extent spanned by two corners.

    Parameters
    ----------
    unit_field_size_array : numpy.ndarray
        Boxes as rows of ``[row_start, row_end, col_start, col_end]``; the
        last row's ``row_end`` / ``col_end`` are taken as the field extent.
    corner_points_ortho : numpy.ndarray
        Corner points; the difference ``last - first`` is read as
        ``(width, height)``.

    Returns
    -------
    numpy.ndarray
        float32 array of the scaled boxes, rounded to whole numbers. No
        offset is applied; the first corner is only used for the extent.
    """
    last_box = unit_field_size_array[-1]
    unit_width = last_box[3]
    unit_height = last_box[1]

    span_x,span_y = corner_points_ortho[-1] - corner_points_ortho[0]

    row_scale = span_y/unit_height
    col_scale = span_x/unit_width

    scaled = np.array(unit_field_size_array,dtype=np.float32)
    # First two columns are row bounds, remaining columns are column bounds.
    scaled[:,:2] = unit_field_size_array[:,:2]*row_scale
    scaled[:,2:] = unit_field_size_array[:,2:]*col_scale

    return np.round(scaled)

def get_corner_points(corner_csv):
    """Yield ``(date, corners)`` for every date listed in a corner CSV.

    The CSV must have the columns ``Date``, ``x1``, ``y1``, ``x2`` and ``y2``.
    ``corners`` is an integer array ``[[x1, y1], [x2, y2]]``.
    """
    table = pd.read_csv(corner_csv)
    per_date = table.set_index("Date")[['x1','y1','x2','y2']]
    coords_by_date = per_date.T.to_dict('list')
    for date,coords in coords_by_date.items():
        first_corner,second_corner = coords[:2],coords[2:]
        yield (date,np.array([first_corner,second_corner],dtype=int))
        

In [None]:
def get_bboxes(xml_file):
    """Collect every ``bndbox`` of every ``object`` element in an XML file.

    Each box becomes one row holding the integer text of the ``bndbox``
    children, in document order (presumably a Pascal-VOC style annotation
    with xmin/ymin/xmax/ymax -- confirm against the files).

    Returns
    -------
    numpy.ndarray
        One row per bounding box.
    """
    annotation = ET.parse(xml_file).getroot()
    boxes = [
        [int(coord.text) for coord in bndbox]
        for obj in annotation.findall('object')
        for bndbox in obj.findall('bndbox')
    ]
    return np.array(boxes)


def Dice(predicted,grd_truth):
    """Dice coefficient ``2*|A & B| / (|A| + |B|)`` of two masks.

    Both inputs are interpreted as boolean masks (non-zero means foreground),
    so numerator and denominator count pixels consistently even when a mask
    holds values other than 0/1.

    Returns
    -------
    float
        Overlap in ``[0, 1]``. Two empty masks agree perfectly and give 1.0
        instead of the NaN produced by a 0/0 division.
    """
    predicted = np.asarray(predicted).astype(bool)
    grd_truth = np.asarray(grd_truth).astype(bool)

    total = predicted.sum() + grd_truth.sum()
    if total == 0:
        return 1.0

    return 2*np.logical_and(predicted,grd_truth).sum()/total

def objective(shifts,cur_pos,plots_per_range=39):
    """Cost of shifting one plot box by ``shifts`` on the greenness image.

    Parameters
    ----------
    shifts : sequence of two numbers
        Row and column shift; both are truncated towards zero with ``int``.
    cur_pos : sequence
        ``[row_start, row_end, col_start, col_end, greenness_image, gap_size,
        weights, updated_boxes, plot_index]``. ``weights`` holds one weight
        per cost term (six in total); ``updated_boxes`` holds the current box
        of every plot and ``plot_index`` is the row of this plot in it.
    plots_per_range : int, optional
        Number of plots in one range. Plots whose index is a multiple of it
        get no evenness/alignment penalty (presumably because the "previous"
        plot then belongs to another range). Defaults to 39.

    Returns
    -------
    float
        ``dot(weights, costs)`` with the cost terms, in order:
        1. ``1 - mean`` of the clamped greenness inside the shifted window,
        2. mean of ``|row shift|`` and ``|col shift|`` relative to ``gap_size``,
        3. ``1 -`` smallest column sum divided by the window height,
        4. ``1 -`` smallest row sum divided by the window width,
        5. evenness: how far the gap to the previous plot is from ``gap_size``,
        6. alignment: column offset relative to the previous plot.
    """
    i,k = int(shifts[0]),int(shifts[1])

    cur_ortho_greeness = cur_pos[4]
    gap_size = cur_pos[-4]
    updated_pos = cur_pos[-2]
    cur_plot_index = cur_pos[-1]

    window = cur_ortho_greeness[cur_pos[0]+i:cur_pos[1]+i,cur_pos[2]+k:cur_pos[3]+k]
    # Clamp negatives on a copy: the slice is a view, so clamping in place
    # would permanently overwrite the caller's greenness image.
    cur_selection = np.maximum(window,0)

    sum_ax_0 = np.sum(cur_selection,axis=0)
    sum_ax_1 = np.sum(cur_selection,axis=1)

    min_ax_0 = np.min(sum_ax_0)/cur_selection.shape[0]
    min_ax_1 = np.min(sum_ax_1)/cur_selection.shape[1]

    cur_plot = updated_pos[cur_plot_index]
    # Wrap around using the real number of plots rather than a fixed 39*6.
    prv_plot = updated_pos[int((cur_plot_index - 1)%len(updated_pos))]

    # Deviation of the gap to the previous plot from the nominal gap,
    # squashed into [0, 1).
    gap_n = np.abs(np.abs(cur_plot[0] - prv_plot[1]) - gap_size)/(2*gap_size)
    evenness_cost = 1 - (np.exp(-4*gap_n))

    alignment_cost = ((np.abs(cur_plot[2] - prv_plot[2]))/(2*gap_size))

    # Work on a copy so that zeroing terms never leaks back to the caller.
    weight = np.copy(cur_pos[-3])

    if cur_plot_index%plots_per_range==0:
        weight[-1],weight[-2] = 0,0

    shift_cost = (abs(i)/(gap_size)+abs(k)/(gap_size))/2
    cur_cost = np.array([
        (1-np.mean(cur_selection)),
        shift_cost,
        (1-min_ax_0),
        (1-min_ax_1),
        (evenness_cost),
        (alignment_cost),
    ])

    cost = np.dot(weight,cur_cost)

    return cost
    
    
def estimate_bboxes(init_guess,ortho_img,weights=None):
    """Refine every plot box by searching for the best row/column shift.

    Parameters
    ----------
    init_guess : numpy.ndarray
        Initial boxes, one row ``[row_start, row_end, col_start, col_end]``
        per plot. Needs at least two rows: the distance between the first two
        plots defines the search range.
    ortho_img : numpy.ndarray
        Greenness image the boxes are fitted to.
    weights : array-like of 6 floats, optional
        Weights of the cost terms of ``objective``. When omitted, the first
        weight is 5 and the remaining five are drawn from a Dirichlet
        distribution, so results are random unless numpy is seeded.

    Returns
    -------
    numpy.ndarray
        Refined boxes, one row per plot, in the order of ``init_guess``.
    """
    if weights is None:
        weights = np.append(np.array([5]),np.random.dirichlet(np.ones(5)))

    # Gap between the end of the first plot and the start of the second one;
    # each shift is searched within +/- this gap.
    inbetween_gap = init_guess[1,0] - init_guess[0,1]
    bounds = [(-inbetween_gap,inbetween_gap)]*2

    # Boxes fitted so far are written back here so that the evenness and
    # alignment terms of later plots see the already refined neighbours.
    init_guess_updates = np.copy(init_guess)
    ortho_sol = []
    for plot_box in range(init_guess.shape[0]):
        r_1,r_2,c_1,c_2 = (int(v) for v in init_guess[plot_box][:4])
        cur_pos = [r_1,r_2,c_1,c_2,ortho_img,inbetween_gap,weights,init_guess_updates,plot_box]
        sol = differential_evolution(objective,bounds=bounds,args=(cur_pos,))

        # objective() truncates the shifts with int(), so the cost reported by
        # the optimiser belongs to the truncated shift. Rounding to nearest
        # here could store a neighbouring shift that was never the optimum.
        i,k = np.trunc(sol.x)
        ortho_sol.append([r_1+i,r_2+i,c_1+k,c_2+k])
        init_guess_updates[plot_box] = np.array(ortho_sol[-1])

    return np.array(ortho_sol)


In [None]:
def get_mask_ortho_pos(f,img_name):
    """Initial plot boxes for one image, in that image's pixel coordinates.

    Looks up the corner points of ``img_name`` in ``csv/dataset_csv.csv``,
    scales the field-unit boxes ``f`` to the extent between the two corners
    and offsets them by the first corner (``y1`` for the row bounds, ``x1``
    for the column bounds).
    """
    corner_table = pd.read_csv('csv/dataset_csv.csv')
    matching_rows = corner_table[corner_table['Date']==img_name]
    # Rows of `cur` are the two corners: [[x1, y1], [x2, y2]].
    cur = matching_rows[['x1','y1','x2','y2']].values.reshape((2,2))

    scaled_boxes = get_scaled_bboxes_for_ortho(f,cur)
    x_origin,y_origin = cur[0,0],cur[0,1]
    origin_offset = np.array([y_origin,y_origin,x_origin,x_origin])
    return scaled_boxes + origin_offset

In [None]:
# Load the 2017-08-12 orthomosaic, its greenness map and its ground truth,
# then build the initial plot boxes for that date.
# NOTE(review): `f` is not assigned in any cell visible above; presumably it is
# the array returned by get_unit_field_structure(...) -- confirm it is defined
# before this cell runs, otherwise a fresh kernel raises NameError here.
ortho_08_12,ortho_greeness_08_12,grd_truth_08_12 = get_ortho_and_mask('2017-08-12')
init_guess_08_12 = get_mask_ortho_pos(f,'2017-08-12')

In [None]:
# Same as the previous cell, for the 2017-07-07 image. These module-level names
# are read directly by weight_obj below.
# NOTE(review): relies on `f`, which no visible cell defines -- confirm.
ortho_07_07,ortho_greeness_07_07,grd_truth_07_07 = get_ortho_and_mask('2017-07-07')
init_guess_07_07 = get_mask_ortho_pos(f,'2017-07-07')

In [None]:
def weight_obj(cur_val,grd_t1,grd_t2):
    """Score one weight vector on the 2017-07-07 and 2017-08-12 images.

    Fits the boxes on both images with ``cur_val`` as cost weights, compares
    each resulting mask with its ground truth (``grd_t1`` for 2017-07-07,
    ``grd_t2`` for 2017-08-12) and returns ``1 - mean Dice`` so that a
    minimiser maximises the overlap. Progress is printed along the way.

    Reads the module-level ``init_guess_*`` and ``ortho_greeness_*`` arrays
    of both dates.
    """
    print(cur_val)

    def _overlap(init_guess,greeness,grd_truth,label):
        # Fit the boxes, rasterise them and report the overlap for one date.
        fitted_boxes = estimate_bboxes(init_guess,greeness,np.array(cur_val))
        predicted_mask = get_mask(fitted_boxes,greeness.shape)
        overlap = Dice(predicted_mask>0.5,grd_truth>0.5)
        print(overlap,label)
        return overlap

    cur_overlap_1 = _overlap(init_guess_07_07,ortho_greeness_07_07,grd_t1,'2017-07-07')
    cur_overlap_2 = _overlap(init_guess_08_12,ortho_greeness_08_12,grd_t2,'2017-08-12')
    print()

    mean_overlap = (cur_overlap_1 + cur_overlap_2)/2
    return 1 - mean_overlap


In [None]:
# Search the six cost weights in [0, 1] that maximise the mean Dice overlap on
# the two reference dates; the optimiser result is shown as the cell output.
differential_evolution(
    weight_obj,
    bounds=[(0,1)]*6,
    args=(grd_truth_07_07,grd_truth_08_12),
)

In [None]:
# Load previously recorded weight-search results (the file has no header row,
# so columns are numbered 0..n).
# NOTE(review): no visible cell writes 'weight_search_out.csv'; presumably it was
# produced from the printed output of the weight search -- confirm provenance.
results = pd.read_csv('weight_search_out.csv',header=None)

In [None]:
# Preview the first rows of the recorded search results.
results.head()

In [None]:
# Keep only the rows whose column 6 exceeds 0.89.
# NOTE(review): column 6 is presumably the recorded Dice score and columns 0-5
# the six weights (the evaluation loop uses every column but the last as
# weights) -- confirm against the CSV layout.
best_results = results[results[6] > .89]

In [None]:
# ortho_07_26,ortho_greeness_07_26,grd_truth_07_26 = get_ortho_and_mask('2017-08-12')
# init_guess_07_26 = get_mask_ortho_pos(f,'2017-08-12')

In [None]:
# Re-evaluate every shortlisted weight vector on all dates listed in `ds` and
# report the per-date Dice overlap plus the average per weight vector.
# NOTE(review): `ds` and `f` are not defined in any visible cell -- confirm they
# exist before running; a missing name is now reported instead of being hidden.
# The *_07_26 names are reused for every date in the loop and are kept
# unchanged in case other cells read them.
for i,r in best_results.iterrows():
    # Every column but the last holds the weights; the last one is the score.
    weight_cur = r.values[:-1]

    print(weight_cur)

    avg_weight = []
    for d in ds.Date:
        try:
            ortho_07_26,ortho_greeness_07_26,grd_truth_07_26 = get_ortho_and_mask(d)
            init_guess_07_26 = get_mask_ortho_pos(f,d)

            ortho_sol = estimate_bboxes(init_guess_07_26,ortho_greeness_07_26,np.array(weight_cur))

            s = get_mask(ortho_sol,ortho_greeness_07_26.shape)
            cur_dice = (Dice(s>0.5,grd_truth_07_26>0.5),d)
            avg_weight.append(cur_dice[0])
            print(cur_dice)
        except Exception as err:
            # Still best-effort (a date without image or ground truth is
            # skipped), but say so: a bare `except: pass` also swallowed
            # KeyboardInterrupt and hid real bugs such as undefined names.
            print('skipped',d,'-',repr(err))

    # Avoid the empty-mean warning when every date was skipped.
    average = np.array(avg_weight).mean() if avg_weight else float('nan')
    print('average: ',average)
    print()