In [None]:
# True: build small random tensors so the notebook runs without any external
# data files. False: read the real ROMS/HYCOM .mat files from disk.
DEMO_MODE: bool = True

In [None]:
import numpy as np
import torch
import h5py
import glob
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
import matplotlib.pyplot as plt
import gc
import torchvision.transforms.functional as TF
import re
from torch_dct import dct_2d, idct_2d

# Set device: first CUDA GPU when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed the RNGs up front so Restart & Run All is reproducible: the torch RNG
# drives the demo tensors, weight initialization and DataLoader shuffling.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

# Data loading functions
def load_data_from_mat(filename, u_v_names=None, mld_name=None):
    """Read velocity and/or MLD datasets from one HDF5-based ``.mat`` file.

    Args:
        filename: path of the file opened with ``h5py.File(..., 'r')``.
        u_v_names: optional pair ``(u_key, v_key)`` of dataset names; when falsy
            no velocity is read.
        mld_name: optional dataset name for the mixed-layer depth; when falsy no
            MLD is read. Singleton dimensions are squeezed out of the MLD.

    Returns:
        ``(u_velocity, v_velocity, mld)`` as float32 tensors; any entry that was
        not requested is ``None``.
    """
    def as_tensor(handle, key):
        # materialize the HDF5 dataset in memory before leaving the `with`
        return torch.tensor(np.array(handle[key]))

    u_velocity = v_velocity = mld = None
    with h5py.File(filename, 'r') as f:
        if u_v_names:
            u_velocity = as_tensor(f, u_v_names[0]).float()
            v_velocity = as_tensor(f, u_v_names[1]).float()
        if mld_name:
            mld = as_tensor(f, mld_name).squeeze().float()
    return u_velocity, v_velocity, mld

def load_multiple_h5py_data(u_v_files_pattern, u_v_names):
    """Load u/v velocity from every file matching a glob and concatenate along time.

    Files are ordered by the integers embedded in their paths (so ``month_2``
    comes before ``month_10``), not lexically, so consecutive files concatenate
    in chronological order.

    Args:
        u_v_files_pattern: glob pattern for the input files.
        u_v_names: pair ``(u_key, v_key)`` of dataset names inside each file.

    Returns:
        ``(u_velocity, v_velocity)``: each file's tensors concatenated on dim 0.

    Raises:
        ValueError: if the pattern matches no files (instead of an opaque
            ``torch.cat`` error on an empty list).
    """
    def numeric_key(path):
        return [int(num) for num in re.findall(r'\d+', path)]

    u_v_files = sorted(glob.glob(u_v_files_pattern), key=numeric_key)
    if not u_v_files:
        raise ValueError(f"No velocity files match pattern: {u_v_files_pattern}")

    u_velocity_list, v_velocity_list = [], []
    for file in u_v_files:
        u_velocity, v_velocity, _ = load_data_from_mat(file, u_v_names, None)
        u_velocity_list.append(u_velocity)
        v_velocity_list.append(v_velocity)

    u_velocity = torch.cat(u_velocity_list, dim=0)  # [t_total, ...]
    v_velocity = torch.cat(v_velocity_list, dim=0)  # [t_total, ...]
    return u_velocity, v_velocity

def load_multiple_mld_data(mld_files_pattern, mld_name):
    """Load mixed-layer depth from every file matching a glob and concatenate on dim 0.

    Files are sorted by the integers embedded in their paths, the same ordering
    ``load_multiple_h5py_data`` uses for velocity files. ``glob.glob`` returns
    files in arbitrary order, so without this sort the MLD time series could be
    stitched together out of order and misaligned with the velocity fields.

    Args:
        mld_files_pattern: glob pattern for the MLD files.
        mld_name: dataset name of the MLD inside each file.

    Returns:
        Tensor of all files' MLD concatenated along dim 0.

    Raises:
        ValueError: if no file yields MLD data.

    NOTE(review): ``load_data_from_mat`` squeezes the MLD, so a file holding a
    single time step would lose its time axis before concatenation.
    """
    mld_files = sorted(
        glob.glob(mld_files_pattern),
        key=lambda path: [int(num) for num in re.findall(r'\d+', path)],
    )

    mld_list = []
    for file in mld_files:
        _, _, mld = load_data_from_mat(file, None, mld_name)
        if mld is not None:
            mld_list.append(mld)

    if not mld_list:
        raise ValueError("No MLD data found.")
    return torch.cat(mld_list, dim=0)  # [t_total, ...]

def normalize(tensor):
    """Standardize each channel of a [t, C, n, m] tensor.

    The mean and (unbiased) standard deviation are taken over the time and both
    spatial axes (dims 0, 2, 3), so every channel is shifted and scaled by its
    own statistics; 1e-5 is added to the std to avoid division by zero.

    NOTE(review): this is applied separately to the train, val and test inputs,
    so each split is standardized with its own statistics rather than the
    training set's.
    """
    reduce_dims = (0, 2, 3)
    channel_mean = tensor.mean(dim=reduce_dims, keepdim=True)
    channel_std = tensor.std(dim=reduce_dims, keepdim=True)
    centered = tensor - channel_mean
    return centered / (channel_std + 1e-5)

