In [None]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm

import seaborn as sns
import matplotlib.pyplot as plt

SEED = 123456
torch.manual_seed(SEED)
np.random.seed(SEED)

# CHOOSE EXAMPLE
# Options: "switching", "selection", or "approximate_optimum".
EXAMPLE = "switching"
FULL_BATCH = True  # True for full-batch gradient descent, False for mini-batch
# Fail fast: an unknown name would otherwise leave criterion1..3 undefined further down.
assert EXAMPLE in ("switching", "selection", "approximate_optimum"), f"unknown EXAMPLE: {EXAMPLE!r}"

In [None]:
input_dim = 20
hidden_dim = 512
output_dim = 7


class Net(nn.Module):
    """Two-layer MLP: Linear(in -> hidden) -> ReLU -> Linear(hidden -> out).

    Layer sizes default to the module-level ``input_dim``, ``hidden_dim`` and
    ``output_dim`` (read at construction time), so ``Net()`` builds the same network
    as before. Pass explicit sizes to build a differently shaped network.
    """

    def __init__(self, in_dim=None, hid_dim=None, out_dim=None):
        super().__init__()
        in_dim = input_dim if in_dim is None else in_dim
        hid_dim = hidden_dim if hid_dim is None else hid_dim
        out_dim = output_dim if out_dim is None else out_dim
        self.fc1 = nn.Linear(in_dim, hid_dim)
        self.fc2 = nn.Linear(hid_dim, out_dim)

    def forward(self, x):
        """Map x with shape [..., in_dim] to outputs with shape [..., out_dim]."""
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
# LOSSES 

def poly_loss(outputs, y_tensor, H, alpha, epsilon=0.0):
    """Batch mean of ((outputs - y_tensor + epsilon)^T H (outputs - y_tensor + epsilon)) ** alpha.

    The quadratic form is computed per row, then raised to ``alpha``. The loss is
    therefore a polynomial of effective degree ``2 * alpha`` in the residual.

    Args:
        outputs (torch.Tensor): model outputs, [BATCH, OUTDIM].
        y_tensor (torch.Tensor): targets, [BATCH, OUTDIM].
        H (torch.Tensor): [OUTDIM, OUTDIM] matrix defining the quadratic form.
        alpha (float): power applied to the quadratic form (may be non-integer, e.g. 1.5).
        epsilon (float): constant added to every residual entry. A non-zero value
            shifts the minimiser away from ``outputs == y_tensor``.

    Returns:
        torch.Tensor: scalar loss (mean over the batch).

    Note:
        For non-integer ``alpha``, a negative quadratic form yields NaN. Use a
        positive semi-definite ``H``.
    """
    diff = outputs - y_tensor + epsilon  # [BATCH, OUTDIM]
    # Row-wise quadratic form diff_i^T H diff_i.
    quad_form = torch.matmul(diff, H)  # [BATCH, OUTDIM]
    quad_form = torch.sum(quad_form * diff, dim=1)  # [BATCH]
    loss = quad_form ** alpha  # effective degree 2 * alpha in diff
    return loss.mean()


if EXAMPLE == "switching":
    # Switching example: same quadratic form, powers 1, 1.5 and 2 (effective degrees 2, 3, 4).
    # When the loss is large, we should see a preference for the third loss function, via w3.
    # As the loss gets smaller, we should see a preference for the first loss function, via w1.
    # The switch is shown in the weight plot.
    H = torch.eye(output_dim)
    criterion1 = lambda outputs, y_tensor: poly_loss(outputs, y_tensor, H, alpha=1.0)
    criterion2 = lambda outputs, y_tensor: poly_loss(outputs, y_tensor, H, alpha=1.5)
    criterion3 = lambda outputs, y_tensor: poly_loss(outputs, y_tensor, H, alpha=2.0)

