In [9]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch_oopt.optim import LBFGS
import os
import argparse
import time
import datetime
import numpy as np
import obs_configs
import csv

from torchdiffeq import odeint
from torch.utils.data import DataLoader
from amortized_assimilation.data_utils import ChunkedTimeseries, L96, TimeStack, gen_data
from amortized_assimilation.models import MultiObs_ConvEnAF, MultiObs_UEnAF
from amortized_assimilation.operators import filter_obs, mystery_operator

In [10]:
def test(epoch, start_time, base_data, noise, m, model, obs_dict, indices, device, missing = False,
         n_segments = 10):
    """Test loop: assimilate a noisy copy of ``base_data`` and print RMSE / ensemble spread.

    Args:
        epoch: iteration number, only used in the printed report.
        start_time: ``time.time()`` reference for the elapsed-time column.
        base_data: ground-truth trajectory; ``base_data[0]`` seeds the initial ensemble.
        noise: std of the Gaussian noise added to the initial ensemble and to the observations.
        m: ensemble size.
        model, obs_dict, indices, device, missing: forwarded to ``assimilate_unseen_obs_ens``.
        n_segments: number of equal-length time segments for the per-segment report
            (clamped to ``[1, number of steps]``; any remainder steps are not reported).

    Returns:
        Scalar tensor: mean over time steps of the per-step RMSE between predictions and truth.
    """
    # NOTE(review): torch.no_grad() does not switch modules to eval mode, so any dropout in
    # `model` stays active here unless the caller ran model.eval() - confirm this is intended.
    with torch.no_grad():
        # Initial ensemble: the first true state repeated m times, plus noise.
        state = base_data[0].unsqueeze(1).repeat(1, m, 1)
        state += torch.randn_like(state) * noise
        state = state.to(device = device)
        noisy_test = (base_data + torch.randn_like(base_data) * noise).to(device = device)
        pred_y_test, _, _, ens = assimilate_unseen_obs_ens(model, noisy_test, state, m,
                                                           obs_dict, indices, device, missing)
        # Transfer/reshape once instead of once per segment.
        pred_cpu = pred_y_test.cpu()
        truth = base_data.squeeze()
        loss = torch.mean(torch.mean((pred_cpu - truth)**2, dim = 1)**.5)

        n = ens.shape[0]
        # Avoid empty (NaN) segments when there are fewer steps than requested segments.
        n_segments = max(1, min(n_segments, n))
        seg_len = n // n_segments
        ens_std = []
        ens_loss = []
        for i in range(n_segments):
            seg = slice(i * seg_len, (i + 1) * seg_len)
            # Spread along dim 1 of the stacked ensembles (presumably the member axis - TODO confirm).
            ens_std.append(torch.std(ens[seg], 1).mean().item())
            seg_rmse = torch.mean(torch.mean((pred_cpu[seg] - truth[seg])**2, dim = 1)**.5)
            ens_loss.append(format(seg_rmse.item(), '.4f'))
        ens_std_s = [format(s, '.4f') for s in ens_std]

        print('Iter {:04d} | Test Loss {:.6f} | test_std {:.4f} | Time {:.1f}'.format(epoch, loss.item(),
                                                                                      sum(ens_std)/len(ens_std),
                                                                                      time.time() - start_time))
        print('Segmentwise loss/std', list(zip(ens_loss, ens_std_s)))
        return loss

