# In [None]:
import os
import glob
import numpy as np
import netCDF4 as nc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import matplotlib.pyplot as plt

# Run configuration
data_dir, output_dir = "", ""  # input root and destination folder (fill in before running)
seq_len, pred_len = 6, 1  # number of input steps, number of output steps
batch_size = 1  # tune to the GPU (1-2 for a 3090)
epochs = 10
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Dataset
class RadarDataset(Dataset):
    """Sliding-window dataset over per-folder series of radar NetCDF scans.

    Every sub-directory of ``data_dir`` is treated as one independent series:
    its ``*.nc`` files are sorted by path and cut into overlapping windows of
    ``seq_len + pred_len`` consecutive files.  A window never spans two
    folders, and folders with too few files contribute no window.

    Each item is a pair ``(inputs, targets)`` of float32 tensors built from
    the ``'DBZ'`` variable of the files in one window.  The first axis of the
    variable is dropped, the frames are stacked along a new leading axis and
    a trailing singleton axis is appended, i.e. the shapes are
    ``(seq_len, *frame, 1)`` and ``(pred_len, *frame, 1)``.

    Args:
        data_dir: Directory whose sub-directories hold the ``.nc`` files.
        seq_len: Number of leading files of a window used as input.
        pred_len: Number of trailing files of a window used as target.
        dbz_min: Reflectivity value mapped to 0 by the normalisation.
        dbz_max: Reflectivity value mapped to 1 by the normalisation.

    Raises:
        ValueError: If ``seq_len`` or ``pred_len`` is smaller than 1, or if
            ``dbz_max`` does not exceed ``dbz_min``.
    """

    def __init__(self, data_dir, seq_len=6, pred_len=1, *, dbz_min=-29.0, dbz_max=60.0):
        if seq_len < 1 or pred_len < 1:
            raise ValueError("seq_len and pred_len must both be >= 1")
        if dbz_max <= dbz_min:
            raise ValueError("dbz_max must be greater than dbz_min")
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.dbz_min = float(dbz_min)
        self.dbz_max = float(dbz_max)
        self.folders = sorted(
            entry for entry in glob.glob(os.path.join(data_dir, "*")) if os.path.isdir(entry)
        )
        window = seq_len + pred_len
        self.sequences = []
        for folder in self.folders:
            files = sorted(glob.glob(os.path.join(folder, "*.nc")))
            # The range is empty when the folder holds fewer than `window` files.
            self.sequences.extend(
                files[start:start + window] for start in range(len(files) - window + 1)
            )

    def __len__(self):
        """Return the number of windows found across all folders."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Load window ``idx`` from disk and return ``(inputs, targets)``."""
        window = self.sequences[idx]
        split = self.seq_len
        inputs = np.stack([self._load_frame(path) for path in window[:split]], axis=0)
        targets = np.stack(
            [self._load_frame(path) for path in window[split:split + self.pred_len]], axis=0
        )
        # np.stack returns fresh float32 arrays, so from_numpy avoids one more copy.
        return torch.from_numpy(inputs).unsqueeze(-1), torch.from_numpy(targets).unsqueeze(-1)

    def _load_frame(self, path):
        """Read the 'DBZ' variable of one file and return it normalised."""
        # The context manager closes the file even if the read raises.
        with nc.Dataset(path) as ds:
            raw = ds.variables['DBZ'][:]  # materialised before the file is closed
        return self._normalize(raw)

    def _normalize(self, raw):
        """Drop the first axis of ``raw`` and scale ``[dbz_min, dbz_max]`` to ``[0, 1]``.

        netCDF4 returns masked arrays and ``np.stack`` does not preserve the
        mask, so missing cells are filled explicitly instead of leaking the
        file's raw fill value into the tensors.
        """
        # NOTE(review): missing cells are assumed to mean "no echo" and are set
        # to dbz_min (0 after scaling) -- confirm this suits the data.
        filled = np.asarray(np.ma.filled(raw, self.dbz_min), dtype=np.float32)
        # Original comments describe the variable as (1, 18, 500, 500); index 0
        # removes that leading axis (presumably time -- TODO confirm).
        frame = filled[0]
        scaled = (frame - self.dbz_min) / (self.dbz_max - self.dbz_min)
        return scaled.astype(np.float32, copy=False)

# ConvLSTM model
class _ConvLSTMLayer(nn.Module):
    """One convolutional LSTM layer unrolled over a batch-first sequence.

    A single convolution over the concatenated input and hidden state
    produces the input, forget, candidate and output gates (no peephole
    connections).  Padding of ``kernel_size // 2`` keeps the spatial size.
    """

    def __init__(self, in_channels, hidden_channels, kernel_size):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.gates = nn.Conv2d(
            in_channels + hidden_channels,
            4 * hidden_channels,
            kernel_size,
            padding=kernel_size // 2,
        )

    def forward(self, x):
        """Map ``(batch, seq, C, H, W)`` to ``(batch, seq, hidden, H, W)``.

        Returns the per-step hidden states and the final ``(h, c)`` pair.
        """
        batch, steps, _, height, width = x.shape
        # Hidden and cell state start at zero on the input's device and dtype.
        h = x.new_zeros(batch, self.hidden_channels, height, width)
        c = x.new_zeros(batch, self.hidden_channels, height, width)
        outputs = []
        for t in range(steps):
            stacked = torch.cat((x[:, t], h), dim=1)
            i, f, g, o = torch.chunk(self.gates(stacked), 4, dim=1)
            c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
            h = torch.sigmoid(o) * torch.tanh(c)
            outputs.append(h)
        return torch.stack(outputs, dim=1), (h, c)


