# Optimal Denoising Schedule for Denoising Autoencoders on MNIST

In [22]:
import numpy as np
import casadi
from casadi import *
import matplotlib.pyplot as plt
from numpy.random import normal
from tqdm import tqdm
from sklearn.datasets import fetch_openml
from collections import defaultdict
import pickle

In [2]:
#Define auxiliary functions for the ODEs (Theory)

def lt_lt(Delta,sigma,Q,M,k,S,b):
    return (M[k,:].T@M[k,:])*(1-Delta)+((1-Delta)*sigma**2+Delta)*Q
    
def lt_lx(Delta,sigma,Q,M,k,S,b):
    return (M[k,:].T@M[k,:])*sqrt(1-Delta)+(sqrt(1-Delta)*sigma**2)*Q
    
def lx_lx(Delta,sigmas,Q,M,k,S,b):
    return (M[k,:].T@M[k,:])+(sigmas[k]**2)*Q
    
def nut_nut(Delta,sigmas,Q,M,k,S,b):
    stmp = reshape(S[k,:],K,1)
    return (stmp@stmp.T)*(1-Delta)+((1-Delta)*sigmas[k]**2+Delta)*S
    
def lt_nut(Delta,sigmas,Q,M,k,S,b): #shape = r x K
    stmp= reshape(S[k,:],K,1)
    return (M[k,:].T@stmp.T)*(1-Delta)+((1-Delta)*sigmas[k]**2+Delta)*M.T
    
def lt_nux(Delta,sigmas,Q,M,k,S,b):  #shape = r x K
    stmp= reshape(S[k,:],K,1)
    return (M[k,:].T@stmp.T)*sqrt(1-Delta)+(sqrt(1-Delta)*sigmas[k]**2)*M.T
    
def lx_nut(Delta,sigmas,Q,M,k,S,b):  #shape = r x K
    stmp= reshape(S[k,:],K,1)
    return (M[k,:].T@stmp.T)*sqrt(1-Delta)+(sqrt(1-Delta)*sigmas[k]**2)*M.T
    
    
def lt_L(Delta,sigmas,Q,M,k,S,b):
    r,_=Q.size()
    return lt_lx(Delta,sigmas[k],Q,M,k,S,b)-lt_lt(Delta,sigmas[k],Q,M,k,S,b)@(b*MX.eye(r)+Q)
        
def L_L(Delta,sigmas,Q,M,k,S,b):
    r,_=Q.size()
    A=(b*MX.eye(r)+Q)@lt_lx(Delta,sigmas[k],Q,M,k,S,b)
    return lx_lx(Delta,sigmas,Q,M,k,S,b)-A-A.T+(b*MX.eye(r)+Q)@lt_lt(Delta,sigmas[k],Q,M,k,S,b)@(b*MX.eye(r)+Q)

def x_x(Delta,sigmas,Q,M,k,S,b):
    return sigmas[k]**2
def xt_xt(Delta,sigmas,Q,M,k,S,b):
    return (1-Delta)*sigmas[k]**2+Delta
def x_xt(Delta,sigmas,Q,M,k,S,b):
    return  sqrt(1-Delta)*sigmas[k]**2
    
def compute_mse(Delta,sigmas,Q,M,S,b,cs,N,K):
    mse=MX(0)
    for k in range(K):
        mse+=cs[k]*(trace(Q@lt_lt(Delta,sigmas[k],Q,M,k,S,b))\
            -2*trace(lt_lx(Delta,sigmas[k],Q,M,k,S,b))+2*b*trace(lt_lt(Delta,sigmas[k],Q,M,k,S,b))\
        +b**2*N*((1-Delta)*sigmas[k]**2+Delta)-2*b*N*sqrt(1-Delta)*sigmas[k]**2+N*sigmas[k]**2)
    return mse

In [3]:
K=2 #Number of Gaussian clusters (denoted C1 in the paper)
r=2 #Number of network hidden nodes (denoted K in the paper)

In [4]:
Delta = MX.sym('Delta',1)  #Noise level
sigmas = MX.sym('sigmas', K)  #Variance of the Gaussian clusters
Q = MX.sym('Q', r, r)  #Student-student overlap
M = MX.sym('M', K, r)  #Student-centroid overlap (denoted R in the paper)
S = MX.sym('S', K, K)  #Centroid-centroid overlap dDenoted Omega in the paper)
b = MX.sym('b', 1)  #Skip connection
cs = MX.sym('cs', K) #Probability of cluster membership (p_c in the paper)
N = MX.sym('N', 1) #Input dimension
mse_sym = compute_mse(Delta, sigmas, Q, M, S, b, cs, N,K)
mse_func = Function('mse_func', [Delta, sigmas, Q, M, S, b, cs, N], [mse_sym])

