# ADMM solvers for Lasso and Ridge regression (serial and parallel variants).
import torch
import torch.multiprocessing as mp
import numpy as np
import random
# from .admm_lasso import ADMMLassoSolver
# from .admm_ridge import ADMMRidgeSolver
class SolveIndividual:
    """Per-sample x-update helper used by the parallel consensus ADMM solvers."""

    def solve(self, A, b, nu, rho, Z):
        """Return the single-sample x-update as a column vector.

        A is a 1-D feature vector, b a scalar target, nu and Z column vectors.
        """
        gram = A.dot(A.T)  # scalar ||A||^2 (computed before reshaping)
        col = A.reshape(-1, 1)
        return (col * b + rho * Z - nu) / (gram + rho)
class CombineSolution:
    """Per-sample dual-variable update helper for the parallel ADMM solvers."""

    def combine(self, nuBar, xBar, Z, rho):
        """Return (nuBar + rho * (xBar - Z)) reshaped to a row vector."""
        nu_col = nuBar.reshape(-1, 1)
        x_col = xBar.reshape(-1, 1)
        updated = nu_col + rho * (x_col - Z)
        return updated.T
# End of Helper class for solving in Parallel
class ADMMRidgeSolver():
    """ADMM solver for the ridge objective.

        min_x 0.5 * ||X w - b||^2 + lamb * ||w||^2

    in consensus form  min f(x) + g(z)  s.t.  x - z = 0.

    params: [data matrix (N x d), targets b (N x 1), lamb, rho].
    """

    def __init__(self, params, parallel=False):
        if len(params) < 4:
            raise ValueError("Incorrect number of parameters passed: 4 required, but {} are passed".format(len(params)))
        self.A = params[0].T          # (d, N): columns are samples
        self.b = params[1]            # (N, 1) targets
        self.lamb = params[2]         # ridge penalty weight
        self.rho = params[3]          # ADMM penalty parameter
        self.z = np.zeros((self.A.shape[0], 1))
        self.parallel = parallel
        if self.parallel:
            # one x / nu row per sample (consensus form)
            self.x = np.zeros((self.A.shape[1], self.A.shape[0]))
            self.nu = np.zeros((self.A.shape[1], self.A.shape[0]))
            self.x_bar = np.mean(self.x, 0).reshape(-1, 1)
            self.nu_bar = np.mean(self.nu, 0).reshape(-1, 1)
            self.num_cores = mp.cpu_count()
        else:
            self.x = np.zeros((self.A.shape[0], 1))
            self.nu = np.zeros((self.A.shape[0], 1))

    def step(self):
        """Perform one ADMM iteration and return the current loss."""
        if self.parallel:
            self.update_parallel()
        else:
            self.update_serial()
        return self.getLoss()

    def getParams(self):
        """Return the current weight estimate (d x 1)."""
        if self.parallel:
            return self.x_bar
        return self.x

    def get_diff(self):
        """Print the primal residual x - z (debug helper)."""
        if self.parallel:
            print(self.x_bar - self.z)
        else:
            print(self.x - self.z)

    def update_one_x(self, i):
        """x-update for sample i.

        Returns the new row in addition to assigning it, so a process pool
        can hand the result back to the parent (in-place mutation inside a
        worker process would be lost).
        """
        a = self.A[:, i]
        row = (a * self.b[i] + self.rho * self.z.reshape(-1) - self.nu[i]) / (a.dot(a) + self.rho)
        self.x[i] = row
        return row

    def update_one_nu(self, i):
        """Dual update for sample i; returns the new row (see update_one_x)."""
        row = self.nu[i] + self.rho * (self.x[i] - self.z.reshape(-1))
        self.nu[i] = row
        return row

    def update_parallel(self):
        """Consensus ADMM iteration with one x_i / nu_i block per sample.

        BUG FIX: the original called Pool.map on methods that only mutated
        self inside the worker processes, so the parent's state never
        changed; the per-row results are now returned and written back here.
        """
        n = self.A.shape[1]
        with mp.Pool(self.num_cores) as pool:
            self.x = np.array(pool.map(self.update_one_x, range(n)))
        self.x_bar = np.mean(self.x, 0).reshape(-1, 1)
        # closed-form z-update for the ridge penalty
        self.z = (self.rho * self.x_bar + self.nu_bar) / (self.rho + 2 * self.lamb)
        with mp.Pool(self.num_cores) as pool:
            self.nu = np.array(pool.map(self.update_one_nu, range(n)))
        self.nu_bar = np.mean(self.nu, 0).reshape(-1, 1)

    def update_serial(self):
        """One serial ADMM iteration.

        BUG FIX: the original computed inv(A A^T + rho), which broadcasts
        the scalar rho onto every matrix entry; the x-update requires
        A A^T + rho * I. Also uses np.linalg.solve instead of forming the
        explicit inverse.
        """
        d = self.A.shape[0]
        gram = self.A.dot(self.A.T) + self.rho * np.eye(d)
        self.x = np.linalg.solve(gram, self.A.dot(self.b) + self.rho * self.z - self.nu)
        self.z = (self.rho * self.x + self.nu) / (self.rho + 2 * self.lamb)
        self.nu = self.nu + self.rho * (self.x - self.z)

    def getLoss(self):
        """Ridge objective 0.5*||A^T x - b||^2 + lamb*||x||^2 at the iterate."""
        x = self.x_bar if self.parallel else self.x
        return 0.5 * np.linalg.norm(self.A.T.dot(x) - self.b) ** 2 + self.lamb * (np.linalg.norm(x) ** 2)