class LRVelocityMLDDataset(Dataset):
    """Training/validation samples pairing low-resolution inputs with band-passed targets.

    Each item is ``(input, target)``:
    ``input`` stacks the channels ``[low_u, low_v, mld]`` -> [3, n, m], and
    ``target`` stacks ``[inter_u, inter_v]`` -> [2, n, m].

    Args:
        low_u_velocity, low_v_velocity: input velocity components, [t, n, m].
        mld: mixed-layer depth, either spatial [t, n, m] or one value per time
            step [t]; a 1-D series is broadcast over the (n, m) grid.
        inter_u_velocity, inter_v_velocity: target velocity components, [t, n, m].
        normalize_fn: optional callable applied once to the stacked inputs
            [t, 3, n, m]; targets are never normalized.

    Every tensor goes through ``torch.nan_to_num`` (NaN -> 0, +/-inf -> the
    largest finite values) before stacking.
    """

    def __init__(self, low_u_velocity, low_v_velocity, mld, inter_u_velocity, inter_v_velocity, normalize_fn=None):
        self.low_u_velocity = torch.nan_to_num(low_u_velocity)  # [t, n, m]
        self.low_v_velocity = torch.nan_to_num(low_v_velocity)  # [t, n, m]

        mld = torch.nan_to_num(mld)
        if mld.dim() == 1:
            # One MLD value per time step: repeat it over the spatial grid so it
            # can be stacked as a channel next to the velocity fields.
            mld = mld[:, None, None].expand(-1, *self.low_u_velocity.shape[1:])
        self.mld = mld  # [t, n, m]

        self.inter_u_velocity = torch.nan_to_num(inter_u_velocity)  # [t, n, m]
        self.inter_v_velocity = torch.nan_to_num(inter_v_velocity)  # [t, n, m]

        self.inputs = torch.stack([self.low_u_velocity, self.low_v_velocity, self.mld], dim=1)  # [t, 3, n, m]
        self.outputs = torch.stack([self.inter_u_velocity, self.inter_v_velocity], dim=1)  # [t, 2, n, m]

        if normalize_fn is not None:
            self.inputs = normalize_fn(self.inputs)

    def __len__(self):
        return self.inputs.shape[0]

    def __getitem__(self, idx):
        # ([3, n, m] input, [2, n, m] target) for time step `idx`
        return self.inputs[idx], self.outputs[idx]


In [None]:
class TestDataset(Dataset):
    """Test-set samples pairing low-resolution inputs with band-passed targets.

    Each item is ``(input, target)``:
    ``input`` stacks the channels ``[low_u, low_v, mld]`` -> [3, n, m], and
    ``target`` stacks ``[inter_u, inter_v]`` -> [2, n, m].

    Args:
        low_u_velocity, low_v_velocity: input velocity components, [t, n, m].
        mld: mixed-layer depth, either spatial [t, n, m] or one value per time
            step [t]; a 1-D series is broadcast over the (n, m) grid.
        inter_u_velocity, inter_v_velocity: target velocity components, [t, n, m].
        normalize_fn: optional callable applied once to the stacked inputs
            [t, 3, n, m]; targets are never normalized.

    Every tensor goes through ``torch.nan_to_num`` (NaN -> 0, +/-inf -> the
    largest finite values) before stacking.
    """

    def __init__(self, low_u_velocity, low_v_velocity, mld, inter_u_velocity, inter_v_velocity, normalize_fn=None):
        self.low_u_velocity = torch.nan_to_num(low_u_velocity)  # [t, n, m]
        self.low_v_velocity = torch.nan_to_num(low_v_velocity)  # [t, n, m]

        mld = torch.nan_to_num(mld)
        if mld.dim() == 1:
            # One MLD value per time step: repeat it over the spatial grid so it
            # can be stacked as a channel next to the velocity fields.
            mld = mld[:, None, None].expand(-1, *self.low_u_velocity.shape[1:])
        self.mld = mld  # [t, n, m]

        self.inter_u_velocity = torch.nan_to_num(inter_u_velocity)  # [t, n, m]
        self.inter_v_velocity = torch.nan_to_num(inter_v_velocity)  # [t, n, m]

        self.inputs = torch.stack([self.low_u_velocity, self.low_v_velocity, self.mld], dim=1)  # [t, 3, n, m]
        self.outputs = torch.stack([self.inter_u_velocity, self.inter_v_velocity], dim=1)  # [t, 2, n, m]

        if normalize_fn is not None:
            self.inputs = normalize_fn(self.inputs)

    def __len__(self):
        return self.inputs.shape[0]

    def __getitem__(self, idx):
        # ([3, n, m] input, [2, n, m] target) for time step `idx`
        return self.inputs[idx], self.outputs[idx]

