In [46]:
!pip install GDAL



In [47]:
import os
import numpy as np
from osgeo import gdal
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Dataset, WeightedRandomSampler
from scipy.interpolate import griddata
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import glob

In [48]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [49]:
# Đặt seed để tái lập kết quả
torch.manual_seed(42)
np.random.seed(42)

# Định nghĩa các hằng số và đường dẫn
BASE_PATH = "/kaggle/input/btl-ai/DATA_SV"
HIMA_PATH = os.path.join(BASE_PATH, "Hima")
ERA5_PATH = os.path.join(BASE_PATH, "ERA5")
PRECIP_PATH = os.path.join(BASE_PATH, "Precipitation/Radar")
OUTPUT_PATH = "/kaggle/working/output/"
os.makedirs(OUTPUT_PATH, exist_ok=True)

In [50]:
HIMA_BANDS = ['B04B', 'B05B', 'B06B', 'B09B', 'B10B', 'B11B', 'B12B', 'B14B', 'B16B', 'I2B', 'I4B', 'IRB', 'VSB', 'WVB']  # 14 band
ERA5_PARAMS = ['CAPE', 'CIN', 'EWSS', 'IE', 'ISOR', 'KX', 'PEV', 'R250', 'R500', 'R850', 'SLHF', 'SLOR', 'SSHF', 'TCLW', 'TCW', 'TCWV', 'U250', 'U850', 'V250', 'V850']  # 20 tham số
SELECTED_FEATURES = ['B04B', 'B10B', 'B11B', 'B16B', 'IRB', 'CAPE', 'R850', 'TCWV', 'U850', 'I2B', 'TCLW', 'TCW' ]  # Giả định các đặc trưng có tương quan cao
HEIGHT, WIDTH = 90, 250
HEIGHT, WIDTH = 90, 250
in_channel = len(SELECTED_FEATURES)

In [51]:
# Hàm đọc file GeoTIFF
def read_geotiff(file_path):
    try:
        ds = gdal.Open(file_path)
        band = ds.GetRasterBand(1)
        data = band.ReadAsArray()
        ds = None
        if data.shape != (HEIGHT, WIDTH):
            print(f"Invalid shape {data.shape} for file {file_path}, expected ({HEIGHT}, {WIDTH})")
            return None
        return data
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return None

In [52]:
# Hàm phân tích thời gian từ tên file
def parse_datetime_from_filename(filename, data_type):
    try:
        if data_type == "Hima":
            parts = filename.split('_')
            if len(parts) < 2:
                return None
            time_part = parts[1].split('_TB.tif')[0]
            time_part = time_part.replace('.Z', '')
            dt = datetime.strptime(time_part, '%Y%m%d%H%M')
        elif data_type == "ERA5":
            parts = filename.split('_')
            if len(parts) < 2:
                return None
            time_part = parts[1].replace('.tif', '')
            dt = datetime.strptime(time_part, '%Y%m%d%H%M%S')
        elif data_type == "Radar":
            time_part = filename.split('_')[1].replace('.tif', '')
            dt = datetime.strptime(time_part, '%Y%m%d%H%M%S')
        else:
            return None
        return dt.replace(minute=0, second=0, microsecond=0)
    except Exception as e:
        global error_count
        if error_count < 5:
            print(f"Error parsing datetime from {filename} (type {data_type}): {e}")
            error_count += 1
        return None

In [53]:
error_count = 0

# Hàm thu thập file
def collect_files(base_path, expected_subdirs=None, data_type=None):
    files_dict = {}
    file_count = 0
    for root, dirs, files in os.walk(base_path):
        for file in files:
            if file.endswith('.tif'):
                file_path = os.path.join(root, file)
                dt = parse_datetime_from_filename(file, data_type)
                if dt is None:
                    continue
                file_count += 1
                if expected_subdirs:
                    subdir = os.path.basename(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(file_path)))))
                    if dt not in files_dict:
                        files_dict[dt] = {}
                    files_dict[dt][subdir] = file_path
                else:
                    files_dict[dt] = file_path
    print(f"Found {file_count} files in {base_path}")
    return files_dict

