In [1]:
import torch
import torch.nn as nn
import pandas as pd
import math
import random
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

t = 0.5
seed = 42
MAX_EPOCHS = int(1e6)

# problem setup: min( ||p*a - t||^2 + \lambda*\sum{(p_i)*(1-p_i)*a_i^2})

def get_regularization(n, p, a, i=0):
    reg = torch.sum(p * (1-p) * torch.pow(a, 2))
    if i == 0:
        lmbda = 1.0
    elif i == 1:
        lmbda = 1.0/(n**math.log(n))
    elif i == 2:
        lmbda = 1.0/n
    elif i == 3:
        lmbda = 1/n**2
    else:
        lmbda = 0
    
    reg = lmbda * reg
    return reg

def set_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    print("Seeded everything: {}".format(seed))
    
set_seed(seed)

Seeded everything: 42


In [2]:
# decide regularizer

def subset_sum(num_samples=10, regularizer=0):
    a = torch.zeros(num_samples)
    p = nn.Parameter(torch.Tensor(a.size()))
    a.requires_grad = False
    p_list = []
    loss_list = []
    epoch_list = []

    optimizer = optim.Adam(
                [p],
                lr=0.01,
                weight_decay=0)

    # initialize a as uniform [-1, 1]
    nn.init.uniform_(a, a=-1, b=1)
    nn.init.normal_(p, mean=0.5)
    p.data = torch.clamp(p.data, 0.0, 1.0)

    for num_iter in range(MAX_EPOCHS):
        optimizer.zero_grad()
        loss = (t - torch.sum(p*a))**2 + get_regularization(num_samples, p, a, i=regularizer)
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
        epoch_list.append(num_iter)
        p_list.append(p.data)
        if num_iter % 100000 == 0 and num_iter != 0:
            print("Iteration={} | Loss={}".format(num_iter, loss))
        p.data = torch.clamp(p.data, 0.0, 1.0)
        if num_iter > 5 and loss.item() == loss_list[-2]:
            # print("Iteration={} | Converged".format(num_iter))
            break
        
    min_error = (1.0/num_samples)**np.log(num_samples)
    # print("Minimum Error = {}".format(min_error))
    results_df = pd.DataFrame({'epoch': epoch_list, 'loss': loss_list})
    return min_error, p, results_df

def get_frac_nonzeros(x, num_samples):
    num_middle = torch.sum(torch.gt(x,
                           torch.ones_like(x)*0) *
                           torch.lt(x,
                           torch.ones_like(x*(1)).int()))
    return 1.0*num_middle/num_samples

def get_dist_to_vertex(x):
    rounded_x = torch.gt(x, torch.ones_like(x)*0.5).int().float()
    return torch.norm(x-rounded_x)

In [3]:
avg_error_ratios = []
for num_samples in [1e2, 1e3,]:
    print("Num Samples = {}".format(num_samples))
    error_ratio_list = []
    for i in range(1):
        for reg in range(5):
            num_samples = int(num_samples)
            min_error, p_final, results_df = subset_sum(num_samples, 0)
            final_error = results_df.tail(1).loss.item()
            frac_nonzeros = get_frac_nonzeros(p_final, num_samples)
            dist_to_vertex = get_dist_to_vertex(p_final)
            print("Regularizer = {} | Error_ratio={} | Final error={} | Minimum error={} | frac_nonzeros={} | dist_to_veretx={}".\
                  format(reg, final_error/min_error, final_error, min_error, frac_nonzeros, dist_to_vertex))
            error_ratio = final_error/min_error
            error_ratio_list.append(error_ratio)
        avg_error_ratio = 1.0*sum(error_ratio_list)/len(error_ratio_list)
        avg_error_ratios.append(avg_error_ratio)

Num Samples = 100.0
Regularizer = 0 | Error_ratio=262636.16966381756 | Final error=0.00016181328101083636 | Minimum error=6.161119438269389e-10 | frac_nonzeros=0.0 | dist_to_veretx=0.0
Regularizer = 1 | Error_ratio=827252.2892519373 | Final error=0.0005096800159662962 | Minimum error=6.161119438269389e-10 | frac_nonzeros=0.0 | dist_to_veretx=0.0
Regularizer = 2 | Error_ratio=4426.375523687913 | Final error=2.727142828007345e-06 | Minimum error=6.161119438269389e-10 | frac_nonzeros=0.0 | dist_to_veretx=0.0
Regularizer = 3 | Error_ratio=11015.96240943746 | Final error=6.7870660132030025e-06 | Minimum error=6.161119438269389e-10 | frac_nonzeros=0.0 | dist_to_veretx=0.0
Regularizer = 4 | Error_ratio=1257533.1925051364 | Final error=0.0007747812196612358 | Minimum error=6.161119438269389e-10 | frac_nonzeros=0.0 | dist_to_veretx=0.0
Num Samples = 1000.0
Regularizer = 0 | Error_ratio=1287913568553.491 | Final error=2.4356836547667626e-09 | Minimum error=1.8911856464889798e-21 | frac_nonzeros=