In [None]:
%load_ext autoreload
%autoreload 2

# Smooth EMOS

Experiment with a module that smooths the linear model weights after the training epoch.
The hope is that this will improve the validation score, because we postulate that, for a given station, the biases vary slowly over time.


In [None]:
import collections
import hydra
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pathlib
import torch
import torch.nn as nn

from pandas.api.types import CategoricalDtype


from smc01.postprocessing.train import make_dataloader
from smc01.postprocessing.util import load_checkpoint_from_run

In [None]:
# Base data directory, taken from the DATA_DIR environment variable.
# Indexing os.environ fails fast with a KeyError naming the variable when it is
# unset, instead of the opaque TypeError raised by pathlib.Path(None).
DATA_DIR = pathlib.Path(os.environ['DATA_DIR'])

In [None]:
# Compose the 'train' configuration from the project's packaged config module.
with hydra.initialize_config_module(config_module='smc01.postprocessing.conf'):
    cfg = hydra.compose(config_name='train')

In [None]:
# Instantiate the object described by the `experiment.dataset` node of the composed config.
dataset = hydra.utils.instantiate(cfg.experiment.dataset)
# Build a data loader for it with the project's helper.
# NOTE(review): `loader` is not referenced again in the visible part of the notebook.
loader = make_dataloader(cfg, dataset)

In [None]:
# Load the model from the checkpoint of the run stored under outputs/2022-01-24/16-26-56.
# NOTE(review): the run directory is hard-coded; presumably this is a trained EMOS model — confirm.
model = load_checkpoint_from_run(DATA_DIR / 'runs/postprocessing/outputs/2022-01-24/16-26-56')

In [None]:
# Display the model's repr to inspect what was loaded.
model

In [None]:
# Inspect the shape of the model's weights tensor.
model.weights.shape

## First try with dataloader

In [None]:
# Average the weights over their leading axis three times in a row, i.e. collapse
# the first three axes and keep only the trailing one(s).
model.weights.mean(dim=0).mean(dim=0).mean(dim=0)

In [None]:
# Largest value found anywhere in the biases tensor.
model.biases.max()

In [None]:
# Smallest value found anywhere in the biases tensor.
model.biases.min()

In [None]:
# Largest value found anywhere in the weights tensor.
model.weights.max()

In [None]:
# Smallest value found anywhere in the weights tensor.
model.weights.min()

In [None]:
# Check the weights tensor shape before indexing into it.
model.weights.shape

In [None]:
# Compare the even- and odd-indexed entries along dim 1 of model.weights[1000, :, 9, 0].
# Labelled curves on an explicit Axes make the two series distinguishable.
fig, ax = plt.subplots()
ax.plot(model.weights[1000, 0::2, 9, 0].detach().cpu().numpy(), label='dim 1 indices 0::2')
ax.plot(model.weights[1000, 1::2, 9, 0].detach().cpu().numpy(), label='dim 1 indices 1::2')
ax.set_title('model.weights[1000, :, 9, 0]')
ax.set_xlabel('position within the strided slice of dim 1')
ax.legend();

In [None]:
# Compare the even- and odd-indexed entries along dim 1 of model.weights[1000, :, 12, 0].
# Labelled curves on an explicit Axes make the two series distinguishable.
fig, ax = plt.subplots()
ax.plot(model.weights[1000, 0::2, 12, 0].detach().cpu().numpy(), label='dim 1 indices 0::2')
ax.plot(model.weights[1000, 1::2, 12, 0].detach().cpu().numpy(), label='dim 1 indices 1::2')
ax.set_title('model.weights[1000, :, 12, 0]')
ax.set_xlabel('position within the strided slice of dim 1')
ax.legend();

In [None]:
# NOTE(review): `transformed_dataset` is not defined in any cell visible here, so this
# raises NameError on a fresh kernel unless it is defined elsewhere — confirm the
# intended object (possibly `dataset`) or drop the cell.
transformed_dataset.transform.station_dtype

In [None]:
# Shape of the weights slice at index 437 of the first axis.
model.weights[437].shape

In [None]:
# Inspect the shape of the model's biases tensor.
model.biases.shape

In [None]:
# Take biases[1], average it over its first axis, then plot every 8th entry
# starting at index 1 and stopping before the last element.
fig, ax = plt.subplots()
ax.plot(model.biases[1].mean(dim=0)[1:-1:8].detach().cpu().numpy())
ax.set_title('model.biases[1].mean(dim=0)[1:-1:8]')
ax.set_xlabel('position within the strided slice');

In [None]:
import torch.nn.functional as F

In [None]:
# Check the weights tensor shape again before working on the smoothing.
model.weights.shape

