In [1]:
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.nn import functional as F

In [2]:
# Step 1. Load Dataset
# Step 2. Make Dataset Iterable
# Step 3. Create Model Class
# Step 4. Instantiate Model Class
# Step 5. Instantiate Loss Class
# Step 6. Instantiate Optimizer Class
# Step 7. Train Model

In [3]:
train_dataset = dsets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = dsets.MNIST(root='./data', train=False, transform=transforms.ToTensor())

In [4]:
batch_size = 100

#train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
class LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        outputs = self.linear(x) # pytorch Crossentropy loss takes scores, not logits as inputs
        #outputs = F.softmax(self.linear(x), dim = -1) # why better without F.softmax
        return outputs 

In [6]:
# was not worth caching this
def get_minibatch(dataset, n, i, input_dim = 784):
    lst1, lst2 = zip(*[dataset[ii] for ii in range(i * n, (i+1) * n)])
    return(torch.stack(lst1).view(-1, input_dim), torch.stack(lst2))

In [7]:
%%time
input_dim = 784
output_dim = 10

# Note we will use a Bayesian model average of (some distribution over) history of sampled class probs in the sampling case
criterion = torch.nn.CrossEntropyLoss(reduction = 'sum') # computes softmax and then the cross entropy
model = LogisticRegression(input_dim, output_dim)
lr_rate = 0.001
optimizer = torch.optim.SGD(model.parameters(), lr=lr_rate)
N = train_dataset.__len__()//batch_size # TO DO: get N from dataset
print(N)
epochs = 4 #n_iters / (len(train_dataset) / batch_size)
i = 0
for epoch in range(int(epochs)):
    epoch_losses = []
    for it in range(N):
        optimizer.zero_grad()
        images, labels = get_minibatch(train_dataset, batch_size, it)    
        outputs = model(images)
        loss = criterion(outputs, labels) # NEXT STEP: replace with an evaluate call!!!
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.detach().item())
        i += 1
    # end of epoch: calculate Accuracy
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = Variable(images.view(-1, 28*28))
        outputs = model(images)
        # TO DO convert these outputs to probs, to get more accurate Accuracy metric
        # and to support historical averaging (e.g. q(n) = p(n) * 2/(n+1) + q(n-1) * (n-1)/(n+1))
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        # for gpu, bring the predicted and labels back to cpu fro python operations to work
        correct+= (predicted == labels).sum()
    accuracy = 100 * float(correct)/total
    print("Iteration: {}. Loss: {:1f}. Accuracy: {}.".format(i, torch.tensor(epoch_losses).mean(), accuracy))

600
Iteration: 600. Loss: 53.662323. Accuracy: 90.16.
Iteration: 1200. Loss: 36.005524. Accuracy: 90.95.
Iteration: 1800. Loss: 33.190990. Accuracy: 91.43.
Iteration: 2400. Loss: 31.717945. Accuracy: 91.56.
Wall time: 24 s


In [21]:
from HINTS_fn import *
# PyTorch test function for HINTS
# proposals are Langevin/HMC and must create a new deep copy of the model
# HINTS uses the ID() of the model object for tracking/caching purposes

class TorchMNIST(UserFn):
        def __init__(self, additive = True, state_is_hashable = True):
            self.batch_size = 100
            self.input_dim = 784
            self.output_dim = 10
            self.train_dataset = dsets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
            self.N = self.train_dataset.__len__()//self.batch_size # num scenarios
            self.lr = 0.001 # TO DO pass this parameter
            self.additive = additive # used by HINTS
            print(self.N)
            self.criterion = torch.nn.CrossEntropyLoss(reduction = 'sum')
            super().__init__(None, state_is_hashable)
        #
        def user_sample_initial_state(self):
            model = LogisticRegression(self.input_dim, self.output_dim)
            return(model)
        #    
        # the return value will be cached, but the gradient computed at level zero will
        # be present in the state only immediately after a cache miss (first call for each scenario, state)
        def user_evaluate(self, state, term_index, level):
            if level == 0:
                f = -self.minibatch_loss(state, term_index)
                f.backward()
            else:
                with torch.no_grad():
                    f = -self.minibatch_loss(state, term_index)
            return(f)
        #
        # this can be called with or without torch.no_grad
        def minibatch_loss(self, model, term_index):
            images, labels = get_minibatch(self.train_dataset, self.batch_size, term_index)  
            outputs = model(images)
            loss = self.criterion(outputs, labels) # NEXT STEP: replace with an evaluate call!!!
            return(loss)

    