In [None]:
# Build train/val/test datasets and loaders. Real mode reads the ROMS/HYCOM .mat
# files from the hard-coded paths below; DEMO_MODE builds random tensors instead.
if not DEMO_MODE:
    # Dataset names inside the training .mat files (inputs, targets, MLD).
    low_u_v_names_train = ['low_BBE_U_surf', 'low_BBE_V_surf']
    inter_u_v_names_train = ['inv_dct_band_surf_u', 'inv_dct_band_surf_v']
    mld_name_train = 'low_mld'
    
    # NOTE: the 'veloicty' spelling in this path is presumably the on-disk file
    # name; it is a runtime string, so confirm before "fixing" it.
    train_low_u_velocity, train_low_v_velocity = load_multiple_h5py_data(
        '/home/user_bk/roms_ccs/ver4_rotate/data_for_cnn_final/downsampling_roms_veloicty_rho_points_month_*.mat',
        low_u_v_names_train
    )
    print(f"train_low_u_velocity shape: {train_low_u_velocity.shape}")  # expected [t_total, n, m]
    print(f"train_low_v_velocity shape: {train_low_v_velocity.shape}")  # expected [t_total, n, m]
    
    train_inter_u_velocity, train_inter_v_velocity = load_multiple_h5py_data(
        '/home/user_bk/roms_ccs/ver4_rotate/data_for_cnn_final/intermediate_surface_velocity_rho_points_month_*.mat',
        inter_u_v_names_train
    )
    print(f"train_inter_u_velocity shape: {train_inter_u_velocity.shape}")  # expected [t_total, n, m]
    print(f"train_inter_v_velocity shape: {train_inter_v_velocity.shape}")  # expected [t_total, n, m]
    
    mld_files_pattern_train = '/home/user_bk/roms_ccs/ver4_rotate/data_for_cnn_final/downsampling_roms_spatial_mld_month_*.mat'
    train_mld = load_multiple_mld_data(mld_files_pattern_train, mld_name_train)
    print(f"train_mld shape: {train_mld.shape}")  # must be [t_total, n, m] to stack with the velocities
    
    # Only the time dimension is checked here; the datasets' torch.stack later
    # also requires matching spatial shapes.
    assert train_low_u_velocity.shape[0] == train_low_v_velocity.shape[0] == train_mld.shape[0], \
        "All input tensors must have the same number of samples (t dimension)."
    
    assert train_inter_u_velocity.shape[0] == train_inter_v_velocity.shape[0], \
        "All output tensors must have the same number of samples (t dimension)."
    
    # Hours per calendar month of a non-leap year (sums to 8760); used to slice
    # the concatenated hourly series back into months.
    num_hours_per_month = [744, 672, 744, 720, 744, 720, 744, 744, 720, 744, 720, 744]
    cumulative_hours = np.cumsum([0] + num_hours_per_month)
    
    input_train_low_u, input_val_low_u = [], []
    input_train_low_v, input_val_low_v = [], []
    input_train_mld, input_val_mld = [], []
    output_train_inter_u, output_val_inter_u = [], []
    output_train_inter_v, output_val_inter_v = [], []
    
    # Split each month separately so every month contributes ~80% train / ~20% val.
    for i in range(12):
        start, end = cumulative_hours[i], cumulative_hours[i+1]
    
        input_low_u_month = train_low_u_velocity[start:end]
        input_low_v_month = train_low_v_velocity[start:end]
        input_mld_month = train_mld[start:end]
        output_inter_u_month = train_inter_u_velocity[start:end]
        output_inter_v_month = train_inter_v_velocity[start:end]
    
        # One tuple of per-hour slices (one per variable) for each time step.
        full_combined = list(zip(
            input_low_u_month, input_low_v_month, input_mld_month,
            output_inter_u_month, output_inter_v_month
        ))
    
        # Random split with a fixed random_state (reproducible).
        # NOTE(review): adjacent hourly frames are highly correlated, so a random
        # split puts near-duplicates in both train and val; a contiguous hold-out
        # per month would give a more honest validation loss.
        train_combined, val_combined = train_test_split(full_combined, test_size=0.2, random_state=42)
    
        (low_u_train, low_v_train, mld_train, inter_u_train, inter_v_train) = zip(*train_combined)
        (low_u_val, low_v_val, mld_val, inter_u_val, inter_v_val) = zip(*val_combined)
        
        # np.stack turns each tuple of torch slices into one numpy array.
        input_train_low_u.append(np.stack(low_u_train))
        input_val_low_u.append(np.stack(low_u_val))
        input_train_low_v.append(np.stack(low_v_train))
        input_val_low_v.append(np.stack(low_v_val))
        input_train_mld.append(np.stack(mld_train))
        input_val_mld.append(np.stack(mld_val))
    
        output_train_inter_u.append(np.stack(inter_u_train))
        output_val_inter_u.append(np.stack(inter_u_val))
        output_train_inter_v.append(np.stack(inter_v_train))
        output_val_inter_v.append(np.stack(inter_v_val))
        
    inputs_low_u_train = np.concatenate(input_train_low_u)
    inputs_low_u_val = np.concatenate(input_val_low_u)
    inputs_low_v_train = np.concatenate(input_train_low_v)
    inputs_low_v_val = np.concatenate(input_val_low_v)
    inputs_mld_train = np.concatenate(input_train_mld)
    inputs_mld_val = np.concatenate(input_val_mld)
    
    outputs_inter_u_train = np.concatenate(output_train_inter_u)
    outputs_inter_u_val = np.concatenate(output_val_inter_u)
    outputs_inter_v_train = np.concatenate(output_train_inter_v)
    outputs_inter_v_val = np.concatenate(output_val_inter_v)
    
    # 8760 = total hours in the year = sum(num_hours_per_month).
    print(f"Train size: {len(inputs_low_u_train)} ({len(inputs_low_u_train) / 8760 * 100:.1f}%)")
    print(f"Validation size: {len(outputs_inter_u_val)} ({len(outputs_inter_u_val) / 8760 * 100:.1f}%)")
    
    # Back to float32 torch tensors for the Dataset classes.
    inputs_low_u_train = torch.from_numpy(inputs_low_u_train).float()
    inputs_low_u_val = torch.from_numpy(inputs_low_u_val).float()
    inputs_low_v_train = torch.from_numpy(inputs_low_v_train).float()
    inputs_low_v_val = torch.from_numpy(inputs_low_v_val).float()
    inputs_mld_train = torch.from_numpy(inputs_mld_train).float()
    inputs_mld_val = torch.from_numpy(inputs_mld_val).float()
    
    outputs_inter_u_train = torch.from_numpy(outputs_inter_u_train).float()
    outputs_inter_u_val = torch.from_numpy(outputs_inter_u_val).float()
    outputs_inter_v_train = torch.from_numpy(outputs_inter_v_train).float()
    outputs_inter_v_val = torch.from_numpy(outputs_inter_v_val).float()
    
    # NOTE(review): normalize is applied per dataset, so train, val and test
    # inputs are each standardized with their own statistics.
    train_dataset = LRVelocityMLDDataset(
        low_u_velocity=inputs_low_u_train,
        low_v_velocity=inputs_low_v_train,
        mld=inputs_mld_train,
        inter_u_velocity=outputs_inter_u_train,
        inter_v_velocity=outputs_inter_v_train,
        normalize_fn=normalize
    )
    
    val_dataset = LRVelocityMLDDataset(
        low_u_velocity=inputs_low_u_val,
        low_v_velocity=inputs_low_v_val,
        mld=inputs_mld_val,
        inter_u_velocity=outputs_inter_u_val,
        inter_v_velocity=outputs_inter_v_val,
        normalize_fn=normalize
    )
    
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
    
    # Test inputs come from HYCOM (re-gridded to the ROMS grid) instead of
    # downsampled ROMS.
    low_u_v_names_test = ['hycom_u_surf_rotate', 'hycom_v_surf_rotate']
    inter_u_v_names_test = ['inv_dct_band_surf_u', 'inv_dct_band_surf_v']
    mld_name_test = 'mld_romsgrid'
    
    # Load the Coriolis parameter (hycom_f is not used anywhere else in this notebook).
    hycom_file_path = '/home/user_bk/roms_ccs/ver4_rotate/data_for_cnn/surf_vertical_vorticity_HYCOM_month_1.mat'
    
    with h5py.File(hycom_file_path, 'r') as mat_file:
        hycom_f = mat_file['hycom_f'][:]
    
    def load_test_data():
        """Load HYCOM test inputs, HYCOM MLD and the ROMS band-passed targets.

        NOTE(review): the targets use the same file pattern as the training
        targets, so the test targets overlap the training period; only the
        inputs (HYCOM instead of downsampled ROMS) differ.
        """
        test_u_velocity, test_v_velocity = load_multiple_h5py_data(
            '/home/user_bk/roms_ccs/ver4_rotate/data_for_cnn/HYCOM_GLBy008_spatial_surf_velocity_rotate_ROMSgrid_2021_*.mat',
            low_u_v_names_test
        )
    
        mld_files_pattern_test = '/home/user_bk/roms_ccs/ver4_rotate/data_for_cnn/HYCOM_GLBy008_temp_ROMSgrid_2021_*.mat'
        test_mld = load_multiple_mld_data(mld_files_pattern_test, mld_name_test)
        print(f"test_mld shape: {test_mld.shape}")  # [t_total, ...]; printed again below
    
        test_inter_u_velocity, test_inter_v_velocity = load_multiple_h5py_data(
            '/home/user_bk/roms_ccs/ver4_rotate/data_for_cnn_final/intermediate_surface_velocity_rho_points_month_*.mat',
            inter_u_v_names_test
        )
        
        print(f"test_u_velocity shape: {test_u_velocity.shape}")  # [t_total, n, m]
        print(f"test_v_velocity shape: {test_v_velocity.shape}")  # [t_total, n, m]
        print(f"test_mld shape: {test_mld.shape}")                # [t_total, ...]
        print(f"test_inter_u_velocity shape: {test_inter_u_velocity.shape}")                # [t_total, n, m]
        print(f"test_inter_v_velocity shape: {test_inter_v_velocity.shape}")                # [t_total, n, m]
    
        return test_u_velocity, test_v_velocity, test_mld, test_inter_u_velocity, test_inter_v_velocity
    
    test_u_velocity, test_v_velocity, test_mld, test_inter_u_velocity, test_inter_v_velocity = load_test_data()
    
    # Keep the first 2912 input steps and every 3rd target step (2912 * 3 = 8736),
    # presumably aligning 3-hourly HYCOM inputs with hourly ROMS targets — TODO confirm.
    # NOTE: test_mld[:2912, :, :] requires a 3-D [t, n, m] MLD; a 1-D [t] series
    # would raise IndexError here.
    test_u_velocity = test_u_velocity[:2912, :, :]
    test_v_velocity = test_v_velocity[:2912, :, :]
    test_mld = test_mld[:2912, :, :]
    test_inter_u_velocity = test_inter_u_velocity[0:8736:3, :, :]
    test_inter_v_velocity = test_inter_v_velocity[0:8736:3, :, :]
    
    test_dataset = TestDataset(
        low_u_velocity=test_u_velocity,
        low_v_velocity=test_v_velocity,
        mld=test_mld,
        inter_u_velocity = test_inter_u_velocity,
        inter_v_velocity = test_inter_v_velocity,
        normalize_fn=normalize
    )
    
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

