In [None]:
# Pipeline switch read by the data-preparation cell: True builds small random dummy
# tensors (no external files), False loads the real .mat files from disk.
DEMO_MODE = True

In [None]:
import numpy as np
import torch
import h5py
import glob
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
import matplotlib.pyplot as plt
import gc
import torchvision.transforms.functional as TF
import re

# Set device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Data loading functions
def load_data_from_mat(filename, u_v_names=None, w_name=None, mld_name=None):
    """Read the requested variables from one HDF5-backed ``.mat`` file.

    Every selector is optional: a falsy selector skips that variable and
    ``None`` is returned in its slot.

    Args:
        filename: Path of the file opened read-only with ``h5py``.
        u_v_names: Pair of dataset names ``(u_name, v_name)``.
        w_name: Dataset name of the vertical velocity; the array is squeezed.
        mld_name: Dataset name of the mixed-layer depth; the array is squeezed.

    Returns:
        Tuple ``(u_velocity, v_velocity, w_velocity, mld)`` whose entries are
        float tensors or ``None``. Array shapes depend on the file contents
        (presumably time-major, e.g. [t, n, m] -- confirm against the data).
    """
    def _as_float_tensor(handle, key, squeeze=False):
        # Materialise the dataset in memory before the file is closed.
        data = torch.tensor(np.array(handle[key]))
        if squeeze:
            data = data.squeeze()
        return data.float()

    u_velocity = v_velocity = w_velocity = mld = None

    with h5py.File(filename, 'r') as handle:
        if u_v_names:
            u_velocity = _as_float_tensor(handle, u_v_names[0])
            v_velocity = _as_float_tensor(handle, u_v_names[1])
        if w_name:
            w_velocity = _as_float_tensor(handle, w_name, squeeze=True)
        if mld_name:
            mld = _as_float_tensor(handle, mld_name, squeeze=True)

    return u_velocity, v_velocity, w_velocity, mld
    
def load_multiple_h5py_data(u_v_files_pattern, u_v_names):
    """Load u and v from every file matching a glob pattern and join them in time.

    Files are ordered by the integers embedded in their paths (numeric, not
    lexicographic, so ``..._10`` sorts after ``..._2``) and concatenated
    along dimension 0.

    Args:
        u_v_files_pattern: Glob pattern selecting the velocity files.
        u_v_names: Pair of dataset names ``(u_name, v_name)`` inside each file.

    Returns:
        Tuple ``(u_velocity, v_velocity)`` of concatenated tensors.

    Raises:
        ValueError: If no file matches ``u_v_files_pattern``.
    """
    u_v_files = sorted(
        glob.glob(u_v_files_pattern),
        key=lambda path: [int(num) for num in re.findall(r'\d+', path)],
    )
    # Fail early with a readable message instead of an opaque torch.cat error.
    if not u_v_files:
        raise ValueError(f"No velocity files match pattern: {u_v_files_pattern}")

    u_velocity_list, v_velocity_list = [], []
    for file in u_v_files:
        print(file)
        u_velocity, v_velocity = load_data_from_mat(file, u_v_names, None)[:2]
        u_velocity_list.append(u_velocity)
        v_velocity_list.append(v_velocity)

    u_velocity = torch.cat(u_velocity_list, dim=0)
    v_velocity = torch.cat(v_velocity_list, dim=0)

    return u_velocity, v_velocity

def load_multiple_w_data(w_files_pattern, w_name):
    """Load the vertical velocity from every matching file and join it in time.

    Files are ordered by the integers embedded in their paths and the
    per-file tensors are concatenated along dimension 0.

    Args:
        w_files_pattern: Glob pattern selecting the vertical-velocity files.
        w_name: Dataset name of the vertical velocity inside each file.

    Returns:
        The concatenated vertical-velocity tensor.

    Raises:
        ValueError: If no file matches ``w_files_pattern``.
    """
    w_files = sorted(
        glob.glob(w_files_pattern),
        key=lambda path: [int(num) for num in re.findall(r'\d+', path)],
    )
    # Fail early with a readable message instead of an opaque torch.cat error.
    if not w_files:
        raise ValueError(f"No vertical-velocity files match pattern: {w_files_pattern}")

    w_velocity_list = []
    for file in w_files:
        print(file)
        _, _, w_velocity, _ = load_data_from_mat(file, None, w_name)
        w_velocity_list.append(w_velocity)

    return torch.cat(w_velocity_list, dim=0)

