In [1]:
!pip install numpy torch
!pip install GDAL==$(gdal-config --version)
!pip install scikit-learn matplotlib cartopy shapely

Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)
  Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)
  Downloading nvidia_nvjitlink_cu12-1

In [2]:
!ls -l /kaggle/input/btl-ai/DATA_SV/Hima/B04B

total 0
drwxr-xr-x 4 nobody nogroup 0 Apr 27 06:37 2019
drwxr-xr-x 4 nobody nogroup 0 Apr 27 06:37 2020


In [3]:
import os
import numpy as np
from osgeo import gdal
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import logging

# Thiết lập logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Thiết lập thiết bị
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Đặt seed để tái lập kết quả
torch.manual_seed(42)
np.random.seed(42)

Using device: cpu


In [4]:
# Hằng số và đường dẫn
BASE_PATH = "/kaggle/input/btl-ai/DATA_SV"
HIMA_PATH = os.path.join(BASE_PATH, "Hima")
ERA5_PATH = os.path.join(BASE_PATH, "ERA5")
PRECIP_PATH = os.path.join(BASE_PATH, "Precipitation/Radar")
OUTPUT_PATH = "/kaggle/working/output/"
os.makedirs(OUTPUT_PATH, exist_ok=True)

HIMA_BANDS = ['B04B', 'B05B', 'B06B', 'B09B', 'B10B', 'B11B', 'B12B', 'B14B', 'B16B', 'I2B', 'I4B', 'IRB', 'VSB', 'WVB']  # 14 band
ERA5_PARAMS = ['CAPE', 'CIN', 'EWSS', 'IE', 'ISOR', 'KX', 'PEV', 'R250', 'R500', 'R850', 'SLHF', 'SLOR', 'SSHF', 'TCLW', 'TCW', 'TCWV', 'U250', 'U850', 'V250', 'V850']  # 20 tham số
HEIGHT, WIDTH = 90, 250
NUM_EPOCHS = 20
BATCH_SIZE = 2

# Hàm đọc file GeoTIFF
def read_geotiff(file_path):
    try:
        ds = gdal.Open(file_path)
        band = ds.GetRasterBand(1)
        data = band.ReadAsArray().astype(np.float32)
        ds = None
        if data.shape != (HEIGHT, WIDTH):
            logging.warning(f"Invalid shape {data.shape} for file {file_path}, expected ({HEIGHT}, {WIDTH})")
            return None
        return data
    except Exception as e:
        logging.error(f"Error reading {file_path}: {e}")
        return None

In [5]:
# Hàm phân tích thời gian từ tên file
def parse_datetime_from_filename(filename, data_type):
    try:
        if data_type == "Hima":
            time_part = filename.split('_')[1].split('_TB.tif')[0].replace('.Z', '')
            dt = datetime.strptime(time_part, '%Y%m%d%H%M')
        elif data_type in ["ERA5", "Radar"]:
            time_part = filename.split('_')[1].replace('.tif', '')
            dt = datetime.strptime(time_part, '%Y%m%d%H%M%S')
        else:
            return None
        return dt.replace(minute=0, second=0, microsecond=0)
    except Exception as e:
        logging.warning(f"Error parsing datetime from {filename} (type {data_type}): {e}")
        return None

# Hàm thu thập file
def collect_files(base_path, expected_subdirs=None, data_type=None, current_band=None):
    files_dict = {}
    file_count = 0
    for root, _, files in os.walk(base_path):
        for file in files:
            if file.endswith('.tif'):
                file_path = os.path.join(root, file)
                dt = parse_datetime_from_filename(file, data_type)
                if dt is None:
                    continue
                file_count += 1
                if expected_subdirs:
                    if current_band is None:
                        logging.warning(f"current_band not provided, skipping {file_path}")
                        continue
                    if dt not in files_dict:
                        files_dict[dt] = {}
                    files_dict[dt][current_band] = file_path
                else:
                    files_dict[dt] = file_path
    logging.info(f"Found {file_count} files in {base_path}")
    return files_dict

