In [31]:
import time
from typing import Any, Dict, List, Tuple, Union

import argparse
import joblib
import tensorflow as tf
from termcolor import cprint
import numpy as np
import pickle as pkl
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils.baseline_utils as baseline_utils
#from lstm_train_test_det import *
from lstm_train_test import *

In [32]:
from glob import glob

In [6]:
!ls ../../argoverse

agents_train.npy	       agents_val_rotation.npy	test_obs
agents_train_rotation.npy      agents_val_transi.npy	train
agents_train_small.npy	       argoverse-api.git	train_original
agents_train_small_transi.npy  features			val
agents_train_transi.npy        preprocessing.ipynb	val_original
agents_val.npy		       test


In [33]:
 with open('../../argoverse/agents_train_rotation.npy', 'rb') as f:
    wholetraj = np.load(f)

In [9]:
wholetraj.shape

(205942, 50, 2)

In [12]:
conformal_train = wholetraj[:102971]
conformal_err = wholetraj[102971:]
conformal_train.shape

(102971, 50, 2)

In [13]:
with open('../../argoverse/conformal_train.npy', 'wb') as f:
    np.save(f, conformal_train)
    
with open('../../argoverse/conformal_err.npy', 'wb') as f:
    np.save(f, conformal_err)

In [34]:
device = 'cuda:0'

In [102]:
encoder = EncoderRNN(input_size=2)
decoder = DecoderRNN(output_size=2)
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)
encoder.to(device)
decoder.to(device)

encoder_optimizer = torch.optim.Adam(encoder.parameters())
decoder_optimizer = torch.optim.Adam(decoder.parameters())

model_utils = ModelUtils()

model_path = '../checkpoints/argo_lstm_conformal/LSTM_rollout30.pth.tar'
epoch, rollout_len, _ = model_utils.load_checkpoint(
            model_path, encoder, decoder, encoder_optimizer,
            decoder_optimizer)

=> loading checkpoint '../checkpoints/argo_lstm_conformal/LSTM_rollout30.pth.tar'
=> loaded checkpoint ../checkpoints/argo_lstm_conformal/LSTM_rollout30.pth.tar (epoch: 193, loss: 4.861854553222656)


In [78]:

class LSTMDataset(Dataset):
    """PyTorch Dataset for LSTM Baselines."""
    def __init__(self, wholetraj, shuffle=True):
        normalized = wholetraj

        self.input_data = normalized[:, :20, :]
        self.output_data = normalized[:, 20:, :]
        self.data_size = self.input_data.shape[0]

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx: int
                    ) -> Tuple[torch.FloatTensor, Any, Dict[str, np.ndarray]]:
        return (
            torch.FloatTensor(self.input_data[idx]),
            torch.FloatTensor(
                self.output_data[idx])
        )

In [28]:
model_utils = ModelUtils()

val_dataset = LSTMDataset(conformal_err)
val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=128,
        drop_last=False,
        shuffle=False,
        collate_fn=model_utils.my_collate_fn,
    )

In [67]:
from tqdm import tqdm

In [87]:
# Set to eval mode
encoder.eval()
decoder.eval()
criterion = nn.MSELoss()

rollout_len = 30
all_mse = {}
all_de = []
for i, (_input, target, helpers) in tqdm(enumerate(val_loader)):

    _input = _input.to(device)
    target = target.to(device)

    # Encoder
    batch_size = _input.shape[0]
    input_length = _input.shape[1]
    output_length = target.shape[1]
    input_shape = _input.shape[2]

    # Initialize encoder hidden state
    encoder_hidden = model_utils.init_hidden(
        batch_size,
        encoder.module.hidden_size if use_cuda else encoder.hidden_size)


    # Encode observed trajectory
    for ei in range(input_length):
        encoder_input = _input[:, ei, :]
        encoder_hidden = encoder(encoder_input, encoder_hidden)

    # Initialize decoder input with last coordinate in encoder
    decoder_input = encoder_input
    
    decoder_hidden = encoder_hidden
    output_shape = list(target.shape)
    decoder_outputs = torch.zeros(output_shape).to(device)
    
    mses = torch.zeros(rollout_len).to(device)
    de = []

    for di in range(rollout_len):
        decoder_output, decoder_hidden = decoder(decoder_input,
                                                 decoder_hidden)
        decoder_outputs[:, di, :] = decoder_output
        mses = (decoder_output[:, :2]-target[:, di, :2]).pow(2).sum(axis=-1).pow(0.5)
        if di in all_mse:
            all_mse[di] = np.concatenate((all_mse[di], mses.detach().cpu().numpy()))
        else:
            all_mse[di] = mses.detach().cpu().numpy()

        # Use own predictions as inputs at next step
        decoder_input = decoder_output
        de.append(torch.sqrt((decoder_output[:, 0] - target[:, di, 0])**2 +
                           (decoder_output[:, 1] - target[:, di, 1])**2).detach().cpu().numpy())
    
    all_de.append(de)
    
    if i>60:
        break

