In [1]:
import pandas as pd
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training import train_state
from functools import partial
# from relative_fitness_mechanisms.selective_pressure_prediction import (create_lagged_features, 
#                                                                        process_inputs_all, 
#                                                                        withhold_test_locations_and_split, 
#                                                                        SelectivePressureData
#                                                                        create_training_batches,
#                                                                       train_step, 
#                                                                       loss_fn)

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import string

In [2]:
## Prepping data

import sys
sys.path.append( '../relative_fitness_mechanisms/')
import plot_utils
from selective_pressure_prediction import (create_lagged_features, 
                                           process_inputs_all, 
                                            withhold_test_locations_and_split, 
                                            SelectivePressureData,
                                            create_training_batches,
                                            train_step, 
                                            loss_fn)

from ml_utils import TrainerModule

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Tab-separated input table. The `_full` variant of this file lives at
# "../data/selective_pressure_growth_cases_full.tsv" and can be swapped in here.
SELECTIVE_PRESSURE_PATH = "../data/selective_pressure_growth_cases.tsv"

# Read the table and convert the `date` column to pandas datetimes in one chain.
selective_pressure_df = pd.read_csv(SELECTIVE_PRESSURE_PATH, sep="\t").assign(
    date=lambda frame: pd.to_datetime(frame["date"])
)

In [4]:
# Drop every row that has at least one missing value (pandas `dropna` defaults).
# The name is rebound on purpose: later cells read `selective_pressure_df`.
selective_pressure_df = selective_pressure_df.dropna()

In [5]:
# Name of the column the model is trained to predict.
TARGET = "empirical_growth_rate"
keep_features = ["date", "location", "selective_pressure"]
keep_targets = [TARGET]

# Columns handed to the lagging helper: the kept features followed by the target.
selected_columns = keep_features + keep_targets

# Build one lagged-feature frame per location, keyed by location name.
# NOTE(review): the third argument (28) is passed positionally to
# `create_lagged_features`; presumably it is the number of lags — confirm against the helper.
input_dfs = {}
for loc, group in selective_pressure_df.groupby("location"):
    location_frame = group[selected_columns]
    input_dfs[loc] = create_lagged_features(location_frame, ["selective_pressure"], 28)

In [6]:
# Turn the per-location frames into dates, locations, features and targets.
# NOTE(review): later cells index `dates_vec` / `locations_vec` with `X_train.index`,
# so they are presumably row-aligned with `X` and `y` — confirm against `process_inputs_all`.
dates_vec, locations_vec, X, y = process_inputs_all(input_dfs, target=TARGET)

In [7]:
# Locations handed to the split helper as the withheld set.
WITHHELD_LOCATIONS = ["England"]
# NOTE(review): how the remaining locations are divided between train and test is
# decided inside `withhold_test_locations_and_split` — confirm there.
X_train, y_train, X_test, y_test = withhold_test_locations_and_split(X, y, locations_vec, WITHHELD_LOCATIONS)

In [8]:
from ml_utils import create_data_loaders

# Wrap the raw arrays of each split in the project dataset type.
train_set = SelectivePressureData(X_train.values, y_train.values)
val_set = SelectivePressureData(X_test.values, y_test.values)

# One loader per dataset, in the same order as the datasets are passed.
# NOTE(review): `train=[True, False]` presumably marks which loader gets
# training behaviour (e.g. shuffling) — confirm in `create_data_loaders`.
train_loader, val_loader = create_data_loaders(
    train_set,
    val_set,
    train=[True, False],
    batch_size=64,
    num_workers=1,
)

In [9]:
# Defining model
from flax import linen as nn

