In [None]:
# %%capture
# !pip install utm
# !pip install openpyxl

In [3]:
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import cm
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.gaussian_process.kernels import RBF
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.gaussian_process import GaussianProcessRegressor
import matplotlib.pyplot as plt
import utm

import warnings
warnings.filterwarnings("ignore")

In [4]:
data = pd.read_excel("../Dataset/Maharashtra_Soil_Nutrients_Data.xlsx")
data.head()

Unnamed: 0,lon,lat,OC,N,P,K
0,73.401111,17.894722,1.08,756.0,9.43,834.37
1,73.401389,17.894722,1.12,781.2,9.21,265.1
2,73.402222,17.894722,0.68,478.8,8.99,318.96
3,73.403056,17.894722,1.76,1234.8,9.65,954.77
4,73.403333,17.894722,1.78,1247.4,8.77,371.77


In [5]:
def scaled_coord(x,y):
    """
    parameters
    ----------
    x : numpy array, float64
        list of longitude cordinates
    y : numpy array, float64
        list of latitude cordinates
        
    return
    ------
    scaled(0-1) x and y
    """
    x = (x-x.min())/(x.max()-x.min())
    y = (y-y.min())/(y.max()-y.min())
    return x,y

In [6]:
%%time

val_col = ['OC','N','P','K']
values = data[val_col]
coordinates = data[['lon','lat']]
#lat,lon to utm projection

x,y,zone,ut = utm.from_latlon(coordinates['lat'].values,coordinates['lon'].values)

lon,lat = y/1000,x/1000 #in km

# lon, lat = scaled_coord(lon,lat)
# normalize values of OC, N, K, P

#standardise lon and lat
# lon = (lon-np.mean(lon))/np.std(lon)
# lat = (lat-np.mean(lat))/np.std(lat)

test_k = MinMaxScaler().fit_transform(values)
values = test_k

Wall time: 15 ms


In [7]:
data['lon'] = lon
data['lat'] = lat
for i,col in enumerate(val_col):
    data[col] = values[:,i]

In [8]:
data.head()

Unnamed: 0,lon,lat,OC,N,P,K
0,1979.264401,330.619321,0.011632,0.079661,0.001163,0.012262
1,1979.264149,330.648754,0.012067,0.082316,0.001136,0.003893
2,1979.263392,330.737053,0.007283,0.05045,0.001109,0.004684
3,1979.262635,330.825353,0.019024,0.130115,0.00119,0.014032
4,1979.262383,330.854785,0.019241,0.131443,0.001082,0.005461


In [9]:
#split dataset into train and test
# split the dataset into train and test dataset
ix = np.random.choice(data.shape[0],int(data.shape[0]*0.2),replace = False)
data_train = data.iloc[[int(i) for i in range(data.shape[0]) if i not in ix]].reset_index(drop = True)
data_test = data.iloc[ix].reset_index(drop = True)

In [10]:
data_train.shape, data_test.shape

((20837, 6), (5209, 6))

In [11]:
data_train.head()

Unnamed: 0,lon,lat,OC,N,P,K
0,1979.264149,330.648754,0.012067,0.082316,0.001136,0.003893
1,1979.263392,330.737053,0.007283,0.05045,0.001109,0.004684
2,1979.262635,330.825353,0.019024,0.130115,0.00119,0.014032
3,1979.261879,330.913651,0.01511,0.10356,0.001109,0.010119
4,1979.261627,330.943085,0.016306,0.111527,0.001136,0.007821


In [12]:
data_test.head()

Unnamed: 0,lon,lat,OC,N,P,K
0,2018.770731,429.996503,0.005435,0.010744,0.000575,0.005802
1,2012.560689,491.570067,0.004783,0.066383,0.001542,0.010198
2,1994.453146,558.712577,0.003805,0.026551,0.001535,0.005758
3,2011.329738,497.238802,0.003479,0.024338,0.001752,0.012507
4,2145.359516,530.740139,0.004674,0.00852,0.000761,0.00395


In [13]:
data_test.shape

(5209, 6)

## Data loading in torch.Dataloader

In [14]:
class NutrientsDataset(Dataset):
    def __init__(self, df, num_context=40, num_extra_target=10):
        self.df = df
        self.num_context = num_context
        self.num_extra_target = num_extra_target

    def get_rows(self, i):
        rows = self.df.iloc[i : i + (self.num_context + self.num_extra_target)].copy()
        x = rows.iloc[:,:2].copy()
        y = rows.iloc[:,2:].copy()
        return x, y


    def __getitem__(self, i):
        x, y = self.get_rows(i)
        return x.values, y.values
        
    def __len__(self):
        return len(self.df) - (self.num_context + self.num_extra_target)

In [15]:
def npsample_batch(x, y, size=None, sort=False):
    
    """Sample from numpy arrays along 2nd dim."""
    inds = np.random.choice(range(x.shape[1]), size=size, replace=False)
    return x[:, inds], y[:, inds]

def collate_fns(max_num_context, max_num_extra_target, sample, sort=True, context_in_target=True):
    def collate_fn(batch, sample=sample):
        # Collate
        x = np.stack([x for x, y in batch], 0)
        y = np.stack([y for x, y in batch], 0)

        # Sample a subset of random size
        num_context = np.random.randint(4, max_num_context)
        num_extra_target = np.random.randint(4, max_num_extra_target)

        x = torch.from_numpy(x).float()
        y = torch.from_numpy(y).float()

        
        x_context = x[:, :max_num_context]
        y_context = y[:, :max_num_context]
    
        x_target_extra = x[:, max_num_context:]
        y_target_extra = y[:, max_num_context:]
        
        if sample:

            x_context, y_context = npsample_batch(
                x_context, y_context, size=num_context
            )

            x_target_extra, y_target_extra = npsample_batch(
                x_target_extra, y_target_extra, size=num_extra_target, sort=sort
            )

        # do we want to compute loss over context+target_extra, or focus in on only target_extra?
        if context_in_target:
            x_target = torch.cat([x_context, x_target_extra], 1)
            y_target = torch.cat([y_context, y_target_extra], 1)
        else:
            x_target = x_target_extra
            y_target = y_target_extra

        
        return x_context, y_context, x_target, y_target

    return collate_fn

 ## NP Model

In [16]:
class baseNPBlock(nn.Module):
    """relu non-linearities for NP block"""
    def __init__(self, inp_size,op_size, norm, bias = False, p = 0):
        """init function for linear2d class
        
        parameters
        ----------
        inp_size : int
                input dimension for the Encoder part (d_in)
        op_size : int
                output dimension for Encoder part(d_out)
        norm : str
                normalization to be applied on linear output
                pass norm == 'batch' to apply batch normalization
                else dropout normalization is applied
        bias : bool
                if True, bias is included for linear layer else discarded
        p : float
                probality to be considered while applying Dropout regularization
                
        """
        super().__init__()
        self.norm = norm
        self.linear = nn.Linear(inp_size,op_size,bias = bias)
        self.relu  = nn.ReLU()
        self.batch_norm = nn.BatchNorm2d(op_size)
        self.dropout = nn.Dropout2d(p)
        
    def forward(self,x):
        x = self.linear(x)
        x = self.batch_norm(x.permute(0,2,1)[:,:,:,None]) if self.norm == 'batch' else self.dropout(x.permute(0,2,1)[:,:,:,None])
        
        x = self.relu(x[:,:,:,0].permute(0,2,1))
        return x

In [17]:
class batch_MLP(nn.Module):
    """ Batch MLP layer for NP-Encoder"""
    def __init__(self, in_size, op_size, num_layers, norm, p = 0):
        """init function for linear2d class
        
        parameters
        ----------
        inp_size : int
                input dimension for the Encoder part (d_in)
        op_size : int
                output dimension for Encoder part(d_out)
        norm : str
                normalization to be applied on linear output
                pass norm == 'batch' to apply batch normalization
                else dropout normalization is applied
                
        return torch.tensor of size (B,num_context_points,d_out)
        """
        super().__init__()
        self.in_size = in_size
        self.op_size = op_size
        self.num_layers = num_layers
        self.norm  = norm
        
        self.first_layer = baseNPBlock(in_size, op_size, self.norm, False,p)
        self.encoder = nn.Sequential(*[baseNPBlock(op_size, op_size, self.norm, False, p) for layer in range(self.num_layers-2)])
        self.last_layer = nn.ReLU()
        
    def forward(self, x):
        x = self.first_layer(x)
        x = self.encoder(x)
        x = self.last_layer(x)
        
        return x

In [18]:
class LinearAttention(nn.Module):
    def __init__(self,in_ch, out_ch):
        super().__init__()
        self.linear = nn.Linear(in_ch, out_ch, bias = False)
        torch.nn.init.normal_(self.linear.weight,std = in_ch**0.5) #initilize weight matrix
        
    def forward(self,x):
        return self.linear(x)
    
    
class AttentionModule(nn.Module):
    def __init__(
        self,
        hidden_dim, 
        attn_type , 
        attn_layers,
        x_dim, 
        rep='mlp',
        n_multiheads = 8,
        norm = 'dropout',
        p = 0):
        
        super().__init__()
        self.rep = rep
#         self.n_multiheads = n_multiheads
        # rep determines whether raw input given to the model would be used as key and query or
        # it's output through MLP. 
        if self.rep =='mlp':
            
            #Both Key and Value needs to have same dimension
            self.batch_mlpk = batch_MLP(x_dim, hidden_dim, attn_layers, norm ,p)
            self.batch_mlpq = batch_MLP(x_dim, hidden_dim, attn_layers, norm, p)
        
        
        if attn_type == 'uniform':
            self.attn_func = self.uniform_attn
        if attn_type=='laplace':
            self.attn_func = self.laplace_attn
        if attn_type == 'dot':
            self.attn_func = self.dot_attn
        elif attn_type == 'multihead':
            self.w_k = nn.ModuleList([LinearAttention(hidden_dim,hidden_dim) for head in range(n_multiheads)])
            self.w_v = nn.ModuleList([LinearAttention(hidden_dim,hidden_dim) for head in range(n_multiheads)])
            self.w_q = nn.ModuleList([LinearAttention(hidden_dim,hidden_dim) for head in range(n_multiheads)])
            
            self.w = LinearAttention(hidden_dim*n_multiheads,hidden_dim)
            self.attn_func = self.multihead_attn
            self.num_heads = n_multiheads
            
            
            
    def forward(self, k, q, v):
        if self.rep =='mlp':
            k = self.batch_mlpk(k) #(B, n, H)
            q = self.batch_mlpq(q) #(B, m, H)
        
        rep = self.attn_func(k,q,v)
        
        return rep
    
    
    def uniform_attn(self, k, q, v):
        num_points = q.shape[1]
        rep = torch.mean(v, axis = 1, keepdim = True)
        rep = rep.repeat(1,num_points,1)
        
        return rep
    
    def laplace_attn(self, k, q, v, scale = 0.5):
        k = k.unsqueeze(1)
        v = v.unsqueeze(2)
        
        w = torch.abs((k-v)*scale)
        w = w.sum(dim = -1)
        weight = torch.softmax(w, dim = -1)
        
        #batch matrix multiplication (einstein summation convention for tensor)
        rep = torch.einsum("bik, bkj -> bij",weight, v)
        
        return rep
    
    
    def dot_product_attn(self, k, q, v):
#         print("k =",k.shape)
#         print("q =",q.shape)
#         print("v =",v.shape)    
        β = q.shape[-1]**0.5
        w_unnorm = torch.einsum('bjk,bik->bij', k, q)/β
#         print("w_unnorm =",w_unnorm.shape)
        
        weight = torch.softmax(w_unnorm, dim = -1)
        rep = torch.einsum("bik, bkj -> bij",weight, v)
#         print("rep =",rep.shape)
        return rep
    
    def multihead_attn(self, k , q, v):
        outs = []
        
        for i in range(self.num_heads):
            k = self.w_k[i](k) #(B, n, H)
#             print("k =",k.shape)
            q = self.w_q[i](q) #(B, m, H)
#             print("q =",q.shape)
            v = self.w_v[i](v) #(B, n, H)
#             print("v =",v.shape)
            out = self.dot_product_attn(k, q, v)
            outs.append(out)
            
        outs = torch.stack(outs, dim = -1) #(B, m, H, n_heads)
#         print("outs dim =", outs.shape)
        outs = outs.view(outs.shape[0], outs.shape[1], -1) #(B, m, n_heads*H)
#         print("outs shape =",outs.shape)
        rep = self.w(outs) #(B, m, H)
        
        return rep
    
    

In [19]:
# AttentionModule?

In [20]:
class DeterministicEncoder(nn.Module):
    def __init__(
                self,
                in_dim,
                x_dim,
                norm = 'dropout',
                hidden_dim = 32,
                encoder_layer = 2,
                self_attn_type ='dot',
                cross_attn_type ='dot',
                p_encoder = 0,
                p_attention = 0,
                attn_layers = 2,
                use_self_attn = False
                ):
        super().__init__()
        
        self.use_self_attn = use_self_attn
        
        self.encoder = batch_MLP(in_dim, hidden_dim, encoder_layer,norm, p_encoder)
        
        if self.use_self_attn:
            self.self_attn = AttentionModule(hidden_dim, self_attn_type, attn_layers,x_dim, rep = 'mlp',norm = norm, p = p_attention)
            
        self.cross_attn = AttentionModule(hidden_dim, cross_attn_type, attn_layers, x_dim)
        
    
    def forward(self, context_x, context_y, target_x):
        #concatenate context_x, context_y along the last dim.
        det_enc_in = torch.cat([context_x, context_y], dim = -1)
        
        det_encoded = self.encoder(det_enc_in) #(B, n, hd)
        
        if self.use_self_attn:
            det_encoded = self.self_attn(det_encoded, det_encoded, det_encoded)
            
        h = self.cross_attn(context_x, target_x, det_encoded)
        
        return h
        
        
    
        
        

