In [1]:
!pip install GDAL



In [2]:
import os
import torch.nn.functional as F 
import numpy as np
from osgeo import gdal
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import glob
from scipy.interpolate import griddata

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [4]:
# Đặt seed để tái lập kết quả
torch.manual_seed(42)
np.random.seed(42)

# Định nghĩa các hằng số và đường dẫn
BASE_PATH = "/kaggle/input/btl-ai/DATA_SV"
HIMA_PATH = os.path.join(BASE_PATH, "Hima")
ERA5_PATH = os.path.join(BASE_PATH, "ERA5")
PRECIP_PATH = os.path.join(BASE_PATH, "Precipitation/Radar")
OUTPUT_PATH = "/kaggle/working/output/"
os.makedirs(OUTPUT_PATH, exist_ok=True)

In [5]:
HIMA_BANDS = ['B04B', 'B05B', 'B06B', 'B09B', 'B10B', 'B11B', 'B12B', 'B14B', 'B16B', 'I2B', 'I4B', 'IRB', 'VSB', 'WVB']  # 14 band
ERA5_PARAMS = ['CAPE', 'CIN', 'EWSS', 'IE', 'ISOR', 'KX', 'PEV', 'R250', 'R500', 'R850', 'SLHF', 'SLOR', 'SSHF', 'TCLW', 'TCW', 'TCWV', 'U250', 'U850', 'V250', 'V850']  # 20 tham số
HEIGHT, WIDTH = 90, 250
IN_CHANNEL = len(HIMA_BANDS) + len(ERA5_PARAMS)

DOWNSCALE_FACTOR = 2
NEW_HEIGHT, NEW_WIDTH = HEIGHT // DOWNSCALE_FACTOR, WIDTH // DOWNSCALE_FACTOR

In [6]:
# Improved function to read GeoTIFF files
def read_geotiff(file_path, data_type="Radar"):
    try:
        ds = gdal.Open(file_path)
        if ds is None:
            print(f"Failed to open {file_path}")
            return None
            
        band = ds.GetRasterBand(1)
        nodata = band.GetNoDataValue()
        data = band.ReadAsArray().astype(np.float32)
        ds = None
        
        if data.shape != (HEIGHT, WIDTH):
            print(f"Invalid shape {data.shape} for file {file_path}, expected ({HEIGHT}, {WIDTH})")
            return None
        
        # Handle missing values
        if nodata is not None:
            mask = data == nodata
            if np.any(mask):
                if data_type == "Radar":
                    # For radar data, missing values should be 0 (no rain)
                    data[mask] = 0
                else:  # Hima and ERA5
                    # For satellite and meteorological data, use mean value imputation
                    valid_data = data[~mask]
                    if len(valid_data) > 0:
                        mean_value = np.mean(valid_data)
                        data[mask] = mean_value
                    else:
                        data[mask] = 0
        
        # Additional outlier handling
        if data_type != "Radar":  # For non-radar data
            # Remove extreme outliers (values beyond 5 standard deviations)
            valid_mask = ~(np.isinf(data) | np.isnan(data) | (data == nodata))
            if np.sum(valid_mask) > 0:
                valid_data = data[valid_mask]
                mean_val = np.mean(valid_data)
                std_val = np.std(valid_data)
                lower_bound = mean_val - 5 * std_val
                upper_bound = mean_val + 5 * std_val
                data = np.clip(data, lower_bound, upper_bound)
        
        return data
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return None

In [7]:
# Hàm phân tích thời gian từ tên file
def parse_datetime_from_filename(filename, data_type):
    try:
        if data_type == "Hima":
            parts = filename.split('_')
            if len(parts) < 2:
                return None
            time_part = parts[1].split('_TB.tif')[0]
            time_part = time_part.replace('.Z', '')
            dt = datetime.strptime(time_part, '%Y%m%d%H%M')
        elif data_type == "ERA5":
            parts = filename.split('_')
            if len(parts) < 2:
                return None
            time_part = parts[1].replace('.tif', '')
            dt = datetime.strptime(time_part, '%Y%m%d%H%M%S')
        elif data_type == "Radar":
            time_part = filename.split('_')[1].replace('.tif', '')
            dt = datetime.strptime(time_part, '%Y%m%d%H%M%S')
        else:
            return None
        return dt.replace(minute=0, second=0, microsecond=0)
    except Exception as e:
        global error_count
        if error_count < 5:
            print(f"Error parsing datetime from {filename} (type {data_type}): {e}")
            error_count += 1
        return None

