In [None]:
from ftplib import FTP
import numpy as np
from sklearn import datasets
import netCDF4
import time
import matplotlib.pyplot as plt
from pydmd import DMD, CDMD, RDMD

In [None]:
import warnings
warnings.filterwarnings("ignore")

## Download SST data

In [None]:
# # Import and save data locally
# ftp = FTP('ftp.cdc.noaa.gov')
# ftp.login()
# ftp.cwd('/Datasets/noaa.oisst.v2/')

# filenames = ['sst.wkmean.1990-present.nc', 'lsmask.nc']

# for filename in filenames:
#     localfile = open(filename, 'wb')
#     ftp.retrbinary('RETR ' + filename, localfile.write, 1024)
#     localfile.close()

# ftp.quit();

In [None]:
# Load the weekly SST record and the land/sea mask into memory. Each netCDF
# file is closed as soon as its variables have been read (previously the first
# handle was overwritten without being closed and the second was never closed).
with netCDF4.Dataset('sst.wkmean.1990-present.nc') as f:
    lat, lon = f.variables['lat'][:], f.variables['lon'][:]
    sst = f.variables['sst'][:]

with netCDF4.Dataset('lsmask.nc') as f:
    # Materialise the mask as a plain ndarray before the file closes; this is
    # the same data np.squeeze used to pull from the lazy netCDF variable.
    mask = np.asarray(f.variables['mask'][:])

In [None]:
# Show the SST field at one time step on the grid selected by the mask.
time_index = 0

masks = np.bool_(np.squeeze(mask))
# Size the canvas from the mask instead of hard-coding a 180x360 grid; cells
# where the mask is False stay NaN (presumably land — confirm against
# lsmask.nc) and are left blank by imshow.
snapshot = np.full(masks.shape, np.nan)
snapshot[masks] = sst[time_index, masks]

plt.imshow(snapshot, cmap=plt.cm.coolwarm)
plt.xticks([])
plt.yticks([])
plt.title('First snapshot of SST')
plt.show()

In [None]:
def compute_error(true, est):
    """
    Relative error of an estimate with respect to a reference.

    Returns ``||true - est|| / ||true||`` using the default norm of
    ``np.linalg.norm`` (Frobenius for 2-D inputs, Euclidean for 1-D inputs).
    """
    residual_norm = np.linalg.norm(true - est)
    reference_norm = np.linalg.norm(true)
    return residual_norm / reference_norm

## Sketching

In [None]:
from scipy.linalg import hadamard

def generate_countsketch_matrix(m, r, seed=None):
    """
    Build an (m, r) CountSketch test matrix for right-multiplying a matrix
    that has m columns.

    Every one of the m rows receives exactly one nonzero entry, -1 or +1, in a
    uniformly random column. ``A @ S`` therefore hashes each column of ``A``
    into one of r buckets with a random sign and sums every bucket, which is
    the CountSketch transform. (Placing one nonzero per *column* instead would
    merely sample r signed columns of ``A``, possibly repeating some.)

    Parameters
    ----------
    m : int
        Number of rows; must equal the number of columns of the sketched matrix.
    r : int
        Sketch size (number of columns / hash buckets).
    seed : int or None
        Seed for ``np.random.default_rng``; ``None`` draws fresh entropy.

    Returns
    -------
    numpy.ndarray
        Float array of shape (m, r).
    """
    rng = np.random.default_rng(seed)
    S = np.zeros((m, r))
    if m == 0 or r == 0:
        # Nothing to hash, or no bucket to hash into.
        return S
    buckets = rng.integers(0, r, size=m)
    signs = rng.choice([-1, 1], size=m)
    S[np.arange(m), buckets] = signs
    return S