In [21]:
class LatentEncoder(nn.Module):
    def __init__(self,
                in_dim,
                hidden_dim = 32,
                latent_dim = 32,
                self_attn_type = 'dot',
                encoder_layer = 3,
                min_std = 0.01,
                norm = 'dropout',
                p_encoder = 0,
                p_attn = 0,
                use_self_attn = False,
                attn_layers = 2,
                ):
        
        super().__init__()
        
        self._use_attn = use_self_attn
        
        self.encoder = batch_MLP(in_dim, hidden_dim, encoder_layer,norm, p_encoder)
        
        if self._use_attn:
            self.self_attn = AttentionModule(hidden_dim, self_attn_type, attn_layers,x_dim, rep = 'identity',norm = norm, p = p_attention)
        
        self.secondlast_layer = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.l_sigma = nn.Linear(hidden_dim, latent_dim) 
        self.min_std = min_std
#         self.use_lvar = use_lvar
        self.use_attn = use_self_attn
        
        
        
    def forward(self,x,y):
        encoder_inp = torch.cat([x,y], dim = -1) 
        
        encoded_op = self.encoder(encoder_inp)#(B, n, hd)
#         print("encoder_op shape = ",encoded_op.shape)
        if self.use_attn:
            encoded_op = self.self_attn(encoded_op, encoded_op, encoded_op) #(B, n, hd)
            
        
        mean_val = torch.mean(encoded_op, dim = 1) #mean aggregation (B, hd)
        
        #further MLP layer that maps parameters to gaussian latent
        mean_repr = torch.relu(self.secondlast_layer(mean_val)) #(B, hd)
        
        μ = self.mean(mean_repr) # (B, ld)
#         print("mean = ", μ.shape)
        log_scale = self.l_sigma(mean_repr) #(B, ld)
        
        #to avoid mode collapse
        σ = self.min_std + (1-self.min_std)*torch.sigmoid(log_scale*0.5) #(b, ld)
#         print(σ)
        dist = torch.distributions.Normal(μ, σ)
        
        return dist
        
        
            

In [22]:
class Decoder(nn.Module):
    def __init__(self,
                 x_dim,
                 y_dim,
                 hidden_dim = 32,
                 latent_dim = 32,
                 n_decoder_layer = 3,
                 use_deterministic_path = True,
                 min_std = 0.01,
                 norm = 'dropout',
                 dropout_p = 0,
                ):
        super().__init__()
        
        self.norm = norm
        self.target_transform = nn.Linear(x_dim, hidden_dim)
        
        if use_deterministic_path:
            hidden_dim_2 = 2 * hidden_dim + latent_dim
        else:
            hidden_dim_2 = hidden_dim + latent_dim
            
        self.decoder = batch_MLP(hidden_dim_2, hidden_dim_2, n_decoder_layer, norm, dropout_p)
        
        self.mean = nn.Linear(hidden_dim_2, y_dim)
        self.std = nn.Linear(hidden_dim_2, y_dim)
        self.deterministic_path = use_deterministic_path
        self.min_std = min_std
        
        
    def forward(self, r, z, t_x):
        x = self.target_transform(t_x)
        
        if self.deterministic_path:
            z = torch.cat([r,z], dim = -1)
#             print("z.shape =", z.shape)
        r = torch.cat([z,x], dim = -1)
        
        r = self.decoder(r)
        
        mean = self.mean(r)
        log_sigma = self.std(r)
        
        #clamp sigmad
        sigma = self.min_std + (1 - self.min_std) * F.softplus(log_sigma)
        
        dist = torch.distributions.Normal(mean,sigma)
        
        return dist

In [23]:
class LatentModel(nn.Module):
    def __init__(self,
               x_dim,
               y_dim,
               hidden_dim = 32,
               latent_dim = 32,
               latent_self_attn_type = 'multihead',
                det_self_attn_type = 'multihead',
                det_cross_attn_type = 'multihead',
               n_lat_enc_layer = 2,
               n_det_enc_layer = 2,
               n_decoder_layer = 2,
               use_deterministic_enc = False,
               min_std = 0.01,
               p_drop = 0,
               norm = 'dropout',
               p_attn_drop = 0,
               attn_layers = 2,
               use_self_attn = False,
               context_in_target = True,
                training = False):
        
        super().__init__()
        self.laten_encoder = LatentEncoder(x_dim+y_dim,
                                           hidden_dim=hidden_dim,
                                           latent_dim=latent_dim,
                                           self_attn_type=latent_self_attn_type,
                                           encoder_layer=n_lat_enc_layer,
                                           min_std=min_std,
                                           norm = norm,
                                           p_encoder=p_drop,
                                           p_attn=p_attn_drop,
                                           use_self_attn=use_self_attn,
                                           attn_layers=attn_layers 
                                          )
        self.deterministic_encoder = DeterministicEncoder(x_dim+y_dim,
                                                          x_dim,
                                                          norm = norm,
                                                          hidden_dim=hidden_dim,
                                                          encoder_layer=n_det_enc_layer,
                                                          self_attn_type=det_self_attn_type,
                                                          cross_attn_type=det_cross_attn_type,
                                                          p_encoder=p_drop,
                                                          p_attention=p_attn_drop,
                                                          attn_layers=attn_layers,
                                                          use_self_attn=use_self_attn
                                                         )
        self.decoder = Decoder(x_dim,
                              y_dim,
                              hidden_dim  = hidden_dim,
                              latent_dim=latent_dim,
                              n_decoder_layer=n_decoder_layer,
                              use_deterministic_path=use_deterministic_enc,
                              min_std=min_std,
                              norm=norm,
                              dropout_p=p_drop
                              )
        self.use_deterministic_enc = use_deterministic_enc
        self.context_in_target = context_in_target
        self.training = training
        
        
    def forward(self, c_x, c_y, t_x, t_y = None):
        dist_prior = self.laten_encoder(c_x, c_y)

        if t_y is not None:
            dist_posterior = self.laten_encoder(t_x, t_y)
            z = dist_posterior.loc
        else:
            z = dist_prior.loc
            
        n_target = t_x.shape[1]
        z = z.unsqueeze(1).repeat(1, n_target,1) #(B, n_target, L)
        
        if self.use_deterministic_enc:
            r = self.deterministic_encoder(c_x, c_y, t_x) #(B, n_target=m, H)
#             print(r.shape)
        else:
            r = None
            
        dist = self.decoder(r, z, t_x)
        
        #at test time, target y is not Known so we return None
        if t_y is not None:
            log_p = dist.log_prob(t_y).mean(-1)
            kl_loss = torch.distributions.kl_divergence(dist_posterior, dist_prior).mean(-1)
            kl_loss = kl_loss[:,None].expand(log_p.shape)
            loss = (kl_loss-log_p).mean()
            mse_loss = F.mse_loss(dist.loc, t_y, reduction = 'none')[:,:c_x.size(1)].mean()
        else:
            kl_loss  =None
            log_p = None
            mse_loss = None
            loss = None
            
        y_pred = dist.rsample() if self.training else dist.loc
            
        return y_pred,  dict(loss = loss, loss_p = log_p, loss_kl = kl_loss, loss_mse = mse_loss), dist



In [44]:
Regressor = LatentModel(2,4,
                       p_drop = 0,
                        hidden_dim = 64,
                        latent_dim = 16,
                       n_decoder_layer = 3,
                       norm = 'batch',
                       context_in_target=False
                       )

In [45]:
#train data loader
hparamas = dict(num_context = 15,
               num_extra_target = 16,
               batch_size = 400,
               context_in_target = False)
train_df = NutrientsDataset(data_train,hparamas['num_context'],hparamas['num_extra_target'])

train_loader = DataLoader(train_df,
                          batch_size=hparamas['batch_size'],
                         shuffle = True,
                         collate_fn=collate_fns(
                             hparamas['num_context'],hparamas['num_extra_target'], True,hparamas['context_in_target']))

In [46]:
# #eval data loader
# hparamas = dict(num_context = 15,
#                num_extra_target = 16,
#                batch_size = 40,
#                context_in_target = False)
# eval_df = NutrientsDataset(data_test,hparamas['num_context'],hparamas['num_extra_target'])

# train_loader = DataLoader(eval_df,
#                           batch_size=hparamas['batch_size'],
#                          shuffle = True,
#                          collate_fn=collate_fns(
#                              hparamas['num_context'],hparamas['num_extra_target'], True,hparamas['context_in_target']))

In [47]:
# data_train.iloc[:,:2]

In [61]:
#eval loss
def test(do_eval=True):
    """Run model on test/val data"""
    if do_eval:
        Regressor.eval()
    with torch.no_grad():
        target_x, target_y = data_test.iloc[:,:2], data_test.iloc[:,2:]
        context_x, context_y = data_train.iloc[:,:2], data_train.iloc[:,2:]

        context_x = torch.from_numpy(context_x.values).float()[None, :]
        context_y = torch.from_numpy(context_y.values).float()[None, :]
        target_x = torch.from_numpy(target_x.values).float()[None, :]
        target_y = torch.from_numpy(target_y.values).float()[None, :]
        print(context_x.shape, context_y.shape, target_x.shape, target_y.shape)
        y_pred, losses, extra = Regressor.forward(context_x, context_y, target_x, target_y)
        print(y_pred.shape)
    yr=(target_y-y_pred)[0].detach().cpu().numpy()
#     print(yr)
    return yr, y_pred, losses, extra 

In [62]:
test()[0]

torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])


array([[-0.19728503, -0.3132522 ,  0.845472  , -0.02644168],
       [-0.35095876, -0.42687303,  0.5910084 ,  0.02158483],
       [-0.4037054 , -0.18509154,  0.25507063,  0.09816705],
       ...,
       [-0.1871854 , -0.2815249 ,  0.92175674, -0.00298213],
       [-0.2412051 , -0.41691044,  0.9135716 ,  0.06411079],
       [-0.22108693, -0.25960636,  0.94111425, -0.04580018]],
      dtype=float32)

In [51]:
opt = torch.optim.Adam(Regressor.parameters(), lr=1e-4)

In [63]:
from tqdm.auto import tqdm 

for epoch in range(1000):
    loss = 0 
    mse_loss = 0
    Regressor.train()
    for batch in tqdm(train_loader):
        context_x, context_y, target_x, target_y = batch
#         context_x.shape
        Regressor.zero_grad()
        y_pred, losses, extra = Regressor.forward(context_x, context_y, target_x, target_y)
#         print(y_pred.shape)
        losses['loss'].backward()
        loss += losses['loss'].cpu().detach().numpy()
        mse_loss+=losses['loss_mse'].cpu().detach().numpy()
        opt.step()
    loss /= len(train_loader)
    
    print(epoch)
    print('ELBO train_loss', loss)
    print('mse train_loss', mse_loss/len(train_loader))
    
    val_loss = test()[0]
    val_loss = np.mean(np.abs(val_loss))
    print('val_loss = ', val_loss)
    print("-----------------------------------------------------------------------")

  0%|          | 0/53 [00:00<?, ?it/s]

0
ELBO train_loss 0.40834776804132283
mse train_loss 0.0067994619350669515
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.15735935
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

1
ELBO train_loss 0.24481830478839153
mse train_loss 0.002267636929831977
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.12186013
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

2
ELBO train_loss 0.07959765993840641
mse train_loss 0.001636472598754994
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.19178033
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

3
ELBO train_loss -0.0879209226322413
mse train_loss 0.0013878671100960587
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0794784
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

4
ELBO train_loss -0.26227339652349363
mse train_loss 0.0011973627293833866
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.297436
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

5
ELBO train_loss -0.433786042456357
mse train_loss 0.0010687425646948505
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.08530353
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

6
ELBO train_loss -0.608558622733602
mse train_loss 0.0009885331361009827
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.30869812
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

7
ELBO train_loss -0.7532960362029526
mse train_loss 0.0013734547394977988
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.13459015
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

8
ELBO train_loss -0.9304907591837757
mse train_loss 0.001073311989081426
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.034837585
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

9
ELBO train_loss -1.131720078441332
mse train_loss 0.0007996864996868062
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.06763096
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

10
ELBO train_loss -1.2974604019578897
mse train_loss 0.0007904292290749134
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.094235405
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

11
ELBO train_loss -1.4506332851805777
mse train_loss 0.0008067153185553287
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.08215833
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

12
ELBO train_loss -1.6132602399250247
mse train_loss 0.0007268361646225149
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0933292
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

13
ELBO train_loss -1.686658429649641
mse train_loss 0.000830927143809122
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.23638672
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

14
ELBO train_loss -1.6151672084376496
mse train_loss 0.001097363437952052
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.14626084
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

15
ELBO train_loss -1.7923856726232565
mse train_loss 0.0011022007565084353
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.08225398
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

16
ELBO train_loss -1.9609086783427112
mse train_loss 0.0007058670991407883
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03979873
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