In [5]:
# Load MNIST dataset
mnist = fetch_openml('mnist_784', version=1)
X = mnist.data.to_numpy() / 255.0 /5  # normalize to [0,1]
y = mnist.target.astype(int)


# Group by digit
digit_data = defaultdict(list)
for img, label in zip(X, y):
    digit_data[label].append(img)


# Fit Gaussian model x = mu + sigma * z for each digit
mu_sigma_per_digit = {}
for digit in range(10):
    data = np.stack(digit_data[digit])
    mu = np.mean(data, axis=0)
    sigma_data = np.sqrt(np.mean((data - mu) ** 2))
    mu_sigma_per_digit[digit] = (mu, sigma_data)

In [6]:
#define centroids
mu0=mu_sigma_per_digit[0][0] 
mu1=mu_sigma_per_digit[1][0]
N=len(mu0) #input dimension

In [7]:
P=len(digit_data[0])+len(digit_data[1]) #total number of samples
p0=len(digit_data[0])/P
p1=len(digit_data[1])/P

In [8]:
np.random.seed(195)
T_real=1  #alphaF
mus=np.array([mu0,mu1])
musNormalized=np.array([mu0/np.linalg.norm(mu0),mu1/np.linalg.norm(mu1)])
cs=[p0,p1] #probability of cluster membership
sigmas=[mu_sigma_per_digit[0][1],mu_sigma_per_digit[1][1]] #standard deviations of the clusters
eta=5
T=T_real*eta #scaled total time
b0=0.0 #initial condition of the skip connection
#Initializations
mutmp=np.concatenate((mus,mus),axis=0)
weights_0=0.2*musNormalized+normal(0,1/np.sqrt(N),size=(r,N))
weights_0=np.array([weights_0[p]*np.sqrt(N)/np.linalg.norm(weights_0[p]) for p in range(r)])
M0=mus@weights_0.T/np.sqrt(N)
Q0=weights_0@weights_0.T/N
S=mus@mus.T

In [9]:
O = MX.sym('O',r*(r+K)+1)
M, Q = O[:r*K], O[r*K:-1]
M = reshape(M, K, r)
Q = reshape(Q, r, r)
b = O[-1:]

In [10]:
#ODEs
dM=MX.zeros(K,r)
dQ=MX.zeros(r,r)
db=MX.zeros(1,1)
for k in range(K):
    dMtmp=eta*(lt_nux(Delta,sigmas,Q,M,k,S,b).T-2*b*lt_nut(Delta,sigmas,Q,M,k,S,b).T-M@lt_lt(Delta,sigmas[k],Q,M,k,S,b)\
               +lx_nut(Delta,sigmas,Q,M,k,S,b).T-lt_nut(Delta,sigmas,Q,M,k,S,b).T@Q)

    dQtmp=(lt_L(Delta,sigmas,Q,M,k,S,b)+lt_L(Delta,sigmas,Q,M,k,S,b).T)*(2*eta+eta**2*(x_xt(Delta,sigmas,Q,M,k,S,b)-b*xt_xt(Delta,sigmas,Q,M,k,S,b)))\
            +eta**2*L_L(Delta,sigmas,Q,M,k,S,b)*xt_xt(Delta,sigmas,Q,M,k,S,b)\
            +eta**2*lt_lt(Delta,sigmas[k],Q,M,k,S,b)*(x_x(Delta,sigmas,Q,M,k,S,b)-2*b*x_xt(Delta,sigmas,Q,M,k,S,b)+b**2*xt_xt(Delta,sigmas,Q,M,k,S,b))  
    dM+=cs[k]*dMtmp
    dQ+=cs[k]*dQtmp
    db+=cs[k]*eta*(sqrt(1-Delta)*sigmas[k]**2-b*(1-Delta)*sigmas[k]**2-b*Delta)
    
dQ=0.5*(dQ+dQ.T)

# Scaling time with eta
dQ/=eta
dM/=eta
db/=eta

In [11]:
dM_flat = reshape(dM, K * r, 1)  # Column-major order (column by column)
dQ_flat = reshape(dQ, r * r, 1)  # Column-major order (column by column)

# Concatenate M_flat and Q_flat into a single vector
xdot = vertcat(dM_flat, dQ_flat, db)

