In [1]:
%cd /home/q123/Desktop/explo

### local imports 
from src.environment import EnvironmentObjective
from src.optim import step
from src.policy import MLP

### botorch
from botorch.fit import fit_gpytorch_model
from botorch.models import SingleTaskGP
from botorch.acquisition import ExpectedImprovement
from botorch.optim import optimize_acqf
from botorch.models.gpytorch import GPyTorchModel

### gpytorch 
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.kernels import RBFKernel,ScaleKernel,Kernel
from gpytorch.models import ExactGP

### general imports
import numpy as np
import gpytorch
import torch
import gym

### Logging 
import logging
logger = logging.getLogger('__main__')
logger.setLevel(logging.CRITICAL)

/home/q123/Desktop/explo


  from .autonotebook import tqdm as notebook_tqdm


# Imports and kernels


In [2]:
### Toy kernel for warningging

class MyKernel(gpytorch.kernels.RBFKernel):
   
    def forward(self,x1,x2,**params):
        
        logger.warning(f'x1 {x1.shape} / x2 {x2.shape}')
        kernel = super().forward(x1,x2,**params)
        logger.warning(f'pair kernel {kernel.shape}')
        return kernel

In [3]:
class MyGP(ExactGP,GPyTorchModel):
    
    _num_outputs = 1
    
    
    def __init__(self, train_x, train_y,train_s, likelihood,
                 kernel=None,mlp=None):
        
        ExactGP.__init__(self,train_x, train_y, likelihood)
        
        self.mean_module = gpytorch.means.ConstantMean()
        
        if kernel is None:
            self.covar_module = MyKernel()
        else :
            self.covar_module = kernel(mlp,train_s)
            
            
        ### necessary attribute for gpytorch to function
        #self.num_outputs = 1
    
    def update_train_data(self,new_x, new_y,new_s,strict=False):
        
        train_x = torch.cat([self.train_inputs[0], new_x])
        train_y = torch.cat([self.train_targets, new_y])
        ExactGP.set_train_data(self,inputs=train_x,targets=train_y,strict=strict)
        
        
        ### update state kernels with new states
        if isinstance(self.covar_module,StateKernel):
            self.covar_module.update(new_s)
        
    def forward(self, x):
        
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

In [4]:
class StateKernel(gpytorch.kernels.Kernel):
    
    """Abstract class for a kernel that uses state action pairs metric
    """
    
    
    def __init__(self,mlp,train_s):
        
        super().__init__()
        self.update_states(train_s)
        self.mlp = mlp
        ### set rbf_module with appropriate ard dims
        num_states = 200
        self.rbf_module = ScaleKernel(RBFKernel())
        

    def test_policy(self,params_batch,states):
        
        logger.warning('mlp :params_batch.shape{params_batch.shape}')
        actions = self.mlp(params_batch,states)
        logger.warning('mlp :actions.shape{actions.shape}')
        # first_dims = params_batch.shape[:-1]
        # last_dims = actions.shape[-2:]
        # actions = actions.reshape(*first_dims,*last_dims)
        actions = torch.flatten(actions,start_dim=-2)
        logger.warning('reshape :actions.shape{actions.shape}')
        return actions
    
        
        
    def forward(self,x1,x2,**params):
        
        logger.warning(f'x1 {x1.shape} / x2 {x2.shape}')
        
        #Evaluate current parameters
        actions1 = self.test_policy(x1,self.states)
        actions2 = self.test_policy(x2,self.states)
        logger.warning(f'actions1 {actions1.shape} actions2 {actions2.shape} ')
        # Compute pairwise pairwise kernel 
        kernel = self.rbf_module(actions1, actions2, **params)
        logger.warning(f'pair kernel {kernel.shape}')
        
        return kernel 
    
    def update(self,new_s):
        
        raise NotImplementedError
    
        
