# In [None]:
import os
import torch
import logging
import numpy as np
import h5py
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
from tqdm import tqdm
import xarray as xr
from model.Triton_model import Triton


# ============================== Initialization Configuration ==============================
# Fixed seed so random draws made through NumPy or PyTorch are repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Timestamped, INFO-level logging for the whole script.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")



def load_single_model(model_path):
    """Build the Triton network and load inference weights from ``model_path``.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load`` whose
            content is a state dict (parameter name -> tensor).

    Returns:
        The ``Triton`` model, placed on the module-level ``device``, with the
        checkpoint weights loaded.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    # Fail fast: no point in allocating the network when there is nothing to load.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found.: {model_path}")

    model = Triton(
        shape_in=(10, 2, 128, 128),
        spatial_hidden_dim=256,
        output_channels=2,
        temporal_hidden_dim=512,
        num_spatial_layers=4,
        num_temporal_layers=8).to(device)

    checkpoint = torch.load(model_path, map_location=device)

    # Keys starting with 'module.' are presumably left over from a
    # DataParallel/DistributedDataParallel wrapper. Remove only that leading
    # prefix: a blanket str.replace would also mangle keys that merely contain
    # the substring (e.g. 'submodule.weight' -> 'subweight').
    prefix = 'module.'
    if any(key.startswith(prefix) for key in checkpoint):
        checkpoint = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint.items()
        }

    model.load_state_dict(checkpoint)
    logging.info(f"Model loaded successfully.: {os.path.basename(model_path)}")
    return model

# ============================== Dataloader ==============================
class OceanDataLoader:
    """Serve input/label windows of the ``ugos``/``vgos`` fields of a NetCDF file.

    The dataset is opened with xarray and must expose a ``time`` coordinate
    plus the variables ``ugos`` and ``vgos`` (read by ``load_single_case``).
    """

    def __init__(self, nc_path):
        """Open the dataset at ``nc_path`` and cache its time axis."""
        self.ds = xr.open_dataset(nc_path)
        # Cast to second resolution first so that NumPy converts the values to
        # ``datetime.datetime`` objects, which the date lookups below compare against.
        self.time_stamps = self.ds.time.values.astype('datetime64[s]').astype(datetime)

    def generate_target_dates(self, start_date, end_date, interval_days=1):
        """Generate a continuous initial date sequence.

        Args:
            start_date: First candidate date, formatted 'YYYY-MM-DD'.
            end_date: Last candidate date (inclusive), formatted 'YYYY-MM-DD'.
            interval_days: Spacing between candidates in days; must be positive.

        Returns:
            The candidate dates that exist on the dataset's time axis, as
            'YYYY-MM-DD' strings in chronological order.

        Raises:
            ValueError: If ``interval_days`` is not positive, or if no candidate
                date is present in the dataset.
        """
        # A zero or negative step would never move past ``end_date``.
        if interval_days <= 0:
            raise ValueError(f"interval_days must be positive, got {interval_days}")

        current_date = datetime.strptime(start_date, "%Y-%m-%d")
        last_date = datetime.strptime(end_date, "%Y-%m-%d")
        stride = timedelta(days=interval_days)

        # Build the lookup once instead of scanning the time axis per candidate.
        available = set(self.time_stamps)

        all_dates = []
        while current_date <= last_date:
            if current_date in available:
                all_dates.append(current_date.strftime("%Y-%m-%d"))
            current_date += stride

        if not all_dates:
            raise ValueError("No valid dates found. Please check the date range and data files.")
        return all_dates

    def load_single_case(self, target_date, pred_days):
        """Load a single initial condition and its ground-truth continuation.

        Args:
            target_date: The last date of the input sequence (format: 'YYYY-MM-DD').
            pred_days: The number of time steps to load as labels after it.

        Returns:
            ``(initial, label, initial_dates, label_dates)``. ``initial`` and
            ``label`` are float tensors on the module-level ``device`` laid out
            as (1, time, 2, ...) with ``ugos`` and ``vgos`` on the channel axis
            and NaNs replaced by 0. ``initial`` spans the 10 time steps ending
            at ``target_date``; ``label`` spans the following ``pred_days``
            steps. The two date arrays hold the matching timestamps.

        Raises:
            ValueError: If ``target_date`` is malformed or not on the time axis,
                if fewer than 9 earlier steps exist, or if the label window
                runs past the end of the data.
        """
        target_dt = datetime.strptime(target_date, "%Y-%m-%d")
        matches = np.where(self.time_stamps == target_dt)[0]
        if matches.size == 0:
            available_dates = [d.strftime("%Y-%m-%d") for d in self.time_stamps[-10:]]
            raise ValueError(f"Invalid date {target_date}, the last 10 available dates: {available_dates}")
        end_idx = int(matches[0])

        # The input window is the target step plus the 9 steps before it.
        start_idx = end_idx - 9
        if start_idx < 0:
            raise ValueError(f"At least 9 days of data are required, the earliest available date: {self.time_stamps[0].strftime('%Y-%m-%d')}")
        if end_idx + pred_days >= len(self.time_stamps):
            raise ValueError(f"Prediction exceeds data range, the data cutoff date is: {self.time_stamps[-1].strftime('%Y-%m-%d')}")

        label_start = end_idx + 1
        label_end = label_start + pred_days

        def load_window(start, end):
            """Stack ugos/vgos for time steps [start, end) into a (1, T, 2, ...) tensor."""
            channels = []
            for var_name in ('ugos', 'vgos'):
                data = self.ds[var_name].isel(time=slice(start, end)).values
                # Missing values are zero-filled before conversion.
                channels.append(torch.FloatTensor(np.nan_to_num(data, nan=0.0)))
            return torch.stack(channels, dim=1).unsqueeze(0).to(device)

        initial = load_window(start_idx, label_start)
        label = load_window(label_start, label_end)
        initial_dates = self.time_stamps[start_idx:label_start]
        label_dates = self.time_stamps[label_start:label_end]

        return initial, label, initial_dates, label_dates

# ============================== Inference engine ==============================
def predict_single(model, initial_input, pred_days):
    """Roll the model forward autoregressively and return ``pred_days`` frames.

    Each forward pass output is appended to the forecast and its last 10
    frames (along dim 1) become the next input.

    Args:
        model: Network called as ``model(x)``; it is switched to eval mode.
            The loop count assumes every call yields 10 frames along dim 1
            -- TODO confirm against the model definition.
        initial_input: Tensor fed to the first forward pass; it is not modified.
        pred_days: Number of frames to return; must be at least 1.

    Returns:
        Tensor holding the first ``pred_days`` frames of the concatenated
        outputs along dim 1, on the same device as ``initial_input``.

    Raises:
        ValueError: If ``pred_days`` is smaller than 1.
    """
    # Without this guard the loop would not run and torch.cat would fail on
    # an empty list with a far less helpful message.
    if pred_days < 1:
        raise ValueError(f"pred_days must be at least 1, got {pred_days}")

    model.eval()
    predictions = []
    current_input = initial_input.clone()

    # Ceiling division: number of 10-frame passes needed to cover pred_days.
    total_steps = (pred_days + 9) // 10

    # Mixed precision only applies to CUDA inputs; torch.autocast replaces the
    # deprecated torch.cuda.amp.autocast and stays silent on CPU-only hosts.
    on_cuda = initial_input.is_cuda
    amp_context = torch.autocast(device_type='cuda' if on_cuda else 'cpu', enabled=on_cuda)

    with torch.no_grad(), amp_context:
        for _ in tqdm(range(total_steps), desc="Prediction progress.", leave=False):
            output = model(current_input)
            # Park finished chunks on the CPU while the rollout continues.
            predictions.append(output.cpu())
            current_input = output[:, -10:]

    return torch.cat(predictions, dim=1)[:, :pred_days].to(initial_input.device)

# ============================== Batch processing. ==============================
def process_batch(model, data_loader, target_dates, pred_days, save_dir):
    """Run a forecast for every date in ``target_dates`` and save each to HDF5.

    A failure for one date is logged (with traceback) and does not stop the
    remaining dates.

    Args:
        model: Network handed to ``predict_single``.
        data_loader: Object providing ``load_single_case(date_str, pred_days)``.
        target_dates: Iterable of 'YYYY-MM-DD' strings, one per initial condition.
        pred_days: Number of steps to forecast for each date.
        save_dir: Output directory (created if missing); each case is written
            to ``forecast_<YYYYMMDD>.h5`` inside it.

    Returns:
        The number of dates that were processed and saved successfully.
    """
    os.makedirs(save_dir, exist_ok=True)
    success_count = 0

    for date_str in tqdm(target_dates, desc="Process initial conditions."):
        # Pre-bind so the cleanup below is valid even if loading fails early.
        initial = label = prediction = None
        try:
            initial, label, init_dates, label_dates = data_loader.load_single_case(date_str, pred_days)
            # Keep every second grid point along the last two axes.
            initial = initial[..., ::2, ::2]
            label = label[..., ::2, ::2]

            prediction = predict_single(model, initial, pred_days)

            save_path = os.path.join(save_dir, f"forecast_{date_str.replace('-','')}.h5")
            save_results(
                initial.cpu(),
                label.cpu(),
                prediction.cpu(),
                init_dates,
                label_dates,
                save_path
            )
            success_count += 1

        except Exception as e:
            # logging.exception keeps the traceback, which a bare error line loses.
            logging.exception(f"Process {date_str} failed: {str(e)}")
        finally:
            # Release tensors and cached GPU memory on failures too (e.g. after
            # an out-of-memory error), not only on the success path.
            initial = label = prediction = None
            torch.cuda.empty_cache()

    logging.info(f"Processing complete, successfully processed {success_count}/{len(target_dates)} initial conditions")
    return success_count

# ============================== Results saved ==============================
def save_results(initial, label, prediction, init_dates, label_dates, save_path):
    """Write one forecast case to a new HDF5 file at ``save_path``.

    Datasets written: 'initial', 'label' and 'prediction' (the tensors as
    NumPy arrays) plus 'initial_dates' and 'label_dates' ('YYYY-MM-DD' strings
    stored as UTF-8). File attributes record the last input date and the
    first and last label dates.

    Args:
        initial, label, prediction: Tensors exposing ``.numpy()``.
        init_dates, label_dates: Sequences of objects exposing ``strftime``.
        save_path: Destination file; an existing file is overwritten.
    """
    date_fmt = "%Y-%m-%d"
    tensor_items = (
        ('initial', initial),
        ('label', label),
        ('prediction', prediction),
    )
    date_items = (
        ('initial_dates', init_dates),
        ('label_dates', label_dates),
    )

    with h5py.File(save_path, 'w') as h5_file:
        for key, tensor in tensor_items:
            h5_file.create_dataset(key, data=tensor.numpy())

        for key, stamps in date_items:
            formatted = [stamp.strftime(date_fmt) for stamp in stamps]
            utf8_str = h5py.string_dtype(encoding='utf-8')
            h5_file.create_dataset(key, data=np.array(formatted, dtype=utf8_str))

        attrs = h5_file.attrs
        attrs['input_end_date'] = init_dates[-1].strftime(date_fmt)
        attrs['pred_start_date'] = label_dates[0].strftime(date_fmt)
        attrs['pred_end_date'] = label_dates[-1].strftime(date_fmt)

def visualize_enhanced(h5_path, step=0, save_fig=True):
    """Plot initial, true and predicted current speed for one forecast step.

    Args:
        h5_path: HDF5 file containing the datasets 'initial', 'label',
            'prediction', 'initial_dates' and 'label_dates'.
        step: Index of the forecast step to show. Values past the last
            available label step are clamped to it, for both the data and
            the date shown in the titles.
        save_fig: When true, save ``forecast_<input end date>_day<step+1>.png``
            in the current working directory and close the figure; otherwise
            show it interactively.
    """
    with h5py.File(h5_path, 'r') as f:
        initial = f['initial'][0]
        label = f['label'][0]
        prediction = f['prediction'][0]
        init_dates = [d.decode() for d in f['initial_dates'][:]]
        label_dates = [d.decode() for d in f['label_dates'][:]]

    # Clamp once so the plotted frame, its date label and the file name agree.
    step = min(step, len(label_dates) - 1)
    input_end_date = init_dates[-1]
    pred_date = label_dates[step]

    def to_speed(fields):
        """Speed magnitude per frame from channels 0 and 1 of a (time, 2, ...) array."""
        return np.sqrt(fields[:, 0]**2 + fields[:, 1]**2)

    label_speed_all = to_speed(label)
    pred_speed_all = to_speed(prediction)
    speed_initial = to_speed(initial)[-1]
    speed_label = label_speed_all[step]
    speed_pred = pred_speed_all[step]

    fig, axes = plt.subplots(1, 3, figsize=(24, 6))
    fig.suptitle(f"Comparison of Ocean Surface Current Speed\nInput End Date: {input_end_date} → Prediction Date: {pred_date}", 
                y=1.05, fontsize=14, fontweight='bold')

    # The colour scale is shared by all panels and is fixed over every step in
    # the file. It must come from speed magnitudes: the maximum of the raw
    # velocity components can be smaller than the plotted speeds and would clip.
    speed_max = max(
        np.nanmax(label_speed_all),
        np.nanmax(pred_speed_all),
        np.nanmax(speed_initial),
    )
    plot_kwargs = {
        'cmap': 'jet',
        'extent': [123.1, 154.9, 10.06, 41.94],  
        'origin': 'lower',
        'vmin': 0,
        'vmax': speed_max
    }

    im0 = axes[0].imshow(speed_initial, **plot_kwargs)
    axes[0].set_title(f"Initial Field Last Day\n{init_dates[-1]}", fontsize=12)
    axes[0].set_xlabel('Longitude', fontsize=10)
    axes[0].set_ylabel('Latitude', fontsize=10)

    im1 = axes[1].imshow(speed_label, **plot_kwargs)
    axes[1].set_title(f"True Values\n{pred_date}", fontsize=12)
    axes[1].set_xlabel('Longitude', fontsize=10)

    im2 = axes[2].imshow(speed_pred, **plot_kwargs)
    axes[2].set_title(f"Predicted Values\n{pred_date}", fontsize=12)
    axes[2].set_xlabel('Longitude', fontsize=10)

    cbar = fig.colorbar(im1, ax=axes, orientation='vertical', shrink=0.8, pad=0.03)
    cbar.set_label('Current Speed (m/s)', fontsize=10)

    plt.tight_layout()

    if save_fig:
        fig_name = f"forecast_{input_end_date}_day{step+1}.png"
        fig.savefig(fig_name, dpi=300, bbox_inches='tight')
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)
    else:
        plt.show()


# ============================== Main ==============================
if __name__ == "__main__":
    # Experiment tag shared by the checkpoint file name and the output directory.
    backbone = 'Kuro_Triton_exp1_128_20250322'
    config = {
        'model_path': f'/jizhicfs/easyluwu/ocean_project/NPJ_baselines/Exp_2_Kuroshio/checkpoints/{backbone}_best_model.pth',
        'data_path': '/jizhicfs/easyluwu/ocean_project/kuro/KURO.nc',
        'date_range': {  
            'start': '2021-01-01',
            'end': '2021-12-31',
            'interval': 5  
        },
        'pred_days': 120,
        'save_dir':f'./{backbone}_forecast_results'
    }

    try:
        model = load_single_model(config['model_path'])
        data_loader = OceanDataLoader(config['data_path'])
        
        target_dates = data_loader.generate_target_dates(
            start_date=config['date_range']['start'],
            end_date=config['date_range']['end'],
            interval_days=config['date_range']['interval']
        )
        logging.info(f"Generated {len(target_dates)} initial dates, example: {target_dates[:5]}...")
        
        process_batch(
            model, 
            data_loader,
            target_dates,
            config['pred_days'],
            config['save_dir']
        )
        
        # Plot the first, middle and last forecast step, derived from
        # pred_days so the indices stay valid when the horizon changes
        # (120 -> 0, 60, 119).
        sample_steps = sorted({0, config['pred_days'] // 2, config['pred_days'] - 1})
        # First and last initial condition, without repeating a single date.
        sample_dates = list(dict.fromkeys([target_dates[0], target_dates[-1]]))
        for date in sample_dates:
            h5_file = os.path.join(config['save_dir'], f"forecast_{date.replace('-','')}.h5")
            # A case that failed during batch processing leaves no file behind;
            # skip it instead of aborting the whole run.
            if not os.path.exists(h5_file):
                logging.warning(f"Skipping visualization, forecast file not found: {h5_file}")
                continue
            for step in sample_steps:
                visualize_enhanced(h5_file, step=step)
        
    except Exception as e:
        logging.error(f"Main process error: {str(e)}")
        raise