17
ELBO train_loss -2.028430223464966
mse train_loss 0.0006099285586420798
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.10861727
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

18
ELBO train_loss -2.03132977575626
mse train_loss 0.0008410568659870341
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.06149123
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

19
ELBO train_loss -2.1949547304297394
mse train_loss 0.0005433122865759047
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.06145505
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

20
ELBO train_loss -2.2235995823482297
mse train_loss 0.0005981200369171587
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.035563897
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

21
ELBO train_loss -2.2916094964405276
mse train_loss 0.0005517425229728995
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.08023176
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

22
ELBO train_loss -2.30663182375566
mse train_loss 0.0005545936286546079
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.061161123
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

23
ELBO train_loss -2.3499891251887917
mse train_loss 0.0005334508522932049
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.05072058
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

24
ELBO train_loss -2.3791281232294046
mse train_loss 0.0004988302400375207
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.07912498
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

25
ELBO train_loss -2.390800134190973
mse train_loss 0.0005487725308894198
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.041973863
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

26
ELBO train_loss -2.430025378488145
mse train_loss 0.0004925897371305047
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.08228272
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

27
ELBO train_loss -2.38328251973638
mse train_loss 0.0005532702203743371
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.033492282
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

28
ELBO train_loss -2.4689285459945784
mse train_loss 0.0004639147540866309
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.10504768
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

29
ELBO train_loss -2.424550913414865
mse train_loss 0.0005527803825899818
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.044320457
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

30
ELBO train_loss -2.5484791251848327
mse train_loss 0.00045219865734117846
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.035357434
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

31
ELBO train_loss -2.5804037328036324
mse train_loss 0.00045931464245648314
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03848247
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

32
ELBO train_loss -2.5648806758646696
mse train_loss 0.0004816116152313661
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.046441387
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

33
ELBO train_loss -2.5601606863849566
mse train_loss 0.00045347760842216886
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.06501396
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

34
ELBO train_loss -2.489996419762665
mse train_loss 0.0004751663207502614
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.06731669
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

35
ELBO train_loss -2.4873395056094765
mse train_loss 0.0004947196787227613
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.091746904
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

36
ELBO train_loss -2.4305238869954957
mse train_loss 0.0005578509787809244
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.04201887
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

37
ELBO train_loss -2.5489689853956117
mse train_loss 0.0004704142744951653
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.056978036
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

38
ELBO train_loss -2.615879684124353
mse train_loss 0.0004199807735858485
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.030261366
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

39
ELBO train_loss -2.6499106164248483
mse train_loss 0.0004016833181080039
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022900581
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

40
ELBO train_loss -2.528158480266355
mse train_loss 0.0005401445853080213
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.06127004
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

41
ELBO train_loss -2.5456336579232848
mse train_loss 0.0005359085724812072
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.040983748
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

42
ELBO train_loss -2.6595204686218836
mse train_loss 0.00042905151033510437
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019157488
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

43
ELBO train_loss -2.6776152984151302
mse train_loss 0.0004062995656156245
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023120705
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

44
ELBO train_loss -2.6051556366794513
mse train_loss 0.00042995634018707106
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.04497247
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

45
ELBO train_loss -2.621221276949037
mse train_loss 0.00044986148247328357
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023434158
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

46
ELBO train_loss -2.656302185553425
mse train_loss 0.0004336328974942554
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.038297594
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

47
ELBO train_loss -2.65714352985598
mse train_loss 0.000435002778551348
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03259304
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

48
ELBO train_loss -2.6733519198759548
mse train_loss 0.0004252419361673331
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021271918
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

49
ELBO train_loss -2.699694311843728
mse train_loss 0.00041466848683139343
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.029299356
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

50
ELBO train_loss -2.7149279162568867
mse train_loss 0.00041957082636522584
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01916606
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

51
ELBO train_loss -2.6869465213901593
mse train_loss 0.000436896316318351
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.042730946
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

52
ELBO train_loss -2.687530022747112
mse train_loss 0.00041985670313291813
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.035137452
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

53
ELBO train_loss -2.7307794004116417
mse train_loss 0.00040168261961607296
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.029165657
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

54
ELBO train_loss -2.680409980270098
mse train_loss 0.0004477554364986542
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0464146
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

55
ELBO train_loss -2.7090913867050745
mse train_loss 0.00042083312291651964
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.024314495
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

56
ELBO train_loss -2.677936758916333
mse train_loss 0.0004227496458454725
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03771767
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

57
ELBO train_loss -2.6536983993818177
mse train_loss 0.0004468771419268242
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02200252
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

58
ELBO train_loss -2.7148541661928283
mse train_loss 0.0004258290839816426
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03446635
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

59
ELBO train_loss -2.739054468442809
mse train_loss 0.0004042669186777137
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022941245
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

60
ELBO train_loss -2.746487257615575
mse train_loss 0.000413771241822034
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.024856444
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

61
ELBO train_loss -2.757738171883349
mse train_loss 0.0003933089871491077
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018837165
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

62
ELBO train_loss -2.694187841325436
mse train_loss 0.0004144069699697934
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03046815
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

63
ELBO train_loss -2.7353887827891223
mse train_loss 0.00041874260960129213
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.037382655
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

64
ELBO train_loss -2.6715808162149393
mse train_loss 0.0004453679045908294
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03245331
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

65
ELBO train_loss -2.673145244706352
mse train_loss 0.0004481890448588737
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.034287553
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

66
ELBO train_loss -2.7123975950591968
mse train_loss 0.0004237311426001021
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.038316306
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

67
ELBO train_loss -2.7266859535901053
mse train_loss 0.00043082604995260963
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.028997213
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

68
ELBO train_loss -2.717068853243342
mse train_loss 0.00044031320375891156
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.04022708
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

69
ELBO train_loss -2.6871960590470514
mse train_loss 0.0004290330104578181
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.032693528
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

70
ELBO train_loss -2.7596108823452354
mse train_loss 0.0004100055367365163
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.027789975
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

71
ELBO train_loss -2.7535473175768583
mse train_loss 0.0004133255185805402
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.027227316
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

72
ELBO train_loss -2.7557513084051743
mse train_loss 0.0004112909203131666
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022256583
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

73
ELBO train_loss -2.791561976918634
mse train_loss 0.00039038426579933894
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022633737
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

74
ELBO train_loss -2.77646171596815
mse train_loss 0.00040031644133499487
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.024538416
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

75
ELBO train_loss -2.7674157799414867
mse train_loss 0.0004014619579428877
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.024063416
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

76
ELBO train_loss -2.7981077500109404
mse train_loss 0.00038419249313795625
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01521877
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

77
ELBO train_loss -2.7856735238489114
mse train_loss 0.00040269532363932086
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.035567947
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

78
ELBO train_loss -2.779385994065483
mse train_loss 0.00038891323383036507
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.029989
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

79
ELBO train_loss -2.76929830605129
mse train_loss 0.0004105641573056895
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.036931884
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

80
ELBO train_loss -2.774078036254307
mse train_loss 0.0004058765598357532
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02645277
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

81
ELBO train_loss -2.760878120953182
mse train_loss 0.00041133613421203883
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.057841524
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

82
ELBO train_loss -2.7700484158857814
mse train_loss 0.00040183970863922095
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03850194
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

83
ELBO train_loss -2.7686117550112166
mse train_loss 0.0004162760918195588
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03277808
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

84
ELBO train_loss -2.808971845878745
mse train_loss 0.00040220505561249084
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01766633
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

85
ELBO train_loss -2.795418858528137
mse train_loss 0.00041262902383130257
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02739082
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

86
ELBO train_loss -2.7878275142525726
mse train_loss 0.0003988187217026211
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022258848
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

87
ELBO train_loss -2.8032612800598145
mse train_loss 0.00036440596735667224
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.027568065
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

88
ELBO train_loss -2.7971991840398536
mse train_loss 0.0004009901326839408
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.031396713
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

89
ELBO train_loss -2.8094387864166834
mse train_loss 0.00037288515419997976
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022040673
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

90
ELBO train_loss -2.816300873486501
mse train_loss 0.00038171893332751013
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.035062887
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

91
ELBO train_loss -2.7893894033611946
mse train_loss 0.00039151550752752163
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022891128
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

92
ELBO train_loss -2.7941895768327534
mse train_loss 0.0004131938218257724
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.034245074
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

93
ELBO train_loss -2.7984625348504983
mse train_loss 0.0004021816465230483
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023176039
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

94
ELBO train_loss -2.8229239166907543
mse train_loss 0.00037665628522792657
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02517226
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

95
ELBO train_loss -2.7783130767210475
mse train_loss 0.00042031230026074107
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.025518272
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

96
ELBO train_loss -2.7997708163171446
mse train_loss 0.00040315467269019274
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023000775
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

97
ELBO train_loss -2.809939573395927
mse train_loss 0.0004005729512394986
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023404894
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

98
ELBO train_loss -2.8309177497647844
mse train_loss 0.0003793839821570008
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021173801
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

99
ELBO train_loss -2.7937425689877204
mse train_loss 0.000408082824026867
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0304697
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

100
ELBO train_loss -2.783164323500867
mse train_loss 0.0004173906608967919
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.05739398
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

101
ELBO train_loss -2.7858076500442794
mse train_loss 0.00040288730733149315
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021215368
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

102
ELBO train_loss -2.836868983394695
mse train_loss 0.00038211171560314535
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023472086
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

103
ELBO train_loss -2.823196784505304
mse train_loss 0.00039158793268417766
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021746492
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

104
ELBO train_loss -2.828353360014142
mse train_loss 0.0003813079268372846
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022470651
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

105
ELBO train_loss -2.841936669259701
mse train_loss 0.0003795765676944218
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.045253035
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

106
ELBO train_loss -2.752782162630333
mse train_loss 0.0003922777019774998
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.031153632
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

107
ELBO train_loss -2.833757665922057
mse train_loss 0.0004016335738959002
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021649126
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

108
ELBO train_loss -2.8403475914361342
mse train_loss 0.0003923105876585293
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.029500604
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

109
ELBO train_loss -2.8266270250644325
mse train_loss 0.0003848013551554309
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015265242
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

110
ELBO train_loss -2.7782821835211986
mse train_loss 0.00042548615821496634
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.041723676
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

111
ELBO train_loss -2.728231777560036
mse train_loss 0.00043579978447732566
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.050865274
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

112
ELBO train_loss -2.786755197453049
mse train_loss 0.0004225926677463576
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019598346
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

113
ELBO train_loss -2.839338626501695
mse train_loss 0.0003904984401262207
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022519322
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

114
ELBO train_loss -2.8310651239359155
mse train_loss 0.00038491567802167374
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023748519
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

115
ELBO train_loss -2.833147984630657
mse train_loss 0.0004004755352276711
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.020887977
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

116
ELBO train_loss -2.8663319686673723
mse train_loss 0.0003630895876514567
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021494161
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

117
ELBO train_loss -2.8611560452659175
mse train_loss 0.0003911000745094223
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017221201
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

118
ELBO train_loss -2.8588940872336335
mse train_loss 0.00038970623715596167
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018768128
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

119
ELBO train_loss -2.8644030319069915
mse train_loss 0.00037234731335390606
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019917117
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

120
ELBO train_loss -2.872068364665193
mse train_loss 0.00038055557003112964
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014044623
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

121
ELBO train_loss -2.831835945822158
mse train_loss 0.00038697886558574675
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.046954643
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

122
ELBO train_loss -2.818664294368816
mse train_loss 0.0003804508311111691
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.025822336
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

123
ELBO train_loss -2.8629550888853252
mse train_loss 0.0003791862586952465
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.026355596
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

124
ELBO train_loss -2.847321348370246
mse train_loss 0.00038538365937848486
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015783625
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

125
ELBO train_loss -2.8644602973506137
mse train_loss 0.00037571615270418025
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022090718
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

126
ELBO train_loss -2.8593154493367896
mse train_loss 0.0003727751222267782
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.025898205
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

127
ELBO train_loss -2.8302857842085496
mse train_loss 0.0003796930533047449
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019619942
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

128
ELBO train_loss -2.800240494170279
mse train_loss 0.0004109444762777783
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015602909
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

129
ELBO train_loss -2.866087958497821
mse train_loss 0.0003631133956830281
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023285294
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

130
ELBO train_loss -2.8623218941238693
mse train_loss 0.0004039431108328742
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.027457152
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

131
ELBO train_loss -2.85437192556993
mse train_loss 0.0003696000101901415
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.04548142
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

132
ELBO train_loss -2.8303261163099758
mse train_loss 0.0004078565000112713
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02576965
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

133
ELBO train_loss -2.8253410739718743
mse train_loss 0.00038066602812833944
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023476088
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

134
ELBO train_loss -2.8468073494029493
mse train_loss 0.0003846533693162338
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.027263926
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

135
ELBO train_loss -2.8753079153456778
mse train_loss 0.0003688578509817884
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.020601723
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

136
ELBO train_loss -2.868587210493268
mse train_loss 0.00038171170981170363
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.032912787
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

137
ELBO train_loss -2.854706213159381
mse train_loss 0.0003799668961876722
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03477619
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

138
ELBO train_loss -2.8249344893221586
mse train_loss 0.00043506523839171695
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.027320022
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

139
ELBO train_loss -2.825375700896641
mse train_loss 0.0003967495557724692
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.030704245
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