class PositionalEncoding(nn.Module):
    """Adds a fixed (non-learned) sinusoidal positional encoding to the input.

    Attributes:
      max_len: Number of positions for which encodings are generated.
      d_model: Width of each positional-encoding vector.
    """
    max_len: int
    d_model: int

    @nn.compact
    def __call__(self, x):
        """Return `x` plus the positional encoding.

        Args:
          x: Input array; axis 1 is treated as the position axis.

        Returns:
          `x + pe`, where `pe` has shape
          `(1, min(max_len, x.shape[1]), d_model)` and is combined with `x`
          through ordinary broadcasting.
        """
        pe = self.positional_encoding(self.max_len, self.d_model)
        pe = pe[:x.shape[1], :]  # Keep only as many positions as the input has
        pe = jnp.expand_dims(pe, axis=0)  # Add a leading batch dimension
        x = x + pe  # The encoding is broadcast over the batch
        return x

    def positional_encoding(self, max_len, d_model):
        """Build the `(max_len, d_model)` sine/cosine encoding table.

        Even columns hold sines and odd columns hold cosines of
        `position * div_term`.

        Args:
          max_len: Number of rows (positions) in the table.
          d_model: Number of columns in the table; may be odd.

        Returns:
          A NumPy array of shape `(max_len, d_model)`.
        """
        position = np.arange(max_len)[:, np.newaxis]
        div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
        angles = position * div_term
        pe = np.zeros((max_len, d_model))
        pe[:, 0::2] = np.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the angles; for even d_model this slice is a no-op.
        pe[:, 1::2] = np.cos(angles[:, : d_model // 2])
        return pe
    
class FeedForward(nn.Module):
    """Two dense layers with a ReLU in between: Dense(d_ff) -> ReLU -> Dense(d_model).

    Attributes:
      d_model: Output width of the second dense layer.
      d_ff: Output width of the first (hidden) dense layer.
    """
    d_model: int
    d_ff: int

    @nn.compact
    def __call__(self, x):
        """Apply the two-layer network to `x` and return the result."""
        hidden = nn.Dense(self.d_ff)(x)
        activated = nn.relu(hidden)
        return nn.Dense(self.d_model)(activated)
    
class BlockLayer(nn.Module):
    """Self-attention followed by a feed-forward network.

    Each sub-layer is wrapped as `LayerNorm(input + sublayer(input))`.

    Attributes:
      d_model: Passed as `qkv_features` to the attention layer and as the
        output width of the feed-forward network.
      num_heads: Number of attention heads.
      d_ff: Hidden width of the feed-forward network.
    """
    d_model: int
    num_heads: int
    d_ff: int

    @nn.compact
    def __call__(self, x):
        """Run the attention and feed-forward sub-layers on `x`."""
        # Attention sub-layer: only `x` is passed to the attention module.
        self_attention = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads, qkv_features=self.d_model, use_bias=True
        )
        attended = self_attention(x)
        normed = nn.LayerNorm()(x + attended)

        # Feed-forward sub-layer with its own residual connection and norm.
        transformed = FeedForward(d_model=self.d_model, d_ff=self.d_ff)(normed)
        return nn.LayerNorm()(normed + transformed)
    
class SimpleTransformer(nn.Module):
    """Regression model: dense embedding, three `BlockLayer`s, then a dense head.

    Attributes:
      num_heads: Number of attention heads in every block.
      d_model: Size of the attention representations.
      d_ff: Size of the feed-forward / head hidden layers.
      output_dim: Dimensionality of the regression output.
    """
    num_heads: int
    d_model: int  # Size of the attention representations
    d_ff: int  # Size of feedforward
    output_dim: int  # Dimensionality of the regression output

    @nn.compact
    def __call__(self, x, **kwargs):
        """Map a batch of inputs to one prediction per batch row.

        Args:
          x: Input array. NOTE(review): presumably `(batch, n_features)` — confirm
            against the data loader.
          **kwargs: Accepted for call compatibility and not used by this method.

        Returns:
          The head output with its last axis squeezed away and then averaged
          over the new last axis.
        """
        # Embed the raw features, then add a trailing axis of size one.
        embedded = nn.relu(nn.Dense(self.d_model)(x))
        tokens = jnp.expand_dims(embedded, axis=-1)
        tokens = PositionalEncoding(max_len=tokens.shape[1], d_model=self.d_model)(tokens)

        # Three identical attention blocks applied one after another.
        for _ in range(3):
            tokens = BlockLayer(
                d_model=self.d_model, num_heads=self.num_heads, d_ff=self.d_ff
            )(tokens)

        # Dense head: three Dense(d_ff) + ReLU layers, then the output projection.
        hidden = tokens
        for _ in range(3):
            hidden = nn.relu(nn.Dense(self.d_ff)(hidden))
        projected = nn.Dense(self.output_dim)(hidden)

        # `squeeze(axis=-1)` requires the projected last axis to have size one.
        squeezed = jnp.squeeze(projected, axis=-1)
        return jnp.mean(squeezed, axis=-1)
    
# Loss used for both training and evaluation: the project `loss_fn` with its
# `alpha` keyword fixed.
# NOTE(review): the meaning of `alpha` is defined inside `loss_fn` — confirm there.
LOSS_ALPHA = 1e-4
combined_loss_fn = partial(loss_fn, alpha=LOSS_ALPHA)

In [10]:
class TransformerTrainer(TrainerModule):
    """`TrainerModule` specialised to train a `SimpleTransformer`."""

    def __init__(self, num_heads, d_model, d_ff, output_dim, **kwargs):
        """Forward the model hyperparameters to `TrainerModule`.

        Args:
          num_heads: Forwarded to `SimpleTransformer.num_heads`.
          d_model: Forwarded to `SimpleTransformer.d_model`.
          d_ff: Forwarded to `SimpleTransformer.d_ff`.
          output_dim: Forwarded to `SimpleTransformer.output_dim`.
          **kwargs: Passed through unchanged to `TrainerModule.__init__`.
        """
        model_hparams = {
            'num_heads': num_heads,
            'd_model': d_model,
            'd_ff': d_ff,
            'output_dim': output_dim,
        }
        super().__init__(
            model_class=SimpleTransformer,
            model_hparams=model_hparams,
            **kwargs)

    def create_functions(self):
        """Build the per-batch training and evaluation functions.

        Returns:
          A `(train_step, eval_step)` tuple of functions.
        """

        def train_step(state, batch):
            """
            Perform a single training step.

            Args:
              state: The current training state.
              batch: A batch of training data, unpacked as `(inputs, targets)`.

            Returns:
              Updated training state and training metrics.
            """
            inputs, targets = batch

            def batch_loss(params):
                # Loss as a function of the parameters only, so it can be differentiated.
                return combined_loss_fn({'params': params}, state, inputs, targets)

            loss, grads = jax.value_and_grad(batch_loss)(state.params)
            updated_state = state.apply_gradients(grads=grads)
            return updated_state, {'loss': loss}

        def eval_step(state, batch):
            """
            Perform a single evaluation step.

            Args:
              state: The current training state.
              batch: A batch of evaluation data, unpacked as `(inputs, targets)`.

            Returns:
              Evaluation metrics.
            """
            inputs, targets = batch
            return {
                'loss': combined_loss_fn({'params': state.params}, state, inputs, targets),
            }

        return train_step, eval_step

In [11]:
import os

# Pull one batch from the training loader and keep a length-one slice of it
# (`[0:1]`) as the example input for the trainer.
example_batch = next(iter(train_loader))
example_input = example_batch[0:1]

# Absolute path of the directory passed to the trainer's logger.
log_dir = os.path.abspath("./saved_models/")

trainer = TransformerTrainer(
    num_heads=4,
    d_model=32,
    d_ff=32,
    output_dim=1,
    optimizer_hparams={'lr': 4e-4},
    exmp_input=example_input,
    logger_params={'base_log_dir': log_dir},
)




[3m                           SimpleTransformer Summary                            [0m
┏━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath         [0m[1m [0m┃[1m [0m[1mmodule       [0m[1m [0m┃[1m [0m[1minputs       [0m[1m [0m┃[1m [0m[1moutputs     [0m[1m [0m┃[1m [0m[1mparams       [0m[1m [0m┃
┡━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│               │ SimpleTransf… │ -             │ [2mfloat32[0m[64]  │               │
│               │               │ [2mfloat32[0m[64,2… │              │               │
│               │               │ - train: True │              │               │
├───────────────┼───────────────┼───────────────┼──────────────┼───────────────┤
│ Dense_0       │ Dense         │ [2mfloat32[0m[64,2… │ [2mfloat32[0m[64,… │ bias:         │
│               │               │               │              │ [2mfloat32[0m[32]   │
│               │    

In [12]:
# Number of passes over the training loader.
NUM_EPOCHS = 50

# Train with the validation loader and no separate test loader.
metrics = trainer.train_model(
    train_loader,
    val_loader,
    test_loader=None,
    num_epochs=NUM_EPOCHS,
)
# Display the metrics returned by the trainer as the cell output.
metrics

Epochs:   0%|                                                                                                   | 0/50 [00:00<?, ?it/s]
Training:   0%|                                                                                                 | 0/91 [00:00<?, ?it/s][A
Training:   1%|▉                                                                                        | 1/91 [00:05<08:26,  5.62s/it][A
Training:   7%|█████▊                                                                                   | 6/91 [00:05<01:00,  1.41it/s][A
Training:  12%|██████████▋                                                                             | 11/91 [00:05<00:26,  3.06it/s][A
Training:  18%|███████████████▍                                                                        | 16/91 [00:05<00:14,  5.24it/s][A
Training:  23%|████████████████████▎                                                                   | 21/91 [00:06<00:08,  8.02it/s][A
Training:  29%|███████████████

Training:  87%|████████████████████████████████████████████████████████████████████████████▍           | 79/91 [00:01<00:00, 49.81it/s][A
Training:  92%|█████████████████████████████████████████████████████████████████████████████████▏      | 84/91 [00:01<00:00, 49.82it/s][A
Training:  98%|██████████████████████████████████████████████████████████████████████████████████████  | 89/91 [00:01<00:00, 49.73it/s][A
Epochs:  12%|██████████▉                                                                                | 6/50 [00:35<03:44,  5.10s/it]
Training:   0%|                                                                                                 | 0/91 [00:00<?, ?it/s][A
Training:   5%|████▉                                                                                    | 5/91 [00:00<00:01, 47.05it/s][A
Training:  11%|█████████▋                                                                              | 10/91 [00:00<00:01, 48.10it/s][A
Training:  18%|███████████████

Training:  52%|█████████████████████████████████████████████▍                                          | 47/91 [00:00<00:00, 48.93it/s][A
Training:  57%|██████████████████████████████████████████████████▎                                     | 52/91 [00:01<00:00, 49.17it/s][A
Training:  63%|███████████████████████████████████████████████████████                                 | 57/91 [00:01<00:00, 49.38it/s][A
Training:  69%|████████████████████████████████████████████████████████████▉                           | 63/91 [00:01<00:00, 49.79it/s][A
Training:  76%|██████████████████████████████████████████████████████████████████▋                     | 69/91 [00:01<00:00, 49.94it/s][A
Training:  82%|████████████████████████████████████████████████████████████████████████▌               | 75/91 [00:01<00:00, 50.05it/s][A
Training:  89%|██████████████████████████████████████████████████████████████████████████████▎         | 81/91 [00:01<00:00, 49.96it/s][A
Training:  95%|████████████

Training:  29%|█████████████████████████▏                                                              | 26/91 [00:00<00:01, 49.33it/s][A
Training:  35%|██████████████████████████████▉                                                         | 32/91 [00:00<00:01, 49.77it/s][A
Training:  42%|████████████████████████████████████▋                                                   | 38/91 [00:00<00:01, 49.97it/s][A
Training:  48%|██████████████████████████████████████████▌                                             | 44/91 [00:00<00:00, 50.20it/s][A
Training:  55%|████████████████████████████████████████████████▎                                       | 50/91 [00:01<00:00, 50.36it/s][A
Training:  62%|██████████████████████████████████████████████████████▏                                 | 56/91 [00:01<00:00, 50.32it/s][A
Training:  68%|███████████████████████████████████████████████████████████▉                            | 62/91 [00:01<00:00, 50.28it/s][A
Training:  75%|████████████

Training:  31%|███████████████████████████                                                             | 28/91 [00:00<00:01, 49.94it/s][A
Training:  36%|███████████████████████████████▉                                                        | 33/91 [00:00<00:01, 49.62it/s][A
Training:  42%|████████████████████████████████████▋                                                   | 38/91 [00:00<00:01, 49.51it/s][A
Training:  47%|█████████████████████████████████████████▌                                              | 43/91 [00:00<00:00, 49.64it/s][A
Training:  53%|██████████████████████████████████████████████▍                                         | 48/91 [00:00<00:00, 49.73it/s][A
Training:  58%|███████████████████████████████████████████████████▎                                    | 53/91 [00:01<00:00, 48.56it/s][A
Training:  64%|████████████████████████████████████████████████████████                                | 58/91 [00:01<00:00, 46.88it/s][A
Training:  69%|████████████

Training:  35%|██████████████████████████████▉                                                         | 32/91 [00:00<00:01, 50.02it/s][A
Training:  41%|███████████████████████████████████▊                                                    | 37/91 [00:00<00:01, 49.85it/s][A
Training:  47%|█████████████████████████████████████████▌                                              | 43/91 [00:00<00:00, 50.22it/s][A
Training:  54%|███████████████████████████████████████████████▍                                        | 49/91 [00:00<00:00, 50.20it/s][A
Training:  60%|█████████████████████████████████████████████████████▏                                  | 55/91 [00:01<00:00, 50.35it/s][A
Training:  67%|██████████████████████████████████████████████████████████▉                             | 61/91 [00:01<00:00, 50.27it/s][A
Training:  74%|████████████████████████████████████████████████████████████████▊                       | 67/91 [00:01<00:00, 50.29it/s][A
Training:  80%|████████████

Training:  38%|█████████████████████████████████▊                                                      | 35/91 [00:00<00:01, 50.43it/s][A
Training:  45%|███████████████████████████████████████▋                                                | 41/91 [00:00<00:00, 50.49it/s][A
Training:  52%|█████████████████████████████████████████████▍                                          | 47/91 [00:00<00:00, 50.59it/s][A
Training:  58%|███████████████████████████████████████████████████▎                                    | 53/91 [00:01<00:00, 50.44it/s][A
Training:  65%|█████████████████████████████████████████████████████████                               | 59/91 [00:01<00:00, 50.47it/s][A
Training:  71%|██████████████████████████████████████████████████████████████▊                         | 65/91 [00:01<00:00, 50.45it/s][A
Training:  78%|████████████████████████████████████████████████████████████████████▋                   | 71/91 [00:01<00:00, 50.45it/s][A
Training:  85%|████████████

Epochs:  82%|█████████████████████████████████████████████████████████████████████████▊                | 41/50 [03:20<00:46,  5.12s/it]
Training:   0%|                                                                                                 | 0/91 [00:00<?, ?it/s][A
Training:   5%|████▉                                                                                    | 5/91 [00:00<00:01, 47.54it/s][A
Training:  11%|█████████▋                                                                              | 10/91 [00:00<00:01, 48.84it/s][A
Training:  18%|███████████████▍                                                                        | 16/91 [00:00<00:01, 49.50it/s][A
Training:  23%|████████████████████▎                                                                   | 21/91 [00:00<00:01, 49.46it/s][A
Training:  30%|██████████████████████████                                                              | 27/91 [00:00<00:01, 49.78it/s][A
Training:  35%|███████████████

Training:  49%|███████████████████████████████████████████▌                                            | 45/91 [00:00<00:00, 46.16it/s][A
Training:  55%|████████████████████████████████████████████████▎                                       | 50/91 [00:01<00:00, 46.25it/s][A
Training:  60%|█████████████████████████████████████████████████████▏                                  | 55/91 [00:01<00:00, 46.39it/s][A
Training:  66%|██████████████████████████████████████████████████████████                              | 60/91 [00:01<00:00, 46.10it/s][A
Training:  71%|██████████████████████████████████████████████████████████████▊                         | 65/91 [00:01<00:00, 46.18it/s][A
Training:  77%|███████████████████████████████████████████████████████████████████▋                    | 70/91 [00:01<00:00, 45.89it/s][A
Training:  82%|████████████████████████████████████████████████████████████████████████▌               | 75/91 [00:01<00:00, 45.75it/s][A
Training:  88%|████████████

{'val/loss': 0.030865227803587914,
 'train/loss': 0.03358151763677597,
 'epoch_time': 1.8158459663391113}

In [13]:
# Model object obtained from the trainer; the plotting cell calls it directly
# on feature arrays (`model_bd(X)`) to get predictions.
model_bd = trainer.bind_model()

In [14]:
# Subset locations: each one gets its own growth-rate panel in the figure.
locations_subset = ["California",   "Michigan", "Nevada", "New York", "Texas", "Washington"]
# Locations that additionally get selective-pressure and prevalence panels.
target_locations = ["England"]

In [1]:
TRAIN_PERIOD_COLOR, TEST_PERIOD_COLOR = "#d1c7c9", "#949fa5"
PREDICTION_COLOR = "#006633"
DATA_COLOR = "grey"
N_COLS = 3  # Panels per figure row


def format_date_axis(ax):
    """Label x-axis ticks as abbreviated month over year."""
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b\n%Y'))


def plot_growth_rate_panel(ax, loc, show_legend=False):
    """Draw observed and predicted growth rates for one location on `ax`.

    Reads the notebook globals `locations_vec`, `dates_vec`, `X_train`,
    `y_train`, `X_test`, `y_test` and `model_bd`.

    Args:
      ax: Matplotlib axes to draw on.
      loc: Location name compared against `locations_vec.values`.
      show_legend: When true, add a legend to the panel.
    """
    # NOTE(review): indexing with `X_train.index` / `X_test.index` assumes the
    # index labels are integer positions into `locations_vec` / `dates_vec` — confirm.
    loc_idx = locations_vec.values == loc
    loc_idx_train = loc_idx[X_train.index]
    loc_idx_test = loc_idx[X_test.index]
    dates_train = dates_vec[X_train.index][loc_idx_train]
    dates_test = dates_vec[X_test.index][loc_idx_test]

    # Observed values
    ax.plot(dates_train, y_train[loc_idx_train], color=DATA_COLOR, label="Data")
    ax.plot(dates_test, y_test[loc_idx_test], color=DATA_COLOR)

    # Model predictions
    pred_train = model_bd(X_train[loc_idx_train].values)
    pred_test = model_bd(X_test[loc_idx_test].values)
    ax.plot(dates_train, pred_train, color=PREDICTION_COLOR, label="Predicted")
    ax.plot(dates_test, pred_test, color=PREDICTION_COLOR)

    # Shade the training period and mark its end; skipped when the location
    # has no dates in the training split (the max is then null).
    min_train_time, max_train_time = dates_train.min(), dates_train.max()
    if not pd.isnull(max_train_time):
        ax.axvline(max_train_time, color="k", linestyle="--")
        ax.axvspan(min_train_time, max_train_time, color=TRAIN_PERIOD_COLOR, alpha=0.3)

    # Shade the test period, with the same guard.
    min_test_time, max_test_time = dates_test.min(), dates_test.max()
    if not pd.isnull(max_test_time):
        ax.axvspan(min_test_time, max_test_time, color=TEST_PERIOD_COLOR, alpha=0.3)

    ax.set_ylabel("Empirical Growth Rate")
    ax.set_title(loc)
    format_date_axis(ax)
    if show_legend:
        ax.legend()


# Rows needed for the subset panels (ceiling division), plus one row per target
# location. A plain `len // 3 + 1` is only correct when the subset size is a
# multiple of three; otherwise the last subset panels land in the target row.
n_subset_rows = (len(locations_subset) + N_COLS - 1) // N_COLS
n_rows = max(n_subset_rows + len(target_locations), 1)

fig = plt.figure(figsize=(20, 10))
spec = fig.add_gridspec(ncols=N_COLS, nrows=n_rows)

axes = []

data = {loc: df for loc, df in selective_pressure_df.groupby("location")}

# One growth-rate panel per subset location, each sharing its x-axis with the
# previous panel.
ax = None
for l, loc in enumerate(locations_subset):
    ax = fig.add_subplot(spec[l], sharex=ax)
    plot_growth_rate_panel(ax, loc, show_legend=(l == 0))
    axes.append(ax)

# One row per target location: selective pressure, prevalence, predictions.
# The three panels of a row share an x-axis with each other only.
for t, loc in enumerate(target_locations):
    row = n_subset_rows + t

    # Add selective pressure
    ax = fig.add_subplot(spec[row, 0])
    ax.plot(data[loc]["date"], data[loc]["selective_pressure"], color="lightblue")
    ax.set_ylabel("Selective Pressure")
    ax.set_title(loc)
    format_date_axis(ax)
    axes.append(ax)

    # Add prevalence
    ax = fig.add_subplot(spec[row, 1], sharex=ax)
    ax.plot(data[loc]["date"], data[loc]["smooth_cases"], color="grey")
    ax.set_title(loc)
    format_date_axis(ax)
    ax.set_ylabel("ONS Prevalence")
    axes.append(ax)

    # Add predictions
    ax = fig.add_subplot(spec[row, -1], sharex=ax)
    plot_growth_rate_panel(ax, loc)
    axes.append(ax)

# Letter each panel (A., B., ...) in the order it was created.
ax_labels = string.ascii_uppercase
for ax, ax_label in zip(axes, ax_labels):
    ax.text(-0.1, 1.05, ax_label + ".", transform=ax.transAxes, size=36, weight='bold')
fig.tight_layout()

NameError: name 'plt' is not defined

In [None]:
# Write the figure built in the previous cell to the manuscript figures folder.
# Requires `fig` to exist, i.e. the plotting cell must have run successfully first.
fig.savefig("../manuscript/figures/selective_pressure_prediction.png")