In [2]:

#%%
import torch
import onnxruntime as ort
import os
import numpy as np
import onnx
import time
import itertools

# Coordinate axes used to translate degrees into array indices:
# 721 latitudes running from 90 down to -90, and 1440 longitudes starting at
# -180 (the duplicate +180 endpoint is dropped).
lat_indices = np.linspace(90, -90, 721)
lon_indices = np.linspace(-180, 180, 1441)[:-1]


def latlon_extent(lon_min, lon_max, lat_min, lat_max):
    """Convert a lon/lat bounding box into index bounds on the grid axes.

    Both longitudes are shifted by -180 before being matched against
    ``lon_indices``; the returned ``extent`` holds the shifted longitudes.

    Returns:
        A tuple ``(lat_start, lat_end, lon_start, lon_end, extent,
        latlon_ratio)``. ``lat_start`` is the index nearest ``lat_max`` and
        ``lat_end`` the index nearest ``lat_min`` because ``lat_indices``
        is stored north to south. ``extent`` is
        ``[lon_min, lon_max, lat_min, lat_max]`` (shifted longitudes) and
        ``latlon_ratio`` is the box width divided by its height.
    """
    def nearest(axis, value):
        # Position of the axis entry closest to ``value``.
        return np.argmin(np.abs(axis - value))

    west = lon_min - 180
    east = lon_max - 180

    row_first = nearest(lat_indices, lat_max)
    row_last = nearest(lat_indices, lat_min)
    col_first = nearest(lon_indices, west)
    col_last = nearest(lon_indices, east)

    aspect = (east - west) / (lat_max - lat_min)
    bounds = [west, east, lat_min, lat_max]
    return row_first, row_last, col_first, col_last, bounds, aspect


lat_start, lat_end, lon_start, lon_end, extent, latlon_ratio = latlon_extent(250, 310, 5, 45)


# Initialization date/time components; the run loop forecasts every
# combination of these lists.
year, month, day, times = ['2012'], ['06'], ['23'], ['00']

# Ensemble member IDs to generate.
ens_list = range(100, 2050)
# Multipliers applied to each field's standard deviation to size the noise.
perturbation_scale_list = [0.1]
# Each inner list names the variables perturbed together in one experiment.
factor_list_list = [['z']]

# Root directory holding the ONNX models and the input data.
pangu_dir = r'/home1/jek/Pangu-Weather'

# Variable names and their position along the first axis of the
# surface / upper arrays.
surface_factor = ['MSLP', 'U10', 'V10', 'T2M']
surface_dict = {name: position for position, name in enumerate(surface_factor)}
upper_factor = ['z', 'q', 't', 'u', 'v']
upper_dict = {name: position for position, name in enumerate(upper_factor)}


# Session-level onnxruntime settings shared by both models.
options = ort.SessionOptions()
options.enable_mem_reuse = False
options.enable_mem_pattern = False
options.enable_cpu_mem_arena = True

# Increase the number for faster inference and more memory consumption
# options.intra_op_num_threads = 1

# CUDA provider settings for the first and the second GPU.
cuda_provider_options_gpu0 = dict(arena_extend_strategy='kSameAsRequested', device_id=0)
cuda_provider_options_gpu1 = dict(arena_extend_strategy='kSameAsRequested', device_id=1)

# Load the 6-hour and 24-hour Pangu-Weather models.
# NOTE: both sessions are created with the GPU 0 options; the GPU 1 options
# above are defined but not used here.
ort_session_6 = ort.InferenceSession(
    rf'{pangu_dir}/pangu_weather_6.onnx',
    sess_options=options,
    providers=[('CUDAExecutionProvider', cuda_provider_options_gpu0)],
)
ort_session_24 = ort.InferenceSession(
    rf'{pangu_dir}/pangu_weather_24.onnx',
    sess_options=options,
    providers=[('CUDAExecutionProvider', cuda_provider_options_gpu0)],
)


def _save_cropped(output_dir, lead_hours, upper, surface):
    """Write the regional crop of one forecast step below ``output_dir``.

    The crop uses the module-level ``lat_start``/``lat_end``/``lon_start``/
    ``lon_end`` index bounds (both ends inclusive). Files are written as
    ``upper/<lead_hours>h.npy`` and ``surface/<lead_hours>h.npy``.
    """
    rows = slice(lat_start, lat_end + 1)
    cols = slice(lon_start, lon_end + 1)
    np.save(os.path.join(output_dir, f'upper/{lead_hours}h'), upper[:, :, rows, cols])
    np.save(os.path.join(output_dir, f'surface/{lead_hours}h'), surface[:, rows, cols])


# Wall-clock reference for the cumulative timing printed after each member.
start = time.time()

