In [1]:
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.nn import functional as F

In [2]:
# Step 1. Load Dataset
# Step 2. Make Dataset Iterable
# Step 3. Create Model Class
# Step 4. Instantiate Model Class
# Step 5. Instantiate Loss Class
# Step 6. Instantiate Optimizer Class
# Step 7. Train Model

In [3]:
train_dataset = dsets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = dsets.MNIST(root='./data', train=False, transform=transforms.ToTensor())

In [4]:
batch_size = 100

#train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
class LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        outputs = self.linear(x) # pytorch Crossentropy loss takes scores, not logits as inputs
        #outputs = F.softmax(self.linear(x), dim = -1) # why better without F.softmax
        return outputs 

In [6]:
# was not worth caching this
def get_minibatch(dataset, n, i, input_dim = 784):
    lst1, lst2 = zip(*[dataset[ii] for ii in range(i * n, (i+1) * n)])
    xs = torch.stack(lst1).view(-1, input_dim)
    ys = torch.tensor(lst2)
    return(xs, ys)

In [7]:
%%time
input_dim = 784
output_dim = 10

# Note we will use a Bayesian model average of (some distribution over) history of sampled class probs in the sampling case
criterion = torch.nn.CrossEntropyLoss(reduction = 'sum') # computes softmax and then the cross entropy
model = LogisticRegression(input_dim, output_dim)
lr_rate = 0.001
optimizer = torch.optim.SGD(model.parameters(), lr=lr_rate)
N = train_dataset.__len__()//batch_size # TO DO: get N from dataset
print(N)
epochs = 4 #n_iters / (len(train_dataset) / batch_size)
i = 0
for epoch in range(int(epochs)):
    epoch_losses = []
    for it in range(N):
        optimizer.zero_grad()
        images, labels = get_minibatch(train_dataset, batch_size, it)    
        outputs = model(images)
        loss = criterion(outputs, labels) # NEXT STEP: replace with an evaluate call!!!
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.detach().item())
        i += 1
    # end of epoch: calculate Accuracy
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = Variable(images.view(-1, 28*28))
        outputs = model(images)
        # TO DO convert these outputs to probs, to get more accurate Accuracy metric
        # and to support historical averaging (e.g. q(n) = p(n) * 2/(n+1) + q(n-1) * (n-1)/(n+1))
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        # for gpu, bring the predicted and labels back to cpu fro python operations to work
        correct+= (predicted == labels).sum()
    accuracy = 100 * float(correct)/total
    print("Iteration: {}. Loss: {:1f}. Accuracy: {}.".format(i, torch.tensor(epoch_losses).mean(), accuracy))

600
Iteration: 600. Loss: 53.603058. Accuracy: 90.16.
Iteration: 1200. Loss: 35.963299. Accuracy: 90.93.
Iteration: 1800. Loss: 33.158787. Accuracy: 91.46.
Iteration: 2400. Loss: 31.691315. Accuracy: 91.63.
CPU times: user 1min 10s, sys: 2min 47s, total: 3min 58s
Wall time: 22.8 s


In [8]:
from HINTS_fn import *
# PyTorch test function for HINTS
# proposals are Langevin/HMC and must create a new deep copy of the model
# HINTS uses the ID() of the model object for tracking/caching purposes

device = 'cuda'

class TorchMNIST(UserFn):
        def __init__(self, additive = True):
            self.batch_size = 100 # 60000 dataset size (so will not see many GPU benefits)
            self.input_dim = 784
            self.output_dim = 10
            self.train_dataset = dsets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
            self.N = self.train_dataset.__len__()//self.batch_size # num scenarios
            self.lr = 0.001 # TO DO pass this parameter
            self.additive = additive # used by HINTS
            print(self.N)
            self.criterion = torch.nn.CrossEntropyLoss(reduction = 'sum')
            super().__init__(None)
        #
        def sample_initial_state(self):
            model = LogisticRegression(self.input_dim, self.output_dim).to(device)
            return(model)
        #    
        @lru_cache(maxsize = 100000)
        def evaluate(self, state, term_index, gradient = False):
            self.counter += 1
            if gradient:
                f = -self.minibatch_loss(state, term_index)
                f.backward()
            else:
                with torch.no_grad():
                    f = -self.minibatch_loss(state, term_index)
            return(f)
        #
        # this can be called with or without torch.no_grad
        def minibatch_loss(self, model, term_index):
            images, labels = get_minibatch(self.train_dataset, self.batch_size, term_index)  
            outputs = model(images.to(device))
            loss = self.criterion(outputs.cpu(), labels) # NEXT STEP: replace with an evaluate call!!!
            return(loss)

    
MH = TorchMNIST()
state0 = MH.sample_initial_state()
%time v = MH.evaluate(state0, 6, True)
%time v = MH.evaluate(state0, 6, True)
print(v)
for f in state0.parameters():
    print(f.shape, f.grad.shape)