elif EXAMPLE == "selection":
    # Selection example: same power, different Hessians.
    # We expect a preference for the first loss function, via w1,
    # due to the eigenvalues of the Hessian matrices (H1 = I dominates H2, H3).
    alpha = 1.
    H1 = torch.eye(output_dim)
    criterion1 = lambda outputs, y_tensor: poly_loss(outputs, y_tensor, H1, alpha)
    H2 = 0.01 * torch.eye(output_dim)
    H2[0, 0] = 1.
    criterion2 = lambda outputs, y_tensor: poly_loss(outputs, y_tensor, H2, alpha)
    H3 = 0.0001 * torch.eye(output_dim)
    H3[0, 0] = 1.
    criterion3 = lambda outputs, y_tensor: poly_loss(outputs, y_tensor, H3, alpha)

elif EXAMPLE == "approximate_optimum":
    # Similar to the switching example, but criteria 2 and 3 shift the residual by +/-0.001,
    # so the three losses no longer share exactly the same minimiser.
    H = torch.eye(output_dim)
    criterion1 = lambda outputs, y_tensor: poly_loss(outputs, y_tensor, H, alpha=1.0, epsilon=0.0)
    criterion2 = lambda outputs, y_tensor: poly_loss(outputs, y_tensor, H, alpha=1.5, epsilon=0.001)
    criterion3 = lambda outputs, y_tensor: poly_loss(outputs, y_tensor, H, alpha=2.0, epsilon=-0.001)

else:
    # Without this, criterion1..3 would be undefined and training would fail later with a NameError.
    raise ValueError(
        f'Unknown EXAMPLE {EXAMPLE!r}; expected "switching", "selection" or "approximate_optimum".'
    )



In [None]:
# GENERATE DATA
# X_tensors has shape [n_gradient_steps, batch_size, input_dim]: one batch per gradient step.

if FULL_BATCH:  # full-batch gradient descent: the same batch is repeated at every step
    n_gradient_steps = 200
    batch_size = 100
    X = np.repeat(np.random.uniform(-1, 1, (1, batch_size, input_dim)), repeats=n_gradient_steps, axis=0)
else:  # mini-batch gradient descent: a fresh batch is drawn for every step
    n_gradient_steps = 1000
    batch_size = 1000
    X = np.random.uniform(-1, 1, (n_gradient_steps, batch_size, input_dim))


X_tensors = torch.from_numpy(X).float()
# Targets come from a randomly initialised network with the same architecture as the learner.
net = Net()
with torch.no_grad():  # targets are constants, so no autograd graph is needed
    # Make outputs larger so the losses start big enough to show the weight switch.
    y_tensors = net(X_tensors) / output_dim * 20.

In [None]:
def pamoo(outputs, loss1, loss2, loss3, w, lr=3e-3, n_iter=1000, ridge=1e-4):
    """Update non-negative loss weights by projected gradient ascent (PAMOO step).

    Let J be the matrix whose columns are the gradients of each loss w.r.t. ``outputs``,
    summed over the batch. Let L be the vector of loss values. Starting from ``w``, this
    runs ``n_iter`` steps of gradient ascent on ``2 * L^T w - w^T A w`` with
    ``A = J^T J + ridge * I``. After each step it clamps ``w`` to ``>= 1e-6``.

    Args:
        outputs: network outputs the losses were computed from, [BATCH, OUTDIM].
        loss1, loss2, loss3: scalar losses computed from ``outputs``.
        w: current weights, [NUM_LOSSES]; used as the warm start.
        lr: step size of the inner ascent.
        n_iter: number of inner ascent steps.
        ridge: value added to the diagonal of J^T J.

    Returns:
        torch.Tensor: new weights, [NUM_LOSSES], every entry >= 1e-6.
    """
    losses = [loss1, loss2, loss3]
    # Only first derivatives are needed, so no higher-order graph is built.
    # retain_graph=True keeps the loss graphs alive for the caller's later .backward().
    jacobian = torch.stack(
        [torch.autograd.grad(loss, outputs, retain_graph=True)[0].sum(dim=0) for loss in losses],
        dim=1,
    )  # [OUTDIM, NUM_LOSSES]
    # Detach so the returned weights never carry an autograd graph.
    loss_values = torch.stack([loss.detach() for loss in losses], dim=0)  # [NUM_LOSSES]

    A = jacobian.t() @ jacobian + ridge * torch.eye(len(losses))  # [NUM_LOSSES, NUM_LOSSES]

    for _ in range(n_iter):
        gradient = 2 * loss_values - 2 * torch.matmul(A, w)
        w = torch.clamp(w + lr * gradient, min=1e-6)  # project onto w >= 1e-6
    return w
    
