In [None]:
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import pandas as pd
import proplot as pplt
from scipy.fftpack import * 

2021.08.08

尝试使用proplot绘制平均图

使用 pre_process/merge_cmorph_cn051.ipynb 中合并得到的 CMORPH 与 CN05.1 数据作为观测

对每个格点，分别计算 AM（4–5 月）和 JJA（6–8 月）两个时段 20 年（1998–2017）季节平均降水年际时间序列的相关系数（CORR）与均方根误差（RMSE），并绘图

## 数据读入

In [None]:
# ----- input files (1998-2017 precipitation) -----
dir_in = "/raid52/yycheng/MPAS/REFERENCE/TEMP_DATA_large/pre/ordata/"
filename_obs = "obsmerge_pre_98-17.nc"
filename_vr = "vr_pre_98-17.nc"
filename_rcm = "rcm_pre_98-17.nc"

ds_or = {}
ds_or['obs'] = xr.open_dataset(dir_in + filename_obs)
ds_or['vr'] = xr.open_dataset(dir_in + filename_vr)
ds_or['rcm'] = xr.open_dataset(dir_in + filename_rcm)

# Extract the precipitation variable from each dataset.
var = {}
var['obs'] = ds_or['obs']['premerge']
var['vr'] = ds_or['vr']['precip_MPAS']
var['rcm'] = ds_or['rcm']['precip_MPAS']

# Drop the redundant 'lev' coordinate carried by the observations.
var['obs'] = var['obs'].reset_coords(names='lev', drop=True)

# ----- put every dataset on the observation coordinates -----
# NOTE: coordinate names are paired *positionally* with the observation ones,
# so every file must list its coordinates in the same order as the observations.
var_list = ['obs', 'vr', 'rcm']
obs_coords = var['obs'].coords
for i in var_list:
    rename_dict = dict(zip(var[i].coords.keys(), obs_coords.keys()))
    # Show which coordinate is renamed to which.
    for rename_i in rename_dict:
        print(rename_i + " -----converting to----- " + rename_dict[rename_i])

    var[i] = var[i].rename(rename_dict)
    # Replace the coordinate values with the observation ones so that later
    # xarray alignment (xr.corr, arithmetic) matches grid points exactly.
    # This goes through the public API instead of assigning the private
    # ``_coords`` attribute, which bypasses xarray's index bookkeeping and
    # size checks. Coordinates the observations do not have are dropped,
    # as the old ``_coords`` swap effectively did.
    extra_coords = [name for name in var[i].coords if name not in obs_coords]
    var[i] = var[i].drop_vars(extra_coords).assign_coords(obs_coords)
    var[i] = var[i].rename(i)

In [None]:
# Boolean masks on the (shared) time axis for the two seasons.
time_idx_am = var['obs'].time.dt.month.isin([4, 5])        # April-May
time_idx_jja = var['obs'].time.dt.month.isin([6, 7, 8])    # June-August

# var_selmonth[season][dataset] -> that dataset restricted to the season's months.
var_selmonth = {'am': {}, 'jja': {}}
for mod_name in ('obs', 'vr', 'rcm'):
    full_series = var[mod_name]
    var_selmonth['am'][mod_name] = full_series.isel(time=time_idx_am)
    var_selmonth['jja'][mod_name] = full_series.isel(time=time_idx_jja)


## 计算年际相关系数

In [None]:
# Build the interannual series var_interannual[season][dataset]: the seasonal
# values averaged within each calendar year. The year labels are taken from the
# VR time axis and reused for every dataset.
var_interannual = {}
time_for_groupby = {
    season: var_selmonth[season]['vr'].time.dt.year
    for season in ('am', 'jja')
}

for iseason in ('am', 'jja'):
    year_labels = time_for_groupby[iseason]
    var_interannual[iseason] = {
        mod_name: var_selmonth[iseason][mod_name].groupby(year_labels).mean(dim='time')
        for mod_name in ('obs', 'vr', 'rcm')
    }