140
ELBO train_loss -2.834379070210007
mse train_loss 0.0003966606010129359
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.026402088
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

141
ELBO train_loss -2.8423396956245854
mse train_loss 0.00040014651505153553
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.025299843
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

142
ELBO train_loss -2.866404762807882
mse train_loss 0.000392795819898997
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019400515
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

143
ELBO train_loss -2.8692083988549575
mse train_loss 0.0003908436524655389
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.024869444
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

144
ELBO train_loss -2.8791690682465174
mse train_loss 0.0003709720715634384
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015912987
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

145
ELBO train_loss -2.8610323330141463
mse train_loss 0.00039086656624073477
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.044660263
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

146
ELBO train_loss -2.813949076634533
mse train_loss 0.00038792707936460467
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021527898
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

147
ELBO train_loss -2.8453541504886917
mse train_loss 0.00038857949828456666
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.029121662
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

148
ELBO train_loss -2.8266844749450684
mse train_loss 0.0003909584786083971
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021506302
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

149
ELBO train_loss -2.8668200002526336
mse train_loss 0.0003809605912911653
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.031942423
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

150
ELBO train_loss -2.8083037987070263
mse train_loss 0.0004270665994449958
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.038877442
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

151
ELBO train_loss -2.810326524500577
mse train_loss 0.00039653049233887906
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.020682983
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

152
ELBO train_loss -2.8780521311849916
mse train_loss 0.0003879702793333223
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018776275
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

153
ELBO train_loss -2.882470427819018
mse train_loss 0.00038785658741615853
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018038506
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

154
ELBO train_loss -2.882865199502909
mse train_loss 0.00037209105014555016
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0182118
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

155
ELBO train_loss -2.8743696167783916
mse train_loss 0.0003912321059582404
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.024787799
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

156
ELBO train_loss -2.865720416015049
mse train_loss 0.00039357266977499677
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019224154
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

157
ELBO train_loss -2.8771630413127394
mse train_loss 0.00038356923790349855
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017459322
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

158
ELBO train_loss -2.8367187983866007
mse train_loss 0.00038685984739304503
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.028658686
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

159
ELBO train_loss -2.856478232257771
mse train_loss 0.00038679293439205175
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.024945142
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

160
ELBO train_loss -2.8721246584406437
mse train_loss 0.0003884245667887746
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021576945
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

161
ELBO train_loss -2.8773518643289244
mse train_loss 0.000373224415080534
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018797044
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

162
ELBO train_loss -2.8906463587059164
mse train_loss 0.0003660686387140128
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018553153
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

163
ELBO train_loss -2.8922863546407447
mse train_loss 0.00037673077071612735
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011548341
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

164
ELBO train_loss -2.9038295475941784
mse train_loss 0.00035571995194570846
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012743288
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

165
ELBO train_loss -2.898029388121839
mse train_loss 0.00038517355333512896
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021735154
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

166
ELBO train_loss -2.8943143385761188
mse train_loss 0.0003681928921548017
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.026069101
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

167
ELBO train_loss -2.8809936001615704
mse train_loss 0.00037460973544511944
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022071948
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

168
ELBO train_loss -2.843641016280876
mse train_loss 0.00040279257903095194
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.034608122
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

169
ELBO train_loss -2.831554453327971
mse train_loss 0.00040914611929351077
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.024178142
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

170
ELBO train_loss -2.8976258106951445
mse train_loss 0.0003662515760570729
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014768819
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

171
ELBO train_loss -2.750285450017677
mse train_loss 0.0004241053662416613
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.04325812
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

172
ELBO train_loss -2.7535977183647877
mse train_loss 0.0004358716212423905
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.033679396
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

173
ELBO train_loss -2.852871004140602
mse train_loss 0.000397838423696329
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01976156
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

174
ELBO train_loss -2.874987341322989
mse train_loss 0.0003895174905785166
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022980476
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

175
ELBO train_loss -2.870921476831976
mse train_loss 0.0003729398842277181
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.04316494
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

176
ELBO train_loss -2.824267014017645
mse train_loss 0.0004075158273203755
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.020249136
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

177
ELBO train_loss -2.8890671100256577
mse train_loss 0.000377225859819929
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016551444
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

178
ELBO train_loss -2.894300267381488
mse train_loss 0.00038325494179449414
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0247359
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

179
ELBO train_loss -2.8601305361063973
mse train_loss 0.00038253521202966784
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.031422973
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

180
ELBO train_loss -2.8440344041248538
mse train_loss 0.00038826665297405407
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019292945
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

181
ELBO train_loss -2.878740027265729
mse train_loss 0.0003908992553336265
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021293286
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

182
ELBO train_loss -2.9105441345358796
mse train_loss 0.00036576327693182975
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019475816
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

183
ELBO train_loss -2.918705998726611
mse train_loss 0.0003631172247586842
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010069778
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

184
ELBO train_loss -2.9087651180771164
mse train_loss 0.00035283522722135595
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014048449
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

185
ELBO train_loss -2.9063973066941746
mse train_loss 0.0003636134892728461
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015408349
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

186
ELBO train_loss -2.916422983385482
mse train_loss 0.00036324652868437244
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013998062
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

187
ELBO train_loss -2.8914375012775637
mse train_loss 0.0003773053621495459
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.020863684
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

188
ELBO train_loss -2.9063106212975844
mse train_loss 0.0003734604309462362
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015245865
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

189
ELBO train_loss -2.9051669183767066
mse train_loss 0.00037441434907586364
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021070791
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

190
ELBO train_loss -2.8896248475560604
mse train_loss 0.0003624906132485331
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.024386728
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

191
ELBO train_loss -2.888084150710196
mse train_loss 0.0003853769193706542
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.024229651
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

192
ELBO train_loss -2.8979804290915436
mse train_loss 0.00039260874377159435
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01213028
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

193
ELBO train_loss -2.911479057006116
mse train_loss 0.00035248756576724844
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01781776
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

194
ELBO train_loss -2.903986359542271
mse train_loss 0.0003590937379861848
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015056159
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

195
ELBO train_loss -2.9095431948607824
mse train_loss 0.00035819719943112503
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011822013
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

196
ELBO train_loss -2.918958528986517
mse train_loss 0.00037825506873525466
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013705866
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

197
ELBO train_loss -2.8125236664178237
mse train_loss 0.000389493708081438
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018060638
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

198
ELBO train_loss -2.9107526203371443
mse train_loss 0.00036581783672961636
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014610746
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

199
ELBO train_loss -2.8781909312842027
mse train_loss 0.00037881258226061274
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.032853648
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

200
ELBO train_loss -2.8999024787039125
mse train_loss 0.00038102317441096703
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02082017
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

201
ELBO train_loss -2.925482821914385
mse train_loss 0.0003713654839644595
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01634286
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

202
ELBO train_loss -2.9133817519781724
mse train_loss 0.00037600758940174755
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013394219
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

203
ELBO train_loss -2.9199407010708214
mse train_loss 0.0003577603605457248
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01565521
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

204
ELBO train_loss -2.8788457794009514
mse train_loss 0.0003956260336980329
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02125668
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

205
ELBO train_loss -2.9198737639301227
mse train_loss 0.0003663313658556567
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021647876
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

206
ELBO train_loss -2.7567367598695576
mse train_loss 0.00043019825369690736
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017229244
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

207
ELBO train_loss -2.9098933507811346
mse train_loss 0.00040560065387563673
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013387659
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

208
ELBO train_loss -2.916485075680715
mse train_loss 0.0003789072575702174
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013284927
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

209
ELBO train_loss -2.9381947967241393
mse train_loss 0.00036031652883327796
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014667197
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

210
ELBO train_loss -2.911092922372638
mse train_loss 0.0003706155364612504
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.028299022
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

211
ELBO train_loss -2.8704999190456464
mse train_loss 0.00042139720072846787
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014844022
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

212
ELBO train_loss -2.925741375617261
mse train_loss 0.0003609996539417584
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018425237
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

213
ELBO train_loss -2.9199470259108633
mse train_loss 0.0003805386678443097
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019325288
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

214
ELBO train_loss -2.6993904833523734
mse train_loss 0.0004396672159617872
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.030421948
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

215
ELBO train_loss -2.636239564643716
mse train_loss 0.000552856062433489
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.05855865
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

216
ELBO train_loss -2.6755507172278636
mse train_loss 0.00047074231612364285
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03899775
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

217
ELBO train_loss -2.838323696604315
mse train_loss 0.00040163987130585637
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017861107
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

218
ELBO train_loss -2.9030959628662973
mse train_loss 0.0003811844027675864
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.024722155
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

219
ELBO train_loss -2.8812919877610117
mse train_loss 0.00042029309996957274
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015081955
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

220
ELBO train_loss -2.9214048790481857
mse train_loss 0.00036128352831100237
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011758095
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

221
ELBO train_loss -2.9156625315828144
mse train_loss 0.0003854849278039457
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017703826
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

222
ELBO train_loss -2.912112699364716
mse train_loss 0.000370042072114851
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0119657675
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

223
ELBO train_loss -2.907405340446616
mse train_loss 0.00035957714492226687
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018455256
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

224
ELBO train_loss -2.9045214697999775
mse train_loss 0.00037879154101488584
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011365886
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

225
ELBO train_loss -2.9220426577442096
mse train_loss 0.00037630438257694104
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016980486
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

226
ELBO train_loss -2.9083893838918433
mse train_loss 0.00036755999147162756
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02169832
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

227
ELBO train_loss -2.8907845807525345
mse train_loss 0.00039550681891719334
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016261615
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

228
ELBO train_loss -2.9179828481854133
mse train_loss 0.0003621666104629425
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0116627
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

229
ELBO train_loss -2.8608079108427154
mse train_loss 0.00037389758475704713
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02108458
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

230
ELBO train_loss -2.904929759367457
mse train_loss 0.00037627975716703695
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.020212365
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

231
ELBO train_loss -2.931363416167925
mse train_loss 0.0003771333423122447
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01678097
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

232
ELBO train_loss -2.92823526544391
mse train_loss 0.0003756064076508167
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011474746
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

233
ELBO train_loss -2.9223374150833994
mse train_loss 0.00036685869074646243
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02188466
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

234
ELBO train_loss -2.8985819794097036
mse train_loss 0.00036712678232810124
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022085259
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

235
ELBO train_loss -2.9063847784726127
mse train_loss 0.00038601749632639356
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018508235
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

236
ELBO train_loss -2.936106672826803
mse train_loss 0.00037562015454311204
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010454091
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

237
ELBO train_loss -2.811065466898792
mse train_loss 0.0004150876929140035
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03454203
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

238
ELBO train_loss -2.8664573633445882
mse train_loss 0.00039887921868182085
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017410588
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

239
ELBO train_loss -2.786259408267039
mse train_loss 0.00038063966362728333
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01728567
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

240
ELBO train_loss -2.9287285444871434
mse train_loss 0.000366287479201279
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014175976
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

241
ELBO train_loss -2.9241251945495605
mse train_loss 0.00038103054179853917
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019001136
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

242
ELBO train_loss -2.8744475515383594
mse train_loss 0.0003974990370652621
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03132294
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

243
ELBO train_loss -2.856891834510947
mse train_loss 0.00039196579595193054
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019964492
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

244
ELBO train_loss -2.930144247019066
mse train_loss 0.00036351028818830427
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012256696
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

245
ELBO train_loss -2.937780456722907
mse train_loss 0.0003661108291092909
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01378582
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

246
ELBO train_loss -2.918149025935047
mse train_loss 0.0003720911272981573
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013845463
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

247
ELBO train_loss -2.918219010784941
mse train_loss 0.0003646800845294734
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01631845
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

248
ELBO train_loss -2.9150788649073185
mse train_loss 0.00037713176905801346
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02337139
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

249
ELBO train_loss -2.8939396705267564
mse train_loss 0.00037398032475052014
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.032848056
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

250
ELBO train_loss -2.9005395084057213
mse train_loss 0.0003952703703379364
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015518238
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

251
ELBO train_loss -2.880370592171291
mse train_loss 0.00037257739958871717
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02493231
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

252
ELBO train_loss -2.917711028512919
mse train_loss 0.00037252584819348074
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01375847
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

253
ELBO train_loss -2.9299657524756664
mse train_loss 0.0003714528978683652
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.009890891
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

254
ELBO train_loss -2.949913794139646
mse train_loss 0.0003563958840150351
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012457425
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

255
ELBO train_loss -2.9499708121677615
mse train_loss 0.0003549307763229458
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01471181
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

256
ELBO train_loss -2.757726403902162
mse train_loss 0.0003737848411854532
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.025637068
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

257
ELBO train_loss -2.906325101852417
mse train_loss 0.00039313196831707893
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019313658
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

258
ELBO train_loss -2.916476213707114
mse train_loss 0.0003777305471473637
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.027048381
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

259
ELBO train_loss -2.9142597171495543
mse train_loss 0.000364401580640642
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019345203
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

260
ELBO train_loss -2.908228206184675
mse train_loss 0.000377068721924311
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.034731623
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

261
ELBO train_loss -2.9088325320549733
mse train_loss 0.0003770793711806138
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018641558
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

262
ELBO train_loss -2.9325535432347714
mse train_loss 0.0003637388106413572
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01599474
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

263
ELBO train_loss -2.9450173872821734
mse train_loss 0.00035536675280144545
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012255272
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