def generate_shrt_matrix(m, r, seed=None):
    """
    Build an (m, r) subsampled randomized Hadamard test matrix.

    With p = 2**ceil(log2 m), the result is ``(R H D)[:, :m].T`` where ``D`` is
    a random +/-1 diagonal of size p, ``H`` is the unnormalised p x p Hadamard
    matrix from ``scipy.linalg.hadamard``, and ``R`` picks r of its rows
    (with replacement only when r > p). Keeping the first m columns amounts to
    zero-padding a length-m input up to length p.

    Parameters
    ----------
    m : int
        Number of rows of the returned test matrix (must be >= 1).
    r : int
        Number of columns (sampled Hadamard rows).
    seed : int or None
        Seed for ``np.random.default_rng``; ``None`` draws fresh entropy.

    Returns
    -------
    numpy.ndarray
        Integer array of shape (m, r) with entries +/-1.

    Raises
    ------
    ValueError
        If m < 1.
    """
    if m < 1:
        raise ValueError("m must be a positive integer")
    rng = np.random.default_rng(seed)
    p = 1 << int(np.ceil(np.log2(m)))
    # Draw the signs before the row indices to keep the same RNG order (and so
    # the same output for a given seed) as the dense-matrix formulation.
    signs = rng.choice([-1, 1], size=p)
    indices = rng.choice(p, size=r, replace=p < r)
    # H @ diag(signs) only rescales columns, so broadcasting replaces an
    # O(p^3) integer matmul; only the selected rows and first m columns are used.
    return (hadamard(p)[indices, :m] * signs[:m]).T

In [None]:
# Keep only grid cells where `masks` is True: rows are time steps, columns are the selected cells.
X = sst[..., masks]
X.shape

In [None]:
snapshots_matrix = X.transpose()  # one column per time step, one row per selected grid cell

## Randomized DMD 
### (1) Varying the target rank (number of sampling rows fixed at `r`)

In [None]:
num_trials = 20                       # repetitions of every configuration
m = np.shape(snapshots_matrix)[1]     # number of snapshots; test matrices are (m, r)
r = 50                                # number of sampling rows, fixed in experiment (1)
rank_values = [10, 20, 50, 100, 200]  # target ranks swept in experiment (1)

In [None]:
# Exact DMD baseline for each target rank. No test matrix is drawn here, so the
# repeated trials presumably differ only in timing.
dmd_oversampling_error = np.zeros((len(rank_values), num_trials))
dmd_oversampling_times = np.zeros((len(rank_values), num_trials))

for i, rank in enumerate(rank_values):
    for j in range(num_trials):
        # perf_counter is monotonic and high-resolution, unlike time.time.
        t0 = time.perf_counter()
        dmd = DMD(svd_rank=rank, exact=True)
        dmd.fit(snapshots_matrix)
        t1 = time.perf_counter()

        dmd_oversampling_error[i, j] = compute_error(
            snapshots_matrix, dmd.reconstructed_data
        )
        dmd_oversampling_times[i, j] = t1 - t0

In [None]:
# RDMD with a Gaussian sketch: one (m, r) standard-normal test matrix per trial.
# Only the fit is timed; building the test matrix is excluded.
# NOTE(review): the sketch has only r columns, so for rank > r the effective
# rank is presumably capped at r inside RDMD — confirm against pydmd.
oversampling_error = np.zeros((len(rank_values), num_trials))
oversampling_times = np.zeros((len(rank_values), num_trials))

for i, rank in enumerate(rank_values):
    for j in range(num_trials):
        # Seed by trial index so re-running the notebook reproduces the sketches.
        test_matrix = np.random.default_rng(j).standard_normal((m, r))
        t0 = time.perf_counter()  # monotonic, high-resolution timer
        rdmd = RDMD(svd_rank=rank, test_matrix=test_matrix).fit(
            snapshots_matrix
        )
        t1 = time.perf_counter()

        oversampling_error[i, j] = compute_error(
            snapshots_matrix, rdmd.reconstructed_data
        )
        oversampling_times[i, j] = t1 - t0

In [None]:
# RDMD with a CountSketch test matrix of shape (m, r), swept over target ranks.
# Only the fit is timed; building the test matrix is excluded.
oversampling_error_count = np.zeros((len(rank_values), num_trials))
oversampling_times_count = np.zeros((len(rank_values), num_trials))

# The loop must bind `rank`: RDMD below reads it as the target rank (binding a
# different name left svd_rank stuck at whatever `rank` the previous cell left).
for i, rank in enumerate(rank_values):
    for j in range(num_trials):
        # Seed by trial index so re-running the notebook reproduces the sketches.
        test_matrix = generate_countsketch_matrix(m, r, seed=j)
        t0 = time.perf_counter()  # monotonic, high-resolution timer
        rdmd_count = RDMD(svd_rank=rank, test_matrix=test_matrix).fit(
            snapshots_matrix
        )
        t1 = time.perf_counter()

        oversampling_error_count[i, j] = compute_error(
            snapshots_matrix, rdmd_count.reconstructed_data
        )
        oversampling_times_count[i, j] = t1 - t0

