In [1]:
import os

import torch
import torch.nn as nn
import pyro
from tqdm import trange


from epidemic import Epidemic
from neural.modules import Mlp, LazyFn
from neural.aggregators import LSTMImplicitDAD
from neural.baselines import DesignBaseline, BatchDesignBaseline
from neural.critics import CriticDotProd

from estimators.bb_mi import InfoNCE
from oed.design import OED

## SIR model

We are going to study the SIR model, an SDE-based model from epidemiology (for details see e.g. the [Wikipedia article](https://en.wikipedia.org/wiki/Compartmental_models_in_epidemiology)). The model is governed by two parameters, the infection rate and the recovery rate, which we wish to learn. What we control is the time $\tau \in (0, 100)$ at which we measure the number of infected people in the population.

Before we begin, you have to generate some training data (if you haven't done so already). Note that this may take some time to run, but you only need to do it once.

In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Reuse previously simulated trajectories if they are already on disk;
# otherwise simulate them (slow) and ask the helper to save the result.
# NOTE(review): the check looks under data/ while the helper is given a bare
# filename -- presumably solve_sir_sdes writes into data/; confirm.
if not os.path.exists("data/sir_sde_data.pt"):
    from epidemic_simulate_data import solve_sir_sdes
    simdata = solve_sir_sdes(
        num_samples=100000,
        device=device,
        grid=10000,
        save=True,
        filename="sir_sde_data.pt",
    )
else:
    print("Loading data.")
    simdata = torch.load("data/sir_sde_data.pt", map_location=device)

Loading data.


Below we define some constants required for the model. 

In [3]:
# Number of experiments we wish to perform.
T = 2
# Shape of a design: one observation per design; the design itself is 1d.
design_dim = (1, 1)
# The observation is 1-dimensional too -- the number of infected people.
observation_dim = 1
# Dimension of the parameter we want to learn.
latent_dim = 2

### Constant designs

Let's start very simple: suppose we choose two constants, say $t_1=25$ and $t_2=75$. The cells below set up a constant design network and show a few realisations of running such a policy (i.e. regardless of the underlying parameter, our design strategy is to always query at the same times).

In [4]:
class SimpleDesignNet2Constants(DesignBaseline):
    """
    Non-adaptive design network that replays two pre-defined constants.

    design1, design2: constants returned for the first and second experiment
    design_dim: shape of a single design tensor

    The constants are stored untransformed. The transformed design
    (corresponding to time) equals Sigmoid(design)*100, which is a number
    between 0 and 100.
    """
    def __init__(self, design1, design2, design_dim=design_dim):
        super().__init__(design_dim=design_dim)
        # One constant tensor of shape <design_dim> per experiment.
        self.design = [torch.zeros(design_dim) + value for value in (design1, design2)]

    def forward(self, *design_obs_pairs):
        # The number of pairs passed in selects which stored constant to
        # return; the contents of the pairs are never inspected.
        num_previous = len(design_obs_pairs)
        return self.design[num_previous]

## Untransformed designs -1.1 and 1.1: after the transformation
## (pass through sigmoid, multiply by 100) they correspond to ~25 and ~75
design_net_const = SimpleDesignNet2Constants(design1=-1.1, design2=1.1)

In [6]:
# Wrap the constant policy in the SIR model and print three example runs.
sir_model2 = Epidemic(design_net=design_net_const, T=2, simdata=simdata)
for run_idx in range(3):
    _ = sir_model2.eval(verbose=True)
    print("\n")

Example run
*True Theta: tensor([0.6063, 0.0813])*
xi1: 24.978992462158203  y1: 170.554931640625
xi2: 75.02100372314453  y2: 0.8645527362823486


Example run
*True Theta: tensor([0.2852, 0.1718])*
xi1: 24.978992462158203  y1: 25.199087142944336
xi2: 75.02100372314453  y2: 11.668272972106934


Example run
*True Theta: tensor([0.6901, 0.3115])*
xi1: 24.978992462158203  y1: 45.527217864990234
xi2: 75.02100372314453  y2: 0.0




### Optimized designs

Are these the best constants we could choose? Probably not! 

We can optimize the two constants to obtain designs that are optimal according to the expected information gain (EIG) objective.

In [7]:
pyro.clear_param_store()

# The initial designs are sampled from Uniform(-5, 5), so fix the seed first.
torch.manual_seed(20211101)

# Distribution of the initial (untransformed) designs; its bounds live on <device>.
initial_design_dist = torch.distributions.Uniform(
    torch.tensor(-5.0, device=device),
    torch.tensor(5.0, device=device),
)

# <BatchDesignBaseline> (from neural.baselines) is a very simple extension
# of the <SimpleDesignNet2Constants> class above
design_net_optimized = BatchDesignBaseline(
    T=2,
    design_dim=design_dim,
    design_init=initial_design_dist,
)
sir_model_const_optimized = Epidemic(
    design_net=design_net_optimized, T=2, simdata=simdata
)

# Show a few runs with the not-yet-optimized designs.
print("--- Initial designs ---")
for run_idx in range(3):
    _ = sir_model_const_optimized.eval(verbose=True)
    print("\n")

--- Initial designs ---
Example run
*True Theta: tensor([0.7385, 0.0712])*
xi1: 38.99824905395508  y1: 54.18731689453125
xi2: 91.85362243652344  y2: 0.639757513999939


Example run
*True Theta: tensor([1.4754, 0.0524])*
xi1: 38.99824905395508  y1: 86.43742370605469
xi2: 91.85362243652344  y2: 4.93392276763916


Example run
*True Theta: tensor([0.2476, 0.0391])*
xi1: 38.99824905395508  y1: 306.08258056640625
xi2: 91.85362243652344  y2: 47.065521240234375




#### Critic network

In [8]:
### We need a critic network -- we will train a tiny one
encoding_dim = 16
hidden_dim = 64

# Encoder network (MLP): encodes individual design-outcome pairs, which are
# then stacked and passed through an LSTM aggregator to get a vector of
# size <encoding_dim>.
critic_pair_encoder = Mlp(
    input_dim=[*design_dim, observation_dim],
    hidden_dim=hidden_dim,
    output_dim=encoding_dim,
)
# Emission network (MLP): takes the final representation and passes it
# through the final ("head") layers.
critic_head = Mlp(input_dim=encoding_dim, hidden_dim=hidden_dim, output_dim=encoding_dim)
critic_history_encoder = LSTMImplicitDAD(
    encoder_network=critic_pair_encoder,
    emission_network=critic_head,
    empty_value=torch.zeros(design_dim).to(device),
).to(device)

# MLP that maps the parameters to a vector of size <encoding_dim>.
critic_latent_encoder = Mlp(
    input_dim=latent_dim, hidden_dim=[64, 128], output_dim=encoding_dim
).to(device)

# Critic takes experimental histories and parameters as inputs and returns a number
# Optimal critic achieves tight bounds
critic_net = CriticDotProd(
    history_encoder_network=critic_history_encoder,
    latent_encoder_network=critic_latent_encoder,
).to(device)

In [9]:
# Let's optimize for a few steps
pyro.clear_param_store()

# First set up the loss: InfoNCE on the model with the constant designs,
# scored by the critic defined above.
mi_loss = InfoNCE(
    model=sir_model_const_optimized.model,
    critic=critic_net,
    batch_size=256,
    num_negative_samples=255,
)

# ... and an optimizer, both handed to the OED wrapper that takes the steps
optimizer = pyro.optim.Adam({"lr": 0.001})
oed = OED(optim=optimizer, loss=mi_loss)

num_steps = 5000
num_steps_range = trange(1, num_steps + 1, desc="Loss: 0.000 ")
for step in num_steps_range:
    sir_model_const_optimized.train()
    loss = oed.step()
    # Show the most recent loss in the progress bar.
    num_steps_range.set_description(f"Loss: {loss:.3f} ")

print("--- Final designs ---")
_ = sir_model_const_optimized.eval(verbose=True)

Loss: 3.718 : 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [04:44<00:00, 17.57it/s]

--- Final designs ---
Example run
*True Theta: tensor([0.2109, 0.1383])*
xi1: 33.92506790161133  y1: 34.3079719543457
xi2: 90.68833923339844  y2: 9.713905334472656





We will need to train for longer to ensure the designs (and critic params) have converged. Although the two designs might be the optimal constants, this design strategy is still not adaptive, i.e. we are not using past information to make future design decisions.

### Adaptive Designs

We wish our designs to be a function of the history. Specifically, given we have $T=2$ designs here, we want the second design to be informed by the design-outcome pair from the first experiment. Here's how we can do that:

In [10]:
pyro.clear_param_store()
# Fix the seed so the randomly initialised design network is reproducible.
torch.manual_seed(20211101)
### Set up a design network
# We do this in (essentially) the same way as with the critic net, so the code
# looks pretty much the same. One big difference is that the design net takes
# intermediate histories of variable length as inputs, while the critic only
# takes the full history $h_T$ (this is dealt with in the networks in
# neural.aggregators).
design_net_adaptive = LSTMImplicitDAD(
    encoder_network=Mlp(input_dim=[*design_dim, observation_dim], hidden_dim=hidden_dim, output_dim=encoding_dim),
    # note that the "head" layer here outputs a design, i.e. something of size <design_dim>
    emission_network=Mlp(input_dim=encoding_dim, hidden_dim=hidden_dim, output_dim=design_dim),
    # Put the empty-history placeholder on <device>, exactly as the critic
    # networks do, so it matches the simulated data.
    empty_value=torch.zeros(design_dim).to(device)
).to(device)  # move the network itself too; this is a no-op when device == 'cpu'

sir_model_adaptive=Epidemic(
    design_net=design_net_adaptive,
    T=2,
    simdata=simdata,
    # note! we need to make sure the designs are increasing!
    # to do that, simply select design_transform="ts" (for time series)
    design_transform="ts"
)
# Show a few runs with the untrained design network.
print("--- Initial designs ---")
for i in range(3):
    _ = sir_model_adaptive.eval(verbose=True)
    print("\n")

--- Initial designs ---
Example run
*True Theta: tensor([0.8327, 0.0386])*
xi1: 45.891273498535156  y1: 97.0364990234375
xi2: 70.79920196533203  y2: 41.37067413330078


Example run
*True Theta: tensor([0.4836, 0.0941])*
xi1: 45.891273498535156  y1: 20.98923110961914
xi2: 70.77635192871094  y2: 1.7925260066986084


Example run
*True Theta: tensor([0.6650, 0.1082])*
xi1: 45.891273498535156  y1: 12.45532512664795
xi2: 70.7790756225586  y2: 0.0




In [11]:
### Define a new critic network (same architecture as before, newly constructed)
encoding_dim = 16
hidden_dim = 64

# Per-pair encoder and "head" MLPs for the history branch of the critic.
critic_pair_encoder2 = Mlp(
    input_dim=[*design_dim, observation_dim],
    hidden_dim=hidden_dim,
    output_dim=encoding_dim,
)
critic_head2 = Mlp(input_dim=encoding_dim, hidden_dim=hidden_dim, output_dim=encoding_dim)
critic_history_encoder2 = LSTMImplicitDAD(
    encoder_network=critic_pair_encoder2,
    emission_network=critic_head2,
    empty_value=torch.zeros(design_dim).to(device),
).to(device)

# Parameter branch of the critic.
critic_latent_encoder2 = Mlp(
    input_dim=latent_dim, hidden_dim=[64, 128], output_dim=encoding_dim
).to(device)

# Combine the two branches into the critic used for the adaptive policy.
critic_net2 = CriticDotProd(
    history_encoder_network=critic_history_encoder2,
    latent_encoder_network=critic_latent_encoder2,
).to(device)

Notice that in the initial runs above the second design, $\xi_2$, now differs slightly from run to run, since it depends on the outcome of the first experiment. Let's now optimize the design network (and the critic) for a few steps.

In [12]:
# Let's optimize for a few steps
pyro.clear_param_store()

# First set up a new loss with the new model (though we could have used the old loss
# and the old model, by just changing the design network):
mi_loss_adaptive = InfoNCE(
    model=sir_model_adaptive.model,
    critic=critic_net2,
    batch_size=256,
    num_negative_samples=255,
)

# ... and an optimizer, both handed to the OED wrapper that takes the steps
optimizer = pyro.optim.Adam({"lr": 0.001})
oed = OED(optim=optimizer, loss=mi_loss_adaptive)

num_steps = 5000
num_steps_range = trange(1, num_steps + 1, desc="Loss: 0.000 ")
for step in num_steps_range:
    sir_model_adaptive.train()
    loss = oed.step()
    # Show the most recent loss in the progress bar.
    num_steps_range.set_description(f"Loss: {loss:.3f} ")

# Print a few runs with the trained design network.
print("--- Final designs ---")
for run_idx in range(3):
    _ = sir_model_adaptive.eval(verbose=True)
    print("\n")

Loss: 3.091 : 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [05:07<00:00, 16.26it/s]


--- Final designs ---
Example run
*True Theta: tensor([0.7548, 0.1081])*
xi1: 10.930394172668457  y1: 274.59637451171875
xi2: 33.50607681274414  y2: 36.46894454956055


Example run
*True Theta: tensor([0.4764, 0.1522])*
xi1: 10.930394172668457  y1: 103.99957275390625
xi2: 33.510799407958984  y2: 49.70225524902344


Example run
*True Theta: tensor([0.3431, 0.0438])*
xi1: 10.930394172668457  y1: 50.67266082763672
xi2: 34.003257751464844  y2: 251.98411560058594


