# In [None]:
import numpy as np


#### rho -> PH
#### rho(lev,lat,lon)
def phanom(rhoa,rho0,eta,dz):
    """Accumulate a pressure-like anomaly level by level, starting at index 0.

    Parameters
    ----------
    rhoa : 3-D array indexed (lev, lat, lon) per the file's header comment.
        Entries equal to 0 are treated as missing and become NaN in the
        calculation. The caller's array is left untouched.
    rho0 : scalar divisor applied to every density term.
    eta : scalar or array broadcastable to one (lat, lon) slab; contributes
        ``g*eta`` to the first level.
    dz : sequence indexed by level (``dz[k]``).

    Returns
    -------
    ph : float array with the same shape as ``rhoa``.
    """
    kk, jj, ii = np.shape(rhoa)
    # Mask zeros in a fresh array instead of assigning NaN into the input:
    # in-place assignment would silently alter the caller's data and fails
    # outright when the input has an integer dtype.
    rho = np.where(rhoa == 0, np.nan, rhoa)
    ph = np.zeros([kk, jj, ii])
    g = 9.8
    # First level: g*eta plus the first layer's own contribution.
    ph[0, :, :] = g*eta + g*rho[0, :, :]*dz[0]/rho0
    # Deeper levels: previous level plus the mean of the two adjacent density
    # values times this level's thickness.
    for k in range(1, kk):
        ph[k, :, :] = ph[k-1, :, :] + g*(rho[k, :, :] + rho[k-1, :, :])/2*dz[k]/rho0
    return ph


#### rho -> PH method2
#### rho(lev,lat,lon)
def phanom_r(rhoanom,rho0,dz):
    """Accumulate ``g*rhoanom*dz/rho0`` from the last level index back to index 0.

    rhoanom is a 3-D array indexed (lev, lat, lon) per the file's header
    comment; dz is indexed by level; rho0 is a scalar divisor.  The returned
    array has the same shape as rhoanom: the last level holds only its own
    layer term, and every earlier level adds its layer term to the level
    after it.
    """
    gravity = 9.8
    nlev, nlat, nlon = np.shape(rhoanom)
    ph = np.zeros([nlev, nlat, nlon])

    last = nlev - 1
    ph[last, :, :] = gravity*rhoanom[last, :, :]*dz[last]/rho0

    # Walk from the second-to-last level towards index 0.
    for lev in range(last - 1, -1, -1):
        ph[lev, :, :] = ph[lev + 1, :, :] + gravity*rhoanom[lev, :, :]*dz[lev]/rho0
    return ph


#### u/v -> Ux/Uy
## u(lev,lat,lon)
def partialx3(u,dx,dy):
    """Centred differences of a (lev, lat, lon) field along its two horizontal axes.

    Returns ``(ux, uy)``, both float arrays shaped like ``u``:

    * ``ux`` differences along the last axis with spacing ``dx``; the first
      and last columns use their wrap-around neighbours.
    * ``uy`` differences along the middle axis with spacing ``dy``; the first
      and last rows are set to NaN.
    """
    nlev, nlat, nlon = np.shape(u)
    ux = np.zeros([nlev, nlat, nlon])
    uy = np.zeros([nlev, nlat, nlon])

    # (target column, western neighbour, eastern neighbour) for the two
    # columns that wrap around the last axis.
    wrap_columns = ((0, -1, 1), (-1, -2, 0))

    for lev in range(nlev):
        for col in range(1, nlon - 1):
            ux[lev, :, col] = (u[lev, :, col + 1] - u[lev, :, col - 1])/2/dx
        for col, west, east in wrap_columns:
            ux[lev, :, col] = (u[lev, :, east] - u[lev, :, west])/2/dx

        for row in range(1, nlat - 1):
            uy[lev, row, :] = (u[lev, row + 1, :] - u[lev, row - 1, :])/2/dy
        # No neighbour on one side of the first and last rows.
        uy[lev, [0, -1], :] = np.nan

    return ux, uy

#### u/v -> Uz
## u(lev,lat,lon)
def partialz3(u,dz):
    """Difference a (lev, lat, lon) field along its first axis.

    ``dz`` is indexed by level.  Interior levels use the two neighbouring
    levels, divided by ``dz[k] + dz[k+1]/2 + dz[k-1]/2``.  The first and last
    levels use a one-sided difference divided by half the sum of the two
    thicknesses involved.  Returns a float array shaped like ``u``.
    """
    nlev, nlat, nlon = np.shape(u)
    uz = np.zeros([nlev, nlat, nlon])

    for lev in range(1, nlev - 1):
        two_sided_gap = dz[lev] + dz[lev + 1]/2 + dz[lev - 1]/2
        uz[lev, :, :] = (u[lev + 1, :, :] - u[lev - 1, :, :])/two_sided_gap

    # One-sided differences at both ends of the level axis.
    first_gap = dz[0]/2 + dz[1]/2
    uz[0, :, :] = (u[1, :, :] - u[0, :, :])/first_gap
    last_gap = dz[-1]/2 + dz[-2]/2
    uz[-1, :, :] = (u[-1, :, :] - u[-2, :, :])/last_gap

    return uz