In [None]:
# RDMD with a subsampled randomized Hadamard (SHRT) test matrix of shape (m, r),
# swept over target ranks. Only the fit is timed.
oversampling_error_shrt = np.zeros((len(rank_values), num_trials))
oversampling_times_shrt = np.zeros((len(rank_values), num_trials))

# The loop must bind `rank`: RDMD below reads it as the target rank (binding a
# different name left svd_rank stuck at whatever `rank` the previous cell left).
for i, rank in enumerate(rank_values):
    for j in range(num_trials):
        # Seed by trial index so re-running the notebook reproduces the sketches.
        test_matrix = generate_shrt_matrix(m, r, seed=j)
        t0 = time.perf_counter()  # monotonic, high-resolution timer
        rdmd_shrt = RDMD(svd_rank=rank, test_matrix=test_matrix).fit(
            snapshots_matrix
        )
        t1 = time.perf_counter()

        oversampling_error_shrt[i, j] = compute_error(
            snapshots_matrix, rdmd_shrt.reconstructed_data
        )
        oversampling_times_shrt[i, j] = t1 - t0

In [None]:
# Experiment (1): mean +/- std over trials of error and fit time versus target
# rank. Each entry is (errors, times, colour, legend label), drawn in this order.
rank_sweep_series = [
    (dmd_oversampling_error, dmd_oversampling_times, "r", "Exact DMD"),
    (oversampling_error, oversampling_times, "g", "RDMD_gaussian"),
    (oversampling_error_count, oversampling_times_count, "orange", "RDMD_count"),
    (oversampling_error_shrt, oversampling_times_shrt, "k", "RDMD_shrt"),
]
# (index into each series tuple, panel title, y-axis label)
rank_sweep_panels = [
    (0, "Reconstruction Error", "Relative Error"),
    (1, "Training Time", "Runtime (s)"),
]

plt.figure(figsize=(8, 3))
for panel, panel_title, panel_ylabel in rank_sweep_panels:
    plt.subplot(1, 2, panel + 1)
    for series in rank_sweep_series:
        trials = series[panel]
        plt.errorbar(rank_values, np.mean(trials, axis=1), yerr=np.std(trials, axis=1), fmt='-o', c=series[2], label=series[3])
    plt.title(panel_title)
    plt.xlabel("Target Rank")
    plt.ylabel(panel_ylabel)
    plt.legend()
plt.tight_layout()
plt.show()

### (2) Varying the number of sampling rows (target rank fixed at 50)