In [8]:
error_count = 0

# Hàm thu thập file
def collect_files(base_path, expected_subdirs=None, data_type=None):
    files_dict = {}
    file_count = 0
    for root, dirs, files in os.walk(base_path):
        for file in files:
            if file.endswith('.tif'):
                file_path = os.path.join(root, file)
                dt = parse_datetime_from_filename(file, data_type)
                if dt is None:
                    continue
                file_count += 1
                if expected_subdirs:
                    subdir = os.path.basename(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(file_path)))))
                    if dt not in files_dict:
                        files_dict[dt] = {}
                    files_dict[dt][subdir] = file_path
                else:
                    files_dict[dt] = file_path
    print(f"Found {file_count} files in {base_path}")
    return files_dict

In [9]:
# Improved data preprocessing function with adaptive normalization
def preprocess_data(data, data_type):
    if data is None:
        return None
    
    data = np.where(np.isinf(data) | np.isnan(data) | (data < -9000), 0, data)
    
    if data_type == "Radar":
        data = np.maximum(data, 0)
        # Chỉ áp dụng log transform, không chia theo mức độ mưa
        data_transformed = np.log1p(data)
        # Chuẩn hóa về [0, 1] nhưng giữ độ tương phản
        data_max = np.percentile(data_transformed, 99) if np.max(data_transformed) > 0 else 1.0
        data_min = np.min(data_transformed)
        range_val = data_max - data_min
        if range_val > 0:
            data_transformed = (data_transformed - data_min) / range_val
        else:
            data_transformed = np.zeros_like(data_transformed)
        return data_transformed
    else:
        valid_data = data[~np.isnan(data) & ~np.isinf(data) & (data != -9999)]
        if len(valid_data) == 0:
            return np.zeros_like(data)
        min_val = np.percentile(valid_data, 1)
        max_val = np.percentile(valid_data, 99)
        data = np.clip(data, min_val, max_val)
        range_val = max_val - min_val
        if range_val > 0:
            data = (data - min_val) / range_val
        else:
            data = np.zeros_like(data)
        return data

In [10]:
# FocalLoss nhị phân
class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.eps = 1e-7
        
    def forward(self, y_pred, y_true):
        y_pred = torch.clamp(y_pred, self.eps, 1 - self.eps)
        bce = - (y_true * torch.log(y_pred) + (1 - y_true) * torch.log(1 - y_pred))
        pt = torch.exp(-bce)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce
        return focal_loss.mean()

In [11]:
import os
import numpy as np
from datetime import timedelta

