In [28]:
!pip install GDAL



In [29]:
import os
import numpy as np
from osgeo import gdal
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import glob

In [30]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [31]:
# Đặt seed để tái lập kết quả
torch.manual_seed(42)
np.random.seed(42)

# Định nghĩa các hằng số và đường dẫn
BASE_PATH = "/kaggle/input/btl-ai/DATA_SV"
HIMA_PATH = os.path.join(BASE_PATH, "Hima")
ERA5_PATH = os.path.join(BASE_PATH, "ERA5")
PRECIP_PATH = os.path.join(BASE_PATH, "Precipitation/Radar")
OUTPUT_PATH = "/kaggle/working/output/"
os.makedirs(OUTPUT_PATH, exist_ok=True)

In [32]:
HIMA_BANDS = ['B04B', 'B05B', 'B06B', 'B09B', 'B10B', 'B11B', 'B12B', 'B14B', 'B16B', 'I2B', 'I4B', 'IRB', 'VSB', 'WVB']  # 14 band
ERA5_PARAMS = ['CAPE', 'CIN', 'EWSS', 'IE', 'ISOR', 'KX', 'PEV', 'R250', 'R500', 'R850', 'SLHF', 'SLOR', 'SSHF', 'TCLW', 'TCW', 'TCWV', 'U250', 'U850', 'V250', 'V850']  # 20 tham số
HEIGHT, WIDTH = 90, 250

In [33]:
# 2. Hàm xử lý dữ liệu
# Hàm đọc file GeoTIFF
def read_geotiff(file_path):
    try:
        ds = gdal.Open(file_path)
        band = ds.GetRasterBand(1)
        data = band.ReadAsArray()
        ds = None
        if data.shape != (HEIGHT, WIDTH):
            print(f"Invalid shape {data.shape} for file {file_path}, expected ({HEIGHT}, {WIDTH})")
            return None
        return data
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return None

In [34]:
# Hàm phân tích thời gian từ tên file
def parse_datetime_from_filename(filename, data_type):
    try:
        if data_type == "Hima":
            parts = filename.split('_')
            if len(parts) < 2:
                return None
            time_part = parts[1].split('_TB.tif')[0]
            time_part = time_part.replace('.Z', '')
            dt = datetime.strptime(time_part, '%Y%m%d%H%M')
        elif data_type == "ERA5":
            parts = filename.split('_')
            if len(parts) < 2:
                return None
            time_part = parts[1].replace('.tif', '')
            dt = datetime.strptime(time_part, '%Y%m%d%H%M%S')
        elif data_type == "Radar":
            time_part = filename.split('_')[1].replace('.tif', '')
            dt = datetime.strptime(time_part, '%Y%m%d%H%M%S')
        else:
            return None
        return dt.replace(minute=0, second=0, microsecond=0)
    except Exception as e:
        global error_count
        if error_count < 5:
            print(f"Error parsing datetime from {filename} (type {data_type}): {e}")
            error_count += 1
        return None

In [35]:
error_count = 0

# Hàm thu thập file
def collect_files(base_path, expected_subdirs=None, data_type=None):
    files_dict = {}
    file_count = 0
    for root, dirs, files in os.walk(base_path):
        for file in files:
            if file.endswith('.tif'):
                file_path = os.path.join(root, file)
                dt = parse_datetime_from_filename(file, data_type)
                if dt is None:
                    continue
                file_count += 1
                if expected_subdirs:
                    subdir = os.path.basename(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(file_path)))))
                    if dt not in files_dict:
                        files_dict[dt] = {}
                    files_dict[dt][subdir] = file_path
                else:
                    files_dict[dt] = file_path
    print(f"Found {file_count} files in {base_path}")
    return files_dict

