## Time-instant models: GHGs(t) -> temperature(t):

- Prediction task:
  - $f_{global}: co2_{global-mean}(t) \in \mathbb{R} \rightarrow tas_{global-mean}(t) \in \mathbb{R}$
  - $f_{local}: tas_{global-mean}(t) \in \mathbb{R} \rightarrow tas_{0:lat, 0:lon}(t) \in \mathbb{R}^{(lat, lon)}$
- Models:
  - Fully-connected Neural Net (FCNN):
    - $f_{global}$: [batch_size,1,1,1] -> (dense layers) -> [batch_size,1,1,1]
    - $f_{local}$: [batch_size,1,1,1] -> (dense layers) -> [batch_size,1,1,lat*lon] -> Reshape(batch_size,1,lat,lon)
  - UNet:
    - [batch_size, 1, 1, 96 * 144] -> Reshape(batch_size, 1, 96, 144) -> UNet -> (batch_size, 1, 96, 144) -> tanh output activation (TBD whether to apply it)


## Autoregressive Markovian models: GHGs(t), temperature(t-1) -> temperature(t):

- Prediction task:

$[\text{co2}_\text{global-mean}(t), \text{tas}_{0:Lat, 0:Lon}(t)] \in \mathbb{R}\times \mathbb{R}^{(Lat,Lon)} \rightarrow \text{tas}_{0:Lat, 0:Lon}(t+1) \in \mathbb{R}^{(Lat, Lon)}$

- Limitations
  - Captures atmospheric heating due to CO2. Assumes that the current state (tas and CO2) captures the full carbon cycle (a Markov assumption); longer-term effects with memory, such as ocean heating and ice loss, are therefore not captured.
  - Spikes in CO2 would not be accurately modeled.

In [None]:
#@title code: Create train and test dataset

from emcli.dataset.autoregressiveDataset import AutoregressiveDataset

# How many time steps the target lies ahead of the input, e.g. [t=0] -> [t=3]
# for len_snippet = 3. This is also the number of autoregressive forecasting
# steps the model takes before the loss is applied.
len_snippet = 1

ar_dataset_train = AutoregressiveDataset(
    X_train, Y_train, len_snippet=len_snippet, split='train')
ar_dataset_test = AutoregressiveDataset(
    X_test, Y_test, len_snippet=len_snippet, split='test')

# Pull the first training (input, output) pair to sanity-check shapes/dtypes.
input_sample, output_sample = ar_dataset_train[0]
print('Number of training data samples: ', len(ar_dataset_train))
print('Number of test data samples: ', len(ar_dataset_test))
print('Input sample: ', input_sample.shape, input_sample.dtype)
print('Output sample: ', output_sample.shape, output_sample.dtype)

Number of training data samples:  1078
Number of test data samples:  250
Input sample:  torch.Size([1, 2, 96, 144]) torch.float32
Output sample:  torch.Size([1, 96, 144]) torch.float32


In [None]:
# Create data loaders
# We use the test as val set due to our limited data size. This seems acceptable
# as we will perform limited amount of hyperparameter tuning.
from torch.utils.data import DataLoader

batch_size = 8
loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
ar_train_loader = DataLoader(ar_dataset_train, shuffle=True, **loader_args)
ar_val_loader = DataLoader(ar_dataset_test, shuffle=False, drop_last=True, **loader_args)

In [None]:
#@title code: PushforwardUNet
from emcli.models.unet.unet_model import PushforwardUNet

In [None]:
#@title code: evaluate.py -> evaluate()

import torch
import torch.nn.functional as F
from tqdm.notebook import tqdm