else:
    print("DEMO MODE: using dummy tensors (no external data files).")

    # Tiny random splits; n, m must match the 280 x 340 size of freq_mask and of
    # the UNet's (fixed-size) output.
    t_train = 16
    t_val = 8
    t_test = 8
    n, m = 280, 340

    inputs_low_u_train = torch.randn(t_train, n, m)
    inputs_low_v_train = torch.randn(t_train, n, m)
    inputs_mld_train_1d   = torch.randn(t_train)
    outputs_inter_u_train = torch.randn(t_train, n, m)
    outputs_inter_v_train = torch.randn(t_train, n, m)

    inputs_low_u_val = torch.randn(t_val, n, m)
    inputs_low_v_val = torch.randn(t_val, n, m)
    inputs_mld_val_1d   = torch.randn(t_val)
    outputs_inter_u_val = torch.randn(t_val, n, m)
    outputs_inter_v_val = torch.randn(t_val, n, m)

    # Broadcast the per-time-step MLD over the grid: [t] -> [t, n, m].
    inputs_mld_train = inputs_mld_train_1d[:, None, None].repeat(1, n, m)  # [t,n,m]
    inputs_mld_val   = inputs_mld_val_1d[:, None, None].repeat(1, n, m)    # [t,n,m]

    train_dataset = LRVelocityMLDDataset(
        low_u_velocity=inputs_low_u_train,
        low_v_velocity=inputs_low_v_train,
        mld=inputs_mld_train,
        inter_u_velocity=outputs_inter_u_train,
        inter_v_velocity=outputs_inter_v_train,
        normalize_fn=normalize
    )

    val_dataset = LRVelocityMLDDataset(
        low_u_velocity=inputs_low_u_val,
        low_v_velocity=inputs_low_v_val,
        mld=inputs_mld_val,
        inter_u_velocity=outputs_inter_u_val,
        inter_v_velocity=outputs_inter_v_val,
        normalize_fn=normalize
    )

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader   = DataLoader(val_dataset, batch_size=4, shuffle=False)

    # ---- dummy test tensors ----
    test_u_velocity = torch.randn(t_test, n, m)
    test_v_velocity = torch.randn(t_test, n, m)
    test_mld_1d        = torch.randn(t_test)   # [t]
    test_inter_u_velocity = torch.randn(t_test, n, m)
    test_inter_v_velocity = torch.randn(t_test, n, m)

    test_mld = test_mld_1d[:, None, None].repeat(1, n, m)  # [t, n, m]

    test_dataset = TestDataset(
        low_u_velocity=test_u_velocity,
        low_v_velocity=test_v_velocity,
        mld=test_mld,
        inter_u_velocity=test_inter_u_velocity,
        inter_v_velocity=test_inter_v_velocity,
        normalize_fn=normalize
    )
    test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [None]:
def create_dct_bandpass_mask(H, W, dx=1.0, dy=1.0, lambda_min=5.0, lambda_max=20.0):
    """Binary DCT-domain band-pass mask keeping wavelengths in [lambda_min, lambda_max].

    DCT coefficient (i, j) of an H x W field is assigned the wavelengths
    ``lambda_x = 2*(H*dx)/(2i+1)`` along the first axis and
    ``lambda_y = 2*(W*dy)/(2j+1)`` along the second axis. These are combined
    into the equivalent isotropic wavelength
    ``lambda_eq = 1/sqrt(1/lambda_x**2 + 1/lambda_y**2)``. The coefficient is
    kept (1.0) when ``lambda_min <= lambda_eq <= lambda_max``, otherwise zeroed.

    Args:
        H, W: height and width of the spatial grid (280 x 340 in this notebook).
        dx, dy: grid spacing along the H and W axes, in km (default 1.0).
        lambda_min, lambda_max: inclusive passband limits, in km.

    Returns:
        float32 tensor of shape [H, W] holding 0.0 / 1.0.
    """
    Lx = H * dx  # domain extent along the first (H) axis
    Ly = W * dy  # domain extent along the second (W) axis

    # Using (2k + 1) rather than k keeps the k = 0 (lowest-frequency) term finite.
    lambda_x = 2 * Lx / (2 * np.arange(H) + 1)  # [H]
    lambda_y = 2 * Ly / (2 * np.arange(W) + 1)  # [W]

    # Vectorized over the whole grid instead of an H*W Python double loop.
    inv_sq = (1 / lambda_x[:, None] ** 2) + (1 / lambda_y[None, :] ** 2)  # [H, W]
    lambda_eq = 1.0 / np.sqrt(inv_sq)

    in_band = (lambda_eq >= lambda_min) & (lambda_eq <= lambda_max)
    mask = in_band.astype(np.float32)
    return torch.tensor(mask)

def dual_loss(pred, target, freq_mask, alpha=0.95, beta=0.05):
    """Weighted sum of a pixel-space Smooth-L1 loss and a band-passed spectral MSE.

    The spectral term band-passes both fields by multiplying their 2-D DCT by
    ``freq_mask`` (an [H, W] mask viewed as [1, 1, H, W]) and inverting the
    transform; it then takes the MSE between the two filtered fields. Both fields
    are cast to float32 before the transforms.

    Returns:
        ``alpha * smooth_l1(pred, target) + beta * mse(band(pred), band(target))``.
    """
    mask_4d = freq_mask.unsqueeze(0).unsqueeze(0)

    def band_pass(field):
        return idct_2d(dct_2d(field.float()) * mask_4d)

    pixel_term = F.smooth_l1_loss(pred, target)
    spectral_term = F.mse_loss(band_pass(pred), band_pass(target))
    return alpha * pixel_term + beta * spectral_term

# Band-pass mask for the spectral loss term: keeps DCT modes whose equivalent
# wavelength is between 5 and 20 km on a 280 x 340 grid with 1 km spacing. Its
# size must match the spatial size of the model outputs and targets.
freq_mask = create_dct_bandpass_mask(
    H=280, W=340, dx=1.0, dy=1.0, lambda_min=5.0, lambda_max=20.0
).to(device)

In [None]:
def plot_losses(train_losses, val_losses):
    """Plot per-epoch training (blue) and validation (red) loss against 1-based epoch numbers."""
    epoch_axis = range(1, len(train_losses) + 1)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(epoch_axis, train_losses, 'bo-', label='Training loss')
    ax.plot(epoch_axis, val_losses, 'ro-', label='Validation loss')
    ax.set_title('Training and Validation Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    plt.show()

def weights_init(m):
    """Xavier-uniform init for ``nn.Conv2d`` weights, zeroing the bias when present.

    Meant for ``model.apply(weights_init)``. Every other module type is left
    untouched. That includes ``nn.ConvTranspose2d`` (not a subclass of
    ``nn.Conv2d``), which keeps PyTorch's default initialization.
    """
    if not isinstance(m, nn.Conv2d):
        return
    nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)

