In [None]:
from ftplib import FTP

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
import netCDF4


In [None]:
import warnings

warnings.filterwarnings("ignore")

import time
import numpy as np
import matplotlib.pyplot as plt

from pydmd import DMD, CDMD, RDMD

In [None]:
# # Import and save data locally
# ftp = FTP('ftp.cdc.noaa.gov')
# ftp.login()
# ftp.cwd('/Datasets/noaa.oisst.v2/')

# filenames = ['sst.wkmean.1990-present.nc', 'lsmask.nc']

# for filename in filenames:
#     localfile = open(filename, 'wb')
#     ftp.retrbinary('RETR ' + filename, localfile.write, 1024)
#     localfile.close()

# ftp.quit();

In [None]:
# Input files (NOAA OISST v2, fetched by the commented-out FTP cell above).
SST_PATH = 'sst.wkmean.1990-present.nc'
MASK_PATH = 'lsmask.nc'

# Read everything needed into memory inside `with` blocks so both files are
# closed afterwards instead of leaking open dataset handles.
with netCDF4.Dataset(SST_PATH) as ds:
    lat = ds.variables['lat'][:]
    lon = ds.variables['lon'][:]
    # First axis is time (indexed as sst[t, masks] below); the remaining axes
    # match the 2-D mask grid.
    sst = ds.variables['sst'][:]

with netCDF4.Dataset(MASK_PATH) as ds:
    # Squeezed and cast to bool in the next cell.
    mask = ds.variables['mask'][:]

In [None]:
# Boolean grid mask: True where SST values are kept (presumably ocean points,
# per lsmask.nc — confirm). Squeeze drops any singleton (e.g. time) axis.
masks = np.asarray(np.squeeze(mask)).astype(bool)

# Index of the snapshot shown below; 0 matches the "First snapshot" title.
SNAPSHOT_INDEX = 0

# Scatter the kept values back onto a NaN-filled grid so masked cells render blank.
snapshot = np.full(masks.shape, np.nan)
snapshot[masks] = sst[SNAPSHOT_INDEX, masks]

plt.imshow(snapshot, cmap=plt.cm.coolwarm)
plt.xticks([])
plt.yticks([])
plt.title('First snapshot of SST')

# Data matrix: one row per time step, one column per kept grid point.
# NOTE(review): sst may be a masked array; X keeps that type.
X = sst[:, masks]

In [None]:
def compute_error(true, est):
    """
    Relative reconstruction error ``||true - est|| / ||true||``.

    ``np.linalg.norm`` is used with its defaults: Frobenius norm for 2-D
    inputs, Euclidean norm for 1-D inputs.
    """
    residual_norm = np.linalg.norm(true - est)
    reference_norm = np.linalg.norm(true)
    return residual_norm / reference_norm

from scipy.linalg import hadamard

def generate_countsketch_matrix(m, r, seed=None):
    """
    Build an (m, r) CountSketch test matrix for right-multiplication ``X @ S``.

    Each of the ``m`` rows gets exactly one nonzero entry: a random sign
    (+1 or -1) in a uniformly random column ("bucket") out of ``r``. Each
    column of ``X`` is therefore hashed, with a random sign, into one of the
    ``r`` columns of ``X @ S``, so every snapshot contributes to the sketch.

    Parameters
    ----------
    m : int
        Number of rows, i.e. number of columns of the matrix being sketched.
    r : int
        Sketch size (number of buckets / columns of the result).
    seed : int or None, optional
        Seed passed to ``numpy.random.default_rng``.

    Returns
    -------
    numpy.ndarray
        Float array of shape (m, r) with exactly one +/-1 per row
        (all zeros when ``r == 0``).
    """
    S = np.zeros((m, r))
    if m == 0 or r == 0:
        # Nothing to hash (rng.integers would reject an empty bucket range).
        return S
    rng = np.random.default_rng(seed)
    buckets = rng.integers(0, r, size=m)
    signs = rng.choice([-1.0, 1.0], size=m)
    S[np.arange(m), buckets] = signs
    return S