In [None]:
rank = 50
# Numbers of sampling rows to sweep: rank/2, rank, rank*ln(rank) and rank^2.
sampling_values = [rank // 2, rank, int(rank * np.log(rank)), rank ** 2]

In [None]:
# Exact DMD baseline at the fixed target rank. Results are stored with shape
# (1, num_trials), matching the (n_settings, num_trials) layout of the other
# result arrays.
dmd_error = np.zeros((1, num_trials))
dmd_times = np.zeros((1, num_trials))

for j in range(num_trials):
    # perf_counter is monotonic and high-resolution, unlike time.time.
    t0 = time.perf_counter()
    dmd = DMD(svd_rank=rank, exact=True)
    dmd.fit(snapshots_matrix)
    t1 = time.perf_counter()

    dmd_error[0, j] = compute_error(snapshots_matrix, dmd.reconstructed_data)
    dmd_times[0, j] = t1 - t0

In [None]:
# RDMD with a Gaussian sketch at the fixed target rank, swept over the number
# of sampling rows. Only the fit is timed.
oversampling_error = np.zeros((len(sampling_values), num_trials))
oversampling_times = np.zeros((len(sampling_values), num_trials))

# Use a dedicated loop name so the global `r` from the config cell is not
# silently overwritten with the last sweep value.
for i, n_rows in enumerate(sampling_values):
    for j in range(num_trials):
        # Seed by trial index so re-running the notebook reproduces the sketches.
        test_matrix = np.random.default_rng(j).standard_normal((m, n_rows))
        t0 = time.perf_counter()  # monotonic, high-resolution timer
        rdmd = RDMD(svd_rank=rank, test_matrix=test_matrix).fit(
            snapshots_matrix
        )
        t1 = time.perf_counter()

        oversampling_error[i, j] = compute_error(
            snapshots_matrix, rdmd.reconstructed_data
        )
        oversampling_times[i, j] = t1 - t0

In [None]:
# RDMD with a CountSketch test matrix at the fixed target rank, swept over the
# number of sampling rows. Only the fit is timed.
oversampling_error_count = np.zeros((len(sampling_values), num_trials))
oversampling_times_count = np.zeros((len(sampling_values), num_trials))

# Use a dedicated loop name so the global `r` from the config cell is not
# silently overwritten with the last sweep value.
for i, n_rows in enumerate(sampling_values):
    for j in range(num_trials):
        # Seed by trial index so re-running the notebook reproduces the sketches.
        test_matrix = generate_countsketch_matrix(m, n_rows, seed=j)
        t0 = time.perf_counter()  # monotonic, high-resolution timer
        rdmd_count = RDMD(svd_rank=rank, test_matrix=test_matrix).fit(
            snapshots_matrix
        )
        t1 = time.perf_counter()

        oversampling_error_count[i, j] = compute_error(
            snapshots_matrix, rdmd_count.reconstructed_data
        )
        oversampling_times_count[i, j] = t1 - t0

In [None]:
# RDMD with an SHRT test matrix at the fixed target rank, swept over the number
# of sampling rows. This must iterate over `sampling_values` (not `rank_values`)
# so the sketch sizes, and the result arrays' lengths, match the x-axis values
# used when plotting experiment (2).
oversampling_error_shrt = np.zeros((len(sampling_values), num_trials))
oversampling_times_shrt = np.zeros((len(sampling_values), num_trials))

for i, n_rows in enumerate(sampling_values):
    for j in range(num_trials):
        # Seed by trial index so re-running the notebook reproduces the sketches.
        test_matrix = generate_shrt_matrix(m, n_rows, seed=j)
        t0 = time.perf_counter()  # monotonic, high-resolution timer
        rdmd_shrt = RDMD(svd_rank=rank, test_matrix=test_matrix).fit(
            snapshots_matrix
        )
        t1 = time.perf_counter()

        oversampling_error_shrt[i, j] = compute_error(
            snapshots_matrix, rdmd_shrt.reconstructed_data
        )
        oversampling_times_shrt[i, j] = t1 - t0

In [None]:
# Experiment (2): mean +/- std over trials of error and fit time versus the
# number of sampling rows at a fixed target rank. Exact DMD uses no sketch, so
# it is drawn as a horizontal reference line at its mean over trials.
row_sweep_series = [
    (oversampling_error, oversampling_times, "g", "RDMD_gaussian"),
    (oversampling_error_count, oversampling_times_count, "orange", "RDMD_count"),
    (oversampling_error_shrt, oversampling_times_shrt, "k", "RDMD_shrt"),
]
# (index into each series tuple, exact-DMD baseline, panel title, y-axis label)
row_sweep_panels = [
    (0, dmd_error, "Reconstruction Error", "Relative Error"),
    (1, dmd_times, "Training Time", "Runtime (s)"),
]

plt.figure(figsize=(8, 3))
for panel, baseline, panel_title, panel_ylabel in row_sweep_panels:
    plt.subplot(1, 2, panel + 1)
    for series in row_sweep_series:
        trials = series[panel]
        plt.errorbar(sampling_values, np.mean(trials, axis=1), yerr=np.std(trials, axis=1), fmt='-o', c=series[2], label=series[3])
    # np.mean over the whole array yields a scalar, which axhline expects for y
    # (a 1-element array relies on implicit array-to-scalar conversion).
    plt.axhline(y=np.mean(baseline), c="r", label="Exact DMD")
    plt.title(panel_title)
    plt.xlabel("# of sampling rows")
    plt.ylabel(panel_ylabel)
    # The sweep values span two decades; use a log axis on both panels.
    plt.xscale("log")
    plt.legend()
plt.tight_layout()
plt.show()