# Hàm tính trung bình toàn cục
def compute_global_means(files_dict, bands, data_type):
    logging.info(f"Computing global means for {data_type}...")
    global_means = {band: [] for band in bands}
    for dt, paths in files_dict.items():
        for band in bands:
            file_path = paths.get(band)
            if file_path and os.path.exists(file_path):
                data = read_geotiff(file_path)
                if data is not None:
                    valid_data = data[~(np.isinf(data) | np.isnan(data) | (data == -9999))]
                    if valid_data.size > 0:
                        global_means[band].append(np.mean(valid_data))
    
    for band in bands:
        if global_means[band]:
            global_means[band] = np.mean(global_means[band])
        else:
            global_means[band] = 0
            logging.warning(f"No valid data for {data_type} band {band}, using 0 as global mean")
    
    logging.info(f"Completed global means for {data_type}")
    return global_means

In [6]:
# Hàm xử lý dữ liệu thiếu và chuẩn hóa
def preprocess_data(data, data_type, missing_threshold=0.7):
    if data is None:
        logging.warning("Received None data for preprocessing")
        return None, None
    
    invalid_mask = np.isinf(data) | np.isnan(data) | (data == -9999)
    invalid_ratio = np.sum(invalid_mask) / data.size
    if invalid_ratio > missing_threshold:
        logging.warning(f"Data has {invalid_ratio*100:.2f}% invalid values, skipping")
        return None, None
    
    if data_type == "Radar":
        data = np.log1p(np.maximum(data, 0))
    else:
        data_min, data_max = np.nanmin(data), np.nanmax(data)
        if data_max > data_min:
            data = (data - data_min) / (data_max - data_min)
        else:
            data = np.zeros_like(data)
    
    return data, invalid_mask

# Hàm điền giá trị thiếu
def fill_missing_values(sequence, invalid_masks, sequence_dts, global_means, data_type="Hima"):
    sequence = sequence.copy()
    n_frames, height, width, n_channels = sequence.shape
    bands = HIMA_BANDS if data_type == "Hima" else ERA5_PARAMS
    
    for c in range(n_channels):
        band_idx = c if data_type == "Hima" else c - len(HIMA_BANDS)
        if band_idx >= len(bands) or band_idx < 0:
            continue
        band = bands[band_idx]
        
        for t in range(n_frames):
            mask = invalid_masks[t, :, :, c]
            if not np.any(mask):
                continue

            # Step 1: Forward fill (≤ 2 hours)
            for t_prev in range(t-1, -1, -1):
                time_diff = (sequence_dts[t] - sequence_dts[t_prev]).total_seconds() / 3600
                if time_diff <= 2 and not np.any(invalid_masks[t_prev, :, :, c][mask]):
                    sequence[t, :, :, c][mask] = sequence[t_prev, :, :, c][mask]
                    invalid_masks[t, :, :, c][mask] = False
                    logging.info(f"Forward filled channel {c}, frame {t}")
                    break

            # Step 2: Backward fill (≤ 2 hours)
            if np.any(mask):
                for t_next in range(t+1, n_frames):
                    time_diff = (sequence_dts[t_next] - sequence_dts[t]).total_seconds() / 3600
                    if time_diff <= 2 and not np.any(invalid_masks[t_next, :, :, c][mask]):
                        sequence[t, :, :, c][mask] = sequence[t_next, :, :, c][mask]
                        invalid_masks[t, :, :, c][mask] = False
                        logging.info(f"Backward filled channel {c}, frame {t}")
                        break

            # Step 3: Linear Interpolation (≤ 4 hours)
            if np.any(mask):
                t_prev, t_next = None, None
                for t_p in range(t-1, -1, -1):
                    time_diff = (sequence_dts[t] - sequence_dts[t_p]).total_seconds() / 3600
                    if time_diff <= 4 and not np.any(invalid_masks[t_p, :, :, c][mask]):
                        t_prev = t_p
                        break
                for t_n in range(t+1, n_frames):
                    time_diff = (sequence_dts[t_n] - sequence_dts[t]).total_seconds() / 3600
                    if time_diff <= 4 and not np.any(invalid_masks[t_n, :, :, c][mask]):
                        t_next = t_n
                        break
                if t_prev is not None and t_next is not None:
                    time_prev = (sequence_dts[t] - sequence_dts[t_prev]).total_seconds() / 3600
                    time_next = (sequence_dts[t_next] - sequence_dts[t]).total_seconds() / 3600
                    total_time = time_prev + time_next
                    if total_time > 0:
                        weight_prev = time_next / total_time
                        weight_next = time_prev / total_time
                        sequence[t, :, :, c][mask] = (
                            weight_prev * sequence[t_prev, :, :, c][mask] +
                            weight_next * sequence[t_next, :, :, c][mask]
                        )
                        invalid_masks[t, :, :, c][mask] = False
                        logging.info(f"Linearly interpolated channel {c}, frame {t}")

            # Step 4: Fallback - Global mean
            if np.any(mask):
                sequence[t, :, :, c][mask] = global_means[band]
                logging.info(f"Filled channel {c}, frame {t} with global mean {global_means[band]}")

    return sequence

