/
solvers_reg.py
101 lines (63 loc) · 2.55 KB
/
solvers_reg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
from torch import cholesky, cholesky_solve
import numpy as np
from sketches import gaussian, srht, less, sparse_rademacher, rrs
from utils import average
from time import time
# Registry of sketching callables imported from `sketches`, keyed by the
# `sketch` name accepted by IHS. The 'less_sparse' entry is backed by
# `sparse_rademacher`.
SKETCHES = dict(
    gaussian=gaussian,
    srht=srht,
    less_sparse=sparse_rademacher,
    less=less,
    rrs=rrs,
)
def direct_method(data_matrix, target):
    """Solve min_x ||data_matrix @ x - target||_2 directly and time the solve.

    Args:
        data_matrix: 2-D tensor A with d columns.
        target: 1-D or 2-D tensor b; a 1-D target is reshaped to one column.

    Returns:
        (x_opt, baseline_time): the least-squares solution with
        data_matrix.shape[1] rows and one column per target column, and the
        elapsed seconds spent in the solve.
    """
    from time import perf_counter

    if target.ndim == 1:
        target = target.reshape((-1, 1))
    start = perf_counter()
    # torch.lstsq(B, A) is deprecated and removed in recent PyTorch releases.
    # torch.linalg.lstsq(A, B).solution already has data_matrix.shape[1]
    # rows, so the old slicing of padded rows is unnecessary.
    x_opt = torch.linalg.lstsq(data_matrix, target).solution
    if x_opt.is_cuda:
        # CUDA kernels launch asynchronously; wait so the timing covers the
        # actual solve and not just the kernel launch.
        torch.cuda.synchronize(x_opt.device)
    baseline_time = perf_counter() - start
    return x_opt, baseline_time
class Solver:
    """Base class holding a problem instance (A, b) and a reference solution.

    Args:
        data_matrix: 2-D tensor A; its shape is unpacked into (self.n, self.d).
        target: 1-D or 2-D tensor b; a 1-D target is reshaped to one column,
            and self.c is its number of columns.
        x_opt: reference solution stored as self.x_opt; a 1-D tensor is
            reshaped to one column.
        reg_param: weight of the (reg_param / 2) * ||x||^2 penalty that
            `compute_error` adds.
    """

    def __init__(self, data_matrix, target, x_opt, reg_param=1e-2):
        self.A = data_matrix
        if target.ndim == 1:
            target = target.reshape((-1, 1))
        self.b = target
        (self.n, self.d), self.c = self.A.shape, self.b.shape[1]
        self.x_opt = x_opt
        if self.x_opt.ndim == 1:
            self.x_opt = self.x_opt.reshape((-1, 1))
        # Kept on the instance because compute_error needs it.
        self.reg_param = reg_param

    def compute_error(self, x):
        """Return the L2-regularized logistic loss at x as a 0-dim tensor.

        mean(log(1 + exp(-b * (A @ x)))) + reg_param / 2 * ||x||^2

        log(1 + exp(z)) is evaluated as logaddexp(0, z): identical in value,
        but it does not overflow to inf when z is large.
        """
        margins = -self.b * (self.A @ x)
        losses = torch.logaddexp(torch.zeros_like(margins), margins)
        return losses.mean() + self.reg_param / 2 * (x ** 2).sum()
class IHS(Solver):
    """Iterative Hessian Sketch solver for min_x ||A x - b||^2.

    Each iteration computes the gradient g = A^T (A x - b). It then averages
    the directions ((S_k A)^T (S_k A))^{-1} g over q freshly drawn sketches
    S_k A and steps x <- x - mu * p.

    NOTE(review): the step targets the least-squares objective (with no
    regularization), while `Solver.compute_error` reports a regularized
    logistic loss. Confirm this pairing is intended.

    Args:
        A, b, x_opt: forwarded to `Solver`.
        sketch: key into `SKETCHES` selecting the sketching function.
        reg_param: forwarded to `Solver`; it is only used by `compute_error`.
    """

    def __init__(self, A, b, x_opt, sketch='gaussian', reg_param=1e-2):
        super().__init__(A, b, x_opt, reg_param=reg_param)
        self.sketch = sketch
        self.sketch_fn = SKETCHES[sketch]
        # Defaults so the helper methods also work before `solve` is called;
        # `solve` overwrites both.
        self.line_search = False
        self.nnz = self.d

    def compute_step_size(self, m, q, x, p, g):
        """Return the step size mu for the update x <- x - mu * p.

        Without line search, this is the fixed step q / (gamma * (gamma - 1 + q))
        with gamma = m / (m - d). With line search, and g = A^T (A x - b) as
        computed in `ihs_iteration`, it is the exact minimizer over mu of
        ||A (x - mu p) - b||^2, i.e. <p, g> / ||A p||^2.
        """
        if self.line_search:
            # ||A p||^2 equals <p, A^T A p> but needs one fewer matmul.
            Ap = self.A @ p
            return (p * g).sum() / (Ap * Ap).sum()
        gamma = m / (m - self.d)
        return q / (gamma * (gamma - 1 + q))

    def ihs_iteration(self, x, m, q):
        """Return x after one IHS step with sketch size m, averaging q sketches."""
        g = self.A.T @ (self.A @ x - self.b)
        # Match A's dtype as well as its device. A default-float32 accumulator
        # would silently downcast float64 directions added into it.
        p = torch.zeros(self.d, self.c, dtype=self.A.dtype, device=self.A.device)
        for _ in range(q):
            sa = self.sketch_fn(self.A, m, nnz=self.nnz)
            # Lower-triangular Cholesky factor of the sketched Hessian; this
            # matches cholesky_solve's default upper=False.
            L = torch.linalg.cholesky(sa.T @ sa)
            p += 1. / q * cholesky_solve(g, L)
        mu = self.compute_step_size(m, q, x, p, g)
        return x - mu * p

    @average
    def solve(self, m, q=1, line_search=False, n_iterations=10, n_trials=1, nnz=None):
        """Run `n_iterations` IHS steps from a random starting point.

        Args:
            m: sketch size passed to the sketch function.
            q: number of sketches averaged per iteration.
            line_search: use exact line search instead of the fixed step.
            n_iterations: number of IHS updates.
            n_trials: not read in this body; presumably consumed by the
                `average` decorator. TODO confirm.
            nnz: forwarded to the sketch function; defaults to d.

        Returns:
            (x, errors, cv_rate): the final iterate; a float tensor of the
            n_iterations + 1 `compute_error` values divided by the initial
            one; and (errors[-1] / errors[0]) ** (1 / n_iterations).
        """
        self.nnz = self.d if nnz is None else nnz
        self.line_search = line_search
        # Draw on the CPU generator, then move to A's device and dtype. This
        # keeps the random stream unchanged and avoids a float32/float64
        # mismatch in A @ x.
        x = (1. / np.sqrt(self.d) * torch.randn(self.d, self.c)).to(
            device=self.A.device, dtype=self.A.dtype)
        errors = [self.compute_error(x)]
        for _ in range(n_iterations):
            x = self.ihs_iteration(x, m, q)
            errors.append(self.compute_error(x))
        cv_rate = (errors[-1] / errors[0]) ** (1. / n_iterations)
        # float() copies each scalar to the host, so this also works when A
        # (and therefore each error) lives on a GPU.
        normalized = torch.Tensor([float(e) for e in errors]) / float(errors[0])
        return x, normalized, cv_rate