@torch.inference_mode()
def evaluate(model, dataloader, criterion, device, amp):
  """Compute the sample-weighted mean validation loss of `model`.

  Runs `model` over every batch of `dataloader` under inference mode (and
  under autocast when `amp` is True) and averages `criterion(pred, targets)`.

  Args:
    model: Module to evaluate. It is switched to eval mode for the pass, and
      its previous train/eval mode is restored before returning.
    dataloader: Yields (inputs, targets) batches.
    criterion: Loss function. Assumed to return a per-batch mean (e.g.
      nn.MSELoss with its default reduction='mean') -- TODO confirm for
      other criteria.
    device: torch.device the batches are moved to.
    amp: Whether to run the forward pass under torch.autocast.

  Returns:
    The loss averaged over all samples (0-dim tensor), or 0.0 when the
    dataloader yields no batches.
  """
  was_training = model.training
  model.eval()
  n_val = len(dataloader.dataset)
  total_loss = 0.
  n_seen = 0

  # autocast is given the 'cpu' device type when running on MPS.
  with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
    with tqdm(total=n_val, desc='validation.', unit='img', leave=False) as pbar2:
      for inputs, targets in dataloader:
        batch_size = inputs.shape[0]
        inputs = inputs.to(device=device)
        targets = targets.to(device=device)

        pred = model(inputs)

        # Weight each batch mean by its size, so a smaller final batch is not
        # over-weighted compared with full batches.
        total_loss += criterion(pred, targets) * batch_size
        n_seen += batch_size

        pbar2.update(batch_size)
        pbar2.set_postfix(**{'val MSE/img': float(total_loss) / n_seen})

  # Restore the caller's mode instead of forcing train mode.
  model.train(was_training)
  return total_loss / max(n_seen, 1)

# val_score = evaluate(model, ar_val_loader, nn.MSELoss(), device, cfg["amp"])

In [None]:
#@title code: train.py -> train_model()

from torch import optim
from tqdm.notebook import tqdm
from pathlib import Path

def _param_histograms(model):
  """Build wandb histograms of `model`'s weights and gradients.

  Tensors that contain inf/NaN are skipped. Parameters whose `.grad` is None
  are skipped instead of raising. A None gradient happens, for example, when
  a parameter was not reached by the last backward pass, because gradients
  are cleared with `zero_grad(set_to_none=True)`.
  """
  histograms = {}
  for tag, value in model.named_parameters():
    tag = tag.replace('/', '.')
    if not (torch.isinf(value) | torch.isnan(value)).any():
      histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
    grad = value.grad
    if grad is not None and not (torch.isinf(grad) | torch.isnan(grad)).any():
      histograms['Gradients/' + tag] = wandb.Histogram(grad.data.cpu())
  return histograms


