In [1]:
import os
import numpy as np
from datetime import datetime, timedelta
from osgeo import gdal
import pandas as pd
import csv
import gc
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import IterableDataset, DataLoader
from sklearn.metrics import mean_squared_error

# Define constants and paths
BASE_PATH = "/kaggle/input/btl-ai/DATA_SV"
ERA5_PATH = os.path.join(BASE_PATH, "ERA5")
AWS_PATH = os.path.join(BASE_PATH, "Precipitation/AWS")
OUTPUT_PATH = "/kaggle/working"
# Side effect at import/cell execution: make sure the output directory exists.
os.makedirs(OUTPUT_PATH, exist_ok=True)
# ERA5 parameter names; collect_files matches them against directory names and
# the CSV export turns each one into an "ERA5_<param>" column.
SELECTED_ERA5_PARAMS = ['TCW', 'U850', 'EWSS', 'V850', 'TCLW', 'U250', 'R850', 'R500', 'CAPE', 'KX', 'V250', 'R250']
# Raster shape (rows, columns) that read_geotiff accepts.
HEIGHT, WIDTH = 90, 250

# Reduced number of timepoints and model parameters
NUM_TIMEPOINTS = 5  # how many common datetimes are exported to the CSV
TIMESTEPS = 3  # number of consecutive rows in each LSTM input sequence
FEATURES = len(SELECTED_ERA5_PARAMS)  # LSTM input size (one feature per ERA5 parameter)
HIDDEN_SIZE = 32
NUM_LAYERS = 1
EPOCHS = 5
BATCH_SIZE = 16

# Global counter used to cap the number of datetime-parsing error messages printed
error_count = 0

# Read a GeoTIFF file and extract its georeferencing information
def read_geotiff(file_path):
    """Read band 1 of a GeoTIFF together with its grid origin and resolution.

    Args:
        file_path: Path of the GeoTIFF file to open with GDAL.

    Returns:
        A 5-tuple ``(data, lon_min, lat_max, lon_res, lat_res)``:

        * ``data`` is the band-1 array with shape ``(HEIGHT, WIDTH)``.
        * ``lon_min`` / ``lat_max`` are elements 0 and 3 of the geotransform.
        * ``lon_res`` / ``lat_res`` are elements 1 and 5 of the geotransform.

        When the file has no geotransform the four georeferencing values are
        ``None`` (the caller falls back to an index grid). When the file
        cannot be read or its shape is not ``(HEIGHT, WIDTH)`` all five
        values are ``None``.
    """
    failure = (None, None, None, None, None)
    try:
        ds = gdal.Open(file_path)
        if ds is None:
            # gdal.Open signals failure with None unless GDAL exceptions are
            # enabled; report it explicitly instead of via an AttributeError.
            print(f"Error reading {file_path}: GDAL could not open the file")
            return failure
        data = ds.GetRasterBand(1).ReadAsArray()
        geotransform = ds.GetGeoTransform()
        ds = None  # drop the reference so GDAL closes the dataset
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return failure

    if data is None:
        print(f"Error reading {file_path}: band 1 returned no data")
        return failure

    # Validate the shape first so that every successful return, including the
    # "no geotransform" one below, hands back a (HEIGHT, WIDTH) array.
    if data.shape != (HEIGHT, WIDTH):
        print(f"Invalid shape {data.shape} for file {file_path}, expected ({HEIGHT}, {WIDTH})")
        return failure

    if geotransform is None:
        print(f"No geotransform available for {file_path}, using default grid")
        return data, None, None, None, None

    lon_min = geotransform[0]
    lon_res = geotransform[1]
    lat_max = geotransform[3]
    lat_res = geotransform[5]
    return data, lon_min, lat_max, lon_res, lat_res

# Parse the timestamp encoded in a file name
def parse_datetime_from_filename(filename, data_type):
    """Extract the timestamp from a ``<prefix>_<YYYYmmddHHMMSS>.tif`` file name.

    Args:
        filename: Base name of the file.
        data_type: ``"ERA5"`` or ``"AWS"``; any other value yields ``None``.

    Returns:
        The parsed ``datetime`` with minutes, seconds and microseconds set to
        zero, or ``None`` when the name cannot be parsed. Parsing failures
        are printed only while the module-level ``error_count`` is below 5.
    """
    global error_count
    try:
        if data_type not in ("ERA5", "AWS"):
            return None
        pieces = filename.split('_')
        # ERA5 names without an underscore are skipped silently; an AWS name
        # without one falls through to the IndexError handler below.
        if data_type == "ERA5" and len(pieces) < 2:
            return None
        stamp = pieces[1].replace('.tif', '')
        parsed = datetime.strptime(stamp, '%Y%m%d%H%M%S')
        return parsed.replace(minute=0, second=0, microsecond=0)
    except Exception as exc:
        if error_count < 5:
            print(f"Error parsing datetime from {filename} (type {data_type}): {exc}")
            error_count += 1
        return None