# Sequence creation with adjusted balancing
def create_time_sequences(hima_files, era5_files, precip_files, common_datetimes):
    X, y = [], []
    rain_count = 0
    no_rain_count = 0
    
    for i in range(4, len(common_datetimes)):
        dt = common_datetimes[i]
        valid_sequence = True
        for j in range(1, 5):
            if common_datetimes[i-j] != dt - timedelta(hours=j):
                valid_sequence = False
                break
        if not valid_sequence:
            continue

        sequence = []
        sequence_valid = True
        
        for j in range(5):
            dt_j = common_datetimes[i-4+j]
            hima_data = []
            for band in HIMA_BANDS:
                file_path = hima_files.get(dt_j, {}).get(band)
                if not file_path:
                    sequence_valid = False
                    break
                data = read_geotiff(file_path, data_type="Hima")
                data = preprocess_data(data, "Hima")
                if data is None:
                    sequence_valid = False
                    break
                hima_data.append(data)
            if not sequence_valid:
                break
            hima_data = np.stack(hima_data, axis=-1)

            era5_data = []
            for param in ERA5_PARAMS:
                file_path = era5_files.get(dt_j, {}).get(param)
                if not file_path:
                    sequence_valid = False
                    break
                data = read_geotiff(file_path, data_type="ERA5")
                data = preprocess_data(data, "ERA5")
                if data is None:
                    sequence_valid = False
                    break
                era5_data.append(data)
            if not sequence_valid:
                break
            era5_data = np.stack(era5_data, axis=-1)

            combined = np.concatenate([hima_data, era5_data], axis=-1)
            sequence.append(combined)
        
        if not sequence_valid:
            continue

        radar_file = precip_files.get(dt)
        if not radar_file:
            continue
        radar_data = read_geotiff(radar_file, data_type="Radar")
        radar_data = preprocess_data(radar_data, "Radar")
        if radar_data is None:
            continue

        has_rain = np.any(radar_data > 0.1)
        if has_rain:
            rain_count += 1
        else:
            no_rain_count += 1
            if no_rain_count > rain_count * 2 and np.random.random() < 0.9:
                continue

        sequence_tensor = torch.tensor(np.array(sequence), dtype=torch.float32)
        radar_tensor = torch.tensor(radar_data, dtype=torch.float32)
        sequence_tensor = sequence_tensor.permute(0, 3, 1, 2)
        radar_tensor = radar_tensor.unsqueeze(0).unsqueeze(0)
        sequence_downsampled = F.avg_pool2d(sequence_tensor, kernel_size=DOWNSCALE_FACTOR, stride=DOWNSCALE_FACTOR)
        radar_downsampled = F.max_pool2d(radar_tensor, kernel_size=DOWNSCALE_FACTOR, stride=DOWNSCALE_FACTOR)
        sequence_downsampled = sequence_downsampled.permute(0, 2, 3, 1).numpy()
        radar_downsampled = radar_downsampled.squeeze().numpy()

        X.append(sequence_downsampled)
        y.append(radar_downsampled)

    print(f"Created {len(X)} sequences ({rain_count} with rain, {no_rain_count} without rain)")
    if len(X) == 0:
        raise ValueError("No valid sequences created. Check your data paths and filters.")
    
    X = np.array(X)
    y = np.array(y)
    X = X.transpose(0, 1, 4, 2, 3)
    return X, y

