## mount & install

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install tensorboardX blosc2
!pip install ffmpeg
!pip install boto3

Collecting tensorboardX
  Downloading tensorboardx-2.6.4-py3-none-any.whl.metadata (6.2 kB)
Downloading tensorboardx-2.6.4-py3-none-any.whl (87 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensorboardX
Successfully installed tensorboardX-2.6.4
Collecting ffmpeg
  Downloading ffmpeg-1.4.tar.gz (5.1 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: ffmpeg
  Building wheel for ffmpeg (setup.py) ... [?25l[?25hdone
  Created wheel for ffmpeg: filename=ffmpeg-1.4-py3-none-any.whl size=6083 sha256=03be28d30c1101acb9930c6f6beeda6266e2685be665c99b068d83a41618dc2d
  Stored in directory: /root/.cache/pip/wheels/26/21/0c/c26e09dff860a9071683e279445262346e008a9a1d2142c4ad
Successfully built ffmpeg
Installing collected packages: ffmpeg
Successfully installed ffmpeg-1.4
Collecting boto3
  Downloading boto3-1.40.19-py3-none-any.whl.metadata 

## 各ライブラリのインポート

In [3]:
import os
import datetime
import math
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras import layers
import huggingface_hub
import blosc2
import tensorboardX
import tensorflow as tf
import pandas as pd
import random
from typing import List, Tuple, Optional
import gc
import boto3

## データセットの読み込み

In [4]:
# 使用例とパラメータ設定
def setup_multi_day_training(base_date: str, num_days: int):
    """
    連続する複数日のトレーニングデータを設定

    Parameters:
    -----------
    base_date : str
        開始日付 (例: "20240809")
    num_days : int
        連続する日数

    Returns:
    --------
    List[str]: 日付のリスト
    """
    from datetime import datetime, timedelta

    # 基準日をdatetimeオブジェクトに変換
    base_dt = datetime.strptime(base_date, "%Y%m%d")

    # 連続する日付のリストを生成
    date_list = []
    for i in range(num_days):
        current_date = base_dt + timedelta(days=i)
        date_list.append(current_date.strftime("%Y%m%d"))

    return date_list


# 使用例: 2024年8月9日から5日間のデータを使用
train_dates = setup_multi_day_training("20240809", num_days=5)
print(f"Training dates: {train_dates}")

val_dates = setup_multi_day_training("20240819", num_days=2)
print(f"Validation dates: {val_dates}")

Training dates: ['20240809', '20240810', '20240811', '20240812', '20240813']
Validation dates: ['20240819', '20240820']


In [5]:
lon = np.load('/content/drive/MyDrive/climate_data_platform/xrain/data/metadata/lon_centers.npy')
print(lon.shape)
lat = np.load('/content/drive/MyDrive/climate_data_platform/xrain/data/metadata/lat_centers.npy')
print(lat.shape)

(3200, 3200)
(3200, 3200)


## datasetの処理

In [6]:
lon_min, lon_max = 138, 140
lat_min, lat_max = 35,  37

lat_axis = lat[:, 0]         # shape (3200,)
lon_axis = lon[0, :]         # shape (3200,)

idx_lat = np.where((lat_axis >= lat_min) & (lat_axis <= lat_max))[0]
idx_lon = np.where((lon_axis >= lon_min) & (lon_axis <= lon_max))[0]

LAT_SLICE = slice(idx_lat.min(), idx_lat.max() + 1)   # 行 = 緯度
LON_SLICE = slice(idx_lon.min(), idx_lon.max() + 1)

In [7]:
SEQ_LEN      = 6
DOWNSAMPLE   = 5
SCALE_FACTOR = 300.0      # mm h‑1 の上限
BATCH_SIZE   = 16         # ConvLSTM はメモリ使用量大 (Reduced from 16 to 8)
EPOCHS       = 200

# 縮小後の空間サイズを自動算出
H = math.ceil((LAT_SLICE.stop - LAT_SLICE.start) / DOWNSAMPLE)  # 緯度方向
W = math.ceil((LON_SLICE.stop - LON_SLICE.start) / DOWNSAMPLE)  # 経度方向
print(f"Input size  : (T={SEQ_LEN}, H={H}, W={W})")

estimated_frames_per_day = 1440
total_frames = len(train_dates) * estimated_frames_per_day
total_val_frames = len(val_dates) * estimated_frames_per_day
multi_steps_per_epoch = max(1, (total_frames - SEQ_LEN) // BATCH_SIZE)
multi_val_steps = max(1, (total_val_frames - SEQ_LEN) // BATCH_SIZE)

print(f"Estimated steps per epoch: {multi_steps_per_epoch}")

Input size  : (T=6, H=192, W=128)
Estimated steps per epoch: 449


In [8]:
# 論文に忠実なImportance Sampling実装

def calculate_acceptance_probability(crop: np.ndarray,
                                   training_mode: bool = True,
                                   epsilon: float = 1e-6) -> float:
    """
    降水量データの受容確率を計算する（既存関数を置き換え）
    """
    valid_mask = ~np.isnan(crop)

    if training_mode:
        # Training: g(x) = 1 - e^(-x) for valid, g(x) = 0 for missing
        valid_values = crop[valid_mask]
        g_values = 1.0 - np.exp(-valid_values)
        total_g = np.sum(g_values)  # missing pixelsは自動的に0
    else:
        # Test/Validation: g(x) = x for all (論文に従い、missing valueは0として扱う)
        safe_crop = np.nan_to_num(crop, nan=0.0)
        total_g = np.sum(safe_crop)

    return total_g + epsilon


def hierarchical_sampling(crops_with_probs: List[Tuple[np.ndarray, float]],
                        n_samples: int,
                        remove_threshold: bool = False) -> List[np.ndarray]:
    """
    階層的サンプリング（閾値による排除を撤廃）
    """
    if not crops_with_probs:
        return []

    # 論文では閾値による排除は行わない - 全てのクロップが候補
    all_crops = crops_with_probs

    if not all_crops:
        return []

    # 受容確率を正規化
    probs = np.array([prob for _, prob in all_crops])

    # ゼロ確率を避けるための最小値設定
    min_prob = np.min(probs[probs > 0]) * 0.001 if np.any(probs > 0) else epsilon
    probs = np.maximum(probs, min_prob)

    probs = probs / np.sum(probs)  # 正規化

    # 論文に従った確率的サンプリング（with replacement）
    selected_indices = np.random.choice(
        len(all_crops),
        size=min(n_samples, len(all_crops)),
        p=probs,
        replace=True  # 重要: 重複選択を許可
    )

    selected_crops = [all_crops[i][0] for i in selected_indices]
    return selected_crops


def calculate_importance_sampling_steps(date_list: List[str],
                                      batch_size: int,
                                      training_mode: bool = True) -> int:
    """
    Importance Samplingでのsteps per epochを計算（既存関数を置き換え）
    """

    # 論文の設定に基づく推定
    estimated_frames_per_day = 144  # 10分間隔で1日144フレーム
    total_frames = len(date_list) * estimated_frames_per_day

    if training_mode:
        # Training: 256x256クロップ、stride=32
        crop_size = (256, 256)
        stride = 32
        # 推定空間サイズ（ダウンサンプリング後）
        estimated_spatial_size = (400, 300)
    else:
        # Test/Validation: 512x512クロップ
        crop_size = (512, 512)
        stride = 32
        estimated_spatial_size = (400, 300)

    # 空間方向のクロップ数
    h_crops = max(1, (estimated_spatial_size[0] - crop_size[0]) // stride + 1)
    w_crops = max(1, (estimated_spatial_size[1] - crop_size[1]) // stride + 1)
    spatial_crops_per_frame = h_crops * w_crops

    # 時間方向の有効フレーム数（シーケンス長を考慮）
    seq_len = 6
    valid_temporal_frames = max(1, total_frames - seq_len - 1)

    # 総クロップ数
    total_possible_crops = valid_temporal_frames * spatial_crops_per_frame

    # 論文に従い、importance samplingによる実効的な増加は控えめに
    if training_mode:
        # Training: 重要なサンプルの重複選択により若干増加
        effective_factor = 1.2
    else:
        # Validation: ほぼ実データ通り
        effective_factor = 1.0

    effective_crops = int(total_possible_crops * effective_factor)
    steps_per_epoch = max(1, effective_crops // batch_size)

    print(f"Steps calculation ({'train' if training_mode else 'val'}):")
    print(f"  Days: {len(date_list)}")
    print(f"  Frames per day: {estimated_frames_per_day}")
    print(f"  Spatial crops per frame: {spatial_crops_per_frame}")
    print(f"  Total possible crops: {total_possible_crops}")
    print(f"  Effective factor: {effective_factor}")
    print(f"  Steps per epoch: {steps_per_epoch}")

    return steps_per_epoch


def importance_sampling_generator(date_list: List[str],
                                batch_size: int,
                                lon_slice: slice,
                                lat_slice: slice,
                                seq_len: int = 6,
                                downsample: int = 5,
                                training_mode: bool = True,
                                buffer_size: int = 2):
    """
    Importance Samplingを適用したバッチジェネレータ（既存関数を置き換え）
    """

    # Wasabi Cloud設定（既存コードと同じ）
    from google.colab import userdata
    import boto3
    import blosc2
    import os
    import gc

    WASABI_ENDPOINT = "https://s3.ap-northeast-1.wasabisys.com"
    WASABI_ACCESS_KEY = userdata.get('WASABI_ACCESS_KEY')
    WASABI_SECRET_KEY = userdata.get('WASABI_SECRET_KEY')
    BUCKET_NAME = "xrain-composite-cx"

    s3_client = boto3.client(
        's3',
        endpoint_url=WASABI_ENDPOINT,
        aws_access_key_id=WASABI_ACCESS_KEY,
        aws_secret_access_key=WASABI_SECRET_KEY,
        region_name='ap-northeast-1'
    )

    def load_data_for_date(date: str) -> Optional[np.ndarray]:
        """データロード（既存と同じ）"""
        base_path = "b2nd"
        year = date[0:4]
        month = date[4:6]
        filename = f"{base_path}/{year}/{month}/cx-{date}.b2nd"
        temp_filename = f"/tmp/cx-{date}.b2nd"

        try:
            s3_client.download_file(BUCKET_NAME, filename, temp_filename)
            rains = blosc2.open(temp_filename)
            data = rains[:, lat_slice, lon_slice]
            data = data[:, ::downsample, ::downsample]
            if os.path.exists(temp_filename):
                os.remove(temp_filename)
            return data.astype(np.float32)
        except Exception as e:
            print(f"Failed to load data for {date}: {e}")
            return None

    def generate_crops_with_probabilities(data: np.ndarray) -> List[Tuple[np.ndarray, float]]:
        """クロップ生成と重要度計算"""
        T, H, W = data.shape

        # 論文に従ったクロップサイズ
        if training_mode:
            crop_h, crop_w = 256, 256
        else:
            crop_h, crop_w = 256, 256  # Validationも256x256で統一（メモリ節約）

        stride = 32  # 論文の設定
        crops_with_probs = []

        # 時間方向のクロップ（270分 = 27フレーム@10分間隔）
        temporal_crop_size = 27

        for t_start in range(0, max(1, T - temporal_crop_size), temporal_crop_size // 2):
            t_end = min(t_start + temporal_crop_size, T)

            # 空間方向のクロップ
            for h_start in range(0, max(1, H - crop_h + 1), stride):
                for w_start in range(0, max(1, W - crop_w + 1), stride):
                    h_end = min(h_start + crop_h, H)
                    w_end = min(w_start + crop_w, W)

                    crop = data[t_start:t_end, h_start:h_end, w_start:w_end]

                    # 受容確率を計算
                    prob = calculate_acceptance_probability(crop, training_mode)
                    crops_with_probs.append((crop, prob))

        return crops_with_probs

    # データバッファ
    data_buffer = []
    current_buffer_dates = []

    while True:
        for start_date_idx in range(len(date_list)):
            buffer_dates = date_list[start_date_idx:start_date_idx + buffer_size]

            if current_buffer_dates != buffer_dates:
                data_buffer = []
                gc.collect()

                for date in buffer_dates:
                    data = load_data_for_date(date)
                    if data is not None:
                        data_buffer.append(data)

                current_buffer_dates = buffer_dates.copy()

                if not data_buffer:
                    continue

                combined_data = np.concatenate(data_buffer, axis=0)
                print(f"Loaded data for dates {buffer_dates}, combined shape: {combined_data.shape}")

            # クロップ生成
            crops_with_probs = generate_crops_with_probabilities(combined_data)

            if not crops_with_probs:
                continue

            # 階層的サンプリング（閾値なし）
            selected_crops = hierarchical_sampling(
                crops_with_probs,
                batch_size * 5  # 充分なサンプル数を確保
            )

            # シーケンス生成
            random.shuffle(selected_crops)

            xs, ys = [], []
            for crop in selected_crops:
                if len(xs) >= batch_size:
                    break

                T_crop = crop.shape[0]
                if T_crop <= seq_len:
                    continue

                max_start = T_crop - seq_len - 1
                t_start = np.random.randint(0, max(1, max_start))

                window = crop[t_start:t_start + seq_len + 1]

                # NaN処理を改善
                if np.isnan(window).any():
                    window = np.nan_to_num(window, nan=0.0)

                x_seq = window[:-1][..., None]
                y_target = window[-1][..., None]

                xs.append(x_seq)
                ys.append(y_target)

            if xs:
                yield np.array(xs, dtype=np.float32), np.array(ys, dtype=np.float32)

In [9]:
train_dates = setup_multi_day_training("20240809", num_days=2)
val_dates = setup_multi_day_training("20240819", num_days=2)

train_gen = importance_sampling_generator(
    train_dates, BATCH_SIZE, LON_SLICE, LAT_SLICE, SEQ_LEN, DOWNSAMPLE,
    training_mode=True
)

val_gen = importance_sampling_generator(
    val_dates, BATCH_SIZE, LON_SLICE, LAT_SLICE, SEQ_LEN, DOWNSAMPLE,
    training_mode=False
)

# Steps計算
train_steps = calculate_importance_sampling_steps(train_dates, BATCH_SIZE, training_mode=True)
val_steps = calculate_importance_sampling_steps(val_dates, BATCH_SIZE, training_mode=False)

# 統計分析
# analyze_importance_sampling_statistics(
#     train_dates, LON_SLICE, LAT_SLICE, DOWNSAMPLE
# )

Steps calculation (train):
  Days: 2
  Frames per day: 144
  Spatial crops per frame: 10
  Total possible crops: 2810
  Effective factor: 1.2
  Steps per epoch: 210
Steps calculation (val):
  Days: 2
  Frames per day: 144
  Spatial crops per frame: 1
  Total possible crops: 281
  Effective factor: 1.0
  Steps per epoch: 17


## モデルの定義

In [12]:
# bilinear ワープ & Sobel
def dense_image_warp_bilinear(img, flow):
    """
    img:  (B,H,W,C), flow: (B,H,W,2) with (dy, dx) in pixels (backward sampling)
    returns: warped image (B,H,W,C)
    """
    B = tf.shape(img)[0]
    H = tf.shape(img)[1]
    W = tf.shape(img)[2]
    C = tf.shape(img)[3]

    # base grid
    y = tf.cast(tf.range(H), tf.float32)
    x = tf.cast(tf.range(W), tf.float32)
    Y, X = tf.meshgrid(y, x, indexing="ij")           # (H,W)
    grid = tf.stack([Y, X], axis=-1)                 # (H,W,2)
    grid = tf.broadcast_to(grid[None, ...], [B, H, W, 2])

    # sample coords
    coords = grid + tf.cast(flow, tf.float32)        # (B,H,W,2)
    y_s = coords[..., 0]
    x_s = coords[..., 1]

    # neighbors
    y0 = tf.floor(y_s);  x0 = tf.floor(x_s)
    y1 = y0 + 1.0;       x1 = x0 + 1.0

    y0c = tf.clip_by_value(tf.cast(y0, tf.int32), 0, H-1)
    y1c = tf.clip_by_value(tf.cast(y1, tf.int32), 0, H-1)
    x0c = tf.clip_by_value(tf.cast(x0, tf.int32), 0, W-1)
    x1c = tf.clip_by_value(tf.cast(x1, tf.int32), 0, W-1)

    # gather 4 neighbors
    batch_idx = tf.reshape(tf.range(B, dtype=tf.int32), [B,1,1])
    batch_idx = tf.tile(batch_idx, [1,H,W])

    def gather(b, yy, xx):
        idx = tf.stack([b, yy, xx], axis=-1)     # (B,H,W,3)
        return tf.gather_nd(img, idx)            # (B,H,W,C)

    Ia = gather(batch_idx, y0c, x0c)
    Ib = gather(batch_idx, y0c, x1c)
    Ic = gather(batch_idx, y1c, x0c)
    Id = gather(batch_idx, y1c, x1c)

    # bilinear weights
    wy1 = y1 - y_s
    wx1 = x1 - x_s
    wy0 = 1.0 - wy1
    wx0 = 1.0 - wx1

    wa = (wy1 * wx1)[..., None]
    wb = (wy1 * wx0)[..., None]
    wc = (wy0 * wx1)[..., None]
    wd = (wy0 * wx0)[..., None]

    out = wa*Ia + wb*Ib + wc*Ic + wd*Id
    return out


def sobel_grad_l2(img):
    """
    img: (B,H,W,1)
    returns: (B,H,W,1) with sqrt(gx^2 + gy^2)
    """
    img = tf.cast(img, tf.float32)
    kx = tf.constant([[1,0,-1],[2,0,-2],[1,0,-1]], dtype=tf.float32)
    ky = tf.constant([[1,2,1],[0,0,0],[-1,-2,-1]], dtype=tf.float32)
    kx = tf.reshape(kx, [3,3,1,1])
    ky = tf.reshape(ky, [3,3,1,1])
    gx = tf.nn.conv2d(img, kx, strides=1, padding="SAME")
    gy = tf.nn.conv2d(img, ky, strides=1, padding="SAME")
    return tf.sqrt(tf.maximum(gx*gx + gy*gy, 1e-12))

In [13]:
class EvolutionUNetConvLSTM(keras.Model):
    """
    inputs:  (B, T, H, W, 1)   mm/h
    outputs: (B, H, W, 1)      mm/h  （x_evo: 1ステップ先）
    """
    def __init__(self, seq_len, scale_factor=300.0, max_disp=3.0, lambda_motion=1e-2):
        super().__init__()
        self.seq_len = seq_len
        self.scale = float(scale_factor)
        self.max_disp = float(max_disp)
        self.lambda_motion = float(lambda_motion)

        Act = "relu"
        # ---- Encoder ----
        self.norm_in = layers.Rescaling(1.0 / self.scale)
        self.c1a = layers.ConvLSTM2D(32, 3, padding="same", return_sequences=True, activation=Act)
        self.c1b = layers.ConvLSTM2D(32, 3, padding="same", return_sequences=True, activation=Act)
        self.p1  = layers.MaxPooling3D(pool_size=(1,2,2), padding="same")

        self.c2a = layers.ConvLSTM2D(64, 3, padding="same", return_sequences=True, activation=Act)
        self.c2b = layers.ConvLSTM2D(64, 3, padding="same", return_sequences=True, activation=Act)
        self.p2  = layers.MaxPooling3D(pool_size=(1,2,2), padding="same")

        # ---- Bottleneck ----
        self.c3a = layers.ConvLSTM2D(128, 3, padding="same", return_sequences=True, activation=Act)
        self.c3b = layers.ConvLSTM2D(128, 3, padding="same", return_sequences=True, activation=Act)

        # ---- Decoder ----
        self.u4  = layers.UpSampling3D(size=(1,2,2))
        self.c4a = layers.ConvLSTM2D(64, 3, padding="same", return_sequences=True, activation=Act)
        self.c4b = layers.ConvLSTM2D(64, 3, padding="same", return_sequences=True, activation=Act)

        self.u5  = layers.UpSampling3D(size=(1,2,2))
        self.c5a = layers.ConvLSTM2D(32, 3, padding="same", return_sequences=True, activation=Act)
        self.c5b = layers.ConvLSTM2D(32, 3, padding="same", return_sequences=False, activation=Act)

        # ---- Heads ----
        self.head_flow = layers.Conv2D(2, 3, padding="same", activation="tanh")  # -> [-1,1]
        self.head_res  = layers.Conv2D(1, 1, padding="same", activation=None)
        self.denorm_out = layers.Rescaling(self.scale)

        # metrics
        self.loss_tracker      = keras.metrics.Mean(name="loss")
        self.accum_tracker     = keras.metrics.Mean(name="accum_loss")
        self.motion_tracker    = keras.metrics.Mean(name="motion_reg")
        self.mae_tracker       = keras.metrics.MeanAbsoluteError(name="mae")

    def evolve_step(self, x_prev_norm, flow, resid):
        # flow: [-1,1] → [-max_disp, max_disp]
        disp = tf.cast(flow, x_prev_norm.dtype) * self.max_disp
        x_adv_norm = dense_image_warp_bilinear(x_prev_norm, disp)  # (B,H,W,1)
        x_evo_norm = tf.nn.relu(x_adv_norm + resid)
        return x_adv_norm, x_evo_norm

    def _forward_heads(self, x):
        z = self.norm_in(x)
        c1 = self.c1b(self.c1a(z))
        p1 = self.p1(c1)

        c2 = self.c2b(self.c2a(p1))
        p2 = self.p2(c2)

        c3 = self.c3b(self.c3a(p2))

        u4 = self.u4(c3); u4 = layers.concatenate([u4, c2])
        c4 = self.c4b(self.c4a(u4))

        u5 = self.u5(c4); u5 = layers.concatenate([u5, c1])
        c5 = self.c5a(u5)
        feat = self.c5b(c5)                     # (B,H,W,32)

        flow = self.head_flow(feat)             # [-1,1]
        res  = self.head_res(feat)              # normalized residual
        return flow, res

    def call(self, inputs, training=None):
        flow, res = self._forward_heads(inputs)
        x_prev_norm = self.norm_in(inputs[:, -1, ...])
        _, x_evo_norm = self.evolve_step(x_prev_norm, flow, res)
        return self.denorm_out(x_evo_norm)      # (B,H,W,1) mm/h

    def weighted_l1(self, y_true, y_pred):
        w = tf.minimum(24.0, 1.0 + y_true)
        return tf.reduce_mean(tf.abs(y_true - y_pred) * w)

    def motion_reg(self, flow_px, weight_field_mmph):
        # flow_px: (B,H,W,2) in pixels (after scaling)
        vy = flow_px[..., 0:1]; vx = flow_px[..., 1:2]
        gy = sobel_grad_l2(vy); gx = sobel_grad_l2(vx)    # (B,H,W,1)
        w  = tf.minimum(24.0, 1.0 + weight_field_mmph)
        return tf.reduce_mean(w * (gy + gx))

    def train_step(self, data):
        x, y = data  # x:(B,T,H,W,1) mm/h, y:(B,H,W,1) mm/h
        with tf.GradientTape() as tape:
            flow, res = self._forward_heads(x)
            x_prev_norm = self.norm_in(x[:, -1, ...])
            x_adv_norm, x_evo_norm = self.evolve_step(x_prev_norm, flow, res)
            x_adv = self.denorm_out(x_adv_norm)
            x_evo = self.denorm_out(x_evo_norm)

            accum = self.weighted_l1(y, x_adv) + self.weighted_l1(y, x_evo)
            motion = self.motion_reg(flow * self.max_disp, y)
            loss = accum + self.lambda_motion * motion

        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

        self.loss_tracker.update_state(loss)
        self.accum_tracker.update_state(accum)
        self.motion_tracker.update_state(motion)
        self.mae_tracker.update_state(y, x_evo)
        return {"loss": self.loss_tracker.result(),
                "accum_loss": self.accum_tracker.result(),
                "motion_reg": self.motion_tracker.result(),
                "mae": self.mae_tracker.result()}

    def test_step(self, data):
        x, y = data
        flow, res = self._forward_heads(x)
        x_prev_norm = self.norm_in(x[:, -1, ...])
        x_adv_norm, x_evo_norm = self.evolve_step(x_prev_norm, flow, res)
        x_adv = self.denorm_out(x_adv_norm)
        x_evo = self.denorm_out(x_evo_norm)

        accum = self.weighted_l1(y, x_adv) + self.weighted_l1(y, x_evo)
        motion = self.motion_reg(flow * self.max_disp, y)
        loss = accum + self.lambda_motion * motion

        self.loss_tracker.update_state(loss)
        self.accum_tracker.update_state(accum)
        self.motion_tracker.update_state(motion)
        self.mae_tracker.update_state(y, x_evo)
        return {"loss": self.loss_tracker.result(),
                "accum_loss": self.accum_tracker.result(),
                "motion_reg": self.motion_tracker.result(),
                "mae": self.mae_tracker.result()}

    @property
    def metrics(self):
        return [self.loss_tracker, self.accum_tracker, self.motion_tracker, self.mae_tracker]

In [14]:
# 2. Steps per epochの再計算
def calculate_importance_sampling_steps(date_list: List[str],
                                      batch_size: int,
                                      crop_size: Tuple[int, int] = (256, 256),
                                      stride: int = 32,
                                      downsample: int = 5,
                                      importance_factor: float = 2.0) -> int:
    """
    Importance Samplingでのsteps per epochを計算

    Parameters:
    -----------
    importance_factor : float
        Importance Samplingによる有効サンプル増加係数
        重要なサンプルが繰り返し選ばれるため、実効的なデータ量が増加する
    """

    # 推定フレーム数（既存の計算）
    estimated_frames_per_day = 1440  # 1日 = 1440分, 10分間隔で144フレーム
    total_frames = len(date_list) * estimated_frames_per_day

    # 空間クロップ数の推定
    # 仮定: 各日のデータサイズを (144, 400, 300) とする
    estimated_h, estimated_w = 400, 300  # ダウンサンプリング後のサイズ
    crop_h, crop_w = crop_size

    # クロップ数の計算
    n_crops_h = max(1, (estimated_h - crop_h) // stride + 1)
    n_crops_w = max(1, (estimated_w - crop_w) // stride + 1)
    crops_per_day = n_crops_h * n_crops_w

    # Importance Samplingによる有効サンプル数
    # 重要なサンプルが重複選択されるため、実効的なサンプル数は増加
    effective_samples = int(len(date_list) * crops_per_day * importance_factor)

    # バッチ数に変換
    steps_per_epoch = max(1, effective_samples // batch_size)

    print(f"Importance Sampling Steps Calculation:")
    print(f"  Days: {len(date_list)}")
    print(f"  Estimated crops per day: {crops_per_day}")
    print(f"  Importance factor: {importance_factor}")
    print(f"  Effective samples: {effective_samples}")
    print(f"  Steps per epoch: {steps_per_epoch}")

    return steps_per_epoch

# steps per epochを計算
importance_steps_per_epoch = calculate_importance_sampling_steps(
    train_dates, BATCH_SIZE, importance_factor=2.5
)

importance_val_steps = calculate_importance_sampling_steps(
    val_dates, BATCH_SIZE, importance_factor=1.5  # 検証では重複が少ない
)

import tensorflow as tf
import numpy as np

# 例: 既存の steps_per_epoch, EPOCHS をそのまま使う
warmup_epochs = 5
lr_start = 3e-4   # ウォームアップ開始
lr_max   = 1e-3   # ピーク
lr_min   = 1e-4   # 終了時（最小）

class WarmupCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, steps_per_epoch, total_epochs,
                 warmup_epochs, lr_start, lr_max, lr_min):
        self.spe = float(steps_per_epoch)
        self.T = float(total_epochs)
        self.warm = float(warmup_epochs)
        self.lr_start = float(lr_start)
        self.lr_max = float(lr_max)
        self.lr_min = float(lr_min)

    def __call__(self, step):
        # step → epoch に変換
        epoch = tf.cast(step, tf.float32) / self.spe
        # 1) Warmup
        lr_warm = self.lr_start + (self.lr_max - self.lr_start) * (epoch / self.warm)
        # 2) Cosine decay (after warmup)
        progress = tf.clip_by_value((epoch - self.warm) / (self.T - self.warm), 0.0, 1.0)
        lr_cos = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1.0 + tf.cos(np.pi * progress))
        return tf.where(epoch < self.warm, lr_warm, lr_cos)

# 学習率スケジュールを作成
lr_sched = WarmupCosine(importance_steps_per_epoch, EPOCHS, warmup_epochs, lr_start, lr_max, lr_min)

# Optimizer を差し替え（clipnorm で安定化）
opt = tf.keras.optimizers.Adam(learning_rate=lr_sched, clipnorm=1.0, epsilon=1e-7)

Importance Sampling Steps Calculation:
  Days: 2
  Estimated crops per day: 10
  Importance factor: 2.5
  Effective samples: 50
  Steps per epoch: 3
Importance Sampling Steps Calculation:
  Days: 2
  Estimated crops per day: 10
  Importance factor: 1.5
  Effective samples: 30
  Steps per epoch: 1


In [15]:
model = EvolutionUNetConvLSTM(
    seq_len=SEQ_LEN,
    scale_factor=SCALE_FACTOR,
    max_disp=3.0,
    lambda_motion=1e-2
)
model.compile(optimizer=opt)
model.build(input_shape=(None, SEQ_LEN, H, W, 1))
_ = model(tf.zeros((1, SEQ_LEN, H, W, 1), dtype=tf.float32), training=False)
model.summary()



### 学習可視化用

In [16]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import tensorflow as tf
import tensorboardX

# --- JMA colormap (既存と同じ) ---
JMA_COLORS = ["#f2f2ff","#a0d2ff","#218cff","#0000ff",
              "#faf500","#ff9900","#ff2100","#c800c8"]
JMA_BOUNDS = [0,1,5,10,20,30,50,80,1000]
JMA_CMAP   = mcolors.ListedColormap(JMA_COLORS)
JMA_NORM   = mcolors.BoundaryNorm(JMA_BOUNDS, JMA_CMAP.N)

def _quiver_downsample(vy, vx, step=8):
    """矢印が多すぎて見づらいのを防ぐため、格子を間引く"""
    return vy[::step, ::step], vx[::step, ::step]

def _imshow_mmhr(ax, img2d, title=""):
    m = ax.imshow(img2d, cmap=JMA_CMAP, norm=JMA_NORM)
    ax.set_title(title); ax.set_xticks([]); ax.set_yticks([])
    return m

In [17]:
def visualize_flow_source(model, x_batch, y_batch=None, quiver_step=8, figsize=(12,6), show=False, print_stats=True):
    """
    model  : EvolutionUNetConvLSTM
    x_batch: (B,T,H,W,1)  mm/h
    y_batch: (B,H,W,1) or None
    """
    model.trainable = False
    x_batch = tf.convert_to_tensor(x_batch, dtype=tf.float32)
    B, T, H, W, _ = x_batch.shape

    # 1) ヘッド出力（flow [-1,1], res(正規化), 最後フレーム正規化）
    flow, res_norm = model._forward_heads(x_batch)      # (B,H,W,2), (B,H,W,1)
    x_prev_norm = model.norm_in(x_batch[:, -1, ...])    # (B,H,W,1)

    # 2) 進化演算子：x_adv_norm, x_evo_norm
    x_adv_norm, x_evo_norm = model.evolve_step(x_prev_norm, flow, res_norm)

    # 3) 可視化は mm/h で
    x_prev = model.denorm_out(x_prev_norm).numpy()[0, ..., 0]
    x_adv  = model.denorm_out(x_adv_norm ).numpy()[0, ..., 0]
    x_evo  = model.denorm_out(x_evo_norm ).numpy()[0, ..., 0]
    source = model.denorm_out(res_norm    ).numpy()[0, ..., 0]  # ← 残差(mm/h)

    vy = (flow.numpy()[0, ..., 0]) * model.max_disp  # [px/step]
    vx = (flow.numpy()[0, ..., 1]) * model.max_disp

    # 4) 図を描く
    fig = plt.figure(figsize=figsize)
    ax1 = fig.add_subplot(2,3,1); _imshow_mmhr(ax1, x_prev, "x_prev (t-1)")
    ax2 = fig.add_subplot(2,3,2); _imshow_mmhr(ax2, x_adv,  "x_adv = warp(x_prev)")
    ax3 = fig.add_subplot(2,3,3); _imshow_mmhr(ax3, x_evo,  "x_evo = ReLU(x_adv + s)")
    ax4 = fig.add_subplot(2,3,4); _imshow_mmhr(ax4, source, "source s (mm/h)")
    ax5 = fig.add_subplot(2,3,5)
    _imshow_mmhr(ax5, x_prev, "flow (quiver on x_prev)")
    # クイバーを間引いて描画
    yy, xx = np.mgrid[0:H, 0:W]
    yq, xq = yy[::quiver_step, ::quiver_step], xx[::quiver_step, ::quiver_step]
    vyq, vxq = vy[::quiver_step, ::quiver_step], vx[::quiver_step, ::quiver_step]
    ax5.quiver(xq, yq, vxq, vyq, angles='xy', scale_units='xy', scale=1.0, width=0.002)

    # 5) もし教師があれば差分も
    if y_batch is not None:
        y = tf.convert_to_tensor(y_batch, dtype=tf.float32).numpy()[0, ..., 0]
        ax6 = fig.add_subplot(2,3,6); _imshow_mmhr(ax6, np.abs(y - x_evo), "|y - x_evo|")
    plt.tight_layout()
    # plt.show()

    if show:
        from IPython.display import display
        display(fig)
    else:
        plt.close(fig)

    # 6) 重要メトリクス（モニタリング用）
    stats = {}
    if y_batch is not None:
        y = tf.convert_to_tensor(y_batch, dtype=tf.float32).numpy()[0, ..., 0]
        w = np.minimum(24.0, 1.0 + y)
        stats["L1_w_y-x_adv"] = float(np.mean(w * np.abs(y - x_adv)))
        stats["L1_w_y-x_evo"] = float(np.mean(w * np.abs(y - x_evo)))
    stats["flow_mean_mag(px)"] = float(np.mean(np.sqrt(vy**2 + vx**2)))
    stats["source_mean(mm/h)"] = float(np.mean(np.abs(source)))
    stats["neg_before_relu(%)"] = float(100.0 * np.mean((x_adv + source) < 0.0))

    if print_stats:
        print("Metrics:", stats)

    return {"x_prev": x_prev, "x_adv": x_adv, "x_evo": x_evo,
            "source": source, "flow_vy": vy, "flow_vx": vx, "stats": stats}

In [18]:
class FlowSourceVizCallback(keras.callbacks.Callback):
    def __init__(self, val_sample, quiver_step=8, tag_prefix="evo"):
        """
        val_sample: (x_val, y_val)  いずれも numpy (B,T,H,W,1) / (B,H,W,1)
        """
        super().__init__()
        self.sample_x, self.sample_y = val_sample
        self.quiver_step = quiver_step
        self.tag_prefix = tag_prefix

        # JMA カラーを RGB 3ch (0-1) へ
        self.jma_colors = np.array([[int(h[i:i+2],16)/255. for i in (1,3,5)]
                                    for h in JMA_COLORS], dtype=np.float32)

    def _to_rgb(self, mmhr_2d):
        mmhr_2d = np.clip(mmhr_2d, 0, 1000)
        idx = np.digitize(mmhr_2d, JMA_BOUNDS) - 1
        idx = np.clip(idx, 0, len(self.jma_colors)-1)
        rgb = self.jma_colors[idx]                  # (H,W,3)
        return rgb.transpose(2,0,1)[None,...]       # (1,3,H,W)

    def on_train_begin(self, logs=None):
        self.writer = tensorboardX.SummaryWriter()

    def on_epoch_end(self, epoch, logs=None):
        x = self.sample_x[:1]   # 1サンプルだけ
        y = self.sample_y[:1] if self.sample_y is not None else None

        # 予測と中間生成物
        out = visualize_flow_source(self.model, x, y,
                                    quiver_step=self.quiver_step,
                                    figsize=(12,6),
                                    show=False,
                                    print_stats=False)
        x_prev, x_adv, x_evo, source = out["x_prev"], out["x_adv"], out["x_evo"], out["source"]
        vy, vx = out["flow_vy"], out["flow_vx"]

        # 画像（RGB）をTBへ
        for name, arr in [("x_prev", x_prev), ("x_adv", x_adv), ("x_evo", x_evo), ("source", source)]:
            rgb = self._to_rgb(arr.astype(np.float32))
            self.writer.add_images(f"{self.tag_prefix}/{name}", rgb, epoch)

        # flow magnitude をグレースケールで
        mag = np.sqrt(vy**2 + vx**2)
        mag_norm = (mag / (np.max(mag) + 1e-6)).astype(np.float32)[None,None,...]  # (1,1,H,W)
        self.writer.add_images(f"{self.tag_prefix}/flow_mag_norm", mag_norm, epoch)

        # スカラーメトリクス
        for k, v in out["stats"].items():
            self.writer.add_scalar(f"{self.tag_prefix}/{k}", v, epoch)

## 学習

In [19]:
class TensorBoardXCallback(keras.callbacks.Callback):
    def __init__(self, val_sample):
        super().__init__()
        # --- JMA カラーマップ ---
        cmap_hex = ["#f2f2ff","#a0d2ff","#218cff","#0000ff",
                    "#faf500","#ff9900","#ff2100","#c800c8"]
        self.jma_colors = np.array([[int(h[i:i+2],16)/255. for i in (1,3,5)]
                                    for h in cmap_hex], dtype=np.float32)
        self.color_border = [0,1,5,10,20,30,50,80,1000]

        self.sample_x = val_sample  # 1 バッチだけ保持しておく

    def on_train_begin(self, logs=None):
        self.writer = tensorboardX.SummaryWriter()

    def on_epoch_end(self, epoch, logs=None):
        # ------ 損失を記録 ------
        self.writer.add_scalar("loss/train", logs["loss"],     epoch)
        self.writer.add_scalar("loss/val",   logs["val_loss"], epoch)

        # ------ 推論画像を記録 ------
        pred = self.model.predict(self.sample_x, verbose=0)[0]     # (H,W,1)
        pred = np.clip(pred[...,0], 0, 1e9)                       # (H,W)

        idx  = np.digitize(pred, self.color_border) - 1            # (H,W)
        rgb  = self.jma_colors[idx]                                # (H,W,3)
        rgb  = rgb.transpose(2,0,1)[None,...]                      # (1,3,H,W)

        self.writer.add_images("pred", rgb, epoch)

In [20]:
import itertools
from typing import Generator, Tuple, Optional

def seq_batch_generator(rains: np.ndarray,
                       batch_size: int,
                       lon_slice: slice,
                       lat_slice: slice,
                       seq_len: int = 6,
                       downsample: int = 5,
                       shuffle: bool = True):
    """
    単一のnumpy配列からシーケンスバッチを生成するジェネレータ

    Parameters:
    -----------
    rains : np.ndarray
        形状 (T, H, W) の降水量データ
    batch_size : int
        バッチサイズ
    lon_slice, lat_slice : slice
        経度・緯度方向のスライス
    seq_len : int
        シーケンス長
    downsample : int
        ダウンサンプリング係数
    shuffle : bool
        インデックスをシャッフルするかどうか

    Yields:
    -------
    Tuple[np.ndarray, np.ndarray]
        (入力データ, ターゲットデータ)のタプル
        入力: (batch_size, seq_len, H, W, 1)
        ターゲット: (batch_size, H, W, 1)
    """

    # データの前処理
    data = rains[:, lat_slice, lon_slice]
    data = data[:, ::downsample, ::downsample]

    T = data.shape[0]

    while True:  # 無限ループでエポックを繰り返す
        # 有効なインデックスを生成
        indices = list(range(0, T - seq_len - 1))

        if shuffle:
            np.random.shuffle(indices)

        # バッチサイズずつ処理
        for i in range(0, len(indices), batch_size):
            batch_indices = indices[i:i + batch_size]

            xs, ys = [], []

            for t in batch_indices:
                # シーケンス窓を取得
                window = data[t:t + seq_len + 1]

                # 欠測値チェック
                if np.isnan(window).any():
                    continue

                # 入力とターゲットに分割
                x_seq = window[:-1][..., None]  # (seq_len, H, W, 1)
                y_target = window[-1][..., None]  # (H, W, 1)

                xs.append(x_seq)
                ys.append(y_target)

            if xs:
                yield np.array(xs, dtype=np.float32), np.array(ys, dtype=np.float32)

def get_importance_validation_sample_fixed(date_list: List[str],
                                         batch_size: int,
                                         lon_slice: slice,
                                         lat_slice: slice,
                                         seq_len: int = 6,
                                         downsample: int = 5,
                                         max_attempts: int = 10) -> Tuple[np.ndarray, np.ndarray]:
    """
    修正版: Importance Samplingを使った検証サンプルを安全に取得
    ジェネレータを使い回さず、専用の取得関数として実装
    """

    from google.colab import userdata
    import boto3
    import blosc2
    import os

    # Wasabi Cloud設定
    WASABI_ENDPOINT = "https://s3.ap-northeast-1.wasabisys.com"
    WASABI_ACCESS_KEY = userdata.get('WASABI_ACCESS_KEY')
    WASABI_SECRET_KEY = userdata.get('WASABI_SECRET_KEY')
    BUCKET_NAME = "xrain-composite-cx"

    s3_client = boto3.client(
        's3',
        endpoint_url=WASABI_ENDPOINT,
        aws_access_key_id=WASABI_ACCESS_KEY,
        aws_secret_access_key=WASABI_SECRET_KEY,
        region_name='ap-northeast-1'
    )

    def load_single_date(date: str) -> Optional[np.ndarray]:
        """単一の日付のデータを安全にロード"""
        try:
            base_path = "b2nd"
            year, month = date[0:4], date[4:6]
            filename = f"{base_path}/{year}/{month}/cx-{date}.b2nd"
            temp_filename = f"/tmp/cx-{date}-sample.b2nd"

            s3_client.download_file(BUCKET_NAME, filename, temp_filename)
            rains = blosc2.open(temp_filename)
            data = rains[:, lat_slice, lon_slice]
            data = data[:, ::downsample, ::downsample]

            if os.path.exists(temp_filename):
                os.remove(temp_filename)

            return data.astype(np.float32)
        except Exception as e:
            print(f"Failed to load validation sample for {date}: {e}")
            return None

    # 最初の日付からデータをロード
    sample_date = date_list[0]
    data = load_single_date(sample_date)

    print(f"Loaded validation sample from {sample_date}, shape: {data.shape}")

    # クロップ生成と重要度計算
    crop_size = (256, 256)
    T, H, W = data.shape
    crop_h, crop_w = crop_size

    best_crop = None
    best_score = -1

    # 複数の位置でクロップを試行
    for attempt in range(max_attempts):
        # ランダムな位置を選択
        h_start = np.random.randint(0, max(1, H - crop_h))
        w_start = np.random.randint(0, max(1, W - crop_w))
        h_end = min(h_start + crop_h, H)
        w_end = min(w_start + crop_w, W)

        crop = data[:, h_start:h_end, w_start:w_end]

        if crop.shape[0] <= seq_len:
            continue

        # 受容確率を計算
        score = calculate_acceptance_probability(crop, training_mode=False)

        if score > best_score:
            best_score = score
            best_crop = crop

    if best_crop is None:
        print("Warning: No valid crop found, using first available")
        best_crop = data[:, :crop_h, :crop_w]

    # シーケンスを生成
    T_crop = best_crop.shape[0]


    # 最も降水量が多い時点を中心にシーケンスを選択
    max_rain_per_frame = [np.max(best_crop[t]) for t in range(T_crop)]
    best_end_idx = np.argmax(max_rain_per_frame)

    # シーケンスの開始点を調整
    start_idx = max(0, min(best_end_idx - seq_len, T_crop - seq_len - 1))

    window = best_crop[start_idx:start_idx + seq_len + 1]

    # 入力とターゲットに分割
    x_seq = window[:-1][..., None]  # (seq_len, H, W, 1)
    y_target = window[-1][..., None]  # (H, W, 1)

    # バッチ次元を追加
    x_batch = x_seq[np.newaxis]  # (1, seq_len, H, W, 1)
    y_batch = y_target[np.newaxis]  # (1, H, W, 1)

    print(f"Generated validation sample with max rain: {np.max(y_target):.2f} mm/h")

    return x_batch.astype(np.float32), y_batch.astype(np.float32)


def create_finite_generator(date_list: List[str],
                          batch_size: int,
                          lon_slice: slice,
                          lat_slice: slice,
                          seq_len: int = 6,
                          downsample: int = 5,
                          max_batches_per_epoch: int = 100) -> Generator:
    """
    有限回数で停止するジェネレータ（デバッグ用）
    """
    batch_count = 0

    for x_batch, y_batch in importance_sampling_generator(
        date_list, batch_size, lon_slice, lat_slice,
        seq_len, downsample, buffer_size=2
    ):
        yield x_batch, y_batch
        batch_count += 1

        if batch_count >= max_batches_per_epoch:
            print(f"Generator stopped after {batch_count} batches (safety limit)")
            break


val_sample_x, val_sample_y = get_importance_validation_sample_fixed(
    val_dates, BATCH_SIZE, LON_SLICE, LAT_SLICE, SEQ_LEN, DOWNSAMPLE
)

# 4. コールバックの更新
fs_cb_importance = FlowSourceVizCallback(
    (val_sample_x, val_sample_y),
    quiver_step=8,
    tag_prefix="evo_importance"
)

print("=== Importance Sampling Configuration ===")
print(f"Original steps per epoch: {multi_steps_per_epoch}")
print(f"New steps per epoch: {importance_steps_per_epoch}")
print(f"Original val steps: {multi_val_steps}")
print(f"New val steps: {importance_val_steps}")

Loaded validation sample from 20240819, shape: (1440, 192, 128)
Generated validation sample with max rain: nan mm/h
=== Importance Sampling Configuration ===
Original steps per epoch: 449
New steps per epoch: 3
Original val steps: 179
New val steps: 1


In [None]:
# 5. 最終的なmodel.fit呼び出し
model.fit(
    train_gen,  # Importance Samplingジェネレータ
    steps_per_epoch=multi_steps_per_epoch,  # 計算されたsteps
    epochs=EPOCHS,
    validation_data=val_gen,  # Importance Samplingジェネレータ
    validation_steps=val_steps,  # 計算されたsteps
    callbacks=[
        keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
        keras.callbacks.TerminateOnNaN(),
        TensorBoardXCallback(val_sample_x[:1]),  # 検証サンプル
        fs_cb_importance,  # コールバック
    ],
)

Loaded data for dates ['20240809', '20240810'], combined shape: (2880, 192, 128)
Loaded data for dates ['20240810'], combined shape: (1440, 192, 128)
Epoch 1/200
[1m  2/449[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m3:46[0m 507ms/step - accum_loss: 1.9274 - loss: 1.9274 - mae: 0.0615 - motion_reg: 4.6471e-04 Loaded data for dates ['20240809', '20240810'], combined shape: (2880, 192, 128)
[1m  3/449[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m2:30:42[0m 20s/step - accum_loss: 2.0517 - loss: 2.0517 - mae: 0.0654 - motion_reg: 6.0858e-04Loaded data for dates ['20240810'], combined shape: (1440, 192, 128)
[1m  4/449[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m2:20:33[0m 19s/step - accum_loss: 2.0046 - loss: 2.0046 - mae: 0.0641 - motion_reg: 7.1675e-04Loaded data for dates ['20240809', '20240810'], combined shape: (2880, 192, 128)
[1m  5/449[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m2:38:17[0m 21s/step - accum_loss: 1.9885 - loss: 1.9885 - mae: 0.0638 - motion_reg: 8.3583e-04Loaded data for dates ['202

## 学習結果の表示

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

plt.ioff()  # jupyter notebookで自動的にグラフが表示されないようにする

jma_colors = [
    "#f2f2ff",
    "#a0d2ff",
    "#218cff",
    "#0000ff",
    "#faf500",
    "#ff9900",
    "#ff2100",
    "#c800c8",
]
jma_bounds = [0, 1, 5, 10, 20, 30, 50, 80, 1000]
jma_cmap = mcolors.ListedColormap(jma_colors)
jma_norm = mcolors.BoundaryNorm(jma_bounds, jma_cmap.N)
jma_cmap.set_bad("#ffffff")


In [None]:
import matplotlib.figure


# 2つの雨量データを並べて表示する関数
def plot2(fig: matplotlib.figure.Figure, data0, data1):
    ax0 = fig.add_subplot(1, 2, 1)
    ax1 = fig.add_subplot(1, 2, 2)

    # 配色用の ScalarMappable を一度だけ生成
    sm = plt.cm.ScalarMappable(cmap=jma_cmap, norm=jma_norm)
    sm.set_array([])

    cax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    fig.colorbar(sm, cax=cax, orientation="vertical",
                 boundaries=jma_bounds, spacing="uniform",
                 ticks=[1,5,10,20,30,50,80]).set_label("mm/h")

    ax0.imshow(data0, cmap=jma_cmap, norm=jma_norm)
    ax1.imshow(data1, cmap=jma_cmap, norm=jma_norm)
    ax0.set_title("input")
    ax1.set_title("output")

In [None]:
def make_input_seq(rains: np.ndarray, t: int, seq_len: int = SEQ_LEN):
    """
    t-seq_len … t-1 の窓を (1,T,H,W,1) で返す
    行 = 緯度(LAT_SLICE) / 列 = 経度(LON_SLICE)
    """
    assert t >= seq_len, "t は SEQ_LEN 以上にしてください"

    seq = rains[t-seq_len : t,
                LAT_SLICE,   # ← 行方向 = 緯度
                LON_SLICE]   # ← 列方向 = 経度
    seq = seq[:, ::DOWNSAMPLE, ::DOWNSAMPLE]   # 10 倍間引き
    seq = seq[..., None]                       # (T,H,W,1)
    return seq[np.newaxis].astype("float32")

In [None]:
x_in = make_input_seq(val_rains, SEQ_LEN)
print("x_in shape :", x_in.shape)   # → (1, 6, 192, 128, 1)

y_pred = model(x_in, training=False).numpy()[0, ..., 0]   # (192,128)
print("y_pred shape:", y_pred.shape)

In [None]:
i = SEQ_LEN + 0                              # 表示したいフレーム番号
x_in   = make_input_seq(val_rains, i)        # (1,T,H,W,1)
y_pred = model(x_in, training=False).numpy()[0, ..., 0]   # (H,W)

y_true = val_rains[i,
                   LAT_SLICE,  # 行 = 緯度
                   LON_SLICE]  # 列 = 経度
y_true = y_true[::DOWNSAMPLE, ::DOWNSAMPLE]

fig = plt.figure(figsize=(8,4))
plot2(fig, y_true, y_pred)   # True → Pred を並べて表示
plt.tight_layout()
plt.show()

In [None]:
from tqdm import tqdm

init_i      = SEQ_LEN
predict_num = val_rains.shape[0] - init_i   # val_rains 全体から SEQ_LEN 枚を引いた数
# predict_num = 2

# ── 2) 結果を格納する配列を確保 ───────────────────────────────
H = math.ceil((LAT_SLICE.stop - LAT_SLICE.start) / DOWNSAMPLE)
W = math.ceil((LON_SLICE.stop - LON_SLICE.start) / DOWNSAMPLE)
predict_rain = np.zeros((predict_num, H, W), dtype=np.float32)

# ── 3) val_rains に対してループ推論 ───────────────────────────
for i in tqdm(range(predict_num), desc="Predicting on val_rains"):
    t = init_i + i
    x_in   = make_input_seq(val_rains, t)                           # (1,SEQ_LEN,H,W,1)
    y_pred = model(x_in, training=False).numpy()[0, ..., 0]         # → (H,W)
    predict_rain[i] = y_pred

In [None]:
import matplotlib.animation as animation

fig = plt.figure(figsize=(10, 5))

def update_plot(i):
    fig.clear()
    # 真値を切り出し＆ダウンサンプリング
    true_frame = val_rains[i + init_i,
                           LAT_SLICE, LON_SLICE][::DOWNSAMPLE, ::DOWNSAMPLE]
    pred_frame = predict_rain[i]  # (H,W)
    plot2(fig, true_frame, pred_frame)
    return []

ani = animation.FuncAnimation(
    fig,
    update_plot,
    frames=range(predict_num),
    interval=200,   # ミリ秒
    blit=False
)

# 保存（MP4）
writer = animation.FFMpegWriter(fps=10, codec='libx264')
ani.save("predict_val.mp4", writer=writer, dpi=200)

plt.close(fig)

In [None]:
from IPython.display import HTML, Video
from base64 import b64encode

mp4 = open('predict_val.mp4','rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()

HTML(f"""
<video width="800" controls>
  <source src="{data_url}" type="video/mp4">
</video>
""")

## 自己回帰

In [None]:
import numpy as np
from collections import deque

# ---- 開始フレーム（0-based）----
START_T = 1070  # ← ここを変更（1-basedなら 1069 にする）

T_total = val_rains.shape[0]
assert START_T >= SEQ_LEN, "START_T は SEQ_LEN 以上にしてください"

# ---- 予測枚数：START_T から最後まで ----
predict_num = T_total - START_T
# predict_num = 2

# ---- 履歴を START_T の直前 SEQ_LEN 枚の真値で初期化 ----
history = deque(maxlen=SEQ_LEN)
for t in range(START_T - SEQ_LEN, START_T):
    frame = val_rains[t, LAT_SLICE, LON_SLICE][::DOWNSAMPLE, ::DOWNSAMPLE]
    history.append(frame.astype("float32"))

# ---- 推論ループ ----
predict_rain = np.empty((predict_num, H, W), dtype="float32")

for i in range(predict_num):
    # 履歴 → (1,SEQ_LEN,H,W,1)
    x = np.stack(history, axis=0)[np.newaxis, ..., np.newaxis]

    # 1枚先（時刻 t = START_T + i）を予測
    y_hat = model(x, training=False).numpy()[0, ..., 0]  # (H,W)

    # 保存 & 履歴更新（自己回帰）
    predict_rain[i] = y_hat
    history.append(y_hat)

print("START_T =", START_T)
print("predict_rain.shape =", predict_rain.shape)  # → (T_total - START_T, H, W)

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)

# カラーバー（1回だけ作成）
sm = plt.cm.ScalarMappable(cmap=jma_cmap, norm=jma_norm)
sm.set_array([])
cbar = fig.colorbar(sm, ax=ax.ravel().tolist(), orientation="vertical", fraction=0.046, pad=0.02)
cbar.set_label("mm/h")

def update_plot(i):
    ax[0].cla()
    ax[1].cla()

    # 左：真値（時刻 START_T+i）
    true_frame = val_rains[START_T + i, LAT_SLICE, LON_SLICE][::DOWNSAMPLE, ::DOWNSAMPLE]
    im0 = ax[0].imshow(true_frame, vmin=0, vmax=SCALE_FACTOR, cmap="Blues")
    ax[0].set_title(f"Ground Truth (t={START_T + i})")

    # 右：自己回帰予測（i 番目）
    pred_frame = predict_rain[i]
    im1 = ax[1].imshow(pred_frame, vmin=0, vmax=SCALE_FACTOR, cmap="Blues")
    ax[1].set_title(f"Prediction (t={START_T + i})")

    return im0, im1

import matplotlib.animation as animation

ani = animation.FuncAnimation(
    fig, update_plot, frames=range(predict_num), interval=200
)

writer = animation.FFMpegWriter(fps=10, codec="libx264")
ani.save("predict_val_self.mp4", writer=writer, dpi=200)

plt.close(fig)

In [None]:
mp4 = open('predict_val_self.mp4','rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()

HTML(f"""
<video width="800" controls>
  <source src="{data_url}" type="video/mp4">
</video>
""")

In [None]:
from google.colab import runtime

runtime.unassign()