# Collect .tif files and index them by timestamp
def collect_files(base_path, expected_subdirs=None, data_type=None):
    """Walk ``base_path`` and index every parseable ``.tif`` file by timestamp.

    Args:
        base_path: Directory that is walked recursively.
        expected_subdirs: Optional collection of directory names. When given,
            a file is kept only if the directory four levels above it has one
            of these names, and the result is nested per directory name.
        data_type: Passed to ``parse_datetime_from_filename``.

    Returns:
        Without ``expected_subdirs``: ``{datetime: file_path}``.
        With ``expected_subdirs``: ``{datetime: {subdir_name: file_path}}``.
        A later file with the same key overwrites an earlier one.
    """
    found = {}
    parsed_total = 0
    for folder, _subfolders, names in os.walk(base_path):
        for name in names:
            if not name.endswith('.tif'):
                continue
            timestamp = parse_datetime_from_filename(name, data_type)
            if timestamp is None:
                continue
            # Counted before the sub-directory filter, so the printed total
            # covers every file whose name parsed.
            parsed_total += 1
            full_path = os.path.join(folder, name)
            if not expected_subdirs:
                found[timestamp] = full_path
                continue
            # Climb four directory levels from the file to find the
            # directory whose name identifies the parameter.
            ancestor = full_path
            for _ in range(4):
                ancestor = os.path.dirname(ancestor)
            param_dir = os.path.basename(ancestor)
            if param_dir in expected_subdirs:
                found.setdefault(timestamp, {})[param_dir] = full_path
    print(f"Found {parsed_total} files in {base_path}")
    return found

# Replace invalid values with NaN
def preprocess_data(data):
    """Return ``data`` with infinities, NaNs and the -9999 sentinel set to NaN.

    Args:
        data: Array to clean, or ``None``.

    Returns:
        A new array of the same shape with the invalid entries replaced by
        ``np.nan``; ``None`` is passed through unchanged.
    """
    if data is None:
        return None
    not_finite = np.isinf(data) | np.isnan(data)
    is_sentinel = data == -9999
    return np.where(not_finite | is_sentinel, np.nan, data)

# Build flattened latitude / longitude grids
def create_lat_lon_grid(lon_min, lat_max, lon_res, lat_res, height, width):
    """Build row-major flattened latitude and longitude arrays for a grid.

    Args:
        lon_min: Longitude of the first column, or ``None``.
        lat_max: Latitude of the first row, or ``None``.
        lon_res: Longitude step between columns.
        lat_res: Latitude step between rows.
        height: Number of rows.
        width: Number of columns.

    Returns:
        ``(lat_flat, lon_flat)``, each of length ``height * width``. When
        ``lon_min`` or ``lat_max`` is ``None`` the row and column indices are
        used in place of coordinates.
    """
    has_origin = lon_min is not None and lat_max is not None
    if has_origin:
        row_lats = lat_max + np.arange(height) * lat_res
        col_lons = lon_min + np.arange(width) * lon_res
    else:
        row_lats = np.linspace(0, height - 1, height)
        col_lons = np.linspace(0, width - 1, width)
    # Row-major flattening: each latitude is repeated across a whole row,
    # while the longitude sequence restarts on every row.
    return np.repeat(row_lats, width), np.tile(col_lons, height)

