In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from tqdm import tqdm
from IPython.display import HTML

from sampler import Sampler
from dataset import SQGDataset

In [None]:
# --- Sampler configuration -------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = "best_model.pth"

# Shape of one state; the first axis indexes the data levels (see `level` below).
image_shape = (2, 64, 64)

members = 5  # ensemble size


def eps(t):
    """Noise level used when sampling: 0.1 at t=0, decreasing linearly to 0 at t=1."""
    return 0.1 * (1 - t)


def invert_eps(t):
    """Noise level used when inverting: zero for every t, so no noise is injected."""
    return 0. * (1 - t)


steps = 100         # step count handed to the Sampler for sampling
invert_steps = 100  # step count handed to the Sampler for inversion
debug = True

sampler = Sampler(device, members, eps, steps, invert_eps, invert_steps, model_path, debug)

### Example: inverting a physical state

In [None]:
# Physical SQG states, served one at a time in dataset order. `mean`/`std` are
# presumably the normalisation applied by SQGDataset — confirm in dataset.py.
data_std = 2660
bs = 1
dataset = SQGDataset("data/SQG", mean=0, std=data_std)
loader = torch.utils.data.DataLoader(dataset, shuffle=False, batch_size=bs)

In [None]:
# Invert one physical state: take the first state from `loader`, replicate it
# once per ensemble member, and let the sampler map the batch back to noise.
truth = next(iter(loader))
truth = truth.to(device)
z1 = truth.repeat(members, 1, 1, 1)

z0 = sampler.invert(z1)

### Example: sampling a physical state

In [None]:
# Load the pre-inverted SQG series (states already mapped to noise space), which can
# serve as sampler input noise. It is left unscaled (std=1) — presumably because the
# inverted states are already unit-scale; confirm against how data/inverted_SQG was made.
data_std_inverted = 1
bs = 1
dataset_inverted = SQGDataset("data/inverted_SQG", mean=0, std=data_std_inverted)
loader_inverted = torch.utils.data.DataLoader(
    dataset_inverted,
    batch_size=bs,
    shuffle=True,  # a different random state each time the loader is iterated
)

In [None]:
# Choose where the starting noise for the sampler comes from:
#   "inverted"   - z0 produced by the inversion cell above (round-trip check invert -> sample)
#   "timeseries" - one pre-inverted state from `loader_inverted`, repeated for every member
#   "random"     - independent standard-normal noise for every member
#   "shared"     - a single standard-normal draw repeated for every member (isolates eps)
noise_source = "inverted"

if noise_source == "inverted":
    z0_init = z0
elif noise_source == "timeseries":
    # The loader yields `bs` states; keep one and repeat it so the ensemble still has
    # `members` entries (the plotting cell indexes individual members).
    z0_init = next(iter(loader_inverted))[:1].to(device).repeat(members, 1, 1, 1)
elif noise_source == "random":
    z0_init = torch.randn((members, *image_shape), device=device)
elif noise_source == "shared":
    z0_init = torch.randn((1, *image_shape), device=device).repeat(members, 1, 1, 1)
else:
    raise ValueError(f"unknown noise_source: {noise_source!r}")

# A separate name keeps `z0` from the inversion cell intact, so re-running this cell
# with a different option does not change what "inverted" refers to.
z1, _ = sampler.sample(z0_init)

### Plotting

In [None]:
# CPU copies with no autograd history for plotting: matplotlib converts its input via
# NumPy, which fails both on CUDA tensors and on tensors that require grad.
plot_truth = truth.detach().cpu()
plot_sample = z1.detach().cpu()  # presumably one row per ensemble member
plot_ens_mean = plot_sample.mean(dim=0)
level = 0  # which of the two data levels to display

In [None]:
fig, axs = plt.subplots(2, 3, figsize=(8, 6))
cmap = plt.get_cmap('viridis', 10)  # 10 discrete colour levels ('jet' is an alternative)

# Use the CPU copies from the previous cell throughout: `truth` may still live on the
# GPU, and subtracting it from the CPU ensemble fails with a device mismatch on CUDA.
# Truth, mean and samples share the truth's colour range so the panels are comparable.
vmin = plot_truth[:, level].min().item()
vmax = plot_truth[:, level].max().item()

# RMSE against the truth over all levels and grid points (dims 1-3):
# one value per ensemble member, and one for the ensemble mean.
rmse = (plot_sample - plot_truth).pow(2).mean(dim=(1, 2, 3)).sqrt()
rmse_mean = (plot_ens_mean - plot_truth).pow(2).mean(dim=(1, 2, 3)).sqrt()[0]


def _to_image(t):
    """Return tensor `t` as a NumPy array matplotlib can draw (detached, on the CPU)."""
    return t.detach().cpu().numpy()


