In [1]:
import pinot
import numpy as np
import torch

Using backend: pytorch


In [2]:
ds = pinot.data.moonshot_meta()


def get_separate_dataset(ds):
    # n_tasks = ds[0][1].shape[-1]
    n_tasks = 6
    datasets = [[] for _ in range(n_tasks)]
    
    for g, y in ds:
        for idx in range(n_tasks):
            if np.isnan(y[idx].numpy()) == False:
                datasets[idx].append((g, y[idx][None]))
    
    return datasets


datasets = get_separate_dataset(ds)

In [3]:

def probability_of_improvement(distribution, y_best=0.0):

    return 1.0 - distribution.cdf(y_best)

def expected_improvement(distribution, y_best=0.0):
    return distribution.mean - y_best

def upper_confidence_bound(distribution, y_best=0.0, kappa=0.5):
    from pinot.inference.utils import confidence_interval
    _, high = confidence_interval(distribution, kappa)
    return high

In [68]:
def _independent(distribution):
    return torch.distributions.normal.Normal(
        distribution.mean.flatten(),
        distribution.variance.pow(0.5).flatten())

def _slice_fn_tensor(data, idxs):
    # data.shape = (N, 2)
    # idx is a list
    data = data[idxs]
    assert data.dim() == 2
    assert data.shape[-1] == 2
    return data[:, 0][:, None], data[:, 1][:, None]

def _slice_fn_graph(data, idxs):
    # data is a list
    # idx is a list
    data = [data[idx] for idx in idxs]
    gs, ys = list(zip(*data))
    import dgl
    gs = dgl.batch(gs)
    ys = torch.stack(ys, dim=0)
    return gs, ys


In [145]:
class MultiTaskBayesianOptimizationExperiment(torch.nn.Module):
    """ Multitask BO.
    """
    def __init__(
        self,
        nets,
        datasets,
        acquisition,
        optimizer,
        limit=100,
        n_epochs_training=100,
        workup=_independent,
        slice_fn=_slice_fn_graph,
        net_state_dicts=None):
        
        self.nets = nets
        self.optimizer = optimizer
        self.n_epochs_training = n_epochs_training
        
        self.acquisition = acquisition
        
        self.worup = workup
        self.slice_fn = slice_fn
        
        
        self.limit = limit
        
        self.olds = []
        self.news = []
        
        self.datasets = datasets
        
        self.net_state_dicts = net_state_dicts
        
        
        self.workup = workup 
        
        self.n_tasks = len(self.nets)
        
        for idx in range(self.n_tasks):
            self.olds.append([])
            self.news.append(list(range(len(datasets[idx]))))

        self.y_bests = [0 for x in range(self.n_tasks)]
        
    def reset_net(self):
        # TODO:
        # reset optimizer too
        
        for net in self.nets:
            for module in net.modules():
                if isinstance(module, torch.nn.Linear):
                    module.reset_parameters()
            
        if self.net_state_dicts is not None:
            for idx, net in enumerate(self.nets):
                net.load_state_dict(self.net_state_dicts[idx])
                
    def blind_pick(self):
        import random
        
        for idx in range(self.n_tasks):
            best = random.choice(self.news[idx])
            self.olds[idx].append(self.news[idx].pop(best))
            
    def train(self):
        """ Train the model with new data.
        """
        # reset
        self.reset_net()

        # set to train status
        [net.train() for net in self.nets]

        # grab old data
        # (N, 2) for tensor
        # N list of 2-tuple for lists
        old_datas = [self.slice_fn(self.datasets[idx], self.olds[idx]) for idx in range(self.n_tasks)]

        for _ in range(self.n_epochs_training):
            self.optimizer.zero_grad()
            loss = 0.0
            for idx in range(self.n_tasks):
                g, y = old_datas[0]
                net = self.nets[idx]
                _loss = net.loss(g, y).mean()
                loss += _loss
            loss.backward()
            self.optimizer.step()
            
        for idx in range(self.n_tasks):
            gs, ys = old_datas[idx]
            self.y_bests[idx] = torch.max(ys)
            
            
    def acquire(self):
        for idx, net in enumerate(self.nets):
            if len(self.news[idx]) > 0:
                gs, _ = self.slice_fn(self.datasets[idx], self.news[idx])
                distribution = net.condition(gs)
                distribution = self.workup(distribution)
                score = self.acquisition(distribution, y_best=self.y_bests[idx])

                
                best = torch.argmax(score)
                self.olds[idx].append(self.news[idx].pop(best))
                
                
    def run(self):
        self.blind_pick()
        
        step = 0
        while any(len(self.news[idx]) > 0 for idx in range(self.n_tasks)) and step < self.limit:
            bo.train()
            bo.acquire()
            step += 1
        
        