In [36]:
# Hàm xử lý dữ liệu thiếu và chuẩn hóa
def preprocess_data(data, data_type):
    if data is None:
        return None
    # Xử lý giá trị không hợp lệ
    data = np.where(np.isinf(data) | np.isnan(data) | (data == -9999), 0, data)
    if data_type == "Radar":
        # Log transformation cho radar
        data = np.log1p(np.maximum(data, 0))  # log(1 + x)
    else:
        # Min-max scaling cho Himawari và ERA5
        data_min, data_max = np.min(data), np.max(data)
        if data_max > data_min:
            data = (data - data_min) / (data_max - data_min)
        else:
            data = np.zeros_like(data)
    return data

In [37]:
# Hàm tạo chuỗi thời gian t-4 đến t
def create_time_sequences(hima_files, era5_files, precip_files, common_datetimes):
    X, y = [], []
    for i in range(4, len(common_datetimes)):
        dt = common_datetimes[i]
        # Kiểm tra tính liên tục của 5 khung
        valid_sequence = True
        for j in range(1, 5):
            if common_datetimes[i-j] != dt - timedelta(hours=j):
                valid_sequence = False
                break
        if not valid_sequence:
            continue

        # Tạo chuỗi 5 khung
        sequence = []
        for j in range(5):
            dt_j = common_datetimes[i-j]
            # Đọc Himawari
            hima_data = []
            for band in HIMA_BANDS:
                file_path = hima_files.get(dt_j, {}).get(band)
                if not file_path:
                    valid_sequence = False
                    break
                data = read_geotiff(file_path)
                data = preprocess_data(data, "Hima")
                if data is None:
                    valid_sequence = False
                    break
                hima_data.append(data)
            if not valid_sequence:
                break
            hima_data = np.stack(hima_data, axis=-1)  # (90, 250, 14)

            # Đọc ERA5
            era5_data = []
            for param in ERA5_PARAMS:
                file_path = era5_files.get(dt_j, {}).get(param)
                if not file_path:
                    valid_sequence = False
                    break
                data = read_geotiff(file_path)
                data = preprocess_data(data, "ERA5")
                if data is None:
                    valid_sequence = False
                    break
                era5_data.append(data)
            if not valid_sequence:
                break
            era5_data = np.stack(era5_data, axis=-1)  # (90, 250, 20)

            # Kết hợp Himawari và ERA5
            combined = np.concatenate([hima_data, era5_data], axis=-1)  # (90, 250, 34)
            sequence.append(combined)
        if not valid_sequence:
            continue

        # Đọc radar (ground truth)
        radar_file = precip_files.get(dt)
        if not radar_file:
            continue
        radar_data = read_geotiff(radar_file)
        radar_data = preprocess_data(radar_data, "Radar")
        if radar_data is None:
            continue

        sequence = np.stack(sequence, axis=0)  # (5, 90, 250, 34)
        X.append(sequence)
        y.append(radar_data)  # (90, 250)

    X = np.array(X)
    y = np.array(y)
    # Chuyển đổi định dạng để kênh nằm trước chiều không gian
    X = X.transpose(0, 1, 4, 2, 3)  # (samples, 5, 34, 90, 250)
    return X, y

In [38]:
# Định nghĩa lớp ConvLSTMCell tùy chỉnh
class ConvLSTMCell(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super(ConvLSTMCell, self).__init__()
        self.out_channels = out_channels
        self.conv = nn.Conv2d(
            in_channels + out_channels, 4 * out_channels, kernel_size,
            padding=padding, bias=True
        )

    def forward(self, x, h_prev, c_prev):
        # x: (batch, in_channels, height, width)
        # h_prev, c_prev: (batch, out_channels, height, width)
        combined = torch.cat([x, h_prev], dim=1)  # (batch, in_channels + out_channels, height, width)
        conv_out = self.conv(combined)  # (batch, 4 * out_channels, height, width)
        i, f, o, g = torch.chunk(conv_out, 4, dim=1)  # Mỗi cái: (batch, out_channels, height, width)
        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        o = torch.sigmoid(o)
        g = torch.tanh(g)
        c_next = f * c_prev + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

In [39]:
# Định nghĩa lớp ConvLSTM2d
class ConvLSTM2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super(ConvLSTM2d, self).__init__()
        self.cell = ConvLSTMCell(in_channels, out_channels, kernel_size, padding)

    def forward(self, x):
        # Kiểm tra số chiều của đầu vào
        if len(x.size()) == 5:
            # x: (batch, seq_len, channels, height, width)
            batch, seq_len, channels, height, width = x.size()
            is_sequence = True
        elif len(x.size()) == 4:
            # x: (batch, channels, height, width)
            batch, channels, height, width = x.size()
            seq_len = 1
            x = x.unsqueeze(1)  # Thêm chiều seq_len: (batch, 1, channels, height, width)
            is_sequence = False
        else:
            raise ValueError(f"Expected 4 or 5 dimensions, got {len(x.size())}")

        out_channels = self.cell.out_channels
        h = torch.zeros(batch, out_channels, height, width, device=x.device)
        c = torch.zeros(batch, out_channels, height, width, device=x.device)
        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :, :, :]  # (batch, channels, height, width)
            h, c = self.cell(x_t, h, c)
            outputs.append(h)
        output = outputs[-1] if is_sequence else h
        return output, (h, c)  # Trả về output cuối và trạng thái