In [12]:
# Định nghĩa lớp ConvLSTMCell tùy chỉnh
class ConvLSTMCell(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super(ConvLSTMCell, self).__init__()
        self.out_channels = out_channels
        self.conv = nn.Conv2d(
            in_channels + out_channels, 4 * out_channels, kernel_size,
            padding=padding, bias=True
        )

    def forward(self, x, h_prev, c_prev):
        # x: (batch, in_channels, height, width)
        # h_prev, c_prev: (batch, out_channels, height, width)
        combined = torch.cat([x, h_prev], dim=1)  # (batch, in_channels + out_channels, height, width)
        conv_out = self.conv(combined)  # (batch, 4 * out_channels, height, width)
        i, f, o, g = torch.chunk(conv_out, 4, dim=1)  # Mỗi cái: (batch, out_channels, height, width)
        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        o = torch.sigmoid(o)
        g = torch.tanh(g)
        c_next = f * c_prev + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

In [13]:
# ConvLSTM2d
class ConvLSTM2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super(ConvLSTM2d, self).__init__()
        self.cell = ConvLSTMCell(in_channels, out_channels, kernel_size, padding)

    def forward(self, x):
        batch, seq_len, channels, height, width = x.size()
        h = torch.zeros(batch, self.cell.out_channels, height, width, device=x.device)
        c = torch.zeros(batch, self.cell.out_channels, height, width, device=x.device)
        for t in range(seq_len):
            h, c = self.cell(x[:, t], h, c)
        return h, (h, c)

In [14]:
# # Attention module for focusing on important features
# class SpatialAttention(nn.Module):
#     def __init__(self, in_channels):
#         super(SpatialAttention, self).__init__()
#         self.conv1 = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False)
        
#     def forward(self, x):
#         # Generate attention map
#         attn = self.conv1(x)
#         attn = torch.sigmoid(attn)
        
#         # Apply attention
#         return x * attn

In [15]:
# Training function
def train_model(model, train_loader, val_loader, epochs=100, patience=15):
    criterion = FocalLoss(gamma=2.0, alpha=0.25, beta=0.95)
    optimizer = optim.Adam(model.parameters(), lr=0.0005)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
    
    best_loss = float('inf')
    patience_counter = 0
    best_model_state = None
    train_losses = []
    val_losses = []
    
    print("Starting training...")
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            output = model(X_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item() * X_batch.size(0)
        train_loss /= len(train_loader.dataset)
        train_losses.append(train_loss)

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                output = model(X_batch)
                loss = criterion(output, y_batch)
                val_loss += loss.item() * X_batch.size(0)
            val_loss /= len(val_loader.dataset)
            val_losses.append(val_loss)

        scheduler.step(val_loss)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        if val_loss < best_loss:
            best_loss = val_loss
            best_model_state = model.state_dict()
            patience_counter = 0
            torch.save(model.state_dict(), os.path.join(OUTPUT_PATH, 'best_model.pth'))
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.savefig(os.path.join(OUTPUT_PATH, 'training_curve.png'))
    plt.close()
    
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model

In [16]:
# Mô hình đơn giản hóa
class ConvLSTMModel(nn.Module):
    def __init__(self, in_channels):
        super(ConvLSTMModel, self).__init__()
        self.convlstm1 = ConvLSTM2d(in_channels, 64, kernel_size=(5, 5), padding=(2, 2))
        self.convlstm2 = ConvLSTM2d(64, 32, kernel_size=(3, 3), padding=(1, 1))
        self.conv_out = nn.Conv2d(32, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x, _ = self.convlstm1(x)
        x, _ = self.convlstm2(x)
        x = self.conv_out(x)
        return self.sigmoid(x)

In [17]:
# Adjusted evaluation function with multiple thresholds
def evaluate_model(y_true, y_pred, thresholds=[0.1, 0.2, 0.3]):
    y_true_flat = y_true.reshape(-1)
    y_pred_flat = y_pred.reshape(-1)
    rmse = np.sqrt(mean_squared_error(y_true_flat, y_pred_flat))
    if np.std(y_true_flat) > 0 and np.std(y_pred_flat) > 0:
        corr = np.corrcoef(y_true_flat, y_pred_flat)[0, 1]
    else:
        corr = 0
    
    results = {
        'rmse': rmse,
        'corr': corr
    }
    
    for threshold in thresholds:
        y_true_bin = (y_true_flat > 0.1).astype(int)
        y_pred_bin = (y_pred_flat > threshold).astype(int)
        hits = np.sum((y_true_bin == 1) & (y_pred_bin == 1))
        misses = np.sum((y_true_bin == 1) & (y_pred_bin == 0))
        false_alarms = np.sum((y_true_bin == 0) & (y_pred_bin == 1))
        true_negatives = np.sum((y_true_bin == 0) & (y_pred_bin == 0))
        total = hits + misses + false_alarms + true_negatives
        accuracy = (hits + true_negatives) / total if total > 0 else 0
        pod = hits / (hits + misses) if (hits + misses) > 0 else 0
        far = false_alarms / (hits + false_alarms) if (hits + false_alarms) > 0 else 0
        csi = hits / (hits + misses + false_alarms) if (hits + misses + false_alarms) > 0 else 0
        denominator = ((hits + misses) * (misses + true_negatives) + 
                       (hits + false_alarms) * (false_alarms + true_negatives))
        hss = (2 * (hits * true_negatives - misses * false_alarms)) / denominator if denominator > 0 else 0
        random_hits = (hits + misses) * (hits + false_alarms) / total if total > 0 else 0
        ets_denom = (hits + misses + false_alarms - random_hits)
        ets = (hits - random_hits) / ets_denom if ets_denom > 0 else 0
        
        threshold_str = str(threshold).replace('.', '_')
        results[f'accuracy_{threshold_str}'] = accuracy
        results[f'pod_{threshold_str}'] = pod
        results[f'far_{threshold_str}'] = far
        results[f'csi_{threshold_str}'] = csi
        results[f'hss_{threshold_str}'] = hss
        results[f'ets_{threshold_str}'] = ets
    
    return results

In [18]:
# Plotting functions
def plot_scatter(y_true, y_pred, output_path):
    plt.figure(figsize=(10, 8))
    from scipy.stats import gaussian_kde
    xy = np.vstack([y_true.flatten(), y_pred.flatten()])
    z = gaussian_kde(xy)(xy)
    idx = z.argsort()
    x, y, z = y_true.flatten()[idx], y_pred.flatten()[idx], z[idx]
    plt.scatter(x, y, c=z, s=10, cmap='viridis', alpha=0.6)
    plt.colorbar(label='Density')
    plt.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'r--', label='Perfect Prediction')
    plt.xlabel('Ground Truth (Normalized Rainfall)')
    plt.ylabel('Predicted (Normalized Rainfall)')
    plt.title('Scatter Plot: Predicted vs Ground Truth')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(output_path)
    plt.close()

def plot_rainfall_map(y_true, y_pred, output_path):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), subplot_kw={'projection': ccrs.PlateCarree()})
    ax1.set_title('Ground Truth')
    ax2.set_title('Prediction')
    for ax, data in [(ax1, y_true), (ax2, y_pred)]:
        ax.coastlines()
        ax.add_feature(cfeature.BORDERS)
        im = ax.imshow(data, cmap='Blues', origin='upper', transform=ccrs.PlateCarree(), vmin=0, vmax=np.max([y_true.max(), y_pred.max()]))
        plt.colorbar(im, ax=ax, label='Rainfall (Normalized)')
    plt.savefig(output_path)
    plt.close()