In [None]:
def scipy_count_corr_2d(a, b):
    """Grid-point Pearson correlation along the leading (time) axis, with p-values.

    ``scipy.stats.pearsonr`` is applied to ``a[:, i, j]`` and ``b[:, i, j]``
    for every point ``(i, j)`` of the two trailing axes.

    Parameters
    ----------
    a, b : array_like
        3-D arrays of identical shape, ordered as (time, lat, lon).

    Returns
    -------
    list of numpy.ndarray
        ``[corrvalues, pvalues]``, each of shape ``a.shape[1:]``. Points where
        either series contains a NaN at any time step are set to NaN in both.

    Raises
    ------
    ValueError
        If ``a`` and ``b`` are not 3-D arrays of the same shape.
    """
    import numpy as np
    from scipy import stats

    a = np.asarray(a)
    b = np.asarray(b)
    if a.ndim != 3 or a.shape != b.shape:
        raise ValueError(
            f"expected two 3-D arrays of equal shape, got {a.shape} and {b.shape}"
        )

    # Start from all-NaN so skipped (missing) points need no extra handling.
    corrvalues = np.full(a.shape[1:], np.nan)
    pvalues = np.full(a.shape[1:], np.nan)

    # A NaN anywhere in either series makes the statistics undefined, not just
    # a NaN at the first time step, so mask those points out up front.
    invalid = np.isnan(a).any(axis=0) | np.isnan(b).any(axis=0)
    for ilat, ilon in np.argwhere(~invalid):
        r, p = stats.pearsonr(a[:, ilat, ilon], b[:, ilat, ilon])
        corrvalues[ilat, ilon] = r
        pvalues[ilat, ilon] = p
    return [corrvalues, pvalues]

In [None]:
# Interannual correlation between each model and the observations at every grid
# point, plus its p-value (from scipy, since xr.corr gives no significance).
corr_interannual = {}

for iseason in ['am', 'jja']:
    corr_interannual[iseason] = {}
    obs_series = var_interannual[iseason]['obs']
    for mod_name in ['vr', 'rcm']:
        mod_series = var_interannual[iseason][mod_name]
        corr_da = xr.corr(obs_series, mod_series, dim='year')

        # scipy_count_corr_2d wants plain (year, lat, lon) arrays; order the
        # spatial dims exactly like `corr_da` so the p-values line up with it.
        dim_order = ('year',) + corr_da.dims
        pvalues_scipy = scipy_count_corr_2d(
            obs_series.transpose(*dim_order).values,
            mod_series.transpose(*dim_order).values,
        )[1]

        corr_interannual[iseason][mod_name] = {}
        corr_interannual[iseason][mod_name]['corr'] = corr_da
        # Pass dims explicitly: letting xarray infer them from the key order of
        # the coords mapping can mislabel (or reject) the lat/lon axes.
        corr_interannual[iseason][mod_name]['pvalues'] = xr.DataArray(
            pvalues_scipy, coords=corr_da.coords, dims=corr_da.dims, name='pvalues')

### 计算年际均方根误差

In [None]:
# Interannual RMSE between each model and the observations at every grid point:
# sqrt(mean over years of (model - obs)^2).
rmse_interannual = {}

for iseason in ['am', 'jja']:
    rmse_interannual[iseason] = {}
    obs_series = var_interannual[iseason]['obs']
    for mod_name in ['vr', 'rcm']:
        sq_err = (var_interannual[iseason][mod_name] - obs_series) ** 2
        # skipna=False keeps the previous sum / nyears semantics: a NaN in any
        # year yields NaN at that grid point instead of being silently skipped.
        # Working in xarray keeps the result's dims and coords attached to the data.
        rmse_interannual[iseason][mod_name] = np.sqrt(
            sq_err.mean(dim='year', skipna=False)).rename('rmse')

## 绘图部分

In [None]:
# 调整cmap，去掉gist_ncar 中深蓝色的部分
# https://stackoverflow.com/questions/18926031/how-to-extract-a-subset-of-a-colormap-as-a-new-colormap-in-matplotlib

import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np
import cmaps