264
ELBO train_loss -2.8194729157213896
mse train_loss 0.0003983490152574443
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012144757
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

265
ELBO train_loss -2.9161180550197385
mse train_loss 0.00037404071931718445
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015745431
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

266
ELBO train_loss -2.9161763731038794
mse train_loss 0.0003708117366395012
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016068688
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

267
ELBO train_loss -2.8505448327874237
mse train_loss 0.0004055730129245071
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03279401
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

268
ELBO train_loss -2.895455837249756
mse train_loss 0.0003894228276243117
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016889742
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

269
ELBO train_loss -2.9173616958114335
mse train_loss 0.00037110436769237495
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017863888
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

270
ELBO train_loss -2.9258674720548234
mse train_loss 0.0003729104278565226
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016703244
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

271
ELBO train_loss -2.9307606535137825
mse train_loss 0.0003742146512212337
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019873776
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

272
ELBO train_loss -2.919361114501953
mse train_loss 0.00037435551568259536
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016084205
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

273
ELBO train_loss -2.9222412244328915
mse train_loss 0.0003667864470769001
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.026272032
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

274
ELBO train_loss -2.912887658712999
mse train_loss 0.00037407013807305185
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014560412
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

275
ELBO train_loss -2.9366434520145632
mse train_loss 0.00036729379328635504
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017523516
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

276
ELBO train_loss -2.9187477399718085
mse train_loss 0.00038346962729811596
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.025399704
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

277
ELBO train_loss -2.913163814904555
mse train_loss 0.0003618915699299355
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017769048
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

278
ELBO train_loss -2.9167945429963886
mse train_loss 0.0003655519413889192
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018899791
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

279
ELBO train_loss -2.93100838391286
mse train_loss 0.0003782671251343036
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016621526
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

280
ELBO train_loss -2.9290314665380515
mse train_loss 0.00038819053712221884
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014618573
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

281
ELBO train_loss -2.939847473828298
mse train_loss 0.0003811176624676248
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012690781
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

282
ELBO train_loss -2.935683880212172
mse train_loss 0.0003785408818039973
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01587168
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

283
ELBO train_loss -2.9265491692525036
mse train_loss 0.0003793975827064506
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016535686
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

284
ELBO train_loss -2.9191195964813232
mse train_loss 0.0004170990882898277
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013834353
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

285
ELBO train_loss -2.9516933639094516
mse train_loss 0.00036883541604697284
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010627637
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

286
ELBO train_loss -2.943262909943203
mse train_loss 0.0003587135197792448
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013893442
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

287
ELBO train_loss -2.9122818101127192
mse train_loss 0.0004220330870253438
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02274248
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

288
ELBO train_loss -2.9048552940476617
mse train_loss 0.0003843869124912604
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017708823
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

289
ELBO train_loss -2.9464390322847187
mse train_loss 0.0003402222242790608
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023068676
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

290
ELBO train_loss -2.922982584755376
mse train_loss 0.00035988408380258335
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01659735
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

291
ELBO train_loss -2.9380373190034113
mse train_loss 0.0003698762750059788
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018723346
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

292
ELBO train_loss -2.9388376092011073
mse train_loss 0.00036452971492182323
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01877037
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

293
ELBO train_loss -2.9306471437778114
mse train_loss 0.0003583956584510095
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016015686
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

294
ELBO train_loss -2.918680438455546
mse train_loss 0.0003655575353647846
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015081582
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

295
ELBO train_loss -2.9379231299994126
mse train_loss 0.00037559957813497913
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01682667
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

296
ELBO train_loss -2.9423497757821715
mse train_loss 0.0003576431157647298
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012409329
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

297
ELBO train_loss -2.9237518715408615
mse train_loss 0.0003776365431478106
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.037596896
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

298
ELBO train_loss -2.8882553397484547
mse train_loss 0.00038862857832529425
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019772949
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

299
ELBO train_loss -2.9457975513530226
mse train_loss 0.00036978557538944033
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01366618
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

300
ELBO train_loss -2.9472402761567316
mse train_loss 0.0003634090906004685
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011376514
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

301
ELBO train_loss -2.9341172722150697
mse train_loss 0.000374026957323845
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019133404
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

302
ELBO train_loss -2.9009617589554697
mse train_loss 0.0003561035003447202
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02816685
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

303
ELBO train_loss -2.9288332192402966
mse train_loss 0.0003727148904768258
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019107312
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

304
ELBO train_loss -2.944390593834643
mse train_loss 0.00037276624212643345
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.009539293
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

305
ELBO train_loss -2.9462188864653966
mse train_loss 0.00035675532565636667
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0131752975
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

306
ELBO train_loss -2.9508794613604277
mse train_loss 0.0003896602649101109
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010800472
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

307
ELBO train_loss -2.960016277601134
mse train_loss 0.0003546762257102736
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010794149
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

308
ELBO train_loss -2.9619281022053845
mse train_loss 0.0003601443884621884
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014046232
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

309
ELBO train_loss -2.9495673224611103
mse train_loss 0.00035639787941209024
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019539192
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

310
ELBO train_loss -2.944677343908346
mse train_loss 0.0003575501131016831
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017225228
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

311
ELBO train_loss -2.9392026415411032
mse train_loss 0.00035944682538272145
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015153006
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

312
ELBO train_loss -2.9526061921749474
mse train_loss 0.00035932132129157663
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013576163
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

313
ELBO train_loss -2.958575239721334
mse train_loss 0.0003659865266494102
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023977257
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

314
ELBO train_loss -2.9343313945914216
mse train_loss 0.0003675448714825764
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018379379
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

315
ELBO train_loss -2.93447959198142
mse train_loss 0.0003667869228970716
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0094841495
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

316
ELBO train_loss -2.9615486747813673
mse train_loss 0.00035549205961804136
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012625192
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

317
ELBO train_loss -2.964412437295014
mse train_loss 0.00034504045374287326
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010860322
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

318
ELBO train_loss -2.9560153934190856
mse train_loss 0.00035335973315907395
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012053597
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

319
ELBO train_loss -2.961318124015376
mse train_loss 0.0003645082719883231
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022261214
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

320
ELBO train_loss -2.930555613535755
mse train_loss 0.0003773326027857245
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018674972
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

321
ELBO train_loss -2.937492797959526
mse train_loss 0.0003621142467688154
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012474905
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

322
ELBO train_loss -2.958627035033028
mse train_loss 0.0003495046016123181
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015165456
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

323
ELBO train_loss -2.9437862859582
mse train_loss 0.0003716122342014404
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021837408
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

324
ELBO train_loss -2.8940657827089415
mse train_loss 0.0003907059924457363
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.028594803
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

325
ELBO train_loss -2.918076897567173
mse train_loss 0.0003769035670216599
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015142612
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

326
ELBO train_loss -2.899573052548013
mse train_loss 0.0003787923688266672
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023008099
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

327
ELBO train_loss -2.8934946965496495
mse train_loss 0.0003856908648608709
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021711841
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

328
ELBO train_loss -2.9282662373668744
mse train_loss 0.0003831518719517538
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012735217
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

329
ELBO train_loss -2.9661455199403584
mse train_loss 0.00036396586911682533
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010027954
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

330
ELBO train_loss -2.9506333863960124
mse train_loss 0.00039500478662787673
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010207844
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

331
ELBO train_loss -2.966363542484787
mse train_loss 0.0003567168212886724
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014091648
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

332
ELBO train_loss -2.954635260240087
mse train_loss 0.00037209233071409026
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017136315
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

333
ELBO train_loss -2.9511989467548876
mse train_loss 0.0003728307536495674
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014705048
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

334
ELBO train_loss -2.952071243861936
mse train_loss 0.00038765202620163347
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013037526
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

335
ELBO train_loss -2.951939618812417
mse train_loss 0.0003771348084863451
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01238574
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

336
ELBO train_loss -2.965652956152862
mse train_loss 0.00036740929101382926
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010653278
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

337
ELBO train_loss -2.9526008255076857
mse train_loss 0.0003764356595226827
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011952405
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

338
ELBO train_loss -2.9499913611502016
mse train_loss 0.00035908432344616093
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015846185
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

339
ELBO train_loss -2.9543598462950507
mse train_loss 0.0003687350718054112
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013576304
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

340
ELBO train_loss -2.904710769934474
mse train_loss 0.0003760905306509538
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012922173
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

341
ELBO train_loss -2.918247842563773
mse train_loss 0.00037683756859481054
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021814361
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

342
ELBO train_loss -2.892284501273677
mse train_loss 0.00040811183285802816
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.026041986
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

343
ELBO train_loss -2.9362817485377475
mse train_loss 0.00036308825683762443
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018406734
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

344
ELBO train_loss -2.9507881515430956
mse train_loss 0.00036566851513442707
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013821386
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

345
ELBO train_loss -2.943425664361918
mse train_loss 0.00035684070668517135
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01642147
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

346
ELBO train_loss -2.9349345980950123
mse train_loss 0.00036776374189736636
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018784655
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

347
ELBO train_loss -2.9641857552078537
mse train_loss 0.00034210617068654173
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012351028
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

348
ELBO train_loss -2.954128526291757
mse train_loss 0.0003644106220475943
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019802997
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

349
ELBO train_loss -2.9149695747303515
mse train_loss 0.00037183051238073704
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.025144387
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

350
ELBO train_loss -2.9316589000090114
mse train_loss 0.00034424152696408545
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014085304
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

351
ELBO train_loss -2.937568821997013
mse train_loss 0.000376164911804668
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011390866
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

352
ELBO train_loss -2.9361474626469164
mse train_loss 0.0003650833645018415
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017659018
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

353
ELBO train_loss -2.8923262222757877
mse train_loss 0.0003795423992483367
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03830117
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

354
ELBO train_loss -2.910213011615681
mse train_loss 0.00038325565483834033
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011575796
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

355
ELBO train_loss -2.930310415771772
mse train_loss 0.0003980412697068082
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014133709
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

356
ELBO train_loss -2.9657748240344928
mse train_loss 0.0003521666243751445
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017567547
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

357
ELBO train_loss -2.9608327577698907
mse train_loss 0.000371251726426395
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010774615
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

358
ELBO train_loss -2.9503623089700377
mse train_loss 0.00036930790863478697
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.024370918
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

359
ELBO train_loss -2.9184684123633042
mse train_loss 0.000375701813850956
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019585723
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

360
ELBO train_loss -2.9447483656541356
mse train_loss 0.0003595090658461042
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01457707
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

361
ELBO train_loss -2.95654803402019
mse train_loss 0.0003690073145267923
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019128455
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

362
ELBO train_loss -2.9641358852386475
mse train_loss 0.00036495923611581466
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011814543
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

363
ELBO train_loss -2.9685254501846603
mse train_loss 0.00037363633010448574
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011789844
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

364
ELBO train_loss -2.853595846104172
mse train_loss 0.0004210834526500823
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02278104
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

365
ELBO train_loss -2.931242839345392
mse train_loss 0.0003919830136132901
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018667938
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

366
ELBO train_loss -2.964206691058177
mse train_loss 0.0003529255388783432
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014462035
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

367
ELBO train_loss -2.9771388836626738
mse train_loss 0.00036154688667628987
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010373394
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

368
ELBO train_loss -2.966900834497416
mse train_loss 0.00034574753514772174
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014943248
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

369
ELBO train_loss -2.9680053333066545
mse train_loss 0.00036490240347360805
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012721231
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

370
ELBO train_loss -2.9493382067050575
mse train_loss 0.00037066010376087053
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02316937
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

371
ELBO train_loss -2.9383311721513854
mse train_loss 0.00037392625916872246
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010070399
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

372
ELBO train_loss -2.9786298679855636
mse train_loss 0.00037154922989979033
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012185328
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

373
ELBO train_loss -2.965714346687749
mse train_loss 0.0003447330450828029
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013826711
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

374
ELBO train_loss -2.9461267669245883
mse train_loss 0.0003772264674310309
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012753726
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

375
ELBO train_loss -2.970452385128669
mse train_loss 0.000354936624545481
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.020766037
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

376
ELBO train_loss -2.917051951840239
mse train_loss 0.0003886451016131894
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019948544
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

377
ELBO train_loss -2.939574322610531
mse train_loss 0.00042924030523589055
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014439112
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

378
ELBO train_loss -2.9522071716920384
mse train_loss 0.00039107009674035854
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0152451545
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

379
ELBO train_loss -2.958849353610345
mse train_loss 0.00035545979884567334
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01719663
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

380
ELBO train_loss -2.978078045935001
mse train_loss 0.00034660151264551185
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010983937
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

381
ELBO train_loss -2.978419650275752
mse train_loss 0.0003733027793405542
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011565252
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

382
ELBO train_loss -2.9724651867488645
mse train_loss 0.00037304661099421175
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011820117
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

383
ELBO train_loss -2.9707929143365823
mse train_loss 0.0003786266549108199
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.009579961
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

384
ELBO train_loss -2.9721840957425676
mse train_loss 0.000358149215451635
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018676821
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

385
ELBO train_loss -2.9579295932122
mse train_loss 0.00035446755563892985
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016242232
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

386
ELBO train_loss -2.9694406311467008
mse train_loss 0.0003504062080777795
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014849953
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

387
ELBO train_loss -2.9347344591932476
mse train_loss 0.0004069870911683571
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018833743
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

