# In [None]:
def learn_HAT_adagrad(case, E_source, E_target, a, b, num_iter=1000, lr=1, dis=False, cost_function='abs',
                      H_source_known=None, A_known=None, T_source_known=None,
                      H_target_known=None, T_target_known=None,
                      random_seed=0, eps=1e-8, penalty_coeff=0.0, alpha=None):
    """Jointly factorise a source and a target tensor that share one ``A`` factor, using AdaGrad.

    Each tensor is approximated as ``multiply_case(H, A, T, case)``. The source and
    target have their own ``H``/``T`` factors, while ``A`` is shared. The minimised
    objective is::

        alpha * RMSE_source + (1 - alpha) * RMSE_target + penalty_coeff * ||A||_{2,1}

    RMSEs are taken over observed (non-NaN) entries only. After every step, any
    provided known values are re-imposed with ``set_known``. Every negative entry
    is then reset to ``1e-8`` (projection onto the non-negative orthant).

    Parameters
    ----------
    case : key into the global ``cases`` table. It selects the factor layout
        and is also passed to ``multiply_case``.
    E_source, E_target : 3-D arrays to factorise; NaN marks a missing entry.
    a, b : not used by this function.
    num_iter : number of AdaGrad iterations.
    lr : base learning rate.
    dis : if True, print the cost every 500 iterations.
    cost_function : not used by this function.
    H_source_known, A_known, T_source_known, H_target_known, T_target_known :
        optional arrays passed to ``set_known`` after each step.
    random_seed : seed for ``np.random``, which draws the initial factors.
    eps : added to each squared-gradient accumulator every step. This keeps
        the AdaGrad divisor away from zero.
    penalty_coeff : weight of the L2,1 penalty on ``A``.
    alpha : weight of the source RMSE. The target RMSE gets ``1 - alpha``.
        If None, it defaults to the source's share of the combined
        first-axis sizes of ``H_source`` and ``H_target``.

    Returns
    -------
    tuple
        ``(H_source, A, T_source, H_target, T_target, Hs_source, As, Ts_source,
        Hs_target, Ts_target, HATs_source, HATs_target, costs)``. The list
        entries hold the per-iteration history, starting with the initial values.
    """

    def cost_l21(H_source, A, T_source, H_target, T_target, E_source, E_target, case, lam=0.1, alpha=None):
        """Return the weighted masked RMSE of both reconstructions plus ``lam`` * L2,1 norm of ``A``."""
        if alpha is None:
            n_source = np.shape(H_source)[0]
            n_target = np.shape(H_target)[0]
            alpha = 1. * n_source / (n_source + n_target)

        HAT_source = multiply_case(H_source, A, T_source, case)
        # The target reconstruction must use the target factors; the shared A is common to both.
        HAT_target = multiply_case(H_target, A, T_target, case)

        # Only observed (non-NaN) entries contribute; boolean indexing already yields 1-D arrays.
        error_source = (HAT_source - E_source)[~np.isnan(E_source)]
        error_target = (HAT_target - E_target)[~np.isnan(E_target)]

        # L2,1 norm: sum, over slices along A's first axis, of each slice's Euclidean norm.
        A_flat = A.reshape(A.shape[0], -1)
        l21 = np.sum(np.sqrt(np.sum(np.square(A_flat), axis=1)))

        return (alpha * np.sqrt((error_source ** 2).mean())
                + (1 - alpha) * np.sqrt((error_target ** 2).mean())
                + lam * l21)

    # Gradients w.r.t. H_source, A, T_source, H_target, T_target (in that order).
    mg = multigrad(cost_l21, argnums=[0, 1, 2, 3, 4])

    params_source = {}
    params_target = {}
    params_source['M'], params_source['N'], params_source['O'] = E_source.shape
    params_target['M'], params_target['N'], params_target['O'] = E_target.shape

    def create_initial_HAT(params):
        """Draw uniform-random H, A, T with dimensions taken from the ``cases[case]`` layout strings."""
        H_dim_chars = list(cases[case]['HA'].split(",")[0].strip())
        H_dim = tuple(params[x] for x in H_dim_chars)
        A_dim_chars = list(cases[case]['HA'].split(",")[1].split("-")[0].strip())
        A_dim = tuple(params[x] for x in A_dim_chars)
        T_dim_chars = list(cases[case]['HAT'].split(",")[1].split("-")[0].strip())
        T_dim = tuple(params[x] for x in T_dim_chars)
        H = np.random.rand(*H_dim)
        A = np.random.rand(*A_dim)
        T = np.random.rand(*T_dim)
        return H, A, T

    # Make the random initialisation reproducible.
    np.random.seed(random_seed)
    H_source, _, T_source = create_initial_HAT(params_source)
    # A is shared between source and target; keep the one drawn in the second call.
    H_target, A, T_target = create_initial_HAT(params_target)

    # AdaGrad running sums of squared gradients, one per parameter.
    sum_square_gradients_H_source = np.zeros_like(H_source)
    sum_square_gradients_H_target = np.zeros_like(H_target)
    sum_square_gradients_A = np.zeros_like(A)
    sum_square_gradients_T_source = np.zeros_like(T_source)
    sum_square_gradients_T_target = np.zeros_like(T_target)

    # Histories start with the initial values.
    Hs_source = [H_source.copy()]
    Hs_target = [H_target.copy()]
    As = [A.copy()]
    Ts_source = [T_source.copy()]
    Ts_target = [T_target.copy()]

    costs = [cost_l21(H_source, A, T_source, H_target, T_target, E_source, E_target, case,
                      penalty_coeff, alpha=alpha)]
    HATs_source = [multiply_case(H_source, A, T_source, case)]
    HATs_target = [multiply_case(H_target, A, T_target, case)]

    # AdaGrad procedure
    for i in range(num_iter):
        del_h_source, del_a, del_t_source, del_h_target, del_t_target = mg(
            H_source, A, T_source, H_target, T_target,
            E_source, E_target, case, penalty_coeff, alpha)

        sum_square_gradients_H_source += eps + np.square(del_h_source)
        sum_square_gradients_H_target += eps + np.square(del_h_target)
        sum_square_gradients_A += eps + np.square(del_a)
        sum_square_gradients_T_source += eps + np.square(del_t_source)
        sum_square_gradients_T_target += eps + np.square(del_t_target)

        # Per-element step size lr / sqrt(accumulated squared gradients).
        H_source -= np.divide(lr, np.sqrt(sum_square_gradients_H_source)) * del_h_source
        H_target -= np.divide(lr, np.sqrt(sum_square_gradients_H_target)) * del_h_target
        A -= np.divide(lr, np.sqrt(sum_square_gradients_A)) * del_a
        T_source -= np.divide(lr, np.sqrt(sum_square_gradients_T_source)) * del_t_source
        T_target -= np.divide(lr, np.sqrt(sum_square_gradients_T_target)) * del_t_target

        # Projection to known values
        if H_source_known is not None:
            H_source = set_known(H_source, H_source_known)
        if H_target_known is not None:
            H_target = set_known(H_target, H_target_known)
        if A_known is not None:
            A = set_known(A, A_known)
        if T_source_known is not None:
            T_source = set_known(T_source, T_source_known)
        if T_target_known is not None:
            T_target = set_known(T_target, T_target_known)

        # Projection to non-negative space (small positive value instead of exact 0)
        H_source[H_source < 0] = 1e-8
        H_target[H_target < 0] = 1e-8
        A[A < 0] = 1e-8
        T_source[T_source < 0] = 1e-8
        T_target[T_target < 0] = 1e-8

        As.append(A.copy())
        Ts_source.append(T_source.copy())
        Ts_target.append(T_target.copy())
        Hs_source.append(H_source.copy())
        Hs_target.append(H_target.copy())

        costs.append(cost_l21(H_source, A, T_source, H_target, T_target, E_source, E_target,
                              case, penalty_coeff, alpha=alpha))
        HATs_source.append(multiply_case(H_source, A, T_source, case))
        HATs_target.append(multiply_case(H_target, A, T_target, case))

        if dis and i % 500 == 0:
            # The cost for the current parameters was just computed; no need to recompute it.
            print(costs[-1])

    # Parenthesised so the full 13-item tuple is returned (a bare line continuation
    # after a trailing comma would silently drop the last five items).
    return (H_source, A, T_source, H_target, T_target,
            Hs_source, As, Ts_source,
            Hs_target, Ts_target,
            HATs_source, HATs_target, costs)