def generate_shrt_matrix(m, r, seed=None):
    """
    Build an (m, r) subsampled randomized Hadamard transform test matrix.

    The rows of ``H @ diag(signs)`` are sampled, where ``H`` is the
    ``p x p`` Hadamard matrix (``p`` = next power of two >= ``m``) and
    ``signs`` is a random +/-1 vector. Only the first ``m`` columns are kept
    (the rest correspond to zero padding), and the result is transposed to
    shape (m, r). Rows are sampled without replacement unless ``r > p``.
    No 1/sqrt(r) normalization is applied.

    Parameters
    ----------
    m : int
        Number of rows of the returned matrix.
    r : int
        Number of columns (sketch size).
    seed : int or None, optional
        Seed passed to ``numpy.random.default_rng``.

    Returns
    -------
    numpy.ndarray
        Integer array of shape (m, r) with entries +/-1.
    """
    rng = np.random.default_rng(seed)
    # Smallest power of two >= m (exact integer arithmetic; m <= 1 gives 1).
    p = 1 << max(int(m) - 1, 0).bit_length()
    H = hadamard(p)
    # Same RNG draw order as before: signs first, then row indices.
    signs = rng.choice([-1, 1], size=p)
    indices = rng.choice(p, size=r, replace=p < r)
    # H @ diag(signs) only flips column signs, so scale the sampled rows by
    # broadcasting instead of forming a dense O(p^3) matrix product.
    sketch = H[indices, :m] * signs[:m]
    return sketch.T

In [None]:
# (R)DMD snapshot matrix: one column per time step, one row per kept grid point.
snapshots_matrix = np.transpose(X)

In [None]:
# Benchmark configuration for the reconstruction comparison.
rank = 250                           # target rank for every (R)DMD fit
oversampling = [0]                   # extra sketch columns beyond `rank`
num_trials = 1                       # repetitions per oversampling value
m = np.shape(snapshots_matrix)[1]    # number of snapshots (columns)

In [None]:
# Exact DMD baseline. Oversampling does not affect exact DMD; the loop over
# `oversampling` only keeps the result arrays shaped like the RDMD ones.
dmd_oversampling_error = np.zeros((len(oversampling), num_trials))
dmd_oversampling_times = np.zeros((len(oversampling), num_trials))

for i, p in enumerate(oversampling):
    for j in range(num_trials):
        # perf_counter is the monotonic, high-resolution clock for durations.
        t0 = time.perf_counter()
        dmd = DMD(svd_rank=rank, exact=True)
        dmd.fit(snapshots_matrix)
        t1 = time.perf_counter()

        dmd_oversampling_error[i, j] = compute_error(
            snapshots_matrix, dmd.reconstructed_data
        )
        dmd_oversampling_times[i, j] = t1 - t0

In [None]:
# RDMD with a Gaussian test matrix of shape (m, rank + p).
oversampling_error = np.zeros((len(oversampling), num_trials))
oversampling_times = np.zeros((len(oversampling), num_trials))

for i, p in enumerate(oversampling):
    for j in range(num_trials):
        # Seed per trial so the sketch (and the reported error) is reproducible.
        sketch_rng = np.random.default_rng(j)
        test_matrix = sketch_rng.standard_normal((m, rank + p))

        # Only the fit is timed; generating the test matrix is excluded.
        t0 = time.perf_counter()
        rdmd = RDMD(svd_rank=rank, test_matrix=test_matrix).fit(snapshots_matrix)
        t1 = time.perf_counter()

        oversampling_error[i, j] = compute_error(
            snapshots_matrix, rdmd.reconstructed_data
        )
        oversampling_times[i, j] = t1 - t0

In [None]:
# RDMD with a CountSketch test matrix of shape (m, rank + p).
oversampling_error_count = np.zeros((len(oversampling), num_trials))
oversampling_times_count = np.zeros((len(oversampling), num_trials))

for i, p in enumerate(oversampling):
    for j in range(num_trials):
        # Seed per trial so the sketch (and the reported error) is reproducible.
        test_matrix = generate_countsketch_matrix(m, rank + p, seed=j)

        # Only the fit is timed; generating the test matrix is excluded.
        t0 = time.perf_counter()
        rdmd_count = RDMD(svd_rank=rank, test_matrix=test_matrix).fit(
            snapshots_matrix
        )
        t1 = time.perf_counter()

        oversampling_error_count[i, j] = compute_error(
            snapshots_matrix, rdmd_count.reconstructed_data
        )
        oversampling_times_count[i, j] = t1 - t0