In [None]:
def test_model(model, test_loader):
    model.eval()
    predictions = []
    true_labels = []
    with torch.no_grad():
        for inputs, outputs in test_loader:
            inputs = inputs.to(device)
            outputs = outputs.to(device)
            
            with torch.cuda.amp.autocast():
                p_outputs = model(inputs)
                predictions.append(p_outputs.cpu().numpy())  # Store predictions
                true_labels.append(outputs.cpu().numpy())  # Store true labels

    # Concatenate all predictions and true labels
    predictions = np.concatenate(predictions, axis=0)
    true_labels = np.concatenate(true_labels, axis=0)

    return predictions, true_labels

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=100, accumulate_steps=4, device='cuda', scheduler=None, init_weights=True):
    """Train ``model`` with gradient accumulation (and CUDA mixed precision), then plot losses.

    Args:
        model: network to optimize; expected to already live on ``device``.
        train_loader: DataLoader yielding ``(inputs, labels)`` training batches.
        val_loader: DataLoader yielding ``(inputs, labels)`` validation batches.
        criterion: ``criterion(pred, target) -> scalar tensor``. It is used for
            BOTH the training objective and the validation metric, so the two
            curves are comparable.
        optimizer: optimizer over the model's parameters.
        num_epochs: number of passes over ``train_loader``.
        accumulate_steps: mini-batches whose gradients are accumulated per
            optimizer step; a trailing partial group is also stepped.
        device: device (str or ``torch.device``) that batches are moved to.
            ``torch.autocast`` and gradient scaling are enabled only for CUDA.
        scheduler: optional LR scheduler, stepped once per epoch after validation.
        init_weights: if True (the previous, default behaviour), re-initialize the
            model with ``weights_init`` first; pass False to fine-tune existing weights.

    Returns:
        ``(train_losses, val_losses)``: per-epoch mean loss per sample.
    """
    use_amp = torch.device(device).type == 'cuda'
    amp_device_type = 'cuda' if use_amp else 'cpu'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    train_losses = []
    val_losses = []

    if init_weights:
        model.apply(weights_init)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        optimizer.zero_grad()

        for i, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.autocast(device_type=amp_device_type, enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            # Only the back-propagated loss is divided, so accumulated gradients
            # average over the group while the logged loss stays on the same
            # scale as the validation loss.
            scaler.scale(loss / accumulate_steps).backward()

            if (i + 1) % accumulate_steps == 0 or (i + 1) == len(train_loader):
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            running_loss += loss.item() * inputs.size(0)

        avg_train_loss = running_loss / len(train_loader.dataset)
        train_losses.append(avg_train_loss)

        val_loss = 0.0
        model.eval()
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                with torch.autocast(device_type=amp_device_type, enabled=use_amp):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                val_loss += loss.item() * inputs.size(0)

        avg_val_loss = val_loss / len(val_loader.dataset)
        val_losses.append(avg_val_loss)

        if scheduler is not None:
            scheduler.step()

        for param_group in optimizer.param_groups:
            print(f"Learning Rate: {param_group['lr']:.6f}")

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Release cached memory once per epoch (doing it per batch stalls training).
        torch.cuda.empty_cache()
        gc.collect()

    print("Finished Training")
    plot_losses(train_losses, val_losses)
    return train_losses, val_losses

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with an identity skip connection.

    ``forward`` computes ``LeakyReLU(BN2(Conv2(LeakyReLU(BN1(Conv1(x))))) + x)``.
    Input and output have the same shape, [B, channels, H, W]. Submodule
    names (``conv1``, ``bn1``, ``relu``, ``conv2``, ``bn2``) are kept because
    saved checkpoints refer to them.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.LeakyReLU()  # a LeakyReLU despite the attribute name
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        hidden = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(hidden))
        return self.relu(branch + x)