In [54]:
def preprocess_data(data, data_type):
    if data is None:
        return None
    # Xử lý giá trị không hợp lệ bằng nội suy
    if np.any(np.isinf(data) | np.isnan(data) | (data == -9999)):
        mask = np.isinf(data) | np.isnan(data) | (data == -9999)
        x, y = np.indices(data.shape)
        valid_points = np.column_stack((x[~mask], y[~mask]))
        valid_values = data[~mask]
        invalid_points = np.column_stack((x[mask], y[mask]))
        if len(valid_values) > 0:
            interpolated_values = griddata(valid_points, valid_values, invalid_points, method='nearest')
            data[mask] = interpolated_values
        else:
            data[mask] = 0
    if data_type == "Radar":
        data = np.maximum(data, 0)  # Đảm bảo không có giá trị âm
    else:
        # Min-max scaling cho Himawari và ERA5
        data_min, data_max = np.min(data), np.max(data)
        if data_max > data_min:
            data = (data - data_min) / (data_max - data_min)
        else:
            data = np.zeros_like(data)
    return data

In [55]:
# Hàm tạo chuỗi thời gian t-4 đến t
def create_time_sequences(hima_files, era5_files, precip_files, common_datetimes):
    X, y = [], []
    for i in range(4, len(common_datetimes)):
        dt = common_datetimes[i]
        # Kiểm tra tính liên tục của 5 khung
        valid_sequence = True
        for j in range(1, 5):
            if common_datetimes[i-j] != dt - timedelta(hours=j):
                valid_sequence = False
                break
        if not valid_sequence:
            continue

        # Tạo chuỗi 5 khung
        sequence = []
        for j in range(5):
            dt_j = common_datetimes[i-j]
            # Đọc Himawari (chỉ các band được chọn)
            hima_data = []
            for band in [b for b in HIMA_BANDS if b in SELECTED_FEATURES]:
                file_path = hima_files.get(dt_j, {}).get(band)
                if not file_path:
                    valid_sequence = False
                    break
                data = read_geotiff(file_path)
                data = preprocess_data(data, "Hima")
                if data is None:
                    valid_sequence = False
                    break
                hima_data.append(data)
            if not valid_sequence:
                break

            # Đọc ERA5 (chỉ các tham số được chọn)
            era5_data = []
            for param in [p for p in ERA5_PARAMS if p in SELECTED_FEATURES]:
                file_path = era5_files.get(dt_j, {}).get(param)
                if not file_path:
                    valid_sequence = False
                    break
                data = read_geotiff(file_path)
                data = preprocess_data(data, "ERA5")
                if data is None:
                    valid_sequence = False
                    break
                era5_data.append(data)
            if not valid_sequence:
                break

            # Kết hợp Himawari và ERA5
            combined = np.concatenate([np.stack(hima_data, axis=-1), np.stack(era5_data, axis=-1)], axis=-1)  # (90, 250, 10)
            sequence.append(combined)
        if not valid_sequence:
            continue

        # Đọc radar (ground truth)
        radar_file = precip_files.get(dt)
        if not radar_file:
            continue
        radar_data = read_geotiff(radar_file)
        radar_data = preprocess_data(radar_data, "Radar")
        if radar_data is None:
            continue

        sequence = np.stack(sequence, axis=0)  # (5, 90, 250, 10)
        X.append(sequence)
        y.append(radar_data)  # (90, 250)

    X = np.array(X)
    y = np.array(y)
    # Chuyển đổi định dạng để kênh nằm trước chiều không gian
    X = X.transpose(0, 1, 4, 2, 3)  # (samples, 5, 10, 90, 250)
    return X, y