def save_geotiff(data, output_path, reference_file):
    ds = gdal.Open(reference_file)
    driver = gdal.GetDriverByName('GTiff')
    out_ds = driver.Create(output_path, NEW_WIDTH, NEW_HEIGHT, 1, gdal.GDT_Float32)
    geo_transform = list(ds.GetGeoTransform())
    geo_transform[1] *= DOWNSCALE_FACTOR
    geo_transform[5] *= DOWNSCALE_FACTOR
    out_ds.SetGeoTransform(tuple(geo_transform))
    out_ds.SetProjection(ds.GetProjection())
    out_band = out_ds.GetRasterBand(1)
    out_band.WriteArray(data)
    out_band.FlushCache()
    out_ds = None
    ds = None

In [19]:
# Bắt đầu chương trình
print("Collecting Himawari files...")
hima_files = {}
for band in HIMA_BANDS:
    band_path = os.path.join(HIMA_PATH, band)
    if not os.path.exists(band_path):
        print(f"Directory not found: {band_path}")
        continue
    band_files = collect_files(band_path, expected_subdirs=HIMA_BANDS, data_type="Hima")
    for dt, paths in band_files.items():
        if dt not in hima_files:
            hima_files[dt] = {}
        hima_files[dt][band] = paths[band]

Collecting Himawari files...
Found 1438 files in /kaggle/input/btl-ai/DATA_SV/Hima/B04B
Found 1361 files in /kaggle/input/btl-ai/DATA_SV/Hima/B05B
Found 1158 files in /kaggle/input/btl-ai/DATA_SV/Hima/B06B
Found 2777 files in /kaggle/input/btl-ai/DATA_SV/Hima/B09B
Found 2777 files in /kaggle/input/btl-ai/DATA_SV/Hima/B10B
Found 2777 files in /kaggle/input/btl-ai/DATA_SV/Hima/B11B
Found 2777 files in /kaggle/input/btl-ai/DATA_SV/Hima/B12B
Found 2776 files in /kaggle/input/btl-ai/DATA_SV/Hima/B14B
Found 2776 files in /kaggle/input/btl-ai/DATA_SV/Hima/B16B
Found 2776 files in /kaggle/input/btl-ai/DATA_SV/Hima/I2B
Found 2673 files in /kaggle/input/btl-ai/DATA_SV/Hima/I4B
Found 2776 files in /kaggle/input/btl-ai/DATA_SV/Hima/IRB
Found 1448 files in /kaggle/input/btl-ai/DATA_SV/Hima/VSB
Found 2774 files in /kaggle/input/btl-ai/DATA_SV/Hima/WVB