In [12]:
dT_control = 0.01
N_control = int(T//dT_control) # number of control intervals
L=MX(0)
alphas=np.linspace(0,T_real,N_control+1)

In [13]:
dae = {'x':O, 'p':Delta, 'ode':xdot, 'quad':L}
F = integrator('F', 'cvodes', dae, 0, T/N_control)

In [14]:
# Evaluate at a test point
Q0flat=Q0.flatten(order='F')
M0flat=M0.flatten(order='F')
O0=np.concatenate((M0flat, Q0flat,[b0])).tolist()
Fk = F(x0=O0,p=0.1)
print(Fk['xf'])
print(Fk['qf'])


[0.343627, 0.175487, 0.203153, 0.265527, 0.996657, 0.000802741, 0.000802741, 0.996625, 1.70825e-05]
0


In [15]:
# Start with an empty NLP
w=[]
w0 = []
lbw = []
ubw = []
J = 0
g=[]
lbg = []
ubg = []

In [16]:
# "Lift" initial conditions
Ok = MX.sym('O0', r*K+r*r+1)
w += [Ok]
lbw += O0
ubw += O0
w0 += O0
Delta0=0.3
Otmp=O0

In [17]:
# Formulate the NLP
for k in range(N_control):
    # New NLP variable for the control
    Uk = MX.sym('U_' + str(k))
    w += [Uk]
    lbw += [0.01]
    ubw += [0.99]
    w0 += [Delta0]

    # Integrate till the end of the interval
    Fk = F(x0=Ok, p=Uk)
    Ok_end = Fk['xf']

    Otmp = F(x0=Otmp,p=Delta0)['xf']
    
    # New NLP variable for state at end of interval
    Ok = MX.sym('O_' + str(k+1), r*(r+K)+1)
    w   += [Ok]

    
    
    lbw +=  [0 if (K+r<=i and i<(K+r)*r) else -inf for i in range(r*(r+K)+1)]
    ubw += [inf for i in range(r*(r+K)+1)]
    w0 += Otmp.toarray().squeeze().tolist()

     # Add equality constraint
    g   += [Ok_end-Ok]
    lbg += [0 for i in range(r*(r+K)+1)]
    ubg += [0 for i in range(r*(r+K)+1)]



Mf, Qf = Ok[:r*K], Ok[r*K:-1]
Mf = reshape(Mf, K, r)
Qf = reshape(Qf, r, r)
bf=Ok[-1]

In [18]:
opt_list={"Deltaf":[], "M_opt":[], "Q_opt":[], "b_opt":[], "Delta_opt":[], "J_opt":[]}

In [19]:
const_list={"Deltaf":[], "M":[], "Q":[], "b":[], "J_const":[]}

In [20]:
Deltas=[0.1,0.2,0.3,0.4]

In [21]:
for Deltaf in Deltas:
    opt_list["Deltaf"].append(Deltaf)
    J=compute_mse(Deltaf,sigmas,Qf,Mf,S,bf,cs,N,K)
    # Create an NLP solver
    prob = {'f': J, 'x': vertcat(*w), 'g': vertcat(*g)}
    solver = nlpsol('solver', 'ipopt', prob)
    # Solve the NLP
    print("\n\nSolving for $\\Delta_F=${}...\n\n".format(Deltaf))
    sol = solver(x0=w0, lbx=lbw, ubx=ubw, lbg=lbg, ubg=ubg)
    w_opt = sol['x'].full().flatten()

    #Optimal parameters
    M_opt=[w_opt[i:i+K*r]for i in range(0, len(w_opt), K*r+r*r+2)]
    Q_opt=[w_opt[i:i+r*r]for i in range(K*r, len(w_opt), K*r+r*r+2)]
    b_opt=[w_opt[i]for i in range(K*(K+r), len(w_opt), K*r+r*r+2)]
    Delta_opt=[w_opt[i]for i in range(K*(K+r)+1, len(w_opt), K*r+r*r+2)]

    Jsopt=[]
    for i in range(len(b_opt)):
        Jsopt.append(mse_func(Deltaf,sigmas,np.reshape(Q_opt[i],(r,r),order='F'),np.reshape(M_opt[i],(K,r),order='F'),S,b_opt[i],cs,N).full()[0][0])

    opt_list["M_opt"].append(M_opt)
    opt_list["Q_opt"].append(Q_opt)
    opt_list["b_opt"].append(b_opt)
    opt_list["Delta_opt"].append(Delta_opt)
    opt_list["J_opt"].append(Jsopt)

    #save to file
    with open(f'data/schedule_Delta={Deltaf}_Treal={T_real}_eta={eta}.pickle', 'wb') as handle:
        pickle.dump(Delta_opt, handle, protocol=pickle.HIGHEST_PROTOCOL)




Solving for $\Delta_F=$0.1...



******************************************************************************
This program contains Ipopt, a library for large-scale nonlinear optimization.
 Ipopt is released as open source code under the Eclipse Public License (EPL).
         For more information visit https://github.com/coin-or/Ipopt
******************************************************************************

This is Ipopt version 3.14.11, running with linear solver MUMPS 5.4.1.

Number of nonzeros in equality constraint Jacobian...:    45336
Number of nonzeros in inequality constraint Jacobian.:        0
Number of nonzeros in Lagrangian Hessian.............:    27427

Total number of variables............................:     4990
                     variables with only lower bounds:     1996
                variables with lower and upper bounds:      499
                     variables with only upper bounds:        0
Total number of equality constraints.................:    