In [None]:
import warnings
warnings.simplefilter('ignore')
import os
import gc
import sys
import glob
from multiprocessing import Pool
import netCDF4 as nc
import numpy as np
from tqdm import tqdm
import torch

In [None]:
# Data assembly.
# One (netCDF variable name, directory name) pair per pressure-level field,
# listed in the order the fields are stacked along axis 1 of the result.
_PRESSURE_LEVEL_FIELDS = (
    ('z', 'geopotential'),
    ('t', 'temperature'),
    ('u', 'u_component_of_wind'),
    ('v', 'v_component_of_wind'),
    ('q', 'specific_humidity'),
    ('ciwc', 'specific_cloud_ice_water_content'),
    ('clwc', 'specific_cloud_liquid_water_content'),
    ('crwc', 'specific_rain_water_content'),
    ('cswc', 'specific_snow_water_content'),
)

# Shape every assembled day must have.
# NOTE(review): presumably 4 time steps x (9 fields * 13 levels) x 46 x 71
# grid points -- confirm against the source files.
_EXPECTED_DAY_SHAPE = (4, 117, 46, 71)


def _read_field(var_name, dir_name, day):
    """Read one variable for ``day`` and return its cropped raw values."""
    file_path = f'pressure_level/{dir_name}/{day[:4]}/{day}.nc'
    # The context manager closes the file handle; the original left every
    # dataset open, which leaks descriptors in long-running pool workers.
    with nc.Dataset(file_path) as file_obj:
        # Slice the variable directly so only the window of interest is read.
        values = file_obj.variables[var_name][:, :, 35:81, 70:141]
    # Strip the mask (if any) and keep the underlying data, like ``.data``
    # did, while also accepting a plain ndarray.
    return np.ma.getdata(values)


def assemble(day):
    '''
    Stack all pressure-level fields of one day into a single array.

    For each field in ``_PRESSURE_LEVEL_FIELDS`` the file
    ``pressure_level/<field>/<year>/<day>.nc`` is read, cropped to
    ``[:, :, 35:81, 70:141]`` and the results are concatenated along axis 1.

    Args:
        day: day identifier such as '20200320'; its first four characters
            are used as the year sub-directory.

    Returns:
        numpy array of shape (4, 117, 46, 71).

    Raises:
        AssertionError: if the assembled array does not have that shape.
    '''
    arr = np.concatenate(
        [_read_field(var_name, dir_name, day)
         for var_name, dir_name in _PRESSURE_LEVEL_FIELDS],
        axis=1,
    )
    # Sanity check. Raised explicitly (instead of a bare ``assert``) so it is
    # still enforced under ``python -O``; the exception type is unchanged.
    if arr.shape != _EXPECTED_DAY_SHAPE:
        raise AssertionError(
            f'unexpected shape {arr.shape} for day {day}, '
            f'expected {_EXPECTED_DAY_SHAPE}'
        )
    return arr


os.makedirs('prepared_data', exist_ok=True)
# Day identifiers are the geopotential file names without their extension.
# os.path is used (rather than splitting on '/') so this also works where
# glob returns backslash-separated paths.
all_days = sorted(
    os.path.splitext(os.path.basename(path))[0]
    for path in glob.glob('pressure_level/geopotential/*/*.nc')
)
if not all_days:
    # Fail early with a clear message; otherwise np.concatenate below raises
    # an opaque "need at least one array to concatenate".
    raise FileNotFoundError(
        "no input files match 'pressure_level/geopotential/*/*.nc'"
    )
# Assemble the days in parallel; imap keeps the results in day order.
with Pool(8) as pool:
    arr = list(
        tqdm(
            pool.imap(assemble, all_days),
            total=len(all_days),
            desc="Generating data",
        )
    )
# Join the per-day arrays along the first axis.
arr = np.concatenate(arr)
print('processed data shape:', arr.shape)
np.save('prepared_data/data', arr)

In [None]:
# Statistics are reduced over every axis except the channel axis (axis 1).
_reduce_axes = (0, 2, 3)

# x normalization (channels with nunique == 1 are no longer used)
_constant_channels = {78, 79, 80, 91, 92, 93}
x_indics = [ch for ch in range(117) if ch not in _constant_channels]
x_arr = arr[:, x_indics, :, :]
x_mean = np.mean(x_arr, axis=_reduce_axes, keepdims=True)
x_std = np.std(x_arr, axis=_reduce_axes, keepdims=True)
np.save('prepared_data/norm_mean_x', x_mean)
np.save('prepared_data/norm_std_x', x_std)

# y normalization
y_indics = [
    16, 20, 22, 23, 25,
    55, 59, 61, 62, 64,
    68, 72, 74, 75, 77,
    81, 85, 87, 88, 90,
    94, 98, 100, 101, 103,
    107, 111, 113, 114, 116,
]
y_arr = arr[:, y_indics, :, :]
y_mean = np.mean(y_arr, axis=_reduce_axes, keepdims=True)
y_std = np.std(y_arr, axis=_reduce_axes, keepdims=True)
np.save('prepared_data/norm_mean_y', y_mean)
np.save('prepared_data/norm_std_y', y_std)