In [80]:
import torch
from mirtorch.prox import prox
from mirtorch.linear import *
import numpy as np

In [81]:
#initial implementation
class FISTA():
    def __init__(self, max_iter, step, fval, grad, prox, momentum = 1, restart = False):
        self.max_iter = max_iter
        self.step = step
        self.fval = fval
        self.grad = grad
        self.prox = prox
        self.momentum = momentum
        self.restart = restart
    '''
    Experimenting different implementation
    '''
    def _update(self, iter):
        #looks like one possible implementation from jeff lecture slides
        step_prev = self.step
        self.step = (1 + np.sqrt(1 + 4*step_prev*step_prev))/2
        self.momentum = (1-step_prev)/(1+step_prev)

    def run_alg(self, x0):
        x_curr = x0
        z_curr = x0 
        for i in range(self.max_iter):
            x_prev = x_curr
            z_prev = z_curr
            #compute new z_k and x_k
            #print(x_curr.shape, z_curr.shape)
            z_curr = self.prox(x_prev - self.grad(x_prev)) 
            x_curr = z_curr + self.momentum * (z_curr - z_prev)
            print(f'After iter {i}, z = {z_curr}, x = {x_curr}')
            #update momentum value for next iteration
            self._update(i)
        return x_curr

In [82]:
class A(LinearMap):
    def __init__(self):
        super().__init__([2], [2])
        self.mat = torch.Tensor([[2,0],[0,1]])
    def _apply(self, x):
        return torch.matmul(self.mat,x)
    def _apply_adjoint(self,x):
        return self._apply(x)
class grad(LinearMap):
    def __init__(self):
        super().__init__([2],[2])
        self._A = A()
    def _apply(self, x):
        y = self._A(x)
        #print(y)
        y -= torch.Tensor([2,1])
        #print(y)
        y = self._A(y)
        return 2 * y
    

In [83]:
f = FISTA(10, 0, None, grad(), prox.L1Regularizer(.5))

In [84]:
f.run_alg(torch.Tensor([1,1]))

After iter 0, z = tensor([0.5000, 0.5000]), x = tensor([0., 0.])
After iter 1, z = tensor([7.5000, 1.5000]), x = tensor([14.5000,  2.5000])
After iter 2, z = tensor([-93.,   0.]), x = tensor([-93.,   0.])
After iter 3, z = tensor([658.5000,   1.5000]), x = tensor([481.0949,   1.1459])
After iter 4, z = tensor([-3.3592e+03,  3.5410e-01]), x = tensor([-1.8576e+03,  7.8236e-01])
After iter 5, z = tensor([1.3011e+04, 7.1764e-01]), x = tensor([5.3720e+03, 5.4800e-01])
After iter 6, z = tensor([-3.7596e+04,  9.5200e-01]), x = tensor([-1.0555e+04,  8.2677e-01])
After iter 7, z = tensor([7.3893e+04, 6.7323e-01]), x = tensor([8.5446e+03, 8.3663e-01])
After iter 8, z = tensor([-5.9804e+04,  6.6337e-01]), x = tensor([2.4053e+04, 6.6955e-01])
After iter 9, z = tensor([-1.6837e+05,  8.3045e-01]), x = tensor([-9.6644e+04,  7.2007e-01])


tensor([-9.6644e+04,  7.2007e-01])