def assimilate_unseen_obs_ens(model, data, state, m, obs_dict, indices, device, missing):
    """Execute online (step-by-step) assimilation of ``data`` with ``model``.

    At step ``i`` the observation operator keyed ``str(i % len(obs_dict))`` is used, so
    observation types are cycled in order. Every call uses obs type ``'0'``.

    Args:
        model: callable ``model(obs, state, memory, mask, obs_type)`` returning
            ``(pred, state, ens, memory)``.
        data: per-step observations, iterated along dim 0.
        state: initial ensemble for the first model call; its last dim is taken as the
            state dimension (assumed to also be the memory width - 40 for Lorenz-96 here).
        m: ensemble size.
        obs_dict: maps ``'0'``, ``'1'``, ... to observation operators applied to each step.
        indices: maps the same keys to the state indices that operator observes.
        device: device for the memory and mask tensors.
        missing: if True, unobserved dims are filled from a random permutation of the
            current ensemble members, and a mask (+0.1 observed / -0.1 unobserved) is passed;
            otherwise the raw observation is repeated for each member and the mask is None.

    Returns:
        ``(preds, states, memory, ensembles)``: per-step predictions, states and ensembles
        stacked along dim 0 and squeezed, plus the final memory tensor.
    """
    n_dim = state.shape[-1]
    memory = torch.zeros(m, 6, n_dim, device = device)
    obs_type = '0'  # one network is shared across all observation types
    preds, states, ensembles = [], [], []

    for i, obsi in enumerate(data):
        if missing:
            key = str(i % len(obs_dict))
            idx = indices[key]
            # Advanced indexing copies, so writing into `obs` leaves `state` untouched.
            obs = state.detach()[:, torch.randperm(m), :]
            obs[:, :, idx] = obs_dict[key](obsi).unsqueeze(1).repeat(1, m, 1)
            mask = torch.full((obs.shape[0], m, n_dim), -.1, device = device)
            mask[:, :, idx] = .1
        else:
            obs = obsi.unsqueeze(1).repeat(1, m, 1)
            mask = None
        pred, state, ens, memory = model(obs, state, memory, mask, obs_type)
        states.append(state)
        preds.append(pred)
        ensembles.append(ens)

    return (torch.stack(preds, dim = 0).squeeze(), torch.stack(states, dim = 0).squeeze(),
            memory, torch.stack(ensembles, dim = 0).squeeze())

In [11]:

class dummy():
    """Minimal stand-in for ``argparse.ArgumentParser`` for use inside a notebook.

    ``add_argument`` parses nothing. It immediately stores the value the option would
    take if absent from the command line, under the attribute name argparse would use
    (leading dashes stripped, ``-`` converted to ``_``). Instances can also serve as a
    plain attribute namespace.
    """
    def add_argument(self, name, type = None, action = None, default = None):
        """Record the default for option ``name`` as an attribute on ``self``.

        Mirrors argparse defaults: an explicit ``default`` wins. Otherwise
        ``store_true`` gives False, ``store_false`` gives True, and anything else gives None.
        ``type`` is accepted only for signature compatibility and is ignored.
        """
        dest = name.lstrip('-').replace('-', '_')
        if default is None and action == 'store_true':
            default = False
        elif default is None and action == 'store_false':
            default = True
        setattr(self, dest, default)

# Notebook replacement for argparse: each (flag, kwargs) pair is registered on `args`
# in order, which simply records its default value as an attribute.
_arg_specs = [
    ('--dynamics', dict(type=str, default='lorenz96')),
    ('--train_steps', dict(type=int, default=240_000)),
    ('--step_size', dict(type=float, default=.1)),
    ('--batch_steps', dict(type=int, default=40)),
    ('--batch_size', dict(type=int, default=64)),
    ('--m', dict(type=int, default=5)),
    ('--n', dict(type=int, default=40)),
    ('--hidden_size', dict(type=int, default=64)),
    ('--noise', dict(type=float, default=1.0)),
    ('--steps_valid', dict(type=int, default=1000)),
    ('--steps_test', dict(type=int, default=10000)),
    ('--check_disk', dict(action='store_false')),
    ('--obs_conf', dict(type=str, default='every_4th_dim_partial_obs')),
    # ('--obs_conf', dict(type=str, default='full_obs')),
    ('--do', dict(type=float, default=.2)),
    ('--device', dict(type=str, default='gpu')),
]
args = dummy()
for _flag, _kwargs in _arg_specs:
    args.add_argument(_flag, **_kwargs)