def _hutchinson_diag_hessians(outputs, losses, num_samples):
    """Hutchinson estimate of each loss's Hessian diagonal w.r.t. ``outputs``, summed over the batch.

    Draws ``num_samples`` Rademacher probes per loss from torch's global RNG.
    Returns a [NUM_LOSSES, OUTDIM] tensor.
    """
    hessians = []
    for loss in losses:
        diag_hessian = torch.zeros_like(outputs[0])  # [OUTDIM]
        # create_graph=True so the gradient itself can be differentiated (Hessian-vector products).
        grads = torch.autograd.grad(loss, outputs, create_graph=True)  # ([BATCH, OUTDIM],)
        for _ in range(num_samples):
            # Rademacher probe: every entry is -1 or +1 with equal probability.
            z = 2 * torch.randint_like(outputs, high=2, memory_format=torch.preserve_format) - 1  # [BATCH, OUTDIM]
            # Hessian-vector product H z.
            h_z = torch.autograd.grad(
                grads, [outputs], grad_outputs=[z], only_inputs=True, retain_graph=True
            )[0]  # [BATCH, OUTDIM]
            # z * (H z) estimates diag(H): average over the probes, SUM over the batch rows.
            # Negative estimates are kept as-is (no clipping is applied).
            diag_hessian += (h_z * z / num_samples).sum(dim=0)  # [OUTDIM]
        hessians.append(diag_hessian)
    return torch.stack(hessians, dim=0)  # [NUM_LOSSES, OUTDIM]


def _primal_dual_weights(hessians, n_iter, tau):
    """Primal-dual iteration on the bilinear objective w^T H q over two probability simplices.

    w (one entry per loss) is updated multiplicatively to increase w^T H q.
    q (one entry per output dimension) is updated to decrease it.
    Each iteration first forms look-ahead ("extra") iterates from the current (w, q).
    It then updates w and q using the opponent's look-ahead iterate.
    The ``pow(., 1 - lr * tau)`` factor shrinks both towards uniform; ``tau`` sets its strength.
    Returns w, [NUM_LOSSES], normalised to sum to 1.
    """
    num_losses, out_dim = hessians.shape
    w = torch.ones(num_losses)
    q = torch.ones(out_dim) / out_dim  # [OUTDIM]
    lr = 1 / (2 * hessians.abs().max() + tau).item()
    decay = 1 - lr * tau

    for _ in range(n_iter):
        # Look-ahead step from the current (w, q); max/min are subtracted for numerical stability.
        A = hessians @ q
        extra_w = torch.pow(w, exponent=decay) * torch.exp(lr * (A - torch.max(A)))
        extra_w = extra_w / torch.sum(extra_w)
        B = w @ hessians
        extra_q = torch.pow(q, exponent=decay) * torch.exp(-lr * (B - torch.min(B)))
        extra_q = extra_q / torch.sum(extra_q)

        # Actual step: w ascends against extra_q, q descends against extra_w.
        A = hessians @ extra_q
        w = torch.pow(w, exponent=decay) * torch.exp(lr * (A - torch.max(A)))
        w = w / torch.sum(w)
        B = extra_w @ hessians
        q = torch.pow(q, exponent=decay) * torch.exp(-lr * (B - torch.min(B)))
        q = q / torch.sum(q)
    return w


