<a href="https://colab.research.google.com/github/mariomorvan/nam21-astro-ts-physics-dl/blob/main/NAM2021_workshop_astro_ts_physics_dl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Modelling astrophysical time series with physics-based deep learning
### NAM 2021: *Machine Learning Methods for Research in Astrophysics*
- *author*: Mario Morvan
- *contact*: mario.morvan.18@ucl.ac.uk 

In [None]:
# Imports
import numbers
import warnings
import tqdm
import numpy as np
import matplotlib.pylab as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

# Plotting params 
plt.rcParams["figure.figsize"] = (12, 6)
plt.rcParams["font.size"] = 14

# device-agnostic notebook
if torch.cuda.is_available():
  device = 'cuda'
  print('Device name used:', torch.cuda.get_device_name())
else:
  device = 'cpu'

## I) Modelling Astronomical Time Series with RNN

This section presents and experiments with a simple tool: an LSTM architecture, slightly tweaked to impute missing values in a dataset of time series, which opens up various applications for datasets of astronomical time series.

Let's first define a dummy dataset made of the (additive) composition of a random walk and a sine with random offset and period.

In [None]:
class DummyDataset(torch.utils.data.Dataset):
  """A simple torch dataset combining random walk and sine processes"""
  def __init__(self, seq_length, size=100, seed=None):
    """Define a DummyDataset object
    
    Args:
      seq_length: int
        length of the time series to generate
      size: int
        number of samples for the dataset
      seed: int
        manual seed to define for reproducibility (default None)
    """
    super().__init__()
    self.seq_length = seq_length
    self.size = size
    if seed is not None:
      torch.manual_seed(seed)
    
  def __len__(self):
    return self.size
    
  def __getitem__(self, index):
    sine = torch.sin(torch.linspace(0, 50 * torch.rand(1).item(), self.seq_length) 
                      + torch.rand(1).item()*np.pi)
    random_walk = (torch.rand(self.seq_length) - 0.5).cumsum(0)/3
    gaussian_noise = 0  # torch.randn(self.seq_length) / 10
    out = (sine + random_walk + gaussian_noise).unsqueeze(-1)
    # shift the series so that it starts at 0, then scale it by twice its range
    out = out - out[:1]
    value_range = out.max(0, keepdim=True)[0] - out.min(0, keepdim=True)[0]
    return out / 2 / value_range

item = DummyDataset(300, size=100)[0]     
plt.figure()   
plt.plot(item)
plt.xlabel('time index')
plt.ylabel('value')
plt.title('Random sample from our DummyDataset')


In [None]:
# create train and test dummy datasets
seq_length = 200
batch_size = 64

dataset = DummyDataset(seq_length, size=256, seed=0)
dataset_test = DummyDataset(seq_length, size=64, seed=1)

loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
loader_test = DataLoader(dataset_test, batch_size=len(dataset_test), shuffle=False)
batch_test = next(iter(loader_test))

Now let's first teach a basic stacked LSTM to forecast the next step in this dataset.

In [None]:
hidden_size = 64
num_layers = 2
model = nn.LSTM(input_size=1, hidden_size=hidden_size, num_layers=num_layers, 
                batch_first=True)
optimiser = torch.optim.Adam(model.parameters(), lr=0.1)
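# NB: the LSTM returns hidden_size features per time step, so F.mse_loss broadcasts
# the single-feature series against all of them (PyTorch warns about the size mismatch)
criterion = lambda y, pred: F.mse_loss(y, pred)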

def train_forecaster(model, loader, optimiser, criterion, 
                     epochs=1, loader_val=None):
  """Train a basic forecasting pytorch RNN"""
  model.train()
  losses = []

  for epoch in tqdm.tqdm(range(1, 1+epochs)):
    epoch_loss = 0
    for x in loader:
      optimiser.zero_grad()
      with torch.enable_grad():
        pred, _ = model(x)
        loss = criterion(x[:, 1:], pred[:, :-1])   # x_{t+1} ~ f(x_t)
      loss.backward()
      optimiser.step()
      epoch_loss += loss.item()
    losses.append(epoch_loss / len(loader))
  return losses  


In [None]:
losses = train_forecaster(model, loader, optimiser, criterion, epochs=30)
plt.plot(losses)
plt.yscale('log')

In [None]:
# pred = predict_forecaster(model, batch_test)
model.eval()
pred, (h_n, c_n) = model(batch_test)

i = np.random.randint(len(pred))
plt.title('Prediction example on test set')
plt.xlabel('Time')
plt.plot(batch_test[i,1:,0].T)
plt.plot(pred[i,:-1,0].detach().T)
plt.show()

plt.scatter(batch_test[:, 1:,0], pred[:,:-1,0].detach(), s=5)
plt.plot([batch_test.min().item(), batch_test.max().item()], 
         [batch_test.min().item(), batch_test.max().item()], color='red',)
plt.ylabel('Test predictions')
plt.xlabel('Test targets')
plt.show()
# plt.scatter(batch_test[:, 1:,0], batch_test[:,:-1,0].detach(), s=5)
# plt.plot([batch_test.min().item(), batch_test.max().item()], 
#          [batch_test.min().item(), batch_test.max().item()], color='red',)
# plt.show()

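A useful sanity check for any one-step forecaster is the persistence baseline, which simply predicts $x_{t+1} = x_t$. The sketch below scores it with the same shifted alignment as the training loop above; `persistence_mse` is a hypothetical helper name, not part of the notebook.

```python
import torch
from torch.nn import functional as F

def persistence_mse(x):
    """MSE of the naive forecast x_{t+1} = x_t for a batch of shape (batch, time, features)."""
    pred = x                                  # the "prediction" made at step t is simply x_t
    return F.mse_loss(pred[:, :-1], x[:, 1:]).item()

# Toy check on a ramp increasing by 1 at each step: every squared error equals 1
ramp = torch.arange(5.).reshape(1, 5, 1)
print(persistence_mse(ramp))
```

A forecaster that does not score below `persistence_mse(batch_test)` has probably learnt little beyond copying its input.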


### Applications:
- forecasting
- anomaly detection
- encoding latent representation

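As an illustration of the anomaly detection use case, one simple option is to flag the points whose forecast residual is unusually large. The sketch below assumes Gaussian-like residuals and uses a robust scale estimate (median absolute deviation); `flag_anomalies`, the threshold `n_sigma` and the synthetic series are all assumptions made for the example.

```python
import torch

def flag_anomalies(x, pred, n_sigma=5.0):
    """Boolean mask of points whose residual exceeds n_sigma robust standard deviations."""
    resid = x - pred
    centre = resid.median()
    mad = (resid - centre).abs().median()     # median absolute deviation
    sigma = 1.4826 * mad                      # converts the MAD to a std for Gaussian noise
    return (resid - centre).abs() > n_sigma * sigma

torch.manual_seed(0)
time = torch.linspace(0, 20, 200)
pred = torch.sin(time)                        # stand-in for the model's forecasts
x = pred + 0.01 * torch.randn(200)            # observations = forecasts + small noise
x[120] += 1.0                                 # injected outlier
flags = flag_anomalies(x, pred)
print(flags.nonzero().flatten().tolist())
```

With a trained model, `pred` would be the shifted network output (`pred[:, :-1]`) and `x` the matching targets (`x[:, 1:]`).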
### Improvements and ideas to explore together:
- loss
- window predictions
- dropout
- visualisation of the latent representation with t-SNE

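For the dropout idea, a minimal sketch is given below: a stacked LSTM with inter-layer dropout and a linear read-out, so that the output has the same number of features as the input. The class name `LSTMForecaster` and its default values are assumptions for illustration only.

```python
import torch
from torch import nn

class LSTMForecaster(nn.Module):
    """Stacked LSTM followed by a linear read-out (hypothetical helper, not from the notebook)."""
    def __init__(self, input_size=1, hidden_size=64, num_layers=2, dropout=0.1):
        super().__init__()
        # nn.LSTM applies dropout to the outputs of every layer except the last one
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers,
                            batch_first=True, dropout=dropout)
        self.head = nn.Linear(hidden_size, input_size)

    def forward(self, x):
        hidden, _ = self.lstm(x)              # (batch, time, hidden_size)
        return self.head(hidden)              # (batch, time, input_size)

forecaster = LSTMForecaster()
x = torch.randn(8, 200, 1)
pred = forecaster(x)
print(pred.shape)                             # same shape as the input
```

Note that this module returns a single tensor, so the `pred, _ = model(x)` line in `train_forecaster` would need adapting before using it there.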
In [None]:
class LSTMI(nn.Module):
    def __init__(self, input_size, hidden_size, output_size=None, num_layers=1, dropout=0.):
        """Define an LSTM Imputer network


        Args:
          input_size: dimensionality of the input sequences
          hidden_size: Number of units for the LSTM cells
          output_size: dimensionality of the output sequences.
                       If default (None) will be set as input_size.
          num_layers: number of LSTM layers
          dropout: dropout rate to apply after all-but-last layers
        Returns:
          pytorch module 

        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size if output_size is not None else input_size
        self.num_layers = num_layers
        self.dropout = float(dropout)

        self.lstm_cells = nn.ModuleList([nn.LSTMCell(input_size=self.input_size, hidden_size=self.hidden_size)])
        self.lstm_cells.extend([nn.LSTMCell(input_size=self.hidden_size, hidden_size=self.hidden_size)
                                for _ in range(self.num_layers - 1)])
        if dropout != 0:
            self.dropout_layer = nn.Dropout(self.dropout)

        self.fc = nn.Linear(self.hidden_size, self.output_size)

        self.h_t = None
        self.c_t = None
        self.out_t = None

    def init_state(self, batch_size, device=None, dtype=None):
      """Initialise the network's states"""
      self.h_t = [torch.randn(batch_size, self.hidden_size, device=device, dtype=dtype)
                  for _ in range(self.num_layers)]
      self.c_t = [torch.randn(batch_size, self.hidden_size, device=device, dtype=dtype)
                  for _ in range(self.num_layers)]
      self.out_t = torch.randn(batch_size, self.output_size, device=device, dtype=dtype)

    def init_state_like(self, x):
      """Initialise the network's states to match the batch size, device and dtype of a given input"""
      self.init_state(len(x), x.device, x.dtype)

    def forward(self, x, z=None, m=None, mask_nans=True):
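      """Performs a forward pass

      :param x: Input tensor of shape (batch, time, features)
      :param z: Optional covariates of shape (batch, time, z_features), concatenated to the input
      :param m: Imputation mask - 1/True for keeping input and 0/False for forcing dynamic imputation
      :param mask_nans: If True, NaN entries of x are imputed as well
      :return: Output tensor with imputed values
      """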

      # Checking - casting inputs
      if m is not None:
        if x.shape != m.shape:
          mess = f'x and m must have the same shape ({x.shape} != {m.shape})'
          raise RuntimeError(mess)
      else:
        m = torch.ones_like(x)

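      if mask_nans:
        nan_mask = torch.isnan(x)
        m = m * ~nan_mask                                   # impute NaN entries without modifying the caller's mask
        x = torch.where(nan_mask, torch.zeros_like(x), x)   # NaN * 0 is still NaN, so zero them explicitly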

      if len(x.shape) == 2:
        warnings.warn('input has only 2 dims; adding a trailing feature dimension')
        x = x.unsqueeze(-1)
        m = m.unsqueeze(-1)
      elif len(x.shape) != 3:
        raise ValueError(f'input must have 2 or 3 dims, got shape {tuple(x.shape)}')

      batch_size, len_seq, n_dim = x.shape

      if z is not None:
        zdim = z.shape[-1]
      else:
        zdim = 0
      if n_dim + zdim != self.input_size:
        mess = f'input (dim={n_dim}) and covariate (dim={zdim}) dimensions must ' \
                + f'add to input_size param ({self.input_size})'
        raise RuntimeError(mess)

      m = m.type(x.dtype)
      inverted_m = (torch.ones_like(m, device=x.device) - m)

      # Preparing input
      if z is not None:
        input_masked = torch.cat((x * m, z), dim=-1)
      else:
        input_masked = x * m
      outputs = []

      for t in range(len_seq):
        pred_t = self.out_t
        input_t = input_masked[:, t] + F.pad(pred_t * inverted_m[:, t], (0, zdim))
        for k, lstm_cell in enumerate(self.lstm_cells):
          if k == 0:
            self.h_t[0], self.c_t[0] = lstm_cell(input_t, (self.h_t[0], self.c_t[0]))
          else:
            self.h_t[k], self.c_t[k] = lstm_cell(self.h_t[k - 1], (self.h_t[k], self.c_t[k]))
          if self.dropout != 0 and k < self.num_layers - 1:
            self.h_t[k] = self.dropout_layer(self.h_t[k])
            self.c_t[k] = self.dropout_layer(self.c_t[k])
        self.out_t = self.fc(self.h_t[-1])
        outputs.append(self.out_t.unsqueeze(1))
      return torch.cat(outputs, 1)


def train_lstmi_batch(model, loader, optimiser, criterion, epochs=1):
  """Train a LSTMI module"""
  model.train()
  losses = []

  for epoch in tqdm.tqdm(range(1, 1+epochs)):
    epoch_loss = 0
    for x in loader:
      optimiser.zero_grad()
      model.init_state_like(x)
      with torch.enable_grad():
        pred = model(x)
        loss = criterion(x[:, 1:], pred[:, :-1])   # x_{t+1} ~ f(x_t)
      loss.backward()
      optimiser.step()
      epoch_loss += loss.item()
    losses.append(epoch_loss / len(loader))
  return losses  

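The key step in `LSTMI.forward` is the blend between observed and predicted values: where the mask is 1 the network is fed the observation, and where it is 0 it is fed its own previous output. The standalone sketch below reproduces that step for a single time step with made-up numbers, and also shows why NaNs have to be zeroed explicitly rather than relying on the multiplication by the mask.

```python
import torch

x_t = torch.tensor([[0.5], [float('nan')], [0.2]])   # observations at time t (one NaN)
pred_t = torch.tensor([[0.4], [0.7], [0.1]])         # previous outputs of the network
m_t = torch.tensor([[1.], [1.], [0.]])               # user mask: the last sample is a gap

# NaN * 0 is still NaN, so masking by multiplication alone does not remove NaNs
print(torch.isnan(torch.tensor(float('nan')) * 0).item())

m_t = m_t * ~torch.isnan(x_t)                        # also impute where x is NaN
x_safe = torch.where(torch.isnan(x_t), torch.zeros_like(x_t), x_t)
input_t = x_safe * m_t + pred_t * (1 - m_t)          # observed where m=1, predicted where m=0
print(input_t.flatten().tolist())
```

The first sample keeps its observation, while the NaN sample and the masked sample both fall back on the network's previous prediction.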
In [None]:
model = LSTMI(1, hidden_size=hidden_size, num_layers=num_layers)
optimiser = torch.optim.Adam(model.parameters(), lr=0.1)
criterion = lambda y, pred: F.mse_loss(y, pred)
losses = train_lstmi_batch(model, loader, optimiser, criterion, epochs=200)

In [None]:
# learning curve
plt.plot(losses)
plt.yscale('log')
plt.title('learning curve')
plt.xlabel('Epoch')
plt.ylabel('MSE')
plt.show()

In [None]:
# Check of predictions bias
model.eval()
model.init_state_like(batch_test)
pred = model(batch_test)

i = np.random.randint(len(pred))
plt.plot(batch_test[i,1:,0].T)
plt.plot(pred[i,:-1,0].detach().T)
plt.show()

plt.scatter(batch_test[:, 1:,0], pred[:,:-1,0].detach(), s=5)
plt.plot([batch_test.min().item(), batch_test.max().item()], 
         [batch_test.min().item(), batch_test.max().item()], color='red')
plt.ylabel('Test predictions')
plt.xlabel('Test targets')
plt.show()
# plt.scatter(batch_test[:, 1:,0], batch_test[:,:-1,0].detach(), s=5)
# plt.plot([batch_test.min().item(), batch_test.max().item()], 
#          [batch_test.min().item(), batch_test.max().item()], color='red',)
# plt.show()

### Applications
- Anomaly detection
- Imputation
- Modelling and fitting gaps
- More flexible learning (include in the training loss)

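To experiment with gaps, one possible approach is to hide a random window in each series through the mask `m`, and to score the predictions only inside that window. The helpers `gap_mask` and `gap_mse` below are hypothetical names, and the one-gap-per-series design is an assumption of this sketch.

```python
import torch

def gap_mask(batch_size, seq_length, gap_length, generator=None):
    """Mask of shape (batch, time, 1): 1 where data is kept, 0 inside one random gap per series."""
    starts = torch.randint(1, seq_length - gap_length, (batch_size, 1), generator=generator)
    idx = torch.arange(seq_length)[None, :]
    in_gap = (idx >= starts) & (idx < starts + gap_length)
    return (~in_gap).float().unsqueeze(-1)

def gap_mse(x, pred, m):
    """MSE between next-step targets and predictions, restricted to the masked (gap) positions."""
    gap = 1 - m[:, 1:]
    return (((x[:, 1:] - pred[:, :-1]) ** 2) * gap).sum() / gap.sum()

g = torch.Generator().manual_seed(0)
m = gap_mask(batch_size=4, seq_length=50, gap_length=10, generator=g)
x = torch.randn(4, 50, 1)
print(gap_mse(x, torch.zeros_like(x), m).item())
```

With the imputer defined above, the predictions would come from `model(x, m=m)` after a call to `model.init_state_like(x)`.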

### Suggestions:
- Gaussian Loss 
- Gap imputing metric
- See improvements with growing dataset

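For the Gaussian loss suggestion, a possible sketch is to let the network output a mean and a variance per time step and to train it with the Gaussian negative log-likelihood available in PyTorch (`F.gaussian_nll_loss`). The `GaussianHead` module and the random stand-in tensors are assumptions for illustration.

```python
import torch
from torch import nn
from torch.nn import functional as F

class GaussianHead(nn.Module):
    """Maps hidden states to a mean and a positive variance (hypothetical helper)."""
    def __init__(self, hidden_size, output_size=1):
        super().__init__()
        self.fc = nn.Linear(hidden_size, 2 * output_size)

    def forward(self, h):
        mean, raw_var = self.fc(h).chunk(2, dim=-1)
        return mean, F.softplus(raw_var) + 1e-6   # softplus keeps the variance positive

torch.manual_seed(0)
head = GaussianHead(hidden_size=64)
h = torch.randn(8, 200, 64)                    # stand-in for LSTM hidden states
target = torch.randn(8, 200, 1)                # stand-in for next-step targets
mean, var = head(h)
loss = F.gaussian_nll_loss(mean, target, var)  # negative log-likelihood, averaged over elements
loss.backward()
print(loss.item())
```

In `LSTMI`, such a head could replace the final `fc` layer, with the predicted mean fed back as the imputed value.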
## II) Embed differentiable physics model in DL framework

Why hard-code the physics model in a DL framework? 
- Computational efficiency, thanks to automatic differentiation and GPU acceleration
- It can be combined with NNs


Physics model requirements:
- implemented in a DL framework (here PyTorch)
- vectorised to process batches of samples

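The flare profile used in the rest of this section, as coded in `compute_flare`, has amplitude $a_0$, a Gaussian rise of width $\tau_g$ up to the peak time $t_0$, and a decay of scale $\tau_e$ afterwards:

$$
f(t) =
\begin{cases}
a_0 \exp\left(-\dfrac{(t - t_0)^2}{2\tau_g^2}\right) & \text{if } t \le t_0 \\
a_0 \exp\left(-\dfrac{(t - t_0)^2}{\tau_e^2}\right) & \text{if } t > t_0
\end{cases}
$$

Both branches equal $a_0$ at $t = t_0$, so the profile is continuous at the peak.
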
Let's define a simple flare model in PyTorch: 

In [None]:
# Defining physics model and dataset class

# def physics_model():

def compute_flare(time, a_0, tau_g, tau_e, t_0=50):
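  """Flare profile with a Gaussian rise (width tau_g) before t_0 and a decay (scale tau_e) after it

  ref: https://academic.oup.com/mnras/article/445/3/2268/2907951
  """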
  out = torch.empty_like(time)
  out[time <= t_0] = a_0 * torch.exp(-(time[time<=t_0] - t_0)**2 / (2*tau_g**2))
  out[time > t_0] = a_0 * torch.exp(-(time[time>t_0] - t_0)**2 / (tau_e**2))
  return out

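def compute_flare_batch(time, a_0, tau_g, tau_e, t_0=50):
  """Batched version of compute_flare

  Args:
    time: torch.Tensor of shape (T,)
    a_0, tau_g, tau_e: torch.Tensor of shape (batch_size,)
    t_0: peak time (scalar)
  Return: torch.Tensor of shape (batch_size, T)
  """
  batch_size = len(a_0)
  time = time[None,:].repeat(batch_size, 1)
  a_0 = a_0.reshape(batch_size,1)
  tau_g = tau_g.reshape(batch_size,1)
  tau_e = tau_e.reshape(batch_size,1)

  # same piecewise profile as compute_flare, selected element-wise
  dt2 = (time - t_0)**2
  exponent = torch.where(time <= t_0, dt2 / (2*tau_g**2), dt2 / tau_e**2)
  return a_0 * torch.exp(-exponent)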


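class PhysicsDataset(torch.utils.data.Dataset):
  """Dataset of noisy time series simulated from a physics model,
  with parameters drawn uniformly within the given bounds
  """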
  def __init__(self, physics_model, bounds, target_params, 
               seq_length=200, size=100, noise_level=1e-4, seed=0):
    super().__init__()
    self.physics_model = physics_model
    self.bounds = bounds
    self.target_params = target_params
    self.seq_length = seq_length
    self.size = size
    self.noise_level = noise_level
    if seed is not None:
        torch.manual_seed(seed)
    self.device = device
    
    self.params = dict()
    self._sample_priors(self.bounds)
    self.time = torch.linspace(0, 100, self.seq_length, device=self.device)
    self.noise = torch.randn(self.size, self.seq_length, device=self.device) * self.noise_level

  def _sample_priors(self, bounds):
    for par in bounds:
      self.params[par] = torch.rand(self.size, device=self.device) * (bounds[par][1] - bounds[par][0]) + bounds[par][0]

  def _get_item_params(self, index):
    return {par: self.params[par][index] for par in self.params}
      
  def __len__(self):
      return self.size
  
  def __getitem__(self, index):
    x =  (self.physics_model(self.time, **self._get_item_params(index)) 
          + self.noise[index])
    target = torch.stack([self._get_item_params(index)[par] for par in self.target_params])
    # additional_params = torch.tensor([self._get_item_params(index)[par] for par in self.bounds if par not in self.target_params])
    return x.to(self.device), target.to(self.device) #, additional_params


In [None]:
# plot one example
bounds = {'a_0': [1, 5], 'tau_g':[4, 10], 'tau_e':[1, 5], 't_0':[50,50]}
target_params = ['a_0', 'tau_g', 'tau_e']
dataset = PhysicsDataset(compute_flare, bounds, target_params, noise_level=0.05)

x, target = dataset[0]

plt.plot(x.detach().cpu())

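Because the flare model is written with differentiable PyTorch operations, its parameters can be fitted to a light curve by gradient descent. The sketch below re-defines the profile so that it is self-contained (the function name `flare`, the initial guesses, the learning rate and the number of steps are arbitrary assumptions) and fits a single noisy series with Adam.

```python
import torch

def flare(time, a_0, tau_g, tau_e, t_0=50.0):
    """Differentiable flare profile mirroring compute_flare (re-defined to be self-contained)."""
    dt2 = (time - t_0) ** 2
    return a_0 * torch.exp(-torch.where(time <= t_0, dt2 / (2 * tau_g**2), dt2 / tau_e**2))

torch.manual_seed(0)
time = torch.linspace(0, 100, 200)
truth = {'a_0': 3.0, 'tau_g': 6.0, 'tau_e': 2.0}
x = flare(time, **truth) + 0.05 * torch.randn(200)       # noisy observation

# Free parameters, initialised away from the truth
params = {name: torch.tensor(value, requires_grad=True)
          for name, value in {'a_0': 1.0, 'tau_g': 5.0, 'tau_e': 3.0}.items()}
optimiser = torch.optim.Adam(params.values(), lr=0.05)

history = []
for step in range(300):
    optimiser.zero_grad()
    loss = torch.mean((flare(time, **params) - x) ** 2)
    loss.backward()                                      # gradients flow through the physics model
    optimiser.step()
    history.append(loss.item())

print({name: round(p.item(), 2) for name, p in params.items()}, history[0], history[-1])
```

The same mechanism is what allows a physics model to sit inside a neural network's loss function.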
In [None]:
# Reproducibility - does not seem necessary or compatible with some cuda functions
import random

torch.use_deterministic_algorithms(False)

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
# g = torch.Generator()
# g.manual_seed(0)

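For reference, this is how `seed_worker` and a seeded `torch.Generator` are typically passed to a `DataLoader`. The toy dataset and the `make_loader` helper are assumptions of this sketch; it checks that two loaders built with the same generator seed shuffle in the same order.

```python
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    """Seed numpy and random in each worker from the torch seed assigned to that worker."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def make_loader(seed):
    g = torch.Generator()
    g.manual_seed(seed)                        # controls the shuffling order
    data = TensorDataset(torch.arange(16.))
    # num_workers is left at 0 here; worker_init_fn only runs when num_workers > 0
    return DataLoader(data, batch_size=4, shuffle=True,
                      worker_init_fn=seed_worker, generator=g)

order_a = torch.cat([batch[0] for batch in make_loader(0)])
order_b = torch.cat([batch[0] for batch in make_loader(0)])
print(torch.equal(order_a, order_b))
```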
In [None]:
# Instantiating datasets and loaders

##### SIMPLIFY CELL

# General params
target_params = ['a_0', 'tau_g', 'tau_e']
seq_length = 100
size_train = 512
batch_size = 128
size_val = 512
size_test = 1024

# `device` is reused from the first cell ('cuda' if available, else 'cpu')

# Train dataset
bounds_train = {'a_0': [1, 2], 'tau_g':[4, 6], 'tau_e':[2, 5]}
bounds_val_2 = {'a_0': [0, 3], 'tau_g':[3, 8], 'tau_e':[1, 6]}
noise_train = 0.01

dataset_train = PhysicsDataset(compute_flare, bounds_train, target_params, 
                               seq_length=seq_length, size=size_train, 
                               noise_level=noise_train, seed=0)

loader_train = DataLoader(dataset_train, batch_size=batch_size, 
                          shuffle=True, 
                          # num_workers=1,                            #!: investigate why this fails
                          # worker_init_fn=seed_worker, generator=g   #!: does not seem necessary for reproducibility
                          )

# Val dataset 1 (same parameter space as training)
dataset_val_1 = PhysicsDataset(compute_flare, bounds_train, target_params, 
                                seq_length=seq_length, size=size_val, 
                                noise_level=noise_train, seed=1)
loader_val_1 = DataLoader(dataset_val_1, batch_size=len(dataset_val_1))
batch_val_1 = next(iter(loader_val_1))

# Test dataset 1 (same parameter space as training)
dataset_test_1 = PhysicsDataset(compute_flare, bounds_train, target_params, 
                                seq_length=seq_length, size=size_test, 
                                noise_level=noise_train, seed=4)
loader_test_1 = DataLoader(dataset_test_1, batch_size=len(dataset_test_1))
x_test_1, target_test_1 = next(iter(loader_test_1))


# Val dataset 2 (different parameter space)
dataset_val_2 = PhysicsDataset(compute_flare, bounds_val_2, target_params, 
                                seq_length=seq_length, size=size_val, 
                                noise_level=noise_train, seed=2)
loader_val_2 = DataLoader(dataset_val_2, batch_size=len(dataset_val_2))
batch_val_2 = next(iter(loader_val_2))


# Test dataset 2 (different parameter space)
dataset_test_2 = PhysicsDataset(compute_flare, bounds_val_2, target_params, 
                                seq_length=seq_length, size=size_test, 
                                noise_level=noise_train, seed=5)
loader_test_2 = DataLoader(dataset_test_2, batch_size=len(dataset_test_2))
x_test_2, target_test_2 = next(iter(loader_test_2))

In [None]:
# compute_flare_batch(dataset_test_1.time, **dataset_test_1.params)
# [compute_flare(dataset_train.time, **{par: dataset_train.params[par][i:i+1] for par in dataset_train.params}) for i in range(512)]


In [None]:
# plotting some signals

In [None]:
# A simple network to solve the inverse problem data -> params

class Network(nn.Module):
  """Define a simple MLP in pytorch"""
  def __init__(self, seq_length, out_dim):
    super().__init__()
    self.fc1 = nn.Linear(seq_length, 256)
    self.fc2 = nn.Linear(256, 64)
    self.fc3 = nn.Linear(64, 32)
    self.fc4 = nn.Linear(32, out_dim)
      
  def forward(self, x):
    out = F.relu(self.fc1(x))
    out = F.relu(self.fc2(out))
    out = F.relu(self.fc3(out))
    out = self.fc4(out)
    return out


def train_network(model, loader, optimiser, criterion,
                  epochs=1, batch_val=None, metric_val=None, eval_epochs=()):
  """Train a pytorch module"""
  batch_val = {} if batch_val is None else batch_val   # avoid mutable default arguments
  losses = {'train':[]}
  losses.update({name: [] for name in batch_val})

  for epoch in tqdm.tqdm(range(1, 1+epochs)):
    model.train()
    epoch_loss = 0
    for x, target in loader:
      optimiser.zero_grad()
      x = x.to(device)
      target = target.to(device)
      with torch.enable_grad():
        pred = model(x)
        loss = criterion(x, target, pred)
      loss.backward()
      optimiser.step()
      epoch_loss += loss.item()
    losses['train'].append(epoch_loss / len(loader))
    if metric_val is not None and batch_val is not None and epoch in eval_epochs:
      model.eval()
      with torch.no_grad():
        for name, batch in batch_val.items():
          x_val, target_val = batch[0].to(device), batch[1].to(device)
          pred_val = model(x_val)
          # store python floats so that the history can be plotted directly
          losses[name].append(metric_val(x_val, target_val, pred_val).item())
  return losses  

In [None]:
# defining two losses

def naive_loss(x, target, pred):
  """Wrapper around pytorch mse_loss to add inputs to signature

  Args: 
    x: torch.Tensor 
      input time series of shape (batch_size, T) or (T,). (unused argument)
    target: torch.Tensor
      output targets of shape (batch_size, dim) or (dim,)
    pred: torch.Tensor
      predicted targets of shape (batch_size, dim) or (dim,)
  Return: torch.Tensor
    mean squared error value between target and predictions
  """
  return F.mse_loss(target, pred)

def hybrid_loss(dataset, beta=1):
  """Define a hybrid regression & reconstruction loss function

  Args:
    dataset: PhysicsDataset
      torch dataset with arguments target_params, time and physics_model
    beta: 
      weight parameter between regression and reconstruction terms: 
      loss = regression + beta * reconstruction
  Return: function
    loss function associated with provided dataset and beta parameter
  """
  def loss_function(x, target, pred):
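    """Compute the hybrid loss

    Args:
      x: input time series of shape (batch_size, T)
      target: true parameters of shape (batch_size, dim)
      pred: predicted parameters of shape (batch_size, dim)
    Return: torch.Tensor
      regression MSE + beta * reconstruction MSE
    """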
    pred_dict = {dataset.target_params[i]: pred[:,i] for i in range(len(dataset.target_params))}  # Iterate over feature dimension to produce dict of outputs
    reconstructed_time_series = compute_flare_batch(dataset.time.to(device), **pred_dict)
    physics_reconstruction_term = F.mse_loss(x, reconstructed_time_series)
    return naive_loss(x, target, pred) + beta * physics_reconstruction_term
  return loss_function

# Instantiate the two scenarios
scenarios = ['naive', 'hybrid']
network = {scenario: Network(seq_length, len(target_params)).to(device) for scenario in scenarios}
optimiser = {scenario: torch.optim.Adam(network[scenario].parameters(), lr=0.002) for scenario in scenarios}
loss = {'naive': naive_loss,
        'hybrid': hybrid_loss(dataset_train,beta=0.01)}

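In equation form, writing $\theta$ for the true parameters, $\hat{\theta}$ for the network's predictions and $f(t; \hat{\theta})$ for the physics model (`compute_flare_batch`) evaluated at the predicted parameters, the hybrid loss defined above reads:

$$
\mathcal{L}_{\mathrm{hybrid}}(x, \theta, \hat{\theta}) = \mathrm{MSE}(\theta, \hat{\theta}) + \beta \, \mathrm{MSE}\big(x, f(t; \hat{\theta})\big)
$$

The naive loss corresponds to the special case $\beta = 0$.
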
In [None]:
# Running experiment with two losses
loss_history = dict()
epochs = 1500
eval_epochs = list(range(1, 1+epochs, 10))
for scenario in scenarios:
  loss_history[scenario] = train_network(network[scenario], loader_train, optimiser[scenario], loss[scenario], epochs=epochs, 
                                         batch_val={'val_1': batch_val_1, 'val_2': batch_val_2}, metric_val=naive_loss, eval_epochs=eval_epochs)

In [None]:
# Plot learning curves
plt.figure(figsize=(10,6))

linestyle = {'naive': 'solid', 'hybrid': 'dotted'}
alpha = {'naive': 0.7, 'hybrid': 1}
for scenario in scenarios:
  plt.plot(eval_epochs, loss_history[scenario]['val_1'], label=f'Val 1 ({scenario})', 
           c='blue', linestyle=linestyle[scenario], alpha=alpha[scenario])
  plt.plot(eval_epochs, loss_history[scenario]['val_2'], label=f'Val 2 ({scenario})', 
           c='green', linestyle=linestyle[scenario], alpha=alpha[scenario])
  # plt.plot(eval_epochs, loss_history[scenario]['val_3'], label=f'Val 3 ({scenario})', 
  #          c='red', linestyle=linestyle[scenario], alpha=alpha[scenario])
plt.legend()
plt.yscale('log')
plt.xlabel('Epochs')
plt.ylabel('MSE')

In [None]:
# evaluation on test sets

scores_1 = dict()
scores_2 = dict()

for scenario in scenarios:
  network[scenario].eval()
  with torch.no_grad():
    scores_1[scenario] = naive_loss(x_test_1, target_test_1, network[scenario](x_test_1)).item()
    scores_2[scenario] = naive_loss(x_test_2, target_test_2, network[scenario](x_test_2)).item()
print(f'scores_1: {scores_1}\n', f'scores_2: {scores_2}')

Pros:
- accuracy
- stability
- generalisability

Cons:
- need the physics model implemented in the DL framework
- the physics model is liable to add its own complexity to the network's! 
- training may require further tuning to accommodate the different loss terms

Improvements:
- hyperparameter optimisation of the loss weight $\beta$ 
- hyperparameter optimisation of the learning rates in both cases

To go further:
- use your own physics model! 
- design transfer learning and meta-learning experiments to assess generalisability


## III) Bonus: Combine RNN detrending with differentiable physics model


In [None]:
Using the two tools together to perform denoising and fitting at the same time:
 - RNN to model and impute multiple time series with masked signal
 - physics model on the residuals
 - all trained end-to-end

Inference techniques:
 - VAE
 - SVI 

Let's code an example below.
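
One possible starting point is sketched here, under several simplifying assumptions: an LSTM models the slow trend from a series in which the flare window is masked by hand, the flare parameters are free tensors (one set per series) rather than network outputs, and both are trained jointly on the reconstruction error. The names `flare` and `TrendPlusFlare`, the synthetic data and all hyperparameters are illustrative, and neither VAE nor SVI inference is attempted.

```python
import torch
from torch import nn

def flare(time, a_0, tau_g, tau_e, t_0=50.0):
    """Batched flare profile: parameters of shape (batch, 1), time of shape (T,)."""
    dt2 = (time - t_0) ** 2
    return a_0 * torch.exp(-torch.where(time <= t_0, dt2 / (2 * tau_g**2), dt2 / tau_e**2))

class TrendPlusFlare(nn.Module):
    """LSTM trend model plus a flare with one free parameter set per series."""
    def __init__(self, n_series, hidden_size=16):
        super().__init__()
        self.lstm = nn.LSTM(2, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, 1)
        # free flare parameters (a_0, tau_g, tau_e), initialised at arbitrary guesses
        self.flare_params = nn.Parameter(torch.tensor([[1.0, 5.0, 3.0]]).repeat(n_series, 1))

    def forward(self, x, m, time):
        # the LSTM only sees the series outside the masked flare window, plus the mask itself
        hidden, _ = self.lstm(torch.stack((x * m, m), dim=-1))
        trend = self.head(hidden).squeeze(-1)
        a_0, tau_g, tau_e = self.flare_params.split(1, dim=-1)
        return trend, flare(time, a_0, tau_g, tau_e)

# Synthetic data: slow sinusoidal trend + flare + noise
torch.manual_seed(0)
n_series = 8
time = torch.linspace(0, 100, 100)
true_trend = 0.5 * torch.sin(time[None] / 30 + 6.28 * torch.rand(n_series, 1))
true_flare = flare(time, torch.full((n_series, 1), 2.0), torch.full((n_series, 1), 6.0),
                   torch.full((n_series, 1), 3.0))
x = true_trend + true_flare + 0.02 * torch.randn(n_series, 100)
m = ((time < 30) | (time > 65)).float()[None].repeat(n_series, 1)   # mask the flare window

model = TrendPlusFlare(n_series)
optimiser = torch.optim.Adam(model.parameters(), lr=0.01)
history = []
for epoch in range(200):
    optimiser.zero_grad()
    trend, flare_fit = model(x, m, time)
    loss = torch.mean((x - trend - flare_fit) ** 2)   # joint, end-to-end objective
    loss.backward()
    optimiser.step()
    history.append(loss.item())
print(history[0], history[-1])
```

A natural next step would be to replace the plain LSTM with the `LSTMI` imputer, and the free flare parameters with the output of a regression network trained with the hybrid loss.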