def train_model(
        model,
        train_loader,
        val_loader,
        device,
        epochs: int = 5,
        batch_size: int = 1,
        learning_rate: float = 1e-5,
        val_percent: float = 0.1,
        save_checkpoint: bool = True,
        img_scale: float = 0.5,
        amp: bool = False,
        weight_decay: float = 1e-8,
        momentum: float = 0.999,
        gradient_clipping: float = 1.0,
        no_wandb: bool = False,
        parallel: bool = False,
        dir_checkpoint: str = '',
        cfg: dict = None,
  ):
  """
  Train `model` with RMSprop on an MSE loss, validating once per epoch.

  Adapted from: https://github.com/milesial/Pytorch-UNet/blob/2f62e6b1c8e98022a6418d31a76f6abd800e5ae7/train.py#L81

  Args:
    model: Network to train. Must expose `n_channels`, which is checked
      against dim 2 of every input batch.
    train_loader: DataLoader yielding (inputs, targets) training batches.
    val_loader: DataLoader passed to `evaluate()` at the end of every epoch.
    device: torch.device the batches are moved to.
    epochs: Number of passes over `train_loader`.
    batch_size, val_percent, img_scale: Only logged (and sent to the wandb
      config). The loaders determine the real batch size.
    learning_rate, weight_decay, momentum: RMSprop hyperparameters.
    save_checkpoint: Save `model.state_dict()` to `dir_checkpoint` after
      every epoch.
    amp: Enable autocast mixed precision and gradient scaling.
    gradient_clipping: Max global gradient norm, applied to unscaled grads.
    no_wandb: Disable Weights & Biases logging.
    parallel, cfg: Unused here.
    dir_checkpoint: Directory for `checkpoint_epoch{N}.pth` files; created
      if missing.
  """
  n_train = len(train_loader.dataset)
  n_val = len(val_loader.dataset)
  # (Initialize logging.) `experiment` stays None when wandb is disabled, so
  # every later wandb call is guarded on it.
  experiment = None
  if not no_wandb:
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    experiment.config.update(dict(epochs=epochs, batch_size=batch_size,
        learning_rate=learning_rate, val_percent=val_percent,
        save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp))

  logging.info(f'''Starting training:
      Epochs:          {epochs}
      Batch size:      {batch_size}
      Learning rate:   {learning_rate}
      Training size:   {n_train}
      Validation size: {n_val}
      Checkpoints:     {save_checkpoint}
      Device:          {device.type}
      Images scaling:  {img_scale}
      Mixed Precision: {amp}
  ''')

  # Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
  optimizer = optim.RMSprop(model.parameters(),
                lr=learning_rate, weight_decay=weight_decay,
                momentum=momentum, foreach=True)
  scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, 'min', patience=5)  # goal: minimize MSE
  grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
  # nn.MSELoss defaults to reduction='mean', i.e. it averages over every
  # element of the batch.
  criterion = nn.MSELoss()
  # autocast is given the 'cpu' device type when running on MPS.
  autocast_device = device.type if device.type != 'mps' else 'cpu'
  # Validate on the last batch of every epoch. Counting batches (rather than
  # n_train // batch_size) keeps validation on epoch boundaries when n_train
  # is not a multiple of the batch size. It also relies on the loader's real
  # batch size, not the logged `batch_size` argument.
  steps_per_epoch = len(train_loader)
  global_step = 0

  # Begin training
  for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0
    with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
      for i, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device=device)
        targets = targets.to(device=device)

        assert inputs.shape[2] == model.n_channels, \
          f'Network has been defined with {model.n_channels} input channels, ' \
          f'but loaded images have {inputs.shape[2]} channels. Please check that ' \
          'the images are loaded correctly.'

        with torch.autocast(autocast_device, enabled=amp):
          pred = model(inputs)
          loss = criterion(pred, targets)

        # TODO: double check if gradients only go back to last network
        optimizer.zero_grad(set_to_none=True)
        grad_scaler.scale(loss).backward()
        # Unscale first, so clipping acts on the true gradient norm rather
        # than the AMP-scaled one. unscale_ is a no-op when the scaler is
        # disabled.
        grad_scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        grad_scaler.step(optimizer)
        grad_scaler.update()

        pbar.update(inputs.shape[0])
        global_step += 1
        epoch_loss += loss.item()
        if experiment is not None:
          experiment.log({
              'train loss': loss.item(),
              'step': global_step,
              'epoch': epoch
          })
        pbar.set_postfix(**{'avg MSE/img': epoch_loss / float(i+1)})

        # Evaluation round
        if steps_per_epoch > 0 and global_step % steps_per_epoch == 0:
          val_score = float(evaluate(model, val_loader, criterion, device, amp))
          scheduler.step(val_score)

          logging.info('Validation MSE: {}'.format(val_score))
          if experiment is not None:
            try:
              experiment.log({
                'learning rate': optimizer.param_groups[0]['lr'],
                'validation MSE': val_score,
                'inputs': wandb.Image(inputs[0].cpu()),
                'predictions': {
                  'true': wandb.Image(targets[0].float().cpu()),
                  # Log the regression output itself. Assuming pred matches
                  # targets' (B, 1, 96, 144) shape, the segmentation-style
                  # argmax(dim=1) over the size-1 axis is always 0.
                  'pred': wandb.Image(pred[0].detach().float().cpu()),
                },
                'step': global_step,
                'epoch': epoch,
                **_param_histograms(model),
              })
            except Exception as err:
              # Media/histogram logging is best-effort: report it, don't abort
              # training.
              logging.warning('wandb validation logging failed: %s', err)

    if save_checkpoint:
      Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
      state_dict = model.state_dict()
      torch.save(state_dict, str(Path(dir_checkpoint) / 'checkpoint_epoch{}.pth'.format(epoch)))
      logging.info(f'Checkpoint {epoch} saved!')

In [None]:
# Launch training. The batch size passed to train_model is read from the
# training loader itself, so the logged value (and any step arithmetic derived
# from it) matches the batches actually drawn. The loaders are built with their
# own batch size, which can differ from cfg["batch_size"].
train_model(
  model=model,
  train_loader=ar_train_loader,
  val_loader=ar_val_loader,
  epochs=cfg["epochs"],
  batch_size=ar_train_loader.batch_size,
  learning_rate=cfg["learning_rate"],
  device=device,
  img_scale=cfg["scale"],
  val_percent=cfg["validation"] / 100.,
  amp=cfg["amp"],
  no_wandb=cfg["no_wandb"],
  dir_checkpoint=cfg["dir_checkpoint"],
  cfg=cfg,
)

NameError: name 'model' is not defined

cuda