# UNet with Adaptive Pooling
class UNet(nn.Module):
    """U-Net mapping 3 input channels (low u, low v, MLD) to 2 output channels (u, v).

    Encoder: conv blocks separated by fixed and adaptive max-pools. The
    bottleneck always sits on ``pool5``'s adaptive (5, 11) grid and passes
    through two ``ResidualBlock``s at 1024 channels.

    Decoder: the transposed convolutions use fixed kernels and strides, so the
    output is always 280 x 340 regardless of the input size
    (5x11 -> 17x21 -> 35x42 -> 70x85 -> 140x170 -> 280x340). Each encoder skip
    tensor is zero-padded, or cropped where it is larger, to the upsampled size
    before channel concatenation.

    Submodule names are kept stable because saved checkpoints (state_dict and
    the pickled model) refer to them.
    """

    def __init__(self):
        super(UNet, self).__init__()

        def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
            # Conv -> BatchNorm -> LeakyReLU; with the defaults, spatial size is preserved.
            layers = [
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
                nn.BatchNorm2d(num_features=out_channels),
                nn.LeakyReLU()
            ]
            return nn.Sequential(*layers)

        # Contracting path
        self.enc1_1 = CBR2d(in_channels=3, out_channels=32)
        self.enc1_2 = CBR2d(in_channels=32, out_channels=32)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)

        self.enc2_1 = CBR2d(in_channels=32, out_channels=64)
        self.enc2_2 = CBR2d(in_channels=64, out_channels=64)
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0)

        self.enc3_1 = CBR2d(in_channels=64, out_channels=128)
        self.enc3_2 = CBR2d(in_channels=128, out_channels=128)
        self.pool3 = nn.AdaptiveMaxPool2d((10, 21))

        self.enc4_1 = CBR2d(in_channels=128, out_channels=256)
        self.enc4_2 = CBR2d(in_channels=256, out_channels=256)
        self.pool4 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0)

        self.enc5_1 = CBR2d(in_channels=256, out_channels=512)
        self.enc5_2 = CBR2d(in_channels=512, out_channels=512)
        self.pool5 = nn.AdaptiveMaxPool2d((5, 11))

        self.enc6_1 = nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(),
            ResidualBlock(1024),
            ResidualBlock(1024)
        )

        # Expansive path
        self.dec6_1 = CBR2d(in_channels=1024, out_channels=512)
        self.unpool5 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=(1, 1), stride=(4, 2), padding=0)

        self.dec5_2 = CBR2d(in_channels=2 * 512, out_channels=512)
        self.dec5_1 = CBR2d(in_channels=512, out_channels=256)
        self.unpool4 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=(3, 2), stride=(2, 2), padding=0)

        self.dec4_2 = CBR2d(in_channels=2 * 256, out_channels=256)
        self.dec4_1 = CBR2d(in_channels=256, out_channels=128)
        self.unpool3 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=(2, 3), stride=(2, 2), padding=0)

        self.dec3_2 = CBR2d(in_channels=2 * 128, out_channels=128)
        self.dec3_1 = CBR2d(in_channels=128, out_channels=64)
        self.unpool2 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=(2, 2), stride=(2, 2), padding=0)

        self.dec2_2 = CBR2d(in_channels=2 * 64, out_channels=64)
        self.dec2_1 = CBR2d(in_channels=64, out_channels=32)
        self.unpool1 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=(2, 2), stride=(2, 2), padding=0)

        self.dec1_2 = CBR2d(in_channels=2 * 32, out_channels=32)
        self.dec1_1 = CBR2d(in_channels=32, out_channels=32)

        self.fc = nn.Conv2d(in_channels=32, out_channels=2, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _match_size(skip, reference):
        """Return ``skip`` zero-padded (or cropped) so its H x W equals ``reference``'s.

        Each size difference ``d`` is split as ``d // 2`` before and
        ``d - d // 2`` after (floor division). Negative amounts in ``F.pad``
        crop. When the sizes already match, ``skip`` itself is returned.
        """
        diffY = reference.size(2) - skip.size(2)
        diffX = reference.size(3) - skip.size(3)
        if diffY == 0 and diffX == 0:
            return skip
        return F.pad(skip, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])

    def _merge_skip(self, upsampled, skip):
        """Concatenate ``upsampled`` with the size-matched encoder ``skip`` along channels."""
        return torch.cat((upsampled, self._match_size(skip, upsampled)), dim=1)

    def forward(self, x):
        # Contracting path
        enc1_2 = self.enc1_2(self.enc1_1(x))
        enc2_2 = self.enc2_2(self.enc2_1(self.pool1(enc1_2)))
        enc3_2 = self.enc3_2(self.enc3_1(self.pool2(enc2_2)))
        enc4_2 = self.enc4_2(self.enc4_1(self.pool3(enc3_2)))
        enc5_2 = self.enc5_2(self.enc5_1(self.pool4(enc4_2)))
        enc6_1 = self.enc6_1(self.pool5(enc5_2))  # bottleneck, always 5 x 11

        # Expansive path: upsample, merge the matching skip, then refine.
        dec5_1 = self.dec5_1(self.dec5_2(self._merge_skip(self.unpool5(self.dec6_1(enc6_1)), enc5_2)))
        dec4_1 = self.dec4_1(self.dec4_2(self._merge_skip(self.unpool4(dec5_1), enc4_2)))
        dec3_1 = self.dec3_1(self.dec3_2(self._merge_skip(self.unpool3(dec4_1), enc3_2)))
        dec2_1 = self.dec2_1(self.dec2_2(self._merge_skip(self.unpool2(dec3_1), enc2_2)))
        dec1_1 = self.dec1_1(self.dec1_2(self._merge_skip(self.unpool1(dec2_1), enc1_2)))

        return self.fc(dec1_1)


In [None]:
model = UNet().to(device)
model.apply(weights_init)
# torchsummary runs a forward pass on a random batch of this input size; the
# decoder always emits 280 x 340 regardless of it. In train mode that pass
# presumably also nudges the BatchNorm running statistics.
summary(model, input_size=(3, 35, 85))

criterion = lambda pred, target: dual_loss(pred, target, freq_mask, alpha=0.95, beta=0.05)
learning_rate = 0.0002842422846893456  # presumably from a hyper-parameter search — TODO confirm
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

# Halves the LR every 10 epochs; it only takes effect when passed to
# train_model(..., scheduler=scheduler). train_model builds its own GradScaler,
# so none is created here.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)


In [None]:
def check_nan(tensor, tensor_name="Tensor"):
    """Print whether ``tensor`` contains any NaN (report only; nothing is raised or returned)."""
    has_nan = bool(torch.isnan(tensor).any())
    status = "NaN values detected" if has_nan else "No NaN values detected"
    print(f"{status} in {tensor_name}.")

# Sanity check: the datasets apply nan_to_num, so no NaNs are expected here.
for dataset_label, dataset_obj in (
    ("train_dataset.inputs", train_dataset),
    ("val_dataset.inputs", val_dataset),
    ("test_dataset.inputs", test_dataset),
):
    check_nan(dataset_obj.inputs, dataset_label)

In [None]:
# Train the model. Pass the StepLR scheduler built with the optimizer so the
# learning rate actually decays each epoch (it was created but never used).
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=100, accumulate_steps=4, device=device, scheduler=scheduler)

# Test the model
predictions, true_labels = test_model(model, test_loader)

# Print or save the test results as needed
print("Test predictions and true labels have been computed.")

In [None]:
if DEMO_MODE:
    # The grid files exist only on the original machine. Use integer index
    # coordinates shaped like the model outputs so the plots below still render.
    grid_rows, grid_cols = true_labels.shape[-2:]
    roms_lon, roms_lat = np.meshgrid(np.arange(grid_cols), np.arange(grid_rows))
    hycom_lon, hycom_lat = roms_lon, roms_lat