In [None]:
# RDMD with a subsampled randomized Hadamard test matrix of shape (m, rank + p).
oversampling_error_shrt = np.zeros((len(oversampling), num_trials))
oversampling_times_shrt = np.zeros((len(oversampling), num_trials))

for i, p in enumerate(oversampling):
    for j in range(num_trials):
        # Seed per trial so the sketch (and the reported error) is reproducible.
        test_matrix = generate_shrt_matrix(m, rank + p, seed=j)

        # Only the fit is timed; generating the test matrix is excluded.
        t0 = time.perf_counter()
        rdmd_shrt = RDMD(svd_rank=rank, test_matrix=test_matrix).fit(
            snapshots_matrix
        )
        t1 = time.perf_counter()

        oversampling_error_shrt[i, j] = compute_error(
            snapshots_matrix, rdmd_shrt.reconstructed_data
        )
        oversampling_times_shrt[i, j] = t1 - t0

In [None]:
from pathlib import Path

# Render one comparison figure per snapshot for the first N_FRAMES time steps.
N_FRAMES = 52
OUTPUT_DIR = Path('./sst_image')
# savefig does not create directories, so make sure the target exists.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Shared colour limits for the SST panels. These are loop-invariant and scan
# the whole data set, so compute them once instead of once per frame.
max_val = np.max(sst[:, masks])
min_val = np.min(sst[:, masks])
# Fixed symmetric colour range for the error panels (same units as sst).
max_abs_val = 5

# (panel title, fitted model) — one figure row per model.
models_to_plot = [
    ("Exact DMD", dmd),
    ("rDMD", rdmd),
    ("rDMD (count)", rdmd_count),
    ("rDMD (SHRT)", rdmd_shrt),
]


def _to_grid(values):
    """Scatter values for the kept grid points onto a NaN-filled grid."""
    grid = np.full(masks.shape, np.nan)
    grid[masks] = values
    return grid


for time_point_index in range(N_FRAMES):
    snapshot_true = _to_grid(sst[time_point_index, masks])

    fig, axes = plt.subplots(len(models_to_plot), 2, figsize=(20, 18))
    for (title, model), (ax_rec, ax_err) in zip(models_to_plot, axes):
        reconstructed = _to_grid(model.reconstructed_data.real[:, time_point_index])

        # Left column: reconstruction on the same colour scale as the data.
        im = ax_rec.imshow(reconstructed, cmap=plt.cm.coolwarm, vmin=min_val, vmax=max_val)
        ax_rec.set_title(title)
        fig.colorbar(im, ax=ax_rec)

        # Right column: pointwise error (truth - reconstruction).
        im = ax_err.imshow(
            snapshot_true - reconstructed,
            cmap=plt.cm.coolwarm, vmin=-max_abs_val, vmax=max_abs_val,
        )
        ax_err.set_title("Reconstruction error")
        fig.colorbar(im, ax=ax_err)

    fig.tight_layout()
    fig.suptitle(f'Sea Surface Temperature, Week = {time_point_index}', fontsize=20, y=1.02)
    fig.savefig(OUTPUT_DIR / f'img_week{time_point_index:02d}.jpg', bbox_inches='tight', dpi=200)
    # Close explicitly so 52 figures do not accumulate in memory.
    plt.close(fig)

## Spatial modes of the rank-7 DMD and RDMD models

In [None]:
# Benchmark configuration for the low-rank mode comparison.
rank = 7                             # target rank for every (R)DMD fit
oversampling = [0]                   # extra sketch columns beyond `rank`
num_trials = 1                       # repetitions per oversampling value
m = np.shape(snapshots_matrix)[1]    # number of snapshots (columns)

In [None]:
# Exact DMD baseline. Oversampling does not affect exact DMD; the loop over
# `oversampling` only keeps the result arrays shaped like the RDMD ones.
dmd_oversampling_error = np.zeros((len(oversampling), num_trials))
dmd_oversampling_times = np.zeros((len(oversampling), num_trials))

for i, p in enumerate(oversampling):
    for j in range(num_trials):
        # perf_counter is the monotonic, high-resolution clock for durations.
        t0 = time.perf_counter()
        dmd = DMD(svd_rank=rank, exact=True)
        dmd.fit(snapshots_matrix)
        t1 = time.perf_counter()

        dmd_oversampling_error[i, j] = compute_error(
            snapshots_matrix, dmd.reconstructed_data
        )
        dmd_oversampling_times[i, j] = t1 - t0