class ConvLSTM(nn.Module):
    """Stack of ConvLSTM layers followed by a 3x3 convolutional head.

    ``torch.nn`` has no built-in ConvLSTM layer, so the recurrent layers are
    implemented by ``_ConvLSTMLayer``.  The head maps the last hidden state
    of the top layer to a single output channel.

    Args:
        in_channels: Channels of each input frame.
        hidden_channels: Channels of the hidden state of every layer.
        num_layers: Number of stacked recurrent layers.
        kernel_size: Odd side length of the recurrent convolution kernel.

    Raises:
        ValueError: If ``kernel_size`` is even (the spatial size could not be
            preserved) or ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels=1, hidden_channels=32, num_layers=2, kernel_size=3):
        super(ConvLSTM, self).__init__()
        if kernel_size % 2 == 0:
            raise ValueError("kernel_size must be odd to preserve the spatial size")
        if num_layers < 1:
            raise ValueError("num_layers must be >= 1")
        self.convlstm = nn.ModuleList([
            _ConvLSTMLayer(
                in_channels if i == 0 else hidden_channels,
                hidden_channels,
                kernel_size,
            ) for i in range(num_layers)
        ])
        self.conv = nn.Conv2d(hidden_channels, 1, kernel_size=3, padding=1)

    def forward(self, x):
        """Predict one frame from a sequence.

        Args:
            x: Tensor of shape ``(batch, seq_len, channels, height, width)``
                with ``seq_len >= 1``.

        Returns:
            Output of the head applied to the last time step; with the
            default head this is ``(batch, 1, height, width)``.

        Raises:
            ValueError: If ``x`` is not 5-D or has an empty time axis.
        """
        if x.dim() != 5:
            raise ValueError(
                "expected input of shape (batch, seq_len, channels, height, width), "
                f"got {tuple(x.shape)}"
            )
        if x.shape[1] == 0:
            raise ValueError("input sequence must contain at least one time step")
        for layer in self.convlstm:
            x, _ = layer(x)
        # Only the last time step feeds the prediction head.
        return self.conv(x[:, -1])

# Create the output directory.  An empty output_dir means "current
# directory"; os.makedirs("") would raise FileNotFoundError.
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Load data
dataset = RadarDataset(data_dir, seq_len, pred_len)
if len(dataset) == 0:
    # Fail early with a clear message rather than inside the DataLoader sampler.
    raise RuntimeError(f"No training sequences found under {data_dir!r}")
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# The dataset yields (seq, z, y, x, 1) tensors.  The model expects
# (batch, seq, channels, height, width), so the vertical levels are used as
# channels; their count is read from the first sample instead of hard-coding it.
n_levels = dataset[0][0].shape[1]

# Initialise model
hidden_channels = 32
model = ConvLSTM(in_channels=n_levels, hidden_channels=hidden_channels, num_layers=2)
# ConvLSTM's stock head emits a single channel; replace it so that every
# vertical level is predicted.
model.conv = nn.Conv2d(hidden_channels, n_levels, kernel_size=3, padding=1)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
# Mixed precision only makes sense on CUDA; disable it explicitly elsewhere.
use_amp = device.type == "cuda"
scaler = GradScaler(enabled=use_amp)

# Training
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for x, y in dataloader:
        # Drop the trailing singleton axis: (batch, seq, z, y, x, 1) -> (batch, seq, z, y, x)
        x = x.squeeze(-1).to(device)
        # Only the first target step is predicted: (batch, z, y, x)
        target = y[:, 0].squeeze(-1).to(device)
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            y_pred = model(x)  # (batch, z, y, x)
            loss = criterion(y_pred, target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(dataloader):.6f}")

# Save model
torch.save(model.state_dict(), os.path.join(output_dir, "convlstm.pth"))

# Generate a prediction (example on the first window)
model.eval()
with torch.no_grad():
    x, y = dataset[0]
    x = x.squeeze(-1).unsqueeze(0).to(device)  # (1, seq, z, y, x)
    y_pred = model(x)  # (1, z, y, x)
# Undo the [0, 1] normalisation (must match the range used by RadarDataset).
y_pred = y_pred.float().cpu().numpy() * (60 + 29) - 29
# Ground truth in dBZ, kept for comparison with the prediction.
y = y.numpy() * (60 + 29) - 29

# Save as NetCDF; the context manager closes the file even if a write fails.
_, n_z, n_y, n_x = y_pred.shape
with nc.Dataset(os.path.join(output_dir, "pred_0.nc"), 'w', format='NETCDF4') as ds_out:
    ds_out.createDimension('time', 1)
    ds_out.createDimension('z', n_z)
    ds_out.createDimension('y', n_y)
    ds_out.createDimension('x', n_x)
    dbz_var = ds_out.createVariable('DBZ', 'f4', ('time', 'z', 'y', 'x'))
    # The batch axis of size 1 doubles as the 'time' dimension.
    dbz_var[:] = y_pred