In [56]:
# Định nghĩa lớp ConvLSTMCell tùy chỉnh
class ConvLSTMCell(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super(ConvLSTMCell, self).__init__()
        self.out_channels = out_channels
        self.conv = nn.Conv2d(
            in_channels + out_channels, 4 * out_channels, kernel_size,
            padding=padding, bias=True
        )

    def forward(self, x, h_prev, c_prev):
        combined = torch.cat([x, h_prev], dim=1)
        conv_out = self.conv(combined)
        i, f, o, g = torch.chunk(conv_out, 4, dim=1)
        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        o = torch.sigmoid(o)
        g = torch.tanh(g)
        c_next = f * c_prev + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

In [57]:
# Định nghĩa lớp ConvLSTM2d
class ConvLSTM2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super(ConvLSTM2d, self).__init__()
        self.cell = ConvLSTMCell(in_channels, out_channels, kernel_size, padding)

    def forward(self, x):
        if len(x.size()) == 5:
            batch, seq_len, channels, height, width = x.size()
            is_sequence = True
        elif len(x.size()) == 4:
            batch, channels, height, width = x.size()
            seq_len = 1
            x = x.unsqueeze(1)
            is_sequence = False
        else:
            raise ValueError(f"Expected 4 or 5 dimensions, got {len(x.size())}")

        out_channels = self.cell.out_channels
        h = torch.zeros(batch, out_channels, height, width, device=x.device)
        c = torch.zeros(batch, out_channels, height, width, device=x.device)
        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :, :, :]
            h, c = self.cell(x_t, h, c)
            outputs.append(h)
        output = outputs[-1] if is_sequence else h
        return output, (h, c)

In [58]:
# Hàm huấn luyện (giữ weighted loss để xử lý dữ liệu không cân bằng)
def train_model(model, train_loader, val_loader, epochs=30, patience=7):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = nn.MSELoss(reduction='none')
    optimizer = optim.Adam(model.parameters(), lr=0.0001)  # Giảm lr để ổn định
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)
    best_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            output = model(X_batch)
            weights = torch.where(y_batch > 0, torch.tensor(10.0, device=device), torch.tensor(1.0, device=device))
            loss = (criterion(output, y_batch) * weights).mean()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * X_batch.size(0)
        train_loss /= len(train_loader.dataset)

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                output = model(X_batch)
                weights = torch.where(y_batch > 0, torch.tensor(10.0, device=device), torch.tensor(1.0, device=device))
                loss = (criterion(output, y_batch) * weights).mean()
                val_loss += loss.item() * X_batch.size(0)
            val_loss /= len(val_loader.dataset)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, LR: {optimizer.param_groups[0]['lr']}")
        scheduler.step(val_loss)

        if val_loss < best_loss:
            best_loss = val_loss
            best_model_state = model.state_dict()
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping")
                break

    model.load_state_dict(best_model_state)
    return model

In [59]:
# import torch.optim as optim

# # Định nghĩa lớp SpatialAttention
# class SpatialAttention(nn.Module):
#     def __init__(self, in_channels):
#         super(SpatialAttention, self).__init__()
#         self.conv1 = nn.Conv2d(2, in_channels // 8, kernel_size=1)  # Đầu vào là 2 kênh từ avg và max
#         self.conv2 = nn.Conv2d(in_channels // 8, in_channels, kernel_size=1)
#         self.sigmoid = nn.Sigmoid()

#     def forward(self, x):
#         avg_out = torch.mean(x, dim=1, keepdim=True)  # Trung bình theo kênh
#         max_out, _ = torch.max(x, dim=1, keepdim=True)  # Tối đa theo kênh
#         out = torch.cat([avg_out, max_out], dim=1)  # Ghép thành 2 kênh
#         out = self.conv1(out)
#         out = self.conv2(out)
#         return self.sigmoid(out)

