In [2]:
import torch
from mirtorch.prox import prox
from mirtorch.linear import *
import numpy as np

In [3]:
#initial implementation
class FISTA():
    def __init__(self, max_iter, step, fval, grad, prox, momentum = 1, restart = False):
        self.max_iter = max_iter
        self.step = step
        self.fval = fval
        self.grad = grad
        self.prox = prox
        self.momentum = momentum
        self.restart = restart
    '''
    Experimenting different implementation
    '''
    def _update(self):
        #looks like one possible implementation from jeff lecture slides
        step_prev = self.step
        self.step = (1 + np.sqrt(1 + 4*step_prev*step_prev))/2
        self.momentum = (step_prev-1)/self.step
        #print(f'Now step={self.step}, momentum={self.momentum}')

    def run_alg(self, x0):
        x_curr = x0
        z_curr = x0 
        for i in range(self.max_iter):
            print(f'Before iter {i}, z = {z_curr}, x = {x_curr}')
            x_prev = x_curr
            z_prev = z_curr
            #compute new z_k and x_k
            #print(x_curr.shape, z_curr.shape)
            print(self.grad(x_prev))
            print(x_prev - self.grad(x_prev))
            z_curr = self.prox(x_prev - self.grad(x_prev))
            #print(f'In iter {i}, z_curr calculated to be {z_curr}')
            #x_curr = z_curr + self.momentum * (z_curr - z_prev)
            x_curr = (1-self.momentum)*z_curr + self.momentum*z_prev
            #print(f'In iter {i}, x_curr calculated to be {x_curr}')
            #update momentum value for next iteration
            self._update()
        return x_curr

In [4]:
class A(LinearMap):
    def __init__(self):
        super().__init__([2], [2])
        self.mat = torch.Tensor([[2,0],[0,1]])
    def _apply(self, x):
        return torch.matmul(self.mat,x)
    def _apply_adjoint(self,x):
        return self._apply(x)
    
class grad(LinearMap):
    def __init__(self):
        super().__init__([2],[2])
        self._A = A()
    def _apply(self, x):
        y = self._A(x)
        #print(y)
        y -= torch.Tensor([2,1])
        #print(y)
        y = self._A(y)
        return 2 * y
    

In [5]:
f = FISTA(10, 0, None, grad(), prox.L1Regularizer(.5))

In [None]:
f.run_alg(torch.Tensor([1,1]))

Before iter 0, z = tensor([1., 1.]), x = tensor([1., 1.])
tensor([0., 0.])
tensor([1., 1.])


 


Before iter 1, z = tensor([0.5000, 0.5000]), x = tensor([1., 1.])
tensor([0., 0.])
tensor([1., 1.])


 


Before iter 2, z = tensor([0.5000, 0.5000]), x = tensor([0.5000, 0.5000])
tensor([-4., -1.])
tensor([4.5000, 1.5000])


 


Before iter 3, z = tensor([4., 1.]), x = tensor([4., 1.])
tensor([24.,  0.])
tensor([-20.,   1.])