388
ELBO train_loss -2.957725263991446
mse train_loss 0.00036713226318341803
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010811282
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

389
ELBO train_loss -2.9695722516977563
mse train_loss 0.000355658517328952
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018759053
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

390
ELBO train_loss -2.964800902132718
mse train_loss 0.00036342140481552496
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017332323
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

391
ELBO train_loss -2.972355302774681
mse train_loss 0.00036082639307319624
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023576032
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

392
ELBO train_loss -2.961041212081909
mse train_loss 0.000383677015153533
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.028058048
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

393
ELBO train_loss -2.960207318359951
mse train_loss 0.00037864054704770306
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016458679
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

394
ELBO train_loss -2.897656057240828
mse train_loss 0.0004206763902780125
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015237663
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

395
ELBO train_loss -2.956619217710675
mse train_loss 0.0003645175006482223
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0174633
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

396
ELBO train_loss -2.9564068587321155
mse train_loss 0.0003700658106816195
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014701759
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

397
ELBO train_loss -2.9730813863142482
mse train_loss 0.00037110833350112134
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010556787
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

398
ELBO train_loss -2.9901078737007
mse train_loss 0.00034924527701047547
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011745463
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

399
ELBO train_loss -2.9864885627098805
mse train_loss 0.00036192064053809517
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019589808
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

400
ELBO train_loss -2.945207559837485
mse train_loss 0.00036701561667833404
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019900566
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

401
ELBO train_loss -2.9142080989648713
mse train_loss 0.00036912298495529814
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0304306
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

402
ELBO train_loss -2.914049413968932
mse train_loss 0.0003884477235534984
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02570505
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

403
ELBO train_loss -2.9618095766823247
mse train_loss 0.00036342485197204743
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015575232
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

404
ELBO train_loss -2.9671528969170913
mse train_loss 0.00034113751256504573
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016830208
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

405
ELBO train_loss -2.9527375788058876
mse train_loss 0.00036942074020428336
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014477641
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

406
ELBO train_loss -2.9748265788240253
mse train_loss 0.00036239692307613296
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014654772
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

407
ELBO train_loss -2.981643834204044
mse train_loss 0.0003747805399229346
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012025962
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

408
ELBO train_loss -2.988701397517942
mse train_loss 0.0003598726649421601
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012622027
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

409
ELBO train_loss -2.9770326434441334
mse train_loss 0.0003564497631675792
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021946296
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

410
ELBO train_loss -2.949685184460766
mse train_loss 0.00047452496101130855
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017009899
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

411
ELBO train_loss -2.9426065233518495
mse train_loss 0.00039981614868595916
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022363449
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

412
ELBO train_loss -2.9612483978271484
mse train_loss 0.0003828237287344341
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01691817
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

413
ELBO train_loss -2.9824393515316947
mse train_loss 0.0003561826289921366
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015238111
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

414
ELBO train_loss -2.9551421628808074
mse train_loss 0.0003678212112545053
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015870411
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

415
ELBO train_loss -2.9742967272704504
mse train_loss 0.00036756085305483484
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011666598
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

416
ELBO train_loss -2.969380464193956
mse train_loss 0.00035640569365328565
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.029018532
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

417
ELBO train_loss -2.9653939301112913
mse train_loss 0.0003710330954317952
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012506554
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

418
ELBO train_loss -2.9891209467402042
mse train_loss 0.0003624758425054474
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01302385
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

419
ELBO train_loss -2.957971266980441
mse train_loss 0.00036546985499609155
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.024411235
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

420
ELBO train_loss -2.9367687522240407
mse train_loss 0.0004244502243928541
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.020359829
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

421
ELBO train_loss -2.9586267471313477
mse train_loss 0.0003715385724065544
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016282815
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

422
ELBO train_loss -2.9586374489766247
mse train_loss 0.00036368286154751296
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014722385
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

423
ELBO train_loss -2.9434335614150426
mse train_loss 0.00037559147271819693
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02320005
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

424
ELBO train_loss -2.9634718580066033
mse train_loss 0.0003689997412158914
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011423927
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

425
ELBO train_loss -2.9819765990635134
mse train_loss 0.0003625754311456749
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011468711
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

426
ELBO train_loss -2.959725559882398
mse train_loss 0.0004443954044891486
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012139044
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

427
ELBO train_loss -2.9671381419559695
mse train_loss 0.00039332623418886214
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010311372
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

428
ELBO train_loss -2.9823758107311322
mse train_loss 0.00035255876334630093
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010142599
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

429
ELBO train_loss -2.987106012848188
mse train_loss 0.0003539123171545073
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013864804
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

430
ELBO train_loss -2.965222790556134
mse train_loss 0.0003619120524370305
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017827023
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

431
ELBO train_loss -2.975477025193988
mse train_loss 0.0003551796453736968
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016235348
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

432
ELBO train_loss -2.917716253478572
mse train_loss 0.0004261588410887305
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021032268
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

433
ELBO train_loss -2.9340816331359574
mse train_loss 0.00046359231631833847
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022387914
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

434
ELBO train_loss -2.955230029124134
mse train_loss 0.00037335404354138827
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013971854
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

435
ELBO train_loss -2.985091398347099
mse train_loss 0.00037548501063182176
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.009318286
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

436
ELBO train_loss -2.964909983131121
mse train_loss 0.0003640755166998133
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01253686
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

437
ELBO train_loss -2.9890105769319355
mse train_loss 0.00036950887154506626
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011520371
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

438
ELBO train_loss -3.007365568628851
mse train_loss 0.0003533663912094198
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011061987
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

439
ELBO train_loss -2.9899752364968353
mse train_loss 0.0003660821030795012
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018032897
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

440
ELBO train_loss -2.9872064320546277
mse train_loss 0.0003495875346245912
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015582264
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

441
ELBO train_loss -2.9796459629850567
mse train_loss 0.0004165112480240048
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014687313
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

442
ELBO train_loss -2.9803716596567407
mse train_loss 0.0003727551089503082
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01504781
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

443
ELBO train_loss -2.9905974864959717
mse train_loss 0.0003485266250483516
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012804199
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

444
ELBO train_loss -2.985573647157201
mse train_loss 0.0003736502507960705
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012315167
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

445
ELBO train_loss -2.9997014864435734
mse train_loss 0.0003357195429380913
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012831294
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

446
ELBO train_loss -2.991765764524352
mse train_loss 0.0003581695458500312
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01427164
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

447
ELBO train_loss -2.997405583003782
mse train_loss 0.0003486521570077869
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012486503
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

448
ELBO train_loss -2.9908531206958697
mse train_loss 0.0003777758802072022
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015061377
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

449
ELBO train_loss -2.992982598970521
mse train_loss 0.0003635631263297807
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012238546
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

450
ELBO train_loss -2.9874682966268287
mse train_loss 0.0003674985401064403
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015844837
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

451
ELBO train_loss -2.9792039529332577
mse train_loss 0.0003678066285880599
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01266623
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

452
ELBO train_loss -2.989597536482901
mse train_loss 0.0003687829846727117
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018264836
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

453
ELBO train_loss -2.990742170585776
mse train_loss 0.0003637714216206662
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015368823
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

454
ELBO train_loss -2.9929688201760345
mse train_loss 0.0003567018177149431
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01096373
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

455
ELBO train_loss -2.984829358334811
mse train_loss 0.0004307974826210653
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011658567
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

456
ELBO train_loss -3.0041734497502164
mse train_loss 0.000356404539933516
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011410241
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

457
ELBO train_loss -2.9910787726348302
mse train_loss 0.00036603717527677356
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017063694
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

458
ELBO train_loss -2.9818866567791633
mse train_loss 0.00035515166808113035
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01845074
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

459
ELBO train_loss -2.9868062982019388
mse train_loss 0.00035300024017419245
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021548752
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

460
ELBO train_loss -2.996818227588006
mse train_loss 0.00037239799643910647
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01189068
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

461
ELBO train_loss -2.9552131104019455
mse train_loss 0.00042870650858671034
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010089973
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

462
ELBO train_loss -2.986301318654474
mse train_loss 0.0003696862560969268
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018018583
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

463
ELBO train_loss -2.8139544972833597
mse train_loss 0.0003811599470385052
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.030825162
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

464
ELBO train_loss -2.895920164180252
mse train_loss 0.00039955719845272814
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021895662
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

465
ELBO train_loss -2.9473216038829877
mse train_loss 0.00042611182494888537
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013263225
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

466
ELBO train_loss -2.9796244018482714
mse train_loss 0.0004384460186866938
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014336251
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

467
ELBO train_loss -2.995923393177536
mse train_loss 0.0003489277798919675
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010210706
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

468
ELBO train_loss -3.0082861882335736
mse train_loss 0.00036061073780797843
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010869487
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

469
ELBO train_loss -3.010650378353191
mse train_loss 0.00035491401828245596
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012135844
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

470
ELBO train_loss -2.9751142803228126
mse train_loss 0.0003609719756498652
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012931635
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

471
ELBO train_loss -2.9761626315566727
mse train_loss 0.0003526705852050756
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010397351
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

472
ELBO train_loss -2.997911619690229
mse train_loss 0.00036990243111262625
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01240684
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

473
ELBO train_loss -3.0044726695654527
mse train_loss 0.00036284664425041246
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012946015
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

474
ELBO train_loss -2.979237419254375
mse train_loss 0.0003667636812919364
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01876908
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

475
ELBO train_loss -2.960313122227507
mse train_loss 0.0003774207440108272
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012814902
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

476
ELBO train_loss -2.9168880075778603
mse train_loss 0.00036911520929040634
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03337346
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

477
ELBO train_loss -2.9118484955913617
mse train_loss 0.00040678813381401715
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.022122564
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

478
ELBO train_loss -2.9561525020959243
mse train_loss 0.0003710345001034614
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019590514
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

479
ELBO train_loss -2.959011433259496
mse train_loss 0.00037649411456615507
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013143368
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

480
ELBO train_loss -2.943461537361145
mse train_loss 0.00038332859397962476
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017345652
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

481
ELBO train_loss -2.940511663005037
mse train_loss 0.00037937580233325587
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013944931
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

482
ELBO train_loss -2.984730711523092
mse train_loss 0.0003723623005910513
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012768371
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

483
ELBO train_loss -2.968183458976026
mse train_loss 0.000383004015981536
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011765065
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

484
ELBO train_loss -2.9678228936105406
mse train_loss 0.00038664510368176225
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011820395
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

485
ELBO train_loss -2.9738858690801657
mse train_loss 0.00036914371841428975
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0147032505
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

486
ELBO train_loss -2.9913501694517315
mse train_loss 0.0003606863987065395
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.009626982
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

487
ELBO train_loss -2.981534926396496
mse train_loss 0.00037011804189840506
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012508623
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

488
ELBO train_loss -2.990398447468596
mse train_loss 0.000354477863062507
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015372741
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

489
ELBO train_loss -2.9962790192298168
mse train_loss 0.0003536466464075207
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01203144
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

490
ELBO train_loss -3.00979098733866
mse train_loss 0.00035093621848335596
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.008257949
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

491
ELBO train_loss -3.0052395091866546
mse train_loss 0.00036128227464977725
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01074045
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

492
ELBO train_loss -3.001719740201842
mse train_loss 0.0003630765001256599
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010754022
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

493
ELBO train_loss -3.0128556152559676
mse train_loss 0.00035252449948907353
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011855407
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

494
ELBO train_loss -2.9627193251870714
mse train_loss 0.00035474641289219333
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016184298
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

495
ELBO train_loss -2.990475447672718
mse train_loss 0.0003709159618493978
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018510945
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

496
ELBO train_loss -3.005594509952473
mse train_loss 0.0003598360797798894
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.009891844
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

497
ELBO train_loss -3.005831241607666
mse train_loss 0.0003565006818290677
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010970746
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

498
ELBO train_loss -3.003398980734483
mse train_loss 0.0003764344124510041
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01130478
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

499
ELBO train_loss -3.009338545349409
mse train_loss 0.0003385727542544768
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010860354
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

500
ELBO train_loss -2.9201576552301085
mse train_loss 0.00044762928072881515
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021446383
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

501
ELBO train_loss -2.959456963359185
mse train_loss 0.000373575766131563
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.025922999
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

502
ELBO train_loss -2.97331683590727
mse train_loss 0.0003705723183712799
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012601084
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

503
ELBO train_loss -3.0131612903666944
mse train_loss 0.00034577752691957185
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01084631
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

504
ELBO train_loss -3.0062769448982096
mse train_loss 0.000361793284372532
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015605138
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

505
ELBO train_loss -2.986364836962718
mse train_loss 0.00037120914834613014
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01227501
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

506
ELBO train_loss -2.929221713318015
mse train_loss 0.0003654955126310013
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013916384
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

507
ELBO train_loss -3.008693294705085
mse train_loss 0.0003570613925200272
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013871472
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

508
ELBO train_loss -2.9987595846068182
mse train_loss 0.00037364220880367355
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016782384
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

509
ELBO train_loss -3.0054235998189673
mse train_loss 0.0003553684836556643
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014781209
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

510
ELBO train_loss -3.005397544716889
mse train_loss 0.0003688245347837197
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011410528
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

511
ELBO train_loss -3.010487331534332
mse train_loss 0.0003590091169517452
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012448257
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

512
ELBO train_loss -2.9830100221453972
mse train_loss 0.0003796466722324455
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015015772
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