# Build the CSV table pairing AWS values with the selected ERA5 parameters
def create_data_table():
    """Export one CSV row per grid cell and timestamp.

    The function collects the AWS and ERA5 GeoTIFF files, keeps the first
    ``NUM_TIMEPOINTS`` timestamps present in both sets and writes
    ``aws_era5_selected_data.csv`` into ``OUTPUT_PATH`` with the columns
    ``Lat, Lon, Time, AWS, ERA5_<param>...``. A timestamp is skipped entirely
    when its AWS raster or any selected ERA5 raster is missing, unreadable or
    not of shape ``(HEIGHT, WIDTH)``.

    Returns:
        None. The result is the CSV file plus progress messages on stdout.
    """
    print("Creating CSV file...")
    aws_files = collect_files(AWS_PATH, data_type="AWS")
    era5_files = collect_files(ERA5_PATH, expected_subdirs=SELECTED_ERA5_PARAMS, data_type="ERA5")

    common_datetimes = sorted(set(aws_files) & set(era5_files))
    print(f"Common datetimes: {common_datetimes}")

    common_datetimes = common_datetimes[:NUM_TIMEPOINTS]
    print(f"Processing first {NUM_TIMEPOINTS} common datetimes: {common_datetimes}")

    expected_shape = (HEIGHT, WIDTH)
    output_file = os.path.join(OUTPUT_PATH, 'aws_era5_selected_data.csv')
    header = ['Lat', 'Lon', 'Time', 'AWS'] + [f'ERA5_{param}' for param in SELECTED_ERA5_PARAMS]
    with open(output_file, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(header)

        for dt in common_datetimes:
            aws_file = aws_files.get(dt)
            if not aws_file:
                continue
            aws_data, lon_min, lat_max, lon_res, lat_res = read_geotiff(aws_file)
            aws_data = preprocess_data(aws_data)
            # Guard the shape here as well: the rows below pair cells by
            # position, so every raster must cover the same grid.
            if aws_data is None or aws_data.shape != expected_shape:
                continue

            lat_flat, lon_flat = create_lat_lon_grid(lon_min, lat_max, lon_res, lat_res, HEIGHT, WIDTH)

            # Flattened ERA5 rasters, in SELECTED_ERA5_PARAMS (= header) order.
            era5_columns = []
            for param in SELECTED_ERA5_PARAMS:
                era5_file = era5_files.get(dt, {}).get(param)
                if not era5_file:
                    break
                data, _, _, _, _ = read_geotiff(era5_file)
                data = preprocess_data(data)
                if data is None or data.shape != expected_shape:
                    break
                era5_columns.append(data.flatten())
            if len(era5_columns) != len(SELECTED_ERA5_PARAMS):
                # At least one parameter was unusable: drop the timestamp.
                continue

            # Flatten and format once per timestamp instead of once per cell.
            aws_flat = aws_data.flatten()
            time_str = dt.strftime('%Y-%m-%d %H:%M:%S')
            for lat, lon, aws_value, *era5_values in zip(lat_flat, lon_flat, aws_flat, *era5_columns):
                writer.writerow([lat, lon, time_str, aws_value, *era5_values])

            # Release the per-timestamp arrays before loading the next ones.
            del aws_data, aws_flat, lat_flat, lon_flat, era5_columns
            gc.collect()

    if os.path.exists(output_file):
        file_size = os.path.getsize(output_file)
        print(f"File size: {file_size} bytes")
    else:
        print("File was not created!")
    print(f"Files in {OUTPUT_PATH}:")
    print(os.listdir(OUTPUT_PATH))

# Create the CSV file (runs the whole AWS/ERA5 export when this cell executes)
create_data_table()

# Check that the export produced the file before it is loaded for training
DATA_PATH = "/kaggle/working/aws_era5_selected_data.csv"
if not os.path.exists(DATA_PATH):
    print(f"Error: File {DATA_PATH} does not exist!")
    raise FileNotFoundError(f"File {DATA_PATH} not found!")

# IterableDataset that streams per-grid-cell training windows
class TimeSeriesDataset(IterableDataset):
    """Stream ``(sequence, target)`` samples built from the exported CSV.

    Rows are grouped by ``(Lat, Lon)``. For every group, each run of
    ``timesteps`` hourly-consecutive rows that is immediately followed (one
    hour later) by another row yields one sample: the feature columns of the
    run as the input sequence and the ``AWS`` value of the following row as
    the target.

    Args:
        csv_file: Path of a CSV with ``Lat``, ``Lon``, ``Time``, ``AWS`` and
            the ``ERA5_<param>`` columns.
        timesteps: Number of rows in each input sequence.
        feature_params: Optional parameter names whose ``ERA5_<param>``
            columns are used as features. Defaults to
            ``SELECTED_ERA5_PARAMS``.
    """

    def __init__(self, csv_file, timesteps, feature_params=None):
        super().__init__()
        if feature_params is None:
            feature_params = SELECTED_ERA5_PARAMS
        self.csv_file = csv_file
        self.timesteps = timesteps
        self.feature_columns = [f'ERA5_{param}' for param in feature_params]
        df = pd.read_csv(csv_file)
        df['Time'] = pd.to_datetime(df['Time'])
        # Rows with any missing value are dropped, which can leave gaps in
        # time; __iter__ rejects windows that span such a gap.
        df = df.dropna()
        self.df = df.sort_values(['Lat', 'Lon', 'Time'])
        self.groups = self.df.groupby(['Lat', 'Lon'])

    def __iter__(self):
        """Yield ``(float32 [timesteps, n_features], float32 [1])`` tensors."""
        steps = self.timesteps
        # Hour offsets of the window rows relative to the target row, oldest
        # first: [steps, ..., 2, 1]. All of them are checked, including the
        # oldest row at i - steps, which is part of the input sequence.
        back_offsets = np.arange(steps, 0, -1) * np.timedelta64(1, 'h')
        for _, group in self.groups:
            group = group.sort_values('Time')
            times = group['Time'].values
            features = group[self.feature_columns].values
            target = group['AWS'].values

            for i in range(steps, len(group)):
                window_times = times[i - steps:i]
                if not np.array_equal(window_times, times[i] - back_offsets):
                    continue  # window or target is not hourly-contiguous

                sequence = torch.tensor(features[i - steps:i], dtype=torch.float32)
                target_value = torch.tensor([target[i]], dtype=torch.float32)
                yield sequence, target_value

# Create the DataLoader
# The dataset is an IterableDataset, so samples arrive in the order produced
# by its __iter__; shuffle stays False.
dataset = TimeSeriesDataset(DATA_PATH, TIMESTEPS)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

# LSTM regression model
class LSTMModel(nn.Module):
    """LSTM followed by a linear layer applied to the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of values predicted per sequence.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map a ``(batch, time, input_size)`` tensor to ``(batch, output_size)``."""
        # Start every sequence from an all-zero hidden and cell state created
        # on the same device as the input.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        initial_state = (
            torch.zeros(state_shape, device=x.device),
            torch.zeros(state_shape, device=x.device),
        )
        sequence_out, _ = self.lstm(x, initial_state)
        last_step = sequence_out[:, -1, :]
        return self.fc(last_step)