def load_multiple_mld_data(mld_files_pattern, mld_name):
    """Load the mixed-layer depth from every matching file and join it in time.

    Args:
        mld_files_pattern: Glob pattern selecting the files that hold the MLD.
        mld_name: Dataset name of the MLD inside each file.

    Returns:
        The per-file MLD tensors concatenated along dimension 0.

    Raises:
        ValueError: If no MLD tensor could be collected.
    """
    # Order by the integers embedded in each path (numeric, not lexicographic).
    ordered_paths = sorted(
        glob.glob(mld_files_pattern),
        key=lambda path: [int(digits) for digits in re.findall(r'\d+', path)],
    )

    chunks = []
    for path in ordered_paths:
        print(path)
        chunk = load_data_from_mat(path, None, None, mld_name)[3]
        if chunk is not None:
            chunks.append(chunk)

    if not chunks:
        raise ValueError("No MLD data found.")
    return torch.cat(chunks, dim=0)
    
def calculate_stats(tensor):
    """Return the NaN-ignoring mean and standard deviation of a 4-D tensor.

    The reduction runs over axes 0, 2 and 3 with ``keepdims=True``, so one
    value remains per index of axis 1 and both results have shape
    ``[1, C, 1, 1]`` for an input of shape ``[T, C, H, W]``.

    Args:
        tensor: 4-D tensor; NaN entries are excluded from the statistics.

    Returns:
        Tuple ``(mean, std)`` of tensors placed on ``tensor.device``.
    """
    values = tensor.cpu().numpy()
    reduction = {"axis": (0, 2, 3), "keepdims": True}
    channel_mean = np.nanmean(values, **reduction)
    channel_std = np.nanstd(values, **reduction)
    target_device = tensor.device
    return (
        torch.tensor(channel_mean).to(target_device),
        torch.tensor(channel_std).to(target_device),
    )

def normalize(tensor, mean, std):
    """Standardise ``tensor`` as ``(tensor - mean) / (std + 1e-5)``.

    The small constant added to ``std`` keeps the division finite when the
    standard deviation is zero.
    """
    epsilon = 1e-5
    centered = tensor - mean
    return centered / (std + epsilon)

class LRVelocityMLDDataset(Dataset):
    """Dataset pairing normalised (u, v, MLD) input maps with w target maps.

    NaNs in every field are replaced via ``torch.nan_to_num`` before the
    inputs are stacked into channels and normalised with the supplied
    statistics. The target is not normalised.

    Args:
        high_u_velocity: u component, indexed as [t, n, m].
        high_v_velocity: v component, same shape as ``high_u_velocity``.
        w_velocity: Target field, one map per time step.
        mld: Mixed-layer depth, either 1-D ``[t]`` (broadcast to a constant
            map per time step) or already shaped like the velocity maps.
        mean: Statistics passed to ``normalize``; must broadcast against the
            ``[t, 3, n, m]`` input stack.
        std: Same convention as ``mean``.
    """

    def __init__(self, high_u_velocity, high_v_velocity, w_velocity, mld, mean, std):
        self.high_u_velocity = torch.nan_to_num(high_u_velocity)
        self.high_v_velocity = torch.nan_to_num(high_v_velocity)
        self.w_velocity = torch.nan_to_num(w_velocity)
        self.mld = torch.nan_to_num(mld)

        # A per-time-step scalar MLD becomes a spatially constant [t, n, m] map.
        if self.mld.dim() == 1:
            n, m = self.high_u_velocity.shape[1], self.high_u_velocity.shape[2]
            self.mld = self.mld[:, None, None].repeat(1, n, m)

        stacked_inputs = torch.stack(
            [self.high_u_velocity, self.high_v_velocity, self.mld], dim=1
        )  # [t, 3, n, m]
        self.inputs = normalize(stacked_inputs, mean, std)

        # Add the channel axis as a view instead of copying the whole target tensor.
        self.outputs = self.w_velocity[:, None]  # [t, 1, n, m]

    def __len__(self):
        return self.inputs.shape[0]

    def __getitem__(self, idx):
        # Returns (input [3, n, m], target [1, n, m]) for one time step.
        return self.inputs[idx], self.outputs[idx]