513
ELBO train_loss -3.01002018406706
mse train_loss 0.00038105736454764274
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011161655
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

514
ELBO train_loss -3.0135941955278502
mse train_loss 0.0003428231274878237
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013427736
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

515
ELBO train_loss -3.0067191078977764
mse train_loss 0.00038775095231787144
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014185799
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

516
ELBO train_loss -2.9886230140362144
mse train_loss 0.000370308252453435
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016104497
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

517
ELBO train_loss -2.9764732324852132
mse train_loss 0.00039275719637033653
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013883052
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

518
ELBO train_loss -3.018434632499263
mse train_loss 0.0003423792601646504
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012744805
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

519
ELBO train_loss -3.0033152643239722
mse train_loss 0.0003690703644061391
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013998458
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

520
ELBO train_loss -3.0169293970431923
mse train_loss 0.0003635547289137184
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01103261
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

521
ELBO train_loss -3.0104401111602783
mse train_loss 0.00035732052149470755
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015260638
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

522
ELBO train_loss -2.969531585585396
mse train_loss 0.00037948182868586747
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03063122
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

523
ELBO train_loss -2.969559759463904
mse train_loss 0.0003643069662843427
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014226574
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

524
ELBO train_loss -3.0091970911565817
mse train_loss 0.0003578767393504814
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011253978
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

525
ELBO train_loss -3.0073047628942526
mse train_loss 0.0003646158131129407
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016874216
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

526
ELBO train_loss -2.996961724083379
mse train_loss 0.0003661097912557142
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013812962
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

527
ELBO train_loss -2.9989469501207457
mse train_loss 0.00036605097735652105
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016047928
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

528
ELBO train_loss -2.984137179716578
mse train_loss 0.0003735588021677444
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015191004
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

529
ELBO train_loss -2.975395054187415
mse train_loss 0.0003806568348872528
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013400959
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

530
ELBO train_loss -2.9939624161090492
mse train_loss 0.00045696131084421824
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014514794
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

531
ELBO train_loss -3.012997181910389
mse train_loss 0.0003578688533106854
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011508293
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

532
ELBO train_loss -3.010739821308064
mse train_loss 0.00036828610472236143
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010059639
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

533
ELBO train_loss -3.0197639285393483
mse train_loss 0.0003526860958993224
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.00992197
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

534
ELBO train_loss -3.007462294596546
mse train_loss 0.0003795390696053938
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011767263
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

535
ELBO train_loss -3.0187687064116857
mse train_loss 0.00037387995566087685
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013071958
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

536
ELBO train_loss -3.026412419553073
mse train_loss 0.0003646202945537022
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.009053444
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

537
ELBO train_loss -3.0136801296809934
mse train_loss 0.00036267451568559373
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013343222
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

538
ELBO train_loss -3.0207229605260886
mse train_loss 0.0003469208030816484
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012265603
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

539
ELBO train_loss -3.0169822404969415
mse train_loss 0.0003632245871418525
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014438968
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

540
ELBO train_loss -3.0037460866964087
mse train_loss 0.0003753743287645948
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01148202
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

541
ELBO train_loss -3.012625019505339
mse train_loss 0.00036530420944678934
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010390154
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

542
ELBO train_loss -3.0199988698059657
mse train_loss 0.0003661666002853791
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.00982803
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

543
ELBO train_loss -3.012319218437627
mse train_loss 0.0003725494590877854
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021319497
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

544
ELBO train_loss -2.9802523221609727
mse train_loss 0.0003727036591439139
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01507257
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

545
ELBO train_loss -2.991712934565994
mse train_loss 0.0003582834252947063
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013600694
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

546
ELBO train_loss -3.018823205300097
mse train_loss 0.00037511377480047985
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012908155
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

547
ELBO train_loss -3.024139111896731
mse train_loss 0.0003496310750532122
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010669006
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

548
ELBO train_loss -3.0215936291892573
mse train_loss 0.00037785845491288336
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014104011
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

549
ELBO train_loss -3.0302267974277712
mse train_loss 0.00034892639649904524
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012218856
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

550
ELBO train_loss -3.0267220263211234
mse train_loss 0.0003651222134479058
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.031815056
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

551
ELBO train_loss -3.0211420778958304
mse train_loss 0.0003822223855268193
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014430227
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

552
ELBO train_loss -3.0089977417352065
mse train_loss 0.0003857419728345396
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016584268
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

553
ELBO train_loss -3.015391574715668
mse train_loss 0.000366227857258944
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012458476
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

554
ELBO train_loss -3.030875232984435
mse train_loss 0.0003561356334446244
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015945286
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

555
ELBO train_loss -3.0258791221762604
mse train_loss 0.0003538441384114536
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010774056
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

556
ELBO train_loss -3.021581663275665
mse train_loss 0.0003541102800364518
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.028568905
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

557
ELBO train_loss -3.0068560636268473
mse train_loss 0.0004276417633452681
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0153001305
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

558
ELBO train_loss -3.034841631943325
mse train_loss 0.0003537833896566639
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012260866
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

559
ELBO train_loss -3.0350641844407567
mse train_loss 0.0003675265736847167
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.027262157
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

560
ELBO train_loss -3.0266065147687806
mse train_loss 0.0003653604299475808
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.026396688
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

561
ELBO train_loss -3.0428344708568646
mse train_loss 0.0003679528374462125
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015570975
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

562
ELBO train_loss -3.0070438429994404
mse train_loss 0.0003762412264411566
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023540987
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

563
ELBO train_loss -2.9704085475993605
mse train_loss 0.000397548833993738
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015563863
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

564
ELBO train_loss -3.045656941971689
mse train_loss 0.00036348759681030334
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012418874
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

565
ELBO train_loss -3.044751635137594
mse train_loss 0.0003627567057264969
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010629733
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

566
ELBO train_loss -3.053323889678379
mse train_loss 0.00035206496716643914
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0120893475
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

567
ELBO train_loss -3.051156718775911
mse train_loss 0.00036981272624304765
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0107938815
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

568
ELBO train_loss -3.0415842128249833
mse train_loss 0.0003702626039190569
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012498744
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

569
ELBO train_loss -3.0200669900426327
mse train_loss 0.00038112325479610346
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014958619
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

570
ELBO train_loss -3.044704518228207
mse train_loss 0.0003698164622413592
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012198879
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

571
ELBO train_loss -3.0430750082123956
mse train_loss 0.00036844898074017484
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013623525
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

572
ELBO train_loss -3.041359707994281
mse train_loss 0.0003585972271877218
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010018591
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

573
ELBO train_loss -3.0641035898676456
mse train_loss 0.0003622290447688187
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017665818
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

574
ELBO train_loss -3.063297406682428
mse train_loss 0.0003470202192954086
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012585844
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

575
ELBO train_loss -3.0500215044561423
mse train_loss 0.00034793216942695306
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012004771
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

576
ELBO train_loss -3.0463323863047473
mse train_loss 0.00035641027970309807
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010589026
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

577
ELBO train_loss -3.0679318410045697
mse train_loss 0.0003447078993732883
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0155790625
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

578
ELBO train_loss -3.0632315806622774
mse train_loss 0.0003564837968025251
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014968122
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

579
ELBO train_loss -3.0655516228585875
mse train_loss 0.0003721972129566876
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01440255
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

580
ELBO train_loss -3.0772550285987132
mse train_loss 0.0003562426572898403
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013440088
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

581
ELBO train_loss -3.067584712550325
mse train_loss 0.00034222334984663593
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010265485
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

582
ELBO train_loss -3.0725075973654694
mse train_loss 0.00037826195357582655
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013019454
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

583
ELBO train_loss -3.0819240875963896
mse train_loss 0.00035604299183381407
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010090497
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

584
ELBO train_loss -3.0809133322733753
mse train_loss 0.00037400335988319776
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.008573927
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

585
ELBO train_loss -3.0743978878237166
mse train_loss 0.0003454459741610577
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014184293
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

586
ELBO train_loss -3.0816505180214935
mse train_loss 0.00034620328136403467
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012520655
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

587
ELBO train_loss -3.0893638134002686
mse train_loss 0.00036977999486723247
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.009974335
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

588
ELBO train_loss -3.0850565298548283
mse train_loss 0.0003579321362952403
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012252722
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

589
ELBO train_loss -3.087121504657673
mse train_loss 0.0003567585292742324
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.027775103
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

590
ELBO train_loss -3.0472281507725985
mse train_loss 0.00038663562544278873
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011944087
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

591
ELBO train_loss -3.0667266665764576
mse train_loss 0.00037705930985536707
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014542385
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

592
ELBO train_loss -3.0761700171344684
mse train_loss 0.00034923575744955397
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014824448
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

593
ELBO train_loss -3.0842541973545865
mse train_loss 0.00036657011089187536
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012988422
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

594
ELBO train_loss -3.0968939853164383
mse train_loss 0.0003586670058177573
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023033971
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

595
ELBO train_loss -3.1058572463269503
mse train_loss 0.0003535009531813832
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0141682075
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

596
ELBO train_loss -3.0491612332539177
mse train_loss 0.0003570002217596959
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012178422
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

597
ELBO train_loss -3.0925121532296234
mse train_loss 0.00036860070189127244
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011230748
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

598
ELBO train_loss -3.1057131245451153
mse train_loss 0.0003559043435193159
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01533327
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

599
ELBO train_loss -3.1130415718510465
mse train_loss 0.0003725988922183806
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010913912
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

600
ELBO train_loss -3.044097014193265
mse train_loss 0.0004031366099194922
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0137194665
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

601
ELBO train_loss -3.071977439916359
mse train_loss 0.0003646521882334162
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02044652
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

602
ELBO train_loss -3.108085807764305
mse train_loss 0.0003651664045942455
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.020983035
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

603
ELBO train_loss -3.1112709585225806
mse train_loss 0.00035356547554981244
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.009208245
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

604
ELBO train_loss -3.123910571044346
mse train_loss 0.00034562079771312027
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0145005165
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

605
ELBO train_loss -3.1136767504350193
mse train_loss 0.0003557305931023603
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018221581
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

606
ELBO train_loss -3.1084509435689673
mse train_loss 0.0003594932385790601
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012457988
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

607
ELBO train_loss -3.1042747992389605
mse train_loss 0.0003655730540214001
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012098927
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

608
ELBO train_loss -3.1114418146745213
mse train_loss 0.00035960912316782787
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014013546
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

609
ELBO train_loss -3.1215127009265826
mse train_loss 0.00034677960749557416
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01403374
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

610
ELBO train_loss -3.090312142417116
mse train_loss 0.0003525533250484721
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.029791884
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

611
ELBO train_loss -3.0990195499276214
mse train_loss 0.00036742706274622524
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015164272
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

612
ELBO train_loss -3.076111966708921
mse train_loss 0.00043961555446621103
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03821163
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

613
ELBO train_loss -2.8849311999554903
mse train_loss 0.00037514038311475714
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019979626
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

614
ELBO train_loss -2.9466287414982633
mse train_loss 0.00037096199916702044
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01461724
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

615
ELBO train_loss -2.9923710778074444
mse train_loss 0.00034918367202648506
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019739915
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

616
ELBO train_loss -3.015027055200541
mse train_loss 0.00037171043409473914
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017303912
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

617
ELBO train_loss -3.04310619606162
mse train_loss 0.0003572136777637632
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015013591
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

618
ELBO train_loss -3.068444346481899
mse train_loss 0.00035026869247887143
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015906928
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

619
ELBO train_loss -3.073092150238325
mse train_loss 0.0003794854976904561
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.00942278
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

620
ELBO train_loss -3.095283071949797
mse train_loss 0.00035230691005467033
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012914964
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

621
ELBO train_loss -3.1058993114615387
mse train_loss 0.00035370449576184226
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0154341785
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

622
ELBO train_loss -3.111399736044542
mse train_loss 0.00036517627573493026
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.009761806
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

623
ELBO train_loss -3.112572872413779
mse train_loss 0.0003448767766833112
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012376208
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

624
ELBO train_loss -3.128709019355054
mse train_loss 0.0003568756071200489
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.024029654
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

625
ELBO train_loss -3.140510275678815
mse train_loss 0.00036308000370517446
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010761074
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

626
ELBO train_loss -3.1356253893870227
mse train_loss 0.00034877858212533985
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010751532
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

627
ELBO train_loss -3.1276328068859174
mse train_loss 0.00036917413548645474
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010902088
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

628
ELBO train_loss -3.124426490855667
mse train_loss 0.0003694928611435418
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014008356
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

629
ELBO train_loss -3.124304474524732
mse train_loss 0.00036776400081163166
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010757049
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

630
ELBO train_loss -3.1414206162938534
mse train_loss 0.0003660852946165316
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012657199
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

631
ELBO train_loss -3.1176610982642985
mse train_loss 0.00038482571368870095
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019410111
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

632
ELBO train_loss -3.1258049775969305
mse train_loss 0.00035803075627212957
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.020907251
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

633
ELBO train_loss -3.140511193365421
mse train_loss 0.0003399586838775508
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.009623252
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

634
ELBO train_loss -3.1464484907546133
mse train_loss 0.000354919960406029
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011684744
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

635
ELBO train_loss -3.1482765989483528
mse train_loss 0.0003606001884931312
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011243365
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

636
ELBO train_loss -3.1303342468333692
mse train_loss 0.0003790032477689169
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019427255
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

