In [5]:
# @numba.jit() 
import numpy as np
from py_eddy_tracker.dataset.grid import RegularGridDataset
from py_eddy_tracker.dataset.grid import GridDataset
from datetime import datetime
import cv2
import skimage
from skimage import filters
from netCDF4 import Dataset
from scipy import interpolate
from functools import reduce
import traceback


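# After some thought, a dedicated class is used to hold eddy-related data.
# It is also the place to define eddy parameters and methods, such as the eddy radius.
# Intended effect: hovering the mouse pointer over an eddy region shows that eddy's parameters.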
class Eddy:
    """Plain container for one detected eddy.

    Attributes are stored exactly as passed in:
        center:  the eddy center.
        contour: the outline associated with that center.
    """

    def __init__(self, center, contour):
        """Keep references to ``center`` and ``contour`` without copying."""
        self.center, self.contour = center, contour

    # TODO: check whether the result can be cached once this is implemented.
    def radius(self):
        """Placeholder for the eddy radius; not implemented yet, returns None."""
        return None
    
class EddyDectionResult:
    """Detection output for the two eddy families ("a" and "c") plus the inputs.

    ``eddies`` and ``results`` are both 2-item sequences ``(a, c)``.
    Presumably "a" is anticyclonic and "c" cyclonic -- confirm with the caller.
    """

    def __init__(self, eddies, longitude, latitude, raw_data, results):
        """Split the ``(a, c)`` pairs and keep the grid/raw data untouched."""
        eddies_a, eddies_c = eddies
        results_a, results_c = results
        self.eddies_a = eddies_a
        self.eddies_c = eddies_c
        self.longitude = longitude
        self.latitude = latitude
        self.raw_data = raw_data
        self.results_a = results_a
        self.results_c = results_c

    def display(self, ax):
        """Draw the last-stage points and the eddy outlines on ``ax``.

        Args:
            ax: object exposing ``scatter`` and ``plot`` (e.g. a Matplotlib Axes).

        Returns:
            The same ``ax`` that was passed in.
        """
        # Only the last entry of each results list is plotted.
        final_a = self.results_a[-1]
        final_c = self.results_c[-1]

        # Points: "a" in blue, "c" in red.
        for points, colour in ((final_a, 'blue'), (final_c, 'red')):
            ax.scatter(points[..., 0], points[..., 1], color=colour, s=2)

        # Outlines use the opposite colour of the points of the same family.
        for family, colour in ((self.eddies_a, 'red'), (self.eddies_c, 'blue')):
            for eddy in family:
                outline = eddy.contour
                ax.plot(outline[:, 0], outline[:, 1], color=colour)
        return ax
    
    
def shape_error(cnt):
    """Return how far a contour is from filling its minimum enclosing circle.

    The value is ``(circle_area - contour_area) / circle_area``: 0 when the
    contour covers the whole circle, approaching 1 as it covers less of it.

    Args:
        cnt: contour points in a form accepted by ``cv2.contourArea`` and
            ``cv2.minEnclosingCircle``.
    """
    _center, enclosing_radius = cv2.minEnclosingCircle(cnt)
    circle_area = np.pi * enclosing_radius ** 2
    contour_area = cv2.contourArea(cnt)
    return (circle_area - contour_area) / circle_area

def IsValidSeries(s):
    '''
    Classify each row of ``s`` as a valid increasing / decreasing series.

     x - · · - x (when a=2)

    A row is valid when
      · the two end values (the x's) have opposite signs, and
      · the values are monotonically non-decreasing or non-increasing.

    Args:
        s: 2-D array, one series per row
            s1 (x - · · - x)
            s2
            ...
            sn

    Returns:
        Boolean masked array with one entry per row: the data is "row is
        non-decreasing", rows that are not valid are masked, and the fill
        value is False.
    '''
    # End points must have strictly opposite signs (a product of 0 or NaN fails).
    # NOTE: monotonicity alone does not imply a sign change (e.g. 1, 2, 3),
    # so this check is kept even though the original author questioned it.
    first, last = s[:, 0], s[:, -1]
    opposite_ends = (first * last) < 0

    # Monotonicity along each row; any NaN step makes both tests fail.
    steps = np.diff(s, axis=1)
    rising = (steps >= 0).all(axis=1)
    falling = (steps <= 0).all(axis=1)

    valid = opposite_ends & (rising | falling)
    return np.ma.array(rising, mask=np.logical_not(valid), fill_value=False)
    
def Election(V):
    """Collect the grid points on both sides of every sign change of ``V`` along axis 1.

    Args:
        V: 2-D array; only the test ``V > 0`` is used.

    Returns:
        Integer array of shape (2*k, 2) holding ``(row, col)`` pairs: first the
        points just left of each of the k sign changes, then the points just
        right of them.
    """
    # +1 where V > 0, -1 elsewhere.
    signs = (V > 0) * 2 - 1
    kernel = np.array([-1, 1])
    # steps[i, j] == signs[i, j-1] - signs[i, j]; the value is +-2 exactly at a sign flip.
    steps = np.apply_along_axis(np.convolve, 1, signs, v=kernel, mode='same')

    to_positive = np.argwhere(steps == -2)
    to_negative = np.argwhere(steps == 2)

    right_side = np.concatenate([to_positive, to_negative])
    left_side = right_side - [0, 1]
    return np.concatenate([left_side, right_side])

def firstYakusoku(V, candidate, a=3):
    '''
    First constraint.

    For every candidate the values of V at columns col-a .. col+a of the
    candidate's row are passed to IsValidSeries.

    a=2
    ____________
    |_|_|_|_|_|_|
    x_|_!_|_x_|_|
    |_|_|_|_|_|_|
    |_|_|_|_|_|_|
    |_|_|_|_|_|_|

    Args:
        V: 2-D array sampled along axis 1.
        candidate: integer array of (row, col) pairs.
        a: half-width of the sampled window.

    Returns:
        (increase, decrease): the candidates whose window is a valid
        increasing series, and those whose window is a valid decreasing one.
    '''
    width = 2 * a + 1
    # NaN padding makes windows that stick out of the grid fail the validity test.
    padded = np.pad(V, [(0, 0), (a, a)], 'constant', constant_values=np.nan)

    # candidate stores (row, col); after padding, column c starts the window c-a .. c+a.
    rows = candidate[:, 0]
    cols = candidate[:, 1]
    window_cols = cols[:, np.newaxis] + np.arange(width)
    series = padded[rows[:, np.newaxis], window_cols]

    verdict = IsValidSeries(series)
    increase = candidate[verdict.filled()]
    decrease = candidate[(~verdict).filled()]
    return increase, decrease


def secondYakusoku(U, candidate, a=3, flag=1):
    '''
    Second constraint:

    a=2
    ____x_______
    |_|_|_|_|_|_|
    |_|_!_|_|_|_|
    |_|_|_|_|_|_|
    |_|_x_|_|_|_|
    |_|_|_|_|_|_|

    Additional constraint:
        an anticyclone must go from positive to negative (decrease)
        a cyclone must go from negative to positive (increase)

    Args:
        U: zonal velocity, positive eastward.
        candidate: eddy candidate points, integer (row, col) pairs.
        a: detection parameter; the rows row-a .. row+a of the candidate's
           column are checked against the second constraint.
        flag: eddy type of the candidates, flag=1 anticyclone, flag=0 cyclone.

    Returns:
        The candidates with a valid decreasing window when flag is truthy,
        otherwise those with a valid increasing window.
    '''

    assert(flag in [0, 1])

    span = 2 * a + 1
    # NaN padding makes windows that stick out of the grid fail the validity test.
    padded = np.pad(U, [(a, a), (0, 0)], 'constant', constant_values=np.nan)

    # candidate stores (row, col); after padding, row r starts the window r-a .. r+a.
    rows = candidate[:, 0]
    cols = candidate[:, 1]
    window_rows = rows[:, np.newaxis] + np.arange(span)
    series = padded[window_rows, cols[:, np.newaxis]]

    verdict = IsValidSeries(series)
    if flag:
        keep = (~verdict).filled()   # anticyclone: decreasing
    else:
        keep = verdict.filled()      # cyclone: increasing
    return candidate[keep]

def thirdYakusoku(U, V, candidate, b=2):
    '''
    Fix me: High cost
    
    '''
    
    for _ in range(2):
        velocity = U**2 + V**2
        if len(candidate):
            collections = np.stack([velocity[i[0]-b:i[0]+b+1, i[1]-b:i[1]+b+1] for i in candidate], axis=0)
            length = len(collections)
            min_velocity_point_candidate = np.unravel_index(np.argmin(collections.reshape(length, -1), axis=1),shape=collections[0].shape)
            y1, x1 = min_velocity_point_candidate 
            min_velocity_point_candidate  = np.stack([y1, x1], axis=1)
            min_velocity_point_candidate += candidate - np.array([b, b])
        else:
            min_velocity_point_candidate = candidate
        candidate = np.unique(min_velocity_point_candidate, axis=0)
    return candidate


def kanjosenTest(candidate, U, V, a, debug=False):
    '''
    Closed-loop test around one candidate.

    The velocity is sampled along the square ring of half-width ``a`` centred
    on the candidate, in this order: row y-a with increasing x, column x+a
    with increasing y, row y+a with decreasing x, column x-a with decreasing y.
    The candidate passes when, between consecutive ring samples,
      * the quadrant of (U, V) either stays the same or advances by exactly
        one quadrant counter-clockwise, and
      * the velocity direction never turns clockwise
        (imag(vel[k+1] / vel[k]) >= 0).

    A ring that does not fit inside the grid fails the test. This is checked
    explicitly: negative indices would otherwise wrap around silently and
    sample the opposite edge of the grid.

    Args:
        candidate: (row, col) pair.
        U, V: 2-D velocity components of identical shape.
        a: half-width of the ring.
        debug: when True, report out-of-range candidates and plot the ring
            and its surroundings (requires ``plt`` to be available at module
            level; it is not imported in the visible part of this file).

    Returns:
        Boolean: True when the candidate passes.
    '''
    y, x = candidate
    height, width = np.shape(U)[:2]
    if y - a < 0 or x - a < 0 or y + a >= height or x + a >= width:
        if debug:
            print(f'越界！{candidate}') #debug
        return False

    def ring(field):
        # Four sides of the square, each leaving out its last corner.
        return np.concatenate([
            field[y-a, x-a:x+a],
            field[y-a:y+a, x+a],
            field[y+a, x+a:x-a:-1],
            field[y+a:y-a:-1, x-a],
        ])

    loopU = ring(U)
    loopV = ring(V)

    bitU = loopU > 0
    bitV = loopV > 0
    # Quadrant code: (U>0, V<=0) -> 1, (U>0, V>0) -> 1j,
    #                (U<=0, V>0) -> -1, (U<=0, V<=0) -> -1j.
    # Advancing one quadrant counter-clockwise multiplies the code by 1j.
    loopmask = bitU ^ bitV
    loop = np.where(loopmask, np.where(bitU, 1, -1), np.where(bitV, 1, -1) * 1j)
    test = loop[1:]/loop[:-1]

    vel = loopU+loopV*1j
    test2 = vel[1:]/vel[:-1]
    res = np.all(test[test!=1+0j]==0+1j) & np.all(np.imag(test2)>=0)

    if debug:
        plt.figure(figsize=(8,4))
        plt.subplot(1,2,1)
        plt.title(f'({y},{x})',fontdict={'color':'green' if res else 'red'})
        plt.quiver(loopU, loopV)
        plt.subplot(1,2,2)
        w = 7
        # Clamp the lower bounds so the context window cannot wrap around either.
        plt.quiver(U[max(y-w, 0):y+w, max(x-w, 0):x+w], V[max(y-w, 0):y+w, max(x-w, 0):x+w])
        plt.show()

    return res


def lastYakusoku(U, V, candidate, a=3, debug=False):
    '''
    Last constraint: keep the candidates that pass kanjosenTest on a ring of
    half-width ``a - 1``.

    Args:
        U, V: 2-D velocity components.
        candidate: integer array of (row, col) pairs; may be empty.
        a: detection parameter; the ring half-width is a - 1.
        debug: forwarded to kanjosenTest.

    Returns:
        The rows of ``candidate`` that pass the test.
    '''
    candidate = np.asarray(candidate)
    # np.apply_along_axis raises on zero-size input, so an empty candidate
    # list (nothing survived the earlier constraints) is returned as is.
    if len(candidate) == 0:
        return candidate
    is_valid = np.array(
        [kanjosenTest(point, U, V, a - 1, debug=debug) for point in candidate],
        dtype=bool,
    )
    return candidate[is_valid]

def xy2geo(candidates):
    """Look up the geographic coordinates of grid-index candidates.

    Reads the module-level arrays ``roi_lon`` and ``roi_lat``, which are not
    defined in this part of the file.

    Args:
        candidates: integer array whose last axis holds the two indices used
            to index ``roi_lon`` / ``roi_lat`` (first index, second index).

    Returns:
        Array with longitude in column 0 and latitude in column 1.
    """
    # FIXME: clearly optimisable.
    # This function belongs in something like a utils.py module.
    first_idx = candidates[..., 0]
    second_idx = candidates[..., 1]
    lon_values = roi_lon[first_idx, second_idx]
    lat_values = roi_lat[first_idx, second_idx]
    return np.array([lon_values, lat_values]).T

def detect(data, longitude, latitude, roi_U, roi_V, a=3, b=2, levels=None, debug=False):
    '''
    Interface function: forwards every argument unchanged to
    _detect_a_and_c and returns its result.
    '''
    return _detect_a_and_c(
        data=data,
        longitude=longitude,
        latitude=latitude,
        roi_U=roi_U,
        roi_V=roi_V,
        a=a,
        b=b,
        levels=levels,
        debug=debug,
    )

def _contour_levels(levels, negate=False):
    '''Normalise user contour levels for one detection pass.

    None and scalars (e.g. a number of levels) are returned unchanged.
    A sequence is returned as an ascending list of floats, negated first when
    ``negate`` is True so that it describes the same contour lines on a
    negated field.
    '''
    if levels is None or np.ndim(levels) == 0:
        return levels
    values = np.asarray(levels, dtype=float)
    if negate:
        values = -values
    return np.sort(values).tolist()


def _detect_a_and_c(data, longitude, latitude, roi_U, roi_V, a, b, levels, debug):
    '''
    Run the detection once per rotation sense and pack both results into an
    EddyDectionResult.

    The "a" pass works on the negated data with the original velocities, the
    "c" pass on the original data with negated velocities. ``levels`` refers
    to the original data, so the "a" pass receives the negated levels.
    '''
    # Crude but effective: the whole pipeline simply runs twice (high cost).
    eddies_a, results_a = _detect(-data, roi_U, roi_V, a, b, _contour_levels(levels, negate=True), debug)
    eddies_c, results_c = _detect(data, -roi_U, -roi_V, a, b, _contour_levels(levels), debug)
    return EddyDectionResult((eddies_a, eddies_c), longitude, latitude, data, (results_a, results_c))


def _detect(data, roi_U, roi_V, a, b, levels, debug):
    '''
    Run the four constraints in sequence and build Eddy objects.

    Args:
        data: 2-D field contoured around the detected centers.
        roi_U, roi_V: velocity components.
        a, b: detection parameters forwarded to the constraints.
        levels: contour levels forwarded to create_contours (previously
            accepted but never used).
        debug: forwarded to lastYakusoku.

    Returns:
        (eddies, results): the list of Eddy objects, and the geographic
        coordinates of the candidates left after each of the four constraints.
    '''
    candidate = Election(roi_V)
    # Only the "increasing" candidates of the first constraint are used.
    candidate2, _decreasing = firstYakusoku(roi_V, candidate, a=a)
    candidate3 = secondYakusoku(roi_U, candidate2, a=a, flag=1)
    candidate4 = thirdYakusoku(roi_U, roi_V, candidate3, b=b)
    candidate5 = lastYakusoku(roi_U, roi_V, candidate4, a=a, debug=debug)

    centers = xy2geo(candidate5)
    contours = create_contours(data, centers, levels=levels)
    results = [xy2geo(c) for c in (candidate2, candidate3, candidate4, candidate5)]

    eddies = [Eddy(center=center, contour=contour) for center, contour in zip(centers, contours)]
    return eddies, results

def create_contours(data, centers, levels=None, max_shape_error=0.55):
    '''
    Find, for every center, the first closed contour line of ``data`` that
    contains it and is round enough.

    Contours are drawn with ``plt.contour`` over the module-level
    ``roi_lon`` / ``roi_lat`` grids (neither ``plt`` nor these arrays are
    defined in the visible part of this file). As a side effect the contour
    lines are added to the current Matplotlib axes.

    Args:
        data: 2-D field to contour.
        centers: iterable of points tested with ``Path.contains_point``.
        levels: contour levels; None, an empty sequence or a falsy scalar
            selects 2000 evenly spaced levels between the NaN-ignoring
            minimum and maximum of ``data``.
        max_shape_error: a contour is accepted when shape_error() is
            strictly below this value.

    Returns:
        List with one float32 (k, 2) vertex array per center, in the order of
        ``centers``. A center without an accepted contour gets an empty
        (0, 2) array (previously a list of Path objects was left in place,
        which is not usable as a contour array).
    '''
    # `not levels` is ambiguous for ndarrays, so the default is selected explicitly.
    if levels is None or np.size(levels) == 0 or (np.ndim(levels) == 0 and not levels):
        levels = np.linspace(np.nanmin(data), np.nanmax(data), 2000)

    # FIXME: high cost.
    # NOTE: ContourSet.collections is deprecated in recent Matplotlib releases.
    contour_set = plt.contour(roi_lon, roi_lat, data, cmap='rainbow', levels=levels)

    # Keep only closed paths: 79 is the CLOSEPOLY path code.
    closed_paths = [
        path
        for collection in contour_set.collections
        for path in collection.get_paths()
        if path.codes is not None and len(path.codes) and path.codes[-1] == 79
    ]

    res = []
    for p in centers:
        chosen = np.empty((0, 2), dtype=np.float32)
        for path in closed_paths:
            if not path.contains_point(p):
                continue
            cnt = path.vertices.astype(np.float32)
            if shape_error(cnt) < max_shape_error:
                chosen = cnt
                break
        res.append(chosen)
    return res

In [None]:
def grid(data, lon, lat, roi_lon0, roi_lon1, roi_lat0, roi_lat1, step_lon, step_lat=None):
    '''
    Cut a region of interest out of ``data`` and resample it on a regular
    grid with bicubic spline interpolation.

    Uses ``scipy.interpolate.RectBivariateSpline`` in place of ``interp2d``,
    which is no longer available in current SciPy releases.

    Args:
        data: 2-D array indexed as data[lon_index, lat_index].
        lon, lat: 1-D coordinate arrays of the two axes of ``data``.
        roi_lon0, roi_lon1: longitude bounds; source points strictly inside
            (roi_lon0, roi_lon1) are used.
        roi_lat0, roi_lat1: latitude bounds, same convention.
        step_lon: longitude step of the output grid.
        step_lat: latitude step of the output grid; defaults to step_lon.

    Returns:
        Masked array of shape (n_lat, n_lon) sampled at
        np.arange(roi_lat0, roi_lat1, step_lat) x
        np.arange(roi_lon0, roi_lon1, step_lon). Values that are NaN or fall
        outside the [min, max] range of the input ``data`` are masked.
    '''
    minvalue = data.min()
    maxvalue = data.max()

    if step_lat is None:
        step_lat = step_lon

    # Indices of the source points strictly inside the region of interest.
    ix = np.flatnonzero((roi_lon0 < lon) & (lon < roi_lon1))
    iy = np.flatnonzero((roi_lat0 < lat) & (lat < roi_lat1))

    src_lon = np.asarray(lon[ix], dtype=float)
    src_lat = np.asarray(lat[iy], dtype=float)
    src = np.asarray(data[np.ix_(ix, iy)], dtype=float)      # (n_src_lon, n_src_lat)

    # The spline needs ascending coordinates (e.g. latitude may be stored descending).
    lon_order = np.argsort(src_lon)
    lat_order = np.argsort(src_lat)
    src_lon = src_lon[lon_order]
    src_lat = src_lat[lat_order]
    src = src[np.ix_(lon_order, lat_order)]

    # Build the interpolator (bicubic).
    interpolator = interpolate.RectBivariateSpline(src_lon, src_lat, src, kx=3, ky=3)
    print('done')

    # Output grid; sorted because the spline evaluates on ascending coordinates.
    X1 = np.sort(np.arange(roi_lon0, roi_lon1, step_lon))
    Y1 = np.sort(np.arange(roi_lat0, roi_lat1, step_lat))

    # The spline returns (n_lon, n_lat); callers expect (n_lat, n_lon).
    data = interpolator(X1, Y1).T

    mask = (np.isnan(data) | (data < minvalue) | (data > maxvalue))
    data = np.ma.array(data, mask=mask)
    return data

def write(data, roi_lon0, roi_lon1, roi_lat0, roi_lat1, step_lon, step_lat, filename='test2.nc'):
    '''
    Write a gridded field to a NetCDF4 file.

    The file gets the dimensions ``longitude`` (data.shape[1]) and
    ``latitude`` (data.shape[0]), the coordinate variables of the same names
    and the variable ``adt`` (units 'm') holding ``data``.

    Args:
        data: 2-D array laid out as (latitude, longitude).
        roi_lon0, roi_lon1, step_lon: the longitude coordinate is
            np.arange(roi_lon0, roi_lon1, step_lon); its length has to match
            data.shape[1].
        roi_lat0, roi_lat1, step_lat: same for latitude and data.shape[0].
        filename: output path; defaults to the previously hard-coded
            'test2.nc'. An existing file is overwritten (mode 'w').
    '''
    lon_values = np.arange(roi_lon0, roi_lon1, step_lon)
    lat_values = np.arange(roi_lat0, roi_lat1, step_lat)

    with Dataset(filename, 'w', format="NETCDF4") as dataset:
        dataset.createDimension('longitude', data.shape[1])
        dataset.createDimension('latitude', data.shape[0])

        lon = dataset.createVariable('longitude', 'f4', ('longitude',))
        lat = dataset.createVariable('latitude', 'f4', ('latitude',))
        adt_ = dataset.createVariable('adt', 'f4', ('latitude', 'longitude'))

        lon[:] = lon_values
        lat[:] = lat_values
        adt_.units = 'm'
        adt_[:] = data

    print('closed')