#### psi -> u v in MITgcm
#### psi(lat,lon)
def psi2uv(psi,dx,dy):
    """Derive (u, v) from a 2-D streamfunction indexed (lat, lon).

    ``u`` is minus the centred difference of ``psi`` along the first axis
    (spacing ``dy``); its first and last rows are NaN.  ``v`` is the centred
    difference along the second axis (spacing ``dx``); its first and last
    columns use their wrap-around neighbours.  Both outputs are float arrays
    shaped like ``psi``.
    """
    nrow, ncol = np.shape(psi)
    u = np.zeros([nrow, ncol])
    v = np.zeros([nrow, ncol])

    # u: differences across rows, undefined on the two boundary rows.
    for row in range(1, nrow - 1):
        u[row, :] = -(psi[row + 1, :] - psi[row - 1, :])/2/dy
    u[[0, -1], :] = np.nan

    # v: differences across columns, wrapping around at both ends.
    for col in range(1, ncol - 1):
        v[:, col] = (psi[:, col + 1] - psi[:, col - 1])/2/dx
    for col, west, east in ((0, -1, 1), (-1, -2, 0)):
        v[:, col] = (psi[:, east] - psi[:, west])/2/dx

    return u, v


## psi -> vorticity
def psi2vor(psi,dx,dy):
    """Five-point Laplacian of a 2-D streamfunction indexed (lat, lon).

    The second difference along the second axis is divided by ``dx*dx`` and
    the one along the first axis by ``dy*dy``.  The first and last columns
    use their wrap-around neighbours; the first and last rows are NaN.
    Returns a float array shaped like ``psi``.
    """
    nrow, ncol = np.shape(psi)
    vor = np.zeros([nrow, ncol])

    # Interior points.
    for col in range(1, ncol - 1):
        for row in range(1, nrow - 1):
            d2x = (psi[row, col + 1] + psi[row, col - 1] - 2*psi[row, col])/dx/dx
            d2y = (psi[row + 1, col] + psi[row - 1, col] - 2*psi[row, col])/dy/dy
            vor[row, col] = d2x + d2y

    vor[[0, -1], :] = np.nan

    # Boundary columns: same stencil with wrap-around neighbours along the
    # second axis, given as (target, western, eastern) column indices.
    for row in range(1, nrow - 1):
        for col, west, east in ((0, -1, 1), (-1, -2, 0)):
            d2x = (psi[row, east] + psi[row, west] - 2*psi[row, col])/dx/dx
            d2y = (psi[row + 1, col] + psi[row - 1, col] - 2*psi[row, col])/dy/dy
            vor[row, col] = d2x + d2y

    return vor

## psi -> OW
def psi2ow(psi,dx,dy):
    """Return ``Sn**2 + Ss**2 - vor**2`` for a 2-D streamfunction.

    Velocities come from ``psi2uv``; the strain components and the vorticity
    are then taken from ``uv2strain`` and ``uv2vor`` on those velocities.
    """
    u, v = psi2uv(psi, dx, dy)
    normal_strain, shear_strain = uv2strain(u, v, dx, dy)
    zeta = uv2vor(u, v, dx, dy)
    return normal_strain**2 + shear_strain**2 - zeta**2


#### psi -> u v in MITgcm
#### psi(lev,lat,lon)
def psi3dim2uv(psi,dx,dy):
    """Apply ``psi2uv`` to every level of a (lev, lat, lon) streamfunction.

    Returns ``(u, v)``, two float arrays shaped like ``psi``, where level
    ``k`` holds the result of ``psi2uv(psi[k, :, :], dx, dy)``.
    """
    nlev, nlat, nlon = np.shape(psi)
    u = np.zeros([nlev, nlat, nlon])
    v = np.zeros([nlev, nlat, nlon])

    for k in range(nlev):
        # psi[k, :, :] is already 2-D.  Squeezing it would drop a length-1
        # lat or lon axis and make psi2uv's two-value shape unpack fail.
        u[k, :, :], v[k, :, :] = psi2uv(psi[k, :, :], dx, dy)

    return u, v