del _arg_specs, _flag, _kwargs

# CUDA is used only when requested and actually present; otherwise fall back to CPU.
device = torch.device('cuda' if (args.device == 'gpu' and torch.cuda.is_available()) else 'cpu')

# The partial-observation configuration requires masking of unobserved dimensions.
missing = (args.obs_conf == 'every_4th_dim_partial_obs')

In [12]:
# Time grid for the training trajectory; gen_data also returns validation/test trajectories.
t = torch.arange(0, args.train_steps*args.step_size, args.step_size)
true_y, true_y_valid, true_y_test = gen_data(args.dynamics, t, args.steps_test,
                                             args.steps_valid, check_disk=args.check_disk)
print(true_y.max(dim = 0))

# Observation operators/indices come from the selected config. Input types (and known_h)
# come from 'full_obs', because a single network is used for all observation types.
_, obs_dict, indices, _ = obs_configs.lorenz_configs[args.obs_conf]
input_types, _, _, known_h = obs_configs.lorenz_configs['full_obs']

ntypes = len(obs_dict)
# Set up model
model = MultiObs_ConvEnAF(args.n, args.hidden_size, input_types=input_types,
                         m = args.m, missing = missing, do = args.do)
# Trainable parameter count
print('Param Count', sum(p.numel() for p in model.parameters() if p.requires_grad))

### LOAD MODEL
CHECKPOINT_PATH = 'models/2021-01-30_09-09lorenz96_partial_1.0std_64layers/final_convref_lorenz96_partial_0.6111_1.0std_500iters_64filt'
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only machine
# (device selection falls back to CPU); the model is moved to `device` right after.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))
### LOAD MODEL

model = model.to(device = device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay = 0)
# No-op scheduler placeholder exposing a .step() method.
dummy_sched = dummy()
dummy_sched.step = lambda: None

data = ChunkedTimeseries(true_y, args.batch_steps, .5)
loader = DataLoader(data, batch_size=args.batch_size,
                    shuffle=True, num_workers=0, collate_fn=TimeStack())
otype = 'full' if args.obs_conf == 'full_obs' else 'partial'

start_time = time.time()
# Training
itr = 0
train_losses = []
test_losses = []
print('---Test set Results---')
# NOTE(review): model.eval() is not called before testing, so dropout (do=args.do) stays
# active during test(); torch.no_grad() does not change that - confirm this is intended.
test_loss = test(itr, start_time, true_y_test, args.noise, args.m, model, obs_dict, indices, device, missing)

torch.return_types.max(
values=tensor([[14.2660, 14.6090, 15.5068, 14.9297, 14.8302, 15.3349, 15.2931, 15.2986,
         14.5851, 15.4455, 14.6423, 14.7645, 15.4237, 14.9571, 14.0401, 14.3829,
         14.9659, 14.5102, 15.3957, 14.9481, 15.2749, 14.7238, 14.3481, 14.2430,
         14.8245, 15.2190, 14.8637, 15.5550, 14.4125, 14.7028, 14.8244, 14.4934,
         15.1251, 15.7087, 14.9530, 14.1860, 14.7433, 16.4774, 14.3877, 14.8334]]),
indices=tensor([[ 81606,  39221, 226929, 197613, 106480, 232094, 132012, 111264, 215132,
         136947,  53852, 214263, 127270, 208686,  55333, 182263, 223516, 100037,
         142609, 149733, 142728,  45169, 123948,  92716, 207700, 169419,  93137,
           7943,  28354, 219795, 214747,  20403,  62134, 186728, 190992,  94517,
          31404,  33347, 146376,  88804]]))
Param Count 69710
---Test set Results---
Iter 0000 | Test Loss 0.630299 | test_std 0.2875 | Time 57.5
Segmentwise loss/std [('0.6069', '0.2862'), ('0.6297', '0.2866'), ('0.6413', '0.288