MH = TorchMNIST()
state0 = MH.user_sample_initial_state()
v = MH.user_evaluate(state0, 7, 0)
print(v)
for f in state0.parameters():
    print(f.shape, f.grad.shape)

600
tensor(-237.9083, grad_fn=<NegBackward>)
torch.Size([10, 784]) torch.Size([10, 784])
torch.Size([10]) torch.Size([10])


In [22]:
%%time
epochs = 4
lr_rate = 0.001
model = MH.user_sample_initial_state()
optimizer = torch.optim.SGD(model.parameters(), lr = lr_rate)
i = 0
for epoch in range(int(epochs)):
    epoch_losses = []
    for it in range(MH.N):
        optimizer.zero_grad()
        loss = MH.minibatch_loss(model, it)
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.detach().item())
        i += 1
    # end of epoch: calculate Accuracy
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = Variable(images.view(-1, 28*28))
        outputs = model(images)
        # TO DO convert these outputs to probs, to get more accurate Accuracy metric
        # and to support historical averaging (e.g. q(n) = p(n) * 2/(n+1) + q(n-1) * (n-1)/(n+1))
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        # for gpu, bring the predicted and labels back to cpu fro python operations to work
        correct+= (predicted == labels).sum()
    accuracy = 100 * float(correct)/total
    print("Iteration: {}. Loss: {:1f}. Accuracy: {}.".format(i, torch.tensor(epoch_losses).mean(), accuracy))

Iteration: 600. Loss: 53.543919. Accuracy: 90.27.
Iteration: 1200. Loss: 35.941975. Accuracy: 90.97.
Iteration: 1800. Loss: 33.138775. Accuracy: 91.41.
Iteration: 2400. Loss: 31.673714. Accuracy: 91.63.
Wall time: 24 s


In [23]:
from HINTS import *
class HINTS_HMC(HINTS):
    # TO DO move langevin out?
    def __init__(self, args, fn):
        super().__init__(args, fn)
        self.epsilon = args.epsilon
    #
    # HMC version ... correction depends on knowing gradient at proposed state
    def primitive_move(self, model, index = 0):
        scenarios = self.scenarios(0, index)
        #print("primitive", index, len(scenarios))
        v = self.fn(model, scenarios, 0) # puts a gradient into state as a side effect for HMC
        current = self.fn.user_sample_initial_state() # empty model
        # do this with no grad...
        correction = 0.0
        for f, f_prime in zip(model.parameters(), current.parameters()):
            # f.grad.data is the right shape to store momentum temporarily
            ###TO DO how big is f.grad.data compared with unit noise
            f.grad.data = 1.0 * torch.randn(f.shape) + self.epsilon * 0.5 * f.grad.data # TO DO check gradient convention
            f_prime.data = f.data + 0.5 * self.epsilon * f.grad.data # add momentum
            correction -= 0.5 * (f.grad.data * f.grad.data).sum() # Kinetic energy term of -H
        # compute the value of the new state and its gradient
        v_prime = self.fn(current, scenarios, 0)
        # store the new momentum in the grad entries of the candidate model
        for f, f_prime in zip(model.parameters(), current.parameters()):
            p_prime = f.grad.data + self.epsilon * 0.5 * f_prime.grad.data # TO DO check gradient convention
            correction += 0.5 * (p_prime * p_prime).sum() # kinetic energy term of H_new
            f.grad = None # must not reuse (unless we do more leapfrog steps)
            f_prime.grad = None
        
        # standard MHR / HINTS acceptance
        vdiff = (v_prime - v)/self.Ts[0] # PE change ... these are cached evaluations, no side effects
        #print(v_prime, v, self.Ts[0], correction)
        #
        #correction = 0 # TEMPOARY OVERRDE - SGD
        #
        #
        accept = self.metropolis_accept(vdiff - correction)
        #if not(accept):
        #    print("REJECT!")
        (self.acceptances if accept else self.rejections)[0] += 1
        return((current, vdiff) if accept else (model, 0.0))
    


In [24]:
args.additive

True

In [25]:
import argparse
parser = argparse.ArgumentParser()
args = parser.parse_known_args()[0] # defaults