600
CPU times: user 156 ms, sys: 11.9 ms, total: 168 ms
Wall time: 168 ms
CPU times: user 1e+03 ns, sys: 3 µs, total: 4 µs
Wall time: 4.77 µs
tensor(-227.8024, grad_fn=<NegBackward>)
torch.Size([10, 784]) torch.Size([10, 784])
torch.Size([10]) torch.Size([10])


In [9]:
%%time
# MINIBATCH SGD EXAMPLE (no HINTS yet)
epochs = 4
lr_rate = 0.001
model = MH.sample_initial_state()
optimizer = torch.optim.SGD(model.parameters(), lr = lr_rate)
i = 0
for epoch in range(int(epochs)):
    epoch_losses = []
    for it in range(MH.N):
        optimizer.zero_grad()
        loss = MH.minibatch_loss(model, it).cpu()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.detach().item())
        i += 1
    # end of epoch: calculate Accuracy
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = Variable(images.view(-1, 28*28))
        outputs = model(images.to(device)).cpu()
        # TO DO convert these outputs to probs, to get more accurate Accuracy metric
        # and to support historical averaging (e.g. q(n) = p(n) * 2/(n+1) + q(n-1) * (n-1)/(n+1))
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        # for gpu, bring the predicted and labels back to cpu fro python operations to work
        correct+= (predicted == labels).sum()
    accuracy = 100 * float(correct)/total
    print("Iteration: {}. Loss: {:1f}. Accuracy: {}.".format(i, torch.tensor(epoch_losses).mean(), accuracy))

Iteration: 600. Loss: 53.359703. Accuracy: 90.14.
Iteration: 1200. Loss: 35.966660. Accuracy: 91.01.
Iteration: 1800. Loss: 33.157066. Accuracy: 91.3.
Iteration: 2400. Loss: 31.687439. Accuracy: 91.56.
CPU times: user 23.6 s, sys: 13.7 s, total: 37.3 s
Wall time: 20.1 s


In [10]:
# GPU slightly reduces minibatch wall clock time (100)

In [11]:
from HINTS import *

class HINTS_HMC(HINTS):
    # TO DO move langevin out?
    def __init__(self, args, fn, noise_sd = 0.01):
        super().__init__(args, fn)
        self.epsilon = args.epsilon
        self.noise_sd = noise_sd
    #
    # HMC version ... correction depends on knowing gradient at proposed state
    def primitive_move(self, model, index = 0, always_accept = False):
        scenarios = self.scenarios(0, index)
        #print("primitive", index, len(scenarios))
        v = self.fn(model, scenarios, True) # puts a gradient into state as a side effect for HMC
        current = self.fn.sample_initial_state() # empty model
        # do this with no grad...
        correction = 0.0
        for f, f_prime in zip(model.parameters(), current.parameters()):
            # f.grad.data is the right shape to store momentum temporarily
            ###TO DO how big is f.grad.data compared with unit noise
            f.grad.data = self.noise_sd * torch.randn(f.shape).to(device) + self.epsilon * 0.5 * f.grad.data # TO DO check gradient convention
            f_prime.data = f.data + 0.5 * self.epsilon * f.grad.data # add momentum
            correction -= 0.5 * (f.grad.data * f.grad.data).sum() # Kinetic energy term of -H
        # compute the value of the new state and its gradient
        v_prime = self.fn(current, scenarios, True) # need gdt again
        # store the new momentum in the grad entries of the candidate model
        for f, f_prime in zip(model.parameters(), current.parameters()):
            p_prime = f.grad.data + self.epsilon * 0.5 * f_prime.grad.data # TO DO check gradient convention
            correction += 0.5 * (p_prime * p_prime).sum() # kinetic energy term of H_new
            f.grad = None # must not reuse (unless we do more leapfrog steps)
            f_prime.grad = None
        
        # standard MHR / HINTS acceptance
        vdiff = (v_prime - v)/self.Ts[0] # PE change ... these are cached evaluations, no side effects
        #print(v_prime, v, self.Ts[0], correction)
        #
        #correction = 0 # TEMPOARY OVERRDE - SGD
        #
        #
        accept = True if always_accept else self.metropolis_accept(vdiff - correction)
        (self.acceptances if accept else self.rejections)[0] += 1
        return((current, vdiff) if accept else (model, 0.0))
    


In [16]:
import argparse
parser = argparse.ArgumentParser()
args = parser.parse_known_args()[0] # defaults

# checked regular MCMC is accurate (1 level, lbf = 6)
if False:
    args.levels = 6
    log_branch_factor = 1
    N_0 = 1
    #args.design = np.array([N_0] + [2 ** log_branch_factor for l in range(args.levels)])
    args.design = np.array([1,2,3,2,5,2,5])
    NUM_SCENARIOS = args.design.prod()
    aa = False
    #NUM_SCENARIOS = N_0 * 2 ** (args.levels * log_branch_factor) # TO DO get from HINTS