In [None]:
if not DEMO_MODE:
    # ---- load the full-year fields from disk ----
    high_u_v_names_train = ['BBE_surface_u_rho_points', 'BBE_surface_v_rho_points']
    w_name_train = 'BBE_surface_w'
    mld_name_train = 'mld'

    train_high_u_velocity, train_high_v_velocity = load_multiple_h5py_data(
        '/home/user_bk/roms_ccs/ver4_rotate/data_for_cnn_final/ROMS_surface_velocity_rho_points_2021_*.mat',
        high_u_v_names_train
    )
    print(f"train_high_u_velocity shape: {train_high_u_velocity.shape}")
    print(f"train_high_v_velocity shape: {train_high_v_velocity.shape}")

    w_files_pattern_train = '/home/user_bk/roms_ccs/ver4_rotate/data_for_cnn_final/ROMS_surface_vertical_velocity_2021_*.mat'
    train_w_velocity = load_multiple_w_data(w_files_pattern_train, w_name_train)
    print(f"train_w shape: {train_w_velocity.shape}")

    mld_files_pattern_train = '/home/user_bk/roms_ccs/ver4_rotate/data_for_cnn_final/roms_temp_2021_*.mat'
    train_MLD = load_multiple_mld_data(mld_files_pattern_train, mld_name_train)
    print(f"train_mld shape: {train_MLD.shape}")

    # MLD is part of the check: a shorter series would otherwise misalign the samples.
    assert (train_high_u_velocity.shape[0] == train_high_v_velocity.shape[0]
            == train_w_velocity.shape[0] == train_MLD.shape[0]), \
        "All input tensors must have the same number of samples (t dimension)."

    # ---- month-wise 70/15/15 split ----
    # Splitting inside each month keeps every month represented in all three subsets.
    num_hours_per_month = [744, 672, 744, 720, 744, 720, 744, 744, 720, 744, 720, 744]
    cumulative_hours = np.cumsum([0] + num_hours_per_month)
    total_hours = int(cumulative_hours[-1])

    num_samples = train_high_u_velocity.shape[0]
    if num_samples != total_hours:
        print(f"WARNING: loaded {num_samples} time steps but the monthly calendar covers {total_hours}.")

    # Split sample indices rather than the data itself: train_test_split permutes
    # positions only, so the same random_state yields the same partition while
    # avoiding a Python-level zip/stack of every map.
    train_idx_parts, val_idx_parts, test_idx_parts = [], [], []
    for i in range(len(num_hours_per_month)):
        start = int(cumulative_hours[i])
        end = min(int(cumulative_hours[i + 1]), num_samples)
        month_idx = np.arange(start, end)

        train_part, val_test_part = train_test_split(month_idx, test_size=0.3, random_state=42)
        val_part, test_part = train_test_split(val_test_part, test_size=0.5, random_state=42)

        train_idx_parts.append(train_part)
        val_idx_parts.append(val_part)
        test_idx_parts.append(test_part)

    train_idx = torch.from_numpy(np.concatenate(train_idx_parts)).long()
    val_idx = torch.from_numpy(np.concatenate(val_idx_parts)).long()
    test_idx = torch.from_numpy(np.concatenate(test_idx_parts)).long()

    inputs_high_u_train = train_high_u_velocity[train_idx].float()
    inputs_high_u_val = train_high_u_velocity[val_idx].float()
    inputs_high_u_test = train_high_u_velocity[test_idx].float()

    inputs_high_v_train = train_high_v_velocity[train_idx].float()
    inputs_high_v_val = train_high_v_velocity[val_idx].float()
    inputs_high_v_test = train_high_v_velocity[test_idx].float()

    outputs_w_train = train_w_velocity[train_idx].float()
    outputs_w_val = train_w_velocity[val_idx].float()
    outputs_w_test = train_w_velocity[test_idx].float()

    inputs_mld_train = train_MLD[train_idx].float()
    inputs_mld_val = train_MLD[val_idx].float()
    inputs_mld_test = train_MLD[test_idx].float()

    print(f"Train size: {len(inputs_high_u_train)} ({len(inputs_high_u_train) / total_hours * 100:.1f}%)")
    print(f"Validation size: {len(inputs_high_u_val)} ({len(inputs_high_u_val) / total_hours * 100:.1f}%)")
    print(f"Test size: {len(inputs_high_u_test)} ({len(inputs_high_u_test) / total_hours * 100:.1f}%)")

    # ---- normalisation statistics (training subset only) ----
    # A 1-D MLD series must be broadcast to maps before it can be stacked with u and v.
    if inputs_mld_train.dim() == 1:
        n, m = inputs_high_u_train.shape[1], inputs_high_u_train.shape[2]
        inputs_mld_train_map = inputs_mld_train[:, None, None].repeat(1, n, m)
    else:
        inputs_mld_train_map = inputs_mld_train

    train_mean, train_std = calculate_stats(
        torch.stack([inputs_high_u_train, inputs_high_v_train, inputs_mld_train_map], dim=1)
    )

    # ---- datasets and loaders ----
    # The w target is multiplied by 1000 for training; the plotting cell divides it back.
    train_dataset = LRVelocityMLDDataset(
        high_u_velocity=inputs_high_u_train,
        high_v_velocity=inputs_high_v_train,
        mld=inputs_mld_train,
        w_velocity=outputs_w_train * 1000.0,
        mean=train_mean,
        std=train_std
    )

    val_dataset = LRVelocityMLDDataset(
        high_u_velocity=inputs_high_u_val,
        high_v_velocity=inputs_high_v_val,
        mld=inputs_mld_val,
        w_velocity=outputs_w_val * 1000.0,
        mean=train_mean,
        std=train_std
    )

    test_dataset = LRVelocityMLDDataset(
        high_u_velocity=inputs_high_u_test,
        high_v_velocity=inputs_high_v_test,
        mld=inputs_mld_test,
        w_velocity=outputs_w_test * 1000.0,
        mean=train_mean,
        std=train_std
    )

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