## u, v -> divergence
def uv2div(u,v,dx,dy):
    """Horizontal divergence du/dx + dv/dy of a 2-D (lat, lon) velocity field.

    The x-derivative is a centred difference of ``u`` along the second axis
    (spacing ``dx``), wrapping around at the first and last columns.  The
    y-derivative is a centred difference of ``v`` along the first axis
    (spacing ``dy``).  The first and last rows are NaN.  Returns a float
    array shaped like ``u``.
    """
    a, b = np.shape(u)
    div = np.zeros([a, b])

    for j in range(1, a - 1):
        for i in range(b):
            # Negative indices wrap on the west side; the modulo wraps on
            # the east side, so edge columns share the interior stencil.
            west, east = i - 1, (i + 1) % b
            du_dx = (u[j, east] - u[j, west])/dx/2
            dv_dy = (v[j + 1, i] - v[j - 1, i])/dy/2
            div[j, i] = du_dx + dv_dy

    div[0, :] = div[-1, :] = np.nan

    return div

#### u, v -> divergence
#### u(lev,lat,lon)
def uv3div(u,v,dx,dy):
    """Apply ``uv2div`` to every level of (lev, lat, lon) velocity arrays.

    Returns a float array shaped like ``u`` whose level ``k`` is
    ``uv2div(u[k, :, :], v[k, :, :], dx, dy)``.
    """
    nlev, nlat, nlon = np.shape(u)
    div3 = np.zeros([nlev, nlat, nlon])
    for lev in range(nlev):
        div3[lev, :, :] = uv2div(u[lev, :, :], v[lev, :, :], dx, dy)
    return div3


## u,v -> vorticity
def uv2vor(u,v,dx,dy):
    """Relative vorticity dv/dx - du/dy of a 2-D (lat, lon) velocity field.

    dv/dx is a centred difference along the second axis (spacing ``dx``),
    using wrap-around neighbours in the first and last columns.  du/dy is a
    centred difference along the first axis (spacing ``dy``).  The first and
    last rows are NaN.  Returns a float array shaped like ``u``.
    """
    nrow, ncol = np.shape(u)
    vor = np.zeros([nrow, ncol])

    # Interior points.
    for col in range(1, ncol - 1):
        for row in range(1, nrow - 1):
            dv_dx = (v[row, col + 1] - v[row, col - 1])/dx/2
            du_dy = (u[row + 1, col] - u[row - 1, col])/dy/2
            vor[row, col] = dv_dx - du_dy

    vor[[0, -1], :] = np.nan

    # Boundary columns, given as (target, western, eastern) column indices.
    for row in range(1, nrow - 1):
        for col, west, east in ((0, -1, 1), (-1, -2, 0)):
            dv_dx = (v[row, east] - v[row, west])/dx/2
            du_dy = (u[row + 1, col] - u[row - 1, col])/dy/2
            vor[row, col] = dv_dx - du_dy

    return vor



## u,v -> normal/shear strain: Sn Ss
def uv2strain(u,v,dx,dy):
    """Normal and shear strain of a 2-D (lat, lon) velocity field.

    Returns ``(Sn, Ss)`` with

    * ``Sn = du/dx - dv/dy``
    * ``Ss = dv/dx + du/dy``

    x-derivatives are centred differences along the second axis (spacing
    ``dx``), wrapping around at the first and last columns.  y-derivatives
    are centred differences along the first axis (spacing ``dy``).  The first
    and last rows of both outputs are NaN.
    """
    a, b = np.shape(u)
    Sn = np.zeros([a, b])
    Ss = np.zeros([a, b])

    for j in range(1, a - 1):
        for i in range(b):
            # One stencil for every column: negative indices wrap on the
            # west side and the modulo wraps on the east side, so boundary
            # columns cannot drift out of sync with the interior formula.
            west, east = i - 1, (i + 1) % b
            du_dx = (u[j, east] - u[j, west])/dx/2
            dv_dx = (v[j, east] - v[j, west])/dx/2
            du_dy = (u[j + 1, i] - u[j - 1, i])/dy/2
            dv_dy = (v[j + 1, i] - v[j - 1, i])/dy/2
            Sn[j, i] = du_dx - dv_dy
            Ss[j, i] = dv_dx + du_dy

    Ss[0, :] = Ss[-1, :] = np.nan
    Sn[0, :] = Sn[-1, :] = np.nan

    return Sn, Ss

## u,v -> OW
def uv2ow(u,v,dx,dy):
    """Return ``Sn**2 + Ss**2 - vor**2`` for a 2-D velocity field.

    The strain components come from ``uv2strain`` and the vorticity from
    ``uv2vor``, both evaluated on the given ``u`` and ``v``.
    """
    normal_strain, shear_strain = uv2strain(u, v, dx, dy)
    zeta = uv2vor(u, v, dx, dy)
    return normal_strain**2 + shear_strain**2 - zeta**2