61it [00:34,  1.78it/s]


In [84]:
ninetyquant = []
for i in range(30):
    q = np.quantile(all_mse[i], 0.9)
    ninetyquant.append(q)

In [88]:
ninetyquant

[0.9199341535568237,
 0.9753382205963135,
 1.0859079957008362,
 1.2161036133766174,
 1.3877891302108765,
 1.5531651377677917,
 1.752310037612915,
 2.0071715116500854,
 2.247078537940979,
 2.494024634361267,
 2.793949842453003,
 3.1226292848587036,
 3.4556496143341064,
 3.8185853958129883,
 4.143992185592651,
 4.531986236572266,
 4.937242269515991,
 5.381857633590698,
 5.768104076385498,
 6.165144205093384,
 6.541731119155884,
 6.962288856506348,
 7.428401231765747,
 7.888824701309204,
 8.417593002319336,
 8.935048580169678,
 9.438450336456299,
 9.91438627243042,
 10.462123394012451,
 10.978078842163086]

In [92]:
np.array(de).mean()

2.15947

In [75]:
!ls ../../argoverse/agents_val.npy

../../argoverse/agents_val.npy


In [116]:
    
with open('../../argoverse/agents_val_rotation.npy', 'rb') as f:
    test_datase = np.load(f)

In [117]:

test_dataset = LSTMDataset(test_datase)
test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=128,
        drop_last=False,
        shuffle=False,
        collate_fn=model_utils.my_collate_fn,
    )

In [188]:
# Set to eval mode
encoder.eval()
decoder.eval()
criterion = nn.MSELoss()

rollout_len = 30
all_coverage = {}
de_test = []
mrses = []
for i, (_input, target, helpers) in tqdm(enumerate(test_loader)):

    _input = _input.to(device)
    target = target.to(device)

    # Encoder
    batch_size = _input.shape[0]
    input_length = _input.shape[1]
    output_length = target.shape[1]
    input_shape = _input.shape[2]

    # Initialize encoder hidden state
    encoder_hidden = model_utils.init_hidden(
        batch_size,
        encoder.module.hidden_size if use_cuda else encoder.hidden_size)


    # Encode observed trajectory
    for ei in range(input_length):
        
        encoder_input = _input[:, ei, :]
        encoder_hidden = encoder(encoder_input, encoder_hidden)

    # Initialize decoder input with last coordinate in encoder
    decoder_input = encoder_input
    decoder_hidden = encoder_hidden
    output_shape = list(target.shape)
    decoder_outputs = torch.zeros(output_shape).to(device)
    
    mses = torch.zeros(rollout_len).to(device)

    for di in range(rollout_len):
        decoder_output, decoder_hidden = decoder(decoder_input,decoder_hidden)
        decoder_outputs[:, di, :] = decoder_output
        mses = (decoder_output[:, :2]-target[:, di, :2]).pow(2).sum(axis=-1).pow(0.5)
        cover = mses < ninetyquant[di]
        
        second_part = 1/0.1 * np.pi * (mses.cpu().detach().numpy()**2 - ninetyquant[di]**2) 
        condition = np.where(second_part<0, 0, second_part)
        mrs = np.pi*ninetyquant[di] ** 2 + condition 
        mrses.append(mrs)
        if di in all_coverage:
            all_coverage[di] = np.concatenate((all_coverage[di], cover.detach().cpu().numpy()))
        else:
            all_coverage[di] = cover.detach().cpu().numpy()

        # Use own predictions as inputs at next step
        decoder_input = decoder_output
        de_test.append(torch.sqrt((decoder_output[:, 0] - target[:, di, 0])**2 +
                           (decoder_output[:, 1] - target[:, di, 1])**2).detach().cpu().numpy())
    if i > 40:
        break