In [40]:
# Định nghĩa mô hình ConvLSTM
class ConvLSTMModel(nn.Module):
    def __init__(self):
        super(ConvLSTMModel, self).__init__()
        self.convlstm1 = ConvLSTM2d(in_channels=34, out_channels=64, kernel_size=(5, 5), padding=(2, 2))
        self.bn1 = nn.BatchNorm2d(64)
        self.convlstm2 = ConvLSTM2d(in_channels=64, out_channels=32, kernel_size=(5, 5), padding=(2, 2))
        self.bn2 = nn.BatchNorm2d(32)
        self.dropout = nn.Dropout(0.2)
        self.conv = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=(3, 3), padding=(1, 1))
        self.relu = nn.ReLU()

    def forward(self, x):
        # x: (batch, seq_len=5, channels=34, height=90, width=250)
        print(f"Input shape: {x.shape}")
        x, _ = self.convlstm1(x)  # (batch, 64, height, width)
        x = self.bn1(x)
        print(f"After convlstm1: {x.shape}")
        x, _ = self.convlstm2(x)  # (batch, 32, height, width)
        x = self.bn2(x)
        x = self.dropout(x)
        x = self.conv(x)  # (batch, 1, height, width)
        x = self.relu(x)
        return x.squeeze(1)  # (batch, height, width)

In [41]:
# Hàm huấn luyện mô hình
def train_model(model, train_loader, val_loader, epochs=20, patience=5):
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    best_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            output = model(X_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * X_batch.size(0)
        train_loss /= len(train_loader.dataset)

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                output = model(X_batch)
                loss = criterion(output, y_batch)
                val_loss += loss.item() * X_batch.size(0)
            val_loss /= len(val_loader.dataset)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        if val_loss < best_loss:
            best_loss = val_loss
            best_model_state = model.state_dict()
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping")
                break

    model.load_state_dict(best_model_state)
    return model

In [42]:
# Hàm tính chỉ số đánh giá
def evaluate_model(y_true, y_pred, threshold=0.0):
    y_true = y_true.reshape(-1)
    y_pred = y_pred.reshape(-1)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    corr = np.corrcoef(y_true, y_pred)[0, 1] if np.std(y_true) > 0 and np.std(y_pred) > 0 else 0

    y_true_bin = (y_true > threshold).astype(int)
    y_pred_bin = (y_pred > threshold).astype(int)
    hits = np.sum((y_true_bin == 1) & (y_pred_bin == 1))
    misses = np.sum((y_true_bin == 1) & (y_pred_bin == 0))
    false_alarms = np.sum((y_true_bin == 0) & (y_pred_bin == 1))
    true_negatives = np.sum((y_true_bin == 0) & (y_pred_bin == 0))
    total = hits + misses + false_alarms + true_negatives

    accuracy = (hits + true_negatives) / total if total > 0 else 0
    csi = hits / (hits + misses + false_alarms) if (hits + misses + false_alarms) > 0 else 0
    far = false_alarms / (hits + false_alarms) if (hits + false_alarms) > 0 else 0
    hss = (2 * (hits * true_negatives - misses * false_alarms)) / \
          ((hits + misses) * (misses + true_negatives) + (hits + false_alarms) * (false_alarms + true_negatives)) \
          if ((hits + misses) * (misses + true_negatives) + (hits + false_alarms) * (false_alarms + true_negatives)) > 0 else 0
    ets = ((hits - ((hits + misses) * (hits + false_alarms) / total)) / \
           (hits + misses + false_alarms - ((hits + misses) * (hits + false_alarms) / total))) \
          if (hits + misses + false_alarms - ((hits + misses) * (hits + false_alarms) / total)) > 0 else 0

    return {'rmse': rmse, 'corr': corr, 'accuracy': accuracy, 'csi': csi, 'far': far, 'hss': hss, 'ets': ets}

In [43]:
# Hàm vẽ scatter plot
def plot_scatter(y_true, y_pred, output_path):
    plt.figure(figsize=(8, 6))
    plt.scatter(y_true.flatten(), y_pred.flatten(), alpha=0.5)
    plt.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'r--')
    plt.xlabel('Ground Truth (mm/h)')
    plt.ylabel('Predicted (mm/h)')
    plt.title('Scatter Plot: Predicted vs Ground Truth')
    plt.savefig(output_path)
    plt.close()

