Dataset includes several time series representing different ATMs and their daily cash money demand at different locations.

In [1]:
# NOTE FIX THESE IMPORTS

from data_loader import convert_tsf_to_dataframe
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.api as sm
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import linear, softmax
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch import Tensor
from typing import Optional, Tuple, Union, Callable, Any, List
from torch.nn.init import xavier_uniform_, constant_, xavier_normal_
import copy
from torch.nn import Module, LayerNorm, ModuleList
from torch.nn.parameter import Parameter
import datetime
from torch.utils.data import Dataset, DataLoader
import sys

In [2]:
torch.set_printoptions(threshold=1000)

In [3]:
TSF_FILE = 'nn5_daily_dataset_with_missing_values.tsf'

In [4]:
data = convert_tsf_to_dataframe(TSF_FILE)

In [5]:
series_values = data[0]['series_value']

In [None]:
data

In [None]:
data[0]

In [None]:
data[0]['series_value'][0]

In [None]:
date_range_0 = pd.date_range(start='1996-03-18', periods=791, freq='D')

In [None]:
date_range_0

In [None]:
data_0 = pd.DataFrame({'series_value': data[0]['series_value'][0], 'date': date_range_0})
data_0 = data_0.set_index('date')

In [None]:
data_0

In [None]:
plt.plot(data_0['series_value'][:21])

In [None]:
np.array(data_0['series_value'])

In [None]:
# proportion of nulls and zeroes
total = 791 * 111
null_count = 0
zero_count = 0
for i in range(len(data[0]['series_value'])):
    series_values = np.array(pd.to_numeric(data[0]['series_value'][i], errors='coerce'))
    nulls_in_series = np.sum(np.isnan(series_values))
    null_count += nulls_in_series
    zeroes_in_series = len(series_values) - np.count_nonzero(series_values)
    zero_count += zeroes_in_series
(null_count + zero_count) / total
    

Questions and notes:

- Global methods seem better than traditional univariate methods for this dataset because we can leverage related but seperate datasets in a unified model. Traditional methods such as ARIMA do not have this capability and if we would need to fit a separate model for each dataset. (Look into this).  


- How to deal with missing values? I assume if we're using a deep learning method, we could simply treat the missing values as breakpoints for training samples. I.e. let's say you have \[$x_1$, $x_2$, $x_3$, $x_4$, $x_5$, $x_6$, $x_7$, $x_8$\] where this is a sequence in order of time and $x_5$ is missing. Then, assuming the model uses a lag of 2, we can treat {$X$: \[$x_1$, $x_2$\], $y$: $x_3$}, {$X$: \[$x_2$, $x_3$\], $y$: $x_4$} and {$X$: \[$x_6$, $x_7$\], $y$: $x_8$} as our training examples, treating $x_5$ as a breakpoint. Is there a better way? ANSWER: for now no need to interpolate missing data, let's just try training on what we have.


- How can we identify the types of seasonality in the data? Can we deal with multiple seasonality? Traditional seasonal decomposition methods (even those which deal with multiple seasonality i.e. MSTL) will probably not work too well due to the multiple series. How do we integrate seasonality into the model? One idea is to use sin/cos encodings depending on the seasonal position. I.e. for seasonality that occurs one day a week, use $a\sin (x\cdot \frac{\pi}{3.5})$ and the cos version, where $a$ is tuneable. But how many of these terms should we include? Seems like a relatively wieldy hyperparameter. Also, if we need to see how to modify the architecture to include these terms. Another idea is to include the previous season's data, i.e. alongside the $k$ lags from the past we use $x_{t - p}$ as input, where $p$ is the seasonal period. Note that we would also need some kind of positional encoding for this value. Perhaps try dealing with seasonality after an original attempt without it, though.


- TODO: I want to try applying a transformer model to this data, research and look into transformers. From there we can try different things. So first just try applying a vanilla transformer. 
    - BUT WAIT... is the transformer really the best choice here? The problem I see with the transformer is that it is more suited for sequence-to-sequence models. In other words, you get one complete sequence as input and have to output another related sequence. Perhaps there is some modification to this that allows for completion of one sequence but anyway, if I'm correct, the self-attention used in the transformer will not be applicable to numerical time series such as this one. When it comes to sequence-to-sequence models that the transformer is meant for, the words can influence each other, i.e. a word later in the sequence can provide semantic information about a word earlier in the sequence, and self-attention deals with that scenario, where the information flow is not necessarily left-to-right. But for a time series like this, the information flows unidirectionally, elements earlier in the sequence do not depend on elements later in the sequence. Using a self-attention encoding where all elements of the sequence attend on each other would lead to spurious modelling. Perhaps an RNN model would be more suited to this kind of time series data. I will look into deep learning models for time series and find something that makes sense. 



- Better way to decode? Right now training uses "teacher forcing", but inference does not, if a very wrong prediction is produced then it negatively affects the subsequent predictions that get decoded, and without a concept of sequence probability we can't use beam search. Perhaps predict everything at once, i.e. no autoregressive decoding (which is needed for NLP because you don't know when the sequence will end, but we know the desired forecast horizon for time series). New "decoder" takes timestep position, performs some positional embedding, computes attention on encoder output, depending on query generated from positional embedding, performs transformations on the values and outputs the desired number of forecasts, one for each positional embedding. Something to try. Note this could possibly also enable forecasting of variable length for a single trained model depending on how I implement the positional encodings. If I use a lookup table for the positional encodings this is impossible, but if I use matrix mulitplications with the raw timestep number it is possible.


- Seasonailty: first create a seasonal encoder. Essentially given a forecast horizon, i.e. 7, and a hyperparameter for window size, i.e. 5, we create a "seasonal context" of len (horizon + window_sz) - 1. In other words, we take all our desired dates to forecast, seasonally shifted, get those values, as well as the (window_sz - 1) / 2 values on each side. So we have our seasonal context, comprised of horizon_len windows of length window_sz. Example: if our context is \[1, 2, 3, 4, 5\] and our window_sz is 3 (meaning we want to forecast 3 into the future), we have 3 windows \[1, 2, 3\], \[2, 3, 4\], and \[3, 4, 5\]. One corresponding to each position we want to decode. The intuition is that we hope each window contains seasonal data relevant to each date we want to forecast. We encode this seasonal context with its own learned position embeddings and weights; encoder is our modified encoder with causal self attention. Once the self attention finishes, we get a matrix of size (context_len x d_season). Normally we would then use the decoder-encoder attention as in the normal transformer but the problem here is that if we want to use this with the modified decoder as above, we can't do this. The decoder inputs for the modified decoder are merely timestep encodings so I'm not sure if that will give enough information to attend to the seasonal context, but I should probably try it nonetheless. My idea is to first perform attention from the decoder timestep encodings to the encodings of the previous $k$ time series elements immediately preceding the dates we want to forecast. Then we attend from the output of this to the seasonal encodings, applying as mask (see notes) before the softmax so that each decoder position only attends to the corresponding window for each season. We perform this attention on each seasonal component at the same time by concatentating the keys along the position axis. From here we get the regular dimension of values which we transform to get all the decoder timesteps at once. 


- Think about feeature engineering


- How to deal with nulls in target? I.e. how do we remove these from loss? We might still have to discard data... Like if at timestep 3 in the decoder the ground truth isn't available, i.e. null, how can we compute the loss and keep it differentiable? 
    - I suppose we could impute the null value in the target to be the same as the previous value or something. Not ideal but will be differentiable at least
    - Pytorch has a reduce flag in loss which lets you do perform loss elementwise. Idea is that we remove the nulls after doing this, and average the rest. But is this differentiable? Well apparently it's somehow differentiable but some of the gradients become null
    
    
- Perhaps I should apply some sort of clustering method. It seems like each ATM is at a different location, so the location itself is a latent variable that is not explicitly included. Different locations and location types may look different. Clustering could deal with this, use a different model for each cluster. Hopefully these inherent location differences get picked up by the clustering method. Look into time series clustering. 


- Here's the plan: try a variety of different models, and forecast for 1-day, 2-days, ..., 7-days ahead

- Models to consider:
    - ARIMA (baseline)
    - LSTM
    - Transformer
    - Temporal fusion transformer


- Transformer ideas:
    - Deal with nulls by setting them to 0 and applying an (input) mask which we add before the softmax which makes the null components of the self-attention negative infinity, becoming 0 after the softmax, thus negating their attention weights and ensuring they do not factor in to the final result. Experimentally verified that this does not affect the backprop computation (null component has 0 gradient in query/key matrices). 
        - Note that we also have to mask out the row of the queries for the null value, only way I can think of to do this is to first let the softmax do its thing on the row of 0's, then remove the row (using the null timestep dim input), and add in a row of 0's at the deleted position. Wouldn't this break the differentiable graph though? Perhaps there's no need to do this after all. Let's verify. Yes, we don't need to do the initial removal from the softmax tensor, we can simply add a mask in the decoding component which masks out the encoding corresponding to null in its attention
    - Mask out future timesteps during encoding. I.e. make the encoder self attention autoregressive by applying an encoding mask. Deals with the issue mentioned above. So the encoding at position 2 can only be affected by positions 1 and 2, etc.
    

In [None]:
# Trying vanilla transformer from pytorch with adaptations

In [6]:

def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    r"""
    Performs the in-projection step of the attention operation, using packed weights.
    Output is a triple containing projection tensors for query, key and value.
    Args:
        q, k, v: query, key and value tensors to be projected. For self-attention,
            these are typically the same tensor; for encoder-decoder attention,
            k and v are typically the same tensor. (We take advantage of these
            identities for performance if they are present.) Regardless, q, k and v
            must share a common embedding dimension; otherwise their shapes may vary.
        w: projection weights for q, k and v, packed into a single tensor. Weights
            are packed along dimension 0, in q, k, v order.
        b: optional projection biases for q, k and v, packed into a single tensor
            in q, k, v order.
    Shape:
        Inputs:
        - q: :math:`(..., E)` where E is the embedding dimension
        - k: :math:`(..., E)` where E is the embedding dimension
        - v: :math:`(..., E)` where E is the embedding dimension
        - w: :math:`(E * 3, E)` where E is the embedding dimension
        - b: :math:`E * 3` where E is the embedding dimension
        Output:
        - in output list :math:`[q', k', v']`, each output tensor will have the
            same shape as the corresponding input tensor.
    """
    E = q.size(-1)
    if k is v:
        if q is k:
            # self-attention
            return linear(q, w, b).chunk(3, dim=-1)
        else:
            # encoder-decoder attention
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).chunk(2, dim=-1)
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)

    
def _in_projection(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w_q: Tensor,
    w_k: Tensor,
    w_v: Tensor,
    b_q: Optional[Tensor] = None,
    b_k: Optional[Tensor] = None,
    b_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""
    Performs the in-projection step of the attention operation. This is simply
    a triple of linear projections, with shape constraints on the weights which
    ensure embedding dimension uniformity in the projected outputs.
    Output is a triple containing projection tensors for query, key and value.
    Args:
        q, k, v: query, key and value tensors to be projected.
        w_q, w_k, w_v: weights for q, k and v, respectively.
        b_q, b_k, b_v: optional biases for q, k and v, respectively.
    Shape:
        Inputs:
        - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any
            number of leading dimensions.
        - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any
            number of leading dimensions.
        - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any
            number of leading dimensions.
        - w_q: :math:`(Eq, Eq)`
        - w_k: :math:`(Eq, Ek)`
        - w_v: :math:`(Eq, Ev)`
        - b_q: :math:`(Eq)`
        - b_k: :math:`(Eq)`
        - b_v: :math:`(Eq)`
        Output: in output triple :math:`(q', k', v')`,
         - q': :math:`[Qdims..., Eq]`
         - k': :math:`[Kdims..., Eq]`
         - v': :math:`[Vdims..., Eq]`
    """
    Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
    assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
    assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
    assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
    assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
    assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
    assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
    return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)



def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    r"""
    Computes scaled dot product attention on query, key and value tensors, using
    an optional attention mask if passed, and applying dropout if a probability
    greater than 0.0 is specified.
    Returns a tensor pair containing attended values and attention weights.
    Args:
        q, k, v: query, key and value tensors. See Shape section for shape details.
        attn_mask: optional tensor containing mask values to be added to calculated
            attention. May be 2D or 3D; see Shape section for details.
        dropout_p: dropout probability. If greater than 0.0, dropout is applied.
    Shape:
        - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
            and E is embedding dimension.
        - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
            and E is embedding dimension.
        - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
            and E is embedding dimension.
        - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
            shape :math:`(Nt, Ns)`.
        - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
            have shape :math:`(B, Nt, Ns)`
    """
    B, Nt, E = q.shape
    q = q / math.sqrt(E)

    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    if attn_mask is not None:
        attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
    else:
        attn = torch.bmm(q, k.transpose(-2, -1))
    
#     print('attn before soft')
#     print(attn)
    attn = softmax(attn, dim=-1)
#     print('attn after soft')
#     print(attn)

    if dropout_p > 0.0:
        attn = F.dropout(attn, p=dropout_p)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, v)
#     print('attn output')
#     print(output)
    return output, attn



def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_batched: bool = True,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is an binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        use_separate_proj_weight: the function accept the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
            Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
            when ``need_weights=True.``. Default: True
    Shape:
        Inputs:
        - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
          will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        Outputs:
        - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
          attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
          :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
          :math:`S` is the source sequence length. If ``average_weights=False``, returns attention weights per
          head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
    """
    
#     print('query')
#     print(query)
#     print('key')
#     print(key)
#     print('value')
#     print(value)
    
    
    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
    # is batched, run the computation and before returning squeeze the
    # batch dimension so that the output doesn't carry this temporary batch dimension.
    if not is_batched:
        # unsqueeze if the input is unbatched
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    assert embed_dim == embed_dim_to_check, \
        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    
    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert key.shape[:2] == value.shape[:2], \
            f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
 
    if not use_separate_proj_weight:
        assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
        assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
        assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)

#     print('query_emb')
#     print(q)
#     print('key_emb')
#     print(k)
#     print('value_emb')
#     print(v)
    # prep attention mask
    if attn_mask is not None:
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)
        else:
            assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
                f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # prep key padding mask
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

    #
    # reshape q, k, v for multihead attention and make em batch first
    #
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)  
    if static_k is None:
        k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_k.size(0) == bsz * num_heads, \
            f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert static_k.size(2) == head_dim, \
            f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_v.size(0) == bsz * num_heads, \
            f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert static_v.size(2) == head_dim, \
            f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (bsz, src_len), \
            f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).   \
            expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        if attn_mask is None:
            attn_mask = key_padding_mask
        elif attn_mask.dtype == torch.bool:
            attn_mask = attn_mask.logical_or(key_padding_mask)
        else:
            attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))

    # convert mask to float
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
        new_attn_mask.masked_fill_(attn_mask, float("-inf"))
        attn_mask = new_attn_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #
    attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
    attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

    if need_weights:
        # optionally average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        if average_attn_weights:
            attn_output_weights = attn_output_weights.sum(dim=1) / num_heads

        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
            attn_output_weights = attn_output_weights.squeeze(0)
        return attn_output, attn_output_weights
    else:
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
        return attn_output, None

In [7]:
class MultiheadAttention(nn.Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
            Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None,
                average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
    Args:
        query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
            or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
            :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
            Queries are compared against key-value pairs to produce the output.
            See "Attention Is All You Need" for more details.
        key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
            or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
            :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
            See "Attention Is All You Need" for more details.
        value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
            ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
            sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
            See "Attention Is All You Need" for more details.
        key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
            to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
            Binary and byte masks are supported.
            For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
            the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding ``key``
            value will be ignored.
        need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
            Default: ``True``.
        attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
            :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
            :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
            broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
            Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
            corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
            corresponding position is not allowed to attend. For a float mask, the mask values will be added to
            the attention weight.
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
            heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
            effect when ``need_weights=True.``. Default: True (i.e. average weights across heads)

    Outputs:
        - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
          :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
          where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
          embedding dimension ``embed_dim``.
        - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
          returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
          :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
          :math:`S` is the source sequence length. If ``average_weights=False``, returns attention weights per
          head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.

        .. note::
            `batch_first` argument is ignored for unbatched inputs.
        """
        
        is_batched = query.dim() == 3
        if self.batch_first and is_batched:
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights)
        else:
            attn_output, attn_output_weights = multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, average_attn_weights=average_attn_weights)
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
        


In [8]:
def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])



class TransformerModel(Module):
    r"""A transformer model. User is able to modify the attributes as needed. The architecture
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010.
    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of encoder/decoder intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
            other attention and feedforward operations, otherwise after. Default: ``False`` (after).
    Examples::
        >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
        >>> src = torch.rand((10, 32, 512))
        >>> tgt = torch.rand((20, 32, 512))
        >>> out = transformer_model(src, tgt)
    Note: A full example to apply nn.Transformer module for the word language model is available in
    https://github.com/pytorch/examples/tree/master/word_language_model
    """

    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 use_norm=True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(TransformerModel, self).__init__()

        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first, norm_first, 
                                                    use_norm,
                                                    **factory_kwargs)
            encoder_norm = None
            if use_norm:
                encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first, norm_first,
                                                    use_norm,
                                                    **factory_kwargs)
            decoder_norm = None
            if use_norm:
                decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

        self.batch_first = batch_first

    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Take in and process masked source/target sequences.
        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence (optional).
            tgt_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
            tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
            memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).
        Shape:
            - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
              `(N, S, E)` if `batch_first=True`.
            - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
              `(N, T, E)` if `batch_first=True`.
            - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
            - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
            Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked
            positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
            while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
            is provided, it will be added to the attention weight.
            [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
            the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero
            positions will be unchanged. If a BoolTensor is provided, the positions with the
            value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
            - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
              `(N, T, E)` if `batch_first=True`.
            Note: Due to the multi-head attention architecture in the transformer model,
            the output sequence length of a transformer is same as the input sequence
            (i.e. target) length of the decoder.
            where S is the source sequence length, T is the target sequence length, N is the
            batch size, E is the feature number
        Examples:
            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        """

        is_batched = src.dim() == 3
        if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")
        elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")

        if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
            raise RuntimeError("the feature number of src and tgt must be equal to d_model")
        
#         print('ENCODER_STEP')
        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
#         print('DECODER_STEP')
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return output

    @staticmethod
    def generate_square_subsequent_mask(sz: int) -> Tensor:
        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
            Unmasked positions are filled with float(0.0).
        """
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""

        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)


class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers. Users can build the
    BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
        enable_nested_tensor: if True, input will automatically convert to nested tensor
            (and convert back on output). This will improve the overall performance of
            TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.enable_nested_tensor = enable_nested_tensor

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.
        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
        Shape:
            see the docs in Transformer class.
        """
        for mod in self.layers:
            output = mod(src, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(Module):
    r"""TransformerDecoder is a stack of N decoder layers
    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).
    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    """
    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer in turn.
        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
        Shape:
            see the docs in Transformer class.
        """
        output = tgt
#         print('x before decoder')
#         print(output)
        for mod in self.layers:
            output = mod(output, memory, tgt_mask=tgt_mask,
                         memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output

class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.
    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectivaly. Otherwise it's done after. Default: ``False`` (after).
    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    Fast path:
        forward() will use a special optimized implementation if all of the following
        conditions are met:
        - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
          argument ``requires_grad``
        - training is disabled (using ``.eval()``)
        - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
        - norm_first is ``False`` (this restriction may be loosened in the future)
        - activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu``
        - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
        - if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
          nor ``src_key_padding_mask`` is passed
        - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
          unless the caller has manually modified one without modifying the other)
        If the optimized implementation is in use, a
        `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
        passed for ``src`` to represent padding more efficiently than using a padding
        mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
        returned, and an additional speedup proportional to the fraction of the input that
        is padding can be expected.
    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 use_norm=True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        
        self.activation = activation
        
        self.use_norm = use_norm

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.
        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
        Shape:
            see the docs in Transformer class.
        """

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

    
        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = x + self._sa_block(x, src_mask, src_key_padding_mask)
            if self.use_norm:
                x = self.norm1(x)
#             print('x after self attn, res, norm')
#             print(x)
            x = x + self._ff_block(x)
            if self.use_norm:
                x = self.norm2(x)
            

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


class TransformerDecoderLayer(Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.
    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to self attention, multihead
            attention and feedforward operations, respectivaly. Otherwise it's done after.
            Default: ``False`` (after).
    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)
    Alternatively, when ``batch_first`` is ``True``:
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> memory = torch.rand(32, 10, 512)
        >>> tgt = torch.rand(32, 20, 512)
        >>> out = decoder_layer(tgt, memory)
    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 use_norm=True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                                 **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = activation
        
        self.use_norm = use_norm

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayer, self).__setstate__(state)

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.
        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
        Shape:
            see the docs in Transformer class.
        """
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

        x = tgt
        if self.norm_first:
            
            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask) 
            
            
            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)  
            
            x = x + self._ff_block(self.norm3(x))
        else:

            x = x + self._sa_block(x, tgt_mask, tgt_key_padding_mask)
            if self.use_norm:
                x = self.norm1(x)
                
#             print('x after self attn, res, norm')
#             print(x)
             
            x = x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask)
            if self.use_norm:
                x = self.norm2(x)
                
            x = x + self._ff_block(x)
            if self.use_norm:
                x = self.norm3(x)
            
#         print('declayer out')
#         print(x)
        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        x = self.self_attn(x, x, x,
                               attn_mask=attn_mask,
                               key_padding_mask=key_padding_mask,
                               need_weights=False)[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(self, x: Tensor, mem: Tensor,
                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        x = self.multihead_attn(x, mem, mem,
                                attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask,
                                need_weights=False)[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)

In [9]:
# from https://pytorch.org/tutorials/beginner/transformer_tutorial.html

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [10]:
class Transformer(Module):
    
    
    def __init__(self, d_model=64, nhead=4, num_encoder_layers=3, num_decoder_layers=3,
                dim_feedforward=512, dropout=0.1, activation=F.gelu, use_norm=False):
        
        super(Transformer, self).__init__()
        self.transformer_model = TransformerModel(d_model=d_model, nhead=nhead, num_encoder_layers=num_encoder_layers,
                                                  num_decoder_layers=num_decoder_layers, dim_feedforward=dim_feedforward,
                                                  dropout=dropout, activation=activation, use_norm=use_norm)
        
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.input_projection1 = nn.Linear(1, d_model)
        self.input_projection2 = nn.Linear(d_model, d_model)
        self.output_projection = nn.Linear(d_model, 1)
        
    def forward(self, encoder_x, decoder_x, src_mask=None, tgt_mask=None, memory_mask=None):
        """ encoder_x should be vector of size input_seq_len x bs x 1   
            decoder_x should be vector of size output_seq_len x bs x 1
        """

        
        encoder_x = self.pos_encoder(self.input_projection2(torch.sigmoid(self.input_projection1(encoder_x))))
        decoder_x = self.pos_encoder(self.input_projection2(torch.sigmoid(self.input_projection1(decoder_x))))
        

        x = self.transformer_model(encoder_x, decoder_x, 
                                   src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=memory_mask)
        x = self.output_projection(x)
        return x

    
        
        
        

In [None]:
# process data for vanilla transformer. I.e. create all possible sequences. Include null values

In [None]:
series_values = data[0]['series_value']
series_values

In [None]:
date_range = pd.date_range(start='1996-03-18', periods=791, freq='D')  # can do this because all series have same length
date_range

In [None]:
# to make it easy, set a default encoder and decoder seq len
encoder_seq_len = 14  # i.e. 2 weeks of data to use in encoder
decoder_seq_len = 7

In [None]:
# think about rolling cross validation but for now just have one test set at the end
train_set_dates = date_range[date_range < '1998-01-01'] 
test_set_dates = date_range[date_range >= datetime.datetime.strptime('1998-01-01', '%Y-%m-%d') - datetime.timedelta(encoder_seq_len)] 
# i.e. the first dates we want to predict are the week starting from 1998-01-01
# have to make sure these values we want to predict never pop up in the training set, including the decoder inputs
# in other words the latest decoder output should be one day before 1998-01-01 to avoid any leakage


In [11]:
def split_data_into_sequences(data_series, encoder_seq_len, decoder_seq_len, start_date=None, end_date=None, 
                              absolute_start_date='1996-03-18'):
    
    data_series = list(pd.Series(data_series).replace('NaN', np.nan))
    date_range = pd.date_range(start=absolute_start_date, periods=len(data_series), freq='D')
    start_date_idx = date_range.get_loc(start_date) if start_date is not None else 0
    end_date_idx = date_range.get_loc(end_date) if end_date is not None else len(date_range) - 1   

    data_subset = torch.Tensor(data_series[start_date_idx : end_date_idx + 1])
    
    
    sequences_ = []
    for start in range(0, len(data_subset) - decoder_seq_len - encoder_seq_len):
        enc_seq = data_subset[start : start + encoder_seq_len]
        dec_seq = data_subset[start + encoder_seq_len - 1 : start + encoder_seq_len + decoder_seq_len - 1]
        ground_truth = data_subset[start + encoder_seq_len : start + encoder_seq_len + decoder_seq_len]
        if torch.sum(~torch.isnan(ground_truth)).item() > 0:  # sequences with all null ground truth exist. Exclude these model ruiners
            sequences_.append((enc_seq, dec_seq, ground_truth))
            
        
    
#     sequences_ = [(data_subset[start : start + encoder_seq_len],
#                    data_subset[start + encoder_seq_len - 1 : start + encoder_seq_len + decoder_seq_len - 1],
#                    data_subset[start + encoder_seq_len : start + encoder_seq_len + decoder_seq_len])
#                  for start in range(0, len(data_subset) - decoder_seq_len - encoder_seq_len)]
    
    return sequences_


def split_dataset_into_sequences(dataset, encoder_seq_len, decoder_seq_len, start_date=None, end_date=None, 
                                 absolute_start_date='1996-03-18'):
    all_sequences = []
    for time_series in dataset:
        sequences_ = split_data_into_sequences(time_series, encoder_seq_len, decoder_seq_len, start_date=start_date, 
                                               end_date=end_date, absolute_start_date=absolute_start_date)
        all_sequences.extend(sequences_)
    return all_sequences


In [12]:
encoder_seq_len = 14
decoder_seq_len = 7

In [90]:
torch.sum(~torch.isnan(torch.Tensor([np.nan, np.nan, np.nan]))).item() > 0

False

In [None]:
testing_start_date = datetime.datetime.strptime('1998-01-01', '%Y-%m-%d') 
val_start_date = testing_start_date - datetime.timedelta(days=30*4)
training_end_date = val_start_date - datetime.timedelta(days=1)
val_end_date = testing_start_date - datetime.timedelta(days=1)

In [None]:
train_sequences = split_dataset_into_sequences(series_values, encoder_seq_len, decoder_seq_len, end_date=training_end_date)

In [None]:
len(train_sequences)

In [None]:
val_sequences = split_dataset_into_sequences(series_values, encoder_seq_len, decoder_seq_len, 
                                             start_date=val_start_date, end_date=val_end_date)

In [None]:
len(val_sequences)

In [None]:
test_sequences = split_dataset_into_sequences(series_values, encoder_seq_len, decoder_seq_len, start_date=testing_start_date)

In [None]:
len(test_sequences)

In [None]:
t = torch.Tensor([1, 2, 3])
i = torch.Tensor([0, np.nan, 1]).isnan().nonzero().squeeze()
torch.cat([t[:i], t[i+1:]])

In [None]:
torch.Tensor([1]).reshape((1))

In [13]:
def create_null_masks_general(in_seq, out_seq, is_self_attn=True, seq_len_dim=0):
    in_seq_len = in_seq.shape[seq_len_dim]
    out_seq_len = out_seq.shape[seq_len_dim]
    isnan_res = in_seq.isnan().nonzero()
    if len(isnan_res) == 0:
        return torch.zeros((out_seq_len, in_seq_len))
    else:
        null_indices = isnan_res.squeeze()
        if isnan_res.shape[0] == 1:
            null_indices = null_indices.reshape((1,))
        mask = torch.zeros((out_seq_len, in_seq_len)).index_fill_(1, null_indices, float('-inf'))
        if is_self_attn:
            for index in null_indices:
                mask[index, index] = 0  # need this to prevent bug when null is in first position (of decoder self attn)
        return mask
    

In [14]:
def create_null_masks(enc_seq, dec_seq, seq_len_dim=0):
    enc_sa_null_mask = create_null_masks_general(enc_seq, enc_seq, seq_len_dim=seq_len_dim)
    dec_sa_null_mask = create_null_masks_general(dec_seq, dec_seq, seq_len_dim=seq_len_dim)
    dec_enc_null_mask = create_null_masks_general(enc_seq, dec_seq, is_self_attn=False, seq_len_dim=seq_len_dim)
    
    return enc_sa_null_mask, dec_sa_null_mask, dec_enc_null_mask

In [15]:
def create_decoder_sa_causality_mask(n_t):
    return TransformerModel.generate_square_subsequent_mask(n_t)

In [16]:
def generate_transformer_masks(encoder_seq, decoder_seq, seq_length_dim=0):
    encoder_sa_null_mask, decoder_sa_null_mask, decoder_encoder_null_mask = create_null_masks(encoder_seq, decoder_seq, 
                                                                                              seq_length_dim)
    decoder_sa_causality_mask = create_decoder_sa_causality_mask(decoder_seq.shape[seq_length_dim])
    decoder_sa_mask = decoder_sa_null_mask + decoder_sa_causality_mask
    
    return encoder_sa_null_mask, decoder_sa_mask, decoder_encoder_null_mask

In [17]:
class SequenceDatasetForTransformer(Dataset):
    def __init__(self, sequences):
        
        """ sequences should be list of tuples of len 3. 
        First item of a tuple: context (encoder input)
        Second item of a tuple: decoder input
        Third item of tuple: ground truth
        """
        self.sequences = sequences
        
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        encoder_input, decoder_input, ground_truth = self.sequences[idx]
        
        encoder_sa_null_mask, decoder_sa_mask, decoder_encoder_null_mask = generate_transformer_masks(encoder_input, 
                                                                                                      decoder_input, 
                                                                                                      seq_length_dim=0)
        ground_truth_null_nans =  ground_truth.isnan().nonzero()
        ground_truth_null_indices = torch.ones(ground_truth.shape)
        null_indices_for_sequence = ground_truth_null_nans[:, 0]
        ground_truth_null_indices[null_indices_for_sequence] = 0
        
        encoder_input = torch.nan_to_num(encoder_input)
        decoder_input = torch.nan_to_num(decoder_input)
        ground_truth = torch.nan_to_num(ground_truth)
        
        return (encoder_input, decoder_input, ground_truth, ground_truth_null_indices,
                encoder_sa_null_mask, decoder_sa_mask, decoder_encoder_null_mask)
    

In [18]:
class SequenceDatasetForTransformerTesting(Dataset):
    def __init__(self, sequences):
        
        """ sequences should be list of tuples of len 3. 
        First item of a tuple: context (encoder input)
        Second item of a tuple: decoder input
        Third item of tuple: ground truth
        """
        self.sequences = sequences
        
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        encoder_input, _, ground_truth = self.sequences[idx]
        
        ground_truth_null_nans =  ground_truth.isnan().nonzero()
        ground_truth_null_indices = torch.ones(ground_truth.shape)
        null_indices_for_sequence = ground_truth_null_nans[:, 0]
        ground_truth_null_indices[null_indices_for_sequence] = 0
        
        
        
        return (encoder_input, ground_truth, ground_truth_null_indices)

In [16]:
# def process_batch_nulls(encoder_input, decoder_input, ground_truth):  
#     # need this because null indices for ground truth will be different sizes
#     encoder_input = torch.nan_to_num(encoder_input)
#     decoder_input = torch.nan_to_num(decoder_input)
#     ground_truth_null_nans =  ground_truth.isnan().nonzero().T
#     ground_truth_null_indices = torch.ones(ground_truth.shape)  # need this to remove null elements in gt from loss
#     batch_size = ground_truth.shape[0]
    
#     for b in range(batch_size):
#         null_indices_for_sequence = ground_truth_null_nans[1][ground_truth_null_nans[0] == b]
#         ground_truth_null_indices[b, null_indices_for_sequence] = 0
        
#     ground_truth = torch.nan_to_num(ground_truth)
    
#     return encoder_input, decoder_input, ground_truth, ground_truth_null_indices

In [19]:
def format_input(encoder_input, decoder_input, encoder_sa_null_mask, decoder_sa_mask, decoder_encoder_null_mask, n_heads):
    encoder_sa_null_mask = encoder_sa_null_mask.repeat(n_heads, 1, 1)
    decoder_sa_mask = decoder_sa_mask.repeat(n_heads, 1, 1)
    decoder_encoder_null_mask = decoder_encoder_null_mask.repeat(n_heads, 1, 1)
    
    # need e, d to be of shape Length x bs x 1
    encoder_input = encoder_input.T.unsqueeze(-1)
    decoder_input = decoder_input.T.unsqueeze(-1)
    
    return encoder_input, decoder_input, encoder_sa_null_mask, decoder_sa_mask, decoder_encoder_null_mask

In [20]:
def create_train_val_test_sets(data, training_end_date, val_start_date, val_end_date, testing_start_date,
                              encoder_seq_len, decoder_seq_len):
    train_sequences = split_dataset_into_sequences(data, encoder_seq_len, decoder_seq_len, 
                                                   end_date=training_end_date)
    val_sequences = split_dataset_into_sequences(data, encoder_seq_len, decoder_seq_len, 
                                                 start_date=val_start_date, end_date=val_end_date)
    test_sequences = split_dataset_into_sequences(data, encoder_seq_len, decoder_seq_len, 
                                                  start_date=testing_start_date)
    return train_sequences, val_sequences, test_sequences

In [21]:
testing_start_date = datetime.datetime.strptime('1998-01-01', '%Y-%m-%d') 
val_start_date = testing_start_date - datetime.timedelta(days=30*4)
training_end_date = val_start_date - datetime.timedelta(days=1)
val_end_date = testing_start_date - datetime.timedelta(days=1)

In [22]:
nhead = 4

In [23]:
train_sequences, val_sequences, test_sequences = create_train_val_test_sets(series_values, training_end_date, val_start_date, 
                                                                            val_end_date, testing_start_date,
                                                                           14, 7)
train_dataset = SequenceDatasetForTransformer(train_sequences)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

In [24]:
vanilla_transformer_with_nullmask = Transformer(d_model=64, nhead=nhead, num_encoder_layers=3, num_decoder_layers=3,
                dim_feedforward=256, dropout=0.1, activation=F.gelu)

In [25]:
optimizer = torch.optim.Adam(vanilla_transformer_with_nullmask.parameters())

In [30]:
save_path = '/home/mbaroody/Programming/Time series/NN5/v3-{}.pt'

In [39]:

epoch = 0
# prev_params = None
while True:
    running_loss = 0
    print("Epoch {}".format(epoch))
    for i, data in enumerate(train_dataloader):
        if ((i + 1) % 250 == 0):
            print("Reached batch {}, running_loss was {}".format(i+1, running_loss / 250))
            running_loss = 0
            
        encoder_input, decoder_input, ground_truth, ground_truth_null_indices,\
                encoder_sa_null_mask, decoder_sa_mask, decoder_encoder_null_mask = data
        
        
        encoder_input, decoder_input, encoder_sa_null_mask,\
                decoder_sa_mask, decoder_encoder_null_mask = format_input(encoder_input, decoder_input, encoder_sa_null_mask, 
                                                                          decoder_sa_mask, decoder_encoder_null_mask, nhead)
        optimizer.zero_grad()
        outputs = vanilla_transformer_with_nullmask(encoder_input, decoder_input, encoder_sa_null_mask, decoder_sa_mask, 
                                                    decoder_encoder_null_mask)
        # outputs is shape dec_len x bs x 1
        outputs = outputs.squeeze().T

            
        # delete items that are null in target
        losses_nonull_per_sequence = torch.sum(((outputs - ground_truth) * ground_truth_null_indices) ** 2, 1)
        batch_divide = ground_truth_null_indices.shape[1] - torch.sum(1 - ground_truth_null_indices, 1)
        losses_nonull = torch.mean(losses_nonull_per_sequence / batch_divide)
        
        losses_nonull.backward()
        optimizer.step()
        
#         flag = False
#         for param in vanilla_transformer_with_nullmask.parameters():
#             if torch.sum(torch.isnan(param.data)) > 0:
#                 flag = True
#                 break
#         if flag:
#             print('outputs')
#             print(outputs)
#             print('ground truth')
#             print(ground_truth)
#             print('ground truth null indices')
#             print(ground_truth_null_indices)
#             print('loss a')
#             print(losses_nonull_per_sequence)
#             print('batch divide')
#             print(batch_divide)
#             print('loss')
#             print(losses_nonull.item())
            
#             break
                
#         prev_params = vanilla_transformer_with_nullmask.parameters()
        
        running_loss += losses_nonull.item()

        
    print('Saving')
    torch.save(vanilla_transformer_with_nullmask.state_dict(), save_path.format(epoch))
    epoch += 1
        
    

Epoch 0
Reached batch 250, running_loss was 73.03599977111817
Reached batch 500, running_loss was 50.69700951004028
Reached batch 750, running_loss was 41.87771994781494
Reached batch 1000, running_loss was 36.60844202804565
Reached batch 1250, running_loss was 37.05646192169189
Reached batch 1500, running_loss was 37.133101928710936
Reached batch 1750, running_loss was 35.058937873840335
Reached batch 2000, running_loss was 34.744815811157224
Reached batch 2250, running_loss was 35.490919330596924
Reached batch 2500, running_loss was 33.82575410842895
Reached batch 2750, running_loss was 36.09074656677246
Reached batch 3000, running_loss was 33.20457484817505
Reached batch 3250, running_loss was 33.82931559181213
Reached batch 3500, running_loss was 33.41293079376221
Reached batch 3750, running_loss was 33.63183804512024
Reached batch 4000, running_loss was 34.480533950805665
Reached batch 4250, running_loss was 34.516366176605224
Reached batch 4500, running_loss was 32.01423648834228

Reached batch 2000, running_loss was 30.500821786880493
Reached batch 2250, running_loss was 31.269571327209473
Reached batch 2500, running_loss was 29.8469110660553
Reached batch 2750, running_loss was 28.672202409744262
Reached batch 3000, running_loss was 28.809651374816895
Reached batch 3250, running_loss was 28.629059427261353
Reached batch 3500, running_loss was 27.80919753074646
Reached batch 3750, running_loss was 28.58871251296997
Reached batch 4000, running_loss was 29.249101085662844
Reached batch 4250, running_loss was 30.506778217315674
Reached batch 4500, running_loss was 28.67996919250488
Reached batch 4750, running_loss was 28.060999586105346
Reached batch 5000, running_loss was 29.450410301208496
Reached batch 5250, running_loss was 28.772467542648315
Reached batch 5500, running_loss was 31.28061215209961
Reached batch 5750, running_loss was 29.787938190460206
Reached batch 6000, running_loss was 28.994168886184692
Reached batch 6250, running_loss was 31.16738592338562

KeyboardInterrupt: 

In [26]:
val_dataset = SequenceDatasetForTransformerTesting(val_sequences)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=True)

In [27]:
def create_null_mask_for_batch(in_seq, out_seq, is_self_attn=True):
    in_seq_len = in_seq.shape[1]
    out_seq_len = out_seq.shape[1]
    batch_size = in_seq.shape[0]
    null_indices = in_seq.isnan().nonzero()
    if len(null_indices) == 0:
        return torch.zeros((batch_size, out_seq_len, in_seq_len))
    else:
        mask = torch.zeros((batch_size, out_seq_len, in_seq_len))
        for i in range(batch_size):
            sequence_nulls =  null_indices[null_indices[:, 0] == i]
            sequence_nulls = sequence_nulls[:, 1]
            mask[i].index_fill_(1, sequence_nulls, float('-inf'))
            if is_self_attn:
                for index in sequence_nulls:
                    mask[i, index, index] = 0  # need this to prevent bug when null is in first position (of decoder self attn)
        return mask

In [28]:
def create_decoder_causality_mask_batch(dec_seq_len, bs):
    return TransformerModel.generate_square_subsequent_mask(dec_seq_len).unsqueeze(0).repeat(bs, 1, 1)

In [29]:
def create_batch_masks(enc_seq_batch, dec_seq_batch):
    # should be shape bs x seq_len
    enc_mask = create_null_mask_for_batch(enc_seq_batch, enc_seq_batch, is_self_attn=True)
    dec_sa_null_mask = create_null_mask_for_batch(dec_seq_batch, dec_seq_batch, is_self_attn=True)
    dec_sa_causality_mask = create_decoder_causality_mask_batch(dec_seq_batch.shape[1], dec_seq_batch.shape[0])
    dec_sa_mask = dec_sa_null_mask + dec_sa_causality_mask
    dec_enc_mask = create_null_mask_for_batch(enc_seq_batch, dec_seq_batch, is_self_attn=False)
    
    return enc_mask, dec_sa_mask, dec_enc_mask
    
    
    

In [174]:
# e, g, gi = next(iter(val_dataloader))
# print(e)

# d = e[:, -1].unsqueeze(-1)

# create_batch_masks(e, d)

In [51]:
# vanilla_transformer_with_nullmask_ = Transformer(d_model=64, nhead=nhead, num_encoder_layers=2, num_decoder_layers=2,
#                 dim_feedforward=256, dropout=0.1, activation=F.gelu)
vanilla_transformer_with_nullmask_ = Transformer(d_model=64, nhead=nhead, num_encoder_layers=3, num_decoder_layers=3,
                dim_feedforward=256, dropout=0.1, activation=F.gelu)
vanilla_transformer_with_nullmask_.load_state_dict(torch.load('/home/mbaroody/Programming/Time series/NN5/v3-1.pt'))

<All keys matched successfully>

In [30]:
def evaluate_model(model, eval_dataloader):
    model.eval()
    with torch.no_grad():
        total_loss = 0
        batch_count = 0
        for i, val_data in enumerate(eval_dataloader):  # todo refactor code to make val/test stepthroughs easier, remove redundant


            encoder_input, ground_truth, ground_truth_null_indices = val_data

            decoder_input = encoder_input[:, -1].unsqueeze(-1)
            enc_mask = create_null_mask_for_batch(encoder_input, encoder_input, is_self_attn=True)

            encoder_input = torch.nan_to_num(encoder_input)

            ground_truth = torch.nan_to_num(ground_truth)

            batch_errors = torch.zeros((ground_truth.shape[0], ground_truth.shape[1]))
            batch_predictions = torch.zeros((ground_truth.shape[0], ground_truth.shape[1]))
    #         print('enc input')
    #         print(encoder_input)
    #         print('ground truth')
    #         print(ground_truth)
            for h in range(ground_truth.shape[1]):
                dec_sa_null_mask = create_null_mask_for_batch(decoder_input, decoder_input, is_self_attn=True)
                dec_sa_causality_mask = create_decoder_causality_mask_batch(decoder_input.shape[1], decoder_input.shape[0])
                dec_sa_mask = dec_sa_null_mask + dec_sa_causality_mask
                dec_enc_mask = create_null_mask_for_batch(encoder_input, decoder_input, is_self_attn=False)
                decoder_input_ = torch.nan_to_num(decoder_input)

                encoder_input_, decoder_input_, enc_mask_, \
                    dec_sa_mask, dec_enc_mask = format_input(encoder_input, decoder_input_, 
                                                             enc_mask, dec_sa_mask, 
                                                             dec_enc_mask, nhead)
    #             print('dec in')
    #             print(decoder_input_)
    #             print(dec_sa_mask)
    #             print(dec_enc_mask)
                outputs = model(encoder_input_, decoder_input_, enc_mask_, dec_sa_mask, 
                                                            dec_enc_mask)
    #             print('outputs')
    #             print(outputs)
                if h > 0:
                    timestep_outputs = outputs.squeeze().T[:, h]
                else:
                    timestep_outputs = outputs.reshape(outputs.shape[0], outputs.shape[1]).T[:, h]
    #             print('latest')
    #             print(timestep_outputs)
                batch_predictions[:, h] = timestep_outputs
                timestep_batch_loss = (ground_truth[:, h] - timestep_outputs) ** 2
                batch_errors[:, h] = timestep_batch_loss

                decoder_input = torch.cat((decoder_input, timestep_outputs.unsqueeze(-1)), 1)

            losses_nonull_per_sequence = torch.sum(batch_errors, 1)
            batch_divide = ground_truth_null_indices.shape[1] - torch.sum(1 - ground_truth_null_indices, 1)
            error_nonull = torch.mean(losses_nonull_per_sequence / batch_divide)


            print('batch_predictions')
            print(batch_predictions)
            print('ground truth')
            print(ground_truth)
    #         print('batch error')
    #         print(error_nonull)
            batch_count += 1
            total_loss += error_nonull
        avg_loss = total_loss / batch_count
        return avg_loss

            
      
        

In [31]:
evaluate_model(vanilla_transformer_with_nullmask_, val_dataloader)

NameError: name 'vanilla_transformer_with_nullmask_' is not defined

In [32]:
test_dataset = SequenceDatasetForTransformerTesting(test_sequences)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True)

In [58]:
evaluate_model(vanilla_transformer_with_nullmask_, test_dataloader)

batch_predictions
tensor([[14.4533, 15.0215, 14.0074, 15.0193, 17.5832, 24.3482, 24.4432],
        [10.8278, 13.0628, 17.2765, 21.4622, 18.0658, 10.1093, 11.2110],
        [13.0939, 11.1450, 11.0837, 11.5120, 16.7141, 25.4750, 25.8301],
        [23.4957,  8.0764, 11.1902, 10.6207, 10.5891, 15.6228, 27.1202],
        [15.3418, 22.7115, 27.4751, 20.8706, 12.3762, 13.7968, 14.0079],
        [17.5687, 27.6394, 41.0251, 33.7790, 15.2059, 18.5604, 17.0168],
        [20.7200, 12.2637, 12.7205, 12.1016, 13.0882, 19.0379, 24.9836],
        [27.1492, 24.5650, 23.8544, 26.6158, 34.9255, 51.7365, 49.1516]])
ground truth
tensor([[13.3088, 16.9253, 13.9663, 15.7549, 21.3966, 21.0153, 27.1831],
        [ 9.3109, 14.2031, 18.6481, 31.6938, 39.2031, 17.5697, 15.6891],
        [15.6102, 12.5460, 11.0731,  9.3372, 17.8985, 25.9469,  7.2856],
        [27.6170,  8.6928, 11.9279,  9.9816, 12.2567, 17.4382, 31.0363],
        [12.6907, 27.6565, 32.4961, 24.3951, 18.6875, 14.7422, 14.7685],
        [19.5949, 2

batch_predictions
tensor([[14.8978, 16.3136, 20.2336, 24.1461, 24.6050, 11.1329, 15.5892],
        [ 7.8543,  9.1375, 10.3501, 13.4615, 13.1107,  7.5598,  7.7805],
        [11.7566, 13.4077, 15.4122, 19.1295, 18.8638, 18.7799, 11.5193],
        [16.8552,  8.7077, 10.9598, 11.2986, 11.5924, 14.2490, 21.5965],
        [13.1606, 16.7871, 21.8607, 23.4526, 21.9684, 14.3408, 14.0190],
        [13.1713, 15.1047,  7.1575,  6.6644,  7.0184,  7.7302, 10.0560],
        [10.7955, 12.9889, 11.2398,  9.9905, 14.4765, 15.8455,  9.7966],
        [11.0346, 12.9461, 16.5782, 24.0102, 22.3471,  7.0541, 11.6226]])
ground truth
tensor([[18.5034, 19.5029, 18.1220, 21.9095, 30.9311, 12.1646, 13.7033],
        [ 8.8243, 11.8753, 10.6391, 16.1625, 16.2415,  6.8254,  7.1673],
        [ 8.8294, 13.2653, 13.2370, 20.3231, 22.9592, 19.3594, 12.7409],
        [26.4729, 10.5997, 10.0342, 11.4940, 10.2446, 15.3603, 17.6749],
        [12.3225, 18.3193, 23.7770, 42.8853, 39.0058, 17.6749, 19.2925],
        [12.9274,  

batch_predictions
tensor([[14.5069, 16.4294, 19.0423, 19.9567, 18.7685,  7.6664, 11.6821],
        [13.4764, 12.5158, 15.0263, 20.9575, 29.1098, 23.9915, 14.1465],
        [13.0518, 14.1040, 14.9801, 22.0641, 29.5771, 25.4113, 14.2302],
        [16.3377, 23.2227, 17.4442,  9.1561, 10.2580, 10.7409, 11.2954],
        [19.1689, 20.6188, 19.1729, 20.5825, 26.8619, 43.0532, 42.1938],
        [18.5519, 25.1770, 36.4902, 34.9946, 18.8222, 20.4139, 19.4498],
        [12.2551, 15.7981, 13.4075, 13.6532, 20.7192, 29.7764, 30.2052],
        [10.0240, 10.3090, 10.8668, 14.3791, 23.6896, 23.4694,  8.0174]])
ground truth
tensor([[10.4734,  9.9065, 13.4070, 19.8696,  3.1746,  4.7052, 19.7137],
        [16.0289,  9.3254, 14.1582, 22.2647, 30.3713, 24.2489, 31.9019],
        [10.7710, 12.5992, 15.5329, 22.3498, 35.7143, 27.5085, 12.8260],
        [16.6241, 35.5300, 35.5159, 15.8730, 15.5612, 11.4938, 13.1378],
        [34.1553, 30.1446, 26.6015, 24.4473,  8.2625, 39.5408, 46.3435],
        [19.5862, 2

batch_predictions
tensor([[38.5417, 29.2367, 15.7918, 16.4912, 16.5205, 18.1200, 26.4376],
        [13.3887, 12.5969, 11.1133, 11.1142, 14.0296, 20.2683, 23.1089],
        [27.8409, 18.8936, 18.5881, 18.4337, 21.6959, 32.6462, 42.8075],
        [14.0320, 14.2593, 15.1548, 19.8968, 25.9192, 18.8390, 14.2052],
        [ 7.7913,  7.8652,  8.1154,  8.2398,  9.5041, 14.4125, 18.1945],
        [24.6042,  8.7783, 12.8458, 11.8793, 12.2271, 18.2001, 28.1769],
        [19.6268, 22.9512,  9.1145, 12.0825, 13.1817, 14.0113, 15.6501],
        [19.1932, 15.9918,  8.1897,  9.0991,  9.9171, 10.4993, 15.5893]])
ground truth
tensor([[49.6173, 35.8560,  0.0000, 12.0607, 22.6616, 19.7279, 26.1763],
        [12.6417, 12.8260, 11.6355, 12.8685, 21.0176, 28.9683, 20.6916],
        [32.4688, 23.2285, 22.8741, 17.8146, 16.9501, 25.2693, 41.2273],
        [18.1973, 12.6559, 20.7766, 22.6616, 31.0516, 35.4592, 14.3707],
        [ 0.0000,  5.1552,  7.7985,  8.0484,  5.8916, 17.6749, 18.0957],
        [28.7698,  

batch_predictions
tensor([[11.6739, 11.1113, 11.3129, 14.1589, 22.3772, 20.5489, 11.3539],
        [10.3037, 10.3026, 10.3585, 12.8391, 19.9579, 19.9792, 10.6826],
        [18.0012, 19.2982,  9.7116,  9.5569,  9.4997,  9.4684, 11.6683],
        [11.7805, 14.1706, 12.4507, 12.8110, 16.3701, 25.7554, 27.8708],
        [ 8.2556, 13.3799, 12.4078, 12.4587, 15.4349, 23.1830, 27.5927],
        [ 9.6926,  6.5360,  7.3758,  9.0925,  9.9726, 13.2618, 19.0171],
        [14.4319, 13.6153, 14.8659, 19.0367, 21.3607, 21.7911, 13.2109],
        [10.8454, 12.0657, 13.0203, 14.1627, 17.1758, 22.2248, 15.4951]])
ground truth
tensor([[12.9252, 12.9110, 11.3095, 12.6417, 25.2268,  0.0000,  0.0000],
        [14.4003,  8.9690,  7.5881, 13.9137, 21.4361, 21.8306,  8.4429],
        [21.7649, 22.1725, 13.0326, 12.1120,  9.9421, 11.0863, 12.1646],
        [ 8.7160, 13.2795,  9.4813, 15.4762, 16.7234, 27.4660, 30.2438],
        [ 5.0879, 12.7126, 13.6054, 12.8968, 20.7341, 24.8299, 29.6769],
        [19.5292,  

batch_predictions
tensor([[12.4565,  7.6213,  6.8279,  7.5932,  8.1715,  9.4991, 13.0530],
        [ 9.8319, 11.5284, 13.6814, 17.4215,  6.7329,  6.9228,  9.7889],
        [14.5461, 16.7957, 22.8866, 22.8192, 12.5427, 12.5748, 13.8327],
        [28.6611, 37.9018, 27.6937, 17.4512, 18.3811, 16.7938, 17.1228],
        [25.8691, 24.4908, 13.0008, 13.9208, 13.7841, 14.2709, 20.1991],
        [30.4693, 29.2410, 20.6192, 15.9074, 15.6673, 14.8318, 18.3520],
        [ 8.9713,  8.9536, 10.1618, 13.3386,  8.6761,  7.4670,  9.1508],
        [ 6.6519,  7.5553,  9.3733,  9.9271, 12.6152, 18.7787, 21.5137]])
ground truth
tensor([[15.0447, 13.0852,  6.8780, 11.0205,  8.8901, 11.0468, 12.0594],
        [ 7.6670,  9.7712, 17.1094,  2.5118,  8.2194, 11.5203,  9.0479],
        [17.3330, 22.2120, 21.3309, 26.4335, 13.1510,  3.7612, 10.0605],
        [34.3537, 45.5074, 37.4291, 15.3203, 15.5896, 18.8776, 23.0017],
        [29.9603,  8.7727, 10.9410, 17.0068, 16.3690, 21.0743, 26.0346],
        [45.1105, 2

batch_predictions
tensor([[21.3133, 12.1979, 12.4044, 11.9747, 12.9373, 17.1369, 25.1677],
        [19.0613, 25.5765, 36.6821, 29.1503, 12.9723, 16.1511, 16.6046],
        [22.3573, 25.4685, 36.1304, 46.5513, 39.0014, 23.1041, 24.1018],
        [32.8186, 31.1167, 11.8777, 15.6243, 15.5605, 16.3381, 21.3615],
        [ 7.5311,  8.5302, 10.1622, 14.4116, 18.0395,  6.4507,  7.5076],
        [34.2243, 14.9118, 17.4236, 16.8897, 17.5677, 26.3448, 41.2123],
        [23.0829, 17.2424, 15.9761, 14.9183, 15.8809, 28.1679, 34.7831],
        [16.4857, 16.7195, 17.0650, 19.5034, 26.7854, 31.7416, 27.8541]])
ground truth
tensor([[24.9575, 12.8827, 12.2874, 11.7489,  6.9019, 21.7545, 26.1338],
        [17.3044, 22.1514, 36.2812, 43.1689, 20.6066, 16.4116, 16.8934],
        [24.5748, 25.4819,  7.4405, 90.2778, 50.4819, 33.2200, 28.5573],
        [38.8605, 39.4558,  7.8231, 17.6020, 18.4382, 19.0618, 23.0867],
        [ 9.9684, 10.6523, 10.8627, 18.4903, 21.1599,  3.9584,  8.5613],
        [25.6094, 1

batch_predictions
tensor([[14.5978, 18.5171, 23.3632, 17.4507, 12.9032, 13.5041, 13.4095],
        [24.2584, 12.1456, 13.6761, 12.7028, 14.6593, 25.3495, 34.0038],
        [25.4107, 10.2486, 13.3568, 13.5731, 14.6574, 20.3282, 25.1898],
        [16.6184, 24.5685, 20.8497, 11.0390, 11.4754, 12.0070, 13.0997],
        [14.8676, 17.4501, 22.6091, 30.8440, 25.0279, 13.0841, 16.2991],
        [16.2134, 23.8599, 28.7092,  8.4645, 13.7204, 14.3813, 14.5109],
        [46.9465, 35.5626, 22.7932, 22.2192, 21.7688, 24.6896, 35.4314],
        [15.0222, 19.6310, 19.2921,  8.5180,  9.5827, 10.7243, 11.2648]])
ground truth
tensor([[14.6370, 15.4787, 29.1426, 28.4850, 29.1031, 15.2683, 16.5834],
        [ 8.2908,  6.5618, 11.5646,  6.4059, 14.7959, 32.5964, 38.9314],
        [24.4898, 12.2449, 14.6825, 14.1723, 12.1032, 18.4807, 29.9745],
        [18.5941, 28.0045, 27.3810, 14.5833, 13.9456, 12.4717, 14.1015],
        [15.5187, 17.5737,  4.3226, 37.1457, 34.5947, 14.6117, 19.7279],
        [18.7270, 2

batch_predictions
tensor([[18.6807, 24.8439, 25.6746, 14.7631, 13.3192, 13.6868, 13.9592],
        [14.3133, 14.7700, 13.0069, 13.2477, 15.8624, 25.4834, 28.7275],
        [15.2157, 21.6507, 29.2934, 25.8149, 12.5003, 15.3298, 14.9244],
        [12.6812, 13.9489, 10.3497,  7.6983,  8.3203,  8.8404,  8.6382],
        [18.9415, 22.8139, 22.2590,  9.4598, 12.7799, 13.7634, 14.8501],
        [ 9.3425,  9.8551, 12.2672, 17.8479, 16.3878,  6.8338,  8.3159],
        [19.8627, 10.2991, 12.3897, 12.3610, 12.2923, 19.5588, 28.2752],
        [13.1795, 19.5818, 25.7975, 22.7541, 10.2586, 12.7255, 12.8752]])
ground truth
tensor([[18.4524, 26.6723, 25.5527, 11.7630, 16.6950, 13.8039, 17.8855],
        [10.8985, 16.2557, 13.2511, 15.4904, 20.0539, 28.6990, 35.4025],
        [12.2874, 19.4303, 27.6927, 28.5998, 11.4654, 11.2103, 10.5300],
        [22.2931, 16.7942,  0.0000,  7.9082,  9.0278, 10.1049, 12.6559],
        [16.4650, 19.6476, 26.6965, 13.4008, 15.7154, 15.9127, 14.8606],
        [ 8.9164, 1

batch_predictions
tensor([[13.3401, 13.5316, 13.6029, 21.6433, 34.3671, 29.0092, 13.5962],
        [12.1070, 13.0861, 16.9249, 22.7948, 23.0632, 11.3281, 13.4923],
        [19.6095, 10.0012, 11.4974, 10.9226, 11.0820, 18.4696, 26.9353],
        [10.8306, 14.3726, 12.8995, 12.1513, 19.3602, 33.0354, 25.5826],
        [14.4137, 14.5538, 16.2534, 23.2331, 25.6283, 13.3403, 16.1157],
        [37.5364, 22.9596, 13.1215, 14.5748, 13.9244, 17.3846, 29.2382],
        [34.3509, 14.9051, 17.2839, 16.2016, 16.3799, 25.4971, 40.6336],
        [19.2140, 18.9947, 18.6200, 19.2941, 21.2287, 26.6901, 32.2890]])
ground truth
tensor([[14.5692, 16.4683, 17.7438, 22.3214, 33.2908, 38.6196, 14.5266],
        [ 9.6791, 14.2162, 18.1878, 23.2772, 22.6460, 12.0200, 12.0200],
        [29.4319, 12.6118,  7.6539, 13.8743, 12.6644, 18.6875, 24.4477],
        [14.6967, 13.2511, 21.1168, 16.4116, 21.6978, 28.7840, 34.7506],
        [12.5197, 14.6633, 16.9122, 21.1862, 26.3809, 12.2041, 12.2436],
        [39.2007, 2

batch_predictions
tensor([[16.1868, 17.0214, 22.5693, 26.0960, 12.6746, 14.6877, 16.3638],
        [13.9147, 14.9995, 18.9485, 18.3515, 11.9456, 13.2441, 13.7656],
        [20.5258, 27.4683, 26.4583, 12.7318, 14.6849, 15.5305, 16.7975],
        [31.5917, 21.1096, 17.0195, 16.7521, 16.9116, 20.1125, 25.2259],
        [15.2898, 11.3340, 11.6734, 11.6035, 11.4719, 15.6622, 21.1921],
        [22.0493, 22.7455, 11.5733, 11.3894, 11.4539, 12.4660, 16.8688],
        [15.2329,  7.3948,  8.1687,  8.4652,  8.9478, 12.5849, 20.0058],
        [32.2540, 23.6798, 23.4974, 24.4064, 26.4573, 39.5374, 52.1817]])
ground truth
tensor([[18.0300, 18.0826, 24.4871, 28.7217,  9.7186, 16.5965, 13.5324],
        [15.2814, 16.8990, 24.1189, 19.5949, 14.7554, 13.8085, 12.9274],
        [23.4410, 32.1429, 34.0278, 17.5312, 13.7755, 18.1973, 18.0272],
        [28.0510, 14.7817, 13.8874, 23.1326, 16.3204, 17.3856, 26.2756],
        [27.6927, 11.3520,  6.9444, 16.3832, 12.1457, 14.9376, 26.0629],
        [19.2399, 2

batch_predictions
tensor([[14.4573, 25.2990, 33.0363, 29.2247, 13.2820, 15.6985, 14.6566],
        [27.1375, 12.8041, 14.7012, 14.5104, 15.1489, 20.8327, 29.1783],
        [ 7.7331,  8.1730,  8.0034,  9.6599, 12.7146, 15.9604,  6.7763],
        [13.4786, 13.7453, 13.3869, 14.2328, 22.2387, 33.5665, 18.9201],
        [10.4370, 10.9278, 11.9167, 17.7177, 26.4957, 24.0500,  7.5586],
        [22.2831, 23.6213, 16.3355, 16.3505, 17.6547, 18.0675, 20.4073],
        [24.7820, 20.2824, 10.7838, 11.4732, 11.6914, 11.9275, 16.8511],
        [40.0485, 31.3174, 14.5386, 16.8069, 17.0121, 17.4479, 24.8148]])
ground truth
tensor([[18.9909, 31.3634, 50.3968, 32.9223,  0.0000, 10.6434, 16.1706],
        [ 9.8498, 14.8101, 18.1973, 12.6559, 20.7766, 22.6616, 31.0516],
        [ 6.6468,  5.3430,  7.7098,  9.4246, 13.4637, 19.8696,  5.8107],
        [ 7.3838, 20.8192, 11.5930, 15.8022, 20.1531, 40.6179, 36.1536],
        [ 6.8911, 12.8748,  9.1662, 15.3472, 20.3840, 23.9085,  5.0237],
        [27.4660, 3

batch_predictions
tensor([[12.9490, 13.1546, 16.0834, 19.2275, 21.3788, 15.5086, 13.3330],
        [18.4632,  7.7019,  8.2193, 10.4205, 10.1869, 12.1407, 18.5598],
        [12.6799, 13.1819, 13.4224, 21.2868, 28.0853, 13.4422, 14.8652],
        [35.0499, 46.5934, 35.3906, 21.0302, 22.4209, 22.3636, 23.4796],
        [14.5524, 14.0523, 14.4370, 19.1101, 27.3512, 25.6016, 13.7827],
        [28.0449, 21.1157, 12.2543, 12.4154, 11.8600, 14.6279, 20.4863],
        [14.6294, 15.0314, 16.9879, 31.3582, 38.9898, 20.8361, 17.4671],
        [ 8.1554,  9.2287, 10.2154,  6.9574,  5.9997,  6.9851,  7.6328]])
ground truth
tensor([[18.0431, 14.7422, 16.1099, 25.7496,  7.5355, 11.4676, 14.5055],
        [24.2635,  9.2714,  3.7875,  2.3935, 13.2693, 13.2562, 14.7817],
        [11.3756, 13.8348, 17.9905, 33.9295, 32.1278,  5.6944, 19.1610],
        [ 7.4405, 90.2778, 50.4819, 33.2200, 28.5573, 24.1071, 28.0187],
        [13.2167, 13.9400, 12.5855, 14.6633, 26.9069, 27.3540, 14.0452],
        [32.8798, 2

batch_predictions
tensor([[12.8958, 12.9893,  6.9460,  6.7889,  7.8180,  8.6268, 10.5566],
        [16.0075, 21.7976, 28.7582, 27.4756,  9.5087, 14.6883, 15.9797],
        [13.9888, 16.1876, 21.0658, 27.8887, 23.0303, 10.9651, 14.8935],
        [15.3397, 23.5296, 21.6083, 10.4215, 10.7960, 11.1264, 11.3161],
        [13.4369, 13.9028, 21.2874, 27.1509, 12.8590, 13.1617, 14.1450],
        [19.3430, 18.7181, 17.2125, 16.6966, 26.7308, 47.2966, 50.2440],
        [13.5151, 15.8850, 17.4403, 24.4856, 33.7821, 31.5752,  9.7512],
        [29.5490, 39.7602, 41.7579, 29.3671, 24.7253, 24.6184, 24.0655]])
ground truth
tensor([[16.8201, 17.5171,  6.7596,  7.7196,  7.3645,  8.8769, 10.6128],
        [17.6729, 23.3277, 30.2438, 41.8651, 12.7126, 23.9512, 27.3101],
        [14.4661, 19.5818, 28.9716, 28.3535, 17.9116,  4.6554,  9.5345],
        [19.4371, 24.2241, 26.3677, 20.1867, 11.0863,  9.1662, 11.8622],
        [13.8348, 17.9905, 33.9295, 32.1278,  5.6944, 19.1610, 18.3325],
        [ 9.7843, 1

batch_predictions
tensor([[21.9032, 28.4560, 20.6718, 12.1227, 13.4468, 13.0794, 13.9448],
        [19.4407, 18.4218, 18.7122, 28.4735, 44.1813, 34.8099, 16.0533],
        [11.2573, 13.2876, 11.6758, 12.6707, 16.7234, 26.1315, 26.7534],
        [17.2086, 17.0902, 17.9456, 27.0192, 39.6275, 30.2221, 18.8003],
        [15.7633, 15.6699, 15.7631, 22.7015, 26.2996, 20.9521, 15.6010],
        [20.7851,  9.2988, 11.0821, 11.3681, 11.4984, 14.8757, 23.2868],
        [18.4730, 22.1674,  9.5159, 10.6016, 10.5059, 10.5919, 12.7715],
        [15.2891, 14.7770, 13.0752, 13.3613, 18.4149, 23.0113, 18.9717]])
ground truth
tensor([[23.3430, 33.4824, 24.8553, 12.9406, 15.5708, 12.3751, 14.2425],
        [17.3328, 20.5924, 25.5527, 33.3900, 51.8707, 41.5958, 21.7262],
        [ 3.8138,  0.0000,  1.0652, 12.7564, 15.1894, 28.5376, 28.7743],
        [20.1078, 22.8958, 17.9774, 23.9611, 38.9663, 34.1531, 11.0731],
        [16.3690, 14.2290, 16.2415, 22.2647, 33.0357, 11.1678, 12.0890],
        [23.2378,  

batch_predictions
tensor([[13.9246, 12.8128, 12.8748, 19.6330, 33.7239, 33.3164, 14.4158],
        [24.1305, 26.7526, 32.2229, 20.5807, 19.5421, 22.9529, 23.8510],
        [34.6251, 29.3286, 14.8646, 16.5024, 16.2874, 16.0214, 22.4994],
        [20.6709,  9.0238, 10.5859, 10.6547, 11.1116, 13.5923, 17.6597],
        [18.9464, 24.1540, 18.8906, 11.9887, 11.6578, 12.3697, 13.5089],
        [22.6605, 21.9937, 11.7259, 10.5820, 10.2890, 11.2876, 15.1517],
        [22.1340, 19.5965, 12.5354, 13.0399, 14.0559, 18.8417, 22.0699],
        [16.5854, 14.1268, 13.0730, 13.4667, 17.3617, 25.5155, 30.0108]])
ground truth
tensor([[16.4913, 11.9279, 13.9137, 21.2651, 26.0126, 26.0784, 13.2036],
        [26.1441, 27.2094, 37.2304, 39.7948,  6.1021, 18.2141, 24.2372],
        [28.7982, 34.2545, 17.1060, 15.7455, 14.6259, 15.2494, 26.3180],
        [24.5660,  7.4171, 11.5729, 12.1383, 12.4934, 15.5313, 15.0973],
        [17.9847, 28.4722, 19.6570, 12.6559, 13.4070, 13.8180, 13.9031],
        [27.1305, 1

batch_predictions
tensor([[10.2833, 11.9365, 15.9207, 14.3108,  7.7075,  7.5752,  8.7352],
        [26.1836, 25.5716, 11.1408, 13.3754, 13.5822, 15.4595, 19.0471],
        [19.1113, 27.1034, 23.2698, 10.7006, 12.8430, 13.2587, 13.8017],
        [28.7254, 22.9143, 11.0993, 13.3321, 13.3515, 14.6622, 19.9942],
        [11.2991, 14.3260, 18.6876, 23.0320, 20.4859,  9.7564, 11.5668],
        [11.6297, 11.8261, 16.4906, 22.5585, 21.1717, 11.7022, 14.0402],
        [13.9669, 14.2844, 13.7282, 13.2899, 16.3081, 24.1253, 27.9478],
        [19.2417, 15.5029, 15.1530, 14.7886, 19.3283, 36.2330, 37.4867]])
ground truth
tensor([[ 8.8901,  9.3898, 14.8080, 14.8343,  0.0000,  2.4855,  8.1405],
        [27.0833, 26.9133, 12.3866, 11.7489, 11.1961, 20.0397, 26.5448],
        [21.3152, 28.3305, 28.5147, 15.1644, 14.7251, 12.9819, 14.0023],
        [29.7477, 27.2959,  7.4263, 14.8668, 13.3078, 14.3282, 17.7296],
        [ 8.6796, 14.2951, 18.6612, 20.8706,  0.0000,  0.0000,  5.3393],
        [18.8350, 1

batch_predictions
tensor([[14.7016, 14.9493, 13.9452, 15.6661, 25.3991, 30.9172, 23.7057],
        [10.7243, 11.4984, 12.9384, 21.8023, 23.5930,  8.8590, 12.9582],
        [31.3603, 28.3502, 12.7762, 16.0478, 16.5567, 17.1097, 24.6062],
        [19.3360, 20.1991, 20.4302, 23.6204, 31.0642, 40.2957, 37.2786],
        [13.6749, 22.0073, 24.0434, 20.1200, 12.2186, 12.4443, 12.3580],
        [40.5683, 33.3876, 13.0166, 16.1001, 16.1287, 16.2854, 28.9512],
        [13.9289, 17.7247, 25.0092, 22.1793, 11.7317, 11.9500, 12.6883],
        [13.2112, 13.5345, 14.0640, 17.5748, 27.5254, 31.6584, 10.9854]])
ground truth
tensor([[13.7755, 11.8481, 14.6684, 20.6633, 25.1701, 33.7302, 28.3872],
        [11.6649, 11.1257, 10.7575, 28.5508, 28.7086, 10.5339, 11.5071],
        [41.2698, 32.9932, 12.0465, 15.8022, 15.4195, 23.0867, 37.2874],
        [18.5232, 19.3594, 23.0584, 26.9558, 38.6480, 33.7160, 32.8231],
        [10.8418, 19.7562, 29.7902,  6.5760,  5.7115,  8.5884, 14.4983],
        [58.3772, 5

batch_predictions
tensor([[ 7.3527,  6.4981,  6.3899,  6.2476,  6.5825,  7.3225, 11.0023],
        [ 8.4596,  8.1348,  7.2819,  7.8649,  7.8407,  8.1844,  9.1113],
        [ 7.1490,  8.6132,  8.2483,  8.1774, 11.0900, 15.2280, 13.7740],
        [ 9.6268, 10.7364, 10.0216, 10.6436, 12.7713, 20.8363, 19.7090],
        [11.6030, 14.0569, 19.8721, 24.2778, 11.7951, 11.7115, 12.6140],
        [20.4007, 14.2612, 14.3357, 14.0504, 17.1015, 26.2462, 33.0795],
        [22.3245, 11.8404, 11.1669, 10.4381, 10.9561, 15.7473, 24.4088],
        [20.0789, 22.2839, 17.0526,  9.7439, 10.9919, 11.6070, 13.8513]])
ground truth
tensor([[ 7.3838,  6.4626,  8.5034,  6.1933,  4.5210,  5.4989, 10.0057],
        [15.7023, 11.3361,  7.2330, 10.0868,  8.1273, 10.5208, 14.6107],
        [ 3.1179,  7.0862,  8.3475,  9.2545, 15.2353, 17.0777, 26.5448],
        [ 8.8577,  9.5805,  8.8294, 10.4450, 12.3866, 24.6740, 18.6650],
        [10.4308, 13.2937, 20.3656, 36.5363, 21.5703, 20.7058, 20.2664],
        [38.0811, 1

batch_predictions
tensor([[19.7697, 17.6192,  9.4288,  9.1763,  8.9784,  9.7786, 13.9912],
        [17.8758,  9.1121, 10.2875, 10.2311, 10.9078, 16.2359, 22.1379],
        [12.6001, 17.8277, 27.2719, 27.1308, 13.1756, 12.2791, 12.6750],
        [ 6.1289,  6.6727,  7.2734,  8.4356,  9.3327,  6.3124,  5.8041],
        [15.5492,  7.3112,  9.0395,  8.9819,  9.4070, 12.8082, 19.7587],
        [13.1390,  7.1987,  7.7522,  8.6438, 10.3384, 10.7804, 14.0598],
        [15.1738, 13.1217, 13.2878, 19.1490, 28.6870, 27.7351, 13.2399],
        [29.7861, 22.9450, 12.4460, 13.3807, 13.2725, 17.0609, 28.5850]])
ground truth
tensor([[27.4724, 22.7380, 11.1257, 11.7175, 13.1510, 12.5197, 16.0705],
        [20.9325,  6.2217, 10.9836, 12.9535, 10.0482, 17.1485, 20.3515],
        [14.3566, 18.1548, 32.7523, 32.7806, 11.5505, 12.9960, 11.6780],
        [ 5.0170,  5.6689,  6.6043, 11.6355,  8.6735,  5.2154,  4.4218],
        [ 2.6433,  0.0000,  4.6949,  6.3651, 10.1920, 11.7175, 17.4908],
        [16.2415,  

batch_predictions
tensor([[17.3893, 26.6119, 36.7257, 17.1620, 14.8664, 17.0740, 16.6522],
        [10.0527, 14.3517, 11.9997, 11.8109, 17.1445, 28.9159, 30.1506],
        [13.0427, 12.8452, 14.3610, 19.2193, 12.5060,  9.8270, 12.6788],
        [20.4444, 10.3278,  9.4478, 10.7058, 11.0035, 13.0973, 18.5919],
        [32.8047, 28.9111, 12.5457, 13.9962, 14.3933, 15.4827, 24.8520],
        [13.3849, 14.0046, 16.0709, 22.0691, 27.8207, 15.6508, 14.1755],
        [15.9340, 17.8401,  8.3670,  7.5066,  7.7239,  8.7094, 11.3959],
        [13.2844, 14.3183, 13.1344, 14.9043, 18.7710, 23.5014, 19.4926]])
ground truth
tensor([[22.9025, 20.6491, 49.9858, 42.5028, 13.1236, 18.1689, 20.3231],
        [ 8.6402, 12.5460,  9.4819, 10.4287, 19.2530, 26.1967, 27.6302],
        [12.7170, 16.8201, 22.1068, 30.4971,  1.2756,  0.0000,  0.0000],
        [29.1557,  8.9295,  3.9321, 12.8222, 11.9411, 11.2572, 15.1105],
        [27.2357, 31.7464, 11.8227, 16.1757, 11.8490, 11.6123, 23.8690],
        [13.6905, 1

batch_predictions
tensor([[10.8478, 11.0693, 11.3165, 18.2371, 23.1858, 10.0052, 11.4979],
        [13.7098, 14.1686, 20.3793, 31.0269, 26.9405, 11.8497, 15.7978],
        [10.9860, 10.2052, 10.4148, 13.0850, 20.1900, 25.9787, 12.0568],
        [43.3617, 26.9904, 24.8067, 24.7563, 25.7305, 32.9498, 46.0908],
        [19.1490, 28.8314, 29.8961, 19.3381, 14.0229, 16.1044, 15.7741],
        [10.6596, 11.2393, 11.2142, 15.4761, 22.9904, 22.7285,  9.1165],
        [31.2153, 24.1668, 10.7807, 12.4305, 12.5371, 13.2009, 21.4109],
        [16.0900, 15.3278, 15.4942, 15.0147, 18.5507, 26.3409, 25.3335]])
ground truth
tensor([[10.3104, 17.6355, 15.6102, 27.3277, 28.0247,  8.0089, 10.7575],
        [12.6842, 13.4070, 23.0300, 35.2749, 32.9932, 14.9943, 14.5692],
        [10.8560, 11.7772, 10.4308, 13.2937, 20.3656, 36.5363, 21.5703],
        [54.5635, 28.7840, 22.9308, 25.0992, 25.0142, 38.7755, 61.8481],
        [17.1620, 22.8564, 32.9300, 27.2883, 15.8864, 16.1099, 11.5860],
        [11.6781, 1

batch_predictions
tensor([[37.2748, 25.9261, 15.6767, 16.0027, 16.8192, 19.3185, 26.8096],
        [26.8640, 24.8460, 15.1690, 15.4448, 14.9267, 16.0449, 18.3669],
        [16.0600, 15.5952, 18.8215, 34.3101, 38.0440, 22.9005, 17.4291],
        [11.4309,  6.8058,  7.7711,  8.0420,  8.2798, 10.9834, 16.1420],
        [16.6174,  6.5629,  7.9230,  8.7139, 10.1787, 15.1564, 21.6178],
        [16.3405, 23.2779, 20.8877,  9.9323, 10.7266, 11.3167, 11.3950],
        [17.5514,  6.5370,  6.9499,  7.4101,  8.3192, 10.2824, 15.3834],
        [17.5146, 17.8657, 16.5461, 21.6467, 36.3467, 45.3596, 30.0645]])
ground truth
tensor([[35.1474, 29.9461, 15.8305, 17.9989, 16.4966, 21.1451, 19.3027],
        [30.7540, 27.0692, 27.2676,  6.8594,  8.8435, 17.3753, 24.0646],
        [20.6491, 19.7279, 19.8980, 34.5522, 41.2840, 27.0550, 13.2937],
        [10.8233,  3.4850, 10.7049,  1.9069,  0.0789, 10.2709, 14.7028],
        [17.9905,  3.2614,  7.5355,  7.7854, 11.5466, 17.0437, 17.6618],
        [19.1084, 2

batch_predictions
tensor([[14.8377, 18.2784, 19.6152,  7.7995,  7.3691,  9.9654, 11.6928],
        [17.1946, 14.1010, 13.5844, 13.3431, 14.8235, 23.5637, 27.1224],
        [19.7994, 32.3226, 36.1770, 13.5491, 17.2909, 17.4942, 16.3272],
        [12.4526, 19.2820, 19.6753, 10.6718, 10.1323, 10.0404, 10.3225],
        [23.2372,  9.7349, 12.5155, 12.3648, 13.3663, 17.9073, 27.7051],
        [10.8603, 12.6089, 18.6901, 20.6304,  9.2517,  9.5430, 10.7020],
        [21.9030, 18.5825, 10.1132, 10.7784, 11.6207, 12.3585, 17.3169],
        [12.5439, 13.9235, 14.4722, 18.9924, 29.7474, 30.8647, 10.0568]])
ground truth
tensor([[12.5197, 14.2162, 17.5565, 15.1762,  9.0347, 12.9800, 15.7286],
        [12.9932, 20.8969, 14.2688, 13.7428, 18.7270, 25.4471, 21.1599],
        [23.2568, 39.5408, 46.9671, 42.4603, 19.1327, 19.2035, 17.8146],
        [13.9137, 21.4361, 21.8306,  8.4429,  9.4424,  8.1536,  8.9821],
        [26.6582,  9.9348, 11.8764,  9.7364, 14.2574, 15.2920, 25.0283],
        [11.3756, 1

batch_predictions
tensor([[37.0609, 25.0541, 13.4160, 14.8811, 15.1027, 16.0937, 25.3192],
        [10.2800, 10.0957, 10.4124, 15.1352, 22.9143, 14.7982,  9.4568],
        [34.7148, 32.2630, 11.3792, 16.4415, 16.2040, 15.8126, 23.2141],
        [17.6027, 16.5735, 15.8269, 29.1335, 46.8379, 39.2241, 16.7535],
        [13.0000, 13.3551, 15.2873, 22.9884, 22.7779, 16.9779, 12.3633],
        [15.9924, 16.4077, 19.7779, 33.1343, 32.6037, 21.5864, 16.6143],
        [18.5609, 18.1724, 19.7352, 31.0324, 43.5786, 31.8135, 20.0126],
        [15.7527, 22.7944, 24.8564, 12.0574, 13.1895, 13.8018, 13.8600]])
ground truth
tensor([[50.7653, 37.8260, 27.7069, 25.6803, 23.2993, 17.5028, 28.8124],
        [11.4413, 13.1641, 12.8222, 12.1910, 28.9716, 21.1205, 10.5602],
        [23.5686, 36.9898,  6.6043,  4.1667, 19.2602, 22.3498, 23.2568],
        [18.2799, 12.1778, 15.5181, 34.2977, 71.1205, 71.1336, 43.1746],
        [12.7693, 13.2228, 14.3282, 22.5907, 25.2834, 31.6893, 18.7358],
        [17.7296, 1

batch_predictions
tensor([[20.0681, 27.4121, 25.7022, 12.0251, 13.5520, 14.0447, 14.5327],
        [20.2262, 24.6352, 14.7872, 13.6601, 15.0738, 16.2320, 17.4134],
        [13.5944, 14.3572, 16.5177, 22.1209, 22.1169, 13.6455, 15.4454],
        [24.4755, 29.7493, 17.1338, 17.8235, 17.9906, 17.2339, 18.7783],
        [22.9617, 22.1175, 23.7968, 27.6971, 16.4561, 17.8585, 19.8180],
        [20.9266, 12.3163, 12.7024, 12.0533, 11.8440, 18.8122, 28.0403],
        [21.2575, 23.1408,  9.9271, 10.7793, 13.4073, 13.7608, 15.7109],
        [10.5139, 15.3275, 15.4839,  7.7866,  7.1778,  7.5657,  8.1681]])
ground truth
tensor([[20.1247, 34.1837, 32.4263, 15.2494, 12.7551, 16.6383, 10.9552],
        [24.2914, 36.5363, 22.4348, 24.7307, 28.9116, 23.0159, 21.3435],
        [16.5045, 17.5829, 17.7670, 23.0668, 23.8953, 10.7575,  1.5255],
        [22.3498, 33.2341, 18.7217, 19.1893, 22.8883, 23.4552, 21.2727],
        [26.7290, 22.3073, 24.5323, 35.2749, 19.3311, 19.4586, 19.9546],
        [ 7.3251,  

batch_predictions
tensor([[35.4722, 25.9321, 15.5362, 16.5067, 15.9761, 16.0011, 28.3190],
        [18.8492, 27.0985, 26.3390, 11.6749, 12.5282, 13.9689, 14.0380],
        [20.3595, 28.0691, 24.8035, 12.0492, 14.3932, 13.8581, 14.6669],
        [36.6915, 17.6662, 18.6481, 17.5051, 20.0300, 37.3585, 49.0257],
        [13.1039, 15.0308, 15.8943, 16.3290,  8.1046,  9.7380, 11.6088],
        [17.0573, 21.1494, 24.1238, 25.8906, 10.3030, 15.3093, 16.2584],
        [ 8.3073,  9.3526,  9.5751, 11.1461, 13.8003, 20.8260, 20.6204],
        [13.7456, 14.0837, 16.6803, 24.4042, 26.0517, 10.0628, 14.3793]])
ground truth
tensor([[ 0.0000,  0.0000,  0.0000,  7.2988, 14.2999, 14.9802, 28.9399],
        [20.7908, 28.6423, 27.2534, 12.6276, 13.6763, 22.9308, 17.0918],
        [14.3566, 22.6332, 28.1321, 11.6071, 13.2653, 13.8322, 13.0527],
        [54.9974, 25.8811, 21.3309, 19.4240, 19.5686,  6.6938,  7.7065],
        [ 9.9290, 14.0058, 12.6381, 23.6718, 10.2052, 13.6376, 11.1257],
        [18.4949, 2

batch_predictions
tensor([[10.7006, 11.0937, 15.3064, 21.0391, 18.6537,  8.5222, 11.4399],
        [30.4106, 24.3185, 12.1641, 12.8191, 13.0425, 15.0146, 21.9473],
        [23.0906, 26.7311, 12.4492, 15.0204, 15.4683, 17.0698, 20.6010],
        [11.0842, 11.5896, 10.2140, 10.0555, 14.5304, 23.7147, 22.7220],
        [23.2863, 27.3113, 11.0862, 13.6376, 14.0290, 14.4218, 17.5190],
        [22.1509, 20.4094, 11.4432, 11.8834, 11.9774, 12.9037, 18.1752],
        [12.1447, 14.1242, 16.8033, 12.5437,  9.7557, 10.1957, 11.4892],
        [22.8928, 26.8061, 11.1900, 12.4109, 12.1436, 12.5548, 14.6675]])
ground truth
tensor([[11.0863, 16.2152, 14.0715, 21.0021, 13.5587,  8.7059,  7.4566],
        [33.4892, 32.6247, 20.4365, 15.2778, 11.8764, 16.1423, 33.3050],
        [21.9095, 30.9311, 12.1646, 13.7033, 13.1247, 12.9669, 18.2536],
        [11.1257, 11.7175, 13.1510, 12.5197, 16.0705, 27.9327, 24.2109],
        [11.7044, 29.3924, 13.2825, 11.9016, 10.8759, 14.6633, 18.7270],
        [22.7466, 2

batch_predictions
tensor([[12.3436, 14.4838, 17.9214, 26.6325, 32.2505, 12.2136, 14.3982],
        [23.2476, 20.2673, 12.4454, 12.5475, 12.2459, 12.8234, 18.8272],
        [39.3834, 35.0594, 15.5722, 18.2150, 18.1779, 17.1379, 23.3677],
        [14.1862, 14.1245, 14.3963, 20.0430, 30.2824, 29.0058, 14.0515],
        [10.1009, 10.7777, 15.5757, 20.3463, 10.9786,  9.3740, 10.2981],
        [18.8932, 20.3699, 19.8363, 19.5360, 23.2608, 32.8706, 30.5374],
        [13.7675, 21.6831, 20.1399, 10.8688, 10.5550, 10.9497, 11.1754],
        [15.7590, 17.0436, 16.2784, 17.6581, 21.9664, 29.9143, 34.7251]])
ground truth
tensor([[16.4683, 18.7358, 20.5074, 31.5334, 44.0476, 33.1066,  9.6372],
        [24.1071, 18.8209, 10.0198, 11.7630, 16.4683, 11.6780, 23.4694],
        [42.4745, 41.1565, 12.5850, 21.3010, 23.5402, 17.7154, 43.0130],
        [17.7863, 11.2954, 15.0085, 19.7421, 23.7670, 34.0561, 12.6417],
        [ 9.0561, 11.0119, 15.9722, 25.1276, 14.7817, 10.9127,  9.3679],
        [21.1073, 2

batch_predictions
tensor([[11.0139, 16.0099, 23.6556, 23.9304,  9.1652, 12.1306, 11.5779],
        [21.0820, 10.1678, 11.6844, 12.1490, 13.0487, 16.5690, 18.9888],
        [19.4588, 22.0371,  8.7880, 10.0029, 10.5942, 10.7259, 13.2233],
        [13.8957, 17.0406, 22.5618, 20.6808, 11.8631, 12.0695, 12.7576],
        [24.3460, 23.0367, 11.2997, 12.5777, 12.8448, 13.4156, 18.1242],
        [15.8057, 16.3747, 19.4652, 22.7531, 18.0696, 13.2262, 15.0427],
        [13.1370, 14.1899, 13.0055, 13.7110, 17.3597, 23.2997, 18.1195],
        [18.4897, 27.5381, 26.4303,  9.5874, 13.3840, 14.1180, 15.3354]])
ground truth
tensor([[ 7.9432, 20.0947, 24.6055, 14.0058, 13.6770,  8.6928,  9.4161],
        [29.3924, 13.2825, 11.9016, 10.8759, 14.6633, 18.7270, 26.9332],
        [13.7165, 24.7764,  7.3777, 13.3877, 11.2572, 13.5718, 15.0447],
        [16.2982, 17.0210, 25.9921, 24.9575, 12.8827, 12.2874, 11.7489],
        [23.3167, 27.8932, 12.7301, 11.2967, 11.2572, 12.9011, 18.4377],
        [14.4661, 1

batch_predictions
tensor([[22.9341, 29.8060,  9.1531, 13.7528, 13.8373, 14.0577, 17.5470],
        [12.6897, 16.6892, 22.0792, 22.0600, 11.5946, 11.2300, 12.0000],
        [14.5176, 13.4195, 12.2072, 13.2973, 18.9595, 28.8419, 24.1928],
        [42.1064, 13.9082, 19.6546, 18.8788, 17.8171, 28.2058, 42.8927],
        [11.8071, 12.2205, 12.7734, 20.3356, 34.5699, 26.8596, 13.3832],
        [18.0373, 25.9763, 33.7197, 25.8763, 14.7882, 17.5456, 16.5456],
        [ 9.8481, 11.7189, 11.2650, 11.1436, 13.6902, 17.8100, 22.0982],
        [13.5670, 14.7101, 20.6368, 23.8606, 19.3154, 11.7286, 13.6808]])
ground truth
tensor([[23.9611, 38.5192,  6.0100, 12.4408, 17.6092, 16.4782, 22.6065],
        [14.5450, 17.7144, 19.2399, 21.9621, 14.8343, 13.4666, 10.7312],
        [15.8588, 16.6950, 10.3033, 15.2636, 23.9087, 29.0675, 37.2449],
        [48.8033, 17.8327, 12.4408, 13.2036, 15.4129, 31.9963, 60.8496],
        [11.2387, 12.2307,  9.9065, 23.2426, 37.8543, 34.7931, 15.3628],
        [19.7922, 2

batch_predictions
tensor([[24.7284, 19.3148, 10.0235, 10.8218, 10.7734, 11.7135, 18.5665],
        [20.0696, 19.3590, 19.9804, 21.4330, 22.0262, 25.7522, 31.2628],
        [12.5365, 20.4002, 23.7029, 11.8697, 10.1300, 11.7617, 11.9668],
        [ 7.1879,  7.4789,  7.9088,  8.3014,  8.4890, 10.5917, 13.8824],
        [15.0324, 16.5723, 24.5722, 21.8257, 18.1777, 12.7472, 13.7730],
        [21.5261, 27.1790, 20.6948, 12.3020, 13.2388, 13.3378, 14.2181],
        [11.7012, 12.5437, 16.7041, 21.4721, 20.6321,  8.7662, 11.6921],
        [24.5312, 25.6037, 34.3888, 43.9259, 26.4254, 23.0317, 24.2494]])
ground truth
tensor([[30.5556, 23.1009,  6.5476,  9.8781, 11.3379, 11.5079, 33.5317],
        [19.3027, 21.9529, 24.1071, 26.7290, 22.3073, 24.5323, 35.2749],
        [11.0600, 15.2946, 27.7617, 24.0400,  8.3772, 10.0342, 10.8364],
        [ 4.9053,  5.6286,  6.8254,  6.0231,  7.8906, 12.6775, 15.4129],
        [19.6995, 12.2591, 25.2551, 22.8175, 19.9546, 12.4291, 14.3424],
        [27.6565, 3

batch_predictions
tensor([[31.6981, 44.3330, 39.2264, 22.9177, 23.3197, 22.5164, 23.1849],
        [16.2072, 23.4773, 24.0452, 13.4329, 13.0096, 13.3580, 13.4439],
        [25.6928, 24.5667, 12.9545, 13.5552, 13.7683, 13.7647, 17.0068],
        [28.8046, 15.8377, 15.5297, 14.1859, 14.8518, 22.9589, 37.2691],
        [19.2204, 10.9866, 12.2924, 11.0854, 11.3278, 19.6042, 33.8340],
        [17.3484, 27.0882, 34.9772, 28.5547, 13.5750, 18.2013, 18.3120],
        [13.9982, 15.3117, 21.8472, 26.2492, 21.1048, 12.2954, 15.0636],
        [13.3600, 16.2469, 24.4089, 26.4824, 11.9293, 13.3795, 13.7017]])
ground truth
tensor([[36.7914, 58.6876, 57.4830, 31.4484, 24.3764, 27.4660, 20.1105],
        [19.1043, 27.6502, 25.0992, 15.1786, 16.1139, 17.9563, 13.0527],
        [27.9721, 27.1568, 14.9790, 16.1231, 11.4413, 16.5834, 19.6739],
        [41.9385, 14.1110, 16.4256, 16.1231, 19.1215, 22.9748, 43.0037],
        [28.1604, 14.6967, 13.2511, 21.1168, 16.4116, 21.6978, 28.7840],
        [17.7579, 2

batch_predictions
tensor([[10.7781, 12.5560,  9.5332,  8.2101,  8.7976,  9.8274, 10.0342],
        [12.6239, 12.8781, 15.4061, 21.6566, 21.0017, 10.5017, 12.7455],
        [20.7354, 11.8753, 11.0790, 11.6757, 15.2802, 19.3498, 22.1976],
        [21.9148, 12.4456, 12.8469, 12.5919, 13.1879, 17.0465, 27.4704],
        [21.2332, 24.2893, 13.2356, 13.0583, 13.3108, 13.8668, 15.9684],
        [ 8.7140, 11.7104, 17.3758, 20.4408, 11.0414,  9.5795,  9.7039],
        [19.2158,  7.3652,  8.6282,  8.4487,  9.3333, 12.8996, 18.0610],
        [ 6.4021,  6.1448,  5.8835,  6.0325,  6.3654,  7.0770,  7.8346]])
ground truth
tensor([[12.5066, 17.8590,  9.7449,  6.1284,  7.1278, 10.1131,  8.1405],
        [13.7559, 14.3083, 13.3482, 19.4108, 21.8832, 15.7286, 11.6649],
        [34.2057, 27.5644, 12.2567, 14.2031, 20.2656, 21.8701, 16.8332],
        [19.9435, 14.2294, 11.1849, 11.6781, 11.2046, 15.8601, 32.7722],
        [25.5918, 29.8790, 15.3340, 16.7543, 12.5197, 14.6633, 16.9122],
        [13.3614, 1

batch_predictions
tensor([[24.0436, 11.7291, 12.2332, 12.1728, 13.1297, 18.5739, 29.4773],
        [12.1285, 12.2229, 11.0147, 11.2244, 13.7072, 21.9812, 22.8932],
        [ 9.0835, 15.0024, 13.4664, 13.2852, 16.3480, 24.8946, 32.9364],
        [15.3110, 21.7724, 18.6961,  9.2699,  9.3149,  9.5683, 10.3960],
        [30.4023, 24.0196, 10.9110, 12.8249, 12.6171, 14.3290, 21.9651],
        [ 7.7903,  7.3212,  7.8083,  8.5957,  6.9034,  6.5915,  8.1147],
        [37.0611, 17.8791, 18.9271, 17.5719, 17.3128, 28.5393, 44.7966],
        [19.8795, 24.0472, 23.1397,  9.2079, 12.7660, 13.7705, 14.6638]])
ground truth
tensor([[28.1168,  9.3766,  8.6796, 10.7838, 16.0705, 21.6597, 28.9979],
        [13.0952, 12.4433, 12.1882, 11.9189, 14.9802, 28.7132, 21.2585],
        [ 5.6944, 12.2436, 10.5208, 13.2299, 19.6081, 23.9611, 38.5192],
        [11.4545, 27.4724, 22.7380, 11.1257, 11.7175, 13.1510, 12.5197],
        [34.2971, 31.1508, 12.1740, 21.8679, 13.0952, 18.5374, 20.7058],
        [ 7.3777,  

batch_predictions
tensor([[26.2182, 14.2292, 13.2005, 13.8384, 14.1624, 19.3086, 32.1106],
        [ 9.4650, 10.1166, 13.1462, 18.5910, 16.9306,  7.5044,  9.8605],
        [12.1656, 19.0072, 24.0347, 18.5982, 11.1262, 12.2274, 11.7277],
        [12.0394, 13.9480, 19.9889, 24.0185, 19.6261,  9.6027, 12.8979],
        [15.4247, 16.5346, 27.4531, 40.3949, 31.6036, 13.6772, 19.0657],
        [12.6829, 14.4275, 13.3692, 14.6770, 20.2424, 15.0967, 11.9148],
        [10.3400, 12.1172, 15.8108, 20.2764, 15.8038,  8.0536,  9.8586],
        [15.2577, 18.8507, 21.9561, 21.8770,  9.6709, 11.4317, 13.1615]])
ground truth
tensor([[29.3793, 14.3424, 30.3430, 12.6984, 16.7517, 17.0635, 31.1366],
        [ 7.1278, 10.9416, 15.2420, 24.6844, 15.9916,  0.0000,  8.9164],
        [11.4087, 20.9609, 24.1780, 26.8707, 11.0969, 14.5975,  9.8073],
        [11.7772, 14.3141, 21.2302, 22.0805, 23.0726, 12.1032, 10.8985],
        [13.4637, 13.9739, 23.7954, 48.1434, 11.4938, 17.3328, 17.7721],
        [14.5055, 1

batch_predictions
tensor([[16.9655, 24.9893, 38.2676, 32.6973, 13.7753, 18.0148, 16.3595],
        [34.7111, 16.1907, 18.0419, 17.1225, 18.5374, 33.3710, 46.5364],
        [21.8266, 23.5199, 10.7381, 11.7760, 12.1782, 12.5651, 14.4167],
        [32.3954, 16.8370, 15.0221, 15.5782, 15.0419, 17.9161, 28.6859],
        [16.9075, 21.9496, 29.9158, 25.7485, 11.5202, 15.8407, 16.1562],
        [29.0332, 10.2482, 15.0839, 14.4166, 15.1687, 21.8624, 33.8462],
        [ 9.5576, 12.6578, 12.2363, 12.8742, 18.4089, 29.4262, 27.1409],
        [20.7740, 20.6985, 18.5168, 19.2556, 30.4261, 49.6782, 45.1375]])
ground truth
tensor([[22.3304, 23.3561, 51.3282, 26.4466,  0.0000, 12.9800, 17.8853],
        [38.4070, 13.6763,  9.6939, 12.6134, 15.9722, 36.5363, 51.1621],
        [28.4297, 30.8957, 12.3866, 18.0414, 15.5612, 12.9960, 16.9218],
        [34.9816, 13.7165, 11.9542, 21.9227, 16.4256, 18.9242, 39.9001],
        [16.4519, 23.5534, 31.0626, 27.7749, 11.3493, 13.8611, 16.5702],
        [32.9082,  

batch_predictions
tensor([[12.4524, 13.4198, 15.3930, 19.0356, 27.9387, 28.7319, 10.7474],
        [14.9252, 22.5435, 25.1707, 11.4656, 11.0023, 13.4562, 13.3142],
        [11.1292, 11.1987, 11.3782, 15.1882, 23.0504, 26.0405,  9.7306],
        [13.7813, 14.3066, 16.2858, 23.5576, 30.7631, 30.4602, 10.7825],
        [18.7892, 10.4779, 11.3488, 10.9119, 11.7833, 19.0280, 23.7774],
        [23.7000, 11.9016, 12.5750, 13.0429, 14.5592, 19.1309, 24.8919],
        [30.0468, 13.7253, 14.9283, 14.2811, 14.1179, 18.5101, 28.3787],
        [15.9645, 18.0314, 24.0867, 25.9466, 11.1887, 13.5870, 15.5517]])
ground truth
tensor([[10.5997,  8.6007, 13.3614, 20.9890, 20.2130, 34.7712, 13.7559],
        [10.5159, 18.7217, 23.2851, 37.9252, 14.9660, 15.4053, 16.1848],
        [11.3098, 12.2962, 14.5713, 15.0579, 21.8701, 30.0368, 10.0736],
        [23.9512, 27.3101, 14.8101, 14.0731, 32.9932, 18.8492,  0.2126],
        [21.4286,  8.0782, 14.0164, 11.8622, 14.9802, 21.8112, 21.1876],
        [29.7477, 1

batch_predictions
tensor([[20.3095, 20.6662, 18.6368, 20.0035, 33.4555, 51.3476, 46.5670],
        [14.3441, 19.3868, 27.7530, 26.7689, 12.4316, 14.9790, 14.3770],
        [16.6133, 20.7709, 10.6555, 11.8281, 13.0153, 15.5990, 20.4527],
        [ 7.5834,  8.5649, 10.8452, 15.3057, 13.9300,  7.8943,  7.7768],
        [13.5524, 15.0799, 13.3084, 13.0667, 16.6079, 23.6859, 26.4551],
        [13.2685, 16.5601, 19.7938, 19.2584, 11.5891, 12.0968, 12.1645],
        [11.7074, 13.1100, 17.7123, 22.6478, 22.0842, 12.3780, 13.0741],
        [26.4319, 22.2946, 11.3114, 12.4993, 12.8150, 14.8057, 22.8595]])
ground truth
tensor([[22.7117, 20.9232, 16.7938, 20.8443, 31.6149, 55.7075, 45.6207],
        [13.3078, 15.4478, 27.2817, 33.2766, 16.6383, 18.5941, 15.5329],
        [25.6378, 22.9167, 10.7851, 12.1315, 11.1678, 12.6701, 20.5641],
        [ 7.7854,  7.2725,  8.1931, 16.4124, 15.9916,  7.5092,  8.5087],
        [11.5860, 14.0978, 14.7422, 15.3209, 16.2809, 24.2898, 27.4724],
        [12.8118, 1

batch_predictions
tensor([[11.8497, 11.8474, 12.4551, 16.9682, 23.7626, 24.5156, 11.7326],
        [35.9308, 41.7944, 39.6383, 23.1291, 23.8913, 24.4608, 26.3788],
        [10.4815,  7.5100,  6.5373,  7.1096,  7.4439,  8.3899, 10.1220],
        [14.1126, 16.2227,  7.7019,  7.2000,  8.9386, 10.6436, 13.1446],
        [24.9219, 13.5105, 14.1048, 13.2767, 14.3220, 20.8017, 30.2027],
        [17.0534, 16.7753, 17.4600, 31.5530, 41.4436, 30.0060, 17.6399],
        [12.5279, 13.8676, 17.0375, 25.8653, 26.2340, 11.5529, 15.2658],
        [25.1598, 32.3125, 41.3897, 36.0310, 22.8776, 23.8833, 24.0018]])
ground truth
tensor([[11.6781, 16.0836, 14.4924, 21.6597, 27.5118, 30.2341, 13.2956],
        [41.7806, 44.1741, 55.1946, 58.8112, 45.4761, 40.8206, 45.1210],
        [12.6417,  9.8923,  6.0091,  9.4813,  9.5663, 10.7426,  9.1128],
        [15.1762, 16.8727,  6.3388,  2.5381, 10.4156,  9.1662, 12.4408],
        [34.2057, 20.8969, 26.3019, 20.2393, 14.5055, 19.8317, 28.0905],
        [20.9609, 1

batch_predictions
tensor([[34.0695, 43.4232, 35.5010, 22.9721, 23.5271, 22.8500, 25.2787],
        [12.0364, 13.3299, 21.6750, 23.0516,  9.6949, 11.1657, 12.3599],
        [12.0165, 15.7096, 22.1140, 21.7794, 12.3770, 13.2275, 12.4861],
        [30.1542, 24.0150, 14.7210, 15.1558, 15.6311, 16.8653, 22.4650],
        [14.8422, 19.0592, 25.5648, 23.9153, 13.0232, 11.1357, 12.3268],
        [27.2254, 10.9359, 13.8362, 13.0236, 13.4638, 19.6557, 28.7017],
        [17.8527, 18.9872, 25.4858, 22.0426, 12.7313, 14.3884, 15.8794],
        [ 8.1864,  7.4280,  7.6813,  8.0499,  8.1425,  9.9932, 13.9528]])
ground truth
tensor([[43.6366, 60.1757, 41.7800, 34.2404, 24.2205, 21.4002, 28.7415],
        [11.1257, 10.7575, 28.5508, 28.7086, 10.5339, 11.5071,  9.9421],
        [12.0465, 19.2885, 30.4280, 22.6616, 11.3237, 14.3566, 11.1395],
        [29.1241, 54.7902, 53.2738, 58.0499, 28.7415, 21.7971, 21.5561],
        [12.1740, 15.0510, 22.4065, 31.3917, 16.9218,  4.2375, 12.2591],
        [30.4138,  

batch_predictions
tensor([[15.0823, 18.2490, 15.7425, 16.1541, 20.5323, 37.3882, 39.3360],
        [22.8827, 11.7523, 10.5935, 12.0014, 12.1830, 12.7548, 18.9569],
        [10.2681, 12.7036, 12.6108, 14.7563, 19.7539, 24.4107, 22.9688],
        [12.0896, 14.5260, 13.2082, 15.2904, 19.0680, 28.0385, 27.9394],
        [12.0569, 13.7624, 12.7879, 12.9605, 16.7477, 25.1060, 27.1781],
        [18.9186, 34.6286, 36.1704, 25.4625, 15.9668, 16.9209, 15.2251],
        [38.4333, 29.2444, 15.2467, 16.8604, 16.0519, 17.2888, 29.1334],
        [ 9.6021, 10.0803,  9.8936, 10.5468, 14.6044, 20.6582, 16.2390]])
ground truth
tensor([[ 8.9994,  1.3039, 18.9201, 15.0085, 27.8770, 36.7772, 37.0890],
        [26.0652, 22.9748,  9.5345, 14.9001, 14.0189, 14.5055, 16.2678],
        [ 8.2908, 13.6621, 15.0085, 13.4212, 15.0794, 22.6616, 21.7120],
        [12.3866, 11.7489, 11.1961, 20.0397, 26.5448, 24.8866, 26.0771],
        [10.1616, 12.2732, 10.9410, 13.3078, 24.3764, 23.4836, 27.1684],
        [21.5420, 3

batch_predictions
tensor([[ 8.0434,  9.3708,  9.2641,  8.8824,  9.7799, 16.1943, 20.7090],
        [11.7956, 11.9068, 13.4576, 19.2693, 30.5101, 30.9585, 13.0712],
        [11.6017, 10.8972, 11.1187, 13.1120, 19.7262, 17.0965,  8.3054],
        [11.6483, 12.5608, 14.5980, 18.8721, 20.0004, 20.5672, 10.3803],
        [21.2373,  9.6254, 11.3830, 11.6237, 12.0713, 17.7911, 23.5647],
        [ 6.1887,  6.4716,  7.1121,  8.4313,  8.4700,  6.7119,  6.0083],
        [ 9.4595,  9.8968, 10.6808, 12.9882, 17.1107, 14.7110, 10.7800],
        [10.5509,  8.6485,  8.6280,  8.9106,  9.7739, 11.9793, 15.4672]])
ground truth
tensor([[ 3.8138, 10.7312, 17.0174, 13.8611,  8.9690, 25.2630, 31.6544],
        [12.9960, 11.6780, 13.7046, 20.8333, 30.4563, 28.7557, 15.2778],
        [ 2.5644,  5.2078,  2.6302, 17.4250, 21.9095, 19.2267,  6.5097],
        [11.7307, 12.5329, 14.9264, 13.8217, 20.8837, 22.8958, 10.7838],
        [23.0017, 10.7001, 10.0057, 10.9694, 11.1820, 12.2024, 22.9308],
        [ 5.5130,  

batch_predictions
tensor([[10.9566, 12.0219, 12.2164, 14.8464, 20.2124, 24.1357,  9.0872],
        [18.6113, 26.6199, 36.5868, 27.1363, 13.4897, 17.1182, 16.7702],
        [26.9964, 32.5396, 28.4023, 14.8818, 15.9953, 15.6260, 18.3854],
        [17.3375, 18.1110, 17.9835, 19.7614, 21.8272, 26.1198, 29.1279],
        [12.6992, 12.4116, 12.9663, 18.2918, 23.8900, 22.6195, 12.0651],
        [20.5224, 11.1742, 10.4611, 10.2770, 11.2899, 14.1716, 19.9639],
        [14.8416, 16.1521, 18.7358, 33.4055, 34.1519, 13.3822, 18.9719],
        [16.3798, 16.3346, 14.1097, 14.8685, 24.6502, 36.5190, 31.6426]])
ground truth
tensor([[ 0.0000,  0.0000,  6.8254, 12.4540,  9.9684,  8.8506, 28.3272],
        [15.5181, 29.1294, 36.2309, 34.1662, 19.3056, 11.2572, 20.3709],
        [33.9427, 46.4994, 40.8163, 18.5516, 11.4229, 15.6746, 22.7183],
        [12.0181, 12.7834, 20.0539, 16.4824, 21.5986, 27.0266, 32.1145],
        [13.3362, 11.5505,  9.7222, 17.3469, 17.2194, 18.7783,  3.0896],
        [22.9748, 1

batch_predictions
tensor([[22.3220, 22.4490, 21.2038, 21.6722, 24.9900, 32.9545, 43.3896],
        [22.7967, 29.9647, 26.8123, 13.0912, 16.0358, 17.2494, 18.9358],
        [10.9839, 10.8727, 11.4232, 17.3560, 26.9341, 24.3770,  8.1279],
        [10.2489, 12.2267, 13.9095, 18.8268, 14.4960,  7.3622,  9.8642],
        [ 8.6573,  9.0702, 11.0303, 13.8952, 20.1602, 19.0545,  7.0769],
        [37.8383, 25.0386, 24.3005, 24.3575, 26.3285, 35.8207, 48.7755],
        [ 9.6039, 10.7974,  9.9391,  9.5995, 10.6464, 17.9097, 23.2373],
        [12.5920, 16.7479, 22.4306, 20.8862,  8.8200, 10.9728, 12.9575]])
ground truth
tensor([[50.8503, 62.2024, 30.0595, 36.2103, 35.9127, 40.9864, 57.5680],
        [23.6678, 25.0425, 32.2421, 12.7693, 16.4541, 20.6207, 21.0601],
        [ 4.4713,  8.2194,  8.8243, 14.2031, 28.3535, 27.6170,  8.6928],
        [15.7286, 13.2430, 16.8332, 22.7249, 26.5255,  6.4177, 10.7180],
        [15.4787, 10.3104,  7.9958,  8.6928, 15.7023, 17.6618,  4.9448],
        [50.1701, 3

batch_predictions
tensor([[19.0029, 17.9389, 22.4246, 31.3947, 29.4729, 23.9581, 20.4889],
        [13.3298, 15.2312, 19.0049, 13.8466, 10.6480, 11.3392, 11.7517],
        [12.1320, 18.0252, 25.4104, 23.0351, 10.2156, 10.9277, 12.2207],
        [43.3955, 36.3322, 24.0998, 23.7067, 23.6760, 25.5343, 34.2186],
        [20.1365, 19.7690, 26.8859, 36.4887, 25.9904, 19.1842, 20.6242],
        [11.9423, 13.3991, 14.0147, 18.5053, 24.2444, 24.9189,  7.6284],
        [20.3538, 15.0070,  8.3489,  9.1593,  9.7626, 10.9826, 18.3266],
        [12.2583, 18.9273, 27.8231, 21.7870, 10.2290, 12.5151, 12.5144]])
ground truth
tensor([[18.0130, 16.4824, 16.9218, 45.1105, 25.7795,  0.0000,  9.5096],
        [13.4929, 12.6644, 21.8569, 14.4529,  9.5345, 12.7959,  8.8638],
        [15.0316, 27.5776, 22.8564, 17.5171, 12.3882, 16.0179, 14.1636],
        [56.7460, 50.1701, 32.6531, 27.1259, 22.0947, 30.4138, 45.9892],
        [30.2438, 40.3628, 45.9042, 60.7285, 50.8503, 62.2024, 30.0595],
        [12.8543, 1

batch_predictions
tensor([[15.4359, 15.2794, 13.7360, 14.4257, 18.3407, 27.7415, 29.6794],
        [17.0634, 18.3164, 25.3198, 35.9418, 27.6890, 13.6664, 18.3523],
        [29.6775, 10.7038, 15.6454, 14.9393, 15.3358, 21.5076, 33.0695],
        [20.5174, 20.7052, 21.6896, 30.8367, 49.5857, 46.2090, 23.6483],
        [12.1540, 12.3757, 12.0396, 11.9163, 12.6601, 19.1338, 22.0895],
        [20.5170, 12.9288, 13.0905, 13.2213, 13.3805, 21.7835, 21.8548],
        [16.1571, 18.9143, 25.6847, 25.5879, 14.5020, 15.3549, 16.3781],
        [13.9832, 13.6647, 17.1552, 22.1505, 24.2010, 10.8158, 13.9341]])
ground truth
tensor([[16.7375, 12.0465, 12.2732, 16.7659, 21.1593, 29.9320, 29.6344],
        [15.8469, 15.5181, 29.1294, 36.2309, 34.1662, 19.3056, 11.2572],
        [37.3441,  4.9887, 14.5692, 13.7188, 17.1769, 20.2239, 26.5023],
        [24.8948, 20.8969, 20.3577, 37.9537, 71.5413, 77.3146, 24.3293],
        [15.8075, 15.5445, 13.2167, 14.6370, 14.3872, 22.6854, 23.7507],
        [19.9546, 1

batch_predictions
tensor([[17.3819, 24.9124, 24.1125,  9.7069, 12.2877, 12.7340, 13.9937],
        [14.7676, 20.3771, 24.8174, 15.6238, 11.5524, 13.3118, 13.5284],
        [21.2304, 19.5726,  9.3610, 10.6563, 11.8709, 12.9716, 19.1943],
        [21.9794,  8.2248, 11.4440, 11.8496, 11.8301, 15.0254, 21.3957],
        [29.1992, 37.3394, 26.1884, 14.1160, 16.5739, 15.8795, 17.2421],
        [ 7.3714,  8.1998,  8.2742,  8.3055,  9.5474, 12.9042,  8.7753],
        [13.7072, 16.3001, 19.1093, 26.6498, 25.5250, 11.5973, 14.1712],
        [11.7058, 12.0337, 12.7205, 16.8974, 24.7126, 19.6500, 12.2941]])
ground truth
tensor([[16.7234, 27.4660, 30.2438,  9.8073, 13.1094,  9.5522, 16.1706],
        [15.2211, 14.9235, 18.6224, 24.6173, 13.3929, 12.7976, 12.6134],
        [26.8141, 26.7432, 17.0777, 12.3441, 15.3628, 13.0102, 21.1026],
        [25.6312,  4.9711, 12.5592, 12.3882, 13.9137, 18.9637, 13.3088],
        [30.0028, 48.9938, 35.8135, 15.8588, 15.7880, 14.3707, 16.4541],
        [ 7.6672,  

batch_predictions
tensor([[ 8.7448,  9.9516, 13.2400, 14.4639,  9.5440,  7.4885,  8.7676],
        [25.3760, 12.7559, 14.8241, 13.7819, 14.3760, 25.4937, 37.4928],
        [ 8.4528, 11.2554, 15.2403, 13.4528,  6.9378,  7.5142,  8.0502],
        [11.0592, 11.9211, 12.2442, 13.2962, 22.2654, 29.5173, 13.5635],
        [25.5075,  9.6819, 11.2552, 12.4095, 12.6109, 15.7006, 19.5248],
        [14.3181, 16.8664,  7.0759,  6.7530,  7.0197,  7.9783, 11.1774],
        [25.4141, 19.5556, 11.1758, 12.0425, 12.5312, 14.2500, 21.3009],
        [23.8735, 32.3244, 29.4694, 10.7719, 15.9972, 16.0514, 17.1760]])
ground truth
tensor([[ 8.1536, 10.9679, 10.5471, 10.9942, 10.2578,  5.4708, 10.7312],
        [36.1536, 14.2290, 10.3033, 16.8934, 16.6383, 33.7443, 45.7341],
        [ 7.5618, 13.4008, 14.4135, 15.1236,  4.4582,  9.6134,  8.8901],
        [13.7033, 11.3756, 13.8348, 17.9905, 33.9295, 32.1278,  5.6944],
        [26.6176, 10.6786, 12.5855,  9.4424, 11.6386, 17.9511, 24.5923],
        [13.3614, 1

batch_predictions
tensor([[16.8104,  9.8383, 10.6017, 10.7187, 12.9194, 19.4183, 23.9218],
        [13.1017, 19.1243, 24.1808, 19.7890, 10.8262, 12.6381, 12.2868],
        [31.2101, 15.1848, 16.2060, 15.6787, 16.4704, 25.0804, 39.7908],
        [18.1046, 18.0419, 24.3924, 34.0244, 26.4180, 16.0064, 18.3997],
        [28.5631, 42.8081, 29.9085, 15.5910, 17.1372, 15.9414, 17.3783],
        [18.7146, 10.9880, 11.6171, 11.0803, 11.3935, 14.6029, 23.8115],
        [23.4847, 25.3832, 28.7507, 16.5514, 18.3468, 19.2411, 21.5364],
        [18.2509, 23.6810, 23.7235, 10.4086, 12.9768, 13.8948, 14.8635]])
ground truth
tensor([[ 8.6270,  9.3372,  9.8764,  9.9158, 16.8727, 24.6712, 26.1573],
        [17.0437,  6.6938, 28.2614, 26.0521, 13.5060, 12.0068, 14.6107],
        [39.7159, 19.1084, 15.7812,  9.4292, 20.5155, 30.1683, 42.8985],
        [17.6223, 23.5008, 30.8259, 47.2777, 39.3214, 25.0526, 19.8448],
        [35.5584, 55.0595, 37.2166, 10.4025, 12.1173, 19.8696, 18.7075],
        [19.5011, 1

tensor([[17.0483, 24.1576, 23.8193, 10.1465, 12.6019, 13.6282, 14.1074],
        [25.1679, 26.6178, 21.3006, 12.5989, 13.7249, 13.7131, 14.8277],
        [15.1238, 15.7642, 13.8199, 15.1064, 25.2236, 38.7569, 29.9042],
        [14.5923, 18.6362, 23.2066, 22.3841, 11.0753, 11.1683, 12.6502],
        [10.5904, 10.9606, 13.6163, 18.1037, 18.6850, 10.0149, 11.0663],
        [12.6881, 12.6523, 13.1768, 19.9147, 28.8218, 27.4839, 12.2868],
        [14.9402, 18.5395, 26.1372, 14.5189,  8.6669, 11.7762, 12.7168],
        [11.6417, 11.9823, 15.1058, 21.9573, 24.3381,  9.1466, 12.9101]])
ground truth
tensor([[24.3764, 23.4836, 27.1684,  7.9365, 10.6151, 14.6542, 13.0952],
        [21.5136, 29.9603,  8.7727, 10.9410, 17.0068, 16.3690, 21.0743],
        [17.2761, 15.7738, 13.7472, 18.0414, 27.9478, 48.8804, 29.9320],
        [16.7938, 18.1483, 19.8448, 22.9879, 16.2546, 13.9795, 11.2046],
        [10.9285, 10.8627, 12.7696, 18.8585,  0.0000,  0.0000,  5.7996],
        [15.4053, 16.1848, 13.4070, 2

batch_predictions
tensor([[13.1729, 21.9812, 21.3967,  8.7274,  9.5590, 10.4551, 11.1820],
        [18.8891, 32.3895, 31.4519, 17.9674, 17.7741, 17.0793, 15.3906],
        [12.6047, 13.1532, 12.0087, 12.7802, 17.2869, 20.7478, 22.0198],
        [39.9342, 23.7952, 13.7034, 15.5758, 14.7167, 17.1407, 28.7578],
        [10.6348, 14.1640, 12.9631, 12.5943, 16.9511, 24.9817, 27.8472],
        [13.4694, 13.9366, 12.6203, 17.5282, 24.4834, 25.8961, 12.6965],
        [ 6.9614,  7.3420,  8.6106,  7.4266,  7.1907,  8.2951,  8.8780],
        [13.8134, 14.2237, 14.4846, 23.7467, 38.1340, 30.0222, 15.7493]])
ground truth
tensor([[13.0589, 20.9890, 22.0673,  6.0757, 10.8890, 11.7833, 10.6654],
        [31.2075, 31.5618, 41.6241, 23.9654, 26.8424, 18.0130, 16.4824],
        [ 9.6514, 10.6434, 10.3175,  9.5805, 13.5062, 23.7245,  3.1746],
        [38.0952, 26.9700, 11.0402, 17.1910,  9.0703, 17.8430, 32.3554],
        [15.1894,  6.0889, 15.0710, 13.6902, 17.2935, 23.4219, 31.1547],
        [13.6763, 1

batch_predictions
tensor([[13.3638, 15.0695, 20.6855, 32.5250, 30.3012, 10.4289, 16.8283],
        [ 9.5121,  9.5791,  9.4418,  9.6118, 10.3459, 16.6815, 21.8484],
        [14.9466, 14.3498, 16.5264, 28.9036, 37.4860, 24.9876, 15.1713],
        [24.3980, 37.5021, 27.4906, 14.9385, 15.7906, 14.5202, 15.7860],
        [19.3718, 26.7886, 24.0080, 12.5300, 12.2657, 13.1420, 13.9517],
        [ 9.4720, 12.1934, 14.1179, 15.1935, 18.3085, 19.8305, 13.7207],
        [13.3639, 12.8394, 14.8923, 21.9655, 23.2706, 14.5850, 14.8312],
        [16.3178, 22.0981, 32.4438, 25.5339, 14.9161, 16.5288, 15.3026]])
ground truth
tensor([[15.0794, 18.3957, 22.5198, 38.8605, 39.4558,  7.8231, 17.6020],
        [ 6.2467,  8.0352,  9.3109, 10.2183, 11.9674, 23.0274, 28.0773],
        [17.1910,  9.0703, 17.8430, 32.3554, 39.2007, 29.9745, 13.4070],
        [30.0028, 50.0850, 34.5663, 29.7052, 22.5907, 14.9943, 16.0714],
        [21.2018, 25.7228, 21.8254, 13.8747, 14.8384, 11.1536, 16.6383],
        [11.4940, 1

batch_predictions
tensor([[10.0123,  7.5176,  7.0531,  7.8120,  8.1020,  8.3827,  9.3791],
        [12.8946, 13.6841, 14.6451, 20.6235, 32.9200, 30.6114, 12.6960],
        [10.1738, 10.1350,  9.7437, 10.1139, 14.5588, 21.3185, 17.0452],
        [ 9.0466,  9.0010,  6.9947,  6.0935,  5.9658,  6.5234,  7.0698],
        [16.0510, 22.3930, 18.0090, 10.8735, 10.2536, 10.7383, 13.5798],
        [19.8579, 26.4290, 17.7762,  9.6610, 10.3074, 10.3195, 11.3910],
        [ 5.9255,  6.0635,  6.1712,  6.5890,  7.7139,  8.5354,  6.7554],
        [16.6813, 17.8582, 31.2369, 44.5109, 35.6089, 13.4982, 20.6528]])
ground truth
tensor([[12.3583, 14.1015, 11.1536, 10.6009, 11.3095,  9.9915, 10.2466],
        [11.5597, 10.9153, 14.0978, 33.6139, 27.2357, 31.7464, 11.8227],
        [16.8332, 13.2956, 21.3966, 12.8880, 12.8617, 22.3172, 15.5445],
        [10.6151, 11.5363,  7.5113,  4.4501,  4.1100,  5.4705,  5.2579],
        [12.4014, 23.3693, 21.5939, 11.2441,  9.8106,  9.3109, 14.2031],
        [19.4444, 3

batch_predictions
tensor([[ 7.2434,  7.6382,  7.6600,  8.2200, 10.4335, 15.4779, 19.3822],
        [15.5934, 14.9835, 15.0848, 20.0856, 31.6870, 24.9804, 17.6547],
        [ 6.7766,  6.5492,  7.1812,  8.3599,  7.6029,  7.2440,  8.3413],
        [18.1005, 20.9031, 16.3508,  9.3729, 10.0027, 12.5065, 14.4101],
        [22.8159, 14.0224, 14.7846, 14.0886, 14.5597, 21.4773, 29.4660],
        [23.3074, 21.5344, 19.8703, 19.2889, 23.1896, 27.2645, 34.9363],
        [30.2787, 18.7080, 17.6304, 18.9457, 19.8724, 20.9490, 24.4689],
        [17.6250, 27.1288, 29.8084, 13.2232, 15.4225, 15.9951, 15.3945]])
ground truth
tensor([[ 6.2467,  5.5497,  7.3382,  6.9174,  9.9816, 13.3614, 14.2425],
        [17.5737, 17.6304, 15.7880, 21.6695, 37.4291, 28.3447, 19.4444],
        [ 8.6451,  6.3634,  6.9728, 11.7914,  7.3271,  6.2217,  9.5380],
        [18.5692, 26.5781, 24.0531,  8.0615,  8.0747, 15.8338, 18.5692],
        [11.1678, 12.0890, 14.0306, 15.9439, 17.2761, 25.1984, 33.8152],
        [29.3924, 2

batch_predictions
tensor([[ 8.8290, 10.1119, 11.5328, 16.5341, 21.5291,  9.6534,  8.0622],
        [18.1141, 11.9310, 12.1142, 12.0603, 13.1827, 19.7682, 25.8403],
        [12.6000, 13.3028, 13.9988, 21.7095, 33.7551, 27.2785, 12.0342],
        [12.9619, 14.4201, 13.6833, 13.8648, 16.8549, 25.2706, 31.3360],
        [21.4441, 11.8700, 11.8967, 11.2860, 12.0683, 20.4705, 30.7400],
        [10.2754, 14.2333, 13.2378, 15.0213, 24.1985, 34.3819, 20.0270],
        [27.3133, 19.4860, 11.1159, 11.2550, 12.0541, 13.3236, 21.8457],
        [12.5842, 12.9956, 14.6284, 19.6258, 27.7386, 24.7999, 13.4740]])
ground truth
tensor([[10.0210,  7.6539, 11.8885, 13.9269, 18.7796,  8.9427,  9.0084],
        [27.2959, 15.3203, 17.1344, 14.7676, 15.2211, 14.9235, 18.6224],
        [15.2946, 11.7833, 13.4008, 27.9590, 41.6228, 26.7622, 13.3351],
        [13.1904, 15.5971, 15.5050, 13.4797, 17.2804, 27.9721, 26.7228],
        [25.4535, 11.0544,  9.0703,  9.3963, 12.9252, 25.6378, 36.9756],
        [ 5.6973,  

batch_predictions
tensor([[15.5703, 15.3620, 20.0042, 20.5362, 11.7910, 12.7222, 14.7359],
        [10.6965, 10.7951, 11.9973, 16.1136, 19.7441, 17.3054, 11.2634],
        [20.3791, 24.1801, 31.2523, 26.9164, 18.0126, 20.6861, 20.8823],
        [11.3786, 17.5111, 25.5196, 22.8429,  7.7058, 10.2813, 11.7396],
        [18.9248, 28.5812, 25.7279, 11.1639, 13.6221, 13.8585, 14.1324],
        [14.8362, 14.9030, 17.5131, 22.2321, 15.6959, 11.7119, 14.7089],
        [19.0960,  9.3427, 10.7837, 12.9711, 13.0491, 14.3391, 18.0999],
        [ 7.7040,  7.9514,  8.0320,  8.5398, 11.0370, 17.0724, 14.6743]])
ground truth
tensor([[19.5949, 24.0400, 19.2004, 17.1883,  8.9295, 13.5455, 12.2962],
        [11.3624,  9.6660, 10.8759, 11.8622, 25.1447,  5.2341,  8.2720],
        [25.9206, 25.8285, 36.6649, 37.9932, 21.9227, 29.5634, 20.7654],
        [ 8.8243, 14.2031, 28.3535, 27.6170,  8.6928, 11.9279,  9.9816],
        [22.2931, 33.2908, 14.5833,  0.0000, 11.8197, 11.7489, 13.1519],
        [16.8201, 1

batch_predictions
tensor([[ 9.9367, 11.3642, 10.7859, 11.1760, 16.2407, 22.6162, 22.4167],
        [16.7358, 21.5094, 29.2544, 25.6944, 14.0447, 15.7130, 15.6303],
        [11.2672, 17.4599, 18.6029,  8.0932,  9.3192, 10.4182, 10.1633],
        [11.1848,  6.5862,  6.3721,  6.8580,  7.0693,  8.2445, 12.4852],
        [22.1747, 18.8286, 10.1639, 10.4187, 10.8057, 11.2882, 14.8292],
        [15.8530, 21.4137, 30.3353, 27.5057,  9.6274, 15.4429, 15.7969],
        [10.9976, 11.5456, 10.5892, 10.6842, 14.1836, 22.0861, 23.8150],
        [26.2032, 36.4536, 24.6884, 15.0281, 15.4498, 15.2950, 16.7732]])
ground truth
tensor([[ 6.2217, 10.9836, 12.9535, 10.0482, 17.1485, 20.3515, 25.9495],
        [21.2585, 22.5482, 28.3163, 32.4688, 16.5391, 17.0210, 15.3770],
        [ 8.9953, 11.1520,  0.0000,  2.8932,  9.1794,  8.4298,  9.1005],
        [15.8732,  5.2209,  6.6544,  6.6675,  6.4440, 10.5865, 13.0326],
        [27.9327, 24.2109, 14.0452, 12.2830, 10.4156, 10.0210, 16.5045],
        [17.1769, 2

batch_predictions
tensor([[11.7424, 12.4424, 11.4011, 13.1515, 17.1508, 25.8126, 24.1859],
        [21.3858,  7.1835, 10.0077, 10.3410, 10.3385, 14.6421, 23.3690],
        [13.1356, 13.1386, 14.8030, 19.8888, 29.7232, 26.9435, 11.6533],
        [ 9.5012,  9.6891, 10.1244, 13.7633, 20.3803, 16.0153,  9.5195],
        [15.0887, 15.6929, 20.0877, 29.4948, 26.9321, 15.9291, 17.7817],
        [27.8657, 20.9706, 10.9974, 11.1448, 10.6206, 11.4947, 19.2881],
        [23.9953, 21.4470, 11.7200, 12.7134, 12.2053, 12.0273, 18.4167],
        [10.9629, 11.2700, 10.6377, 12.9605, 16.9110, 23.1021, 20.2495]])
ground truth
tensor([[11.6922, 13.5488,  9.7931, 17.0777, 13.4779, 32.8656, 22.9592],
        [ 3.7612,  5.4708, 12.5986, 10.8890, 13.5192, 21.0810, 32.7459],
        [16.9926, 13.5062, 15.5187, 19.7846, 30.8673, 30.9382, 10.7710],
        [14.3214, 15.8732, 12.7696, 17.6486, 25.8154, 25.4866,  8.0089],
        [17.2902, 16.2698, 17.9705, 30.5556, 28.2455, 22.8883, 18.5658],
        [31.0941, 2

batch_predictions
tensor([[18.0901, 28.6116, 38.3412, 24.4507, 15.0309, 16.6603, 16.4245],
        [11.2323, 11.8758, 15.8497, 22.1084, 23.3107, 12.7253, 13.0754],
        [39.1494, 22.4785, 20.5884, 20.8286, 21.5526, 24.7734, 32.9222],
        [15.0094, 14.1421, 13.9746, 13.2876, 14.3225, 23.3092, 34.1559],
        [ 7.2845,  7.9578,  8.1416,  9.2617, 11.8119,  9.7201,  7.0756],
        [23.9655, 15.8582, 14.5226, 13.5784, 13.1055, 16.4205, 20.2837],
        [33.5520, 20.0236, 19.0725, 19.0003, 20.3422, 24.9902, 32.2657],
        [12.5219, 12.8330, 12.9730, 15.1514, 22.1555, 27.3697, 14.3915]])
ground truth
tensor([[21.8112, 25.6094, 35.1474, 29.9461, 15.8305, 17.9989, 16.4966],
        [11.8753,  0.0000,  6.4045, 24.9737, 27.1173, 24.8816, 11.5203],
        [60.7285, 50.8503, 62.2024, 30.0595, 36.2103, 35.9127, 40.9864],
        [41.6525, 22.2222, 10.4875, 16.7375, 18.6224, 22.6616, 31.2783],
        [ 9.0845,  7.0862,  9.2545, 11.0261, 12.6984,  7.8798,  5.8248],
        [21.1862, 1

batch_predictions
tensor([[15.8233, 15.3751, 14.5797, 19.4331, 39.7540, 34.7627, 15.2705],
        [15.7769, 15.5833, 17.3458, 24.3624, 33.8523, 27.3258, 16.1000],
        [20.3837,  9.3199, 11.0987, 10.5781, 10.6462, 13.0438, 22.9973],
        [13.2160, 12.5173, 11.4336, 11.4918, 15.2045, 23.7782, 25.2266],
        [16.3026, 23.8593, 22.0694, 11.4489, 10.7109, 11.1271, 11.3013],
        [12.5989, 16.2687, 15.5185,  7.8757,  8.9568,  8.9364, 10.3052],
        [13.0154, 15.3289, 22.3223, 27.9064, 16.6253, 12.5511, 13.1588],
        [ 7.1168,  6.5067,  6.0865,  5.9086,  6.1404,  6.3993,  7.1333]])
ground truth
tensor([[10.7710, 16.6100, 21.5420, 36.0261, 36.7914, 25.6094, 18.4382],
        [17.4461, 17.0351, 21.4853, 20.9892, 31.9586, 30.6406, 13.6480],
        [22.5340, 10.7568, 16.4824, 11.3237, 10.7710, 15.0652, 18.3532],
        [ 9.0703, 12.1882, 11.7914, 10.9269, 17.3611, 39.1582, 10.0765],
        [17.8985, 25.9469,  7.2856,  8.1536, 12.6775, 11.4150, 21.5018],
        [12.3751, 1

batch_predictions
tensor([[19.3150, 12.6939, 12.8972, 13.0643, 13.8871, 22.5410, 27.4921],
        [19.6322, 29.7531, 27.7650, 10.0112, 15.5048, 14.8008, 15.1179],
        [13.1151, 13.5954, 12.2307, 12.8992, 16.6383, 24.4884, 26.6507],
        [26.6132, 13.4670, 12.6658, 13.3296, 12.9505, 14.1509, 22.3508],
        [13.2460, 12.7197, 13.5424, 19.9040, 31.2458, 30.1736,  8.3837],
        [15.8702, 16.5985, 15.2794, 17.0628, 21.7259, 33.6687, 30.3267],
        [29.1217, 25.8633, 11.8172, 13.2433, 13.2398, 13.4201, 20.0584],
        [11.8519,  9.9721, 10.7088, 11.3811, 13.2995, 20.2972, 22.3748]])
ground truth
tensor([[23.7103, 13.3929, 14.8384, 11.9331, 12.0890, 23.2851, 24.2489],
        [19.2602, 29.4076, 36.0969,  6.7460, 14.0590, 14.7251, 18.8067],
        [ 9.0216, 14.1241,  8.5744, 14.3083, 17.3461, 26.0521, 24.7107],
        [26.1338, 27.2392, 13.2370, 14.5975, 13.1236, 17.0210, 22.6616],
        [14.0590, 14.7251, 18.8067, 21.1876, 37.0040, 34.9348,  8.5317],
        [13.6480, 1

batch_predictions
tensor([[11.8574, 12.0635, 12.9623, 17.3685, 25.3318, 24.9339, 12.8050],
        [11.6679, 10.6783, 11.1553, 14.2061, 22.3532, 20.9701, 12.7590],
        [ 8.1027,  6.1461,  5.7296,  6.0302,  6.3497,  6.5974,  8.0897],
        [18.0554, 18.5791, 25.5196, 36.8383, 30.3808, 17.7428, 20.3981],
        [22.3677, 28.5073, 35.1775, 19.9662, 19.7380, 21.0597, 21.8706],
        [34.8578, 43.7437, 34.4950, 22.4640, 22.5261, 21.7300, 23.3072],
        [28.7419, 13.8754, 12.9307, 12.9057, 13.0536, 13.4906, 22.0061],
        [23.7224, 13.1690, 13.1355, 13.2307, 13.7418, 16.3970, 22.8628]])
ground truth
tensor([[ 9.4671, 15.4620, 17.8997, 20.6774, 33.0499, 28.0896, 13.1803],
        [12.4433, 12.1882, 11.9189, 14.9802, 28.7132, 21.2585, 12.2024],
        [ 8.8152,  4.6202,  3.7557,  3.3305,  4.6344,  5.9240,  8.4042],
        [17.1489, 21.3703, 33.8901, 52.8669, 33.0221, 19.1215, 20.2656],
        [29.5634, 34.5739, 45.1604, 18.3456, 21.1073, 22.9616, 22.3304],
        [39.1298, 5

batch_predictions
tensor([[26.1472, 17.0441, 16.4065, 14.1952, 14.4912, 17.7086, 29.5340],
        [19.7662, 19.8331, 18.9838, 22.1827, 32.0904, 43.2038, 34.2011],
        [16.3983, 18.0313, 29.2394, 40.4457, 30.8384, 15.7973, 19.4875],
        [15.9677, 20.0452, 16.6551,  6.9173,  7.7979,  8.8232, 10.9771],
        [22.1070, 25.6643, 20.4230, 10.2899, 12.0987, 13.1823, 15.1667],
        [10.5085, 14.3074, 14.4520,  7.6787,  7.9945,  8.2387,  9.0595],
        [13.5815, 12.8393, 12.6662, 12.5467, 13.4672, 17.8585, 23.9746],
        [23.6459, 16.8965, 16.3506, 16.3175, 21.5943, 29.4561, 37.7460]])
ground truth
tensor([[29.2092, 20.3231, 18.0272, 11.4087, 14.5550, 22.9450, 30.3571],
        [23.2285, 22.8741, 17.8146, 16.9501, 25.2693, 41.2273, 33.3900],
        [15.5471, 20.1531, 35.5584, 55.0595, 37.2166, 10.4025, 12.1173],
        [17.0437, 17.6618, 15.5839,  2.9327,  3.9190,  8.2062,  8.7454],
        [17.1202, 26.4598,  7.0153,  8.1207, 14.0731, 10.9836, 18.5658],
        [13.4666, 2

batch_predictions
tensor([[23.4778, 29.3126, 11.4200, 16.4304, 16.3180, 16.6837, 21.8090],
        [19.9592, 27.0424, 34.5252, 22.9570, 15.4022, 17.6178, 16.6813],
        [ 6.4100,  6.0097,  6.0285,  6.4527,  7.2233,  9.1987,  9.8519],
        [11.3252, 12.4826, 13.7433, 19.8060, 29.2931, 32.6844, 10.9245],
        [13.7968,  7.8083,  7.4797,  7.5551,  7.8788,  9.6793, 16.2413],
        [20.9222, 20.0841, 18.1130, 17.2510, 20.5349, 38.8824, 42.1793],
        [22.0128, 14.3042, 14.9803, 14.7803, 16.2920, 26.4159, 31.3047],
        [15.2804,  7.3984,  7.5040,  7.5526,  8.1670, 10.5990, 16.4790]])
ground truth
tensor([[36.7772, 40.4195, 10.4734, 20.7766, 15.1077, 18.2965, 23.2143],
        [28.2455, 32.3696, 46.3152, 37.1457, 30.0595, 18.8492, 22.2789],
        [ 6.2217,  3.8124,  5.5130,  6.6752,  4.5068,  8.9002,  7.8231],
        [15.4655, 10.8890, 11.4413, 19.9106, 22.3567, 33.6139, 13.2956],
        [15.9916,  7.5092,  8.5087,  8.2325,  6.8648, 11.1783, 13.5718],
        [27.6644, 2

batch_predictions
tensor([[37.8561, 24.0059, 23.4953, 23.3519, 27.3917, 38.6655, 50.7597],
        [45.5867, 38.9645, 18.4988, 19.7194, 19.1913, 18.8278, 28.2096],
        [28.7373, 26.2966, 14.6259, 14.3615, 14.6880, 13.7691, 22.2489],
        [18.0820,  7.6868,  9.1101,  9.8229, 10.0402, 11.7383, 20.6696],
        [ 8.2196, 10.6293, 14.5059, 17.9030, 15.0295,  6.1803,  7.0356],
        [14.7691,  9.6142,  6.7535,  7.7052,  7.8986,  8.3298, 11.0624],
        [33.0526, 13.0401, 15.3020, 14.7815, 15.4136, 21.4620, 33.2061],
        [ 9.8927, 11.6515, 12.2684,  9.3724,  7.4961,  8.1917,  9.0347]])
ground truth
tensor([[42.2194, 30.9382, 25.8645, 21.2302, 29.7336, 42.8713, 62.4717],
        [55.7075, 45.6207, 29.8527, 28.1694, 20.3314, 23.0800, 29.5371],
        [54.3934, 45.8050, 17.1202, 12.8543, 13.9172, 18.9909, 31.3634],
        [22.4487, 10.0999, 11.8096,  9.2451,  8.7454, 11.7701, 17.8064],
        [ 8.2062,  8.7454, 14.1110, 21.0284, 18.0037,  1.7622,  6.8648],
        [14.2425, 1

batch_predictions
tensor([[12.2160, 15.1993, 12.9135, 13.9930, 25.3220, 36.3818, 29.5518],
        [19.4718, 23.9919, 12.8039,  9.5280, 12.3861, 13.6028, 15.6833],
        [20.2262, 13.2968, 13.4205, 13.3122, 14.6804, 19.5626, 19.8042],
        [15.0341, 20.9714, 26.8444, 12.8517,  9.4909, 12.6912, 13.5217],
        [ 9.3899, 12.2115, 12.1346, 12.5601, 14.8622, 24.0134, 27.5029],
        [20.3615, 24.4568, 18.3686, 10.0544, 11.3892, 11.8552, 13.0252],
        [20.4234, 17.1578, 10.8010, 11.2711, 11.5322, 13.3069, 16.9963],
        [17.6422, 22.0732, 20.9982,  8.4790, 10.3970, 11.9677, 12.5324]])
ground truth
tensor([[ 9.9773, 11.6497, 11.0544, 15.4337, 23.5969, 37.2591, 28.4297],
        [16.9926, 25.2551, 26.4314,  2.2959, 13.6054, 13.4779, 12.3724],
        [32.1278, 22.5013, 23.4482, 20.7522, 14.1636, 14.4661, 24.8685],
        [12.4717, 15.3345, 28.6706, 24.9575,  9.4104, 12.3158,  9.6088],
        [ 7.4303,  9.9553, 12.6775, 13.8085, 15.9784, 28.5771, 28.0905],
        [21.2302, 2

batch_predictions
tensor([[ 8.1815,  8.7688,  9.5788, 10.7322,  8.0792,  6.8631,  7.9687],
        [20.6005, 11.7444, 13.6770, 12.5401, 13.3347, 21.7049, 30.5237],
        [16.5657, 25.3488, 25.0437, 16.8765,  9.8673, 11.4946, 12.4595],
        [24.4493, 10.6648, 12.2408, 11.7308, 12.2647, 18.3394, 26.3085],
        [13.6196, 15.6445, 22.1905, 31.2198, 23.2712, 10.5948, 14.0271],
        [37.5206, 15.6907, 18.2579, 17.7678, 18.2458, 30.1544, 45.3984],
        [10.7803, 11.6531, 15.1358, 21.6497, 12.2752,  8.1736, 10.3641],
        [17.0808, 17.1244, 16.4579, 17.7822, 26.5217, 41.2132, 33.3060]])
ground truth
tensor([[ 9.5663, 10.7426,  9.1128, 18.0556,  7.2421,  5.7540,  6.0516],
        [33.3759, 14.7534, 24.1497, 13.9172, 14.4841, 26.1621, 27.1967],
        [14.5975, 19.8554, 30.8957,  8.9569, 12.1740, 12.7693, 13.0244],
        [27.2251, 11.5221, 12.5283, 15.4053, 13.7755, 15.8872, 34.3821],
        [ 0.0000,  0.0000, 11.0337, 27.4987, 28.1168,  9.3766,  8.6796],
        [23.2993, 1

batch_predictions
tensor([[12.8908, 15.2334, 20.2116, 23.2884, 23.9446,  7.3404, 12.5033],
        [40.4334, 32.7860, 15.1509, 17.2716, 17.5248, 17.7645, 24.6863],
        [12.2771, 15.1251, 12.9325, 14.2232, 22.2478, 32.4496, 27.6309],
        [39.4595, 33.9333, 19.7253, 20.0698, 20.0006, 22.8122, 36.6679],
        [11.9436, 16.3278, 14.1482, 14.2269, 21.7529, 34.2790, 30.8205],
        [14.1061, 14.0073, 15.0848, 23.9660, 35.7357, 31.1624, 11.5980],
        [42.6827, 39.6653, 37.5798, 35.5160, 37.5055, 40.4949, 45.2070],
        [12.9031, 11.5029, 10.8222, 10.9354, 15.3859, 24.1043, 24.1110]])
ground truth
tensor([[17.6355, 15.8206, 15.5050, 19.8974, 31.9306,  4.0505, 14.9264],
        [36.7914, 25.6094, 18.4382, 21.9388, 10.5867, 15.7738, 22.6049],
        [12.1740, 21.8679, 13.0952, 18.5374, 20.7058, 29.7477, 27.2959],
        [43.2272, 38.2956, 22.6328, 25.6839, 18.3456, 25.4603, 35.7443],
        [16.0705, 15.2946, 11.7833, 13.4008, 27.9590, 41.6228, 26.7622],
        [19.0901, 1

batch_predictions
tensor([[13.9195, 14.2738, 14.2844, 15.2409, 19.4462, 27.7604, 25.6844],
        [11.8229, 12.2384, 12.9411, 17.2868, 21.4785, 13.3219, 10.5275],
        [26.8946, 35.9649, 44.7347, 33.6430, 22.5664, 23.3996, 22.2655],
        [27.4537, 29.5711, 33.7072, 40.2730, 44.0565, 36.1034, 28.0407],
        [11.8535, 12.6732, 15.6981, 25.1319, 30.1577, 13.1547, 14.0855],
        [25.5882, 23.4153, 12.5442, 12.9006, 13.7697, 15.2761, 19.5272],
        [13.6108, 16.6838, 15.1736, 15.3939, 18.6358, 31.7051, 39.7741],
        [14.9727,  7.5325,  6.8986,  7.2953,  8.4676,  9.9588, 15.1580]])
ground truth
tensor([[10.0765, 13.6196, 13.5346, 17.0493, 26.4456, 26.6582, 28.9966],
        [12.9669, 14.9790, 14.4792, 17.7407, 20.6076, 26.1573, 13.8743],
        [34.0703, 41.2132, 57.7523, 42.2194, 30.9382, 25.8645, 21.2302],
        [16.6100, 18.3248, 19.6003, 27.6644,  6.6043, 17.3044, 11.5646],
        [11.5646,  6.4059, 14.7959, 32.5964, 38.9314, 30.4989, 13.7755],
        [25.8503, 2

batch_predictions
tensor([[35.5868, 28.5408, 13.2500, 13.9528, 14.3125, 14.6842, 22.0430],
        [10.0911, 10.5710, 10.4795, 10.1474, 12.2696, 14.8027, 17.2966],
        [16.8678, 21.1111, 28.6125, 24.1424,  9.1239, 13.6732, 13.8544],
        [14.0161, 13.6901, 16.4719, 24.3489, 23.8950, 12.8802, 15.1382],
        [31.1479, 12.9231, 14.6732, 13.5787, 13.5396, 20.5370, 30.3125],
        [29.2463, 16.0414, 16.7184, 16.4129, 19.7153, 32.5948, 40.1214],
        [19.5702, 15.2818, 13.9705, 13.8623, 16.8626, 24.9760, 31.6691],
        [11.8505, 14.0837, 13.2044, 16.6177, 24.9775, 27.6894, 21.9061]])
ground truth
tensor([[43.0414, 33.1207, 16.6950, 13.6480, 14.1298, 15.0935, 25.2834],
        [ 5.5556,  7.9790,  6.8736,  6.4059, 10.3600, 11.5646, 16.4399],
        [16.4116, 14.7392, 30.4563, 31.5476, 11.6922, 17.5170, 15.1927],
        [16.2152, 16.6623, 19.3714, 27.9721, 27.1568, 14.9790, 16.1231],
        [37.0200, 14.1504,  8.9032, 16.5176, 18.2667, 23.1720, 32.2988],
        [33.0782, 1

batch_predictions
tensor([[28.1324, 14.2644, 13.3674, 12.3632, 12.4201, 14.4326, 23.3003],
        [18.5486, 10.8672, 11.0939, 11.0283, 13.7389, 14.2456, 17.4199],
        [19.8891, 17.9423,  9.5180,  9.8426,  9.9383, 10.3923, 14.7722],
        [20.6069, 11.2554, 12.4986, 12.5365, 15.4588, 25.0909, 28.1202],
        [25.4897, 15.4529, 14.8206, 14.0188, 14.5384, 20.9925, 33.5506],
        [17.4751, 20.1389, 21.0235, 13.9474, 11.4115, 12.3243, 13.1332],
        [23.5482, 11.0672, 10.7765, 11.0236, 11.3115, 13.5049, 22.6764],
        [22.0983, 38.0507, 27.1172, 15.6815, 17.4559, 15.6473, 13.9656]])
ground truth
tensor([[25.2976, 12.3158, 12.0465, 13.0385, 13.4495, 20.5924, 25.3968],
        [25.3118,  7.7098,  9.3396, 12.4858, 18.4240, 16.6808, 31.0516],
        [26.0521, 23.8559,  8.0747, 11.0600, 11.0863, 16.2152, 14.0715],
        [25.9212,  7.1712,  7.7806, 11.6355, 14.5975, 19.8554, 30.8957],
        [23.2851, 20.7200, 17.4887, 15.0652, 15.0794, 20.1389, 36.1253],
        [14.1723, 1

batch_predictions
tensor([[12.6066, 16.9203, 21.7206, 21.4718, 12.2992, 11.1929, 12.1308],
        [ 8.2978,  7.5952,  8.8079,  9.0334, 10.2856, 14.7139, 20.1096],
        [17.8063, 24.2535, 33.4213, 31.3473, 11.3143, 17.0691, 16.8580],
        [ 8.7155,  8.9852, 10.2411, 15.3095, 22.3091, 19.3118,  7.4019],
        [10.9986, 10.3777, 11.3373, 16.9409, 24.3309, 23.3905, 10.7398],
        [17.1710, 28.3737, 27.9103, 13.8195, 15.6504, 16.2966, 17.1595],
        [13.3390, 12.2516, 12.4790, 15.4457, 22.1746, 27.7077, 11.5167],
        [ 9.3310,  9.4348,  6.4615,  6.0369,  6.5699,  7.3219,  8.0512]])
ground truth
tensor([[10.6654, 16.4387, 23.0274, 25.0000, 16.7280, 13.3219, 12.2962],
        [ 6.6675,  2.8538, 16.2283, 11.1389, 13.6376, 24.8948, 24.2635],
        [16.9785, 25.0000, 29.3651,  4.2234,  0.0000, 15.7455, 17.9138],
        [ 7.3645,  7.0752,  9.9553, 15.0973, 22.7117, 19.4503,  5.2604],
        [ 8.6928,  9.4161,  8.6007, 18.6349, 27.5250, 26.7359, 12.6907],
        [23.2851, 3

batch_predictions
tensor([[11.3630, 13.5663, 11.9038, 13.3246, 21.7607, 30.4975, 22.9552],
        [15.9153, 14.9464, 14.9438, 20.0411, 31.7550, 27.9388, 17.0634],
        [14.9354, 14.7387, 22.7460, 35.8889, 22.1678, 16.0008, 17.4163],
        [ 7.2509,  9.5507,  9.7696,  9.5002, 10.9400, 17.3698, 23.1156],
        [14.5507, 14.1850, 14.1307, 18.0374, 34.8940, 40.3474, 21.9115],
        [12.5000, 13.4032, 14.2892, 17.8448, 26.1964, 27.5017, 15.9091],
        [11.9307, 12.1289, 10.6829, 11.0937, 19.2711, 31.1722, 24.6929],
        [13.5783, 17.5958, 24.9953, 23.4562, 11.8034, 12.1177, 13.1224]])
ground truth
tensor([[ 6.5901, 10.6293,  0.0000,  9.6088, 14.9802, 29.4076, 25.0142],
        [17.2194, 17.2902, 16.2698, 17.9705, 30.5556, 28.2455, 22.8883],
        [14.3707, 17.5454, 32.4405, 46.6270, 38.7472,  4.2942,  0.0000],
        [ 3.5245,  8.8243, 10.4024, 11.8359,  8.9821, 15.2946, 17.7012],
        [12.3724, 20.7200, 21.9529, 15.8588, 33.5176, 37.0465, 36.8197],
        [10.5584, 1

batch_predictions
tensor([[ 7.4565,  8.5114, 10.1545, 10.4420,  7.9705,  7.9587,  8.0252],
        [16.8425,  7.4046,  8.3878,  8.8876, 10.0650, 12.6997, 18.7846],
        [ 6.6015,  8.3300,  9.3707,  9.8064, 11.1618, 14.0947,  6.6010],
        [14.1521, 14.8537, 13.9747, 17.9047, 36.3604, 36.3526, 17.0040],
        [21.1773, 27.1672, 29.7016, 15.5056, 16.4751, 16.7750, 16.7188],
        [13.1937, 19.4539, 28.3846, 25.2899, 11.1226, 13.6796, 13.7919],
        [25.6537, 25.0685, 10.1984, 11.6868, 12.2839, 14.1148, 19.8550],
        [36.8378, 24.2682, 15.5893, 16.5517, 16.1021, 19.6383, 30.5001]])
ground truth
tensor([[ 7.1145,  7.7948, 12.3583,  9.8356,  8.4467,  7.7948,  6.2217],
        [18.2667,  8.3114, 11.3756, 12.2567, 11.2441, 15.0184, 24.9342],
        [ 1.9595, 11.0074,  9.3372,  9.6397, 10.0079, 13.0458,  7.7854],
        [10.5584, 16.7517, 15.4478, 22.1797, 44.8696, 36.2528,  9.3963],
        [17.8713, 19.7279, 26.0062, 15.8872, 17.7296, 17.6446, 15.1644],
        [14.2715, 2

batch_predictions
tensor([[25.4928, 15.3282, 11.7815, 11.8617, 11.6702, 12.7770, 19.8831],
        [21.5707, 20.9843, 10.7094, 10.9122, 10.8275, 12.1073, 16.3339],
        [13.7123, 14.0398, 14.0666, 18.2161, 28.1811, 31.0919, 13.6152],
        [23.1832, 32.2767, 28.0057, 13.2955, 15.8865, 15.6620, 15.5005],
        [ 7.2905,  7.0325,  6.9203,  7.1052,  7.8854, 10.1626, 11.8971],
        [17.6404, 13.2768, 12.5652, 12.1580, 14.6093, 29.9078, 38.7170],
        [14.1043, 14.1737,  8.1261,  7.4192,  7.7598,  8.4577, 11.1889],
        [11.9018, 12.1133, 11.8698, 13.7275, 20.6046, 22.3035, 12.5481]])
ground truth
tensor([[34.9773, 31.1083, 18.7783, 15.0652, 18.2398, 16.3549, 24.7307],
        [27.9904, 47.6332, 21.8112, 19.8413, 16.1706, 14.5550, 18.8350],
        [16.4966, 14.4416, 14.7817, 20.0539, 33.4892, 30.6973, 15.1502],
        [35.5867, 23.0442, 26.1480,  9.2404,  0.1417, 10.8135, 12.9252],
        [ 7.7523,  8.4184,  5.4563,  5.8815,  6.7035, 12.1032, 11.2670],
        [30.4989, 1

batch_predictions
tensor([[28.3348, 13.5029, 15.2369, 15.1721, 15.9761, 23.6982, 29.8416],
        [24.8075, 22.0922, 12.2444, 12.4239, 12.8031, 13.6430, 20.6321],
        [17.0975, 33.0882, 36.4065, 21.5312, 14.3737, 16.8789, 14.5192],
        [13.6754, 15.5308, 21.7260, 29.3938, 25.8876, 12.7432, 16.0756],
        [17.6511, 19.4882,  8.5249,  8.5269,  8.2156,  9.3083, 14.5934],
        [13.3826, 13.8694, 12.2754, 12.5875, 17.1390, 26.8566, 25.6671],
        [16.3397, 31.6544, 35.9489, 20.8322, 14.9277, 15.7459, 14.8982],
        [18.2809, 18.7045, 20.3194, 31.3528, 42.1522, 32.5470, 19.0136]])
ground truth
tensor([[26.4881, 17.5312, 17.0918, 17.5737, 15.4337, 23.0584, 31.0374],
        [25.2126, 19.8271, 11.8906, 14.5975, 17.1627, 16.6100, 18.0697],
        [23.5969, 37.2591, 42.5879, 28.4297, 14.6684, 20.6491, 19.7279],
        [11.8764, 16.1423, 33.3050, 38.7046, 31.8311, 13.8322, 11.7914],
        [ 1.5590,  0.0000,  0.0000, 12.9110, 14.2007,  8.1349, 14.1582],
        [14.5833, 1

batch_predictions
tensor([[18.8907, 20.0690, 10.2031,  8.2850,  9.2379, 10.3752, 11.8541],
        [12.2045, 17.1654, 19.7694,  7.1219,  8.4706, 10.3393, 11.2407],
        [24.0089, 23.7492, 12.9082, 13.0599, 13.3829, 14.9733, 18.6522],
        [ 7.5872,  8.6525, 10.5955, 14.7665, 17.8315,  7.3064,  8.9219],
        [16.1366, 26.9266, 37.1910, 29.6309, 12.8072, 17.0031, 15.5482],
        [16.8467, 17.9960, 21.3819, 31.9800, 43.9637, 36.1701, 13.7301],
        [21.0432, 25.1422, 30.3813, 30.5450, 18.9180, 19.6174, 19.9690],
        [26.2557, 12.5926, 14.1811, 13.3091, 13.1451, 19.5808, 30.1816]])
ground truth
tensor([[20.0026, 21.4492, 10.8233, 10.5734,  8.0221, 11.7570,  9.1662],
        [12.2304, 28.8927, 30.7207,  8.7322, 12.7433, 14.6107,  6.4308],
        [21.3309, 26.4335, 13.1510,  3.7612, 10.0605, 16.7938, 18.1483],
        [ 6.8594,  6.5901, 16.1990, 17.6587, 20.8192,  8.9286,  9.4388],
        [15.7313, 23.9229, 46.4286, 42.1910, 18.3815, 17.5454, 17.4036],
        [13.5771, 2

batch_predictions
tensor([[13.5234, 13.7851, 14.7158, 21.1618, 31.7310, 27.9391, 15.2023],
        [ 7.8576,  8.0902,  6.3494,  5.6615,  6.0273,  6.5923,  6.9255],
        [20.9732, 29.6656, 26.0280, 12.2879, 14.4532, 14.2504, 14.7524],
        [ 8.7239,  9.2205, 11.3993, 17.2491, 18.7276,  8.2792,  9.1428],
        [20.1561, 14.0207, 12.1898, 11.6894, 12.0677, 16.9106, 22.2807],
        [13.2345, 16.7004, 22.9245, 22.2657,  9.0158, 10.8393, 13.0033],
        [17.1265, 16.7169,  7.6134,  7.6452,  8.8651,  9.8649, 12.4760],
        [24.3572, 25.2018, 10.1662, 12.4037, 12.4596, 13.8548, 17.3611]])
ground truth
tensor([[15.2778, 11.8764, 16.1423, 33.3050, 38.7046, 31.8311, 13.8322],
        [ 8.4042, 10.0765,  4.1950,  3.1746,  2.8770,  5.4280,  5.8957],
        [26.1621, 27.1967, 34.1412, 13.7613, 15.1077,  9.4813, 18.4099],
        [ 7.3645,  8.5876,  9.9553, 17.3987, 29.4450,  4.8790,  3.3403],
        [22.8883, 13.6763, 12.8685, 18.8492, 13.1661, 14.5125, 17.0210],
        [14.0847, 2

batch_predictions
tensor([[14.5693, 16.6598, 19.1439, 23.6472, 11.3997, 11.5023, 13.7414],
        [25.3782, 34.8624, 25.3305, 15.1739, 17.0797, 16.2211, 17.1742],
        [12.4890, 16.8972, 22.2172, 10.9428,  8.9476, 11.4366, 11.8335],
        [11.7114, 10.5248,  9.9137, 11.5142, 17.4395, 21.6091, 10.9553],
        [24.2288, 13.7666, 14.3655, 13.6828, 14.6715, 21.1564, 31.3307],
        [12.3963, 13.7281, 17.3074, 22.4295, 16.3401,  8.0178, 11.4483],
        [17.5218, 18.4136, 17.0870, 16.9090, 26.8555, 42.0525, 37.7088],
        [ 8.4133,  9.0466, 10.5954, 14.7645, 14.5370,  8.1215,  7.3857]])
ground truth
tensor([[ 6.8254, 12.4540,  9.9684,  8.8506, 28.3272, 16.2546, 14.0715],
        [26.0521, 36.4150, 26.6570,  0.0000, 19.7659, 18.4377, 18.7796],
        [17.2146, 16.7675, 28.7612, 27.4724, 10.4550, 11.0994, 14.4003],
        [ 2.1831, 10.7312, 11.4676, 13.7691, 20.3314, 17.7012,  6.6807],
        [27.3951, 20.6207, 17.5737, 17.6304, 15.7880, 21.6695, 37.4291],
        [12.5986, 1

batch_predictions
tensor([[ 8.0823,  9.0265,  9.6609, 12.7935, 17.5021,  9.6608,  6.6105],
        [20.0170, 11.3957, 12.1441, 11.2134, 11.9894, 18.6231, 23.8352],
        [19.2861,  8.5820,  8.5405,  9.0934,  9.8120, 10.8428, 17.4306],
        [22.8716, 24.1859, 11.1120, 11.5580, 12.3493, 12.7946, 16.5940],
        [15.5879, 20.1460, 32.3003, 38.6678, 29.6465, 14.8869, 17.4965],
        [21.7542, 23.8116, 11.0952, 12.9895, 13.4357, 14.4003, 16.5740],
        [27.0195, 14.5025, 14.7283, 13.5204, 14.6744, 24.7958, 35.1916],
        [20.4130, 31.3177, 23.5361, 14.7840, 15.0489, 14.0484, 14.7862]])
ground truth
tensor([[ 9.2057, 12.2567, 10.7575, 17.9248, 20.8837, 20.4892, 12.4540],
        [26.8707, 11.0969, 14.5975,  9.8073,  7.9790, 17.0635, 21.9955],
        [19.5423,  6.2467,  8.0352,  9.3109, 10.2183, 11.9674, 23.0274],
        [23.4219, 31.1547, 21.4624,  4.2346, 11.2046, 13.3614, 20.2656],
        [15.6746, 22.7183, 50.0000, 37.0040, 27.3951, 15.5612, 18.5516],
        [25.9337, 3

batch_predictions
tensor([[ 7.5544,  7.5873,  8.0597, 10.3791, 15.1799, 17.8655,  8.3359],
        [22.8317, 35.5932, 25.4200, 12.5146, 15.3340, 14.7301, 15.7698],
        [12.8026, 20.8385, 33.1004, 24.5092, 10.7187, 13.4924, 12.8759],
        [14.6113, 18.0766, 19.7165,  9.1956,  8.5563, 11.1337, 13.1389],
        [19.2311, 26.2391, 24.8065, 12.9768, 14.5849, 14.3832, 14.3232],
        [24.9507, 34.1232, 28.7298, 13.8711, 17.9676, 17.6233, 17.5422],
        [22.5174, 10.8098, 12.4791, 11.7233, 13.0407, 23.7755, 33.7876],
        [10.3212, 13.2813, 16.8819, 21.1681, 17.6205,  9.7371, 11.1031]])
ground truth
tensor([[ 7.7328,  9.6791, 10.0605, 11.7964, 16.2678, 15.1631,  8.3903],
        [21.1599, 36.8885, 26.7491, 14.8080,  8.9295,  0.0000,  0.0000],
        [13.4140, 20.5813, 37.3882, 23.7375,  2.2225,  9.9421, 14.3477],
        [16.3265, 18.0981, 24.2205,  9.9206,  2.3951, 16.3549, 13.2370],
        [27.1259, 30.2579, 35.2749, 12.1032, 13.5062, 16.4966, 14.8810],
        [23.9371, 4

batch_predictions
tensor([[14.4456, 18.0226, 23.2658, 23.3198, 11.3024, 12.1620, 13.8469],
        [11.9589, 13.4167, 16.8182, 24.2984, 19.9495,  9.4781, 12.9743],
        [14.2861, 15.0846, 13.3024, 13.6604, 19.3772, 27.6222, 26.7698],
        [16.2234, 25.0882, 25.2571, 10.5724, 12.4549, 13.2288, 13.4310],
        [14.8304, 12.6176, 13.3339, 12.4022, 12.6087, 17.0344, 24.1008],
        [11.2704, 12.2765, 12.1606, 17.2337, 26.1389, 26.1950,  9.6072],
        [16.2686,  7.0059,  6.6020,  8.6542,  8.8722,  9.6474, 15.9730],
        [18.0539, 24.1685, 26.3397, 12.7535, 14.6763, 15.2056, 15.3126]])
ground truth
tensor([[14.4924, 21.6597, 27.5118, 30.2341, 13.2956, 14.1636,  8.8901],
        [13.3362, 12.9252, 16.1281, 22.4773, 18.5941, 11.1820, 13.4637],
        [12.9677, 19.1610, 17.7721, 18.5374, 19.3027, 25.5811, 29.3367],
        [15.1894, 28.5376, 28.7743,  9.4424, 10.7706,  7.7065, 15.0316],
        [25.5655, 11.5860, 14.0978, 14.7422, 15.3209, 16.2809, 24.2898],
        [13.6507, 1

batch_predictions
tensor([[12.1079, 13.3700, 18.6935, 24.3470, 23.0926,  9.3920, 13.2348],
        [31.9329, 31.6804, 22.6830, 16.6326, 17.7567, 18.5240, 21.7934],
        [10.4109,  9.7357,  9.4616,  9.7445,  8.0801,  8.3955,  9.8908],
        [16.6783, 20.8159, 23.6352, 17.5232, 12.9020, 13.7092, 14.3534],
        [16.8561, 15.9840, 16.7786, 22.4652, 38.5558, 35.2701, 15.7180],
        [10.8971, 10.6768, 14.5333, 17.7376, 17.2649, 17.8702, 11.6275],
        [16.2927, 18.5849, 27.3951, 38.7827, 33.8766, 14.3526, 18.7159],
        [14.7163, 14.2773, 14.5991, 25.8230, 34.1496, 19.1402, 16.4425]])
ground truth
tensor([[18.1406, 14.8668, 13.8180, 25.6803, 26.7007, 10.1616, 12.2732],
        [30.9666, 35.5867, 36.7772, 15.2636, 17.6729, 27.6644, 30.0737],
        [20.4366, 13.1641,  7.3514,  9.6791,  8.8243,  9.9947, 11.0863],
        [14.5450, 17.0042, 32.7985,  8.6139,  9.9553, 16.4519, 14.6107],
        [10.6293, 19.5437, 17.3044, 22.1514, 36.2812, 43.1689, 20.6066],
        [ 0.0000,  

batch_predictions
tensor([[24.7112, 12.3896, 13.9129, 13.5676, 14.3102, 20.0209, 28.3678],
        [28.4193, 14.1747, 15.7705, 15.3562, 15.8482, 25.9119, 40.0191],
        [17.5056, 15.5739,  7.8517,  8.4537,  9.0164,  9.9041, 15.0847],
        [10.8156, 10.5034, 10.9389, 14.2292, 20.1473, 20.5175, 11.3779],
        [21.1712, 19.0196, 19.4724, 23.1375, 28.7647, 18.4187, 19.8527],
        [19.3000, 27.2081, 22.5746, 12.7730, 13.8346, 15.1051, 15.4999],
        [25.8179, 25.4369, 33.7842, 44.6781, 27.7148, 24.4169, 26.5417],
        [16.3596, 17.6825, 28.8982, 34.4933, 19.0206, 16.0145, 17.6684]])
ground truth
tensor([[27.6039, 16.6360, 13.2167, 13.9400, 12.5855, 14.6633, 26.9069],
        [30.5839, 19.9972, 16.8226, 14.9518, 19.6429, 25.4819, 39.3991],
        [23.6323, 17.7670,  6.6018,  9.6528, 11.8622, 10.3630, 15.5971],
        [ 9.2977, 10.5997,  9.6791, 11.2178, 23.9479,  7.9826,  9.1005],
        [20.7058, 23.4410, 26.3464, 25.0850, 32.9507, 16.3124, 19.0618],
        [16.9643, 2

batch_predictions
tensor([[14.1260, 16.8249, 15.4938, 18.2017, 26.2175, 34.4160, 30.8183],
        [15.0920, 14.2590, 12.6724, 12.3511, 17.8015, 29.9106, 28.8179],
        [22.1602, 20.1040, 10.0610, 10.5062, 11.8870, 14.6830, 18.6906],
        [21.6971, 16.7306, 16.6540, 15.7531, 18.6210, 26.9250, 36.9678],
        [16.0709, 18.6593, 22.0594, 29.9384, 30.9228, 22.6190, 20.8115],
        [ 9.4936,  9.1117,  8.5257,  9.1946, 12.7362, 18.9519, 19.2169],
        [20.3220, 20.5728,  9.1741, 11.2634, 12.2693, 12.5260, 17.0452],
        [10.0470,  9.4884,  9.8820, 11.7140, 15.3065, 13.7259, 10.7170]])
ground truth
tensor([[21.1593, 20.4507, 23.7670, 21.2018, 24.2772, 33.4467, 26.0771],
        [11.6922, 13.7188, 10.7001, 11.4654, 20.8192, 39.1723, 29.2234],
        [19.4897, 16.6491,  7.9169, 13.4140, 10.9022, 12.7170, 19.4371],
        [29.0249, 18.1264, 22.3356, 18.6650, 28.2455, 32.3696, 46.3152],
        [12.2436, 20.7654, 27.5907, 31.3519, 37.8222, 21.5018, 21.9095],
        [ 8.9853,  

batch_predictions
tensor([[20.7737, 11.9386, 11.7980, 11.1424, 11.3787, 15.5163, 25.0013],
        [ 9.8702, 12.4093, 11.6015, 11.0104, 12.6123, 20.7155, 27.4404],
        [36.3090, 29.2736, 15.7401, 16.0210, 15.3682, 15.7296, 23.1644],
        [13.1943, 12.8172, 12.6936, 13.7502, 18.5467, 19.6436, 12.7812],
        [24.2418, 17.7259, 10.5812, 10.6761, 11.2997, 12.4118, 18.5929],
        [26.3461, 40.0069, 30.3675, 12.9673, 16.5958, 15.6015, 16.4303],
        [11.2871, 11.9428, 13.4845, 15.1198, 19.1623,  9.6012,  9.9028],
        [15.9710, 22.3269, 23.2745, 11.2302, 11.8189, 12.0059, 12.2422]])
ground truth
tensor([[21.2585, 12.2024, 11.8622, 10.2041, 14.4841, 18.5941, 28.0045],
        [12.0726, 13.9532, 14.2951, 14.7159, 14.4924, 12.1515, 20.7128],
        [40.5471, 30.4989, 18.8209, 18.5374, 12.8401, 17.5454, 30.0028],
        [16.3467, 13.0326, 15.2814, 16.8990, 24.1189, 19.5949, 14.7554],
        [28.5113, 23.1326, 14.5450, 12.3093, 13.5587, 14.6502, 20.1210],
        [31.7035, 4

batch_predictions
tensor([[16.6675, 30.9444, 39.8842, 33.5528, 12.8492, 17.8409, 16.1577],
        [15.6325, 16.1260, 18.0070, 26.2289, 25.0148, 12.0123, 16.4817],
        [ 9.6784, 14.5645, 17.3075,  7.4857,  8.4542,  9.3355,  9.5856],
        [20.7858, 11.7578, 12.7728, 12.6945, 16.1226, 26.5998, 29.8740],
        [ 8.4693,  9.9892,  9.7314, 10.1473, 15.3085, 21.3288, 18.9312],
        [21.7565, 19.6647, 17.7009, 17.5608, 22.8787, 39.6466, 42.7221],
        [13.0256, 19.7412, 27.7992, 24.2615, 12.1944, 14.1623, 13.5169],
        [16.2643, 18.4469, 30.4008, 39.3387, 28.6132, 15.4609, 18.8405]])
ground truth
tensor([[15.3472, 33.8506, 45.4366, 79.2346, 34.7054,  8.9164, 16.4256],
        [19.8696, 19.7988, 19.1185, 24.0788, 35.9836, 36.4796, 16.2982],
        [ 9.9816, 17.0437, 18.3456,  4.0373,  8.1799,  8.3509,  8.4955],
        [26.4456, 12.4575, 12.0040, 11.3379, 14.1865, 26.2897, 42.2902],
        [ 6.6018,  9.6528, 11.8622, 10.3630, 15.5971, 17.4119, 16.7149],
        [32.0862, 1

batch_predictions
tensor([[16.7807, 14.4790,  8.9064,  9.1724,  9.2028, 11.1259, 13.2284],
        [10.5881, 13.2753, 12.1331, 11.5652, 15.6587, 23.8104, 25.8820],
        [18.8078, 19.7460, 28.2840, 40.1311, 25.3550, 17.1247, 19.6891],
        [13.3850, 13.5822, 13.7233, 18.5902, 26.5424, 25.9292, 12.7553],
        [19.7960, 23.3545, 13.4504, 11.5405, 11.9739, 11.9780, 13.5534],
        [ 8.7794,  9.6324, 11.0491, 11.6922,  9.5189,  7.5206,  8.7342],
        [16.9639, 24.3090, 22.3156, 10.7882, 10.1009, 11.1263, 12.0122],
        [11.7982,  6.3567,  6.2954,  7.0638,  7.2793,  8.8678, 10.0372]])
ground truth
tensor([[21.8537, 18.3107,  8.8577,  9.5805,  8.8294, 10.4450, 12.3866],
        [ 9.0479, 11.6781, 14.9921, 12.0594, 20.2656,  2.2883, 32.4040],
        [19.7133, 22.2383, 26.6176, 51.0389, 45.8969, 22.3435, 16.8858],
        [13.5455, 14.6896, 15.8601, 16.1362, 25.6312, 30.2735, 13.9400],
        [23.1009, 24.7874, 12.4575, 14.3566,  2.7636,  0.0000,  8.5601],
        [ 8.6007,  

batch_predictions
tensor([[18.2556,  8.8911,  9.0404, 11.8134, 13.2796, 15.7792, 20.0152],
        [10.5962, 14.7925, 13.7701, 14.2820, 18.4368, 23.5856, 19.0495],
        [12.7058, 14.2560, 17.6131, 25.4177, 21.6407, 10.1483, 13.9265],
        [17.9487,  9.2198, 10.5090, 11.4729, 12.8768, 17.5974, 22.2253],
        [10.7815, 16.6631, 14.9213, 16.6532, 21.0416, 28.8418, 31.1109],
        [13.6394, 14.0927, 18.1597, 26.9766, 27.5781,  8.0228, 15.8312],
        [10.7376, 13.4528, 13.3016, 15.7149, 21.8360, 27.7171, 20.9514],
        [12.7323, 15.5143, 13.5622, 14.2325, 24.2532, 37.1205, 30.5124]])
ground truth
tensor([[21.0743,  8.9711,  3.0754, 17.3469, 14.3707, 14.1156, 19.5011],
        [ 6.1933, 10.5867, 17.9989, 17.5028, 16.9926, 25.2551, 26.4314],
        [11.3804, 19.9830, 15.2778, 24.3764, 21.9104, 12.0465, 14.9518],
        [16.6754,  0.0000,  6.3914, 10.8627, 16.9648, 18.5692, 26.5781],
        [ 7.1145, 16.0856,  9.7506, 16.0714, 20.5357, 25.2409,  0.7370],
        [11.2954,  

batch_predictions
tensor([[11.3280, 13.1063, 12.5794, 13.4157, 17.7163, 25.4968, 25.3526],
        [13.6844, 17.3322, 22.8095, 24.3536, 10.0727, 12.0506, 13.4854],
        [17.7911, 17.4088, 16.1466, 15.3357, 28.6763, 46.7071, 44.4022],
        [18.1649, 22.5885, 19.2960, 10.2115, 11.9974, 13.1338, 15.0715],
        [12.4557, 12.6656, 13.3022, 16.0886, 23.7422, 24.2614, 13.5841],
        [18.9957, 18.4693, 15.3672, 14.6955, 21.8380, 40.7526, 40.0650],
        [21.9315, 23.7536, 30.9362, 34.5562, 15.4799, 19.2451, 20.2050],
        [17.2815, 26.2881, 26.5615, 12.2035, 14.3425, 14.8201, 16.1984]])
ground truth
tensor([[11.2528, 11.9473,  9.9915, 12.7409, 12.9960, 23.9938, 21.4853],
        [11.9937, 19.4240, 25.7496, 26.1704, 13.9663, 11.6386,  9.6791],
        [15.9390, 26.8017, 17.4382, 14.9527, 35.9022, 69.7265, 53.3535],
        [22.4356, 15.4261, 22.6197, 10.2052, 11.7307, 12.5329, 14.9264],
        [14.0847, 13.1641, 17.3330, 22.2120, 21.3309, 26.4335, 13.1510],
        [23.9229,  

batch_predictions
tensor([[17.5977, 17.7667, 15.7349, 17.8952, 32.6955, 43.5736, 32.0873],
        [19.3895,  7.8646,  8.2918, 10.0478, 10.8545, 12.0468, 16.1840],
        [13.5162,  9.0896, 10.7089, 10.7564, 11.1988, 17.5620, 27.3194],
        [11.0697, 12.2090, 18.4381, 25.1362, 22.1935,  8.9823, 11.3930],
        [20.7573, 11.1927, 11.4968, 11.3380, 12.1111, 15.8106, 20.4535],
        [40.5068, 32.9792, 15.1257, 17.3285, 17.5635, 17.8012, 24.7638],
        [14.9666, 14.5280, 15.1675, 25.5836, 36.1127, 25.0915, 14.6103],
        [23.6730, 21.8816, 10.0783, 11.8094, 12.7281, 14.3942, 18.1923]])
ground truth
tensor([[19.7562, 17.4745, 11.5079, 23.3985, 36.9189, 42.7579, 27.2959],
        [ 8.8506, 28.3272, 16.2546, 14.0715, 12.7564, 15.0710, 12.5460],
        [25.6444,  6.5229,  9.9816,  9.2451,  8.7980, 14.3477, 18.2930],
        [13.8743, 12.6644, 18.6875, 24.4477, 25.5523, 10.5208, 12.8748],
        [17.7538, 11.2704, 10.8759, 13.9400, 14.9921, 15.5050, 23.6323],
        [36.7914, 2

tensor([[20.3517, 18.3251, 25.5271, 33.0918, 28.6513, 13.5308, 19.5793],
        [ 6.7917,  7.0294,  7.5917,  9.3975, 12.8236, 15.3237,  6.7263],
        [28.8307,  9.1573, 14.2888, 14.1786, 14.5208, 20.2894, 26.2315],
        [17.4993, 29.3993, 32.1042, 22.4583, 12.5092, 15.5037, 13.5142],
        [10.6833, 12.3039, 14.9445, 18.3311, 20.8901,  7.0786,  9.5278],
        [16.7167, 18.5144, 17.5787, 18.8546, 29.1513, 42.1394, 35.1543],
        [12.0152, 12.7576, 14.8180, 20.9503, 24.2843,  9.2792, 13.0598],
        [10.2671, 13.5364, 19.0084, 18.3009,  7.7009,  7.8360,  9.9519]])
ground truth
tensor([[19.7137, 17.7579, 23.9371, 41.9359, 35.7285, 14.9376, 18.9768],
        [ 8.6928,  6.9306,  6.4308,  8.3640, 12.0068, 15.9127,  3.9716],
        [32.6407,  6.4703, 15.3077, 15.5576, 17.5302, 25.6970, 20.2525],
        [18.3107, 36.4229, 49.4615, 31.5193, 11.6922, 15.3770, 10.5726],
        [13.7691, 10.0868, 14.3346, 18.4245, 22.4882,  5.9837, 11.1915],
        [13.9598, 16.2415, 18.7500, 1

ground truth
tensor([[21.8537, 32.0011, 47.7041, 42.8713, 26.6440, 21.1026, 21.2868],
        [20.0822, 16.9218, 17.5454,  3.1746, 34.7364, 18.7925, 14.1156],
        [ 9.7449, 11.4150, 12.0594, 14.0847, 20.7128, 23.3956,  5.7601],
        [11.3624, 18.9900, 27.3672,  5.6944,  9.9421, 12.3225, 12.0068],
        [18.1746, 11.9411, 18.4114, 25.6970, 39.1899, 27.2488, 18.0037],
        [13.7165, 17.5697, 16.5176,  5.5497,  9.0347,  8.7980,  9.7843],
        [20.4103, 11.5466,  9.7449, 10.9153, 13.9926,  4.5502, 26.9726],
        [14.7554, 13.8085, 12.9274, 16.1362, 19.4240, 27.2225,  5.4182]])
batch_predictions
tensor([[14.3242, 16.5498, 20.9678, 27.4980, 26.0930, 12.0069, 15.9840],
        [13.2366, 13.0087, 13.9747, 21.0017, 30.8275, 27.7346, 11.7267],
        [19.8987, 12.8744, 13.7415, 14.1978, 14.5197, 16.0021, 27.1164],
        [11.4630, 12.6286, 12.6454, 13.5314, 20.7194, 25.3475, 12.1593],
        [20.3581, 21.6418, 15.8998,  9.5313, 10.0175, 10.9955, 13.5915],
        [22.4716, 3

batch_predictions
tensor([[15.0374, 14.0139, 13.2636, 14.7745, 17.9493, 24.4398, 26.1154],
        [26.5123, 20.2127, 10.8695, 11.7420, 11.7741, 13.1103, 20.1834],
        [14.7517, 13.5229, 13.2292, 12.6213, 13.5378, 21.6705, 28.2026],
        [45.9141, 42.1932, 25.3615, 24.3664, 23.9397, 24.9147, 31.9265],
        [13.4990, 15.9766, 20.8496, 25.7623, 25.7303,  8.9703, 14.1120],
        [16.5340, 15.8306, 16.1439, 29.2317, 43.4546, 34.7027, 16.7228],
        [ 6.8880,  7.9039,  7.4440,  7.8861, 10.1163, 12.8751,  9.4153],
        [23.3735, 29.6802, 32.5280, 24.6508, 20.1245, 16.5015, 18.8373]])
ground truth
tensor([[13.1510,  3.7612, 10.0605, 16.7938, 18.1483, 19.8448, 22.9879],
        [28.6990, 25.4393, 13.0385, 14.8810, 12.0748, 15.7596, 22.6332],
        [27.2392, 13.2370, 14.5975, 13.1236, 17.0210, 22.6616, 27.0408],
        [53.4297, 50.6519, 31.7319, 24.9433, 20.9042, 23.3277, 36.7914],
        [16.0998, 16.6525, 17.7579, 24.2489, 24.4898, 12.2449, 14.6825],
        [21.7545, 1

batch_predictions
tensor([[12.1246, 18.7879,  6.9204,  7.1519,  8.8506,  9.3867, 10.4112],
        [14.5241, 16.2213, 24.1320, 29.1300, 26.3852, 11.2986, 15.9650],
        [ 8.9721,  8.0518,  8.5657,  8.5591,  8.6749, 10.1243, 14.7428],
        [13.3424, 14.1760, 12.5909, 12.7925, 18.9458, 27.9258, 28.5218],
        [17.7080, 13.7753, 10.8163, 10.9718, 10.5510, 11.6663, 14.3198],
        [15.8956, 14.0681, 13.0706, 12.9969, 16.9313, 25.6182, 26.7937],
        [11.8605, 12.4416, 14.6988, 22.3816, 23.0026, 11.0784, 13.6779],
        [14.6746, 10.7802, 10.0306, 10.0481, 12.8496, 19.4897, 22.1031]])
ground truth
tensor([[17.1094,  2.5118,  8.2194, 11.5203,  9.0479,  9.4819, 13.4666],
        [11.8490, 11.6123, 23.8690, 29.6817, 34.0084, 11.8753, 13.8743],
        [ 8.1491,  7.6672,  7.9365,  9.0986,  7.2279,  9.6372, 15.0935],
        [11.6518, 16.6097, 13.5981, 14.2162, 18.2141, 22.2646, 29.2215],
        [21.8569, 14.4529,  9.5345, 12.7959,  8.8638, 11.3756, 14.2031],
        [15.1786, 1

batch_predictions
tensor([[18.6199, 16.9062, 17.7945, 22.8823, 41.5458, 41.1571, 18.8163],
        [25.5775, 14.0110, 13.5806, 12.7919, 13.0082, 19.9977, 28.5603],
        [16.3682, 24.5890, 24.8721,  8.9834, 13.1114, 13.7632, 13.8526],
        [14.1661, 22.2713, 29.4778, 22.2003, 10.9268, 13.4147, 12.9463],
        [21.6125, 32.8754, 27.4190, 12.4367, 14.9954, 14.8735, 15.4216],
        [ 9.5186, 10.3415, 12.5913, 15.8217, 16.4473,  8.9032,  9.3457],
        [11.5799, 11.3597, 14.1001, 25.4609, 31.1071,  9.4844, 15.1526],
        [25.4183, 12.2460, 12.9018, 12.6942, 14.2913, 20.1755, 27.1445]])
ground truth
tensor([[1.2358e+01, 1.6185e+01, 1.8452e+01, 3.3688e+01, 4.7506e+01, 5.1701e+01,
         2.0181e+01],
        [2.1953e+01, 0.0000e+00, 4.2517e-02, 0.0000e+00, 5.1871e+00, 1.2982e+01,
         2.8855e+01],
        [1.9416e+01, 2.9889e+01, 3.6154e+01, 4.9603e+00, 1.5646e+01, 1.1295e+01,
         1.8424e+00],
        [1.5873e+01, 2.2761e+01, 3.1888e+01, 3.3376e+01, 1.4753e+01, 2.4150

batch_predictions
tensor([[16.6808, 21.5674, 23.2264, 13.2220, 14.0519, 14.3950, 15.2522],
        [17.0505, 32.3405, 39.3258, 29.6614, 13.4372, 16.9343, 15.3822],
        [13.7343, 13.8539, 14.8633, 20.4248, 28.6338, 27.7854, 13.2552],
        [13.2155, 23.9143, 23.1039,  9.4729, 10.6278, 11.5528, 11.8988],
        [16.6085, 24.7980, 25.7465, 10.5535, 12.9506, 13.6684, 13.2385],
        [11.1333, 10.9771, 13.7932, 21.3123, 20.3929,  8.3697, 11.4063],
        [14.9786, 16.1613, 15.3323, 16.7474, 19.6958, 30.8019, 31.4605],
        [13.4539, 16.4154, 14.8219, 14.8052, 16.6473, 28.8689, 37.6486]])
ground truth
tensor([[28.0247, 14.8211, 32.1278, 22.5013, 23.4482, 20.7522, 14.1636],
        [15.9722, 36.5363, 51.1621, 36.5930, 17.1344, 17.1202, 18.2398],
        [15.5896, 17.4887, 21.3010, 23.4410, 32.1429, 34.0278, 17.5312],
        [10.5997, 22.3041, 22.2777,  7.4303,  9.9027,  7.9695,  9.9421],
        [15.7680, 18.2404, 28.8138,  9.3372, 10.9942, 11.6386, 10.9942],
        [13.3877, 1

batch_predictions
tensor([[20.8637, 12.3034, 12.0947, 11.6337, 11.5209, 14.1979, 21.3068],
        [19.9120, 14.1230, 10.3725,  9.8600, 10.2096, 11.7964, 16.2665],
        [11.5137, 15.5087, 14.2286, 15.0840, 19.9399, 29.0952, 30.2412],
        [10.2466, 10.5067, 10.8064, 14.7512, 21.3617, 18.2978, 10.2451],
        [17.5006, 24.0552, 11.7799, 10.9559, 12.6252, 12.6789, 12.8356],
        [12.9195, 14.7009, 12.7464, 14.2513, 25.2509, 35.6218, 26.0043],
        [14.7349, 15.2861,  7.5529,  7.2946,  7.2254,  8.1611, 10.8935],
        [13.2108, 13.5022, 18.9654, 27.2625, 24.7158, 11.8553, 15.8939]])
ground truth
tensor([[23.2378, 13.7559, 12.2173,  9.7712, 14.2031, 16.9253, 19.5029],
        [19.8185, 13.3482, 11.2441,  9.6923, 10.5339, 13.6639, 12.3225],
        [ 8.4325, 17.2477, 13.5488, 18.1406, 25.6236, 21.7687, 32.6247],
        [ 9.1978, 11.7772, 11.6638, 19.3027, 26.6156, 28.5289, 15.9439],
        [18.0694, 26.0652, 22.9748,  9.5345, 14.9001, 14.0189, 14.5055],
        [11.8764, 1

batch_predictions
tensor([[15.0339, 29.8561, 39.0401, 28.8648, 14.9560, 16.7396, 13.3914],
        [11.3967, 12.2511, 16.8570, 25.6955, 23.2886,  8.4403, 13.5592],
        [19.1514, 21.7194, 27.6387, 29.4536, 17.6724, 17.3908, 18.2522],
        [10.5178, 11.9328, 15.5757, 22.5190, 25.0552, 15.5908, 10.5175],
        [23.3313, 24.9527, 11.8180, 12.7862, 12.6853, 13.0429, 16.8940],
        [41.0620, 36.7979, 15.6451, 17.7271, 17.8632, 16.8180, 24.5242],
        [10.3870, 11.0605, 11.4470, 15.0048, 21.6172, 22.4296, 10.9629],
        [12.5518, 18.6721, 27.8460, 23.1907,  8.7284, 12.2813, 12.6750]])
ground truth
tensor([[20.2239, 32.6814, 48.1151, 36.6497, 19.9405, 20.0113, 15.5471],
        [11.1915, 12.3948, 12.3027, 24.9737,  7.7196,  6.7464,  9.1005],
        [16.9385, 23.6718, 35.0342, 15.3077, 49.8948, 23.3167, 28.6691],
        [12.2591, 13.4212, 16.8084, 23.7103, 28.1888, 13.1236, 11.2954],
        [25.3968, 28.3588, 10.8135,  8.6876,  9.1978, 13.0527, 19.8696],
        [45.6602, 4

batch_predictions
tensor([[16.7533, 22.1279, 24.5214,  9.3613, 12.6483, 13.9096, 14.2614],
        [17.1387, 24.7036, 26.5555, 13.1761, 14.4445, 14.5299, 14.4266],
        [12.7594, 12.3043, 13.3847, 18.6465, 24.7480, 24.0907, 12.6417],
        [11.6876, 12.2218, 15.0100, 20.5943, 20.9297, 10.2310, 11.7820],
        [16.7548, 21.8112, 21.0086, 13.0236, 11.7505, 11.9525, 12.6539],
        [14.4591, 15.5548, 13.2833, 13.4770, 23.7456, 31.6593, 30.6915],
        [14.1467, 17.8332, 23.7102, 23.7222, 17.0860, 10.1783, 12.8616],
        [26.8956, 21.5331, 11.7261, 13.0703, 13.2507, 14.9436, 24.5563]])
ground truth
tensor([[12.6118, 16.2941, 24.4740,  8.3903, 12.1515, 13.2562, 10.8101],
        [16.9122, 21.1862, 26.3809, 12.2041, 12.2436, 16.3204, 13.3088],
        [16.9926, 15.3486, 15.1502, 22.9450, 26.7999, 28.0187,  7.1287],
        [11.0119, 13.8322, 17.1485, 26.1763, 31.2358, 19.1893, 18.1264],
        [20.4649, 21.3010, 22.8883, 13.6763, 12.8685, 18.8492, 13.1661],
        [11.8753, 1

batch_predictions
tensor([[22.3229, 11.3055, 10.5518, 11.5624, 12.1010, 15.3068, 18.5255],
        [ 7.2346, 10.3876, 11.0751, 12.1613, 14.4710, 18.0741,  7.9968],
        [12.8600, 18.2907, 24.9695, 24.2465,  8.6354, 11.8112, 13.0782],
        [13.4383, 13.8902, 16.1816, 24.4399, 35.5832, 33.6547, 10.1871],
        [12.9125, 15.4369, 18.1475, 18.9860, 10.1911,  8.9276, 11.4295],
        [13.4552,  9.5487,  9.9008, 10.1877, 11.6635, 16.3627, 21.0355],
        [10.2515, 12.0667, 10.6207, 11.2468, 20.0222, 31.7199, 23.2096],
        [21.8600,  9.8514, 11.9109, 11.8995, 12.2154, 16.5759, 24.1918]])
ground truth
tensor([[27.0516, 31.2730, 10.7969, 14.8869, 10.3104, 15.8864, 16.3204],
        [ 3.2880, 11.2245, 11.7063, 10.6151, 20.0255,  3.2313,  4.5210],
        [12.2567, 19.0952, 27.7749, 24.9737,  4.9842,  6.8911, 12.8748],
        [21.1876, 16.9926, 16.9785, 25.0000, 29.3651,  4.2234,  0.0000],
        [23.4219, 18.0431, 18.1089, 22.0410, 15.2551,  9.9684, 13.3745],
        [24.4871,  

batch_predictions
tensor([[ 7.9354,  6.4434,  6.3108,  7.0205,  7.1910,  7.2644,  7.7510],
        [39.2017, 25.5783, 24.2011, 24.0178, 27.7332, 37.6434, 52.3562],
        [12.8699, 13.2372, 14.6691, 21.2389, 29.4680, 29.8553,  8.1435],
        [26.6505, 27.0319, 34.8497, 45.5743, 42.1582, 27.1564, 27.5212],
        [19.6241, 18.2825, 16.6472, 17.8976, 25.1545, 39.8049, 33.9818],
        [27.2095, 11.6568, 13.2311, 13.4121, 14.1654, 16.9879, 25.2465],
        [ 7.0839,  7.0011,  7.8869,  9.7706, 14.4328, 18.9946,  6.1287],
        [37.4456, 28.7837, 19.8799, 19.6720, 18.8459, 20.7826, 29.9176]])
ground truth
tensor([[ 7.9365,  7.9082,  5.2721,  7.7239,  7.5113,  7.8656,  8.9569],
        [41.7800, 34.2404, 24.2205, 21.4002, 28.7415, 36.1961, 60.1899],
        [11.5788, 15.2494, 17.6729, 23.3277, 30.2438, 41.8651, 12.7126],
        [22.0947, 30.4138, 45.9892, 66.4683, 53.0896, 40.1644, 34.5663],
        [23.3693, 21.5018, 20.9758, 22.0542, 28.3666, 43.2930, 37.1910],
        [35.3175, 3

batch_predictions
tensor([[12.7975, 12.5073, 13.1664, 16.9403, 23.1473, 23.8910, 10.9103],
        [19.2160, 19.8183, 26.7869, 37.8889, 28.1694, 16.1388, 18.4695],
        [12.3687, 13.2859, 13.3878, 15.7655, 24.9741, 30.4453, 12.4573],
        [21.7899, 24.1327, 10.3984, 12.5570, 13.4656, 13.1433, 17.3161],
        [23.4543, 26.5722, 23.3131, 22.1277, 20.9582, 22.4186, 33.6407],
        [16.6629, 25.4191, 38.9939, 29.6952, 13.3319, 16.7825, 16.0694],
        [12.5637, 11.5039, 11.6297, 13.1751, 14.2227, 12.5517, 11.8555],
        [22.7002, 22.8509, 32.2732, 47.6905, 45.5400, 25.1100, 25.5519]])
ground truth
tensor([[1.4335e+01, 1.4821e+01, 1.8096e+01, 1.5255e+01, 2.5934e+01, 3.1142e+01,
         1.1257e+01],
        [1.2075e+01, 1.8211e+01, 2.5652e+01, 4.9617e+01, 3.5856e+01, 0.0000e+00,
         1.2061e+01],
        [2.6302e-02, 1.3440e+01, 1.5189e+01, 1.6189e+01, 3.0313e+01, 4.1557e+01,
         3.0208e+01],
        [2.3369e+01, 2.7301e+01, 1.0337e+01, 1.2322e+01, 1.2585e+01, 0.0000

batch_predictions
tensor([[16.7968, 16.8413, 26.4608, 34.6928, 22.7367, 14.9257, 18.0670],
        [13.1610, 21.5190, 28.7725, 19.3933, 11.1753, 12.6656, 12.3441],
        [22.9860, 26.8038, 11.2916, 13.1722, 14.1947, 14.8992, 16.0959],
        [13.9323, 14.6340, 14.3907, 15.5717, 21.3805, 29.2393, 24.7839],
        [11.0395, 11.9553, 12.5641, 17.1652, 21.4777, 20.1039,  9.2298],
        [15.7315, 12.4604, 11.4058, 10.9841, 13.1214, 20.1351, 24.5799],
        [11.2652, 12.3758, 16.3795, 22.7238, 21.8710, 12.4980, 12.8548],
        [39.1397, 25.2156, 22.3927, 22.3178, 23.0506, 27.1560, 38.6698]])
ground truth
tensor([[15.3486, 12.7268, 33.6735, 54.1383, 29.3793, 15.5896, 18.8776],
        [16.1625, 24.7370, 37.9669, 22.0805, 14.3346, 17.7801, 13.7296],
        [23.9348, 30.8916,  4.6949,  9.2714, 15.4261, 14.6633, 19.0558],
        [11.6213, 14.4416, 10.9977, 14.1582, 25.2834, 31.1366, 24.3339],
        [12.6701,  9.8356, 10.4308, 12.6701, 22.4065,  4.6769,  8.7018],
        [ 0.0000, 1

batch_predictions
tensor([[16.7063, 20.9688, 27.8556, 24.3810, 17.6239, 14.6761, 14.1814],
        [14.4237, 19.8827, 29.4950, 25.7916, 11.1736, 14.6078, 13.9407],
        [17.8465, 20.5712, 23.4384, 23.3354, 22.2791, 14.5516, 16.3365],
        [15.8292, 19.1507, 22.2692, 21.7195, 11.7818, 13.8261, 14.3021],
        [17.0515, 18.6315, 16.9412, 18.2321, 30.4658, 42.9854, 33.8937],
        [21.5728, 28.0355, 25.5217, 13.7767, 15.2214, 15.5665, 15.8313],
        [23.6066, 20.5107, 12.4321, 12.8142, 12.8547, 13.7760, 18.0239],
        [ 9.0605, 14.4190, 12.6645, 12.7563, 16.7107, 26.4429, 29.4387]])
ground truth
tensor([[25.3401, 13.2795, 31.7602, 29.5493, 12.5142, 18.2398, 19.2885],
        [15.2778, 22.8883, 32.8940, 28.9683,  9.5096, 13.0952, 13.3929],
        [19.0334, 18.0272, 23.9938, 11.1111,  8.6026, 10.0198, 14.0731],
        [17.1883, 21.1073, 24.3951, 26.5913, 11.8622, 16.8464, 15.7417],
        [16.5675, 18.0981, 18.9768, 17.7721, 37.5000, 47.3073, 31.9586],
        [23.0584, 3

batch_predictions
tensor([[11.1980, 11.0908, 14.4869, 18.9915, 17.4508,  9.1518, 11.4248],
        [27.0884, 13.8373, 15.1801, 14.3247, 15.0931, 26.4833, 38.3041],
        [13.6283, 14.6697,  7.5058,  7.0302,  7.3536,  7.8975, 10.8755],
        [ 8.7121, 14.1151, 12.9474, 12.9413, 17.5527, 25.3946, 28.3121],
        [15.6568, 14.7379, 14.1641, 17.8357, 30.2001, 33.2050, 15.8953],
        [19.9218, 24.8271,  9.7717, 10.7621, 10.8999, 11.7408, 13.6752],
        [10.4332, 10.8952, 13.7749, 19.6016, 18.3979,  9.6525, 12.0391],
        [17.3131, 17.6561, 15.9645, 17.3152, 26.5814, 39.9858, 32.4737]])
ground truth
tensor([[ 8.6451, 13.3787, 17.2052, 21.9671, 16.8651,  9.2404, 11.3662],
        [25.9732, 12.2699, 13.0852, 13.8480, 15.1631, 32.6144, 24.3293],
        [14.7291,  0.0000,  0.0000,  7.8380,  8.8375,  8.4824,  9.5082],
        [ 4.7902, 14.3282, 11.8481, 11.6780, 20.0255, 26.5448, 31.3776],
        [19.1043, 16.6950, 19.1752, 21.9671, 32.2137, 35.4308, 18.0414],
        [24.2109, 2

batch_predictions
tensor([[16.2135, 12.7656,  6.6527,  6.8482,  7.1480,  8.3778, 12.6226],
        [19.3382, 10.3870, 10.3125, 10.0923,  9.8693, 11.9449, 20.9734],
        [18.3623, 27.2281, 40.7275, 32.5656, 14.9090, 18.6613, 17.0892],
        [14.5271, 15.4365, 16.9927, 20.1922, 25.7467, 25.5920, 13.0132],
        [27.0336, 22.0184,  9.8256, 11.0104, 11.4343, 12.1024, 17.1853],
        [ 8.5952,  8.4722,  8.7982, 11.1695, 18.0654, 20.0507,  8.4419],
        [26.8666, 27.3519, 12.3577, 13.9046, 13.4865, 13.4879, 17.0425],
        [24.4251, 19.2273, 11.5245, 11.7870, 11.9871, 12.0489, 18.4965]])
ground truth
tensor([[20.3577, 15.1105,  0.0000, 12.4934,  7.8511,  8.6402, 12.6512],
        [17.7670,  2.3672,  5.1946,  9.1662, 12.0726, 16.6097, 23.8690],
        [18.5955, 23.9216, 45.0158, 39.7291, 20.7785, 13.8348,  9.7449],
        [19.5292, 20.5287, 18.7007, 24.7764, 15.6234, 19.9500,  7.7065],
        [31.3350, 21.8112,  9.6372, 14.1015, 11.4229,  9.9632, 18.2398],
        [ 7.9432,  

batch_predictions
tensor([[13.2404, 14.3318, 19.8047, 24.1322, 18.9883, 11.8294, 14.1702],
        [11.6499, 15.0452, 13.3379, 15.0141, 17.9501, 27.8048, 29.2215],
        [25.2785, 10.1212, 12.5697, 11.6539, 11.8472, 17.3167, 27.9065],
        [19.3541, 19.7214,  9.4668,  9.5423,  9.8371, 10.5312, 12.5803],
        [19.1050, 14.1754, 13.2358, 11.6467, 11.8803, 18.5787, 27.1242],
        [ 8.2951, 10.2502, 14.3915, 17.0658,  6.5888,  7.0532,  7.4590],
        [10.7244, 11.4488, 15.0463, 22.3911, 23.3613, 13.1204, 12.5042],
        [10.8728, 12.4621, 18.2106, 22.7511, 17.5827,  8.9109, 11.5622]])
ground truth
tensor([[11.1536, 16.6383, 23.2426, 26.3605, 27.1117, 14.9943, 14.9943],
        [10.7969, 14.8869, 10.3104, 15.8864, 16.3204, 27.2883, 31.1810],
        [27.4592,  8.3377, 13.3482, 11.1915, 12.3948, 12.3027, 24.9737],
        [17.8064, 19.3188, 10.0473,  9.4950, 10.7443,  7.7722, 10.5865],
        [20.8192, 14.7534, 14.6967,  8.7727, 12.8260, 19.3594, 29.2517],
        [ 6.2073, 1

batch_predictions
tensor([[40.2862, 41.1156, 31.8444, 29.3551, 29.3414, 29.5553, 36.7004],
        [20.6067, 16.4757,  8.4937,  9.3746, 10.2639, 10.8855, 15.4256],
        [16.4683, 15.5949, 20.7510, 33.8282, 34.4258, 18.4399, 19.2575],
        [10.3707, 12.2308, 16.7104, 20.9378,  8.9463,  7.9398, 10.3101],
        [20.3839, 33.1952, 28.9678, 12.5452, 15.5467, 14.9614, 14.1551],
        [ 8.7438, 10.3904, 11.7186,  7.9411,  6.7195,  7.4364,  7.9762],
        [34.2045, 22.4613, 13.4228, 14.3219, 13.7978, 15.6048, 24.8069],
        [16.2746, 18.8182, 21.1175,  8.0680,  9.3405, 11.5778, 12.7644]])
ground truth
tensor([[5.6418e+01, 6.2230e+01, 6.6531e+01, 2.0134e+01, 5.2604e-02, 2.3435e+01,
         4.1781e+01],
        [2.0239e+01, 1.8490e+01, 4.6028e+00, 8.8375e+00, 8.7585e+00, 7.8117e+00,
         1.3769e+01],
        [1.7999e+01, 1.7149e+01, 1.5448e+01, 3.8832e+01, 3.9654e+01, 2.1060e+01,
         1.7163e+01],
        [7.6539e+00, 1.1888e+01, 1.3927e+01, 1.8780e+01, 8.9427e+00, 9.0084

batch_predictions
tensor([[16.2403, 21.9851, 22.1461, 12.1590, 11.2077, 11.6637, 12.4834],
        [18.5047, 17.8016, 19.7778, 27.6008, 31.3170, 18.9833, 18.8679],
        [16.2277, 16.7955, 21.5977, 34.1527, 31.4579, 18.1355, 19.9927],
        [12.8538, 19.1235, 26.4585, 26.2796, 12.0810, 13.7564, 12.9098],
        [14.2845, 13.1667, 13.6439, 17.4546, 23.0909, 15.6282, 12.6107],
        [26.1273, 40.9859, 32.3547, 14.5571, 17.7063, 16.0409, 16.3621],
        [12.3500, 19.3010, 25.8543, 26.1226, 10.0744, 11.9566, 13.1426],
        [24.2829, 12.0706, 13.3701, 12.5770, 13.0542, 17.1086, 26.5528]])
ground truth
tensor([[16.4387, 23.0274, 25.0000, 16.7280, 13.3219, 12.2962, 10.6391],
        [19.0689, 16.9385, 23.6718, 35.0342, 15.3077, 49.8948, 23.3167],
        [19.2319, 12.2166, 31.2075, 31.5618, 41.6241, 23.9654, 26.8424],
        [12.2304, 21.9490, 24.4477, 43.8322, 15.1762, 13.8743, 10.4287],
        [19.6870, 16.8201, 13.1641, 17.9511, 22.3304, 25.5655, 11.5860],
        [32.1854, 5

batch_predictions
tensor([[13.7209, 13.9470, 12.9179, 14.4988, 19.4241, 25.5964, 17.9888],
        [27.0709, 38.0967, 32.9668, 15.7940, 18.9215, 18.4242, 19.2332],
        [11.1657, 13.9153, 12.8544,  7.4735,  7.7823,  9.1813, 10.3648],
        [11.3220, 12.0587, 12.5686, 14.8573, 21.6679, 24.4709,  8.4546],
        [16.4406, 16.1530, 14.5767, 14.5288, 16.3833, 23.6832, 28.2297],
        [12.8429, 16.3128, 23.6414, 24.6618,  8.2880, 12.5771, 13.6963],
        [25.3176, 31.1471, 19.6194, 18.1389, 17.6415, 17.3502, 20.7879],
        [14.9016,  6.4947,  6.4435,  7.1267,  7.6908,  9.1288, 14.9080]])
ground truth
tensor([[17.7933, 22.5802, 16.8201, 14.6239, 15.8075, 26.0389, 27.6433],
        [27.8801, 44.0952, 37.7827, 17.7144, 17.9774, 16.6228, 19.5423],
        [10.6391, 16.1625, 16.2415,  6.8254,  7.1673,  6.3914,  9.8238],
        [15.0053,  9.4424,  0.0000, 12.1910, 12.2173, 28.1957,  4.5502],
        [20.5813, 19.4371, 15.0973, 16.3861, 13.9795, 18.4640, 31.9963],
        [12.9393, 1

batch_predictions
tensor([[18.9841, 20.3541, 10.6125, 10.2498, 10.1998, 10.8138, 12.6447],
        [15.2290, 24.7019, 36.5016, 27.5245, 13.1212, 15.8113, 14.5828],
        [39.4353, 37.7611, 19.3119, 19.6239, 19.5315, 17.8047, 21.7656],
        [19.4314, 19.7033, 21.3351, 28.5039, 35.2307, 20.5919, 20.1782],
        [13.2660, 20.5179, 33.4297, 25.6396, 11.0310, 13.3350, 13.2700],
        [ 9.2026, 13.3193, 12.9164, 12.3355, 16.6387, 27.8209, 20.1309],
        [11.5671, 12.2496, 12.5594, 17.8439, 26.4140, 26.5621, 10.2914],
        [12.4508, 16.6561, 23.1765, 18.4980, 11.7883, 11.7187, 11.2832]])
ground truth
tensor([[23.9479,  7.9826,  9.1005, 10.1131, 11.1915, 13.7296, 14.6107],
        [15.0935, 25.2834, 41.7234, 33.4751, 16.1990, 12.8543, 10.7285],
        [38.1803, 39.3282, 32.0862, 19.8980, 20.1814, 13.4637,  0.0000],
        [28.2746, 24.9079, 29.5634, 34.5739, 45.1604, 18.3456, 21.1073],
        [ 9.9065, 23.2426, 37.8543, 34.7931, 15.3628, 11.8764, 11.2245],
        [ 5.7115,  

batch_predictions
tensor([[12.2276, 14.5936, 12.6599, 12.8252, 18.0886, 26.4736, 29.2345],
        [ 9.6654,  6.7192,  6.2986,  7.4425,  7.3356,  7.7981,  9.8649],
        [40.6148, 34.6503, 18.8885, 19.8417, 19.9761, 20.7283, 29.6720],
        [ 7.6681,  8.5370,  9.9987, 14.4544, 17.9042,  6.4947,  7.2023],
        [37.5335, 29.1005, 17.8270, 18.6615, 18.2132, 19.8213, 26.2844],
        [19.9143, 19.7244, 20.2393, 31.7680, 47.0519, 32.8711, 19.7641],
        [12.3346, 20.3735, 28.2467, 21.8599, 11.5212, 12.6972, 11.7092],
        [ 8.7612,  6.7967,  6.0569,  6.1560,  6.3084,  6.4162,  7.1054]])
ground truth
tensor([[10.7568, 12.1032, 10.2324, 11.6780, 17.7579, 24.2063, 27.2251],
        [14.8606, 10.3235,  7.1278,  5.4182, 16.9779,  8.4429, 11.2572],
        [45.7625, 36.7630,  8.0782, 17.1485, 25.5952, 29.3651, 35.6151],
        [ 6.9437,  9.2188, 10.7706, 13.9269, 20.0947,  3.2614,  9.7054],
        [39.3214, 33.3772, 20.3709, 17.2409, 16.8332, 10.2183, 26.4335],
        [25.3814, 2

batch_predictions
tensor([[11.6142, 16.0888, 20.2101, 18.5912,  9.0709,  9.8197, 11.0919],
        [18.9068, 24.4100, 11.0290, 11.7442, 10.9749, 12.3797, 14.3656],
        [21.4973, 24.8346, 27.6792, 16.9351, 16.6435, 19.3592, 21.4274],
        [16.6344, 27.0679, 34.0439, 27.1398, 12.5109, 15.3171, 14.7713],
        [19.2901, 28.0196, 25.6695, 12.4244, 13.7269, 13.7788, 13.4131],
        [16.2363, 21.1903, 21.8150,  8.4430, 10.9618, 12.5072, 13.5245],
        [17.4591, 21.8503, 25.2383, 14.4786, 13.2387, 15.5784, 16.7712],
        [27.2964, 14.9337, 12.2178, 12.7185, 11.9600, 14.9445, 27.4679]])
ground truth
tensor([[ 9.8764, 14.9790, 24.0137, 20.0552,  7.2593, 10.1131, 13.5850],
        [25.1841, 35.3761,  5.0368, 14.3740,  4.3661, 17.5960, 15.7286],
        [26.3940, 30.8259, 36.9148, 37.9932, 29.0110, 33.5218, 29.6554],
        [15.0794, 27.6644, 40.9155, 36.8764, 14.4983, 19.7988, 16.1565],
        [21.3010, 35.1332, 33.7302, 13.0244, 12.7976, 12.6842, 13.4070],
        [13.4070, 1

batch_predictions
tensor([[19.5774, 20.4409,  9.1369, 10.7021, 11.3786, 11.6536, 14.4068],
        [19.6838, 28.3776, 25.0982, 11.8397, 13.7193, 13.8121, 14.3526],
        [19.5664, 22.1512,  9.4455, 11.3199, 12.3110, 12.9041, 16.3989],
        [14.5466, 18.6777, 26.9421, 26.1270, 17.7768, 13.7187, 14.6766],
        [17.6523, 22.9499, 31.8284, 24.9040, 12.9132, 15.5433, 15.2986],
        [33.8282, 43.1505, 35.5628, 22.9381, 23.5022, 22.8154, 25.2485],
        [26.7020, 11.0319, 13.2258, 13.1246, 15.2128, 22.2091, 28.5771],
        [14.1975, 15.1182,  8.0257,  7.3787,  7.3721,  8.7861, 10.1882]])
ground truth
tensor([[22.4065,  4.6769,  8.7018, 11.5930, 12.9960, 16.1848, 19.3027],
        [19.1468, 27.8203, 29.1525,  3.5856, 10.3033, 13.9881, 13.6338],
        [21.5939, 21.8569,  7.2199, 13.1641, 14.4003, 11.7833, 17.8327],
        [16.3549, 15.3345, 22.6899, 38.6196,  8.4184, 12.9819, 14.3849],
        [17.5737,  4.3226, 37.1457, 34.5947, 14.6117, 19.7279, 14.3991],
        [43.6366, 6

batch_predictions
tensor([[10.5608, 11.1901, 14.3263, 20.6636, 20.0335,  9.9502, 11.9970],
        [24.2979, 15.7543, 15.0933, 14.4382, 14.8251, 21.6581, 34.4554],
        [ 7.3573,  8.2796,  8.7634,  9.7297, 13.8053, 20.2644,  8.0413],
        [26.8610, 21.0160, 13.3483, 12.8201, 12.0181, 13.9983, 20.4826],
        [19.8231, 27.0908, 24.9372,  9.5260, 13.5903, 14.2269, 14.7179],
        [ 8.6284,  8.5957,  9.9085, 15.9125, 19.2643,  7.8699, 10.4982],
        [ 8.8314, 10.0283,  6.7874,  5.8247,  6.8931,  7.4966,  8.2058],
        [40.9014, 39.9580, 41.2682, 45.4451, 49.7848, 48.9083, 41.8437]])
ground truth
tensor([[ 8.7868, 13.9598, 18.4807, 27.9904, 47.6332, 21.8112, 19.8413],
        [28.3447, 19.4444, 15.6463, 13.3503, 17.5028, 19.4870, 34.3537],
        [ 3.8664, 11.5071, 10.0473, 10.1262, 13.4140,  7.0095,  6.8911],
        [29.0675, 37.2449, 25.0709,  7.4688, 14.2007, 15.9580, 21.0317],
        [19.7562, 31.4201, 43.3673, 18.0130, 17.6446, 13.8605, 13.1378],
        [ 6.7728,  

batch_predictions
tensor([[12.0150, 12.1223, 13.0513, 17.0232, 25.5188, 24.6564, 11.9925],
        [20.3391, 10.6859, 10.9551, 10.8877, 11.7398, 17.0116, 25.0213],
        [10.0253, 12.2328, 16.1913, 15.2420,  7.4547,  8.9063,  9.7846],
        [10.6573, 11.7669, 13.1543, 17.9648, 24.4580, 19.5889,  9.9085],
        [ 8.5132, 10.7761, 14.8466, 19.4359, 16.0824,  6.4446,  8.4179],
        [14.7962, 17.3408, 23.7967, 25.2853, 13.2084, 14.6613, 14.9451],
        [32.9596, 35.7744, 38.6546, 39.2469, 31.2524, 29.9711, 30.2969],
        [15.1366, 15.2850, 17.2789, 22.6871, 31.1018, 17.8382, 13.9462]])
ground truth
tensor([[8.6796e+00, 1.0784e+01, 1.6070e+01, 2.1660e+01, 2.8998e+01, 2.9340e+01,
         1.4243e+01],
        [2.3317e+01, 1.2099e+01, 1.2362e+01, 8.1273e+00, 9.4161e+00, 1.4650e+01,
         2.7130e+01],
        [1.0757e+01, 1.2375e+01, 1.5045e+01, 1.6754e+01, 6.2467e+00, 1.5400e+01,
         1.5729e+01],
        [1.0984e+01, 1.8226e+01, 1.1961e+01, 1.6823e+01, 2.5128e+01, 2.6899

batch_predictions
tensor([[19.1339, 32.7045, 30.4395, 10.1893, 16.7981, 15.8294, 15.3846],
        [13.7577, 14.5105, 15.0210, 25.1207, 39.0768, 30.4589, 16.0139],
        [19.7716, 13.0960, 12.6290, 11.8852, 12.6958, 21.4789, 29.9116],
        [27.9955, 16.3451, 15.5899, 14.2862, 16.1051, 28.2253, 40.6838],
        [16.4529, 16.4412, 15.6251, 14.7788, 17.0797, 27.4811, 34.3218],
        [15.4846, 23.2953, 18.3828,  9.8361,  9.3634,  9.4939, 10.3523],
        [14.9962, 14.8282, 18.2481, 27.7211, 27.2317, 11.5851, 16.0956],
        [14.0537, 18.6896, 27.0495, 23.9807, 10.0937, 11.6280, 12.0686]])
ground truth
tensor([[29.7194, 23.5686, 36.9898,  6.6043,  4.1667, 19.2602, 22.3498],
        [14.0023, 14.7392, 15.0368, 21.3152, 43.0414, 33.1207, 16.6950],
        [31.1083, 18.7783, 15.0652, 18.2398, 16.3549, 24.7307, 26.0204],
        [34.5663, 29.7052, 22.5907, 14.9943, 16.0714, 27.9620, 34.4529],
        [13.7165, 11.9542, 21.9227, 16.4256, 18.9242, 39.9001, 36.8753],
        [18.6349, 2

batch_predictions
tensor([[29.6865, 18.2128,  9.3164, 11.6455, 12.6483, 16.3727, 25.5607],
        [17.1765, 12.6448, 12.4912, 12.2256, 14.2924, 20.2721, 25.3170],
        [17.7375, 12.4624, 10.9342, 12.0772, 13.4019, 17.3047, 20.7970],
        [12.2648, 13.8732, 13.5119, 15.5497, 23.0309, 31.7035, 22.7679],
        [11.2019, 13.3217, 21.4893, 28.5257, 19.7590, 10.3718, 13.8902],
        [10.8309, 10.9192, 12.4258, 18.8365, 21.9930, 18.8602,  9.9750],
        [14.1147, 13.8974, 15.8633, 29.6154, 39.4129, 21.3637, 15.1013],
        [17.7495, 17.2953, 14.9473, 15.6830, 24.6613, 38.8292, 33.4081]])
ground truth
tensor([[34.0136, 22.8316,  4.8469, 13.2370, 12.5283, 12.2024, 24.0079],
        [26.3414, 17.7933, 22.5802, 16.8201, 14.6239, 15.8075, 26.0389],
        [29.2741, 29.0110,  8.7585,  9.8501, 19.1741, 18.4771, 18.7927],
        [12.3093,  9.3109, 10.0473, 14.7817, 21.5939, 42.0305, 30.3919],
        [ 8.5034, 10.4450, 18.9342, 28.8265, 21.2018,  9.9206,  9.1553],
        [14.0164, 1

batch_predictions
tensor([[13.5241,  8.7385,  7.3652,  7.9316,  8.7475,  8.2781, 10.9879],
        [15.7482, 15.2460, 14.2597, 17.6828, 25.9196, 29.7505, 22.4969],
        [22.7296, 21.2700, 21.5306, 22.5597, 24.7880, 33.0756, 39.0556],
        [12.5450, 12.1531, 13.3562, 18.7445, 25.5617, 23.7155, 13.1027],
        [28.3798,  9.2082, 14.6691, 13.3894, 13.6734, 22.5402, 33.7782],
        [13.1657, 13.7985, 18.3512, 22.6196, 21.1961, 10.3395, 13.1637],
        [32.0886, 27.4457, 14.8273, 14.7852, 15.3118, 16.7194, 24.8492],
        [12.1465, 13.2691, 16.5577, 23.9880, 23.2052,  8.8102, 13.4649]])
ground truth
tensor([[12.6381,  7.5355,  4.7080,  4.7344,  5.0631,  4.9448,  8.1010],
        [14.5125, 15.1219, 16.3549, 15.3345, 22.6899, 38.6196,  8.4184],
        [18.3456, 21.1073, 22.9616, 22.3304, 27.7223, 30.8390, 40.4655],
        [15.1502, 11.0261, 12.7551, 15.1644, 22.8175, 26.2613, 14.1723],
        [45.8759, 11.9189, 18.8634,  3.6706, 17.7721, 20.7200, 29.7619],
        [11.5505,  

batch_predictions
tensor([[23.3330, 22.5134, 10.3925, 11.5814, 12.0247, 12.2536, 17.2805],
        [10.6901, 12.1396, 11.7480,  7.5149,  7.2774, 11.2391, 14.3057],
        [12.3854, 15.6356, 21.2442, 26.4882, 24.8699,  8.4425, 13.1913],
        [16.1134, 15.6506, 16.1012, 28.1813, 42.9250, 33.9961, 13.4885],
        [12.4600, 14.9540, 12.9261, 13.1776, 20.6109, 32.3511, 29.6451],
        [13.9992, 16.9753, 22.5724, 26.4049, 14.1074, 12.9931, 14.7744],
        [21.3840, 13.7287,  8.5495,  8.8112,  8.8510,  9.7404, 14.5055],
        [27.3372, 11.9689, 14.4843, 13.4146, 14.1208, 21.6405, 31.9081]])
ground truth
tensor([[22.8432, 23.2772, 12.1252, 12.6249, 14.0715, 14.9395, 20.5550],
        [19.3188, 17.5171, 25.8285, 16.5965, 10.4287,  9.4555, 23.4219],
        [10.3458, 13.8605, 13.2653, 30.3288,  4.9745,  5.2438, 12.3583],
        [17.3611, 10.0057, 18.8776, 31.7035, 45.2523, 31.4768, 19.2744],
        [10.2041, 13.8605, 12.7976, 13.5346, 24.1497, 39.2149, 29.1808],
        [16.5958, 1

batch_predictions
tensor([[15.1650,  5.7920,  6.5022,  7.2420,  7.9068,  9.2323, 13.4235],
        [25.9146, 11.5582, 13.9653, 12.8694, 13.4546, 23.2454, 32.9174],
        [18.4954, 25.2680, 24.4288, 13.3233, 13.9097, 14.8674, 16.1288],
        [10.2476, 10.9441, 12.9544, 18.0521, 21.2849, 10.4010,  9.3479],
        [23.6172, 13.6826, 13.0996, 12.5437, 12.1508, 14.8802, 23.3993],
        [12.8397, 17.4989, 16.1178,  7.8157,  8.0935,  8.8143,  9.8050],
        [13.3120, 18.6141, 20.8420,  7.7845,  7.0731,  9.9766, 11.8225],
        [32.9904, 35.7325, 17.5271, 20.2476, 21.0753, 21.2816, 26.0138]])
ground truth
tensor([[18.9637,  2.8932,  8.1931,  4.9842,  8.4561,  9.5476, 15.4261],
        [29.1808, 13.1094, 14.3141, 11.2245, 10.7568, 23.9371, 30.6689],
        [23.8690, 30.7601, 28.3535, 18.5560, 15.0316, 13.0195, 16.5308],
        [12.8222, 11.9411, 11.2572, 15.1105, 16.7806,  8.0615,  3.7217],
        [27.5935, 17.2619, 14.0731, 11.9048, 14.2857, 19.1043, 27.6502],
        [13.3614, 2

batch_predictions
tensor([[15.3932, 14.5812, 15.2704, 17.3643, 17.4261, 15.3581, 15.6072],
        [21.5943, 23.6130, 31.3725, 42.1468, 35.7916, 17.8227, 22.4060],
        [22.4534, 13.8801, 11.8438, 12.7681, 12.6487, 13.7026, 18.5714],
        [ 8.1739, 10.5900, 10.3226, 12.4458, 18.7647, 20.7743,  8.2995],
        [11.7707, 11.4338, 13.0974, 19.9633, 25.1890, 20.3923, 11.4187],
        [13.8035, 20.9260, 27.0460, 21.2293, 10.9393, 12.1626, 12.8470],
        [ 8.7391, 11.2614, 16.6330, 14.6474,  6.4477,  6.5560,  8.8714],
        [11.5313, 11.4780, 11.6502, 18.2889, 28.7117, 23.1852, 12.3608]])
ground truth
tensor([[11.0261, 15.0227, 16.4541, 21.7829, 24.3622, 17.4178, 17.5170],
        [23.6111, 25.1417, 19.1893, 41.2557, 40.3345, 14.7817, 23.7528],
        [26.0389, 27.6433, 16.0179, 14.2820, 11.3493, 14.1373, 16.9122],
        [ 3.7875,  2.3935, 13.2693, 13.2562, 14.7817, 17.4382, 13.7296],
        [13.1378, 11.7772, 14.3141, 21.2302, 22.0805, 23.0726, 12.1032],
        [14.6967, 1

batch_predictions
tensor([[18.4147, 20.1782, 23.8403, 28.7211, 30.8647, 20.2350, 22.0893],
        [22.0300, 11.9085, 13.0314, 12.1481, 11.8892, 18.9673, 26.9775],
        [31.5489, 18.2552, 19.1917, 18.6997, 21.7547, 32.3310, 36.6172],
        [11.8626, 12.1896, 14.6009, 21.4566, 23.4828, 15.0388, 13.1313],
        [12.5422, 13.4446, 18.7538, 28.0246, 30.0129, 11.8679, 16.5034],
        [11.7976, 11.8435, 12.4901, 17.0218, 24.0067, 24.6638, 12.3406],
        [33.0148, 44.4635, 39.3673, 24.6525, 23.0220, 21.9875, 23.8346],
        [14.6784, 17.3697, 24.6814, 29.8870, 12.4739, 15.2795, 16.6750]])
ground truth
tensor([[20.4103, 23.3561, 26.1573, 32.8643, 35.7312, 23.0800, 25.5655],
        [26.1054, 14.0731, 13.0811, 12.6276, 13.6054, 20.0964, 27.3810],
        [35.6293, 19.4303, 24.2205, 16.6525, 18.5516, 32.1003, 49.7874],
        [13.9532, 17.5829, 15.5050, 26.3940, 26.4729, 15.8469, 14.0847],
        [ 6.8648, 12.6775, 19.9500, 30.7601, 33.6533, 12.3225, 15.1236],
        [11.6386,  

batch_predictions
tensor([[12.4125, 17.4991, 26.0767, 23.3414,  9.2424, 11.8754, 12.4247],
        [16.1695, 16.6271, 20.3966, 24.0853, 26.5479,  9.0788, 16.3947],
        [11.9291, 12.1364, 14.9218, 20.7036, 22.6188, 14.3940, 13.8199],
        [16.4593, 22.9346, 22.5230,  9.4052, 11.8713, 13.0907, 12.9973],
        [34.0526, 29.1021, 14.0470, 15.8898, 16.1299, 16.8141, 24.7508],
        [35.5606, 45.1999, 26.7902, 22.6250, 22.8208, 22.8315, 24.8159],
        [10.3613, 13.9305, 13.8472, 18.1348, 26.7811, 27.6963, 23.1500],
        [ 7.4907,  7.2297,  7.2574,  7.5022,  8.9095, 12.6825, 16.5466]])
ground truth
tensor([[12.3948, 12.3027, 24.9737,  7.7196,  6.7464,  9.1005, 14.4792],
        [16.2698, 18.4949, 25.2551, 36.7772, 40.4195, 10.4734, 20.7766],
        [22.1514, 20.8050, 18.5941, 32.1145, 28.7273, 17.7721,  7.3129],
        [20.2656,  2.2883, 32.4040,  8.9953, 13.6376, 15.9653, 12.8090],
        [23.4836, 26.4881, 17.5312, 17.0918, 17.5737, 15.4337, 23.0584],
        [42.5454, 6

batch_predictions
tensor([[13.1174, 19.9550, 23.2390,  7.7173, 10.4318, 12.5537, 12.5259],
        [20.3226, 22.2550, 11.7814, 10.8172, 10.5905, 11.6755, 15.2021],
        [10.2148, 11.7560, 14.2229, 19.3054, 26.1226, 21.1052,  9.6674],
        [ 9.3426, 11.4134, 15.5085, 13.5739,  7.5707,  7.4408,  8.0153],
        [ 8.8111, 11.7825, 16.9722, 18.7440,  6.7327,  7.2803,  8.7125],
        [10.1472, 10.5525, 11.4404, 14.3318, 18.3670, 19.6966,  8.9160],
        [17.3889,  9.4423,  9.9622,  9.4888, 10.0408, 13.1258, 21.9589],
        [21.3202, 12.7166,  9.4677, 10.1436, 10.3818, 11.2042, 18.2925]])
ground truth
tensor([[16.3335,  9.9421, 22.9879,  3.5245,  4.7870, 13.9663,  8.4561],
        [26.3545, 28.5245, 15.0053, 10.6391, 11.2178, 14.5450, 17.7144],
        [ 1.4172, 19.8696, 20.0113, 23.3985, 33.6026, 34.4388,  8.4467],
        [11.4150,  7.2462,  0.8417,  0.0000,  0.0000,  5.3393,  8.8243],
        [ 9.5608, 10.1131, 16.1494, 19.9763,  4.0242,  9.1136,  1.5518],
        [11.7570, 1

batch_predictions
tensor([[ 9.9713, 10.6681, 11.2921, 14.1819, 20.4355, 19.0062,  8.5778],
        [14.9189, 14.7321, 15.0075, 19.6636, 28.1075, 26.0841, 12.5965],
        [18.3640, 25.7110, 23.4554, 10.5368, 11.5439, 12.6912, 13.8411],
        [22.7862, 17.3930, 10.5652,  9.9844,  9.5314, 10.3224, 17.9308],
        [12.6915, 11.8859, 14.2987, 20.7720, 19.2637, 10.3670, 13.2465],
        [12.4731, 15.3233, 20.1010, 19.5141, 10.7010, 11.3681, 12.7551],
        [30.8140, 27.3467, 12.8528, 14.9547, 15.3748, 15.2913, 20.5265],
        [25.5574, 13.2909, 14.5466, 14.7857, 16.9821, 22.8833, 26.5515]])
ground truth
tensor([[11.4808, 13.6113, 14.3872, 12.8748, 20.3314, 26.4729, 10.5997],
        [16.4519, 18.2667, 16.6360, 24.4740, 36.4545, 34.9947, 14.7422],
        [27.5776, 22.8564, 17.5171, 12.3882, 16.0179, 14.1636, 11.6649],
        [27.5250, 26.7359, 12.6907,  8.3509, 11.1783,  7.3251, 13.9795],
        [14.2715, 11.8481, 17.0918, 26.7715, 21.8112, 11.5788, 12.0181],
        [11.3493, 1

tensor([[ 5.4280,  5.8957,  9.7931, 10.6859,  4.4218,  4.7619,  5.0170],
        [10.1394, 12.8222, 20.0289,  6.6675,  2.8538, 16.2283, 11.1389],
        [64.6633, 47.9090, 42.1620, 46.8306, 41.8595, 42.5171, 47.6460],
        [15.5971, 11.5992,  9.5871, 12.2041, 13.3219, 18.3325, 22.9748],
        [17.3753, 32.1854, 55.2721, 40.0652, 16.5816, 21.7545, 18.0981],
        [10.6786, 18.2930, 30.1815,  6.8648,  3.1694, 11.0863, 11.2704],
        [21.5561, 21.5703, 12.0748, 18.2115, 25.6519, 49.6173, 35.8560],
        [18.5034, 27.3803, 27.6039, 16.6360, 13.2167, 13.9400, 12.5855]])
batch_predictions
tensor([[18.2775, 19.6105, 10.1351, 10.0853, 10.3667, 10.8533, 14.2444],
        [14.6576, 14.6861, 13.2949, 12.9373, 16.0659, 25.4941, 27.8438],
        [15.7911, 15.4256, 15.5786, 17.9416, 28.2790, 37.5498, 23.5459],
        [14.9376, 17.7014, 21.0600, 23.7473, 13.0420, 11.2732, 13.4013],
        [14.4125, 18.2742, 25.7983, 26.8947,  8.5559, 13.0289, 13.6705],
        [17.2828, 24.3421, 31.23

batch_predictions
tensor([[18.5116, 28.2275, 27.1779, 12.3471, 14.6886, 13.9974, 14.1656],
        [12.3656, 11.6750, 15.1613, 15.7733, 15.8980, 14.9902, 15.6507],
        [16.0567, 26.1848, 36.3466, 25.8633, 13.4724, 14.8382, 14.1439],
        [12.2424, 15.0073, 13.0626, 13.3609, 18.7910, 29.5765, 32.0742],
        [11.2059, 12.3789, 11.6574, 11.4996, 15.6896, 23.9793, 22.2591],
        [13.1895, 18.1098, 22.5925,  9.1229, 11.2213, 12.7608, 13.2039],
        [13.6652, 18.5356, 23.6500, 19.2005, 11.9720, 12.8656, 13.1048],
        [20.0953, 25.0285, 24.0061, 12.1774, 13.9758, 13.9817, 14.7006]])
ground truth
tensor([[15.4478, 27.2817, 33.2766, 16.6383, 18.5941, 15.5329, 15.7596],
        [ 7.1936,  0.0000,  0.4208,  3.5376,  7.2988, 21.2914, 20.2262],
        [15.7738, 22.6049, 38.6905, 38.2937, 22.5624, 21.9529, 20.8617],
        [ 5.6264, 13.0102, 18.4240, 17.8146, 20.8333, 29.2800, 31.6752],
        [10.6786,  7.4171,  9.6660, 10.7049, 18.8059, 30.3130, 29.4319],
        [13.5718, 2

batch_predictions
tensor([[28.5121, 22.9877, 12.7160, 13.8046, 13.8533, 15.0582, 20.0732],
        [14.0122, 14.5712, 18.8290, 27.0737, 25.9595,  9.7325, 13.8313],
        [13.3813, 14.9075, 13.4822, 14.8421, 23.0262, 33.3249, 28.0510],
        [12.4676, 12.5830, 18.9280, 25.2683, 14.3790, 10.3885, 12.7256],
        [15.0337, 15.6451, 18.9370, 27.7054, 25.1766, 17.9234, 18.0291],
        [12.4508, 16.8132, 21.8284, 18.7609,  8.4564, 10.3864, 12.4233],
        [14.1456, 14.0288, 14.4526, 19.0928, 31.2938, 27.1923, 14.1370],
        [12.6848, 14.9215, 12.7856, 12.8914, 22.3808, 33.9379, 29.0314]])
ground truth
tensor([[38.5346, 27.3951, 20.6207, 17.5737, 17.6304, 15.7880, 21.6695],
        [ 0.3260,  3.7415, 17.3186, 32.7664, 28.9541,  1.9133, 10.9552],
        [10.7851, 12.2591,  9.7506, 20.8192, 39.6684, 26.3039, 27.1967],
        [15.8075, 13.7691, 24.1583, 23.1589, 21.1731, 11.4019, 11.1520],
        [11.4087, 14.5550, 22.9450, 30.3571, 30.1304, 15.9580, 15.0652],
        [15.2157, 1

batch_predictions
tensor([[17.2814, 22.3442, 31.2978, 26.5848, 13.2990, 16.8133, 17.1005],
        [14.6203, 14.4408, 14.8779, 28.9796, 41.7981, 35.0967, 13.1220],
        [ 9.2834,  9.6863, 10.4555, 12.5961, 17.2032, 23.6062, 13.8801],
        [18.8999, 18.2036, 17.9410, 21.7389, 28.3738, 28.4685, 24.0054],
        [30.9767, 37.5503, 18.3349, 16.7494, 18.0591, 16.9831, 17.6857],
        [13.7699, 13.0709, 13.3145, 16.6408, 22.9405, 24.3528, 16.1798],
        [13.5471, 19.0060, 23.7525, 23.0544,  8.8714, 12.1502, 13.7953],
        [18.6847, 21.2674,  9.1897,  9.6945,  9.9258, 10.9975, 12.6987]])
ground truth
tensor([[21.6202, 35.9679, 23.1326, 38.2956, 18.4114, 10.2052, 23.7507],
        [13.9663, 16.8069, 16.4387, 34.2451, 58.3772, 50.4866, 15.0579],
        [ 0.0000,  7.4263, 11.4371, 15.9014, 18.8634, 31.3917, 30.8532],
        [58.0499, 28.7415, 21.7971, 21.5561, 28.3872, 40.6321, 22.9734],
        [28.9541, 51.4031, 36.6071, 13.9598, 16.2415, 18.7500, 14.0873],
        [14.8080, 1

batch_predictions
tensor([[17.7921, 19.4390, 25.3847, 33.7424, 24.3505, 15.3915, 18.8537],
        [13.5371, 11.9257, 11.4120, 11.1967, 12.1258, 20.1659, 25.9096],
        [19.3540, 22.5074,  8.2163, 10.1732, 11.3290, 11.2630, 13.0514],
        [ 8.7200, 10.1954, 11.6737, 11.3532, 13.4182, 16.6659, 16.2765],
        [14.2930, 19.9634, 25.2681, 15.1190,  8.9990, 10.6431, 11.5062],
        [ 9.2494,  9.1388,  6.2940,  5.7430,  5.8992,  6.2600,  7.0971],
        [14.1767, 13.8663, 14.1705, 24.0885, 41.6458, 32.0480, 14.5712],
        [21.5503, 33.7003, 25.8020, 13.4837, 15.6212, 14.7165, 15.6588]])
ground truth
tensor([[18.6650, 28.2455, 32.3696, 46.3152, 37.1457, 30.0595, 18.8492],
        [24.0400,  8.3772, 10.0342, 10.8364, 10.3761, 18.0168, 25.1578],
        [ 0.1578, 28.2614,  9.9421, 11.2441, 12.0857,  8.1010, 12.1120],
        [ 5.9179, 11.5860, 11.9805, 13.7033, 21.3703, 17.1489, 15.3077],
        [ 9.0216, 17.7012, 24.5529, 24.1057,  8.5613,  5.7075, 13.9006],
        [ 8.4892, 1

batch_predictions
tensor([[36.3432, 31.0003, 14.9232, 16.7085, 16.7786, 17.4528, 24.8123],
        [17.1244, 17.9654, 23.6540, 34.4093, 31.8745, 15.2884, 20.1508],
        [14.8995, 17.6378, 22.2290, 21.1496, 13.4965, 12.6076, 13.5191],
        [ 9.4348, 10.7033, 12.4621, 16.6880, 20.5015,  6.8109,  8.1776],
        [15.9558, 23.6001, 30.3226, 12.5210, 13.3223, 15.0943, 15.0383],
        [ 7.6785,  7.8317,  6.6095,  6.5261,  7.5238,  7.2724,  7.7126],
        [27.5624, 12.5227, 14.0663, 13.5520, 13.2544, 17.1572, 25.4078],
        [20.9084, 24.6389,  8.8692,  9.7432, 11.2783, 11.6545, 12.8084]])
ground truth
tensor([[44.5720, 39.8526, 19.5011, 21.8963, 17.2619, 20.1389, 22.7183],
        [16.5108, 20.6633, 26.4172, 46.9104, 36.2670, 14.5266, 20.0397],
        [17.7144, 15.9653, 21.6202, 26.0652, 13.2956, 15.5839, 16.5045],
        [13.9663,  8.4561, 10.3630, 11.4545, 20.8574,  4.6160,  7.9037],
        [15.3077, 27.4198, 40.8469, 29.3530, 10.1657, 12.5592, 18.5955],
        [13.2430, 1

batch_predictions
tensor([[28.7204, 14.0190, 16.1628, 14.9998, 14.8555, 20.8009, 35.8226],
        [18.5294, 17.9578, 17.5096, 24.5012, 39.2656, 32.4413, 20.2836],
        [14.0077, 15.5432, 12.9557, 12.7224, 19.8335, 31.5420, 33.8897],
        [13.3933, 15.9665, 14.3499, 14.3056, 18.8525, 27.2978, 29.2864],
        [10.2558, 15.9452, 20.0744, 10.5278,  7.8938,  7.9963,  7.9136],
        [17.1610, 27.3651, 26.6042, 11.1640, 13.2244, 14.4305, 13.2551],
        [24.9306, 27.5359, 22.5620, 13.3221, 13.1837, 12.9524, 15.1541],
        [24.6230,  7.5317, 12.5077, 12.7475, 12.6389, 16.6753, 24.2529]])
ground truth
tensor([[38.4354, 20.3373, 16.5816, 13.9598, 14.3141, 23.0300, 45.5782],
        [18.2930, 17.1489, 21.3703, 33.8901, 52.8669, 33.0221, 19.1215],
        [16.0968, 15.1368,  8.1405, 11.0994, 20.3446, 35.7969, 41.6623],
        [11.0994, 13.5455, 14.6896, 15.8601, 16.1362, 25.6312, 30.2735],
        [15.4655, 22.1331, 23.2641,  9.7317, 13.5587, 10.4156, 13.3614],
        [17.3186, 3

batch_predictions
tensor([[24.2930, 18.4737, 11.4501, 11.9047, 11.1551, 12.4295, 19.6543],
        [11.0218, 10.6309, 10.0923,  9.8277, 11.8422, 19.7924, 21.2898],
        [17.6959, 10.6376, 11.9473, 11.5482, 12.3515, 14.8402, 19.8572],
        [37.9932, 29.2849, 18.2229, 18.2351, 17.0887, 16.9042, 25.2594],
        [16.1891, 17.7382, 16.2055, 16.4360, 23.6075, 37.1980, 40.1122],
        [13.1943, 19.1334, 24.7649, 23.6808, 11.3620, 12.3640, 13.0812],
        [13.2010, 16.7165, 21.1553, 24.2218, 12.9499, 11.2247, 13.6821],
        [ 9.8544, 13.7298, 18.7860, 15.2435,  6.7972,  7.1363,  8.2742]])
ground truth
tensor([[24.1780, 26.8707, 11.0969, 14.5975,  9.8073,  7.9790, 17.0635],
        [ 2.3672,  5.1946,  9.1662, 12.0726, 16.6097, 23.8690, 22.6591],
        [35.4450,  3.7415, 12.6701, 10.0340, 12.2166, 12.4433, 22.6332],
        [45.5074, 37.4291, 15.3203, 15.5896, 18.8776, 23.0017, 34.6088],
        [21.1026, 15.2920, 16.2557, 13.7613, 16.7092, 35.0198, 40.9580],
        [11.8359, 1

batch_predictions
tensor([[25.0154, 37.8152, 30.0676, 12.7813, 16.8646, 15.5839, 15.8547],
        [21.7714, 32.2256, 28.3453, 19.4589, 16.8996, 16.4368, 15.5281],
        [24.4613, 12.4221, 12.9733, 12.6297, 13.3124, 19.8593, 27.5301],
        [11.4071, 14.5148, 21.3672, 26.1933,  9.1065, 12.1711, 10.3057],
        [ 7.6483,  7.6453,  7.5828,  7.4292,  8.6146,  9.8110, 13.5433],
        [15.2414, 22.9346, 24.8615, 13.3604, 12.3123, 12.4920, 12.1424],
        [38.4249, 14.0706, 17.3899, 16.7142, 16.9453, 34.0080, 47.9889],
        [ 7.4824,  8.7811,  9.4395, 12.4691, 18.7365, 18.8832,  7.6585]])
ground truth
tensor([[23.9229, 46.4286, 42.1910, 18.3815, 17.5454, 17.4036, 18.4382],
        [16.9218, 45.1105, 25.7795,  0.0000,  9.5096, 17.7721, 18.1264],
        [21.8254, 13.8747, 14.8384, 11.1536, 16.6383, 23.2426, 26.3605],
        [17.5960, 15.7286,  0.0000, 25.1841,  4.3530, 13.1904,  2.3277],
        [ 6.5760,  6.0232,  5.2154,  4.2800,  6.8736,  9.6088,  1.9700],
        [17.1910, 2

batch_predictions
tensor([[ 6.1562,  6.2914,  6.6878,  7.2934,  7.2899,  6.6297,  6.5317],
        [19.0303, 21.5068, 21.1353, 17.6278, 12.0433, 13.1282, 14.2808],
        [32.2234, 32.3039, 18.9059, 19.1471, 19.1907, 18.7603, 22.6626],
        [26.6420, 11.2021, 12.9619, 11.5834, 11.4090, 13.7096, 22.8659],
        [ 7.8889,  8.6508, 10.4183, 15.4846, 17.9990,  6.4163,  7.7386],
        [ 6.3600,  6.6078,  6.8681,  7.5764,  8.3267,  7.8561,  6.5533],
        [32.8806, 38.2315, 23.5839, 15.6326, 16.6486, 15.8411, 18.1166],
        [10.6949, 13.2046, 20.9188, 20.3390,  7.8651,  8.9608, 10.5456]])
ground truth
tensor([[ 6.1650,  6.3492,  8.5317,  9.8073,  7.1570,  4.0675,  6.4059],
        [18.4377, 22.4882, 29.8527, 27.7223, 15.4918, 12.9274, 15.1631],
        [23.3298, 40.5839, 28.4193, 22.5802, 19.2530, 23.0274, 26.1967],
        [23.2378,  5.9442,  8.7980,  0.9469, 18.5692, 12.1120, 21.0547],
        [ 1.5518,  3.2746, 10.7969, 14.6107, 21.8438,  4.2609,  7.0884],
        [ 6.4059,  

batch_predictions
tensor([[11.8271, 15.5177, 17.1061,  8.0092,  7.9763, 10.0012, 11.8039],
        [12.9742, 13.4377, 13.6355, 18.3814, 27.2037, 25.9659, 13.6034],
        [11.6598,  7.1356,  6.7904,  6.8614,  6.7300,  7.6733,  9.7994],
        [13.2431, 16.1362, 18.3032, 20.2019, 26.5673, 25.4616, 13.0977],
        [13.9302, 13.6780, 14.0301, 18.8225, 26.8840, 22.3059, 14.3311],
        [19.7642, 21.6303, 11.6302, 10.2280, 11.5156, 15.2883, 19.3898],
        [39.2040, 39.6392, 42.5666, 47.6323, 47.8589, 46.3924, 44.5607],
        [14.7603, 13.8423, 13.6082, 17.4677, 25.5807, 24.7945, 15.2731]])
ground truth
tensor([[17.1094, 15.3077, 14.9790,  6.7859,  8.9690, 10.0736, 10.5734],
        [20.3515, 15.4195, 16.6100, 27.1259, 30.2579, 35.2749, 12.1032],
        [ 1.9700,  4.1383,  5.1587,  5.7115,  6.0516,  7.0437, 11.2103],
        [18.3588, 10.1131, 19.7791, 19.8185, 26.5255, 29.7212,  4.2346],
        [19.4897, 12.3093, 14.4266, 20.7522, 26.3940, 23.3561, 13.5324],
        [22.3698, 3

batch_predictions
tensor([[15.5344, 14.4086,  7.5856,  7.6552,  8.2001,  8.9719, 13.5641],
        [16.3470, 26.0466, 24.4064,  8.7617, 11.7060, 12.1871, 12.1788],
        [22.3713, 34.3907, 31.8331, 16.0516, 17.4520, 17.0413, 16.0284],
        [15.2387, 19.1074, 26.9512, 20.8976, 10.4109, 13.1951, 13.0762],
        [23.6991, 17.9980, 16.8971, 16.2322, 20.7378, 38.8115, 47.0252],
        [13.8667,  5.7917,  6.0732,  7.1901,  9.0059, 12.3289, 16.2735],
        [11.6864, 12.2096, 11.8078, 13.1173, 24.1732, 30.5411, 18.7328],
        [13.6406, 14.1815, 14.5114, 16.0074, 21.0633, 22.4034, 10.8450]])
ground truth
tensor([[19.1478, 14.7028,  5.7075,  9.3635,  6.1415,  9.0216, 11.9937],
        [18.5429, 27.4987, 28.8795, 11.9279, 16.9253, 13.9006, 13.3877],
        [15.4478, 38.8322, 39.6542, 21.0601, 17.1627, 12.7268, 10.0198],
        [15.3770, 17.4036, 27.3384, 26.0062, 10.7710, 13.0669, 11.3804],
        [32.1712, 18.7358, 18.7075, 15.5329, 24.6032, 31.5334, 45.6491],
        [20.1078,  

tensor(36.1484)

In [34]:
class TransformerVanilla(Module):
    
    
    def __init__(self, d_model=64, nhead=4, num_encoder_layers=3, num_decoder_layers=3,
                dim_feedforward=512, dropout=0.1, activation=F.gelu, use_norm=True):
        
        super(TransformerVanilla, self).__init__()
        self.transformer_model = TransformerModel(d_model=d_model, nhead=nhead, num_encoder_layers=num_encoder_layers,
                                                  num_decoder_layers=num_decoder_layers, dim_feedforward=dim_feedforward,
                                                  dropout=dropout, activation=activation, use_norm=use_norm)
        
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.input_projection1 = nn.Linear(1, d_model)
        self.input_projection2 = nn.Linear(d_model, d_model)
        self.output_projection = nn.Linear(d_model, 1)
        
    def forward(self, encoder_x, decoder_x, src_mask=None, tgt_mask=None, memory_mask=None):
        """ encoder_x should be vector of size input_seq_len x bs x 1   
            decoder_x should be vector of size output_seq_len x bs x 1
        """
        encoder_x = self.pos_encoder(self.input_projection2(self.input_projection1(encoder_x)))
        decoder_x = self.pos_encoder(self.input_projection2(self.input_projection1(decoder_x)))
        x = self.transformer_model(encoder_x, decoder_x, 
                                   src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=memory_mask)
        x = self.output_projection(x)
        return x

    
        
        
        

In [40]:
model_v0 = TransformerVanilla(d_model=64, nhead=nhead, num_encoder_layers=2, num_decoder_layers=2,
                dim_feedforward=256, dropout=0.1, activation=F.gelu)
model_v0.load_state_dict(torch.load('/home/mbaroody/Programming/Time series/NN5/transformer_vanilla_robustnull_v2-0.pt'))

<All keys matched successfully>

In [41]:
evaluate_model(model_v0, test_dataloader)

batch_predictions
tensor([[11.4240, 11.7356, 12.1022, 12.5022, 12.9190, 13.3243, 13.6408],
        [30.6787, 31.0758, 31.0345, 30.9378, 30.8638, 30.8237, 30.8084],
        [24.8778, 28.7262, 28.5320, 27.3244, 25.2732, 22.0480, 19.4626],
        [17.6461, 22.5739, 21.3124, 19.0141, 16.4467, 15.1427, 15.2202],
        [17.6647, 22.7957, 23.0727, 22.1033, 20.5371, 18.5566, 17.1777],
        [23.0011, 20.9399, 21.2400, 21.9991, 22.6732, 23.2022, 23.6021],
        [16.7230, 19.8401, 19.2410, 18.3349, 17.5219, 16.9240, 16.5867],
        [23.4345, 30.0252, 29.5433, 27.4890, 23.3897, 18.2426, 18.3207]])
ground truth
tensor([[13.5718, 20.6865,  9.6791,  6.8780,  6.9306,  8.3246, 11.8096],
        [24.9079, 29.5634, 34.5739, 45.1604, 18.3456, 21.1073, 22.9616],
        [15.0085, 27.8770, 36.7772, 37.0890, 21.1026, 15.2920, 16.2557],
        [18.0414, 15.5612, 12.9960, 16.9218, 23.0867, 32.1854, 14.1865],
        [21.3010, 35.1332, 33.7302, 13.0244, 12.7976, 12.6842, 13.4070],
        [30.8107, 2

batch_predictions
tensor([[20.8067, 29.1103, 29.0137, 27.0325, 22.3041, 17.1467, 17.3533],
        [20.4119, 21.0035, 20.8544, 20.5631, 20.3071, 20.1406, 20.0661],
        [18.7933, 19.9496, 19.6146, 19.0784, 18.6161, 18.3035, 18.1501],
        [25.6364, 26.1793, 26.1036, 25.8755, 25.6469, 25.4806, 25.3933],
        [20.4224, 21.5751, 21.3344, 20.8683, 20.4482, 20.1543, 20.0001],
        [16.5220, 15.1254, 15.5129, 16.2160, 16.8156, 17.2842, 17.6458],
        [21.0843, 20.0166, 19.7710, 19.9387, 20.2758, 20.6641, 21.0583],
        [24.3506, 28.5611, 28.0166, 26.7180, 24.9353, 22.3779, 20.2739]])
ground truth
tensor([[ 0.0000,  0.0000, 14.8384, 35.4025, 33.6593,  8.3617, 17.2761],
        [13.0527, 15.3345, 29.7761, 26.7999, 15.3345, 18.1122, 14.6117],
        [13.9795, 11.9411, 15.5576, 25.6575, 24.8816, 12.1778, 10.2841],
        [39.5125,  6.8736,  7.3838, 20.8192, 11.5930, 15.8022, 20.1531],
        [13.4271, 15.5445, 28.0247, 14.8211, 32.1278, 22.5013, 23.4482],
        [10.6786,  

batch_predictions
tensor([[19.3798, 19.5794, 19.5151, 19.3745, 19.2441, 19.1614, 19.1317],
        [13.2030, 12.7601, 12.5901, 12.5640, 12.6394, 12.7661, 12.9356],
        [ 5.0716,  4.3449,  4.3207,  4.3548,  4.3802,  4.3972,  4.4094],
        [21.3347, 22.3822, 22.1543, 21.6611, 21.1784, 20.8125, 20.5986],
        [19.8892, 19.7796, 19.8053, 19.8418, 19.8704, 19.8978, 19.9288],
        [23.1622, 25.5485, 25.1959, 24.4085, 23.5533, 22.7522, 22.1007],
        [30.2627, 27.7034, 27.1310, 27.5492, 28.2811, 29.0354, 29.7161],
        [17.9935, 23.7863, 22.9267, 20.7469, 17.5588, 16.0552, 16.3224]])
ground truth
tensor([[22.5013, 20.5024, 10.5734,  8.8769,  8.6796, 14.2951, 18.6612],
        [ 3.4850,  7.9695,  6.3914,  9.6002, 13.9269, 20.3577, 15.1105],
        [ 5.2721,  7.7239,  7.5113,  7.8656,  8.9569, 14.9660,  9.0703],
        [16.6754, 19.8054, 20.7391, 25.1578, 26.9726, 13.5192, 17.7012],
        [31.5623, 28.9453, 12.6381, 16.4256, 16.9385, 16.7806, 18.3719],
        [19.7279, 2

batch_predictions
tensor([[20.2508, 27.6320, 26.7556, 23.8845, 18.5791, 16.6841, 16.9938],
        [38.2789, 38.2719, 38.2730, 38.2732, 38.2730, 38.2728, 38.2728],
        [27.2057, 27.4831, 27.5000, 27.4149, 27.2873, 27.1776, 27.1157],
        [33.5748, 33.1687, 33.1489, 33.2084, 33.2730, 33.3416, 33.4157],
        [13.6617, 17.3827, 21.8854, 20.6916, 18.5013, 16.0529, 14.6150],
        [27.2129, 23.6230, 24.6123, 25.9462, 27.1088, 28.0112, 28.6583],
        [19.7465, 22.5189, 21.8739, 20.7657, 19.5445, 18.3660, 17.5782],
        [14.6181, 18.3109, 21.6524, 21.6582, 20.8906, 19.9369, 19.0848]])
ground truth
tensor([[16.5308, 18.2272, 17.1357, 25.2367, 37.9143, 25.9732, 12.2699],
        [58.4824, 60.6786, 56.6149, 65.8469, 52.4066, 44.9763, 55.1683],
        [13.8605, 17.7296, 13.3645, 15.6179, 40.0652, 53.6706, 24.0930],
        [29.7761, 38.5913, 18.5232, 19.3594, 23.0584, 26.9558, 38.6480],
        [10.5865,  9.5476, 10.0736, 18.5429, 27.4987, 28.8795, 11.9279],
        [14.3566, 4

batch_predictions
tensor([[26.4438, 34.9929, 37.1497, 36.4311, 34.6646, 31.0433, 25.2181],
        [23.2849, 26.6935, 26.1826, 24.8922, 23.0697, 20.8058, 19.2951],
        [13.6317, 12.8158, 13.3054, 14.0125, 14.6529, 15.0863, 15.3150],
        [13.4826, 17.2232, 20.0377, 20.9034, 20.6201, 20.0474, 19.5135],
        [31.1340, 29.8622, 29.4667, 29.5090, 29.7636, 30.1602, 30.6656],
        [20.8650, 21.4058, 21.2585, 20.9730, 20.7182, 20.5501, 20.4744],
        [19.4218, 18.3223, 18.1901, 18.5944, 19.1965, 19.8561, 20.5217],
        [19.0463, 22.0281, 20.9729, 19.3279, 17.3585, 15.7995, 15.5283]])
ground truth
tensor([[33.3640, 53.8532, 47.3567, 22.7117, 20.9232, 16.7938, 20.8443],
        [ 4.3226, 37.1457, 34.5947, 14.6117, 19.7279, 14.3991, 19.0760],
        [13.5850,  9.9816,  9.6923, 10.8890, 11.6781, 16.1494, 23.4219],
        [ 0.0000, 12.9819, 25.3118,  7.7098,  9.3396, 12.4858, 18.4240],
        [ 8.9164, 16.4256, 14.6370, 31.2599, 58.8901, 47.5802, 12.2830],
        [19.3027, 2

batch_predictions
tensor([[15.3520, 19.6411, 22.9015, 24.0992, 23.7341, 22.8954, 21.9650],
        [22.0974, 20.8243, 21.2551, 22.2344, 23.3537, 24.5108, 25.6248],
        [28.8209, 32.0102, 32.0948, 31.3066, 30.1416, 28.7573, 27.1906],
        [20.9436, 22.1412, 21.8427, 21.2592, 20.6872, 20.2513, 19.9894],
        [14.0910, 18.3234, 21.4162, 22.2700, 21.8505, 21.1432, 20.4847],
        [ 7.0016,  5.8869,  5.6891,  5.6995,  5.7297,  5.7515,  5.7671],
        [19.4592, 25.6227, 24.3653, 21.2222, 16.9391, 16.0209, 16.2694],
        [ 9.4287, 10.4196, 11.0885, 11.7087, 12.4295, 13.4694, 14.9995]])
ground truth
tensor([[ 0.0000,  5.1871, 12.9819, 28.8549, 33.4042, 16.5391, 10.2324],
        [13.0244, 12.7976, 12.6842, 13.4070, 23.0300, 35.2749, 32.9932],
        [26.6176, 51.0389, 45.8969, 22.3435, 16.8858, 16.3335, 20.9100],
        [13.9172, 12.9393, 15.3486, 18.7500, 29.6769, 30.7256, 16.7375],
        [11.3520, 12.0748, 16.2698, 17.8430, 34.0278, 31.5193, 12.6559],
        [ 8.7454,  

batch_predictions
tensor([[17.1527, 18.4297, 18.1940, 17.5330, 16.8511, 16.3089, 15.9699],
        [19.7329, 18.3243, 18.1707, 18.6189, 19.1712, 19.6893, 20.1480],
        [25.6944, 21.6685, 22.4947, 23.9591, 25.4001, 26.6402, 27.5830],
        [23.9560, 26.4983, 26.1450, 25.0248, 23.3990, 21.3961, 19.7698],
        [18.6392, 18.8674, 18.7205, 18.5370, 18.4049, 18.3384, 18.3247],
        [19.4334, 18.9493, 18.9385, 19.0759, 19.2311, 19.3722, 19.4951],
        [ 8.1858,  7.4357,  7.0766,  6.9519,  6.9209,  6.9217,  6.9255],
        [15.6443, 16.3370, 16.1960, 15.8621, 15.5711, 15.3847, 15.2991]])
ground truth
tensor([[16.5176, 11.2835, 13.5981, 17.0700, 20.7391, 17.7538, 11.2704],
        [31.4201, 30.3288, 15.9155, 12.9393, 16.5958, 15.7738, 21.3719],
        [38.2937, 22.5624, 21.9529, 20.8617, 18.8634, 24.3339, 44.2035],
        [ 0.2126, 12.0607, 20.5215, 20.4649, 31.9586, 38.6480, 41.0856],
        [12.3299, 17.6871, 21.5278,  7.2988,  3.4014, 14.8951, 12.7268],
        [24.4477, 1

batch_predictions
tensor([[17.2799, 17.1374, 17.1755, 17.2315, 17.2764, 17.3148, 17.3515],
        [16.2620, 19.6003, 20.0637, 19.5898, 19.1043, 18.7501, 18.5548],
        [29.3880, 36.7929, 37.0023, 36.2366, 35.0455, 33.3625, 30.7034],
        [23.6399, 25.1890, 25.0703, 24.4014, 23.5089, 22.5819, 21.7796],
        [14.4288, 19.3182, 22.6821, 21.3827, 19.0719, 15.7228, 14.3834],
        [17.6772, 18.6039, 18.4321, 18.0174, 17.6378, 17.3786, 17.2481],
        [15.2253, 18.8464, 19.3073, 18.4459, 17.4856, 16.6264, 15.9554],
        [17.1788, 21.2622, 20.1404, 18.2762, 16.2816, 14.8871, 14.6179]])
ground truth
tensor([[16.9385, 10.1657,  8.7454, 11.4808, 12.8617, 14.6107, 22.5934],
        [ 1.5255, 15.8206, 14.7422, 18.7533, 21.9227, 24.4477, 11.9542],
        [19.9369, 18.8848, 33.9427, 55.6418, 14.1899,  0.0000, 15.4787],
        [14.2432, 19.1185, 21.0317, 34.8639, 28.4722, 30.2863, 21.1593],
        [10.6293, 15.9439, 18.1689, 17.0635, 27.6219, 35.3175, 30.0737],
        [28.7132, 2

batch_predictions
tensor([[15.6108, 19.2451, 19.9290, 19.0974, 18.1402, 17.2737, 16.6055],
        [15.2633, 16.3018, 16.3842, 16.0913, 15.7623, 15.5114, 15.3563],
        [24.6761, 25.5426, 25.4089, 25.0453, 24.6760, 24.3937, 24.2281],
        [19.7211, 18.1080, 18.6380, 19.2863, 19.7901, 20.1605, 20.4302],
        [15.9947, 20.2612, 20.8492, 19.8768, 18.5969, 17.2353, 16.1449],
        [19.4270, 25.6100, 27.6598, 26.7499, 24.6152, 20.2984, 17.7197],
        [20.8541, 20.0912, 19.9490, 20.1067, 20.3898, 20.7205, 21.0598],
        [15.0917, 14.3273, 14.7493, 15.4311, 16.0211, 16.4626, 16.7569]])
ground truth
tensor([[11.5505, 14.0023, 16.9926, 26.7574, 25.0142, 15.5896, 13.0244],
        [14.1504, 13.3482, 13.1247, 13.6244, 15.4524, 14.1110,  9.6265],
        [17.4382, 19.1215, 30.2735, 42.0831, 36.7438, 16.9516, 15.5313],
        [36.6649, 20.5813, 19.4371, 15.0973, 16.3861, 13.9795, 18.4640],
        [11.7205, 12.4433, 22.0663, 24.0788, 25.8362,  9.4388, 10.5017],
        [19.2177, 3

batch_predictions
tensor([[21.2548, 24.0753, 23.6073, 22.5923, 21.4187, 20.2427, 19.3531],
        [19.9009, 26.0243, 25.1132, 22.6525, 18.8014, 16.6239, 16.8296],
        [24.8347, 21.0106, 21.4130, 22.7276, 24.0982, 25.3517, 26.4105],
        [27.1112, 27.9536, 27.9293, 27.6544, 27.3325, 27.0644, 26.8918],
        [21.3247, 22.7181, 22.3555, 21.6888, 21.0196, 20.4658, 20.0928],
        [22.0988, 29.6901, 28.8956, 26.8147, 22.9607, 18.3877, 18.4513],
        [38.2744, 38.1435, 38.1539, 38.1547, 38.1517, 38.1481, 38.1448],
        [13.2000, 16.4314, 20.3551, 19.3159, 17.7921, 16.3679, 15.3953]])
ground truth
tensor([[13.9795, 18.4640, 31.9963, 13.4403, 14.4398, 11.9805,  3.9716],
        [10.0624, 17.0068, 31.5760, 42.7863, 30.5981, 20.9042, 20.1247],
        [42.6871, 34.0703, 16.8226, 19.0760, 14.9943, 17.0918, 26.9133],
        [19.5160, 19.3319, 27.8538, 32.5750, 23.2246, 21.4229, 14.7685],
        [19.6995, 23.7103, 30.5414, 26.9416, 15.0652, 14.6967, 14.2007],
        [16.1494, 1

batch_predictions
tensor([[25.9251, 21.7233, 22.4939, 23.9396, 25.3931, 26.6868, 27.7104],
        [ 7.3609,  7.9312,  8.0958,  8.1452,  8.1570,  8.1617,  8.1686],
        [15.0373, 18.1964, 20.8717, 19.9071, 18.1455, 16.0313, 14.4908],
        [18.2398, 18.5614, 18.4625, 18.2821, 18.1329, 18.0451, 18.0134],
        [27.1629, 25.4714, 24.8741, 24.9204, 25.3491, 25.9850, 26.7386],
        [ 5.8350,  4.5750,  4.4819,  4.5222,  4.5577,  4.5828,  4.6011],
        [18.7618, 17.1061, 17.9106, 18.6621, 19.2290, 19.6320, 19.9118],
        [24.6438, 21.7762, 21.5058, 22.4119, 23.5081, 24.5797, 25.5829]])
ground truth
tensor([[4.2942e+00, 0.0000e+00, 0.0000e+00, 4.9745e+00, 3.5445e+01, 4.5196e+01,
         3.3759e+01],
        [8.1405e+00, 8.7059e+00, 1.0692e+01, 1.8674e+01, 1.3309e+01, 8.1799e+00,
         8.1405e+00],
        [2.8345e-02, 4.9036e+00, 8.7018e+00, 1.1026e+01, 2.0408e+01, 3.7217e+01,
         2.8160e+01],
        [1.3322e+01, 1.2296e+01, 1.0639e+01, 1.5952e+01, 2.3895e+01, 2.7196

batch_predictions
tensor([[18.8607, 18.1582, 18.1244, 18.3138, 18.5291, 18.7191, 18.8781],
        [ 7.2748,  7.5131,  7.5688,  7.5718,  7.5666,  7.5512,  7.5444],
        [20.5408, 23.4222, 22.5709, 20.9658, 18.7948, 16.8855, 16.4606],
        [28.9861, 30.1215, 29.9538, 29.6543, 29.4142, 29.2657, 29.1969],
        [12.2439, 11.2444, 10.1094,  8.4436,  7.0426,  6.5646,  6.4660],
        [16.6159, 18.8133, 18.7679, 18.2473, 17.7548, 17.4021, 17.1986],
        [26.1263, 29.9498, 29.6511, 28.4731, 26.6581, 23.9770, 21.3514],
        [19.2836, 22.4605, 21.6300, 20.1574, 18.3835, 16.7673, 16.2164]])
ground truth
tensor([[ 4.5371,  0.0000,  8.0878,  7.8511,  9.2188,  7.8248, 12.4408],
        [15.2288, 14.4924, 19.3188, 17.5171, 25.8285, 16.5965, 10.4287],
        [21.0884, 23.6961, 28.4014, 13.4921, 14.0023, 14.5550, 18.2115],
        [26.1573, 32.8643, 35.7312, 23.0800, 25.5655, 26.4466, 30.6023],
        [17.5829,  5.5234,  7.1936,  8.4298,  5.8390, 13.3614,  6.4703],
        [10.8233, 2

batch_predictions
tensor([[15.1943, 18.6542, 19.0010, 18.3992, 17.8123, 17.3839, 17.1397],
        [23.2005, 21.6141, 21.3306, 21.6768, 22.1984, 22.7233, 23.2036],
        [20.4069, 24.3000, 23.3523, 21.3868, 18.5645, 16.5751, 16.5549],
        [20.3493, 25.1718, 24.1246, 22.0446, 19.0894, 16.8915, 16.9748],
        [ 5.4589,  5.7603,  5.7946,  5.7949,  5.7951,  5.7922,  5.7929],
        [11.8832, 10.7767, 10.2740, 10.0266,  9.9437,  9.9763, 10.0824],
        [ 9.5470,  8.3123,  7.6198,  7.4051,  7.3646,  7.3708,  7.3819],
        [20.1798, 27.1976, 30.5113, 29.6586, 27.1249, 20.8222, 17.6650]])
ground truth
tensor([[17.1357, 17.3724, 18.3062, 25.9337,  5.7864,  4.6949, 11.0731],
        [16.3204, 14.5713, 11.7175, 15.6891, 25.7496, 33.9164, 29.0373],
        [15.5576, 17.5302, 25.6970, 20.2525, 29.4450,  4.5239,  7.2199],
        [11.5788, 13.2511, 19.8696, 28.7840, 25.1276, 16.8651, 14.0590],
        [ 4.9842,  6.4703,  6.2599, 11.5334, 15.2946, 18.3982,  5.5892],
        [ 5.8107,  

batch_predictions
tensor([[10.8255, 11.9853, 13.7212, 16.5380, 17.1180, 16.2713, 15.3551],
        [22.7367, 28.6629, 27.7890, 25.7246, 22.2148, 18.4730, 18.4681],
        [29.5708, 33.9164, 33.8849, 33.0081, 31.8828, 30.7254, 29.5242],
        [36.4580, 35.4834, 35.4032, 35.5471, 35.7206, 35.8801, 36.0181],
        [14.0313, 13.3185, 14.1026, 14.9619, 15.5450, 15.8574, 16.0187],
        [17.8319, 18.6483, 18.3855, 17.9385, 17.5544, 17.3044, 17.1920],
        [17.9869, 19.4306, 18.8141, 17.8862, 16.9917, 16.2864, 15.8806],
        [16.7252, 18.5659, 18.3343, 17.7733, 17.2759, 16.9284, 16.7372]])
ground truth
tensor([[13.5587, 10.4156, 13.3614, 18.2799, 27.3540, 36.9148,  7.8774],
        [15.4195, 23.0867, 37.2874, 24.0363, 28.5573,  9.2687,  9.6372],
        [21.2868, 23.1151, 35.6718, 45.7625, 36.7630,  8.0782, 17.1485],
        [67.6446, 59.4388, 35.5726, 34.4246, 21.9104, 25.3260, 37.4717],
        [20.4892, 12.4540, 14.1767, 13.7033, 11.2704, 15.3998, 23.4876],
        [21.8112, 2

batch_predictions
tensor([[20.2149, 20.7252, 20.4993, 20.1435, 19.8405, 19.6482, 19.5702],
        [16.7060, 21.2331, 20.3563, 18.5878, 16.4995, 15.0120, 14.6723],
        [26.7670, 28.4983, 28.3214, 27.8062, 27.2697, 26.8273, 26.5247],
        [17.9058, 19.0414, 18.7361, 18.1526, 17.6025, 17.1947, 16.9632],
        [14.8616, 14.0131, 14.1070, 14.5678, 15.0871, 15.5637, 15.9646],
        [22.4590, 31.3689, 32.0633, 30.1094, 25.6557, 18.6213, 18.3333],
        [24.8788, 25.1725, 25.1277, 24.9786, 24.8253, 24.7165, 24.6668],
        [16.3571, 20.5046, 19.6482, 18.1105, 16.4134, 15.0682, 14.5176]])
ground truth
tensor([[17.7275, 13.7428, 18.5429, 24.8948, 29.6028, 11.0994, 13.5455],
        [16.8226, 25.1276, 26.8991, 12.1457, 12.7693, 19.1468, 17.4603],
        [28.9683, 35.3175, 15.4337, 21.7120, 18.8776, 20.3231, 28.2738],
        [19.8271, 25.2126, 24.9717,  0.2409,  7.2562,  9.3396, 12.2591],
        [ 8.9427,  9.0084, 10.2183,  9.4161, 11.7701, 20.5550, 21.0416],
        [20.6491, 1

batch_predictions
tensor([[18.1553, 16.3175, 17.7689, 19.1681, 20.1381, 20.7703, 21.2094],
        [19.9754, 23.8702, 23.1049, 21.4205, 19.1397, 17.1245, 16.7495],
        [16.1648, 19.3266, 19.3556, 18.6347, 17.8744, 17.2514, 16.8228],
        [25.7888, 34.6035, 35.0384, 33.4712, 30.4533, 24.7767, 20.1396],
        [18.4083, 20.4321, 19.9672, 19.2247, 18.5420, 18.0136, 17.6813],
        [17.9544, 19.1159, 18.9127, 18.4937, 18.1352, 17.9013, 17.7890],
        [37.4878, 34.6606, 34.2748, 34.8539, 35.5938, 36.1944, 36.6045],
        [18.6773, 18.8973, 18.8345, 18.7175, 18.6217, 18.5672, 18.5495]])
ground truth
tensor([[ 3.8407,  8.8010, 10.3033,  5.9949, 11.0686, 16.6241, 30.0595],
        [15.7943, 18.1352, 23.6586, 39.3083, 29.1426, 19.5292, 14.7028],
        [12.2699, 14.5055, 20.3446, 28.6560, 33.0221, 12.0857, 15.4261],
        [20.5924, 25.5527, 33.3900, 51.8707, 41.5958, 21.7262, 21.3861],
        [13.6196, 17.4461, 17.9138, 17.0777, 20.8333, 21.5561, 16.9359],
        [19.5949, 2

batch_predictions
tensor([[15.9593, 21.2266, 22.0312, 20.5763, 18.0836, 15.2404, 14.5034],
        [11.7178, 14.5391, 18.6895, 20.4682, 19.1685, 17.3488, 15.3572],
        [ 8.2082,  9.0539,  9.4565,  9.6953,  9.8528,  9.9722, 10.0764],
        [ 8.2045,  8.1169,  8.0768,  8.0617,  8.0584,  8.0536,  8.0516],
        [13.4041, 12.5578, 12.5320, 12.9815, 13.7197, 14.5532, 15.2948],
        [16.9728, 16.7527, 16.8767, 17.0603, 17.2186, 17.3376, 17.4222],
        [22.1942, 20.3552, 20.1005, 20.6343, 21.3261, 21.9938, 22.6002],
        [28.1682, 29.5556, 29.5494, 29.1609, 28.6718, 28.2180, 27.8676]])
ground truth
tensor([[17.8713, 23.9512, 31.2075, 22.7183,  7.0153, 15.6604, 10.5017],
        [ 8.9427, 10.7143, 10.8277, 22.9308, 30.3571, 18.2540,  9.0561],
        [ 7.1410,  5.9442, 11.3624, 15.5839, 21.3440, 16.7806,  4.0505],
        [12.0068, 15.9127,  3.9716,  6.4308,  5.7207,  6.5755,  7.8906],
        [ 2.9053,  3.3305,  5.7115,  5.8673, 13.2937, 13.6054, 20.5357],
        [18.4508, 1

batch_predictions
tensor([[24.8884, 33.9287, 34.8712, 32.7491, 27.4905, 19.0618, 19.0285],
        [23.8060, 23.2749, 23.2040, 23.3121, 23.4766, 23.6618, 23.8506],
        [23.4889, 20.2084, 21.6232, 23.1483, 24.5571, 25.7381, 26.5794],
        [23.6664, 22.0476, 21.8092, 22.2120, 22.7724, 23.3166, 23.7962],
        [20.9359, 21.5223, 21.3747, 21.0793, 20.8138, 20.6371, 20.5554],
        [23.0082, 20.4049, 21.4898, 22.6425, 23.5905, 24.2996, 24.7917],
        [15.9430, 20.9022, 22.7954, 21.4512, 19.0766, 15.9559, 15.0075],
        [22.2390, 22.2434, 22.2520, 22.2456, 22.2486, 22.2359, 22.2069]])
ground truth
tensor([[1.4953e+01, 3.5902e+01, 6.9726e+01, 5.3353e+01, 1.8924e+01, 1.8280e+01,
         1.2178e+01],
        [4.0136e+01, 3.8435e+01, 2.0337e+01, 1.6582e+01, 1.3960e+01, 1.4314e+01,
         2.3030e+01],
        [1.4527e+01, 1.5944e+01, 1.9388e+01, 1.7517e+01, 3.3702e+01, 2.3484e+01,
         2.6488e+01],
        [1.5958e+01, 1.5065e+01, 1.9615e+01, 2.1259e+01, 2.2548e+01, 2.8316

batch_predictions
tensor([[15.1063, 14.2979, 14.3596, 14.8781, 15.6016, 16.4241, 17.2606],
        [27.1479, 29.9719, 29.7539, 28.9313, 27.8637, 26.6000, 25.1358],
        [19.0668, 19.2437, 19.1760, 19.0322, 18.8975, 18.8113, 18.7806],
        [15.9492, 19.1437, 19.0379, 17.9826, 16.7617, 15.6178, 14.7204],
        [18.0351, 21.0283, 20.1490, 18.7046, 17.0646, 15.5770, 14.9896],
        [20.8892, 21.2297, 21.1393, 20.9576, 20.7975, 20.6971, 20.6579],
        [18.9836, 17.5567, 17.5862, 18.0721, 18.5367, 18.9211, 19.2304],
        [17.9911, 22.9309, 22.9979, 21.7967, 19.8537, 17.6361, 16.4820]])
ground truth
tensor([[20.4235,  5.1420,  9.3766,  6.7728,  7.4566,  9.9027, 18.7138],
        [16.8464, 18.4903, 31.6018, 33.0615, 14.8606, 14.1373, 16.1494],
        [32.4688, 28.7698,  5.5272, 14.2574, 16.7659, 14.9660, 18.4524],
        [13.8605, 33.4467, 20.6774,  5.7115,  5.3713, 13.6338, 15.8872],
        [18.1406, 14.8668, 13.8180, 25.6803, 26.7007, 10.1616, 12.2732],
        [25.5811, 2

batch_predictions
tensor([[12.2548, 11.1616,  9.9219,  8.3022,  7.1732,  6.8101,  6.7338],
        [13.1411, 16.1304, 19.7057, 19.1843, 18.2264, 17.3402, 16.7370],
        [19.3193, 25.7519, 24.6590, 21.8516, 17.7030, 16.4318, 16.7548],
        [12.7033, 12.1837, 11.8394, 11.5437, 11.2266, 10.8144, 10.2026],
        [20.5288, 24.7878, 23.9056, 22.4399, 20.6815, 18.7808, 17.8257],
        [33.2853, 38.3237, 38.2090, 38.0564, 37.9296, 37.8361, 37.7747],
        [17.3599, 19.6532, 18.9557, 17.9705, 17.0504, 16.2942, 15.7884],
        [ 6.5223,  6.7429,  6.7675,  6.7634,  6.7602,  6.7528,  6.7500]])
ground truth
tensor([[12.8617,  4.3924,  7.4040,  8.0747,  7.3514, 11.3887, 15.9258],
        [14.2149, 12.8260, 13.8889, 18.3390, 31.4059, 24.7024,  5.2721],
        [12.7976, 12.6842, 13.4070, 23.0300, 35.2749, 32.9932, 14.9943],
        [ 8.7059,  8.1931,  8.5481, 10.1262, 11.1783, 19.6213, 14.9790],
        [19.3452, 14.2999, 18.0556, 19.4303, 31.5334,  9.8498, 14.8101],
        [28.7415, 3

batch_predictions
tensor([[18.7461, 19.2708, 19.0764, 18.6981, 18.3436, 18.0978, 17.9801],
        [17.1364, 17.6736, 17.4006, 17.0127, 16.7087, 16.5347, 16.4784],
        [24.8359, 20.9529, 22.2527, 24.0385, 25.7869, 27.2295, 28.2124],
        [18.0464, 22.7420, 21.3367, 19.1112, 16.7030, 15.3522, 15.4690],
        [22.8281, 24.6409, 24.5180, 23.6973, 22.5110, 21.1888, 20.0627],
        [27.7262, 32.1499, 32.1630, 30.9289, 28.8199, 25.3919, 21.7886],
        [ 9.7546, 10.6795, 11.4743, 12.3870, 13.6888, 15.5211, 16.5184],
        [18.2151, 23.9656, 24.6830, 23.4574, 21.0129, 17.5999, 16.5439]])
ground truth
tensor([[27.4660, 30.2438,  9.8073, 13.1094,  9.5522, 16.1706, 17.0493],
        [17.4382, 13.7296,  6.6149, 11.2572,  8.9032, 10.9548, 14.7159],
        [36.0119, 10.1616, 15.6604, 16.4541, 19.0051, 26.2188, 38.4921],
        [15.6497, 23.0011, 27.6565, 12.9274, 15.9258, 10.5339, 16.1362],
        [16.8651, 15.1786, 15.3061, 29.9603, 41.1139, 32.4830, 23.7528],
        [14.1723, 2

batch_predictions
tensor([[17.0130, 15.8211, 15.9093, 16.3916, 16.8479, 17.2221, 17.5194],
        [31.4065, 34.1951, 34.0309, 33.4558, 32.9112, 32.5081, 32.2584],
        [18.9551, 19.8217, 19.5672, 19.0912, 18.6558, 18.3515, 18.1969],
        [21.9830, 28.9033, 28.0246, 25.8569, 22.1443, 18.5321, 18.6248],
        [26.9264, 32.9463, 33.3060, 31.7809, 28.8561, 23.6566, 20.3733],
        [25.2796, 28.5073, 28.2335, 27.2259, 25.8367, 24.1390, 22.4732],
        [ 9.9273, 11.2373, 12.7448, 15.3082, 18.5934, 17.6961, 16.5118],
        [18.9028, 24.4604, 24.1060, 22.8444, 20.8997, 18.8063, 17.9007]])
ground truth
tensor([[31.0941, 27.3384, 10.6718, 13.1378, 11.7772, 14.3141, 21.2302],
        [22.3304, 27.7223, 30.8390, 40.4655, 32.8774, 23.9085, 26.1441],
        [14.1636, 11.6649, 20.3314, 29.6686, 30.1289, 14.5581, 11.0994],
        [14.3282, 25.6236, 46.3861, 13.1661,  0.0000, 10.7710, 18.4666],
        [18.5516, 25.9354, 29.1241, 54.7902, 53.2738, 58.0499, 28.7415],
        [44.0760, 3

batch_predictions
tensor([[18.3247, 17.1371, 16.9541, 17.2694, 17.7119, 18.1550, 18.5695],
        [14.8444, 19.1162, 20.4893, 18.9406, 16.7427, 14.3381, 13.3264],
        [16.0739, 16.7543, 16.5470, 16.1915, 15.9043, 15.7333, 15.6674],
        [15.9764, 20.6389, 21.1682, 19.6705, 17.2750, 14.8082, 14.1119],
        [22.7300, 24.3469, 24.0523, 23.3664, 22.6222, 21.9602, 21.4676],
        [13.7338, 17.8663, 21.1095, 19.6322, 17.5287, 15.1713, 13.9050],
        [ 7.1861,  6.3314,  6.1435,  6.1318,  6.1479,  6.1602,  6.1693],
        [12.8392, 16.8135, 21.3030, 21.7910, 20.9238, 19.7022, 18.4634]])
ground truth
tensor([[29.1525, 26.7857, 11.2103, 13.3362, 14.5833, 13.8180, 21.3152],
        [22.9308, 30.3571, 18.2540,  9.0561, 10.2041,  7.0578,  9.8781],
        [11.2704, 15.3998, 23.4876, 19.9895,  7.6013, 11.0074,  9.2057],
        [ 9.9421, 11.4940, 13.5850, 27.6433, 22.6065, 10.0079,  8.8769],
        [19.5862, 28.6990, 22.8600, 27.6361, 13.1944, 13.5062, 13.9172],
        [15.3472, 2

batch_predictions
tensor([[ 8.7352,  8.8285,  8.8703,  8.8781,  8.8691,  8.8479,  8.8334],
        [24.8066, 23.7117, 23.4965, 23.6964, 24.0656, 24.4857, 24.8972],
        [ 7.2602,  6.5110,  6.2542,  6.1978,  6.1998,  6.2096,  6.2180],
        [21.3566, 22.3383, 22.1030, 21.4924, 20.8030, 20.2003, 19.7825],
        [ 4.9944,  4.4212,  4.3936,  4.4205,  4.4412,  4.4548,  4.4648],
        [34.5057, 32.4293, 32.0163, 32.2925, 32.7998, 33.3477, 33.8538],
        [15.1030, 14.7427, 15.8834, 16.8398, 17.5253, 17.9387, 18.1523],
        [35.0936, 38.2103, 38.1003, 37.9316, 37.7988, 37.7062, 37.6484]])
ground truth
tensor([[ 7.9300,  5.9442,  8.6928, 10.6786, 17.0568,  6.3519,  5.8916],
        [15.7549, 19.2267, 12.5986, 16.4519, 18.3193, 32.4698, 31.8517],
        [15.6365, 16.6491,  9.4819,  6.2730,  7.1015, 11.1783, 10.8101],
        [17.8146, 20.8333, 29.2800, 31.6752,  8.4325, 17.2477, 13.5488],
        [ 6.4626,  8.5034,  6.1933,  4.5210,  5.4989, 10.0057,  7.4830],
        [49.2914,  

batch_predictions
tensor([[15.2399, 18.8910, 19.6144, 18.7061, 17.6897, 16.7763, 16.0656],
        [27.0201, 22.5465, 24.0059, 25.8798, 27.6914, 29.2450, 30.3964],
        [ 7.1870,  7.5873,  7.6760,  7.6840,  7.6766,  7.6621,  7.6557],
        [21.8052, 23.6945, 23.3382, 22.5948, 21.8112, 21.1202, 20.6129],
        [26.8072, 28.2371, 28.3031, 27.8651, 27.1821, 26.4067, 25.6494],
        [17.4484, 18.3586, 18.0878, 17.6305, 17.2435, 16.9942, 16.8827],
        [19.5419, 19.1166, 19.1083, 19.2164, 19.3322, 19.4361, 19.5280],
        [22.7478, 25.4519, 25.0483, 24.0680, 22.8780, 21.6143, 20.5573]])
ground truth
tensor([[ 7.4546, 16.6241, 35.5300, 35.5159, 15.8730, 15.5612, 11.4938],
        [41.7234, 16.6950,  7.6531, 21.3719, 21.3577, 29.6769, 48.4977],
        [ 2.6502, 14.3566, 15.0368, 13.7046, 14.6825, 20.9467,  2.2251],
        [22.0147, 20.1999, 25.2893, 42.6092, 21.2914, 21.0153, 25.0789],
        [39.5408, 46.9671, 42.4603, 19.1327, 19.2035, 17.8146, 19.9688],
        [19.5423, 1

batch_predictions
tensor([[17.6806, 22.4899, 21.5630, 19.8460, 17.8369, 16.4539, 16.1856],
        [14.3158, 17.7971, 20.3349, 19.5515, 18.6653, 17.9678, 17.5413],
        [21.8190, 22.2093, 22.1194, 21.9204, 21.7380, 21.6184, 21.5675],
        [24.3610, 28.8110, 28.4237, 27.0203, 24.7449, 21.3408, 19.3511],
        [22.9173, 28.1158, 27.5422, 25.5236, 21.6698, 17.8483, 17.8816],
        [ 7.4004,  7.4515,  7.4642,  7.4636,  7.4589,  7.4417,  7.4328],
        [19.3982, 21.9593, 21.4461, 20.6284, 19.8493, 19.2055, 18.7567],
        [24.3403, 24.2754, 24.2831, 24.2932, 24.2948, 24.3023, 24.3197]])
ground truth
tensor([[11.5079, 14.4983, 19.8696, 37.8260, 30.1304,  4.2234, 14.5975],
        [10.6151, 14.6542, 13.0952, 14.3141, 21.3861, 23.0017, 10.7001],
        [22.4487, 18.5429, 20.8180,  8.1931,  9.4555, 16.3598, 15.2420],
        [17.1485, 15.4478, 38.8322, 39.6542, 21.0601, 17.1627, 12.7268],
        [25.6519, 37.0465, 31.4201, 17.2761, 15.6746, 12.4717, 17.9280],
        [ 8.4298,  

batch_predictions
tensor([[21.1683, 21.1112, 21.1213, 21.1289, 21.1294, 21.1356, 21.1521],
        [31.6762, 31.3416, 31.3104, 31.3573, 31.4094, 31.4714, 31.5487],
        [26.0588, 22.5108, 22.7251, 23.8781, 25.0534, 26.0812, 26.9138],
        [19.2930, 23.1545, 22.0181, 19.9206, 17.2661, 15.8458, 15.9353],
        [12.3295, 14.3770, 16.8011, 16.9347, 16.2784, 15.6470, 15.1564],
        [19.2368, 24.9361, 23.6538, 21.0151, 17.5742, 16.1750, 16.4568],
        [24.6982, 21.0721, 21.7152, 23.1053, 24.4997, 25.7338, 26.7379],
        [11.8779, 11.7361, 11.5997, 11.4617, 11.3070, 11.1118, 10.8313]])
ground truth
tensor([[19.7704, 24.2914, 36.5363, 22.4348, 24.7307, 28.9116, 23.0159],
        [49.0646, 43.1122, 23.6395, 21.4711, 15.7596, 28.6990, 52.4518],
        [26.3747, 26.0629,  2.4093, 12.2307, 23.4127, 39.5125,  6.8736],
        [11.5930, 12.9393, 16.1565, 24.5607, 27.6644,  5.0879, 12.7126],
        [16.8464, 21.1731, 29.1557,  8.9295,  3.9321, 12.8222, 11.9411],
        [18.5374, 2

batch_predictions
tensor([[34.8446, 38.1969, 38.0789, 37.9217, 37.8021, 37.7199, 37.6685],
        [23.2711, 27.8506, 27.2733, 25.4873, 22.3679, 18.8742, 18.3206],
        [28.2891, 24.0384, 24.6150, 26.0134, 27.4003, 28.5751, 29.4694],
        [26.4086, 33.1961, 33.3025, 31.6778, 28.7308, 23.0613, 19.8500],
        [23.5535, 22.2257, 21.7700, 21.8245, 22.1769, 22.6913, 23.3095],
        [20.1848, 26.0256, 30.8805, 32.9656, 33.1656, 32.7029, 32.0569],
        [21.9794, 28.7772, 28.1683, 26.4339, 23.4954, 19.5922, 18.6526],
        [17.2335, 22.2381, 22.8480, 22.0954, 20.9341, 19.4666, 18.0421]])
ground truth
tensor([[25.0992, 25.0142, 38.7755, 61.8481, 20.0822, 26.4739, 27.9195],
        [21.1599, 36.8885, 26.7491, 14.8080,  8.9295,  0.0000,  0.0000],
        [24.1497, 16.7375, 18.1548, 19.3169, 25.5952, 47.6332, 14.7534],
        [16.9076, 20.6633, 30.9949, 46.7687, 34.8356, 19.1610, 16.2840],
        [39.4558, 36.1536, 14.2290, 10.3033, 16.8934, 16.6383, 33.7443],
        [17.6020, 2

batch_predictions
tensor([[18.8288, 20.9657, 20.6227, 19.9565, 19.3371, 18.8666, 18.5780],
        [22.5694, 21.1510, 20.7414, 20.9189, 21.3782, 21.9421, 22.5432],
        [11.5014, 14.1256, 16.9310, 19.0435, 19.3064, 18.7138, 17.9550],
        [18.7501, 19.1618, 19.0381, 18.7839, 18.5547, 18.4058, 18.3417],
        [15.3305, 18.2448, 18.2010, 17.4618, 16.7347, 16.1658, 15.7833],
        [20.0817, 27.7997, 27.2615, 24.4140, 18.3214, 16.2219, 16.4760],
        [13.5122, 12.6951, 13.2946, 14.2514, 15.0910, 15.5449, 15.7415],
        [26.8156, 24.6577, 24.9345, 25.6790, 26.3397, 26.8439, 27.2121]])
ground truth
tensor([[15.4787, 29.1426, 28.4850, 29.1031, 15.2683, 16.5834, 16.1757],
        [34.1412, 13.7613, 15.1077,  9.4813, 18.4099, 19.9972, 30.0595],
        [ 0.0000, 11.7772, 10.7143, 12.3299, 13.8605, 33.4467, 20.6774],
        [13.4779, 32.8656, 22.9592, 10.4167, 15.0227, 14.1015, 15.3770],
        [13.7559, 12.4145, 16.9516, 21.5939, 21.8569,  7.2199, 13.1641],
        [14.0306, 2

batch_predictions
tensor([[19.2146, 21.9736, 21.2977, 20.2369, 19.1432, 18.1402, 17.4412],
        [10.5045, 11.3952, 12.4583, 14.0134, 15.9299, 16.0555, 15.5294],
        [22.0250, 22.3359, 22.2880, 22.1540, 22.0318, 21.9548, 21.9240],
        [16.9231, 16.5700, 16.6044, 16.7421, 16.8884, 17.0191, 17.1302],
        [21.9740, 23.2009, 23.0224, 22.5155, 21.9934, 21.5829, 21.3297],
        [22.0206, 30.3371, 29.8208, 27.5523, 22.6497, 17.7520, 18.1130],
        [ 4.6886,  4.5080,  4.5074,  4.5228,  4.5332,  4.5395,  4.5444],
        [18.9935, 19.5312, 19.2948, 18.9378, 18.6462, 18.4698, 18.4038]])
ground truth
tensor([[ 7.1147, 28.0773, 24.0268,  8.8506, 19.0689, 16.7543, 17.9116],
        [ 2.3668,  8.4042, 16.7659,  1.5590,  0.0000,  0.0000, 12.9110],
        [15.9439, 18.4240, 16.9643, 11.0261, 15.0227, 16.4541, 21.7829],
        [22.8958, 20.0947, 10.5471, 12.7433, 10.1262, 14.4792, 19.5423],
        [31.9963, 13.4403, 14.4398, 11.9805,  3.9716, 16.8201, 23.9348],
        [13.0385, 1

batch_predictions
tensor([[37.5675, 38.1942, 38.1553, 38.0978, 38.0580, 38.0344, 38.0221],
        [19.2274, 18.5275, 18.4015, 18.5805, 18.9076, 19.3124, 19.7550],
        [ 5.5766,  5.1600,  5.1147,  5.1290,  5.1435,  5.1528,  5.1597],
        [26.9380, 22.4267, 23.8706, 25.7419, 27.5740, 29.1782, 30.3900],
        [29.9489, 30.4746, 30.3929, 30.2555, 30.1538, 30.0980, 30.0770],
        [16.5635, 21.2148, 23.5035, 23.0461, 22.0389, 20.6981, 18.9732],
        [13.9983, 17.5080, 21.3400, 20.5603, 19.0255, 17.2761, 15.7272],
        [21.1458, 21.9831, 21.7807, 21.3868, 21.0287, 20.7809, 20.6557]])
ground truth
tensor([[28.8549, 22.5057, 29.1241, 37.3866, 56.7460, 50.1701, 32.6531],
        [42.0305, 30.3919, 13.3482, 12.4540, 14.3740,  8.6796,  8.7454],
        [ 5.0595,  9.1270,  4.8753,  6.2925,  6.6468,  7.3271,  7.5680],
        [37.4717, 22.8175, 18.1973, 13.1236, 11.3662, 21.6837, 47.6190],
        [30.6023, 32.1410, 36.4413, 37.3225, 20.7259, 24.7764, 25.5392],
        [14.2007, 1

batch_predictions
tensor([[29.5861, 29.0110, 28.9898, 29.1067, 29.2379, 29.3613, 29.4747],
        [23.1207, 21.3065, 21.2699, 21.8494, 22.4544, 22.9701, 23.3865],
        [15.8314, 15.5204, 15.5397, 15.6386, 15.7446, 15.8413, 15.9257],
        [17.5350, 18.0574, 17.9245, 17.6666, 17.4522, 17.3232, 17.2712],
        [19.0763, 17.5159, 18.4717, 19.2579, 19.8288, 20.2238, 20.4918],
        [17.2493, 17.6668, 17.5174, 17.2818, 17.0991, 16.9973, 16.9634],
        [16.1079, 19.8925, 22.8549, 23.3157, 22.7187, 21.8280, 20.9099],
        [22.2913, 25.7480, 25.0568, 23.5306, 21.3165, 18.8481, 17.9375]])
ground truth
tensor([[23.2246, 21.4229, 14.7685, 17.8985, 20.6996, 26.6833, 31.8648],
        [23.5544, 16.4116, 11.4654, 18.3815, 17.3469, 40.4337, 25.8078],
        [21.4144, 21.6695, 13.4495, 11.8622, 11.2387, 11.3520, 15.0935],
        [15.4195, 15.5612, 27.0408, 17.0210, 11.4654, 11.7063, 11.5930],
        [26.0771, 26.4598, 15.4620, 13.5488, 16.9076, 12.5283, 22.9734],
        [18.9506, 2

batch_predictions
tensor([[23.7501, 31.8861, 31.7155, 29.9938, 26.9005, 20.9618, 19.0418],
        [18.1543, 23.5357, 22.0972, 18.9932, 15.4802, 14.6919, 14.9032],
        [17.1783, 23.4969, 24.6446, 23.4703, 21.1833, 17.3414, 16.0544],
        [23.1006, 27.8081, 27.3185, 25.4817, 22.0483, 18.2147, 17.9075],
        [12.2042, 11.0213, 10.3003,  9.6714,  9.1224,  8.6941,  8.4036],
        [ 9.3746, 11.1505, 12.8318, 15.1806, 17.5066, 18.2437, 18.0171],
        [ 8.1459,  8.4953,  8.6009,  8.6257,  8.6211,  8.6112,  8.6060],
        [16.8434, 21.4866, 20.9953, 19.7526, 18.2186, 16.8256, 16.1055]])
ground truth
tensor([[18.3982, 16.0836, 21.0021, 31.0889, 33.9690, 17.0042, 14.9527],
        [17.4887, 17.7438, 23.4977, 10.3741,  0.0000,  7.4263, 11.4371],
        [ 9.3821, 14.3566, 18.1548, 32.7523, 32.7806, 11.5505, 12.9960],
        [16.9501, 25.6519, 37.0465, 31.4201, 17.2761, 15.6746, 12.4717],
        [ 4.7080,  7.1410,  5.6418,  6.5755, 13.2956, 12.6249, 16.7149],
        [13.2693, 1

batch_predictions
tensor([[11.4941, 14.2524, 18.4899, 20.4452, 19.1051, 17.2477, 15.1921],
        [25.3785, 23.3088, 22.9470, 23.4178, 24.1037, 24.7847, 25.3986],
        [ 5.8203,  4.8252,  4.7500,  4.7854,  4.8161,  4.8370,  4.8521],
        [17.5916, 16.4169, 17.2588, 17.9176, 18.3725, 18.6826, 18.8956],
        [14.1727, 13.2428, 13.4421, 14.0571, 14.7323, 15.3178, 15.7778],
        [36.3475, 37.6233, 37.6255, 37.4750, 37.3339, 37.2350, 37.1781],
        [14.8052, 16.4999, 16.6053, 16.1360, 15.6460, 15.2672, 15.0182],
        [15.0106, 19.3301, 21.3830, 19.6795, 16.9569, 14.2310, 13.5793]])
ground truth
tensor([[10.3741, 19.6712, 32.2279, 17.3895,  6.9728,  8.9427, 10.7143],
        [17.5170, 16.3549, 21.1593, 24.5890, 27.0550, 36.8481, 29.9461],
        [ 6.3634,  6.9728, 11.7914,  7.3271,  6.2217,  9.5380,  8.2766],
        [ 6.1941,  5.1420, 13.7691, 10.0868, 14.3346, 18.4245, 22.4882],
        [ 4.7212,  7.8248,  9.1399, 10.6128, 11.7964, 23.1983, 22.5802],
        [38.6338, 6

batch_predictions
tensor([[18.7106, 17.3172, 17.6344, 18.2346, 18.7227, 19.0884, 19.3590],
        [ 7.2551,  6.6719,  6.5052,  6.4767,  6.4824,  6.4890,  6.4942],
        [37.4695, 37.4786, 37.4774, 37.4686, 37.4621, 37.4620, 37.4613],
        [13.2110, 17.1250, 20.4099, 21.6400, 21.2771, 20.4713, 19.6558],
        [21.6533, 21.7899, 21.7632, 21.6862, 21.6133, 21.5701, 21.5600],
        [28.5454, 23.0598, 24.4425, 26.4473, 28.4803, 30.3140, 31.8197],
        [17.7034, 16.2396, 17.2139, 18.0053, 18.5421, 18.8931, 19.1251],
        [12.9224, 15.9739, 18.7108, 18.5249, 18.0726, 17.7245, 17.5335]])
ground truth
tensor([[14.2007, 13.6196, 17.4461, 17.9138, 17.0777, 20.8333, 21.5561],
        [ 9.5082, 16.8201, 17.5171,  6.7596,  7.7196,  7.3645,  8.8769],
        [56.7460, 50.1701, 32.6531, 27.1259, 22.0947, 30.4138, 45.9892],
        [10.2578, 17.0437,  6.6938, 28.2614, 26.0521, 13.5060, 12.0068],
        [20.0947, 22.4487, 18.5429, 20.8180,  8.1931,  9.4555, 16.3598],
        [36.0686, 1

batch_predictions
tensor([[15.7814, 18.7022, 19.1792, 18.5632, 17.9239, 17.4301, 17.1162],
        [12.7760, 15.5661, 18.6280, 17.8844, 16.9187, 16.1156, 15.5732],
        [28.0099, 27.6013, 27.5921, 27.6803, 27.7771, 27.8713, 27.9630],
        [17.6145, 22.4249, 21.7194, 19.9939, 17.7961, 16.2246, 15.9046],
        [ 6.5645,  4.3508,  4.1351,  4.1644,  4.2011,  4.2317,  4.2565],
        [14.6806, 14.5898, 14.6193, 14.6544, 14.6821, 14.7085, 14.7350],
        [19.9101, 21.9469, 21.5569, 20.7968, 20.0324, 19.3915, 18.9544],
        [24.2967, 25.6660, 25.5277, 24.9783, 24.3297, 23.7354, 23.2812]])
ground truth
tensor([[13.2228, 15.0935, 27.9337, 23.6820, 11.7347, 16.7234, 11.5930],
        [11.7044, 10.6391, 19.1084, 25.5129, 22.0279, 10.6786,  7.4171],
        [15.8305, 17.9989, 16.4966, 21.1451, 19.3027, 44.0760, 30.8673],
        [10.5208, 13.2299, 19.6081, 23.9611, 38.5192,  6.0100, 12.4408],
        [ 5.4280,  5.8390,  5.0028,  4.6769,  5.0312,  4.9036, 11.7772],
        [17.0777, 2

batch_predictions
tensor([[20.7389, 27.0017, 25.7390, 22.4990, 17.0775, 16.1550, 16.3775],
        [13.8902, 17.0720, 19.7733, 18.7769, 17.4797, 16.2874, 15.3131],
        [20.6421, 22.0844, 21.6807, 20.7768, 19.7152, 18.7200, 18.0284],
        [30.4850, 31.1897, 31.2677, 31.1213, 30.9020, 30.7071, 30.5798],
        [23.2973, 24.7298, 24.7169, 24.1225, 23.2322, 22.2431, 21.3458],
        [19.9827, 20.2598, 20.1327, 19.9549, 19.8191, 19.7451, 19.7244],
        [17.2857, 21.8384, 21.8143, 20.9079, 19.6902, 18.4102, 17.4136],
        [14.7906, 18.3457, 19.5532, 18.5729, 17.4630, 16.4490, 15.6343]])
ground truth
tensor([[17.2477, 22.5198, 29.9461, 26.4456, 12.4575, 12.0040, 11.3379],
        [10.9679, 13.1115, 13.3745, 19.0558, 31.7464, 33.9032, 10.8759],
        [13.9663, 15.7549, 23.8559, 21.1468, 24.2504, 10.1526, 11.3493],
        [40.6321, 22.9734, 24.4898, 17.4036, 23.2426, 29.2092, 31.1650],
        [15.4787,  9.5608, 21.7386, 28.6165, 37.9537, 11.5334, 18.1089],
        [13.8180, 1

batch_predictions
tensor([[24.7747, 33.4828, 34.0219, 31.8947, 27.1973, 19.6303, 19.3634],
        [15.8567, 17.2546, 17.0401, 16.4956, 15.9952, 15.6315, 15.4187],
        [ 6.4543,  6.2817,  6.2547,  6.2585,  6.2635,  6.2654,  6.2671],
        [17.7166, 18.4490, 18.1631, 17.6255, 17.1191, 16.7596, 16.5784],
        [32.3755, 28.6969, 28.3014, 29.1350, 30.1478, 31.0834, 31.8956],
        [37.0180, 36.8260, 36.8417, 36.8582, 36.8697, 36.8843, 36.9040],
        [29.2051, 24.1690, 25.7544, 27.6712, 29.3708, 30.8164, 32.0566],
        [16.7563, 19.9313, 19.5903, 18.7840, 17.9995, 17.3838, 16.9942]])
ground truth
tensor([[16.0431,  4.3367, 20.2239, 32.6814, 48.1151, 36.6497, 19.9405],
        [10.0482, 17.1485, 20.3515, 25.9495,  9.1553,  8.1774, 10.4734],
        [ 7.7948, 12.3583,  9.8356,  8.4467,  7.7948,  6.2217,  7.9649],
        [20.2920, 22.6328, 31.2073,  4.9579, 11.5597, 12.5986, 11.7833],
        [18.0563, 19.1347, 21.9095, 25.5786, 32.4961, 39.3872, 32.4566],
        [59.1128, 4

batch_predictions
tensor([[12.5718, 11.5086, 11.1516, 11.0463, 11.1543, 11.4703, 12.0934],
        [25.3117, 21.7298, 21.4243, 22.4726, 23.7214, 24.9658, 26.1573],
        [37.2778, 37.2619, 37.2601, 37.2563, 37.2543, 37.2544, 37.2618],
        [19.3137, 17.4057, 18.6402, 19.7392, 20.5863, 21.1569, 21.5068],
        [19.4822, 20.1559, 20.0141, 19.7004, 19.4199, 19.2353, 19.1491],
        [ 8.6673,  6.5925,  5.8177,  5.7104,  5.7250,  5.7507,  5.7729],
        [19.4201, 21.8768, 21.2099, 19.8396, 18.1114, 16.5479, 15.9483],
        [18.0057, 16.9200, 16.9033, 17.2757, 17.6627, 17.9910, 18.2574]])
ground truth
tensor([[17.5171,  9.4424,  8.5087,  9.2451,  8.6665, 12.8222, 14.7685],
        [16.4399, 14.0023, 14.7392, 15.0368, 21.3152, 43.0414, 33.1207],
        [60.0482, 54.5210, 28.4297, 23.9654, 21.9246, 23.1718, 41.0714],
        [14.5581, 11.0994, 13.9400, 10.3367, 20.0158, 28.8401,  7.3251],
        [25.9495, 13.8322,  4.5210,  9.2120, 21.5136, 22.0947, 19.0193],
        [ 8.6310,  

batch_predictions
tensor([[25.1323, 24.5017, 24.4307, 24.5542, 24.7273, 24.9090, 25.0841],
        [ 7.7683,  7.1196,  6.9192,  6.8806,  6.8840,  6.8904,  6.8956],
        [14.7890, 17.8540, 18.2735, 17.5112, 16.7137, 16.0528, 15.5666],
        [21.1544, 21.6439, 21.4948, 21.1987, 20.9263, 20.7418, 20.6565],
        [19.0170, 21.1333, 20.6911, 19.8910, 19.0996, 18.4375, 17.9846],
        [ 6.2939,  6.3065,  6.3089,  6.3113,  6.3104,  6.2939,  6.2876],
        [17.2485, 18.8621, 18.5729, 17.9323, 17.3253, 16.8656, 16.5860],
        [16.0738, 19.2001, 19.0929, 18.3043, 17.4845, 16.8056, 16.3277]])
ground truth
tensor([[24.9605, 36.3624, 10.8233, 13.6902, 19.5555, 17.6092, 20.3972],
        [ 8.4035, 12.4803, 16.8201, 10.8233,  3.4850, 10.7049,  1.9069],
        [11.1783, 13.6376, 19.2399, 18.7270, 20.5287,  7.5487, 10.7443],
        [16.0179, 19.8843, 26.1967, 23.3430, 29.8527, 10.8233, 14.7554],
        [11.1536, 16.6383, 23.2426, 26.3605, 27.1117, 14.9943, 14.9943],
        [ 6.7859,  

batch_predictions
tensor([[23.0382, 21.5004, 21.2649, 21.6690, 22.2544, 22.8499, 23.4052],
        [17.9830, 16.4014, 16.4455, 17.1081, 17.7871, 18.3984, 18.9468],
        [ 9.4617, 10.6713, 11.5488, 12.4627, 13.6749, 15.2244, 16.3296],
        [23.9041, 25.7916, 25.6690, 24.9091, 23.8803, 22.7919, 21.8524],
        [19.3058, 23.3724, 22.1757, 20.1233, 17.5356, 15.9395, 16.0039],
        [17.2833, 16.0404, 17.2209, 18.1854, 18.8861, 19.3328, 19.5987],
        [20.4177, 26.8873, 27.6843, 26.5128, 24.0121, 19.8812, 18.1537],
        [29.7191, 34.9202, 35.1757, 34.2277, 32.6536, 30.5024, 27.5201]])
ground truth
tensor([[ 8.5034, 15.5187, 14.8810, 19.2177, 22.8458, 26.3605, 36.0969],
        [25.9863, 35.4024, 10.6128,  8.8769, 10.6654, 11.1520, 15.8469],
        [ 6.6043,  9.8498,  8.8577,  9.0986,  6.9728, 11.4371, 20.5074],
        [13.8874, 11.8096,  1.0521, 21.8306, 24.1846, 32.0358,  0.0000],
        [13.6338, 19.7562, 31.4201, 43.3673, 18.0130, 17.6446, 13.8605],
        [30.3713, 1

batch_predictions
tensor([[16.1903, 19.6503, 20.1570, 19.5810, 18.9665, 18.4867, 18.1909],
        [27.0565, 34.2604, 34.3746, 32.9569, 30.5532, 26.5231, 22.0948],
        [21.5307, 19.5081, 19.7728, 20.4873, 21.1242, 21.6363, 22.0371],
        [19.0830, 17.5168, 18.6405, 19.5692, 20.2619, 20.7397, 21.0510],
        [16.4392, 15.1241, 16.2355, 17.1718, 17.8303, 18.2338, 18.4751],
        [19.1548, 21.0192, 20.5300, 19.6227, 18.6955, 17.9284, 17.4563],
        [18.8068, 21.9616, 21.4323, 20.5495, 19.7034, 19.0279, 18.5863],
        [28.8655, 27.5989, 27.1452, 27.1506, 27.4299, 27.9059, 28.5274]])
ground truth
tensor([[14.9132, 11.3098, 12.6775, 16.2415, 20.1604, 24.2504,  9.9290],
        [21.9246, 35.3033, 43.8917, 49.7591, 20.1247, 27.7069, 30.1162],
        [22.7117, 14.4529, 13.7296, 16.2809, 19.8185, 26.4072, 44.7922],
        [22.6616, 11.3237, 14.3566, 11.1395, 13.3787, 19.6145, 22.3923],
        [22.4093, 11.5071, 10.7049,  9.6528, 10.0999,  6.8780, 25.2236],
        [24.3161, 2

batch_predictions
tensor([[35.7506, 34.4513, 34.4973, 34.7851, 35.0531, 35.2563, 35.4043],
        [26.1406, 35.9673, 36.7500, 35.5440, 32.8244, 27.1379, 21.0226],
        [23.1297, 26.6371, 26.1092, 24.9128, 23.3610, 21.4766, 19.9923],
        [18.3502, 17.3279, 18.6104, 19.7302, 20.6131, 21.2389, 21.6254],
        [ 9.4377, 10.1121, 10.5973, 11.0246, 11.4614, 11.9875, 12.7300],
        [17.6102, 21.6215, 24.8954, 26.6485, 26.6485, 25.8564, 24.6004],
        [ 7.7641,  7.6118,  7.5422,  7.5147,  7.5077,  7.5049,  7.5043],
        [38.0221, 37.4782, 37.5313, 37.6189, 37.6816, 37.7280, 37.7666]])
ground truth
tensor([[25.0263, 42.4250, 30.9442, 28.5245, 23.5928, 20.4235, 23.9479],
        [18.0826, 33.3640, 53.8532, 47.3567, 22.7117, 20.9232, 16.7938],
        [17.2902, 16.2698, 17.9705, 30.5556, 28.2455, 22.8883, 18.5658],
        [13.6639, 20.7259, 19.1478, 15.9390, 19.5292, 26.5650, 29.0373],
        [11.4150, 15.2157, 20.6207, 17.9905,  3.2614,  7.5355,  7.7854],
        [22.3172, 1

batch_predictions
tensor([[17.6206, 22.5861, 22.7977, 21.8575, 20.4050, 18.6458, 17.3379],
        [20.6095, 19.8146, 19.6406, 19.7697, 20.0297, 20.3482, 20.6924],
        [14.9365, 18.5113, 19.4367, 18.0377, 16.1408, 14.0670, 12.8787],
        [15.0506, 15.1065, 15.0773, 15.0226, 14.9779, 14.9590, 14.9621],
        [28.3659, 27.1796, 26.7463, 26.7396, 27.0098, 27.4757, 28.0926],
        [23.5683, 23.4432, 23.4575, 23.4802, 23.4984, 23.5203, 23.5497],
        [37.6611, 37.9906, 37.9696, 37.9315, 37.9051, 37.8912, 37.8861],
        [22.0361, 29.0857, 34.0554, 33.9748, 31.8382, 26.6599, 19.3886]])
ground truth
tensor([[18.3248, 14.3566, 22.6332, 28.1321, 11.6071, 13.2653, 13.8322],
        [35.8232, 11.9674, 15.7549, 17.1094, 17.5565, 17.6355, 27.3672],
        [13.6113, 11.6649, 11.1257, 10.7575, 28.5508, 28.7086, 10.5339],
        [ 6.3350, 17.6162,  2.9053,  3.3305,  5.7115,  5.8673, 13.2937],
        [45.2523, 31.4768, 19.2744, 17.8430, 15.4904, 14.5408, 31.6893],
        [24.1715, 2

batch_predictions
tensor([[38.3298, 38.3039, 38.3139, 38.3183, 38.3205, 38.3221, 38.3236],
        [15.7987, 18.6354, 18.6394, 17.8500, 17.0063, 16.2882, 15.7487],
        [29.0877, 24.1230, 25.1321, 26.8285, 28.4756, 29.8729, 30.9560],
        [18.1841, 16.2597, 17.2010, 18.2331, 19.0592, 19.6461, 20.0559],
        [30.8305, 35.7556, 35.6313, 35.0114, 34.4043, 33.9301, 33.6130],
        [15.4709, 20.0194, 23.4430, 24.1938, 23.6569, 22.8041, 21.9534],
        [18.2814, 23.5106, 22.2443, 19.4754, 16.0821, 15.0681, 15.2380],
        [24.5931, 28.7089, 28.3636, 27.0667, 25.0170, 21.9735, 19.7526]])
ground truth
tensor([[40.8206, 45.1210, 45.1604, 43.3588, 62.1120, 60.3893, 26.4335],
        [22.9308,  5.7540,  0.1134,  8.3333, 13.4212, 16.8226, 22.0805],
        [34.8356, 19.1610, 16.2840, 14.6400, 24.2772, 33.6168, 46.8679],
        [13.5060, 12.0068, 14.6107, 16.1494, 18.5955, 28.1168,  0.9074],
        [23.9479, 28.3535, 36.4282, 29.3924, 27.8012,  5.4314, 21.7123],
        [22.7749, 2

batch_predictions
tensor([[14.0621, 13.3270, 13.2752, 13.5420, 13.9187, 14.3109, 14.6674],
        [ 3.8838,  3.8173,  3.8287,  3.8383,  3.8445,  3.8468,  3.8478],
        [22.5463, 20.1255, 21.2369, 22.3475, 23.2457, 23.9144, 24.3798],
        [34.6345, 32.0235, 31.0704, 31.0489, 31.5310, 32.2484, 33.0819],
        [21.6660, 21.9257, 21.8242, 21.6529, 21.5088, 21.4233, 21.3958],
        [19.3836, 19.7765, 19.6715, 19.4913, 19.3485, 19.2660, 19.2353],
        [18.4553, 22.5084, 21.8885, 20.7528, 19.5282, 18.5082, 17.9124],
        [15.0761, 15.4666, 17.3525, 19.2241, 20.8007, 21.3367, 21.2468]])
ground truth
tensor([[14.4398,  5.4577,  5.3919,  4.9448,  8.2720, 11.0337, 19.6344],
        [ 5.9240,  8.4042, 10.0765,  4.1950,  3.1746,  2.8770,  5.4280],
        [27.6565, 11.7833, 16.4519, 18.2667, 16.6360, 24.4740, 36.4545],
        [38.2956, 22.6328, 25.6839, 18.3456, 25.4603, 35.7443, 51.3019],
        [21.8701, 16.8332, 24.3030, 29.6423, 13.1247, 18.3588, 10.1131],
        [17.5170, 2

batch_predictions
tensor([[12.2852, 11.8460, 11.4197, 10.9099, 10.1813,  9.0775,  7.8484],
        [18.3852, 21.7711, 21.6732, 20.5655, 19.0522, 17.5181, 16.5963],
        [15.2588, 14.3202, 15.3942, 16.3481, 17.0141, 17.3910, 17.5912],
        [33.6758, 28.7670, 28.8263, 29.9861, 31.2236, 32.3522, 33.3181],
        [25.6002, 23.0166, 23.0030, 23.7900, 24.6049, 25.3061, 25.8757],
        [24.8628, 35.5901, 37.0284, 36.3323, 35.1088, 33.2898, 30.2999],
        [10.5086,  9.7440,  8.6237,  7.4652,  6.8120,  6.5905,  6.5417],
        [ 5.2751,  5.4859,  5.4955,  5.4918,  5.4889,  5.4814,  5.4777]])
ground truth
tensor([[ 4.3924,  7.4040,  8.0747,  7.3514, 11.3887, 15.9258, 14.9527],
        [15.8872, 34.3821, 33.6593, 13.2795, 16.3407, 16.6100, 14.3849],
        [ 8.0747, 11.0600, 11.0863, 16.2152, 14.0715, 21.0021, 13.5587],
        [52.4461, 24.4345, 19.6739, 20.5418, 23.3693, 58.2588, 43.2272],
        [18.0414, 20.8333, 22.6049, 27.1967, 21.3152, 28.9683, 35.3175],
        [20.3090, 1

batch_predictions
tensor([[19.8381, 26.1630, 24.9912, 21.6165, 16.5371, 15.7765, 16.0055],
        [36.9233, 36.8934, 36.8922, 36.8894, 36.8889, 36.8907, 36.9018],
        [15.0481, 18.0877, 18.7453, 17.9883, 17.1620, 16.4632, 15.9433],
        [21.2422, 24.1100, 23.3554, 21.9217, 20.0138, 18.0774, 17.2835],
        [21.4597, 18.3857, 19.9041, 21.7444, 23.4173, 24.6252, 25.5345],
        [ 8.7822,  8.7784,  8.7802,  8.7822,  8.7759,  8.7549,  8.7516],
        [16.5721, 21.9386, 22.3701, 21.2504, 19.4284, 17.0757, 15.7322],
        [19.2410, 19.9571, 19.7487, 19.3099, 18.8895, 18.5881, 18.4314]])
ground truth
tensor([[13.6054, 12.1032, 12.1315, 16.6100, 24.4189, 28.4439, 18.6791],
        [24.9858, 19.6995, 32.3413, 30.0879, 54.1383, 53.5006, 39.6967],
        [ 8.3050, 13.2511, 13.8464, 15.0510, 22.3356, 23.6820, 12.2024],
        [ 9.7506, 16.0714, 20.5357, 25.2409,  0.7370,  0.0000,  0.7511],
        [11.5505, 12.9960, 11.6780, 13.7046, 20.8333, 30.4563, 28.7557],
        [ 7.9826, 1

batch_predictions
tensor([[20.8075, 20.1048, 20.0046, 20.1525, 20.3748, 20.6080, 20.8293],
        [20.1916, 19.6396, 19.6106, 19.7448, 19.8996, 20.0409, 20.1647],
        [13.4509, 13.1035, 13.0432, 13.1216, 13.2475, 13.3993, 13.5587],
        [17.4370, 23.6823, 24.0872, 22.6154, 19.5796, 15.9235, 15.6682],
        [18.6488, 24.6804, 24.4417, 22.9406, 20.2012, 17.3430, 16.9390],
        [19.8365, 25.4823, 24.2102, 21.7110, 18.3578, 16.5950, 16.8536],
        [18.8351, 24.0579, 22.8244, 20.2880, 17.1629, 15.9063, 16.1097],
        [16.4958, 15.8724, 15.8846, 16.1263, 16.4063, 16.6620, 16.8763]])
ground truth
tensor([[12.4291, 18.4949, 17.9847, 17.2902, 21.5420, 18.0272, 23.7954],
        [23.8953, 10.7575,  1.5255, 15.8206, 14.7422, 18.7533, 21.9227],
        [15.8730, 21.7545,  6.9870,  7.2279,  5.9524,  7.4121, 16.7092],
        [27.4987, 28.1168,  9.3766,  8.6796, 10.7838, 16.0705, 21.6597],
        [13.4008, 27.9590, 41.6228, 26.7622, 13.3351, 16.5308, 18.2272],
        [15.1894, 1

tensor([[19.7646, 21.1563, 20.8670, 20.3192, 19.8166, 19.4510, 19.2463],
        [ 9.5552,  9.6308,  9.6936,  9.7253,  9.7234,  9.7025,  9.6802],
        [15.0384, 14.8168, 14.9425, 15.1405, 15.3249, 15.4722, 15.5776],
        [15.1696, 18.5290, 20.0077, 19.2771, 18.4224, 17.6942, 17.1645],
        [27.6571, 33.7042, 33.8564, 32.3272, 29.5509, 24.4602, 20.5627],
        [23.3439, 22.9310, 22.8686, 22.9553, 23.0938, 23.2635, 23.4533],
        [25.2205, 21.4668, 22.5748, 24.1569, 25.6850, 26.9853, 27.9458],
        [19.6695, 23.4870, 22.6683, 20.7559, 18.0675, 16.1849, 16.1430]])
ground truth
tensor([[21.6202, 26.0652, 13.2956, 15.5839, 16.5045, 17.5829, 17.7670],
        [ 6.2730,  7.1015, 11.1783, 10.8101, 18.4508, 19.2267,  8.5350],
        [ 1.5590,  0.0000,  0.0000, 12.9110, 14.2007,  8.1349, 14.1582],
        [14.2031, 16.9253, 19.5029, 24.6186, 16.2152, 12.3882, 13.0721],
        [33.6735, 45.0680, 37.3158, 21.4569, 27.5368, 20.1814,  6.3776],
        [31.8452, 21.8679,  5.7965, 1

batch_predictions
tensor([[14.6114, 18.5840, 21.3995, 21.2386, 20.6688, 20.1320, 19.7698],
        [17.4785, 21.6591, 21.2079, 20.1094, 18.8348, 17.6951, 16.9985],
        [33.3860, 31.3233, 30.6877, 30.7831, 31.2415, 31.8756, 32.5893],
        [18.3561, 24.8982, 24.7018, 22.5747, 18.0811, 15.4361, 15.5930],
        [ 7.8400,  8.9830,  9.6002, 10.0665, 10.5023, 10.9891, 11.6098],
        [15.8721, 18.9312, 19.2933, 18.8418, 18.4072, 18.1077, 17.9527],
        [14.8982, 18.1940, 18.7642, 18.0154, 17.2350, 16.6005, 16.1553],
        [25.5699, 28.7252, 28.4830, 27.4341, 25.8233, 23.4351, 20.7789]])
ground truth
tensor([[ 9.1837, 19.4728, 18.7358, 18.8492, 27.2392, 25.5385,  5.8532],
        [21.7545, 27.1825, 14.0306, 14.0164, 10.9694, 14.2574, 12.7834],
        [ 4.6627,  0.0000, 17.6020, 21.4002, 23.2851, 34.4529, 42.4461],
        [ 9.9065, 23.2426, 37.8543, 34.7931, 15.3628, 11.8764, 11.2245],
        [ 3.9716,  4.3530,  8.7980, 11.3098, 19.4766, 19.7002,  3.3535],
        [18.0957, 1

batch_predictions
tensor([[12.2705, 14.6857, 18.4053, 18.9180, 17.6589, 16.1201, 14.6367],
        [28.0488, 34.1247, 34.5463, 33.2208, 30.6541, 26.1650, 21.7016],
        [18.1852, 22.3071, 21.4803, 19.8912, 18.0339, 16.5277, 16.0446],
        [ 5.6606,  4.4263,  4.3346,  4.3731,  4.4067,  4.4302,  4.4472],
        [14.0285, 13.6688, 14.8674, 16.0465, 16.8735, 17.2376, 17.3646],
        [23.2241, 22.3266, 22.1380, 22.3014, 22.6266, 23.0217, 23.4335],
        [15.7158, 20.1653, 24.3110, 25.3149, 24.6959, 23.3914, 21.5867],
        [30.0373, 33.2248, 33.0291, 32.2779, 31.4751, 30.7767, 30.2229]])
ground truth
tensor([[ 7.7591,  9.3766,  8.9953, 11.1520,  0.0000,  2.8932,  9.1794],
        [23.2001, 18.4382, 23.3277, 33.0499, 49.0646, 43.1122, 23.6395],
        [15.4904, 20.0539, 28.6990, 35.4025, 10.5867, 12.2449, 10.8702],
        [13.2430, 14.8606, 10.3235,  7.1278,  5.4182, 16.9779,  8.4429],
        [ 7.9826, 14.0715, 12.9274,  9.5608, 12.7564, 19.2530, 16.7675],
        [29.1557, 1

batch_predictions
tensor([[20.7304, 26.9038, 25.9695, 23.9474, 20.9138, 18.2719, 18.1847],
        [17.8435, 17.7987, 17.8174, 17.8274, 17.8293, 17.8360, 17.8523],
        [22.4192, 28.8404, 28.0092, 25.9609, 22.5646, 18.9991, 18.8634],
        [16.6501, 15.3062, 16.3326, 17.2063, 17.8103, 18.1928, 18.4340],
        [10.6422, 11.6526, 12.7982, 14.3965, 15.9994, 15.8777, 15.3796],
        [ 4.4438,  4.3835,  4.3871,  4.3951,  4.3998,  4.4022,  4.4046],
        [16.5382, 17.4607, 17.2508, 16.8505, 16.5176, 16.3085, 16.2148],
        [14.8621, 14.7663, 14.8015, 14.8431, 14.8758, 14.9052, 14.9329]])
ground truth
tensor([[15.1644, 18.7925, 23.2568, 39.2290, 29.7052, 17.3186, 17.4461],
        [26.3180, 24.5465, 14.0306, 16.0006, 18.8350, 15.1644, 26.6865],
        [49.8948, 23.3167, 28.6691, 25.3419, 22.4750, 48.3824, 43.4903],
        [28.1825, 15.1894,  6.0889, 15.0710, 13.6902, 17.2935, 23.4219],
        [ 9.2977,  9.3372, 16.3335, 20.1867, 19.2530,  8.6796, 11.4940],
        [ 6.7460, 1

batch_predictions
tensor([[20.3892, 27.2869, 28.6526, 27.6957, 25.5877, 21.2954, 18.2858],
        [21.3429, 22.7903, 22.3997, 21.6579, 20.9124, 20.3131, 19.9401],
        [19.0342, 21.1048, 20.6300, 19.6883, 18.6678, 17.7529, 17.1398],
        [18.8231, 17.9444, 17.8203, 18.0483, 18.3886, 18.7439, 19.0811],
        [23.1188, 26.4047, 25.9583, 24.8297, 23.3540, 21.5966, 20.1469],
        [27.6195, 33.2876, 33.6098, 32.1063, 29.1635, 23.9121, 20.4339],
        [19.5722, 17.5224, 18.3489, 19.2688, 20.0079, 20.5445, 20.9166],
        [33.0041, 35.2505, 35.3867, 35.0030, 34.4938, 34.0102, 33.6212]])
ground truth
tensor([[49.0363, 43.1548, 13.2370, 10.7710, 16.6100, 21.5420, 36.0261],
        [20.4103, 22.5276, 22.2777, 11.9148, 11.1520, 18.0957, 17.3593],
        [22.1372, 23.0867, 15.2353, 12.6276, 11.1820, 14.2999, 15.7596],
        [21.9813, 11.2528, 11.9473,  9.9915, 12.7409, 12.9960, 23.9938],
        [19.4870, 34.3537, 23.2851, 20.7200, 17.4887, 15.0652, 15.0794],
        [32.6814, 4

batch_predictions
tensor([[24.4847, 28.5512, 28.5389, 27.2916, 24.8762, 21.1038, 18.8501],
        [16.4472, 17.8887, 17.5041, 16.8490, 16.2739, 15.8625, 15.6312],
        [17.0768, 22.0949, 25.1688, 24.9146, 23.7521, 21.6962, 18.2108],
        [21.1832, 23.5192, 23.0039, 22.0836, 21.0868, 20.1226, 19.3648],
        [22.7967, 21.9338, 21.7987, 21.9839, 22.2802, 22.5999, 22.9054],
        [16.1014, 15.8461, 17.0556, 18.0441, 18.7612, 19.2164, 19.4698],
        [20.2212, 21.9247, 21.1504, 19.9143, 18.5214, 17.3229, 16.7452],
        [10.2947, 12.4666, 14.8460, 17.7527, 19.1872, 18.8190, 18.0260]])
ground truth
tensor([[10.5584, 16.7517, 15.4478, 22.1797, 44.8696, 36.2528,  9.3963],
        [14.7817, 17.4382, 13.7296,  6.6149, 11.2572,  8.9032, 10.9548],
        [17.6020, 18.4382, 19.0618, 23.0867, 34.8214, 38.9031,  6.5051],
        [19.8843, 35.9942, 27.6565, 11.7833, 16.4519, 18.2667, 16.6360],
        [15.5612, 16.5108, 20.7766, 32.7239, 29.2092, 20.3231, 18.0272],
        [29.7761, 1

batch_predictions
tensor([[18.5802, 24.9903, 24.5636, 22.9116, 19.8549, 17.0449, 17.2066],
        [19.9218, 26.1289, 30.7330, 32.8979, 33.2397, 33.0109, 32.7106],
        [26.5893, 32.1670, 32.0891, 30.6180, 28.0973, 23.5908, 20.2393],
        [21.8046, 21.4271, 21.3966, 21.5089, 21.6647, 21.8377, 22.0138],
        [27.0270, 23.1792, 23.1429, 24.3700, 25.7222, 26.9733, 28.0408],
        [20.3360, 21.6723, 21.2958, 20.5519, 19.7649, 19.0886, 18.6279],
        [21.8069, 26.1103, 25.3717, 23.3977, 20.0830, 17.2630, 17.2697],
        [16.2583, 22.4119, 23.4858, 21.6371, 17.5981, 13.9086, 13.6248]])
ground truth
tensor([[12.0331, 10.7049, 12.3356, 19.7265, 30.6812, 27.4329, 11.6386],
        [36.8622, 21.0547, 20.3051, 12.2436, 20.7654, 27.5907, 31.3519],
        [16.2840, 14.6400, 24.2772, 33.6168, 46.8679,  9.8498, 12.5567],
        [ 5.6973,  5.7540, 12.7834, 15.3061, 22.7749, 34.0136, 22.8316],
        [26.9700, 11.0402, 17.1910,  9.0703, 17.8430, 32.3554, 39.2007],
        [13.5981, 1

batch_predictions
tensor([[10.6586, 13.2112, 16.1470, 18.9512, 19.5994, 19.0338, 18.2299],
        [16.4730, 16.4051, 16.4246, 16.4415, 16.4507, 16.4624, 16.4802],
        [20.4595, 19.6654, 19.5045, 19.6439, 19.9044, 20.2071, 20.5171],
        [15.0958, 19.2468, 21.0840, 19.2704, 16.4163, 13.7640, 13.0767],
        [19.9273, 18.9485, 18.7509, 19.0016, 19.4689, 20.0288, 20.6210],
        [27.5452, 28.7629, 28.6895, 28.3119, 27.8908, 27.5388, 27.3014],
        [28.4695, 26.6784, 26.2531, 26.4998, 27.0188, 27.6088, 28.1809],
        [17.6147, 17.2222, 17.2619, 17.4330, 17.6225, 17.7949, 17.9409]])
ground truth
tensor([[11.0863, 11.2704, 13.0195, 18.8716,  5.4050,  6.6281,  3.7612],
        [24.4608, 29.8264,  9.9290, 12.0857, 12.4277, 11.0205, 19.8185],
        [29.8895,  4.9603,  0.0000, 10.9694,  9.1553, 14.2007, 22.8741],
        [15.9916, 21.8175, 26.6176, 10.6786, 12.5855,  9.4424, 11.6386],
        [33.6026, 34.4388,  8.4467, 11.5221, 10.1049, 18.5232, 21.3861],
        [21.5018, 2

batch_predictions
tensor([[19.1042, 19.3124, 19.2547, 19.1242, 19.0053, 18.9311, 18.9044],
        [25.7120, 27.1349, 26.9626, 26.5211, 26.0963, 25.7811, 25.5982],
        [20.6016, 18.7807, 20.3380, 21.8805, 23.2798, 24.4614, 25.2991],
        [18.5081, 19.7208, 19.3079, 18.5790, 17.8668, 17.3110, 16.9809],
        [21.4168, 24.6144, 24.1024, 22.6137, 20.3847, 18.0957, 17.3327],
        [24.6896, 31.1782, 30.8153, 29.0107, 25.6942, 20.4683, 19.3690],
        [18.3705, 18.4995, 18.4124, 18.3049, 18.2290, 18.1936, 18.1901],
        [10.5265, 11.8041, 13.3140, 15.4117, 16.7434, 16.2489, 15.7101]])
ground truth
tensor([[23.2772, 22.6460, 12.0200, 12.0200, 12.2699, 14.5055, 20.3446],
        [27.5652, 27.3526, 24.2489, 24.0788, 33.5176, 17.7012, 18.8067],
        [28.1463, 11.8764, 14.8526,  2.1825,  0.0000,  9.9915, 29.9178],
        [ 6.0889, 15.0710, 13.6902, 17.2935, 23.4219, 31.1547, 21.4624],
        [29.1383, 36.4654, 32.6247, 20.4223, 15.9722, 14.3424, 13.9031],
        [16.5176, 2

batch_predictions
tensor([[ 5.9251,  5.8370,  5.8274,  5.8329,  5.8367,  5.8379,  5.8393],
        [19.6463, 22.9280, 22.3614, 20.6797, 18.2275, 16.2039, 15.9044],
        [23.2162, 30.1932, 33.8624, 34.5246, 34.3878, 34.1525, 33.9543],
        [30.8607, 27.5538, 27.0747, 27.8619, 28.9017, 29.8715, 30.6947],
        [10.5152, 11.8121, 13.4881, 16.5014, 18.3273, 17.0399, 15.5168],
        [18.9899, 24.0174, 23.2067, 21.6998, 19.8275, 18.3107, 17.7879],
        [20.9281, 18.6178, 18.7850, 19.7471, 20.7048, 21.5507, 22.2963],
        [10.1122, 11.1957, 12.1291, 13.1984, 14.6181, 15.9257, 16.1541]])
ground truth
tensor([[ 4.7194,  5.8248,  3.4439,  5.7115,  7.1570,  5.0595,  9.1270],
        [12.2041, 13.4271, 15.1368, 13.3877, 18.7533, 33.4824, 13.8611],
        [18.2141, 24.2372, 23.6323, 30.2078, 31.9174, 39.7817, 26.2230],
        [26.6440, 21.1026, 21.2868, 23.1151, 35.6718, 45.7625, 36.7630],
        [12.2567, 12.9932, 23.7112,  6.9832,  6.9437,  9.1136,  7.9958],
        [18.8322, 1

batch_predictions
tensor([[11.6711, 11.3253, 10.9277, 10.4085,  9.6703,  8.7032,  7.8703],
        [16.8684, 17.8396, 17.6344, 17.1688, 16.7407, 16.4421, 16.2864],
        [16.8334, 16.6868, 16.7185, 16.7680, 16.8078, 16.8423, 16.8757],
        [25.9854, 27.5591, 27.4598, 26.8808, 26.1400, 25.3916, 24.7382],
        [13.4528, 12.5949, 12.8525, 13.4667, 14.1365, 14.6455, 14.9438],
        [22.0819, 21.2066, 21.0718, 21.2938, 21.6565, 22.0587, 22.4504],
        [ 4.2291,  4.2374,  4.2375,  4.2395,  4.2403,  4.2315,  4.2289],
        [18.5229, 24.2070, 23.0672, 19.9314, 15.7810, 14.9450, 15.1215]])
ground truth
tensor([[15.9127,  3.9716,  6.4308,  5.7207,  6.5755,  7.8906, 14.6502],
        [18.6481, 31.6938, 39.2031, 17.5697, 15.6891, 11.7833, 16.6886],
        [11.2704, 10.8759, 13.9400, 14.9921, 15.5050, 23.6323, 30.9179],
        [35.5867, 18.6366, 13.0952, 10.5726, 19.2177, 34.8214, 50.3543],
        [15.6891,  4.4319, 11.5992, 11.6649, 10.7575, 12.3751, 15.0447],
        [21.7687, 3

batch_predictions
tensor([[27.3127, 36.1574, 36.4301, 35.5542, 34.2178, 32.4592, 29.9198],
        [20.9001, 19.0718, 19.5528, 20.2721, 20.8570, 21.2960, 21.6189],
        [13.9636, 16.9966, 19.8274, 19.5326, 18.8457, 18.2345, 17.8366],
        [17.9227, 19.7394, 19.4313, 18.6791, 17.9078, 17.2722, 16.8513],
        [29.0999, 26.4986, 25.6065, 25.8319, 26.5984, 27.5461, 28.5265],
        [19.5006, 21.7886, 21.3662, 20.2183, 18.7654, 17.3785, 16.6128],
        [12.9910, 16.3570, 19.1607, 20.2969, 20.1189, 19.5313, 18.9244],
        [27.4852, 23.1484, 24.7276, 26.5673, 28.2305, 29.5984, 30.6381]])
ground truth
tensor([[21.3435, 23.6111, 25.1417, 19.1893, 41.2557, 40.3345, 14.7817],
        [29.3135, 12.6907, 20.6602, 15.5050, 15.8075, 18.5034, 27.3803],
        [ 9.0420, 10.8985,  7.4546, 16.6241, 35.5300, 35.5159, 15.8730],
        [18.5941, 28.0045, 27.3810, 14.5833, 13.9456, 12.4717, 14.1015],
        [42.7579, 27.2959, 18.2256, 17.8146, 17.4603, 23.7812, 35.0340],
        [18.2404, 2

batch_predictions
tensor([[18.6433, 18.8657, 18.7721, 18.6109, 18.4751, 18.3958, 18.3714],
        [19.5210, 25.9953, 24.8070, 21.7596, 17.3388, 16.3007, 16.5879],
        [23.4814, 30.6448, 30.8948, 29.3312, 26.2710, 20.9482, 19.1051],
        [11.3702,  9.4591,  7.3140,  6.3611,  6.1790,  6.1520,  6.1466],
        [19.3063, 25.2329, 24.2544, 21.8589, 18.3933, 16.5777, 16.7934],
        [16.6458, 20.5104, 21.9089, 21.5869, 21.0798, 20.6199, 20.2998],
        [28.3242, 23.4292, 24.8947, 26.7717, 28.5347, 30.0233, 31.1778],
        [12.3103, 15.7679, 19.5614, 21.0433, 20.4198, 19.4332, 18.4722]])
ground truth
tensor([[12.3882, 16.0179, 14.1636, 11.6649, 20.3314, 29.6686, 30.1289],
        [23.0300, 35.2749, 32.9932, 14.9943, 14.5692, 16.4683, 17.7438],
        [18.2009, 19.5818, 31.8122, 57.2988, 39.2425, 21.1205, 17.0174],
        [ 4.4713,  6.1284,  6.2073,  8.0747,  9.6134, 17.3067, 18.7796],
        [16.7375, 29.1383, 36.4654, 32.6247, 20.4223, 15.9722, 14.3424],
        [ 8.7322, 1

tensor([[12.3583,  9.8356,  8.4467,  7.7948,  6.2217,  7.9649,  8.6026],
        [20.8617, 18.8634, 24.3339, 44.2035, 14.3566, 46.0459, 18.3673],
        [ 9.0845, 12.6417, 14.2432, 13.9031, 22.7608, 27.5368, 27.2959],
        [ 7.6539,  8.5350, 10.8364, 12.4803, 20.9890, 28.2877, 11.4808],
        [25.9337,  6.1941,  5.1420, 13.7691, 10.0868, 14.3346, 18.4245],
        [20.0964, 17.7154, 19.7562, 26.0204, 27.6502,  5.8532, 15.5896],
        [ 6.0232,  8.2341,  7.8798,  9.4955,  9.6372, 16.3265,  8.8861],
        [21.7545, 31.9728, 26.0062, 17.4603, 12.3724, 17.7012, 13.6763]])
batch_predictions
tensor([[33.4229, 34.7775, 34.9419, 34.7334, 34.4217, 34.1311, 33.9166],
        [21.1793, 20.0288, 19.7484, 19.9239, 20.2922, 20.7205, 21.1594],
        [18.6951, 25.0886, 26.6771, 25.6537, 23.2235, 18.5168, 17.0505],
        [14.8978, 18.3714, 18.9849, 18.1695, 17.2844, 16.5184, 15.9308],
        [15.8690, 14.9193, 15.0704, 15.5210, 15.9374, 16.2690, 16.5211],
        [28.6665, 32.7269, 32.82

batch_predictions
tensor([[26.9290, 33.2283, 33.4574, 31.8447, 28.7587, 23.0366, 20.0664],
        [19.1302, 17.3662, 17.8824, 18.6023, 19.1889, 19.6330, 19.9614],
        [17.5368, 22.3938, 23.0331, 22.0136, 20.3446, 18.2577, 16.8900],
        [18.9795, 21.3840, 20.8918, 19.9375, 18.9281, 18.0238, 17.3958],
        [18.8011, 19.0977, 18.9622, 18.7434, 18.5617, 18.4544, 18.4188],
        [21.5199, 24.8741, 24.2743, 23.2178, 22.0583, 20.8721, 19.9085],
        [19.7036, 19.4857, 19.5110, 19.5846, 19.6580, 19.7258, 19.7900],
        [19.1382, 25.5014, 24.8509, 22.6203, 18.6813, 16.2489, 16.4136]])
ground truth
tensor([[19.4019, 33.9569, 44.1185, 32.1854, 15.4620, 20.9892, 17.4036],
        [24.3293, 15.0973, 16.0310, 16.0573, 15.5181, 17.6881, 27.7880],
        [16.8858, 16.9911, 18.5823, 12.9800, 26.6439, 19.6870,  0.0000],
        [18.0957, 15.2551, 25.9337, 31.1415, 11.2572, 15.4655, 15.7812],
        [23.4219, 31.1547, 21.4624,  4.2346, 11.2046, 13.3614, 20.2656],
        [17.7296, 1

batch_predictions
tensor([[26.0463, 30.0359, 29.9368, 28.7325, 26.5527, 22.7666, 19.5496],
        [13.4547, 17.6943, 22.4077, 22.6741, 21.7010, 20.1890, 18.4790],
        [37.9337, 35.0507, 35.1343, 35.8329, 36.4160, 36.7899, 37.0180],
        [ 5.2646,  5.0913,  5.0840,  5.0967,  5.1055,  5.1103,  5.1141],
        [25.7107, 31.4807, 31.3410, 29.8378, 27.2113, 22.2841, 19.3868],
        [28.8432, 24.7967, 25.1792, 26.4563, 27.7319, 28.8067, 29.6327],
        [29.8803, 28.0628, 27.6493, 27.9026, 28.4096, 28.9692, 29.4958],
        [29.7853, 29.2274, 29.1294, 29.2113, 29.3464, 29.5131, 29.7024]])
ground truth
tensor([[18.1689, 15.1502, 13.7046, 41.2132, 46.4853, 32.7948, 16.6241],
        [12.3724, 25.4252, 28.4722, 12.5425,  9.8498, 12.2307, 18.1689],
        [45.4932, 33.4892, 18.6083, 24.3339, 26.9274, 40.7171, 54.5351],
        [ 9.5380,  8.2766, 10.6718, 11.2103, 12.3583, 14.1015, 11.1536],
        [16.0836, 21.0021, 31.0889, 33.9690, 17.0042, 14.9527,  9.0610],
        [35.8418, 1

batch_predictions
tensor([[37.9095, 38.1597, 38.1418, 38.1158, 38.0984, 38.0896, 38.0865],
        [10.8076, 10.9851, 11.2119, 11.4706, 11.7651, 12.1299, 12.6182],
        [18.9869, 17.4420, 18.3585, 19.2679, 19.9151, 20.3556, 20.6558],
        [25.2300, 21.9696, 23.8693, 26.0016, 27.9947, 29.7112, 31.0612],
        [18.5595, 22.7719, 21.6486, 19.5416, 16.9456, 15.4033, 15.4402],
        [28.4085, 26.8479, 26.3675, 26.4668, 26.8716, 27.4244, 28.0361],
        [20.3269, 26.2369, 25.1945, 23.1002, 20.1230, 17.8137, 17.8802],
        [16.4033, 20.5788, 23.8284, 24.8128, 24.2922, 23.2206, 21.8396]])
ground truth
tensor([[45.1956, 55.9382, 52.4802, 25.5385, 19.2460, 23.9938, 36.5363],
        [18.6744, 13.3088,  8.1799,  8.1405,  7.0358, 10.5865, 10.4945],
        [25.5385,  5.8532, 13.0244, 10.4734,  9.9065, 13.4070, 19.8696],
        [24.7166, 12.9677, 19.9830, 12.5000, 17.3044, 30.4847, 56.2642],
        [15.6365, 11.1126, 14.6107, 16.7280, 27.7617,  7.7591, 11.2178],
        [46.8679,  

batch_predictions
tensor([[18.8097, 20.0965, 19.4642, 18.3581, 17.1265, 16.0688, 15.4969],
        [20.7520, 17.9558, 19.4958, 21.1386, 22.6844, 23.8514, 24.5878],
        [19.0910, 19.3672, 19.3009, 19.1442, 19.0022, 18.9122, 18.8758],
        [16.0281, 14.8338, 14.9955, 15.5771, 16.1507, 16.6483, 17.0712],
        [22.6446, 31.8307, 35.0872, 34.0465, 30.9613, 23.7972, 19.0869],
        [16.0447, 17.9765, 17.9745, 17.4400, 16.9114, 16.5159, 16.2695],
        [ 9.8409, 12.0914, 14.6492, 17.4585, 18.5809, 18.4955, 18.0707],
        [21.3309, 28.1064, 30.4409, 30.1182, 29.1948, 27.8308, 25.5795]])
ground truth
tensor([[20.3840, 23.9085,  5.0237, 10.7180,  8.3377,  9.5213, 12.9932],
        [38.9314, 30.4989, 13.7755, 10.6859, 11.9615, 14.0731, 23.7670],
        [26.6307, 24.3293, 15.0973, 16.0310, 16.0573, 15.5181, 17.6881],
        [ 7.7985, 10.0079,  7.9563, 16.5045, 12.0857, 17.5434, 21.5281],
        [12.9393, 15.3912, 10.3033, 22.9025, 43.4524, 40.4904, 24.7591],
        [11.6355, 1

batch_predictions
tensor([[15.8272, 21.2916, 22.3082, 21.1566, 19.2632, 16.5672, 15.0759],
        [22.1675, 24.8570, 24.7742, 23.8729, 22.6531, 21.3921, 20.3793],
        [18.3644, 16.9818, 16.8972, 17.4045, 17.9874, 18.5339, 19.0314],
        [20.6146, 28.0064, 28.6497, 27.2038, 24.2728, 19.6347, 18.8366],
        [11.6975, 13.5224, 15.9525, 17.3399, 16.6336, 15.9216, 15.3802],
        [19.9402, 22.8534, 22.3429, 21.1883, 19.9446, 18.9470, 18.4236],
        [28.2689, 31.1130, 30.9129, 30.1573, 29.2622, 28.3319, 27.3884],
        [ 9.2644, 10.5390, 11.4249, 12.3918, 13.8062, 15.8841, 17.2787]])
ground truth
tensor([[14.7817, 21.5939, 42.0305, 30.3919, 13.3482, 12.4540, 14.3740],
        [ 8.1633, 32.4263, 31.9019, 15.6604, 15.5896, 17.4887, 21.3010],
        [21.1731, 11.4019, 11.1520,  8.6796, 11.0600, 15.2946, 27.7617],
        [10.0198, 21.5561, 45.7200, 38.7755, 22.5482, 19.2177, 24.3764],
        [ 8.8112, 11.0600,  9.2583, 11.4282, 23.3561, 25.2762,  8.7585],
        [23.0274, 2

batch_predictions
tensor([[24.4941, 23.2287, 23.1554, 23.5103, 23.9275, 24.2976, 24.5987],
        [13.7522, 16.9473, 19.9571, 19.0641, 18.0280, 17.2164, 16.7313],
        [21.7043, 25.6192, 24.9344, 22.9969, 19.7034, 16.9580, 16.9390],
        [22.3774, 21.1602, 20.8680, 21.0624, 21.4692, 21.9489, 22.4457],
        [11.0056, 12.0704, 13.3847, 15.3255, 16.3661, 15.8461, 15.1656],
        [17.7915, 17.7090, 17.7378, 17.7686, 17.7895, 17.8103, 17.8366],
        [25.4572, 25.4726, 25.4735, 25.4422, 25.4514, 25.4530, 25.4264],
        [10.6804, 11.4437, 12.3310, 13.5665, 15.2004, 15.7451, 15.3323]])
ground truth
tensor([[12.7693, 14.0164, 19.0334, 18.0272, 23.9938, 11.1111,  8.6026],
        [14.0731, 10.9836, 18.5658, 22.6332, 28.4864, 27.9053, 11.2245],
        [15.0935, 16.2840, 19.8129, 31.0232, 32.9082,  5.4563, 16.7234],
        [29.2800,  4.5493,  4.5493, 14.0164, 19.6854, 26.5306, 32.8231],
        [15.7943, 19.5423,  6.2467,  8.0352,  9.3109, 10.2183, 11.9674],
        [27.4724, 2

batch_predictions
tensor([[ 8.2118,  8.7915,  9.0261,  9.1242,  9.1558,  9.1609,  9.1607],
        [17.5691, 22.0206, 26.0688, 28.5232, 28.8684, 28.1908, 26.8467],
        [13.4975, 16.3260, 17.9506, 17.4460, 16.9090, 16.5328, 16.3241],
        [22.7056, 32.1455, 32.9933, 31.1757, 27.2447, 20.0311, 18.8356],
        [24.5508, 26.2240, 26.0395, 25.4562, 24.8194, 24.2714, 23.8806],
        [17.6283, 20.0812, 19.7260, 19.0097, 18.3459, 17.8446, 17.5367],
        [18.3334, 22.8809, 24.3999, 23.2147, 20.9427, 17.9272, 16.6595],
        [16.0311, 19.7342, 19.6326, 18.7185, 17.6935, 16.7729, 16.0827]])
ground truth
tensor([[ 7.9790,  6.8736,  6.4059, 10.3600, 11.5646, 16.4399,  5.6406],
        [35.4450, 45.1956, 33.7585, 18.9909, 16.8651, 16.0006, 15.0794],
        [ 2.3951, 16.3549, 13.2370, 14.0164, 13.0244, 18.3673,  6.5051],
        [20.1672, 33.2766, 49.9575, 43.3107, 19.7846, 20.5782, 14.5975],
        [22.4206, 29.1241, 46.7120, 33.0924, 32.2279, 30.8107, 30.2438],
        [15.9439, 2

batch_predictions
tensor([[16.9961, 19.9183, 19.6613, 18.7651, 17.8086, 16.9831, 16.3889],
        [22.6634, 20.8156, 20.5002, 20.9754, 21.6399, 22.2965, 22.8990],
        [15.9801, 19.7704, 20.2190, 19.5297, 18.7658, 18.1224, 17.6779],
        [18.5295, 23.4402, 22.2722, 19.7119, 16.4847, 15.2828, 15.4285],
        [29.5868, 25.4549, 24.6533, 25.6364, 27.0582, 28.5077, 29.8700],
        [19.5436, 23.6261, 22.5259, 20.5665, 18.0024, 16.1306, 16.1119],
        [21.2973, 21.3481, 21.3372, 21.2919, 21.2449, 21.2187, 21.2191],
        [14.7514, 13.8205, 14.5409, 15.2466, 15.7314, 16.0290, 16.2092]])
ground truth
tensor([[23.0300, 19.5011, 11.3237, 12.9252, 12.9110, 11.3095, 12.6417],
        [28.3535, 17.9116,  4.6554,  9.5345, 14.7817, 15.7812, 18.1746],
        [12.1315, 11.1678, 12.6701, 20.5641, 26.7715,  6.5334,  8.2625],
        [11.7570, 19.0295, 27.5118, 32.4829, 11.0074, 13.0589, 16.5308],
        [24.7591, 16.7375, 13.4637, 13.9739, 23.7954, 48.1434, 11.4938],
        [22.2647, 3

batch_predictions
tensor([[16.1374, 16.9866, 16.8570, 16.4502, 16.0599, 15.7829, 15.6330],
        [28.5259, 30.4616, 30.3982, 29.8685, 29.2254, 28.6135, 28.1041],
        [22.8533, 26.5726, 25.9702, 24.3420, 21.6355, 18.3840, 17.5571],
        [20.4708, 21.9857, 21.6555, 20.8948, 20.0548, 19.3076, 18.7803],
        [16.7568, 16.4498, 16.5335, 16.7170, 16.9011, 17.0579, 17.1819],
        [28.2481, 29.2191, 29.2338, 28.9648, 28.6156, 28.2999, 28.0719],
        [12.8873, 13.9111, 14.6145, 14.7686, 14.5592, 14.2365, 13.9181],
        [21.5404, 22.3696, 22.1927, 21.6314, 20.9475, 20.3120, 19.8410]])
ground truth
tensor([[21.4361, 21.8306,  8.4429,  9.4424,  8.1536,  8.9821, 11.5992],
        [34.6088, 49.4898, 34.3537,  0.0000, 17.3328, 20.5924, 25.5527],
        [18.2115, 25.4393, 39.4983,  6.4201,  7.5680, 16.4683, 18.7358],
        [23.3844, 29.8895,  4.9603,  0.0000, 10.9694,  9.1553, 14.2007],
        [12.1120, 21.0547, 20.8969,  5.8916,  9.5213,  8.7980, 11.7570],
        [33.2983, 4

batch_predictions
tensor([[19.0113, 22.1314, 21.4396, 20.0362, 18.3359, 16.7723, 16.1468],
        [21.0086, 22.7141, 22.3691, 21.6739, 20.9675, 20.3784, 19.9792],
        [34.9285, 31.8544, 31.3046, 31.8075, 32.5976, 33.3946, 34.0936],
        [ 8.7267,  8.0934,  7.6855,  7.4860,  7.4061,  7.3819,  7.3776],
        [27.6765, 27.4457, 27.4184, 27.4559, 27.4961, 27.5490, 27.6249],
        [23.2360, 20.3835, 20.0499, 20.8591, 21.8374, 22.8076, 23.7552],
        [13.9084, 13.4475, 13.4540, 13.6522, 13.9075, 14.1589, 14.3729],
        [20.2095, 18.3163, 18.7952, 19.5698, 20.2200, 20.7223, 21.1006]])
ground truth
tensor([[18.7270, 26.9332, 24.7764,  8.1405,  9.2057, 13.8480, 12.6512],
        [26.0915, 23.6323, 12.9932, 20.8969, 14.2688, 13.7428, 18.7270],
        [ 7.4405, 13.5771, 21.7545, 21.0601, 32.4972, 40.6321, 40.6746],
        [13.4637, 19.8696,  5.8107,  7.6531,  7.0011,  8.6026, 14.8810],
        [14.6684, 20.6491, 19.7279, 19.8980, 34.5522, 41.2840, 27.0550],
        [20.9042, 2

batch_predictions
tensor([[19.5642, 26.0870, 24.8468, 22.1481, 18.2113, 16.5349, 16.9630],
        [15.8954, 14.5758, 15.0978, 15.8318, 16.4335, 16.8851, 17.2189],
        [26.4041, 23.5712, 23.5074, 24.3780, 25.3123, 26.1293, 26.7933],
        [12.1064, 13.2510, 14.4524, 15.0694, 14.9226, 14.5058, 14.0590],
        [20.1598, 26.3221, 25.1506, 22.6094, 18.8932, 17.0128, 17.3257],
        [21.6240, 21.2867, 21.2555, 21.3500, 21.4789, 21.6250, 21.7808],
        [29.1904, 24.5446, 25.9490, 27.6777, 29.1804, 30.3772, 31.2931],
        [22.7498, 19.9476, 20.6248, 21.8007, 22.8851, 23.7766, 24.4691]])
ground truth
tensor([[23.2710,  5.2721,  8.8152, 13.6338, 15.8447, 20.6774, 31.8594],
        [20.6996,  7.7985, 10.0079,  7.9563, 16.5045, 12.0857, 17.5434],
        [29.9461, 20.0822, 16.8793, 14.0731, 17.0351, 27.5794, 52.4943],
        [19.6344, 12.0857,  0.0000,  0.0000,  5.8785,  8.6007,  8.2194],
        [16.4966, 14.8810, 17.6871, 28.5856, 29.1241, 10.0057, 13.2795],
        [35.7851, 2

batch_predictions
tensor([[29.3900, 28.2100, 28.1004, 28.3510, 28.6744, 28.9779, 29.2355],
        [18.6703, 20.6204, 20.2786, 19.5502, 18.8286, 18.2454, 17.8639],
        [18.7068, 22.2391, 21.6386, 20.5780, 19.4852, 18.5630, 17.9514],
        [21.7174, 25.4498, 25.1080, 23.4291, 20.4875, 17.5985, 17.1298],
        [14.1918, 18.1281, 20.9956, 19.7731, 18.0607, 16.2114, 14.8155],
        [23.6373, 30.6041, 30.0138, 28.4116, 25.9201, 21.9010, 19.7927],
        [24.4722, 24.0571, 24.0112, 24.0984, 24.2181, 24.3516, 24.4920],
        [30.8915, 25.0076, 25.5459, 27.1619, 28.8443, 30.4450, 31.9897]])
ground truth
tensor([[36.7772, 15.2636, 17.6729, 27.6644, 30.0737, 36.8481, 50.2126],
        [16.5249, 13.1236, 13.3787, 17.0918, 18.1831, 21.0743, 12.9677],
        [15.3486, 15.6746, 16.9785, 24.2630, 15.3912,  0.0000,  9.2829],
        [36.0600, 37.0200, 14.1504,  8.9032, 16.5176, 18.2667, 23.1720],
        [12.1740, 12.7693, 13.0244, 16.4683, 28.0471, 29.3084, 28.6990],
        [17.9989, 1

batch_predictions
tensor([[ 8.7381,  8.8640,  8.9323,  8.9563,  8.9499,  8.9322,  8.9189],
        [28.8468, 29.0708, 29.0916, 29.0281, 28.9261, 28.8382, 28.7905],
        [12.8371, 11.8087, 11.4318, 11.2878, 11.3299, 11.5657, 12.0847],
        [25.7375, 33.6287, 33.9673, 31.9796, 27.7838, 20.6832, 19.8292],
        [16.3569, 20.0060, 20.0022, 19.2578, 18.4391, 17.7441, 17.2557],
        [ 7.9800,  8.4290,  8.5585,  8.5909,  8.5904,  8.5843,  8.5824],
        [16.7539, 21.0346, 23.5606, 23.3584, 22.6353, 21.7338, 20.7797],
        [26.5333, 27.2972, 27.3406, 27.1170, 26.7911, 26.4774, 26.2404]])
ground truth
tensor([[1.2954e+01, 1.7938e+01, 1.8924e+01, 4.4845e+00, 7.9958e+00, 5.7733e+00,
         5.4840e+00],
        [3.9768e+01, 4.3013e+01, 1.8211e+01, 1.6525e+01, 1.6738e+01, 1.7375e+01,
         3.2185e+01],
        [9.4424e+00, 8.5087e+00, 9.2451e+00, 8.6665e+00, 1.2822e+01, 1.4769e+01,
         1.7438e+01],
        [2.1046e+01, 4.1610e+01, 4.3679e+01, 3.8407e+01, 1.3676e+01, 9.6939

batch_predictions
tensor([[16.8777, 16.8358, 16.8508, 16.8580, 16.8584, 16.8634, 16.8766],
        [18.0217, 17.4800, 18.7484, 19.8392, 20.6885, 21.2998, 21.6965],
        [10.6786, 11.7143, 12.9216, 14.6251, 16.1236, 15.8463, 15.3226],
        [38.3853, 38.3687, 38.3733, 38.3753, 38.3762, 38.3769, 38.3777],
        [14.8909, 19.3501, 20.7070, 19.9694, 19.1388, 18.4287, 17.9158],
        [20.2763, 27.6159, 29.8707, 28.4317, 24.8956, 18.3700, 17.4984],
        [13.4811, 13.1362, 14.2146, 15.3812, 16.1111, 16.3860, 16.4852],
        [14.9471, 19.1960, 21.0191, 19.9216, 18.4197, 16.7188, 15.3706]])
ground truth
tensor([[2.0160e+01, 2.5460e+01, 1.0757e+01, 1.1599e+01, 9.2057e+00, 1.5742e+01,
         1.5229e+01],
        [1.8126e+01, 2.1556e+01, 1.2571e+01, 1.3223e+01, 1.3194e+01, 2.0139e+01,
         3.0371e+01],
        [9.3372e+00, 1.6334e+01, 2.0187e+01, 1.9253e+01, 8.6796e+00, 1.1494e+01,
         8.4429e+00],
        [4.6870e+01, 5.6418e+01, 6.2230e+01, 6.6531e+01, 2.0134e+01, 5.2604

batch_predictions
tensor([[19.0243, 17.4581, 17.6297, 18.2166, 18.7383, 19.1549, 19.4802],
        [23.9817, 21.2026, 22.6025, 24.0614, 25.2922, 26.2244, 26.8527],
        [25.8738, 31.9554, 31.9496, 30.4144, 27.6514, 22.8687, 20.0991],
        [16.0614, 15.3316, 15.3511, 15.6765, 16.0676, 16.4373, 16.7544],
        [15.3954, 19.5537, 23.1630, 22.8104, 21.6961, 20.2051, 18.4744],
        [17.4504, 17.3394, 17.3724, 17.4159, 17.4494, 17.4796, 17.5116],
        [17.6577, 16.0919, 16.5795, 17.3018, 17.8858, 18.3242, 18.6480],
        [18.9616, 25.8629, 25.1893, 23.2386, 19.7039, 17.1489, 17.5282]])
ground truth
tensor([[31.6893, 18.7358, 16.8651,  9.9490, 10.8418, 19.7562, 29.7902],
        [10.4734, 20.7766, 15.1077, 18.2965, 23.2143, 26.7999, 37.7409],
        [23.9216, 45.0158, 39.7291, 20.7785, 13.8348,  9.7449, 22.6328],
        [26.6156, 28.5289, 15.9439, 12.0323, 14.9660, 16.1990, 26.8991],
        [13.0102, 18.4240, 17.8146, 20.8333, 29.2800, 31.6752,  8.4325],
        [ 0.0000,  

batch_predictions
tensor([[15.3835, 14.9581, 15.0958, 15.3941, 15.6953, 15.9454, 16.1295],
        [ 9.9483, 10.6387, 11.2726, 11.9954, 13.0376, 15.0069, 17.9680],
        [16.9881, 16.2331, 17.3553, 18.2779, 18.9688, 19.4377, 19.7256],
        [ 7.5722,  8.6342,  9.1512,  9.5007,  9.7876, 10.0645, 10.3617],
        [24.0915, 33.0118, 34.5451, 33.1021, 30.2032, 24.8283, 20.9955],
        [31.3189, 29.8738, 29.7901, 30.1397, 30.5519, 30.9166, 31.2126],
        [14.5317, 14.1161, 14.1784, 14.3894, 14.6272, 14.8437, 15.0182],
        [20.9468, 25.2390, 24.3420, 22.6686, 20.4328, 18.1638, 17.6369]])
ground truth
tensor([[ 3.2614,  7.5355,  7.7854, 11.5466, 17.0437, 17.6618, 15.5839],
        [ 3.8664, 11.5071, 10.0473, 10.1262, 13.4140,  7.0095,  6.8911],
        [ 2.7778,  0.0000, 11.7772, 10.7143, 12.3299, 13.8605, 33.4467],
        [12.4934,  7.8511,  8.6402, 12.6512, 18.1483, 15.5050,  5.1420],
        [23.6395, 23.7245, 26.7715, 39.0023, 51.9841, 39.9660, 24.7449],
        [32.8774, 2

batch_predictions
tensor([[ 5.6672,  5.6020,  5.5965,  5.6023,  5.6058,  5.6069,  5.6084],
        [21.7323, 29.9862, 32.3392, 31.2933, 28.7759, 23.0644, 18.8753],
        [20.7116, 18.5683, 19.6962, 20.7582, 21.6029, 22.2120, 22.6176],
        [15.7320, 19.7838, 20.6311, 19.6456, 18.4039, 17.1437, 16.1514],
        [26.5828, 29.9047, 29.8791, 28.9603, 27.4918, 25.4087, 22.9352],
        [17.9326, 17.0631, 18.4725, 19.7268, 20.7069, 21.3426, 21.6724],
        [25.4857, 30.0269, 29.6977, 28.4535, 26.5702, 23.5471, 20.5837],
        [25.9225, 30.6021, 30.4274, 29.0374, 26.5310, 22.1752, 19.4341]])
ground truth
tensor([[ 8.2851,  6.0363,  0.0000,  4.2215, 15.7023, 11.3361,  7.2330],
        [21.5986, 20.8333, 43.3532, 41.5533, 32.8656, 15.5329, 17.9847],
        [13.1803, 10.2183, 14.1298, 15.5896, 20.0113, 31.8027, 27.9478],
        [ 0.0000, 10.2578, 17.0437,  6.6938, 28.2614, 26.0521, 13.5060],
        [19.5423, 25.8548, 25.1973, 33.2983, 45.6996, 19.9895, 17.2278],
        [ 7.4688, 1

batch_predictions
tensor([[28.3604, 28.5642, 28.5597, 28.4830, 28.3989, 28.3438, 28.3276],
        [19.4999, 19.5757, 19.5457, 19.4859, 19.4345, 19.4083, 19.4074],
        [34.6759, 38.1651, 38.0537, 37.8588, 37.6996, 37.5860, 37.5143],
        [38.2336, 38.0078, 38.0457, 38.0745, 38.0917, 38.1055, 38.1186],
        [29.1616, 28.0545, 27.7862, 27.9035, 28.1864, 28.5455, 28.9305],
        [25.8013, 26.7935, 26.8110, 26.4895, 26.0407, 25.5949, 25.2281],
        [14.3054, 18.0624, 20.0620, 18.5295, 16.4094, 14.0980, 13.0567],
        [18.2213, 16.7517, 17.7544, 18.5523, 19.1050, 19.4716, 19.7146]])
ground truth
tensor([[20.6602, 16.8464, 18.4903, 31.6018, 33.0615, 14.8606, 14.1373],
        [16.7543, 17.9116, 21.5150, 19.2399, 26.8937, 11.4019, 15.3998],
        [27.1684, 25.6519, 39.9802, 59.4813, 49.4331, 27.5794, 19.7562],
        [55.9382, 52.4802, 25.5385, 19.2460, 23.9938, 36.5363, 58.8861],
        [12.6984, 14.5975, 16.7234, 18.2540, 30.0028, 48.9938, 35.8135],
        [45.5782, 3

batch_predictions
tensor([[15.5954, 14.4926, 15.6160, 16.5486, 17.1245, 17.4529, 17.6482],
        [18.1368, 18.1453, 18.1423, 18.1239, 18.1196, 18.1176, 18.1108],
        [16.9889, 19.3370, 19.0930, 18.5333, 18.0570, 17.7406, 17.5814],
        [37.0302, 37.4760, 37.4635, 37.4008, 37.3508, 37.3228, 37.3126],
        [20.9037, 18.0415, 19.7485, 21.5324, 23.3113, 24.7654, 25.5582],
        [21.6781, 21.7822, 21.7721, 21.7073, 21.6360, 21.5890, 21.5758],
        [17.3993, 17.8439, 17.7436, 17.5428, 17.3793, 17.2834, 17.2456],
        [23.3469, 24.9176, 24.7952, 24.1595, 23.3551, 22.5679, 21.9282]])
ground truth
tensor([[ 5.9442,  8.7980,  0.9469, 18.5692, 12.1120, 21.0547, 20.8969],
        [22.0410, 15.2551,  9.9684, 13.3745, 16.0179, 16.5439, 14.5581],
        [ 4.8133, 10.1920, 10.7575,  8.8901, 18.7796, 23.3824,  2.5381],
        [23.5686, 27.1684, 25.6519, 39.9802, 59.4813, 49.4331, 27.5794],
        [15.3628, 11.8764, 11.2245, 12.1882, 19.9121, 43.1122, 37.0890],
        [35.1332, 3

batch_predictions
tensor([[17.8489, 20.9874, 20.6217, 19.7973, 18.9816, 18.3308, 17.9121],
        [21.7169, 21.2982, 21.2911, 21.4100, 21.5468, 21.6767, 21.7947],
        [ 8.6516,  9.6511, 10.2893, 10.8430, 11.4355, 12.2111, 13.4099],
        [13.7717, 13.0722, 13.8778, 14.8499, 15.5504, 15.8963, 16.0507],
        [11.7778, 12.7849, 13.9958, 15.0604, 15.1836, 14.7848, 14.2732],
        [20.5162, 18.6504, 18.6160, 19.2334, 19.8676, 20.4252, 20.9016],
        [29.4114, 23.9185, 25.5584, 27.5498, 29.3228, 30.9199, 32.4355],
        [18.8834, 20.1536, 19.8025, 19.0389, 18.2219, 17.5352, 17.0950]])
ground truth
tensor([[14.5125, 17.0210, 25.9495, 13.8322,  4.5210,  9.2120, 21.5136],
        [30.5414, 26.9416, 15.0652, 14.6967, 14.2007, 13.7472, 20.5641],
        [ 5.8815,  7.4972,  8.7727, 12.6134, 14.4558, 21.9529,  3.2880],
        [ 7.9300, 10.1789,  7.1278, 10.9416, 15.2420, 24.6844, 15.9916],
        [18.4903, 21.1599,  3.9584,  8.5613,  9.2320, 10.9942, 12.2304],
        [11.2670, 1

batch_predictions
tensor([[ 5.8719,  6.3689,  6.4575,  6.4647,  6.4619,  6.4532,  6.4520],
        [16.7290, 20.8866, 21.7166, 20.6902, 19.2235, 17.6054, 16.4716],
        [22.9847, 28.3660, 27.7276, 26.5937, 25.3130, 23.8695, 22.1921],
        [11.3595, 14.1303, 18.4267, 20.6671, 19.2693, 17.2861, 15.0321],
        [20.2409, 20.6294, 20.4021, 20.0592, 19.7770, 19.6033, 19.5375],
        [22.0057, 28.5774, 27.8265, 26.1812, 23.6493, 20.3142, 18.9519],
        [37.8440, 37.4258, 37.4748, 37.5358, 37.5761, 37.6063, 37.6332],
        [18.9985, 23.6092, 22.4862, 20.2007, 17.2546, 15.7530, 15.8865]])
ground truth
tensor([[ 1.8543,  9.0084, 11.5071, 17.7275, 18.9506, 27.6433,  6.3388],
        [13.9881, 16.2982, 24.0930, 34.8498, 19.2602, 13.1661, 16.9643],
        [22.8695, 16.4782, 18.4903, 25.6181, 45.0684, 32.9432, 21.1994],
        [ 9.6372,  7.1003, 10.3741, 19.6712, 32.2279, 17.3895,  6.9728],
        [17.5028, 16.9926, 25.2551, 26.4314,  2.2959, 13.6054, 13.4779],
        [17.0351, 2

batch_predictions
tensor([[19.9788, 20.0513, 20.0337, 19.9872, 19.9455, 19.9242, 19.9238],
        [28.0120, 23.7328, 24.8322, 26.3741, 27.7841, 28.9252, 29.7636],
        [14.9360, 14.4925, 15.7658, 16.8037, 17.4194, 17.7215, 17.8828],
        [20.6895, 23.5979, 22.9825, 22.0392, 21.0838, 20.1786, 19.4403],
        [21.1244, 20.9643, 20.9808, 21.0255, 21.0700, 21.1164, 21.1679],
        [20.5183, 20.4420, 20.4518, 20.4660, 20.4714, 20.4821, 20.5073],
        [18.6885, 21.6523, 20.9798, 19.8164, 18.5601, 17.3975, 16.6580],
        [25.9764, 35.1780, 35.4316, 34.2768, 32.4243, 29.7835, 25.1368]])
ground truth
tensor([[24.0400, 19.2004, 17.1883,  8.9295, 13.5455, 12.2962, 11.9411],
        [18.3673, 19.2319, 12.2166, 31.2075, 31.5618, 41.6241, 23.9654],
        [ 3.3022,  7.3980,  5.6689,  8.4184, 15.4478, 16.4399, 23.7528],
        [19.4240, 13.3482, 11.2178, 17.6618, 26.4072,  9.4424, 46.7254],
        [27.4592, 11.3361, 17.7275, 13.1773, 11.3230, 15.9127, 26.8017],
        [39.1723, 2

batch_predictions
tensor([[ 8.1393,  8.6809,  8.9375,  9.0679,  9.1291,  9.1580,  9.1763],
        [17.3822, 16.1417, 17.7739, 19.1590, 20.0490, 20.4585, 20.6559],
        [16.7624, 16.0108, 15.9915, 16.2601, 16.5787, 16.8728, 17.1235],
        [14.4095, 17.2281, 17.9425, 16.7981, 15.3529, 13.7856, 12.4093],
        [37.5675, 38.1942, 38.1553, 38.0978, 38.0580, 38.0344, 38.0221],
        [ 9.1702, 10.5368, 11.8146, 13.6323, 16.9993, 20.0828, 18.9946],
        [37.4011, 38.4058, 38.3991, 38.3946, 38.3892, 38.3843, 38.3791],
        [ 7.7952,  8.3186,  8.4935,  8.5513,  8.5634,  8.5640,  8.5666]])
ground truth
tensor([[ 9.7506,  6.8594,  6.5901, 16.1990, 17.6587, 20.8192,  8.9286],
        [27.6361, 12.2449, 15.8730, 17.4036, 16.2698, 24.3764, 27.9620],
        [10.5471, 12.7433, 10.1262, 14.4792, 19.5423, 18.1878, 16.9385],
        [16.0836, 22.6460,  9.9947,  7.5618,  7.4566,  8.8112, 10.8364],
        [28.8549, 22.5057, 29.1241, 37.3866, 56.7460, 50.1701, 32.6531],
        [ 8.5601,  

batch_predictions
tensor([[20.9619, 20.9211, 20.9289, 20.9363, 20.9292, 20.9266, 20.9397],
        [ 4.8148,  4.7945,  4.7986,  4.8038,  4.8060,  4.8066,  4.8078],
        [15.0577, 15.0832, 15.0798, 15.0585, 15.0373, 15.0289, 15.0344],
        [19.3289, 18.7416, 18.6849, 18.8285, 19.0191, 19.2078, 19.3804],
        [17.9201, 16.1820, 16.7953, 17.6150, 18.2778, 18.7711, 19.1325],
        [23.0479, 21.2777, 21.1156, 21.6197, 22.2201, 22.7680, 23.2346],
        [14.6713, 17.1961, 17.4268, 16.7723, 16.1122, 15.5914, 15.2296],
        [18.2017, 21.4339, 20.9285, 20.0304, 19.1577, 18.4640, 18.0231]])
ground truth
tensor([[16.4116, 21.6978, 28.7840, 34.7506, 17.5454, 17.7863, 10.7285],
        [ 4.6769,  5.0595, 10.6151, 11.5363,  7.5113,  4.4501,  4.1100],
        [21.3435, 25.5669, 11.6638, 11.4938,  9.0561, 11.4512, 13.9598],
        [32.4546, 30.3571,  0.5385, 13.6480, 13.7897, 15.6604, 23.0300],
        [ 3.8138,  0.0000,  1.0652, 12.7564, 15.1894, 28.5376, 28.7743],
        [13.1519, 1

batch_predictions
tensor([[21.2413, 20.1559, 20.0146, 20.2711, 20.6210, 20.9572, 21.2525],
        [16.7116, 18.7913, 18.4948, 17.7291, 16.9793, 16.3794, 15.9754],
        [28.0379, 25.2866, 24.0841, 24.1119, 24.8596, 25.9159, 27.1540],
        [19.1214, 18.0562, 17.8963, 18.1403, 18.4850, 18.8242, 19.1323],
        [19.9484, 19.6480, 19.6565, 19.7344, 19.8113, 19.8799, 19.9432],
        [18.2487, 22.2580, 21.4172, 19.8690, 18.0744, 16.5554, 16.0014],
        [17.6668, 22.5318, 22.6225, 21.5847, 19.9940, 18.1547, 16.9436],
        [30.2524, 34.7949, 34.8411, 33.9962, 32.8053, 31.4516, 29.8838]])
ground truth
tensor([[13.6763, 12.3299, 15.4478, 15.3486, 20.8617, 31.2075,  9.3112],
        [10.8101, 13.5718, 24.6318,  5.4708,  5.5234,  7.1541,  3.1168],
        [50.5102, 41.7375, 16.2415, 20.6349, 19.3169, 19.4019, 33.9569],
        [13.7559, 12.2173,  9.7712, 14.2031, 16.9253, 19.5029, 24.6186],
        [16.9359,  6.9019, 10.5726, 16.6383, 17.9280, 24.3622, 21.2443],
        [14.1636,  

batch_predictions
tensor([[11.1080,  9.2732,  7.7895,  7.2636,  7.1588,  7.1395,  7.1327],
        [15.1150, 18.8999, 19.4025, 18.4441, 17.3524, 16.3329, 15.5064],
        [17.6398, 20.7753, 19.9780, 18.4473, 16.6720, 15.1541, 14.6668],
        [17.0498, 18.1145, 17.7878, 17.1837, 16.6365, 16.2521, 16.0513],
        [17.6550, 16.2958, 17.6178, 18.7084, 19.4773, 19.9351, 20.1924],
        [19.0482, 20.2819, 19.8223, 18.8948, 17.8634, 16.9770, 16.4402],
        [24.6154, 24.0412, 23.9675, 24.0897, 24.2729, 24.4776, 24.6864],
        [16.6284, 16.3959, 16.4274, 16.5075, 16.5830, 16.6487, 16.7068]])
ground truth
tensor([[ 9.0561,  0.2834,  6.6893, 10.3033,  7.9507, 10.1474, 13.1519],
        [15.2946, 27.7617, 24.0400,  8.3772, 10.0342, 10.8364, 10.3761],
        [15.9653, 12.8090, 17.1226, 23.3693, 27.3014, 10.3367, 12.3225],
        [23.6323,  5.9705, 11.6649,  9.6002,  9.0873, 12.7696, 17.8985],
        [15.1894,  6.0889, 15.0710, 13.6902, 17.2935, 23.4219, 31.1547],
        [13.0385, 1

batch_predictions
tensor([[18.4307, 17.0243, 17.6552, 18.3056, 18.7799, 19.1140, 19.3500],
        [ 8.4135,  9.6931, 10.4771, 11.1843, 11.9970, 13.0844, 14.4782],
        [15.9590, 14.6808, 15.4818, 16.2203, 16.7563, 17.1128, 17.3438],
        [27.4363, 24.2918, 23.2266, 23.6590, 24.6820, 25.8899, 27.1831],
        [12.0440, 14.2345, 16.3567, 17.6333, 17.5531, 17.2068, 16.8938],
        [18.2764, 24.0697, 23.2947, 21.1275, 17.8612, 16.1994, 16.3985],
        [31.7018, 31.2914, 31.2942, 31.3901, 31.4944, 31.5967, 31.6945],
        [21.9603, 31.0419, 34.3959, 33.4423, 30.1418, 22.1602, 18.5859]])
ground truth
tensor([[25.2834, 14.2007, 13.6196, 17.4461, 17.9138, 17.0777, 20.8333],
        [26.8707, 16.0289,  0.0000,  0.0000,  8.9569, 12.4150, 14.1865],
        [20.9892, 10.6576, 11.7063,  8.7868, 13.9598, 18.4807, 27.9904],
        [47.1514,  2.0408, 12.3724, 20.7200, 21.9529, 15.8588, 33.5176],
        [ 3.2880, 11.2245, 11.7063, 10.6151, 20.0255,  3.2313,  4.5210],
        [10.8702, 1

batch_predictions
tensor([[13.3766, 15.4168, 17.2711, 17.1798, 16.7418, 16.3779, 16.1411],
        [21.5432, 26.5664, 25.9460, 23.6634, 19.3537, 16.5976, 16.8708],
        [37.6537, 36.8890, 36.9218, 37.0466, 37.1501, 37.2285, 37.2917],
        [21.2882, 23.4366, 22.9841, 22.1725, 21.3344, 20.5795, 20.0086],
        [13.6864, 16.5435, 18.9635, 18.3769, 17.7487, 17.3130, 17.0910],
        [23.0212, 31.8507, 32.0713, 30.5604, 27.7534, 22.3127, 18.8279],
        [13.1566, 16.4094, 20.4947, 19.8394, 18.5220, 17.1260, 16.0912],
        [27.5934, 24.6710, 24.3794, 25.2887, 26.4141, 27.4649, 28.3502]])
ground truth
tensor([[12.7959,  8.8638, 11.3756, 14.2031, 21.3835, 13.2562, 11.7438],
        [24.6457, 31.8452, 21.8679,  5.7965, 12.6701, 11.1820, 14.2149],
        [42.7154, 32.5964, 27.9904, 26.7857, 32.4121, 38.2795, 56.4484],
        [17.9248,  1.4598,  0.0000, 10.3235, 10.2578, 13.5192, 22.2120],
        [19.9121, 23.8237, 21.4144, 12.0748, 14.1015, 11.0261, 11.9189],
        [19.9369, 2

batch_predictions
tensor([[15.6897, 20.0340, 22.1199, 21.1159, 19.4880, 17.3895, 15.8519],
        [18.3931, 23.9098, 26.7330, 26.1265, 24.4484, 21.0727, 17.4986],
        [12.8050, 16.0365, 20.6779, 19.4536, 17.2264, 14.8613, 13.4359],
        [19.1170, 25.6375, 25.0981, 22.8512, 18.6235, 16.2099, 16.4250],
        [19.6346, 23.6022, 22.8291, 20.8343, 17.9451, 16.0316, 16.0489],
        [14.9434, 17.0636, 17.2896, 16.6270, 15.8841, 15.2410, 14.7349],
        [15.5618, 15.3738, 16.5416, 17.4921, 18.1825, 18.6247, 18.8724],
        [28.1557, 27.6333, 27.5688, 27.6587, 27.7884, 27.9343, 28.0862]])
ground truth
tensor([[10.5602, 13.5455, 11.7570, 19.0295, 27.5118, 32.4829, 11.0074],
        [ 8.7443, 18.1831, 19.0334, 29.4076, 38.8180, 36.0119, 10.1616],
        [10.7143, 10.8277, 22.9308, 30.3571, 18.2540,  9.0561, 10.2041],
        [18.5955, 15.6365, 22.9748, 35.8759, 27.3672, 12.2962, 12.1778],
        [13.3877, 18.7533, 33.4824, 13.8611, 13.7033, 11.3756, 13.8348],
        [23.4613, 2

batch_predictions
tensor([[18.6706, 16.8795, 17.5358, 18.4727, 19.2451, 19.8335, 20.2806],
        [22.2476, 20.0421, 20.4643, 21.3086, 22.0459, 22.6230, 23.0570],
        [13.1459, 12.2687, 12.5032, 13.1463, 13.9829, 14.7594, 15.2065],
        [15.5244, 15.3503, 16.6788, 17.8245, 18.6687, 19.1454, 19.3516],
        [17.2086, 20.3122, 20.1372, 19.3723, 18.5853, 17.9473, 17.5237],
        [24.0463, 28.5279, 27.9154, 26.9471, 25.9606, 24.9687, 23.8711],
        [12.9843, 16.5306, 20.2977, 19.1323, 17.6699, 16.4280, 15.6482],
        [27.7570, 28.2185, 28.2373, 28.1008, 27.9113, 27.7474, 27.6459]])
ground truth
tensor([[31.4177,  5.1815, 16.7938,  5.7075, 15.7812, 16.2415, 21.3572],
        [ 9.8498, 14.8101, 18.1973, 12.6559, 20.7766, 22.6616, 31.0516],
        [11.3624,  3.5245,  7.0095,  8.5613,  9.0479, 12.2962, 24.0137],
        [12.0726, 13.9532, 14.2951, 14.7159, 14.4924, 12.1515, 20.7128],
        [11.9756, 12.5142, 12.8118, 17.8005, 28.5714, 24.9433, 12.0040],
        [15.2420, 2

batch_predictions
tensor([[28.3766, 24.2183, 24.9206, 26.3646, 27.7561, 28.9131, 29.7895],
        [20.5332, 26.5052, 25.3347, 22.1464, 17.1259, 16.3833, 16.6578],
        [20.9745, 28.0979, 27.0551, 24.6442, 20.3281, 16.9183, 17.3211],
        [38.2573, 38.1956, 38.2020, 38.2029, 38.2019, 38.2005, 38.1994],
        [16.6229, 20.4109, 20.4853, 19.5165, 18.3197, 17.1674, 16.3085],
        [ 8.3768,  9.0162,  9.2797,  9.4016,  9.4545,  9.4787,  9.4945],
        [17.1653, 21.1887, 21.0046, 20.0220, 18.8453, 17.7500, 16.9870],
        [19.9447, 18.3905, 19.1106, 19.9978, 20.6798, 21.1683, 21.5132]])
ground truth
tensor([[43.1615, 17.2278, 17.4250, 16.5176, 20.7391, 24.9868, 52.8538],
        [27.6644, 40.9155, 36.8764, 14.4983, 19.7988, 16.1565, 15.7313],
        [20.1814, 33.9711, 33.8719,  6.3067, 17.9138, 15.1219, 20.2098],
        [42.2935, 41.8201, 37.6381, 38.2167, 41.5834, 57.9037, 48.6981],
        [13.0385, 13.4495, 20.5924, 25.3968, 28.3588, 10.8135,  8.6876],
        [ 8.2720, 1

batch_predictions
tensor([[16.1204, 19.2967, 19.3874, 18.7601, 18.1278, 17.6445, 17.3468],
        [23.0892, 23.1307, 23.1252, 23.0958, 23.0677, 23.0560, 23.0621],
        [22.1500, 24.9896, 24.4377, 23.1306, 21.3044, 19.2480, 18.0971],
        [24.9614, 21.8637, 23.6178, 25.5447, 27.3083, 28.8321, 30.0953],
        [26.3066, 22.5848, 22.9735, 24.3113, 25.6935, 26.9147, 27.8925],
        [20.2270, 25.1031, 23.7594, 21.0533, 17.1960, 16.1034, 16.3517],
        [15.1933, 19.6613, 23.8570, 23.8228, 22.8187, 21.1746, 18.9474],
        [17.9139, 18.8317, 18.7455, 18.1993, 17.5410, 16.9817, 16.6237]])
ground truth
tensor([[12.5329, 14.9264, 13.8217, 20.8837, 22.8958, 10.7838, 14.5055],
        [18.5658, 13.9456, 14.0873, 22.8741, 32.3980, 30.8107, 23.5544],
        [21.5845, 33.1491, 39.7251,  7.2279, 11.9473, 15.0794, 18.3957],
        [35.9694, 13.9598, 18.2256, 12.4717, 16.0714, 21.9388, 44.3452],
        [41.0856,  8.1633, 13.9881, 22.7466, 19.3878, 22.7183, 31.7744],
        [17.8005, 2

batch_predictions
tensor([[21.3962, 20.6505, 20.5083, 20.6617, 20.9429, 21.2804, 21.6367],
        [18.4222, 23.2855, 26.3681, 26.3524, 25.4116, 23.5950, 20.2790],
        [22.4832, 24.2438, 23.9385, 23.1428, 22.1982, 21.2689, 20.5123],
        [24.4137, 26.1578, 25.9369, 25.3273, 24.6714, 24.1052, 23.6958],
        [11.9457, 12.9841, 13.9362, 14.6968, 14.9271, 14.7673, 14.5078],
        [30.4058, 33.9984, 34.0903, 33.3401, 32.2822, 31.1320, 29.9584],
        [14.5689, 14.2215, 15.4421, 16.4606, 17.0839, 17.3815, 17.5282],
        [16.5826, 20.5337, 20.8562, 20.0692, 19.1187, 18.2311, 17.5567]])
ground truth
tensor([[32.6407,  6.4703, 15.3077, 15.5576, 17.5302, 25.6970, 20.2525],
        [16.2020, 20.5550, 29.0242, 45.4629, 50.9600, 16.7675, 20.5944],
        [11.9411, 18.4114, 25.6970, 39.1899, 27.2488, 18.0037, 18.5692],
        [20.0026, 24.8948, 12.3225, 10.8496, 11.0205, 12.8880, 11.1915],
        [10.4550,  0.0000,  0.0000,  0.0000,  3.3798, 12.6249, 20.0684],
        [27.6696, 3

batch_predictions
tensor([[17.5341, 23.0499, 23.4629, 22.3737, 20.4363, 17.8221, 16.5151],
        [18.9962, 21.1969, 20.8357, 20.1758, 19.5733, 19.1216, 18.8490],
        [ 4.9919,  4.5272,  4.4956,  4.5195,  4.5392,  4.5522,  4.5618],
        [28.4997, 30.9835, 30.9410, 30.2160, 29.2093, 28.0474, 26.7808],
        [13.7045, 15.1376, 15.5818, 15.2707, 14.7623, 14.2755, 13.8616],
        [14.6671, 18.2327, 19.0028, 18.1277, 17.1689, 16.3191, 15.6416],
        [20.7110, 29.3179, 32.4340, 30.8187, 25.6901, 17.0859, 17.4247],
        [ 9.2343,  9.8862, 10.2433, 10.4772, 10.6406, 10.7651, 10.8707]])
ground truth
tensor([[16.9926, 10.9694, 15.2778, 22.8883, 32.8940, 28.9683,  9.5096],
        [16.7806, 13.9400, 14.6370, 15.4787, 29.1426, 28.4850, 29.1031],
        [10.1131,  3.8138,  3.3403, 14.5581, 13.7428, 12.0594, 17.3856],
        [14.9943, 16.0714, 27.9620, 34.4529, 31.9161, 20.2664, 18.0556],
        [23.0274, 28.0773, 10.9811, 14.1504,  7.8248, 11.3756, 12.3751],
        [14.4529, 1

batch_predictions
tensor([[16.6752, 16.4857, 16.5576, 16.6634, 16.7512, 16.8197, 16.8744],
        [18.2936, 20.7297, 20.3047, 19.4293, 18.5365, 17.7750, 17.2459],
        [20.8671, 21.0567, 21.0227, 20.9107, 20.7974, 20.7213, 20.6914],
        [18.6251, 21.7035, 20.5887, 18.6805, 16.2909, 14.7918, 14.7512],
        [15.0345, 18.6022, 20.0438, 19.3236, 18.5015, 17.8118, 17.3249],
        [18.1783, 17.7868, 17.7767, 17.8990, 18.0452, 18.1882, 18.3200],
        [19.3463, 22.7867, 21.2031, 18.2905, 14.5810, 13.8896, 14.0581],
        [18.8783, 22.6991, 21.8235, 20.4267, 18.8909, 17.4908, 16.7474]])
ground truth
tensor([[24.2205,  9.9206,  2.3951, 16.3549, 13.2370, 14.0164, 13.0244],
        [16.6808, 22.7466, 21.3294, 14.8384, 16.5249, 13.1236, 13.3787],
        [32.4263, 34.3679, 13.7613, 18.9484, 11.0544, 18.3248, 14.3566],
        [22.8741, 36.7630, 25.0567, 15.3486, 18.2965, 15.1644, 21.2018],
        [18.8350, 26.4031, 23.1009, 12.6984, 16.3265, 14.2290, 11.9473],
        [22.8432, 2

batch_predictions
tensor([[16.6081, 19.6885, 22.0958, 22.8351, 22.1711, 20.7828, 18.9326],
        [21.3375, 18.6246, 20.0767, 21.8005, 23.2160, 24.2049, 24.9201],
        [19.7318, 21.2196, 20.8287, 20.1579, 19.5169, 19.0135, 18.6989],
        [19.1893, 18.7428, 18.7248, 18.8564, 19.0162, 19.1713, 19.3123],
        [19.1855, 17.6750, 17.8207, 18.5242, 19.1958, 19.7557, 20.2097],
        [25.6176, 24.4399, 24.3382, 24.6229, 24.9828, 25.3165, 25.5982],
        [20.4288, 21.4589, 20.9649, 20.1641, 19.3686, 18.7372, 18.3662],
        [24.8676, 24.8028, 24.8018, 24.8077, 24.7991, 24.7988, 24.8184]])
ground truth
tensor([[ 7.8231, 10.1616, 17.9138, 29.5918, 21.2868,  8.1066, 13.9031],
        [ 5.4971, 14.6896, 13.6113, 14.7291, 15.0579, 25.0526, 31.5097],
        [13.0984, 15.5708, 17.6223, 23.8033, 31.6281, 10.4945, 14.3346],
        [28.9979, 12.6249, 13.5192, 11.3361, 11.9937, 19.4240, 25.7496],
        [27.4855, 10.3104, 13.6770, 15.8864, 19.3977, 19.4503, 21.7912],
        [18.1689, 1

batch_predictions
tensor([[20.5775, 21.8685, 21.6277, 20.9906, 20.3008, 19.7202, 19.3342],
        [22.1878, 21.9813, 21.9980, 22.0473, 22.0932, 22.1372, 22.1828],
        [16.3647, 20.1989, 20.6437, 20.0277, 19.3313, 18.7455, 18.3487],
        [20.7763, 25.5061, 23.9207, 20.9702, 16.8350, 16.1700, 16.5230],
        [21.7943, 27.6279, 27.0194, 24.8788, 20.9025, 17.4330, 17.6325],
        [20.5617, 25.4478, 24.6639, 22.5214, 19.1852, 16.9222, 17.0209],
        [17.7627, 17.8654, 17.8308, 17.7322, 17.6305, 17.5620, 17.5362],
        [15.3361, 18.9031, 19.7522, 18.8418, 17.7863, 16.8067, 16.0223]])
ground truth
tensor([[19.4766, 24.8685, 19.9435, 14.2294, 11.1849, 11.6781, 11.2046],
        [24.5134, 19.3582, 22.6986, 13.9006, 13.5981, 15.2946, 18.4377],
        [20.1672, 26.8991, 29.4218, 12.9677, 19.1610, 17.7721, 18.5374],
        [30.4563, 31.5476, 11.6922, 17.5170, 15.1927, 16.6950, 19.5011],
        [19.6003, 25.1134, 40.8163, 33.6026,  0.0000, 10.6293, 19.5437],
        [21.6465, 3

batch_predictions
tensor([[21.0640, 24.1587, 23.5318, 22.4736, 21.3146, 20.1232, 19.1720],
        [16.4618, 17.8641, 17.6598, 17.1162, 16.6079, 16.2367, 16.0205],
        [20.9542, 20.4166, 20.3665, 20.4953, 20.6668, 20.8403, 21.0029],
        [13.1568, 16.8399, 20.6623, 19.0507, 16.7007, 14.2616, 13.0281],
        [21.0916, 24.1417, 23.6242, 22.4251, 20.9049, 19.2956, 18.2789],
        [ 9.1746, 10.2453, 11.1288, 12.1330, 13.5778, 15.8618, 17.5941],
        [10.7051, 11.6642, 12.7085, 14.0827, 15.5195, 15.6551, 15.2362],
        [13.3810, 17.6778, 21.5200, 20.1169, 17.9969, 15.6296, 14.1076]])
ground truth
tensor([[19.9500,  7.7065,  8.7322, 11.5071, 19.8317, 18.7138, 20.8837],
        [11.8622, 11.2387, 11.3520, 15.0935, 21.1026, 27.4376, 18.1831],
        [29.6627, 23.7670, 12.4291, 18.4949, 17.9847, 17.2902, 21.5420],
        [ 9.1268, 19.4240,  9.0216, 17.7012, 24.5529, 24.1057,  8.5613],
        [19.8185, 26.4072, 44.7922, 28.4061, 16.9911, 17.8327, 18.8585],
        [ 9.3766, 1

batch_predictions
tensor([[18.7785, 18.8998, 18.8469, 18.7363, 18.6336, 18.5708, 18.5529],
        [14.3161, 16.3082, 16.7220, 16.3641, 15.9629, 15.6656, 15.4839],
        [12.7016, 12.6504, 12.6384, 12.6483, 12.6423, 12.6316, 12.6291],
        [38.0700, 38.3364, 38.3182, 38.2968, 38.2827, 38.2749, 38.2715],
        [14.0245, 17.4907, 19.5483, 18.6899, 17.7436, 16.9737, 16.4427],
        [21.6248, 23.3292, 22.7429, 21.7743, 20.7162, 19.7180, 18.9930],
        [17.5526, 19.1366, 18.7397, 17.7626, 16.6523, 15.6763, 15.0715],
        [12.8930, 16.4859, 20.9718, 20.0519, 18.5336, 16.9988, 16.0261]])
ground truth
tensor([[11.9615, 16.8226, 25.1276, 26.8991, 12.1457, 12.7693, 19.1468],
        [ 9.9290, 14.0058, 12.6381, 23.6718, 10.2052, 13.6376, 11.1257],
        [12.5592, 17.4250, 21.4361,  8.6665, 11.5992, 11.2441, 11.6781],
        [28.6565, 26.4172, 27.2534, 44.3736, 65.2069, 49.7449, 34.4813],
        [14.6502, 27.1305, 18.0957,  0.0000,  9.3503, 11.7044, 10.6391],
        [24.3030, 2

batch_predictions
tensor([[17.6362, 23.2194, 28.7068, 30.8555, 30.7425, 29.9227, 28.7302],
        [16.4602, 15.7026, 15.6723, 15.9119, 16.1933, 16.4499, 16.6676],
        [15.0271, 17.4748, 18.0673, 17.5587, 17.0004, 16.5698, 16.2911],
        [18.1002, 22.7844, 22.7783, 21.4620, 19.3858, 17.0523, 16.1556],
        [20.6021, 27.5332, 27.2663, 25.6985, 22.6724, 19.0125, 18.2380],
        [12.6444, 14.8157, 16.7517, 16.6719, 16.0206, 15.3555, 14.7812],
        [27.5909, 31.2567, 31.2198, 30.2109, 28.6212, 26.3002, 23.4144],
        [10.9084, 12.9928, 15.3281, 17.7345, 18.5500, 18.4807, 18.1916]])
ground truth
tensor([[13.8889, 14.2715, 14.1865, 13.7613, 20.7058,  6.3350,  4.2800],
        [19.6344,  7.7722, 11.7307, 12.4803, 12.6907, 16.0047, 22.6854],
        [12.8685, 21.0176, 28.9683, 20.6916, 12.4008, 13.7188, 13.3220],
        [12.7409,  9.2262, 22.6616, 21.2443, 32.8798, 24.6882, 16.6950],
        [10.0198, 21.5561, 45.7200, 38.7755, 22.5482, 19.2177, 24.3764],
        [13.6902, 2

batch_predictions
tensor([[18.6613, 24.5567, 26.1993, 25.5068, 23.9939, 21.2051, 18.0059],
        [24.2699, 29.3836, 29.0434, 27.6034, 25.1926, 21.3596, 19.1217],
        [20.7060, 21.5210, 21.3393, 20.9940, 20.6959, 20.5018, 20.4119],
        [18.5768, 18.5421, 18.5630, 18.5741, 18.5774, 18.5846, 18.5997],
        [11.1358, 13.8845, 16.7291, 18.4149, 18.6724, 18.3356, 17.8941],
        [ 9.1749,  9.5116,  9.6998,  9.8043,  9.8522,  9.8689,  9.8736],
        [20.9781, 23.4659, 23.0524, 21.9625, 20.5680, 19.1199, 18.1479],
        [17.7548, 17.7866, 17.7626, 17.7251, 17.6961, 17.6854, 17.6899]])
ground truth
tensor([[18.0556, 20.7058, 24.1071, 41.1706, 26.4881,  8.7727, 13.9881],
        [30.0028, 50.0850, 34.5663, 29.7052, 22.5907, 14.9943, 16.0714],
        [17.7670, 23.0668, 23.8953, 10.7575,  1.5255, 15.8206, 14.7422],
        [ 8.9144, 10.2466, 11.2387,  9.7647,  9.6655, 18.4099, 21.1310],
        [ 1.2330, 12.6842,  9.7647, 14.8243, 22.7891, 22.1372, 11.6780],
        [10.4813,  

batch_predictions
tensor([[19.1401, 25.3488, 25.0844, 23.4595, 20.3758, 17.3237, 17.0531],
        [19.2591, 21.7789, 20.8221, 19.2449, 17.2724, 15.6463, 15.3054],
        [ 4.8487,  4.9672,  4.9666,  4.9636,  4.9618,  4.9555,  4.9494],
        [37.3487, 38.3100, 38.2484, 38.1810, 38.1349, 38.1058, 38.0885],
        [19.3723, 23.0660, 22.1647, 20.5236, 18.4802, 16.6634, 16.2514],
        [32.6525, 27.6758, 27.9691, 29.3434, 30.7618, 32.0544, 33.1817],
        [25.2369, 26.7456, 26.5932, 26.0048, 25.2898, 24.5964, 24.0167],
        [23.7331, 20.9878, 21.3686, 22.3432, 23.2545, 24.0076, 24.5976]])
ground truth
tensor([[24.2347, 28.6848, 41.6525, 22.2222, 10.4875, 16.7375, 18.6224],
        [19.8554, 30.8957,  8.9569, 12.1740, 12.7693, 13.0244, 16.4683],
        [ 4.1383,  5.1587,  5.7115,  6.0516,  7.0437, 11.2103, 13.4637],
        [35.7993, 53.4297, 50.6519, 31.7319, 24.9433, 20.9042, 23.3277],
        [15.0184, 16.5308, 24.3556, 26.4072,  8.7849, 14.1899, 12.3225],
        [41.5958, 2

batch_predictions
tensor([[20.5815, 18.6805, 18.8715, 19.5664, 20.2002, 20.7147, 21.1204],
        [20.5885, 28.7411, 29.3137, 27.8609, 24.5647, 18.6956, 18.0083],
        [38.2652, 38.1792, 38.1896, 38.1923, 38.1920, 38.1910, 38.1902],
        [14.2874, 17.1906, 17.8774, 17.3446, 16.8411, 16.4947, 16.3053],
        [32.7856, 32.5465, 32.5374, 32.5671, 32.5927, 32.6259, 32.6733],
        [19.4028, 24.6998, 24.0793, 22.6386, 20.6641, 18.8715, 18.1671],
        [26.7259, 28.2003, 28.0733, 27.5444, 26.9208, 26.3385, 25.8725],
        [24.7293, 30.3716, 30.4643, 29.1070, 26.5866, 22.3011, 19.6400]])
ground truth
tensor([[29.1525,  3.5856, 10.3033, 13.9881, 13.6338, 19.7562, 31.4201],
        [21.3152, 36.1536, 42.0777, 33.2341, 19.7704, 18.1689, 15.1502],
        [41.3204, 43.2799, 42.8064, 54.3135, 52.0779, 43.3456, 54.1163],
        [16.9648, 18.5692, 26.5781, 24.0531,  8.0615,  8.0747, 15.8338],
        [55.6418, 14.1899,  0.0000, 15.4787, 21.9227, 27.8538, 38.6770],
        [13.6196, 3

tensor([[18.3788, 20.2729, 19.9047, 18.9293, 17.7951, 16.7721, 16.1294],
        [17.9841, 16.4292, 17.7469, 18.8318, 19.5851, 20.0365, 20.3008],
        [20.1407, 18.1150, 18.3208, 19.1198, 19.8655, 20.4846, 20.9863],
        [16.8921, 21.4760, 20.5646, 18.6561, 16.3287, 14.7735, 14.5286],
        [18.0152, 17.8580, 17.8922, 17.9522, 18.0042, 18.0504, 18.0952],
        [21.1709, 28.6859, 28.3872, 26.7843, 23.6553, 19.3857, 18.3280],
        [18.0462, 18.2850, 18.1796, 17.9968, 17.8420, 17.7506, 17.7214],
        [22.4830, 30.1977, 29.7941, 27.4137, 22.0538, 17.3647, 17.7975]])
ground truth
tensor([[21.8701, 30.0368, 10.0736, 11.2835,  9.1005, 13.4534, 18.2272],
        [11.2245, 14.2715, 16.3124, 11.6780, 17.3895, 27.0833, 29.5777],
        [26.9133, 12.3866, 11.7489, 11.1961, 20.0397, 26.5448, 24.8866],
        [12.7693, 13.0244, 16.4683, 28.0471, 29.3084, 28.6990, 13.2086],
        [14.1723, 16.1990, 19.7988, 27.6927, 11.3520,  6.9444, 16.3832],
        [17.5454, 30.0028, 50.0850, 3

batch_predictions
tensor([[20.7141, 23.2897, 22.6998, 21.6362, 20.4225, 19.1979, 18.3190],
        [38.3120, 38.3008, 38.3030, 38.3036, 38.3037, 38.3037, 38.3037],
        [15.9778, 15.9263, 15.9786, 16.0251, 16.0566, 16.0830, 16.1077],
        [15.5139, 17.5281, 17.7367, 17.2627, 16.7713, 16.4028, 16.1734],
        [26.5461, 29.1181, 28.7757, 28.1313, 27.5187, 27.0197, 26.6594],
        [ 6.3801,  5.6794,  5.5723,  5.5817,  5.6002,  5.6127,  5.6216],
        [25.3132, 30.3704, 30.1478, 28.8545, 26.7527, 23.2326, 20.1795],
        [23.9880, 30.9490, 30.6601, 28.5397, 24.0076, 18.5858, 18.6160]])
ground truth
tensor([[18.5034, 27.3803, 27.6039, 16.6360, 13.2167, 13.9400, 12.5855],
        [38.5324, 45.5155, 53.4850, 55.9179, 38.1378, 40.7286, 22.9090],
        [19.9500, 18.7270, 18.4508, 11.4940, 14.8869, 14.0847, 20.0815],
        [ 9.0347, 10.9022, 10.8101, 14.4661, 19.8185, 13.3482, 11.2441],
        [22.0673, 26.4203, 39.3477, 43.2667, 39.2031, 30.2078, 29.1294],
        [ 6.0363,  

batch_predictions
tensor([[22.6389, 31.5711, 32.9758, 31.8910, 29.8217, 25.9165, 20.6946],
        [38.3277, 38.1904, 38.2059, 38.2105, 38.2102, 38.2087, 38.2072],
        [18.8561, 17.7091, 19.0189, 20.2301, 21.2527, 22.0495, 22.5898],
        [15.9331, 16.5027, 16.3469, 15.9374, 15.5289, 15.2291, 15.0645],
        [29.6701, 25.6571, 25.2159, 26.2856, 27.6161, 28.8789, 29.9822],
        [13.5358, 12.6267, 12.9665, 13.7691, 14.6460, 15.2748, 15.6454],
        [36.8777, 36.6424, 36.6503, 36.6785, 36.6979, 36.7201, 36.7495],
        [19.7175, 17.9887, 18.1367, 18.7877, 19.3867, 19.8774, 20.2697]])
ground truth
tensor([[15.4337, 23.6395, 23.7245, 26.7715, 39.0023, 51.9841, 39.9660],
        [49.6712, 44.2662, 50.1447, 36.9937, 41.6623, 47.6854, 66.2941],
        [ 9.8764, 14.3740, 12.2962, 11.5729, 15.7680, 18.2404, 28.8138],
        [15.8732, 12.7696, 17.6486, 25.8154, 25.4866,  8.0089, 12.8880],
        [21.7517, 16.2020, 18.3719, 18.5955, 23.9216, 45.0158, 39.7291],
        [ 4.4845,  

batch_predictions
tensor([[18.7484, 19.6519, 19.3026, 18.8119, 18.4112, 18.1579, 18.0508],
        [18.2183, 20.8106, 20.1448, 19.1302, 18.1467, 17.3191, 16.7678],
        [20.5116, 22.8690, 22.5759, 21.4527, 19.9081, 18.3228, 17.3743],
        [25.4246, 21.9207, 22.7243, 23.9691, 25.1304, 26.1046, 26.8563],
        [19.6241, 19.1897, 19.1620, 19.2777, 19.4206, 19.5618, 19.6937],
        [15.7177, 14.8366, 15.8391, 16.5965, 17.0896, 17.3927, 17.5800],
        [ 9.9305, 11.3007, 12.7564, 14.8462, 17.3390, 17.5247, 16.9579],
        [18.4325, 22.9971, 22.4682, 21.3957, 20.1622, 19.0726, 18.4381]])
ground truth
tensor([[ 9.9684, 13.3745, 16.0179, 16.5439, 14.5581, 19.2925, 14.6765],
        [15.0935, 23.4127, 20.0539, 21.7545, 31.9728, 26.0062, 17.4603],
        [17.8590, 23.9742, 20.8837, 11.8359,  8.0747, 11.1389, 15.8732],
        [39.6542, 21.0601, 17.1627, 12.7268, 10.0198, 21.5561, 45.7200],
        [25.8220, 12.2307, 12.5000,  8.5176, 11.3520, 17.8005, 23.3277],
        [20.6916,  

batch_predictions
tensor([[23.2419, 21.0704, 21.2958, 22.0872, 22.8184, 23.4048, 23.8541],
        [15.2733, 15.1345, 15.1633, 15.2105, 15.2513, 15.2886, 15.3243],
        [20.6057, 24.3157, 23.4638, 21.6667, 19.1021, 16.9775, 16.7765],
        [20.5463, 20.6331, 20.6123, 20.5529, 20.4965, 20.4648, 20.4610],
        [13.4510, 15.3540, 16.1809, 15.9346, 15.4880, 15.1022, 14.8219],
        [29.0579, 32.3751, 32.6047, 31.7256, 30.2139, 28.0983, 25.3279],
        [25.5400, 26.3336, 26.2136, 25.9470, 25.7132, 25.5615, 25.4911],
        [ 4.8065,  4.7688,  4.7718,  4.7780,  4.7812,  4.7824,  4.7840]])
ground truth
tensor([[16.5391, 17.0210, 15.3770, 16.2698, 32.8515, 24.7732, 21.5420],
        [12.3751, 11.8227, 11.3887, 14.6239, 15.1631, 23.4087, 15.0579],
        [19.7846, 30.8673, 30.9382, 10.7710, 16.9926, 10.9694, 15.2778],
        [18.1406, 11.5363, 12.7693, 13.2228, 14.3282, 22.5907, 25.2834],
        [14.1504, 20.5681, 24.4871,  4.7870,  9.2057, 12.2173, 15.2157],
        [50.0000, 3

batch_predictions
tensor([[28.3020, 34.3420, 34.4940, 33.1395, 30.7334, 26.6483, 22.2205],
        [13.2592, 17.1200, 21.6441, 21.1642, 19.8998, 18.3398, 17.0018],
        [28.2214, 28.7117, 28.5975, 28.4199, 28.2848, 28.2083, 28.1785],
        [18.3446, 18.8861, 18.6527, 18.3114, 18.0401, 17.8807, 17.8237],
        [21.7215, 25.7618, 25.0136, 22.9872, 19.5687, 16.9927, 17.0112],
        [22.3957, 23.8622, 23.5755, 22.9125, 22.1969, 21.5696, 21.1139],
        [16.0063, 18.0698, 17.8329, 17.0729, 16.3234, 15.7154, 15.2865],
        [14.0693, 15.9852, 16.3999, 15.7501, 14.9142, 14.0874, 13.2866]])
ground truth
tensor([[ 6.3776,  0.0000, 13.6196, 32.4405, 14.3141,  8.6876, 16.5249],
        [12.6134,  9.6514, 13.3503, 18.6933, 33.8294, 29.0249,  7.4263],
        [27.0252, 39.1899, 10.6260, 15.6760, 28.9321, 23.4219, 26.3940],
        [15.5445, 13.2167, 14.6370, 14.3872, 22.6854, 23.7507, 12.4540],
        [27.4235, 30.6831,  0.1559, 10.2041, 10.3458, 13.8605, 13.2653],
        [27.4660, 2

batch_predictions
tensor([[20.5080, 19.6235, 19.4798, 19.6602, 19.9413, 20.2334, 20.5064],
        [14.4095, 14.0661, 15.0654, 15.9257, 16.5349, 16.9051, 17.1040],
        [14.3746, 16.2404, 16.7680, 16.4599, 16.0451, 15.7136, 15.4930],
        [23.8689, 31.7835, 31.5230, 30.0902, 27.8750, 24.2049, 20.2932],
        [18.8652, 19.5745, 19.4158, 19.0856, 18.7961, 18.6082, 18.5222],
        [38.2490, 36.2783, 36.4664, 36.9353, 37.2598, 37.4506, 37.5671],
        [ 7.3970,  4.8564,  4.4242,  4.4299,  4.4665,  4.4986,  4.5255],
        [15.2560, 14.3939, 14.4310, 14.9337, 15.6255, 16.3987, 17.1899]])
ground truth
tensor([[ 0.0000,  0.0000,  6.4703, 14.7422, 15.2420, 20.5287, 31.5623],
        [11.1395,  9.9206,  9.7789,  9.0561, 11.0119, 15.9722, 25.1276],
        [11.0337, 10.9548, 11.5992, 12.4934, 15.1236, 13.5060, 10.4024],
        [15.5896, 18.8776, 23.0017, 34.6088, 49.4898, 34.3537,  0.0000],
        [17.0210, 25.9921, 24.9575, 12.8827, 12.2874, 11.7489,  6.9019],
        [50.1417, 3

batch_predictions
tensor([[15.5066, 19.7911, 23.5671, 23.9666, 23.2106, 22.0047, 20.5622],
        [15.3161, 17.7088, 17.9265, 17.3966, 16.8707, 16.4836, 16.2495],
        [21.8396, 23.2253, 22.9230, 22.3165, 21.7114, 21.2231, 20.9042],
        [ 6.0421,  3.8437,  3.7276,  3.9605,  3.7722,  3.9924,  3.8108],
        [18.0905, 17.1482, 17.0969, 17.4236, 17.8171, 18.1832, 18.5006],
        [16.8130, 15.8264, 15.8494, 16.2124, 16.5779, 16.8840, 17.1292],
        [12.6788, 11.7619, 10.9855,  9.9470,  8.5041,  7.3388,  6.8885],
        [18.7955, 18.3995, 18.4165, 18.5607, 18.7253, 18.8777, 19.0107]])
ground truth
tensor([[14.5975, 17.8005, 20.3373, 24.1638, 31.7177, 28.7982,  6.9728],
        [15.6497, 19.9500,  6.2599,  9.6791, 17.0700, 14.4398, 14.5844],
        [13.6763, 16.9643, 28.7982, 29.1383, 17.3328, 14.9660, 12.7834],
        [ 4.1950,  3.1746,  2.8770,  5.4280,  5.8957,  9.7931, 10.6859],
        [29.2741, 29.0110,  8.7585,  9.8501, 19.1741, 18.4771, 18.7927],
        [27.4376, 1

batch_predictions
tensor([[16.0084, 18.0457, 17.6634, 16.8073, 15.9754, 15.2863, 14.7870],
        [28.4429, 22.6671, 24.4436, 26.6794, 28.9236, 31.0864, 33.1979],
        [17.5479, 21.5344, 22.7997, 21.0691, 17.6437, 14.4664, 14.0159],
        [12.1927, 15.9928, 19.4828, 20.7617, 20.5499, 19.9530, 19.3605],
        [21.6233, 19.9394, 19.7445, 20.2300, 20.8253, 21.3779, 21.8598],
        [34.0971, 37.6239, 37.5095, 37.2181, 36.9696, 36.7901, 36.6757],
        [24.6409, 22.9046, 22.5703, 22.9107, 23.4545, 24.0145, 24.5341],
        [38.2631, 38.3508, 38.3520, 38.3517, 38.3512, 38.3505, 38.3495]])
ground truth
tensor([[17.6486, 25.8154, 25.4866,  8.0089, 12.8880, 13.9137,  9.9027],
        [15.0579, 13.8743, 16.3335, 15.3472, 33.8506, 45.4366, 79.2346],
        [21.0810, 23.4350, 11.9279, 15.8601,  9.1399,  8.4824, 12.7827],
        [ 6.4045, 24.9737, 27.1173, 24.8816, 11.5203, 10.8890, 11.2704],
        [33.1774, 17.2052, 14.9518, 18.3248, 19.5862, 28.6990, 22.8600],
        [29.7052, 3

batch_predictions
tensor([[19.5465, 18.2069, 18.0118, 18.3684, 18.8442, 19.3044, 19.7215],
        [23.1046, 26.6731, 26.1100, 24.7674, 22.8656, 20.4456, 18.9645],
        [21.0198, 25.2452, 25.0787, 23.5887, 21.1353, 18.3481, 17.4102],
        [16.9809, 17.4127, 17.1982, 16.8688, 16.6028, 16.4484, 16.3974],
        [38.3428, 38.3424, 38.3424, 38.3423, 38.3421, 38.3420, 38.3420],
        [18.0577, 16.8630, 16.9985, 17.4324, 17.8067, 18.0969, 18.3188],
        [13.9909, 13.7627, 13.7863, 13.8704, 13.9647, 14.0549, 14.1335],
        [15.6967, 15.6309, 15.6688, 15.7076, 15.7364, 15.7622, 15.7871]])
ground truth
tensor([[15.0973, 16.0310, 16.0573, 15.5181, 17.6881, 27.7880, 27.9853],
        [22.5624,  9.6655, 20.2381, 15.0652, 30.1871, 10.7993, 13.6763],
        [16.3993, 15.6234, 16.3993, 16.3730, 22.6854, 22.6460, 13.1904],
        [17.2672, 24.5660,  7.4171, 11.5729, 12.1383, 12.4934, 15.5313],
        [33.7454, 38.9663, 59.4029, 49.6712, 44.2662, 50.1447, 36.9937],
        [21.5986, 1

batch_predictions
tensor([[21.4486, 25.6562, 24.8072, 23.0618, 20.5222, 17.9892, 17.5962],
        [27.8829, 23.8873, 24.7186, 26.1423, 27.4683, 28.5370, 29.3137],
        [21.7785, 30.3002, 33.9273, 33.0660, 30.1179, 23.0480, 18.5818],
        [22.2305, 23.6654, 23.3666, 22.7213, 22.0457, 21.4706, 21.0677],
        [18.1289, 20.8954, 20.5211, 19.8034, 19.1428, 18.6501, 18.3567],
        [20.2085, 18.6370, 18.4874, 19.0198, 19.6523, 20.2395, 20.7583],
        [19.9406, 18.0374, 19.2990, 20.4379, 21.3774, 22.0820, 22.5486],
        [19.3187, 23.4593, 22.3056, 20.0402, 17.0565, 15.6379, 15.7812]])
ground truth
tensor([[22.6986, 28.0905, 29.7212, 12.5855, 18.8322, 15.5708, 17.2935],
        [32.8656, 20.5215, 15.8730, 11.0544, 17.0493, 24.7591, 40.5471],
        [15.9722, 36.5363, 51.1621, 36.5930, 17.1344, 17.1202, 18.2398],
        [19.4303, 31.5334,  9.8498, 14.8101, 18.1973, 12.6559, 20.7766],
        [13.9172, 12.9252, 13.6480, 22.7608, 26.3322, 22.4206, 12.6276],
        [28.3730, 2

batch_predictions
tensor([[10.2209,  8.5516,  7.1622,  6.6816,  6.5931,  6.5917,  6.6034],
        [37.5054, 38.2119, 38.1912, 38.1481, 38.1185, 38.1015, 38.0935],
        [11.4600, 13.3929, 15.8678, 18.4334, 17.5762, 16.5530, 15.7238],
        [15.7848, 16.7636, 16.6832, 16.2365, 15.7550, 15.3548, 15.0707],
        [21.9318, 22.7367, 22.7149, 22.3477, 21.8962, 21.5098, 21.2557],
        [13.8033, 17.1092, 19.2099, 17.9824, 16.5374, 15.1116, 13.8442],
        [15.0328, 20.4255, 24.8517, 26.0884, 25.6535, 24.7606, 23.7250],
        [16.5515, 18.4143, 18.3192, 17.8884, 17.5031, 17.2432, 17.1120]])
ground truth
tensor([[12.4717,  9.1553,  6.8452,  8.0782,  7.9223,  8.9002, 10.4450],
        [26.2897, 39.3707, 54.6202, 49.0079, 31.3917, 28.8549, 22.5057],
        [ 9.9027,  7.9695,  9.9421, 13.0589, 20.9890, 22.0673,  6.0757],
        [20.9890, 22.0673,  6.0757, 10.8890, 11.7833, 10.6654, 13.6902],
        [16.1706, 14.5550, 18.8350, 15.2778, 23.3702, 10.6151, 10.4592],
        [10.0605, 1

batch_predictions
tensor([[17.6979, 20.6638, 20.2276, 19.2912, 18.3311, 17.5178, 16.9520],
        [28.0552, 26.0847, 25.6568, 25.9910, 26.5924, 27.2310, 27.8187],
        [26.9925, 27.6584, 27.6452, 27.4163, 27.1392, 26.9052, 26.7549],
        [15.9682, 15.6493, 16.8762, 17.7793, 18.3434, 18.6644, 18.8501],
        [19.7097, 21.4468, 20.9943, 20.1029, 19.1467, 18.3154, 17.7737],
        [14.3314, 17.6386, 20.4552, 20.8461, 20.2622, 19.5502, 18.9699],
        [11.4280, 10.5214,  9.2231,  7.6164,  6.7054,  6.4614,  6.4221],
        [17.5586, 19.6512, 19.3231, 18.6228, 17.9597, 17.4479, 17.1241]])
ground truth
tensor([[18.0130, 23.8095, 24.6599, 11.0969, 14.9660, 11.6213,  9.5947],
        [26.5913, 22.9879, 22.9090, 22.2514, 24.1189, 23.3298, 40.5839],
        [40.5471, 30.4989, 18.8209, 18.5374, 12.8401, 17.5454, 30.0028],
        [16.0289,  0.0000,  0.0000,  8.9569, 12.4150, 14.1865, 17.1627],
        [20.9892, 32.8656, 33.2766, 15.2778,  9.0278, 13.8889, 19.2035],
        [10.2315, 1

batch_predictions
tensor([[16.3808, 19.6742, 19.6425, 18.9515, 18.2326, 17.6563, 17.2764],
        [16.0973, 14.8036, 15.6298, 16.4357, 17.0069, 17.3870, 17.6391],
        [12.8574, 12.0180, 12.2169, 12.9528, 14.0668, 15.0920, 15.5497],
        [30.3903, 25.9765, 25.9853, 27.3846, 28.9027, 30.2502, 31.3855],
        [16.5349, 15.8021, 17.0374, 18.0333, 18.7332, 19.1534, 19.3834],
        [19.8989, 21.9573, 21.5349, 20.8470, 20.2094, 19.7086, 19.3851],
        [19.5730, 20.4336, 20.2183, 19.7865, 19.3894, 19.1124, 18.9710],
        [21.8908, 29.8145, 29.4658, 27.9430, 25.2818, 21.0795, 18.6506]])
ground truth
tensor([[21.7971, 29.1525, 26.1054, 14.0731, 13.0811, 12.6276, 13.6054],
        [22.1372, 11.6780,  9.1978, 11.7772, 11.6638, 19.3027, 26.6156],
        [ 4.3226,  7.0720,  5.5414,  6.3492, 13.4070,  1.2188, 25.9495],
        [37.3158, 21.4569, 27.5368, 20.1814,  6.3776,  0.0000, 13.6196],
        [24.0400,  8.3772, 10.0342, 10.8364, 10.3761, 18.0168, 25.1578],
        [18.4377, 2

batch_predictions
tensor([[18.9282, 18.8433, 18.8550, 18.8692, 18.8767, 18.8869, 18.9039],
        [17.5142, 21.8052, 24.9293, 25.3436, 24.7816, 23.8482, 22.7321],
        [19.0446, 21.7056, 21.0653, 19.9023, 18.6029, 17.3866, 16.6574],
        [ 5.0464,  4.6430,  4.6245,  4.6471,  4.6641,  4.6750,  4.6830],
        [ 8.9317,  7.7093,  7.1054,  6.9376,  6.9143,  6.9261,  6.9384],
        [28.3679, 26.3718, 25.5626, 25.5181, 25.9865, 26.7337, 27.6499],
        [ 9.2332, 10.3814, 11.2063, 12.0288, 13.0623, 14.5244, 15.9992],
        [13.5120, 16.5165, 19.9324, 18.9430, 17.5842, 16.3668, 15.5011]])
ground truth
tensor([[27.2225,  5.4182,  0.0000, 19.2267, 17.4513, 15.7680, 18.8322],
        [37.1457, 34.5947, 14.6117, 19.7279, 14.3991, 19.0760, 22.2931],
        [14.4841, 18.8917, 21.8537, 30.2863, 30.6264, 13.9739, 13.6905],
        [ 5.1587,  4.9603,  4.6769,  5.0595, 10.6151, 11.5363,  7.5113],
        [12.1457,  8.5034,  5.9949,  7.0578,  7.3413,  8.9427,  8.8294],
        [46.4994, 4

batch_predictions
tensor([[17.4742, 16.0192, 17.2039, 18.1570, 18.8071, 19.2082, 19.4550],
        [18.1685, 16.7160, 17.8528, 18.7259, 19.3222, 19.7044, 19.9475],
        [20.2113, 18.2913, 18.6412, 19.5069, 20.2785, 20.8963, 21.3795],
        [21.3804, 22.5510, 22.2859, 21.7106, 21.1241, 20.6517, 20.3489],
        [25.2408, 28.6900, 28.2125, 27.3074, 26.3253, 25.2959, 24.2521],
        [17.3216, 16.6146, 18.4231, 20.2730, 21.8178, 22.4575, 22.5844],
        [22.3698, 20.2631, 19.8087, 20.3521, 21.1712, 22.0314, 22.8941],
        [23.1072, 22.7866, 22.7593, 22.8377, 22.9369, 23.0464, 23.1632]])
ground truth
tensor([[28.1037, 25.8503,  9.0136, 10.6576, 12.8401, 10.5159, 18.7217],
        [10.7312, 16.0442, 13.6376, 16.4519, 15.0579, 26.2362,  5.3524],
        [12.2567, 14.2031, 20.2656, 21.8701, 16.8332, 24.3030, 29.6423],
        [ 7.3129,  8.3759, 13.9031, 19.4728, 22.8316, 27.6644, 18.3673],
        [27.6959, 22.2251, 28.0510, 14.7817, 13.8874, 23.1326, 16.3204],
        [ 8.6402, 1

batch_predictions
tensor([[19.8688, 18.6597, 18.6216, 18.9931, 19.3794, 19.7044, 19.9662],
        [31.5670, 32.7695, 32.6501, 32.3194, 32.0222, 31.8231, 31.7209],
        [17.6557, 17.1784, 17.1766, 17.3379, 17.5247, 17.6986, 17.8502],
        [15.3269, 15.0478, 16.4109, 17.6470, 18.5943, 19.1094, 19.2932],
        [18.9661, 18.1927, 18.0485, 18.2051, 18.4793, 18.7923, 19.1109],
        [23.6589, 32.3884, 32.2540, 30.1883, 26.0025, 19.2446, 18.9728],
        [ 9.4833, 10.7828, 12.0833, 14.0399, 17.7847, 19.7257, 18.3275],
        [ 4.6123,  4.5933,  4.5974,  4.6026,  4.6050,  4.6057,  4.6069]])
ground truth
tensor([[12.9143, 19.5818, 14.9264, 17.1883, 21.1073, 24.3951, 26.5913],
        [27.2094, 37.2304, 39.7948,  6.1021, 18.2141, 24.2372, 23.6323],
        [23.8237, 21.4144, 12.0748, 14.1015, 11.0261, 11.9189, 21.4144],
        [15.4337, 15.3203, 11.9615, 10.0624, 14.5692, 25.7795, 25.0709],
        [17.7438, 11.7489, 16.4541,  9.2404, 11.4087, 20.9609, 24.1780],
        [15.7880, 1

batch_predictions
tensor([[12.8585, 17.2724, 22.3628, 22.9628, 22.1103, 20.7699, 19.1620],
        [18.0128, 21.1111, 20.7091, 19.8473, 18.9905, 18.2986, 17.8454],
        [16.6518, 21.0299, 21.1025, 20.1232, 18.8488, 17.5536, 16.5966],
        [21.3851, 28.3828, 27.7751, 25.0236, 19.3169, 16.8124, 17.1145],
        [25.3930, 31.5522, 31.5949, 29.9285, 26.6382, 20.7085, 19.0634],
        [16.5117, 21.1387, 25.8415, 26.8393, 26.0545, 24.1057, 20.3663],
        [20.5333, 20.8928, 20.7278, 20.4796, 20.2778, 20.1588, 20.1185],
        [17.8762, 23.1314, 25.1483, 23.8682, 21.0427, 17.0218, 16.2373]])
ground truth
tensor([[12.4540, 13.5850, 15.1499, 17.6486, 25.8548, 31.4177,  5.1815],
        [17.0918, 18.1831, 21.0743, 12.9677, 14.4416, 10.3741, 13.4921],
        [11.9937, 19.4240, 25.7496, 26.1704, 13.9663, 11.6386,  9.6791],
        [43.4524, 38.3929, 19.0051, 19.2744, 12.7126, 21.8679, 24.0788],
        [17.3611, 10.0057, 18.8776, 31.7035, 45.2523, 31.4768, 19.2744],
        [10.0624, 1

batch_predictions
tensor([[19.7865, 24.2050, 23.4268, 21.9665, 20.2396, 18.6481, 17.8374],
        [15.6878, 14.4999, 14.8749, 15.5047, 16.0286, 16.4199, 16.7037],
        [21.9432, 27.6653, 26.5592, 24.0300, 19.5928, 17.3123, 17.6473],
        [18.6535, 16.7296, 17.6230, 18.5313, 19.2655, 19.8000, 20.1667],
        [31.4069, 33.6157, 33.7602, 33.2995, 32.6625, 32.0493, 31.5593],
        [22.7847, 24.5740, 24.3000, 23.3682, 22.1100, 20.7500, 19.6480],
        [20.9186, 27.2976, 26.3309, 24.2135, 20.8188, 17.7585, 17.9334],
        [19.7142, 22.5578, 21.9057, 20.5895, 18.9634, 17.3944, 16.6564]])
ground truth
tensor([[18.4640, 31.9963, 13.4403, 14.4398, 11.9805,  3.9716, 16.8201],
        [ 6.2217, 10.9836, 12.9535, 10.0482, 17.1485, 20.3515, 25.9495],
        [11.9615, 18.2965, 23.6961, 36.9756, 39.6684,  9.2971, 21.1876],
        [34.3503,  9.8764, 14.3740, 12.2962, 11.5729, 15.7680, 18.2404],
        [32.4961, 39.3872, 32.4566, 15.3603,  3.2614, 12.0989, 18.3588],
        [16.4966, 2

batch_predictions
tensor([[ 4.0017,  3.5129,  3.5638,  3.5816,  3.5938,  3.6023,  3.6088],
        [25.0361, 29.2892, 30.4045, 29.9681, 28.9097, 27.2876, 24.7642],
        [28.3256, 29.1228, 28.9965, 28.7693, 28.5885, 28.4800, 28.4334],
        [29.7836, 30.9898, 31.1198, 30.9176, 30.6297, 30.3767, 30.2111],
        [17.2102, 20.6938, 20.2124, 19.2081, 18.1690, 17.3036, 16.7429],
        [27.0177, 34.1279, 38.1267, 38.2677, 38.2696, 38.2612, 38.2516],
        [21.9408, 31.2162, 33.7854, 32.6813, 29.7180, 22.9840, 18.4135],
        [ 7.3033,  6.1057,  5.8545,  5.8390,  5.8548,  5.8669,  5.8753]])
ground truth
tensor([[ 5.1587,  5.6973,  6.1650,  6.3492,  8.5317,  9.8073,  7.1570],
        [19.9688, 31.4201, 45.4223, 38.3645, 17.9280, 18.0414, 16.9785],
        [25.9206, 25.8285, 36.6649, 37.9932, 21.9227, 29.5634, 20.7654],
        [23.7245, 26.7715, 39.0023, 51.9841, 39.9660, 24.7449, 23.6111],
        [17.8713, 15.0227, 17.5028, 20.6774, 28.0045, 23.4269, 13.4779],
        [29.9579, 4

batch_predictions
tensor([[ 9.2155,  8.0558,  7.2612,  6.9367,  6.8526,  6.8514,  6.8672],
        [ 3.7370,  3.6530,  3.6666,  3.6784,  3.6863,  3.6912,  3.6946],
        [23.6813, 22.2755, 22.0203, 22.3037, 22.7569, 23.2263, 23.6610],
        [21.9867, 23.1890, 22.9634, 22.4472, 21.9344, 21.5375, 21.2964],
        [25.0881, 21.9409, 22.5543, 23.6665, 24.6814, 25.5092, 26.1435],
        [15.9271, 20.3379, 22.0896, 21.2196, 19.8520, 18.1489, 16.6183],
        [16.7173, 22.0137, 25.5525, 25.5370, 24.6259, 23.0449, 20.4886],
        [33.6580, 32.3936, 32.2484, 32.4555, 32.7414, 33.0237, 33.2759]])
ground truth
tensor([[14.7028, 12.8617,  4.3924,  7.4040,  8.0747,  7.3514, 11.3887],
        [ 3.1321,  4.2092,  2.9762,  2.8061,  5.4563,  7.8798,  7.6814],
        [25.8022, 18.6086, 21.1862, 14.1110, 13.3877, 16.9516, 27.6696],
        [17.2541, 22.0147, 20.1999, 25.2893, 42.6092, 21.2914, 21.0153],
        [28.1431, 11.4019, 12.9537, 12.1383, 17.3330, 21.1599, 36.8885],
        [13.8889, 1

batch_predictions
tensor([[10.9794, 11.9715, 13.3113, 15.3322, 16.3373, 15.8084, 15.0891],
        [10.4603,  8.2846,  6.3533,  5.8130,  5.7467,  5.7506,  5.7606],
        [16.2353, 19.7297, 19.9207, 19.2216, 18.4652, 17.8387, 17.4117],
        [21.4479, 20.8180, 20.7292, 20.8626, 21.0660, 21.2869, 21.5040],
        [15.3325, 20.0068, 20.8479, 19.6631, 18.0024, 16.0654, 14.7610],
        [11.4538, 12.3227, 13.4035, 14.6758, 15.2812, 15.0468, 14.5615],
        [19.9900, 20.5179, 20.3858, 20.1324, 19.9166, 19.7816, 19.7245],
        [19.1477, 19.1615, 19.1618, 19.1439, 19.1236, 19.1148, 19.1143]])
ground truth
tensor([[ 7.5618,  7.4566,  8.8112, 10.8364, 17.4908, 25.5129,  8.1799],
        [15.8732,  5.2209,  6.6544,  6.6675,  6.4440, 10.5865, 13.0326],
        [13.7559, 14.8211, 22.5802, 26.7228, 14.9132, 11.9411, 13.9532],
        [28.7612,  5.1683, 10.7049, 15.4129, 16.9253, 14.0715, 22.3304],
        [12.8748,  9.0610, 15.0316, 27.5776, 22.8564, 17.5171, 12.3882],
        [10.8627, 1

batch_predictions
tensor([[14.8153, 18.6677, 19.7437, 18.7552, 17.6353, 16.5990, 15.7679],
        [38.2595, 38.2524, 38.2545, 38.2548, 38.2545, 38.2541, 38.2537],
        [21.7986, 20.2499, 19.6693, 19.7732, 20.2859, 20.9986, 21.8591],
        [15.7946, 16.9177, 16.9242, 16.5751, 16.2229, 15.9689, 15.8220],
        [22.1640, 26.2388, 25.6615, 23.8742, 20.8567, 17.8877, 17.5439],
        [13.7580, 17.4754, 21.0963, 19.8160, 17.9387, 16.0299, 14.7163],
        [21.3141, 25.5212, 28.4128, 31.0642, 33.5238, 35.0750, 35.6173],
        [19.0069, 24.8760, 23.4229, 20.1385, 16.1907, 15.6075, 15.8897]])
ground truth
tensor([[11.0994, 14.4003, 18.8059, 15.8864, 27.0516, 31.2730, 10.7969],
        [42.5565, 49.4345, 49.9211, 39.3740, 31.8911, 32.3382, 44.3582],
        [35.6859, 35.4734, 12.5567, 11.6497,  8.8294, 12.7834, 24.8724],
        [11.2441,  9.6923, 10.5339, 13.6639, 12.3225, 18.5692, 14.7554],
        [15.1786, 15.3061, 29.9603, 41.1139, 32.4830, 23.7528, 16.6383],
        [13.3482, 1

batch_predictions
tensor([[22.9044, 21.8178, 21.5660, 21.7757, 22.2205, 22.7732, 23.3621],
        [ 9.3242, 10.0468, 10.5127, 10.8886, 11.2384, 11.6192, 12.1029],
        [26.3921, 25.6545, 25.4989, 25.6145, 25.8479, 26.1468, 26.4766],
        [19.6621, 24.1674, 23.1291, 21.2393, 18.8566, 16.9723, 16.8401],
        [16.2413, 15.5419, 16.8423, 17.8800, 18.5681, 18.9447, 19.1417],
        [17.4469, 19.3133, 18.8278, 18.0034, 17.2100, 16.5623, 16.1275],
        [16.6152, 16.8050, 16.7981, 16.6841, 16.5501, 16.4457, 16.3845],
        [10.9677, 11.8935, 13.0747, 14.7041, 15.7684, 15.4877, 14.9691]])
ground truth
tensor([[20.2806, 22.3356, 21.1026, 37.2024, 39.8243, 10.0198, 14.9235],
        [10.9416, 15.2420, 24.6844, 15.9916,  0.0000,  8.9164, 11.4808],
        [26.3605, 12.9960, 13.3787, 17.5312, 19.7988, 24.9150, 38.1661],
        [16.7543, 23.5139, 30.6286, 35.6260,  1.9858, 11.2309, 16.7017],
        [12.9252, 15.0085, 18.8776, 14.1723, 16.1990, 19.7988, 27.6927],
        [ 9.0986, 1

batch_predictions
tensor([[23.5086, 30.7088, 30.1643, 28.5012, 25.8145, 20.8562, 18.8597],
        [28.0065, 32.5709, 32.8083, 31.5806, 29.3471, 25.7716, 22.1220],
        [22.5590, 21.4222, 21.0848, 21.1865, 21.5213, 21.9763, 22.5019],
        [ 6.5737,  6.8044,  6.8293,  6.8257,  6.8227,  6.8157,  6.8128],
        [19.6353, 20.8995, 20.6517, 20.0232, 19.3609, 18.8208, 18.4769],
        [12.8795, 14.9884, 17.1463, 17.0483, 16.3477, 15.6545, 15.0818],
        [ 3.6760,  3.6530,  3.6594,  3.6651,  3.6688,  3.6709,  3.6723],
        [28.4126, 24.5563, 24.8811, 26.0774, 27.2666, 28.2662, 29.0363]])
ground truth
tensor([[16.5675, 25.8220, 38.1803, 30.5839, 19.9972, 16.8226, 14.9518],
        [41.6100, 43.6791, 38.4070, 13.6763,  9.6939, 12.6134, 15.9722],
        [13.2036, 14.8738, 11.0205, 12.7433, 23.8033, 28.5508, 30.7338],
        [ 9.5082,  8.9953,  8.4035, 12.4803, 16.8201, 10.8233,  3.4850],
        [ 9.2404, 11.4087, 20.9609, 24.1780, 26.8707, 11.0969, 14.5975],
        [12.3866, 2

batch_predictions
tensor([[19.7382, 19.9590, 19.8950, 19.7748, 19.6744, 19.6165, 19.5985],
        [20.8894, 22.4987, 22.1693, 21.4059, 20.5635, 19.8109, 19.2728],
        [ 4.8848,  4.0297,  4.0096,  4.0454,  4.0720,  4.0904,  4.1035],
        [31.9111, 25.7754, 26.8295, 28.2146, 29.3943, 30.4046, 31.2486],
        [22.4109, 19.9547, 20.3808, 21.3114, 22.1528, 22.8313, 23.3541],
        [21.3407, 25.4550, 24.5927, 22.8485, 20.3428, 17.9618, 17.5821],
        [17.8275, 23.4086, 23.0825, 21.7325, 19.6005, 17.4215, 16.7537],
        [19.9724, 26.0511, 24.9433, 21.8103, 17.0170, 16.0194, 16.2318]])
ground truth
tensor([[32.1145, 28.7273, 17.7721,  7.3129,  8.3759, 13.9031, 19.4728],
        [27.7749, 30.1289, 11.2309, 19.8185, 21.0284, 20.0947, 22.4487],
        [ 7.9365,  7.9082,  5.2721,  7.7239,  7.5113,  7.8656,  8.9569],
        [53.9058, 64.7291, 14.0978, 12.3882, 15.8732, 17.8722, 33.8769],
        [12.7170, 18.9769, 18.0694, 16.4519, 23.5534, 31.0626, 27.7749],
        [25.7496, 3

batch_predictions
tensor([[20.9940, 25.7985, 25.0072, 23.7634, 22.3585, 20.9491, 19.7175],
        [ 3.7030,  3.6031,  3.6158,  3.6284,  3.6370,  3.6427,  3.6469],
        [21.6155, 26.6507, 25.8585, 23.6470, 19.8204, 17.1252, 17.3421],
        [16.3830, 15.1553, 16.0922, 16.8490, 17.3755, 17.7169, 17.9342],
        [19.5621, 19.5036, 19.5194, 19.5302, 19.5330, 19.5415, 19.5632],
        [22.2258, 26.6232, 25.9963, 24.3534, 21.8111, 18.8569, 18.1109],
        [32.2209, 29.5065, 28.8976, 29.2988, 30.0237, 30.7894, 31.5095],
        [16.0019, 14.7575, 14.8682, 15.4486, 16.0742, 16.6640, 17.2210]])
ground truth
tensor([[23.1326, 16.3204, 17.3856, 26.2756, 26.8148, 15.2157, 19.4240],
        [ 5.2863,  4.8469,  3.8974,  4.9036,  4.2234,  6.5051,  5.1587],
        [14.3424, 13.9031, 23.1151, 34.4955, 34.6514, 18.4666, 14.7392],
        [29.8264,  9.9290, 12.0857, 12.4277, 11.0205, 19.8185, 22.8432],
        [31.0941, 20.7625,  8.2483, 12.2449,  6.0941, 12.0323, 18.5374],
        [21.7780, 3

batch_predictions
tensor([[ 4.5452,  4.5533,  4.5541,  4.5564,  4.5570,  4.5462,  4.5430],
        [26.0216, 25.9523, 25.9555, 25.9653, 25.9633, 25.9690, 25.9905],
        [ 7.0016,  6.8148,  6.7566,  6.7447,  6.7466,  6.7475,  6.7489],
        [20.3980, 18.4762, 19.5850, 20.5622, 21.3039, 21.8249, 22.1722],
        [37.8955, 35.3487, 35.4054, 36.0245, 36.5457, 36.8818, 37.0885],
        [ 8.2018,  8.9124,  9.2441,  9.4215,  9.5171,  9.5721,  9.6096],
        [19.7482, 19.5630, 19.5796, 19.6246, 19.6643, 19.7011, 19.7393],
        [10.0941, 10.6535, 11.1567, 11.6856, 12.3292, 13.2399, 14.5472]])
ground truth
tensor([[ 4.2942,  4.3226,  4.7477,  8.4892, 10.0057,  6.2217,  3.8124],
        [41.7806, 33.0352, 18.5823, 17.5039, 18.7533, 22.0410, 29.4845],
        [ 6.5492,  7.2593, 11.5466, 17.0174,  4.7475,  6.2336,  6.7859],
        [28.7273, 17.7721,  7.3129,  8.3759, 13.9031, 19.4728, 22.8316],
        [42.7296, 31.1366, 25.8645, 21.9246, 32.4830, 43.6366, 60.1757],
        [ 2.6302, 1

batch_predictions
tensor([[19.3809, 16.7725, 17.7680, 19.1088, 20.3426, 21.3171, 22.1365],
        [17.9585, 23.1836, 27.1660, 28.0103, 27.5411, 26.5429, 25.1269],
        [22.0349, 30.1141, 29.9629, 28.1724, 24.4072, 18.5705, 18.0700],
        [34.2670, 31.4196, 30.4969, 30.6197, 31.2244, 32.0171, 32.8759],
        [37.7532, 37.5339, 37.5594, 37.5832, 37.5985, 37.6135, 37.6306],
        [20.0773, 20.8880, 20.5368, 19.9846, 19.4977, 19.1689, 19.0166],
        [ 9.5445, 11.0104, 12.1773, 13.5146, 14.9592, 15.8701, 16.0910],
        [15.1000, 17.0775, 17.2242, 16.6492, 16.0338, 15.5338, 15.1774]])
ground truth
tensor([[ 7.9507, 10.6151, 10.7710, 11.9898, 19.0051, 38.7330, 51.1905],
        [18.9909, 22.5907, 34.0420,  9.3821, 11.8197, 26.7857, 20.2806],
        [15.5187, 35.8985, 38.9031, 18.9626, 10.5584, 16.7517, 15.4478],
        [24.3425, 20.2130, 19.9369, 18.8848, 33.9427, 55.6418, 14.1899],
        [57.7523, 42.2194, 30.9382, 25.8645, 21.2302, 29.7336, 42.8713],
        [22.4619,  

batch_predictions
tensor([[25.4557, 27.8394, 27.7743, 27.2349, 26.5131, 25.6921, 24.7771],
        [38.2368, 38.2356, 38.2357, 38.2356, 38.2354, 38.2352, 38.2351],
        [16.7962, 21.7628, 25.7079, 27.3998, 27.0630, 25.9031, 24.2533],
        [17.6361, 21.7645, 24.8954, 27.4442, 28.7650, 28.8219, 28.1498],
        [17.9497, 17.5631, 17.5416, 17.6076, 17.6684, 17.7217, 17.7710],
        [20.0802, 18.3877, 20.0744, 21.7210, 23.1808, 24.2382, 24.7795],
        [22.3671, 21.4655, 21.2325, 21.2539, 21.3515, 21.4878, 21.6448],
        [19.8806, 20.7551, 20.4554, 19.8998, 19.3752, 18.9967, 18.7983]])
ground truth
tensor([[22.0410, 29.4845, 44.7002, 34.2451, 21.2388, 20.1078, 22.8958],
        [41.8201, 59.8501, 54.9448, 42.2935, 41.8201, 37.6381, 38.2167],
        [28.8124, 15.3486, 12.7268, 33.6735, 54.1383, 29.3793, 15.5896],
        [ 4.9745, 35.4450, 45.1956, 33.7585, 18.9909, 16.8651, 16.0006],
        [19.5011, 11.3237, 12.9252, 12.9110, 11.3095, 12.6417, 25.2268],
        [ 3.7415, 1

batch_predictions
tensor([[12.3453, 15.2282, 19.6429, 18.9979, 17.7511, 16.4257, 15.3171],
        [15.6869, 15.9488, 15.8895, 15.7224, 15.5642, 15.4614, 15.4172],
        [31.9101, 25.2994, 26.5109, 28.1617, 29.6757, 31.0907, 32.4263],
        [15.4138, 20.2819, 24.7556, 26.4227, 26.0413, 24.9236, 23.3316],
        [19.9313, 24.0544, 23.9349, 22.9524, 21.6874, 20.3926, 19.4262],
        [21.6912, 19.4770, 20.2475, 21.2331, 22.0406, 22.6395, 23.0631],
        [18.8354, 18.0225, 17.9470, 18.1823, 18.4935, 18.7956, 19.0637],
        [14.8792, 18.0200, 19.1372, 18.5299, 17.8872, 17.3987, 17.0978]])
ground truth
tensor([[11.1783,  7.3251, 13.9795, 21.3046, 25.7891, 10.8627, 10.5208],
        [10.9127,  9.3679,  8.7727, 15.1219, 17.6020, 30.8957, 13.6763],
        [14.0978, 12.3882, 15.8732, 17.8722, 33.8769, 48.5402, 41.5965],
        [ 9.9915, 29.9178, 25.4535, 11.0544,  9.0703,  9.3963, 12.9252],
        [13.3877, 16.9516, 27.6696, 26.0521, 14.6502, 14.6896, 13.3614],
        [ 0.0000,  

batch_predictions
tensor([[18.2518, 17.6343, 17.5629, 17.7300, 17.9750, 18.2371, 18.4909],
        [32.0196, 26.4533, 27.3999, 28.9274, 30.3508, 31.6087, 32.7033],
        [19.8781, 24.6393, 23.7372, 21.7173, 18.9647, 17.0105, 16.9908],
        [13.6804, 17.5302, 21.1714, 19.9705, 18.2684, 16.6139, 15.3330],
        [23.3885, 28.1673, 27.9708, 26.3705, 23.2513, 19.1496, 18.2158],
        [21.4955, 27.0998, 26.1639, 24.0696, 20.7627, 17.8099, 17.8655],
        [15.0362, 14.0329, 14.4624, 15.0952, 15.6033, 15.9660, 16.2142],
        [29.0210, 31.1267, 31.0829, 30.4766, 29.7053, 28.9210, 28.1915]])
ground truth
tensor([[26.9726,  9.7449, 11.4150, 12.0594, 14.0847, 20.7128, 23.3956],
        [45.6207, 29.8527, 28.1694, 20.3314, 23.0800, 29.5371, 49.6449],
        [17.5170, 16.3832, 20.7341, 26.0204, 35.9552, 15.3061, 16.4966],
        [ 1.4172, 19.8696, 20.0113, 23.3985, 33.6026, 34.4388,  8.4467],
        [14.3707, 17.5454, 32.4405, 46.6270, 38.7472,  4.2942,  0.0000],
        [16.0705, 1

batch_predictions
tensor([[24.2970, 25.8030, 25.6920, 25.0472, 24.1741, 23.2467, 22.4123],
        [21.5736, 19.4414, 20.0809, 21.2059, 22.1987, 22.9903, 23.6088],
        [22.8099, 22.9725, 22.9450, 22.8685, 22.7995, 22.7601, 22.7513],
        [14.5182, 18.7388, 23.5101, 23.9096, 22.8216, 20.7741, 17.9259],
        [15.9927, 18.7603, 18.8928, 18.1494, 17.3445, 16.6644, 16.1671],
        [20.7379, 29.1559, 28.9441, 26.9954, 22.6392, 17.1335, 17.1859],
        [19.4126, 21.5512, 21.1086, 20.3795, 19.6961, 19.1524, 18.7956],
        [12.5451, 15.2811, 17.6542, 18.5043, 18.3926, 18.1305, 17.9063]])
ground truth
tensor([[ 8.8010, 16.6383, 20.1389, 29.5493, 43.7783, 26.2188,  0.0000],
        [31.5476, 11.6922, 17.5170, 15.1927, 16.6950, 19.5011, 31.4342],
        [20.1999, 25.2893, 42.6092, 21.2914, 21.0153, 25.0789, 20.9100],
        [11.6213,  8.2058, 17.2336, 20.7625, 30.7965, 33.7585,  7.8940],
        [19.1215, 21.4755, 21.8832, 14.6239, 15.6628,  9.9947, 12.4408],
        [30.2438, 5

batch_predictions
tensor([[21.8529, 19.4544, 19.8980, 20.8380, 21.6927, 22.3888, 22.9326],
        [26.9570, 23.0662, 23.6872, 25.0608, 26.3952, 27.5062, 28.3354],
        [19.1057, 25.5069, 25.5157, 23.7116, 19.8423, 16.3118, 16.1463],
        [ 6.7363,  6.7543,  6.7584,  6.7606,  6.7586,  6.7415,  6.7344],
        [17.1490, 21.3404, 20.7216, 19.2326, 17.3848, 15.8037, 15.1687],
        [13.8652, 17.1602, 19.9613, 19.0055, 17.9035, 17.0087, 16.4189],
        [16.3708, 19.0466, 19.3172, 18.8609, 18.4172, 18.1057, 17.9353],
        [29.2712, 28.5888, 28.5133, 28.6361, 28.8075, 28.9877, 29.1620]])
ground truth
tensor([[30.9382, 10.7710, 16.9926, 10.9694, 15.2778, 22.8883, 32.8940],
        [30.4989, 18.8209, 18.5374, 12.8401, 17.5454, 30.0028, 50.0850],
        [19.9500, 30.7601, 33.6533, 12.3225, 15.1236, 10.1394, 12.2304],
        [ 6.6675,  6.4440, 10.5865, 13.0326, 14.3083,  7.6144,  6.6544],
        [15.4053, 13.7755, 15.8872, 34.3821, 33.6593, 13.2795, 16.3407],
        [14.3083, 1

batch_predictions
tensor([[19.1858, 22.7309, 21.7541, 20.0683, 17.9893, 16.2579, 15.9686],
        [22.1431, 23.1230, 23.0290, 22.6246, 22.1867, 21.8386, 21.6368],
        [14.7730, 18.3105, 19.2122, 18.0217, 16.5476, 15.0179, 13.8250],
        [21.4461, 19.1313, 19.5537, 20.4266, 21.2094, 21.8420, 22.3336],
        [19.4305, 17.9354, 19.1511, 20.1865, 20.9875, 21.5599, 21.9359],
        [21.2866, 18.7693, 18.6864, 19.5347, 20.4494, 21.3120, 22.1217],
        [19.9752, 17.9800, 18.9829, 19.9399, 20.6931, 21.2340, 21.5998],
        [19.1664, 24.2072, 28.2112, 29.1276, 28.5780, 27.4590, 25.7077]])
ground truth
tensor([[18.2141, 22.2646, 29.2215, 10.6391, 13.8348, 12.6644, 14.1504],
        [ 2.7778, 13.7613, 13.9314, 19.1327, 33.1207, 25.1417, 16.0856],
        [24.2109, 29.6028,  8.2720, 10.3235, 10.0605, 12.7170, 15.8864],
        [18.9342,  2.7778, 13.7613, 13.9314, 19.1327, 33.1207, 25.1417],
        [30.1420, 23.2246, 20.0421, 12.5460, 17.2672, 17.3724, 17.2541],
        [11.8764, 1

batch_predictions
tensor([[14.0754, 16.6928, 18.5512, 19.0614, 18.9009, 18.6230, 18.3916],
        [20.9111, 20.4720, 20.4329, 20.5419, 20.6859, 20.8361, 20.9826],
        [26.0381, 34.1860, 35.6882, 34.0283, 29.1283, 20.2654, 19.5098],
        [ 5.4932,  5.8330,  5.8773,  5.8699,  5.8640,  5.8537,  5.8456],
        [16.9971, 22.1489, 25.9907, 26.4692, 25.8520, 24.7788, 23.3618],
        [19.9176, 18.2016, 19.3739, 20.3728, 21.1359, 21.6760, 22.0335],
        [14.1187, 13.5346, 13.5784, 13.9069, 14.3442, 14.7917, 15.1826],
        [21.4941, 18.4840, 19.9546, 21.5464, 23.0321, 24.1708, 24.8503]])
ground truth
tensor([[10.1920, 10.7575,  8.8901, 18.7796, 23.3824,  2.5381,  4.8264],
        [28.1321, 11.6071, 13.2653, 13.8322, 13.0527, 15.6463, 30.0028],
        [20.3051, 19.3714, 85.1920, 53.9058, 64.7291, 14.0978, 12.3882],
        [ 9.0084, 11.5071, 17.7275, 18.9506, 27.6433,  6.3388, 13.9926],
        [12.9252, 15.5612, 13.6621, 20.1814, 33.9711, 33.8719,  6.3067],
        [30.3571,  

batch_predictions
tensor([[24.0542, 21.1287, 21.8097, 22.9346, 23.9495, 24.7611, 25.3640],
        [12.4176, 14.8773, 17.4801, 17.1947, 16.4462, 15.7586, 15.2228],
        [24.2447, 21.1191, 22.5765, 24.3042, 25.8126, 26.9099, 27.5884],
        [13.6136, 15.8645, 16.7933, 16.3067, 15.6310, 15.0281, 14.5351],
        [24.8436, 33.9858, 35.4315, 33.5789, 28.6290, 20.0369, 19.5114],
        [23.8423, 21.6781, 21.7080, 22.4715, 23.2624, 23.9426, 24.4943],
        [19.1273, 22.4892, 21.5618, 19.7591, 17.4204, 15.7376, 15.6139],
        [20.1398, 18.5343, 18.4356, 19.0097, 19.6496, 20.2256, 20.7221]])
ground truth
tensor([[29.2800, 15.3203, 17.2336, 22.5624,  9.6655, 20.2381, 15.0652],
        [11.5729, 15.9258, 23.1326, 23.6981, 14.7291, 14.2688, 13.0852],
        [ 5.3430, 12.9252, 15.5612, 13.6621, 20.1814, 33.9711, 33.8719],
        [22.7891, 25.4819, 27.5794, 15.4337, 15.3203, 11.9615, 10.0624],
        [34.0873, 63.3877, 44.1610, 13.9137, 12.0594, 15.0316, 14.1241],
        [29.2092, 2

batch_predictions
tensor([[17.8973, 18.3705, 18.2582, 18.0384, 17.8576, 17.7500, 17.7070],
        [ 8.2612,  4.7894,  4.1097,  4.0827,  4.0903,  4.1007,  4.1112],
        [ 7.7181,  7.9785,  8.0481,  8.0582,  8.0526,  8.0407,  8.0360],
        [18.6411, 24.2352, 23.2220, 21.0458, 18.1196, 16.4238, 16.5265],
        [15.3079, 17.7360, 17.6884, 16.9026, 16.0775, 15.3657, 14.8019],
        [21.7137, 30.6095, 31.3264, 28.8602, 22.2881, 16.5321, 17.1057],
        [18.9773, 21.9048, 21.0829, 19.8777, 18.6234, 17.4582, 16.7335],
        [22.8367, 22.1332, 21.9368, 22.0195, 22.2544, 22.6011, 23.0341]])
ground truth
tensor([[11.8339, 12.8118, 18.2256, 23.6536, 21.5986, 14.3282, 15.3628],
        [ 7.3777,  8.7980, 13.2430, 14.8606, 10.3235,  7.1278,  5.4182],
        [ 8.3509, 11.9411, 12.8485, 18.6875, 25.3025,  0.3945,  2.7486],
        [10.5997,  8.6007, 13.3614, 20.9890, 20.2130, 34.7712, 13.7559],
        [28.9716, 21.1205, 10.5602, 14.1373, 10.6260, 11.3493, 13.4929],
        [14.5844, 3

batch_predictions
tensor([[26.6043, 26.8524, 26.8378, 26.7709, 26.7120, 26.6812, 26.6779],
        [14.7442, 19.5843, 24.0064, 25.2008, 24.5966, 23.3535, 21.7425],
        [19.7748, 24.1122, 25.5139, 24.8106, 23.2612, 20.9461, 18.6676],
        [10.9709, 12.6621, 15.3725, 18.4359, 17.2926, 15.8595, 14.4979],
        [ 5.1109,  4.9778,  4.9665,  4.9739,  4.9795,  4.9823,  4.9848],
        [18.4019, 24.0471, 28.2371, 28.0593, 26.4021, 22.4066, 16.7706],
        [ 3.9879,  4.0501,  4.0480,  4.0475,  4.0472,  4.0401,  4.0358],
        [17.1652, 17.3559, 17.2220, 17.0382, 16.8999, 16.8294, 16.8160]])
ground truth
tensor([[26.7290, 22.3073, 24.5323, 35.2749, 19.3311, 19.4586, 19.9546],
        [13.5913, 13.1236, 15.6604, 26.1054, 35.6859, 35.4734, 12.5567],
        [12.0890, 24.2347, 28.6848, 41.6525, 22.2222, 10.4875, 16.7375],
        [18.4771, 20.4235,  5.1420,  9.3766,  6.7728,  7.4566,  9.9027],
        [ 6.7035, 12.1032, 11.2670,  5.7540,  4.8753,  5.2154,  6.3776],
        [13.6196, 1

batch_predictions
tensor([[28.6567, 24.4684, 25.5368, 27.0341, 28.3722, 29.4333, 30.2128],
        [20.4142, 27.4336, 28.7962, 27.7751, 25.4920, 20.8067, 18.2034],
        [31.5873, 36.4267, 36.4067, 35.8275, 35.1598, 34.5368, 34.0137],
        [ 6.5374,  6.3563,  6.3279,  6.3316,  6.3366,  6.3384,  6.3401],
        [16.0105, 19.7678, 20.2742, 19.3803, 18.3191, 17.3231, 16.5507],
        [11.4873, 12.5460, 13.5013, 14.5221, 15.3218, 15.3657, 15.0401],
        [10.7740, 12.1469, 13.8178, 16.1857, 16.8823, 16.2133, 15.5817],
        [12.9082, 11.9767, 11.6449, 11.5556, 11.6428, 11.8662, 12.2351]])
ground truth
tensor([[14.7534,  0.0000,  5.5272, 25.3685, 22.8175, 31.7177, 48.8237],
        [21.6837, 47.6190, 44.4870, 23.9229,  9.2120, 19.2177, 14.5833],
        [24.9150, 20.4082, 24.6457, 30.6548, 41.8651, 41.1706, 19.3311],
        [ 9.3254,  8.3900, 12.6417,  9.8923,  6.0091,  9.4813,  9.5663],
        [13.4637,  9.4813, 14.1723, 16.2415, 22.9592, 19.5437, 11.1961],
        [ 9.7712, 1

batch_predictions
tensor([[20.6908, 20.6119, 20.6253, 20.6396, 20.6482, 20.6627, 20.6888],
        [20.1136, 21.9081, 21.5004, 20.6908, 19.8352, 19.0856, 18.5636],
        [16.0272, 15.5481, 16.8627, 17.9487, 18.6921, 19.0954, 19.2922],
        [18.6892, 19.7645, 19.4856, 18.9669, 18.5028, 18.1833, 18.0219],
        [24.5155, 21.4864, 21.4245, 22.4470, 23.5421, 24.5284, 25.3699],
        [ 7.8695,  8.1284,  8.1933,  8.2007,  8.1946,  8.1808,  8.1744],
        [22.6242, 20.1332, 21.5129, 22.8440, 23.9784, 24.8873, 25.5562],
        [17.5070, 17.2889, 17.3312, 17.4218, 17.5065, 17.5792, 17.6428]])
ground truth
tensor([[26.0346, 40.4337, 38.0811, 15.7313, 17.2902, 12.9677, 12.0890],
        [15.5181, 17.6881, 27.7880, 27.9853, 13.8217, 12.8880, 13.8480],
        [31.4768,  9.5380, 10.9836, 18.2256, 11.9615, 16.8226, 25.1276],
        [11.1820, 12.7834, 17.6446, 21.4427, 28.7557, 16.5391, 12.6701],
        [41.1423,  7.4688, 18.0556, 13.3503, 21.2302, 18.2965, 32.6247],
        [ 8.6402,  

batch_predictions
tensor([[17.5783, 23.3811, 23.5766, 22.7145, 21.3348, 19.3988, 17.6935],
        [17.2686, 23.7889, 25.4946, 24.2027, 21.1823, 16.1836, 15.7143],
        [19.7934, 18.8791, 18.7274, 19.0310, 19.5673, 20.2144, 20.8994],
        [27.0867, 24.6042, 24.3140, 24.9452, 25.7358, 26.4674, 27.0869],
        [17.4081, 21.0140, 20.3511, 19.1292, 17.8055, 16.6606, 15.9538],
        [20.5082, 27.1540, 27.0170, 25.3386, 21.9378, 18.1794, 17.6142],
        [21.6711, 24.3585, 23.9205, 22.7819, 21.3050, 19.7371, 18.6795],
        [ 9.8058,  9.9289, 10.0375, 10.1130, 10.1498, 10.1595, 10.1578]])
ground truth
tensor([[18.8776, 36.6071, 12.6276, 23.2710,  5.2721,  8.8152, 13.6338],
        [11.6497, 11.0544, 15.4337, 23.5969, 37.2591, 28.4297, 11.9898],
        [34.1005, 44.3977, 12.7696, 15.4655, 10.8890, 11.4413, 19.9106],
        [26.6570,  0.0000, 19.7659, 18.4377, 18.7796, 11.4545, 38.4271],
        [18.9342, 28.8265,  8.0074,  7.3980, 14.2149, 12.8260, 13.8889],
        [ 0.0000, 1

batch_predictions
tensor([[37.9151, 34.9278, 34.9307, 35.6853, 36.3520, 36.7861, 37.0472],
        [31.6329, 30.4850, 30.1875, 30.2739, 30.5281, 30.8798, 31.2907],
        [20.2446, 19.9488, 19.9671, 20.0592, 20.1563, 20.2450, 20.3254],
        [20.9483, 20.2989, 20.2386, 20.4147, 20.6569, 20.9015, 21.1251],
        [22.9056, 29.2635, 28.5031, 26.6874, 23.7833, 19.9504, 19.2252],
        [19.9212, 26.0551, 25.0001, 22.0507, 17.4090, 16.0887, 16.3134],
        [24.7912, 28.7969, 28.4571, 27.1191, 24.9178, 21.6763, 19.5508],
        [19.6040, 25.4072, 27.9121, 27.2077, 25.3925, 21.6388, 18.1184]])
ground truth
tensor([[49.4048, 30.6548, 27.6927, 27.5935, 34.5805, 45.1956, 55.9382],
        [55.0595, 37.2166, 10.4025, 12.1173, 19.8696, 18.7075, 31.6752],
        [31.9728, 26.0062, 17.4603, 12.3724, 17.7012, 13.6763, 16.9643],
        [30.7601, 28.3535, 18.5560, 15.0316, 13.0195, 16.5308, 20.0552],
        [18.6744, 19.7922, 26.0521, 36.4150, 26.6570,  0.0000, 19.7659],
        [19.9688, 3

batch_predictions
tensor([[14.3901, 15.9279, 16.1136, 15.7195, 15.2758, 14.9240, 14.6866],
        [14.9194, 14.8305, 15.1144, 15.3986, 15.6039, 15.7423, 15.8362],
        [24.0951, 29.2491, 28.5517, 27.0241, 24.7931, 21.6123, 19.8828],
        [29.9538, 29.1231, 28.9738, 29.0998, 29.3157, 29.5606, 29.8073],
        [22.1493, 21.8468, 21.8478, 21.9333, 22.0249, 22.1129, 22.1977],
        [13.1925, 15.9760, 17.5821, 18.1396, 18.1207, 17.9204, 17.7071],
        [19.0295, 20.3131, 20.0216, 19.3404, 18.6301, 18.0546, 17.6949],
        [10.1058, 11.4428, 13.0511, 15.3565, 16.9553, 16.2803, 15.5819]])
ground truth
tensor([[15.6234, 20.8969, 16.2809,  7.0095,  9.9947,  9.3635, 10.2709],
        [17.5171, 25.8285, 16.5965, 10.4287,  9.4555, 23.4219, 18.0431],
        [23.6718, 35.0342, 15.3077, 49.8948, 23.3167, 28.6691, 25.3419],
        [49.4898, 34.3537,  0.0000, 17.3328, 20.5924, 25.5527, 33.3900],
        [30.6286, 35.6260,  1.9858, 11.2309, 16.7017, 14.1241, 19.8448],
        [ 0.0000,  

batch_predictions
tensor([[28.8964, 23.0108, 24.0570, 26.0144, 28.1261, 30.1514, 32.0025],
        [19.1699, 17.3356, 17.5166, 18.2881, 19.0176, 19.6350, 20.1518],
        [20.0278, 20.6968, 20.4638, 20.0799, 19.7476, 19.5296, 19.4316],
        [18.7577, 17.3902, 18.5848, 19.5540, 20.2471, 20.7026, 20.9914],
        [18.3969, 17.7498, 17.7210, 17.9204, 18.1602, 18.3819, 18.5722],
        [12.6318, 15.1276, 17.3050, 16.9399, 16.3205, 15.8145, 15.4729],
        [16.8146, 20.1505, 19.6227, 18.6069, 17.5846, 16.7429, 16.1926],
        [14.3040, 15.0196, 17.1140, 19.2349, 20.9720, 21.1719, 20.9561]])
ground truth
tensor([[79.2346, 34.7054,  8.9164, 16.4256, 14.6370, 31.2599, 58.8901],
        [13.0385, 14.8810, 12.0748, 15.7596, 22.6332, 32.3554, 21.4994],
        [15.8305, 14.6117, 35.0198, 13.7613, 28.1037, 15.6463, 11.8339],
        [33.6451, 14.3566, 11.7772, 12.5850, 16.2982, 20.7908, 28.6423],
        [22.0947, 19.0193,  8.9144, 10.2466, 11.2387,  9.7647,  9.6655],
        [ 8.6026, 1

tensor([[ 8.3377, 13.3482, 11.1915, 12.3948, 12.3027, 24.9737,  7.7196],
        [ 7.1003, 12.6276, 18.7358, 16.4966, 27.4235, 30.6831,  0.1559],
        [12.0068,  3.6428,  9.0084,  7.3514,  7.4171, 11.2046, 14.0978],
        [ 7.4972,  8.6168,  7.4405, 12.2874,  6.3209,  4.7194,  5.8248],
        [ 8.1273,  9.4161, 14.6502, 27.1305, 18.0957,  0.0000,  9.3503],
        [10.9410, 11.7347, 23.1718, 16.5675, 22.6190, 12.2307, 15.0510],
        [ 4.8895,  6.7460, 11.3379,  7.3838,  6.4626,  8.5034,  6.1933],
        [18.5941, 11.1820, 13.4637,  9.4813, 14.1723, 16.2415, 22.9592]])
batch_predictions
tensor([[24.9929, 33.7334, 36.5796, 35.6083, 32.2872, 24.1352, 19.5285],
        [24.6778, 26.9930, 27.0466, 26.2683, 24.9697, 23.3128, 21.6626],
        [16.6957, 15.5003, 16.3945, 17.0878, 17.5713, 17.8947, 18.1081],
        [16.8387, 21.4546, 21.2136, 20.2333, 19.0453, 17.9164, 17.1503],
        [17.2036, 21.7722, 20.5988, 18.4847, 16.1421, 14.8219, 14.7620],
        [19.8656, 21.2465, 20.79

batch_predictions
tensor([[18.8524, 17.3211, 17.7718, 18.3755, 18.8520, 19.2073, 19.4693],
        [11.8623, 10.7565,  9.4294,  7.9810,  7.2632,  7.0844,  7.0538],
        [16.4903, 19.7125, 19.8567, 19.2171, 18.5500, 18.0239, 17.6876],
        [10.0259, 11.0449, 12.0531, 13.4753, 15.9372, 17.5158, 16.6020],
        [15.4248, 18.2747, 18.4531, 17.7962, 17.1327, 16.6166, 16.2784],
        [11.4276, 12.4684, 13.6782, 14.9741, 15.3527, 15.0361, 14.5953],
        [16.9189, 19.1220, 18.4920, 17.4780, 16.5251, 15.7612, 15.2786],
        [20.5680, 23.8333, 23.1193, 21.4155, 18.9170, 16.7916, 16.5073]])
ground truth
tensor([[22.8316, 13.5629, 13.8322, 12.4008, 16.2982, 17.0210, 25.9921],
        [ 6.4484,  4.2800,  7.4972,  8.6168,  7.4405, 12.2874,  6.3209],
        [14.5408, 14.5550, 14.5125, 16.5108, 22.5907, 22.0380, 13.8464],
        [14.0715, 12.9274,  9.5608, 12.7564, 19.2530, 16.7675,  0.0000],
        [21.8396, 31.0941, 27.3384, 10.6718, 13.1378, 11.7772, 14.3141],
        [13.1247, 1

batch_predictions
tensor([[14.0748, 16.4938, 17.3328, 16.7721, 16.0892, 15.5101, 15.0689],
        [15.7796, 18.3134, 18.3974, 17.7759, 17.1524, 16.6729, 16.3634],
        [ 9.5834, 10.0176, 10.3906, 10.7536, 11.1570, 11.6952, 12.5679],
        [11.5215, 13.1610, 15.2000, 16.8699, 16.4095, 15.8384, 15.4028],
        [20.0002, 20.5102, 20.2152, 19.7384, 19.2983, 18.9894, 18.8413],
        [17.5368, 19.7948, 19.3254, 18.5544, 17.8528, 17.3193, 16.9884],
        [19.6665, 24.9070, 23.6797, 21.2602, 18.0310, 16.3840, 16.5997],
        [27.8241, 34.8958, 35.0245, 33.7253, 31.5199, 28.0216, 23.4091]])
ground truth
tensor([[12.7038, 12.9669, 17.1357, 20.1604, 25.4603, 10.7575, 11.5992],
        [13.3219, 18.3325, 22.9748, 12.5197, 11.3887,  8.4955, 11.3098],
        [ 8.5350, 10.1789, 34.1136, 45.1999,  0.0000, 11.0468,  9.4950],
        [12.3356, 23.7507,  3.2351,  5.6549, 11.5860,  9.6002,  9.0347],
        [25.8548, 31.4177,  5.1815, 16.7938,  5.7075, 15.7812, 16.2415],
        [15.0794, 2

batch_predictions
tensor([[20.6760, 24.8700, 23.9970, 22.2706, 19.8993, 17.6549, 17.2853],
        [ 7.9131,  7.7080,  7.5902,  7.5307,  7.5065,  7.5023,  7.5010],
        [25.4593, 26.8368, 26.6967, 26.1970, 25.6331, 25.1373, 24.7739],
        [20.0319, 17.8696, 19.3773, 20.8040, 21.9551, 22.7191, 23.1617],
        [15.7271, 17.7457, 17.7101, 17.1396, 16.5871, 16.1756, 15.9193],
        [20.8623, 20.7451, 20.7683, 20.8022, 20.8291, 20.8565, 20.8890],
        [17.0511, 22.0709, 21.6073, 20.0210, 17.7351, 15.8205, 15.4153],
        [37.4593, 34.3173, 33.5850, 34.1236, 35.0925, 36.0087, 36.6641]])
ground truth
tensor([[16.3549, 15.3345, 22.6899, 38.6196,  8.4184, 12.9819, 14.3849],
        [21.4361, 24.6055, 14.4924, 16.4519, 13.3219, 14.5187, 18.2667],
        [20.4507, 23.7670, 21.2018, 24.2772, 33.4467, 26.0771, 21.3294],
        [23.2426,  5.3288, 16.4257, 11.0686, 14.4274, 21.9246, 35.7001],
        [15.5313, 15.0973,  9.3898,  0.0000,  0.4471,  7.2330,  0.0000],
        [11.4654, 1

batch_predictions
tensor([[10.9599, 13.1051, 15.6644, 18.7694, 18.5110, 17.6093, 16.7087],
        [11.0359, 12.8435, 14.6408, 16.5793, 17.3967, 17.1779, 16.7665],
        [15.9910, 18.2343, 18.2505, 17.6777, 17.1156, 16.6956, 16.4351],
        [21.3703, 21.0687, 21.0692, 21.1663, 21.2767, 21.3857, 21.4895],
        [17.8662, 21.1121, 20.6770, 19.7819, 18.8837, 18.1542, 17.6815],
        [18.0306, 20.2576, 19.6784, 18.6459, 17.5574, 16.5874, 15.9491],
        [22.0175, 26.3760, 25.5447, 23.6120, 20.5131, 17.7115, 17.6144],
        [19.7790, 19.7137, 19.7245, 19.7338, 19.7365, 19.7436, 19.7589]])
ground truth
tensor([[ 3.7612,  3.3798,  8.4429, 11.6386, 15.0053, 17.4908,  6.3388],
        [ 7.1541,  3.1168,  9.9553, 25.1052, 29.4450, 27.8406,  9.5871],
        [11.3887,  8.4955, 11.3098, 18.7533, 26.3545, 28.5245, 15.0053],
        [24.0400, 32.6407,  6.4703, 15.3077, 15.5576, 17.5302, 25.6970],
        [17.3330, 22.2120, 21.3309, 26.4335, 13.1510,  3.7612, 10.0605],
        [ 8.5744, 1

tensor([[ 9.7506,  4.5918,  3.9966,  4.9461,  5.6406,  5.9807,  8.9144],
        [ 2.1831, 10.7312, 11.4676, 13.7691, 20.3314, 17.7012,  6.6807],
        [21.8963, 24.9150, 20.4082, 24.6457, 30.6548, 41.8651, 41.1706],
        [18.1220, 28.2746, 24.9079, 29.5634, 34.5739, 45.1604, 18.3456],
        [42.8713, 35.3600, 16.1706, 11.4796, 15.0368, 24.1780, 33.7160],
        [ 9.1268,  6.7859, 12.7959, 12.5592, 17.4250, 21.4361,  8.6665],
        [16.2020,  7.1147, 12.0068,  4.9448,  9.5739, 12.4803, 18.8585],
        [19.0295, 27.5118, 32.4829, 11.0074, 13.0589, 16.5308, 11.3756]])
batch_predictions
tensor([[17.9269, 17.8578, 17.8838, 17.9105, 17.9287, 17.9467, 17.9681],
        [21.4557, 22.8661, 22.5223, 21.6690, 20.6432, 19.6598, 18.9364],
        [18.9412, 18.8283, 18.8413, 18.8628, 18.8777, 18.8937, 18.9149],
        [12.9921, 12.5195, 12.2915, 12.1850, 12.1649, 12.2201, 12.3317],
        [17.8444, 18.7983, 18.5068, 17.9514, 17.4309, 17.0549, 16.8531],
        [19.3142, 21.3508, 20.86

batch_predictions
tensor([[38.3163, 38.3162, 38.3162, 38.3161, 38.3158, 38.3157, 38.3157],
        [28.0717, 32.5993, 32.7363, 31.5815, 29.6247, 26.5441, 22.7596],
        [15.7565, 17.4529, 17.4090, 16.9357, 16.4891, 16.1712, 15.9880],
        [15.7323, 19.8631, 22.7432, 22.8577, 22.2733, 21.5737, 20.9679],
        [21.2532, 28.7907, 28.1979, 26.1098, 22.0052, 17.5745, 17.8091],
        [17.3846, 18.0373, 17.8776, 17.5445, 17.2537, 17.0678, 16.9856],
        [20.9187, 27.4265, 32.4229, 33.2183, 32.2636, 30.1373, 25.8488],
        [20.1533, 20.6952, 20.4712, 20.1558, 19.9097, 19.7667, 19.7163]])
ground truth
tensor([[45.5155, 53.4850, 55.9179, 38.1378, 40.7286, 22.9090, 46.3177],
        [43.3532, 41.5533, 32.8656, 15.5329, 17.9847, 18.8209, 19.7988],
        [16.0047, 22.6854, 23.3035,  9.1268, 11.8622,  9.9684, 10.4419],
        [ 7.2199, 17.6355, 15.8206, 15.5050, 19.8974, 31.9306,  4.0505],
        [14.3141, 23.0300, 45.5782, 37.4717, 22.8175, 18.1973, 13.1236],
        [21.1026, 2

batch_predictions
tensor([[19.1608, 23.7641, 22.8315, 20.7954, 18.0857, 16.2096, 16.1897],
        [17.2712, 18.5797, 18.1889, 17.4053, 16.6163, 15.9849, 15.5957],
        [22.4045, 20.4477, 20.1183, 20.6909, 21.4845, 22.2855, 23.0509],
        [22.3519, 20.1296, 21.3131, 22.4504, 23.3517, 24.0129, 24.4675],
        [17.4670, 19.2859, 18.9076, 18.3077, 17.8061, 17.4686, 17.2967],
        [15.9533, 20.2338, 20.1696, 19.0379, 17.5848, 16.1009, 15.0220],
        [17.8304, 16.7275, 16.7275, 17.1008, 17.4742, 17.7853, 18.0349],
        [13.7879, 13.4163, 13.4479, 13.6370, 13.8897, 14.1530, 14.3866]])
ground truth
tensor([[15.0842, 15.7943, 18.1352, 23.6586, 39.3083, 29.1426, 19.5292],
        [14.7159, 14.4924, 12.1515, 20.7128,  8.2062, 15.1368, 12.1515],
        [33.4042, 16.5391, 10.2324,  9.3821, 14.3566, 18.1548, 32.7523],
        [22.4348, 24.7307, 28.9116, 23.0159, 21.3435, 25.3827, 31.5618],
        [13.2937, 19.4586, 28.3305, 25.7086, 13.8039, 15.0935, 23.4127],
        [33.5317, 3

batch_predictions
tensor([[26.4794, 22.4454, 23.2945, 24.7215, 26.0995, 27.2714, 28.1608],
        [21.3469, 22.4208, 22.2031, 21.6276, 21.0147, 20.5370, 20.2361],
        [17.6994, 20.4841, 19.8819, 18.7890, 17.6452, 16.6242, 15.9265],
        [24.3478, 32.0240, 33.5073, 32.0931, 29.2557, 24.3683, 21.1520],
        [21.4585, 29.8022, 30.1548, 29.0585, 27.1224, 23.3742, 19.1019],
        [14.3460, 13.9601, 13.9835, 14.1283, 14.2972, 14.4554, 14.5885],
        [15.6163, 20.2017, 20.9062, 20.0341, 18.9012, 17.6956, 16.6413],
        [11.9975, 15.7776, 20.9145, 23.7011, 23.0735, 21.7456, 20.0775]])
ground truth
tensor([[33.6026,  0.0000, 10.6293, 19.5437, 17.3044, 22.1514, 36.2812],
        [24.2489, 24.4898, 12.2449, 14.6825, 14.1723, 12.1032, 18.4807],
        [ 9.7931, 17.0777, 13.4779, 32.8656, 22.9592, 10.4167, 15.0227],
        [14.6400, 24.2772, 33.6168, 46.8679,  9.8498, 12.5567, 16.1706],
        [23.4127, 39.5125,  6.8736,  7.3838, 20.8192, 11.5930, 15.8022],
        [14.0321,  

batch_predictions
tensor([[36.4363, 37.6342, 37.5869, 37.4495, 37.3413, 37.2727, 37.2355],
        [11.3721, 12.9281, 14.6426, 16.6456, 16.8735, 16.2907, 15.7502],
        [27.6403, 28.0762, 28.0666, 27.9161, 27.7427, 27.6090, 27.5382],
        [16.4756, 15.2959, 15.8145, 16.3787, 16.7937, 17.0870, 17.2936],
        [20.7028, 27.1918, 26.1701, 23.5070, 19.0133, 16.8180, 17.1103],
        [17.7237, 17.9961, 17.8954, 17.7292, 17.5972, 17.5234, 17.4999],
        [15.4378, 15.6539, 15.5983, 15.4885, 15.4008, 15.3531, 15.3376],
        [19.2818, 20.7324, 20.3799, 19.7224, 19.0891, 18.5976, 18.2982]])
ground truth
tensor([[39.1298, 59.1128, 46.8821, 34.0561, 23.5686, 27.1684, 25.6519],
        [ 9.1005, 10.1131, 11.1915, 13.7296, 14.6107, 25.3288, 27.5776],
        [ 2.4093, 12.2307, 23.4127, 39.5125,  6.8736,  7.3838, 20.8192],
        [23.3035,  9.1268, 11.8622,  9.9684, 10.4419, 17.2278, 20.8574],
        [19.1185, 21.0317, 34.8639, 28.4722, 30.2863, 21.1593, 20.4507],
        [22.7117, 2

batch_predictions
tensor([[23.3573, 23.5937, 23.5393, 23.3983, 23.2596, 23.1670, 23.1309],
        [18.5532, 24.2130, 25.1992, 24.1218, 22.0412, 18.8364, 17.1704],
        [ 8.6403,  8.8162,  8.8854,  8.9028,  8.8942,  8.8768,  8.8645],
        [17.8932, 20.6798, 20.0980, 19.0940, 18.0822, 17.2117, 16.6064],
        [22.9323, 31.0417, 31.0086, 29.6478, 27.2681, 23.0949, 19.3857],
        [34.3855, 35.1159, 35.2179, 35.1142, 34.9514, 34.8081, 34.7142],
        [17.6991, 16.5022, 17.5623, 18.3643, 18.9118, 19.2695, 19.5028],
        [19.9124, 26.8376, 26.4872, 24.8411, 21.6774, 18.1709, 17.8122]])
ground truth
tensor([[21.1026, 37.2024, 39.8243, 10.0198, 14.9235, 22.3781, 19.3027],
        [12.8968, 13.5771, 10.0624, 17.0068, 31.5760, 42.7863, 30.5981],
        [11.9937, 15.0710, 15.6628,  8.5087,  7.9300,  5.9442,  8.6928],
        [15.6102, 15.8075, 17.2541, 28.0642, 31.3914, 13.7822, 16.7543],
        [25.9070, 42.4745, 41.1565, 12.5850, 21.3010, 23.5402, 17.7154],
        [56.7596, 5

batch_predictions
tensor([[38.3425, 38.3170, 38.3252, 38.3286, 38.3301, 38.3313, 38.3324],
        [35.6460, 37.1407, 37.1617, 36.9836, 36.8050, 36.6759, 36.6000],
        [18.5056, 18.8456, 18.7402, 18.5276, 18.3409, 18.2243, 18.1776],
        [24.2084, 32.6935, 32.7397, 30.5644, 25.9083, 18.9949, 18.9871],
        [22.8709, 20.2159, 20.8806, 21.9765, 22.9643, 23.7601, 24.3631],
        [10.6591, 10.2521,  9.7334,  9.1230,  8.5444,  8.1494,  7.9491],
        [12.9780, 16.2367, 18.7760, 17.6275, 16.2627, 14.9602, 13.7482],
        [15.2931, 17.3112, 17.1958, 16.4858, 15.7609, 15.1531, 14.6915]])
ground truth
tensor([[52.1699, 63.6376, 64.6107, 41.6491, 42.5829, 45.9890, 42.5565],
        [31.0091, 42.5454, 67.6446, 59.4388, 35.5726, 34.4246, 21.9104],
        [20.7391, 17.7538, 11.2704, 10.8759, 13.9400, 14.9921, 15.5050],
        [17.3753, 32.1854, 55.2721, 40.0652, 16.5816, 21.7545, 18.0981],
        [31.3519, 22.5671, 19.6476, 16.4913, 17.6092, 33.2062, 34.9816],
        [10.5471, 1

batch_predictions
tensor([[18.2402, 19.0969, 18.9114, 18.5428, 18.2262, 18.0227, 17.9298],
        [20.8251, 21.3116, 21.1973, 20.9581, 20.7452, 20.6072, 20.5467],
        [15.4739, 18.7322, 19.1431, 18.5457, 17.9502, 17.5081, 17.2473],
        [19.4740, 19.0987, 19.0933, 19.1905, 19.2935, 19.3868, 19.4710],
        [25.6109, 27.0227, 26.8470, 26.3774, 25.9087, 25.5460, 25.3227],
        [ 8.9671,  7.6135,  6.9977,  6.8510,  6.8381,  6.8521,  6.8655],
        [19.8423, 18.3950, 19.7245, 20.9428, 21.9607, 22.7580, 23.3208],
        [23.6945, 32.7453, 33.4388, 30.7789, 23.9927, 17.2314, 17.9141]])
ground truth
tensor([[13.0669, 13.8889, 16.5675, 20.2806,  0.0000,  0.0000,  9.8498],
        [29.6769, 30.7256, 16.7375, 12.0465, 12.2732, 16.7659, 21.1593],
        [ 8.4467, 19.9121, 23.8237, 21.4144, 12.0748, 14.1015, 11.0261],
        [33.0089, 27.3277, 14.9264, 19.4897, 12.3093, 14.4266, 20.7522],
        [20.9100, 29.1689, 32.1278, 23.1720, 25.7759, 17.2935, 18.1878],
        [15.6463,  

batch_predictions
tensor([[18.4324, 19.2092, 18.7919, 18.2295, 17.7667, 17.4738, 17.3562],
        [17.0851, 16.3691, 17.7905, 19.0752, 20.0894, 20.7209, 21.0076],
        [16.0268, 19.7119, 19.9276, 18.8251, 17.4491, 16.1003, 15.1170],
        [22.6423, 30.7933, 35.7162, 35.3834, 33.2568, 28.1030, 20.0753],
        [18.8497, 17.9818, 17.8700, 18.1020, 18.4359, 18.7774, 19.0955],
        [ 7.5017,  7.0140,  6.8241,  6.7673,  6.7584,  6.7609,  6.7638],
        [15.2833, 18.0071, 18.3189, 17.6916, 17.0519, 16.5574, 16.2354],
        [24.8533, 25.8627, 25.7601, 25.3686, 24.9398, 24.5890, 24.3627]])
ground truth
tensor([[17.9705, 24.5607, 29.8895, 31.8027,  6.1933, 10.5867, 17.9989],
        [23.1009,  6.5476,  9.8781, 11.3379, 11.5079, 33.5317, 33.4467],
        [10.8560, 11.7772, 10.4308, 13.2937, 20.3656, 36.5363, 21.5703],
        [31.2599, 58.8901, 47.5802, 12.2830, 14.0058, 14.4135, 12.1383],
        [28.9847, 29.7344, 16.2152, 17.7538, 17.7275, 13.7428, 18.5429],
        [10.9942, 1

batch_predictions
tensor([[14.9755, 14.9061, 14.9483, 14.9906, 15.0218, 15.0498, 15.0769],
        [17.8152, 21.9455, 20.7959, 18.7072, 16.2553, 14.7750, 14.7314],
        [38.3157, 38.3865, 38.3816, 38.3763, 38.3730, 38.3715, 38.3704],
        [12.9096, 12.4576, 12.2146, 12.0716, 11.9916, 11.9577, 11.9565],
        [17.7748, 21.2495, 20.6662, 19.3675, 17.8758, 16.5404, 15.8046],
        [17.6481, 22.1831, 21.5909, 20.0313, 18.0018, 16.3440, 15.8232],
        [20.4514, 19.0769, 18.8775, 19.2296, 19.7018, 20.1563, 20.5634],
        [14.6247, 18.9158, 23.2282, 24.2871, 23.5628, 22.0845, 20.0442]])
ground truth
tensor([[22.1331, 23.2641,  9.7317, 13.5587, 10.4156, 13.3614, 18.2799],
        [ 9.5522, 16.1706, 17.0493, 28.8974, 29.7194, 11.1678, 13.4921],
        [40.9864, 57.5680, 33.1207, 20.5357, 16.6100, 18.3248, 19.6003],
        [ 9.4292,  9.5082,  7.8511,  8.8638, 10.9679, 16.6886, 14.1767],
        [20.5924, 25.3968, 28.3588, 10.8135,  8.6876,  9.1978, 13.0527],
        [12.8543, 1

batch_predictions
tensor([[10.5284, 11.7709, 13.1474, 15.2530, 17.2377, 16.6233, 15.8733],
        [13.1993, 16.3365, 20.1713, 19.5857, 18.4592, 17.3312, 16.5054],
        [16.2667, 20.9914, 22.1162, 21.3599, 20.2341, 18.8884, 17.5550],
        [21.3115, 19.1641, 19.6092, 20.4482, 21.1772, 21.7484, 22.1796],
        [ 6.3610,  6.5043,  6.5221,  6.5189,  6.5161,  6.5050,  6.5006],
        [26.3256, 29.9543, 29.6329, 28.6791, 27.4726, 25.9984, 24.2834],
        [18.8184, 25.4119, 25.2055, 23.6556, 20.6328, 17.2740, 17.0484],
        [20.3441, 23.4694, 22.8162, 21.5764, 20.0988, 18.5955, 17.7067]])
ground truth
tensor([[16.0856, 11.6780, 14.2149, 22.2080, 19.5295, 10.4592, 14.9235],
        [ 8.1931, 12.0594, 23.1983, 25.9863, 35.4024, 10.6128,  8.8769],
        [14.1582, 22.2647, 30.3713, 24.2489, 31.9019, 28.4297, 13.7046],
        [29.6344, 12.0040, 14.7251, 15.4195, 19.8271, 22.2222, 27.7636],
        [ 4.6291,  5.3524, 12.6775, 14.3609, 11.4150,  6.0757,  6.4045],
        [27.8538, 3

batch_predictions
tensor([[25.9853, 31.1061, 30.7113, 29.6441, 28.3822, 26.8818, 24.6640],
        [25.9626, 26.3909, 26.3244, 26.1794, 26.0582, 25.9866, 25.9603],
        [11.0017, 13.0984, 15.1034, 16.6922, 17.2463, 17.2197, 17.0393],
        [37.6264, 37.5105, 37.5235, 37.5297, 37.5327, 37.5396, 37.5514],
        [19.0584, 19.9883, 19.7372, 19.0981, 18.3903, 17.7966, 17.4179],
        [14.0625, 17.7285, 19.3539, 18.1884, 16.8121, 15.4302, 14.2253],
        [11.5939, 13.8251, 16.9055, 18.0643, 17.2466, 16.4784, 15.9172],
        [30.1812, 28.0405, 27.4209, 27.6717, 28.3254, 29.1015, 29.8721]])
ground truth
tensor([[27.6644, 30.0737, 36.8481, 50.2126, 12.7551, 15.4337, 23.6395],
        [20.7058, 23.4410, 26.3464, 25.0850, 32.9507, 16.3124, 19.0618],
        [14.7392, 14.1015, 15.0652, 19.6145, 25.0142,  9.7647,  3.2738],
        [54.6202, 49.0079, 31.3917, 28.8549, 22.5057, 29.1241, 37.3866],
        [28.0905, 27.4592,  8.3377, 13.3482, 11.1915, 12.3948, 12.3027],
        [ 9.2057, 1

batch_predictions
tensor([[18.6184, 22.9726, 21.5991, 19.5339, 17.3422, 16.0151, 16.0056],
        [19.6932, 17.7477, 18.1327, 18.9444, 19.6497, 20.2035, 20.6273],
        [21.5819, 19.5745, 21.4235, 23.3652, 25.2444, 26.9082, 28.0199],
        [11.2195, 12.4533, 14.1612, 16.1593, 16.1545, 15.5526, 14.9479],
        [27.5797, 22.7940, 24.0743, 25.8336, 27.5569, 29.0622, 30.2150],
        [19.8040, 18.0615, 19.3521, 20.4822, 21.3727, 22.0067, 22.4104],
        [21.3536, 19.4609, 19.8490, 20.7411, 21.5407, 22.1821, 22.6821],
        [12.5077, 14.4649, 16.5947, 16.8642, 16.3760, 15.9363, 15.6340]])
ground truth
tensor([[ 2.3277, 14.9395, 23.0274, 29.5371, 28.3535,  3.8269,  9.6265],
        [12.3866, 11.7489, 11.1961, 20.0397, 26.5448, 24.8866, 26.0771],
        [14.9802, 14.2857, 10.4025, 18.3815, 27.1825, 26.0913, 27.6502],
        [12.9932, 22.1462, 22.4487, 10.0999, 11.8096,  9.2451,  8.7454],
        [38.3645, 17.9280, 18.0414, 16.9785, 22.2931, 36.5221, 42.7863],
        [14.0448, 1

batch_predictions
tensor([[36.2520, 36.3922, 36.3932, 36.3590, 36.3275, 36.3134, 36.3159],
        [17.0430, 21.1992, 21.3579, 20.6928, 19.8952, 19.1675, 18.6294],
        [16.9447, 21.5590, 21.6362, 20.7603, 19.5709, 18.2932, 17.2783],
        [29.2462, 28.8100, 28.8303, 28.9277, 29.0257, 29.1134, 29.1923],
        [23.4296, 29.0157, 28.2555, 26.1215, 22.0840, 17.9584, 18.1250],
        [ 7.4874,  7.5631,  7.5801,  7.5784,  7.5735,  7.5590,  7.5514],
        [11.5380, 13.8056, 16.1947, 17.9257, 18.1276, 17.7674, 17.3466],
        [16.2777, 17.6403, 17.2775, 16.6171, 16.0192, 15.5767, 15.3135]])
ground truth
tensor([[54.6202, 46.4002, 33.4751, 26.0913, 21.0176, 29.5210, 39.1298],
        [12.8880, 11.1915, 18.8585, 24.8422,  1.4861,  0.0000, 13.7822],
        [20.0397, 26.5448, 24.8866, 26.0771,  7.8090, 10.5017, 14.7392],
        [20.4103, 23.3561, 26.1573, 32.8643, 35.7312, 23.0800, 25.5655],
        [17.0635, 31.1366,  0.6236, 22.7749, 26.4881, 15.5612, 16.5108],
        [ 8.6928, 1

batch_predictions
tensor([[19.7079, 21.1487, 20.6311, 19.9035, 19.2437, 18.7436, 18.4513],
        [21.2092, 26.2578, 25.4429, 23.1622, 19.2768, 16.8451, 17.0719],
        [21.3137, 27.3300, 26.3208, 23.7923, 19.2296, 16.5596, 16.8733],
        [13.6585, 16.7017, 19.4812, 18.5223, 17.2685, 16.1352, 15.2386],
        [22.2343, 22.4217, 22.4190, 22.3253, 22.2130, 22.1279, 22.0869],
        [21.4733, 21.4613, 21.4607, 21.4596, 21.4580, 21.4560, 21.4668],
        [17.7892, 22.5658, 23.7495, 21.9313, 18.2385, 15.1352, 14.8216],
        [23.4149, 20.5983, 20.9156, 21.9754, 23.0083, 23.8950, 24.6228]])
ground truth
tensor([[13.0102, 21.1026, 24.3197, 18.2540, 13.8180, 12.7409, 11.8056],
        [17.2902, 12.9677, 12.0890, 24.2347, 28.6848, 41.6525, 22.2222],
        [22.3781, 19.3027, 25.9779, 33.6735, 41.1423,  7.4688, 18.0556],
        [11.0205,  7.5881, 10.5471, 11.3230, 24.2109, 29.6028,  8.2720],
        [39.3740, 16.0179, 18.2930, 21.4624, 14.6633, 21.6728, 28.5113],
        [18.0272, 2

batch_predictions
tensor([[17.5629, 21.5274, 20.0976, 17.7681, 15.1275, 13.9333, 13.9012],
        [38.3899, 38.3587, 38.3666, 38.3698, 38.3711, 38.3720, 38.3730],
        [24.9087, 28.4616, 28.3253, 27.2843, 25.6301, 23.2812, 20.9403],
        [21.8366, 19.4721, 19.5490, 20.4864, 21.4518, 22.3171, 23.0788],
        [17.9087, 17.5730, 17.6024, 17.7250, 17.8526, 17.9649, 18.0606],
        [18.9633, 18.1753, 18.0533, 18.2212, 18.4778, 18.7446, 18.9961],
        [14.6515, 14.2414, 14.3246, 14.5710, 14.8478, 15.0987, 15.2981],
        [19.1450, 19.4013, 19.3383, 19.1968, 19.0711, 18.9932, 18.9636]])
ground truth
tensor([[ 9.1662, 15.3472, 20.3840, 23.9085,  5.0237, 10.7180,  8.3377],
        [42.5039, 56.3388, 23.0668,  0.0000, 29.9579, 43.4377, 46.8701],
        [22.9748, 43.0037, 42.7012, 17.3330, 18.4903, 15.9390, 22.3304],
        [36.0969,  6.7460, 14.0590, 14.7251, 18.8067, 21.1876, 37.0040],
        [26.8141, 26.7432, 17.0777, 12.3441, 15.3628, 13.0102, 21.1026],
        [23.0867, 1

batch_predictions
tensor([[23.5074, 25.8062, 25.4845, 24.5385, 23.2999, 21.9071, 20.6611],
        [11.8905, 13.8543, 16.5695, 17.8503, 16.9927, 16.1418, 15.4808],
        [17.3331, 16.8840, 16.8811, 17.0167, 17.1662, 17.3013, 17.4181],
        [24.4895, 29.5520, 29.0227, 27.5336, 25.2343, 21.7544, 19.7968],
        [10.4884, 11.5283, 12.6438, 14.2265, 16.2702, 16.3914, 15.7870],
        [29.0145, 28.4841, 28.3836, 28.4622, 28.5997, 28.7761, 28.9811],
        [19.0618, 19.7135, 19.3871, 18.8428, 18.3432, 17.9965, 17.8316],
        [18.3967, 22.6904, 22.1921, 21.1268, 19.9289, 18.9009, 18.3013]])
ground truth
tensor([[14.7422, 15.1631, 25.1578, 26.9069, 27.9853, 15.4787, 17.4382],
        [11.3361, 13.3877, 14.4266, 16.2283, 24.3556, 23.2378,  6.5623],
        [11.1961, 12.6701, 15.4337, 15.4195, 15.5612, 27.0408, 17.0210],
        [18.7533, 22.0410, 29.4845, 44.7002, 34.2451, 21.2388, 20.1078],
        [ 9.0986,  6.9728, 11.4371, 20.5074, 17.1202,  4.6910, 10.3175],
        [51.7715, 3

batch_predictions
tensor([[28.6590, 30.1266, 30.1227, 29.7369, 29.2613, 28.8291, 28.5020],
        [30.7605, 36.6681, 36.7319, 36.1108, 35.3245, 34.4918, 33.6482],
        [13.7242, 13.3830, 14.4571, 15.5258, 16.2304, 16.5435, 16.6681],
        [15.7956, 15.8385, 15.8223, 15.7830, 15.7482, 15.7326, 15.7350],
        [15.8835, 19.6559, 20.8388, 20.1222, 19.2337, 18.4131, 17.7684],
        [15.4641, 19.4472, 21.0162, 20.3168, 19.4104, 18.5510, 17.8422],
        [16.3740, 18.6976, 18.2902, 17.3790, 16.4904, 15.7603, 15.2475],
        [20.8158, 20.8549, 20.8480, 20.8057, 20.7582, 20.7297, 20.7279]])
ground truth
tensor([[16.7375, 18.1548, 19.3169, 25.5952, 47.6332, 14.7534,  0.0000],
        [24.8299, 29.9178, 44.8413, 49.2914,  7.4405, 13.5771, 21.7545],
        [18.6612,  4.7212,  7.8248,  9.1399, 10.6128, 11.7964, 23.1983],
        [17.4119, 16.7149,  6.3782,  8.9164, 10.7180, 14.1899, 25.9469],
        [12.0465, 19.2885, 30.4280, 22.6616, 11.3237, 14.3566, 11.1395],
        [11.8359, 1

tensor(75.5777)

In [None]:
e, d, t, gn, ea, da, dea = se, d, t, gn, ea, da, dea = sequence_iter(dl, 3)
print(gn)

In [None]:
t = torch.Tensor([[1, np.nan], [2, 2], [np.nan, np.nan]]).isnan().nonzero().T
print(t)
t[1][t[0] == 0]


In [None]:
d = SequenceDatasetForTransformer(train_sequences)
dl = DataLoader(d, batch_size=8, shuffle=True)

In [None]:
e, d, t, ea, da, dea = next(iter(dl))

In [None]:
t

In [None]:
ea.repeat(3, 1, 1).shape

In [None]:
process_batch_nulls(e,d,t)

In [None]:
# testing this vanilla transformer. 
# Note that to batch, examples within a single batch must have same enc and dec input lengths
test_transformer = Transformer()
test_encoder_x = torch.unsqueeze(torch.Tensor([[1, 2, 3], [7, 8, 9]]).T, -1)
test_decoder_x = torch.unsqueeze(torch.Tensor([[0, 4], [0, 10]]).T, -1)
ground_truth = torch.Tensor([[4, 5], [10, 11]]).T
test_dec_output = torch.squeeze(test_transformer(test_encoder_x, test_decoder_x), -1)
# ground truth and output are now output_len x bs
test_loss = F.mse_loss(test_dec_output, ground_truth)
test_loss.backward()


In [None]:
input1 = torch.Tensor([0, 1, 2, 3, 4])
input2 = torch.Tensor([0, 1, 2, 3, 4])
w1 = torch.nn.Parameter(torch.Tensor([[0, 1, 2, 3, 4],
                                     [0, 1, 2, 3, 4],
                                     [0, 1, 2, 3, 4],
                                     [0, 1, 2, 3, 4], 
                                      [0, 1, 2, 3, 4]]), requires_grad=True)
output1 = input1 @ w1
output2 = input2 @ w1
target1 = torch.Tensor([3, 4, 1, np.nan, 4])
target2 = torch.Tensor([3, np.nan, 1, np.nan, 4])
loss1 = F.mse_loss(output1, target1, reduction='none')
loss_without_nan1 = loss1[~torch.isnan(loss1)]
loss2 = F.mse_loss(output2, target2, reduction='none')
loss_without_nan2 = loss2[~torch.isnan(loss2)]
losses = torch.cat((loss_without_nan1, loss_without_nan2), 0)
final_loss = torch.sum(losses) / (len(loss_without_nan1) + len(loss_without_nan2))
final_loss.backward()


In [None]:
w1.grad

In [None]:
input1 = torch.Tensor([2, 1, 2, 3, 4])
input2 = torch.Tensor([3, 1, 2, 3, 4])
w1 = torch.nn.Parameter(torch.Tensor([[5, 1, 2, 3, 4],
                                     [5, 1, 2, 3, 4],
                                     [6, 1, 2, 3, 4],
                                     [8, 1, 2, 3, 4], 
                                      [9, 1, 2, 3, 4]]), requires_grad=True)
output1 = input1 @ w1
output2 = input2 @ w1
target1 = torch.Tensor([3, 4, np.nan, 4, 4])
indices1 = list((~target1.isnan()).nonzero(as_tuple=True))
output1_nonull = output1[indices1]
target1_nonull = target1[indices1]
target2 = torch.Tensor([3, np.nan, 1, np.nan, 4])
indices2 = list((~target2.isnan()).nonzero(as_tuple=True))
output2_nonull = output2[indices2]
target2_nonull = target2[indices2]
loss1 = F.mse_loss(output1_nonull, target1_nonull)
loss2 = F.mse_loss(output2_nonull, target2_nonull)
final_loss = (loss1 + loss2) / 2
final_loss.backward()


In [None]:
loss1

In [None]:
# pytorch test

inputs = torch.Tensor([4, 0, 7])
weights_vec1 = torch.nn.Parameter(torch.Tensor([[5, 4, 2], [3, 4, 1], [5, 3, 2], [1, 2, 3]]), requires_grad=True)
weights_vec2 = torch.nn.Parameter(torch.Tensor([[1, 2, 3], [4, 5, 6], [2, 3, 2], [4, 3, 1]]), requires_grad=True)
vec3 = torch.Tensor([[1, 2, 3], [4, 5, 6], [2, 3, 2], [4, 3, 1]])
q = weights_vec1 * inputs
k = weights_vec2 * inputs
logits = (q.T @ k) + torch.Tensor([[0, float('-inf'), 0], [0, float('-inf'), 0], [0, float('-inf'), 0]])
logits = logits / 1000  # div by 1000 just for stability for demonstration
logits.retain_grad()
print('logits: ', logits, '\n')
soft = torch.nn.functional.softmax(logits, dim=1)
soft.retain_grad()
print('softmax: ', soft, '\n')
logits2 = (soft @ vec3.T) @ vec3 + torch.Tensor([[0, float('-inf'), 0], [0, float('-inf'), 0], [0, float('-inf'), 0]])
print('logits2: ', logits2, '\n')
soft2 = torch.nn.functional.softmax(logits2, dim=1)
print('softmax: ', soft2, '\n')
final_m = (soft2 @ vec3.T).T
print('final m: ', final_m, '\n')

final = torch.sum(final_m)  # just to get some scalar value at the end
print('final: ', final, '\n')
final.backward()
print('soft gradients: ', soft.grad, '\n')
print('logits gradients: ', logits.grad, '\n')
print('wq gradients: ', weights_vec1.grad, '\n')
print('wk gradients: ', weights_vec2.grad, '\n')

In [None]:
# pytorch test

inputs = torch.Tensor([4, 0, 7])
weights_vec1 = torch.nn.Parameter(torch.Tensor([5, 4, 2]), requires_grad=True)
weights_vec2 = torch.nn.Parameter(torch.Tensor([1, 2, 3]), requires_grad=True)
q = torch.unsqueeze(weights_vec1, 1) @ torch.unsqueeze(inputs, 1).T
k = torch.unsqueeze(weights_vec2, 1) @ torch.unsqueeze(inputs, 1).T
q.retain_grad()
k.retain_grad()
logits = (q.T @ k) + torch.Tensor([[0, float('-inf'), 0], [0, 0, 0], [0, float('-inf'), 0]])
logits = torch.cat([logits[:1], logits[2:]]) / 1000  # div by 1000 just for stability for demonstration
logits.retain_grad()
print('logits: ', logits, '\n')
soft = torch.nn.functional.softmax(logits, dim=1)
soft = torch.cat([soft[:1], torch.Tensor([[0, 0, 0]]), soft[1:]])
soft.retain_grad()
print('softmax: ', soft, '\n')
final = torch.sum(torch.prod(soft + 1, dim=0))  # just to get some scalar value at the end
print('final: ', final, '\n')
final.backward()
print('soft gradients: ', soft.grad, '\n')
print('logits gradients: ', logits.grad, '\n')
print('q gradients: ', q.grad, '\n')
print('k gradients: ', k.grad, '\n')
print('wq gradients: ', weights_vec1.grad, '\n')
print('wk gradients: ', weights_vec2.grad, '\n')

In [None]:
# for decoder w1 -- projecting from scalar to matrix dimensionwise

tens1 = torch.Tensor([1, 2])  # timesteps
tens2 = torch.Tensor([1, 2, 3, 4])  # pos encoding
tens1 = tens1.repeat(4, 1)
tens1.T * tens2

In [None]:
t1 = torch.Tensor([[1, 2], [1 ,2], [1, 2]])
t2 = torch.Tensor([[1, 4], [2 ,5], [3, 6], [4, 5]])


In [None]:
t1 = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
v = torch.nn.Parameter(torch.Tensor([3, 3.1, 3.2, 3.4]), requires_grad=True)
z = t1 * v
s = F.softmax(z, dim=1)
i_ = torch.multinomial(s, 1).squeeze()
print(i_)
selected = torch.cat((torch.unsqueeze(z[0][i_[0]], -1) , torch.unsqueeze(z[1][i_[1]], -1)))
target = torch.Tensor([4, 8])
l = torch.mean(target - z.T[i_])
l.backward()

In [None]:
z

In [None]:
v.grad

In [328]:
l = nn.LayerNorm(4)
t = torch.Tensor([[[1, 2, 3, 4], [2, 4, 6, 8]], [[1, 2, 3, 4], [2, 4, 4.2, 5.3]]])
l(t)

tensor([[[-1.3416, -0.4472,  0.4472,  1.3416],
         [-1.3416, -0.4472,  0.4472,  1.3416]],

        [[-1.3416, -0.4472,  0.4472,  1.3416],
         [-1.5752,  0.1050,  0.2730,  1.1971]]],
       grad_fn=<NativeLayerNormBackward>)

In [333]:
p = PositionalEncoding(64, max_len=5000)
p.pe

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00]],

        [[ 8.4147e-01,  5.4030e-01,  6.8156e-01,  ...,  1.0000e+00,
           1.3335e-04,  1.0000e+00]],

        [[ 9.0930e-01, -4.1615e-01,  9.9748e-01,  ...,  1.0000e+00,
           2.6670e-04,  1.0000e+00]],

        ...,

        [[ 9.5625e-01, -2.9254e-01,  6.4315e-01,  ...,  6.3049e-01,
           6.1813e-01,  7.8608e-01]],

        [[ 2.7050e-01, -9.6272e-01, -5.1133e-02,  ...,  6.3036e-01,
           6.1823e-01,  7.8599e-01]],

        [[-6.6395e-01, -7.4778e-01, -7.1816e-01,  ...,  6.3022e-01,
           6.1834e-01,  7.8591e-01]]])