41it [00:23,  1.71it/s]


In [189]:
np.mean(mrses)

211.90213

In [414]:
def mrs():
    second_part = 1/0.1 * np.pi * (mses.pow(2) - ninetyquant[di].pow(2)) 
    condition = np.where(second_part<0, 0, second_part)
    mrs = np.pi*ninetyquant[di].pow(2) + condition


In [97]:
# 309 it

In [172]:
np.array(de_test[:-30]).shape

(9240, 128)

In [165]:
de_array = np.array(de_test)

In [166]:
shapes = [h.shape for h in de_array]

In [167]:
from collections import Counter
Counter(shapes)

Counter({(128,): 9240, (48,): 30})

In [173]:
new_de = np.array(de_test[:-30]).reshape((308,30,128))

In [176]:
fde = np.mean(new_de[:,29,:])
fde

4.6787267

In [145]:
np.mean(np.concatenate(de_test).flatten())


2.159612

In [None]:
np.mean(np.concatenate(de_test).flatten())

In [122]:
all_coverage[0].mean()

0.9006384272395622

In [126]:
all_coverage[14].mean()

0.926682205107418

In [125]:
all_coverage[29].mean()

0.9276702472638833

In [14]:
!nvidia-smi

Fri Jan 14 01:08:37 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.91.03    Driver Version: 460.91.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:1A:00.0 Off |                  N/A |
| 27%   44C    P2    61W / 250W |   6363MiB / 11019MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:1B:00.0 Off |                  N/A |
| 22%   28C    P8     2W / 250W |   1258MiB / 11019MiB |      0%      Default |
|       

# conformal ecco

In [378]:
import os
import sys
import numpy as np
sys.path.append('..')
sys.path.append('.')
from collections import namedtuple
import time
import pickle
import argparse
from argoverse.map_representation.map_api import ArgoverseMap
from datasets.argoverse_loader_old import read_pkl_data
from train_utils import *

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


In [379]:
from models.rho_reg_ECCO_original import ECCONetwork
"""Returns an instance of the network for training and evaluation"""
model = ECCONetwork(radius_scale = 40, 
                        layer_channels = [8, 16, 8, 8, 1], 
                        encoder_hidden_size=18)

In [381]:
load_model_path = '../conformal-ecco-new'

print('loading model from ' + load_model_path)
model = torch.load(load_model_path + '.pth')


loading model from ../conformal-ecco-new


In [415]:
! ls ../../argoverse/conformal

calibrate  train  val


In [441]:
cal_path = '../../argoverse/conformal/calibrate'
cal_dataset = read_pkl_data(val_path, batch_size=8, shuffle=False, repeat=False)


In [444]:
val_path = '../../argoverse/conformal/val'
val_dataset = read_pkl_data(val_path, batch_size=8, shuffle=False, repeat=False)


In [398]:
train_window = 30