In [44]:
# Hàm hiển thị bản đồ
def plot_rainfall_map(y_true, y_pred, output_path):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), subplot_kw={'projection': ccrs.PlateCarree()})
    ax1.set_title('Ground Truth')
    ax2.set_title('Prediction')
    for ax, data in [(ax1, y_true), (ax2, y_pred)]:
        ax.coastlines()
        ax.add_feature(cfeature.BORDERS)
        im = ax.imshow(data, cmap='Blues', origin='upper', transform=ccrs.PlateCarree())
        plt.colorbar(im, ax=ax, label='Rainfall (mm/h)')
    plt.savefig(output_path)
    plt.close()

In [45]:
# Hàm lưu GeoTIFF
def save_geotiff(data, output_path, reference_file):
    ds = gdal.Open(reference_file)
    driver = gdal.GetDriverByName('GTiff')
    out_ds = driver.Create(output_path, WIDTH, HEIGHT, 1, gdal.GDT_Float32)
    out_ds.SetGeoTransform(ds.GetGeoTransform())
    out_ds.SetProjection(ds.GetProjection())
    out_band = out_ds.GetRasterBand(1)
    out_band.WriteArray(data)
    out_band.FlushCache()
    out_ds = None
    ds = None

In [46]:
# Bắt đầu chương trình
print("Collecting Himawari files...")
hima_files = {}
for band in HIMA_BANDS:
    band_path = os.path.join(HIMA_PATH, band)
    if not os.path.exists(band_path):
        print(f"Directory not found: {band_path}")
        continue
    band_files = collect_files(band_path, expected_subdirs=HIMA_BANDS, data_type="Hima")
    for dt, paths in band_files.items():
        if dt not in hima_files:
            hima_files[dt] = {}
        hima_files[dt][band] = paths[band]

Collecting Himawari files...
Found 1438 files in /kaggle/input/btl-ai/DATA_SV/Hima/B04B
Found 1361 files in /kaggle/input/btl-ai/DATA_SV/Hima/B05B
Found 1158 files in /kaggle/input/btl-ai/DATA_SV/Hima/B06B
Found 2777 files in /kaggle/input/btl-ai/DATA_SV/Hima/B09B
Found 2777 files in /kaggle/input/btl-ai/DATA_SV/Hima/B10B
Found 2777 files in /kaggle/input/btl-ai/DATA_SV/Hima/B11B
Found 2777 files in /kaggle/input/btl-ai/DATA_SV/Hima/B12B
Found 2776 files in /kaggle/input/btl-ai/DATA_SV/Hima/B14B
Found 2776 files in /kaggle/input/btl-ai/DATA_SV/Hima/B16B
Found 2776 files in /kaggle/input/btl-ai/DATA_SV/Hima/I2B
Found 2673 files in /kaggle/input/btl-ai/DATA_SV/Hima/I4B
Found 2776 files in /kaggle/input/btl-ai/DATA_SV/Hima/IRB
Found 1448 files in /kaggle/input/btl-ai/DATA_SV/Hima/VSB
Found 2774 files in /kaggle/input/btl-ai/DATA_SV/Hima/WVB


