In [None]:

import os
import os.path as osp
import torch.utils.data
import numpy as np
import tqdm
from tqdm.contrib import tenumerate
from utils import load_data, Model
import pickle as pkl


# Train BFGS

In [None]:
# Interval-evaluation / early-stopping knobs, expressed as fractions of an epoch
# (eval_every, check_step_ratio) and as a step-norm threshold (stop_step_norm).
# NOTE(review): no live code in this notebook reads these; they are kept as the
# documented defaults for the interval-eval / step-norm early-stopping experiments.
eval_every, check_step_ratio, stop_step_norm = 0.1, 0.01, 1e-5
def train_bfgs(batch_size, lr, loss_lambda, num_epochs, save_dir):
    """Train ``Model`` with mini-batch BFGS steps and evaluate after every epoch.

    Loads ``data1_sub.mat`` via ``load_data``. The stacked feature arrays are
    handed to ``Model`` once. The data loaders only yield sample *indices*
    (``'x'``) plus labels (``'y'``), because the kernel is precomputed from
    ``train_data`` / ``test_data``.

    Args:
        batch_size: number of training samples per BFGS step.
        lr: step size passed unchanged to ``model.bfgs`` (no schedule).
        loss_lambda: regularization weight forwarded to ``Model``.
        num_epochs: number of passes over the training set.
        save_dir: directory created up front; nothing is written to it yet.

    Returns:
        ``(epoch_losses, accs)`` with one entry per epoch:
        - the sum of the batch losses (``model.get_loss(...)[0]``);
        - the test accuracy as a fraction of correctly signed predictions.
        An accuracy entry is ``nan`` if the test set is empty.
    """
    os.makedirs(save_dir, exist_ok=True)

    # data
    train_samples, test_samples = load_data('data1_sub.mat')
    train_data = np.stack([sample['x'] for sample in train_samples])
    test_data = np.stack([sample['x'] for sample in test_samples])
    generator = torch.Generator('cpu')
    generator.manual_seed(42)
    # Replace features with their index: the model works from a precomputed kernel.
    train_samples = [{'x': i, 'y': sample['y']} for i, sample in enumerate(train_samples)]
    test_samples = [{'x': i, 'y': sample['y']} for i, sample in enumerate(test_samples)]
    train_loader = torch.utils.data.DataLoader(
        train_samples, batch_size=batch_size, shuffle=True, generator=generator
    )
    # NOTE(review): the test loader shares the seeded generator, so every
    # evaluation pass also advances the RNG that shuffles the training batches.
    test_loader = torch.utils.data.DataLoader(
        test_samples, batch_size=1, shuffle=True, generator=generator,
    )

    # model
    model = Model(1, train_data=train_data, test_data=test_data, loss_lambda=loss_lambda)
    accs = []
    epoch_losses = []

    def evaluate(num_eval_iters=None):
        """Return sign-agreement accuracy over the first ``num_eval_iters`` test batches.

        All test batches are scored when ``num_eval_iters`` is None. Returns
        ``nan`` when no sample was scored.
        """
        num_correct, num_samples = 0, 0
        max_batches = len(test_loader) if num_eval_iters is None else num_eval_iters
        for i, batch in enumerate(test_loader):
            # ``>=`` so exactly ``max_batches`` batches are scored.
            if i >= max_batches:
                break
            x = batch['x'].numpy()
            y = batch['y'].numpy()
            pred = model.predict(x) > 0
            gt = y > 0
            num_correct += np.sum(gt == pred)
            num_samples += y.shape[0]
        if num_samples == 0:
            return float('nan')
        return num_correct / num_samples

    for epoch in range(num_epochs):
        batch_losses = []
        for batch in train_loader:
            x = batch['x'].numpy()
            y = batch['y'].numpy()
            returns = model.get_loss(x, y)
            # Constant step size; the norm returned by model.bfgs is not used here.
            model.bfgs(x, y, returns, lr)
            batch_losses.append(returns[0].item())

        epoch_losses.append(np.sum(batch_losses))
        acc = evaluate()
        accs.append(acc)
        print(f'on epoch {epoch}: acc: {100 * acc: .02f}%')
        print(f'on epoch {epoch}: epoch total loss: {epoch_losses[-1]: .02f}')

    return epoch_losses, accs
            

# Main Hyperparameters

In [None]:
# Shared run settings. Each experiment below stores its per-epoch curves in
# experiment_result_dict under a descriptive name (used as the plot label).
experiment_result_dict = dict()
num_epochs = 200  # 100 was also tried
lr = 10  # constant step size handed to train_bfgs
# NOTE(review): the BFGS run below passes loss_lambda=0 explicitly, so this
# value does not currently reach any training call.
loss_lambda = 1e-3

# BFGS

In [None]:

# TODO(review): there might be a bug in the BFGS update.
# Author's observations: converges with batch size 15, does not converge with batch size >= 20.
bfgs_batch_size = 4000
epoch_losses, accs = train_bfgs(
    batch_size=bfgs_batch_size,
    save_dir='bfgs',
    lr=lr,  # may need a smaller learning rate for larger batches (noted for batch size 100)
    loss_lambda=0,  # regularization disabled for this run
    num_epochs=num_epochs,
)
# Key the results by the batch size actually used so plot legends are accurate.
experiment_result_dict[f'batch_size={bfgs_batch_size}'] = {
    'accs': accs,
    'epoch_losses': epoch_losses,
}


In [None]:
# Disabled experiment: BFGS with batch_size=1 (kept for reference, not executed).
# NOTE(review): if re-enabled, store the result under a key that matches the
# batch size (fixed below) so it does not overwrite or mislabel another run.
# epoch_losses, accs = train_bfgs(
#     batch_size = 1,
#     save_dir = 'bfgs',
#     lr = lr, # seems like need to down scale the learning rate for batch size 100???
#     loss_lambda = 0,
#     num_epochs = num_epochs,
# )
# # converges when bsz = 15
# # doesn't converge when bsz >= 20
# experiment_result_dict['batch_size=1'] = {
#     'accs': accs,
#     'epoch_losses': epoch_losses
# }


# Plotting

In [None]:
# Compact per-experiment summary (final-epoch values) instead of dumping every
# per-epoch list; the full curves are plotted in the next cell.
{
    name: {
        'num_epochs': len(result['accs']),
        'final_acc': result['accs'][-1] if result['accs'] else None,
        'final_epoch_loss': result['epoch_losses'][-1] if result['epoch_losses'] else None,
    }
    for name, result in experiment_result_dict.items()
}

In [None]:
# accs
from matplotlib import pyplot as plt

for experiment_name, experiment_result in experiment_result_dict.items():
    plt.plot(experiment_result['accs'], label=experiment_name)

plt.title('sgd accuracy curve')
plt.show()


# epoch losses

for experiment_name, experiment_result in experiment_result_dict.items():
    plt.plot(np.array(experiment_result['epoch_losses']), label=experiment_name)

plt.title('epoch losses')
plt.legend()
plt.show()