In [443]:
with torch.no_grad():
    count = 0
    prediction_gt = {}
    losses = []
    all_mse = []

    val_iter = iter(cal_dataset)

    for i, sample in enumerate(cal_dataset):
        pred = []
        gt = []

        if count % 1 == 0:
            print('{}'.format(count + 1), end=' ', flush=True)

        count += 1

        data = {}
        convert_keys = (['pos' + str(i) for i in range(31)] + 
                        ['vel' + str(i) for i in range(31)] + 
                        ['pos_2s', 'vel_2s', 'lane', 'lane_norm'])

        for k in convert_keys:
            data[k] = torch.tensor(np.stack(sample[k])[...,:2], dtype=torch.float32, device=device)


        for k in ['track_id' + str(i) for i in range(31)] + ['city', 'agent_id', 'scene_idx']:
            data[k] = np.stack(sample[k])

        for k in ['car_mask', 'lane_mask']:
            data[k] = torch.tensor(np.stack(sample[k]), dtype=torch.float32, device=device).unsqueeze(-1)

        scenes = data['scene_idx'].tolist()
        data['agent_id'] = data['agent_id'][:,np.newaxis]

        data['car_mask'] = data['car_mask'].squeeze(-1)
        accel = torch.zeros(1, 1, 2).to(device)
        data['accel'] = accel

        lane = data['lane']
        lane_normals = data['lane_norm']
        agent_id = data['agent_id']
        city = data['city']

        inputs = ([
            data['pos_2s'], data['vel_2s'], 
            data['pos0'], data['vel0'], 
            data['accel'], None,
            data['lane'], data['lane_norm'], 
            data['car_mask'], data['lane_mask']
        ])

        pr_pos1, pr_vel1, states = model(inputs)
        gt_pos1 = data['pos1']

        l = 0.5 * loss_fn(pr_pos1, gt_pos1, 
                          torch.sum(data['car_mask'], dim = -2) - 1, data['car_mask'].squeeze(-1))

        pr_agent, gt_agent = get_agent_ecco(pr_pos1, data['pos1'],
                                       data['track_id0'], 
                                       data['track_id1'], 
                                       agent_id, device)
        pred.append(pr_agent.unsqueeze(1).detach().cpu())
        gt.append(gt_agent.unsqueeze(1).detach().cpu())
        del pr_agent, gt_agent
        clean_cache(device)

        pos0 = data['pos0']
        vel0 = data['vel0']
        for i in range(29):
            pos_enc = torch.unsqueeze(pos0, 2)
            vel_enc = torch.unsqueeze(vel0, 2)
            inputs = (pos_enc, vel_enc, pr_pos1, pr_vel1, data['accel'], None, 
                      data['lane'], data['lane_norm'], data['car_mask'], data['lane_mask'])
            pos0, vel0 = pr_pos1, pr_vel1
            pr_pos1, pr_vel1, states = model(inputs, states)
            clean_cache(device)

            if i < train_window - 1:
                gt_pos1 = data['pos'+str(i+2)]
                l += 0.5 * loss_fn(pr_pos1, gt_pos1,
                                   torch.sum(data['car_mask'], dim = -2) - 1, data['car_mask'].squeeze(-1))

            pr_agent, gt_agent = get_agent_ecco(pr_pos1, data['pos'+str(i+2)],
                                           data['track_id0'], 
                                           data['track_id'+str(i+2)], 
                                           agent_id, device)

            pred.append(pr_agent.unsqueeze(1).detach().cpu())
            gt.append(gt_agent.unsqueeze(1).detach().cpu())

            clean_cache(device)

        losses.append(l)

        predict_result = (torch.cat(pred, axis=1), torch.cat(gt, axis=1))
        for idx, scene_id in enumerate(scenes):
            prediction_gt[scene_id] = (predict_result[0][idx], predict_result[1][idx])
        
        if count > 800:
            break

    try:
        total_loss = torch.sum(torch.stack(losses),axis=0) / max_iter
    except:
        total_loss = 0

    result = {}
    de = {}
    
    for k, v in prediction_gt.items():
        mse = (v[0][:,:]-v[1][:,:]).pow(2).sum(axis=-1).pow(0.5)
        all_mse.append(mse.detach().numpy())
        de[k] = torch.sqrt((v[0][:,0] - v[1][:,0])**2 + 
                        (v[0][:,1] - v[1][:,1])**2)

    ade = []
    de1s = []
    de2s = []
    de3s = []
    for k, v in de.items():
        ade.append(np.mean(v.numpy()))
        de1s.append(v.numpy()[10])
        de2s.append(v.numpy()[20])
        de3s.append(v.numpy()[-1])

    result['ADE'] = np.mean(ade)
    result['ADE_std'] = np.std(ade)
    result['DE@1s'] = np.mean(de1s)
    result['DE@1s_std'] = np.std(de1s)
    result['DE@2s'] = np.mean(de2s)
    result['DE@2s_std'] = np.std(de2s)
    result['DE@3s'] = np.mean(de3s)
    result['DE@3s_std'] = np.std(de3s)

    print(result)
    print('done')