In [20]:
print("Collecting ERA5 files...")
era5_files = {}
for param in ERA5_PARAMS:
    param_path = os.path.join(ERA5_PATH, param)
    if not os.path.exists(param_path):
        print(f"Directory not found: {param_path}")
        continue
    param_files = collect_files(param_path, expected_subdirs=ERA5_PARAMS, data_type="ERA5")
    for dt, paths in param_files.items():
        if dt not in era5_files:
            era5_files[dt] = {}
        era5_files[dt][param] = paths[param]

Collecting ERA5 files...
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/CAPE
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/CIN
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/EWSS
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/IE
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/ISOR
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/KX
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/PEV
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/R250
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/R500
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/R850
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/SLHF
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/SLOR
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/SSHF
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/TCLW
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/TCW
Found 2928 files in /kaggle/input/btl-ai/DATA_SV/ERA5/TCWV
Found 2928 files in /kaggle/input/btl-

In [21]:
print("Collecting Precipitation files...")
precip_files = collect_files(PRECIP_PATH, data_type="Radar")

Collecting Precipitation files...
Found 2487 files in /kaggle/input/btl-ai/DATA_SV/Precipitation/Radar


In [22]:
# Đồng bộ thời gian
common_datetimes = set(hima_files.keys()) & set(era5_files.keys()) & set(precip_files.keys())
common_datetimes = sorted(list(common_datetimes))
print(f"Số thời điểm đồng bộ: {len(common_datetimes)}")

Số thời điểm đồng bộ: 2337


In [23]:
from scipy.stats import pearsonr
from skimage.transform import resize

correlations = {}
# Chỉ chọn những thời điểm có đủ cả Hima, ERA5, Radar
common_datetimes = [dt for dt in common_datetimes if dt in precip_files]

def resize_data(data, new_height, new_width):
    return resize(data, (new_height, new_width), order=1, mode='reflect', anti_aliasing=True)

# Đọc và resize toàn bộ radar data trước
radar_data_list = []
for dt in common_datetimes:
    radar = read_geotiff(precip_files[dt])
    if radar is not None:
        radar_resized = resize_data(radar, NEW_HEIGHT, NEW_WIDTH)
        radar_data_list.append(radar_resized.flatten())
    else:
        # Nếu thiếu radar tại thời điểm nào thì bỏ luôn thời điểm đó
        common_datetimes.remove(dt)

radar_data_array = np.array(radar_data_list)  # shape = (num_samples, pixels)

print(f"Số mẫu sau khi đồng bộ và resize: {len(common_datetimes)}")

for band in HIMA_BANDS + ERA5_PARAMS:
    feature_data_list = []
    valid_datetimes = []
    for dt in common_datetimes:
        file_path = hima_files.get(dt, {}).get(band) or era5_files.get(dt, {}).get(band)
        if file_path:
            data = read_geotiff(file_path)
            if data is not None:
                resized_data = resize_data(data, NEW_HEIGHT, NEW_WIDTH)
                feature_data_list.append(resized_data.flatten())
                valid_datetimes.append(dt)
    
    # Chỉ giữ những radar sample ứng với những feature sample còn lại
    if len(feature_data_list) >= 5:  # Chỉ tính nếu có ít nhất 5 sample (để Pearson meaningful)
        feature_array = np.array(feature_data_list)
        # Lọc radar data ứng với valid_datetimes
        radar_matched = [radar_data_list[common_datetimes.index(dt)] for dt in valid_datetimes]
        radar_array = np.array(radar_matched)

        # Tính tương quan (flatten toàn bộ pixel)
        corr, _ = pearsonr(feature_array.flatten(), radar_array.flatten())
        correlations[band] = abs(corr)

# Chọn top 10 features có tương quan cao nhất
selected_features = sorted(correlations, key=correlations.get, reverse=True)[:10]

print("Top 10 features có tương quan cao nhất với lượng mưa:")
for feat in selected_features:
    print(f"{feat}: {correlations[feat]:.4f}")


Số mẫu sau khi đồng bộ và resize: 2337
Top 10 features có tương quan cao nhất với lượng mưa:
B11B: 0.1550
IRB: 0.1535
I2B: 0.1512
B14B: 0.1502
I4B: 0.1451
B16B: 0.1446
B12B: 0.1425
B10B: 0.1380
TCW: 0.1326
TCLW: 0.1290
