In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec


In [None]:
def grad_function(X, Y, dl_dw, dl_duv, args):
    """Combine two per-axis gradient factors into one gradient field.

    Both callables are invoked as ``f(X, Y, args)`` and must return an
    indexable pair ``(x_component, y_component)``. The result is their
    component-wise product: ``(dl_dw[0] * dl_duv[0], dl_dw[1] * dl_duv[1])``.

    Args:
        X, Y: coordinate arrays of the evaluation grid.
        dl_dw: callable giving the first factor (evaluated first).
        dl_duv: callable giving the second factor.
        args: parameter dict forwarded unchanged to both callables.

    Returns:
        Tuple ``(x_grads, y_grads)``.
    """
    canonical = dl_dw(X, Y, args)
    reparam = dl_duv(X, Y, args)
    return tuple(canonical[axis] * reparam[axis] for axis in (0, 1))


def point_canonical_grad(X, Y, args):
    """Direction from each grid point toward the optimum ``(opt_x, opt_y)``.

    Returns ``(opt_x - X, opt_y - Y)``, which equals the negative gradient of
    the squared-distance loss ``0.5 * ||w - w_opt||^2``.
    """
    target_x, target_y = args['opt_x'], args['opt_y']
    return target_x - X, target_y - Y


def line_canonical_grad(X, Y, args):
    """Negative gradient of the line loss for ``opt_a*x + opt_b*y + opt_c = 0``.

    With residual ``s = (a*x + b*y + c) / (a**2 + b**2)`` this returns
    ``(-s*a, -s*b)``, the negative gradient of
    ``(a*x + b*y + c)**2 / (2*(a**2 + b**2))``, i.e. half the squared
    distance to the line.
    """
    a, b, c = args['opt_a'], args['opt_b'], args['opt_c']
    scaled_residual = (X*a + Y*b + c) / (a**2 + b**2)
    return -scaled_residual * a, -scaled_residual * b


def identity_param_grad(X, Y, args):
    """Unit factor for the canonical (identity) parameterization.

    Returns float arrays of ones shaped like ``X`` and ``Y``, so
    multiplying by them leaves the other factor unchanged. ``args`` is unused.
    """
    return tuple(np.full(np.shape(grid), 1.0) for grid in (X, Y))


def deep_param_grad(X, Y, args):
    """Per-axis factor ``N**2 * w**(2 - 2/N)`` with depth ``N = args['N']``.

    Presumably the squared Jacobian ``(dw/du)**2`` for a reparameterization
    ``w = u**N`` (deep-linear style) — inferred from the formula and the name.
    NOTE(review): for N < 1 the exponent is negative, so w == 0 yields inf;
    negative w with a non-integer exponent yields nan.
    """
    depth = args['N']
    scale = depth**2
    exponent = 2 - (2/depth)
    return scale * X**exponent, scale * Y**exponent


def log_param_grad(X, Y, args):
    """Per-axis factor ``exp(-2*w)``.

    Presumably ``(dw/du)**2`` for ``w = log(u)``, since ``dw/du = 1/u = exp(-w)``
    (inferred from the name). ``args`` is unused.
    """
    return tuple(np.exp(-2*coord) for coord in (X, Y))


def exp_param_grad(X, Y, args):
    """Per-axis factor ``w**2``.

    Presumably ``(dw/du)**2`` for ``w = exp(u)``, since ``dw/du = w``
    (inferred from the name). ``args`` is unused.
    """
    squared_x = X**2
    squared_y = Y**2
    return squared_x, squared_y
    

def polar_param_grad(X, Y, args, canonical_grad):
    """Velocity of w = (X, Y) under gradient flow in polar coordinates (r, theta).

    The point is written as ``(x, y) = (r cos(theta), r sin(theta))``. The
    canonical direction ``g = canonical_grad(X, Y, args)`` is pulled back to
    ``(r, theta)`` with ``J^T g`` and pushed forward again with ``J``, so the
    result is ``J @ J.T @ g``, where ``J = d(x, y)/d(r, theta)``.

    Args:
        X, Y: coordinate arrays (broadcast-compatible).
        args: parameter dict forwarded unchanged to ``canonical_grad``.
        canonical_grad: callable ``(X, Y, args) -> (g_x, g_y)``. The canonical
            grads in this notebook return the negative loss gradient.

    Returns:
        Tuple ``(dx_dt, dy_dt)`` with the induced velocity components.
    """
    R = np.sqrt(X**2 + Y**2)
    # arctan2 returns the full-quadrant angle and is well defined at X == 0,
    # so no epsilon is needed to avoid dividing by zero.
    theta = np.arctan2(Y, X)

    dl_dx, dl_dy = canonical_grad(X, Y, args)

    # Jacobian entries of (x, y) = (r cos(theta), r sin(theta)) w.r.t. (r, theta).
    dx_dr = np.cos(theta)
    dx_dtheta = -R*np.sin(theta)
    dy_dr = np.sin(theta)
    dy_dtheta = R*np.cos(theta)

    # Pull back to parameter space: (dl_dr, dl_dtheta) = J^T g.
    dl_dr = dl_dx*dx_dr + dl_dy*dy_dr
    dl_dtheta = dl_dx*dx_dtheta + dl_dy*dy_dtheta

    # Push forward to the (x, y) plane: velocity = J (J^T g).
    dx_dt = dx_dr*dl_dr + dx_dtheta*dl_dtheta
    dy_dt = dy_dr*dl_dr + dy_dtheta*dl_dtheta

    return dx_dt, dy_dt