1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 {'ADE': 2.0078743, 'ADE_std': 1.9154725, 'DE@1s': 1.1720881, 'DE@1s_std': 1.2563577, 'DE@2s': 2.7079055, 'DE@2s_std': 2.710586, 'DE@3s': 4.561687, 'DE@3s_std': 4.323260

In [396]:
def get_agent_ecco(pr, gt, pr_id, gt_id, agent_id, device='cpu'):
        
    pr_agent = pr[pr_id == agent_id,:]
    gt_agent = gt[gt_id == agent_id,:]
    
    return pr_agent, gt_agent

In [403]:
all_mse = [l.detach().numpy() for l in all_mse]

In [445]:
all_mse = np.array(all_mse)
all_mse.shape

(1879, 30)

In [446]:
ninetyquant = []
for i in range(30):
    q = np.quantile(all_mse[:,i], 0.9)
    ninetyquant.append(q)

In [447]:
ninetyquant

[0.5263745784759523,
 0.7044272184371948,
 0.8153408527374268,
 0.9904256820678712,
 1.1850415706634523,
 1.3769285678863525,
 1.5819029569625862,
 1.6921488285064696,
 1.9486913681030276,
 2.225640773773194,
 2.4322388648986824,
 2.6510236263275155,
 2.945130205154419,
 3.2434028625488285,
 3.5610769271850584,
 3.8270337104797365,
 4.163390064239502,
 4.436709499359132,
 4.734423065185547,
 5.203984355926514,
 5.651400947570801,
 6.060585594177247,
 6.366055202484132,
 6.839110851287845,
 7.364124011993409,
 7.811906909942628,
 8.197876548767091,
 8.637243843078615,
 9.115517807006837,
 9.614289665222168]

In [453]:
with torch.no_grad():
    count = 0
    prediction_gt = {}
    losses = []
    all_mse = []
    mrses = []
    all_coverage = []
    val_iter = iter(val_dataset)

    for i, sample in enumerate(val_dataset):
        pred = []
        gt = []

        if count % 1 == 0:
            print('{}'.format(count + 1), end=' ', flush=True)

        count += 1

        data = {}
        convert_keys = (['pos' + str(i) for i in range(31)] + 
                        ['vel' + str(i) for i in range(31)] + 
                        ['pos_2s', 'vel_2s', 'lane', 'lane_norm'])

        for k in convert_keys:
            data[k] = torch.tensor(np.stack(sample[k])[...,:2], dtype=torch.float32, device=device)


        for k in ['track_id' + str(i) for i in range(31)] + ['city', 'agent_id', 'scene_idx']:
            data[k] = np.stack(sample[k])

        for k in ['car_mask', 'lane_mask']:
            data[k] = torch.tensor(np.stack(sample[k]), dtype=torch.float32, device=device).unsqueeze(-1)

        scenes = data['scene_idx'].tolist()
        data['agent_id'] = data['agent_id'][:,np.newaxis]

        data['car_mask'] = data['car_mask'].squeeze(-1)
        accel = torch.zeros(1, 1, 2).to(device)
        data['accel'] = accel

        lane = data['lane']
        lane_normals = data['lane_norm']
        agent_id = data['agent_id']
        city = data['city']

        inputs = ([
            data['pos_2s'], data['vel_2s'], 
            data['pos0'], data['vel0'], 
            data['accel'], None,
            data['lane'], data['lane_norm'], 
            data['car_mask'], data['lane_mask']
        ])

        pr_pos1, pr_vel1, states = model(inputs)
        gt_pos1 = data['pos1']

        l = 0.5 * loss_fn(pr_pos1, gt_pos1, 
                          torch.sum(data['car_mask'], dim = -2) - 1, data['car_mask'].squeeze(-1))

        pr_agent, gt_agent = get_agent_ecco(pr_pos1, data['pos1'],
                                       data['track_id0'], 
                                       data['track_id1'], 
                                       agent_id, device)
        pred.append(pr_agent.unsqueeze(1).detach().cpu())
        gt.append(gt_agent.unsqueeze(1).detach().cpu())
        del pr_agent, gt_agent
        clean_cache(device)

        pos0 = data['pos0']
        vel0 = data['vel0']
        for i in range(29):
            pos_enc = torch.unsqueeze(pos0, 2)
            vel_enc = torch.unsqueeze(vel0, 2)
            inputs = (pos_enc, vel_enc, pr_pos1, pr_vel1, data['accel'], None, 
                      data['lane'], data['lane_norm'], data['car_mask'], data['lane_mask'])
            pos0, vel0 = pr_pos1, pr_vel1
            pr_pos1, pr_vel1, states = model(inputs, states)
            clean_cache(device)

            if i < train_window - 1:
                gt_pos1 = data['pos'+str(i+2)]
                l += 0.5 * loss_fn(pr_pos1, gt_pos1,
                                   torch.sum(data['car_mask'], dim = -2) - 1, data['car_mask'].squeeze(-1))

            pr_agent, gt_agent = get_agent_ecco(pr_pos1, data['pos'+str(i+2)],
                                           data['track_id0'], 
                                           data['track_id'+str(i+2)], 
                                           agent_id, device)

            pred.append(pr_agent.unsqueeze(1).detach().cpu())
            gt.append(gt_agent.unsqueeze(1).detach().cpu())

            clean_cache(device)

        losses.append(l)

        predict_result = (torch.cat(pred, axis=1), torch.cat(gt, axis=1))
        for idx, scene_id in enumerate(scenes):
            prediction_gt[scene_id] = (predict_result[0][idx], predict_result[1][idx])
        
    try:
        total_loss = torch.sum(torch.stack(losses),axis=0) / max_iter
    except:
        total_loss = 0

    result = {}
    de = {}
    
    for k, v in prediction_gt.items():
        mse = (v[0][:,:]-v[1][:,:]).pow(2).sum(axis=-1).pow(0.5).cpu().detach().numpy()
        all_mse.append(mse)
        
        cover = mse < ninetyquant
        
        second_part = 1/0.1 * np.pi * (np.power(mse,2) - np.power(ninetyquant,2)) 
        condition = np.where(second_part<0, 0, second_part)
        mrs = np.pi*np.power(ninetyquant,2) + condition 
        mrses.append(mrs)
        all_coverage.append(cover)

        
        de[k] = torch.sqrt((v[0][:,0] - v[1][:,0])**2 + 
                        (v[0][:,1] - v[1][:,1])**2)

    ade = []
    de1s = []
    de2s = []
    de3s = []
    for k, v in de.items():
        ade.append(np.mean(v.numpy()))
        de1s.append(v.numpy()[10])
        de2s.append(v.numpy()[20])
        de3s.append(v.numpy()[-1])

    result['ADE'] = np.mean(ade)
    result['ADE_std'] = np.std(ade)
    result['DE@1s'] = np.mean(de1s)
    result['DE@1s_std'] = np.std(de1s)
    result['DE@2s'] = np.mean(de2s)
    result['DE@2s_std'] = np.std(de2s)
    result['DE@3s'] = np.mean(de3s)
    result['DE@3s_std'] = np.std(de3s)

    print(result)
    print('done')


1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 {'ADE': 2.0084646, 'ADE_std': 1.9147258, 'DE@1s': 1.1727345, 'DE@1s_std': 1.2566749, 'DE@2s': 2.7086236, 'DE@2s_std': 2.7090895, 'DE@3s': 4.5598664, 'DE@3s_std': 4.3220024}

In [462]:
np.array(all_coverage).shape

(1871, 30)

In [466]:
np.array(all_coverage)[:,1].mean()

0.9005879208979155

In [458]:
np.array(all_coverage)[:,9].mean()

0.900053447354356

In [459]:
np.array(all_coverage)[:,19].mean()

0.900053447354356

In [456]:
np.array(all_coverage)[:,29].mean()

0.900053447354356

In [467]:
np.array(mrses).mean()

220.9237621733659

In [469]:
for parameter in model.parameters():
    print(parameter.shape)

torch.Size([19, 8, 3, 8, 2])
torch.Size([19, 8])
torch.Size([1, 8, 3, 8, 2])
torch.Size([1, 8])
torch.Size([19, 8, 8])
torch.Size([24, 16, 3, 8, 8])
torch.Size([24, 16, 8])
torch.Size([16, 8, 3, 8, 8])
torch.Size([16, 8, 8])
torch.Size([8, 8, 3, 8, 8])
torch.Size([8, 8, 8])
torch.Size([8, 1, 3, 2, 8])
torch.Size([8, 1])
torch.Size([24, 16, 8])
torch.Size([16, 8, 8])
torch.Size([8, 8, 8])
torch.Size([8, 1, 8])


In [471]:
pytorch_total_params

129320

In [472]:
pytorch_total_params = sum(p.numel() for p in model.parameters())
pytorch_total_params

129320