In [7]:
# Hàm tạo chuỗi thời gian
def create_time_sequences(hima_files, era5_files, precip_files, common_datetimes, hima_global_means, era5_global_means):
    logging.info("Starting to create time sequences...")
    def generate_sequences():
        for i in range(4, len(common_datetimes)):
            dt = common_datetimes[i]
            valid_sequence = True
            for j in range(1, 5):
                if common_datetimes[i-j] != dt - timedelta(hours=j):
                    valid_sequence = False
                    break
            if not valid_sequence:
                continue

            sequence = []
            invalid_masks = []
            sequence_dts = []
            for j in range(5):
                dt_j = common_datetimes[i-j]
                if dt_j not in hima_files or dt_j not in era5_files:
                    valid_sequence = False
                    break

                # Đọc Himawari
                hima_data = []
                hima_masks = []
                for band in HIMA_BANDS:
                    file_path = hima_files[dt_j].get(band)
                    if not file_path or not os.path.exists(file_path):
                        valid_sequence = False
                        break
                    data = read_geotiff(file_path)
                    if data is None:
                        valid_sequence = False
                        break
                    data, mask = preprocess_data(data, "Hima")
                    if data is None:
                        valid_sequence = False
                        break
                    hima_data.append(data)
                    hima_masks.append(mask)
                if not valid_sequence:
                    break
                hima_data = np.stack(hima_data, axis=-1)
                hima_masks = np.stack(hima_masks, axis=-1)

                # Đọc ERA5
                era5_data = []
                era5_masks = []
                for param in ERA5_PARAMS:
                    file_path = era5_files[dt_j].get(param)
                    if not file_path or not os.path.exists(file_path):
                        valid_sequence = False
                        break
                    data = read_geotiff(file_path)
                    if data is None:
                        valid_sequence = False
                        break
                    data, mask = preprocess_data(data, "ERA5")
                    if data is None:
                        valid_sequence = False
                        break
                    era5_data.append(data)
                    era5_masks.append(mask)
                if not valid_sequence:
                    break
                era5_data = np.stack(era5_data, axis=-1)
                era5_masks = np.stack(era5_masks, axis=-1)

                combined = np.concatenate([hima_data, era5_data], axis=-1)
                combined_masks = np.concatenate([hima_masks, era5_masks], axis=-1)
                sequence.append(combined)
                invalid_masks.append(combined_masks)
                sequence_dts.append(dt_j)
            if not valid_sequence:
                continue

            # Đọc radar
            radar_file = precip_files.get(dt)
            if not radar_file or not os.path.exists(radar_file):
                continue
            radar_data = read_geotiff(radar_file)
            if radar_data is None:
                continue
            radar_data, _ = preprocess_data(radar_data, "Radar")
            if radar_data is None:
                continue

            # Điền giá trị thiếu
            sequence = np.stack(sequence, axis=0)
            invalid_masks = np.stack(invalid_masks, axis=0)
            sequence = fill_missing_values(sequence, invalid_masks, sequence_dts, hima_global_means, "Hima")
            sequence = fill_missing_values(sequence, invalid_masks, sequence_dts, era5_global_means, "ERA5")

            sequence = sequence.transpose(0, 3, 1, 2)
            yield sequence, radar_data

    X, y = [], []
    for seq, target in generate_sequences():
        X.append(seq)
        y.append(target)
    
    logging.info(f"Completed time sequences, generated {len(X)} samples")
    return np.array(X), np.array(y)

