In [1]:
import pandas as pd
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training import train_state
from functools import partial
# from relative_fitness_mechanisms.selective_pressure_prediction import (create_lagged_features, 
#                                                                        process_inputs_all, 
#                                                                        withhold_test_locations_and_split, 
#                                                                        SelectivePressureData
#                                                                        create_training_batches,
#                                                                       train_step, 
#                                                                       loss_fn)

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import string

In [2]:
## Prepping data

import sys
sys.path.append( '../relative_fitness_mechanisms/')
import plot_utils
from selective_pressure_prediction import (create_lagged_features, 
                                           process_inputs_all, 
                                            withhold_test_locations_and_split, 
                                            SelectivePressureData,
                                            SelectivePressureTSData,
                                            create_training_batches,
                                            train_step)

from ml_utils import TrainerModule

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Read the tab-separated input table and convert its "date" column to pandas
# datetimes in one chained expression.
# Alternative input kept from the original cell (disabled):
#selective_pressure_df = pd.read_csv("../data/selective_pressure_growth_cases.tsv", sep="\t")
selective_pressure_df = pd.read_csv(
    "../data/selective_pressure_growth_cases_full.tsv", sep="\t"
).assign(date=lambda frame: pd.to_datetime(frame["date"]))

In [4]:
# Remove every row containing a missing value, then add the natural log of
# "smooth_cases" as a new "log_smooth_cases" column.
# NOTE(review): np.log yields -inf / NaN for zero / negative case counts --
# confirm "smooth_cases" is strictly positive.
selective_pressure_df = (
    selective_pressure_df
    .dropna()
    .assign(log_smooth_cases=lambda frame: np.log(frame["smooth_cases"]))
)

In [5]:
#excluded = ["Minnesota", "Colorado", "Tennessee", "North Dakota", "Missouri", "Indiana", "Florida", "Connecticut"]
#selective_pressure_df = selective_pressure_df[~selective_pressure_df.location.isin(excluded)]
#included = ["England", "California",   "Michigan", "Nevada", "New York", "Texas", "Washington"]
#selective_pressure_df = selective_pressure_df[selective_pressure_df.location.isin(included)]


In [6]:
# Columns carried into the per-location frames: identifiers/features plus the
# regression targets.
TARGETS = ["empirical_growth_rate", "log_smooth_cases"]
keep_targets = TARGETS
keep_features = ["date", "location", "selective_pressure"]

# Build one lagged-feature frame per location, keyed by location name.
# 28 is passed as the third positional argument of create_lagged_features
# (presumably the number of lags -- confirm against the helper).
input_dfs = {}
for loc, group in selective_pressure_df.groupby("location"):
    location_frame = group[[*keep_features, *keep_targets]]
    input_dfs[loc] = create_lagged_features(location_frame, ["selective_pressure"], 28)

In [7]:
# Drop rows with missing values from every per-location frame and pass the
# *cleaned* frames to process_inputs_all.
# Fix: the original stored the cleaned frames under the misspelled name
# `inputs_dfs` and then passed the uncleaned `input_dfs`, so the dropna() was
# silently discarded.
input_dfs_clean = {loc: input_df.dropna() for loc, input_df in input_dfs.items()}
inputs_dfs = input_dfs_clean  # old name kept so any existing reference still resolves
dates_vec, locations_vec, X, y = process_inputs_all(input_dfs_clean, target=TARGETS)

In [8]:
# Locations listed here are withheld from training and used as the test split.
WITHHELD_LOCATIONS = ["England"]

train_test_parts = withhold_test_locations_and_split(
    X, y, locations_vec, WITHHELD_LOCATIONS
)
X_train, y_train, X_test, y_test = train_test_parts

In [9]:
from ml_utils import create_data_loaders

# Sequence datasets: each split is built from its own rows of the feature and
# target frames, together with the matching entries of `locations_vec` and
# `dates_vec` (selected through the split's index).
train_set = SelectivePressureTSData(
    X_train.values,
    y_train.values,
    locations_vec[X_train.index],
    dates_vec[X_train.index],
    sequence_length=32
)
val_set = SelectivePressureTSData(
    X_test.values,
    y_test.values,
    locations_vec[X_test.index],
    # Fix: this used X_train.index, which paired validation rows with the
    # dates of the training rows.
    dates_vec[X_test.index],
    sequence_length=1
)

# `train=[True, False]` is forwarded per loader (training loader first,
# validation loader second).
train_loader, val_loader = create_data_loaders(train_set, val_set,
                                               train=[True, False],
                                               batch_size=1,
                                               num_workers=1)



In [10]:
from ml_utils import create_data_loaders

    
# Flat (non-sequence) datasets built from the raw feature/target arrays.
# This rebinds train_set / val_set / train_loader / val_loader.
train_set = SelectivePressureData(X_train.values, y_train.values)
val_set = SelectivePressureData(X_test.values, y_test.values)

loader_options = dict(train=[True, False], batch_size=128, num_workers=1)
train_loader, val_loader = create_data_loaders(train_set, val_set, **loader_options)

In [11]:
# Defining model
from flax import linen as nn

