In [1]:
from video_tools import load_video, save_video
import numpy as np
import spams

In [9]:
def Omega(M, video_size=(100,100)):
    """
    Valor de la norma Omega de la matriz D. los valores de frame_size se usan
    para calcular los indices g que necesita la norma
    """
    video = M.reshape((frame_size[0], frame_size[1], M.shape[1]))
    
    z = np.zeros(M.shape[1])
    for i in range(1, frame_size[0] - 1): 
        for j in range(1, frame_size[1] - 1): 
            window = video[i-1:i+2, j-1:j+2, :]
            vec_window = window.reshape(9, window.shape[2])
            z+= np.max(vec_window, axis=0)
    
    omega = np.sum(z)
    
    return omega

def get_groups(video_size):
    out = None

    for i in range(1, video_size[0] - 1): 
        for j in range(1, video_size[1] - 1): 
            print(i, j)
            window = np.zeros(video_size)
            window[i-1:i+2, j-1:j+2] += 1

            if out is None:
                out = window.flatten()
            else:
                out = np.column_stack((out, window.flatten()))
    
    return out

def augmented_lagrangian(L, S, Y, mu, lamb, D):
    nuc = np.linalg.norm(L, ord='nuc')
    omega = Omega(S)
    tmp = D - L - S
    frob = np.linalg.norm(tmp, ord='fro')
    return nuc + lamb * omega + np.sum(Y * tmp) + 0.5 * mu * frob ** 2

def J(Y, lamb):
    #if not scipy.sparse.issparse(Y):
    #    Y = scipy.sparse.csc_matrix(Y)
    Y_norm = np.linalg.norm(Y)
    return max(Y_norm, 1/lamb * Y_norm)


def soft_threshold(X, eps):
    out = np.zeros(X.shape)
    out[X > eps] = X[X > eps] - eps
    out[X < -eps] = X[X < -eps] + eps
    return out


def video_segmentation(D, video_size, max_iter=500, mu_0=0.001, lamb=0.01):
    Y_k = D / J(D, lamb)
    S_k = np.zeros(D.shape)
    #S_k = D - Y_k
    
    print('Getting groups')
    groups = get_groups(video_size)
    
    print('Building tree')
    tree = {'eta_g': np.ones(groups.shape[1]), 
            'groups': np.zeros((groups.shape[1], groups.shape[1])), 
            'groups_vars': groups}
    
    err = [] 
    for i in range(1, max_iter):
        mu_k = mu_0 * i  # mu_i crece de forma lineal

        print("Iteracion {}".format(i))
        # Resolver L_{k+1} = argmin_L L(L, S_k, Y_k, mu_k)
        tmp = D - S_k + (1 / mu_k) * Y_k
        #print('Converting to sparse')
        #tmp = scipy.sparse.csc_matrix(tmp)
        
        print("Starting SVD")
        U, S, V = np.linalg.svd(D - S_k + (1 / mu_k) * Y_k, full_matrices=False)
        print("Applying Soft-threshold")
        S_shrink = soft_threshold(S, (1 / mu_k))
        L_k = np.dot(U * S_shrink, V)

        # Resolver S_{k+1} = argmin_S L(L_k, S, Y_k, mu_k)
        G_S = D - L_k + (1 / mu_k) * Y_k
        
        S_k = spams.proximalTree(G_S, tree)
        
        print("Updating Y_k")
        # Paso en valor del dual
        Y_k += mu_k * (D - L_k - S_k)
        print(augmented_lagrangian(L=L_k, S=S_k, lamb=lamb, mu=mu_k, Y=Y_k, D=D))

        err.append(np.linalg.norm(D - L_k - S_k, ord='fro') / np.linalg.norm(D, ord='fro'))
        if err[-1] < 1e-10:
            print('Method converged in iteration {}'.format(i))
            break


    L = L_k.copy()
    S = S_k.copy()
    
    return L, S, err

In [None]:
import time

D, video_size = load_video("test.avi")

start = time.time()
print('Started video segmentation')

L, S, err = video_segmentation(D, video_size, max_iter=100, mu_0=.001, lamb=0.001)

print('Finished video segmentation')
print(time.time() - start)

Started video segmentation
Getting groups
1 1
1 2
1 3
1 4
1 5
1 6
1 7
1 8
1 9
1 10
1 11
1 12
1 13
1 14
1 15
1 16
1 17
1 18
1 19
1 20
1 21
1 22
1 23
1 24
1 25
1 26
1 27
1 28
1 29
1 30
1 31
1 32
1 33
1 34
1 35
1 36
1 37
1 38
1 39
1 40
1 41
1 42
1 43
1 44
1 45
1 46
1 47
1 48
1 49
1 50
1 51
1 52
1 53
1 54
1 55
1 56
1 57
1 58
1 59
1 60
1 61
1 62
1 63
1 64
1 65
1 66
1 67
1 68
1 69
1 70
1 71
1 72
1 73
1 74
1 75
1 76
1 77
1 78
1 79
1 80
1 81
1 82
1 83
1 84
1 85
1 86
1 87
1 88
1 89
1 90
1 91
1 92
1 93
1 94
1 95
1 96
1 97
1 98
2 1
2 2
2 3
2 4
2 5
2 6
2 7
2 8
2 9
2 10
2 11
2 12
2 13
2 14
2 15
2 16
2 17
2 18
2 19
2 20
2 21
2 22
2 23
2 24
2 25
2 26
2 27
2 28
2 29
2 30
2 31
2 32
2 33
2 34
2 35
2 36
2 37
2 38
2 39
2 40
2 41
2 42
2 43
2 44
2 45
2 46
2 47
2 48
2 49
2 50
2 51
2 52
2 53
2 54
2 55
2 56
2 57
2 58
2 59
2 60
2 61
2 62
2 63
2 64
2 65
2 66
2 67
2 68
2 69
2 70
2 71
2 72
2 73
2 74
2 75
2 76
2 77
2 78
2 79
2 80
2 81
2 82
2 83
2 84
2 85
2 86
2 87
2 88
2 89
2 90
2 91
2 92
2 93
2 94
2 95
2 96
2 97
2

array([ 1.,  1.,  1.,  1.])