class GridKernel(StateKernel):
    
    
    def get_grid(self,low,high,samples_per_dim):
        
        
        state_dims = low.shape[0]
        points = [torch.linspace(low[i],high[i],samples_per_dim) 
                    for i in range(state_dims)]
        grid = torch.meshgrid(*points)
        grid = torch.stack(grid)
        grid = torch.flatten(grid,start_dim=1).T ## [n_states,state_dim]
        
        logger.warning(f' grid shape {grid.shape}')
        
        return grid
    
    def update_states(self,new_s):
        
        self.high,_= torch.max(new_s,dim=0)
        self.low,_= torch.min(new_s,dim=0)
        self.states = self.get_grid(self.low,self.high,
                                    samples_per_dim=5)
        
        #print(f'observation box : \n low {self.low} \n high :{self.high} \n grid shape {self.states.shape}')
    
    def update(self,new_s):
        
        
        tmp_buff = torch.cat([self.states, new_s])
        high,_= torch.max(tmp_buff,dim=0)
        low,_= torch.min(tmp_buff,dim=0)
        
        logger.warning(f'BUffer shape {tmp_buff.shape}')
        
        ### update only if be
        if any(high>self.high) or any(low<self.low):
            self.update_states(tmp_buff)

# Experiment 

In [5]:
### pendulum
mlp = MLP([3,1],add_bias=True) ## pendulum
env = gym.make("Pendulum-v1")

### Swimmer 
#mlp = MLP([8,2],add_bias=True) ##swimmer
#env = gym.make("Swimmer-v3")

### Inverted pendulum

# mlp = MLP([4,1],add_bias=True) ##swimmer
# env = gym.make("InvertedPendulum-v2")


# Initialize environment

objective_env = EnvironmentObjective(
  env=env,
  mlp=mlp,
  manipulate_state=None,
  manipulate_reward=None,
)

### initialize train_x, train_y
train_x = torch.rand(100,mlp.len_params) ## [n_trials,n_params]
train_data = [objective_env.run(p) for p in train_x]
train_y = torch.Tensor([d[0] for d in train_data]).reshape(-1)  ## [n_trials,1]
train_s = torch.stack( [d[1] for d in train_data])  ## [n_trials,max_len,state_dim]
train_s = torch.flatten(train_s,start_dim=0,end_dim=1) ## [n_trials*max_len,state_dim]

# initialize likelihood and model
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = MyGP(train_x, train_y,train_s,likelihood,
                    kernel=GridKernel,mlp=mlp)

In [6]:
### now we loop :
max_iter = 5

for i in range(max_iter):

  step(model,objective_env)

  if i % 1 == 0 and i>=1:
  

    max = model.train_targets.max()
    batch_mean = model.train_targets[i-1:i].mean()
    batch_max = model.train_targets[i-1:i].max()
    curr = model.train_targets[-1]
    print(f'current {curr} / max {max} /batch_mean {batch_mean} /batch_max {batch_max} ')

    #print(f'model.train_inputs.shape{model.train_inputs[0].shape}')

##############################
likelihood.noise_covar.raw_noise tensor([0.])
mean_module.constant tensor([0.])
covar_module.rbf_module.raw_outputscale tensor(0.)
covar_module.rbf_module.base_kernel.raw_lengthscale tensor([[0.]])
##############################
##############################
likelihood.noise_covar.raw_noise tensor([6169.4604])
mean_module.constant tensor([-1564.2587])
covar_module.rbf_module.raw_outputscale tensor(3972.7710)
covar_module.rbf_module.base_kernel.raw_lengthscale tensor([[-39.9550]])
##############################
current -1671.0499267578125 / max -1040.72265625 /batch_mean -1548.0745849609375 /batch_max -1548.0745849609375 
##############################
likelihood.noise_covar.raw_noise tensor([11078.7637])
mean_module.constant tensor([-1562.0969])
covar_module.rbf_module.raw_outputscale tensor(8882.0742)
covar_module.rbf_module.base_kernel.raw_lengthscale tensor([[-39.9550]])
##############################
current -1087.8255615234375 / max -1040.72265625 /

# Manually fitting GP (maximizing likelihood)

In [7]:
# training_iter = 100 

# # Find optimal model hyperparameters
# model.train()
# likelihood.train()

# # Use the adam optimizer
# optimizer = torch.optim.SGD(model.parameters(), lr=0.25)  # Includes GaussianLikelihood parameters

# # "Loss" for GPs - the marginal log likelihood
# mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

# for i in range(3):
#     # Zero gradients from previous iteration
#     optimizer.zero_grad()
#     # Output from model
#     output = model(train_x)
#     # Calc loss and backprop gradients
#     loss = -mll(output, train_y)
#     logger.warning(f'Loss {loss.shape}')
#     loss.backward()
#     print('Iter %d/%d - Loss: %.3f noise: %.3f' % 
#         (
#         i + 1, training_iter, loss.item(),
#         model.likelihood.noise.item())
#         )
#     optimizer.step()