In [1]:
import sys,os
import torch
import yaml
import logging
from pydantic import ValidationError

sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(os.path.abspath(os.path.dirname(os.getcwd()))))))
from datasets.weather_bench import WeatherDataset
from models.VariableEncoder.datasets.dataset import CustomDataset
from models.VariableEncoder.training.configs import TrainingConfig
from models.VariableEncoder.training.configs import TrainingRunConfig

def get_normal_dataset(config: TrainingConfig):
    """Build the dataset described by ``config`` together with its statistics and variable codes.

    Args:
        config: Training settings. The fields read here are ``tgt_time_len``,
            ``train_offset``, ``air_variable``, ``surface_variable``,
            ``only_input_variable`` and ``constant_variable``.

    Returns:
        A 3-tuple ``(dataset, mean_std, (src_var_list, tgt_var_list))``:
        ``dataset`` is a ``CustomDataset`` wrapping the loaded source/label
        data, ``mean_std`` is the third value returned by
        ``WeatherDataset.load``, and the two lists are the results of
        ``WeatherDataset.get_var_code`` for the source and target variables.
    """
    # The multiplier of 1 is kept from the original so the value is unchanged.
    target_len = 1 * config.tgt_time_len

    # Use the GPU when one is visible, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    weather = WeatherDataset(config.train_offset, device=device)
    # Original author's note (not verifiable from this cell):
    # dataset.shape: torch.Size([7309, 100, 1450])
    source, label, mean_std = weather.load(
        config.air_variable,
        config.surface_variable,
        config.only_input_variable,
        config.constant_variable,
    )

    # Input-only and constant variables are counted together and handed to
    # CustomDataset as ``n_only_input``.
    input_only_count = len(config.only_input_variable) + len(config.constant_variable)
    dataset = CustomDataset(source, label, target_len, n_only_input=input_only_count)

    # Source codes cover every non-air variable group; target codes cover
    # only the surface variables alongside the air variables.
    source_extras = (
        config.surface_variable
        + config.only_input_variable
        + config.constant_variable
    )
    src_var_list = weather.get_var_code(config.air_variable, source_extras)
    tgt_var_list = weather.get_var_code(config.air_variable, config.surface_variable)

    var_lists = (src_var_list, tgt_var_list)
    return dataset, mean_std, var_lists

# The training config lives in ``configs/train_config.yaml`` under the parent
# of the current working directory.
config_path = os.path.join(os.path.dirname(os.getcwd()), 'configs/train_config.yaml')

# Load and validate the run configuration. Every failure is logged and then
# re-raised: without the re-raise, ``config`` would stay unbound and the next
# statement would fail with an unrelated NameError that hides the real cause.
try:
    with open(config_path) as f:
        config_dict = yaml.safe_load(f)
    config: TrainingRunConfig = TrainingRunConfig.parse_obj(config_dict)
except FileNotFoundError:
    logging.error(f"Config file {config_path} does not exist. Exiting.")
    raise
except yaml.YAMLError:
    logging.error(f"Config file {config_path} is not valid YAML. Exiting.")
    raise
except ValidationError as e:
    logging.error(f"Config file {config_path} is not valid. Exiting.\n{e}")
    raise


# Build the dataset, its statistics (third value of WeatherDataset.load) and
# the (source, target) variable-code lists from the ``training`` section.
dataset, mean_std, var_list = get_normal_dataset(config.training)

데이터셋 불러오는 중...
==== LOAD DATASET ====
 <xarray.Dataset>
Dimensions:                       (time: 87673, latitude: 30, longitude: 30,
                                   level: 37)
Coordinates:
  * latitude                      (latitude) float32 32.0 32.25 ... 39.0 39.25
  * level                         (level) int64 1 2 3 5 7 ... 925 950 975 1000
  * longitude                     (longitude) float32 124.0 124.2 ... 131.2
  * time                          (time) datetime64[ns] 2011-12-31 ... 2021-1...
Data variables: (12/15)
    10m_u_component_of_wind       (time, latitude, longitude) float32 dask.array<chunksize=(128, 30, 30), meta=np.ndarray>
    10m_v_component_of_wind       (time, latitude, longitude) float32 dask.array<chunksize=(128, 30, 30), meta=np.ndarray>
    2m_temperature                (time, latitude, longitude) float32 dask.array<chunksize=(128, 30, 30), meta=np.ndarray>
    geopotential                  (time, level, latitude, longitude) float32 dask.array<chunksize=(1

100%|██████████| 13/13 [00:53<00:00,  4.14s/it]


65.73858 sec
==== LOAD DATASET ====
 <xarray.Dataset>
Dimensions:                       (time: 87673, latitude: 20, longitude: 20,
                                   level: 37)
Coordinates:
  * latitude                      (latitude) float64 21.0 22.5 ... 48.0 49.5
  * level                         (level) int64 1 2 3 5 7 ... 925 950 975 1000
  * longitude                     (longitude) float64 115.5 117.0 ... 144.0
  * time                          (time) datetime64[ns] 2011-12-31 ... 2021-1...
Data variables: (12/15)
    10m_u_component_of_wind       (time, latitude, longitude) float32 dask.array<chunksize=(128, 20, 20), meta=np.ndarray>
    10m_v_component_of_wind       (time, latitude, longitude) float32 dask.array<chunksize=(128, 20, 20), meta=np.ndarray>
    2m_temperature                (time, latitude, longitude) float32 dask.array<chunksize=(128, 20, 20), meta=np.ndarray>
    geopotential                  (time, level, latitude, longitude) float32 dask.array<chunksize=(128, 

100%|██████████| 13/13 [00:29<00:00,  2.29s/it]


35.82630 sec


In [2]:
%load_ext autoreload
%autoreload 2

from models.VariableEncoder.training.lightning import TrainModule

# Checkpoint to restore the trained module from.
# NOTE(review): absolute, machine-specific path; consider deriving it from a
# configurable base directory.
CHECKPOINT_PATH = '/workspace/Haea_dev/models/VariableEncoder/tb_logs/lightning_logs/hzgku0jm/checkpoints/epoch=3-step=17532.ckpt'

model = TrainModule.load_from_checkpoint(CHECKPOINT_PATH)

/workspace/venv/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.


In [3]:
from torch.utils.data import DataLoader
import pytorch_lightning as pl

# Seed the split so the held-out 10% is the same subset on every run; an
# unseeded random_split yields a different test set each time the cell runs.
# NOTE(review): this split is independent of whatever split was used when the
# checkpoint was trained, so test_ds may overlap the training data; confirm.
SPLIT_SEED = 42
train_ds, test_ds = torch.utils.data.random_split(
    dataset,
    [0.9, 0.1],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Prediction must not shuffle: with shuffle=True the returned batches cannot
# be matched back to positions in test_ds.
# NOTE(review): drop_last=True is kept from the original and silently skips a
# trailing partial batch; switch it off if the model accepts smaller batches.
# NOTE(review): get_normal_dataset passes a CUDA device to WeatherDataset when
# a GPU is present. If the dataset holds CUDA tensors, worker processes
# (num_workers > 0) are not recommended and can stall; confirm, and use
# num_workers=0 in that case.
data_loader = DataLoader(
    test_ds,
    batch_size=16,
    shuffle=False,
    drop_last=True,
    num_workers=8,
)
trainer = pl.Trainer()

# One entry per batch, in the same order as data_loader yields them.
predict = trainer.predict(model, data_loader, return_predictions=True)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/workspace/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_preci

TrainerFn.PREDICTING
init cuda:0


/workspace/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


KeyboardInterrupt: 

: 