def camoo(outputs, loss1, loss2, loss3, w, num_samples=10, n_iter=100, tau=0.01):
    """Compute loss weights from estimated output-space curvature (CAMOO step).

    Estimates the diagonal Hessian of each loss w.r.t. ``outputs`` (Hutchinson, summed
    over the batch). Then runs a primal-dual iteration in which the loss weights w
    maximise and a distribution q over output dimensions minimises w^T H q.

    Args:
        outputs: network outputs the losses were computed from, [BATCH, OUTDIM].
        loss1, loss2, loss3: scalar losses computed from ``outputs``.
        w: ignored; the primal-dual iteration restarts from uniform weights on every call.
        num_samples: Rademacher probes per loss for the Hutchinson estimate.
        n_iter: primal-dual iterations (set high to get the maximum visualization effect).
        tau: regularisation strength of the primal-dual updates.

    Returns:
        torch.Tensor: weights, [NUM_LOSSES], non-negative and summing to 1.
    """
    del w  # incoming weights are not used
    hessians = _hutchinson_diag_hessians(outputs, [loss1, loss2, loss3], num_samples)
    return _primal_dual_weights(hessians, n_iter, tau)


In [None]:

def run_experiment(method=pamoo, optimizer=torch.optim.SGD, lr=0.005):
    """Train a freshly initialised Net on the global (X_tensors, y_tensors) batches.

    At every step criterion1..3 are evaluated on the current batch. If ``method`` is
    given, it updates the loss weights ``w`` under ``torch.no_grad()``. One optimizer
    step is then taken on ``w[0]*loss1 + w[1]*loss2 + w[2]*loss3``. With
    ``method=None`` the weights stay at 1 (equal weighting).

    Args:
        method: callable(outputs, loss1, loss2, loss3, w) -> new w, or None.
        optimizer: optimizer class, instantiated on the network parameters.
        lr: learning rate handed to the optimizer.

    Returns:
        dict of per-step lists: 'weighted_losses', 'MSE' (logged only, not optimised),
        'losses' -> {'loss1', 'loss2', 'loss3'} and 'weights' -> {'w1', 'w2', 'w3'}.
    """
    model = Net()
    w = torch.tensor([1.0, 1.0, 1.0])  # equal starting weights
    opt = optimizer(model.parameters(), lr=lr)
    mse_loss = nn.MSELoss()

    results = {
        'weighted_losses': [],
        'MSE': [],
        'losses': {'loss1': [], 'loss2': [], 'loss3': []},
        'weights': {'w1': [], 'w2': [], 'w3': []}
    }
    loss_log, weight_log = results['losses'], results['weights']

    batches = zip(X_tensors, y_tensors)
    for X_tensor, y_tensor in tqdm(batches, total=len(X_tensors), desc='Training'):
        outputs = model(X_tensor)
        loss1, loss2, loss3 = (crit(outputs, y_tensor) for crit in (criterion1, criterion2, criterion3))
        mse = mse_loss(outputs, y_tensor)

        if method:
            # The weight update itself is not part of the training graph.
            with torch.no_grad():
                w = method(outputs, loss1, loss2, loss3, w)

        weighted_loss = w[0] * loss1 + w[1] * loss2 + w[2] * loss3
        opt.zero_grad()
        weighted_loss.backward()
        opt.step()

        results['weighted_losses'].append(weighted_loss.item())
        results['MSE'].append(mse.item())
        for key, value in (('loss1', loss1), ('loss2', loss2), ('loss3', loss3)):
            loss_log[key].append(value.item())
        for idx, key in enumerate(('w1', 'w2', 'w3')):
            weight_log[key].append(w[idx].item())
    return results