In [8]:
# Định nghĩa lớp ConvLSTMCell
class ConvLSTMCell(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super(ConvLSTMCell, self).__init__()
        self.out_channels = out_channels
        self.conv = nn.Conv2d(
            in_channels + out_channels, 4 * out_channels, kernel_size,
            padding=padding, bias=True
        )

    def forward(self, x, h_prev, c_prev):
        combined = torch.cat([x, h_prev], dim=1)
        conv_out = self.conv(combined)
        i, f, o, g = torch.chunk(conv_out, 4, dim=1)
        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        o = torch.sigmoid(o)
        g = torch.tanh(g)
        c_next = f * c_prev + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

# Định nghĩa lớp ConvLSTM2d
class ConvLSTM2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super(ConvLSTM2d, self).__init__()
        self.cell = ConvLSTMCell(in_channels, out_channels, kernel_size, padding)

    def forward(self, x):
        if len(x.size()) == 5:
            batch, seq_len, channels, height, width = x.size()
            is_sequence = True
        elif len(x.size()) == 4:
            batch, channels, height, width = x.size()
            seq_len = 1
            x = x.unsqueeze(1)
            is_sequence = False
        else:
            raise ValueError(f"Expected 4 or 5 dimensions, got {len(x.size())}")

        out_channels = self.cell.out_channels
        h = torch.zeros(batch, out_channels, height, width, device=x.device)
        c = torch.zeros(batch, out_channels, height, width, device=x.device)
        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :, :, :]
            h, c = self.cell(x_t, h, c)
            outputs.append(h)
        output = outputs[-1] if is_sequence else h
        return output, (h, c)

# Định nghĩa mô hình ConvLSTM
class ConvLSTMModel(nn.Module):
    def __init__(self):
        super(ConvLSTMModel, self).__init__()
        self.convlstm1 = ConvLSTM2d(in_channels=34, out_channels=64, kernel_size=(5, 5), padding=(2, 2))
        self.bn1 = nn.BatchNorm2d(64)
        self.convlstm2 = ConvLSTM2d(in_channels=64, out_channels=32, kernel_size=(5, 5), padding=(2, 2))
        self.bn2 = nn.BatchNorm2d(32)
        self.dropout = nn.Dropout(0.2)
        self.conv = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=(3, 3), padding=(1, 1))
        self.relu = nn.ReLU()

    def forward(self, x):
        x, _ = self.convlstm1(x)
        x = self.bn1(x)
        x, _ = self.convlstm2(x)
        x = self.bn2(x)
        x = self.dropout(x)
        x = self.conv(x)
        x = self.relu(x)
        return x.squeeze(1)

In [9]:
# Hàm huấn luyện mô hình
def train_model(model, train_loader, val_loader, epochs=NUM_EPOCHS, patience=5):
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    best_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            output = model(X_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * X_batch.size(0)
        train_loss /= len(train_loader.dataset)

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                output = model(X_batch)
                loss = criterion(output, y_batch)
                val_loss += loss.item() * X_batch.size(0)
            val_loss /= len(val_loader.dataset)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        if val_loss < best_loss:
            best_loss = val_loss
            best_model_state = model.state_dict()
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping")
                break

    model.load_state_dict(best_model_state)
    return model

# Hàm tính chỉ số đánh giá
def evaluate_model(y_true, y_pred, threshold=0.0):
    y_true = y_true.reshape(-1)
    y_pred = y_pred.reshape(-1)
    valid_mask = ~(np.isnan(y_true) | np.isnan(y_pred))
    y_true = y_true[valid_mask]
    y_pred = y_pred[valid_mask]
    
    rmse = np.sqrt(mean_squared_error(y_true, y_pred)) if len(y_true) > 0 else float('inf')
    corr = np.corrcoef(y_true, y_pred)[0, 1] if len(y_true) > 1 and np.std(y_true) > 0 and np.std(y_pred) > 0 else 0

    y_true_bin = (y_true > threshold).astype(int)
    y_pred_bin = (y_pred > threshold).astype(int)
    hits = np.sum((y_true_bin == 1) & (y_pred_bin == 1))
    misses = np.sum((y_true_bin == 1) & (y_pred_bin == 0))
    false_alarms = np.sum((y_true_bin == 0) & (y_pred_bin == 1))
    true_negatives = np.sum((y_true_bin == 0) & (y_pred_bin == 0))
    total = hits + misses + false_alarms + true_negatives

    accuracy = (hits + true_negatives) / total if total > 0 else 0
    csi = hits / (hits + misses + false_alarms) if (hits + misses + false_alarms) > 0 else 0
    far = false_alarms / (hits + false_alarms) if (hits + false_alarms) > 0 else 0
    hss = (2 * (hits * true_negatives - misses * false_alarms)) / \
          ((hits + misses) * (misses + true_negatives) + (hits + false_alarms) * (false_alarms + true_negatives)) \
          if ((hits + misses) * (misses + true_negatives) + (hits + false_alarms) * (false_alarms + true_negatives)) > 0 else 0
    ets = ((hits - ((hits + misses) * (hits + false_alarms) / total)) / \
           (hits + misses + false_alarms - ((hits + misses) * (hits + false_alarms) / total))) \
          if (hits + misses + false_alarms - ((hits + misses) * (hits + false_alarms) / total)) > 0 else 0

    return {'rmse': rmse, 'corr': corr, 'accuracy': accuracy, 'csi': csi, 'far': far, 'hss': hss, 'ets': ets}

# Hàm vẽ scatter plot
def plot_scatter(y_true, y_pred, output_path):
    plt.figure(figsize=(8, 6))
    plt.scatter(y_true.flatten(), y_pred.flatten(), alpha=0.5)
    plt.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'r--')
    plt.xlabel('Ground Truth (mm/h)')
    plt.ylabel('Predicted (mm/h)')
    plt.title('Scatter Plot: Predicted vs Ground Truth')
    plt.savefig(output_path)
    plt.close()