else:
    print("DEMO MODE: using dummy tensors (no external data files).")

    t_train, t_val, t_test = 16, 8, 8
    n, m = 280, 340

    # ---- dummy inputs ----
    inputs_high_u_train = torch.randn(t_train, n, m)
    inputs_high_u_val   = torch.randn(t_val, n, m)
    inputs_high_u_test  = torch.randn(t_test, n, m)

    inputs_high_v_train = torch.randn(t_train, n, m)
    inputs_high_v_val   = torch.randn(t_val, n, m)
    inputs_high_v_test  = torch.randn(t_test, n, m)

    # Dummy w targets, already multiplied by 1000 like the real targets are
    # when the datasets are built in the non-demo branch.
    outputs_w_train = torch.randn(t_train, n, m) * 1000.0
    outputs_w_val   = torch.randn(t_val, n, m) * 1000.0
    outputs_w_test  = torch.randn(t_test, n, m) * 1000.0

    # MLD stays 1-D ([t]); the Dataset broadcasts it to maps.
    inputs_mld_train = torch.randn(t_train)
    inputs_mld_val   = torch.randn(t_val)
    inputs_mld_test  = torch.randn(t_test)

    # Statistics need the broadcast MLD map so it can be stacked with u and v.
    mld_train_map = inputs_mld_train[:, None, None].repeat(1, n, m)  # [t, n, m]
    train_mean, train_std = calculate_stats(
        torch.stack([inputs_high_u_train, inputs_high_v_train, mld_train_map], dim=1)
    )

    train_dataset = LRVelocityMLDDataset(
        high_u_velocity=inputs_high_u_train,
        high_v_velocity=inputs_high_v_train,
        mld=inputs_mld_train,
        w_velocity=outputs_w_train,
        mean=train_mean,
        std=train_std
    )

    val_dataset = LRVelocityMLDDataset(
        high_u_velocity=inputs_high_u_val,
        high_v_velocity=inputs_high_v_val,
        mld=inputs_mld_val,
        w_velocity=outputs_w_val,
        mean=train_mean,
        std=train_std
    )

    test_dataset = LRVelocityMLDDataset(
        high_u_velocity=inputs_high_u_test,
        high_v_velocity=inputs_high_v_test,
        mld=inputs_mld_test,
        w_velocity=outputs_w_test,
        mean=train_mean,
        std=train_std
    )

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader   = DataLoader(val_dataset, batch_size=4, shuffle=False)
    test_loader  = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [None]:
def plot_losses(train_losses, val_losses):
    """Draw the per-epoch training and validation loss curves on one figure.

    Args:
        train_losses: Training loss per epoch; its length sets the epoch axis.
        val_losses: Validation loss per epoch, plotted against the same axis.
    """
    epoch_axis = range(1, len(train_losses) + 1)
    curves = (
        (train_losses, 'bo-', 'Training loss'),
        (val_losses, 'ro-', 'Validation loss'),
    )

    plt.figure(figsize=(10, 5))
    for series, line_format, legend_label in curves:
        plt.plot(epoch_axis, series, line_format, label=legend_label)
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

def weights_init(m):
    """Initialise an ``nn.Conv2d`` with Xavier-uniform weights and zero bias.

    Intended for ``model.apply``; any module that is not an ``nn.Conv2d``
    instance is left untouched.
    """
    if not isinstance(m, nn.Conv2d):
        return

    nn.init.xavier_uniform_(m.weight)
    has_bias = m.bias is not None
    if has_bias:
        m.bias.data.zero_()

In [None]:
def test_model(model, test_loader):
    model.eval()
    predictions = []
    true_labels = []
    with torch.no_grad():  # No need to compute gradients during testing
        for inputs, outputs in test_loader:
            inputs = inputs.to(device)
            outputs = outputs.to(device)
            
            with torch.cuda.amp.autocast():
                p_outputs = model(inputs)
                predictions.append(p_outputs.cpu().numpy())  # Store predictions
                true_labels.append(outputs.cpu().numpy())  # Store true labels

    # Concatenate all predictions and true labels
    predictions = np.concatenate(predictions, axis=0)
    true_labels = np.concatenate(true_labels, axis=0)

    return predictions, true_labels

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=200, accumulate_steps=4, device='cuda', scheduler=None):
    """Train with gradient accumulation and optional CUDA mixed precision.

    Each epoch runs one pass over ``train_loader`` (stepping the optimizer
    every ``accumulate_steps`` batches and once more on the final batch),
    then one evaluation pass over ``val_loader``. Loss curves are plotted
    when training ends.

    Args:
        model: Network to train. NOTE: its Conv2d layers are re-initialised
            via ``weights_init`` on every call, so calling this function again
            restarts training rather than resuming it.
        train_loader: Loader of ``(inputs, labels)`` training batches.
        val_loader: Loader of ``(inputs, labels)`` validation batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating ``model``'s parameters.
        num_epochs: Number of epochs.
        accumulate_steps: Number of batches whose gradients are accumulated
            before an optimizer step.
        device: Device the batches are moved to.
        scheduler: Optional LR scheduler, stepped once per epoch.

    Returns:
        Tuple ``(train_losses, val_losses)`` with one sample-weighted average
        loss per epoch.
    """
    # Mixed precision and loss scaling only make sense on an available CUDA device.
    use_amp = torch.device(device).type == 'cuda' and torch.cuda.is_available()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    train_losses = []
    val_losses = []

    model.apply(weights_init)

    num_batches = len(train_loader)
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        optimizer.zero_grad()

        for i, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            # Only the back-propagated value is divided, so accumulated gradients
            # average over the group while the logged loss stays on the same scale
            # as the validation loss.
            scaler.scale(loss / accumulate_steps).backward()

            if (i + 1) % accumulate_steps == 0 or (i + 1) == num_batches:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            running_loss += loss.item() * inputs.size(0)

        avg_train_loss = running_loss / len(train_loader.dataset)
        train_losses.append(avg_train_loss)

        val_loss = 0.0
        model.eval()
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                with torch.cuda.amp.autocast(enabled=use_amp):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                val_loss += loss.item() * inputs.size(0)

        avg_val_loss = val_loss / len(val_loader.dataset)
        val_losses.append(avg_val_loss)

        if scheduler is not None:
            scheduler.step()

        for param_group in optimizer.param_groups:
            print(f"Learning Rate: {param_group['lr']:.6f}")

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Housekeeping once per epoch; doing this on every batch only slows training.
        torch.cuda.empty_cache()
        gc.collect()

    print("Finished Training")
    plot_losses(train_losses, val_losses)

    return train_losses, val_losses