class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal positional-encoding table to its input.

    Attributes:
        max_len: Number of positions (rows) in the encoding table.
        d_model: Number of encoding channels (columns) in the table.
    """
    max_len: int
    d_model: int

    @nn.compact
    def __call__(self, x):
        """Return ``x`` plus the positional-encoding table.

        The table is cut to ``x.shape[1]`` rows, given a leading axis of
        size 1, and added to ``x`` using ordinary broadcasting.
        """
        pe = self.positional_encoding(self.max_len, self.d_model)
        pe = pe[:x.shape[1], :]  # keep only as many positions as the input has
        pe = jnp.expand_dims(pe, axis=0)  # leading axis broadcasts over the batch
        return x + pe

    def positional_encoding(self, max_len, d_model):
        """Build the ``(max_len, d_model)`` sine/cosine table as a NumPy array.

        Even columns hold ``sin(position * frequency)`` and odd columns hold
        ``cos(position * frequency)``.

        Args:
            max_len: Number of rows (positions).
            d_model: Number of columns; may be even or odd.

        Returns:
            ``numpy.ndarray`` of shape ``(max_len, d_model)``.
        """
        position = np.arange(max_len)[:, np.newaxis]
        div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = np.zeros((max_len, d_model))
        pe[:, 0::2] = np.sin(angles)
        # An odd d_model has one fewer odd column than even columns; trimming
        # the angles avoids the shape mismatch the untrimmed assignment raised.
        pe[:, 1::2] = np.cos(angles[:, : d_model // 2])
        return pe
    
class FeedForward(nn.Module):
    """Two dense layers with a ReLU in between.

    Attributes:
        d_model: Width of the output layer.
        d_ff: Width of the hidden layer.
    """
    d_model: int
    d_ff: int

    @nn.compact
    def __call__(self, x):
        """Apply ``Dense(d_ff) -> ReLU -> Dense(d_model)`` to ``x``."""
        hidden = nn.relu(nn.Dense(self.d_ff)(x))
        return nn.Dense(self.d_model)(hidden)
    
class BlockLayer(nn.Module):
    """Self-attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer output is added to its input and the sum is passed
    through a ``LayerNorm``.

    Attributes:
        d_model: Attention ``qkv_features`` and feed-forward output width.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
    """
    d_model: int
    num_heads: int
    d_ff: int

    @nn.compact
    def __call__(self, x):
        """Run ``x`` through the attention and feed-forward sub-layers."""
        self_attention = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.d_model,
            use_bias=True,
        )
        attended = self_attention(x)
        normed = nn.LayerNorm()(x + attended)

        transformed = FeedForward(d_model=self.d_model, d_ff=self.d_ff)(normed)
        return nn.LayerNorm()(normed + transformed)
    
class SimpleTransformer(nn.Module):
    """Dense embedding, three ``BlockLayer``s and a dense regression head.

    NOTE(review): a class with the same name is defined again in a later
    cell of this notebook and replaces this one once that cell runs.

    Attributes:
        num_heads: Number of attention heads in every block.
        d_model: Size of the attention representations.
        d_ff: Size of the feed-forward layers.
        output_dim: Width of the final dense layer. The trailing axis is
            squeezed afterwards, so this forward pass needs ``output_dim == 1``.
    """
    num_heads: int
    d_model: int
    d_ff: int
    output_dim: int

    @nn.compact
    def __call__(self, x, **kwargs):
        """Return the head output averaged over the last remaining axis.

        Extra keyword arguments are accepted and ignored.
        """
        embedded = nn.relu(nn.Dense(self.d_model)(x))

        # Append a trailing axis of size 1 before adding the positional table.
        tokens = jnp.expand_dims(embedded, axis=-1)
        tokens = PositionalEncoding(max_len=tokens.shape[1], d_model=self.d_model)(tokens)

        for _ in range(3):
            tokens = BlockLayer(
                d_model=self.d_model, num_heads=self.num_heads, d_ff=self.d_ff
            )(tokens)

        hidden = tokens
        for _ in range(3):
            hidden = nn.relu(nn.Dense(self.d_ff)(hidden))

        out = nn.Dense(self.output_dim)(hidden)
        out = jnp.squeeze(out, axis=-1)
        return jnp.mean(out, axis=-1)
    
# Define loss function
#combined_loss_fn = partial(loss_fn, alpha=1e-4)

In [12]:
# Defining model
from flax import linen as nn

class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal positional-encoding table to its input.

    Attributes:
        max_len: Number of positions (rows) in the encoding table.
        d_model: Number of encoding channels (columns) in the table.
    """
    max_len: int
    d_model: int

    @nn.compact
    def __call__(self, x):
        """Return ``x`` plus the positional-encoding table.

        The table is cut to ``x.shape[1]`` rows, given a leading axis of
        size 1, and added to ``x`` using ordinary broadcasting.
        """
        pe = self.positional_encoding(self.max_len, self.d_model)
        pe = pe[:x.shape[1], :]  # keep only as many positions as the input has
        pe = jnp.expand_dims(pe, axis=0)  # leading axis broadcasts over the batch
        return x + pe

    def positional_encoding(self, max_len, d_model):
        """Build the ``(max_len, d_model)`` sine/cosine table as a NumPy array.

        Even columns hold ``sin(position * frequency)`` and odd columns hold
        ``cos(position * frequency)``.

        Args:
            max_len: Number of rows (positions).
            d_model: Number of columns; may be even or odd.

        Returns:
            ``numpy.ndarray`` of shape ``(max_len, d_model)``.
        """
        position = np.arange(max_len)[:, np.newaxis]
        div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = np.zeros((max_len, d_model))
        pe[:, 0::2] = np.sin(angles)
        # An odd d_model has one fewer odd column than even columns; trimming
        # the angles avoids the shape mismatch the untrimmed assignment raised.
        pe[:, 1::2] = np.cos(angles[:, : d_model // 2])
        return pe
    
class FeedForward(nn.Module):
    """Position-wise feed-forward network: dense, ReLU, dense.

    Attributes:
        d_model: Width of the second (output) dense layer.
        d_ff: Width of the first (hidden) dense layer.
    """
    d_model: int
    d_ff: int

    @nn.compact
    def __call__(self, x):
        """Apply the hidden layer with ReLU, then project to ``d_model``."""
        expand = nn.Dense(self.d_ff)
        activated = nn.relu(expand(x))
        project = nn.Dense(self.d_model)
        return project(activated)
    
class BlockLayer(nn.Module):
    """Self-attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer output is added to its input and the sum is passed
    through a ``LayerNorm``. The dropout calls are commented out, so
    ``dropout_rate`` and ``deterministic`` currently have no effect.

    Attributes:
        d_model: Attention ``qkv_features`` and feed-forward output width.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout_rate: Dropout probability (unused while dropout is disabled).
    """
    d_model: int
    num_heads: int
    d_ff: int
    dropout_rate: float

    @nn.compact
    def __call__(self, x, deterministic):
        """Run ``x`` through both sub-layers; ``deterministic`` is accepted but unused."""
        self_attention = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.d_model,
            use_bias=True,
        )
        attended = self_attention(x)
        normed = nn.LayerNorm()(x + attended)
        #normed = nn.Dropout(rate=self.dropout_rate, deterministic=deterministic)(normed)

        transformed = FeedForward(d_model=self.d_model, d_ff=self.d_ff)(normed)
        result = nn.LayerNorm()(normed + transformed)
        #result = nn.Dropout(rate=self.dropout_rate, deterministic=deterministic)(result)
        return result
    
class SimpleTransformer(nn.Module):
    """Two-layer dense embedding, three ``BlockLayer``s and a dense regression head.

    Attributes:
        num_heads: Number of attention heads in every block.
        d_model: Size of the attention representations.
        d_ff: Size of the feed-forward layers.
        output_dim: Dimensionality of the regression output.
        dropout_rate: Dropout probability forwarded to every ``BlockLayer``.
    """
    num_heads: int
    d_model: int
    d_ff: int
    output_dim: int
    dropout_rate: float

    @nn.compact
    def __call__(self, x, deterministic, **kwargs):
        """Return the head output averaged over axis 1.

        Args:
            x: Input features.
            deterministic: Forwarded to every ``BlockLayer``.
            **kwargs: Accepted and ignored.
        """
        # Embedding stem: two Dense(d_model) + ReLU layers.
        for _ in range(2):
            x = nn.relu(nn.Dense(self.d_model)(x))

        # Append a trailing axis of size 1 before adding the positional table.
        x = jnp.expand_dims(x, axis=-1)
        x = PositionalEncoding(max_len=x.shape[1], d_model=self.d_model)(x)

        block_options = dict(
            d_model=self.d_model,
            num_heads=self.num_heads,
            d_ff=self.d_ff,
            dropout_rate=self.dropout_rate,
        )
        for _ in range(3):
            x = BlockLayer(**block_options)(x, deterministic)

        # Regression head: two Dense(d_ff) + ReLU layers, then the output layer.
        for _ in range(2):
            x = nn.relu(nn.Dense(self.d_ff)(x))
        x = nn.Dense(self.output_dim)(x)
        return jnp.mean(x, axis=1)
    
# Weight of the smoothness penalty: the trainer's loss is computed as
# `mae_loss + alpha * loss_smooth`, reading this module-level value.
alpha = 1e-4

In [13]:

class TransformerTrainer(TrainerModule):
    """Trainer that wires ``SimpleTransformer`` into ``TrainerModule``.

    NOTE(review): the step functions call ``combined_loss_fn``, which is not
    defined in the visible cells (its only definition is commented out), and
    ``model_hparams`` omits ``dropout_rate``, which the most recent
    ``SimpleTransformer`` declares as a required field. A class with the same
    name is defined again in a later cell and replaces this one.
    """

    def __init__(self, num_heads, d_model, d_ff, output_dim, **kwargs):
        """Forward the model hyperparameters; other kwargs pass through unchanged."""
        hparams = {
            'num_heads': num_heads,
            'd_model': d_model,
            'd_ff': d_ff,
            'output_dim': output_dim,
        }
        super().__init__(model_class=SimpleTransformer, model_hparams=hparams, **kwargs)

    def create_functions(self):
        """Build and return the ``(train_step, eval_step)`` pair."""

        def train_step(state, batch):
            """
            Run one optimisation step.

            Args:
              state: The current training state.
              batch: An ``(x, y)`` pair of training data.

            Returns:
              The updated training state and a metrics dict with the loss.
            """
            inputs, targets = batch

            def objective(p):
                return combined_loss_fn({'params': p}, state, inputs, targets)

            loss, grads = jax.value_and_grad(objective)(state.params)
            new_state = state.apply_gradients(grads=grads)
            return new_state, {'loss': loss}

        def eval_step(state, batch):
            """
            Evaluate the loss on one batch without updating the state.

            Args:
              state: The current training state.
              batch: An ``(x, y)`` pair of evaluation data.

            Returns:
              A metrics dict with the loss.
            """
            inputs, targets = batch
            variables = {'params': state.params}
            return {'loss': combined_loss_fn(variables, state, inputs, targets)}

        return train_step, eval_step

In [14]:
from jax import random, jit

def smoothness_loss(predictions, x, boundary_handling=True):
    """Mean squared second difference of ``predictions`` along axis 0.

    Args:
        predictions: Array with at least one axis. Differences are taken
            between neighbouring entries along axis 0; any trailing axes are
            treated element-wise.
        x: Unused; accepted so existing call sites keep working.
        boundary_handling: When True, the first and last entries along axis 0
            are set to zero before averaging, because their neighbours come
            from the wrap-around of ``jnp.roll``. They still count in the
            denominator of the mean.

    Returns:
        Scalar array: the mean of the squared second differences.
    """
    # Neighbours at t+1 and t-1 (wrapping around at the ends).
    next_step = jnp.roll(predictions, -1, axis=0)
    prev_step = jnp.roll(predictions, 1, axis=0)

    # Central finite-difference estimate of the second derivative.
    second_derivative = prev_step - 2 * predictions + next_step

    if boundary_handling:
        n_steps = predictions.shape[0]
        positions = np.arange(n_steps)
        is_boundary = (positions == 0) | (positions == n_steps - 1)
        # Shape the mask as (n_steps, 1, ..., 1) so it broadcasts over any
        # number of trailing axes. The earlier `[mask, :]` indexing raised
        # for 1-D predictions.
        is_boundary = is_boundary.reshape((n_steps,) + (1,) * (predictions.ndim - 1))
        second_derivative = jnp.where(is_boundary, 0.0, second_derivative)

    # Penalise large second differences (i.e. encourage smoothness).
    return jnp.mean(jnp.square(second_derivative))


class TransformerTrainer(TrainerModule):
    """Trainer for ``SimpleTransformer`` with an MAE loss plus a smoothness penalty.

    The loss is ``mae_loss + alpha * loss_smooth``, where ``alpha`` is read
    from module scope and ``loss_smooth`` comes from ``smoothness_loss``.

    NOTE(review): ``smoothness_loss`` differences predictions along axis 0,
    i.e. between consecutive samples of a batch. That is a temporal
    smoothness penalty only if batches are ordered in time -- confirm the
    loader's ordering.
    """

    def __init__(self, num_heads, d_model, d_ff, output_dim, dropout_rate, **kwargs):
        """Forward the model hyperparameters; other kwargs pass through unchanged."""
        hparams = {
            'num_heads': num_heads,
            'd_model': d_model,
            'd_ff': d_ff,
            'dropout_rate': dropout_rate,
            'output_dim': output_dim,
        }
        super().__init__(model_class=SimpleTransformer, model_hparams=hparams, **kwargs)

    def run_model_init(self, exmp_input, init_rng):
        """Initialise the model from the first element of ``exmp_input``.

        ``init_rng`` is split into separate 'params' and 'dropout' keys.
        """
        features = exmp_input[0]
        #features = features.squeeze(0)
        params_rng, dropout_rng = random.split(init_rng)
        rngs = {'params': params_rng, 'dropout': dropout_rng}
        return self.model.init(rngs, x=features, deterministic=True)

    def print_tabulate(self, exmp_input):
        """Print the model summary table for the first element of ``exmp_input``."""
        features = exmp_input[0]
        #features = features.squeeze(0)
        rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}
        print(self.model.tabulate(rngs=rngs, x=features, deterministic=False))

    def create_functions(self):
        """Build and return the jitted ``(train_step, eval_step)`` pair."""

        def loss_fn(params, state, x, y, rng, train):
            # Split off a dropout key; the other half is handed back to the
            # caller through the auxiliary output.
            next_rng, dropout_rng = random.split(rng)
            predictions = state.apply_fn(
                params,
                x,
                deterministic=not train,
                rngs={'dropout': dropout_rng},
            )
            mae_loss = jnp.mean(jnp.abs(predictions - y))
            #mse_loss = jnp.mean(jnp.square(predictions - y))
            loss_smooth = smoothness_loss(predictions, x)
            total = mae_loss + alpha * loss_smooth
            return total, (mae_loss, loss_smooth, next_rng)

        @jit
        def train_step(state, batch):
            """
            Run one optimisation step.

            Args:
              state: The current training state.
              batch: An ``(x, y)`` pair of training data.

            Returns:
              The updated training state and a metrics dict with the total
              loss, the MAE term and the smoothness term.
            """
            x, y = batch
            #x = x.squeeze(0)
            #y = y.squeeze(0)

            def objective(p):
                return loss_fn({'params': p}, state, x, y, state.rng, train=True)

            (loss, (mae_loss, loss_smooth, rng)), grads = jax.value_and_grad(
                objective, has_aux=True
            )(state.params)
            # The key returned by loss_fn is passed along with the gradients.
            state = state.apply_gradients(grads=grads, rng=rng)
            return state, {'loss': loss, 'mae': mae_loss, 'smooth': loss_smooth}

        @jit
        def eval_step(state, batch):
            """
            Evaluate the loss on one batch without updating the state.

            Args:
              state: The current training state.
              batch: An ``(x, y)`` pair of evaluation data.

            Returns:
              A metrics dict with the total loss.
            """
            x, y = batch
            #x = x.squeeze(0)
            #y = y.squeeze(0)
            variables = {'params': state.params}
            loss, _ = loss_fn(variables, state, x, y, state.rng, train=False)
            return {'loss': loss}

        return train_step, eval_step