# Hàm hiển thị bản đồ
def plot_rainfall_map(y_true, y_pred, output_path):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), subplot_kw={'projection': ccrs.PlateCarree()})
    ax1.set_title('Ground Truth')
    ax2.set_title('Prediction')
    for ax, data in [(ax1, y_true), (ax2, y_pred)]:
        ax.coastlines()
        ax.add_feature(cfeature.BORDERS)
        im = ax.imshow(data, cmap='Blues', origin='upper', transform=ccrs.PlateCarree())
        plt.colorbar(im, ax=ax, label='Rainfall (mm/h)')
    plt.savefig(output_path)
    plt.close()

# Hàm lưu GeoTIFF
def save_geotiff(data, output_path, reference_file):
    ds = gdal.Open(reference_file)
    driver = gdal.GetDriverByName('GTiff')
    out_ds = driver.Create(output_path, WIDTH, HEIGHT, 1, gdal.GDT_Float32)
    out_ds.SetGeoTransform(ds.GetGeoTransform())
    out_ds.SetProjection(ds.GetProjection())
    out_band = out_ds.GetRasterBand(1)
    out_band.WriteArray(data)
    out_band.FlushCache()
    out_ds = None
    ds = None

In [10]:
# Phần chính
# Thu thập file
logging.info("Collecting Himawari files...")
hima_files_raw = {}
for band in HIMA_BANDS:
    band_path = os.path.join(HIMA_PATH, band)
    if not os.path.exists(band_path):
        logging.warning(f"Directory not found: {band_path}")
        continue
    band_files = collect_files(band_path, expected_subdirs=HIMA_BANDS, data_type="Hima", current_band=band)
    for dt, paths in band_files.items():
        if dt not in hima_files_raw:
            hima_files_raw[dt] = {}
        hima_files_raw[dt][band] = paths[band]

# Lọc các thời điểm có đủ tất cả band
hima_files = {}
for dt, bands in hima_files_raw.items():
    if all(band in bands for band in HIMA_BANDS):
        hima_files[dt] = bands
    else:
        logging.warning(f"Datetime {dt} is missing some Himawari bands, skipping")

logging.info("Collecting ERA5 files...")
era5_files_raw = {}
for param in ERA5_PARAMS:
    param_path = os.path.join(ERA5_PATH, param)
    if not os.path.exists(param_path):
        logging.warning(f"Directory not found: {param_path}")
        continue
    param_files = collect_files(param_path, expected_subdirs=ERA5_PARAMS, data_type="ERA5", current_band=param)
    for dt, paths in param_files.items():
        if dt not in era5_files_raw:
            era5_files_raw[dt] = {}
        era5_files_raw[dt][param] = paths[param]