# ADMM Lasso Solver
class ADMMLassoSolver:
    """ADMM solver for the Lasso objective.

        min_x 0.5 * ||A x - b||^2 + alpha * ||x||_1

    params: [A (N x D), b (N x 1), alpha, rho].
    """

    def __init__(self, params, parallel=False):
        if len(params) < 4:
            raise ValueError("Incorrect number of parameters passed: 4 required, but {} are passed".format(len(params)))
        self.A = params[0]        # (N, D) design matrix
        self.b = params[1]        # (N, 1) targets
        self.alpha = params[2]    # l1 penalty weight
        self.rho = params[3]      # ADMM penalty parameter
        self.D = self.A.shape[1]
        self.N = self.A.shape[0]
        # Always allocate the per-sample blocks: step_iterative uses them
        # even when parallel=False (the original left them undefined then).
        self.XBar = np.zeros((self.N, self.D))
        self.nuBar = np.zeros((self.N, self.D))
        self.nu = np.zeros((self.D, 1))
        self.X = np.random.randn(self.D, 1)
        self.Z = np.zeros((self.D, 1))
        self.parallel = parallel
        self.numberOfThreads = mp.cpu_count()

    def _soft_threshold(self, v, k):
        """Elementwise soft-thresholding: the prox operator of k * ||.||_1."""
        return np.sign(v) * np.maximum(np.abs(v) - k, 0.0)

    def step(self):
        """Perform a single optimization step; returns the Lasso loss."""
        if self.parallel:
            return self.step_parallel()
        # x-update. BUG FIX: the original added the scalar rho to every
        # entry of A^T A (broadcasting); the update requires A^T A + rho*I.
        lhs = self.A.T.dot(self.A) + self.rho * np.eye(self.D)
        self.X = np.linalg.solve(lhs, self.A.T.dot(self.b) + self.rho * self.Z - self.nu)
        # z-update. BUG FIX: the original used X + nu/rho - (alpha/rho)*sign(Z),
        # which is not the l1 prox; the correct update soft-thresholds
        # X + nu/rho at alpha/rho.
        self.Z = self._soft_threshold(self.X + self.nu / self.rho, self.alpha / self.rho)
        # dual update
        self.nu = self.nu + self.rho * (self.X - self.Z)
        return self.getLoss()

    def solveIndividual(self, i):
        """Solve the x_i block for sample i; returns a (D, 1) column."""
        solve = SolveIndividual()
        # .item() replaces np.asscalar, removed in NumPy >= 1.23.
        return solve.solve(self.A[i], self.b[i].item(), self.nuBar[i].reshape(-1, 1), self.rho, self.Z)

    def combineSolution(self, i):
        """Dual update for sample i; returns a (1, D) row."""
        combine = CombineSolution()
        return combine.combine(self.nuBar[i].reshape(-1, 1), self.XBar[i].reshape(-1, 1), self.Z, self.rho)

    def step_parallel(self):
        """Consensus-form step with per-sample solves fanned out to a pool.

        BUG FIXES vs. the original:
          * results of solveIndividual/combineSolution were discarded (child
            processes mutated their own copies of self, so the parent's
            state never changed); results are now returned and written back;
          * range(0, N-1, 4) skipped three of every four samples.
        """
        with mp.Pool(self.numberOfThreads) as pool:
            rows = pool.map(self.solveIndividual, range(self.N))
        for i, row in enumerate(rows):
            self.XBar[i] = np.asarray(row).reshape(-1)
        self.X = np.average(self.XBar, axis=0).reshape(-1, 1)
        self.nu = np.average(self.nuBar, axis=0).reshape(-1, 1)
        # z-update via the l1 prox (see step for the fix rationale)
        self.Z = self._soft_threshold(self.X + self.nu / self.rho, self.alpha / self.rho)
        with mp.Pool(self.numberOfThreads) as pool:
            rows = pool.map(self.combineSolution, range(self.N))
        for i, row in enumerate(rows):
            self.nuBar[i] = np.asarray(row).reshape(-1)
        return self.getLoss()

    def step_iterative(self):
        """Serial version of the consensus-form update.

        BUG FIX: the original iterated range(0, N-1), silently skipping the
        last sample in both loops.
        """
        for i in range(self.N):
            self.XBar[i] = self.solveIndividual(i).T
        self.X = np.average(self.XBar, axis=0).reshape(-1, 1)
        self.nu = np.average(self.nuBar, axis=0).reshape(-1, 1)
        self.Z = self._soft_threshold(self.X + self.nu / self.rho, self.alpha / self.rho)
        for i in range(self.N):
            t = self.nuBar[i].reshape(-1, 1)
            t = t + self.rho * (self.XBar[i].reshape(-1, 1) - self.Z)
            self.nuBar[i] = t.T
        return self.getLoss()

    def getLoss(self):
        """Lasso objective 0.5*||A X - b||^2 + alpha*||X||_1 at the iterate."""
        return 0.5 * np.linalg.norm(self.A.dot(self.X) - self.b) ** 2 + self.alpha * np.linalg.norm(self.X, 1)

    def getParams(self):
        """Return the current weight estimate (D, 1)."""
        return self.X
class ADMM:
    """
    Wrapper ADMM class implements following Solver
    - Lasso
    """

    def __init__(self, params, objFunc="Lasso", parallel=False):
        """Instantiate the solver matching objFunc ("Lasso" or "Ridge")."""
        if objFunc not in ("Lasso", "Ridge"):
            raise ValueError("Un-Supported Solver")
        solver_cls = ADMMLassoSolver if objFunc == "Lasso" else ADMMRidgeSolver
        self.solver = solver_cls(params, parallel)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        return self.solver.step()

    def getLoss(self):
        """Delegate: loss of the underlying solver."""
        return self.solver.getLoss()

    def getParams(self):
        """Delegate: weights of the underlying solver."""
        return self.solver.getParams()