# Malaspina stochastic matrix factorization

This method is based on the paper found here: https://proceedings.neurips.cc/paper/2007/file/d7322ed717dedf1eb4e6e52a37ea7bcd-Paper.pdf

First, import the data cube and randomly draw up to 1000 velocity fields (favoring fields with more valid pixels), sorted by date:

In [None]:
import netCDF4 as ncf
import numpy as np
import matplotlib.pyplot as plt
import torch
import pickle

fn = './MalaspinaGlacierCube_32607.nc'
data = ncf.Dataset(fn)

# Maximum number of velocity fields to draw from the cube
max_fields = 1000

# Draw probability is proportional to valid_fraction**temperature, so fields
# with a higher number of non-masked entries are more probable
temperature = 10.0

# Precomputed non-masked fraction, one entry per field index of the cube
np.random.seed(0)
with open('valid_fraction.p', 'rb') as f:
    valid_fraction = np.array(pickle.load(f))

draw_p = valid_fraction**temperature
draw_p /= draw_p.sum()

# Sampling without replacement can only return entries with nonzero probability
# (np.random.choice raises ValueError otherwise), so cap the sample size there.
n_draw = min(max_fields, np.count_nonzero(draw_p))
inds = np.random.choice(len(valid_fraction), size=n_draw, p=draw_p, replace=False)

# Sort the drawn indices by date
inds = inds[np.argsort(data.variables['date_center'][inds])]

# Working with magnitudes; this is the variable handle, indexed in the next cell
U = data.variables['v']

This next line actually reads the selected velocity fields from disk into a (masked) numpy array.

In [None]:
# Index the netCDF variable with the date-sorted indices; this is where the
# selected velocity fields are actually read into memory.
U_inds = U[inds]

Compute the time steps between consecutive selected fields.

In [None]:
# Date of each selected field, in the same (date-sorted) order as inds
dates = data.variables['date_center'][inds]
# Time step between consecutive selected fields; length is len(dates) - 1
dt = dates[1:] - dates[:-1]

Load a glacier outline, reproject it into the data cube's coordinate system, and rasterize it onto the grid for use as a mask.

In [None]:
# Grid coordinates of the cropped study area. These crops must match the
# [:, :300, 50:300] crop applied to the velocity fields in the next cell.
xs = data.variables['x'][50:300]
ys = data.variables['y'][:300]

# Derive the cube dimensions from the loaded fields and the coordinate crops.
# U_trunc cannot be used here: it is only created in the next cell, which
# itself needs glacier_mask from this one.
n_t, n_row, n_col = U_inds.shape[0], len(ys), len(xs)

with open('boundary.p', 'rb') as f:
    boundary_points = pickle.load(f)

Xs, Ys = np.meshgrid(xs, ys)
X = np.vstack((Xs.ravel(), Ys.ravel())).T

import pyproj
import matplotlib.path as path
# Reproject the outline from EPSG:3338 into the cube's EPSG:32607 coordinates
tform = pyproj.Transformer.from_crs(crs_from=3338, crs_to=32607, always_xy=True)
fx, fy = tform.transform(boundary_points[:, 0], boundary_points[:, 1])

# Rasterize: a grid point is "on ice" if it lies inside the outline polygon
p = path.Path(np.vstack((fx, fy)).T)
glacier_mask = p.contains_points(X)

# Quick visual check of the on-ice grid points
plt.scatter(*X[glacier_mask].T)

# Repeat the per-pixel mask for every selected field -> (n_t, n_row, n_col)
glacier_mask = np.tile(glacier_mask.reshape(n_row, n_col), (n_t, 1, 1))
# Pixels x time layout, matching the flattened velocity matrix Uf
mask_f = glacier_mask.reshape(n_t, -1).T

Compute the median of the off-ice but still valid pixels for each image, and subtract it from that image's valid pixels (a per-image bias correction).

In [None]:
# Truncate to a slightly smaller area to avoid weirdness in the ocean
U_trunc = U_inds.filled(-1).astype(float)[:,:300,50:300]

off_ice_means = np.zeros(U_trunc.shape[0])
for i in range(U_trunc.shape[0]):
    
    off_ice_means[i] = np.median(U_trunc[i][(U_trunc[i]!=-1) & (glacier_mask[i]==0)])
    U_trunc[i][U_trunc[i]!=-1] -= off_ice_means[i]
    
Uf = (U_trunc).reshape(U_trunc.shape[0],-1).T


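Solve a probabilistic matrix factorization problem with several kinds of regularization: norm penalties on both factors, a spatial-gradient penalty on the modes, and a time-gradient penalty on the mode coefficients.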

In [None]:
# How many non-orthogonal modes to compute
l = 30

# Regularization strength on column-space (spreads out mode coefficients)
lamda_u = 10.0

# Regularization strength on row-space (makes it more likely to use more modes to describe a velocity field)
lamda_v = 1.0

# Penalizes mode gradients
lamda_x = 1000.0

# Penalizes variation in mode strength through time
lamda_t = 10000.0

# avoid divide by zero in time step (which is used in time-gradient regularization)
dt_0 = torch.from_numpy(dt + 1.0)

# Do some reshaping and rescaling
R = torch.from_numpy(Uf)
Rhat = R.ravel()
I_0 = Rhat!=-1.
I_1 = torch.from_numpy(mask_f.ravel())
I = I_0 & I_1
Rbar = Rhat[I]

Rmean = Rbar.mean()
Rstd = Rbar.std()

Rbar = (Rbar - Rmean)/Rstd

# Define low-rank matrix factors 
U_ = torch.randn(R.shape[0],l,requires_grad=True)
V_ = torch.randn(l,R.shape[1],requires_grad=True)

# Initialize
U_.data[:] *= 1e-2
V_.data[:] *= 1e-2

# Define an optimizer for gradient descent
optimizer = torch.optim.Adam([U_,V_],1e-1)

# Do 500 iterations of gradient descent
for i in range(500):
    optimizer.zero_grad()
    
    # Compute the predicted matrix
    G = U_ @ V_
   
    # Compute the gradients of the right factor (the mode coefficients through time)
    dudt = (V_[:,1:] - V_[:,:-1])/dt_0
    
    # Reshape the left factor (columns hold the modes) into a grid and takes its gradients.
    U_grid = U_.T.reshape(l,n_row,n_col)    
    dudrow = U_grid[:,1:] - U_grid[:,:-1]
    dudcol = U_grid[:,:,1:] - U_grid[:,:,:-1]
    
    # Flatten and mask the predictions
    Gbar = G.ravel()[I]

    # Compute a variety of negative log likelihoods
    
    # Data misfit
    E_misfit = ((Rbar - Gbar)**2).mean() 
    #E_misfit = (torch.abs(Rbar - Gbar)).mean() 
    
    # norm regularization
    E_reg = lamda_u*(U_**2).sum()/len(Rbar) + lamda_v*(V_**2).sum()/len(Rbar)
    
    # Spatial gradient regularization
    E_space = lamda_x/len(Rbar)*((dudrow**2).sum() + (dudcol**2).sum())  
    
    # Time gradient regularization
    E_time = lamda_t/len(Rbar)*(dudt**2).sum()
    
    # Sum to form total cost
    E = E_misfit + E_reg + E_space + E_time
    
    # Backpropagate gradients
    E.backward()
    
    # Update factors
    optimizer.step()
    print(i,E_misfit.item(),E_reg.item(),E_space.item(),E_time.item())



The factors are not orthogonal, which makes them hard to interpret. However, because the reconstruction has no missing data, we can apply a low-rank SVD to it to efficiently extract the first 10 modes, which are orthonormal.

In [None]:
# Number of orthonormal modes to keep
n_modes = 10

# Exact truncated SVD of the rank-l product U_ @ V_, computed without forming
# the dense pixels x time matrix. With reduced QRs U_ = Qu Ru and V_^T = Qv Rv,
#   U_ @ V_ = Qu (Ru Rv^T) Qv^T,
# so only the small core Ru Rv^T needs an SVD. Unlike torch.svd_lowrank
# (a randomized approximation), this is exact and deterministic.
Qu, Ru = torch.linalg.qr(U_.detach())
Qv, Rv = torch.linalg.qr(V_.detach().T)
W_core, s, Zh_core = torch.linalg.svd(Ru @ Rv.T, full_matrices=False)

# Same convention as torch.svd_lowrank: U_ @ V_ ~= u @ diag(s) @ v.T
u = (Qu @ W_core)[:, :n_modes]
s = s[:n_modes]
v = (Qv @ Zh_core.T)[:, :n_modes]

Plot the columns of $u$.  Mode importance is given by $s$.  

In [None]:
import os

out_dir = 'modes_biascorrected'
os.makedirs(out_dir, exist_ok=True)

extent = (xs.min(), xs.max(), ys.min(), ys.max())

# Share of the retained reconstruction's squared norm carried by each mode,
# s_i^2 / sum_j s_j^2 (s_i / s_0 is a singular-value ratio, not a variance fraction)
s_np = s.numpy()
var_frac = s_np**2 / (s_np**2).sum()

for i in range(len(s_np)):
    # A fresh figure per mode, so images and scatter points do not pile up
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(u[:, i].detach().reshape(n_row, n_col), extent=extent,
              cmap=plt.cm.seismic, vmin=-0.03, vmax=0.03)
    ax.scatter(fx, fy)
    ax.set_xlim(xs.min(), xs.max())
    ax.set_ylim(ys.min(), ys.max())
    ax.set_title(f"Mode {i}, Variance Fraction {var_frac[i]*100:.01f}%")
    fig.savefig(f'{out_dir}/mode_{i:02d}.png')
    plt.close(fig)

Reconstruct the velocity fields from the truncated SVD (this removes scanline modes and other unwanted artifacts).

In [None]:
# Rank-k reconstruction u diag(s) v^T (pixels x time), transposed to
# time-major and unflattened into one (n_row, n_col) image per field
U_recon = torch.matmul(u * s, v.t()).t().reshape(n_t, n_row, n_col)

Plot the reconstructed velocity fields along with the original fields.  

In [None]:
import os
import matplotlib.pyplot as plt

out_dir = 'out_images_biascorrected'
os.makedirs(out_dir, exist_ok=True)

extent = (xs.min(), xs.max(), ys.min(), ys.max())

for i in range(n_t):
    fig, axs = plt.subplots(nrows=1, ncols=2)

    # Left: reconstruction mapped back from standardized units, floored at 1
    recon = torch.maximum(U_recon[i] * Rstd + Rmean, torch.ones(n_row, n_col))
    axs[0].imshow(recon, vmin=0, vmax=2000, extent=extent)

    # Right: bias-corrected observation. Mask the -1 fill (no-data) pixels so they
    # are left blank instead of being drawn in the colormap's lowest color.
    axs[1].imshow(np.ma.masked_equal(U_trunc[i], -1), vmin=0, vmax=2000, extent=extent)

    for ax in axs:
        ax.plot(fx, fy, 'r-')
        ax.set_xlim(xs.min(), xs.max())
        ax.set_ylim(ys.min(), ys.max())

    fig.set_size_inches(12, 8)
    fig.subplots_adjust(wspace=0.0)
    axs[0].set_title(f't={1985+(dates[i]-dates[0])/365.:.02f}')
    fig.savefig(f'{out_dir}/vel_{i:03d}.png')
    plt.close(fig)

Plot a time series of the coefficients of the top 10 modes (with an arbitrary vertical offset per mode so the patterns are visible).

In [None]:
# Plot every retained mode's coefficient through time, offset vertically
n_plotted = v.shape[1]
years = 1985 + (dates - dates[0]) / 365

# Draw on a dedicated figure rather than whatever figure happens to be current
fig, ax = plt.subplots()
for i in range(n_plotted):
    # Spread colors over the full [0, 1] colormap range; dividing by a constant
    # smaller than the mode count clips the last modes to the same color
    color = plt.cm.viridis(i / max(n_plotted - 1, 1))
    ax.plot(years, v[:, i] + i * 0.2, color=color)

# The offsets are arbitrary, so hide the y tick labels
ax.tick_params(labelleft=False)
# x values are absolute years (1985 + elapsed years), not years since 1985
ax.set_xlabel('Year')
fig.savefig('time_series.png')