In [None]:
# Run experiments.
# EW = equal weighting: method=None keeps all three weights at 1.
# CAMOO returns weights that sum to 1 instead of 3, so its learning rate is
# num_losses (3x) times that of the matching EW / PAMOO run.
sgd_results = run_experiment(method=None, optimizer=torch.optim.SGD, lr=0.0005)
camoo_results = run_experiment(method=camoo, optimizer=torch.optim.SGD, lr=0.0015)
pamoo_results = run_experiment(method=pamoo, optimizer=torch.optim.SGD, lr=0.0005)
adam_results = run_experiment(method=None, optimizer=torch.optim.Adam, lr=0.005)
cadam_results = run_experiment(method=camoo, optimizer=torch.optim.Adam, lr=0.015)
padam_results = run_experiment(method=pamoo, optimizer=torch.optim.Adam, lr=0.005)


In [None]:
def plot_results(results_dict, filename):
    """Plot each criterion and the weighted loss for every run in a 2x2 grid (log y-axis).

    Args:
        results_dict: maps a legend label to a results dict from run_experiment
            (uses 'losses' -> 'loss1'/'loss2'/'loss3' and 'weighted_losses').
        filename: path passed to ``Figure.savefig``. Without an extension,
            matplotlib's default save format is used.
    """
    fig, axs = plt.subplots(2, 2, figsize=(12, 8))

    for method_name, results in results_dict.items():
        axs[0, 0].plot(results['losses']['loss1'], label=method_name)
        axs[0, 1].plot(results['losses']['loss2'], label=method_name)
        axs[1, 0].plot(results['losses']['loss3'], label=method_name)
        axs[1, 1].plot(results['weighted_losses'], label=method_name)

    # Raw strings: '\e' and '\s' are not valid Python escapes.
    # Powers match the alpha values used by criterion1..3 for each example.
    titles = {
        "switching": (
            r'$f_1(x)=((y-t)^T H (y-t))^1$',
            r'$f_2(x)=((y-t)^T H (y-t))^{1.5}$',
            r'$f_3(x)=((y-t)^T H (y-t))^2$',
        ),
        "selection": (
            r'$f_1(x)=(y-t)^T H_{1} (y-t)$',
            r'$f_2(x)=(y-t)^T H_{0.01} (y-t)$',
            r'$f_3(x)=(y-t)^T H_{0.0001} (y-t)$',
        ),
        "approximate_optimum": (
            r'$f_1(x)=((y-t)^T H (y-t))$',
            r'$f_2(x)=((y-t+\epsilon)^T H (y-t+\epsilon))^{1.5}$',
            r'$f_3(x)=((y-t-\epsilon)^T H (y-t-\epsilon))^2$',
        ),
    }
    if EXAMPLE in titles:
        for ax, title in zip((axs[0, 0], axs[0, 1], axs[1, 0]), titles[EXAMPLE]):
            ax.set_title(title)
    axs[1, 1].set_title(r'$f(x)= \sum w_i f_i(x)$')

    for ax in axs.flat:
        ax.set_yscale('log')
        ax.legend()

    fig.tight_layout()
    fig.savefig(filename, bbox_inches='tight')
    plt.show()


# Legend label -> results of the corresponding run (insertion order = plotting order).
results_dict = dict(
    [
        ('EW-SGD', sgd_results),
        ('CAMOO-SGD', camoo_results),
        ('PAMOO-SGD', pamoo_results),
        ('EW-ADAM', adam_results),
        ('CAMOO-ADAM', cadam_results),
        ('PAMOO-ADAM', padam_results),
    ]
)

plot_results(results_dict, 'all_plots')

In [None]:

def plot_weights(results, filename):
    """Plot one run's MSE (left axis) together with its loss weights w1..w3 (right axis).

    Args:
        results: a results dict from run_experiment (uses 'MSE' and 'weights').
        filename: path passed to ``Figure.savefig``.
    """
    fig, ax1 = plt.subplots()
    # Left axis: mean squared error per gradient step.
    ax1.plot(results['MSE'], color='blue')
    ax1.set_xlabel('Gradient Steps')
    ax1.set_ylabel('Mean Squared Error', color='blue')
    ax1.tick_params(axis='y', labelcolor='blue')
    ax1.grid(False)  # Turn off grid for left subplot
    # ax1.set_yscale('log')

    # Right axis (shared x): the weight assigned to each loss.
    ax2 = ax1.twinx()
    colors = plt.cm.Reds(np.linspace(0.5, 1, 3))  # Generate 3 different red tones
    # Raw strings: '\e' is not a valid Python escape.
    # Powers match the alpha values used by criterion1..3 for each example.
    labels = {
        "switching": (
            r'$w_1$  scales $f_1(x)=((y-t)^T H (y-t))$',
            r'$w_2$  scales $f_2(x)=((y-t)^T H (y-t))^{1.5}$',
            r'$w_3$  scales $f_3(x)=((y-t)^T H (y-t))^2$',
        ),
        "selection": (
            r'$w_1$  scales $f_1(x)=(y-t)^T H_{1} (y-t)$',
            r'$w_2$  scales $f_2(x)=(y-t)^T H_{0.01} (y-t)$',
            r'$w_3$  scales $f_3(x)=(y-t)^T H_{0.0001} (y-t)$',
        ),
        "approximate_optimum": (
            r'$w_1$  scales $f_1(x)=((y-t)^T H (y-t))$',
            r'$w_2$  scales $f_2(x)=((y-t+\epsilon)^T H (y-t+\epsilon))^{1.5}$',
            r'$w_3$  scales $f_3(x)=((y-t-\epsilon)^T H (y-t-\epsilon))^2$',
        ),
    }
    for key, color, label in zip(('w1', 'w2', 'w3'), colors, labels.get(EXAMPLE, ())):
        ax2.plot(results['weights'][key], color=color, label=label)
    ax2.set_ylabel('Weight for respective Loss Function', color='red')
    ax2.tick_params(axis='y', labelcolor='red')
    ax2.grid(False)  # Turn off grid for right subplot
    ax2.legend(loc='upper right')
    fig.savefig(filename, bbox_inches='tight')
    plt.show()


# plot_weights(results_dict['PAMOO-SGD'],'weights_PAMOO-SGD')
# plot_weights(results_dict['PAMOO-ADAM'],'weights_PAMOO-ADAM')
plot_weights(results_dict['CAMOO-SGD'], 'weights_CAMOO-SGD')
# plot_weights(results_dict['CAMOO-ADAM'], 'weights_CAMOO-ADAM')


In [None]:
def plot_mse(results_dict, filename):
    """Plot the MSE training curve of every run on a single log-scale axis.

    Args:
        results_dict: maps a legend label to a results dict from run_experiment (uses 'MSE').
        filename: path passed to ``Figure.savefig``.
    """
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))

    for method_name, results in results_dict.items():
        ax.plot(results['MSE'], label=method_name)

    ax.set_ylabel('Mean Squared Error')
    ax.set_xlabel('Gradient Steps')
    # Optional zoom if needed, e.g. ax.set_ylim(1e-3, 2e0) and ax.set_xlim(0, 200).
    ax.set_yscale('log')
    ax.legend()

    fig.tight_layout()
    fig.savefig(filename, bbox_inches='tight')
    plt.show()


# Backward compatibility: as before, running this cell leaves `plot_results` bound to the MSE
# plot. Re-run the 2x2 loss-plot cell before calling that version again.
plot_results = plot_mse


results_dict = {
    'EW-SGD': sgd_results,
    'CAMOO-SGD': camoo_results,
    'PAMOO-SGD': pamoo_results,
    'EW-ADAM': adam_results,
    'CAMOO-ADAM': cadam_results,
    'PAMOO-ADAM': padam_results,
}

plot_mse(results_dict, 'MSE_plot')