def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Return a new colormap made from the [minval, maxval] slice of `cmap`.

    `n` evenly spaced colors are sampled from `cmap` between `minval` and
    `maxval` (fractions of the colormap range) and packed into a
    LinearSegmentedColormap named ``trunc(<name>,<minval>,<maxval>)``.
    """
    sample_points = np.linspace(minval, maxval, n)
    sampled_colors = cmap(sample_points)
    truncated_name = 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval)
    return colors.LinearSegmentedColormap.from_list(truncated_name, sampled_colors)

# Colormaps used by the figure below.
# (alternatives tried for correlation: plt.get_cmap('bwr'), cmaps.GMT_panoply)
cmap_corr = cmaps.ncl_default
cmap_rmse = plt.get_cmap('gist_ncar')
new_cmap_corr = truncate_colormap(cmap_corr, 0., 1.)    # full range, resampled to 100 colors
new_cmap_rmse = truncate_colormap(cmap_rmse, 0.4, 1.)   # drop the lower (dark-blue) 40% of gist_ncar

# Quick visual check: preview exactly the colormaps that the figure uses.
arr = np.linspace(0, 50, 100).reshape((10, 10))
fig, ax = plt.subplots(ncols=2)
ax[0].imshow(arr, interpolation='nearest', cmap=new_cmap_corr)
ax[0].set_title('CORR colormap')
ax[1].imshow(arr, interpolation='nearest', cmap=new_cmap_rmse)
ax[1].set_title('RMSE colormap')
plt.show()

In [None]:
# 国内政区图的绘制
# Load the border data, CN-border-La.dat is download from
# https://gmt-china.org/data/CN-border-La.dat
import cartopy.crs as ccrs
import cartopy.io.shapereader as shpreader
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.patches as mpatches

cn_border_file = "/m2data2/yycheng/data_stage/CN-border/CN-border_line/CN-border-La.dat"
with open(cn_border_file) as border_src:
    context = border_src.read()
# Segments are separated by '>' (presumably GMT multi-segment format); skip empty pieces.
blocks = [segment for segment in context.split('>') if segment]
# Parse each segment's whitespace-separated numbers into a flat 1-D float array.
borders = [np.fromstring(segment, dtype=float, sep=' ') for segment in blocks]


### 画图renew
#### 2021.08.08
cn_border_file + shapefile ，但是存在重叠，尝试消除掉China边界，但是其他邻国边界无法处理；
#### 2021.08.09

不使用cn_border_file ，使用shapefile + coast_line（proplot自带） 的办法

shapefile有一些重叠，不绘制行政区

shapefile重新进行绘制，考虑来自 domain_info 中测试的多个shape file中挑选出地资所（改变了prj方式之后就可以正常绘制，具体查看prj后缀文件）进行使用

如果都使用环资所的全球、全国数据，那么是不会出现问题的

#### 2021.08.13
绘制相关性时，使用 hatch 打点比较困难，这里直接使用 scatter 打点，并对过于密集的格点手动降采样（[::4]，即每 4 个格点取 1 个），散点间隔约为 1°

In [None]:
# # Prototype of the scatter-based significance stippling (kept for reference).
# # NOTE(review): this uses p < 0.1 and subsamples scatter_test a second time with [::4];
# # the plotting cell uses p < 0.05 and indexes the subsampled lon/lat directly.
# test_or = corr_interannual[iseason][mod_name]['pvalues'][::4,::4]
# test = test_or < 0.1
# scatter_test = np.argwhere(test.values)
# # Once the significant points are found, plot them at the matching lon/lat
# plt.scatter(lon[::4][ scatter_test[:,1][::4] ],lat[::4][ scatter_test[:,0][::4] ], s=0.1, color = 'k', marker='o')

In [None]:
# import proplot as plot
from matplotlib import pyplot as plt
import proplot as plot
import os

# ----- grid coordinates shared by all panels -----
lon = corr_interannual['am']['vr']['corr'].lon.values
lat = corr_interannual['am']['vr']['corr'].lat.values

# ----- create the figure -----
# 2 rows (AM, JJA) x 4 columns (CORR: VR, RCM | RMSE: VR, RCM), cylindrical projection.
fig, axs = plot.subplots(ncols=4, nrows=2, proj='cyl')

# ----- rivers and administrative boundaries from shapefiles -----
# Shapefile download reference:
# http://gaohr.win/site/blogs/2017/2017-04-18-GIS-basic-data-of-China.html
river_border_shapefile = "/raid52/yycheng/MPAS/REFERENCE/MODEL_CONSTANT/R1/" + "hyd1_4l.shp"
southsea_shapefile = "/m2data2/yycheng/data_stage/CN-border/SouthSea/" + "southsea_island.shp"  # not drawn
ninelines_shapefile = "/m2data2/yycheng/data_stage/CN-border/SouthSea/" + "nine_lines.shp"
# Provinces + islands (source: Peipei); not drawn (overlaps)
bou24p_shapefile = "/m2data2/yycheng/data_stage/CN-border/peipeihelp/" + "bou2_4p.shp"
# Provinces, source: https://www.resdc.cn/data.aspx?DATAID=200
province_shapefile = "/m2data2/yycheng/data_stage/CN-border/CN-sheng/" + "change_proj_CN-sheng-A.shp"

# Read each shapefile once. Reader.geometries() returns a one-shot iterator,
# so keep the geometries in a list that is reused for every panel.
river_geoms = list(shpreader.Reader(river_border_shapefile).geometries())
ninelines_geoms = list(shpreader.Reader(ninelines_shapefile).geometries())
province_geoms = list(shpreader.Reader(province_shapefile).geometries())

for ax in axs:
    ax.add_geometries(river_geoms, ccrs.PlateCarree(), facecolor='none', edgecolor='b', linewidth=0.4, zorder=1)
    ax.add_geometries(province_geoms, ccrs.PlateCarree(), facecolor='none', edgecolor='k', linewidth=0.6, zorder=1)
    ax.add_geometries(ninelines_geoms, ccrs.PlateCarree(), facecolor='none', edgecolor='k', linewidth=0.6, zorder=1)

# ----- contour levels (also used as colorbar ticks) -----
corr_ticks = np.linspace(-0.5, 1, 16)  # -0.5 .. 1.0 every 0.1
# Fine levels (every 0.5) up to 10, coarser steps above.
rmse_ticks = np.concatenate((np.linspace(0, 10, 21), [12, 14, 16, 18, 20, 25]), axis=0)

# Significance stippling: mark points with p < SIG_LEVEL, keeping only every
# SIG_STRIDE-th grid point in each direction so the dots are not too dense.
SIG_LEVEL = 0.05
SIG_STRIDE = 4

# ----- panel titles -----
axs[0, :].format(ltitle='AM')
axs[1, :].format(ltitle='JJA')
axs[:, 0].format(title='VR')
axs[:, 1].format(title='RCM')
axs[:, 2].format(title='VR')
axs[:, 3].format(title='RCM')
axs[0, 1].format(rtitle='CORR')
axs[0, 3].format(rtitle='RMSE')

# ----- filled contours -----
for season_ind, season_name in enumerate(['am', 'jja']):
    for mod_ind, mod_name in enumerate(['vr', 'rcm']):
        # Correlation map (columns 1-2) with proplot's diverging 'midpoint' norm.
        m_corr = axs[season_ind, mod_ind].contourf(
            lon, lat, corr_interannual[season_name][mod_name]['corr'].values,
            levels=corr_ticks, cmap=new_cmap_corr, norm="midpoint")

        # Stipple significant points on the subsampled grid; the index pairs
        # assume the p-value array is ordered (lat, lon).
        pvalues_sub = corr_interannual[season_name][mod_name]['pvalues'][::SIG_STRIDE, ::SIG_STRIDE]
        sig_idx = np.argwhere((pvalues_sub < SIG_LEVEL).values)
        axs[season_ind, mod_ind].scatter(
            lon[::SIG_STRIDE][sig_idx[:, 1]], lat[::SIG_STRIDE][sig_idx[:, 0]],
            s=0.15, color='k', marker='o')

        # RMSE map (columns 3-4).
        m_rmse = axs[season_ind, mod_ind + 2].contourf(
            lon, lat, rmse_interannual[season_name][mod_name].values,
            levels=rmse_ticks, cmap=new_cmap_rmse)

# ----- colorbars (one per metric, spanning its two columns) -----
fig.colorbar(m_corr, loc='b', width=0.1, cols=(1, 2),
             ticklabelsize=5, ticks=corr_ticks, title='correlation')

fig.colorbar(m_rmse, loc='b', width=0.1, cols=(3, 4),
             ticklabelsize=5, ticks=rmse_ticks, title='RMSE')

# ----- map / axis formatting -----
axs.format(
    abc=True,
    abcloc='ul',
    # base map
    reso='med',
    coast=True,
    coastlinewidth=0.4,
    borders=False,
    lakes=False,
    land=False,
    ocean=False,
    labels=True,
    longrid=True,
    latgrid=True,
    # geographic extent and grid lines
    lonlim=(70, 140), latlim=(5, 60),
    gridlabelsize=5,
    gridminor=True,
    lonlocator=np.arange(70, 142, 10),
    latlocator=np.arange(5, 60 + 2, 10),
    lonminorlocator=np.arange(70, 140 + 2, 2),
    latminorlocator=np.arange(5, 60 + 2, 2),
    suptitle="interannual TCC & RMSE",
)

# ----- save figure -----
fig.patch.set_facecolor('white')
os.makedirs('./output_pic', exist_ok=True)  # savefig fails if the directory is missing
fig.savefig('./output_pic/pre_corr_98-17_pplt_0913.png', dpi=600, facecolor="white")