In [None]:
# (Conv → BatchNorm → ReLU) × 2
def double_conv(in_channels, out_channels, dropout_rate=0.0):
    """Build two stacked 3x3 Conv → BatchNorm → ReLU stages as one ``nn.Sequential``.

    Args:
        in_channels: Channels entering the first convolution.
        out_channels: Channels produced by both convolutions.
        dropout_rate: When greater than zero, a single ``nn.Dropout2d`` is
            placed at index 2, i.e. between the first BatchNorm and the first
            ReLU.

    Returns:
        ``nn.Sequential`` with 6 layers (7 when dropout is enabled).
    """
    def conv_stage(stage_in):
        # padding=1 with a 3x3 kernel preserves the spatial size.
        return [
            nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]

    first_stage = conv_stage(in_channels)
    second_stage = conv_stage(out_channels)

    if dropout_rate > 0:
        conv, norm, activation = first_stage
        first_stage = [conv, norm, nn.Dropout2d(dropout_rate), activation]

    return nn.Sequential(*first_stage, *second_stage)


class UNet(nn.Module):
    """U-Net encoder/decoder built from ``double_conv`` blocks.

    The encoder applies one ``double_conv`` per entry of ``features``, each
    followed by 2x2 max pooling. The bottleneck doubles the last width. The
    decoder mirrors the encoder with a stride-2 transposed convolution,
    concatenation with the matching skip connection, and a ``double_conv``.
    A 1x1 convolution produces the output channels.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        features: Channel widths of the encoder stages, shallowest first.
        dropout_rate: Forwarded to every ``double_conv`` block.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=1,
                 features=(64, 128, 256, 512),
                 dropout_rate=0.0):
        super().__init__()

        # Immutable default plus a local copy: callers may pass any sequence.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.downs = nn.ModuleList()
        current_in_channels = in_channels
        for feature in features:
            self.downs.append(double_conv(current_in_channels, feature, dropout_rate))
            current_in_channels = feature

        self.bottleneck = double_conv(features[-1], features[-1] * 2, dropout_rate)

        # ``ups`` alternates [upsample, double_conv] for each decoder stage.
        self.ups = nn.ModuleList()
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(in_channels=feature * 2,
                                   out_channels=feature,
                                   kernel_size=2, stride=2)
            )
            self.ups.append(
                double_conv(feature * 2, feature, dropout_rate)
            )

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down_block in self.downs:
            x = down_block(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Decoder stages consume the skips from deepest to shallowest.
        for stage, skip_connection in enumerate(reversed(skip_connections)):
            x = self.ups[2 * stage](x)

            # Pooling floors odd sizes, so the upsampled map can be smaller than
            # the skip; resize to the skip's spatial size before concatenating.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = F.interpolate(x, size=skip_connection.shape[2:],
                                  mode='bilinear', align_corners=False)

            x = torch.cat((skip_connection, x), dim=1)
            x = self.ups[2 * stage + 1](x)

        return self.final_conv(x)

In [None]:
# Build the network on the selected device and apply the custom Conv2d initialisation.
model = UNet().to(device)
model.apply(weights_init)
# Layer summary for a single input of 3 channels and 280 x 340 points.
summary(model, input_size=(3, 280, 340))

# Mean-squared-error loss optimised with AdamW.
criterion = nn.MSELoss()
learning_rate = 0.00022608954273243165
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

# NOTE(review): this scaler is never handed to train_model, which creates its own GradScaler.
scaler = torch.cuda.amp.GradScaler()

# Multiplies the learning rate by 0.5 every 10 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

In [None]:
def check_nan(tensor, tensor_name="Tensor"):
    """Print whether ``tensor`` contains at least one NaN.

    Args:
        tensor: Tensor to inspect.
        tensor_name: Label used in the printed message.
    """
    contains_nan = torch.isnan(tensor).any()
    if not contains_nan:
        print(f"No NaN values detected in {tensor_name}.")
        return
    print(f"NaN values detected in {tensor_name}.")

# Sanity check: report whether the normalised inputs of each split contain NaNs.
check_nan(train_dataset.inputs, "train_dataset.inputs")

check_nan(val_dataset.inputs, "val_dataset.inputs")

check_nan(test_dataset.inputs, "test_dataset.inputs")

In [None]:
# Train the model
# NOTE(review): the StepLR `scheduler` created in the setup cell is not passed here, so
# the learning rate stays constant for all epochs; pass scheduler=scheduler if the decay
# was intended -- confirm before changing, since it alters the training result.
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=100, accumulate_steps=4, device=device)

# Test the model
predictions, true_labels = test_model(model, test_loader)

# Print or save the test results as needed
print("Test predictions and true labels have been computed.")

In [None]:
# Grid coordinate files (only read when real data is used).
hycom_grid_fn_path = '/home/user_bk/roms_ccs/ver4_rotate/data_for_cnn/HYCOM_grid_info.mat'
roms_grid_fn_path = '/home/user_bk/roms_ccs/ver4_rotate/data_for_cnn/roms_grid_info.mat'

if DEMO_MODE:
    # Demo mode has no external files: use index grids shaped like the maps so
    # the plotting and saving cells still run end to end.
    grid_rows, grid_cols = true_labels.shape[-2:]
    roms_lon, roms_lat = np.meshgrid(
        np.arange(grid_cols, dtype=float),
        np.arange(grid_rows, dtype=float),
    )  # both [grid_rows, grid_cols]
    # No HYCOM grid exists in demo mode.
    hycom_lon = hycom_lat = None
else:
    with h5py.File(hycom_grid_fn_path, 'r') as mat_file:
        hycom_lon = mat_file['lon'][:]
        hycom_lat = mat_file['lat'][:]

    with h5py.File(roms_grid_fn_path, 'r') as mat_file:
        roms_lon = mat_file['lon_rho'][:]
        roms_lat = mat_file['lat_rho'][:]

def visualize_with_lat_lon(lat, lon, true_velocity, predicted_velocity, sample_idx=0, vlim=0.0002):
    """Plot one ground-truth field and its prediction side by side.

    Args:
        lat: Latitude coordinates, passed to ``pcolormesh`` as Y.
        lon: Longitude coordinates, passed to ``pcolormesh`` as X.
        true_velocity: Ground-truth fields indexed by sample along axis 0.
        predicted_velocity: Predicted fields, indexed like ``true_velocity``.
        sample_idx: Index of the sample to draw.
        vlim: Symmetric colour limit; both panels use ``[-vlim, vlim]``.

    Raises:
        IndexError: If ``sample_idx`` is outside the available samples.
    """
    # Pick the requested sample and drop singleton axes so a 2-D map remains.
    panels = (
        ('Ground Truth Velocity', true_velocity[sample_idx].squeeze()),
        ('Predicted Velocity', predicted_velocity[sample_idx].squeeze()),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    # Identical colour scale on both panels keeps them directly comparable.
    for ax, (title, field) in zip(axes, panels):
        mesh = ax.pcolormesh(lon, lat, field, cmap='seismic', vmin=-vlim, vmax=vlim)
        ax.set_title(title)
        ax.set_xlabel('Longitude')
        ax.set_ylabel('Latitude')
        fig.colorbar(mesh, ax=ax)

    plt.tight_layout()
    plt.show()


# Plot one test sample of the vertical velocity w; dividing by 1000 undoes the
# scaling that was applied to the targets for training.
# The index is clamped because the demo test split holds only a few samples.
sample_to_plot = min(500, len(true_labels) - 1)
visualize_with_lat_lon(roms_lat, roms_lon, true_labels[:,0,:,:] / 1000, predictions[:,0,:,:] / 1000, sample_idx=sample_to_plot)


In [None]:
from scipy.io import savemat

# Store the test targets and predictions as two HDF5 datasets in the working directory.
# NOTE(review): the file is written by h5py, so despite the .mat extension it is plain
# HDF5 -- confirm the downstream reader opens it with an HDF5 reader.
with h5py.File('UNet_test_250416_uvmld_to_w_dropout_optimized_output.mat', 'w') as f:
    f.create_dataset('true_labels', data=true_labels)
    f.create_dataset('predictions', data=predictions)


In [None]:
# Store the ROMS longitude/latitude grids next to the predictions so that the
# output file can be plotted without the original grid file.
with h5py.File('UNet_test_250416_uvmld_to_w_dropout_optimized_roms_grid_info.mat', 'w') as f2:
    f2.create_dataset('roms_lon', data=roms_lon)
    f2.create_dataset('roms_lat', data=roms_lat)

In [None]:
# Persist the trained network twice: the whole pickled module (loading it requires the
# UNet class definition to be importable) and the state_dict alone.
torch.save(model, 'UNet_test_250416_uvmld_to_w_optimized_dropout.pth')
torch.save(model.state_dict(), 'UNet_test_250416_uvmld_to_w_optimized_dropout_state_dict.pt')