def set_cbar(im):
    """Attach a horizontal colorbar, ticks and label on top, just above the axes of `im`."""
    ax = im.axes
    # Inset spanning the full axes width and 5% of its height; the anchor box is shifted
    # up by 18% of the axes height so the bar sits above the image.
    cax = inset_axes(ax,
                     width="100%",
                     height="5%",
                     loc='upper center',
                     bbox_to_anchor=(0, 0.18, 1, 1),
                     bbox_transform=ax.transAxes,
                     borderpad=0)
    cbar = ax.figure.colorbar(im, cax=cax, orientation='horizontal')
    cbar.ax.xaxis.set_ticks_position('top')
    cbar.ax.xaxis.set_label_position('top')


def highlight_cell(obs_mask, ax=None, **kwargs):
    """Outline every pixel (i, j) where `obs_mask[i, j]` is truthy.

    Draws an unfilled 1x1 rectangle centred on each selected pixel of an `imshow`
    image (pixel centres sit at integer coordinates). Extra `kwargs` are passed to
    `plt.Rectangle` (e.g. edgecolor, linewidth). Draws on `ax`, or the current axes.
    NOTE(review): not called in this notebook section; presumably `obs_mask` is a 2-D
    observation mask on the plotted grid — confirm at call sites.
    """
    ax = ax or plt.gca()
    for i in range(obs_mask.shape[0]):      # rows
        for j in range(obs_mask.shape[1]):  # columns
            if obs_mask[i, j]:
                ax.add_patch(plt.Rectangle((j - 0.5, i - 0.5), 1, 1, fill=False, **kwargs))


truth_ax, mean_ax, std_ax, *sample_axs = axs.flat

im = truth_ax.imshow(_to_image(plot_truth[0, level]), cmap=cmap, vmin=vmin, vmax=vmax)
truth_ax.set_title('Truth', fontsize=12)
set_cbar(im)

mean_ax.imshow(_to_image(plot_ens_mean[level]), cmap=cmap, vmin=vmin, vmax=vmax)
mean_ax.set_title(f'Mean,  RMSE {rmse_mean.item():.2f}', fontsize=12)

# Ensemble spread gets its own colour scale starting at zero.
im = std_ax.imshow(_to_image(plot_sample.std(dim=0)[level]), cmap=cmap, vmin=0)
set_cbar(im)
std_ax.set_title('Std', fontsize=12)

# Remaining panels show individual members; hide any panel beyond the ensemble size.
for k, ax in enumerate(sample_axs):
    if k >= plot_sample.shape[0]:
        ax.set_visible(False)
        continue
    ax.imshow(_to_image(plot_sample[k, level]), cmap=cmap, vmin=vmin, vmax=vmax)
    ax.set_title(f'Sample #{k + 1}, RMSE {rmse[k].item():.2f}', fontsize=12)

for ax in axs.flat:
    ax.set_aspect('equal')
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
plt.tight_layout()


In [None]:
# Animate the first frames of the physical SQG series next to the pre-inverted noise series.
level = 0  # There are two levels in the data, 0 and 1
cmap = plt.get_cmap('viridis', 10)  # You can try other cmaps also

n_frames_max = 1000  # states loaded from each series (a single DataLoader batch)

# Reuse data_std / data_std_inverted so this cell cannot drift from the cells above.
# Distinct names keep it from overwriting `truth`, `loader`, `bs`, ... which the
# inversion and plotting cells rely on.
series_dataset = SQGDataset("data/SQG", mean=0, std=data_std)
series_loader = torch.utils.data.DataLoader(series_dataset, batch_size=n_frames_max, shuffle=False)
noise_series_dataset = SQGDataset("data/inverted_SQG", mean=0, std=data_std_inverted)
noise_series_loader = torch.utils.data.DataLoader(noise_series_dataset, batch_size=n_frames_max, shuffle=False)

truth_series = next(iter(series_loader))
noise_series = next(iter(noise_series_loader))
# The two folders may hold different numbers of states; only animate frames present in both.
n_frames = min(truth_series.shape[0], noise_series.shape[0])

fig, ax = plt.subplots(1, 2, figsize=(6, 3), constrained_layout=True)
for panel in ax:
    panel.axis('off')

# Fix each colour scale to +/- `stds` standard deviations of its whole series so
# frames stay comparable over time.
stds = 3
truth_sd = truth_series.std().item()
noise_sd = noise_series.std().item()

ims = [
    ax[0].imshow(truth_series[0, level], cmap=cmap, animated=True,
                 vmin=-stds * truth_sd, vmax=stds * truth_sd),
    ax[1].imshow(noise_series[0, level], cmap=cmap, animated=True,
                 vmin=-stds * noise_sd, vmax=stds * noise_sd),
]


def update(frame):
    """Show state `frame` of both series; return the changed artists for blitting."""
    ims[0].set_array(truth_series[frame, level])
    ims[1].set_array(noise_series[frame, level])
    return ims


ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=10, blit=True)

#ani.save(f"SQG.mp4", writer="ffmpeg", fps=30, dpi=300)

# Close the figure so only the JS animation is shown, not an extra static first frame.
plt.close(fig)
HTML(ani.to_jshtml())