We need as input for OED

- Mandatory:
  - design space
  - model prior distribution or samples
     - if only samples, only the relative information gain can be computed
  - Observation forward mapping (deterministic, any probabilistic influence is absorbed into nuisance parameters)
  - Class with forward and backward operator and flag to indicate that it is backwards differentiable
    - observation model model parameter jacobian
      - either takes the model it is linearised about as input or none
      - can be used to differentiate between linear and linearisable model
    - observation model design parameter gradients

  - Observation noise distribution
    - can be explicit or implicit
- Optional:
  - nuisance parameter distribution
    - can be explicit or implicit (only needs to be able to be sampled from)
    - optional conditional on model parameters
    - optional only samples if independent of model parameters
    - can also be used to simulate the target inverse mapping
  - target forward mapping
    - can be deterministic or probabilistic (explicit or implict)
    - in either case we only need to be able to sample from it

 

In [1]:
import torch
import torch.distributions as dist

In [5]:
from abc import ABC, abstractmethod
from torch import Tensor
from torch.distributions import Distribution

from typing import Union

class Observation_Model():

    def __init__(
        self,
        forward_function,
        obs_noise_dist = None,
        m_prior_dist: Distribution=None,
        m_prior_samples: Tensor=None,
        nuisance_dist: callable=None,
        nuisance_parameter_samples: Tensor=None,
        target_forward_function: callable=None,
        ):
        
        if type(forward_function) is callable:
            class Dummy_Foward_Class():
                def __init__(self, forward_function):
                    self.forward_function = forward_function
                def forward(self, design, model_parameters, nuisance_parameters):
                    return self.forward_function(design, model_parameters, nuisance_parameters)
        else:
            self.forward_function = forward_function
            
        #  Check if either prior samples xor prior distribution is provided
        assert (m_prior_dist is None) ^ (m_prior_samples is None), "Either model parameter prior distribution or samples must be provided"
        self.m_prior_dist = m_prior_dist
        self.m_prior_samples = m_prior_samples
        
        #  Check if either prior samples xor prior distribution is provided
        assert (nuisance_dist is None) ^ (nuisance_parameter_samples is None), "Either nuisance parameter prior distribution or samples must be provided"
        self.nuisance_dist = nuisance_dist
        self.nuisance_parameter_samples = nuisance_parameter_samples
        
        # Check if observation noise distribution is provided
        if obs_noise_dist is None:
            class Delta(Distribution):
                """ Inspired by https://pytorch.org/rl/_modules/torchrl/modules/distributions/continuous.html#Delta """
                def __init__(self, loc):
                    super().__init__(validate_args=False)
                    self.loc = loc
                def sample(self, size=torch.Size()):
                    if size is None:
                        size = torch.Size([])
                    return self.loc.expand(*size, *self.loc.shape)
            self.obs_noise_dist = Delta
        else:
            self.obs_noise_dist = obs_noise_dist

        # Check if observation noise distribution is explicitly or implicitly defined
        if getattr(self.obs_noise_dist, 'log_prob', None) is None:
            self.implict_obs_noise_dist = True
        else:
            self.implict_obs_noise_dist = False
            
        self.target_forward_function = target_forward_function
        if hasattr(self.target_forward_function, 'sample'):
            self.implict_target_forward_function = True
            self.deterministic_target_forward_function = False
        elif hasattr(self.target_forward_function, 'log_prob'):
            self.implict_target_forward_function = False
            self.deterministic_target_forward_function = False
        else:
            self.deterministic_target_forward_function = True
            
    
    def get_forward_function_samples(
        self,
        num_samples: int,
        design: dict,
        model_parameters: Tensor,
        nuisance_parameters: Tensor=None,
        ) -> Tensor:
        
        if nuisance_parameters is None:
            return self.forward_function.forward(design, model_parameters[:num_samples])
        else:
            return self.forward_function.forward(design, model_parameters[:num_samples], nuisance_parameters[:num_samples])

    def get_foward_model_distribution(
        self,
        num_samples: int,
        design: dict,
        model_parameters: Tensor,
        nuisance_parameters: Tensor=None,
        ) -> Distribution:
        return self.obs_noise_dist(self.get_forward_function_samples(num_samples, design, model_parameters, nuisance_parameters), design)
    
    def get_forward_model_samples(
        self,
        num_samples: int,
        design: dict,
        model_parameters: Tensor,
        nuisance_parameters: Tensor=None,
        ) -> Tensor:
        return self.get_foward_model_distribution(num_samples, design, model_parameters, nuisance_parameters).sample()
    
    # @abstractmethod
    # def forward_function(
    #     self,
    #     design : dict,
    #     model_parameters: Tensor,
    #     nuisance_parameters: Tensor=None) -> Tensor:
        
    #     raise NotImplementedError
    
    

        
        

test_m_prior_dist = dist.Normal(0,1)
test_m_prior_samples = test_m_prior_dist.sample((100,))    
    
test_nuisance_dist = dist.Normal(0,1)
test_nuisance_parameter_samples = test_nuisance_dist.sample((100,))

test_forward_function = lambda x, y: x + y

test_observation_model = Observation_Model(
    forward_function=test_forward_function,
    m_prior_dist=test_m_prior_dist,
    # m_prior_samples=test_m_prior_samples,
    nuisance_dist=test_nuisance_dist,
    # nuisance_parameter_samples=test_nuisance_parameter_samples,
    )

print(test_observation_model.obs_noise_dist(torch.linspace(0,10,100)[:, None] ).sample().shape)
print(test_observation_model.implict_obs_noise_dist)



torch.Size([100, 1])
False
False