# checked regular MCMC is accurate (1 level, lbf = 6)
if True:
    args.levels = 6
    log_branch_factor = 1
    N_0 = 1
    #args.design = np.array([N_0] + [2 ** log_branch_factor for l in range(args.levels)])
    args.design = np.array([1,2,3,2,5,2,5])
    NUM_SCENARIOS = args.design.prod()
    aa = False
    #NUM_SCENARIOS = N_0 * 2 ** (args.levels * log_branch_factor) # TO DO get from HINTS
else: # test Langevin MCMC first
    args.levels = 1 
    NUM_SCENARIOS = 600 # 600 for naive mcmc, 1 for SGD
    args.design = np.array([1,NUM_SCENARIOS]) # ensure whole dataset is covered
    aa = True
    # use always_accept flag to get MCMC
    
print(NUM_SCENARIOS)

# design now has levels + 1 entries

# additive log probability is more natural from a Bayesian perspective but both are valid

args.additive = True # effectively selects a different temperature structure when False (= average or expectation)

args.T = 1.0 #top level
args.dT = 0.0 if args.additive else 0.5 # temperature increment by level (mainly for optimisation or averaging structure)
args.epsilon = 0.02 # for HMC
print(args.__dict__)



600
{'levels': 6, 'design': array([1, 2, 3, 2, 5, 2, 5]), 'additive': True, 'T': 1.0, 'dT': 0.0, 'epsilon': 0.02}


In [26]:

g = TorchMNIST(args.additive, state_is_hashable = True)
hmc = HINTS_HMC(args, g)
state  = g.sample_initial_state()
print(state)



600
RESET
6
[  1   2   6  12  60 120 600]
600
[1. 1. 1. 1. 1. 1. 1.]
LogisticRegression(
  (linear): Linear(in_features=784, out_features=10, bias=True)
)


In [None]:
for t in range(1000):
    hmc.shuffle()
    print(t)
    state, correction = hmc.hints(state, args.levels, always_accept = aa) # e.g. dbg = (t==0)
    g.cached_evaluate.cache_clear() # risk of revisiting same state and scenario after reject
    # diagnostic histogram
    # show progress
    if True:
        with torch.no_grad():
            loss = MH.minibatch_loss(state, t % NUM_SCENARIOS)
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = Variable(images.view(-1, 28*28))
            outputs = state(images)
            # TO DO convert these outputs to probs, to get more accurate Accuracy metric
            # and to support historical averaging (e.g. q(n) = p(n) * 2/(n+1) + q(n-1) * (n-1)/(n+1))
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            # for gpu, bring the predicted and labels back to cpu fro python operations to work
            correct+= (predicted == labels).sum()
        accuracy = 100 * float(correct)/total
        print("Iteration: {}. Loss: {}. Accuracy: {}."\
              .format(t+1, loss, accuracy), hmc.acceptances, hmc.rejections)

#TO DO skip accept/reject at top level

0
Iteration: 1. Loss: 111.57228088378906. Accuracy: 63.04. [475 237  83  46  10   5   1] [125  63  17   4   0   0   0]
1
Iteration: 2. Loss: 89.43009185791016. Accuracy: 69.35. [1024  457  148   79   19   10    2] [176 143  52  21   1   0   0]
2
Iteration: 3. Loss: 97.63916015625. Accuracy: 71.83. [1585  662  208  114   28   15    3] [215 238  92  36   2   0   0]
3
Iteration: 4. Loss: 57.772796630859375. Accuracy: 74.41. [2135  879  262  147   36   20    4] [265 321 138  53   4   0   0]
4
Iteration: 5. Loss: 69.00692749023438. Accuracy: 75.93. [2683 1092  312  183   41   25    5] [317 408 188  67   9   0   0]
5
Iteration: 6. Loss: 90.3050537109375. Accuracy: 76.84. [3259 1316  359  219   47   29    6] [341 484 241  81  13   1   0]
6
Iteration: 7. Loss: 84.50695037841797. Accuracy: 77.76. [3821 1531  410  255   53   34    7] [379 569 290  95  17   1   0]
7
Iteration: 8. Loss: 60.63371658325195. Accuracy: 78.35. [4365 1757  461  287   61   39    8] [435 643 339 113  19   1   0]
8


In [None]:
# looks to be better than MCMC
# TO DO higher branch factor (4+) more efficient
# aim for bigger moves at primitive level (or HMC chain?)
# GPU
# Bayesian accuracy measure thru decaying average (triangle distrib)