Comparing SAG to Momentum
----

In [2]:
import torch
from torch.autograd import Variable
import numpy as np

In [24]:
def f(w, x):
    return w[0] + w[1] * x + w[2] * x * x

def F(w, x):
    return sum(f(w, xi) for xi in x)

w = Variable(torch.Tensor([-1, 0, 1]), requires_grad=True)

In [27]:
x = Variable(torch.Tensor(np.arange(3)))

print f(w, x[0])
print F(w, x)

Variable containing:
-1
[torch.FloatTensor of size 1]

Variable containing:
 2
[torch.FloatTensor of size 1]



In [109]:
# estimate gradient
y = F(w, x)
w.grad.data.zero_()
y.backward()
print 'Exact gradient', w.grad

i = np.random.randint(len(x))
y = f(w, x[i])
w.grad.data.zero_()
y.backward()
print 'Gradient at x[{}]'.format(i), w.grad


Exact gradient Variable containing:
 3
 3
 5
[torch.FloatTensor of size 3]

Gradient at x[2] Variable containing:
 1
 2
 4
[torch.FloatTensor of size 3]



In [290]:


def generate(n, T):
    global indices
    global binary_indices
    indices = np.random.randint(n, size=(n_samples, T))
    # assume each index has been sampled once
    for t in xrange(n):
        indices[:, t] = t
    
    binary_indices = np.zeros((n_samples, T, n))
        for sample_idx in xrange(n_samples):
        binary_indices[sample_idx, np.arange(T), indices[sample_idx]] = 1.

    binary_indices.argmax(axis=0)

    propagate = np.zeros((n, T, T))
    for u in xrange(T):
        sub = binary_indices[u:, :]
        #print 'sub', sub
        sub_argmax = sub.argmax(axis=0)
        #print 'argmax', sub_argmax
        for i in xrange(n):
            if sub_argmax[i] < T:
                propagate[i, u, sub_argmax[i]] = 1

    #print 'propagate indices', propagate
    return indices, binary_indices, propagate

for sample in xrange(100):
    indices, binary_indices, propagate = generate(n, T)
    propagate.sum(axis=0)

IndentationError: unexpected indent (<ipython-input-290-c4f9455aba82>, line 14)

In [311]:
n = 30
T = 600
n_samples = 200

In [312]:
def get_backpointer(n_samples, T, n):

    indices = np.random.randint(n, size=(n_samples, T))
    # assume each index has been sampled once
    for t in xrange(n):
        indices[:, t] = t

    binary_indices = np.zeros((n_samples, T, n))
    for sample_idx in xrange(n_samples):
        binary_indices[sample_idx, np.arange(T), indices[sample_idx]] = 1.


    backpointer = 666 * np.ones((n_samples, T, n)).astype(int)

    for t in xrange(T):
        up_to_t = binary_indices[:, :t+1]
        #print 'submatrix', up_to_t
        backpointer[:, t, :] = up_to_t[:, ::-1].argmax(axis=1)
        #print 'backpointer', backpointer[:, t, :]  # there is always a 1 (see generate)
        
    return backpointer

backpointer = get_backpointer(n_samples, T, n)

In [313]:
i = 0  # here we sum only over one n
all_dst_abs = []
all_dst_sum = []
for u in xrange(T):
    momentum_acc = []
    sag_acc = []
    for t in xrange(u, T):
        momentum_acc.append(np.ones(n_samples) * beta**(t-u))
        sag_acc.append(backpointer[:, t, i] == (t-u))

    momentum_acc = np.asarray(momentum_acc).T
    sag_acc = np.asarray(sag_acc).astype(np.float).T

    #print 'mommentum', momentum_acc
    #print 'sag', sag_acc

    dst_abs = np.sum(np.abs(momentum_acc - sag_acc), axis=1)
    dst_sum = np.abs(np.sum(momentum_acc-sag_acc, axis=1))

    all_dst_abs.append(dst_abs)
    all_dst_sum.append(dst_sum)
    
all_dst_abs = np.asarray(all_dst_abs)
all_dst_sum = np.asarray(all_dst_sum)
print 'Absolute distance {} +/- {}'.format(all_dst_abs.mean(), all_dst_abs.std())
print 'Sum distance {} +/- {}'.format(all_dst_sum.mean(), all_dst_sum.std())


Absolute distance 3.80413568317 +/- 6.7370711398
Sum distance 3.79881031162 +/- 6.73828417109


In [204]:
u_ = np.arange(T)
t_ = np.arange(T)
u_mesh, t_mesh = np.meshgrid(u_,t_)

t_mesh - u_mesh

beta = (1-1./n)
momentum_moment = beta ** (t_mesh - u_mesh) * (t_mesh>=u_mesh)
momentum_moment = momentum_moment[np.newaxis, :, :] * np.ones((n, 1, 1))

In [205]:
momentum_moment

array([[[ 1.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.66666667,  1.        ,  0.        ,  0.        ,  0.        ],
        [ 0.44444444,  0.66666667,  1.        ,  0.        ,  0.        ],
        [ 0.2962963 ,  0.44444444,  0.66666667,  1.        ,  0.        ],
        [ 0.19753086,  0.2962963 ,  0.44444444,  0.66666667,  1.        ]],

       [[ 1.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.66666667,  1.        ,  0.        ,  0.        ,  0.        ],
        [ 0.44444444,  0.66666667,  1.        ,  0.        ,  0.        ],
        [ 0.2962963 ,  0.44444444,  0.66666667,  1.        ,  0.        ],
        [ 0.19753086,  0.2962963 ,  0.44444444,  0.66666667,  1.        ]],

       [[ 1.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.66666667,  1.        ,  0.        ,  0.        ,  0.        ],
        [ 0.44444444,  0.66666667,  1.        ,  0.        ,  0.        ],
        [ 0.2962963 ,

In [176]:
t_mesh

array([[ True, False, False, ..., False, False, False],
       [False,  True, False, ..., False, False, False],
       [False, False,  True, ..., False, False, False],
       ..., 
       [False, False, False, ...,  True, False, False],
       [False, False, False, ..., False,  True, False],
       [False, False, False, ..., False, False,  True]], dtype=bool)