In [None]:
# RDMD with a Gaussian test matrix of shape (m, rank + p).
oversampling_error = np.zeros((len(oversampling), num_trials))
oversampling_times = np.zeros((len(oversampling), num_trials))

for i, p in enumerate(oversampling):
    for j in range(num_trials):
        # Seed per trial so the sketch (and the reported error) is reproducible.
        sketch_rng = np.random.default_rng(j)
        test_matrix = sketch_rng.standard_normal((m, rank + p))

        # Only the fit is timed; generating the test matrix is excluded.
        t0 = time.perf_counter()
        rdmd = RDMD(svd_rank=rank, test_matrix=test_matrix).fit(snapshots_matrix)
        t1 = time.perf_counter()

        oversampling_error[i, j] = compute_error(
            snapshots_matrix, rdmd.reconstructed_data
        )
        oversampling_times[i, j] = t1 - t0

In [None]:
# RDMD with a CountSketch test matrix of shape (m, rank + p).
oversampling_error_count = np.zeros((len(oversampling), num_trials))
oversampling_times_count = np.zeros((len(oversampling), num_trials))

for i, p in enumerate(oversampling):
    for j in range(num_trials):
        # Seed per trial so the sketch (and the reported error) is reproducible.
        test_matrix = generate_countsketch_matrix(m, rank + p, seed=j)

        # Only the fit is timed; generating the test matrix is excluded.
        t0 = time.perf_counter()
        rdmd_count = RDMD(svd_rank=rank, test_matrix=test_matrix).fit(
            snapshots_matrix
        )
        t1 = time.perf_counter()

        oversampling_error_count[i, j] = compute_error(
            snapshots_matrix, rdmd_count.reconstructed_data
        )
        oversampling_times_count[i, j] = t1 - t0

In [None]:
# RDMD with a subsampled randomized Hadamard test matrix of shape (m, rank + p).
oversampling_error_shrt = np.zeros((len(oversampling), num_trials))
oversampling_times_shrt = np.zeros((len(oversampling), num_trials))

for i, p in enumerate(oversampling):
    for j in range(num_trials):
        # Seed per trial so the sketch (and the reported error) is reproducible.
        test_matrix = generate_shrt_matrix(m, rank + p, seed=j)

        # Only the fit is timed; generating the test matrix is excluded.
        t0 = time.perf_counter()
        rdmd_shrt = RDMD(svd_rank=rank, test_matrix=test_matrix).fit(
            snapshots_matrix
        )
        t1 = time.perf_counter()

        oversampling_error_shrt[i, j] = compute_error(
            snapshots_matrix, rdmd_shrt.reconstructed_data
        )
        oversampling_times_shrt[i, j] = t1 - t0

In [None]:
# NaN-filled grids that receive |mode| values at the kept grid points.
mode_dmd = np.full(masks.shape, np.nan)
mode_rdmd = np.full(masks.shape, np.nan)
mode_rdmd_count = np.full(masks.shape, np.nan)
mode_rdmd_shrt = np.full(masks.shape, np.nan)

# Zero-based column indices into each model's `.modes`; titles show them 1-based.
# NOTE(review): mode ordering may differ between DMD variants — confirm before
# comparing panels within a column.
mode_numbers = [0, 4, 6]

# (title prefix, fitted model, grid buffer) — one figure row per model.
mode_rows = [
    ("DMD", dmd, mode_dmd),
    ("RDMD", rdmd, mode_rdmd),
    ("RDMD (count)", rdmd_count, mode_rdmd_count),
    ("RDMD (shrt)", rdmd_shrt, mode_rdmd_shrt),
]

# squeeze=False keeps `axs` 2-D even if only one mode is requested.
fig, axs = plt.subplots(
    nrows=len(mode_rows), ncols=len(mode_numbers), figsize=(15, 12), squeeze=False
)

for row, (label, model, grid) in enumerate(mode_rows):
    for col, mode_number in enumerate(mode_numbers):
        grid[masks] = np.abs(model.modes[:, mode_number])
        ax = axs[row, col]
        ax.imshow(grid, cmap=plt.cm.coolwarm)
        ax.set_title(f'{label} Mode Number: {mode_number + 1}')
        ax.axis('off')

plt.tight_layout()
plt.show()