637
ELBO train_loss -3.120215020089779
mse train_loss 0.0003560568998954347
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014264525
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

638
ELBO train_loss -3.1522854634051054
mse train_loss 0.000352654319733188
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011539043
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

639
ELBO train_loss -3.1575722604427696
mse train_loss 0.0003625527090164569
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012819426
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

640
ELBO train_loss -3.141640541688451
mse train_loss 0.0003500482628126365
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012807744
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

641
ELBO train_loss -3.1520650116902478
mse train_loss 0.00034128789425353116
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011983101
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

642
ELBO train_loss -3.1432759491902478
mse train_loss 0.00040280277443223824
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.024417011
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

643
ELBO train_loss -3.077586695833026
mse train_loss 0.0003913460120356659
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.039184712
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

644
ELBO train_loss -3.1496138932569973
mse train_loss 0.0003620556732507879
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010962108
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

645
ELBO train_loss -3.166074460407473
mse train_loss 0.00034244759595856284
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011570456
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

646
ELBO train_loss -2.9788584259321107
mse train_loss 0.00042818871492581476
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.042123202
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

647
ELBO train_loss -3.0852077457140075
mse train_loss 0.00035927954233110933
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02978024
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

648
ELBO train_loss -3.1494943825703747
mse train_loss 0.00035873539048413976
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02365448
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

649
ELBO train_loss -3.179914613939681
mse train_loss 0.00033718289739896
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.042401224
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

650
ELBO train_loss -3.1667210470955327
mse train_loss 0.00037916019539137915
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01038504
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

651
ELBO train_loss -3.1760255165819853
mse train_loss 0.00036459473562470794
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.009494026
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

652
ELBO train_loss -3.172026418290048
mse train_loss 0.000341494880421175
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.04980095
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

653
ELBO train_loss -3.160310794722359
mse train_loss 0.0003507009028942535
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021760792
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

654
ELBO train_loss -3.1850310991395197
mse train_loss 0.000355055442443425
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010815708
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

655
ELBO train_loss -3.1903823771566713
mse train_loss 0.0003420261535486031
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.034287848
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

656
ELBO train_loss -3.1770839601192833
mse train_loss 0.00035512868511569597
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013766807
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

657
ELBO train_loss -3.1737688217522964
mse train_loss 0.00035541292657759394
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.101771966
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

658
ELBO train_loss -3.1178284591099
mse train_loss 0.00035066966171692703
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023129579
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

659
ELBO train_loss -3.1504086975781425
mse train_loss 0.00041686349086032934
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018221006
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

660
ELBO train_loss -3.173454959437532
mse train_loss 0.0003516917476070307
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019541165
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

661
ELBO train_loss -3.1804589505465524
mse train_loss 0.000348301270245223
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014971787
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

662
ELBO train_loss -3.175193723642601
mse train_loss 0.00037290779066509304
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014862021
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

663
ELBO train_loss -3.139890472843962
mse train_loss 0.0003988072460073591
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02623039
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

664
ELBO train_loss -3.0232776200996256
mse train_loss 0.00039374487055026275
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.028472953
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

665
ELBO train_loss -3.1519020256006494
mse train_loss 0.00037026739644132695
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017241102
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

666
ELBO train_loss -3.171225408338151
mse train_loss 0.0003679275248015194
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.024376586
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

667
ELBO train_loss -3.18026742395365
mse train_loss 0.0003619805238657754
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014129191
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

668
ELBO train_loss -3.1912151507611544
mse train_loss 0.00035212699346956006
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010567215
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

669
ELBO train_loss -3.190836582543715
mse train_loss 0.00036018968033736117
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0122805815
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

670
ELBO train_loss -3.191329281285124
mse train_loss 0.0003535250274031774
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.02141952
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

671
ELBO train_loss -3.185905654475374
mse train_loss 0.00035899694674573665
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021743977
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

672
ELBO train_loss -3.191414959025833
mse train_loss 0.0005007981559380292
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012147647
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

673
ELBO train_loss -3.200642653231351
mse train_loss 0.0003904470446835852
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.04024268
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

674
ELBO train_loss -3.1938662034160687
mse train_loss 0.00036256785399087194
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015312696
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

675
ELBO train_loss -3.210690444370486
mse train_loss 0.00037032158694994406
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01967043
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

676
ELBO train_loss -3.1946271815389955
mse train_loss 0.0003853695923440545
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01334219
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

677
ELBO train_loss -3.0243926880494603
mse train_loss 0.00038561000151533353
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.085661866
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

678
ELBO train_loss -3.120729198995626
mse train_loss 0.0003849617051804122
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.030327873
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

679
ELBO train_loss -3.1965764018724547
mse train_loss 0.00034976346151384895
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013091515
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

680
ELBO train_loss -3.2128915966681713
mse train_loss 0.00036513852353651464
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012847767
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

681
ELBO train_loss -3.193607939864105
mse train_loss 0.00040027346458716565
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012220801
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

682
ELBO train_loss -3.2058821444241508
mse train_loss 0.0003765564807800506
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012819303
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

683
ELBO train_loss -3.2200132585921377
mse train_loss 0.0003458483093571937
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013646087
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

684
ELBO train_loss -3.2251498654203594
mse train_loss 0.0003458980077170363
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023922948
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

685
ELBO train_loss -3.224615699840042
mse train_loss 0.0003493143366935294
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013332954
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

686
ELBO train_loss -3.2283305896902985
mse train_loss 0.00036359454322834763
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016974794
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

687
ELBO train_loss -3.226599580836746
mse train_loss 0.00034937062267793937
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.015407589
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

688
ELBO train_loss -3.2287645025073357
mse train_loss 0.0003520347523109829
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.0123443855
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

689
ELBO train_loss -3.235228277602286
mse train_loss 0.00037929863599787977
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.025494438
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

690
ELBO train_loss -3.240288648965224
mse train_loss 0.0003458892437574489
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.013800731
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

691
ELBO train_loss -3.2376787437582917
mse train_loss 0.00036780724622076377
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011346192
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

692
ELBO train_loss -3.2368734827581442
mse train_loss 0.00035112160027361
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.010370887
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

693
ELBO train_loss -3.240170919670249
mse train_loss 0.00035369034388700046
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012061499
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

694
ELBO train_loss -3.2408111140413105
mse train_loss 0.00040073708888090586
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.04925789
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

695
ELBO train_loss -3.2128715605105995
mse train_loss 0.0003625203172768519
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.011920876
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

696
ELBO train_loss -3.231299895160603
mse train_loss 0.0003651919893322969
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.012051086
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

697
ELBO train_loss -3.2416453271541954
mse train_loss 0.0003749323272440618
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.019404398
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

698
ELBO train_loss -3.2421668610482848
mse train_loss 0.00033612983610990614
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.023281118
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

699
ELBO train_loss -3.2295216794283883
mse train_loss 0.00037925126896468255
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01751405
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

700
ELBO train_loss -3.215396723657284
mse train_loss 0.0003982369388774354
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018904917
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

701
ELBO train_loss -3.2346770043643014
mse train_loss 0.00035108759711531675
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.028848104
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

702
ELBO train_loss -3.236990456311208
mse train_loss 0.0003568366159496056
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014739772
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

703
ELBO train_loss -3.2490914317796813
mse train_loss 0.0003655116498812963
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014889522
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

704
ELBO train_loss -3.242183316428706
mse train_loss 0.0003782826407707103
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.018669982
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

705
ELBO train_loss -3.242716127971433
mse train_loss 0.0003730615050167302
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.027037501
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

706
ELBO train_loss -3.215073383079385
mse train_loss 0.0003626884190796189
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03430771
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

707
ELBO train_loss -3.252434923963727
mse train_loss 0.0003663710255905072
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.01383243
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

708
ELBO train_loss -3.24580313574593
mse train_loss 0.00038051643169414746
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017978817
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

709
ELBO train_loss -3.2418219098504983
mse train_loss 0.0003754524796880465
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.017922299
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

710
ELBO train_loss -3.2060360458661927
mse train_loss 0.00037200654786023014
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.014887975
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

711
ELBO train_loss -3.2496387283757047
mse train_loss 0.00035575572577006414
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.042479232
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

712
ELBO train_loss -3.24133237352911
mse train_loss 0.0003655838956100919
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.03384453
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

713
ELBO train_loss -3.250419265819046
mse train_loss 0.0003703219018753548
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016448788
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

714
ELBO train_loss -3.24164154844464
mse train_loss 0.0003682042602208039
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.027619187
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

715
ELBO train_loss -3.253242218269492
mse train_loss 0.0003721281287001165
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.016647995
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

716
ELBO train_loss -3.2642099182560758
mse train_loss 0.00034525178351993055
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.036524754
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

717
ELBO train_loss -3.263999507112323
mse train_loss 0.00037628159674877335
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])
torch.Size([1, 5209, 4])
val_loss =  0.021006078
-----------------------------------------------------------------------


  0%|          | 0/53 [00:00<?, ?it/s]

718
ELBO train_loss -3.2425110497564638
mse train_loss 0.0003764521092273084
torch.Size([1, 20837, 2]) torch.Size([1, 20837, 4]) torch.Size([1, 5209, 2]) torch.Size([1, 5209, 4])


KeyboardInterrupt: 

# Meuse data testing

In [None]:
%%capture
!pip install geopandas

In [None]:
import os,requests
import geopandas as gpd

In [None]:
allow_columns = ['lead', 'elev', 'dist', 'x', 'y', 'ffreq', 'soil']
label ='lead'

In [None]:
meuse = gpd.read_file('meuse')
meuse.crs = {'init':'epsg:28992'}
meuse['x'] = meuse['geometry'].apply(lambda x: x.x)
meuse['y'] = meuse['geometry'].apply(lambda x: x.y)
meuse.sample(2)

In [None]:
for col in ['ffreq', 'soil']:
    meuse[col] = pd.to_numeric(meuse[col]) * 1.0
meuse[allow_columns].info()

In [None]:
np.random.seed(0)
test_indexes = np.random.choice(a=meuse.index, size=int(np.round(len(meuse.index.values)/4)))
train_indexes = [index for index in meuse.index if index not in test_indexes]
meuse_test = meuse.loc[test_indexes,:].copy()
meuse_train = meuse.loc[train_indexes,:].copy()
print('Number of observations in training: {}, in test: {}'.format(len(meuse_train), len(meuse_test)))

In [None]:
df_train = meuse_train[allow_columns].copy()
norm_mean = meuse[allow_columns].mean()
norm_std = meuse[allow_columns].std()

hparams = dict(
    num_context=15,
    num_extra_target=16,
    batch_size=40,
    context_in_target=False,
)


In [None]:
class DataSet(torch.utils.data.Dataset):
    def __init__(self, df, num_context=40, num_extra_target=10, label_names=['lead']):
        self.df = df
        self.num_context = num_context
        self.num_extra_target = num_extra_target
        self.label_names = label_names

    def get_rows(self, i):
        rows = self.df.iloc[i : i + (self.num_context + self.num_extra_target)].copy()

        # make sure tstp, which is our x axis, is the first value

        # This will be the last row, and will change it upon sample to let the model know some points are in the future

        x = rows.drop(columns=self.label_names).copy()
        y = rows[self.label_names].copy()
#         print(x.shape)
#         print(y.shape)
        return x, y


    def __getitem__(self, i):
        x, y = self.get_rows(i)
        return x.values, y.values
        
    def __len__(self):
        return len(self.df) - (self.num_context + self.num_extra_target)


In [None]:
for i, batch in enumerate(loader_train):
    a,b,c,d = batch
    print(d.shape)

In [None]:
df_train = meuse_train[allow_columns].copy()
df_train -= norm_mean
df_train /= norm_std

data_train = DataSet(
    df_train, hparams["num_context"], hparams["num_extra_target"]
)
loader_train = torch.utils.data.DataLoader(
    data_train,
    batch_size=hparams["batch_size"],
    shuffle=True,
    collate_fn=collate_fns(
        hparams["num_context"], hparams["num_extra_target"], sample=True, context_in_target=hparams["context_in_target"]
    ),
)

In [None]:
df_test = meuse_test[allow_columns].copy()
df_test -= norm_mean
df_test /= norm_std

data_test = DataSet(
    df_test, hparams["num_context"], hparams["num_extra_target"]
)
loader_test = torch.utils.data.DataLoader(
    data_test,
    batch_size=hparams["batch_size"],
    shuffle=False,
    collate_fn=collate_fns(
        hparams["num_context"], hparams["num_extra_target"], sample=False, context_in_target=hparams["context_in_target"]
    ),
)

In [None]:
Regressor = LatentModel(len(allow_columns)-1,1,
                       p_drop = 0.5,
                        hidden_dim = 64,
                        latent_dim = 16,
                       n_decoder_layer = 3,
                       norm = 'batch',
                       context_in_target=False
                       )

In [None]:
from tqdm.auto import tqdm 

for epoch in range(1000):
    loss = 0 
    Regressor.train()
    for batch in tqdm(loader_train):
        context_x, context_y, target_x, target_y = batch
        print(context_x.shape)
        Regressor.zero_grad()
        y_pred, losses, extra = Regressor.forward(context_x, context_y, target_x, target_y)

        losses['loss'].backward()
        loss += losses['loss'].cpu().detach().numpy()
        opt.step()
    loss /= len(loader_train)
    
    print(epoch)
    print('train_loss', loss)