# Lọc các thời điểm có đủ tất cả param
era5_files = {}
for dt, params in era5_files_raw.items():
    if all(param in params for param in ERA5_PARAMS):
        era5_files[dt] = params
    else:
        logging.warning(f"Datetime {dt} is missing some ERA5 params, skipping")

logging.info("Collecting Precipitation files...")
precip_files = collect_files(PRECIP_PATH, data_type="Radar")

# Tìm thời gian chung
common_datetimes = sorted(list(set(hima_files.keys()) & set(era5_files.keys()) & set(precip_files.keys())))
logging.info(f"Found {len(common_datetimes)} common datetimes")

# Tính global means
hima_global_means = compute_global_means(hima_files, HIMA_BANDS, "Hima")
era5_global_means = compute_global_means(era5_files, ERA5_PARAMS, "ERA5")

# Tạo chuỗi thời gian
X, y = create_time_sequences(hima_files, era5_files, precip_files, common_datetimes, hima_global_means, era5_global_means)

# Chia dữ liệu
logging.info("Splitting data into train/val/test...")
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)
logging.info(f"Train: {len(X_train)} samples, Val: {len(X_val)} samples, Test: {len(X_test)} samples")

# Chuyển sang tensor
X_train = torch.tensor(X_train, dtype=torch.float32).to(device)
y_train = torch.tensor(y_train, dtype=torch.float32).to(device)
X_val = torch.tensor(X_val, dtype=torch.float32).to(device)
y_val = torch.tensor(y_val, dtype=torch.float32).to(device)
X_test = torch.tensor(X_test, dtype=torch.float32).to(device)
y_test = torch.tensor(y_test, dtype=torch.float32).to(device)

# Tạo DataLoader
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
test_dataset = TensorDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

# Khởi tạo mô hình
model = ConvLSTMModel().to(device)
logging.info("Initialized ConvLSTM model")

# Xóa cache GPU
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    logging.info("Cleared GPU cache")

# Huấn luyện mô hình
logging.info("Starting training...")
model = train_model(model, train_loader, val_loader)

# Lưu mô hình
torch.save(model.state_dict(), os.path.join(OUTPUT_PATH, "convlstm_model.pth"))
logging.info("Model saved to convlstm_model.pth")

# Đánh giá trên tập kiểm thử
logging.info("Evaluating on test set...")
model.eval()
y_pred = []
with torch.no_grad():
    for X_batch, _ in test_loader:
        X_batch = X_batch.to(device)
        output = model(X_batch)
        y_pred.append(output.cpu().numpy())
y_pred = np.concatenate(y_pred, axis=0)

# Đánh giá
metrics = evaluate_model(y_test.numpy(), y_pred)
print("Evaluation Metrics:", metrics)

# Vẽ scatter plot
plot_scatter(y_test.numpy(), y_pred, os.path.join(OUTPUT_PATH, 'scatter_plot.png'))

# Vẽ bản đồ cho mẫu đầu tiên
plot_rainfall_map(y_test[0].numpy(), y_pred[0], os.path.join(OUTPUT_PATH, 'rainfall_map.png'))

# Lưu bản đồ dự đoán dưới dạng GeoTIFF
save_geotiff(y_pred[0], os.path.join(OUTPUT_PATH, 'predicted_rainfall.tif'),
             precip_files[common_datetimes[-1]])

Epoch 1/20, Train Loss: 0.1116, Val Loss: 0.0673
Epoch 2/20, Train Loss: 0.0777, Val Loss: 0.0663
Epoch 3/20, Train Loss: 0.0653, Val Loss: 0.0658
Epoch 4/20, Train Loss: 0.0692, Val Loss: 0.0617
Epoch 5/20, Train Loss: 0.0592, Val Loss: 0.0698
Epoch 6/20, Train Loss: 0.0582, Val Loss: 0.0635
Epoch 7/20, Train Loss: 0.0551, Val Loss: 0.0673
Epoch 8/20, Train Loss: 0.0540, Val Loss: 0.0621
Epoch 9/20, Train Loss: 0.0508, Val Loss: 0.0676
Early stopping
Evaluation Metrics: {'rmse': 0.19621901, 'corr': 0.6616477218956265, 'accuracy': 0.8937179138321996, 'csi': 0.2568134105424724, 'far': 0.6998443291326909, 'hss': 0.3585643622798681, 'ets': 0.2184455814410702}