def polar_point_grad(X, Y, args):
    """Polar-parameterized flow field for the point loss (uses point_canonical_grad)."""
    return polar_param_grad(X, Y, args, point_canonical_grad)


def polar_line_grad(X, Y, args):
    """Polar-parameterized flow field for the line loss (uses line_canonical_grad)."""
    return polar_param_grad(X, Y, args, line_canonical_grad)


def plot_gradient_field(X, Y, X_grads, Y_grads, grad_norms, title, plot_type, 
                        color, args, additional_plot_params=None, path=None):
    """Streamplot a 2-D gradient field on [0, 10]^2 and overlay the loss optimum.

    Args:
        X, Y: coordinate grids accepted by streamplot (the script builds them
            with np.mgrid).
        X_grads, Y_grads: field components at each grid point.
        grad_norms: per-point magnitude; colors the streamlines when ``color``
            is truthy.
        title: base title. ``', name=value'`` is appended for each key in
            ``additional_plot_params``.
        plot_type: ``'point'`` marks ``(opt_x, opt_y)``. Any other value draws
            the line ``opt_a*w1 + opt_b*w2 + opt_c = 0``.
        color: truthy -> streamlines colored by ``grad_norms`` with a colorbar.
            Falsy -> black streamlines.
        args: loss-parameter dict (plus any keys named in
            ``additional_plot_params``).
        additional_plot_params: keys of ``args`` shown in the title and
            prefixed to the file name. ``None`` (default) means none.
        path: if given, the figure is saved there (after prefixing) before
            being shown.
    """
    # None instead of a mutable [] default; copy so the caller's list is never touched.
    additional_plot_params = list(additional_plot_params or [])
    lim = (0, 10)

    for p in additional_plot_params:
        title += f', {p}={args[p]}'
    if path and additional_plot_params:
        path = '_'.join(str(args[p]) for p in additional_plot_params) + '_' + path

    if color:
        fig, ax = plt.subplots(figsize=(12, 9))
        stream = ax.streamplot(X, Y, X_grads, Y_grads, color=grad_norms, density=3, cmap='copper_r')
        fig.colorbar(stream.lines, ax=ax)
    else:
        fig, ax = plt.subplots(figsize=(9, 9))
        ax.streamplot(X, Y, X_grads, Y_grads, color='k', density=3)

    if plot_type == 'point':
        ax.plot([args['opt_x']], [args['opt_y']], 'go')
    elif args['opt_b'] == 0:
        # Vertical line opt_a*w1 + opt_c = 0 cannot be solved for w2.
        ax.axvline(-args['opt_c'] / args['opt_a'], color='g')
    else:
        # A straight line needs only its endpoints; span the full x-range.
        xs = np.array(lim, dtype=float)
        ax.plot(xs, (-args['opt_c'] - args['opt_a'] * xs) / args['opt_b'], 'g')

    ax.set_xlabel(r'$w_{1}$', fontsize='xx-large')
    ax.set_ylabel(r'$w_{2}$', fontsize='xx-large')
    ax.set_title(title, fontsize='xx-large')
    ax.set_xlim(*lim)
    ax.set_ylim(*lim)
    fig.tight_layout()
    if path:
        fig.savefig(path)
    plt.show()



In [None]:
# Gradient parameters
# ===================
# grad_function multiplies its two callables component-wise. dl_du already
# returns the full polar-parameterized field (canonical loss included), so
# dl_dw is the identity factor and contributes only ones.
# Alternative pairing: dl_dw = line_canonical_grad with a *_param_grad as dl_du.
dl_dw = identity_param_grad
dl_du = polar_line_grad

args = dict(
    opt_x=3,    # optimum for the 'point' loss
    opt_y=6,
    # 'line' loss: opt_a*w1 + opt_b*w2 + opt_c = 0
    # (alternative setting: opt_a=-1, opt_b=-1, opt_c=10)
    opt_a=-1,
    opt_b=-2,
    opt_c=15,
    N=0.25,     # read only by deep_param_grad
)
# Keys of args to append to the title / file name, e.g. ['N'].
additional_plot_params = []


# Plot parameters
# ===============
plot_type = 'line'  # should match the loss used inside dl_du
title = f"Polar Parameterization, {plot_type} loss"
color = False
if color:
    color_str = 'color'
else:
    color_str = 'bw'
path = f'{plot_type}_{dl_du.__name__}_{color_str}.png'


# Gradient field on a 100x100 grid over [0, w]^2
# ==============================================
w = 10  # matches the 0..10 axis limits used by plot_gradient_field
Y, X = np.mgrid[0:w:100j, 0:w:100j]
X_grads, Y_grads = grad_function(X, Y, dl_dw, dl_du, args)
grad_norms = np.sqrt(X_grads**2 + Y_grads**2)

# Plot
# ====
plot_gradient_field(
    X, Y, X_grads, Y_grads, grad_norms, title, plot_type, color, args,
    additional_plot_params=additional_plot_params, path=path,
)