In [120]:
net = pinot.representation.Sequential(
    pinot.representation.dgl_legacy.gn(),
    [32, 'tanh', 32, 'tanh', 32, 'tanh'])

nets = []
for idx in range(len(datasets)):
    
    model = pinot.inference.gp.gpr.exact_gpr.ExactGPR(
        kernel=pinot.inference.gp.kernels.deep_kernel.DeepKernel(
            base_kernel=pinot.inference.gp.kernels.rbf.RBF(torch.ones(32)),
            representation=net))
    
    nets.append(model)
    
params = []
for net in nets:
    params += list(net.parameters())

bo = MultiTaskBayesianOptimizationExperiment(
    nets=nets,
    datasets=datasets,
    acquisition=probability_of_improvement,
    optimizer=torch.optim.Adam(params, 1e-5))

  torch.tensor(scale))


In [121]:
bo.run()

[367, 265]
[281, 23]
[23, 40]
[355, 345]
[4, 460]
[50, 70]
[367, 265, 375]
[281, 23, 207]
[23, 40, 35]
[355, 345, 354]
[4, 460, 318]
[50, 70, 34]
[367, 265, 375, 99]
[281, 23, 207, 11]
[23, 40, 35, 4]
[355, 345, 354, 282]
[4, 460, 318, 13]
[50, 70, 34, 6]
[367, 265, 375, 99, 18]
[281, 23, 207, 11, 342]
[23, 40, 35, 4, 29]
[355, 345, 354, 282, 20]
[4, 460, 318, 13, 25]
[50, 70, 34, 6, 49]
[367, 265, 375, 99, 18, 111]
[281, 23, 207, 11, 342, 256]
[23, 40, 35, 4, 29, 27]
[355, 345, 354, 282, 20, 111]
[4, 460, 318, 13, 25, 143]
[50, 70, 34, 6, 49, 60]
[367, 265, 375, 99, 18, 111, 393]
[281, 23, 207, 11, 342, 256, 316]
[23, 40, 35, 4, 29, 27, 14]
[355, 345, 354, 282, 20, 111, 372]
[4, 460, 318, 13, 25, 143, 491]
[50, 70, 34, 6, 49, 60, 23]
[367, 265, 375, 99, 18, 111, 393, 212]
[281, 23, 207, 11, 342, 256, 316, 114]
[23, 40, 35, 4, 29, 27, 14, 2]
[355, 345, 354, 282, 20, 111, 372, 213]
[4, 460, 318, 13, 25, 143, 491, 267]
[50, 70, 34, 6, 49, 60, 23, 45]
[367, 265, 375, 99, 18, 111, 393, 212

KeyboardInterrupt: 

In [140]:
regrets = []
for idx in range(bo.n_tasks):
    actual_best = max(bo.datasets[idx][idx_][1].squeeze() for idx_ in range(len(bo.datasets[idx]))).detach().numpy()
    
    regret = []
    
    for step in range(1, len(bo.olds[idx])):
        
        idxs = bo.olds[idx][:step]
        
        _, ys = bo.slice_fn(bo.datasets[idx], idxs)
        
        y_best_now = torch.max(ys.flatten()).detach().numpy()
        
        regret.append(actual_best - y_best_now)
        
    regrets.append(regret)
    

In [144]:
regrets

[[0.8810369,
  0.8810369,
  0.8810369,
  0.8810369,
  0.18761182,
  0.18761182,
  0.18761182,
  0.18761182,
  0.18761182,
  0.18761182,
  0.18761182,
  0.18761182,
  0.18761182,
  0.086015165,
  0.086015165,
  0.086015165,
  0.086015165,
  0.086015165,
  0.086015165,
  0.086015165,
  0.086015165],
 [0.37817907,
  0.37817907,
  0.37817907,
  0.37817907,
  0.37817907,
  0.37817907,
  0.37817907,
  0.37817907,
  0.37817907,
  0.37817907,
  0.08990431,
  0.08990431,
  0.08990431,
  0.08990431,
  0.08990431,
  0.08990431,
  0.08990431,
  0.08990431,
  0.08990431,
  0.08990431,
  0.08990431],
 [0.88493466,
  0.84685165,
  0.7956941,
  0.7956941,
  0.7205272,
  0.6696938,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.92803633,
  0.7156153,
  0.7156153,
  0.7156153,
  0.14995617,
  0.14995617,
  0.14995617,
  0.14995617,
  0.14995617,
  0.14995617,
  0.14995617,
  0.14995617,
  0.14995617,
  0.14995617,
  0.14995617,
  0.14995617,