In [15]:
import os

# The trainer receives the first batch, sliced to its first element, as the
# example input.
# NOTE(review): the Dropout calls in BlockLayer are commented out, so
# dropout_rate appears to have no effect on the model at present -- confirm.
example_batch = next(iter(train_loader))

trainer = TransformerTrainer(
    num_heads=16,
    d_model=64,
    d_ff=64,
    output_dim=2,
    dropout_rate=0.3,
    optimizer="adamw",
    optimizer_hparams={'lr': 4e-4},
    exmp_input=example_batch[0:1],
    logger_params={'base_log_dir': os.path.abspath("./saved_models/")},
)




[3m                           SimpleTransformer Summary                            [0m
┏━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath         [0m[1m [0m┃[1m [0m[1mmodule       [0m[1m [0m┃[1m [0m[1minputs       [0m[1m [0m┃[1m [0m[1moutputs     [0m[1m [0m┃[1m [0m[1mparams       [0m[1m [0m┃
┡━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│               │ SimpleTransf… │ deterministi… │ [2mfloat32[0m[128… │               │
│               │               │ False         │              │               │
│               │               │ x:            │              │               │
│               │               │ [2mfloat32[0m[128,… │              │               │
├───────────────┼───────────────┼───────────────┼──────────────┼───────────────┤
│ Dense_0       │ Dense         │ [2mfloat32[0m[128,… │ [2mfloat32[0m[128… │ bias:         │
│               │            

In [None]:
# Train for 50 epochs with no separate test loader, then display the
# returned metrics as the cell's output.
metrics = trainer.train_model(
    train_loader, val_loader, test_loader=None, num_epochs=50
)
metrics

Epochs:   0%|                                                                                                   | 0/50 [00:00<?, ?it/s]
Training:   0%|                                                                                                | 0/111 [00:00<?, ?it/s][A
Training:   1%|▊                                                                                       | 1/111 [00:04<08:03,  4.40s/it][A
Training:   2%|█▌                                                                                      | 2/111 [00:04<03:26,  1.90s/it][A
Training:   3%|██▍                                                                                     | 3/111 [00:04<01:57,  1.09s/it][A
Training:   4%|███▏                                                                                    | 4/111 [00:04<01:16,  1.40it/s][A
Training:   5%|███▉                                                                                    | 5/111 [00:04<00:53,  1.98it/s][A
Training:   5%|████▊          

Training:   0%|                                                                                                | 0/111 [00:00<?, ?it/s][A
Training:   1%|▊                                                                                       | 1/111 [00:00<00:16,  6.83it/s][A
Training:   2%|█▌                                                                                      | 2/111 [00:00<00:15,  6.87it/s][A
Training:   3%|██▍                                                                                     | 3/111 [00:00<00:15,  7.00it/s][A
Training:   4%|███▏                                                                                    | 4/111 [00:00<00:15,  7.07it/s][A
Training:   5%|███▉                                                                                    | 5/111 [00:00<00:14,  7.12it/s][A
Training:   5%|████▊                                                                                   | 6/111 [00:00<00:14,  7.12it/s][A
Training:   6%|█████▌      

Training:   1%|▊                                                                                       | 1/111 [00:00<00:18,  5.95it/s][A
Training:   2%|█▌                                                                                      | 2/111 [00:00<00:16,  6.46it/s][A
Training:   3%|██▍                                                                                     | 3/111 [00:00<00:16,  6.72it/s][A
Training:   4%|███▏                                                                                    | 4/111 [00:00<00:15,  6.83it/s][A
Training:   5%|███▉                                                                                    | 5/111 [00:00<00:15,  6.92it/s][A
Training:   5%|████▊                                                                                   | 6/111 [00:00<00:15,  6.96it/s][A
Training:   6%|█████▌                                                                                  | 7/111 [00:01<00:14,  7.02it/s][A
Training:   7%|██████▎     

Training:   2%|█▌                                                                                      | 2/111 [00:00<00:16,  6.79it/s][A
Training:   3%|██▍                                                                                     | 3/111 [00:00<00:15,  6.77it/s][A
Training:   4%|███▏                                                                                    | 4/111 [00:00<00:15,  6.83it/s][A
Training:   5%|███▉                                                                                    | 5/111 [00:00<00:15,  6.89it/s][A
Training:   5%|████▊                                                                                   | 6/111 [00:00<00:15,  6.95it/s][A
Training:   6%|█████▌                                                                                  | 7/111 [00:01<00:14,  6.98it/s][A
Training:   7%|██████▎                                                                                 | 8/111 [00:01<00:14,  6.94it/s][A
Training:   8%|███████▏    

Training:   3%|██▍                                                                                     | 3/111 [00:00<00:15,  6.76it/s][A
Training:   4%|███▏                                                                                    | 4/111 [00:00<00:15,  6.84it/s][A
Training:   5%|███▉                                                                                    | 5/111 [00:00<00:15,  6.87it/s][A
Training:   5%|████▊                                                                                   | 6/111 [00:00<00:15,  6.90it/s][A
Training:   6%|█████▌                                                                                  | 7/111 [00:01<00:15,  6.90it/s][A
Training:   7%|██████▎                                                                                 | 8/111 [00:01<00:14,  6.96it/s][A
Training:   8%|███████▏                                                                                | 9/111 [00:01<00:14,  6.98it/s][A
Training:   9%|███████▊    

Training:   4%|███▏                                                                                    | 4/111 [00:00<00:15,  6.99it/s][A
Training:   5%|███▉                                                                                    | 5/111 [00:00<00:15,  6.99it/s][A
Training:   5%|████▊                                                                                   | 6/111 [00:00<00:15,  6.98it/s][A
Training:   6%|█████▌                                                                                  | 7/111 [00:01<00:14,  7.03it/s][A
Training:   7%|██████▎                                                                                 | 8/111 [00:01<00:14,  7.01it/s][A
Training:   8%|███████▏                                                                                | 9/111 [00:01<00:14,  6.98it/s][A
Training:   9%|███████▊                                                                               | 10/111 [00:01<00:14,  6.96it/s][A
Training:  10%|████████▌   

Training:   5%|███▉                                                                                    | 5/111 [00:00<00:15,  6.98it/s][A
Training:   5%|████▊                                                                                   | 6/111 [00:00<00:15,  6.98it/s][A
Training:   6%|█████▌                                                                                  | 7/111 [00:01<00:14,  7.02it/s][A
Training:   7%|██████▎                                                                                 | 8/111 [00:01<00:14,  7.03it/s][A
Training:   8%|███████▏                                                                                | 9/111 [00:01<00:14,  7.03it/s][A
Training:   9%|███████▊                                                                               | 10/111 [00:01<00:14,  7.00it/s][A
Training:  10%|████████▌                                                                              | 11/111 [00:01<00:14,  6.99it/s][A
Training:  11%|█████████▍  

Training:   5%|████▊                                                                                   | 6/111 [00:00<00:15,  6.98it/s][A
Training:   6%|█████▌                                                                                  | 7/111 [00:01<00:14,  6.97it/s][A
Training:   7%|██████▎                                                                                 | 8/111 [00:01<00:14,  7.01it/s][A
Training:   8%|███████▏                                                                                | 9/111 [00:01<00:14,  7.04it/s][A
Training:   9%|███████▊                                                                               | 10/111 [00:01<00:14,  7.03it/s][A
Training:  10%|████████▌                                                                              | 11/111 [00:01<00:14,  7.00it/s][A
Training:  11%|█████████▍                                                                             | 12/111 [00:01<00:14,  6.99it/s][A
Training:  12%|██████████▏ 

Training:   6%|█████▌                                                                                  | 7/111 [00:01<00:15,  6.86it/s][A
Training:   7%|██████▎                                                                                 | 8/111 [00:01<00:14,  6.90it/s][A
Training:   8%|███████▏                                                                                | 9/111 [00:01<00:14,  6.94it/s][A
Training:   9%|███████▊                                                                               | 10/111 [00:01<00:14,  6.93it/s][A
Training:  10%|████████▌                                                                              | 11/111 [00:01<00:14,  6.92it/s][A
Training:  11%|█████████▍                                                                             | 12/111 [00:01<00:14,  6.91it/s][A
Training:  12%|██████████▏                                                                            | 13/111 [00:01<00:14,  6.96it/s][A
Training:  13%|██████████▉ 

Training:   7%|██████▎                                                                                 | 8/111 [00:01<00:14,  7.02it/s][A
Training:   8%|███████▏                                                                                | 9/111 [00:01<00:14,  7.02it/s][A
Training:   9%|███████▊                                                                               | 10/111 [00:01<00:14,  7.00it/s][A
Training:  10%|████████▌                                                                              | 11/111 [00:01<00:14,  7.00it/s][A
Training:  11%|█████████▍                                                                             | 12/111 [00:01<00:14,  6.93it/s][A
Training:  12%|██████████▏                                                                            | 13/111 [00:01<00:14,  6.89it/s][A
Training:  13%|██████████▉                                                                            | 14/111 [00:02<00:14,  6.82it/s][A
Training:  14%|███████████▊

Training:   8%|███████▏                                                                                | 9/111 [00:01<00:14,  7.04it/s][A
Training:   9%|███████▊                                                                               | 10/111 [00:01<00:14,  7.01it/s][A
Training:  10%|████████▌                                                                              | 11/111 [00:01<00:14,  7.04it/s][A
Training:  11%|█████████▍                                                                             | 12/111 [00:01<00:14,  7.07it/s][A
Training:  12%|██████████▏                                                                            | 13/111 [00:01<00:13,  7.08it/s][A
Training:  13%|██████████▉                                                                            | 14/111 [00:01<00:13,  7.09it/s][A
Training:  14%|███████████▊                                                                           | 15/111 [00:02<00:13,  7.08it/s][A
Training:  14%|████████████

Training:   9%|███████▊                                                                               | 10/111 [00:01<00:14,  6.85it/s][A
Training:  10%|████████▌                                                                              | 11/111 [00:01<00:14,  6.74it/s][A
Training:  11%|█████████▍                                                                             | 12/111 [00:01<00:14,  6.73it/s][A
Training:  12%|██████████▏                                                                            | 13/111 [00:01<00:14,  6.75it/s][A
Training:  13%|██████████▉                                                                            | 14/111 [00:02<00:14,  6.76it/s][A
Training:  14%|███████████▊                                                                           | 15/111 [00:02<00:14,  6.79it/s][A
Training:  14%|████████████▌                                                                          | 16/111 [00:02<00:13,  6.81it/s][A
Training:  15%|████████████

Training:  10%|████████▌                                                                              | 11/111 [00:01<00:14,  6.93it/s][A
Training:  11%|█████████▍                                                                             | 12/111 [00:01<00:14,  6.96it/s][A
Training:  12%|██████████▏                                                                            | 13/111 [00:01<00:14,  6.96it/s][A
Training:  13%|██████████▉                                                                            | 14/111 [00:02<00:13,  6.97it/s][A
Training:  14%|███████████▊                                                                           | 15/111 [00:02<00:13,  6.99it/s][A
Training:  14%|████████████▌                                                                          | 16/111 [00:02<00:13,  6.99it/s][A
Training:  15%|█████████████▎                                                                         | 17/111 [00:02<00:13,  6.97it/s][A
Training:  16%|████████████

Training:  11%|█████████▍                                                                             | 12/111 [00:01<00:14,  6.88it/s][A
Training:  12%|██████████▏                                                                            | 13/111 [00:01<00:14,  6.92it/s][A
Training:  13%|██████████▉                                                                            | 14/111 [00:02<00:14,  6.91it/s][A
Training:  14%|███████████▊                                                                           | 15/111 [00:02<00:13,  6.88it/s][A
Training:  14%|████████████▌                                                                          | 16/111 [00:02<00:13,  6.92it/s][A
Training:  15%|█████████████▎                                                                         | 17/111 [00:02<00:13,  6.98it/s][A
Training:  16%|██████████████                                                                         | 18/111 [00:02<00:13,  6.99it/s][A
Training:  17%|████████████

Training:  12%|██████████▏                                                                            | 13/111 [00:01<00:14,  6.84it/s][A
Training:  13%|██████████▉                                                                            | 14/111 [00:02<00:14,  6.84it/s][A
Training:  14%|███████████▊                                                                           | 15/111 [00:02<00:13,  6.88it/s][A
Training:  14%|████████████▌                                                                          | 16/111 [00:02<00:13,  6.91it/s][A
Training:  15%|█████████████▎                                                                         | 17/111 [00:02<00:13,  6.93it/s][A
Training:  16%|██████████████                                                                         | 18/111 [00:02<00:13,  6.97it/s][A
Training:  17%|██████████████▉                                                                        | 19/111 [00:02<00:13,  6.99it/s][A
Training:  18%|████████████

Training:  13%|██████████▉                                                                            | 14/111 [00:02<00:13,  7.06it/s][A
Training:  14%|███████████▊                                                                           | 15/111 [00:02<00:13,  7.07it/s][A
Training:  14%|████████████▌                                                                          | 16/111 [00:02<00:13,  7.08it/s][A
Training:  15%|█████████████▎                                                                         | 17/111 [00:02<00:13,  7.06it/s][A
Training:  16%|██████████████                                                                         | 18/111 [00:02<00:13,  7.04it/s][A
Training:  17%|██████████████▉                                                                        | 19/111 [00:02<00:13,  7.00it/s][A
Training:  18%|███████████████▋                                                                       | 20/111 [00:02<00:12,  7.00it/s][A
Training:  19%|████████████

Training:  14%|███████████▊                                                                           | 15/111 [00:02<00:13,  7.08it/s][A
Training:  14%|████████████▌                                                                          | 16/111 [00:02<00:13,  7.08it/s][A
Training:  15%|█████████████▎                                                                         | 17/111 [00:02<00:13,  7.05it/s][A
Training:  16%|██████████████                                                                         | 18/111 [00:02<00:13,  7.04it/s][A
Training:  17%|██████████████▉                                                                        | 19/111 [00:02<00:13,  7.03it/s][A
Training:  18%|███████████████▋                                                                       | 20/111 [00:02<00:12,  7.04it/s][A
Training:  19%|████████████████▍                                                                      | 21/111 [00:02<00:12,  7.06it/s][A
Training:  20%|████████████

Training:  14%|████████████▌                                                                          | 16/111 [00:02<00:13,  6.97it/s][A
Training:  15%|█████████████▎                                                                         | 17/111 [00:02<00:13,  6.99it/s][A
Training:  16%|██████████████                                                                         | 18/111 [00:02<00:13,  6.98it/s][A
Training:  17%|██████████████▉                                                                        | 19/111 [00:02<00:13,  7.00it/s][A
Training:  18%|███████████████▋                                                                       | 20/111 [00:02<00:13,  6.94it/s][A
Training:  19%|████████████████▍                                                                      | 21/111 [00:03<00:12,  6.95it/s][A
Training:  20%|█████████████████▏                                                                     | 22/111 [00:03<00:12,  6.97it/s][A
Training:  21%|████████████

Training:  15%|█████████████▎                                                                         | 17/111 [00:02<00:13,  7.00it/s][A
Training:  16%|██████████████                                                                         | 18/111 [00:02<00:13,  7.06it/s][A
Training:  17%|██████████████▉                                                                        | 19/111 [00:02<00:13,  7.05it/s][A
Training:  18%|███████████████▋                                                                       | 20/111 [00:02<00:12,  7.08it/s][A
Training:  19%|████████████████▍                                                                      | 21/111 [00:03<00:12,  7.05it/s][A
Training:  20%|█████████████████▏                                                                     | 22/111 [00:03<00:12,  7.06it/s][A
Training:  21%|██████████████████                                                                     | 23/111 [00:03<00:12,  7.05it/s][A
Training:  22%|████████████

Training:  16%|██████████████                                                                         | 18/111 [00:02<00:13,  6.82it/s][A
Training:  17%|██████████████▉                                                                        | 19/111 [00:02<00:13,  6.87it/s][A
Training:  18%|███████████████▋                                                                       | 20/111 [00:02<00:13,  6.92it/s][A
Training:  19%|████████████████▍                                                                      | 21/111 [00:03<00:12,  6.94it/s][A
Training:  20%|█████████████████▏                                                                     | 22/111 [00:03<00:12,  6.99it/s][A
Training:  21%|██████████████████                                                                     | 23/111 [00:03<00:12,  7.00it/s][A
Training:  22%|██████████████████▊                                                                    | 24/111 [00:03<00:12,  7.05it/s][A
Training:  23%|████████████

Training:  17%|██████████████▉                                                                        | 19/111 [00:02<00:14,  6.43it/s][A
Training:  18%|███████████████▋                                                                       | 20/111 [00:02<00:14,  6.48it/s][A
Training:  19%|████████████████▍                                                                      | 21/111 [00:03<00:13,  6.56it/s][A
Training:  20%|█████████████████▏                                                                     | 22/111 [00:03<00:13,  6.67it/s][A
Training:  21%|██████████████████                                                                     | 23/111 [00:03<00:13,  6.72it/s][A
Training:  22%|██████████████████▊                                                                    | 24/111 [00:03<00:12,  6.80it/s][A
Training:  23%|███████████████████▌                                                                   | 25/111 [00:03<00:12,  6.86it/s][A
Training:  23%|████████████

Training:  18%|███████████████▋                                                                       | 20/111 [00:02<00:13,  6.96it/s][A
Training:  19%|████████████████▍                                                                      | 21/111 [00:03<00:12,  6.97it/s][A
Training:  20%|█████████████████▏                                                                     | 22/111 [00:03<00:12,  6.98it/s][A
Training:  21%|██████████████████                                                                     | 23/111 [00:03<00:12,  7.01it/s][A
Training:  22%|██████████████████▊                                                                    | 24/111 [00:03<00:12,  7.01it/s][A
Training:  23%|███████████████████▌                                                                   | 25/111 [00:03<00:12,  6.99it/s][A
Training:  23%|████████████████████▍                                                                  | 26/111 [00:03<00:12,  6.98it/s][A
Training:  24%|████████████

Training:  19%|████████████████▍                                                                      | 21/111 [00:02<00:12,  7.09it/s][A
Training:  20%|█████████████████▏                                                                     | 22/111 [00:03<00:12,  7.12it/s][A
Training:  21%|██████████████████                                                                     | 23/111 [00:03<00:12,  7.13it/s][A
Training:  22%|██████████████████▊                                                                    | 24/111 [00:03<00:12,  7.14it/s][A
Training:  23%|███████████████████▌                                                                   | 25/111 [00:03<00:12,  7.14it/s][A
Training:  23%|████████████████████▍                                                                  | 26/111 [00:03<00:11,  7.12it/s][A
Training:  24%|█████████████████████▏                                                                 | 27/111 [00:03<00:11,  7.09it/s][A
Training:  25%|████████████

Training:  20%|█████████████████▏                                                                     | 22/111 [00:03<00:12,  6.95it/s][A
Training:  21%|██████████████████                                                                     | 23/111 [00:03<00:12,  6.96it/s][A
Training:  22%|██████████████████▊                                                                    | 24/111 [00:03<00:12,  6.96it/s][A
Training:  23%|███████████████████▌                                                                   | 25/111 [00:03<00:12,  6.99it/s][A
Training:  23%|████████████████████▍                                                                  | 26/111 [00:03<00:12,  6.99it/s][A
Training:  24%|█████████████████████▏                                                                 | 27/111 [00:03<00:12,  7.00it/s][A
Training:  25%|█████████████████████▉                                                                 | 28/111 [00:04<00:11,  7.00it/s][A
Training:  26%|████████████

Training:  21%|██████████████████                                                                     | 23/111 [00:03<00:12,  7.00it/s][A
Training:  22%|██████████████████▊                                                                    | 24/111 [00:03<00:12,  6.99it/s][A
Training:  23%|███████████████████▌                                                                   | 25/111 [00:03<00:12,  7.00it/s][A
Training:  23%|████████████████████▍                                                                  | 26/111 [00:03<00:12,  7.00it/s][A
Training:  24%|█████████████████████▏                                                                 | 27/111 [00:03<00:11,  7.02it/s][A
Training:  25%|█████████████████████▉                                                                 | 28/111 [00:04<00:11,  7.00it/s][A
Training:  26%|██████████████████████▋                                                                | 29/111 [00:04<00:11,  6.99it/s][A
Training:  27%|████████████

Training:  22%|██████████████████▊                                                                    | 24/111 [00:03<00:12,  6.97it/s][A
Training:  23%|███████████████████▌                                                                   | 25/111 [00:03<00:12,  6.98it/s][A
Training:  23%|████████████████████▍                                                                  | 26/111 [00:03<00:12,  7.01it/s][A
Training:  24%|█████████████████████▏                                                                 | 27/111 [00:03<00:12,  6.98it/s][A
Training:  25%|█████████████████████▉                                                                 | 28/111 [00:04<00:11,  7.00it/s][A
Training:  26%|██████████████████████▋                                                                | 29/111 [00:04<00:11,  7.01it/s][A
Training:  27%|███████████████████████▌                                                               | 30/111 [00:04<00:11,  7.02it/s][A
Training:  28%|████████████

Training:  23%|███████████████████▌                                                                   | 25/111 [00:03<00:12,  6.93it/s][A
Training:  23%|████████████████████▍                                                                  | 26/111 [00:03<00:12,  6.91it/s][A
Training:  24%|█████████████████████▏                                                                 | 27/111 [00:03<00:12,  6.94it/s][A
Training:  25%|█████████████████████▉                                                                 | 28/111 [00:04<00:11,  6.93it/s][A
Training:  26%|██████████████████████▋                                                                | 29/111 [00:04<00:11,  6.94it/s][A
Training:  27%|███████████████████████▌                                                               | 30/111 [00:04<00:11,  6.92it/s][A
Training:  28%|████████████████████████▎                                                              | 31/111 [00:04<00:11,  6.97it/s][A
Training:  29%|████████████

Training:  23%|████████████████████▍                                                                  | 26/111 [00:03<00:12,  7.03it/s][A
Training:  24%|█████████████████████▏                                                                 | 27/111 [00:03<00:11,  7.02it/s][A
Training:  25%|█████████████████████▉                                                                 | 28/111 [00:04<00:11,  7.04it/s][A
Training:  26%|██████████████████████▋                                                                | 29/111 [00:04<00:11,  6.99it/s][A
Training:  27%|███████████████████████▌                                                               | 30/111 [00:04<00:11,  6.98it/s][A
Training:  28%|████████████████████████▎                                                              | 31/111 [00:04<00:11,  6.97it/s][A
Training:  29%|█████████████████████████                                                              | 32/111 [00:04<00:11,  7.00it/s][A
Training:  30%|████████████

Training:  24%|█████████████████████▏                                                                 | 27/111 [00:03<00:11,  7.05it/s][A
Training:  25%|█████████████████████▉                                                                 | 28/111 [00:04<00:11,  7.05it/s][A
Training:  26%|██████████████████████▋                                                                | 29/111 [00:04<00:11,  7.07it/s][A
Training:  27%|███████████████████████▌                                                               | 30/111 [00:04<00:11,  7.06it/s][A
Training:  28%|████████████████████████▎                                                              | 31/111 [00:04<00:11,  7.06it/s][A
Training:  29%|█████████████████████████                                                              | 32/111 [00:04<00:11,  7.00it/s][A
Training:  30%|█████████████████████████▊                                                             | 33/111 [00:04<00:11,  7.01it/s][A
Training:  31%|████████████

Training:  25%|█████████████████████▉                                                                 | 28/111 [00:03<00:11,  7.04it/s][A
Training:  26%|██████████████████████▋                                                                | 29/111 [00:04<00:11,  7.01it/s][A
Training:  27%|███████████████████████▌                                                               | 30/111 [00:04<00:11,  7.00it/s][A
Training:  28%|████████████████████████▎                                                              | 31/111 [00:04<00:11,  7.01it/s][A
Training:  29%|█████████████████████████                                                              | 32/111 [00:04<00:11,  7.01it/s][A
Training:  30%|█████████████████████████▊                                                             | 33/111 [00:04<00:11,  7.03it/s][A
Training:  31%|██████████████████████████▋                                                            | 34/111 [00:04<00:10,  7.01it/s][A
Training:  32%|████████████

Training:  26%|██████████████████████▋                                                                | 29/111 [00:04<00:11,  6.84it/s][A
Training:  27%|███████████████████████▌                                                               | 30/111 [00:04<00:11,  6.85it/s][A
Training:  28%|████████████████████████▎                                                              | 31/111 [00:04<00:11,  6.86it/s][A
Training:  29%|█████████████████████████                                                              | 32/111 [00:04<00:11,  6.87it/s][A
Training:  30%|█████████████████████████▊                                                             | 33/111 [00:04<00:11,  6.81it/s][A
Training:  31%|██████████████████████████▋                                                            | 34/111 [00:04<00:11,  6.74it/s][A
Training:  32%|███████████████████████████▍                                                           | 35/111 [00:05<00:11,  6.71it/s][A
Training:  32%|████████████

Training:  27%|███████████████████████▌                                                               | 30/111 [00:04<00:11,  7.19it/s][A
Training:  28%|████████████████████████▎                                                              | 31/111 [00:04<00:11,  7.23it/s][A
Training:  29%|█████████████████████████                                                              | 32/111 [00:04<00:11,  7.17it/s][A
Training:  30%|█████████████████████████▊                                                             | 33/111 [00:04<00:10,  7.15it/s][A
Training:  31%|██████████████████████████▋                                                            | 34/111 [00:04<00:10,  7.15it/s][A
Training:  32%|███████████████████████████▍                                                           | 35/111 [00:05<00:10,  7.18it/s][A
Training:  32%|████████████████████████████▏                                                          | 36/111 [00:05<00:10,  7.20it/s][A
Training:  33%|████████████

Training:  28%|████████████████████████▎                                                              | 31/111 [00:04<00:11,  7.01it/s][A
Training:  29%|█████████████████████████                                                              | 32/111 [00:04<00:11,  7.00it/s][A
Training:  30%|█████████████████████████▊                                                             | 33/111 [00:04<00:11,  7.01it/s][A
Training:  31%|██████████████████████████▋                                                            | 34/111 [00:04<00:10,  7.04it/s][A
Training:  32%|███████████████████████████▍                                                           | 35/111 [00:04<00:10,  7.04it/s][A
Training:  32%|████████████████████████████▏                                                          | 36/111 [00:05<00:10,  7.03it/s][A
Training:  33%|█████████████████████████████                                                          | 37/111 [00:05<00:10,  7.03it/s][A
Training:  34%|████████████

Training:  29%|█████████████████████████                                                              | 32/111 [00:04<00:11,  6.93it/s][A
Training:  30%|█████████████████████████▊                                                             | 33/111 [00:04<00:11,  6.94it/s][A
Training:  31%|██████████████████████████▋                                                            | 34/111 [00:04<00:11,  6.88it/s][A
Training:  32%|███████████████████████████▍                                                           | 35/111 [00:05<00:11,  6.90it/s][A
Training:  32%|████████████████████████████▏                                                          | 36/111 [00:05<00:10,  6.92it/s][A
Training:  33%|█████████████████████████████                                                          | 37/111 [00:05<00:10,  6.93it/s][A
Training:  34%|█████████████████████████████▊                                                         | 38/111 [00:05<00:10,  6.90it/s][A
Training:  35%|████████████

Training:  30%|█████████████████████████▊                                                             | 33/111 [00:04<00:11,  6.80it/s][A
Training:  31%|██████████████████████████▋                                                            | 34/111 [00:04<00:11,  6.87it/s][A
Training:  32%|███████████████████████████▍                                                           | 35/111 [00:05<00:11,  6.77it/s][A
Training:  32%|████████████████████████████▏                                                          | 36/111 [00:05<00:11,  6.74it/s][A
Training:  33%|█████████████████████████████                                                          | 37/111 [00:05<00:10,  6.84it/s][A
Training:  34%|█████████████████████████████▊                                                         | 38/111 [00:05<00:10,  6.88it/s][A
Training:  35%|██████████████████████████████▌                                                        | 39/111 [00:05<00:10,  6.96it/s][A
Training:  36%|████████████

Training:  31%|██████████████████████████▋                                                            | 34/111 [00:04<00:11,  6.98it/s][A
Training:  32%|███████████████████████████▍                                                           | 35/111 [00:04<00:10,  6.98it/s][A
Training:  32%|████████████████████████████▏                                                          | 36/111 [00:05<00:10,  7.01it/s][A
Training:  33%|█████████████████████████████                                                          | 37/111 [00:05<00:10,  7.00it/s][A
Training:  34%|█████████████████████████████▊                                                         | 38/111 [00:05<00:10,  6.99it/s][A
Training:  35%|██████████████████████████████▌                                                        | 39/111 [00:05<00:10,  7.00it/s][A
Training:  36%|███████████████████████████████▎                                                       | 40/111 [00:05<00:10,  7.02it/s][A
Training:  37%|████████████

Training:  32%|███████████████████████████▍                                                           | 35/111 [00:04<00:10,  7.05it/s][A
Training:  32%|████████████████████████████▏                                                          | 36/111 [00:05<00:10,  7.03it/s][A
Training:  33%|█████████████████████████████                                                          | 37/111 [00:05<00:10,  7.04it/s][A
Training:  34%|█████████████████████████████▊                                                         | 38/111 [00:05<00:10,  7.00it/s][A
Training:  35%|██████████████████████████████▌                                                        | 39/111 [00:05<00:10,  7.02it/s][A
Training:  36%|███████████████████████████████▎                                                       | 40/111 [00:05<00:10,  7.05it/s][A
Training:  37%|████████████████████████████████▏                                                      | 41/111 [00:05<00:09,  7.05it/s][A
Training:  38%|████████████

In [None]:
# Obtain a directly callable model from the trainer; `predict` calls it as
# `model_bd(X, deterministic=True)`.
# NOTE(review): presumably bind_model() attaches the trained parameters to the
# model — confirm against ml_utils.TrainerModule.
model_bd = trainer.bind_model()

In [None]:
def wrap_y(y):
    """Label the columns of a 2-D array of targets.

    Column ``t`` of ``y`` becomes the DataFrame column named ``TARGETS[t]``;
    columns of ``y`` beyond ``len(TARGETS)`` are ignored.
    """
    columns = {}
    for col_idx, target_name in enumerate(TARGETS):
        columns[target_name] = y[:, col_idx]
    return pd.DataFrame(columns)

def predict(X):
    """Run the bound model on ``X`` in deterministic mode.

    Returns the model output wrapped by ``wrap_y``, i.e. a DataFrame with one
    column per entry of ``TARGETS``.
    """
    return wrap_y(model_bd(X, deterministic=True))

# TODO: Remake plots so that we can look at each variable

In [None]:
# Predict on the full training feature matrix; as the cell's last expression,
# the resulting DataFrame (one column per target) is displayed.
predict(X_train.values)

In [None]:
# Locations that each get their own data-vs-prediction panel.
locations_subset = [
    "California",
    "Michigan",
    "Nevada",
    "New York",
    "Texas",
    "Washington",
]
# Locations that additionally get selective-pressure and prevalence panels.
target_locations = [
    "England",
]

In [None]:
fig = plt.figure(figsize=(20, 10))

# Grid layout: location panels fill N_COLS columns row by row, followed by one
# three-panel row (selective pressure, prevalence, predictions) per target location.
N_COLS = 3
# Ceiling division, so a partially filled last row of location panels never
# shares grid cells with the target rows. (`len // 3 + 1` was only correct when
# the number of locations was a multiple of 3.)
n_location_rows = (len(locations_subset) + N_COLS - 1) // N_COLS
spec = fig.add_gridspec(ncols=N_COLS, nrows=max(1, n_location_rows + len(target_locations)))

TRAIN_PERIOD_COLOR, TEST_PERIOD_COLOR = "#d1c7c9", "#949fa5"
PREDICTION_COLOR = "#006633"
DATA_COLOR = "grey"

# Position of the target to plot; indexes both the prediction columns and target_labels.
plot_target = 0
target_labels = ["Empirical Growth Rate", "Smoothed Cases"]
axes = []

# Per-location slices of the full data frame, keyed by location name.
data = {loc: df for loc, df in selective_pressure_df.groupby("location")}


def _plot_location_fit(ax, loc):
    """Plot observed data and model predictions for one location on ``ax``.

    Draws the train and test segments of the data and of the predictions,
    marks the end of the training period with a dashed line, shades the
    training and test periods, and sets the title, y-label and date format.
    Reads the train/test splits, ``dates_vec``, ``locations_vec``, ``predict``
    and the plotting constants from the notebook namespace.
    """
    loc_idx = locations_vec.values == loc
    loc_idx_train = loc_idx[X_train.index]
    loc_idx_test = loc_idx[X_test.index]
    train_dates = dates_vec[X_train.index][loc_idx_train]
    test_dates = dates_vec[X_test.index][loc_idx_test]

    # Observed data
    # NOTE(review): if y_train / y_test hold several target columns, all of them
    # are drawn here — confirm whether only `plot_target` should be shown.
    ax.plot(train_dates, y_train[loc_idx_train], color=DATA_COLOR, label="Data")
    ax.plot(test_dates, y_test[loc_idx_test], color=DATA_COLOR)

    # Model predictions. predict() labels its columns with TARGETS, so the
    # target is selected by position rather than by the label `plot_target`.
    pred_train = predict(X_train[loc_idx_train].values).iloc[:, plot_target]
    pred_test = predict(X_test[loc_idx_test].values).iloc[:, plot_target]
    ax.plot(train_dates, pred_train, color=PREDICTION_COLOR, label="Predicted")
    ax.plot(test_dates, pred_test, color=PREDICTION_COLOR)

    # Highlight training and test periods; skipped when the location has no
    # rows in that split (min/max are then null).
    min_train_time, max_train_time = train_dates.min(), train_dates.max()
    if not pd.isnull(max_train_time):
        ax.axvline(max_train_time, color="k", linestyle="--")
        ax.axvspan(min_train_time, max_train_time, color=TRAIN_PERIOD_COLOR, alpha=0.3)

    min_test_time, max_test_time = test_dates.min(), test_dates.max()
    if not pd.isnull(max_test_time):
        ax.axvspan(min_test_time, max_test_time, color=TEST_PERIOD_COLOR, alpha=0.3)

    ax.set_ylabel(target_labels[plot_target])
    ax.set_title(loc)
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b\n%Y'))


# One panel per location; each panel shares its x-axis with the previous one.
for l, loc in enumerate(locations_subset):
    ax = fig.add_subplot(spec[l], sharex=None if l == 0 else ax)
    _plot_location_fit(ax, loc)
    if l == 0:
        ax.legend()
    axes.append(ax)

# Index of the last location panel; not used below, kept in case a later cell reads it.
last_l = len(locations_subset) - 1

# One row of three panels per target location, placed after the location rows.
for l, loc in enumerate(target_locations):
    row = n_location_rows + l

    # Selective pressure
    ax = fig.add_subplot(spec[row, 0])
    ax.plot(data[loc]["date"], data[loc]["selective_pressure"], color="lightblue")
    ax.set_ylabel("Selective Pressure")
    ax.set_title(loc)
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b\n%Y'))
    axes.append(ax)

    # Prevalence
    ax = fig.add_subplot(spec[row, 1], sharex=ax)
    ax.plot(data[loc]["date"], data[loc]["smooth_cases"], color="grey")
    ax.set_title(loc)
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b\n%Y'))
    ax.set_ylabel("ONS Prevalence")
    axes.append(ax)

    # Data and predictions
    ax = fig.add_subplot(spec[row, 2], sharex=ax)
    _plot_location_fit(ax, loc)
    axes.append(ax)

# Letter the panels A., B., C., ... in the order they were created.
ax_labels = string.ascii_uppercase
for ax, ax_label in zip(axes, ax_labels):
    ax.text(-0.1, 1.05, ax_label + ".", transform=ax.transAxes, size=36, weight='bold')
fig.tight_layout()