else:
    hycom_grid_fn_path = '/home/user_bk/roms_ccs/ver4_rotate/data_for_cnn/HYCOM_grid_info.mat'
    with h5py.File(hycom_grid_fn_path, 'r') as mat_file:
        hycom_lon = mat_file['lon'][:]
        hycom_lat = mat_file['lat'][:]

    roms_grid_fn_path = '/home/user_bk/roms_ccs/ver4_rotate/data_for_cnn/roms_grid_info.mat'
    with h5py.File(roms_grid_fn_path, 'r') as mat_file:
        roms_lon = mat_file['lon_rho'][:]
        roms_lat = mat_file['lat_rho'][:]

def visualize_with_lat_lon(lat, lon, true_velocity, predicted_velocity, sample_idx=0, vmin=-0.5, vmax=0.5):
    """Plot ground-truth and predicted velocity for one sample side by side.

    Args:
        lat, lon: coordinate arrays passed straight to ``pcolormesh``
            (1-D [height] / [width], or 2-D [height, width]).
        true_velocity, predicted_velocity: stacks of fields indexed by sample
            along axis 0, e.g. [N, height, width]; the selected slice is squeezed.
        sample_idx: index along axis 0 of the sample to draw.
        vmin, vmax: colour limits shared by both panels (defaults as before).
    """
    panels = (
        (true_velocity[sample_idx].squeeze(), 'Ground Truth Velocity'),
        (predicted_velocity[sample_idx].squeeze(), 'Predicted Velocity'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    for ax, (field, title) in zip(axes, panels):
        mesh = ax.pcolormesh(lon, lat, field, cmap='seismic', vmin=vmin, vmax=vmax)
        ax.set_title(title)
        ax.set_xlabel('Longitude')
        ax.set_ylabel('Latitude')
        fig.colorbar(mesh, ax=ax)

    plt.tight_layout()
    plt.show()

# Plot u-velocity (channel 0) and v-velocity (channel 1). The sample index is
# clamped so the cell also runs on short test sets (e.g. 8 samples in DEMO_MODE).
sample_to_plot = min(100, len(predictions) - 1)
for channel in (0, 1):
    visualize_with_lat_lon(roms_lat, roms_lon, true_labels[:, channel, :, :], predictions[:, channel, :, :], sample_idx=sample_to_plot)


In [None]:
import matplotlib.pyplot as plt
import torch

def plot_bandpassed_comparison(pred, target, freq_mask, vmin=-0.5, vmax=0.5, cmap='RdBu_r'):
    """Show the band-passed target and prediction side by side.

    pred, target: a 2-D [H, W] field, or a stack from which the first 2-D slice
        is taken ([C, H, W] -> [0], [B, C, H, W] -> [0, 0]).
    freq_mask: [H, W] DCT-domain mask; it is moved to ``pred``'s device.

    Each field is cast to float32 and filtered as ``idct_2d(dct_2d(x) * mask)``,
    then drawn with ``imshow`` using the shared colour limits and no ticks.
    """
    def first_slice(arr):
        if arr.ndim == 4:
            return arr[0, 0]
        if arr.ndim == 3:
            return arr[0]
        return arr

    pred_field = first_slice(pred).float()
    target_field = first_slice(target).float()
    band_mask = freq_mask.to(pred_field.device).unsqueeze(0).unsqueeze(0)  # [1,1,H,W]

    def band_pass_numpy(field):
        filtered = idct_2d(dct_2d(field) * band_mask)
        return filtered.detach().cpu().squeeze().numpy()

    pred_np = band_pass_numpy(pred_field)
    target_np = band_pass_numpy(target_field)

    panels = (
        ("Ground Truth (bandpassed)", target_np),
        ("Prediction (bandpassed)", pred_np),
    )
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (title, image) in zip(axs, panels):
        im = ax.imshow(image, vmin=vmin, vmax=vmax, cmap=cmap)
        ax.set_title(title)
        plt.colorbar(im, ax=ax)
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    plt.show()

In [None]:
# Test sample to inspect (u-velocity channel). Clamped so the cell also runs on
# short test sets such as DEMO_MODE's 8 samples.
tt = min(500, len(predictions) - 1)
pred = torch.tensor(predictions[tt, 0, :, :])
target = torch.tensor(true_labels[tt, 0, :, :])

plot_bandpassed_comparison(pred, target, freq_mask)

In [None]:
from scipy.io import savemat

# Written with h5py, i.e. an HDF5 file (the container format of MATLAB v7.3 .mat files).
with h5py.File('UNet_test_250411_except_MLD_temp_sorted_real_5km_20km_optimized_output.mat', 'w') as out_file:
    for dataset_name, values in (('true_labels', true_labels), ('predictions', predictions)):
        out_file.create_dataset(dataset_name, data=values)


In [None]:
# Store the ROMS coordinates next to the predictions for later plotting.
with h5py.File('UNet_test_250411_except_MLD_temp_sorted_real_5km_20km_optimized_roms_grid_info.mat', 'w') as grid_file:
    for dataset_name, values in (('roms_lon', roms_lon), ('roms_lat', roms_lat)):
        grid_file.create_dataset(dataset_name, data=values)

In [None]:
# Save both the whole pickled module (loading it needs the UNet class to be
# importable) and the more portable state_dict.
for checkpoint_path, payload in (
    ('UNet_test_250411_except_MLD_temp_sorted_real_5km_20km_optimized.pth', model),
    ('UNet_test_250411_except_MLD_temp_sorted_real_5km_20km_optimized_state_dict.pt', model.state_dict()),
):
    torch.save(payload, checkpoint_path)