In [60]:
# Định nghĩa mô hình ConvLSTM không có SpatialAttention
class ConvLSTMModel(nn.Module):
    def __init__(self, in_channels=in_channel, hidden_channels=64, kernel_size=5):
        super(ConvLSTMModel, self).__init__()
        self.convlstm1 = ConvLSTM2d(in_channels, hidden_channels, kernel_size, padding=(kernel_size//2))
        self.bn1 = nn.BatchNorm2d(hidden_channels)
        self.convlstm2 = ConvLSTM2d(hidden_channels, 32, kernel_size, padding=(kernel_size//2))
        self.bn2 = nn.BatchNorm2d(32)
        self.dropout = nn.Dropout(0.2)
        self.conv = nn.Conv2d(32, 1, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x, _ = self.convlstm1(x)
        x = self.bn1(x)
        x, _ = self.convlstm2(x)
        x = self.bn2(x)
        x = self.dropout(x)
        x = self.conv(x)
        x = self.relu(x)
        return x.squeeze(1)

In [61]:
def evaluate_model(y_true, y_pred, threshold=1.0):
    rmse = np.sqrt(mean_squared_error(y_true.flatten(), y_pred.flatten()))
    corr = np.corrcoef(y_true.flatten(), y_pred.flatten())[0, 1] if np.std(y_true) > 0 and np.std(y_pred) > 0 else 0

    y_true_bin = (y_true > threshold).astype(int)
    y_pred_bin = (y_pred > threshold).astype(int)
    hits = np.sum((y_true_bin == 1) & (y_pred_bin == 1))
    misses = np.sum((y_true_bin == 1) & (y_pred_bin == 0))
    false_alarms = np.sum((y_true_bin == 0) & (y_pred_bin == 1))
    true_negatives = np.sum((y_true_bin == 0) & (y_pred_bin == 0))
    total = hits + misses + false_alarms + true_negatives

    accuracy = (hits + true_negatives) / total if total > 0 else 0
    csi = hits / (hits + misses + false_alarms) if (hits + misses + false_alarms) > 0 else 0
    far = false_alarms / (hits + false_alarms) if (hits + false_alarms) > 0 else 0
    hss = (2 * (hits * true_negatives - misses * false_alarms)) / \
          ((hits + misses) * (misses + true_negatives) + (hits + false_alarms) * (false_alarms + true_negatives)) \
          if ((hits + misses) * (misses + true_negatives) + (hits + false_alarms) * (false_alarms + true_negatives)) > 0 else 0
    ets = ((hits - ((hits + misses) * (hits + false_alarms) / total)) / \
           (hits + misses + false_alarms - ((hits + misses) * (hits + false_alarms) / total))) \
          if (hits + misses + false_alarms - ((hits + misses) * (hits + false_alarms) / total)) > 0 else 0

    return {'rmse': rmse, 'corr': corr, 'accuracy': accuracy, 'csi': csi, 'far': far, 'hss': hss, 'ets': ets}


In [62]:
# Hàm vẽ scatter plot
def plot_scatter(y_true, y_pred, output_path):
    plt.figure(figsize=(8, 6))
    plt.scatter(y_true.flatten(), y_pred.flatten(), alpha=0.5)
    plt.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'r--')
    plt.xlabel('Ground Truth (mm/h)')
    plt.ylabel('Predicted (mm/h)')
    plt.title('Scatter Plot: Predicted vs Ground Truth')
    plt.savefig(output_path)
    plt.close()

In [63]:
# Hàm hiển thị bản đồ
def plot_rainfall_map(y_true, y_pred, output_path):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), subplot_kw={'projection': ccrs.PlateCarree()})
    ax1.set_title('Ground Truth')
    ax2.set_title('Prediction')
    for ax, data in [(ax1, y_true), (ax2, y_pred)]:
        ax.coastlines()
        ax.add_feature(cfeature.BORDERS)
        im = ax.imshow(data, cmap='Blues', origin='upper', transform=ccrs.PlateCarree())
        plt.colorbar(im, ax=ax, label='Rainfall (mm/h)')
    plt.savefig(output_path)
    plt.close()

In [64]:
# Hàm lưu GeoTIFF
def save_geotiff(data, output_path, reference_file):
    ds = gdal.Open(reference_file)
    driver = gdal.GetDriverByName('GTiff')
    out_ds = driver.Create(output_path, WIDTH, HEIGHT, 1, gdal.GDT_Float32)
    out_ds.SetGeoTransform(ds.GetGeoTransform())
    out_ds.SetProjection(ds.GetProjection())
    out_band = out_ds.GetRasterBand(1)
    out_band.WriteArray(data)
    out_band.FlushCache()
    out_ds = None
    ds = None

In [65]:
# Bắt đầu chương trình
print("Collecting Himawari files...")
hima_files = {}
for band in HIMA_BANDS:
    band_path = os.path.join(HIMA_PATH, band)
    if not os.path.exists(band_path):
        print(f"Directory not found: {band_path}")
        continue
    band_files = collect_files(band_path, expected_subdirs=HIMA_BANDS, data_type="Hima")
    for dt, paths in band_files.items():
        if dt not in hima_files:
            hima_files[dt] = {}
        hima_files[dt][band] = paths[band]

Collecting Himawari files...
Found 1438 files in /kaggle/input/btl-ai/DATA_SV/Hima/B04B
Found 1361 files in /kaggle/input/btl-ai/DATA_SV/Hima/B05B
Found 1158 files in /kaggle/input/btl-ai/DATA_SV/Hima/B06B
Found 2777 files in /kaggle/input/btl-ai/DATA_SV/Hima/B09B
Found 2777 files in /kaggle/input/btl-ai/DATA_SV/Hima/B10B
Found 2777 files in /kaggle/input/btl-ai/DATA_SV/Hima/B11B
Found 2777 files in /kaggle/input/btl-ai/DATA_SV/Hima/B12B
Found 2776 files in /kaggle/input/btl-ai/DATA_SV/Hima/B14B
Found 2776 files in /kaggle/input/btl-ai/DATA_SV/Hima/B16B
Found 2776 files in /kaggle/input/btl-ai/DATA_SV/Hima/I2B
Found 2673 files in /kaggle/input/btl-ai/DATA_SV/Hima/I4B
Found 2776 files in /kaggle/input/btl-ai/DATA_SV/Hima/IRB
Found 1448 files in /kaggle/input/btl-ai/DATA_SV/Hima/VSB
Found 2774 files in /kaggle/input/btl-ai/DATA_SV/Hima/WVB


In [66]:
print("Collecting ERA5 files...")
era5_files = {}
for param in ERA5_PARAMS:
    param_path = os.path.join(ERA5_PATH, param)
    if not os.path.exists(param_path):
        print(f"Directory not found: {param_path}")
        continue
    param_files = collect_files(param_path, expected_subdirs=ERA5_PARAMS, data_type="ERA5")
    for dt, paths in param_files.items():
        if dt not in era5_files:
            era5_files[dt] = {}
        era5_files[dt][param] = paths[param]

Collecting ERA5 files...
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/CAPE
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/CIN
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/EWSS
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/IE
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/ISOR
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/KX
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/PEV
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/R250
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/R500
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/R850
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/SLHF
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/SLOR
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/SSHF
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/TCLW
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/TCW
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/TCWV
Found 2928 files in /kaggle/input/btl-

In [67]:
print("Collecting Precipitation files...")
precip_files = collect_files(PRECIP_PATH, data_type="Radar")

Collecting Precipitation files...
Found 2487 files in /kaggle/input/btl-ai/DATA_SV/Precipitation/Radar


In [68]:
# Đồng bộ thời gian
common_datetimes = set(hima_files.keys()) & set(era5_files.keys()) & set(precip_files.keys())
common_datetimes = sorted(list(common_datetimes))
print(f"Số thời điểm đồng bộ: {len(common_datetimes)}")

Số thời điểm đồng bộ: 2337


In [69]:
from scipy.interpolate import griddata
# Tạo chuỗi thời gian
print("Creating time sequences...")
X, y = create_time_sequences(hima_files, era5_files, precip_files, common_datetimes)
print(f"X shape: {X.shape}, y shape: {y.shape}")

Creating time sequences...
X shape: (720, 5, 12, 90, 250), y shape: (720, 90, 250)


### Data augumentation

In [70]:
# Chia dữ liệu thành train, val, test
X_train_full, X_temp, y_train_full, y_temp = train_test_split(X, y, test_size=0.3, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)
print(f"Train: {X_train_full.shape}, Val: {X_val.shape}, Test: {X_test.shape}")

Train: (504, 5, 12, 90, 250), Val: (108, 5, 12, 90, 250), Test: (108, 5, 12, 90, 250)


In [71]:
# Chuyển sang tensor
X_train_full = torch.tensor(X_train_full, dtype=torch.float32)
y_train_full = torch.tensor(y_train_full, dtype=torch.float32)
X_val = torch.tensor(X_val, dtype=torch.float32)
y_val = torch.tensor(y_val, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32)

# Tính tỷ lệ pixel mưa trong mỗi mẫu
rain_indices = []
no_rain_indices = []
threshold_ratio = 0.05  # Ngưỡng: ít nhất 5% pixel có mưa
total_pixels = 90 * 250  # Số pixel trong mỗi mẫu

for i in range(len(X_train_full)):
    rain_pixels = torch.sum(y_train_full[i] > 0).item()
    rain_ratio = rain_pixels / total_pixels
    if rain_ratio >= threshold_ratio:
        rain_indices.append(i)
    else:
        no_rain_indices.append(i)

print(f"Number of samples with rain in train set: {len(rain_indices)}")
print(f"Number of samples without rain in train set: {len(no_rain_indices)}")


# Tính trọng số cho WeightedRandomSampler
weights = torch.zeros(len(X_train_full))
oversample_factor = 5  # Giảm từ 10 xuống 5 để tiết kiệm RAM
for i in rain_indices:
    weights[i] = oversample_factor * (len(no_rain_indices) / len(rain_indices))  # Tăng trọng số cho mẫu có mưa
for i in no_rain_indices:
    weights[i] = 1.0  # Trọng số mặc định cho mẫu không mưa

# Tạo sampler
train_sampler = WeightedRandomSampler(weights=weights, num_samples=len(X_train_full), replacement=True)

# Tạo TensorDataset
train_dataset = TensorDataset(X_train_full, y_train_full)
val_dataset = TensorDataset(X_val, y_val)
test_dataset = TensorDataset(X_test, y_test)

# Tạo DataLoader với sampler cho tập train
train_loader = DataLoader(train_dataset, batch_size=8, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=8)
test_loader = DataLoader(test_dataset, batch_size=8)

Number of samples with rain in train set: 188
Number of samples without rain in train set: 316


In [72]:
# Khởi tạo và huấn luyện mô hình
model = ConvLSTMModel().to(device)
model = train_model(model, train_loader, val_loader)

Epoch 1/30, Train Loss: 17.7588, Val Loss: 43.1117, LR: 0.0001
Epoch 2/30, Train Loss: 17.7262, Val Loss: 43.4238, LR: 0.0001
Epoch 3/30, Train Loss: 22.5581, Val Loss: 43.4488, LR: 0.0001
Epoch 4/30, Train Loss: 23.5547, Val Loss: 42.0627, LR: 0.0001
Epoch 5/30, Train Loss: 22.8258, Val Loss: 41.8399, LR: 0.0001
Epoch 6/30, Train Loss: 17.5843, Val Loss: 41.8889, LR: 0.0001
Epoch 7/30, Train Loss: 18.2838, Val Loss: 41.8687, LR: 0.0001
Epoch 8/30, Train Loss: 20.2648, Val Loss: 43.8251, LR: 0.0001
Epoch 9/30, Train Loss: 17.4918, Val Loss: 41.6055, LR: 0.0001
Epoch 10/30, Train Loss: 15.5867, Val Loss: 42.4865, LR: 0.0001
Epoch 11/30, Train Loss: 15.1016, Val Loss: 41.3256, LR: 0.0001
Epoch 12/30, Train Loss: 15.6329, Val Loss: 41.8009, LR: 0.0001
Epoch 13/30, Train Loss: 15.0371, Val Loss: 41.9465, LR: 0.0001
Epoch 14/30, Train Loss: 13.7249, Val Loss: 41.1826, LR: 0.0001
Epoch 15/30, Train Loss: 15.5693, Val Loss: 41.9359, LR: 0.0001
Epoch 16/30, Train Loss: 18.1580, Val Loss: 40.85

In [73]:
# Đánh giá trên tập kiểm thử
model.eval()
y_pred = []
with torch.no_grad():
    for X_batch, _ in test_loader:
        X_batch = X_batch.to(device)
        output = model(X_batch)
        y_pred.append(output.cpu().numpy())
y_pred = np.concatenate(y_pred, axis=0)

# Đánh giá
metrics = evaluate_model(y_test.numpy(), y_pred, threshold=0.5)
print("Evaluation Metrics:", metrics)

Evaluation Metrics: {'rmse': 0.9670093, 'corr': 0.4309980857215112, 'accuracy': 0.8712012345679012, 'csi': 0.15011092676118296, 'far': 0.8403013704961982, 'hss': 0.220457084379325, 'ets': 0.12388410666816273}


In [74]:
# Vẽ scatter plot
plot_scatter(y_test.numpy(), y_pred, os.path.join(OUTPUT_PATH, 'scatter_plot.png'))

# Vẽ bản đồ cho mẫu đầu tiên
plot_rainfall_map(y_test[0].numpy(), y_pred[0], os.path.join(OUTPUT_PATH, 'rainfall_map.png'))

# Lưu bản đồ dự đoán dưới dạng GeoTIFF
save_geotiff(y_pred[0], os.path.join(OUTPUT_PATH, 'predicted_rainfall.tif'),
             precip_files[common_datetimes[-1]])