In [47]:
print("Collecting ERA5 files...")
era5_files = {}
for param in ERA5_PARAMS:
    param_path = os.path.join(ERA5_PATH, param)
    if not os.path.exists(param_path):
        print(f"Directory not found: {param_path}")
        continue
    param_files = collect_files(param_path, expected_subdirs=ERA5_PARAMS, data_type="ERA5")
    for dt, paths in param_files.items():
        if dt not in era5_files:
            era5_files[dt] = {}
        era5_files[dt][param] = paths[param]

Collecting ERA5 files...
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/CAPE
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/CIN
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/EWSS
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/IE
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/ISOR
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/KX
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/PEV
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/R250
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/R500
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/R850
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/SLHF
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/SLOR
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/SSHF
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/TCLW
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/TCW
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/TCWV
Found 2928 files in /kaggle/input/btl-

In [48]:
print("Collecting Precipitation files...")
precip_files = collect_files(PRECIP_PATH, data_type="Radar")

Collecting Precipitation files...
Found 2487 files in /kaggle/input/btl-ai/DATA_SV/Precipitation/Radar


In [49]:
# Đồng bộ thời gian
common_datetimes = set(hima_files.keys()) & set(era5_files.keys()) & set(precip_files.keys())
common_datetimes = sorted(list(common_datetimes))
print(f"Số thời điểm đồng bộ: {len(common_datetimes)}")

Số thời điểm đồng bộ: 2337


In [50]:
# Tạo chuỗi thời gian
print("Creating time sequences...")
X, y = create_time_sequences(hima_files, era5_files, precip_files, common_datetimes)
print(f"X shape: {X.shape}, y shape: {y.shape}")

Creating time sequences...
X shape: (326, 5, 34, 90, 250), y shape: (326, 90, 250)


In [51]:
# Chia dữ liệu
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)
print(f"Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")

Train: (228, 5, 34, 90, 250), Val: (49, 5, 34, 90, 250), Test: (49, 5, 34, 90, 250)


In [52]:
# Chuyển sang tensor
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)
X_val = torch.tensor(X_val, dtype=torch.float32)
y_val = torch.tensor(y_val, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32)

# Tạo DataLoader
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2)
test_loader = DataLoader(test_dataset, batch_size=2)

In [53]:
# Khởi tạo và huấn luyện mô hình
model = ConvLSTMModel().to(device)
model = train_model(model, train_loader, val_loader)

Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])

In [54]:
# Đánh giá trên tập kiểm thử
model.eval()
y_pred = []
with torch.no_grad():
    for X_batch, _ in test_loader:
        X_batch = X_batch.to(device)
        output = model(X_batch)
        y_pred.append(output.cpu().numpy())
y_pred = np.concatenate(y_pred, axis=0)

# Đánh giá
metrics = evaluate_model(y_test.numpy(), y_pred)
print("Evaluation Metrics:", metrics)

Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])
Input shape: torch.Size([2, 5, 34, 90, 250])
After convlstm1: torch.Size([2, 64, 90, 250])

In [55]:
# Vẽ scatter plot
plot_scatter(y_test.numpy(), y_pred, os.path.join(OUTPUT_PATH, 'scatter_plot.png'))

# Vẽ bản đồ cho mẫu đầu tiên
plot_rainfall_map(y_test[0].numpy(), y_pred[0], os.path.join(OUTPUT_PATH, 'rainfall_map.png'))

# Lưu bản đồ dự đoán dưới dạng GeoTIFF
save_geotiff(y_pred[0], os.path.join(OUTPUT_PATH, 'predicted_rainfall.tif'),
             precip_files[common_datetimes[-1]])