#### psi2waf
#### psi(lev,lat,lon)
#### psi_m -> mean state; psi_p -> prime state
def psi3dim2waf(psi_m,psi_p,dx,dy,rho):
    """Compute flux components (wx, wy) from mean and perturbation streamfunctions.

    Parameters
    ----------
    psi_m : mean-state streamfunction, indexed (lev, lat, lon) or (lat, lon).
    psi_p : perturbation streamfunction with the same layout as ``psi_m``.
    dx, dy : grid spacings passed on to ``psi2uv`` / ``psi3dim2uv`` and used
        in the centred differences below.
    rho : weighting factor.  Used as given for 3-D input; for 2-D input only
        ``rho[0, :, :]`` is used.

    Returns
    -------
    wx, wy : arrays of shape (lev, lat, lon); ``lev`` is 1 for 2-D input.

    Notes
    -----
    Differences along the last axis are only evaluated for columns
    ``1 .. cc-2`` and those along the middle axis for rows ``1 .. bb-2``; the
    corresponding terms stay zero on the remaining edge columns and rows.
    """
    is_3d = len(np.shape(psi_m)) == 3
    if is_3d:
        aa, bb, cc = np.shape(psi_m)
        up, vp = psi3dim2uv(psi_p, dx, dy)
        um, vm = psi3dim2uv(psi_m, dx, dy)
        rho_factor = rho
    else:
        bb, cc = np.shape(psi_m)
        aa = 1
        # Each component gets its own array.  A chained assignment such as
        # ``up = vp = np.zeros(...)`` would bind both names to one array, so
        # storing v would overwrite u.
        up = np.zeros([aa, bb, cc])
        vp = np.zeros([aa, bb, cc])
        um = np.zeros([aa, bb, cc])
        vm = np.zeros([aa, bb, cc])
        up[0, :, :], vp[0, :, :] = psi2uv(psi_p, dx, dy)
        um[0, :, :], vm[0, :, :] = psi2uv(psi_m, dx, dy)

        # Promote the perturbation streamfunction to a single-level 3-D array
        # so the slicing below works for both layouts.
        psi_p3 = np.zeros([aa, bb, cc])
        psi_p3[0, :, :] = psi_p
        psi_p = psi_p3
        rho_factor = rho[0, :, :]

    # NOTE(review): the normalisation uses the perturbation speed (up, vp).
    # If this is meant to follow a formulation that divides by the
    # mean-flow speed, (um, vm) would belong here -- confirm intent.
    speed = np.sqrt(up*up + vp*vp)

    wx_term1 = np.zeros([aa, bb, cc])
    wx_term2 = np.zeros([aa, bb, cc])
    wy_term1 = np.zeros([aa, bb, cc])
    wy_term2 = np.zeros([aa, bb, cc])

    # Terms multiplied by um use centred differences along the last axis.
    for i in range(1, cc - 1):
        dvp_dx = (vp[:, :, i + 1] - vp[:, :, i - 1])/2/dx
        dup_dx = (up[:, :, i + 1] - up[:, :, i - 1])/2/dx
        wx_term1[:, :, i] = um[:, :, i]*(vp[:, :, i]*vp[:, :, i] - psi_p[:, :, i]*dvp_dx)
        wy_term1[:, :, i] = um[:, :, i]*(-up[:, :, i]*vp[:, :, i] + psi_p[:, :, i]*dup_dx)

    # Terms multiplied by vm use centred differences along the middle axis.
    for j in range(1, bb - 1):
        dvp_dy = (vp[:, j + 1, :] - vp[:, j - 1, :])/2/dy
        dup_dy = (up[:, j + 1, :] - up[:, j - 1, :])/2/dy
        wx_term2[:, j, :] = vm[:, j, :]*(vp[:, j, :]*up[:, j, :] + psi_p[:, j, :]*dvp_dy)
        wy_term2[:, j, :] = vm[:, j, :]*(up[:, j, :]*up[:, j, :] + psi_p[:, j, :]*dup_dy)

    wx = (wx_term1 - wx_term2)*rho_factor/speed/2
    wy = (wy_term1 + wy_term2)*rho_factor/speed/2

    return wx, wy

#### Flatten a multi-dimensional Axes array to one dimension
def trim_axs(axs, N):
    """
    Keep the first *N* entries of *axs* in flattened order.

    Every entry after the first *N* has its ``remove()`` method called; the
    kept entries are returned as a one-dimensional sequence.
    """
    flat_axes = axs.flat
    surplus = flat_axes[N:]
    for extra in surplus:
        extra.remove()
    kept = flat_axes[:N]
    return kept