In [1]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import h5py


import torch
import torch.distributions as dist 

In [2]:
def forward_function(bulk_velocity, offset):    
    # since we know the source is at x=0, we can calculate the distance to the receiver directly by using pythagoras theorem
    return torch.sqrt(0.5**2 + offset**2) / bulk_velocity

In [3]:
n_prior = int(1e5)

bulk_velocity_prior = dist.Independent(dist.Normal(torch.tensor([4.5,]), torch.tensor([0.5,])), 1)
bulk_velocity_prior_samples = bulk_velocity_prior.sample([n_prior])

In [4]:
def data_likelihood(samples, **kwargs):

    torch.manual_seed(0)
    data_likelihood = dist.Independent(dist.Normal(samples, torch.tensor(0.1)), 1)
    
    return data_likelihood

In [5]:
n_design_points = 10
design_names = [str(i) for i in range(n_design_points)]
offsets      = torch.linspace(1, 9, n_design_points)

filename_base_case = 'data/forward_model_data.h5'
dataset = 'data'

design_dicts = {}

for i, name in enumerate(design_names):
    design_dicts[name] = {'index': i, 'offset': offsets[i], 'file': filename_base_case, 'dataset': dataset, 'cost': 1.0}

In [6]:
with h5py.File(filename_base_case, 'w') as f:
    data = f.create_dataset("data", (n_prior, n_design_points, 1))
    
    for design in design_dicts:
                
        data[:, design_dicts[design]['index']] = forward_function(bulk_velocity_prior_samples, design_dicts[design]['offset'])
        
with h5py.File(filename_base_case, 'r') as f:
    data = torch.from_numpy(f['data'][:])

In [7]:
from geobed import BED_discrete

BED_class = BED_discrete(
    design_dicts, data_likelihood,
    prior_samples=bulk_velocity_prior_samples, prior_dist=bulk_velocity_prior)

In [8]:
import zuko

class GMM_guide(torch.nn.Module):
    def __init__(self, data_samples, **kwargs):
        data_mean = torch.mean(data_samples, dim=0)
        data_std = torch.std(data_samples, dim=0)
        features = data_samples.shape[-1]
    
        super().__init__()
    
        self.base = zuko.flows.GMM(features=features, **kwargs)
    
        self.transforms = [zuko.flows.Unconditional(dist.AffineTransform, -data_mean/data_std, 1/data_std, buffer=True),]# buffer=True excludes the parameters from the optimization
    
    def forward(self):
                
        transform = zuko.transforms.ComposedTransform(*(t(None) for t in self.transforms))

        base = self.base(None)

        return zuko.flows.NormalizingFlow(transform, base)
        
    def log_prob(self, x):
                
        out = self.forward().log_prob(x)        
        return out
    
    def sample(self, n_samples):
        
        shape = torch.Size([n_samples])
        return self.forward().sample(shape) 

In [13]:
N = int(1e3)
M = int(1e3)
n_batch = 100
n_epochs = 10

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau

out = BED_class.calculate_eig_list(
    design_list=[['1', '9'], ['1', '8'], ['1', '9'], ['1', '8'], ['1', '9'], ['1', '8'], ['1', '9'], ['1', '8'], ['1', '9'], ['1', '8']],
    method='variational_marginal',
    method_kwargs= {'guide': GMM_guide,
     'N': N,
     'M': M,
     'guide_kwargs': {'components':20},
     'n_batch': n_batch,
     'n_epochs': n_epochs,
     'optimizer': torch.optim.Adam,
     'optimizer_kwargs': {'lr': 1e-3},
     'scheduler': scheduler,
     'scheduler_kwargs': {'patience': 2, 'threshold': 1e-1},
     'return_guide': True,
     'return_train_loss': True,
     'return_test_loss': True,
     },
    progress_bar=True,
    num_workers=2,
    parallel_method='joblib',)

Calculating eig:   0%|          | 0/10 [00:00<?, ?it/s]

In [14]:
print(out)

[tensor([1.4377, 1.3507, 1.4377, 1.3507, 1.4377, 1.3507, 1.4377, 1.3507, 1.4377,
        1.3507]), ({'N': 1000, 'M': 1000, 'n_epochs': 10, 'n_batch': 100, 'optimizer_kwargs': {'lr': 0.001}, 'scheduler_kwargs': {'patience': 2, 'threshold': 0.1}, 'guide': GMM_guide(
  (base): GMM(
    (phi): Parameters(
        (0): tensor of shape (20,)
        (1): tensor of shape (20, 2)
        (2): tensor of shape (20, 2)
        (3): tensor of shape (20, 1)
    )
  )
), 'train_loss': [-0.09762141853570938, -0.16061685979366302, -0.10318402200937271, -0.37735259532928467, -0.1420172154903412, -0.38839206099510193, -0.23535801470279694, -0.2044912725687027, -0.27565211057662964, -0.22276811301708221, -0.2686990797519684, -0.25171998143196106, -0.302926242351532, -0.24390682578086853, -0.11313101649284363, -0.25050726532936096, -0.2005438655614853, -0.2336709201335907, -0.18671968579292297, -0.30655986070632935, -0.1593945175409317, -0.26421427726745605, -0.0627373605966568, -0.18082325160503387, -0.3