elif True: # test Langevin MCMC first (or set noise_sd to small value for SGD)
    args.levels = 1 
    NUM_SCENARIOS = 600 # 600 for naive mcmc, 1 for SGD
    args.design = np.array([1,NUM_SCENARIOS]) # ensure whole dataset is covered
    aa = True # use always_accept flag to get minibatch Langevin (so we actually only have  a 1 level architecture)
else: # full Langevin/gradient descent
    args.levels = 0 
    NUM_SCENARIOS = 600 # 600 for naive mcmc, 1 for SGD
    args.design = np.array([NUM_SCENARIOS]) # ensure whole dataset is covered
    aa = True # always accept only applies to non-primitive moves SO THIS WILL NOT WORK!!
    
    
print(NUM_SCENARIOS)

# design now has levels + 1 entriess

# additive log probability is more natural from a Bayesian perspective but both are valid

args.additive = True # effectively selects a different temperature structure when False (= average or expectation)

args.T = 1.0 #top level
args.dT = 0.0 if args.additive else 0.5 # temperature increment by level (mainly for optimisation or averaging structure)
args.epsilon = 0.02 # for HMC
print(args.__dict__)



600
{'levels': 1, 'design': array([  1, 600]), 'additive': True, 'T': 1.0, 'dT': 0.0, 'epsilon': 0.02}


In [17]:

g = TorchMNIST(args.additive)
hmc = HINTS_HMC(args, g, noise_sd = 1e-3) # noise sd crucial for acceptance rate (check maths for sd not equal to 1
state  = g.sample_initial_state()
print(state)



600
RESET
1
[  1 600]
600
[1. 1.]
LogisticRegression(
  (linear): Linear(in_features=784, out_features=10, bias=True)
)


In [None]:
for t in range(1000):
    hmc.shuffle()
    print(t)
    g.evaluate.cache_clear() # risk of revisiting same state and scenario after reject, and gradient not being available
    state, correction = hmc.hints(state, args.levels, always_accept = aa) # e.g. dbg = (t==0)
    # diagnostic histogram
    # show progress
    if True:
        with torch.no_grad():
            loss = MH.minibatch_loss(state, t % NUM_SCENARIOS)
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = Variable(images.view(-1, 28*28))
            outputs = state(images.to(device))
            # TO DO convert these outputs to probs, to get more accurate Accuracy metric
            # and to support historical averaging (e.g. q(n) = p(n) * 2/(n+1) + q(n-1) * (n-1)/(n+1))
            _, predicted = torch.max(outputs.cpu().data, 1)
            total += labels.size(0)
            # for gpu, bring the predicted and labels back to cpu fro python operations to work
            correct+= (predicted == labels).sum()
        accuracy = 100 * float(correct)/total
        print("Iteration: {}. Loss: {}. Accuracy: {}."\
              .format(t+1, loss, accuracy), hmc.acceptances, hmc.rejections)

#TO DO skip accept/reject at top level

0
Iteration: 1. Loss: 82.80477142333984. Accuracy: 83.52. [411   1] [189   0]
1
Iteration: 2. Loss: 68.9365463256836. Accuracy: 86.54. [894   2] [306   0]
2
Iteration: 3. Loss: 53.33583450317383. Accuracy: 87.78. [1415    3] [385   0]
3
Iteration: 4. Loss: 31.505752563476562. Accuracy: 88.39. [1937    4] [463   0]
4
Iteration: 5. Loss: 44.37473678588867. Accuracy: 88.76. [2487    5] [513   0]
5
Iteration: 6. Loss: 50.96900177001953. Accuracy: 89.28. [3025    6] [575   0]
6
Iteration: 7. Loss: 57.9368896484375. Accuracy: 89.53. [3568    7] [632   0]
7
Iteration: 8. Loss: 34.14878845214844. Accuracy: 89.72. [4125    8] [675   0]
8
Iteration: 9. Loss: 54.5328369140625. Accuracy: 90.01. [4665    9] [735   0]
9
Iteration: 10. Loss: 43.673362731933594. Accuracy: 90.14. [5215   10] [785   0]
10
Iteration: 11. Loss: 59.189971923828125. Accuracy: 90.34. [5771   11] [829   0]
11
Iteration: 12. Loss: 43.510948181152344. Accuracy: 90.42. [6322   12] [878   0]
12
Iteration: 13. Loss: 41.03447723388

In [54]:
print(g.total_counter, g.counter, "miss% = " + str(int((100.0 * g.counter)/ g.total_counter))) # check cache ratio

9540 6752 miss% = 70


In [None]:
# looks to be better than MCMC [but for this dataset, no compelling case for MCMC methods]
# TO DO higher branch factor (4+) more efficient
# HOW MUCH NOISE AT PRIMITIVE LEVEL? - zero case for SGD
# aim for bigger moves at primitive level (or HMC chain?)
# GPU - DONE
# Bayesian accuracy measure thru decaying average (triangle distrib)