# Initialise the model on the training device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LSTMModel(input_size=FEATURES, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, output_size=1)
# Move the model first so the optimizer is built from the parameters that
# are actually trained on `device`.
model = model.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
print(f"Training on device: {device}")

# Train the model
for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    batch_count = 0  # number of samples seen this epoch (not batches)
    for batch_X, batch_y in dataloader:
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)
        optimizer.zero_grad()
        output = model(batch_X)
        # Both tensors are (batch, 1); drop only that trailing dimension.
        loss = criterion(output.squeeze(-1), batch_y.squeeze(-1))
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean, so weight it by the batch size to
        # obtain a per-sample average at the end of the epoch.
        train_loss += loss.item() * batch_X.size(0)
        batch_count += batch_X.size(0)
    if batch_count == 0:
        # Without this check the division below fails with an opaque
        # ZeroDivisionError when the dataset yields no valid sequence.
        raise RuntimeError(
            "The dataset produced no training samples; check the exported CSV "
            "and the TIMESTEPS / NUM_TIMEPOINTS settings."
        )
    train_loss /= batch_count
    print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {train_loss:.6f}")

# Save the model weights
output_model_path = os.path.join(OUTPUT_PATH, "lstm_aws_era5.pth")
torch.save(model.state_dict(), output_model_path)
print(f"Model saved to {output_model_path}")

# Show the files in the output directory
print(f"Files in {OUTPUT_PATH}:")
print(os.listdir(OUTPUT_PATH))

# Create the download link
# `display` is imported explicitly so the cell does not depend on the name
# being injected by the notebook front end.
from IPython.display import FileLink, display
print("Download link for model:")
# FileLink is given a path relative to the current working directory, as in
# the original cell, but derived from where the file was actually saved.
display(FileLink(os.path.relpath(output_model_path)))

Creating CSV file...
Found 2807 files in /kaggle/input/btl-ai/DATA_SV/Precipitation/AWS
Found 58560 files in /kaggle/input/btl-ai/DATA_SV/ERA5
Common datetimes: [datetime.datetime(2019, 4, 1, 0, 0), datetime.datetime(2019, 4, 1, 1, 0), datetime.datetime(2019, 4, 1, 2, 0), datetime.datetime(2019, 4, 1, 3, 0), datetime.datetime(2019, 4, 1, 4, 0), datetime.datetime(2019, 4, 1, 5, 0), datetime.datetime(2019, 4, 1, 6, 0), datetime.datetime(2019, 4, 1, 7, 0), datetime.datetime(2019, 4, 1, 8, 0), datetime.datetime(2019, 4, 1, 9, 0), datetime.datetime(2019, 4, 1, 10, 0), datetime.datetime(2019, 4, 1, 11, 0), datetime.datetime(2019, 4, 1, 12, 0), datetime.datetime(2019, 4, 1, 13, 0), datetime.datetime(2019, 4, 1, 14, 0), datetime.datetime(2019, 4, 1, 15, 0), datetime.datetime(2019, 4, 1, 16, 0), datetime.datetime(2019, 4, 1, 17, 0), datetime.datetime(2019, 4, 1, 18, 0), datetime.datetime(2019, 4, 1, 19, 0), datetime.datetime(2019, 4, 1, 20, 0), datetime.datetime(2019, 4, 1, 21, 0), datetime.dat