# Ensemble generation: for every experiment (perturbed variables x noise
# scale x initialization time) and every member, add Gaussian noise to the
# selected input fields, roll the model out to +168 h in 6-hour steps, and
# save the regional crop of every step.
for factor_list in factor_list_list:
    # Directory-name suffix listing the perturbed variables, e.g. "_z".
    factor_str = "".join([f"_{f}" for f in factor_list])

    for perturbation_scale in perturbation_scale_list:
        for y, m, d, tm in itertools.product(year, month, day, times):
            time_str = f'{y}/{m}/{d}/{tm}UTC'
            input_data_dir = rf'{pangu_dir}/input_data/{time_str}'

            input_upper = np.load(os.path.join(input_data_dir, 'upper.npy')).astype(np.float32)
            input_surface = np.load(os.path.join(input_data_dir, 'surface.npy')).astype(np.float32)

            # Noise amplitude: the standard deviation over the last two axes
            # of each field, scaled by perturbation_scale. The upper array
            # yields one value per (variable, second-axis index) pair
            # (presumably pressure levels -- confirm against the input
            # files); the surface array yields one value per variable.
            std_dev_upper = np.std(input_upper, axis=(2, 3), dtype=np.float32) * perturbation_scale
            std_dev_surface = np.std(input_surface, axis=(1, 2), dtype=np.float32) * perturbation_scale

            for ens in ens_list:
                output_data_dir = rf'/data03/Pangu_TC_ENS/output_data/{time_str}/{perturbation_scale}ENS{factor_str}/{ens}'
                os.makedirs(os.path.join(output_data_dir, 'upper'), exist_ok=True)
                os.makedirs(os.path.join(output_data_dir, 'surface'), exist_ok=True)

                perturbed_upper = input_upper.copy()
                perturbed_surface = input_surface.copy()

                # Member 0 is kept as the unperturbed control run.
                if ens != 0:
                    # Seed from the UTF-8 bytes of the experiment key. The
                    # built-in hash() is salted per process for strings, so
                    # it cannot give the same seed on a later run; a
                    # sequence of ints fed to default_rng can.
                    seed_bytes = repr((ens, tuple(factor_list), perturbation_scale)).encode('utf-8')
                    rng = np.random.default_rng(list(seed_bytes))

                    for factor in factor_list:
                        if factor in upper_dict:
                            idx = upper_dict[factor]
                            # Independent noise for every slice along the
                            # second axis of the upper array.
                            for j in range(input_upper.shape[1]):
                                perturbation = rng.normal(0, std_dev_upper[idx, j], input_upper[idx, j].shape)
                                perturbed_upper[idx, j] = input_upper[idx, j] + perturbation.astype(np.float32)
                        elif factor in surface_dict:
                            idx = surface_dict[factor]
                            perturbation = rng.normal(0, std_dev_surface[idx], input_surface[idx].shape)
                            # The base field must be the surface input (the
                            # upper array has a different shape and meaning).
                            perturbed_surface[idx] = input_surface[idx] + perturbation.astype(np.float32)

                _save_cropped(output_data_dir, 0, perturbed_upper, perturbed_surface)

                # State at the most recent multiple of 24 h; it is the input
                # of the 24-hour model.
                perturbed_24, perturbed_surface_24 = perturbed_upper, perturbed_surface

                for i in range(28):
                    start_i = time.time()
                    predict_interval = 6*(i+1)
                    if (i+1) % 4 == 0:
                        # Every fourth step is produced by one 24-hour model
                        # step from the previous 24 h state instead of a
                        # fourth consecutive 6-hour step.
                        output, output_surface = ort_session_24.run(None, {'input': perturbed_24, 'input_surface': perturbed_surface_24})
                        perturbed_24, perturbed_surface_24 = output, output_surface
                    else:
                        # Other steps advance the latest state by 6 hours.
                        output, output_surface = ort_session_6.run(None, {'input': perturbed_upper, 'input_surface': perturbed_surface})

                    _save_cropped(output_data_dir, predict_interval, output, output_surface)

                    perturbed_upper, perturbed_surface = output, output_surface
                    end_i = time.time()
                    print(f'{factor_list} {perturbation_scale}_{ens}ENS {i+1}번째 반복 +{predict_interval}h {end_i-start_i}s')

                end = time.time()
                print(f"{factor_list} {perturbation_scale}_{ens}ENS: {end-start}s")

['z'] 0.1_100ENS 1번째 반복 +6h 2.434659242630005s
['z'] 0.1_100ENS 2번째 반복 +12h 2.2632029056549072s
['z'] 0.1_100ENS 3번째 반복 +18h 2.2601563930511475s
['z'] 0.1_100ENS 4번째 반복 +24h 2.430621385574341s
['z'] 0.1_100ENS 5번째 반복 +30h 2.2594118118286133s
['z'] 0.1_100ENS 6번째 반복 +36h 2.2603647708892822s
['z'] 0.1_100ENS 7번째 반복 +42h 2.2588560581207275s
['z'] 0.1_100ENS 8번째 반복 +48h 2.26133394241333s
['z'] 0.1_100ENS 9번째 반복 +54h 2.2607803344726562s
['z'] 0.1_100ENS 10번째 반복 +60h 2.2645153999328613s
['z'] 0.1_100ENS 11번째 반복 +66h 2.2623019218444824s
['z'] 0.1_100ENS 12번째 반복 +72h 2.261709451675415s
['z'] 0.1_100ENS 13번째 반복 +78h 2.261134147644043s
['z'] 0.1_100ENS 14번째 반복 +84h 2.262500047683716s
['z'] 0.1_100ENS 15번째 반복 +90h 2.2624857425689697s
['z'] 0.1_100ENS 16번째 반복 +96h 2.263079881668091s
['z'] 0.1_100ENS 17번째 반복 +102h 2.2621490955352783s
['z'] 0.1_100ENS 18번째 반복 +108h 2.2649497985839844s
['z'] 0.1_100ENS 19번째 반복 +114h 2.262261390686035s
['z'] 0.1_100ENS 20번째 반복 +120h 2.2760519981384277s
['z'] 0.1_100EN