In [None]:
def smooth_across_one_dim(tensor, dim, filter_size=3, dilation=1):
    """Smooth a tensor along one dimension with a circular moving average.

    Every 1-D slice taken along `dim` is replaced by the unweighted mean of
    `filter_size` of its values spaced `dilation` positions apart. The slice
    wraps around at its ends (circular padding), so the output has the same
    shape as the input. Useful to smooth EMOS weights across time, for instance.

    Args
        tensor: The floating point tensor to smooth. It must have at least one
            dimension.
        dim: The number of the dimension across which to perform the smoothing.
        filter_size: The number of values averaged by the smoothing filter.
        dilation: The spacing between the values averaged by the filter.

    Returns
        A tensor with the shape and dtype of `tensor` that had the smoothing
        filter applied. It is computed under `torch.no_grad()`, so it is not
        attached to the autograd graph.

    Raises
        ValueError: If `filter_size` or `dilation` is smaller than 1."""
    if filter_size < 1 or dilation < 1:
        raise ValueError(
            'filter_size and dilation must be at least 1, got '
            f'filter_size={filter_size} and dilation={dilation}.'
        )

    # Move target dimension last.
    moved = tensor.transpose(dim, -1)
    moved_shape = moved.shape
    length = moved_shape[-1]

    # conv1d expects (batch, channels, length). The first dimension is kept as
    # the batch and all intermediary dims are folded into the channels. 1-D and
    # 2-D inputs get singleton batch/channel dims so they are supported too.
    batch = moved_shape[0] if moved.dim() > 1 else 1
    flat = moved.reshape(batch, -1, length)

    # Circular padding. The filter spans dilation * (filter_size - 1) positions;
    # padding by exactly that amount keeps the output length equal to the input
    # length. Even filter sizes put the extra position on the left.
    left_padding = dilation * (filter_size // 2)
    right_padding = dilation * (filter_size - 1) - left_padding

    # Create filter.
    # We want the filter to be the average of all the filtered values. It must
    # share the dtype of the input, otherwise conv1d rejects non-float32 tensors.
    n_channels = flat.shape[1]
    fltr = torch.ones(
        n_channels, 1, filter_size, dtype=flat.dtype, device=flat.device
    ) / filter_size

    # Perform convolution.
    # We use groups=n_channels so that one channel at a time is filtered.
    with torch.no_grad():
        padded = F.pad(flat, [left_padding, right_padding], mode='circular')
        filtered = F.conv1d(padded, fltr, groups=n_channels, dilation=dilation)

    # Restore the intermediary dims, then bring the filtered dim back where it was.
    return filtered.reshape(moved_shape).transpose(dim, -1)

In [None]:
# Shape of the weights tensor that is about to be smoothed.
model.weights.shape

In [None]:
# Smooth the weights along dim 1 with a 3-tap average whose taps are 2 positions
# apart, so even- and odd-indexed entries of dim 1 are averaged separately.
smoothed = smooth_across_one_dim(
    model.weights,
    dim=1,
    filter_size=3,
    dilation=2,
)

In [None]:
# Inspect the shape of the smoothed tensor.
smoothed.shape

In [None]:
# Overlay the raw and the smoothed weights along dim 1 at index [437, :, 0, 0].
fig, ax = plt.subplots()
ax.plot(model.weights[437, :, 0, 0].detach().cpu().numpy(), label='model.weights')
ax.plot(smoothed[437, :, 0, 0].detach().cpu().numpy(), label='smoothed')
ax.set_title('[437, :, 0, 0]')
ax.set_xlabel('index along dim 1')
ax.legend();

In [None]:
# Overlay the raw and the smoothed weights for the even indices of dim 1 at [437, 0::2, 8, 0].
fig, ax = plt.subplots()
ax.plot(model.weights[437, 0::2, 8, 0].detach().cpu().numpy(), label='model.weights')
ax.plot(smoothed[437, 0::2, 8, 0].detach().cpu().numpy(), label='smoothed')
ax.set_title('[437, 0::2, 8, 0]')
ax.set_xlabel('position within the 0::2 slice of dim 1')
ax.legend();

In [None]:
# Overlay the raw and the smoothed weights for the odd indices of dim 1 at [437, 1::2, 8, 0].
fig, ax = plt.subplots()
ax.plot(model.weights[437, 1::2, 8, 0].detach().cpu().numpy(), label='model.weights')
ax.plot(smoothed[437, 1::2, 8, 0].detach().cpu().numpy(), label='smoothed')
ax.set_title('[437, 1::2, 8, 0]')
ax.set_xlabel('position within the 1::2 slice of dim 1')
ax.legend();

In [None]:
# Compare the even- and odd-indexed entries along dim 1 of the smoothed weights at [437, :, 8, 0].
fig, ax = plt.subplots()
ax.plot(smoothed[437, 0::2, 8, 0].detach().cpu().numpy(), label='smoothed[437, 0::2, 8, 0]')
ax.plot(smoothed[437, 1::2, 8, 0].detach().cpu().numpy(), label='smoothed[437, 1::2, 8, 0]')
ax.set_title('smoothed[437, :, 8, 0]')
ax.set_xlabel('position within the strided slice of dim 1')
ax.legend();

In [None]:
# Overlay the raw and the smoothed weights along dim 2 at index [100, 30, :, 0].
fig, ax = plt.subplots()
ax.plot(model.weights[100, 30, :, 0].detach().cpu().numpy(), label='model.weights')
ax.plot(smoothed[100, 30, :, 0].detach().cpu().numpy(), label='smoothed')
ax.set_title('[100, 30, :, 0]')
ax.set_xlabel('index along dim 2')
ax.legend();

In [None]:
# Overlay the raw and the smoothed weights along dim 1 at index [100, :, 0, 1].
fig, ax = plt.subplots()
ax.plot(model.weights[100, :, 0, 1].detach().cpu().numpy(), label='model.weights')
ax.plot(smoothed[100, :, 0, 1].detach().cpu().numpy(), label='smoothed')
ax.set_title('[100, :, 0, 1]')
ax.set_xlabel('index along dim 1')
ax.legend();