In [1]:
import argparse
import os, sys 
import time
import tabulate

sys.path.append('/home/mohit/understandingbdl/')
import torch
import torch.nn.functional as F
import torchvision
import numpy as np

from swag import data, models, utils, losses
from swag.posteriors import SWAG

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
#parser = argparse.ArgumentParser(description='SGD/SWA training')
savedir = ''
with open('swagModelPaths.txt', 'r') as f:
    swag_ckpts = f.readlines()[0][1:-1].split(' ')
dataset = 'CIFAR10'
data_path = '/home/mohit/cifar-normal/cifar-10-batches-py/'
use_test = True
batch_size = 256
num_workers = 4
model = 'PreResNet20'
label_arr = None
max_num_models = 20
swag_samples = 5
inference = 'low_rank_gaussian'
subspace = 'covariance'
no_cov_mat = False

device = None
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
torch.backends.cudnn.benchmark = True
model_cfg = getattr(models, model)
columns = ['swag', 'sample', 'te_loss', 'te_acc', 'ens_loss', 'ens_acc']

In [3]:
swag_ckpts

['/media/data_dump/Mohit/bayesianML/models/iter16/swag-300.pt',
 '/media/data_dump/Mohit/bayesianML/models/iter10/swag-300.pt',
 '/media/data_dump/Mohit/bayesianML/models/iter3/swag-300.pt',
 '/media/data_dump/Mohit/bayesianML/models/iter17/swag-300.pt',
 '/media/data_dump/Mohit/bayesianML/models/iter14/swag-300.pt',
 '/media/data_dump/Mohit/bayesianML/models/iter23/swag-300.pt',
 '/media/data_dump/Mohit/bayesianML/models/iter19/swag-300.pt',
 '/media/data_dump/Mohit/bayesianML/models/iter28/swag-300.pt',
 '/media/data_dump/Mohit/bayesianML/models/iter31/swag-300.pt',
 '/media/data_dump/Mohit/bayesianML/models/iter7/swag-300.pt',
 '/media/data_dump/Mohit/bayesianML/models/iter25/swag-300.pt',
 '/media/data_dump/Mohit/bayesianML/models/iter2/swag-300.pt',
 '/media/data_dump/Mohit/bayesianML/models/iter20/swag-300.pt',
 '/media/data_dump/Mohit/bayesianML/models/iter15/swag-300.pt',
 '/media/data_dump/Mohit/bayesianML/models/iter32/swag-300.pt',
 '/media/data_dump/Mohit/bayesianML/models/

In [4]:
loaders, num_classes = data.loaders(
    dataset,
    data_path,
    batch_size,
    num_workers,
    model_cfg.transform_train,
    model_cfg.transform_test,
    use_validation=not use_test,
    split_classes=None
    )

Files already downloaded and verified
You are going to run models on the test set. Are you sure?
Files already downloaded and verified


In [5]:
print('Using Noisy Test set')
corrupted_testset = np.load("/media/data_dump/Mohit/bayesianML/cifar-corrupted/motion_blur_5.npz")
loaders['test'].dataset.data = corrupted_testset["data"]
loaders['test'].dataset.targets = corrupted_testset["labels"]

Using Noisy Test set


In [6]:
if label_arr:
    print("Using labels from {}".format(label_arr))
    label_arr = np.load(label_arr)
    print("Corruption:", (loaders['train'].dataset.targets != label_arr).mean())
    loaders['train'].dataset.targets = label_arr

In [7]:
print('Preparing model')
model = model_cfg.base(*model_cfg.args, num_classes=num_classes,
                       **model_cfg.kwargs)
model.to(device)
print("Model has {} parameters".format(sum([p.numel() for p in model.parameters()])))


swag_model = SWAG(model_cfg.base,
                subspace, {'max_rank': max_num_models},
                *model_cfg.args, num_classes=num_classes,
                **model_cfg.kwargs)
swag_model.to(device)

Preparing model
Model has 272282 parameters


In [None]:
from ax import load
def getBoAcc(Pi, expPath):
    experiment = load_experiment(expPath)
    print("Computing for base case")
    # Base case = uniform weights for each SWAG and its models
    swagCount = 0
    for ckpt in swag_ckpts:
        swagCount += 1
    n_ensembled = 0.
    multiswag_probs = None
    for ckpt_i, ckpt in enumerate(swag_ckpts):
        print("Checkpoint {}".format(ckpt))
        checkpoint = torch.load(ckpt)
        swag_model.subspace.rank = torch.tensor(0)
        swag_model.load_state_dict(checkpoint['state_dict'])

        for sample in range(swag_samples):
            swag_model.sample(.5)
            utils.bn_update(loaders['train'], swag_model)
            res = utils.predict(loaders['test'], swag_model)
            probs = res['predictions']
            targets = res['targets']
            nll = utils.nll(probs, targets)
            acc = utils.accuracy(probs, targets)

            if multiswag_probs is None:
                multiswag_probs = probs.copy()
            else:
                #TODO: rewrite in a numerically stable way
                multiswag_probs +=  (probs - multiswag_probs)/ (n_ensembled + 1)
            n_ensembled += 1

            ens_nll = utils.nll(multiswag_probs, targets)
            ens_acc = utils.accuracy(multiswag_probs, targets)
            values = [ckpt_i, sample, nll, acc, ens_nll, ens_acc]
            table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
            print(table)
    initialPi = [1/swagCount]*swagCount
    return (initialPi, ens_nll, ens_acc)

In [18]:
def boSwag(Pi):
    n_ensembled = 0.
    multiswag_probs = None
    for ckpt_i, ckpt in enumerate(swag_ckpts):
        #print("Checkpoint {}".format(ckpt))
        checkpoint = torch.load(ckpt)
        swag_model.subspace.rank = torch.tensor(0)
        swag_model.load_state_dict(checkpoint['state_dict'])
        #swagWeight = Pi[ckpt]
        swagWeight = Pi[ckpt]/sum([Pi[i] for i in Pi])
        indivWeight = swagWeight/swag_samples

        for sample in range(swag_samples):
            swag_model.sample(.5)
            utils.bn_update(loaders['train'], swag_model)
            res = utils.predict(loaders['test'], swag_model)
            probs = res['predictions']
            targets = res['targets']
            nll = utils.nll(probs, targets)
            acc = utils.accuracy(probs, targets)

            if multiswag_probs is None:
                multiswag_probs = indivWeight*probs.copy()
            else:
                #TODO: rewrite in a numerically stable way
                #multiswag_probs +=  (probs - multiswag_probs)/ (n_ensembled + 1)
                multiswag_probs += indivWeight*probs.copy()
            n_ensembled += 1

            ens_nll = utils.nll(multiswag_probs, targets)
            ens_acc = utils.accuracy(multiswag_probs, targets)
            values = [ckpt_i, sample, nll, acc, ens_nll, ens_acc]
            #table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
            #print(table)
    if useMetric == 'nll':
        return ens_nll
    else:
        return -ens_acc

In [14]:
Pi = {}
for ckpt_i, ckpt in enumerate(swag_ckpts):
    Pi[ckpt] = 0

In [15]:
# https://ax.dev/tutorials/gpei_hartmann_developer.html
from ax import (
    ComparisonOp,
    ParameterType, 
    RangeParameter,
    SearchSpace, 
    SimpleExperiment, 
    OutcomeConstraint,
    SumConstraint
)
from ax.modelbridge.registry import Models

parameters = []
for i in Pi:
    parameters.append(RangeParameter(name=i, parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0))
#probCons = SumConstraint(parameters=[i for i in parameters], is_upper_bound=True, bound=1.0)

#swagSearchSpace = SearchSpace(parameters=parameters, parameter_constraints=[probCons])
swagSearchSpace = SearchSpace(parameters=parameters)

exp = SimpleExperiment(
    name="BO_SWAG",
    search_space=swagSearchSpace,
    evaluation_function=boSwag,
    objective_name="testSwag",
    minimize=True,
)

print(f"Running Sobol initialization trials...")
sobol = Models.SOBOL(exp.search_space)
for i in range(5):
    exp.new_trial(generator_run=sobol.gen(1))
    
for i in range(20):
    print(f"Running GP+EI optimization trial {i+1}/20...")
    # Reinitialize GP+EI model at each step with updated data.
    gpei = Models.BOTORCH(experiment=exp, data=exp.eval())
    batch = exp.new_trial(generator_run=gpei.gen(1))

print(exp.eval_trial(exp.trials[1]))
print("Done!")

Running Sobol initialization trials...
Running GP+EI optimization trial 1/20...
Running GP+EI optimization trial 2/20...
Running GP+EI optimization trial 3/20...
Running GP+EI optimization trial 4/20...
Running GP+EI optimization trial 5/20...
Running GP+EI optimization trial 6/20...
Running GP+EI optimization trial 7/20...
Running GP+EI optimization trial 8/20...
Running GP+EI optimization trial 9/20...
Running GP+EI optimization trial 10/20...
Running GP+EI optimization trial 11/20...
Running GP+EI optimization trial 12/20...
Running GP+EI optimization trial 13/20...
Running GP+EI optimization trial 14/20...
Running GP+EI optimization trial 15/20...
Running GP+EI optimization trial 16/20...
Running GP+EI optimization trial 17/20...
Running GP+EI optimization trial 18/20...
Running GP+EI optimization trial 19/20...
Running GP+EI optimization trial 20/20...
<ax.core.data.Data object at 0x7f3419f890d0>
Done!


In [16]:
from ax.plot.trace import optimization_trace_single_method
from ax.utils.notebook.plotting import render, init_notebook_plotting

objective_means = np.array([[trial.objective_mean for trial in exp.trials.values()]])
best_objective_plot = optimization_trace_single_method(
        y=np.minimum.accumulate(objective_means, axis=1),
        optimum=-0.7813,  # Known minimum objective for Hartmann6 function.
)
render(best_objective_plot)

In [17]:
from ax import save
save(exp, '/home/mohit/boLogs/cifar10corruptedAcc.json')

In [19]:
objective_means

array([[-0.7811 , -0.77922, -0.77842, -0.7822 , -0.78084, -0.77642,
        -0.77922, -0.77856, -0.78052, -0.78124, -0.7806 , -0.77862,
        -0.77956, -0.78134, -0.78236, -0.77854, -0.7808 , -0.78002,
        -0.78082, -0.7766 , -0.78264, -0.77988, -0.781  , -0.78076,
        -0.78066]])

In [65]:
def getCurrAcc(Pi):
    # Constraints may sum to greater than 1
    Pi = Pi/sum(Pi)
    swagCount = 0
    for ckpt in swag_ckpts:
        swagCount += 1
    n_ensembled = 0.
    multiswag_probs = None
    for ckpt_i, ckpt in enumerate(swag_ckpts):
        #print("Checkpoint {}".format(ckpt))
        checkpoint = torch.load(ckpt)
        swag_model.subspace.rank = torch.tensor(0)
        swag_model.load_state_dict(checkpoint['state_dict'])
        swagWeight = Pi[ckpt_i]
        indivWeight = swagWeight/swag_samples

        for sample in range(swag_samples):
            swag_model.sample(.5)
            utils.bn_update(loaders['train'], swag_model)
            res = utils.predict(loaders['test'], swag_model)
            probs = res['predictions']
            targets = res['targets']
            nll = utils.nll(probs, targets)
            acc = utils.accuracy(probs, targets)

            if multiswag_probs is None:
                #multiswag_probs = probs.copy()
                multiswag_probs = indivWeight*probs.copy()
            else:
                #TODO: rewrite in a numerically stable way
                multiswag_probs += indivWeight*probs.copy()
                #multiswag_probs +=  (probs - multiswag_probs)/ (n_ensembled + 1)
            n_ensembled += 1

            ens_nll = utils.nll(multiswag_probs, targets)
            ens_acc = utils.accuracy(multiswag_probs, targets)
            values = [ckpt_i, sample, nll, acc, ens_nll, ens_acc]
            table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
            print(table)
    initialPi = [1/swagCount]*swagCount
    return ens_nll

In [66]:
# from torch.multiprocessing import Pool, Process, set_start_method
# try:
#      set_start_method('spawn', force=True)
# except RuntimeError:
#     pass

In [67]:
# """from scipy.optimize import minimize
# cons = ({'type': 'eq',
#          'fun' : lambda x: np.array([sum(x) - 1])})

# res = minimize(getCurrAcc, initialGuess, constraints=cons, ver)"""
# import black_box as bb

# best_params = bb.search_min(f = getCurrAcc,domain = [[0.0, 1.0]]*len(swag_ckpts), budget = 80, batch = 10, resfile = 'output.csv')

In [76]:
# import numpy.random
# #initialGuess = numpy.random.sample(len(swag_ckpts))
# temp = np.random.dirichlet(np.ones(len(swag_ckpts)),size=1)
# #initialGuess = initialGuess/sum(initialGuess)

In [77]:
from ax import load
exp = load('/home/mohit/boLogs/cifar10normal.json')

In [83]:
objective_means = np.array([[trial.objective_mean for trial in exp.trials.values()]])

In [95]:
import inspect
for i in exp.trials.values():
    method_list = [func for func in dir(i) if callable(getattr(i, func))]
    #print(inspect.getmembers(i, predicate=inspect.ismethod))
    print(method_list)
    break

['__class__', '__delattr__', '__dir__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_check_existing_and_name_arm', '_get_candidate_metadata', '_get_candidate_metadata_from_all_generator_runs', '_mark_failed_if_past_TTL', '_set_generation_step_index', 'add_arm', 'add_generator_run', 'assign_runner', 'complete', 'fetch_data', 'get_metric_mean', 'mark_abandoned', 'mark_as', 'mark_completed', 'mark_failed', 'mark_running', 'mark_staged', 'run', 'update_run_metadata']
