Dataset includes several time series representing different ATMs and their daily cash money demand at different locations.

In [1]:
# NOTE FIX THESE IMPORTS

from data_loader import convert_tsf_to_dataframe
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.api as sm
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import linear, softmax
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch import Tensor
from typing import Optional, Tuple, Union, Callable, Any, List
from torch.nn.init import xavier_uniform_, constant_, xavier_normal_
import copy
from torch.nn import Module, LayerNorm, ModuleList
from torch.nn.parameter import Parameter
import datetime
from torch.utils.data import Dataset, DataLoader
import sys

In [2]:
torch.set_printoptions(threshold=1000)

In [3]:
TSF_FILE = 'nn5_daily_dataset_with_missing_values.tsf'

In [4]:
data = convert_tsf_to_dataframe(TSF_FILE)

In [5]:
series_values = data[0]['series_value']

In [None]:
data

In [None]:
data[0]

In [None]:
data[0]['series_value'][0]

In [None]:
date_range_0 = pd.date_range(start='1996-03-18', periods=791, freq='D')

In [None]:
date_range_0

In [None]:
data_0 = pd.DataFrame({'series_value': data[0]['series_value'][0], 'date': date_range_0})
data_0 = data_0.set_index('date')

In [None]:
data_0

In [None]:
plt.plot(data_0['series_value'][:21])

In [None]:
np.array(data_0['series_value'])

In [None]:
# proportion of nulls and zeroes
total = 791 * 111
null_count = 0
zero_count = 0
for i in range(len(data[0]['series_value'])):
    series_values = np.array(pd.to_numeric(data[0]['series_value'][i], errors='coerce'))
    nulls_in_series = np.sum(np.isnan(series_values))
    null_count += nulls_in_series
    zeroes_in_series = len(series_values) - np.count_nonzero(series_values)
    zero_count += zeroes_in_series
(null_count + zero_count) / total
    

Questions and notes:

- Global methods seem better than traditional univariate methods for this dataset because we can leverage related but seperate datasets in a unified model. Traditional methods such as ARIMA do not have this capability and if we would need to fit a separate model for each dataset. (Look into this).  


- How to deal with missing values? I assume if we're using a deep learning method, we could simply treat the missing values as breakpoints for training samples. I.e. let's say you have \[$x_1$, $x_2$, $x_3$, $x_4$, $x_5$, $x_6$, $x_7$, $x_8$\] where this is a sequence in order of time and $x_5$ is missing. Then, assuming the model uses a lag of 2, we can treat {$X$: \[$x_1$, $x_2$\], $y$: $x_3$}, {$X$: \[$x_2$, $x_3$\], $y$: $x_4$} and {$X$: \[$x_6$, $x_7$\], $y$: $x_8$} as our training examples, treating $x_5$ as a breakpoint. Is there a better way? ANSWER: for now no need to interpolate missing data, let's just try training on what we have.


- How can we identify the types of seasonality in the data? Can we deal with multiple seasonality? Traditional seasonal decomposition methods (even those which deal with multiple seasonality i.e. MSTL) will probably not work too well due to the multiple series. How do we integrate seasonality into the model? One idea is to use sin/cos encodings depending on the seasonal position. I.e. for seasonality that occurs one day a week, use $a\sin (x\cdot \frac{\pi}{3.5})$ and the cos version, where $a$ is tuneable. But how many of these terms should we include? Seems like a relatively wieldy hyperparameter. Also, if we need to see how to modify the architecture to include these terms. Another idea is to include the previous season's data, i.e. alongside the $k$ lags from the past we use $x_{t - p}$ as input, where $p$ is the seasonal period. Note that we would also need some kind of positional encoding for this value. Perhaps try dealing with seasonality after an original attempt without it, though.


- TODO: I want to try applying a transformer model to this data, research and look into transformers. From there we can try different things. So first just try applying a vanilla transformer. 
    - BUT WAIT... is the transformer really the best choice here? The problem I see with the transformer is that it is more suited for sequence-to-sequence models. In other words, you get one complete sequence as input and have to output another related sequence. Perhaps there is some modification to this that allows for completion of one sequence but anyway, if I'm correct, the self-attention used in the transformer will not be applicable to numerical time series such as this one. When it comes to sequence-to-sequence models that the transformer is meant for, the words can influence each other, i.e. a word later in the sequence can provide semantic information about a word earlier in the sequence, and self-attention deals with that scenario, where the information flow is not necessarily left-to-right. But for a time series like this, the information flows unidirectionally, elements earlier in the sequence do not depend on elements later in the sequence. Using a self-attention encoding where all elements of the sequence attend on each other would lead to spurious modelling. Perhaps an RNN model would be more suited to this kind of time series data. I will look into deep learning models for time series and find something that makes sense. 



- Better way to decode? Right now training uses "teacher forcing", but inference does not, if a very wrong prediction is produced then it negatively affects the subsequent predictions that get decoded, and without a concept of sequence probability we can't use beam search. Perhaps predict everything at once, i.e. no autoregressive decoding (which is needed for NLP because you don't know when the sequence will end, but we know the desired forecast horizon for time series). New "decoder" takes timestep position, performs some positional embedding, computes attention on encoder output, depending on query generated from positional embedding, performs transformations on the values and outputs the desired number of forecasts, one for each positional embedding. Something to try. Note this could possibly also enable forecasting of variable length for a single trained model depending on how I implement the positional encodings. If I use a lookup table for the positional encodings this is impossible, but if I use matrix mulitplications with the raw timestep number it is possible.


- Seasonailty: first create a seasonal encoder. Essentially given a forecast horizon, i.e. 7, and a hyperparameter for window size, i.e. 5, we create a "seasonal context" of len (horizon + window_sz) - 1. In other words, we take all our desired dates to forecast, seasonally shifted, get those values, as well as the (window_sz - 1) / 2 values on each side. So we have our seasonal context, comprised of horizon_len windows of length window_sz. Example: if our context is \[1, 2, 3, 4, 5\] and our window_sz is 3 (meaning we want to forecast 3 into the future), we have 3 windows \[1, 2, 3\], \[2, 3, 4\], and \[3, 4, 5\]. One corresponding to each position we want to decode. The intuition is that we hope each window contains seasonal data relevant to each date we want to forecast. We encode this seasonal context with its own learned position embeddings and weights; encoder is our modified encoder with causal self attention. Once the self attention finishes, we get a matrix of size (context_len x d_season). Normally we would then use the decoder-encoder attention as in the normal transformer but the problem here is that if we want to use this with the modified decoder as above, we can't do this. The decoder inputs for the modified decoder are merely timestep encodings so I'm not sure if that will give enough information to attend to the seasonal context, but I should probably try it nonetheless. My idea is to first perform attention from the decoder timestep encodings to the encodings of the previous $k$ time series elements immediately preceding the dates we want to forecast. Then we attend from the output of this to the seasonal encodings, applying as mask (see notes) before the softmax so that each decoder position only attends to the corresponding window for each season. We perform this attention on each seasonal component at the same time by concatentating the keys along the position axis. From here we get the regular dimension of values which we transform to get all the decoder timesteps at once. 


- Think about feeature engineering


- How to deal with nulls in target? I.e. how do we remove these from loss? We might still have to discard data... Like if at timestep 3 in the decoder the ground truth isn't available, i.e. null, how can we compute the loss and keep it differentiable? 
    - I suppose we could impute the null value in the target to be the same as the previous value or something. Not ideal but will be differentiable at least
    - Pytorch has a reduce flag in loss which lets you do perform loss elementwise. Idea is that we remove the nulls after doing this, and average the rest. But is this differentiable? Well apparently it's somehow differentiable but some of the gradients become null
    
    
- Perhaps I should apply some sort of clustering method. It seems like each ATM is at a different location, so the location itself is a latent variable that is not explicitly included. Different locations and location types may look different. Clustering could deal with this, use a different model for each cluster. Hopefully these inherent location differences get picked up by the clustering method. Look into time series clustering. 


- Here's the plan: try a variety of different models, and forecast for 1-day, 2-days, ..., 7-days ahead

- Models to consider:
    - ARIMA (baseline)
    - LSTM
    - Transformer
    - Temporal fusion transformer


- Transformer ideas:
    - Deal with nulls by setting them to 0 and applying an (input) mask which we add before the softmax which makes the null components of the self-attention negative infinity, becoming 0 after the softmax, thus negating their attention weights and ensuring they do not factor in to the final result. Experimentally verified that this does not affect the backprop computation (null component has 0 gradient in query/key matrices). 
        - Note that we also have to mask out the row of the queries for the null value, only way I can think of to do this is to first let the softmax do its thing on the row of 0's, then remove the row (using the null timestep dim input), and add in a row of 0's at the deleted position. Wouldn't this break the differentiable graph though? Perhaps there's no need to do this after all. Let's verify. Yes, we don't need to do the initial removal from the softmax tensor, we can simply add a mask in the decoding component which masks out the encoding corresponding to null in its attention
    - Mask out future timesteps during encoding. I.e. make the encoder self attention autoregressive by applying an encoding mask. Deals with the issue mentioned above. So the encoding at position 2 can only be affected by positions 1 and 2, etc.
    

In [None]:
# Trying vanilla transformer from pytorch with adaptations

In [6]:

def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    r"""
    Performs the in-projection step of the attention operation, using packed weights.
    Output is a triple containing projection tensors for query, key and value.
    Args:
        q, k, v: query, key and value tensors to be projected. For self-attention,
            these are typically the same tensor; for encoder-decoder attention,
            k and v are typically the same tensor. (We take advantage of these
            identities for performance if they are present.) Regardless, q, k and v
            must share a common embedding dimension; otherwise their shapes may vary.
        w: projection weights for q, k and v, packed into a single tensor. Weights
            are packed along dimension 0, in q, k, v order.
        b: optional projection biases for q, k and v, packed into a single tensor
            in q, k, v order.
    Shape:
        Inputs:
        - q: :math:`(..., E)` where E is the embedding dimension
        - k: :math:`(..., E)` where E is the embedding dimension
        - v: :math:`(..., E)` where E is the embedding dimension
        - w: :math:`(E * 3, E)` where E is the embedding dimension
        - b: :math:`E * 3` where E is the embedding dimension
        Output:
        - in output list :math:`[q', k', v']`, each output tensor will have the
            same shape as the corresponding input tensor.
    """
    E = q.size(-1)
    if k is v:
        if q is k:
            # self-attention
            return linear(q, w, b).chunk(3, dim=-1)
        else:
            # encoder-decoder attention
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).chunk(2, dim=-1)
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)

    
def _in_projection(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w_q: Tensor,
    w_k: Tensor,
    w_v: Tensor,
    b_q: Optional[Tensor] = None,
    b_k: Optional[Tensor] = None,
    b_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""
    Performs the in-projection step of the attention operation. This is simply
    a triple of linear projections, with shape constraints on the weights which
    ensure embedding dimension uniformity in the projected outputs.
    Output is a triple containing projection tensors for query, key and value.
    Args:
        q, k, v: query, key and value tensors to be projected.
        w_q, w_k, w_v: weights for q, k and v, respectively.
        b_q, b_k, b_v: optional biases for q, k and v, respectively.
    Shape:
        Inputs:
        - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any
            number of leading dimensions.
        - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any
            number of leading dimensions.
        - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any
            number of leading dimensions.
        - w_q: :math:`(Eq, Eq)`
        - w_k: :math:`(Eq, Ek)`
        - w_v: :math:`(Eq, Ev)`
        - b_q: :math:`(Eq)`
        - b_k: :math:`(Eq)`
        - b_v: :math:`(Eq)`
        Output: in output triple :math:`(q', k', v')`,
         - q': :math:`[Qdims..., Eq]`
         - k': :math:`[Kdims..., Eq]`
         - v': :math:`[Vdims..., Eq]`
    """
    Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
    assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
    assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
    assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
    assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
    assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
    assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
    return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)



def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    r"""
    Computes scaled dot product attention on query, key and value tensors, using
    an optional attention mask if passed, and applying dropout if a probability
    greater than 0.0 is specified.
    Returns a tensor pair containing attended values and attention weights.
    Args:
        q, k, v: query, key and value tensors. See Shape section for shape details.
        attn_mask: optional tensor containing mask values to be added to calculated
            attention. May be 2D or 3D; see Shape section for details.
        dropout_p: dropout probability. If greater than 0.0, dropout is applied.
    Shape:
        - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
            and E is embedding dimension.
        - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
            and E is embedding dimension.
        - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
            and E is embedding dimension.
        - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
            shape :math:`(Nt, Ns)`.
        - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
            have shape :math:`(B, Nt, Ns)`
    """
    B, Nt, E = q.shape
    q = q / math.sqrt(E)

    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    if attn_mask is not None:
        attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
    else:
        attn = torch.bmm(q, k.transpose(-2, -1))
    
#     print('attn before soft')
#     print(attn)
    attn = softmax(attn, dim=-1)
#     print('attn after soft')
#     print(attn)

    if dropout_p > 0.0:
        attn = F.dropout(attn, p=dropout_p)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, v)
#     print('attn output')
#     print(output)
    return output, attn



def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_batched: bool = True,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is an binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        use_separate_proj_weight: the function accept the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
            Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
            when ``need_weights=True.``. Default: True
    Shape:
        Inputs:
        - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
          will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        Outputs:
        - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
          attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
          :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
          :math:`S` is the source sequence length. If ``average_weights=False``, returns attention weights per
          head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
    """
    
#     print('query')
#     print(query)
#     print('key')
#     print(key)
#     print('value')
#     print(value)
    
    
    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
    # is batched, run the computation and before returning squeeze the
    # batch dimension so that the output doesn't carry this temporary batch dimension.
    if not is_batched:
        # unsqueeze if the input is unbatched
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    assert embed_dim == embed_dim_to_check, \
        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    
    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert key.shape[:2] == value.shape[:2], \
            f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
 
    if not use_separate_proj_weight:
        assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
        assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
        assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)

#     print('query_emb')
#     print(q)
#     print('key_emb')
#     print(k)
#     print('value_emb')
#     print(v)
    # prep attention mask
    if attn_mask is not None:
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)
        else:
            assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
                f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # prep key padding mask
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

    #
    # reshape q, k, v for multihead attention and make em batch first
    #
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)  
    if static_k is None:
        k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_k.size(0) == bsz * num_heads, \
            f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert static_k.size(2) == head_dim, \
            f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_v.size(0) == bsz * num_heads, \
            f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert static_v.size(2) == head_dim, \
            f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (bsz, src_len), \
            f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).   \
            expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        if attn_mask is None:
            attn_mask = key_padding_mask
        elif attn_mask.dtype == torch.bool:
            attn_mask = attn_mask.logical_or(key_padding_mask)
        else:
            attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))

    # convert mask to float
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
        new_attn_mask.masked_fill_(attn_mask, float("-inf"))
        attn_mask = new_attn_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #
    attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
    attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

    if need_weights:
        # optionally average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        if average_attn_weights:
            attn_output_weights = attn_output_weights.sum(dim=1) / num_heads

        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
            attn_output_weights = attn_output_weights.squeeze(0)
        return attn_output, attn_output_weights
    else:
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
        return attn_output, None

In [7]:
class MultiheadAttention(nn.Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
            Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None,
                average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
    Args:
        query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
            or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
            :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
            Queries are compared against key-value pairs to produce the output.
            See "Attention Is All You Need" for more details.
        key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
            or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
            :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
            See "Attention Is All You Need" for more details.
        value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
            ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
            sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
            See "Attention Is All You Need" for more details.
        key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
            to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
            Binary and byte masks are supported.
            For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
            the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding ``key``
            value will be ignored.
        need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
            Default: ``True``.
        attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
            :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
            :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
            broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
            Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
            corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
            corresponding position is not allowed to attend. For a float mask, the mask values will be added to
            the attention weight.
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
            heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
            effect when ``need_weights=True.``. Default: True (i.e. average weights across heads)

    Outputs:
        - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
          :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
          where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
          embedding dimension ``embed_dim``.
        - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
          returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
          :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
          :math:`S` is the source sequence length. If ``average_weights=False``, returns attention weights per
          head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.

        .. note::
            `batch_first` argument is ignored for unbatched inputs.
        """
        
        is_batched = query.dim() == 3
        if self.batch_first and is_batched:
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights)
        else:
            attn_output, attn_output_weights = multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, average_attn_weights=average_attn_weights)
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
        


In [8]:
def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])



class TransformerModel(Module):
    r"""A transformer model. User is able to modify the attributes as needed. The architecture
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010.
    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of encoder/decoder intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
            other attention and feedforward operations, otherwise after. Default: ``False`` (after).
    Examples::
        >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
        >>> src = torch.rand((10, 32, 512))
        >>> tgt = torch.rand((20, 32, 512))
        >>> out = transformer_model(src, tgt)
    Note: A full example to apply nn.Transformer module for the word language model is available in
    https://github.com/pytorch/examples/tree/master/word_language_model
    """

    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 use_norm=True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(TransformerModel, self).__init__()

        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first, norm_first, 
                                                    use_norm,
                                                    **factory_kwargs)
            encoder_norm = None
            if use_norm:
                encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first, norm_first,
                                                    use_norm,
                                                    **factory_kwargs)
            decoder_norm = None
            if use_norm:
                decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

        self.batch_first = batch_first

    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Take in and process masked source/target sequences.
        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence (optional).
            tgt_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
            tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
            memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).
        Shape:
            - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
              `(N, S, E)` if `batch_first=True`.
            - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
              `(N, T, E)` if `batch_first=True`.
            - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
            - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
            Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked
            positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
            while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
            is provided, it will be added to the attention weight.
            [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
            the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero
            positions will be unchanged. If a BoolTensor is provided, the positions with the
            value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
            - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
              `(N, T, E)` if `batch_first=True`.
            Note: Due to the multi-head attention architecture in the transformer model,
            the output sequence length of a transformer is same as the input sequence
            (i.e. target) length of the decoder.
            where S is the source sequence length, T is the target sequence length, N is the
            batch size, E is the feature number
        Examples:
            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        """

        is_batched = src.dim() == 3
        if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")
        elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")

        if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
            raise RuntimeError("the feature number of src and tgt must be equal to d_model")
        
#         print('ENCODER_STEP')
        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
#         print('DECODER_STEP')
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return output

    @staticmethod
    def generate_square_subsequent_mask(sz: int) -> Tensor:
        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
            Unmasked positions are filled with float(0.0).
        """
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""

        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)


class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers. Users can build the
    BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
        enable_nested_tensor: if True, input will automatically convert to nested tensor
            (and convert back on output). This will improve the overall performance of
            TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.enable_nested_tensor = enable_nested_tensor

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.
        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
        Shape:
            see the docs in Transformer class.
        """
        for mod in self.layers:
            output = mod(src, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(Module):
    r"""TransformerDecoder is a stack of N decoder layers
    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).
    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    """
    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer in turn.
        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
        Shape:
            see the docs in Transformer class.
        """
        output = tgt
#         print('x before decoder')
#         print(output)
        for mod in self.layers:
            output = mod(output, memory, tgt_mask=tgt_mask,
                         memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output

class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.
    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectivaly. Otherwise it's done after. Default: ``False`` (after).
    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    Fast path:
        forward() will use a special optimized implementation if all of the following
        conditions are met:
        - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
          argument ``requires_grad``
        - training is disabled (using ``.eval()``)
        - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
        - norm_first is ``False`` (this restriction may be loosened in the future)
        - activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu``
        - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
        - if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
          nor ``src_key_padding_mask`` is passed
        - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
          unless the caller has manually modified one without modifying the other)
        If the optimized implementation is in use, a
        `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
        passed for ``src`` to represent padding more efficiently than using a padding
        mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
        returned, and an additional speedup proportional to the fraction of the input that
        is padding can be expected.
    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 use_norm=True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        
        self.activation = activation
        
        self.use_norm = use_norm

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.
        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
        Shape:
            see the docs in Transformer class.
        """

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

    
        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = x + self._sa_block(x, src_mask, src_key_padding_mask)
            if self.use_norm:
                x = self.norm1(x)
#             print('x after self attn, res, norm')
#             print(x)
            x = x + self._ff_block(x)
            if self.use_norm:
                x = self.norm2(x)
            

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


class TransformerDecoderLayer(Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.
    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to self attention, multihead
            attention and feedforward operations, respectivaly. Otherwise it's done after.
            Default: ``False`` (after).
    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)
    Alternatively, when ``batch_first`` is ``True``:
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> memory = torch.rand(32, 10, 512)
        >>> tgt = torch.rand(32, 20, 512)
        >>> out = decoder_layer(tgt, memory)
    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 use_norm=True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                                 **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = activation
        
        self.use_norm = use_norm

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayer, self).__setstate__(state)

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.
        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
        Shape:
            see the docs in Transformer class.
        """
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

        x = tgt
        if self.norm_first:
            
            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask) 
            
            
            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)  
            
            x = x + self._ff_block(self.norm3(x))
        else:

            x = x + self._sa_block(x, tgt_mask, tgt_key_padding_mask)
            if self.use_norm:
                x = self.norm1(x)
                
#             print('x after self attn, res, norm')
#             print(x)
             
            x = x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask)
            if self.use_norm:
                x = self.norm2(x)
                
            x = x + self._ff_block(x)
            if self.use_norm:
                x = self.norm3(x)
            
#         print('declayer out')
#         print(x)
        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        x = self.self_attn(x, x, x,
                               attn_mask=attn_mask,
                               key_padding_mask=key_padding_mask,
                               need_weights=False)[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(self, x: Tensor, mem: Tensor,
                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        x = self.multihead_attn(x, mem, mem,
                                attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask,
                                need_weights=False)[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)

In [9]:
# from https://pytorch.org/tutorials/beginner/transformer_tutorial.html

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [10]:
class Transformer(Module):
    
    
    def __init__(self, d_model=64, nhead=4, num_encoder_layers=3, num_decoder_layers=3,
                dim_feedforward=512, dropout=0.1, activation=F.gelu, use_norm=False):
        
        super(Transformer, self).__init__()
        self.transformer_model = TransformerModel(d_model=d_model, nhead=nhead, num_encoder_layers=num_encoder_layers,
                                                  num_decoder_layers=num_decoder_layers, dim_feedforward=dim_feedforward,
                                                  dropout=dropout, activation=activation, use_norm=use_norm)
        
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.input_projection1 = nn.Linear(1, d_model)
        self.input_projection2 = nn.Linear(d_model, d_model)
        self.output_projection = nn.Linear(d_model, 1)
        
    def forward(self, encoder_x, decoder_x, src_mask=None, tgt_mask=None, memory_mask=None):
        """ encoder_x should be vector of size input_seq_len x bs x 1   
            decoder_x should be vector of size output_seq_len x bs x 1
        """

        
        encoder_x = self.pos_encoder(self.input_projection2(torch.sigmoid(self.input_projection1(encoder_x))))
        decoder_x = self.pos_encoder(self.input_projection2(torch.sigmoid(self.input_projection1(decoder_x))))
        

        x = self.transformer_model(encoder_x, decoder_x, 
                                   src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=memory_mask)
        x = self.output_projection(x)
        return x

    
        
        
        

In [None]:
# process data for vanilla transformer. I.e. create all possible sequences. Include null values

In [None]:
series_values = data[0]['series_value']
series_values

In [None]:
date_range = pd.date_range(start='1996-03-18', periods=791, freq='D')  # can do this because all series have same length
date_range

In [None]:
# to make it easy, set a default encoder and decoder seq len
encoder_seq_len = 14  # i.e. 2 weeks of data to use in encoder
decoder_seq_len = 7

In [None]:
# think about rolling cross validation but for now just have one test set at the end
train_set_dates = date_range[date_range < '1998-01-01'] 
test_set_dates = date_range[date_range >= datetime.datetime.strptime('1998-01-01', '%Y-%m-%d') - datetime.timedelta(encoder_seq_len)] 
# i.e. the first dates we want to predict are the week starting from 1998-01-01
# have to make sure these values we want to predict never pop up in the training set, including the decoder inputs
# in other words the latest decoder output should be one day before 1998-01-01 to avoid any leakage


In [11]:
def split_data_into_sequences(data_series, encoder_seq_len, decoder_seq_len, start_date=None, end_date=None, 
                              absolute_start_date='1996-03-18'):
    
    data_series = list(pd.Series(data_series).replace('NaN', np.nan))
    date_range = pd.date_range(start=absolute_start_date, periods=len(data_series), freq='D')
    start_date_idx = date_range.get_loc(start_date) if start_date is not None else 0
    end_date_idx = date_range.get_loc(end_date) if end_date is not None else len(date_range) - 1   

    data_subset = torch.Tensor(data_series[start_date_idx : end_date_idx + 1])
    
    
    sequences_ = []
    for start in range(0, len(data_subset) - decoder_seq_len - encoder_seq_len):
        enc_seq = data_subset[start : start + encoder_seq_len]
        dec_seq = data_subset[start + encoder_seq_len - 1 : start + encoder_seq_len + decoder_seq_len - 1]
        ground_truth = data_subset[start + encoder_seq_len : start + encoder_seq_len + decoder_seq_len]
        if torch.sum(~torch.isnan(ground_truth)).item() > 0:  # sequences with all null ground truth exist. Exclude these model ruiners
            sequences_.append((enc_seq, dec_seq, ground_truth))
            
        
    
#     sequences_ = [(data_subset[start : start + encoder_seq_len],
#                    data_subset[start + encoder_seq_len - 1 : start + encoder_seq_len + decoder_seq_len - 1],
#                    data_subset[start + encoder_seq_len : start + encoder_seq_len + decoder_seq_len])
#                  for start in range(0, len(data_subset) - decoder_seq_len - encoder_seq_len)]
    
    return sequences_


def split_dataset_into_sequences(dataset, encoder_seq_len, decoder_seq_len, start_date=None, end_date=None, 
                                 absolute_start_date='1996-03-18'):
    all_sequences = []
    for time_series in dataset:
        sequences_ = split_data_into_sequences(time_series, encoder_seq_len, decoder_seq_len, start_date=start_date, 
                                               end_date=end_date, absolute_start_date=absolute_start_date)
        all_sequences.extend(sequences_)
    return all_sequences


In [12]:
encoder_seq_len = 14
decoder_seq_len = 7

In [90]:
torch.sum(~torch.isnan(torch.Tensor([np.nan, np.nan, np.nan]))).item() > 0

False

In [27]:
testing_start_date = datetime.datetime.strptime('1998-01-01', '%Y-%m-%d') 
val_start_date = testing_start_date - datetime.timedelta(days=30*4)
training_end_date = val_start_date - datetime.timedelta(days=1)
val_end_date = testing_start_date - datetime.timedelta(days=1)

In [28]:
train_sequences = split_dataset_into_sequences(series_values, encoder_seq_len, decoder_seq_len, end_date=training_end_date)

In [None]:
len(train_sequences)

In [None]:
val_sequences = split_dataset_into_sequences(series_values, encoder_seq_len, decoder_seq_len, 
                                             start_date=val_start_date, end_date=val_end_date)

In [None]:
len(val_sequences)

In [None]:
test_sequences = split_dataset_into_sequences(series_values, encoder_seq_len, decoder_seq_len, start_date=testing_start_date)

In [None]:
len(test_sequences)

In [None]:
t = torch.Tensor([1, 2, 3])
i = torch.Tensor([0, np.nan, 1]).isnan().nonzero().squeeze()
torch.cat([t[:i], t[i+1:]])

In [None]:
torch.Tensor([1]).reshape((1))

In [13]:
def create_null_masks_general(in_seq, out_seq, is_self_attn=True, seq_len_dim=0):
    in_seq_len = in_seq.shape[seq_len_dim]
    out_seq_len = out_seq.shape[seq_len_dim]
    isnan_res = in_seq.isnan().nonzero()
    if len(isnan_res) == 0:
        return torch.zeros((out_seq_len, in_seq_len))
    else:
        null_indices = isnan_res.squeeze()
        if isnan_res.shape[0] == 1:
            null_indices = null_indices.reshape((1,))
        mask = torch.zeros((out_seq_len, in_seq_len)).index_fill_(1, null_indices, float('-inf'))
        if is_self_attn:
            for index in null_indices:
                mask[index, index] = 0  # need this to prevent bug when null is in first position (of decoder self attn)
        return mask
    

In [14]:
def create_null_masks(enc_seq, dec_seq, seq_len_dim=0):
    enc_sa_null_mask = create_null_masks_general(enc_seq, enc_seq, seq_len_dim=seq_len_dim)
    dec_sa_null_mask = create_null_masks_general(dec_seq, dec_seq, seq_len_dim=seq_len_dim)
    dec_enc_null_mask = create_null_masks_general(enc_seq, dec_seq, is_self_attn=False, seq_len_dim=seq_len_dim)
    
    return enc_sa_null_mask, dec_sa_null_mask, dec_enc_null_mask

In [25]:
def create_sa_causality_mask(n_t):
    return TransformerModel.generate_square_subsequent_mask(n_t)

In [26]:
def generate_transformer_masks(encoder_seq, decoder_seq, encoder_causality=False, seq_length_dim=0):
    encoder_sa_null_mask, decoder_sa_null_mask, decoder_encoder_null_mask = create_null_masks(encoder_seq, decoder_seq, 
                                                                                              seq_length_dim)
    decoder_sa_causality_mask = create_sa_causality_mask(decoder_seq.shape[seq_length_dim])
    decoder_sa_mask = decoder_sa_null_mask + decoder_sa_causality_mask
    
    encoder_sa_mask = encoder_sa_null_mask
    if encoder_causality:
        encoder_sa_causality_mask = create_sa_causality_mask(encoder_seq.shape[seq_length_dim])
        encoder_sa_mask = encoder_sa_mask + encoder_sa_causality_mask
    
    return encoder_sa_mask, decoder_sa_mask, decoder_encoder_null_mask

In [33]:
# testing mask generation
enc_seq_ex, dec_seq_ex, _ = train_sequences[5]
print(enc_seq_ex)
print(dec_seq_ex)
generate_transformer_masks(enc_seq_ex, dec_seq_ex, encoder_causality=True)

tensor([16.6100, 15.3203, 11.6071, 19.8838, 23.7670, 34.0278, 33.7868, 18.2540,
        19.3878, 17.2619, 23.8095, 36.1253, 33.6026, 32.6247])
tensor([32.6247,  7.4830,     nan, 10.1332, 17.9138, 22.1230, 33.2341])


(tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
         [0., 0.

In [34]:
class SequenceDatasetForTransformer(Dataset):
    def __init__(self, sequences, encoder_causality=False):
        
        """ sequences should be list of tuples of len 3. 
        First item of a tuple: context (encoder input)
        Second item of a tuple: decoder input
        Third item of tuple: ground truth
        """
        self.sequences = sequences
        self.encoder_causality = encoder_causality
        
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        encoder_input, decoder_input, ground_truth = self.sequences[idx]
        
        encoder_sa_mask, decoder_sa_mask, \
            decoder_encoder_null_mask = generate_transformer_masks(encoder_input, decoder_input, 
                                                                   encoder_causality=self.encoder_causality, 
                                                                   seq_length_dim=0)
        ground_truth_null_nans =  ground_truth.isnan().nonzero()
        ground_truth_null_indices = torch.ones(ground_truth.shape)
        null_indices_for_sequence = ground_truth_null_nans[:, 0]
        ground_truth_null_indices[null_indices_for_sequence] = 0
        
        encoder_input = torch.nan_to_num(encoder_input)
        decoder_input = torch.nan_to_num(decoder_input)
        ground_truth = torch.nan_to_num(ground_truth)
        
        return (encoder_input, decoder_input, ground_truth, ground_truth_null_indices,
                encoder_sa_mask, decoder_sa_mask, decoder_encoder_null_mask)
    

In [18]:
class SequenceDatasetForTransformerTesting(Dataset):
    def __init__(self, sequences):
        
        """ sequences should be list of tuples of len 3. 
        First item of a tuple: context (encoder input)
        Second item of a tuple: decoder input
        Third item of tuple: ground truth
        """
        self.sequences = sequences
        
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        encoder_input, _, ground_truth = self.sequences[idx]
        
        ground_truth_null_nans =  ground_truth.isnan().nonzero()
        ground_truth_null_indices = torch.ones(ground_truth.shape)
        null_indices_for_sequence = ground_truth_null_nans[:, 0]
        ground_truth_null_indices[null_indices_for_sequence] = 0
        
        
        
        return (encoder_input, ground_truth, ground_truth_null_indices)

In [16]:
# def process_batch_nulls(encoder_input, decoder_input, ground_truth):  
#     # need this because null indices for ground truth will be different sizes
#     encoder_input = torch.nan_to_num(encoder_input)
#     decoder_input = torch.nan_to_num(decoder_input)
#     ground_truth_null_nans =  ground_truth.isnan().nonzero().T
#     ground_truth_null_indices = torch.ones(ground_truth.shape)  # need this to remove null elements in gt from loss
#     batch_size = ground_truth.shape[0]
    
#     for b in range(batch_size):
#         null_indices_for_sequence = ground_truth_null_nans[1][ground_truth_null_nans[0] == b]
#         ground_truth_null_indices[b, null_indices_for_sequence] = 0
        
#     ground_truth = torch.nan_to_num(ground_truth)
    
#     return encoder_input, decoder_input, ground_truth, ground_truth_null_indices

In [35]:
def format_input(encoder_input, decoder_input, encoder_sa_mask, decoder_sa_mask, decoder_encoder_null_mask, n_heads):
    encoder_sa_mask = encoder_sa_mask.repeat(n_heads, 1, 1)
    decoder_sa_mask = decoder_sa_mask.repeat(n_heads, 1, 1)
    decoder_encoder_null_mask = decoder_encoder_null_mask.repeat(n_heads, 1, 1)
    
    # need e, d to be of shape Length x bs x 1
    encoder_input = encoder_input.T.unsqueeze(-1)
    decoder_input = decoder_input.T.unsqueeze(-1)
    
    return encoder_input, decoder_input, encoder_sa_mask, decoder_sa_mask, decoder_encoder_null_mask

In [20]:
def create_train_val_test_sets(data, training_end_date, val_start_date, val_end_date, testing_start_date,
                              encoder_seq_len, decoder_seq_len):
    train_sequences = split_dataset_into_sequences(data, encoder_seq_len, decoder_seq_len, 
                                                   end_date=training_end_date)
    val_sequences = split_dataset_into_sequences(data, encoder_seq_len, decoder_seq_len, 
                                                 start_date=val_start_date, end_date=val_end_date)
    test_sequences = split_dataset_into_sequences(data, encoder_seq_len, decoder_seq_len, 
                                                  start_date=testing_start_date)
    return train_sequences, val_sequences, test_sequences

In [21]:
testing_start_date = datetime.datetime.strptime('1998-01-01', '%Y-%m-%d') 
val_start_date = testing_start_date - datetime.timedelta(days=30*4)
training_end_date = val_start_date - datetime.timedelta(days=1)
val_end_date = testing_start_date - datetime.timedelta(days=1)

In [22]:
nhead = 4

In [36]:
encoder_causality = True

In [37]:
train_sequences, val_sequences, test_sequences = create_train_val_test_sets(series_values, training_end_date, val_start_date, 
                                                                            val_end_date, testing_start_date,
                                                                           14, 7)
train_dataset = SequenceDatasetForTransformer(train_sequences, encoder_causality=encoder_causality)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

In [38]:
vanilla_transformer_with_nullmask = Transformer(d_model=64, nhead=nhead, num_encoder_layers=3, num_decoder_layers=3,
                dim_feedforward=256, dropout=0.1, activation=F.gelu)

In [39]:
optimizer = torch.optim.Adam(vanilla_transformer_with_nullmask.parameters())

In [40]:
save_path = '/home/mbaroody/Programming/Time series/TimeSeriesExperiments/v4-{}.pt'

In [41]:

epoch = 0
# prev_params = None
while True:
    running_loss = 0
    print("Epoch {}".format(epoch))
    for i, data in enumerate(train_dataloader):
        if ((i + 1) % 250 == 0):
            print("Reached batch {}, running_loss was {}".format(i+1, running_loss / 250))
            running_loss = 0
            
        encoder_input, decoder_input, ground_truth, ground_truth_null_indices,\
                encoder_sa_mask, decoder_sa_mask, decoder_encoder_null_mask = data
        
        
        encoder_input, decoder_input, encoder_sa_mask,\
                decoder_sa_mask, decoder_encoder_null_mask = format_input(encoder_input, decoder_input, encoder_sa_mask, 
                                                                          decoder_sa_mask, decoder_encoder_null_mask, nhead)
        optimizer.zero_grad()
        outputs = vanilla_transformer_with_nullmask(encoder_input, decoder_input, encoder_sa_mask, decoder_sa_mask, 
                                                    decoder_encoder_null_mask)
        # outputs is shape dec_len x bs x 1
        outputs = outputs.squeeze().T

            
        # delete items that are null in target
        losses_nonull_per_sequence = torch.sum(((outputs - ground_truth) * ground_truth_null_indices) ** 2, 1)
        batch_divide = ground_truth_null_indices.shape[1] - torch.sum(1 - ground_truth_null_indices, 1)
        losses_nonull = torch.mean(losses_nonull_per_sequence / batch_divide)
        
        losses_nonull.backward()
        optimizer.step()
        
#         flag = False
#         for param in vanilla_transformer_with_nullmask.parameters():
#             if torch.sum(torch.isnan(param.data)) > 0:
#                 flag = True
#                 break
#         if flag:
#             print('outputs')
#             print(outputs)
#             print('ground truth')
#             print(ground_truth)
#             print('ground truth null indices')
#             print(ground_truth_null_indices)
#             print('loss a')
#             print(losses_nonull_per_sequence)
#             print('batch divide')
#             print(batch_divide)
#             print('loss')
#             print(losses_nonull.item())
            
#             break
                
#         prev_params = vanilla_transformer_with_nullmask.parameters()
        
        running_loss += losses_nonull.item()

        
    print('Saving')
    torch.save(vanilla_transformer_with_nullmask.state_dict(), save_path.format(epoch))
    epoch += 1
        
    

Epoch 0
Reached batch 250, running_loss was 79.06778560638428
Reached batch 500, running_loss was 51.179389556884765
Reached batch 750, running_loss was 37.37841262817383
Reached batch 1000, running_loss was 40.77492350769043
Reached batch 1250, running_loss was 40.383444709777834
Reached batch 1500, running_loss was 36.412503072738645
Reached batch 1750, running_loss was 34.781448913574216
Reached batch 2000, running_loss was 33.51221705436706
Reached batch 2250, running_loss was 34.32927252578735
Reached batch 2500, running_loss was 33.79481966018677
Reached batch 2750, running_loss was 33.58684712982178
Reached batch 3000, running_loss was 32.362316875457765
Reached batch 3250, running_loss was 36.87517970657348
Reached batch 3500, running_loss was 32.95948349761963
Reached batch 3750, running_loss was 34.21847271347046
Reached batch 4000, running_loss was 29.59201545333862
Reached batch 4250, running_loss was 33.98009145355225
Reached batch 4500, running_loss was 31.888674238204956

Reached batch 2000, running_loss was 29.806238704681398
Reached batch 2250, running_loss was 28.03031008529663
Reached batch 2500, running_loss was 30.641440671920776
Reached batch 2750, running_loss was 29.55279969215393
Reached batch 3000, running_loss was 30.86311809539795
Reached batch 3250, running_loss was 29.258385364532472
Reached batch 3500, running_loss was 29.730132108688355
Reached batch 3750, running_loss was 26.688039253234862
Reached batch 4000, running_loss was 28.602966915130615
Reached batch 4250, running_loss was 30.16948080444336
Reached batch 4500, running_loss was 31.073651790618896
Reached batch 4750, running_loss was 30.700766967773436
Reached batch 5000, running_loss was 28.802714279174804
Reached batch 5250, running_loss was 28.676991682052613
Reached batch 5500, running_loss was 27.115360122680663
Reached batch 5750, running_loss was 27.20308097076416
Reached batch 6000, running_loss was 29.347605529785156
Reached batch 6250, running_loss was 28.9004223518371

KeyboardInterrupt: 

In [50]:
val_dataset = SequenceDatasetForTransformerTesting(val_sequences)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=True)

In [43]:
def create_null_mask_for_batch(in_seq, out_seq, is_self_attn=True):
    in_seq_len = in_seq.shape[1]
    out_seq_len = out_seq.shape[1]
    batch_size = in_seq.shape[0]
    null_indices = in_seq.isnan().nonzero()
    if len(null_indices) == 0:
        return torch.zeros((batch_size, out_seq_len, in_seq_len))
    else:
        mask = torch.zeros((batch_size, out_seq_len, in_seq_len))
        for i in range(batch_size):
            sequence_nulls =  null_indices[null_indices[:, 0] == i]
            sequence_nulls = sequence_nulls[:, 1]
            mask[i].index_fill_(1, sequence_nulls, float('-inf'))
            if is_self_attn:
                for index in sequence_nulls:
                    mask[i, index, index] = 0  # need this to prevent bug when null is in first position (of decoder self attn)
        return mask

In [44]:
def create_causality_mask_batch(dec_seq_len, bs):
    return TransformerModel.generate_square_subsequent_mask(dec_seq_len).unsqueeze(0).repeat(bs, 1, 1)

In [45]:
def create_batch_masks(enc_seq_batch, dec_seq_batch):
    # should be shape bs x seq_len
    enc_mask = create_null_mask_for_batch(enc_seq_batch, enc_seq_batch, is_self_attn=True)
    dec_sa_null_mask = create_null_mask_for_batch(dec_seq_batch, dec_seq_batch, is_self_attn=True)
    dec_sa_causality_mask = create_causality_mask_batch(dec_seq_batch.shape[1], dec_seq_batch.shape[0])
    dec_sa_mask = dec_sa_null_mask + dec_sa_causality_mask
    dec_enc_mask = create_null_mask_for_batch(enc_seq_batch, dec_seq_batch, is_self_attn=False)
    
    return enc_mask, dec_sa_mask, dec_enc_mask
    
    
    

In [174]:
# e, g, gi = next(iter(val_dataloader))
# print(e)

# d = e[:, -1].unsqueeze(-1)

# create_batch_masks(e, d)

In [54]:
# vanilla_transformer_with_nullmask_ = Transformer(d_model=64, nhead=nhead, num_encoder_layers=2, num_decoder_layers=2,
#                 dim_feedforward=256, dropout=0.1, activation=F.gelu)
vanilla_transformer_with_nullmask_ = Transformer(d_model=64, nhead=nhead, num_encoder_layers=3, num_decoder_layers=3,
                dim_feedforward=256, dropout=0.1, activation=F.gelu)
vanilla_transformer_with_nullmask_.load_state_dict(torch.load('/home/mbaroody/Programming/Time series/TimeSeriesExperiments/v4-3.pt'))

<All keys matched successfully>

In [48]:
def evaluate_model(model, eval_dataloader, encoder_causality=False):
    model.eval()
    with torch.no_grad():
        total_loss = 0
        batch_count = 0
        for i, val_data in enumerate(eval_dataloader):  # todo refactor code to make val/test stepthroughs easier, remove redundant


            encoder_input, ground_truth, ground_truth_null_indices = val_data

            decoder_input = encoder_input[:, -1].unsqueeze(-1)
            enc_mask = create_null_mask_for_batch(encoder_input, encoder_input, is_self_attn=True)
            if encoder_causality:
                enc_sa_mask = create_causality_mask_batch(encoder_input.shape[1], encoder_input.shape[0])
                enc_mask = enc_mask + enc_sa_mask

            encoder_input = torch.nan_to_num(encoder_input)

            ground_truth = torch.nan_to_num(ground_truth)

            batch_errors = torch.zeros((ground_truth.shape[0], ground_truth.shape[1]))
            batch_predictions = torch.zeros((ground_truth.shape[0], ground_truth.shape[1]))
    #         print('enc input')
    #         print(encoder_input)
    #         print('ground truth')
    #         print(ground_truth)
            for h in range(ground_truth.shape[1]):
                dec_sa_null_mask = create_null_mask_for_batch(decoder_input, decoder_input, is_self_attn=True)
                dec_sa_causality_mask = create_causality_mask_batch(decoder_input.shape[1], decoder_input.shape[0])
                dec_sa_mask = dec_sa_null_mask + dec_sa_causality_mask
                dec_enc_mask = create_null_mask_for_batch(encoder_input, decoder_input, is_self_attn=False)
                decoder_input_ = torch.nan_to_num(decoder_input)

                encoder_input_, decoder_input_, enc_mask_, \
                    dec_sa_mask, dec_enc_mask = format_input(encoder_input, decoder_input_, 
                                                             enc_mask, dec_sa_mask, 
                                                             dec_enc_mask, nhead)
    #             print('dec in')
    #             print(decoder_input_)
    #             print(dec_sa_mask)
    #             print(dec_enc_mask)
                outputs = model(encoder_input_, decoder_input_, enc_mask_, dec_sa_mask, 
                                                            dec_enc_mask)
    #             print('outputs')
    #             print(outputs)
                if h > 0:
                    timestep_outputs = outputs.squeeze().T[:, h]
                else:
                    timestep_outputs = outputs.reshape(outputs.shape[0], outputs.shape[1]).T[:, h]
    #             print('latest')
    #             print(timestep_outputs)
                batch_predictions[:, h] = timestep_outputs
                timestep_batch_loss = (ground_truth[:, h] - timestep_outputs) ** 2
                batch_errors[:, h] = timestep_batch_loss

                decoder_input = torch.cat((decoder_input, timestep_outputs.unsqueeze(-1)), 1)

            losses_nonull_per_sequence = torch.sum(batch_errors, 1)
            batch_divide = ground_truth_null_indices.shape[1] - torch.sum(1 - ground_truth_null_indices, 1)
            error_nonull = torch.mean(losses_nonull_per_sequence / batch_divide)


            print('batch_predictions')
            print(batch_predictions)
            print('ground truth')
            print(ground_truth)
    #         print('batch error')
    #         print(error_nonull)
            batch_count += 1
            total_loss += error_nonull
        avg_loss = total_loss / batch_count
        return avg_loss

            
      
        

In [51]:
evaluate_model(vanilla_transformer_with_nullmask_, val_dataloader, encoder_causality=encoder_causality)

batch_predictions
tensor([[39.7324, 40.1827, 33.4770, 29.3647, 29.0814, 28.2421, 36.2364],
        [23.4058, 33.1547, 40.3623, 35.1886, 26.9763, 22.1588, 22.7864],
        [16.3108, 13.0066, 15.7317, 14.9797, 17.1447, 20.9447, 30.3867],
        [17.2823, 20.3815, 29.9368, 30.1558, 16.4632, 17.0013, 14.7735],
        [18.2353, 24.2265, 34.2950, 29.7136, 16.0194, 18.3915, 17.7236],
        [14.3601, 15.6162, 14.3060, 15.4450, 21.1630, 27.5691, 21.3094],
        [16.9266, 18.8469, 25.7142, 29.9132, 25.2900, 18.4843, 19.3121],
        [13.7636, 17.3571, 24.3512, 16.4217, 12.4746, 11.9286, 11.9821]])
ground truth
tensor([[ 2.7920,  4.5777, 33.6876, 12.1457, 22.4773, 23.1009, 32.9932],
        [21.7971, 40.3912, 56.8027, 45.4932, 37.8118, 30.4989, 26.0629],
        [31.3651, 12.4145, 17.9379, 19.3977, 14.1504, 21.8306, 35.9548],
        [20.1672, 11.7063, 37.7409, 48.9654, 22.7891, 20.2806,  9.1553],
        [14.7817, 30.5130, 44.4586,  0.0000, 15.2778, 18.5658, 20.4790],
        [ 0.0000,  

batch_predictions
tensor([[32.3337, 35.5807, 22.5141, 14.7143, 16.5846, 15.1842, 19.2186],
        [16.2578, 21.8546, 25.3921, 22.4328, 15.3888, 15.7359, 14.6118],
        [13.0849, 15.6244, 14.6949, 15.5995, 16.7833, 21.3619, 26.6184],
        [18.8399, 11.6734, 12.7592, 13.2151, 15.0331, 15.5408, 19.9649],
        [14.4486, 14.5702, 15.0183, 16.6305, 21.1939, 23.4572, 11.2312],
        [23.4987, 36.4754, 24.8249, 18.5034, 18.0603, 15.8117, 17.5707],
        [12.9021, 13.7842, 16.6863, 21.6792, 22.6988, 13.2371, 13.8243],
        [17.9878, 14.5884, 13.0957, 14.3285, 10.8224, 12.3311, 18.1670]])
ground truth
tensor([[31.1366, 32.4830, 26.9558,  9.8356, 12.1882, 13.6480, 14.1582],
        [17.1094, 18.8059, 30.0894, 22.3435, 16.9122, 15.7286, 18.1483],
        [12.6907, 17.1620, 12.3225, 14.8738, 19.3056, 24.7107, 29.0505],
        [27.7354, 16.0442, 16.2941, 11.8622, 11.0600, 12.8748, 21.5676],
        [13.0063, 11.9279, 15.2551, 14.7422, 23.1326, 25.9206, 11.9805],
        [29.3084, 3

tensor([[ 5.1623, 11.0138, 11.1697, 11.4073, 12.6343, 17.3518, 17.6539],
        [19.0050, 20.9226, 30.5079, 40.8449, 34.0592, 21.0600, 20.7165],
        [25.3500, 33.3030, 31.6987, 18.5864, 20.4046, 20.4195, 24.9651],
        [13.1794, 12.6658, 12.9785, 17.5768, 19.1716, 10.6787, 10.9536],
        [26.8457, 29.4347, 33.6902, 35.5428, 30.0075, 23.5333, 24.6579],
        [19.3830, 14.6829, 12.3386, 13.2239, 14.5474, 20.2126, 28.8895],
        [29.3740, 18.4630, 17.2197, 16.4155, 16.9308, 22.6183, 33.2261],
        [30.5825, 29.7803, 39.0148, 49.5706, 47.7942, 38.0658, 33.3081]])
ground truth
tensor([[ 5.7733, 16.5702, 12.9537, 12.0331, 18.6349, 21.7517, 19.3714],
        [15.9297, 18.0697, 34.8356, 58.2766, 45.3940, 33.0074, 30.0028],
        [ 4.1163,  3.1299, 17.9774,  7.5092, 19.6870, 20.6470, 25.3682],
        [12.9393, 11.8197, 12.1599, 17.0210, 24.4898, 21.7545, 10.8844],
        [21.3309, 27.4855, 31.6544, 39.1110, 24.6318, 17.5565, 30.8127],
        [20.6633, 32.1995, 24.5181, 3

batch_predictions
tensor([[29.3653, 31.9211, 40.3486, 49.0798, 45.4519, 37.0345, 32.8064],
        [10.8695, 11.9120, 14.2152, 11.4412,  9.2064,  9.8334,  8.5330],
        [13.2112, 16.7027, 19.4301, 17.1246, 12.9196, 13.9446, 13.2137],
        [16.1685, 15.5754, 16.2718, 18.6991, 19.1720, 11.8951, 14.9142],
        [11.5821, 15.8763, 15.4009, 15.2688, 17.5660, 25.6894, 27.1236],
        [12.9859, 12.6227, 12.7537, 13.4186, 14.5666, 18.6033, 14.2114],
        [16.5524, 16.3822, 23.4260, 31.4476, 32.3197, 17.8276, 18.2906],
        [12.5209, 12.1556, 12.5165, 15.0150, 19.9532, 19.6483, 11.5191]])
ground truth
tensor([[52.3668, 75.9495, 14.1440, 11.0402, 25.2409, 19.1185, 28.7557],
        [10.9548,  8.3114, 16.7675, 14.8080,  8.5876,  8.5481,  8.1010],
        [ 9.7317, 16.3861, 23.8164, 22.7249, 14.9527, 18.7007, 12.0068],
        [14.2162, 16.6228, 15.9653, 23.3956, 25.8548, 11.5729, 14.4003],
        [ 9.9947, 18.5692, 19.3188, 20.0026, 17.7801, 28.1694, 30.2604],
        [12.9274, 1

batch_predictions
tensor([[28.3612, 11.9317, 15.2749, 15.0651, 15.7491, 21.7815, 30.1640],
        [16.9335, 21.9890, 26.8988, 14.6581, 15.3547, 14.7495, 14.6566],
        [10.0856,  7.9993,  7.0102,  6.7958,  7.0137,  8.0387, 10.8899],
        [18.2539, 18.7539, 20.3553, 25.3431, 26.3219, 14.9034, 18.9088],
        [23.2148, 22.7592, 20.4957, 21.9716, 20.8957, 15.1053, 14.9535],
        [17.9745, 17.6300, 18.4353, 24.9910, 35.6650, 35.3902, 18.9036],
        [25.3239, 16.7362, 16.4287, 19.1841, 19.5858, 20.4090, 24.0727],
        [33.4546, 31.4880, 13.7611, 17.1232, 16.1972, 18.0510, 24.7028]])
ground truth
tensor([[41.3730, 10.4550, 17.1357, 13.6639, 11.9542, 26.4729, 38.1773],
        [17.4908, 27.5513, 28.7217, 14.9921, 15.3209, 18.2272, 16.2415],
        [ 9.9348,  5.6831,  4.4218,  3.6706,  4.4076,  6.3917,  9.6372],
        [30.2209, 30.6812,  3.9584,  5.0237, 13.2036,  6.6412, 18.0037],
        [17.3330, 25.2236, 20.0815, 26.1441, 24.1320,  8.4298, 12.1646],
        [19.0760, 1

batch_predictions
tensor([[22.1609, 19.5094, 22.0340, 21.8726, 22.1465, 24.3919, 30.5809],
        [29.8190, 22.9160, 15.6101, 15.6996, 15.5572, 15.4981, 23.9923],
        [23.9957, 14.0403, 16.2473, 16.1003, 16.3835, 20.4313, 23.3403],
        [15.4615, 16.3998, 23.5328, 27.3940, 20.4434, 17.3980, 15.2884],
        [18.8260, 23.5861,  8.2290, 14.0065, 15.3653, 15.3453, 18.1367],
        [25.4242, 31.0204, 25.4767, 16.1743, 17.8898, 17.7176, 18.6939],
        [15.5842, 19.8677, 23.8304, 21.6449, 14.6306, 14.2666, 13.8942],
        [39.3065, 25.8855, 18.4621, 17.1820, 15.3887, 16.5489, 27.2026]])
ground truth
tensor([[18.7075, 19.4728, 20.6916, 19.5862,  5.8248, 13.1378,  0.0000],
        [28.6848, 24.6882, 14.6542, 16.0006, 16.5675, 15.7313, 23.2285],
        [22.1088, 14.1723, 13.5913, 12.1173, 13.6196, 19.8129, 23.0159],
        [14.1241, 18.0168, 22.2514, 26.3414, 16.8990, 11.0337, 14.2820],
        [22.0522, 37.4717, 14.4274,  0.0000, 24.6740, 28.8974, 40.1644],
        [28.8974, 4

tensor([[12.4277, 31.2467, 38.5061, 28.3140, 18.9506, 22.6197, 24.3819],
        [12.4671, 10.4024, 13.7165, 12.9406, 29.0242,  0.0000, 14.1373],
        [ 7.2988,  7.8373, 12.9819,  7.4972, 16.1423, 16.3832, 25.6661],
        [30.4445, 14.0847, 14.7685, 11.7833, 12.5592, 16.5834, 22.7906],
        [ 7.5822,  3.8124,  5.7965,  6.4768,  6.2925,  9.7364, 11.6355],
        [11.4282, 11.5860, 17.7801, 15.1105, 21.3966, 27.7354, 16.0442],
        [ 7.3251,  6.2730,  9.3766, 14.2688, 14.0715,  6.2993, 10.0868],
        [10.8233, 12.9537, 20.1473, 18.1615,  9.9816,  8.5876, 10.3630]])
batch_predictions
tensor([[39.2989, 23.9463, 16.5078, 15.6814, 14.9813, 17.8119, 28.5187],
        [20.5542, 22.1969, 20.8527, 21.4131, 21.3359, 15.6481, 15.4321],
        [20.1549, 22.7557, 11.6212, 12.9971, 13.3474, 12.2312, 13.7501],
        [31.4624, 22.3573, 20.4395, 19.8368, 20.8648, 23.5155, 37.6328],
        [ 9.8541,  6.3302,  7.9080,  8.4708,  9.3697, 15.4043, 17.4677],
        [21.2910, 22.0861, 25.62

batch_predictions
tensor([[11.4236,  9.2684,  9.9041, 11.6624, 13.8844, 18.0065, 18.6999],
        [11.8696, 15.9439, 20.0870, 10.5184, 10.5660, 10.8215,  9.5110],
        [31.8890, 16.9213, 18.6681, 18.0544, 19.5346, 28.9673, 38.8970],
        [18.8395, 12.7125, 12.9270, 12.1055, 12.0978, 14.7910, 20.5993],
        [25.4865, 37.3103, 43.7034, 30.2334, 26.2097, 23.4074, 19.9511],
        [22.2574, 18.8824, 12.2017, 12.9177, 13.0003, 12.8358, 18.4220],
        [40.1761, 22.6281, 19.7461, 17.4337, 16.5957, 20.7675, 34.1213],
        [11.8034, 12.4972, 15.1760, 22.3338, 26.4098, 15.9640,  9.6251]])
ground truth
tensor([[ 9.9816,  8.5876, 10.3630,  9.4029, 11.8490, 22.8169, 20.8969],
        [10.5076, 14.5713, 20.6733, 20.8048, 14.9527, 13.0326, 14.6633],
        [37.0323, 17.1060, 23.9371, 20.3940, 17.8855, 32.3554, 36.8056],
        [19.2035, 11.1678, 13.7755, 11.1536, 12.5992, 20.0113, 28.6706],
        [ 5.8673, 54.6202, 54.9603, 45.8192, 10.0907, 19.3878, 22.5624],
        [25.0992, 2

batch_predictions
tensor([[12.6879, 17.2633, 22.0848, 21.5515,  7.7303, 12.2250, 11.7632],
        [21.0353, 22.8544, 24.1400, 31.2557, 38.1493, 40.9911, 30.4048],
        [22.0964, 21.0753, 21.6839, 28.4714, 28.8065, 20.3153, 22.6259],
        [19.0022, 20.1459, 10.5290,  9.7363,  9.0603, 10.0812, 12.7260],
        [23.8394, 24.3981, 32.7967, 43.7479, 38.8018, 29.1852, 25.3302],
        [30.3939, 10.6672, 15.5567, 15.1495, 16.3662, 19.7211, 27.9434],
        [30.1008, 18.7179, 18.4314, 18.3909, 20.5226, 27.6131, 38.0504],
        [10.3131, 11.3045, 11.0113, 11.2117, 12.7450, 18.0534, 19.6630]])
ground truth
tensor([[11.4808, 11.1389, 18.8585, 18.7533,  4.1426,  8.0878, 11.9542],
        [21.6070, 24.2109, 21.3309, 27.4855, 31.6544, 39.1110, 24.6318],
        [21.7649, 16.7675, 23.1589, 38.2036, 19.1478,  0.0000, 17.8985],
        [18.7138,  0.0000,  9.5739,  8.7849,  8.4035, 10.4945, 12.0857],
        [22.7466,  6.1508, 17.4887,  0.0000,  0.0000, 33.5743, 23.8946],
        [ 0.0000,  

batch_predictions
tensor([[16.7955, 14.2584,  8.8815,  9.1153,  9.5123, 11.9622, 12.6929],
        [10.7114, 10.8124, 14.9421, 15.5436,  6.7374,  8.5323,  7.4895],
        [ 7.2333,  7.0770,  7.9450,  8.6884, 10.5364, 13.6967, 14.0846],
        [35.2306, 43.8199, 38.7707, 25.4445, 26.9621, 25.4266, 26.2682],
        [38.1037, 19.8100, 20.3482, 18.2965, 18.1686, 30.0044, 43.1592],
        [18.3675, 20.0247, 27.3742, 41.1791, 38.4795, 22.7762, 18.7656],
        [14.6757, 14.7392, 14.9536, 15.4978, 18.9437, 23.2236, 10.0167],
        [15.1472, 14.0904, 16.7007, 24.6173, 30.5078, 13.2520, 16.4708]])
ground truth
tensor([[16.8332,  3.9979,  7.8906,  9.4555,  8.9295, 14.1636, 11.0994],
        [11.4019, 14.0321, 18.1220, 17.3856,  6.8122, 10.9679,  7.2856],
        [ 9.5345,  8.7454,  1.7228,  5.6812,  9.2451, 18.1089, 24.4082],
        [49.0363, 68.0272, 63.6054, 19.1185, 47.9450, 46.8963, 62.9393],
        [33.6168, 23.3702, 25.1701, 17.9138, 19.1327, 40.1077, 72.9592],
        [ 8.4751, 2

batch_predictions
tensor([[14.6322, 18.0253, 19.1931, 12.8421, 11.4821, 11.1474, 10.6002],
        [22.7154, 24.0219, 11.2870, 11.7079, 12.4850, 13.0863, 16.8279],
        [16.5722, 16.9463,  9.7843, 10.9880, 10.1045, 10.8132, 11.7472],
        [12.7585, 13.8584, 14.0403, 14.2592, 15.2855, 24.1770, 24.5449],
        [14.4107, 15.5512, 15.5167, 21.1391, 21.7279, 14.2583, 14.5435],
        [16.0306, 19.8892, 32.0075, 29.2145, 19.1580, 17.0235, 15.7735],
        [14.0528, 17.5011, 15.8665, 16.7258, 26.5871, 37.1346, 19.5369],
        [ 9.2799,  9.7101,  9.4069, 11.2474, 12.9213, 15.6535, 19.3332]])
ground truth
tensor([[16.8367, 17.8997, 27.4802, 46.6412, 32.0295, 17.2194, 27.6077],
        [25.3543, 25.9354, 19.3169, 10.3175, 12.2874, 10.6859, 16.8226],
        [14.6239,  5.7996,  0.0000,  4.5897,  7.9563, 10.9548,  8.3114],
        [ 0.3401,  8.7868, 12.7976, 11.7914, 17.8288, 21.5703, 26.4456],
        [11.7175, 13.7954, 19.1478, 30.4314, 27.6170, 11.2309, 12.9669],
        [22.1797, 2

batch_predictions
tensor([[26.7408, 33.8729, 29.1019, 23.5018, 26.3616, 26.7260, 24.9897],
        [12.5957, 13.2639, 14.4556, 16.3080, 21.8713, 19.6255, 10.3651],
        [19.4019, 22.9393, 22.3206, 14.9752, 16.3433, 15.4794, 15.3395],
        [18.7604, 23.1570, 19.5095, 15.6945, 15.7795, 14.4291, 14.3185],
        [44.7145, 31.0492, 26.2382, 25.5733, 28.0777, 39.3350, 49.7726],
        [11.7602, 14.6522, 16.1281, 11.2243,  9.1626, 11.6407, 10.8459],
        [22.9869, 21.4208, 16.4712, 16.6184, 15.7610, 16.2400, 18.9102],
        [15.0672, 16.2447, 16.3841, 18.2151, 28.8374, 28.1801,  9.3406]])
ground truth
tensor([[32.2199,  5.5366, 67.1489, 27.4987, 37.9011, 30.1157, 29.1426],
        [13.2937, 12.1173,  1.9416, 19.1043, 26.3605, 22.8600, 10.4025],
        [23.9371, 22.3214, 24.5890, 13.3645, 18.3532, 16.9785, 16.4966],
        [18.5516, 32.9082, 21.2160, 16.1848, 14.0164, 12.0748, 14.7109],
        [58.5176, 40.3912, 48.6961, 52.3668, 75.9495, 14.1440, 11.0402],
        [ 1.8280, 1

batch_predictions
tensor([[18.5645, 20.7017, 22.8088, 21.9301, 15.1972, 16.6552, 15.5886],
        [43.8710, 26.9111, 27.2661, 25.7746, 26.3661, 37.5693, 48.1209],
        [13.7670, 16.1510, 19.6239, 24.8489, 12.6683, 12.9396, 13.2550],
        [11.0690, 13.9709, 14.5354, 15.5873, 17.5856, 23.6078,  9.7417],
        [21.2104, 23.6159, 14.9449, 14.4929, 13.7001, 13.8312, 15.0217],
        [23.4632, 21.7540, 28.7775, 39.0457, 47.4944, 39.9183, 27.3477],
        [11.7230, 11.8581, 15.9934, 18.6397, 23.0601, 11.7445, 13.3915],
        [24.6070, 25.9873, 13.6510, 15.9642, 15.7119, 15.4495, 18.1614]])
ground truth
tensor([[17.1060, 21.2585, 30.7540, 28.2029, 26.9274, 28.3588, 20.7341],
        [63.6054, 19.1185, 47.9450, 46.8963, 62.9393,  2.6786,  3.9683],
        [14.8343, 14.0978, 21.3572, 34.4556, 29.1557, 12.7038, 13.6639],
        [13.6480, 14.2857, 16.0856, 20.9751, 15.3345, 23.7387,  2.2251],
        [22.2383, 19.5686, 10.7706, 11.4282, 11.5860, 17.7801, 15.1105],
        [26.7857, 2

batch_predictions
tensor([[14.7714, 14.1667, 15.7822, 20.2529, 29.2354, 22.4329, 13.0310],
        [13.3221, 14.4049, 14.5482, 16.9437, 20.6565, 22.6245, 11.8747],
        [22.6698, 21.4696, 23.8444, 31.2463, 39.6559, 34.3384, 20.5771],
        [19.0707, 15.9049, 16.2310, 16.3860, 20.1164, 29.7374, 27.2965],
        [21.5757, 15.4922, 14.0770, 13.8469, 14.3929, 16.6834, 23.0077],
        [15.7478, 19.9543, 28.1508, 21.6249, 14.5776, 14.5832, 13.1213],
        [12.2594, 12.8059, 15.8650,  8.7898,  8.4495,  9.0846,  8.3527],
        [24.2183, 15.7584, 14.4542, 13.6807, 13.0126, 13.4394, 17.8176]])
ground truth
tensor([[12.4858,  7.1995, 14.3849, 20.8050, 30.8107,  5.5414, 13.6480],
        [15.7171, 16.0006, 16.6525, 18.3957, 26.1621, 25.3118, 12.1599],
        [21.0034, 18.4099, 25.0992, 35.1474, 35.7001,  3.0329, 24.3622],
        [24.6599, 24.0788, 16.5249, 19.6429, 23.9229, 32.4405, 25.7937],
        [ 0.0000, 11.9473, 14.3424, 10.3175, 15.4478, 14.5408, 22.3923],
        [14.1440, 1

batch_predictions
tensor([[ 8.1858, 10.0464, 11.2116, 12.1137, 15.1512, 17.1112,  5.4782],
        [19.0425, 11.0459, 14.1980, 15.4629, 15.7993, 24.0505, 24.3533],
        [17.6002, 17.0940, 17.2929, 18.4405, 22.6363, 23.6128, 11.0507],
        [14.7663, 14.9467, 16.0288, 19.9047, 21.9581, 11.7341, 13.7088],
        [20.1080, 15.8920, 15.2899, 15.2411, 15.4223, 21.0813, 25.8213],
        [31.6104, 14.0616, 17.6529, 16.8306, 17.9601, 29.3398, 44.0439],
        [21.5620, 25.1925, 26.8474, 14.2457, 16.7714, 15.6760, 16.5890],
        [26.1306, 33.5831, 39.4290, 35.1589, 28.8671, 23.9789, 25.5242]])
ground truth
tensor([[ 9.8369, 11.4150,  7.4829, 13.3351,  7.9169, 22.0805,  6.1152],
        [70.8758, 13.2795, 21.0034, 25.9212, 18.5091, 34.6088, 42.7721],
        [18.8716, 16.0968, 21.0021, 16.8201, 23.5797, 26.1836, 12.0200],
        [14.2031,  6.9174, 15.1631, 23.5797, 23.3824,  9.6660, 15.7549],
        [22.2506, 11.8056, 13.8322, 15.8730, 13.8464, 21.1026, 27.5510],
        [43.1122,  

batch_predictions
tensor([[14.7900, 16.0278, 18.5902, 22.1061, 24.1048, 13.6551, 17.0860],
        [13.2408, 15.2935, 13.9020, 14.3648, 17.4572, 21.8937, 19.8109],
        [11.8118, 14.0770, 16.5433, 16.7101,  9.1700,  9.8402, 11.7562],
        [18.5047, 15.0441, 15.0256, 14.8039, 15.7890, 22.7568, 35.5508],
        [25.1259, 25.7854, 34.6128, 44.9276, 41.7363, 31.3418, 27.2581],
        [36.5123, 33.0872, 20.8500, 18.9923, 18.9929, 22.7636, 30.8241],
        [16.4706, 17.3632, 15.5939, 15.3406, 17.5094, 23.1289, 27.8200],
        [20.1763, 18.8125, 16.3817, 16.2976, 18.6692, 23.3538, 29.4272]])
ground truth
tensor([[13.7559, 14.2294, 18.6481, 28.0510,  0.0000, 12.2962, 17.3461],
        [10.8418, 13.8747, 11.2812,  9.3679, 16.7800, 22.0238,  0.0000],
        [21.5281,  2.9721,  4.4582,  7.6670,  4.8922,  8.7322,  8.9690],
        [33.4042, 14.6967, 13.1378, 13.9739, 15.6179, 24.2630, 44.6003],
        [23.8520, 27.3526, 38.0102, 56.3350, 52.3526, 33.9711, 31.3917],
        [58.7849, 6

batch_predictions
tensor([[24.4291, 19.5245, 15.6527, 14.5019, 15.8540, 18.7750, 26.0451],
        [17.4072, 20.6231, 27.7833, 34.1751, 17.6273, 18.5250, 18.7597],
        [19.1919, 18.7225, 24.3093, 21.6853, 15.6888, 15.8661, 16.5791],
        [25.5884, 31.6534, 20.1265, 17.7813, 17.7896, 17.2091, 21.7533],
        [16.5975, 13.0043, 14.4897, 14.3991, 16.2260, 19.5680, 26.5747],
        [16.3423, 17.0881, 18.8257, 22.5701, 19.4924, 14.7240, 15.7439],
        [19.3921, 19.3451, 12.8156, 11.0482, 11.7530, 13.7946, 16.4935],
        [23.4462, 29.9123, 37.3314, 30.0271, 22.5551, 21.0519, 19.3069]])
ground truth
tensor([[29.5635, 19.7279, 15.9155, 13.5204, 14.5408, 19.8129, 23.2143],
        [ 8.5459, 20.1105, 32.0720, 45.4365, 40.3628,  0.0000, 19.3878],
        [17.9248, 16.8201, 23.4745, 29.3661, 21.1336, 25.9074, 15.4261],
        [37.7834, 43.4666, 16.0856, 36.9331, 38.2937, 42.8005,  5.6406],
        [23.7812, 14.8101, 15.8447, 15.1219, 14.7534, 19.8554,  3.2880],
        [18.0981, 1

batch_predictions
tensor([[17.1015, 17.7401, 10.9477, 12.2859, 11.9722, 13.1187, 15.1716],
        [ 8.5700, 10.8107, 13.9115, 17.2207, 17.5883, 10.3418,  9.1740],
        [16.9503, 19.1637, 28.0987, 30.2601, 24.7131, 16.4030, 16.5163],
        [21.6012, 23.2315,  9.7793, 14.1494, 13.3011, 13.3246, 15.3046],
        [15.8871, 14.6877, 14.3829, 18.4442, 25.3280, 24.1790, 12.6322],
        [24.0740, 25.0012, 15.0069, 16.6156, 15.7295, 16.6731, 20.7401],
        [18.8968, 19.3226, 22.0953, 25.0406, 23.8490, 12.2517, 17.3605],
        [32.2958, 23.5524,  9.1839, 14.5793, 14.2258, 15.3228, 24.3188]])
ground truth
tensor([[19.0558, 28.3140, 11.6123, 13.1641,  1.1704, 28.4850, 18.8453],
        [10.3630,  9.4029, 11.8490, 22.8169, 20.8969,  9.9027,  9.6660],
        [13.6054, 15.0510, 22.2080, 33.0074,  9.8073, 15.3061, 11.9189],
        [22.1372,  0.0000,  8.4184, 12.4433, 10.2041, 13.9739, 14.9943],
        [ 7.1570, 12.4433, 11.7630, 25.6803, 36.7914, 35.6434, 21.9246],
        [30.7540, 2

batch_predictions
tensor([[12.6176, 12.6709, 12.5078, 13.7290, 22.1354, 19.4184,  7.8466],
        [25.2134, 20.4458, 14.9202, 14.5531, 14.6500, 18.4668, 23.1537],
        [24.9581, 23.0466, 13.1140, 14.5345, 14.3413, 15.6757, 20.5770],
        [ 8.7153, 10.8291, 11.2793, 12.9275, 14.5521, 10.5657,  6.4811],
        [13.2758, 12.8039, 13.6458, 20.0051, 26.4686, 22.2567, 11.6232],
        [20.2711, 25.9244, 12.9897, 15.0219, 14.4298, 14.9654, 21.1407],
        [13.1737, 14.7496, 13.7763, 10.1517,  8.7712, 12.7568, 12.8559],
        [ 5.7001,  5.6006,  7.6524,  9.0486, 10.2098, 12.6330, 14.0146]])
ground truth
tensor([[ 9.4555, 10.3235, 11.6386, 12.4803, 26.2362,  0.0000,  9.3240],
        [37.8880, 46.7386, 33.6270, 28.3009, 26.4598, 33.1405,  5.3656],
        [27.5227, 25.1701, 12.2449, 15.6037, 19.3311, 13.9314, 19.5153],
        [15.9439, 16.5675, 12.7551, 11.4796, 21.0176, 13.9314,  0.8929],
        [13.1115,  1.8411, 15.9127, 20.3183, 32.1410,  5.1683,  0.0000],
        [25.1134, 2

batch_predictions
tensor([[15.6766, 14.7346, 15.4953, 17.0439, 24.4282, 27.7984, 14.0459],
        [12.2572, 13.7909, 14.3928, 16.9248, 17.7492, 27.0623, 17.1737],
        [12.1034, 13.7482, 13.0737, 13.4997, 18.3033, 23.1624, 22.5510],
        [17.4231, 17.9330, 19.8564, 24.3817, 27.6364, 12.5881, 17.8769],
        [17.7000, 18.4587, 10.8561, 11.3099, 12.0331, 12.9630, 13.2424],
        [17.4567, 17.2358, 14.3950, 16.3931, 21.9569, 26.0557, 24.0639],
        [11.8481, 16.8022, 15.9624, 15.7260, 19.1697, 31.2181, 30.5889],
        [29.1429, 41.3850, 30.3661, 23.6233, 23.3762, 22.0902, 23.1733]])
ground truth
tensor([[15.9390, 15.6234, 17.9774, 18.0957, 26.3019, 31.4703, 13.2956],
        [ 0.0000,  8.9002, 13.1519, 22.1514, 21.1735, 32.9932,  7.1287],
        [11.5221, 18.4807, 13.4495, 12.3158, 24.6032, 25.1559,  7.8515],
        [21.7687, 21.9955, 23.3277, 31.9161, 34.7931, 10.5584, 22.6757],
        [16.7017, 15.9784,  7.5881, 10.0999,  7.0621, 11.7570, 10.0868],
        [17.2409, 1

batch_predictions
tensor([[26.8300, 15.8943, 16.6433, 15.9276, 16.6567, 19.7187, 26.4530],
        [32.6784, 32.3641, 23.8505, 22.5908, 23.6538, 30.1082, 41.9672],
        [23.4969, 25.5080, 33.0296, 38.7247, 30.8706, 23.9095, 25.7539],
        [22.3603, 31.6412, 42.1002, 35.5816, 26.3515, 22.1329, 20.0196],
        [13.0867, 15.0115, 14.8365, 16.0862, 21.1573, 24.8259, 11.8762],
        [13.2460, 18.2292, 20.9539, 22.4990, 16.7047, 13.6504, 12.9489],
        [10.6308, 13.8554, 14.1266, 17.3351, 21.7056, 22.0800, 12.6854],
        [32.0848, 16.4605, 15.8136, 16.4972, 17.0484, 22.3393, 29.3605]])
ground truth
tensor([[27.0516, 13.8480, 14.9790, 15.7549, 15.2946, 14.9001, 30.8259],
        [39.6684, 24.7874, 25.0425, 20.5924, 22.9592, 36.5363, 50.4960],
        [52.7486, 52.0516, 34.5213, 41.8201, 29.0768, 21.6070, 24.2109],
        [26.3747, 37.8118, 49.3764, 35.9269, 27.8203, 23.4269, 19.1327],
        [ 0.0000,  9.2846, 13.9663, 14.5450, 28.1036, 25.1973, 10.1789],
        [10.0473, 2

batch_predictions
tensor([[18.9632,  6.8660, 13.1148, 12.1500, 12.5738, 15.7420, 20.0671],
        [17.7090, 12.4279, 13.0957, 13.9059, 14.8371, 20.4062, 26.7120],
        [29.7902, 22.7428, 13.0699, 13.6694, 12.9961, 14.8065, 23.4834],
        [28.0120, 38.9981, 25.9588, 15.9858, 14.9226, 14.3871, 17.8597],
        [35.6190, 17.1907, 18.6716, 17.7145, 18.1813, 24.5537, 39.8045],
        [27.0655, 30.6357, 29.0097, 18.6232, 20.6951, 19.4363, 22.0123],
        [19.0816, 24.6057, 25.4437, 15.7799, 15.3933, 13.8389, 14.0948],
        [26.3424, 21.1871, 13.8678, 13.5303, 13.9073, 14.4993, 21.9026]])
ground truth
tensor([[19.3714,  6.9043, 23.4219, 24.4477, 23.8427,  1.8148,  2.8932],
        [24.7633, 15.6365, 17.3987, 16.2809, 16.3072, 22.4750, 26.0521],
        [32.2421,  0.0000, 12.9110, 11.8622, 11.3095, 13.0102, 24.3622],
        [19.9688,  0.0000, 15.2353, 11.9756, 12.0465, 14.3849, 14.6117],
        [ 5.1587, 11.3237, 18.7642, 19.3594, 17.8713, 30.5697, 47.4490],
        [28.2880, 3

batch_predictions
tensor([[11.0570, 13.9008, 19.1144, 27.2217, 23.3496, 11.7967, 12.9764],
        [20.5120, 31.6806, 42.6612, 32.9846, 23.0358, 21.1773, 18.8880],
        [12.8323, 12.5613, 12.2470, 15.2542, 19.1822, 21.4198, 11.5082],
        [33.4199, 36.3656, 25.5041, 21.8598, 21.4288, 23.2516, 25.7374],
        [40.2809, 28.7837, 17.5390, 17.2525, 16.5595, 20.8206, 34.0288],
        [21.0717, 26.2451, 19.2125, 16.3057, 16.4858, 14.2911, 17.6938],
        [21.1348, 21.8428, 16.5844, 15.4495, 14.0628, 14.1517, 16.1789],
        [13.6647, 15.7994, 20.8147, 28.3134, 22.1121, 13.6940, 14.8096]])
ground truth
tensor([[ 5.5556,  5.5981, 20.8333, 33.9002,  0.0000,  8.7868,  8.4609],
        [19.4371, 33.1405, 47.6986, 39.0189, 25.3551, 18.7533, 15.1236],
        [12.3619, 10.7575, 12.9800, 15.7023, 28.0510, 25.4734, 14.5450],
        [22.6723, 25.6707, 21.9358, 18.8059, 13.9926, 14.6896, 14.9527],
        [48.0584,  0.0000, 18.5799, 17.4461, 13.1378, 22.2222, 33.9569],
        [16.0179, 2

batch_predictions
tensor([[14.8278, 13.3500, 15.1359, 23.3431, 28.7106, 20.0557, 14.7748],
        [44.9982, 38.3804, 24.2491, 24.7354, 24.1885, 24.5254, 35.2469],
        [ 9.7723, 14.4595, 14.5792, 14.9215, 16.8075, 23.9165, 26.2449],
        [23.0046, 28.7660, 32.2544, 17.0877, 17.1991, 15.6825, 16.7889],
        [28.0879, 23.8783, 12.6896, 14.8649, 15.2117, 17.1101, 21.0208],
        [10.8543, 13.4750, 16.2007, 11.4368,  8.8538,  8.7392,  8.0177],
        [23.9275, 21.7753, 24.4702, 36.0390, 38.5184, 30.4023, 24.1438],
        [22.2530, 26.2915, 21.3133, 14.1665, 14.6809, 14.1642, 16.6680]])
ground truth
tensor([[ 0.3685,  3.0045, 13.5629, 25.1134, 33.1066, 22.5198, 12.8260],
        [42.7438,  0.0000, 30.6122,  2.5510, 45.7625, 28.4155, 36.2670],
        [ 7.5255, 18.9342, 17.0210, 20.3090, 28.2738, 38.8039, 41.9218],
        [ 2.3803, 45.6865, 31.1678, 20.6733,  8.7717, 15.1894, 14.7422],
        [28.6706,  5.0454, 13.6480, 14.2857, 16.0856, 20.9751, 15.3345],
        [10.9836, 1

batch_predictions
tensor([[11.0812, 13.2800, 19.7876, 18.6722,  7.4892, 10.1012,  9.8138],
        [16.8779, 17.3648, 19.5053, 24.0075, 22.6267, 16.2999, 16.3695],
        [14.2926, 16.7392, 19.8532, 24.7885, 29.9830, 13.6304, 16.3657],
        [46.8630, 36.4356, 14.5961, 17.6210, 17.2362, 17.3794, 32.6657],
        [17.1098, 16.0632, 13.9563, 14.4387, 18.2920, 27.9172, 26.5127],
        [23.6128, 18.9913, 11.8553, 13.5707, 13.5822, 13.7657, 19.7608],
        [ 9.3526,  8.5984,  8.3768,  8.2582,  7.5634,  7.5126,  7.9856],
        [22.5093, 25.5151, 24.3739, 15.3810, 15.1143, 15.8836, 17.2965]])
ground truth
tensor([[10.8233, 14.4003, 31.9306, 31.9963,  0.4077, 13.4534, 29.0505],
        [11.6649, 16.8464, 17.3330, 24.6975, 28.2614, 18.2930, 18.4903],
        [10.0198, 16.4116, 21.1593, 37.1032,  4.1525, 12.5567, 14.4274],
        [42.2672, 42.6223, 14.6633, 23.7375, 24.0137, 14.1504, 35.9548],
        [19.3169, 26.7432, 26.6440, 30.1587,  5.4280,  3.8974, 12.6559],
        [28.3305, 2

batch_predictions
tensor([[20.5025, 21.0873, 26.3172, 32.1616, 35.7752, 26.6289, 18.7805],
        [36.9989, 27.1036, 17.7654, 16.0747, 15.2330, 15.2927, 23.4176],
        [10.1230,  9.9472, 10.6521, 13.6152, 16.5068, 11.6633,  9.8945],
        [10.4303, 11.2502, 14.9406, 18.2229, 18.0371,  8.7997, 11.0770],
        [20.5278, 12.2543, 13.4645, 13.5577, 13.7339, 18.1791, 24.3321],
        [21.1872, 22.9392, 27.4657, 12.8578, 13.2634, 14.0712, 15.6824],
        [16.0756, 16.7811, 18.2205, 23.0653, 27.1995, 30.0444, 12.3175],
        [37.2255, 43.7912, 30.7794, 23.1657, 23.0648, 21.6371, 27.5169]])
ground truth
tensor([[28.6428, 21.2783, 24.8685, 34.6660, 41.2941, 29.0637, 21.1073],
        [38.7330, 38.1519, 22.1230, 15.2636, 14.8668, 22.1797, 28.5998],
        [10.6260, 10.3498, 11.5992, 11.8490, 18.3456, 13.8217,  9.9947],
        [ 9.4950, 10.6391, 14.1767, 28.0905, 24.8159, 10.5602, 12.9274],
        [17.4776, 13.5850, 14.7817, 13.2036, 11.7964, 19.5555, 33.7322],
        [25.3682, 3

batch_predictions
tensor([[35.2794, 26.4047, 16.0256, 16.5894, 15.0339, 16.9393, 27.3946],
        [39.3366, 20.1556, 19.8011, 20.4758, 20.1572, 24.9976, 38.4803],
        [17.2027, 15.5944, 15.9895, 18.3437, 24.6483, 24.0379, 14.7509],
        [13.8544, 14.9590, 14.4507, 16.4618, 20.8040, 21.3693, 13.0043],
        [46.0873, 36.4967, 22.1026, 20.6364, 18.4953, 19.7033, 32.7863],
        [15.3133, 17.3670, 17.6127, 20.0392, 24.0896, 22.7930, 16.5280],
        [24.0194, 28.0914, 39.5869, 39.5298, 27.5358, 24.1866, 24.3443],
        [19.0569, 12.9717, 11.1538, 10.7813, 10.7903, 11.5801, 19.8652]])
ground truth
tensor([[38.7428, 33.0221,  0.0000, 14.9658, 18.3719, 16.6491, 28.8927],
        [32.0011,  0.0000,  8.8152, 21.0884,  8.0924, 26.4598, 37.5850],
        [14.7159,  8.0089, 15.8995,  5.2735, 29.2609, 25.7496, 14.6633],
        [13.6338, 13.9739, 13.4495, 17.0493, 28.0612, 23.8804, 12.6701],
        [47.6986, 39.0189, 25.3551, 18.7533, 15.1236,  5.2604, 14.0058],
        [15.7286, 1

batch_predictions
tensor([[21.2202, 19.8723,  8.1965, 10.4022, 10.1744, 10.4163, 14.0889],
        [13.6602, 14.6744, 14.5082, 14.4274, 16.9053, 27.0030, 23.5467],
        [22.9143, 30.1483, 21.6182, 13.6036, 14.0151, 12.9993, 15.9277],
        [14.6444, 14.9112, 20.1782, 25.3551, 23.1959, 13.7668, 16.7681],
        [15.5281, 19.6507, 23.3282, 22.7122, 13.9959, 14.0730, 13.4279],
        [19.6766, 21.8404, 29.6745, 25.5494, 15.8867, 14.8136, 15.9139],
        [ 9.4949,  9.5509, 15.6216, 16.4028, 17.1803, 13.4793, 21.6368],
        [20.2171, 16.2057, 16.5752, 16.3265, 17.5265, 20.2641, 25.5687]])
ground truth
tensor([[17.6223,  0.0000,  5.3787,  6.9174,  7.2725,  9.6397, 14.4661],
        [13.0102, 15.7880, 15.2494, 15.8163, 17.1202, 33.8861, 30.5414],
        [23.3985, 32.2421,  0.0000, 12.9110, 11.8622, 11.3095, 13.0102],
        [14.7251, 16.0573, 22.5340, 26.2472, 22.3781, 13.4354, 14.3141],
        [16.8990, 19.8580, 28.0379,  0.0000, 15.7154, 14.2162, 13.0458],
        [17.3895, 2

batch_predictions
tensor([[22.0234, 23.6730, 14.7641, 14.8581, 13.9885, 14.6794, 14.9605],
        [15.8642, 21.1652, 29.6694, 26.9678, 19.8882, 20.2215, 19.6982],
        [19.1443, 21.5401, 19.2058, 18.0431, 25.0970, 37.3788, 33.1340],
        [15.1312, 17.7681, 22.2596, 22.5019, 12.7910, 15.5747, 13.7672],
        [18.8190, 24.9525, 14.9438, 12.9303, 12.9666, 12.2942, 12.9746],
        [20.8525, 20.1736,  7.7712,  9.5615,  9.7613, 10.4633, 15.5964],
        [10.5021, 12.3505, 17.8719, 17.8063, 10.4787,  9.5870,  8.5729],
        [15.9568, 15.7322, 18.2018, 23.2393, 32.9130, 28.1104, 18.2522]])
ground truth
tensor([[22.6328, 22.8169, 14.6239, 16.0179, 11.7175, 13.7954, 19.1478],
        [30.1587,  5.4280,  3.8974, 12.6559,  7.8798, 14.6542, 12.4433],
        [14.5692, 17.0493, 16.8084, 18.9059, 25.0709, 43.5941,  5.1587],
        [15.5050, 17.7801, 23.8427, 26.2625, 11.9279, 19.5818, 15.3209],
        [17.8146, 26.4314,  0.0000, 12.9393, 13.5913, 12.7126, 11.1536],
        [21.9753, 1

batch_predictions
tensor([[13.7961, 14.4965, 19.6350, 18.8974, 13.1560, 10.5621, 11.0751],
        [16.4385, 18.3672, 25.1257, 25.8239, 14.5099, 15.9334, 14.6036],
        [13.0924, 14.9263, 15.4215, 14.8068, 18.4740, 23.7688, 32.5870],
        [27.3246, 10.5998, 12.2950, 12.4752, 11.2124, 13.4759, 21.1229],
        [13.7156, 21.0688, 23.6823, 18.2526, 11.3275, 13.4954, 12.0657],
        [12.8841, 12.9084, 15.7740, 13.3024,  6.8774, 10.0275,  9.2666],
        [20.3087, 20.2607, 17.3553, 16.1106, 19.4314, 23.8053, 29.5568],
        [22.5124, 11.0924, 11.9548, 12.4168, 13.4905, 16.4012, 24.8709]])
ground truth
tensor([[12.1910, 13.0589, 15.3735, 34.6397, 17.2672, 12.7038,  8.6139],
        [17.9774, 18.0957, 26.3019, 31.4703, 13.2956, 15.1894, 17.9116],
        [17.4745, 14.0023,  8.3192, 10.9552, 11.9615, 22.0805, 28.0896],
        [33.3759, 23.2001, 11.7347, 14.2290, 15.9014, 13.8322, 19.7562],
        [ 9.6655, 20.7483, 21.1876,  0.0000, 11.2387, 10.3741, 10.5867],
        [ 9.7843, 1

batch_predictions
tensor([[13.3348, 12.2950, 13.8494, 13.1558,  9.1212, 11.9912,  9.6305],
        [13.8037, 13.4085, 14.8234, 15.9790, 20.8012, 20.9918, 12.9563],
        [20.8843, 20.8026, 28.6750, 40.5977, 35.4615, 19.8753, 20.7347],
        [27.4517, 15.5301, 16.5467, 16.1050, 17.5589, 27.3791, 34.1371],
        [ 7.7931, 12.6885, 13.5696, 12.7145, 14.1480, 14.3679, 10.9947],
        [13.3147, 13.0122, 12.1544, 12.3562, 13.0489, 19.0722, 19.9085],
        [25.6277, 12.6042, 14.2574, 14.6421, 14.4438, 17.1028, 24.1917],
        [10.4766, 12.4589, 17.2661, 16.1617, 11.2049,  9.2745,  8.6408]])
ground truth
tensor([[13.6902, 14.5055, 20.3577, 14.2557,  4.6423, 13.0326,  7.8643],
        [13.8605, 10.4592, 12.0181, 19.0901, 25.8078, 24.8158, 17.0635],
        [14.2574, 16.5391, 32.5113, 50.6944, 31.3209, 14.6825, 18.3390],
        [28.1562, 17.0042, 17.8722,  8.9032, 16.7412, 27.8406, 29.7212],
        [ 9.9027, 10.7706,  8.7717, 10.7969, 12.7170, 16.5834,  3.2483],
        [11.4282, 1

batch_predictions
tensor([[17.6675, 18.5085, 20.7017, 27.1614, 28.2927, 13.0520, 17.0997],
        [24.5825, 34.4494, 19.0268, 16.6783, 16.3274, 16.4945, 17.9126],
        [21.9528, 20.8329, 16.8949, 17.4415, 18.2924, 22.8537, 31.3115],
        [16.8459, 16.4418, 16.9437, 19.4005, 25.1874, 25.5177, 13.7097],
        [ 8.0621, 12.8863, 14.2022, 13.8862, 15.9025, 25.1056, 26.6705],
        [21.5335, 20.2433, 12.7026, 14.1638, 14.1118, 16.4944, 18.6032],
        [15.7803, 14.8871, 16.6362, 19.1024, 22.1440, 23.2761, 13.9594],
        [15.8823, 23.4053, 21.7014, 27.4759, 38.0790, 47.0412, 40.3216]])
ground truth
tensor([[12.5000, 17.1202, 16.4541, 30.1587, 31.8736,  5.5981, 11.9331],
        [25.1973, 39.1110, 21.2388, 13.7691, 14.4003, 10.4682, 12.7433],
        [16.0705, 17.9116, 16.0442, 19.6870, 19.9369, 26.7491, 31.4966],
        [14.9790, 15.7549, 15.2946, 14.9001, 30.8259,  0.0000, 17.6881],
        [ 8.0484, 13.9532, 11.8885, 12.4803, 16.8201, 34.1662, 35.4682],
        [21.8043, 2

batch_predictions
tensor([[14.2931, 16.0551, 18.3835, 14.6627,  8.2952, 11.8480, 12.2031],
        [18.3564, 21.0900, 26.4507, 27.5558, 15.1920, 18.3293, 16.8420],
        [21.2179, 26.0393, 24.3913, 13.1495, 16.3696, 16.8224, 17.7389],
        [14.4064, 14.7631, 15.8893, 20.2233, 24.9679, 20.8439, 12.1665],
        [19.4834, 17.6092, 19.7279, 10.3849, 11.5926, 18.2925, 18.9761],
        [23.2483, 20.1036, 16.1842, 16.7084, 23.8856, 38.8997, 32.8803],
        [19.8055, 27.8669, 25.6527, 14.1171, 15.3327, 14.2458, 17.0266],
        [24.4894, 29.1358, 23.0767, 18.4898, 19.3971, 19.5134, 20.5320]])
ground truth
tensor([[15.5576, 22.4224, 30.8916, 30.1552,  6.3519, 10.2841,  9.1136],
        [18.2141,  2.2225, 41.3072, 36.4676, 15.3603, 17.6618, 16.3467],
        [ 3.4193,  3.2746, 12.4408,  5.6681, 12.5460, 13.8085, 20.4366],
        [14.2149, 13.1094, 15.4053, 24.6599, 24.7024, 25.0850, 14.2999],
        [19.9830, 28.8549, 25.9212,  8.0924,  4.4218, 24.6740, 22.5482],
        [23.8662, 1

batch_predictions
tensor([[32.8410, 21.9742, 20.8983, 19.4806, 22.1142, 30.8084, 37.3092],
        [21.7163, 26.6381, 25.2371, 13.9228, 14.7445, 15.4631, 16.4055],
        [10.7958, 12.1008, 13.0827, 15.9196, 20.1860, 17.2265,  8.6384],
        [15.7992, 20.2546, 24.4131, 20.6262, 15.0901, 14.7892, 15.1637],
        [14.6280, 19.3052, 17.3262, 24.2436, 32.4591, 27.1244, 17.5819],
        [26.4143, 13.2458, 15.4019, 14.9746, 15.7467, 19.7795, 24.7066],
        [16.6528, 19.4949, 18.9065, 17.7452, 22.0045, 30.9772, 30.9910],
        [23.3069, 22.3749, 35.3057, 47.6613, 42.6002, 26.7779, 25.3797]])
ground truth
tensor([[ 3.0329, 24.3622, 15.0794, 21.7971, 25.0425, 31.0374, 49.5607],
        [18.4949, 28.7698,  0.0000, 14.2715, 15.6463, 12.3299, 14.7676],
        [11.0337,  8.3509, 10.6654, 14.4003, 22.2120, 20.3051,  6.3519],
        [14.5318, 19.3319, 24.5003, 25.3025, 13.2825, 15.7154, 12.4145],
        [ 9.1695,  7.6814, 20.7200, 20.6066, 64.6825, 50.6944, 28.5998],
        [30.8522, 1

batch_predictions
tensor([[15.9420, 18.2944, 15.4107, 10.4393, 11.6015, 11.7365, 11.7776],
        [18.3961, 16.7391, 17.9676, 18.5508, 28.2953, 34.0191, 17.9219],
        [20.3331, 19.0213, 18.5921, 22.9372, 26.5886, 32.2819, 22.9039],
        [12.8673, 13.1954, 18.0781,  7.0850,  9.1525,  9.7338,  8.2883],
        [16.4809, 19.2187, 26.5067, 34.4471, 22.5218, 19.3330, 17.9302],
        [26.2457, 13.9850, 16.0082, 16.5357, 16.9695, 21.9709, 30.6339],
        [42.3782, 49.8666, 34.5657, 26.8726, 27.1002, 26.0533, 31.6133],
        [23.0258, 14.3543, 15.3654, 14.9540, 15.7810, 18.9392, 26.0898]])
ground truth
tensor([[17.0568, 18.1483, 21.3177,  9.2583, 12.4408, 13.1247, 11.9411],
        [19.2460, 13.4212, 18.9484, 28.9966, 50.9495,  0.0000, 19.9830],
        [29.2346, 19.4897, 19.8974, 28.8532, 46.4492, 44.3582, 24.2898],
        [20.9100, 16.9516, 21.2651,  3.4850, 12.0200,  7.2462, 10.2052],
        [10.0057, 18.3248, 31.3067, 48.4410, 30.7681, 22.1514, 17.7012],
        [35.7575, 1

batch_predictions
tensor([[29.8097, 23.0459, 17.3301, 16.1533, 15.1434, 16.3821, 21.8851],
        [13.5620, 14.1686, 15.4462, 18.9579, 25.7662, 25.7710,  8.9552],
        [20.0727, 18.7268,  7.5280, 13.1376, 12.0540, 12.0159, 15.8952],
        [29.4641, 14.1507, 15.6111, 15.0830, 15.1751, 21.0967, 33.3012],
        [15.6868, 14.4343, 15.0411, 17.4118, 24.1030, 20.7588, 14.6242],
        [10.4321, 10.6274, 10.5512, 11.2652, 13.4115, 16.5399, 14.4952],
        [12.8115, 12.5165, 14.1455, 18.5769, 23.6261, 22.7570, 12.1582],
        [32.3352, 12.8105, 16.9213, 18.2929, 19.3486, 25.5068, 36.2918]])
ground truth
tensor([[33.1349, 30.9666, 22.1088, 18.5232, 14.5125, 19.8696, 20.9325],
        [14.1156, 14.6967, 14.1156, 19.4161, 41.2982, 32.9649,  8.9994],
        [21.7517, 19.3714,  6.9043, 23.4219, 24.4477, 23.8427,  1.8148],
        [28.1463, 14.9235, 15.0368, 10.4875, 14.7109, 23.3135, 35.4450],
        [11.9898, 13.6763, 15.0935, 15.8588, 24.2489,  0.0000, 11.9473],
        [ 7.4961,  

batch_predictions
tensor([[28.6675, 20.6037, 15.6639, 18.4649, 21.0896, 26.9968, 34.6606],
        [12.2222, 14.0358, 14.2310, 15.1940, 19.0620, 21.6291, 10.8644],
        [30.4609, 40.9409, 42.7799, 32.7964, 31.0899, 27.7838, 27.5815],
        [32.7413, 16.7036, 19.2657, 18.1613, 17.7498, 20.8863, 32.4038],
        [37.6818, 34.6115, 21.9510, 21.7396, 20.6103, 24.2825, 32.8090],
        [16.4755, 24.5414, 21.7790, 15.0083, 14.6556, 16.7758, 18.7028],
        [22.6436, 31.8037, 27.6569, 17.6767, 17.5750, 16.4957, 16.2024],
        [37.3647, 28.1993, 18.1810, 16.5135, 17.4106, 20.4303, 30.3916]])
ground truth
tensor([[24.6712, 26.0784, 28.7217, 28.3798,  3.8532,  3.6954, 18.3062],
        [15.3077, 14.2031,  6.9174, 15.1631, 23.5797, 23.3824,  9.6660],
        [24.9433, 40.1077, 18.8067, 13.0102, 18.1973, 13.7472, 17.4036],
        [37.9393, 27.3668, 34.4388, 41.2557, 47.5198,  6.3917,  5.0312],
        [42.1202,  3.1321,  0.0000, 19.4586, 22.0947, 29.3651, 36.9756],
        [19.4897, 2

batch_predictions
tensor([[25.4951, 13.5874, 15.2298, 14.3352, 15.1757, 18.8822, 26.7666],
        [16.0261, 14.6353, 17.3634, 19.0139, 18.9602, 27.3271, 13.0072],
        [24.1301, 28.8351, 29.7675, 20.3867, 18.5088, 18.3205, 25.1453],
        [13.6350, 14.5513, 17.4465, 18.0390, 27.6531, 24.0140, 12.3400],
        [13.8245, 19.2771, 26.6360, 17.0863, 11.7534, 11.9704, 11.8684],
        [21.1332, 24.5961, 34.7709, 41.9628, 35.9396, 27.1708, 22.8152],
        [16.5570,  9.3912, 11.6530, 11.2880, 12.5132, 15.5956, 22.7634],
        [19.7925, 20.9715, 19.3183, 23.7048, 26.9273, 14.0766, 17.0404]])
ground truth
tensor([[33.0215, 21.6270, 16.4966, 11.5788,  8.8435, 21.4711, 37.8685],
        [13.1944,  9.6514, 15.3061, 21.0176, 33.3192, 36.5221, 11.2528],
        [50.9863, 33.9558, 34.8369, 23.1720, 19.2793, 16.2809, 18.3062],
        [ 8.9002, 13.1519, 22.1514, 21.1735, 32.9932,  7.1287,  0.0000],
        [10.1899, 21.7687, 32.6956, 21.1310, 10.0198, 10.9552,  8.6168],
        [19.1327, 2

batch_predictions
tensor([[16.0077, 15.8327, 18.6506, 20.9360, 13.2251, 12.2740, 18.2094],
        [20.1942, 27.1278, 29.8043, 15.0761, 16.8076, 17.1033, 16.1532],
        [35.2783, 27.5702, 16.6332, 16.8792, 16.1627, 17.0169, 24.1901],
        [23.1633, 29.2414, 19.2854, 14.9521, 15.8587, 16.2227, 20.0783],
        [26.3737, 23.4566, 12.7862, 14.1866, 12.8433, 12.1354, 20.9848],
        [26.4927, 23.4229, 12.6928, 14.7349, 14.6822, 15.3811, 22.1962],
        [29.6499, 41.8439, 32.4181, 22.0519, 19.9871, 18.5991, 20.9437],
        [17.5213, 21.7813, 11.6200, 13.0056, 13.0836, 10.7287, 13.5020]])
ground truth
tensor([[27.1305, 24.5266,  0.6707,  0.7233,  4.1294,  2.9721,  7.7328],
        [23.5534, 30.6812, 31.5623, 13.5850, 18.4640, 16.2415, 16.2415],
        [36.1820, 11.1253,  0.0000,  9.6088, 17.1344, 16.3124, 20.0539],
        [21.9104, 30.7256, 24.8016, 12.1882, 12.7551, 13.0385, 21.2302],
        [25.5244, 25.2693, 10.8560, 13.1236, 13.3645, 16.3265, 21.7687],
        [32.2137, 2

tensor([[13.1291, 13.2080,  9.9384,  9.9321, 12.0993, 15.7240, 11.5400],
        [20.0776, 24.5516, 24.9991, 21.2854, 20.1624, 16.7680, 16.0823],
        [13.2257, 15.2513, 15.9053, 18.6187, 24.7622, 22.3021, 12.7858],
        [14.8916, 15.6227, 18.1416, 20.3833, 23.8014, 13.1871, 15.5566],
        [15.3325, 17.6264, 22.5392, 16.7822, 11.5350, 12.8339, 12.0238],
        [31.4279, 22.8026, 25.7727, 31.0635, 38.1512, 44.3214, 40.0478],
        [15.9425, 16.6767, 15.9254, 16.7892, 21.0421, 29.5772, 17.2815],
        [17.2514, 18.5960, 21.3022, 28.8937, 26.3729, 18.4883, 17.8471]])
ground truth
tensor([[ 1.1967,  1.4861,  0.9732,  0.3288,  2.6959,  1.7359,  0.3814],
        [13.4212, 15.8163, 27.6786, 34.9490, 28.4439, 15.3061, 15.4053],
        [20.7483, 11.9189, 12.3299, 17.9280, 20.6066, 18.9909, 10.8418],
        [13.4637, 15.2211, 20.3090, 26.7857, 29.7336, 13.6196, 18.5658],
        [ 1.9416, 19.1043, 26.3605, 22.8600, 10.4025, 17.6587, 15.5329],
        [42.1344, 37.7126, 32.9649, 3

batch_predictions
tensor([[10.0340, 13.1151, 12.4739, 12.9062, 15.7727, 22.8741, 22.0068],
        [13.6057, 13.6272, 13.9742, 14.6113, 18.2713, 24.9006, 18.6116],
        [12.7064, 18.7719, 18.9470,  9.0966, 10.8005, 10.2028,  9.8814],
        [20.1969, 23.4841, 11.2781, 11.4181, 11.0084, 11.0831, 14.6973],
        [21.1727, 20.1517, 18.3441, 14.2747, 15.4144, 22.3869, 22.1305],
        [30.4870, 27.3136, 15.2504, 15.6470, 14.8714, 15.9884, 24.7642],
        [18.5146, 10.6027, 11.2016, 11.6331, 12.7483, 16.5030, 20.3706],
        [44.4669, 26.2300, 16.9468, 17.5592, 17.0301, 19.7054, 32.2382]])
ground truth
tensor([[10.5602, 12.9274, 11.4545, 13.1904, 16.6491, 21.6465, 18.5166],
        [ 0.0000, 11.4087, 13.3078, 13.3787, 17.5028, 22.3214, 22.6474],
        [15.8864, 21.0679, 19.7133,  8.3903, 11.0074,  9.4950, 10.6391],
        [17.6749, 26.8937, 10.1920, 12.1515, 10.2972,  8.3640, 15.1499],
        [19.7791, 26.9069, 23.6191, 13.8217, 12.0463, 17.3330, 25.2236],
        [36.3379, 3

batch_predictions
tensor([[19.4186, 24.9738, 36.6787, 32.7062, 22.0507, 17.3945, 17.3159],
        [11.9827, 12.2471, 13.2585, 15.7745, 20.8655, 19.7713,  9.4727],
        [17.0686, 22.5966, 26.1120, 27.7172, 17.7784, 15.9632, 14.8441],
        [27.8784, 41.3603, 36.2924, 20.8797, 19.5851, 18.7875, 20.9821],
        [14.3874, 14.7306, 10.4720, 10.8058, 13.1933, 13.3341, 13.7710],
        [41.9748, 14.6277, 21.3320, 21.0122, 20.7029, 28.5909, 35.7813],
        [37.9370, 13.2717, 18.8952, 18.7887, 17.7518, 27.8008, 36.0283],
        [23.5100, 30.5012, 24.3090, 17.3396, 18.5987, 19.3092, 18.0088]])
ground truth
tensor([[22.6332, 34.6088, 50.6661, 46.4286, 27.3810,  2.1400, 16.1139],
        [ 9.2320, 11.1126, 11.8622, 14.9790, 23.0668,  0.0000,  7.5881],
        [13.0952, 17.8288,  3.0612,  3.5431,  8.5459,  8.0499, 13.3362],
        [32.0862, 47.5765, 45.1247, 26.2330, 20.7908, 21.6553, 21.9955],
        [16.8201, 17.3330,  9.5871,  6.7596, 14.4266, 14.8606, 11.0863],
        [36.1111, 1

batch_predictions
tensor([[16.5422, 18.4562, 17.8117, 19.9452, 21.9085,  8.4623, 12.3924],
        [14.8541, 15.6434, 20.6970, 24.2745, 21.9199, 12.6884, 15.0624],
        [11.1978,  9.8115, 10.0619, 10.4340, 12.2816, 11.1800,  8.9246],
        [15.0177, 17.9505, 21.6117, 19.7728, 13.2462, 14.9999, 12.5975],
        [ 9.5987, 10.7014, 14.9059,  8.4093,  7.7039,  8.9080,  8.8891],
        [23.4776, 24.8440, 14.1355, 14.1065, 13.5052, 13.9112, 17.6606],
        [ 6.3197,  7.3630, 10.4309, 11.4437, 12.6556, 14.3636, 15.0256],
        [12.3022, 18.5154, 23.4969, 20.3609, 10.8765, 12.6441, 11.4480]])
ground truth
tensor([[21.6128, 16.7800, 18.9626, 14.2432, 25.5952,  8.9994,  4.1667],
        [12.1032, 16.3690, 18.9768, 24.4615, 22.7891, 12.8685, 13.2370],
        [ 8.7585, 11.4512,  6.0658,  8.3192, 10.8985,  9.2829,  7.5113],
        [16.1494, 19.9369, 25.1315, 26.9858, 14.0584, 20.0815, 15.9127],
        [ 6.6610,  1.3464,  0.7653,  4.0958,  2.9478,  5.5839,  6.4201],
        [25.6312, 2

batch_predictions
tensor([[19.3377, 22.7003, 22.8820, 13.3874, 14.6308, 14.9360, 15.4314],
        [20.2799, 23.3845,  9.4518, 13.1108, 13.1307, 14.2410, 16.2147],
        [20.0899, 13.9570, 13.9264, 13.8998, 13.4234, 20.0748, 29.2026],
        [26.9439, 25.5247, 16.5140, 16.2454, 15.3902, 16.7641, 20.9290],
        [24.2845,  6.8917, 12.0306, 12.8003, 14.5524, 17.1626, 23.5373],
        [17.8872, 21.4425, 20.1494, 10.1237, 14.7584, 15.8081, 14.8899],
        [15.0626, 18.7108, 23.0201, 21.6382, 10.6078, 15.8228, 13.7087],
        [13.5009, 19.7520, 21.4275, 11.3555, 12.8737, 13.4617, 11.8509]])
ground truth
tensor([[18.3957, 26.1621, 25.3118, 12.1599, 11.8622, 10.0624, 12.0748],
        [24.2372, 25.8943,  8.7980, 17.7933, 15.4129, 15.3077, 14.9264],
        [31.4484, 14.5975, 15.2353, 18.4524, 16.9785, 20.7908, 36.1253],
        [29.4187, 30.2867, 14.7948, 19.9106, 16.0047, 16.5176, 19.7791],
        [24.3161,  9.1136, 12.0331, 12.0726,  0.9337,  7.1936,  0.0000],
        [17.9248, 1

batch_predictions
tensor([[20.7594, 22.5669, 11.3978, 13.2186, 13.5746, 13.2384, 15.6306],
        [17.5115, 27.4224, 29.9602, 20.9706, 13.3679, 14.5129, 14.7198],
        [14.6524,  6.0995,  9.9674, 10.0175, 10.1122, 11.4511, 17.2676],
        [14.7518, 14.7546, 15.8189, 20.8704, 25.1876, 13.6943, 14.2948],
        [13.5776, 19.2593, 22.2332, 25.6903, 14.2441, 15.6206, 13.0180],
        [18.4634, 23.2658, 22.9112, 12.8674, 13.9697, 13.3533, 14.0960],
        [14.8682, 17.1689, 25.1262, 19.3260, 12.2693, 12.9250, 12.5954],
        [15.7405, 16.2501, 23.2653, 31.4584, 27.7318, 17.2203, 16.6763]])
ground truth
tensor([[24.7896, 22.7380,  9.5608, 16.6360, 13.2167, 11.3098,  1.3019],
        [14.1582, 30.7115, 35.8844, 27.5652, 11.1536, 15.8022, 11.3095],
        [16.3730,  3.2877, 13.7033,  9.8895, 11.0863, 12.9143, 19.4240],
        [13.7472, 12.6559, 18.3107, 25.3968, 23.3844, 13.7330, 13.6338],
        [10.9552, 11.9615, 22.0805, 28.0896, 22.3356, 10.3175, 11.2528],
        [14.9093, 2

batch_predictions
tensor([[23.6579, 16.6547, 12.1747, 12.9536, 13.0987, 15.8367, 22.6343],
        [18.9317, 15.8360, 11.0260,  9.2740, 10.0658, 12.7433, 13.1209],
        [16.6877, 23.7435, 31.1286, 21.8256, 12.3585, 15.3989, 13.5733],
        [ 8.0214, 11.3934, 11.8686, 12.6566, 13.7928, 21.0202, 15.8687],
        [28.7451, 16.7305, 16.9522, 15.7130, 19.2731, 28.6299, 38.7378],
        [20.5062, 21.1464, 20.8371, 27.8057, 35.6072, 34.2707, 16.0605],
        [26.5932, 11.1939, 14.4297, 15.0180, 14.6925, 16.8141, 23.7300],
        [15.4158, 17.0192, 20.8536, 26.4046, 23.0537, 15.6160, 16.5808]])
ground truth
tensor([[27.5085, 15.7596, 16.2415, 16.5675, 12.0181, 16.2415, 22.9734],
        [23.4613, 19.3714, 11.0600,  8.9953, 10.1789, 16.1757, 18.9900],
        [19.0051, 27.1684, 54.7052, 36.5646, 12.8401, 22.1655, 33.5743],
        [13.4403, 20.2525, 10.9153, 10.9548, 14.7685, 25.4734, 24.0663],
        [ 7.3696, 18.9909, 17.7296, 13.0385, 20.3798, 36.2387, 42.3611],
        [20.5782, 1

batch_predictions
tensor([[16.2193, 20.1925, 26.9003, 25.1070, 19.8415, 14.5800, 14.1715],
        [18.9732, 19.8497, 31.8738, 42.2143, 33.3697, 20.2861, 19.4760],
        [17.1174, 20.1672, 26.3193, 38.2961, 36.6139, 24.7878, 13.9194],
        [ 8.7483,  9.7449, 10.9966, 13.7326, 16.5072, 10.2585,  5.3503],
        [10.4740, 11.6207, 11.1527, 12.1824, 14.2843, 18.2218, 21.9247],
        [26.8551, 37.2048, 34.5126, 17.1955, 21.1225, 21.1577, 20.0384],
        [18.9014, 17.4486, 22.1129, 34.9982, 39.4088, 24.6068, 18.7553],
        [15.0085, 18.9127, 24.3353, 21.2273, 13.7073, 14.9737, 14.2170]])
ground truth
tensor([[18.4666, 25.4252, 41.5391, 21.2160,  0.0000,  7.7523, 14.0164],
        [14.7109, 22.8741, 31.2217, 55.6406, 40.0368, 20.2948, 18.7500],
        [16.1139, 20.8333, 32.0862, 47.5765, 45.1247, 26.2330, 20.7908],
        [ 9.6791,  7.6144, 10.2709, 13.2036, 19.7791,  8.1799,  3.7743],
        [10.4550, 10.2315,  8.9427,  7.8643, 11.5597, 18.8059, 23.6981],
        [28.8690, 3

batch_predictions
tensor([[37.1109, 43.5588, 47.5170, 44.5765, 41.0045, 38.8920, 37.5147],
        [13.2907, 14.0384, 16.7535, 26.6534, 28.2164, 10.5524, 14.1814],
        [20.1850, 19.7481, 19.7685, 17.1217, 15.3923, 14.7608, 17.6400],
        [18.8765, 26.7156, 27.1105, 22.5143, 19.4497, 18.0159, 18.3272],
        [10.0773, 10.8439, 12.2751, 16.2445, 15.7648,  5.5624, 10.8285],
        [20.3574, 18.9354, 20.9230, 31.9465, 45.5868, 40.6885, 19.2524],
        [27.4427, 28.9043, 13.5335, 15.8009, 15.1947, 15.6142, 21.3646],
        [31.7154, 23.7091, 15.0807, 15.5848, 15.9790, 16.7078, 23.2995]])
ground truth
tensor([[43.1689, 58.6026, 74.4614, 57.9932, 53.9683, 70.1672, 67.6020],
        [34.8764, 32.8248,  2.2488,  1.5650, 14.4661,  3.6691, 14.2162],
        [42.6871, 54.3934, 36.3520, 31.6185, 16.4966, 17.6162, 13.4212],
        [23.7507, 33.5876, 57.9300, 45.1999, 25.3551, 26.9463, 23.5402],
        [ 9.8895, 11.0863, 12.9143, 19.4240, 16.7017,  5.7733, 16.5702],
        [46.0034, 3

batch_predictions
tensor([[12.2442, 14.7071, 22.1377, 26.6533, 19.8056, 13.7205, 12.4833],
        [24.6024, 36.4276, 27.9053, 16.9660, 16.7643, 15.0172, 17.0135],
        [16.1251, 16.1375, 20.3736, 20.7911, 13.9556, 14.1711, 12.7482],
        [23.2524, 25.0136, 13.3058, 15.6656, 14.0019, 15.0403, 18.0397],
        [30.3101, 11.8315, 15.6346, 15.5577, 16.9666, 19.1044, 28.8861],
        [ 9.8548, 12.2590, 16.0001, 19.8648, 18.2104,  7.1032, 10.0255],
        [12.4672, 12.3733, 13.9052, 18.3351,  8.8287,  6.8698, 10.9506],
        [15.9590, 14.3104, 16.7563, 19.7892, 30.2557, 32.0001, 15.0604]])
ground truth
tensor([[13.5346, 17.2619, 23.0017, 37.5000, 25.0283, 16.6525, 16.5816],
        [23.6678, 39.2574, 31.6893, 17.4887, 14.4983, 12.8685, 17.0635],
        [11.0600, 12.8748, 21.5676, 23.4745, 12.5197, 13.9006, 13.1773],
        [27.9458, 30.6418, 12.8090, 19.5160, 13.4403, 18.3982, 16.7280],
        [31.8736,  5.5981, 11.9331, 10.4875, 15.8872, 17.9563, 30.9949],
        [ 8.8769, 1

batch_predictions
tensor([[17.1572, 15.5713, 20.9663, 29.9960, 22.4330, 15.2021, 15.1866],
        [24.8964, 13.4127, 16.2477, 15.4831, 16.9702, 27.4393, 32.6697],
        [23.7033, 13.3875, 13.8443, 12.3934, 14.8312, 19.8544, 28.6365],
        [13.5664, 16.9914, 21.7164, 18.6953, 10.2950, 14.4228, 13.7302],
        [14.9137, 17.3882, 16.4220, 19.2661, 27.9723, 38.5740, 27.3655],
        [12.6460, 19.8303, 21.5504, 12.2857, 11.4855, 12.0297, 10.5627],
        [23.0037, 30.3466, 32.3087, 17.5891, 20.0806, 17.1908, 17.6360],
        [27.4321, 13.1253, 16.4554, 17.3079, 18.1637, 21.8058, 28.5737]])
ground truth
tensor([[12.7696, 16.3730, 20.6207, 32.7065, 23.7901, 14.7422, 12.7827],
        [26.9558,  9.8356, 12.1882, 13.6480, 14.1582, 30.7115, 35.8844],
        [ 5.5414, 13.6480, 14.0306, 12.1599, 15.2494, 24.5607, 31.1224],
        [16.3730, 20.2130, 31.0758, 25.6049,  8.1010, 19.7791, 12.2699],
        [16.4966, 17.4887, 13.2511, 21.1735, 33.3900,  4.9320, 63.5629],
        [ 9.9684, 2

batch_predictions
tensor([[14.5414, 17.6621, 18.7981, 11.4410,  8.8751,  8.9623, 10.0373],
        [24.1090, 22.9564, 14.5910, 14.5767, 13.5037, 14.2445, 18.6086],
        [19.5098, 22.8583, 15.7361, 14.5112, 13.3465, 13.0076, 13.3285],
        [16.5582, 17.5112, 16.4213, 15.7954, 20.2643, 34.6320, 32.4132],
        [23.9238, 31.2729, 35.4754, 24.0193, 20.9932, 17.3927, 21.1432],
        [14.5930, 19.5494, 23.3955, 20.9487, 13.8504, 14.9844, 13.4004],
        [22.7943, 20.6314, 21.4378, 28.5675, 39.8161, 39.4519, 21.2888],
        [48.0159, 40.4357, 26.9725, 24.7453, 20.8082, 22.5428, 32.8014]])
ground truth
tensor([[12.9537, 20.1473, 18.1615,  9.9816,  8.5876, 10.3630,  9.4029],
        [25.7511, 23.3560, 11.7914, 13.1944, 10.7710, 12.2166, 21.0601],
        [16.7942, 25.0992, 20.4932,  9.6655, 13.8322, 10.9269, 10.6576],
        [16.0856, 15.9439, 14.4416, 17.9422, 21.2018, 54.3651, 48.9654],
        [29.0637, 27.6959, 42.2672, 24.6712, 26.0784, 28.7217, 28.3798],
        [15.0794, 1

batch_predictions
tensor([[19.4812, 20.8422, 22.7556, 31.7438, 31.9655, 17.2302, 19.7531],
        [24.8248, 36.0564, 30.1870, 12.7866, 16.7766, 16.8229, 17.4061],
        [ 8.2471, 10.9067, 10.5455, 11.1313, 10.4995,  8.8672,  8.8349],
        [13.9464, 15.0700, 18.1855, 23.3530, 20.0228, 14.1821, 14.0461],
        [17.8886, 22.8886, 24.0565, 19.1682, 18.5355, 16.4420, 16.3915],
        [21.4939, 12.3567, 14.0898, 13.4943, 14.8364, 17.2345, 25.0958],
        [24.7697, 33.0110, 27.0082, 19.8631, 19.4925, 18.3116, 19.0472],
        [12.6943, 13.5850, 13.2446, 16.2958, 20.2066, 21.3523, 11.0020]])
ground truth
tensor([[33.9032, 45.7785,  4.1163,  3.1299, 17.9774,  7.5092, 19.6870],
        [30.0454, 59.5663, 40.4762, 17.6871, 30.7681, 38.3362, 61.9473],
        [ 8.7160, 12.1173,  9.8356,  5.6831,  6.7744,  5.1446,  5.1587],
        [13.6763, 15.0935, 15.8588, 24.2489,  0.0000, 11.9473, 14.3424],
        [21.6991, 32.6539, 26.7885, 21.5676, 27.9195, 24.6712, 33.3509],
        [32.3119, 1

batch_predictions
tensor([[11.9781, 12.1625, 18.6509, 16.9947, 12.7187, 10.6449, 10.6081],
        [28.9444, 17.2555, 16.8391, 15.8190, 16.7643, 19.5186, 25.2611],
        [22.0905, 31.6104, 29.9138, 10.7823, 15.7649, 16.0153, 16.5397],
        [33.0782, 28.4116, 26.7147, 20.2189, 20.3776, 28.5988, 28.5995],
        [18.1763, 18.8561, 28.3026, 39.7462, 33.6783, 15.4240, 21.0723],
        [12.2249, 13.2547, 12.2680, 14.4653, 18.8584, 25.7272, 11.9568],
        [16.4012, 16.3392, 16.1795, 18.1355, 24.0711, 25.0035, 14.7628],
        [24.2718, 24.1768, 29.7540, 40.5960, 48.2888, 42.1191, 23.6906]])
ground truth
tensor([[13.5455, 12.6512, 20.8574, 21.1468, 14.6896, 14.4661, 12.9537],
        [31.6807, 13.9532, 17.5171, 12.7564, 18.4640, 14.6765, 27.8406],
        [18.7358, 34.0136, 32.1854,  7.0011, 12.7976, 13.0244, 18.2398],
        [ 7.7722, 55.3787, 57.7196, 46.7649,  0.0000, 25.6049, 40.0053],
        [35.7001, 49.5181,  3.9966,  6.3067, 20.0822,  9.0136, 20.1814],
        [13.6480, 1

batch_predictions
tensor([[17.1352, 18.0227, 19.0047, 24.1583, 27.8795, 14.6993, 16.5181],
        [16.9964, 17.2381, 17.6898, 20.7180, 28.7182, 30.0828, 18.5228],
        [16.0040, 16.6230, 18.8930, 29.0342, 28.5420, 10.8582, 15.9285],
        [12.7689, 13.3697, 13.1254, 14.5997, 18.9388, 21.1906, 13.7987],
        [22.6142, 16.5768, 13.5263, 16.8675, 13.1330, 12.9200, 18.1265],
        [48.2996, 36.8192, 25.4757, 22.8300, 19.9128, 23.5845, 40.5328],
        [21.4440, 19.3843, 17.4982, 18.0245, 23.2128, 34.8794, 36.7272],
        [14.4638, 13.6392, 13.9946, 14.4519, 16.3717, 21.2806, 19.7808]])
ground truth
tensor([[12.6118, 15.6365, 19.9369, 23.8559, 30.6155,  9.1268, 13.6113],
        [ 1.9983,  7.7381, 18.4666, 25.4252, 41.5391, 21.2160,  0.0000],
        [13.0244, 18.2398, 25.0425, 28.6281,  0.0000,  8.9286, 13.8180],
        [ 9.0347,  8.9558, 10.0999, 13.7428, 22.6197, 18.6612, 11.3887],
        [23.9216, 22.8827, 14.0189, 18.2404, 18.8322, 17.5697, 15.9127],
        [54.9603, 4

batch_predictions
tensor([[10.5258, 14.1741, 14.7135,  9.5117,  7.5183,  7.2862,  6.4775],
        [21.7628, 20.4318, 18.5927, 16.5768, 15.3615, 14.7482, 16.0809],
        [16.4980, 15.9580, 17.0869, 19.5209, 24.8263, 21.9683, 12.8461],
        [18.5816, 23.3485, 23.9771, 13.4608, 15.4236, 14.4952, 14.8493],
        [19.3411, 24.3348, 34.4558, 33.3295, 20.8432, 19.2205, 17.5329],
        [22.0406, 25.2842, 23.6179, 24.2651, 32.0201, 47.2629, 41.5386],
        [35.1444, 22.7696, 15.4134, 16.0886, 15.2454, 17.8093, 26.2588],
        [27.2190, 23.9262, 13.3481, 15.8713, 15.3406, 16.2091, 20.6661]])
ground truth
tensor([[ 9.2451, 18.1089, 24.4082,  9.3766,  7.6670,  8.7717,  7.8643],
        [18.8059, 26.4203, 15.9653, 18.4640, 11.4676, 15.8732,  6.6675],
        [16.4683, 11.5646, 16.4824, 19.3736, 31.5334,  4.3509, 15.5187],
        [16.1494, 21.7123, 24.7764,  9.1662, 13.0063, 11.9279, 15.2551],
        [16.0856, 23.6820, 36.4654, 40.8163, 25.9921,  3.0754, 16.4399],
        [20.7766, 2

batch_predictions
tensor([[10.7407, 14.1603, 16.1590,  7.6672,  5.4700,  8.2820,  8.2281],
        [17.6865, 20.7074, 21.2679, 22.4292, 13.2273, 14.0183, 12.2973],
        [13.0958, 15.3473, 16.5329, 20.7830, 20.1720, 13.3415, 14.2578],
        [32.8110, 19.8437, 12.3739, 13.6683, 13.5754, 16.1288, 25.2077],
        [29.0769, 38.8585, 27.9818, 24.3086, 24.4220, 22.9008, 23.1560],
        [26.9034, 16.0727, 16.6477, 16.2904, 17.1991, 19.6327, 25.6095],
        [18.8820, 27.8790, 24.0877, 15.4533, 16.1840, 15.2899, 15.3503],
        [16.3696, 17.2148, 14.5864, 15.9835, 21.8184, 31.5295, 24.5850]])
ground truth
tensor([[ 9.6791, 11.2704, 17.2015,  1.3677,  3.3009, 10.9679, 18.6744],
        [15.3061, 21.0176, 33.3192, 36.5221, 11.2528, 15.0652, 10.0198],
        [10.4592, 12.0181, 19.0901, 25.8078, 24.8158, 17.0635, 18.4099],
        [37.7409, 27.6927,  8.9286, 15.0085, 11.8906, 15.5896, 25.4393],
        [ 5.1304, 60.5159, 42.4036,  4.4218, 13.2086, 22.7608, 20.1389],
        [28.9399, 1

batch_predictions
tensor([[24.9493, 15.0809, 15.0720, 14.9943, 15.2156, 19.4898, 26.0914],
        [14.2582, 16.6530, 14.7155, 14.7707, 16.1652, 21.3416, 21.8813],
        [39.6956, 26.6354, 22.8644, 19.4190, 20.1005, 31.3072, 44.7234],
        [20.7859, 22.1202, 21.3785, 13.6342, 13.5870, 15.9819, 17.0008],
        [19.9175, 32.1634, 45.7112, 37.5912, 22.4706, 21.7375, 19.3536],
        [21.4926, 14.7985, 15.0089, 13.7349, 15.9290, 22.9284, 32.8747],
        [30.1463, 26.5177, 10.8889, 13.4449, 12.4571, 14.2138, 22.1545],
        [14.4227, 15.1849, 16.0567, 21.0049, 27.0553, 15.1702, 15.5186]])
ground truth
tensor([[ 0.0000, 15.7171, 10.7993, 13.8322, 12.8543, 18.6933, 27.7211],
        [21.8569, 22.3304, 20.7917, 17.7933, 16.6886, 24.3819, 18.0037],
        [42.8327, 24.4082, 25.9337, 20.9890, 22.6197, 49.8027, 67.8722],
        [17.0831, 22.4750, 25.1973, 14.7948, 11.5729, 19.4240, 15.9258],
        [19.1752, 31.6327, 53.4722, 35.5159, 19.4870, 21.6270, 17.6729],
        [31.4177, 1

batch_predictions
tensor([[18.9452, 20.2271, 21.6772, 34.4935, 40.9756, 36.1688, 13.0355],
        [26.0732, 25.4461, 13.9952, 15.6683, 14.7207, 15.5413, 19.9270],
        [19.2278, 22.2055, 21.5526, 10.5945, 12.2176, 18.0403, 17.1674],
        [35.6107, 14.9343, 19.6471, 19.2669, 19.6086, 24.8628, 33.7894],
        [20.4131, 21.5641, 22.0516, 24.7409, 29.6235, 19.0053, 19.0903],
        [16.2367, 16.5292, 19.1968, 27.5339, 28.9152, 16.5643, 16.7809],
        [29.5400, 23.3663, 14.7042, 13.8643, 14.9326, 14.4180, 22.0171],
        [18.2308, 19.2673, 18.9909, 18.6363, 22.4934, 16.9478, 15.8604]])
ground truth
tensor([[19.7988, 19.3452, 30.1871, 42.5737, 61.5646, 49.0079, 37.2449],
        [23.1434,  3.3022,  9.0561, 15.4053, 13.9739, 13.7755, 22.0096],
        [19.9972, 25.3260, 29.0958,  8.4467,  5.0454, 19.9830, 17.5312],
        [48.2246, 23.3298, 31.9832, 45.6996, 28.0116,  1.4466,  1.2625],
        [12.7693, 11.0969, 16.4399, 24.8724, 27.5935, 11.0969, 13.2511],
        [17.2761, 1

batch_predictions
tensor([[43.6008, 39.7698, 30.8856, 31.2030, 35.1230, 38.8523, 39.0257],
        [25.2062, 13.0847, 13.9390, 13.2364, 15.2126, 18.0135, 25.6440],
        [12.9204, 13.3700, 18.2815, 22.1548, 13.1385, 10.9482, 13.8909],
        [14.1993, 17.3543, 18.3467, 17.7912, 20.3938, 27.2437, 23.4653],
        [19.5204, 22.9713,  8.7958, 11.8113, 11.6011, 11.9036, 12.3377],
        [25.4372, 28.3448, 27.8544, 14.7148, 17.2864, 16.9689, 19.8844],
        [23.8309, 23.8647, 34.1459, 41.1981, 37.7248, 25.7733, 23.9115],
        [15.3952, 18.4822, 22.5301, 21.2149, 13.0057, 13.6223, 12.9931]])
ground truth
tensor([[67.1357, 57.8906, 44.5555, 49.5266, 45.5287, 40.0184, 43.3062],
        [ 7.1429, 13.5913, 13.6196, 14.9235, 17.5595, 14.6259, 31.6185],
        [18.0556, 13.9314, 16.5958, 23.8520, 23.2001, 12.5850, 12.4717],
        [27.4660, 25.9779,  5.6973, 19.1043, 34.9490, 50.5811, 38.1944],
        [ 0.0000,  0.0000,  9.9290, 12.6381, 10.1920, 13.6902, 15.1894],
        [24.0646, 2

batch_predictions
tensor([[19.9370, 21.8966, 19.3473, 16.6086, 15.6005, 15.9484, 18.8861],
        [24.1040, 23.7908, 29.1039, 32.2154, 38.3548, 28.1463, 25.4534],
        [11.8919, 13.4092, 13.8375, 18.2031, 19.2369, 13.4037, 12.7342],
        [14.7024, 16.2136, 21.0598, 22.8108, 18.3908, 12.9706, 14.6804],
        [11.3078, 11.8659, 18.4216, 20.9542, 11.6767, 12.6615, 11.6232],
        [14.4192, 17.2334, 20.2318, 18.1926, 15.8643, 11.5223, 11.7196],
        [14.7246, 17.3866, 18.4835, 19.8558, 20.5710, 10.8449, 14.4307],
        [17.2716, 21.8679, 22.9778, 14.6134, 14.5645, 13.7846, 14.0781]])
ground truth
tensor([[18.3062, 24.7764, 15.0973, 12.4803, 12.3488, 14.9658, 15.6628],
        [38.2299, 48.8953,  2.3803,  0.0000, 20.2920, 19.5160, 23.7507],
        [12.9537, 13.0721, 16.6754, 23.2246, 29.2083, 14.7817, 13.5981],
        [10.9552, 12.7834, 20.4649, 23.7528, 21.3577, 10.7001, 11.9473],
        [12.0200, 13.7822, 18.4245, 19.5160, 11.3624, 10.9416,  9.6265],
        [ 9.2977,  

batch_predictions
tensor([[22.0917, 26.6406, 11.8374, 14.3414, 14.3369, 14.9067, 19.1836],
        [15.5662, 17.3477, 18.6827, 23.8064, 26.3578, 23.7860, 14.0442],
        [17.1945, 17.0794, 16.9610, 16.5844, 16.8797, 22.4673, 24.7683],
        [14.3177, 15.0717, 14.7067, 15.6495, 19.8567, 25.6021, 20.9757],
        [20.6909, 12.6120, 12.6734, 12.4140, 13.8274, 18.7654, 24.8788],
        [ 8.9420, 11.5485, 11.7238, 12.2879, 13.4569, 19.1389, 22.0187],
        [21.6191, 28.3321, 27.4245, 19.8354, 22.3639, 20.6129, 18.1884],
        [15.1304, 20.1396, 18.8744, 13.7501, 13.6565, 12.4997, 11.1653]])
ground truth
tensor([[24.4898, 28.2171, 15.3486, 13.3503, 14.3282, 12.7834, 10.1332],
        [27.6077, 30.9524, 34.9632, 22.4490, 28.8549, 15.3628, 12.8260],
        [21.1336, 25.9074, 15.4261, 16.2678, 21.6991, 32.6539, 26.7885],
        [11.5646, 14.2149, 13.1094, 15.4053, 24.6599, 24.7024, 25.0850],
        [24.1638, 13.5204, 14.2574, 10.5300, 15.4053, 22.5907, 27.0692],
        [ 5.2867, 1

batch_predictions
tensor([[12.9166, 13.8661, 13.9530, 17.2413, 20.1652, 23.7868, 16.5800],
        [19.2389, 20.0823, 24.8731, 37.4701, 37.1178, 21.7116, 19.2062],
        [30.5727, 13.3717, 15.6003, 15.9537, 16.3337, 21.7092, 29.1429],
        [21.7441, 18.1563, 16.5049, 15.0614, 16.4101, 18.5224, 22.8178],
        [17.5818, 21.9744, 25.0939,  9.9987, 11.2133, 13.9055, 12.2453],
        [24.7546, 22.5961, 12.9110, 12.8789, 13.5220, 14.0437, 20.0408],
        [13.7477, 10.2720, 13.9211, 20.8244, 23.4619, 21.7849, 11.8140],
        [13.7064, 14.2716, 22.7934, 26.6589, 19.8573, 12.0578, 13.7988]])
ground truth
tensor([[13.2167, 10.8101, 10.0473, 21.1336, 23.5271, 40.6102, 19.6081],
        [17.3856, 17.9905, 28.3798, 45.6339,  0.0000, 18.0431, 14.2294],
        [34.9348, 13.6763, 14.3141, 14.2857, 16.5533, 21.0317, 35.5584],
        [ 0.0000,  0.0000, 13.3219, 15.5971, 12.5460, 16.4782, 20.4761],
        [15.3912, 20.6207, 33.0924, 21.7120,  5.8248, 15.4337, 14.8243],
        [25.1134,  

batch_predictions
tensor([[ 9.4409, 11.2352, 14.0328, 19.1547, 18.1579,  7.3709, 10.1246],
        [17.7882, 21.1621, 27.8929, 30.4408, 20.7228, 19.4990, 17.0566],
        [29.8305, 28.1748, 25.3078, 25.7025, 30.1684, 37.1989, 36.4049],
        [14.4451, 17.5041, 22.7547, 17.8583, 13.8095, 13.3688, 12.3973],
        [17.5203, 13.6607, 14.0826, 13.7488, 14.2906, 17.7870, 20.1151],
        [18.1504, 18.5176, 20.1955, 21.5734, 21.5550, 16.8703, 14.0115],
        [20.8568,  7.8769,  9.0518,  8.7450,  8.3862, 10.7322, 14.6576],
        [21.6707, 22.2248, 24.8650, 35.0311, 36.0870, 24.8206, 23.8840]])
ground truth
tensor([[ 9.2451,  9.0873, 13.3088, 17.6223,  0.0000,  5.3787,  6.9174],
        [21.7971, 25.0425, 31.0374, 49.5607, 38.7188, 25.5244, 25.1276],
        [24.8816, 34.2188, 28.1299, 28.7743, 31.4834, 41.1099, 39.2820],
        [15.4478, 14.5408, 22.3923, 21.7262, 14.3282, 13.8605, 10.4592],
        [22.1088, 18.5091, 21.5136, 12.5142, 15.3912, 16.1848, 23.6961],
        [16.8201, 1

batch_predictions
tensor([[17.1004, 16.5085, 17.6857, 26.5734, 37.0986, 28.7203, 16.5353],
        [17.7148, 19.0086, 17.1841, 17.0373, 19.2138, 27.5019, 28.3542],
        [21.7678, 26.2652, 20.8521, 13.1381, 14.1526, 13.8418, 15.3799],
        [13.6581, 17.1229, 18.2519, 23.4340, 16.1029, 15.1262, 14.1949],
        [12.6928, 12.7541, 13.8460, 16.5589, 21.9846, 22.2507, 10.0627],
        [15.9338, 16.6500, 19.9947, 26.2722, 24.9729, 15.4343, 14.8268],
        [24.8223, 30.8971, 24.3043, 15.6937, 15.8426, 15.4804, 16.9904],
        [18.6756, 25.4241, 35.0685, 19.6244, 16.6853, 18.6515, 16.4403]])
ground truth
tensor([[11.9898, 12.1740, 15.7596, 19.9688,  0.0000, 15.2353, 11.9756],
        [11.1536, 31.9728, 18.1406, 16.1281, 21.2727, 32.8515, 34.7647],
        [24.6599, 24.7024, 25.0850, 14.2999, 16.1706, 17.0493, 16.1423],
        [16.6491, 18.4903, 20.0026, 24.6975, 23.0800, 15.1236, 11.7044],
        [10.5997, 11.5466,  5.5497, 15.2025, 23.1063, 24.6449,  8.7717],
        [13.4779, 1

batch_predictions
tensor([[18.6409, 19.4859, 15.5655, 13.5340, 11.3569, 10.8107, 13.1657],
        [13.5707, 14.3066, 19.3197, 22.4009, 16.6605, 14.3000, 13.9559],
        [12.5369, 10.0496, 12.8283, 14.4188, 12.2714, 13.1948, 15.5673],
        [34.6215, 20.6427, 19.7415, 19.4725, 19.6246, 25.5743, 41.7424],
        [18.4333, 10.6243, 12.2430, 13.4372, 15.4113, 19.0055, 25.2323],
        [16.5370, 18.3149, 25.5567, 36.1258, 30.1388, 19.9844, 18.7325],
        [13.6235, 13.3252, 13.3291, 16.2963, 21.4252, 23.1637,  7.1735],
        [10.3353,  9.5924,  5.9067,  5.1793,  4.8671,  4.9116,  5.7421]])
ground truth
tensor([[16.0179,  6.2073,  1.1967,  1.4861,  0.9732,  0.3288,  2.6959],
        [13.7613, 13.8747, 17.5170, 24.4048, 22.6616, 14.1298, 15.6179],
        [ 4.7477,  7.3980,  7.3271,  9.1695, 10.0057, 11.9756, 16.2982],
        [33.3900, 22.8033, 24.8299, 20.6633, 28.0471, 38.2937, 50.9637],
        [46.6412, 32.0295, 17.2194, 27.6077, 30.9524, 34.9632, 22.4490],
        [11.0261, 1

batch_predictions
tensor([[14.4283, 13.0427, 12.9560, 16.1902, 21.4370, 13.7789, 13.6305],
        [ 8.8051,  7.9013,  9.0527, 11.5342, 14.5946, 17.7295,  9.1494],
        [16.7014, 18.1800, 22.7947, 28.3650, 28.7830, 10.8924, 15.5084],
        [23.4673, 23.6449, 22.1341, 22.2249, 30.6286, 39.1909, 33.2514],
        [16.3107, 11.4091,  9.0993,  8.0848,  8.1357, 10.3626, 12.7548],
        [17.4862, 15.2397, 14.2766, 14.2643, 15.3606, 21.9469, 22.1186],
        [23.3836, 24.5294, 22.5737, 26.3601, 35.8484, 48.0952, 40.2648],
        [18.4041, 19.3986, 27.3655, 35.2464, 29.4714, 16.0245, 19.4000]])
ground truth
tensor([[16.1281, 13.5062, 12.7976, 14.4558, 22.6190, 24.5323, 14.8951],
        [ 7.6670,  8.7717,  7.8643, 11.0074, 17.2278, 23.9216,  5.9048],
        [12.8260, 21.0034, 22.1230, 33.0924,  1.7007,  6.1508, 18.5232],
        [20.1247, 19.5720, 18.6366, 20.9751, 28.5431, 40.8305,  8.2908],
        [17.4036, 11.4371,  6.4342,  7.9507,  7.8373, 10.9836, 13.4070],
        [13.7822, 1

batch_predictions
tensor([[43.5712, 32.0546, 24.7819, 24.9218, 25.0853, 35.0309, 45.8842],
        [16.0548, 17.7144, 23.8315, 24.3722, 11.6128, 14.1371, 15.2354],
        [23.2409, 14.3997, 14.3117, 13.1270, 14.9919, 22.6238, 28.7642],
        [22.5437, 28.3363, 25.0299, 15.2429, 16.1743, 14.7928, 16.5543],
        [29.8728, 32.1510, 19.0011, 18.0786, 16.2013, 16.3145, 24.2779],
        [26.4206, 25.5581, 22.6489, 25.3335, 38.0072, 46.0046, 39.4858],
        [15.2967, 17.6704, 17.1502, 20.5074, 29.2959, 28.0518, 19.4187],
        [21.7192, 17.9879, 20.7397, 28.4194, 31.2766, 27.5229, 22.1187]])
ground truth
tensor([[43.3957, 41.4541, 33.0924, 15.6746, 13.6338, 41.8509, 52.8345],
        [14.4135, 15.9390, 23.1852, 21.0021,  0.0000,  0.0000,  9.2846],
        [ 0.0000, 14.4841,  0.3685,  3.0045, 13.5629, 25.1134, 33.1066],
        [16.0431, 23.1434,  3.3022,  9.0561, 15.4053, 13.9739, 13.7755],
        [18.0957, 12.7564,  0.0000,  4.9053, 15.0710, 14.9790, 22.1594],
        [ 7.7098, 1

ground truth
tensor([[42.5737, 61.5646, 49.0079, 37.2449, 23.2143, 22.6190, 30.9524],
        [23.5271, 40.6102, 19.6081, 16.3072,  4.7870,  5.0105,  1.5255],
        [23.1589, 32.7196, 63.2299,  2.8143,  0.0000, 14.9395, 17.3461],
        [29.1294, 28.1036, 36.5992,  0.0000, 20.2656, 32.8248, 17.1226],
        [30.2579,  0.0000, 14.4841,  0.3685,  3.0045, 13.5629, 25.1134],
        [21.3577, 10.7001, 11.9473, 12.3866, 12.5992, 23.1718, 25.9212],
        [13.1661, 20.4649, 24.9291, 24.9291, 14.4700, 12.9252, 10.9127],
        [21.1451, 26.4172, 35.3316,  5.0312, 12.0040, 11.2954, 12.6276]])
batch_predictions
tensor([[13.1181, 19.6316, 19.5038, 12.2006, 12.6102, 12.2797, 11.3663],
        [12.2800, 15.9657, 19.2851, 22.2918,  8.2519, 11.8007, 10.9474],
        [17.2362, 18.8877, 21.8058, 28.1750, 32.0185, 12.6302, 18.6956],
        [20.0454, 22.0208, 30.0188, 28.8512, 17.3844, 19.1941, 19.0397],
        [13.1753, 19.5259, 22.7741, 21.3886, 13.4315, 13.1213, 12.8686],
        [25.8067, 1

batch_predictions
tensor([[22.6661, 18.9141, 17.2133, 18.4495, 21.0201, 32.1237, 28.1401],
        [29.3273, 43.0258, 38.5138, 19.8036, 18.7339, 18.0025, 21.0233],
        [23.3968, 15.7539, 15.2679, 14.8360, 14.4368, 15.1594, 21.3646],
        [15.7192, 14.4629, 13.2717, 12.9076, 17.6795, 25.9553, 20.7520],
        [21.0361, 21.6921, 14.1900, 12.4606, 12.5923, 13.0522, 14.8653],
        [21.7584, 23.0171, 23.6546, 26.2257, 28.8292, 17.3160, 18.4508],
        [16.8181, 19.2983, 26.0587, 23.0464, 12.9241, 14.0069, 12.6430],
        [22.2425, 26.3887, 21.6386, 14.0222, 15.2820, 14.8720, 15.9863]])
ground truth
tensor([[22.5482, 16.6100, 15.9439, 17.5879, 24.8158, 36.1820, 11.1253],
        [34.1837, 52.6077, 44.9121, 20.3373, 23.7954, 18.3532, 23.5119],
        [29.0505, 17.5697, 14.7554, 12.7564, 17.3461, 17.7144, 40.6234],
        [14.0873, 17.7721, 13.3503, 14.6542, 23.9371, 35.4167, 27.6927],
        [24.2241, 26.3019, 13.8480, 14.2294, 11.9805, 12.4408, 12.9406],
        [14.7948, 1

batch_predictions
tensor([[45.8122, 42.3167, 30.4484, 26.9349, 26.4100, 26.4540, 36.1740],
        [35.4420, 21.2692, 19.0307, 18.4800, 18.7682, 25.1969, 36.1752],
        [15.6197, 21.2606, 30.6045, 28.4160, 17.7635, 16.9745, 14.6498],
        [21.5010, 31.6424, 38.6124, 24.4521, 19.5915, 17.6667, 16.5909],
        [19.0075, 21.3457, 17.6288, 14.5974, 14.8072, 14.7917, 15.6439],
        [ 9.8722, 11.2450,  9.9544, 12.4594, 15.5538, 11.3361,  6.7247],
        [20.3911, 24.7560, 18.1687, 14.9224, 14.9766, 11.7592, 13.6007],
        [36.8790, 20.2182, 21.2618, 20.4136, 20.5110, 29.5346, 45.2922]])
ground truth
tensor([[52.4376, 43.3957, 41.4541, 33.0924, 15.6746, 13.6338, 41.8509],
        [ 9.6791, 11.7175, 13.1641, 21.6465, 19.9632, 27.7091, 46.1862],
        [22.4490, 29.8328, 26.7857, 13.0952, 20.5641, 16.8793, 11.3662],
        [21.9529, 39.7251, 51.4881, 34.7364, 15.5896, 11.7063, 14.7392],
        [20.2525, 21.0284,  0.0000,  6.5360, 10.1131, 11.5860, 11.8885],
        [ 9.9421, 1

batch_predictions
tensor([[29.4176, 12.1873, 15.4971, 15.0227, 15.4298, 22.8031, 33.4449],
        [23.0586, 21.6706, 12.0290, 14.7769, 14.4769, 15.0406, 17.7948],
        [20.0848, 23.8526, 22.9579, 14.4685, 16.6559, 14.3756, 15.4577],
        [19.2252, 23.7864, 27.3752, 15.2031, 16.6154, 16.8906, 17.2674],
        [36.6779, 45.7534, 31.9836, 20.1359, 20.9686, 20.5743, 21.9229],
        [25.6467, 30.6021, 14.6395, 17.1943, 16.4582, 17.3694, 22.2222],
        [26.9833, 35.6262, 29.9378, 17.3611, 18.2887, 17.3662, 18.0572],
        [29.5500, 22.5439, 13.1962, 14.5574, 12.0992, 14.0062, 21.0889]])
ground truth
tensor([[30.3146, 21.8963, 14.8243, 11.2954,  2.3243, 15.5471,  0.0000],
        [26.2330, 20.6491, 10.5300, 14.5550, 12.1740, 14.0448, 16.8367],
        [18.3193, 22.3567, 26.8806, 12.6118, 18.2536, 13.7559, 14.2294],
        [18.6744, 25.3025, 31.7596, 14.7948, 14.7685, 12.6118, 15.6365],
        [28.8974, 51.9983, 40.8022, 16.9926, 18.7783, 14.2574, 16.5391],
        [31.1224, 2

tensor([[39.1110, 21.2388, 13.7691, 14.4003, 10.4682, 12.7433, 21.4492],
        [23.7638, 30.2998, 37.4277, 34.9027,  0.0000,  1.7491, 16.8727],
        [ 9.4029, 11.8490, 22.8169, 20.8969,  9.9027,  9.6660,  7.4303],
        [22.5765, 40.2353, 46.3010, 32.8515, 20.2381, 17.2477, 12.4433],
        [45.7785,  4.1163,  3.1299, 17.9774,  7.5092, 19.6870, 20.6470],
        [15.6234, 14.4135, 15.9390, 23.1852, 21.0021,  0.0000,  0.0000],
        [ 5.6406,  7.1287, 18.7217, 10.0340, 20.1672, 21.0176, 27.9620],
        [15.9521, 16.3598, 23.2904, 33.9427, 32.2856, 20.8311, 18.7927]])
batch_predictions
tensor([[25.5821, 24.0991, 24.7086, 34.9191, 47.2249, 41.9971, 26.9669],
        [14.6247, 14.1695, 13.8979, 15.1265, 17.5653, 25.5470, 23.0361],
        [11.1253, 13.8018, 14.5739, 14.4167, 16.0030, 21.8688, 22.3287],
        [39.4085, 21.4178, 19.5210, 18.2349, 19.8500, 28.7320, 46.2485],
        [25.5995, 25.8150, 36.9743, 45.2150, 43.4168, 36.6941, 32.0666],
        [32.8781, 39.5411, 15.28

batch_predictions
tensor([[16.0510, 14.0650, 12.7664, 13.6387, 16.9950, 23.8133, 21.0131],
        [16.6005, 16.0893, 14.4361, 15.1181, 17.4337, 22.3406, 24.7318],
        [31.5312, 45.8227, 31.3975, 17.7836, 18.6744, 17.8631, 21.6606],
        [13.0982, 11.2377, 12.5152, 14.3573, 17.7614, 19.0887, 10.1832],
        [14.8191, 19.8472, 11.4201,  9.7979, 12.0903, 11.2237, 10.7372],
        [19.0010, 20.7278, 12.7942, 12.8250, 10.4074, 11.4722, 15.5150],
        [22.4283,  9.1310, 13.7472, 13.6110, 15.2076, 20.2560, 28.7969],
        [17.6429, 19.8854, 17.6357, 19.2863, 20.8995, 10.3165, 12.9857]])
ground truth
tensor([[ 5.2998,  8.2325,  8.9032, 13.0721, 17.6618, 25.6181, 25.4866],
        [15.0447, 14.8080, 13.1641, 15.9916, 19.9237, 28.1431, 36.4019],
        [40.4053, 52.0408, 41.8509, 22.5198, 18.1406, 18.1406, 19.2744],
        [14.3477, 10.3498, 13.7165, 26.6307, 33.9953, 40.8469, 31.8517],
        [13.5192, 22.8827, 10.4419,  7.5355, 15.3209, 13.6113, 14.8869],
        [18.6366, 1

batch_predictions
tensor([[21.1761, 27.1788, 33.4627, 31.1504, 21.1405, 21.0966, 18.3605],
        [22.3020, 23.7565, 31.9984, 42.3138, 34.9234, 25.7179, 24.9623],
        [11.7370, 13.2136, 12.7334, 13.7615, 17.9287, 27.0746, 17.3670],
        [13.6645, 19.1058, 18.2936, 17.3935, 17.9167, 26.8002, 31.3336],
        [16.4935,  9.6614, 10.7880, 10.6065, 12.2066, 15.8102, 13.3185],
        [14.5476, 18.8095, 10.3278, 10.9061, 11.7802, 11.3323, 12.0805],
        [17.9631, 19.9944, 20.9509, 25.9737, 35.7451, 18.7706, 19.7964],
        [41.3229, 13.0737, 20.1564, 19.8174, 25.5289, 38.0823, 48.1994]])
ground truth
tensor([[47.9308, 53.3730, 55.5697, 50.5527, 35.5442, 40.2353, 50.4535],
        [40.4478, 51.4598,  6.9303,  7.5539, 19.7988, 17.8855, 19.5011],
        [ 8.7868,  8.4609,  7.1854, 10.1899, 21.7687, 32.6956, 21.1310],
        [ 9.4424, 18.6481, 18.6612, 19.8843, 20.6733, 34.4424, 36.2704],
        [23.4087, 15.7549,  9.5871, 17.2541, 17.3593, 16.4782, 21.4361],
        [10.6576, 1

batch_predictions
tensor([[20.4068, 12.7418, 16.0058, 18.5455, 14.6772, 12.4131, 16.5422],
        [20.3145, 22.0102, 25.5201, 25.0352, 15.4954, 17.9251, 19.5550],
        [13.8140, 11.7738, 10.3297, 11.2912, 14.5073, 17.9214,  9.0401],
        [30.9376, 29.7590, 17.9089, 18.0181, 16.9970, 17.7538, 26.8075],
        [21.8234, 26.8580, 28.1994, 15.4029, 14.2444, 14.7985, 15.4951],
        [23.9579, 20.2274, 15.4067, 15.4684, 15.1096, 15.0599, 17.5763],
        [19.3893, 23.4490, 25.5777, 15.8695, 16.7449, 15.2603, 16.7719],
        [22.2467, 27.3187, 30.4010, 23.5305, 22.0913, 23.1285, 25.2059]])
ground truth
tensor([[21.0679, 18.5166, 18.3062, 22.7906, 13.9532, 11.2967, 18.0168],
        [30.6812,  3.9584,  5.0237, 13.2036,  6.6412, 18.0037, 16.0179],
        [ 6.9870,  3.9966,  4.6344, 12.2732, 12.7268,  0.0000,  5.4138],
        [62.9110, 46.3010, 25.9212,  6.5901, 37.1599, 41.7517,  8.3617],
        [26.0784,  4.4187, 24.5792, 51.1310, 17.9774, 12.5460, 12.5592],
        [35.6391, 2

batch_predictions
tensor([[14.0367, 13.1413, 17.5577, 21.5644, 19.6944, 11.2610, 14.6952],
        [13.8887, 16.4868, 15.8030, 15.6706, 20.5770, 31.5464, 26.6789],
        [ 7.9576,  8.6704, 11.2033, 13.3517, 15.4514,  9.4904,  7.0392],
        [40.6396, 33.4540, 14.0691, 18.3473, 18.2693, 19.3233, 29.1638],
        [14.0085, 13.3785, 13.9475, 17.8514, 23.7686, 17.4365, 14.1151],
        [ 8.2217, 12.4281, 10.3682,  7.7796,  6.8919,  6.5000,  5.9988],
        [14.7069, 19.6703, 26.4777, 13.2038, 13.4124, 13.4673, 13.5394],
        [ 6.6915,  8.9551,  9.1050,  9.6952, 11.0152, 14.8377,  7.8543]])
ground truth
tensor([[15.5329, 16.3407, 16.0289, 26.2330, 20.6491, 10.5300, 14.5550],
        [10.5159, 11.0402, 24.5181, 14.2857, 26.5306, 44.6287, 36.2670],
        [ 5.1587,  7.5964,  8.0499,  9.9206, 10.9836,  8.9569,  8.0074],
        [40.6888, 28.8690, 13.0669, 16.9076, 13.8889, 23.0159, 31.1366],
        [11.8359, 13.4403, 13.3745,  4.3924, 28.1431, 24.6449, 11.8227],
        [ 8.3759, 1

batch_predictions
tensor([[28.2806, 14.4564, 16.4242, 16.0877, 16.3620, 20.6418, 29.1353],
        [16.9965, 19.0059, 18.7883, 23.6034, 29.9711, 33.3226, 16.6162],
        [18.8495, 19.8302, 19.7900, 22.6838, 31.4017, 30.4658, 14.5835],
        [27.0903, 23.6979, 10.7824, 14.5659, 15.1143, 14.1489, 18.7435],
        [25.2123, 13.7811, 14.8999, 14.0835, 14.7274, 22.5892, 34.2273],
        [12.8839, 13.4319, 12.9159, 14.2212, 15.0685,  7.5820,  5.4959],
        [15.9912, 19.4584, 20.8430, 15.1183, 13.6680, 14.1835, 13.4608],
        [18.8582, 32.1456, 43.0711, 31.7447, 19.3596, 17.8026, 17.7490]])
ground truth
tensor([[29.3226, 13.8464, 16.7659, 15.3770, 18.4524, 25.6236, 34.0845],
        [ 0.0000, 19.4586, 22.0947, 29.3651, 36.9756, 37.2166, 52.7069],
        [16.4541, 21.2868, 20.6207, 24.8583, 34.3679, 25.4819,  0.0000],
        [22.3961, 27.7223, 10.4287, 14.5976, 16.9648, 14.4398, 18.7664],
        [30.1304, 15.0085, 15.1077, 13.0952, 16.5249, 24.6173, 34.3396],
        [23.2001, 1

batch_predictions
tensor([[19.5792, 21.7897, 11.5067, 12.2135, 12.1893, 13.3071, 17.0220],
        [16.0387, 16.5252, 16.9494, 19.5772, 26.0580, 24.3974, 10.0711],
        [ 9.9019,  8.6414, 12.8272, 13.4596, 12.7030, 13.8637, 14.3115],
        [ 8.7005,  9.9514, 11.4663, 13.9787, 14.7498,  8.0162,  7.8918],
        [20.4182, 10.0322, 13.1905, 12.6165, 13.8695, 16.3445, 23.3414],
        [16.2759, 21.1732, 29.8590, 26.9421, 15.0639, 16.6984, 16.3540],
        [14.3316, 14.3145, 12.9422, 14.3972, 17.3662, 28.2405, 23.2004],
        [19.9136, 18.8453, 17.3460, 20.3014, 31.3087, 41.1492, 28.7801]])
ground truth
tensor([[25.2409, 28.1179, 14.2007, 13.5062, 11.3946, 13.1661, 20.4649],
        [19.6995, 15.2494, 14.7817, 16.8934, 23.0726, 23.1718,  7.8940],
        [12.2041,  9.9027, 10.7706,  8.7717, 10.7969, 12.7170, 16.5834],
        [ 8.7585, 10.4550, 13.5981, 16.7412, 18.2667,  7.3514, 10.3498],
        [28.4439,  9.0420, 13.8464, 14.0873, 14.6259, 16.4824, 31.4626],
        [13.5062, 1

batch_predictions
tensor([[20.0960, 28.6935, 31.5380, 21.9737, 18.4615, 17.1265, 15.9730],
        [21.5981, 26.3346, 25.5498, 17.1569, 16.6228, 15.6748, 16.8065],
        [25.3060, 21.4486, 14.4063, 15.1354, 14.4278, 16.2260, 19.0786],
        [35.9312, 23.0708, 21.4440, 18.7825, 20.7652, 32.9791, 40.2936],
        [16.1008, 17.5962, 16.8156, 16.6241, 19.1157, 27.0225, 26.6749],
        [10.6756, 16.2135, 15.0147, 10.1203, 10.2731,  9.9288, 11.1570],
        [21.2542, 18.8974, 21.0273, 19.4682, 14.9660, 14.3001, 19.2170],
        [24.5454, 22.2638, 10.6748, 13.2946, 12.4855, 12.5748, 20.9182]])
ground truth
tensor([[21.6128, 33.8577, 45.2239, 27.8345, 16.1990, 17.2052, 10.1899],
        [19.6003, 34.3963, 31.9161, 23.1859, 30.4989, 35.3741, 34.9206],
        [27.2225, 23.9611, 17.8196, 18.6481, 12.3882, 20.3840, 16.0179],
        [50.5527, 35.5442, 40.2353, 50.4535, 47.5907,  2.7920,  4.5777],
        [20.2664, 11.2387, 14.6684, 15.5896, 19.3169, 27.4943, 25.7228],
        [ 9.8369, 1

batch_predictions
tensor([[18.5598,  7.1102,  9.9151, 10.2624, 11.4665, 13.7032, 18.9982],
        [ 8.9246,  9.9865, 10.1818, 10.8444, 11.4311, 14.9431, 10.4981],
        [13.7079, 16.2673, 16.5990, 19.0037, 22.3925, 29.4484, 23.2271],
        [13.5564, 15.3155, 15.3053, 16.3618, 19.9470,  9.5501, 11.6670],
        [18.0160, 16.3175, 17.3532, 26.2282, 36.4389, 31.0857, 17.7464],
        [20.0690, 11.1081, 13.3091, 13.4266, 14.2493, 17.0114, 20.0061],
        [27.0369, 13.9281, 10.2892, 12.6763, 12.3424, 15.4322, 21.8715],
        [ 9.8016,  7.2640,  5.5826,  6.0233,  6.5522,  7.0886,  8.8837]])
ground truth
tensor([[19.0558,  6.3651, 10.3630, 10.0868, 12.8354, 16.7280, 22.6197],
        [ 8.1774, 12.0748, 11.1395,  8.3759,  9.5663, 15.8447,  9.3537],
        [12.1882, 12.7551, 13.0385, 21.2302, 16.9359, 30.8107,  5.4847],
        [21.0884, 16.3832, 17.8146, 22.6616, 22.2222,  0.0000,  3.1746],
        [20.1247, 16.0431, 19.6854, 20.3940, 41.0289, 36.4512, 15.9014],
        [25.2409, 1

batch_predictions
tensor([[23.6891, 31.0206, 39.7848, 31.9378, 22.7526, 22.2083, 21.2169],
        [16.0379, 16.2683, 14.6454, 14.6538, 19.5695, 20.5207, 24.9788],
        [12.3554, 15.2300, 18.4315, 22.9750, 13.7355, 12.9047, 13.4019],
        [19.8989, 10.8953, 12.3073, 12.7682, 14.5319, 22.6991, 31.2407],
        [21.7816, 20.2222, 18.2166, 21.8199, 31.9922, 29.9149, 19.4589],
        [15.7967, 16.2894, 14.4554, 13.3090, 13.7091, 13.0462, 13.8654],
        [19.7225, 24.3606, 30.7468, 32.6701, 23.6484, 21.5230, 20.2723],
        [20.3639, 17.1532, 15.3485, 15.4710, 17.7669, 30.8500, 31.6772]])
ground truth
tensor([[22.7775, 33.7322, 44.9632, 40.2420, 22.5934, 20.3577, 22.4093],
        [11.1536, 15.6604, 14.2999, 16.1848, 20.1389, 25.1134, 21.9529],
        [14.9235, 17.5595, 14.6259, 31.6185, 25.1134, 14.9518, 14.0023],
        [27.4943, 10.3175, 12.5850, 10.4025, 13.3220, 23.1859, 37.0040],
        [17.8985,  3.7875, 11.7175, 21.7386, 34.1399, 24.3161,  0.0000],
        [14.9527, 2

batch_predictions
tensor([[22.7365, 24.6273, 11.7994, 14.3471, 13.5926, 14.1886, 17.1203],
        [17.7828, 20.3084, 22.5160, 24.3349, 25.0585, 28.0392, 14.9742],
        [16.4009, 14.8737, 16.2699, 18.3510, 21.7549, 20.4003, 15.5110],
        [26.9639, 31.0954, 33.0046, 23.3281, 24.5995, 23.0801, 22.1367],
        [34.7500, 46.5249, 36.8638, 25.7479, 22.9701, 21.6792, 25.1246],
        [10.0891, 10.6782, 11.6378, 13.8743, 16.0067, 22.2322, 10.0684],
        [16.5029, 22.5167, 22.1230, 10.4515, 11.9909, 12.0841, 11.0512],
        [13.2892, 14.6892, 15.2897, 15.8248, 19.7292, 27.0700, 13.9475]])
ground truth
tensor([[23.3824, 25.8811, 13.6376, 20.5944, 11.9542, 16.6097, 18.0563],
        [16.5533, 22.4632, 24.5323, 26.1480, 25.3827, 30.7398, 21.5278],
        [19.9500, 17.8590, 16.5308, 18.3719, 23.7375, 26.9200, 18.6744],
        [35.1920, 35.4287, 55.4577, 30.6549, 26.1836, 38.2299, 48.8953],
        [37.8118, 49.3764, 35.9269, 27.8203, 23.4269, 19.1327, 27.6219],
        [ 8.5176,  

batch_predictions
tensor([[18.8868, 20.0057, 11.5095, 10.1513, 10.1548, 10.6874, 13.0116],
        [22.4572, 23.5330, 33.0600, 44.0752, 39.0891, 28.5413, 22.3645],
        [17.7959, 23.6995, 31.5062, 22.7971, 18.7618, 17.9276, 15.7564],
        [15.7519, 15.9286, 16.9042, 26.9037, 31.3669, 25.3431, 11.3268],
        [20.2711, 29.5433, 22.9207, 14.5021, 14.7561, 13.4926, 15.4079],
        [35.0041, 23.7735, 19.1344, 17.4989, 17.1496, 25.4471, 42.5839],
        [31.4594, 10.4115, 16.1903, 15.7718, 16.2406, 19.7470, 32.7027],
        [ 7.5885,  5.7136,  5.2267,  5.4837,  5.8756,  6.0871,  7.5238]])
ground truth
tensor([[22.8169, 20.8969,  9.9027,  9.6660,  7.4303, 10.3630, 11.5334],
        [16.2840, 26.3747, 37.8118, 49.3764, 35.9269, 27.8203, 23.4269],
        [14.8606, 24.3556, 37.9932, 30.5892, 20.5418, 19.5160, 19.4240],
        [15.8022, 11.3095, 15.6321, 35.6718, 48.6395,  4.6627, 10.9269],
        [18.8776, 28.0754,  0.0000, 11.0686, 10.8560,  8.6310, 11.9331],
        [50.9600, 4

batch_predictions
tensor([[17.0552, 22.2864, 33.0525, 35.5551, 26.5957, 18.6411, 18.3550],
        [ 6.2285, 11.4456, 13.3415, 14.6040, 17.1297, 22.9265, 24.6616],
        [26.6235, 18.8702, 15.8897, 16.2574, 21.1891, 31.3499, 32.1629],
        [ 7.6731,  9.4495, 10.6344, 11.0947, 14.4417, 13.7040,  7.4082],
        [37.6844, 43.3867, 35.3866, 27.5025, 25.7975, 23.1245, 30.1014],
        [16.1327, 19.8660, 20.0280,  8.2067, 11.4350, 11.3332, 11.6577],
        [14.3732, 13.6894, 15.6796, 15.4826, 14.9771, 23.0583, 21.8181],
        [24.3045, 25.4706, 22.9166, 23.7043, 30.5602, 45.6779, 38.0916]])
ground truth
tensor([[15.9580, 21.9813, 37.7409, 48.0584,  0.0000, 18.5799, 17.4461],
        [ 9.1136, 12.0331, 12.0726,  0.9337,  7.1936,  0.0000,  0.0000],
        [20.4892, 16.7675, 19.1347, 24.5529, 30.2998, 58.7849, 61.6649],
        [ 8.5481,  8.1010, 12.9669, 12.7564, 16.8332,  3.9979,  7.8906],
        [ 6.9174, 30.0368, 14.1899, 95.9627, 48.9479, 34.7843,  7.7722],
        [18.6349, 2

batch_predictions
tensor([[14.5651, 15.2252, 18.2049, 21.5621, 23.4782, 10.3239, 15.2707],
        [12.6820, 12.5565, 14.8416, 18.7345, 17.8107, 13.3114, 13.2369],
        [23.8201, 14.3934, 14.5672, 13.0227, 14.3113, 17.3733, 27.2743],
        [12.1440, 15.7783, 15.0729,  9.3250,  8.2366,  9.5460,  9.6434],
        [15.0865, 17.6568, 17.6842, 12.5146, 10.8172,  9.9390, 11.6260],
        [12.2825, 13.8792, 14.2685, 15.0661, 23.2888, 24.5839, 10.1012],
        [10.8574, 14.4077, 18.1631, 22.4566, 11.4300, 11.3253,  9.5739],
        [13.9023, 15.2941, 16.6855, 23.2676, 31.4552, 23.2642, 10.9036]])
ground truth
tensor([[16.3407, 16.2557, 24.3056, 29.3367, 26.0488,  8.5884, 19.7279],
        [10.6128, 10.9811, 13.3219, 20.4498, 21.1336, 14.9790,  8.2062],
        [ 0.0000, 21.2651,  6.1941,  0.0000,  0.0000,  0.0000,  0.0000],
        [13.5981, 16.7412, 18.2667,  7.3514, 10.3498, 16.0179, 21.5281],
        [ 8.3114, 17.2804,  3.9190,  0.0000,  8.9032,  5.4971,  8.9032],
        [ 8.7868, 1

batch_predictions
tensor([[27.3675, 22.8103, 25.6917, 35.7483, 45.8298, 39.7295, 27.6290],
        [13.7229, 15.1888, 14.6218, 15.5194, 20.9599, 28.1104, 19.1573],
        [31.6615, 28.0683, 16.8880, 14.0463, 15.6819, 17.1739, 23.9255],
        [14.4936, 15.7662, 19.6541, 20.3689, 17.6432, 13.3565, 12.5942],
        [ 9.5157, 13.0906, 13.2511, 15.0840, 12.0771,  8.7004,  9.2967],
        [16.5572, 16.4334, 17.9585, 21.3030, 25.9388, 14.6506, 16.0255],
        [15.8839, 22.4600, 21.7517, 11.2781, 12.7070, 11.9354, 11.0234],
        [27.9945, 14.9984, 16.5466, 16.3155, 17.1611, 25.8067, 32.2850]])
ground truth
tensor([[27.3101, 20.0397, 21.9246, 32.5822, 53.8690, 43.7075, 29.4076],
        [13.3482, 12.0989, 12.4934, 17.2278, 23.7638, 37.7038,  5.8259],
        [64.6825, 50.6944, 28.5998, 23.4836, 20.2948, 18.5799, 19.8696],
        [11.5860, 11.8885, 20.0947, 23.1326,  9.1662, 12.4934, 15.8206],
        [ 7.8380, 12.9800, 14.7554, 19.5292, 19.3319,  9.8895,  9.1794],
        [11.7833, 1

tensor([[37.8636, 18.6958, 20.1223, 19.3295, 17.6990, 29.6230, 40.9834],
        [38.2110, 30.8177, 13.1313, 17.5562, 17.0027, 18.2202, 25.8269],
        [28.8605, 23.2219, 20.6263, 28.6445, 38.3599, 48.1709, 39.1468],
        [13.3858, 13.4098, 19.5389, 18.0168, 19.1619, 12.1585, 13.8529],
        [12.6914, 13.5293, 12.4513, 12.6034, 14.9924, 20.3563, 17.2896],
        [16.4430, 17.4300, 25.7601, 33.8079, 27.5357, 17.6427, 17.8966],
        [24.0201, 39.0843, 37.1858, 23.8616, 21.7873, 20.1125, 19.8874],
        [16.2920, 19.1832, 24.7328, 23.9641, 14.9380, 16.0239, 13.8750]])
ground truth
tensor([[35.4156, 15.3998, 15.4392, 17.6092, 20.5944, 36.9148, 48.3956],
        [54.3226, 39.6259, 15.4620, 20.3373, 19.5295, 15.4337, 26.8707],
        [29.6627, 26.7857, 21.7971, 29.9178, 38.3078, 55.2012, 42.5595],
        [14.2999, 16.1848, 20.1389, 25.1134, 21.9529, 11.7063, 16.5249],
        [11.1678, 13.7755, 11.1536, 12.5992, 20.0113, 28.6706, 21.8396],
        [13.8605, 14.9802, 31.8736, 4

batch_predictions
tensor([[14.2982, 14.5705, 12.3202, 12.1082, 11.4151, 11.4485, 12.4115],
        [ 8.4116,  9.6967, 12.6899, 14.7368, 15.2179,  6.9307,  8.2209],
        [ 7.1714,  8.5087, 10.0887, 12.4102, 14.3389, 15.5497,  6.8221],
        [27.4723, 31.9808, 23.4527, 16.0859, 17.0869, 17.2220, 20.4354],
        [11.9690, 13.3551, 17.0670, 16.0420,  6.9826, 12.9898, 11.2724],
        [18.3193,  9.0808, 10.5935, 10.5709, 11.6973, 13.9771, 22.9519],
        [19.1731, 20.2665, 21.2967, 22.1310, 11.6123, 11.0014, 18.3268],
        [11.3859, 14.1576, 13.2358, 13.2534, 16.3192, 21.5285, 17.7892]])
ground truth
tensor([[13.1773, 21.3440, 29.1820, 24.7633, 15.6365, 17.3987, 16.2809],
        [ 7.3908,  9.8895, 11.7175, 12.6644, 14.7291,  6.1941,  7.3382],
        [ 7.3382,  6.9569,  6.3388, 10.9285, 13.8743,  0.0000,  7.5092],
        [32.6956, 41.4824,  0.0000, 13.6480, 21.4286, 17.8430, 17.5595],
        [12.0331, 18.6349, 21.7517, 19.3714,  6.9043, 23.4219, 24.4477],
        [17.5829,  

batch_predictions
tensor([[11.9736, 14.0108, 17.3286, 21.4947, 23.3585, 12.6514, 13.7880],
        [48.0642, 36.2285, 24.5328, 25.9241, 24.0740, 26.3959, 35.1417],
        [12.1474, 12.4340, 13.0231, 17.1883, 20.1102, 10.9374, 12.6963],
        [15.3735, 17.1744, 26.2848, 26.7001, 11.7166, 13.6374, 13.1513],
        [29.2932, 29.1543, 19.8289, 21.4606, 20.4101, 20.7924, 22.7039],
        [28.8832, 25.3763, 20.9731, 21.9887, 29.7675, 44.7420, 37.5361],
        [10.6828, 15.0156, 14.2832, 14.5163, 15.3963, 19.9784, 24.5275],
        [33.9778, 13.9880, 17.0452, 16.4057, 17.3162, 31.8796, 46.6439]])
ground truth
tensor([[11.3946, 13.1661, 20.4649, 24.9291, 24.9291, 14.4700, 12.9252],
        [55.2012, 42.5595, 29.3226, 24.6882, 18.6650, 28.0045, 43.8776],
        [10.1920, 13.6902, 15.1894, 19.0558, 28.3140, 11.6123, 13.1641],
        [19.9688, 21.0743, 40.2069, 41.9359, 14.9235, 17.8430,  8.9711],
        [32.5750, 27.5250, 21.3572, 22.7512, 17.8722, 21.7780, 18.1352],
        [29.4076, 2

batch_predictions
tensor([[27.7582, 25.2512, 26.6226, 36.5161, 47.5510, 43.4058, 28.8623],
        [ 5.4370,  6.7720,  9.1819,  9.5559,  6.9655,  5.1310,  4.9623],
        [14.2099, 14.7142, 15.9776, 21.3774, 26.7523, 18.2872, 14.5643],
        [18.3857, 21.2629, 22.4827, 28.4010, 32.2790, 27.1622, 20.3756],
        [13.3658, 16.3912, 16.0553, 16.7542, 17.2308, 25.1677, 22.0573],
        [35.5176, 41.9122, 28.3620, 19.9204, 22.4342, 23.2256, 25.3638],
        [19.5434, 27.9920, 24.8042, 13.9878, 15.3099, 15.3794, 16.4381],
        [50.1164, 44.0526, 20.3013, 22.6872, 21.9444, 25.2540, 36.2097]])
ground truth
tensor([[28.1888, 23.8520, 27.3526, 38.0102, 56.3350, 52.3526, 33.9711],
        [ 5.7256,  7.5822,  9.3963, 10.0482,  5.9099,  6.3917,  5.6973],
        [10.4308, 12.1032, 12.3583, 19.9972, 33.8577, 21.4427, 15.1786],
        [ 4.8264, 18.5297, 20.3577, 26.7228, 32.6144, 42.1883, 31.0363],
        [16.4124, 13.7954, 11.9279, 16.4256, 13.4929, 25.2104, 39.7159],
        [39.9376, 4

batch_predictions
tensor([[17.1640, 17.1040, 21.8888, 28.9673, 30.7524, 15.7411, 18.6778],
        [13.2611, 14.8818, 18.7144, 21.5426, 12.1404, 13.7338, 12.2305],
        [13.8326, 16.0120, 18.2740, 24.6966, 25.4048, 13.6697, 16.4169],
        [10.9503,  9.5730, 10.9745, 12.4078, 10.9325, 10.0970,  9.5166],
        [25.1613, 26.2226, 15.4472, 16.7098, 15.0620, 15.0565, 20.2049],
        [18.2394, 12.1148, 12.8947, 13.2154, 14.1920, 19.3799, 22.4384],
        [25.7920, 20.2306, 17.3635, 16.7534, 16.8412, 19.9672, 25.7369],
        [20.6951, 22.1323, 13.6067, 14.5002, 13.1368, 14.7324, 13.5297]])
ground truth
tensor([[17.8430, 18.2540, 21.7829, 28.0045, 33.8861, 13.5913, 11.1678],
        [16.2415, 13.8889, 20.7058, 25.7795, 23.1151, 15.1077, 16.7234],
        [12.6512, 16.2546, 19.8054, 32.2462, 28.5376,  8.3377, 22.2251],
        [13.2693, 12.6249,  2.8801,  0.0000,  8.1142, 21.9884, 12.3619],
        [25.5527, 23.3277, 11.4796, 14.0023, 10.2891, 15.0794, 18.6224],
        [ 9.1662, 1

batch_predictions
tensor([[32.8242, 24.3258, 14.4070, 14.8892, 14.1643, 15.5564, 24.5695],
        [20.7653, 32.0595, 44.0068, 22.4150, 16.5296, 19.4545, 19.5158],
        [21.3769, 24.2671, 30.2234, 26.4753, 14.0088, 16.9174, 18.1954],
        [14.3328,  6.2078,  8.3485,  8.4724,  9.3816, 11.5159, 16.6156],
        [13.4404, 14.7017, 15.7954, 18.6088, 27.5531, 22.9724, 11.2737],
        [36.3229, 37.4715, 24.6525, 24.5557, 22.1610, 22.1203, 25.3498],
        [16.3930, 15.5861, 16.0627, 19.2018, 21.2049, 17.3970, 16.8998],
        [25.1738, 18.4219, 12.5068, 13.6431, 13.2269, 14.6653, 22.6182]])
ground truth
tensor([[32.7098, 25.5952, 15.8163, 12.7268, 10.8135, 13.1661, 25.0425],
        [20.0947, 26.4861, 16.4782, 50.8943, 16.2678, 22.8564,  1.6965],
        [24.6032, 27.2676, 33.5601, 37.8260,  9.6939,  6.8027, 24.0363],
        [13.2562,  4.9316,  7.4040,  6.9306, 10.2315, 12.1778, 17.0174],
        [12.0857, 13.2036, 15.8075, 23.5928, 37.7564, 28.7086, 12.5723],
        [48.3167, 4

batch_predictions
tensor([[15.3342, 17.2486, 16.0861, 16.2821, 22.2560, 36.1172, 29.3138],
        [14.6387, 16.7064, 15.0637, 18.0870, 18.7301, 26.3872, 19.4491],
        [12.6640, 19.5679, 22.5763, 20.0635, 12.5167, 14.5105, 13.0222],
        [22.5452, 19.8259, 20.7468, 32.8088, 43.9984, 39.7685, 22.6002],
        [15.5132, 13.9456, 15.7387, 21.4997, 27.3750, 19.7275, 15.8843],
        [14.7105, 20.5715, 25.1633, 12.2458, 14.2343, 13.9242, 13.5696],
        [23.7160, 21.7234, 11.8429, 14.3143, 13.9096, 14.5076, 17.6973],
        [ 9.7919, 12.1682, 11.8859, 12.3335, 13.8217, 22.3655, 17.2751]])
ground truth
tensor([[17.8571, 18.9626, 20.6491, 14.4558, 22.8883, 42.6304,  0.0000],
        [12.0040, 14.9802, 13.8039, 20.1247, 18.1548, 31.1083, 20.4507],
        [12.3158, 24.6032, 25.1559,  7.8515,  0.0000,  5.8107, 14.0448],
        [25.9337, 20.9890, 22.6197, 49.8027, 67.8722,  0.0000,  0.0000],
        [10.8277, 11.6638, 14.9093, 20.6633, 39.7817, 24.5748, 14.0306],
        [ 8.8294, 2

batch_predictions
tensor([[17.1064, 16.6515, 17.3763, 23.4293, 32.6766, 21.8736, 15.6483],
        [31.4643, 42.4149, 34.3940, 24.9718, 24.4373, 22.0844, 23.7588],
        [15.2179, 16.2088, 16.2179, 20.5433, 27.9722, 27.8322, 18.1926],
        [14.5283, 15.6202, 17.2816, 21.8575, 23.6897, 13.4198, 16.6637],
        [21.7714, 20.7439, 26.1289, 24.7996, 19.5009, 20.9833, 19.6663],
        [26.7153, 22.5277, 18.3729, 16.9263, 17.5414, 20.7090, 28.1931],
        [28.3048, 28.3241, 35.9451, 45.4118, 43.3247, 27.6397, 30.6997],
        [29.9142, 12.0398, 15.9535, 15.4531, 16.0512, 18.6277, 30.4798]])
ground truth
tensor([[18.9637, 13.1378, 17.4776, 26.4335, 33.3377,  4.0637,  0.0000],
        [36.5363, 50.4960, 14.3707, 28.8974, 26.6015, 17.1627, 28.8265],
        [15.8588, 12.2591, 13.6763, 28.2880, 26.6582, 24.3764, 17.4745],
        [11.9542, 16.6097, 18.0563, 27.9458, 30.6418, 12.8090, 19.5160],
        [19.1741, 23.6586, 34.0610, 33.1010, 20.6470, 26.8148, 21.7649],
        [25.7937,  

tensor([[11.0782, 18.1450, 16.8017, 16.3568, 17.9416, 30.4635, 33.6182],
        [ 8.5464,  9.3256, 11.4210, 13.0379, 13.8506,  7.0203,  7.9198],
        [33.3855, 23.8558, 18.8587, 15.6667, 16.5523, 25.0632, 39.5965],
        [14.6034, 15.0982, 14.8311, 17.1672, 25.7020, 26.3800, 11.6304],
        [11.0270, 10.5690, 12.3124, 13.7767, 20.8499, 29.5764, 10.4267],
        [35.2434, 22.0170, 21.9022, 20.9365, 21.4598, 30.4164, 41.9305],
        [13.1951, 14.0490, 14.5525, 15.5611, 20.1815, 26.1943, 13.9254],
        [27.7967, 14.7429, 16.1822, 15.8312, 16.0029, 17.8638, 27.1255]])
ground truth
tensor([[ 7.4405, 22.9875, 17.8713, 21.3861, 22.2080, 32.3413, 34.5947],
        [ 0.1447, 15.7154, 13.1510, 12.0331, 11.4413,  3.9321,  8.5350],
        [35.9269, 23.8662, 15.6179, 13.9881, 21.3719, 31.0799, 42.0068],
        [15.9390, 15.2814, 18.0037,  1.3677, 42.6223, 28.2877,  8.2851],
        [18.4666, 12.2591, 13.9314, 14.2007, 20.4790, 44.7421, 29.8611],
        [40.2420, 22.5934, 20.3577, 2

batch_predictions
tensor([[20.8779, 21.9054, 15.2683, 14.3694, 14.9050, 15.7294, 18.5703],
        [14.2430, 14.6053, 15.2947, 18.8809, 21.9533, 20.9830, 11.3346],
        [11.2750, 13.3595,  9.6376,  9.4352, 10.1253, 11.5091, 10.4014],
        [12.1140,  9.5351, 10.2903, 13.0439, 16.9731, 10.1679, 10.0560],
        [13.4081, 15.1439, 19.6098, 20.2948, 11.5406, 12.3356, 13.1568],
        [27.8548, 37.8279, 45.2079, 37.8446, 26.2041, 23.0348, 21.8250],
        [16.8221, 21.8171, 28.2673, 21.3991, 13.7289, 15.0494, 13.2069],
        [11.3011, 12.8723, 14.8902, 16.9130, 21.6677, 23.2333,  8.3693]])
ground truth
tensor([[23.2001,  9.7931,  1.3747,  0.9637,  6.7319, 16.8367, 17.8997],
        [13.7897, 10.8985, 14.9660, 18.0130, 21.0601,  5.1871, 13.2937],
        [ 2.8801,  0.0000,  8.1142, 21.9884, 12.3619,  3.8532, 11.0863],
        [ 9.9816, 12.0068, 20.3709, 22.6065,  6.4308, 12.0726,  8.7980],
        [15.0053, 16.3467, 21.5939, 24.6975, 10.6128, 15.3077, 14.2031],
        [30.9524, 3

batch_predictions
tensor([[10.4896, 11.8758, 13.6951, 13.3287, 14.5257, 12.2206, 10.8245],
        [31.3884, 18.8171, 17.0547, 16.2286, 16.5317, 24.8706, 38.5906],
        [23.5610, 17.5528, 14.4659, 14.4808, 15.0836, 19.0452, 33.1666],
        [15.9479, 15.6155, 16.9650, 22.5699, 25.8585, 17.3078, 16.8638],
        [15.3004, 13.4068, 13.4137, 14.8647, 11.9171,  9.8313, 12.1244],
        [18.0519, 10.1714, 13.9030, 14.2533, 16.7076, 21.9802, 29.4961],
        [13.3770, 17.4310, 15.5692, 15.9252, 17.3734, 23.0958, 25.4214],
        [12.0393, 13.8726, 13.0689, 13.5751, 19.1779, 25.2935, 21.8161]])
ground truth
tensor([[10.3498,  7.6013, 11.7438, 11.7307, 15.0579, 15.2814,  4.3793],
        [30.9807, 18.5374, 15.9722, 14.1015, 19.5578, 28.0471, 51.9558],
        [38.1519, 22.1230, 15.2636, 14.8668, 22.1797, 28.5998, 41.0289],
        [15.5971, 14.9790, 14.7948, 23.5139,  0.0000, 11.3493, 12.7959],
        [16.3832, 10.5442, 13.3078, 17.2477,  4.7477,  7.3980,  7.3271],
        [22.2364,  

batch_predictions
tensor([[27.7300, 32.0410, 24.1324, 12.0400, 14.1581, 13.7369, 16.9302],
        [14.0485, 18.4496, 15.8463, 17.8490, 13.1320, 12.2239, 13.0952],
        [28.0689, 15.6338, 16.7921, 17.6675, 16.8077, 20.6497, 26.2125],
        [24.8646, 36.2372, 34.6498, 19.5073, 19.1945, 17.9463, 19.0826],
        [ 9.3478,  8.7724, 10.2013, 11.2764, 12.9960, 15.6557, 17.0924],
        [35.4931, 37.8030, 18.3680, 19.7034, 20.9813, 22.9776, 32.4930],
        [12.2751, 14.8481, 18.5256, 18.0885, 11.4775, 13.3748, 10.1587],
        [19.0101, 21.4682, 20.4655, 14.8571, 14.8852, 11.4157, 13.6010]])
ground truth
tensor([[35.6718, 48.6395,  4.6627, 10.9269, 18.8634, 14.4274, 18.2540],
        [17.3593, 16.4782, 21.4361, 21.9884, 18.8059, 17.6355, 13.1510],
        [ 0.0000, 27.4660, 25.9779,  5.6973, 19.1043, 34.9490, 50.5811],
        [23.6820, 36.4654, 40.8163, 25.9921,  3.0754, 16.4399, 19.2885],
        [ 7.3514, 10.3498, 16.0179, 21.5281,  2.9721,  4.4582,  7.6670],
        [ 3.9966,  

batch_predictions
tensor([[20.9213, 21.9485, 20.8201, 22.5126, 31.3362, 39.8419, 22.7951],
        [14.8839, 17.0923, 21.2952,  9.6316, 10.3323, 10.2895, 10.3485],
        [41.7922, 49.5694, 41.2695, 34.8815, 34.9696, 34.2382, 37.8390],
        [18.6752, 13.2142, 11.3024, 12.0761, 14.0700, 16.1583, 20.6105],
        [18.7127, 19.3315, 18.4834, 28.7029, 38.0269, 35.4167, 13.1202],
        [29.3444, 38.7985, 27.4577, 18.6955, 16.8586, 15.8414, 19.5273],
        [10.1167, 12.4941, 17.7028,  9.5078,  6.0388,  7.0327,  7.7723],
        [20.8968, 14.8046, 13.8218, 13.2132, 13.9830, 15.6272, 20.6699]])
ground truth
tensor([[58.9427, 18.7358, 24.0646, 22.8316, 35.5300, 50.1134, 39.6400],
        [16.0714, 21.0459, 27.2676,  4.6627, 11.9048, 11.2528,  7.5397],
        [14.1440, 11.0402, 25.2409, 19.1185, 28.7557, 29.0816, 54.1100],
        [17.1489, 10.3235,  9.3109,  1.3677, 10.4550, 13.2430,  8.1010],
        [14.3282, 22.7749, 19.5578, 29.3793, 45.5924, 36.9473, 17.4461],
        [30.6973, 4

batch_predictions
tensor([[19.8741, 21.4868, 22.1004, 11.0940, 12.6774, 17.9190, 16.7341],
        [18.1356, 22.9792, 19.9863, 12.7233, 14.5480, 13.1769, 12.7858],
        [15.4201, 17.8397, 25.5224, 22.4341, 14.7469, 13.9047, 13.3322],
        [15.1705, 16.3885, 27.3790, 33.6336, 17.6220, 14.4478, 16.6880],
        [23.6830, 25.4593, 12.8732, 16.1717, 15.0327, 15.8630, 18.4946],
        [24.2426, 10.7660, 11.9551, 15.9327, 16.2856, 18.8513, 23.0189],
        [21.7501, 30.2037, 41.2246, 27.8554, 23.4662, 21.1830, 19.4446],
        [18.8882, 22.5300, 20.9795, 11.8662, 15.2019, 14.7785, 14.3512]])
ground truth
tensor([[18.5941, 23.0442, 23.8237,  8.4184,  4.7619, 22.4065, 17.1060],
        [16.0573, 19.5862,  0.0000,  0.0000, 14.5833, 13.7755, 15.4337],
        [16.3861, 20.2525, 22.4487, 27.7223, 15.2551, 15.7154, 13.7954],
        [14.4274, 18.2540, 30.5414, 43.9626, 24.4898, 10.9269, 16.5108],
        [25.1710, 27.8538, 13.8085, 14.3872, 13.0721, 13.2167, 16.1494],
        [23.8237,  

batch_predictions
tensor([[20.2807, 26.0638, 14.6899, 13.9928, 13.8878, 14.2397, 16.0883],
        [18.3238, 19.9776, 19.6281,  9.5716, 12.6911, 14.0739, 16.6588],
        [15.6737, 21.2472, 21.5319, 17.6987, 13.1274, 13.6132, 13.7243],
        [18.3823, 24.3527, 19.1455, 14.6825, 15.3460, 14.0934, 16.3505],
        [28.8711, 37.9006, 44.5499, 41.2404, 38.0477, 32.3817, 32.3193],
        [17.6922, 23.1766, 22.8900, 10.4451, 15.4174, 13.6326, 12.7059],
        [19.9969, 28.7624, 36.9194, 19.1278, 19.9393, 19.6433, 17.0390],
        [13.9701, 16.2372, 15.2379, 15.8970, 22.7948, 30.8809, 23.4424]])
ground truth
tensor([[23.9216, 31.0889, 14.5844, 16.2678, 13.9532, 14.7685, 17.4908],
        [ 5.8259, 17.6618, 27.6039, 15.4129, 10.9153, 22.0016, 21.0679],
        [14.9518, 21.1310, 25.7937,  5.6406, 13.1519, 13.2086, 17.0068],
        [19.1478, 27.2225, 23.9611, 17.8196, 18.6481, 12.3882, 20.3840],
        [41.2941, 44.1873, 60.3893, 58.3772, 29.4056, 38.0195, 42.2015],
        [16.8858, 2

batch_predictions
tensor([[19.5872, 23.4857, 28.6086, 21.2364, 20.6002, 22.5178, 21.8027],
        [14.9280, 21.0572, 19.6588, 10.7257, 11.6822, 11.8095, 11.4686],
        [15.4233, 18.2291, 14.8185,  9.1162, 10.6412,  9.7983, 10.1370],
        [17.1065, 21.0849, 19.0297, 14.7476, 13.9766, 12.3753, 14.2189],
        [11.3425, 11.6540, 14.0106, 14.1655, 15.4526, 19.1578, 14.1436],
        [10.3282, 11.3423, 11.0337,  9.7473,  9.8774,  9.0123, 11.0396],
        [20.9225, 20.3325, 23.3988, 25.6299, 31.4514, 24.8920, 22.4276],
        [26.3156, 25.9730, 27.0214, 18.9370, 18.8392, 20.7589, 23.5965]])
ground truth
tensor([[23.5402, 24.0788, 30.6689,  2.2534,  0.0000, 14.4274, 28.4864],
        [13.7033, 25.6970, 30.0894,  6.7596, 11.7701, 12.9669, 10.5997],
        [15.2288, 20.6207, 16.1231,  5.3919, 11.0468,  7.1936,  9.5082],
        [15.3209, 17.6223,  0.0000,  9.9027, 10.8759, 10.1789, 10.4419],
        [15.1762, 12.8222, 10.4945, 14.1636, 12.9800, 22.3567,  6.0494],
        [ 9.8356,  

batch_predictions
tensor([[15.0263, 15.9235, 15.6016, 19.2059, 23.3454, 23.2742, 13.1639],
        [10.3304, 10.1984, 10.0257, 12.3712, 12.3292, 14.7251, 14.0608],
        [15.7259, 15.8038, 17.8701, 20.7607, 27.8284, 25.5320, 14.3754],
        [10.7177, 13.1438, 13.6616, 14.9559,  9.7497,  8.1223,  7.1169],
        [19.1674, 18.0112, 21.3934, 28.3475, 30.9114, 19.6841, 18.3630],
        [37.8926, 37.7414, 38.9509, 41.2999, 44.1552, 48.7154, 45.5402],
        [13.3469, 14.2774, 16.7313, 19.3916, 24.6318, 14.5519, 13.2865],
        [34.3841, 44.9044, 36.1063, 24.7150, 24.0312, 21.7671, 21.9020]])
ground truth
tensor([[15.7154, 12.4145, 15.3735, 18.1878, 28.6691, 29.7607, 15.3077],
        [ 4.3793, 15.0316,  9.8106, 13.6902, 14.5055, 20.3577, 14.2557],
        [32.4698, 31.3651, 40.7812,  4.6949,  3.0773, 13.1247, 11.5203],
        [13.6902, 11.2835, 22.1988,  0.0000, 11.1652,  8.3246,  6.9174],
        [17.6881, 20.5944, 30.6418, 42.5302,  9.6791, 11.7175, 13.1641],
        [27.5776, 5

batch_predictions
tensor([[31.9503, 15.1886, 18.1876, 17.2787, 18.4548, 28.2162, 41.8877],
        [12.8999, 21.8574, 11.4190, 11.2122, 10.3675, 10.1081, 11.4605],
        [26.8008, 26.9868, 30.5299, 36.4172, 35.9802, 29.8713, 28.9463],
        [15.3563, 11.9079, 12.1428, 14.2532, 14.4733, 15.8922, 20.0121],
        [15.0549, 15.8662, 15.4060, 20.4568, 25.6281, 21.8012, 11.4810],
        [18.6087, 23.2078, 24.1066, 12.8562, 15.5641, 14.2034, 15.0089],
        [16.9868, 22.0503, 22.9629, 12.2461, 14.2492, 13.3189, 14.0879],
        [15.6074, 14.0442,  6.0782,  7.6441,  8.1744, 10.7079, 10.9022]])
ground truth
tensor([[35.9552, 12.8401, 23.9229, 22.0947, 22.6757, 29.2517, 41.8651],
        [21.8569, 26.1047, 10.5339, 13.6639, 11.2967, 12.0200, 13.7822],
        [23.8822, 27.3935, 33.4035, 35.9679, 42.3593, 24.8816, 34.2188],
        [13.9269, 15.1762, 12.8222, 10.4945, 14.1636, 12.9800, 22.3567],
        [13.0385, 13.9314, 12.4717, 19.7421, 25.8220,  0.0000, 12.3866],
        [17.7801, 2

batch_predictions
tensor([[20.7837, 20.1084, 19.2907, 20.5158, 29.4878, 30.3323, 18.2432],
        [15.9828, 14.1838, 14.6946, 16.1079, 19.6646, 20.2610, 13.9430],
        [15.0400, 21.1482, 18.9512, 10.5625, 12.4685, 10.8497, 10.6980],
        [19.1373, 12.2312, 12.7052, 12.0188, 13.1695, 16.3364, 19.2764],
        [ 7.1201, 11.4382, 11.4119, 11.5907, 14.0582, 18.7282, 10.8378],
        [28.5497, 19.3121, 16.9728, 15.8413, 16.6286, 20.4280, 29.6470],
        [16.0953, 17.5915, 27.0973, 33.7886, 21.8448, 17.3000, 17.1619],
        [21.4286, 12.2691, 13.1600, 13.1114, 14.0620, 19.2420, 28.2208]])
ground truth
tensor([[25.3156, 20.0158, 19.3582, 17.0174, 29.2478, 34.3635, 20.3446],
        [13.3645, 12.8968, 13.9031, 19.8838, 22.1797,  0.0000, 13.2795],
        [15.7812, 30.2867, 21.1336, 11.7833, 13.9400, 15.8206, 13.7165],
        [14.4135, 10.0736, 11.1257, 10.2183, 13.5850, 16.1494,  8.3772],
        [ 7.5355, 15.3209, 13.6113, 14.8869, 22.5276, 28.7743, 13.2036],
        [26.4881, 1

batch_predictions
tensor([[34.0285, 19.6554, 19.1300, 17.9401, 19.2429, 31.2934, 44.2885],
        [25.2814, 39.6877, 37.8875, 16.3580, 18.9799, 19.6260, 20.0992],
        [15.9384, 17.0920, 27.2716, 30.8991, 24.8384, 12.9720, 15.5552],
        [19.8309, 30.4774, 28.1892, 16.1487, 15.3574, 14.8883, 16.4124],
        [18.1776, 16.7073, 17.1364, 20.6870, 29.9389, 27.1165, 19.0286],
        [16.0839, 11.4846, 13.0753, 13.2795, 15.7554, 22.7394, 24.5394],
        [19.7736,  9.7059, 11.5459, 12.1001, 12.7678, 18.1562, 20.7852],
        [26.4515, 38.0807, 32.6134, 23.8624, 22.3452, 21.5553, 20.3570]])
ground truth
tensor([[45.3940, 33.0074, 30.0028, 15.8588, 23.7245, 32.8798, 57.1429],
        [25.8943, 54.7870, 53.8269, 30.2735, 25.1973, 16.3204, 23.6718],
        [11.3095, 15.6321, 35.6718, 48.6395,  4.6627, 10.9269, 18.8634],
        [23.9611, 26.0784, 35.5471, 15.2946, 13.5324, 11.9148, 14.3346],
        [18.5232, 14.5125, 19.8696, 20.9325, 27.5652, 27.3810, 18.9059],
        [15.7596, 1

batch_predictions
tensor([[26.4211, 26.2849, 13.3406, 15.5632, 15.1741, 14.9488, 18.4537],
        [33.7413, 40.3579, 45.2961, 42.2201, 35.7061, 31.0769, 29.3679],
        [30.0642, 30.1993, 13.8652, 16.6262, 16.4757, 15.6001, 20.7315],
        [ 7.1329,  5.7377,  6.1887,  6.9079,  7.5362,  9.7018, 12.0452],
        [20.0042, 23.1007, 21.8486, 12.5621, 14.5729, 13.3009, 13.1016],
        [23.2726, 25.4276, 10.4238, 14.1169, 14.1158, 14.5229, 17.1904],
        [37.7248, 31.8146, 12.8982, 17.0371, 16.6920, 16.9918, 27.3497],
        [22.7582, 21.6911, 23.4400, 19.7992, 12.8179, 13.5100, 16.5800]])
ground truth
tensor([[33.8861, 30.5414, 13.1803, 13.8464, 14.2715, 17.5170, 20.3231],
        [36.7772, 46.7829, 62.9252, 50.5811, 43.3248, 36.7489, 37.6134],
        [27.4235, 31.8169, 13.6621, 14.4274, 13.5771, 12.5709, 14.6400],
        [ 8.1207,  7.6956,  4.9461,  7.3554,  8.2625, 11.6213, 13.8322],
        [24.6032, 25.1559,  7.8515,  0.0000,  5.8107, 14.0448, 12.4858],
        [31.4626, 3

batch_predictions
tensor([[23.0384, 23.5779, 10.9166, 12.5281, 12.0982, 11.6577, 16.8998],
        [12.0345, 13.8717, 13.9229, 18.5397, 22.3124, 23.3360, 12.8313],
        [26.4147, 36.5647, 25.3375, 18.6234, 18.0166, 16.8300, 20.7794],
        [23.2594, 23.1882, 15.3656, 15.0088, 11.9977, 12.7585, 17.7288],
        [16.1641, 14.7229, 12.7743, 12.6487, 14.4344, 19.6527, 19.6648],
        [26.9932, 11.4461, 15.1791, 15.6714, 15.6804, 17.3832, 25.0277],
        [17.4433, 23.0617, 26.7463, 17.9713, 10.7327, 13.6274, 12.7701],
        [ 6.8910,  8.6554, 11.0878, 12.1657, 16.0957, 13.2014,  4.9138]])
ground truth
tensor([[23.1063, 24.6449,  8.7717, 11.7701, 10.8496, 10.1262, 17.3593],
        [13.9137, 13.4534, 11.8490, 18.2930, 18.1352, 26.3940, 17.6486],
        [23.0274, 36.4019, 24.4477, 17.8590, 15.7812, 13.0326, 16.3993],
        [26.0346, 31.4484, 15.9014, 11.3946, 13.5346, 14.0448, 16.6383],
        [14.2951, 14.3872, 10.6128, 10.9811, 13.3219, 20.4498, 21.1336],
        [28.3798,  

batch_predictions
tensor([[12.1393, 18.7740, 19.9061, 10.3726, 11.4592, 12.5445, 11.6148],
        [10.7056, 14.9254, 15.8362, 15.9042, 16.7615, 25.0890, 25.6330],
        [15.2625, 15.5380, 14.3384, 15.3436, 18.5243, 24.3909, 22.5341],
        [24.9262, 24.4176, 24.6891, 23.4115, 26.3717, 37.7846, 36.9235],
        [16.2796, 17.6277, 20.4696, 26.7745, 28.3597, 14.4965, 15.8004],
        [24.3358, 32.2078, 42.5634, 35.9994, 28.2484, 26.0057, 22.7246],
        [13.5603, 14.9334, 16.6196, 20.9957, 27.6052, 22.2614,  9.1174],
        [15.3566, 16.3572, 16.4275, 18.7720, 25.6032, 28.1571, 17.6122]])
ground truth
tensor([[12.1599, 17.0210, 24.4898, 21.7545, 10.8844, 14.2007, 14.7109],
        [19.1468, 25.1134, 16.5958, 21.9246, 25.9495, 39.5408, 42.9280],
        [18.5166, 16.1757,  6.0100, 17.8327, 18.3851, 26.8411, 40.0710],
        [29.1952, 31.5755, 19.5029, 22.7380, 22.9090, 42.1883, 38.9269],
        [14.9235, 14.0023, 23.4694, 29.5493, 34.9348, 13.6763, 14.3141],
        [ 6.1508, 1

batch_predictions
tensor([[32.2205, 40.3761, 26.1842, 17.2318, 16.7934, 15.6551, 22.4296],
        [24.5949, 23.2506, 19.2150, 19.8152, 34.3067, 45.8523, 43.4280],
        [37.9871, 29.7345, 14.9248, 17.1499, 17.0829, 17.7049, 29.8775],
        [20.2965, 23.8542, 12.8338, 13.5819, 13.7385, 14.2903, 14.9790],
        [12.2773, 13.7928, 14.5881, 17.4134, 20.7387, 22.6856,  8.6803],
        [18.6540, 20.5566, 24.9755, 36.3830, 27.0415, 20.8925, 18.5267],
        [22.5750, 14.7070, 14.5538, 14.0614, 15.6625, 21.5133, 30.1055],
        [16.8231, 19.3992, 25.0481, 17.4272, 14.4730, 14.8288, 13.3353]])
ground truth
tensor([[31.9444, 43.2115,  7.3696, 18.9909, 17.7296, 13.0385, 20.3798],
        [50.9732, 21.3440, 36.9674, 33.6007, 45.2262, 51.3809, 44.7396],
        [44.7988, 36.7205, 16.6667, 19.7137, 18.1831, 24.0646, 32.2988],
        [20.7058, 25.7795, 23.1151, 15.1077, 16.7234, 13.5488, 16.2982],
        [18.9626, 16.3407, 16.2557, 24.3056, 29.3367, 26.0488,  8.5884],
        [19.1478, 1

batch_predictions
tensor([[17.4628, 19.6873, 22.7945, 21.9809, 16.8185, 15.4914, 16.4010],
        [30.1230, 28.5990, 26.2650, 27.6306, 31.0791, 33.9741, 34.8721],
        [22.0461, 24.2319, 16.9465, 14.5802, 13.9718, 14.0043, 15.8960],
        [18.0019, 23.3217, 16.3679, 12.5183, 13.5653, 13.3510, 12.9537],
        [11.6748, 15.7645, 20.8856, 20.1227, 11.1320, 14.0562, 12.7288],
        [12.9343, 13.0379, 14.6497, 18.6506, 20.8173, 10.8633, 12.2708],
        [29.4324, 14.9797, 16.3001, 16.6209, 17.2872, 19.0507, 25.6179],
        [15.7176, 15.2530, 14.4067, 16.5086, 19.3035, 25.1341, 15.0893]])
ground truth
tensor([[20.4892, 19.7396, 25.2893, 24.2767, 19.5292, 17.6881, 13.4534],
        [31.6893, 27.9337, 29.0533, 23.7387, 22.4065, 30.0312, 33.2341],
        [20.5924, 32.1429, 17.5028, 14.2007, 10.4308, 12.1032, 12.3583],
        [18.1264, 29.1950, 26.0062, 43.6083, 10.0765,  5.7115, 23.8662],
        [13.3614, 18.9506, 29.2872, 27.6302,  9.5739, 22.3041, 22.1857],
        [14.0058, 1

batch_predictions
tensor([[19.4370, 26.2113, 34.0105, 29.4245, 21.5617, 19.8626, 18.2407],
        [13.6405, 17.0733, 17.9937, 18.6151, 22.1762, 21.3281, 14.1132],
        [ 9.9432,  7.9063,  8.2733,  7.5562,  8.0150,  7.5025,  9.7261],
        [ 8.3751, 10.2514, 10.8924, 12.1960, 10.9923,  8.5251, 10.1354],
        [19.9696, 22.4388, 14.6491, 13.8942, 17.1287, 18.6011, 19.6419],
        [34.9778, 42.2474, 29.8344, 22.9938, 21.9783, 22.7845, 29.8119],
        [12.9398, 14.6304, 13.9552, 15.0680, 20.1319, 27.0557, 13.6909],
        [14.6821, 15.2398, 16.4158, 18.9877, 24.0979, 23.7441,  9.8050]])
ground truth
tensor([[17.2278, 24.8685, 36.2835, 29.2741, 16.8990, 15.9258, 15.9521],
        [12.9537, 24.0794, 24.4345, 21.4492, 24.9342, 24.5660, 16.4650],
        [ 5.2863,  4.7052,  5.1304,  4.7761,  3.4580,  6.0232,  9.1695],
        [ 9.7931, 10.2749,  9.1837, 11.8906, 10.8135,  9.9206,  8.7585],
        [23.7244, 22.0542,  0.0000,  9.5082, 17.5960, 17.6618, 17.0831],
        [ 6.8254,  

batch_predictions
tensor([[19.9224,  8.8103,  7.6387, 10.2087, 10.2060, 11.3197, 12.8063],
        [29.3463, 33.9601, 17.3975, 16.2374, 16.0824, 17.1775, 25.0332],
        [14.0470, 16.4217, 22.9678, 21.9551, 15.0283, 11.2167, 12.7254],
        [14.6418, 14.7782, 15.5792, 21.6473, 32.7112, 24.4119, 13.7371],
        [28.1145, 15.3310, 16.8350, 15.8828, 17.7266, 19.9268, 30.0131],
        [15.9384, 17.8615, 15.6634, 15.7386, 20.2190, 31.1111, 30.4459],
        [12.2953, 12.3529, 14.6844, 12.1947, 10.3521, 13.2524,  8.4079],
        [15.5061, 19.0021, 17.7249, 10.9500, 11.2315, 11.9801, 12.8310]])
ground truth
tensor([[18.9769,  6.5492,  2.9458,  8.9427,  8.2588,  9.4161, 15.3603],
        [36.6518, 38.4666, 24.5397, 20.1078, 21.1468, 20.6339, 29.4713],
        [10.6859, 16.8226, 26.7149, 27.5085, 15.7596, 16.2415, 16.5675],
        [13.2430, 12.8222, 21.3309, 29.8659, 41.1625, 29.0900, 21.6728],
        [30.7996, 17.3724, 15.0316, 16.3993, 21.0679, 21.1468, 32.9432],
        [22.7749, 1

batch_predictions
tensor([[16.0928, 19.7810, 31.8275, 28.9383, 19.2694, 17.0064, 15.8394],
        [29.3761, 13.1194, 10.9506, 12.2186, 12.0457, 14.8315, 24.5664],
        [ 9.9088, 11.0041, 11.8305, 12.9897, 17.8227, 20.3881,  5.9143],
        [22.9217, 35.3207, 30.8844, 19.7779, 17.9059, 17.3436, 18.9596],
        [33.7937, 16.9348, 15.9501, 16.0925, 16.5163, 22.0172, 29.9994],
        [18.3927, 19.0246, 22.0682, 27.3957, 16.0530, 15.5967, 16.8102],
        [16.5788, 19.3576, 24.8902, 32.7866, 25.4183, 17.9569, 18.0001],
        [23.1757, 15.2386, 15.1330, 15.2370, 15.0606, 19.2275, 24.1366]])
ground truth
tensor([[22.1797, 28.5998, 41.0289, 44.7988, 21.2443, 19.0760, 15.9580],
        [28.5714, 18.6791,  7.4688, 10.0198,  9.5805, 11.8056, 21.7971],
        [10.8364,  7.6144,  6.7596, 11.8227, 17.9642,  0.0000,  6.2993],
        [28.5998, 41.0289, 44.7988, 21.2443, 19.0760, 15.9580, 16.0856],
        [41.2151, 15.6760, 16.0310, 15.6760, 13.7033, 28.1299, 44.9237],
        [20.3577, 1

batch_predictions
tensor([[35.3379, 43.4260, 27.5384, 16.7326, 16.2417, 16.8593, 23.4999],
        [23.0124, 24.2869, 11.2890, 14.8675, 14.3236, 15.5183, 18.7548],
        [30.6371, 11.8399, 16.3110, 15.4903, 17.3009, 20.5515, 27.7774],
        [13.6883, 15.0324, 19.2353, 23.8651, 24.4186, 13.8357, 15.4678],
        [22.2226, 23.0050, 19.0893, 21.3689, 33.4358, 45.0529, 34.6555],
        [20.1355, 29.1580, 35.5781, 20.8176, 20.5511, 22.9839, 22.7069],
        [28.1660, 21.3937, 20.5160, 31.7956, 45.3527, 43.5866, 33.5368],
        [15.6415, 16.2426, 17.1961, 17.9223, 23.7756, 31.8573, 13.4473]])
ground truth
tensor([[33.6310, 50.1276, 34.5096, 18.8209, 19.7988, 13.0385, 23.5402],
        [25.9779,  3.1604, 10.0198, 14.0164, 14.3707, 19.0618, 20.6491],
        [ 3.5998, 12.9677, 15.5329, 13.2653, 12.9819, 22.1655, 27.5510],
        [12.6842, 10.9269, 19.7562, 40.3203, 31.0374, 18.3532, 24.2205],
        [24.1213, 32.6105, 35.4025, 53.1037,  6.7035,  9.5663, 19.7704],
        [41.3993,  

batch_predictions
tensor([[43.1925, 28.8031, 23.2249, 22.5280, 19.6907, 23.2702, 33.9805],
        [20.2097, 20.6933, 12.7988, 14.8798, 14.4327, 15.2341, 15.6029],
        [17.8419, 17.7053, 14.5073, 14.8503, 16.4365, 20.2324, 20.3005],
        [21.9218, 21.8904, 22.0906, 31.6388, 38.4600, 34.4170, 19.7373],
        [11.7658, 12.4104, 15.9440, 17.9099, 20.0937, 19.2181, 15.9142],
        [11.7174, 11.7198, 13.9359, 14.6235,  8.6589,  6.8233,  9.9711],
        [17.2923, 23.8944, 31.1319, 20.2513, 15.8272, 15.6363, 14.0933],
        [22.4750, 10.8617, 12.0031, 12.0099, 13.3399, 18.2153, 25.6780]])
ground truth
tensor([[51.7007, 47.1939, 27.7069, 15.9722, 20.9467, 21.2727, 38.2795],
        [23.3956, 25.8548, 11.5729, 14.4003, 17.9248, 16.8464, 19.7922],
        [11.8339, 13.3645, 12.8968, 13.9031, 19.8838, 22.1797,  0.0000],
        [19.5720, 18.6366, 20.9751, 28.5431, 40.8305,  8.2908, 20.8475],
        [10.5442, 11.2528, 12.5567, 14.2007,  0.0000,  0.0000,  0.0000],
        [10.3600, 1

batch_predictions
tensor([[40.5384, 33.2050, 24.1263, 23.4509, 20.8809, 21.8486, 32.9815],
        [21.0673, 23.1728, 12.9645, 14.7018, 13.0972, 13.1235, 16.7315],
        [15.9782, 15.4973, 16.0166, 20.2840, 25.8260, 20.9989, 13.1142],
        [23.2581, 30.5574, 26.7815, 14.5136, 15.5873, 14.9556, 15.0164],
        [39.1542, 27.4305, 17.9779, 16.5830, 15.3148, 17.4752, 28.1005],
        [24.0036, 38.6688, 30.2210, 19.0009, 20.1052, 19.8987, 19.8628],
        [26.5999, 24.0319, 21.4603, 23.4519, 28.1089, 37.8004, 29.0271],
        [17.6801, 19.6290, 20.4630, 24.2983, 33.9626, 30.4638, 14.7932]])
ground truth
tensor([[41.6667, 46.8821, 32.7239, 16.9501, 28.0187, 21.7971, 40.3912],
        [20.0964,  2.4235, 12.9252,  7.1854, 13.1944, 14.7676, 16.4116],
        [13.6763, 13.6054, 13.5204, 23.1718, 17.8571, 19.9830, 10.3883],
        [20.2381, 35.8844, 28.1463, 14.9235, 15.0368, 10.4875, 14.7109],
        [48.4410, 30.7681, 22.1514, 17.7012, 16.2273, 19.1752, 27.6077],
        [33.2058, 4

batch_predictions
tensor([[19.3728, 17.7174, 16.2351, 16.6874, 19.7785, 28.5013, 25.5189],
        [24.7615, 25.6857, 23.2663, 23.9255, 27.6966, 34.9429, 33.1652],
        [16.6189, 24.2451, 24.4459, 10.9119, 13.9834, 14.0346, 13.2259],
        [35.1018, 25.1304, 15.4542, 16.8575, 16.0845, 17.7470, 24.2024],
        [22.0462, 19.9969,  7.7628, 12.2153, 12.6437, 12.8105, 15.6160],
        [17.8341, 18.5527, 16.9180, 19.8702, 30.4341, 39.3212, 25.6211],
        [16.6856, 16.7305, 20.5510, 30.7312, 26.9436, 16.7826, 17.6820],
        [16.0699, 15.5065, 16.2855, 18.7353, 23.7283, 29.1302, 17.0219]])
ground truth
tensor([[18.5941, 15.3345, 13.7755, 17.4461, 21.3010, 29.6769,  7.4546],
        [31.7201, 39.7948, 23.8822, 27.3935, 33.4035, 35.9679, 42.3593],
        [14.7554, 25.5655, 20.6733,  3.3140,  4.7344,  7.6144,  9.6528],
        [30.8653,  0.0000, 20.1999, 19.0032, 14.3872, 14.8606, 24.3556],
        [21.4098, 22.5802, 10.6128, 14.3740,  8.2457, 12.4671, 20.2262],
        [19.7846, 1

batch_predictions
tensor([[ 9.7855, 11.5547, 12.0947, 13.8722, 15.8410,  8.2191,  5.4278],
        [36.1224, 46.6491, 38.7014, 26.6238, 26.1822, 24.0120, 24.3268],
        [22.8167, 20.4070, 20.4854, 31.4479, 45.1450, 39.6756, 22.1346],
        [11.6789, 12.3565, 16.6041, 22.8488, 19.9858, 11.0778, 13.4046],
        [17.5761, 16.7459, 17.8796, 19.2823, 24.3299, 28.4021, 15.1525],
        [13.9625, 13.1762, 14.6391, 14.4093, 11.5546,  9.4559, 11.4479],
        [23.2579, 26.7762, 13.3190, 14.7418, 14.7303, 14.4339, 16.6910],
        [17.7069, 24.3380, 24.8615, 10.2719, 14.1438, 12.8707, 12.0402]])
ground truth
tensor([[ 8.2194,  6.7991,  9.6791, 11.2704, 17.2015,  1.3677,  3.3009],
        [34.6939, 55.6689, 39.4274, 30.8532, 25.1559, 22.7466,  6.1508],
        [30.0028, 15.8588, 23.7245, 32.8798, 57.1429, 57.0720, 19.9546],
        [15.8206, 13.7165, 16.2809, 23.4745, 16.1757, 10.1657, 10.1657],
        [17.7538, 15.4392, 20.4498, 17.4250, 29.0900, 33.9427, 18.2404],
        [ 8.7717, 1

batch_predictions
tensor([[31.1480, 36.2321, 35.3454, 27.8725, 30.5205, 27.5216, 26.6278],
        [32.9953, 20.2353, 17.6045, 17.4790, 19.8314, 26.9729, 35.8949],
        [24.4626, 21.3055, 24.1079, 34.7947, 41.5347, 38.9473, 24.9747],
        [31.9244, 42.4318, 27.9572, 23.4321, 20.5886, 20.1032, 23.5673],
        [27.4675, 30.2302, 26.6084, 16.9393, 18.9654, 20.5665, 24.3000],
        [12.4543, 14.3849, 13.7032, 14.8557, 18.6221, 24.3313, 18.9834],
        [21.0908, 25.2822, 21.8732, 13.6783, 14.3646, 14.2934, 15.7920],
        [10.3463, 12.5722,  8.6145,  6.6773,  5.8304,  6.6019,  6.7154]])
ground truth
tensor([[31.4834, 41.1099, 39.2820, 29.0900, 29.2346, 28.0642, 29.5634],
        [43.8453, 17.4513, 15.3735, 15.0842, 22.2777, 26.4729, 46.2914],
        [40.2353, 50.4535, 47.5907,  2.7920,  4.5777, 33.6876, 12.1457],
        [38.2795, 55.3713, 43.1406, 28.8832, 32.1003, 24.3481, 27.8770],
        [ 6.2217,  5.7256, 13.9456,  7.9507, 19.9546, 17.6587, 25.1417],
        [12.1740, 1

batch_predictions
tensor([[ 9.0131,  9.9937,  9.5163, 10.9722, 12.7240, 16.1357, 15.7886],
        [19.7534, 19.4784, 14.1078, 13.4853, 12.1108, 11.7449, 14.3870],
        [43.5341, 47.2178, 39.2644, 34.2188, 36.1118, 39.0836, 39.5359],
        [16.4994, 14.9911, 16.9530, 23.7151, 32.0417, 22.4974, 14.0402],
        [25.5734, 25.7120, 14.3019, 15.5624, 15.3249, 15.5432, 20.4999],
        [38.0606, 26.1154, 22.5705, 25.0480, 29.9845, 39.4903, 42.4787],
        [40.4238, 18.4000, 19.3889, 17.8562, 20.0506, 30.0043, 46.1031],
        [23.9378, 20.5968, 12.9704, 13.7805, 13.5487, 15.1093, 19.7943]])
ground truth
tensor([[ 7.4171, 11.6123, 10.3104, 11.7307, 10.9548, 14.6239,  5.7996],
        [23.0931, 22.6591, 14.2951, 14.3872, 10.6128, 10.9811, 13.3219],
        [48.4350, 61.8885, 58.7059, 27.5776, 56.4440, 39.7422, 30.9837],
        [14.9660, 12.1882, 16.0431, 29.4785, 45.4507, 31.7035, 15.4620],
        [23.6849, 28.5113, 12.9274, 13.1378, 10.5602, 16.8990, 19.8580],
        [51.5731, 4

batch_predictions
tensor([[15.2624, 14.1639,  9.8009,  8.4372,  9.3637,  9.7065, 10.6691],
        [19.2892, 19.6461, 12.3825, 12.7126, 10.6508, 11.3016, 13.6895],
        [33.1477, 16.0129, 20.1299, 18.0018, 20.0034, 30.8144, 35.7731],
        [13.5555, 17.3489, 20.5438, 10.6654, 13.1291, 12.8552, 10.2973],
        [25.8192, 19.4509, 17.4574, 15.4623, 16.5152, 24.7404, 36.0495],
        [12.6346, 19.5151, 18.2293, 17.8859, 20.9247, 28.9951, 30.5772],
        [14.5766, 17.0272, 15.9961, 16.3361, 18.5402, 26.2409, 30.8980],
        [24.0181, 18.6132, 17.9376, 17.4659, 22.1555, 31.4289, 30.5251]])
ground truth
tensor([[13.5850, 12.3225,  7.2330,  8.9164,  7.0752,  5.4840,  7.6933],
        [20.0421, 19.2267,  3.5245,  9.1005,  8.2983, 10.8759, 16.2283],
        [31.4484, 15.2636, 15.7738, 18.9768, 16.2840, 32.6956, 41.4824],
        [12.6644, 16.5045, 23.7507, 20.2130, 16.2546, 12.1778, 10.5997],
        [31.4768, 19.0051, 16.8367, 13.8605, 14.9802, 31.8736, 49.6457],
        [10.5584, 2

batch_predictions
tensor([[14.0366, 14.2541, 14.3066, 14.7148, 19.6838, 25.5863, 13.7837],
        [20.0625, 30.1886, 24.0333, 13.3522, 14.7800, 12.9407, 14.4017],
        [16.0931, 19.7840, 25.3108, 20.9769, 14.0939, 15.5735, 13.9445],
        [16.0209, 18.3586, 20.6511, 25.2841, 27.9306, 12.5989, 17.5967],
        [25.3003, 23.4822, 22.7424, 28.9137, 42.7150, 39.7611, 24.1213],
        [22.8549, 32.3672, 28.2291, 14.7015, 18.4631, 18.8783, 16.5783],
        [15.3453, 20.1040, 13.2207, 11.5484, 10.0490,  8.5139, 11.6642],
        [16.9887, 16.5455, 18.1926, 26.2444, 21.6806, 16.9131, 17.8660]])
ground truth
tensor([[13.6054, 14.5266, 12.7976,  8.8294, 21.3294, 31.8736, 31.4484],
        [22.6757, 40.3486, 37.0040, 20.4790, 21.3294, 31.0658, 36.2528],
        [12.3299, 17.9280, 20.6066, 18.9909, 10.8418, 13.8747, 11.2812],
        [14.7817, 18.7075, 20.2381, 29.0958,  3.5998, 12.9677, 15.5329],
        [ 8.8577, 13.6905, 22.6332, 34.6088, 50.6661, 46.4286, 27.3810],
        [26.3545, 3

tensor(79.7747)

In [55]:
test_dataset = SequenceDatasetForTransformerTesting(test_sequences)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True)

In [56]:
evaluate_model(vanilla_transformer_with_nullmask_, test_dataloader, encoder_causality=encoder_causality)

batch_predictions
tensor([[22.1771, 34.4582, 44.8772, 25.9733, 21.7029, 21.2547, 19.5553],
        [36.8802, 36.1286, 12.6004, 15.7763, 15.6215, 15.7772, 21.1368],
        [21.5604, 29.2440, 28.7017, 16.2675, 17.6960, 15.9569, 16.4904],
        [33.1258, 16.8547, 18.6037, 16.7708, 20.8429, 35.8431, 48.5800],
        [33.9833, 28.6188, 31.4902, 39.1149, 46.2199, 50.9829, 47.3872],
        [24.1578, 10.0502, 13.0063, 12.5410, 12.9372, 18.3746, 31.6111],
        [13.5777, 15.8519, 25.1355, 27.5033,  9.6382, 13.1102, 13.1232],
        [29.6623, 32.2600, 21.2895, 19.3225, 17.9875, 17.2012, 19.8468]])
ground truth
tensor([[20.8837, 32.1936, 56.7596, 53.1036, 29.5371, 24.8948, 20.8969],
        [43.4524, 38.3929, 19.0051, 19.2744, 12.7126, 21.8679, 24.0788],
        [17.2804, 28.3272, 31.4834, 17.3461, 16.9779, 14.1767, 17.7275],
        [43.3107, 19.7846, 20.5782, 14.5975, 20.3798, 36.1536, 51.8849],
        [41.6491, 42.5829, 45.9890, 42.5565, 49.4345, 49.9211, 39.3740],
        [25.6803, 1

batch_predictions
tensor([[15.5842, 21.2663, 29.8220, 24.4215, 12.9773, 14.5629, 13.9883],
        [12.5921, 21.8803, 20.1765,  6.7041,  8.4279,  9.2058, 10.3396],
        [23.0246, 16.0621, 11.5454, 11.3136, 11.2164, 12.3928, 16.4560],
        [17.3630, 22.7473, 23.1574, 25.7855, 16.4138, 16.9408, 16.1709],
        [15.2246, 12.9621, 13.7205, 15.2224, 19.2421, 28.3171, 28.6975],
        [14.2381, 16.1181, 23.3567, 31.2161, 32.6730, 12.3735, 15.7982],
        [31.6802, 34.3491,  8.1711, 14.4761, 15.3362, 16.6844, 21.4998],
        [ 9.5539, 10.7359, 18.0034, 18.7640,  7.1780,  7.4424,  8.6846]])
ground truth
tensor([[11.0731, 22.9353, 11.9542, 31.3125, 10.6654,  0.0000, 10.2578],
        [11.7964, 23.1983, 22.5802,  6.6938, 10.4156,  7.6802, 10.4550],
        [19.8185, 13.3482, 11.2441,  9.6923, 10.5339, 13.6639, 12.3225],
        [14.1636, 14.4661, 24.8685, 21.1862, 12.7564,  9.2451, 18.0431],
        [18.3673, 10.5584, 10.9694, 12.1740, 15.0510, 22.4065, 31.3917],
        [18.8350, 2

batch_predictions
tensor([[26.8848, 25.7983, 14.4655, 17.2576, 16.2151, 16.0713, 21.7104],
        [ 8.7871,  9.2125,  9.8631, 10.9479, 13.2874, 16.6957, 11.5960],
        [11.4998, 12.3261, 21.6345, 28.7725, 20.4605, 11.0797, 12.2334],
        [ 8.8054, 10.5014, 11.4152, 16.7617, 13.9410,  5.1909,  5.4051],
        [15.1606, 28.8469, 37.2602, 25.4666, 13.9368, 14.1789, 14.0363],
        [ 8.5392,  8.8435,  9.1924, 15.4687, 20.1639, 14.3862,  7.6921],
        [18.9404, 17.4120, 10.6313, 12.2768, 12.6421, 14.5669, 20.0149],
        [12.8403, 13.1605, 18.8431, 25.6271, 25.0096, 11.2214, 14.5471]])
ground truth
tensor([[33.0357, 11.1678, 12.0890, 14.0306, 15.9439, 17.2761, 25.1984],
        [ 8.2851, 11.1783, 12.6249,  7.5092, 12.5066, 17.8590,  9.7449],
        [ 7.0578,  9.8781, 19.6712, 29.5777, 23.2426, 10.2041, 11.4938],
        [14.3566, 15.0368, 13.7046, 14.6825, 20.9467,  2.2251,  3.0329],
        [14.1865, 31.5051, 40.1644, 25.4819,  7.9790, 11.6497, 12.9819],
        [10.3175, 1

batch_predictions
tensor([[13.2000, 16.7770, 26.0313, 31.0657,  8.1231, 12.6526, 12.3473],
        [16.2863, 23.4763, 25.6828, 11.0013, 12.6394, 12.3128, 14.1099],
        [ 7.6911,  6.2632,  7.6177,  8.1390,  9.3069, 13.8554, 19.6204],
        [29.0726, 23.6244, 16.5906, 16.7265, 16.3284, 16.0243, 25.8797],
        [29.6828,  9.2144, 13.4395, 14.1275, 15.7642, 17.7099, 29.8718],
        [ 9.7149,  9.8873, 11.0195, 16.6486, 21.4277, 15.3013,  7.3844],
        [25.2001, 32.1197, 36.9222, 26.8193, 21.7651, 20.3840, 20.6190],
        [11.2779, 12.4280, 14.2585, 20.2739, 23.8701, 11.3624, 11.2183]])
ground truth
tensor([[12.2567, 19.0952, 27.7749, 24.9737,  4.9842,  6.8911, 12.8748],
        [17.1485, 26.1763, 31.2358, 19.1893, 18.1264, 13.2795, 11.6071],
        [ 2.2357,  0.0000,  7.8248,  9.4555, 12.3751, 12.4277, 28.9584],
        [35.9836, 36.4796, 16.2982, 15.0794, 20.8050, 18.4099, 24.2205],
        [ 4.9745,  5.2438, 12.3583, 13.9598, 16.5675, 21.9246, 33.4042],
        [ 9.0347,  

batch_predictions
tensor([[14.9993, 20.5641, 31.8934, 31.1210, 12.0680, 13.9329, 14.6317],
        [ 9.6439, 13.2842, 12.6770, 13.1543, 17.1056, 31.1296, 29.6172],
        [16.2486, 16.1949, 16.4305, 21.8056, 29.6092, 28.8976, 12.1661],
        [20.3541, 22.6691, 34.2868, 43.3190, 42.7889, 25.1851, 20.4885],
        [32.9642, 18.0537, 19.8891, 21.2718, 24.7263, 34.9576, 42.6559],
        [ 6.7038,  7.4770,  8.8657, 13.3600, 16.1047, 19.1093,  5.7984],
        [14.1721, 17.5515, 16.6435, 17.1591, 20.1209, 27.3630, 27.3590],
        [38.9875, 25.3218, 22.2764, 21.4653, 24.5057, 33.7237, 45.6218]])
ground truth
tensor([[13.4637, 19.8129, 40.4337, 35.2608, 19.0476, 16.8651, 15.1786],
        [ 7.1712, 15.3912, 10.9410, 12.8118, 22.5198, 32.2137, 22.5198],
        [16.4519, 18.2667, 16.6360, 24.4740, 36.4545, 34.9947, 14.7422],
        [20.5418, 23.3693, 58.2588, 43.2272, 38.2956, 22.6328, 25.6839],
        [32.4566, 15.3603,  3.2614, 12.0989, 18.3588, 24.7107, 41.0968],
        [ 3.3305,  

batch_predictions
tensor([[29.5046,  8.7651, 14.2024, 13.4510, 13.5437, 17.4799, 29.0551],
        [17.5008, 23.8698, 31.8808, 30.3755, 13.4482, 15.7141, 15.5865],
        [34.9099, 44.9846, 36.9902, 15.8409, 18.7223, 17.4416, 18.1926],
        [ 6.9664, 10.8841, 10.8432, 11.0365, 17.2693, 22.8578, 10.1371],
        [14.8915, 20.7513, 28.0025, 26.8566,  9.7036, 14.2622, 14.7274],
        [22.1186, 29.3170, 42.9373, 39.6101, 31.3795, 25.3791, 20.9539],
        [26.4099, 36.8773, 12.3618, 16.6842, 16.2443, 16.3904, 18.7225],
        [17.1725, 19.2145, 23.6075, 21.0863, 20.2847, 11.8643, 14.0810]])
ground truth
tensor([[33.7302,  3.8407, 14.1015, 10.9410, 14.0731, 15.4620, 21.0034],
        [13.3929, 24.7874, 35.6718, 29.2800, 15.3203, 17.2336, 22.5624],
        [33.9569, 44.1185, 32.1854, 15.4620, 20.9892, 17.4036, 17.1060],
        [ 9.4029, 12.3093, 10.0079, 16.8464, 21.1731, 29.1557,  8.9295],
        [16.2698, 24.3764, 27.9620, 24.6882,  9.6230, 14.0306, 12.9393],
        [27.0550, 3

batch_predictions
tensor([[22.5676,  9.6402, 13.1896, 12.1344, 12.3539, 14.3898, 23.7866],
        [ 9.0416, 15.1620, 15.3200, 15.5561, 18.2681, 24.3088, 23.5351],
        [33.8538, 16.1187, 18.9870, 17.6834, 21.4609, 35.7575, 46.0256],
        [ 8.8377, 10.1111, 11.3276, 11.5591, 13.3698, 22.3423, 22.1801],
        [12.1167, 16.5357, 15.6191, 15.2828, 18.4632, 28.2877, 33.8672],
        [18.1062, 20.7504, 27.3451, 31.1409, 31.3893, 21.0667, 18.3158],
        [15.5249, 14.7313, 15.3126, 23.0415, 34.2966, 34.8075,  8.5058],
        [ 7.8509,  9.9680, 10.7471, 14.2943, 16.9634, 19.0007,  7.5665]])
ground truth
tensor([[18.4807, 10.1757, 10.6576, 10.4167,  8.6735, 15.3061, 21.9671],
        [ 4.3793,  4.5635, 34.6939, 38.7472, 41.0006, 33.2766,  7.9790],
        [35.8418, 16.5675, 18.0981, 18.9768, 17.7721, 37.5000, 47.3073],
        [10.0999, 11.8096,  9.2451,  8.7454, 11.7701, 17.8064, 19.3188],
        [14.1865, 14.5266, 13.5629, 14.3849, 17.0210, 27.4518, 36.6922],
        [19.5160, 1

batch_predictions
tensor([[31.5982, 17.9572, 19.7279, 21.9183, 23.1990, 23.5805, 26.5531],
        [15.8797, 15.5769, 16.6102, 25.9180, 37.5729, 27.6701, 18.6736],
        [16.8051, 16.1523, 16.8345, 23.7380, 34.0452, 28.0191, 11.7825],
        [13.4758, 15.2594, 17.1203, 23.0562, 25.4878, 10.5077, 12.3259],
        [13.8791, 14.7720, 21.8118, 29.4857, 30.5888, 13.8887, 15.4010],
        [12.5038, 20.5034, 22.5025,  7.2926,  9.4705,  9.7550, 12.1053],
        [14.8475, 14.7173, 14.8040, 20.2615, 27.6752, 24.5416, 13.7530],
        [10.0490, 10.7843, 11.8809, 14.7654, 16.3877,  9.6237,  9.1813]])
ground truth
tensor([[35.2749, 19.3311, 19.4586, 19.9546, 18.8350, 23.9512, 27.1542],
        [17.4887, 15.0652, 15.0794, 20.1389, 36.1253, 26.2330, 19.0476],
        [14.9235, 22.3781, 19.3027, 25.9779, 33.6735, 41.1423,  7.4688],
        [15.8206, 14.7422, 18.7533, 21.9227, 24.4477, 11.9542, 15.3603],
        [ 9.6088, 16.7800, 18.5091, 30.6406, 21.8112, 15.8588, 16.6950],
        [12.1646, 2

batch_predictions
tensor([[22.5900, 28.7802, 34.2112, 19.0372, 18.8348, 16.8022, 17.6745],
        [ 9.7312, 12.4121, 19.1772, 12.2387,  7.1020,  7.7357,  7.7931],
        [12.9827, 21.2293, 21.2141,  8.8471, 10.6596, 10.0799, 10.6509],
        [ 8.8686,  9.2579,  9.3090, 10.3926, 11.8240, 18.7536, 22.2079],
        [13.5737, 14.0316, 13.8468, 20.5367, 27.6437, 21.4776, 10.6448],
        [ 5.9850,  6.9815,  7.8508,  8.4038, 10.5386, 15.3830,  8.6142],
        [36.3074, 15.5009, 19.9129, 20.9317, 22.1453, 29.1944, 36.8165],
        [22.9020, 25.0052, 21.4474, 13.0143, 14.3873, 14.4908, 15.4097]])
ground truth
tensor([[17.8713, 19.7279, 26.0062, 15.8872, 17.7296, 17.6446, 15.1644],
        [12.7959, 12.5592, 17.4250, 21.4361,  8.6665, 11.5992, 11.2441],
        [18.1220, 26.0521, 23.8559,  8.0747, 11.0600, 11.0863, 16.2152],
        [ 6.2467,  8.0352,  9.3109, 10.2183, 11.9674, 23.0274, 28.0773],
        [14.9001, 14.0189, 14.5055, 16.2678, 22.1462, 26.1178, 11.3624],
        [ 8.2194, 1

batch_predictions
tensor([[14.3727, 15.1458, 20.5486, 32.2876, 32.7614, 12.2461, 16.6073],
        [25.1424, 25.1568, 29.3525, 41.2765, 49.5600, 44.0948, 28.3281],
        [17.1760, 27.6013, 41.0944, 29.8027, 17.1101, 17.1799, 15.2839],
        [ 7.7887,  6.7307,  7.0267,  6.0088,  5.3892,  5.8638, 10.7853],
        [10.0191, 11.8498, 20.3960, 22.1776, 12.1967, 11.2573, 11.4640],
        [27.0580, 17.8563, 18.2245, 18.6949, 22.9454, 35.7988, 47.5947],
        [26.4556, 14.7920, 16.5004, 14.0357, 14.1966, 17.6374, 25.2136],
        [26.6352, 21.6257, 16.7150, 19.1559, 24.6039, 35.9218, 36.5463]])
ground truth
tensor([[12.8968, 13.3078, 15.4478, 27.2817, 33.2766, 16.6383, 18.5941],
        [19.2460, 23.9938, 36.5363, 58.8861, 46.5278, 40.6463, 25.2693],
        [19.5437, 28.8690, 49.3197, 36.5221, 19.7846, 17.8571, 14.7817],
        [ 7.4830,  6.2358,  8.9853,  8.6451,  6.3634,  6.9728, 11.7914],
        [11.0863, 12.1646, 22.1988, 17.7670,  2.3672,  5.1946,  9.1662],
        [51.8806, 2

batch_predictions
tensor([[24.3930, 30.1977, 12.5187, 16.1556, 15.3817, 15.3302, 18.7676],
        [16.0981, 17.9475, 25.8066, 24.2321, 29.6081, 21.8340, 18.9699],
        [11.8488, 14.6115, 14.1839, 14.7751, 23.6650, 36.9967, 28.8795],
        [30.7140, 29.7176,  9.3911, 13.4208, 12.3324, 12.7501, 19.8288],
        [ 9.1936,  8.0935,  7.7647,  5.6541,  5.0920,  8.6867, 11.6026],
        [26.2748, 27.2150, 10.3653, 14.5692, 14.8664, 15.1593, 18.0427],
        [28.9905, 14.5510, 13.1135, 14.3295, 15.0680, 16.9388, 23.4919],
        [17.5771, 17.2311, 28.1117, 40.6754, 34.8704, 20.0811, 19.0510]])
ground truth
tensor([[25.5523,  5.2604,  8.0747, 14.1241, 14.1899, 16.7675, 22.3567],
        [21.7386, 25.5129, 28.2483, 32.1804, 37.2699, 22.0673, 31.9832],
        [10.7568, 10.7568, 11.1395, 15.8730, 23.8095, 36.3237, 26.4172],
        [26.1967, 27.6302,  7.4171, 10.5865,  9.5476, 10.0736, 18.5429],
        [12.0594, 17.3856, 16.4256,  8.7717,  5.9705, 15.2288, 14.4924],
        [28.0642, 3

batch_predictions
tensor([[14.2954, 15.8354, 16.3393, 24.7311, 36.0288, 33.3387,  9.0940],
        [13.5792, 13.7444, 23.5191, 33.6945, 32.0624, 12.8997, 15.4033],
        [35.0946, 28.8023, 18.3438, 17.0376, 15.3440, 16.4092, 27.1403],
        [16.7506, 15.8695, 20.9121, 36.1982, 43.3092, 31.6565, 17.9773],
        [28.5556, 20.5646, 20.3793, 20.2665, 24.5187, 33.4405, 38.4817],
        [22.9642, 24.7895, 10.3335, 12.4534, 13.4774, 13.4433, 17.7539],
        [25.6533, 35.9628, 29.2865, 13.7563, 16.3075, 16.4942, 18.2275],
        [21.6923, 27.6145, 29.3910, 13.0158, 17.1840, 16.7816, 18.1911]])
ground truth
tensor([[14.0873, 17.0493,  2.6927, 20.4649, 36.0402, 35.2749,  7.7239],
        [ 8.7868, 14.8526, 27.6644, 34.7647, 27.9904, 12.0040, 15.1502],
        [16.7375, 48.7387, 28.8124, 14.4274, 13.0385, 15.9864, 30.0737],
        [16.9359, 17.3895, 13.7046,  0.0000, 21.6837, 35.5867, 18.6366],
        [23.2246, 21.4229, 14.7685, 17.8985, 20.6996, 26.6833, 31.8648],
        [23.1063, 2

batch_predictions
tensor([[ 8.3816,  9.3390, 13.5406, 20.0333, 16.9291,  7.8994,  8.2104],
        [27.0988, 25.5193, 11.8850, 13.6963, 13.0232, 13.0487, 17.4192],
        [26.2573, 10.3976, 14.9421, 15.4671, 17.4745, 25.9327, 28.9683],
        [ 7.8415,  9.1264,  9.8784,  9.8889, 13.2637, 22.3884, 16.2106],
        [12.8733, 14.9550, 12.9594, 13.9820, 19.4508, 24.3163, 22.4726],
        [15.0086, 14.5387, 18.6839, 27.7506, 26.7257, 13.6241, 16.4307],
        [16.8429, 22.7625, 29.3959, 31.6688,  9.5400, 13.5580, 15.0447],
        [11.9831, 12.3514, 12.4332, 14.7982, 25.9334, 29.5039, 10.4362]])
ground truth
tensor([[ 9.2451,  8.6665, 12.8222, 14.7685, 17.4382,  7.1015,  8.2851],
        [21.4755, 21.8832, 14.6239, 15.6628,  9.9947, 12.4408, 12.4014],
        [21.5420, 11.4654, 11.5788, 13.4637, 13.2653, 25.7937, 28.3730],
        [17.0305, 12.2830,  9.8106, 10.8364, 16.4256, 22.5539, 15.7154],
        [11.0969, 14.5975,  9.8073,  7.9790, 17.0635, 21.9955, 22.6332],
        [14.6117, 1

batch_predictions
tensor([[25.4348, 31.1046, 25.2844, 14.4187, 15.6909, 14.7151, 16.6885],
        [ 8.4431, 11.9084, 11.6862, 13.1281, 18.4241, 21.6611,  8.5497],
        [22.4697, 26.4437, 25.5065,  9.9942, 13.3964, 13.3618, 15.9156],
        [12.9430, 13.8399, 17.3112, 24.0689, 22.3695,  9.4764, 13.0770],
        [32.1405, 23.6851, 13.9960, 14.5981, 15.0618, 18.6114, 26.2388],
        [18.7478,  8.9730,  8.5216,  8.3296,  9.7102, 11.2952, 18.7103],
        [16.1215, 17.7420, 22.1124, 32.2979, 24.5692, 13.2732, 16.0909],
        [17.8169, 24.0744, 27.4786, 15.3149, 15.5370, 14.4655, 15.6592]])
ground truth
tensor([[2.5500e+01, 3.6310e+01, 6.1678e+00, 2.6302e-02, 1.3440e+01, 1.5189e+01,
         1.6189e+01],
        [3.7875e+00, 2.3935e+00, 1.3269e+01, 1.3256e+01, 1.4782e+01, 1.7438e+01,
         1.3730e+01],
        [1.6879e+01, 2.9365e+01, 2.4391e+01, 1.0176e+01, 1.3194e+01, 1.2103e+01,
         1.7078e+01],
        [1.4032e+01, 1.3440e+01, 1.9740e+01, 2.7209e+01, 2.1028e+01, 1.2283

batch_predictions
tensor([[12.7430,  7.8834,  8.1790,  8.5101,  8.8787, 11.6771, 17.6266],
        [16.0611, 16.3463, 16.5026, 26.8879, 42.9081, 31.9105, 15.5223],
        [16.6609,  8.4447, 10.0726, 11.1174, 13.0006, 20.3858, 20.7759],
        [13.1041, 17.8269, 27.4142, 28.0830,  6.9708, 11.4602, 12.4447],
        [37.3313, 13.7596, 16.9746, 17.1391, 20.5497, 35.7880, 49.8302],
        [13.2005, 15.7003, 23.8785, 29.7460,  9.4123, 13.0978, 14.0401],
        [12.7389, 20.3414, 24.0612, 19.9024, 11.4011, 12.6363, 12.3463],
        [15.4194, 17.9260, 20.1082, 24.9796, 30.7324, 27.7596, 12.4243]])
ground truth
tensor([[17.5171,  6.7596,  7.7196,  7.3645,  8.8769, 10.6128, 16.6754],
        [10.7710, 18.4666, 19.7279, 25.7653, 50.7653, 37.8260, 27.7069],
        [15.6497,  8.4166,  9.9027, 10.1657, 10.9942, 10.9022, 19.1347],
        [ 9.1662, 15.3472, 20.3840, 23.9085,  5.0237, 10.7180,  8.3377],
        [37.5709, 16.8509, 18.1831, 13.7046, 14.2290, 28.1888, 39.7676],
        [17.3987, 1

batch_predictions
tensor([[15.5764, 15.5296, 19.3067, 30.4028, 35.5865, 17.8483, 17.1101],
        [19.5384, 18.4899, 37.9517, 49.0665, 46.6852, 21.8920, 18.4596],
        [39.3101, 25.5345, 23.6038, 19.9541, 20.0445, 26.6623, 43.6667],
        [19.1252, 12.6224, 13.8815, 14.4610, 16.4621, 24.4266, 25.8093],
        [13.9756, 14.7305, 17.6042, 27.5867, 33.8172,  9.6614, 15.6713],
        [14.0357, 21.1249, 29.4041, 24.5739,  9.3132, 12.8931, 13.0481],
        [14.2938, 18.2401, 24.8729, 22.7882, 11.5344, 13.3500, 13.5448],
        [40.6571, 46.0167, 36.4335, 29.6607, 28.4061, 28.7303, 33.6004]])
ground truth
tensor([[16.6950, 19.1752, 21.9671, 32.2137, 35.4308, 18.0414, 20.8333],
        [15.8732, 17.8722, 33.8769, 48.5402, 41.5965, 13.9006, 13.3088],
        [39.3282, 32.0862, 19.8980, 20.1814, 13.4637,  0.0000, 15.4478],
        [26.4598, 15.4620, 13.5488, 16.9076, 12.5283, 22.9734, 27.5652],
        [18.3325, 17.3067, 14.8080, 27.5907, 28.5113,  8.1799,  6.8254],
        [14.6967, 1

batch_predictions
tensor([[20.0295, 20.3846, 25.1410, 22.3280, 19.5459, 18.8449, 19.6407],
        [15.8380, 11.8154, 13.1869, 13.5827, 14.8608, 21.2970, 27.9210],
        [20.0678, 16.5389,  7.8574,  9.1509,  8.1594,  9.0574, 14.1919],
        [15.4648, 24.9523, 29.3271, 26.9343, 13.7382, 14.7321, 14.8934],
        [29.6785, 37.0966, 44.3691, 48.4263, 45.9693, 37.4552, 29.4223],
        [14.9341, 17.0006, 25.8257, 30.1249, 14.7265, 16.6151, 15.1971],
        [22.9153, 11.7862, 13.6199, 13.8358, 15.5233, 26.7665, 30.1836],
        [13.8046, 18.5955, 22.8293, 16.0664,  8.9355,  9.6278, 12.0080]])
ground truth
tensor([[21.7971, 21.5561, 28.3872, 40.6321, 22.9734, 24.4898, 17.4036],
        [27.2959, 15.3203, 17.1344, 14.7676, 15.2211, 14.9235, 18.6224],
        [18.1483, 15.5050,  5.1420,  7.5618,  5.1815,  9.3503, 13.3614],
        [14.7251, 21.5136, 29.9603,  8.7727, 10.9410, 17.0068, 16.3690],
        [43.2930, 45.1604, 52.1699, 63.6376, 64.6107, 41.6491, 42.5829],
        [14.6633, 1

batch_predictions
tensor([[14.3103, 17.8706, 27.5379, 25.3575, 14.1276, 15.0043, 14.2277],
        [18.3836, 20.7432,  7.1840,  7.5248,  7.5495,  7.6594, 12.6640],
        [14.6163, 22.3469, 28.9700, 25.8323, 11.2128, 13.9852, 13.8427],
        [26.7885, 25.3796, 13.7793, 15.0884, 14.0933, 15.2800, 21.2095],
        [15.0168, 18.5294, 26.6327, 31.3525, 19.3523, 15.8795, 15.4885],
        [16.3797, 22.1434, 31.2277, 18.1012, 14.9775, 15.2568, 15.1331],
        [30.6030, 33.2189,  8.2731, 13.2298, 14.0082, 15.8537, 21.5705],
        [12.9810, 18.6024, 28.4006, 27.6177, 10.7152, 13.8186, 12.2479]])
ground truth
tensor([[15.5181, 17.6881, 27.7880, 27.9853, 13.8217, 12.8880, 13.8480],
        [ 1.2188, 25.9495,  5.4422,  9.7647,  4.6344,  8.1066, 11.9756],
        [13.4212, 29.1241, 20.1389, 18.1406, 11.5363, 12.7693, 13.2228],
        [22.7749, 28.2596, 11.9473, 16.9926, 15.3486, 15.1502, 22.9450],
        [17.8855, 24.3622, 15.9864, 24.1213, 16.0714, 12.9110, 19.0760],
        [22.3498, 2

batch_predictions
tensor([[15.8204, 21.9647, 18.4063,  7.4678,  9.5650, 10.5764, 11.7881],
        [12.5621, 14.9489, 22.2884, 22.7275,  8.9788, 12.3746, 12.7994],
        [44.9116, 43.0247, 40.6228, 39.7192, 38.4379, 38.7675, 38.8072],
        [17.5837, 17.0338, 16.4747, 17.3212, 27.6609, 45.5089, 40.8062],
        [ 6.2983,  5.3422,  5.2220,  5.6228,  5.6562,  7.8535, 10.3014],
        [27.8200, 38.1781, 24.2382, 13.5747, 15.3664, 14.5300, 15.8964],
        [37.0535, 29.9614, 12.6250, 15.1865, 12.7918, 14.4912, 28.1515],
        [12.9540, 18.9984, 27.6067, 25.0316, 10.3359, 11.9484, 12.7930]])
ground truth
tensor([[19.0426, 17.2672, 22.2777,  4.5634,  6.1152,  9.1136,  8.4824],
        [11.3361, 15.6365, 19.7528, 19.6344,  7.7722, 11.7307, 12.4803],
        [57.9037, 48.6981, 44.2136, 45.8443, 45.3051, 41.9911, 43.3062],
        [22.8175, 18.1973, 13.1236, 11.3662, 21.6837, 47.6190, 44.4870],
        [ 4.1950,  3.1746,  2.8770,  5.4280,  5.8957,  9.7931, 10.6859],
        [33.6735, 5

batch_predictions
tensor([[14.3090, 15.2385, 22.1029, 31.4278, 31.6453, 11.6265, 14.9971],
        [15.4590, 23.6880, 29.3544, 21.1391, 12.9851, 13.1116, 14.7063],
        [ 8.7209,  9.6207, 12.1136, 12.8171, 15.5734, 19.9085, 22.5682],
        [23.5877, 22.3807, 34.5522, 49.6828, 46.8543, 33.4263, 26.8137],
        [22.8030, 27.2935,  8.2431, 11.4005, 12.4093, 12.3456, 17.0381],
        [32.3535, 26.2540, 11.2802, 13.6627, 11.7686, 12.8549, 19.3595],
        [14.9569, 15.3169, 14.8110, 15.9942, 21.0855, 32.3024, 30.9776],
        [15.5516, 19.7702, 24.3899, 20.4991, 12.2360, 13.2846, 13.2443]])
ground truth
tensor([[22.9308, 17.0918, 19.6429, 28.8974, 24.8158, 10.0765, 13.6196],
        [12.7834, 21.4853, 25.8220, 12.2307, 12.5000,  8.5176, 11.3520],
        [ 6.8648,  3.1694, 11.0863, 11.2704, 13.0195, 18.8716,  5.4050],
        [24.1071, 28.0187, 39.6400, 55.2012, 47.8600, 33.5034, 27.9620],
        [14.2425, 23.6323,  5.9705, 11.6649,  9.6002,  9.0873, 12.7696],
        [29.9178, 2

batch_predictions
tensor([[24.4299, 27.5816,  7.4664, 12.5930, 12.1359, 14.6217, 19.3179],
        [22.7325, 28.2426, 22.1981, 12.7129, 13.9091, 12.7075, 15.5401],
        [25.2571, 34.7821, 31.3288, 17.8107, 16.6364, 15.8979, 17.7932],
        [18.2732,  8.8027, 11.4977, 11.2236, 11.4331, 15.6472, 23.3223],
        [14.7119, 17.6471, 24.3791, 31.3433, 10.7548, 14.3659, 13.7704],
        [ 9.3273, 12.7800, 12.7833, 14.8467, 20.6239, 28.6472, 20.5471],
        [32.5828, 33.2469,  8.8499, 15.1195, 15.6664, 16.4015, 21.0552],
        [20.4034, 22.0185, 23.8027, 29.0829, 33.9197, 32.2461, 28.6274]])
ground truth
tensor([[29.5371, 28.3535,  3.8269,  9.6265,  3.7349, 14.0847,  9.2977],
        [23.2426, 26.3605, 27.1117, 14.9943, 14.9943, 13.1094, 14.6825],
        [29.5493, 43.7783, 26.2188,  0.0000,  3.4864, 15.6321, 16.5675],
        [18.4903,  4.6028,  8.8375,  8.7585,  7.8117, 13.7691, 20.3314],
        [18.7138, 27.8932, 30.9179, 21.9621,  0.0000,  0.0000, 19.9369],
        [ 6.1650, 1

batch_predictions
tensor([[23.2629, 32.6681, 24.7313, 13.4139, 15.1801, 14.5927, 16.6194],
        [14.1586, 14.1573, 15.5930, 23.1173, 32.4632, 26.2323, 10.8862],
        [20.5843, 11.8886, 13.5057, 12.3960, 12.6816, 20.0332, 25.2878],
        [20.7341,  9.3465, 11.0097, 10.0736,  9.7428, 12.1889, 21.6197],
        [33.6219, 42.6846, 35.2107, 15.9970, 17.0877, 16.3207, 18.6485],
        [17.1535, 29.3612, 48.4788, 43.7175, 21.2947, 19.9901, 17.4261],
        [11.9359, 14.9459, 14.9590, 15.4447, 21.6750, 30.1224, 25.9240],
        [21.0974, 21.1885, 21.7801, 28.4997, 35.3032, 34.4397, 16.9763]])
ground truth
tensor([[23.6586, 39.3083, 29.1426, 19.5292, 14.7028, 14.5055, 18.9506],
        [14.9093, 12.8118, 12.9535, 20.1247, 32.4263, 34.3679, 13.7613],
        [19.5720, 11.8906,  9.0278, 11.4087, 11.1111, 23.2285, 31.9444],
        [17.3198,  0.0000,  6.1152, 10.3893,  9.9816, 12.9932, 22.1462],
        [33.2766, 49.9575, 43.3107, 19.7846, 20.5782, 14.5975, 20.3798],
        [16.0310, 2

batch_predictions
tensor([[12.5606, 12.5643, 13.6916, 14.6009, 20.9311, 22.3501, 13.6738],
        [12.2157, 12.8254, 13.5494, 18.3934, 26.5243, 23.5348, 11.9270],
        [17.7094, 19.2566, 29.7458, 11.8838, 14.2804, 13.7159, 15.4800],
        [19.5017, 26.2621, 24.5723, 12.1772, 15.1221, 13.7057, 15.6191],
        [40.0473, 38.9587, 23.5583, 18.0737, 15.7167, 15.3704, 20.4366],
        [20.6074,  7.4777,  8.6907, 10.4679, 11.4619, 14.8965, 23.6840],
        [25.9833, 12.5015, 14.7196, 14.0883, 15.0807, 17.9230, 28.8402],
        [11.3482, 12.0020, 12.0853, 11.8054, 17.1118, 25.8548, 26.5514]])
ground truth
tensor([[13.6639, 13.1773, 12.8880, 10.8233, 25.7233, 25.6970, 14.2425],
        [10.4024,  9.3240, 11.8753, 16.7938, 22.8564, 23.3167, 12.0989],
        [18.7270, 26.9332, 24.7764,  8.1405,  9.2057, 13.8480, 12.6512],
        [21.3966, 21.0153, 27.1831, 11.7175, 16.5965, 13.8480, 17.5171],
        [35.8985, 38.9031, 18.9626, 10.5584, 16.7517, 15.4478, 22.1797],
        [19.7528, 1

batch_predictions
tensor([[29.4917, 44.2799, 41.9123, 19.2630, 19.3205, 17.6806, 18.6272],
        [ 9.6530, 12.5270, 12.1291, 12.6742, 16.2417, 21.2998, 23.8042],
        [ 7.7117,  7.7527,  7.9094,  7.5256,  9.3783, 10.2421, 13.6705],
        [ 8.7030, 10.6036, 16.8519, 20.5133,  8.9629,  8.0286,  8.6569],
        [13.8201, 15.2569, 19.5831, 29.0610, 29.3543, 14.3353, 13.6136],
        [17.8524, 31.2419, 30.9975, 12.7311, 15.3915, 13.9906, 15.1888],
        [ 9.8345, 10.1153, 12.9145, 19.7666, 20.8262,  7.8274,  9.7181],
        [ 8.9607,  9.3566, 12.6373, 20.4739, 21.2687,  8.5045,  9.7566]])
ground truth
tensor([[25.0526, 46.5281, 12.1120,  0.0000,  8.6533, 16.2020, 20.5550],
        [ 7.0358, 11.7570, 13.7559, 12.4145, 16.9516, 21.5939, 21.8569],
        [ 6.5760,  6.0232,  5.2154,  4.2800,  6.8736,  9.6088,  1.9700],
        [ 8.3246, 11.8096, 17.0831, 16.8990,  8.1010,  7.7328,  9.6791],
        [10.9694, 12.1740, 15.0510, 22.4065, 31.3917, 16.9218,  4.2375],
        [22.2506, 2

batch_predictions
tensor([[13.2988, 16.7536, 25.0290, 28.4983, 17.0494, 14.2711, 14.4879],
        [36.3288, 12.9767, 17.6224, 18.9620, 22.6894, 34.8242, 44.3686],
        [21.3298, 15.2293, 10.9925, 12.2315, 10.8187, 11.9142, 14.2644],
        [21.1384, 18.6143,  9.5835, 11.0873,  9.7865, 11.9222, 16.0250],
        [12.8940, 13.4724, 15.5974, 22.9416, 30.0775, 14.9941, 12.9254],
        [25.0480, 11.3805, 14.5980, 14.4432, 15.5207, 23.0464, 33.6408],
        [10.0777, 12.7333, 16.9649, 19.2327, 17.5324,  7.3460, 10.0990],
        [ 7.6878,  8.8360,  9.9119, 11.8195, 18.4316, 24.9278,  7.5138]])
ground truth
tensor([[10.1049, 17.1910, 29.9178, 32.9790, 14.0448, 13.2795, 11.4938],
        [40.6746, 20.7341, 18.4807, 22.3498, 21.8537, 32.0011, 47.7041],
        [21.3835, 13.2562, 11.7438, 11.0337, 10.9548, 11.5992, 12.4934],
        [24.1922, 20.6916,  9.8073, 11.3379,  8.5743,  4.9036, 13.4212],
        [ 9.1399, 13.4008, 13.4929, 23.1852, 23.0668, 11.5860, 11.1257],
        [33.3759, 1

batch_predictions
tensor([[11.1826, 12.7429, 15.4174, 21.7860, 25.2646, 12.2018, 11.8882],
        [28.3458, 45.6658, 44.0547, 21.9854, 17.7364, 17.2018, 16.5926],
        [32.5119, 17.1713, 18.7597, 18.2457, 23.2172, 37.3860, 47.2665],
        [19.8827,  9.7973, 11.8118, 11.9849, 12.5736, 17.3046, 25.1804],
        [29.2488,  8.8969, 13.5316, 13.3059, 13.7944, 18.3919, 30.1374],
        [15.1461, 15.6799, 20.2184, 23.3453,  7.9228,  9.1404, 14.9619],
        [12.3925, 13.3985, 13.5989, 16.9838, 24.5817, 25.4754,  8.3675],
        [11.3636, 11.7072, 12.6081, 22.2344, 29.5527, 20.0165, 10.3800]])
ground truth
tensor([[ 8.4955, 11.3098, 18.7533, 26.3545, 28.5245, 15.0053, 10.6391],
        [28.9321, 43.1483, 42.4645, 15.9390, 26.8017, 17.4382, 14.9527],
        [34.3537,  0.0000, 17.3328, 20.5924, 25.5527, 33.3900, 51.8707],
        [22.1372, 11.6780,  9.1978, 11.7772, 11.6638, 19.3027, 26.6156],
        [24.9737,  4.9842,  6.8911, 12.8748,  9.1662, 15.3472, 20.3840],
        [14.3707, 1

batch_predictions
tensor([[ 7.2139,  9.1447,  9.2849,  9.7994, 12.0458, 18.9793, 19.2393],
        [13.4473, 18.0300, 24.5572, 23.6460,  8.8985, 12.7060, 12.3297],
        [14.5678, 16.7983, 16.0385, 20.3719, 39.1047, 44.9072, 26.8521],
        [19.5394, 28.5288, 22.6634, 14.3656, 13.6245, 12.0555, 13.4690],
        [16.1756, 15.9705, 17.4121, 23.7629, 36.2029, 30.6895, 14.5630],
        [26.8320, 11.2436, 13.7123, 14.1068, 15.3778, 17.8223, 25.2824],
        [33.2598, 21.1306, 19.4489, 17.2862, 19.7260, 29.6849, 40.9087],
        [17.1930, 17.0517, 18.2759, 28.2358, 39.3938, 32.4321, 17.0767]])
ground truth
tensor([[ 4.4845,  7.9958,  5.7733,  5.4840, 10.7180, 15.5050, 21.9095],
        [11.6071, 21.9671, 24.7449, 23.7812, 11.0119, 12.1457, 10.3175],
        [14.6684, 20.6491, 19.7279, 19.8980, 34.5522, 41.2840, 27.0550],
        [19.3594, 29.2517, 26.7715, 12.7834, 12.7409,  9.2262, 22.6616],
        [13.4637, 15.1644, 18.7925, 23.2568, 39.2290, 29.7052, 17.3186],
        [38.5629,  

batch_predictions
tensor([[12.8123, 14.5516, 13.9454, 17.8952, 21.6837, 23.8856, 13.1927],
        [14.9054, 10.5124, 12.4507, 12.7390, 14.5600, 19.2227, 27.2059],
        [12.9665, 13.4380, 14.7445, 25.4404, 35.5412, 24.6967, 11.1103],
        [10.6059,  9.3816,  9.4542, 11.1303, 16.2931, 14.0174,  8.9741],
        [14.0826, 20.4618, 33.6555, 32.0273, 12.5447, 14.5925, 14.4189],
        [27.1437, 25.1754, 19.0746, 21.9269, 15.4205, 15.8422, 18.6106],
        [15.1718, 15.1556, 26.7296, 37.1444, 31.1999, 17.3376, 16.4706],
        [20.1670, 17.2782, 15.9686, 16.1553, 21.2529, 32.4703, 30.8384]])
ground truth
tensor([[ 4.5210,  9.2120, 21.5136, 22.0947, 19.0193,  8.9144, 10.2466],
        [27.4724, 10.4550, 11.0994, 14.4003, 18.8059, 15.8864, 27.0516],
        [10.0482, 10.0198, 13.8747, 27.1825, 40.6604, 29.5918, 11.7772],
        [ 9.6791,  8.8243,  9.9947, 11.0863, 20.5155, 13.5850,  9.9816],
        [13.4070, 23.0300, 35.2749, 32.9932, 14.9943, 14.5692, 16.4683],
        [32.7239, 2

batch_predictions
tensor([[14.4052, 15.5236, 15.6476, 17.9109, 31.1764, 42.9555, 22.6759],
        [ 9.9578,  6.3826,  5.1691,  4.8853,  5.2251,  6.0397, 10.1393],
        [29.2662, 13.4113, 16.2737, 17.1659, 18.4277, 24.7647, 27.2802],
        [12.1180, 13.5278, 22.7953, 28.9283,  9.4758, 11.8408, 11.9790],
        [37.0777, 30.5655, 12.7125, 14.0430, 12.9436, 13.6746, 23.4428],
        [15.3022, 14.7149, 23.5152, 40.3871, 39.1594, 20.5904, 18.8257],
        [12.1787,  8.1513,  6.6709,  8.6030, 10.2217,  8.6426, 11.7771],
        [13.8279, 15.7498, 21.8645, 25.9385, 13.9192, 14.5316, 14.5609]])
ground truth
tensor([[13.5455, 16.2546, 13.9795, 16.0705, 21.3046, 29.6423, 22.7117],
        [10.6859,  4.4218,  4.7619,  5.0170,  5.6689,  6.6043, 11.6355],
        [32.1145, 14.4133, 19.3452, 14.2999, 18.0556, 19.4303, 31.5334],
        [12.7170, 15.8864, 16.9516, 18.0957, 14.4135, 16.5176, 11.2835],
        [37.8543, 34.7931, 15.3628, 11.8764, 11.2245, 12.1882, 19.9121],
        [14.3707, 1

batch_predictions
tensor([[12.5702, 11.1815, 11.4135, 12.7621, 15.6504, 21.0919, 24.5613],
        [12.5080, 13.2420, 19.6373, 25.3414, 24.7367, 11.6834, 13.5202],
        [20.0916, 20.1123, 11.1187, 12.6142, 12.0113, 13.2717, 16.0847],
        [31.5102, 28.6525, 17.1877, 14.8719, 15.0185, 16.2527, 22.3453],
        [14.2013, 22.7855, 33.7626, 26.1963, 10.3652, 12.3717, 12.4781],
        [ 7.8709, 11.0370, 11.8948, 16.2733,  5.8619,  6.3391,  7.4235],
        [20.0152, 29.1128, 43.7605, 43.1277, 17.2960, 19.6158, 18.9260],
        [13.8677, 17.2937, 25.3276, 24.7107,  6.9763, 10.2194, 10.3835]])
ground truth
tensor([[12.5197, 11.3887,  8.4955, 11.3098, 18.7533, 26.3545, 28.5245],
        [13.3220, 14.2432, 16.6808, 22.7466, 21.3294, 14.8384, 16.5249],
        [19.6607,  7.9300,  7.3251,  8.9821, 11.7570, 11.2835, 14.1504],
        [42.4603, 34.2120,  0.0000, 11.5930, 17.1769, 13.1519, 20.4790],
        [13.8747, 27.1825, 40.6604, 29.5918, 11.7772, 13.8322, 14.1015],
        [ 6.5492,  

batch_predictions
tensor([[31.3297, 12.1824, 15.0656, 15.2974, 17.2020, 30.5031, 45.4161],
        [12.2170, 14.9299, 20.1071, 26.3829, 18.7566,  8.1165, 10.1796],
        [13.5396, 13.2758, 13.8079, 23.6654, 31.9861, 19.4120,  8.9939],
        [29.9525, 12.4178, 14.1427, 13.8834, 14.3758, 22.0891, 39.5314],
        [11.6939, 15.9139, 16.0220,  7.9737,  7.5253,  8.3717,  8.6165],
        [43.3283, 28.3598, 27.0869, 24.6547, 26.3442, 38.4251, 51.6692],
        [ 9.6556, 11.9747, 18.9506, 19.5669,  7.0519,  7.4736,  8.9705],
        [21.6595, 21.9866, 27.7670, 43.1086, 42.4800, 19.1653, 20.9030]])
ground truth
tensor([[46.5676, 12.3619, 14.2294, 15.0710, 14.5844, 31.9306, 42.7801],
        [11.4371, 15.9014, 18.8634, 31.3917, 30.8532,  8.1207,  1.4172],
        [10.6293,  0.0000,  9.6088, 14.9802, 29.4076, 25.0142,  6.1933],
        [34.7931, 15.3628, 11.8764, 11.2245, 12.1882, 19.9121, 43.1122],
        [ 9.5082, 16.8201, 17.5171,  6.7596,  7.7196,  7.3645,  8.8769],
        [58.4892, 3

tensor([[38.2956, 18.4114, 10.2052, 23.7507, 21.8569, 26.6044, 37.1383],
        [12.7976, 13.5346, 24.1497, 39.2149, 29.1808, 13.1094, 14.3141],
        [27.8932, 14.8080, 13.8217, 12.2304, 16.8201, 19.3451, 30.4445],
        [ 8.3903, 10.4813,  9.7186, 17.7407, 18.2536,  7.7854,  6.0494],
        [ 0.7511, 13.7046, 15.3770, 20.3090, 36.6497, 37.3441, 14.6967],
        [27.6959, 25.0921, 27.0252, 39.1899, 10.6260, 15.6760, 28.9321],
        [ 7.4698,  9.7843,  6.1678, 11.9016, 18.9374, 21.0021,  7.6933],
        [31.9444, 25.5952, 11.7347, 14.7817, 11.5646, 13.3220, 19.8129]])
batch_predictions
tensor([[31.0164, 14.2583, 17.4948, 17.3296, 19.8086, 31.5066, 46.7817],
        [40.1772, 26.6316, 14.8344, 16.1645, 14.3667, 16.5613, 30.2102],
        [13.6484, 12.7728, 13.8997, 18.5069, 24.9207, 22.4738, 12.3445],
        [10.0372, 13.8177, 13.5984, 14.9961, 21.7419, 33.2731, 25.1533],
        [22.5935, 34.4837, 31.0866, 16.8858, 18.0162, 16.0452, 16.9433],
        [10.4764, 10.5248, 12.41

batch_predictions
tensor([[22.7776, 10.3118, 14.0263, 14.0444, 15.5384, 20.0806, 35.8489],
        [11.3237, 11.1812, 10.9986, 11.5277, 17.4034, 24.2610, 16.6259],
        [15.9169, 16.0876, 16.3960, 29.7702, 39.9340, 32.7963, 12.9780],
        [14.0232, 15.8196, 22.2832, 34.4833, 31.5008, 12.4070, 16.0148],
        [14.4125, 15.9365, 26.0233, 35.7493, 31.9541, 14.8918, 15.6078],
        [34.1477, 10.7483, 14.8505, 14.1149, 15.2161, 20.8106, 35.9798],
        [15.3105, 16.3542, 21.3016, 31.6127, 31.6233, 10.3391, 13.0564],
        [13.1334, 15.7519, 15.2870, 15.4916, 20.8033, 31.3300, 29.8779]])
ground truth
tensor([[30.3430, 11.7489, 17.1202, 14.9802, 15.7880, 32.5397, 39.7109],
        [ 8.1536, 12.6775, 11.4150, 21.5018, 18.0694, 28.0642, 28.7349],
        [12.9110,  0.0000,  0.0000, 14.8384, 35.4025, 33.6593,  8.3617],
        [13.4354, 19.4586, 24.9858, 36.0969, 34.7789, 19.8696, 14.2432],
        [12.7693, 14.9093, 25.9070, 36.8764, 30.6264, 12.7409, 14.2715],
        [36.1395,  

batch_predictions
tensor([[ 7.4611,  8.9074, 11.9817, 12.8587,  8.1147,  8.1994,  7.5527],
        [24.4495, 33.8875, 43.3259, 39.4281, 28.3577, 24.9976, 22.4172],
        [16.1217, 16.6846, 23.4494, 40.2325, 43.6220, 31.9380, 16.9906],
        [15.0650, 24.8940, 24.0768, 12.0944, 13.1104, 11.5448, 12.0250],
        [22.2287, 25.3640,  9.9454, 12.5347, 12.6049, 12.8563, 16.0661],
        [12.4150, 11.9193, 18.3768,  7.7584, 11.0192,  8.6574,  8.6425],
        [14.6984, 16.3132, 15.3631, 17.1217, 27.2442, 38.3841, 27.2201],
        [ 7.6362,  7.7612, 10.6732, 10.5639, 11.1011, 17.5555, 22.2474]])
ground truth
tensor([[ 7.1145,  7.7948, 12.3583,  9.8356,  8.4467,  7.7948,  6.2217],
        [25.0142, 38.7755, 61.8481, 20.0822, 26.4739, 27.9195, 29.5777],
        [18.5516, 16.9076, 20.6633, 30.9949, 46.7687, 34.8356, 19.1610],
        [21.7545, 26.1338, 26.1763, 14.9660, 15.1502, 11.0261, 12.7551],
        [20.6602, 26.9726,  9.7449, 11.4150, 12.0594, 14.0847, 20.7128],
        [11.8764, 1

tensor([[13.3362, 17.1485, 15.6888, 12.2874, 19.4303, 27.6927, 28.5998],
        [11.0261, 20.4082, 37.2166, 28.1604, 14.6967, 13.2511, 21.1168],
        [ 8.9711,  3.0754, 17.3469, 14.3707, 14.1156, 19.5011, 20.4223],
        [ 8.6735, 15.3061, 21.9671, 19.7846,  9.4813, 12.7409,  8.6451],
        [14.7422, 13.1773, 18.3588, 16.4387, 21.7780, 33.8638, 31.1415],
        [11.6649, 20.3314, 29.6686, 30.1289, 14.5581, 11.0994, 13.9400],
        [16.7517, 15.4478, 22.1797, 44.8696, 36.2528,  9.3963, 10.6718],
        [12.1740, 15.0510, 22.4065, 31.3917, 16.9218,  4.2375, 12.2591]])
batch_predictions
tensor([[13.6945, 14.0995, 18.7288, 30.8868, 25.8575, 12.5893, 15.5981],
        [32.7649, 36.1967, 28.5276, 14.6324, 16.1339, 15.7142, 17.7077],
        [12.7930, 13.1016, 16.8625, 23.8325, 25.0449, 12.2615, 13.2439],
        [15.1063, 15.4714, 22.1396, 31.9398, 33.5229, 13.1186, 15.7978],
        [15.7333, 16.1200, 15.6709, 18.1024, 31.5253, 45.7337, 31.1990],
        [23.2214, 33.5236, 28.86

batch_predictions
tensor([[27.7180, 14.6133, 11.6965, 12.4493, 13.2186, 13.7107, 25.5215],
        [27.0489, 33.6825, 31.4346, 14.0497, 15.9334, 16.2647, 17.9761],
        [10.5334, 11.1859, 11.1334, 13.1082, 19.5746, 27.2204, 14.9266],
        [10.6474, 11.6667, 11.7092, 15.6201, 24.2976, 23.5530,  8.8890],
        [42.9838, 42.8785, 18.6101, 20.4014, 19.3991, 23.1019, 32.4523],
        [17.3893, 18.4684, 34.8549, 33.4128, 29.8389, 17.7431, 16.6927],
        [13.4374, 14.2113, 15.3238, 27.7476, 38.4339, 31.9401,  9.9494],
        [17.8411, 17.0683, 16.0910, 26.1122, 37.9228, 36.4222, 20.7569]])
ground truth
tensor([[24.9433, 23.1576, 11.9048, 11.0969,  9.9206, 11.9615, 29.0249],
        [23.1983, 34.1005, 26.8411, 19.2136, 21.3966, 17.4382, 19.1215],
        [13.9739,  9.1978, 11.6071, 14.2432, 18.1973, 29.0533, 11.8906],
        [11.3361, 13.3877, 14.4266, 16.2283, 24.3556, 23.2378,  6.5623],
        [37.0748, 45.1105, 14.8668, 12.1882, 21.5986, 20.8333, 43.3532],
        [13.3645, 1

batch_predictions
tensor([[17.6517, 16.8062, 18.2150, 22.3852, 25.9644, 16.5987, 17.1506],
        [14.0674, 20.0767, 28.3863, 29.1944, 11.7904, 14.8161, 14.1285],
        [33.0788, 35.8721, 19.1804, 17.8224, 17.2855, 18.9859, 25.1028],
        [12.4938, 14.8023, 14.6489, 15.9654, 20.0137, 27.3094, 23.2930],
        [ 8.3256,  8.5593, 11.4671, 16.2056, 16.0748,  6.5125,  7.7784],
        [33.2591, 44.8218, 40.6515, 29.1937, 26.5521, 26.2324, 25.7501],
        [14.8874, 15.3833, 15.1880, 19.9980, 28.7620, 26.9537, 11.6870],
        [22.3008, 18.3561,  8.3150, 10.2118, 10.3465, 10.6972, 17.2074]])
ground truth
tensor([[23.4127, 21.3861, 23.8095, 27.4660, 37.4008, 21.9955, 23.6961],
        [14.2162, 18.2141, 22.2646, 29.2215, 10.6391, 13.8348, 12.6644],
        [40.8163, 33.6026,  0.0000, 10.6293, 19.5437, 17.3044, 22.1514],
        [13.4779, 14.6684, 12.8543, 10.7143, 20.1672, 26.8991, 29.4218],
        [ 7.1936, 10.7838, 12.6381, 15.0053, 11.7833,  4.2609, 12.0463],
        [39.6400, 5

batch_predictions
tensor([[ 9.1527, 12.8184, 13.3245, 13.7638, 19.3888, 29.4615, 28.6391],
        [14.6482, 15.5411, 20.9281, 33.1149, 32.3073, 11.0322, 16.0577],
        [14.0047, 17.4900, 17.0474, 16.5237, 18.5583, 26.1550, 37.5585],
        [13.5729, 12.5366,  8.1604,  7.4733,  7.8588,  7.8008, 11.2925],
        [23.1564, 11.3814, 13.4328, 13.2149, 13.0094, 15.0283, 22.2653],
        [19.7341, 26.0737, 24.3184,  9.4549, 13.6280, 13.2618, 14.9137],
        [12.2821, 13.0075, 17.8487, 24.2467, 23.4840, 11.0061, 12.3646],
        [14.3662, 15.7149, 24.2730, 32.9534, 28.6006, 12.6475, 13.7734]])
ground truth
tensor([[ 6.5229,  9.9816,  9.2451,  8.7980, 14.3477, 18.2930, 30.3130],
        [10.9694, 15.2778, 22.8883, 32.8940, 28.9683,  9.5096, 13.0952],
        [15.6497, 17.8722, 16.4782, 22.8169, 27.7486, 24.9605, 36.3624],
        [16.8201, 17.5171,  6.7596,  7.7196,  7.3645,  8.8769, 10.6128],
        [22.6460,  9.1268, 13.1378, 13.2167, 11.3361, 15.6365, 19.7528],
        [22.3567, 2

batch_predictions
tensor([[12.1411, 14.2432, 14.0040, 15.0215, 20.6364, 33.9470, 31.2576],
        [13.4572, 13.9912, 14.5937, 20.7662, 27.9564, 18.5103, 10.3928],
        [12.1291, 15.6806, 24.7669, 27.4516,  9.2194, 11.9081, 11.3590],
        [17.3976, 15.7488, 16.6104, 16.5369, 20.8719, 30.8560, 29.3011],
        [ 7.2456,  9.3590,  8.7452,  8.3842, 11.7301, 16.2545, 10.8757],
        [20.0587, 18.1106, 19.8077, 27.2020, 36.3977, 19.8430, 19.8820],
        [28.1596, 10.8613, 14.3545, 14.2032, 13.5269, 15.5614, 24.6464],
        [14.8513, 15.8060, 25.9559, 34.0266, 34.3319, 11.7891, 15.7216]])
ground truth
tensor([[15.2778, 10.0057, 15.9014, 13.4637, 19.8129, 40.4337, 35.2608],
        [16.3204, 14.8211, 11.5071, 18.0694, 26.0652, 22.9748,  9.5345],
        [ 7.9432, 20.0947, 24.6055, 14.0058, 13.6770,  8.6928,  9.4161],
        [49.8948, 23.3167, 28.6691, 25.3419, 22.4750, 48.3824, 43.4903],
        [ 4.3924,  7.4040,  8.0747,  7.3514, 11.3887, 15.9258, 14.9527],
        [22.8883, 2

batch_predictions
tensor([[10.9178, 11.4802, 14.8022, 22.7480, 23.3650,  7.9825, 10.1709],
        [29.3395, 31.8618, 11.7012, 15.6119, 15.1824, 14.9932, 16.6810],
        [12.8166, 15.4903, 22.1871, 32.4865, 34.9799,  9.6089, 12.2486],
        [30.4465, 12.9721, 16.1438, 15.6670, 16.4682, 19.1544, 33.2016],
        [11.0395, 12.1197, 15.9075, 25.7323, 26.2389,  8.1145, 11.4191],
        [17.3125, 18.3728, 20.2056, 36.2200, 49.3702, 43.2086, 26.6719],
        [11.3597, 14.6284,  9.0756,  6.7263,  6.9260,  6.7347,  8.9729],
        [20.8538,  8.4099,  9.7350,  9.5847,  8.6139, 10.8554, 19.5283]])
ground truth
tensor([[11.0600,  9.2583, 11.4282, 23.3561, 25.2762,  8.7585, 11.3361],
        [27.7069, 29.8611, 12.5709, 15.6604, 13.3078, 17.3044, 17.4036],
        [12.6276, 18.7358, 16.4966, 27.4235, 30.6831,  0.1559, 10.2041],
        [30.8957, 16.7092, 21.2727, 17.7721, 20.9325, 23.2285, 30.4422],
        [14.4792, 15.3077, 20.2920, 22.6328, 31.2073,  4.9579, 11.5597],
        [24.9858, 1

batch_predictions
tensor([[14.0450, 15.4378, 18.7329, 25.8021, 25.6713, 13.7465, 15.2051],
        [22.0499, 10.4283, 12.8374, 12.4332, 13.0769, 19.6221, 27.7718],
        [15.1086, 21.8647, 32.1641, 30.2113, 10.3653, 14.0934, 14.1855],
        [14.1816, 20.6098, 29.4905, 14.3451, 13.2578, 12.4881, 12.8072],
        [15.6584, 17.0039, 17.5110, 23.1807, 39.2112, 41.6904, 30.7102],
        [ 7.4896,  9.5179, 10.0497, 10.9031, 16.3233, 23.1409, 11.8793],
        [27.2762, 12.4063, 14.3037, 14.1263, 15.0168, 25.4819, 38.1412],
        [27.0967, 12.6779, 14.8205, 13.7973, 14.0728, 17.6157, 21.0405]])
ground truth
tensor([[16.5045, 17.5829, 17.7670, 23.0668, 23.8953, 10.7575,  1.5255],
        [35.5159, 15.8730, 15.5612, 11.4938, 13.1378, 19.2744, 23.8662],
        [17.6871, 18.0981, 27.1117, 24.8724, 12.1315, 14.9093, 12.8118],
        [13.2937, 20.3656, 36.5363, 21.5703, 20.7058, 20.2664, 20.7341],
        [15.8075, 18.3982, 16.0836, 21.0021, 31.0889, 33.9690, 17.0042],
        [ 5.4708, 1

batch_predictions
tensor([[14.7584, 13.6048, 13.6351, 17.2428, 24.8664, 27.9944, 11.3228],
        [10.8478, 14.2395, 13.7070, 15.6380, 22.6360, 37.4757, 29.2656],
        [10.5043, 13.6165, 14.8635, 15.8633, 20.8947, 29.6491, 27.5861],
        [42.2678, 42.4955, 28.1143, 27.9282, 27.7919, 31.0640, 40.3153],
        [16.1915, 16.7699, 15.9274, 15.8553, 17.2153, 22.6013, 28.5050],
        [33.2181, 13.2838, 16.8511, 15.6899, 16.5859, 22.1381, 37.3813],
        [12.6656, 13.7846, 13.4365, 13.6416, 18.4986, 28.1815, 28.8507],
        [33.0630, 24.6376, 21.1084, 18.0220, 16.4482, 24.4870, 37.0476]])
ground truth
tensor([[15.5576, 15.5445, 15.4524, 15.8206, 24.0794, 26.4203, 10.3104],
        [ 5.7965, 12.6701, 11.1820, 14.2149, 19.9688, 34.2971, 31.1508],
        [12.6381, 16.4256, 16.9385, 16.7806, 18.3719, 30.6812, 28.9716],
        [44.1741, 55.1946, 58.8112, 45.4761, 40.8206, 45.1210, 45.1604],
        [20.5813, 19.4371, 15.0973, 16.3861, 13.9795, 18.4640, 31.9963],
        [39.7251,  

batch_predictions
tensor([[40.5862, 37.0821, 39.2054, 43.0457, 47.7053, 46.9839, 40.5450],
        [11.8487, 12.6294, 13.4166, 19.9277, 31.1452, 20.8574,  7.2104],
        [16.7727, 17.7911, 19.0538, 35.0244, 44.5167, 36.1563, 13.9686],
        [13.6223, 17.8583,  6.9019,  6.6972,  6.9972,  7.2873,  9.1371],
        [20.5059, 23.5913, 22.1927,  9.7005, 13.2669, 12.9736, 14.2146],
        [39.9601, 20.6827, 14.4451, 15.2642, 14.8916, 19.3372, 33.0043],
        [20.3212, 28.9877, 44.2408, 40.6309, 27.0491, 22.5336, 19.6101],
        [15.5851, 15.8051, 13.5279, 13.3781, 12.6275, 14.6899, 20.6208]])
ground truth
tensor([[41.8201, 37.6381, 38.2167, 41.5834, 57.9037, 48.6981, 44.2136],
        [ 5.3713, 13.6338, 15.8872, 15.8163, 25.2834, 21.3010, 10.1190],
        [18.0556, 14.1865, 17.0068, 32.4405, 52.0975, 35.7993, 18.6791],
        [17.3067, 18.7796,  4.7080,  7.1410,  5.6418,  6.5755, 13.2956],
        [15.0794, 22.6616, 21.7120,  7.7098, 12.6701,  9.8356, 10.4308],
        [56.2642, 2

batch_predictions
tensor([[11.1849, 12.5434, 17.5606, 19.0054,  6.9881, 11.8924, 12.6827],
        [13.7318, 14.2730, 20.7483, 34.8478, 39.1191, 11.6711, 15.6801],
        [14.8623, 19.0198, 27.2414, 26.7076, 14.6877, 15.4770, 14.9132],
        [17.5525, 16.2019, 16.7679, 28.4795, 42.3720, 34.8616, 21.5121],
        [14.9542, 19.4466, 26.5186, 32.6827,  7.3652, 14.6680, 15.7113],
        [28.9886, 15.2519, 16.5994, 15.7830, 16.5841, 19.4539, 29.6089],
        [12.0022, 18.6978, 22.1889,  8.5695, 11.8990, 10.3040, 11.1570],
        [12.4721, 22.1282, 29.8073, 13.4755, 10.9719, 10.7295, 12.3408]])
ground truth
tensor([[15.7680, 20.6733, 17.3593, 24.9079, 12.9537, 14.2162, 21.4361],
        [ 6.8648, 12.6775, 19.9500, 30.7601, 33.6533, 12.3225, 15.1236],
        [13.0527, 15.3345, 29.7761, 26.7999, 15.3345, 18.1122, 14.6117],
        [18.0556, 14.2715, 17.7863, 30.0454, 42.4603, 31.2642, 20.4790],
        [15.9014, 18.7500, 25.1134, 29.4359,  5.5981, 12.8543, 10.4167],
        [30.7256, 1

batch_predictions
tensor([[10.4134, 12.3896, 12.3639, 11.0555, 15.5936, 21.5998, 21.9781],
        [19.5521, 12.1178, 10.1753, 12.3155, 15.4887, 19.5960, 20.0546],
        [14.3126, 20.9374, 28.3390, 28.9706, 11.8160, 13.6605, 14.2621],
        [43.1618, 30.5926, 29.2062, 27.6699, 29.2686, 37.7276, 49.0577],
        [17.3830, 19.1271, 23.5068, 36.5510, 42.4956, 28.3373, 16.3855],
        [11.4227, 15.5378, 14.7721, 15.4208, 20.8564, 35.5952, 28.6483],
        [14.5577, 14.4264, 18.1951, 28.1929, 32.4442, 12.3561, 15.5696],
        [10.7456, 10.6600, 11.1643, 14.7770, 22.3544, 17.6606, 15.3866]])
ground truth
tensor([[ 9.0873, 10.0999, 10.6523, 14.2031, 19.7922, 20.9495, 20.1078],
        [22.0410, 15.2551,  9.9684, 13.3745, 16.0179, 16.5439, 14.5581],
        [15.5896, 20.0113, 31.8027, 27.9478,  4.4076,  0.6236,  0.3260],
        [50.6519, 31.7319, 24.9433, 20.9042, 23.3277, 36.7914, 58.6876],
        [19.9369, 21.6991, 23.5271, 30.9048, 42.5434, 45.4629, 19.3582],
        [11.7489, 1

batch_predictions
tensor([[14.4854, 15.2179, 23.1577, 31.6682, 28.6361, 12.0416, 15.8750],
        [12.5611, 11.6205, 10.8596, 11.7460, 16.8156, 25.9130, 23.1723],
        [22.4717, 23.4666,  7.9879, 11.0501, 10.7338, 10.8343, 15.1485],
        [26.9751, 18.7299, 17.4061, 16.8533, 20.4430, 27.5545, 27.4148],
        [16.2383, 16.6523, 21.6287, 38.4359, 42.7201, 25.2917, 16.0608],
        [13.6858, 15.3856, 21.6166, 23.8521, 24.7909, 10.0520, 12.6985],
        [ 8.1763,  8.9809, 10.3181, 10.9992, 13.9268, 16.3694, 20.0007],
        [13.4619, 13.8862, 19.7919, 25.2181, 25.5842, 11.4611, 13.7558]])
ground truth
tensor([[16.3690, 21.0743, 26.0346, 40.4337, 38.0811, 15.7313, 17.2902],
        [ 0.0000,  9.3503, 11.7044, 10.6391, 19.1084, 25.5129, 22.0279],
        [24.3425, 21.1731,  9.1136,  8.8112, 11.0600,  9.2583, 11.4282],
        [22.2383, 21.3966, 24.2635, 19.5160, 19.3319, 27.8538, 32.5750],
        [19.4161, 15.1361, 22.7183, 50.5811, 24.0221, 31.4059, 10.0624],
        [12.5329, 1

batch_predictions
tensor([[12.4666, 13.6121, 15.8331, 25.2548, 33.6417, 27.5117, 12.6811],
        [22.0974, 33.1618, 36.0087, 11.7494, 14.7481, 14.4020, 16.8562],
        [12.6233, 16.1462, 21.2551,  9.5962,  8.8780, 11.4176, 13.8096],
        [19.1268,  7.4024, 10.6068,  9.9503,  9.7652, 10.3904, 19.3945],
        [19.4321, 33.3267, 42.9375, 37.3036, 16.9750, 18.4616, 17.3070],
        [10.5891, 10.4810, 16.2941, 23.0886, 15.1048,  7.8075, 11.4686],
        [11.8897, 13.2924, 12.0778, 13.0823, 15.1672, 21.8294, 16.4395],
        [45.1551, 42.1691, 40.5327, 38.8516, 37.0670, 37.1249, 38.4832]])
ground truth
tensor([[10.7710, 12.5992, 15.5329, 22.3498, 35.7143, 27.5085, 12.8260],
        [27.1825, 26.0913, 27.6502,  7.9223,  7.1003, 12.6276, 18.7358],
        [11.0731, 14.5844, 16.2941, 14.4398,  5.9179, 11.5860, 11.9805],
        [17.6618,  4.9448,  5.5760,  4.9974,  8.4692,  7.2593, 19.1873],
        [30.6689, 38.6338, 60.1616, 47.1655, 37.1599, 24.9858, 19.6995],
        [10.0079, 1

batch_predictions
tensor([[18.3309, 16.3672, 15.3196, 16.3291, 22.5817, 36.4940, 26.9512],
        [14.8294,  7.0783,  8.3618,  7.8264,  7.8935, 11.8850, 17.3759],
        [12.9501, 12.9256, 13.6014, 22.4163, 31.1538, 23.0432, 11.3679],
        [ 9.4461, 10.3701, 12.2545, 13.0170, 16.8248, 21.7270, 24.0846],
        [17.4568, 20.2415, 22.8645, 14.2426, 11.9945, 11.8471, 13.1110],
        [12.7233, 19.0091, 11.5474,  6.2051,  5.6724,  8.7890, 10.5188],
        [14.0008, 14.3989, 22.0257, 31.9597, 30.9044, 11.2763, 14.9737],
        [22.6962, 18.2979,  6.9214,  8.7351,  8.9562, 10.7877, 17.5975]])
ground truth
tensor([[19.0476, 18.1831, 13.3220, 16.3407, 21.8254, 37.4150, 31.8736],
        [15.1236,  4.4582,  9.6134,  8.8901,  7.5355, 11.2309, 14.9658],
        [13.6054, 10.7426, 14.7392, 26.9133, 35.8135, 21.6837, 12.0748],
        [10.6786, 12.5855,  9.4424, 11.6386, 17.9511, 24.5923, 28.1825],
        [13.5192, 22.2120, 20.7654, 10.4813, 13.4140, 14.5713, 16.2415],
        [10.8844, 1

batch_predictions
tensor([[21.8639, 10.5833, 12.7352, 11.9767, 11.5824, 15.0815, 21.5230],
        [16.2638, 19.1116,  6.6897,  7.0047, 12.0779, 14.8014, 16.1399],
        [20.8445,  6.9623,  9.5858,  9.4006,  9.7799, 11.5894, 18.8969],
        [15.4333, 20.9037, 26.3869, 29.6669,  9.8658, 13.7566, 14.4699],
        [29.1661, 27.2606, 13.9888, 15.9982, 15.5679, 16.2126, 21.1036],
        [11.4525, 16.4944, 22.6294,  7.7645, 10.0058,  9.1542,  9.7600],
        [ 9.9359, 11.7817, 12.0084, 15.0112, 20.2745, 22.2178, 17.6652],
        [24.3457, 26.1202, 10.6971, 15.0137, 15.5001, 16.2215, 20.8577]])
ground truth
tensor([[19.7846,  9.4813, 12.7409,  8.6451, 13.3787, 17.2052, 21.9671],
        [19.4586, 31.9161, 33.4609,  4.3793,  4.5635, 34.6939, 38.7472],
        [22.1857,  3.5245,  8.8243, 10.4024, 11.8359,  8.9821, 15.2946],
        [15.0184, 16.5308, 24.3556, 26.4072,  8.7849, 14.1899, 12.3225],
        [29.6160, 29.3135, 12.6907, 20.6602, 15.5050, 15.8075, 18.5034],
        [ 8.9953, 1

batch_predictions
tensor([[33.8524, 46.9055, 42.6703, 32.1852, 26.5518, 27.6890, 27.6012],
        [16.1963, 16.8096, 16.2375, 18.7663, 30.2796, 43.3871, 34.1358],
        [39.5678, 32.2909, 14.7122, 16.8257, 16.1896, 17.9543, 32.8026],
        [13.2253, 13.2412, 13.6433, 22.8562, 36.7708, 30.0751, 12.1415],
        [11.7373, 12.6932, 14.4454, 21.1676, 27.2384, 11.7068, 10.1365],
        [13.4966, 19.8675, 29.0872, 22.2761, 11.9608, 12.9167, 13.1074],
        [20.5352, 18.2581, 33.4659, 40.5325, 33.9441, 26.6547, 23.6530],
        [15.6695, 17.9207, 25.9238, 31.6033, 12.2802, 15.7356, 15.9167]])
ground truth
tensor([[37.0465, 60.9269, 49.4048, 30.6548, 27.6927, 27.5935, 34.5805],
        [18.1746, 27.5644, 20.6602, 16.8464, 18.4903, 31.6018, 33.0615],
        [47.1514,  2.0408, 12.3724, 20.7200, 21.9529, 15.8588, 33.5176],
        [13.6054, 12.1032, 12.1315, 16.6100, 24.4189, 28.4439, 18.6791],
        [15.0085, 13.2086, 16.1990, 26.5164, 28.1037, 25.8503,  9.0136],
        [14.6502, 2

batch_predictions
tensor([[ 6.9314,  6.2901,  6.1409,  6.0099,  6.7304, 10.7386,  9.8141],
        [13.2199,  9.3143,  7.2559,  8.1665,  8.3644,  9.3096, 13.6315],
        [24.6917,  8.7254, 12.5739, 13.4344, 14.4390, 15.4808, 21.5019],
        [14.1268, 17.3502, 27.3520, 26.0059, 12.8225, 13.4270, 13.2392],
        [ 6.0625,  6.2108,  6.8125, 11.6005, 14.6404,  6.0453,  6.2773],
        [13.8892, 15.1688, 15.5288, 18.5537, 24.9119, 27.8940, 26.4646],
        [ 8.8719, 10.5011, 11.1204, 17.7491,  6.9740,  6.1757,  7.3729],
        [15.6403, 16.2030, 15.7799, 24.7220, 39.1350, 38.9194, 21.5797]])
ground truth
tensor([[ 7.7523,  8.4184,  5.4563,  5.8815,  6.7035, 12.1032, 11.2670],
        [12.1120,  7.1147,  5.2472,  8.2588,  8.7980,  9.5213, 13.2693],
        [26.4729, 10.5997, 10.0342, 11.4940, 10.2446, 15.3603, 17.6749],
        [10.8560, 20.4365, 21.8963, 22.6899, 15.5612, 12.5709, 10.2608],
        [ 5.5130,  5.3288,  7.6814,  9.2829, 11.1536,  6.1224,  6.9586],
        [13.6763, 1

batch_predictions
tensor([[14.4664, 14.7880, 15.8665, 26.4393, 37.0905, 30.3259, 11.8226],
        [12.0482, 11.5077, 11.9006, 15.7676, 25.0446, 25.2174, 11.8473],
        [46.9777, 38.4421, 21.4832, 20.7229, 20.3114, 24.5148, 38.3451],
        [21.8262, 20.1411, 22.0558, 31.0650, 30.8740, 34.7177, 26.3799],
        [30.1119, 14.1876, 16.6078, 16.1937, 18.4967, 24.7573, 34.0922],
        [14.3901, 14.6449, 14.8351, 23.6994, 33.1546, 26.1618, 10.3954],
        [ 7.4706,  6.5546,  7.5214, 10.7257, 17.6787, 21.1559,  6.0514],
        [33.2060, 39.9677, 33.6650, 14.0441, 16.4049, 16.0922, 18.3373]])
ground truth
tensor([[11.6497,  8.8294, 12.7834, 24.8724, 42.0777, 32.1570,  4.9745],
        [10.7049,  9.6528, 10.0999,  6.8780, 25.2236, 23.0537, 11.2572],
        [62.0200, 50.1578, 23.9085, 21.9621, 17.7407, 18.0826, 33.3640],
        [33.0484, 23.1194, 25.5129, 36.3756, 43.1483, 27.6039, 24.4871],
        [34.5380, 14.1865, 16.3124, 14.6542, 13.3929, 24.7874, 35.6718],
        [13.5913, 1

ground truth
tensor([[13.0458, 10.3104, 15.1762, 14.2425, 23.6323,  5.9705, 11.6649],
        [14.9802, 28.9399, 40.9580, 37.5283, 12.6417, 16.4399, 14.7676],
        [14.7817, 19.0426, 17.2672, 22.2777,  4.5634,  6.1152,  9.1136],
        [ 9.7364, 14.2574, 15.2920, 25.0283, 28.8549,  8.7160, 13.2795],
        [21.2585, 33.7443, 33.5317, 13.3362, 17.1485, 15.6888, 12.2874],
        [43.7783, 26.2188,  0.0000,  3.4864, 15.6321, 16.5675, 25.8220],
        [ 1.5518,  3.2746, 10.7969, 14.6107, 21.8438,  4.2609,  7.0884],
        [34.8214, 38.9031,  6.5051, 18.5941, 14.6684, 11.1820, 29.7194]])
batch_predictions
tensor([[11.6039, 12.4902, 14.0646, 17.8681, 27.5200, 29.6511,  9.4583],
        [31.9337, 35.2640, 23.1910, 23.2250, 21.9611, 22.9853, 26.1015],
        [15.3224, 21.0100, 39.8098, 37.0187, 20.5460, 17.5222, 16.5787],
        [13.8680, 15.7674, 15.5969, 17.1879, 23.4030, 39.1385, 31.7264],
        [24.9557, 26.1116, 29.6863, 13.4502, 13.9219, 13.7194, 15.3545],
        [25.4863, 2

batch_predictions
tensor([[12.9507, 20.2455, 21.7849,  7.2077,  6.5013,  8.4917, 10.5068],
        [25.7451,  8.9954, 12.2297, 13.5280, 13.8500, 14.4139, 18.3936],
        [12.2499, 14.5639, 19.2176, 23.4257, 12.9597, 11.4076, 11.7004],
        [17.4322, 19.8360, 29.3271, 37.2815, 21.7173, 16.8651, 17.3867],
        [11.4269, 11.2421, 12.9598, 15.4843, 13.3810, 13.8200,  9.3120],
        [13.4898, 13.5744, 14.9056, 24.8479, 33.2725, 22.3562, 11.6898],
        [15.3131, 23.8500, 20.7493,  9.7270, 10.7466,  9.5123, 10.7335],
        [ 9.9601, 12.0179, 13.5152, 16.3686, 19.9095, 26.5487, 24.2269]])
ground truth
tensor([[14.9921, 19.8974, 16.7412,  3.6691,  6.5360,  7.5355, 10.8759],
        [24.4740,  8.3903, 12.1515, 13.2562, 10.8101, 13.5718, 24.6318],
        [10.4813,  9.8238, 17.9642, 22.8301,  7.4698,  7.8906,  6.9832],
        [23.5402, 25.9206, 36.2441, 48.3167, 43.0037, 14.4792, 19.9369],
        [ 0.0000,  0.0000, 16.5965, 21.1336, 16.9122, 20.3183, 20.3972],
        [18.2965, 1

batch_predictions
tensor([[11.5703, 10.1256, 11.3462, 14.7697, 19.3156, 20.5440, 22.4970],
        [13.5388, 19.1872,  9.8571,  7.5394,  7.5978,  7.7093, 11.1310],
        [31.2708, 12.1749, 13.5805, 13.9579, 14.9750, 16.0345, 25.4389],
        [22.9764, 28.0828, 10.2277, 12.9208, 12.1384, 11.9314, 13.4955],
        [10.6331, 11.0458, 19.0713, 21.2807,  7.8848,  8.3670, 10.2655],
        [15.2113, 12.8248, 13.4834, 12.6677, 15.5770, 20.8419, 17.8464],
        [40.0167, 28.7363, 21.3017, 19.8013, 15.9372, 16.1383, 29.2919],
        [13.3925, 17.8884, 24.4186, 25.9535, 13.6971, 13.3075, 12.8005]])
ground truth
tensor([[15.2551,  9.9684, 13.3745, 16.0179, 16.5439, 14.5581, 19.2925],
        [12.5592, 17.4250, 21.4361,  8.6665, 11.5992, 11.2441, 11.6781],
        [32.9826, 33.6928, 12.9274,  9.5345, 14.3609, 22.8958, 19.7265],
        [28.5508, 28.7086, 10.5339, 11.5071,  9.9421, 11.4940, 13.5850],
        [11.5071, 11.8885, 13.4008, 17.7407,  6.7070,  8.0484, 11.5992],
        [12.2304, 1

batch_predictions
tensor([[16.0450, 16.8561, 17.0663, 18.8989, 27.8406, 29.9602, 27.5502],
        [19.1249, 25.4569, 21.1246, 11.5719, 13.7899, 13.3120, 14.6157],
        [ 9.3856, 14.2218, 14.3400, 14.4266, 17.7230, 23.4987, 20.7294],
        [14.2056, 20.0307, 26.3837, 26.7695,  8.9928, 12.4227, 13.6290],
        [11.3770, 11.3896, 22.5072, 26.0231,  8.7347, 11.8607, 12.8354],
        [17.6318, 22.4000, 35.6807, 38.4996,  9.3190, 16.4889, 16.7405],
        [17.6556, 20.3861, 35.2528, 35.6312, 11.0137, 16.6694, 15.8935],
        [13.4733, 13.9059, 15.5036, 19.1683, 23.0210, 21.2000, 20.9152]])
ground truth
tensor([[13.3078, 14.7392, 12.3016, 14.7251, 21.5136, 29.9603,  8.7727],
        [23.4694, 29.7902, 24.8299,  0.0000,  8.7443, 13.2653, 14.4841],
        [ 6.1933, 10.5867, 17.9989, 17.5028, 16.9926, 25.2551, 26.4314],
        [13.0952, 14.3141, 21.3861, 23.0017, 10.7001, 10.0057, 10.9694],
        [ 8.1010, 12.1120, 23.6455, 26.4466, 11.3756, 12.9011, 13.2693],
        [19.0618, 2

batch_predictions
tensor([[24.7206, 34.6609, 46.6066, 43.6741, 31.1696, 27.7012, 23.5163],
        [12.8416, 14.6192, 14.2430, 16.0755, 23.4620, 33.8569, 24.8874],
        [13.5389, 17.6407, 14.2656,  7.8361,  8.9719,  9.1911,  9.3529],
        [15.7286, 16.5471, 20.9560, 30.5418, 32.7123,  9.6577, 14.8288],
        [23.3938, 25.4198, 32.3505,  9.0032, 15.3769, 15.5756, 16.6383],
        [17.1531, 25.4332,  8.3870, 11.9803, 11.1242, 10.3785, 14.3695],
        [11.6974, 12.3286, 14.3539, 22.2756, 25.6752, 18.6032, 10.9919],
        [18.6803,  6.6177,  8.1603,  7.9105,  7.9324, 10.0877, 15.7418]])
ground truth
tensor([[23.3277, 36.7914, 58.6876, 57.4830, 31.4484, 24.3764, 27.4660],
        [12.0607, 11.4938, 11.6213, 16.7659, 23.3844, 29.8895,  4.9603],
        [13.7954, 20.6865,  4.6686,  5.0763,  9.2057, 12.2567, 10.7575],
        [11.9615, 18.2965, 23.6961, 36.9756, 39.6684,  9.2971, 21.1876],
        [25.2551, 36.7772, 40.4195, 10.4734, 20.7766, 15.1077, 18.2965],
        [17.8985, 2

tensor([[20.0562, 27.5703, 35.1314, 22.3274, 21.1103, 18.3904, 17.9080],
        [ 9.8184, 11.8009, 12.9958, 13.6760, 17.0056, 26.7018, 13.8230],
        [14.9008, 19.2170, 28.6342, 29.8748, 16.2039, 15.6800, 15.1414],
        [42.4603, 15.7706, 17.5049, 17.0936, 19.4045, 33.5611, 52.4483],
        [12.2474, 13.3572, 14.7233, 19.6877, 24.6654, 11.0708, 11.8449],
        [ 6.1748,  7.3924, 10.5513, 12.1839,  5.9160,  5.6072,  6.0717],
        [12.5611, 11.6205, 10.8596, 11.7460, 16.8156, 25.9130, 23.1723],
        [ 8.3555,  7.7693,  6.9295,  8.4271, 11.8986, 13.2725,  8.8011]])
ground truth
tensor([[27.6959, 22.2251, 28.0510, 14.7817, 13.8874, 23.1326, 16.3204],
        [ 0.0000, 11.7772, 10.7143, 12.3299, 13.8605, 33.4467, 20.6774],
        [20.0552, 28.9979, 31.4834, 33.2983, 17.2015, 10.2446, 12.3225],
        [42.4645, 15.9390, 26.8017, 17.4382, 14.9527, 35.9022, 69.7265],
        [ 9.2057, 15.7417, 15.2288, 27.3935, 28.4982,  9.0216, 14.1241],
        [ 6.0516,  7.0437, 11.2103, 1

batch_predictions
tensor([[14.3351, 14.6237, 15.1771, 21.1208, 26.6578, 25.5553,  9.8385],
        [31.7027, 39.7762, 31.4194, 11.8546, 14.7818, 14.7137, 17.1070],
        [22.6416, 10.2783, 12.9672, 12.1214, 12.8795, 19.5315, 30.5337],
        [24.4793, 22.8834, 13.0106, 14.7138, 14.9706, 16.5214, 20.4069],
        [17.5325, 24.5761, 30.0651,  7.3809, 13.8122, 12.7671, 15.0757],
        [12.7467, 21.0828, 22.7716,  8.8447, 10.9837, 12.1243, 11.8412],
        [12.7728, 13.7734, 19.1593, 27.3436, 26.8560, 10.9222, 12.2164],
        [31.1204,  8.0223, 13.3649, 13.3667, 15.0570, 19.5855, 27.0238]])
ground truth
tensor([[14.0023, 14.7534, 20.6916, 29.4076, 25.2126, 21.5420, 11.4654],
        [31.3067, 44.6003,  3.4155,  0.0000, 13.2795,  9.2262, 18.0414],
        [23.1009,  6.5476,  9.8781, 11.3379, 11.5079, 33.5317, 33.4467],
        [20.8333, 21.5561, 16.9359,  6.9019, 10.5726, 16.6383, 17.9280],
        [15.7286,  0.0000, 25.1841,  4.3530, 13.1904,  2.3277, 14.9395],
        [14.2149, 2

tensor([[10.6718, 11.2103, 12.3583, 14.1015, 11.1536, 10.6009, 11.3095],
        [16.1281, 13.2653, 19.4161, 29.8895, 36.1536,  4.9603, 15.6463],
        [42.5454, 67.6446, 59.4388, 35.5726, 34.4246, 21.9104, 25.3260],
        [28.9979, 29.3398, 14.2425, 14.8738, 17.4908, 19.4371, 20.8311],
        [17.5737, 15.4337, 23.0584, 31.0374, 30.0454, 13.3078, 14.7392],
        [ 7.3129,  8.3759, 13.9031, 19.4728, 22.8316, 27.6644, 18.3673],
        [16.6100, 21.5420, 36.0261, 36.7914, 25.6094, 18.4382, 21.9388],
        [28.4982, 28.8927,  6.1021,  4.4713,  8.2194,  8.8243, 14.2031]])
batch_predictions
tensor([[ 9.7699, 13.0288, 13.7502, 14.3069, 16.2103, 22.7213, 28.1618],
        [35.5952, 33.3730, 13.4919, 16.0719, 14.6351, 14.9823, 23.4200],
        [36.8874, 44.4702, 37.4912, 25.0217, 22.8808, 21.3084, 21.7186],
        [33.8127, 33.5030, 34.1597, 33.4212, 33.8645, 35.4356, 40.2595],
        [17.9419, 18.5008, 18.7145, 20.3423, 30.8211, 42.7708, 34.7938],
        [24.0171, 24.0796, 28.69

batch_predictions
tensor([[31.1817,  9.3798, 13.6869, 14.1004, 16.4307, 22.1301, 35.1986],
        [25.1512, 30.2900,  7.2471, 13.2766, 12.9875, 13.4349, 17.1722],
        [17.7949, 25.0892, 18.3448, 12.1575, 12.8249, 11.4499, 11.7132],
        [12.8160, 20.6652, 13.7104,  7.8938,  9.0965,  8.0193,  8.6291],
        [12.3764, 13.8953, 14.4200, 16.6884, 26.6778, 33.7343, 10.3647],
        [22.7592, 12.1029, 12.5511, 11.8929, 11.3200, 12.7691, 22.1928],
        [31.3253, 12.7150, 15.5275, 14.5581, 15.1035, 24.0136, 36.3830],
        [ 6.4277,  7.4369, 11.8619,  8.9027,  5.7966,  6.4900,  6.0368]])
ground truth
tensor([[41.8509, 14.9802, 14.2857, 10.4025, 18.3815, 27.1825, 26.0913],
        [24.5607, 27.6644,  5.0879, 12.7126, 13.6054, 12.8968, 20.7341],
        [14.5692, 25.7795, 25.0709,  8.1774, 12.0181, 11.5646,  9.4813],
        [ 6.9728, 11.4371, 20.5074, 17.1202,  4.6910, 10.3175, 11.9615],
        [15.4655, 10.8890, 11.4413, 19.9106, 22.3567, 33.6139, 13.2956],
        [17.7670,  

batch_predictions
tensor([[22.0379, 27.3993,  8.0347, 12.3621, 12.9903, 13.0246, 16.4586],
        [17.9934, 25.6260, 24.2688, 13.4975, 13.8611, 12.8029, 15.0230],
        [12.7105, 14.9469, 15.4892, 17.5256, 22.4708, 30.5639, 30.5651],
        [19.8745, 11.8614, 13.6439, 12.7460, 14.1081, 22.3307, 30.7797],
        [14.1609, 16.6848, 15.5751, 16.4178, 21.8454, 37.8703, 40.7021],
        [22.0845, 11.0854, 12.0274, 12.1494, 15.3074, 20.2810, 24.3049],
        [15.4447, 15.0375, 16.0135, 26.5272, 36.9764, 29.7106, 14.7279],
        [31.3372, 24.3511, 15.0572, 15.2416, 14.5275, 14.5483, 21.8714]])
ground truth
tensor([[19.6081, 25.6312,  4.9711, 12.5592, 12.3882, 13.9137, 18.9637],
        [17.0210, 25.9921, 24.9575, 12.8827, 12.2874, 11.7489,  6.9019],
        [15.6604, 15.5896, 17.4887, 21.3010, 23.4410, 32.1429, 34.0278],
        [25.5392, 14.0847, 16.0968, 15.3077, 15.7023, 23.3430, 33.4824],
        [ 7.2279, 11.9473, 15.0794, 18.3957, 22.5198, 38.8605, 39.4558],
        [ 0.0000,  

batch_predictions
tensor([[15.0625, 12.2696, 13.8770, 15.3709, 22.6172, 32.3645, 28.9936],
        [23.3363, 22.5551, 10.9123, 13.1315, 12.1841, 12.2341, 14.4112],
        [ 7.1577,  9.1161, 11.1176, 15.9006,  7.3875,  6.7750,  7.1788],
        [11.6621, 13.5596, 18.6364, 23.8655, 18.2297, 11.2931, 11.6258],
        [10.7833, 15.3128, 15.2963, 15.4133, 18.8181, 30.7685, 30.4566],
        [15.4115, 25.5954, 39.1031, 30.1851, 12.5749, 14.5531, 15.0371],
        [13.8398,  8.2194, 10.0820, 11.3588, 12.2801, 15.0350, 22.2307],
        [ 9.9317,  9.2056,  8.8041, 10.2344, 12.3634, 20.7948, 18.5300]])
ground truth
tensor([[12.5850, 15.5896, 10.9977, 12.6134, 21.2018, 25.7228, 21.8254],
        [25.2268,  0.0000,  0.0000,  9.5096, 13.9739, 14.4416, 14.8668],
        [ 4.2800,  6.8736,  9.6088,  1.9700,  4.1383,  5.1587,  5.7115],
        [10.9022, 10.8101, 14.4661, 19.8185, 13.3482, 11.2441,  9.6923],
        [ 8.1799,  6.8254, 10.2446, 19.2136, 21.1336, 26.3282, 37.2304],
        [14.0023, 2

batch_predictions
tensor([[19.8187, 22.3657, 21.0709, 12.2633, 13.2346, 15.2957, 16.9199],
        [13.4217, 15.7443, 15.0657, 15.8784, 20.9068, 29.2122, 26.5042],
        [17.6791, 18.0087, 16.8876, 19.4726, 31.8912, 46.0967, 35.7719],
        [27.8544, 13.0116, 14.8111, 15.2573, 17.6065, 23.0570, 33.5374],
        [12.9640, 13.8165, 18.5663, 25.7608, 28.8757, 12.1128, 13.6866],
        [17.8216, 19.3126, 19.8183, 24.1561, 38.6914, 47.6730, 37.8604],
        [12.9667, 13.1874, 16.2975, 24.4634, 26.6221, 10.3831, 13.8696],
        [20.8885, 14.3254, 10.3544, 10.9771, 10.2937, 11.4270, 14.8326]])
ground truth
tensor([[ 0.0000,  8.2851, 15.0053,  7.1936,  0.0000,  0.4208,  3.5376],
        [12.9677, 19.1610, 17.7721, 18.5374, 19.3027, 25.5811, 29.3367],
        [26.2330, 20.7766, 13.5771, 17.0210, 26.9416, 43.8350, 32.8656],
        [27.8143, 10.1920, 13.3614, 14.4661, 19.5818, 28.9716, 28.3535],
        [ 9.6791, 14.2162, 18.1878, 23.2772, 22.6460, 12.0200, 12.0200],
        [20.8759, 2

batch_predictions
tensor([[20.1147, 29.4158, 23.8537,  9.7949, 13.5751, 13.1098, 14.3169],
        [ 7.4368,  7.3451,  7.4076,  9.7180, 14.1583, 16.8212,  6.1769],
        [24.8210, 26.8005, 13.5443, 13.5217, 13.8246, 15.7109, 20.6310],
        [23.9423, 20.6086, 19.3238, 18.4958, 24.1582, 43.1701, 43.9823],
        [13.6121, 15.9566, 16.4434, 20.3228, 23.0265, 24.0943, 23.2845],
        [ 7.3505,  7.7929,  9.4999, 11.7759, 17.9861, 18.6424,  8.2552],
        [27.4522, 29.5307, 17.1171, 18.1324, 16.7628, 16.4212, 19.0981],
        [14.2392, 14.1709, 14.7057, 17.7802, 22.6353, 25.6061, 22.0964]])
ground truth
tensor([[24.1583, 23.1589, 21.1731, 11.4019, 11.1520,  8.6796, 11.0600],
        [ 6.4308,  5.7207,  6.5755,  7.8906, 14.6502, 14.9132,  5.0631],
        [19.8448, 22.9879, 16.2546, 13.9795, 11.2046, 13.3351, 15.8338],
        [27.6644, 25.4960, 15.7596, 16.6950, 27.7920, 38.1803, 39.3282],
        [13.9006, 13.5981, 15.2946, 18.4377, 22.4882, 29.8527, 27.7223],
        [ 6.2730,  

batch_predictions
tensor([[25.9406,  9.8643, 12.9994, 13.1908, 13.9363, 17.2303, 25.8058],
        [12.2733, 13.7437, 20.0036, 26.0797, 25.4815, 11.3397, 13.0965],
        [17.7509, 18.7520, 33.5623, 45.4653, 40.4294, 22.7090, 20.6398],
        [40.2858, 32.6581, 25.1309, 22.1656, 21.3589, 24.2448, 34.0198],
        [23.6786, 31.1551, 24.8023, 12.8938, 14.8714, 13.9615, 15.6348],
        [17.3265, 22.4720,  6.8325,  8.2068,  8.0497,  8.5525, 11.3314],
        [12.5616, 15.7682, 24.0682, 26.0731, 17.1065, 13.6696, 13.6201],
        [14.4982, 15.6793, 25.1057, 35.7831, 30.0906, 12.7731, 14.9801]])
ground truth
tensor([[26.9726,  9.7449, 11.4150, 12.0594, 14.0847, 20.7128, 23.3956],
        [12.2699, 14.5055, 20.3446, 28.6560, 33.0221, 12.0857, 15.4261],
        [18.4382, 23.3277, 33.0499, 49.0646, 43.1122, 23.6395, 21.4711],
        [52.8203, 47.1230,  0.0000, 29.9320, 23.2710, 29.7052, 34.7931],
        [23.3430, 33.4824, 24.8553, 12.9406, 15.5708, 12.3751, 14.2425],
        [14.6107, 2

batch_predictions
tensor([[13.5114, 17.6254, 27.6108, 29.1160, 17.5825, 14.3287, 13.6735],
        [22.7917, 11.7559, 13.2229, 12.6044, 13.3630, 16.4604, 21.1651],
        [35.4727, 16.0565, 18.7984, 17.2799, 18.5549, 27.7516, 45.9860],
        [15.2355, 18.6216, 19.4113, 19.3996, 30.4250, 38.6729, 34.5585],
        [14.1434, 14.8233, 15.7793, 25.0637, 34.1273, 25.9193, 12.1230],
        [25.4625, 30.3089,  7.3449, 13.4088, 13.6431, 14.7643, 18.0818],
        [41.3433, 26.5718, 18.7809, 18.5233, 18.1452, 23.2760, 40.6329],
        [15.1477,  9.9778,  7.5382,  8.2538,  9.4833,  7.8269, 11.2922]])
ground truth
tensor([[18.4382, 27.0550, 40.5612, 35.3033, 20.4649,  8.8010, 16.6383],
        [17.7538, 11.2704, 10.8759, 13.9400, 14.9921, 15.5050, 23.6323],
        [44.3027, 16.5958, 25.4252, 22.6757, 14.5125, 25.9070, 42.4745],
        [13.7755, 25.0000, 19.7137, 17.7579, 23.9371, 41.9359, 35.7285],
        [ 9.5345, 14.3609, 22.8958, 19.7265, 33.9032, 27.7880, 11.5334],
        [21.3572, 3

batch_predictions
tensor([[24.8552, 40.0429, 35.9225, 13.3541, 16.3567, 15.6414, 17.4003],
        [22.8518, 24.4712, 35.4375, 47.6802, 46.2137, 33.3201, 25.0366],
        [18.6652,  8.8819, 10.1903, 10.6291, 10.1778, 11.3518, 18.8063],
        [13.5858, 15.3111, 18.5615, 29.2942, 36.3988,  8.1038, 13.4762],
        [25.4462, 26.7064,  8.0684, 11.9104, 11.8736, 12.7287, 16.5458],
        [19.1426, 17.8825, 21.7039, 30.4909, 31.5535, 24.8548, 19.2072],
        [19.0295, 19.9500, 21.5978, 23.6394, 23.2210, 26.5205, 33.6106],
        [11.9704, 12.5962, 13.9642, 18.7823, 26.3403, 30.4732, 10.0426]])
ground truth
tensor([[24.0788, 35.5300, 43.9059,  8.9994,  1.3039, 18.9201, 15.0085],
        [20.9042, 23.3277, 36.7914, 58.6876, 57.4830, 31.4484, 24.3764],
        [14.4398,  5.4577,  5.3919,  4.9448,  8.2720, 11.0337, 19.6344],
        [ 9.8781, 14.6542, 17.3328, 28.9966, 37.0465,  4.7619, 11.6213],
        [28.5376, 28.7743,  9.4424, 10.7706,  7.7065, 15.0316, 14.3214],
        [19.0689, 1

batch_predictions
tensor([[16.5523, 16.5994, 17.2261, 28.4035, 40.0666, 33.0147, 11.5587],
        [13.1517, 13.5525, 14.7983, 20.0656, 30.6618, 27.6272,  9.6706],
        [ 8.4277,  8.8315,  8.2361, 11.2430, 15.1136, 18.3788,  6.8455],
        [11.1730, 11.0114, 11.6336, 14.3780, 18.1231, 23.3347,  8.2118],
        [23.0928, 30.6089, 24.7252, 13.8337, 15.7088, 14.3611, 17.6872],
        [20.1740, 14.9117, 15.8056, 15.9424, 18.7987, 26.5506, 35.5804],
        [12.0775, 13.3483, 13.3318, 13.3881, 19.6960, 36.2654, 26.8653],
        [19.2572, 24.2721, 31.4268, 18.0760, 16.7396, 15.7814, 16.4730]])
ground truth
tensor([[17.2194, 19.1185, 16.4966, 29.2234, 27.2959, 33.0499, 12.5850],
        [11.8764,  9.7364, 14.2574, 15.2920, 25.0283, 28.8549,  8.7160],
        [ 6.6468,  5.3430,  7.7098,  9.4246, 13.4637, 19.8696,  5.8107],
        [11.5729, 12.1383, 12.4934, 15.5313, 15.0973,  9.3898,  0.0000],
        [22.6616, 31.0516, 35.4592, 14.3707, 14.8951, 17.4603, 17.1202],
        [38.0811, 1

batch_predictions
tensor([[29.6560, 14.4083, 15.3610, 16.8518, 19.3510, 28.1985, 38.6640],
        [14.6590, 22.8655, 17.4925,  9.0918, 10.2936, 10.5002, 11.8703],
        [13.0418, 12.7866, 12.8971, 17.7760, 27.0539, 25.2514, 12.2567],
        [14.5854, 15.7173, 15.3272, 18.3740, 28.4096, 32.3417, 25.5320],
        [ 8.3571,  9.9226,  8.9924, 12.1000, 17.4383,  7.4388,  6.3750],
        [18.0154, 17.4362, 32.6455, 50.6378, 48.2660, 30.4829, 20.8668],
        [12.1544, 13.7911, 15.8445, 27.3673, 35.5680, 21.2187,  9.7495],
        [35.0160, 18.7439, 19.5519, 18.1289, 22.9887, 37.3124, 49.1275]])
ground truth
tensor([[26.8411, 19.2136, 21.3966, 17.4382, 19.1215, 30.2735, 42.0831],
        [14.6107, 25.3288, 27.5776, 14.9001, 15.9390,  9.0216,  8.6139],
        [17.1627,  9.7364, 12.0465, 19.2885, 30.4280, 22.6616, 11.3237],
        [13.7755, 11.8481, 14.6684, 20.6633, 25.1701, 33.7302, 28.3872],
        [ 4.7344,  5.0631,  4.9448,  8.1010,  4.6291,  5.0105,  3.6297],
        [14.5844, 1

batch_predictions
tensor([[14.3307, 15.1898, 25.2525, 39.9176, 34.0571, 12.2895, 15.2856],
        [18.5128, 22.8073,  7.9941,  9.0985,  9.5863, 10.1437, 12.6083],
        [12.1127, 13.7394, 22.5171, 30.1941, 22.2459, 11.2531, 12.5224],
        [10.6683, 11.1369, 12.0608, 21.4948, 29.3378, 17.1506,  8.9208],
        [42.8420, 46.5602, 40.4269, 32.7647, 35.7477, 39.1504, 40.6527],
        [ 8.0043,  8.7841, 10.4558, 12.2818, 19.3714, 21.7086,  7.5043],
        [12.5485, 13.3786, 12.8548, 15.2264, 22.4439, 27.4021, 24.7844],
        [17.6306, 27.4683, 34.0418,  7.9489, 13.1348, 13.2611, 14.9364]])
ground truth
tensor([[16.2698, 15.6888, 25.9212, 47.4206, 34.9065, 17.3895,  8.3192],
        [23.0274, 28.0773, 10.9811, 14.1504,  7.8248, 11.3756, 12.3751],
        [15.6102, 23.1194, 34.1005, 44.3977, 12.7696, 15.4655, 10.8890],
        [ 8.9427, 10.7143, 10.8277, 22.9308, 30.3571, 18.2540,  9.0561],
        [52.1699, 63.6376, 64.6107, 41.6491, 42.5829, 45.9890, 42.5565],
        [ 8.0352,  

batch_predictions
tensor([[14.3349, 15.1163, 15.6780, 19.5709, 27.4532, 28.6715, 11.7782],
        [14.6731, 15.2773, 17.3812, 23.6426, 27.8402, 11.5729, 14.8127],
        [13.7376, 13.7052, 22.1177, 26.8708, 17.0494, 10.9490, 12.3045],
        [23.6581, 15.9625, 17.0109, 16.0683, 18.1780, 34.6775, 44.5562],
        [12.8123, 15.6634, 22.7916, 29.7701,  8.9965, 13.3953, 14.3418],
        [14.4530, 12.9848, 21.7941, 26.5076, 12.1002, 15.3847, 15.7074],
        [21.9690, 13.0583, 13.7931, 12.5492, 13.3098, 20.4049, 29.7192],
        [40.4550, 24.5682, 18.2710, 17.8589, 16.1174, 17.7739, 30.3112]])
ground truth
tensor([[16.7543, 13.5981, 13.7822, 15.1894, 22.7775, 30.9705, 10.6260],
        [11.3098, 12.6775, 16.2415, 20.1604, 24.2504,  9.9290, 12.6644],
        [ 8.6796, 11.0600, 15.2946, 27.7617, 24.0400,  8.3772, 10.0342],
        [29.3793, 15.5896, 18.8776, 19.1468, 19.6854, 33.6735, 45.0680],
        [16.9253, 14.0715, 22.3304,  0.0000,  0.0000,  9.2451, 13.1378],
        [17.9116, 2

batch_predictions
tensor([[33.6205, 19.5042, 21.0209, 19.3707, 22.9696, 34.8577, 46.6547],
        [18.5563, 20.3087, 22.4690, 20.5661, 12.2635, 12.5232, 16.4430],
        [24.5448, 10.5117,  9.3555, 11.1227, 11.8301, 13.9533, 18.3335],
        [12.8714, 17.6863, 26.2926, 30.4197, 16.9735, 17.1607, 16.9456],
        [21.2926, 19.4871, 19.0672, 22.6794, 34.1974, 47.1024, 36.5375],
        [19.7263, 31.5291, 30.8593, 12.3554, 14.9256, 14.0686, 15.3098],
        [26.8474, 14.6994, 16.2461, 15.3345, 20.1516, 35.0520, 43.6586],
        [ 9.8053, 12.2074, 13.0262, 13.6916, 18.4268, 24.5871, 20.5056]])
ground truth
tensor([[40.0085, 17.0777, 21.1310, 17.1910, 24.9008, 37.4008, 50.3827],
        [17.3593,  0.0000,  8.2851, 15.0053,  7.1936,  0.0000,  0.4208],
        [28.8974, 29.7761, 12.3583, 13.2937, 15.0368, 14.4700, 18.0414],
        [13.3362, 30.2438, 54.3934, 45.8050, 17.1202, 12.8543, 13.9172],
        [19.1215, 20.2656, 26.5387, 23.5928, 31.4966, 42.4776, 46.1599],
        [19.8129, 4

batch_predictions
tensor([[ 8.8160,  8.9145, 11.1569, 16.1188, 22.5918, 10.5831, 10.5879],
        [12.8499, 14.6526, 23.0844, 30.8639, 23.2638, 11.1272, 14.0484],
        [16.5464, 16.7530, 26.3289, 38.3902, 34.5858, 15.7768, 16.7760],
        [ 9.8350, 12.9567, 13.5992, 13.3446, 16.3716, 23.9357, 23.9502],
        [13.6653, 13.9265, 13.8574, 18.1133, 27.4302, 31.5167, 27.5505],
        [40.7659, 29.6801, 13.8026, 15.7676, 15.6568, 21.1046, 36.8938],
        [22.5127, 11.9973, 14.2002, 13.5760, 14.1820, 19.6719, 27.6697],
        [16.3445, 16.9008, 15.4077, 24.1476, 42.6592, 40.1984, 16.2510]])
ground truth
tensor([[10.4156, 13.3614, 18.2799, 27.3540, 36.9148,  7.8774, 11.2309],
        [11.3095, 16.5249, 24.2063, 27.2959, 24.9150,  5.0595, 12.1032],
        [20.3709, 19.8054, 25.4340, 42.0568, 28.3009, 21.0021, 24.3161],
        [ 8.4035, 11.5203, 12.9011, 11.6123, 14.4661, 22.8958, 26.9463],
        [17.8571, 15.3628, 12.2874, 16.8367, 25.0425, 43.1122, 10.9552],
        [50.2126, 2

batch_predictions
tensor([[32.0177, 26.0300, 20.4897, 19.0920, 17.4670, 18.4607, 28.1886],
        [15.4034, 15.2770, 17.2538, 25.6349, 38.7859, 30.1925, 13.7076],
        [19.1740, 28.1802, 38.2215, 33.5360, 21.8729, 20.2660, 17.9604],
        [10.1873, 14.9042, 15.3366, 15.6461, 19.8310, 32.2533, 27.3208],
        [11.0892, 12.1534, 15.8061, 23.2939, 24.9889, 10.2262, 12.1564],
        [16.6317, 24.6091, 29.7185, 17.0746, 10.8237, 12.8002, 14.0191],
        [16.2660, 16.2030, 26.6525, 37.9968, 33.7723, 12.8238, 16.9494],
        [28.9755, 12.1395, 14.7645, 14.2747, 14.5967, 18.2058, 29.6135]])
ground truth
tensor([[33.3050, 28.9257, 20.3515, 20.1247, 16.2273, 16.9359, 25.0000],
        [14.2432, 19.1185, 21.0317, 34.8639, 28.4722, 30.2863, 21.1593],
        [21.3435, 25.8078, 40.0085, 43.6791, 34.1553, 30.1446, 26.6015],
        [14.2715, 17.5454, 15.1219, 14.6400, 16.8793, 27.6219, 33.2483],
        [ 8.7868, 13.9598, 18.4807, 27.9904, 47.6332, 21.8112, 19.8413],
        [20.2656, 2

batch_predictions
tensor([[15.2977, 15.3322, 24.3146, 41.5554, 37.9784, 13.6308, 17.3393],
        [23.5487, 27.7440, 18.4473, 20.9879, 22.3837, 20.2027, 21.2697],
        [19.4279, 30.1454, 41.5465, 32.7858, 15.7189, 16.7842, 16.9550],
        [13.9822, 13.8163, 14.8777, 21.8171, 27.6272, 26.6793, 12.3424],
        [23.3275, 31.0077, 19.4163, 15.6737, 15.8949, 15.5170, 17.8535],
        [23.1868, 29.3559, 18.5671, 12.8543, 14.7622, 14.2818, 16.6146],
        [13.4874, 15.5308, 16.0456, 17.8393, 27.3064, 39.6220, 23.1880],
        [37.2442, 42.8926, 27.2025, 19.6387, 18.2861, 18.3413, 22.2019]])
ground truth
tensor([[12.4717, 16.0714, 21.9388, 44.3452, 35.5867, 11.2670, 11.3662],
        [27.4660, 37.4008, 21.9955, 23.6961, 20.7058, 23.4410, 26.3464],
        [21.8112, 25.6094, 35.1474, 29.9461, 15.8305, 17.9989, 16.4966],
        [16.9926, 15.3486, 15.1502, 22.9450, 26.7999, 28.0187,  7.1287],
        [21.3719, 29.3934, 28.0471, 12.9535, 15.9439, 15.1786, 17.0918],
        [24.8724, 3

batch_predictions
tensor([[11.1936,  7.6692,  9.8387, 13.0728, 13.6968, 15.1233, 17.2759],
        [ 6.8460,  6.8971,  7.6438, 10.0161,  7.1104,  5.4103,  7.0531],
        [14.5331, 16.7045, 16.0876, 18.5459, 35.2338, 39.5615, 25.9428],
        [11.3699, 13.2748,  9.8507,  6.9682,  7.7437,  7.7197,  9.3727],
        [13.6494, 13.9402, 20.0360, 24.0368, 24.0647,  9.0932, 13.1409],
        [32.7473, 19.8139, 20.6224, 20.6855, 24.0184, 36.4015, 46.0067],
        [14.9780, 19.8217, 30.9664, 31.8463, 12.8835, 15.6879, 14.5480],
        [27.3710, 22.8090, 10.7503, 12.9123, 13.4433, 13.6499, 21.5743]])
ground truth
tensor([[13.5981,  9.8501, 21.8043, 24.6712,  4.9974, 39.2425,  4.1952],
        [ 7.5113,  7.8656,  8.9569, 14.9660,  9.0703,  6.0232,  8.2341],
        [ 0.0000,  7.2988, 14.2999, 14.9802, 28.9399, 40.9580, 37.5283],
        [12.6249, 14.6765,  9.2714,  6.8517,  9.1925,  8.6402, 10.2972],
        [ 9.8356, 10.4308, 12.6701, 22.4065,  4.6769,  8.7018, 11.5930],
        [46.1599, 1

batch_predictions
tensor([[12.9631, 18.1998, 27.0956, 11.6972, 11.7597, 12.5829, 14.6046],
        [22.7579, 38.9798, 42.8910, 36.2789, 25.4659, 23.7950, 20.1370],
        [17.1046, 17.5972, 20.1750, 30.9731, 35.8378, 16.6048, 17.8812],
        [11.6320, 12.3793, 11.0829, 11.5957, 19.0939, 26.0394, 20.9927],
        [15.0355, 18.8368, 26.8619, 35.4914,  7.4911, 14.2265, 15.1365],
        [10.8086,  6.4708,  7.8381,  7.8066,  7.9585, 11.0799, 15.7561],
        [14.6934, 20.0721, 27.2750, 26.5927, 11.7878, 14.1080, 13.8813],
        [26.0609, 11.5155, 14.0494, 14.4325, 14.7199, 18.9016, 28.5473]])
ground truth
tensor([[15.6102, 27.3277, 28.0247,  8.0089, 10.7575, 14.8738, 17.4119],
        [18.8848, 33.9427, 55.6418, 14.1899,  0.0000, 15.4787, 21.9227],
        [18.2536, 19.5949, 20.0026, 24.8948, 12.3225, 10.8496, 11.0205],
        [12.6907,  8.3509, 11.1783,  7.3251, 13.9795, 21.3046, 25.7891],
        [16.4782, 22.6065, 24.0400, 32.6407,  6.4703, 15.3077, 15.5576],
        [12.0068,  

batch_predictions
tensor([[22.1636, 21.4224, 17.8278, 12.8007, 13.4400, 12.3221, 14.1239],
        [14.9686, 14.3004, 13.5545, 14.1383, 19.6850, 24.9880, 25.9674],
        [15.5045, 22.6865, 22.4359,  8.7110, 11.8820, 10.8615, 12.0030],
        [14.4884, 15.4870, 16.1809, 29.2050, 43.3022, 35.4634, 13.6128],
        [18.7130, 10.0926, 11.2934, 11.7176, 14.1233, 20.8529, 25.0266],
        [28.2271, 17.7266, 17.6917, 16.0952, 16.5595, 21.5009, 36.8803],
        [15.2879, 16.8425, 18.4691, 22.5394, 24.7605, 20.6133, 11.1768],
        [17.9083,  9.5763, 12.2226, 13.8860, 16.5736, 21.6323, 28.9040]])
ground truth
tensor([[ 6.6938, 28.2614, 26.0521, 13.5060, 12.0068, 14.6107, 16.1494],
        [14.1156, 11.0686, 12.8827, 14.9235, 16.3407, 26.6582, 27.5935],
        [17.2278, 20.8574, 20.8311, 10.1131, 13.0063, 10.8364, 12.5066],
        [14.0023, 14.7392, 15.0368, 21.3152, 43.0414, 33.1207, 16.6950],
        [ 8.6270,  9.3372,  9.8764,  9.9158, 16.8727, 24.6712, 26.1573],
        [23.2851, 2

batch_predictions
tensor([[ 9.1440, 11.4154, 11.5829, 12.6790, 17.1856, 26.3559, 19.4514],
        [ 7.9091, 10.1192,  8.0505,  7.0889,  7.5945,  7.1965,  7.3211],
        [29.7856, 12.6366, 14.9837, 13.8963, 14.0511, 16.5161, 26.8206],
        [ 9.5069, 12.9669, 13.2198, 14.8438, 22.0346, 30.4561, 19.2961],
        [15.2563, 16.5379, 18.4403, 32.8624, 43.7567, 31.5994, 15.5778],
        [12.9151, 15.6253, 15.6103, 16.9580, 32.3908, 43.1575, 30.0798],
        [10.6025, 15.7018, 15.8810, 17.4979, 21.7692, 28.1563, 30.6513],
        [22.4677, 22.3372, 27.7994, 37.2972, 41.4579, 27.1472, 22.4397]])
ground truth
tensor([[ 6.7464,  9.1005, 14.4792, 15.3077, 20.2920, 22.6328, 31.2073],
        [11.4796, 18.2540,  8.6310,  7.8515,  3.8832,  6.9586,  9.0561],
        [23.1718,  9.0703, 12.1882, 11.7914, 10.9269, 17.3611, 39.1582],
        [ 6.8452, 14.2290, 10.6434, 10.2608, 15.0652,  0.9921, 48.0867],
        [17.4178, 18.5941, 21.8112, 25.6094, 35.1474, 29.9461, 15.8305],
        [15.6746, 1

batch_predictions
tensor([[42.6403, 34.2337, 16.3792, 18.2465, 16.7759, 20.5760, 34.2798],
        [13.3683, 14.3731, 14.6257, 20.4778, 29.3777, 31.4552,  7.8612],
        [14.0420, 14.4918, 16.2621, 23.8858, 29.7951,  9.0960, 14.4322],
        [ 9.1498, 13.4206, 14.0196, 14.7160, 17.2556, 19.7983, 24.5948],
        [17.7486, 18.6208, 19.3831, 22.5495, 25.7675, 25.9478, 27.1401],
        [11.0935, 13.5944, 13.2952, 14.0530, 19.7821, 28.4901, 28.7425],
        [11.8854, 18.0454, 26.1363, 28.7757, 11.0925, 11.2061, 11.7482],
        [13.7518, 16.4779, 16.2362, 18.3340, 23.6358, 33.5397, 32.7479]])
ground truth
tensor([[49.9575, 43.3107, 19.7846, 20.5782, 14.5975, 20.3798, 36.1536],
        [15.6463, 11.2954,  1.8424,  5.6973, 20.2239, 28.6423,  5.0737],
        [ 9.4424,  0.0000, 12.1910, 12.2173, 28.1957,  4.5502, 15.3866],
        [ 6.1941,  5.1420, 13.7691, 10.0868, 14.3346, 18.4245, 22.4882],
        [12.7693, 14.0164, 19.0334, 18.0272, 23.9938, 11.1111,  8.6026],
        [11.5221, 1

batch_predictions
tensor([[11.7735, 14.2687, 14.2329, 14.9495, 25.2704, 37.7799, 28.1222],
        [14.7765, 16.3567, 24.8557, 26.6147, 14.3931, 14.5784, 14.1743],
        [40.9169, 29.1098, 19.5625, 18.1528, 15.8655, 19.6583, 37.9730],
        [10.9843, 12.1074, 12.4459, 11.9928,  6.6875, 11.5482, 13.6721],
        [26.5407, 16.5517,  9.9723, 11.7463, 12.6468, 14.0767, 19.9420],
        [ 9.7525, 10.9651, 11.2574, 17.0571, 22.3768, 16.3095,  7.4906],
        [11.4575, 17.6803, 17.2902,  6.7826,  7.9730,  7.4481,  7.9233],
        [11.5688, 12.2503, 13.8046, 22.8040, 25.6332,  8.4165, 11.9275]])
ground truth
tensor([[ 9.7364, 11.2103, 10.0057, 12.3866, 22.7324, 35.7851, 27.0692],
        [17.3330, 22.2120, 21.3309, 26.4335, 13.1510,  3.7612, 10.0605],
        [46.4853, 32.7948, 16.6241, 19.7562, 14.0590, 21.5420, 33.6593],
        [11.9148, 13.6376, 18.3456, 13.5981,  9.8501, 21.8043, 24.6712],
        [31.7464, 33.9032, 10.8759, 12.6644, 14.8343, 13.0852, 16.7149],
        [ 8.9164, 1

batch_predictions
tensor([[17.5869, 19.9148, 21.3023, 19.7442, 20.2141, 23.6427, 32.6464],
        [15.9589, 25.7647, 27.9799, 22.7038, 14.1646, 14.1832, 14.7362],
        [ 8.6834,  9.2281, 10.6747, 12.7100, 15.9091, 14.2542,  9.6508],
        [18.1314, 29.4860, 24.1899, 10.7104, 14.0383, 12.8545, 15.2721],
        [38.9056, 33.8155, 18.0049, 20.0313, 18.6380, 21.6788, 30.1320],
        [15.3998, 16.3498, 27.8264,  9.7824, 11.3180, 11.7200, 12.9941],
        [18.9370, 18.7450, 19.5602, 34.8284, 45.0255, 36.7225, 17.1280],
        [23.5371, 22.5412, 28.1389, 44.3936, 50.5886, 41.7342, 27.6597]])
ground truth
tensor([[18.7217, 19.1893, 22.8883, 23.4552, 21.2727, 25.7086, 41.5958],
        [18.7270, 25.4471, 21.1599,  0.0000,  9.9290, 10.6128, 15.7417],
        [ 0.0000,  1.8543,  9.0084, 11.5071, 17.7275, 18.9506, 27.6433],
        [17.4036, 27.3384, 26.0062, 10.7710, 13.0669, 11.3804, 19.9830],
        [47.3923, 41.3265, 17.4036, 26.0062, 16.2273, 21.3435, 25.8078],
        [15.9916, 2

batch_predictions
tensor([[29.0763, 43.1754, 41.7862, 13.8768, 18.0656, 18.6267, 19.3495],
        [24.7851, 31.2712, 10.9271, 14.2237, 14.9373, 15.0711, 16.7697],
        [29.0007, 27.1102, 15.7480, 15.9994, 14.8396, 14.7473, 19.4142],
        [ 8.5993,  8.1041,  7.7615,  8.9636, 12.1165,  7.5954,  6.2684],
        [11.6041, 12.3591, 17.9816, 26.3346, 24.4768,  9.3895, 11.0017],
        [28.9121, 10.3185, 13.3534, 15.2076, 18.2513, 25.0083, 34.8222],
        [10.6533, 14.5707, 14.2979, 15.2606, 17.4910, 26.6308, 36.1038],
        [25.3126, 26.3618, 12.3482, 15.6754, 15.3823, 16.5758, 20.5927]])
ground truth
tensor([[36.4413, 53.8401, 13.9532, 11.4808, 18.1483, 16.2809, 24.2767],
        [27.5907, 28.5113,  8.1799,  6.8254, 10.2446, 19.2136, 21.1336],
        [28.4587, 31.0363, 20.2788, 21.8701, 21.9884, 19.8711, 17.2804],
        [ 6.9728,  6.8452,  4.9603,  9.4104, 10.8985, 13.8180,  9.4955],
        [19.4240,  9.0216, 17.7012, 24.5529, 24.1057,  8.5613,  5.7075],
        [24.2205,  

batch_predictions
tensor([[16.4780, 16.1835, 16.5153, 26.0153, 41.1909, 40.3489, 18.1981],
        [41.5144, 33.4027, 13.3930, 17.8058, 17.0906, 17.6688, 26.8026],
        [21.2133, 12.7088, 12.8658, 12.5190, 15.1537, 22.7775, 26.2707],
        [17.3176, 24.7903, 25.2079, 12.4178, 12.3536, 11.4097, 12.7048],
        [15.1421, 14.7061, 15.5432, 21.8755, 30.4879, 25.8382, 21.8771],
        [ 8.2437,  9.5175, 16.1680, 21.1362,  6.4219,  7.5821,  7.6140],
        [ 7.7533,  8.1959,  8.9821, 11.6318, 18.8818, 18.9500,  7.4225],
        [46.8673, 41.6174, 31.9414, 28.9340, 28.0257, 29.1060, 38.2687]])
ground truth
tensor([[10.4875, 16.7375, 18.6224, 22.6616, 31.2783, 31.4342, 15.7171],
        [48.5402, 38.7188, 13.3645, 19.8129, 18.2823, 16.2982, 31.5193],
        [24.8816, 11.5203, 10.8890, 11.2704, 12.1646, 20.9495, 27.0647],
        [13.8874, 25.9206, 20.1473,  8.4035, 10.0999, 10.9153, 14.4529],
        [28.4297, 13.7046, 25.3401, 13.2795, 31.7602, 29.5493, 12.5142],
        [ 6.2073, 1

batch_predictions
tensor([[12.4799, 12.8378, 13.5460, 14.9832, 21.9121, 27.6948, 23.4229],
        [18.5018, 10.2569,  8.7895, 10.4225, 10.6842, 11.1735, 13.7419],
        [14.8071, 13.5116, 15.1170, 15.1627, 17.7673, 24.4193, 31.5423],
        [14.4184, 14.6888, 14.7262, 20.3422, 28.8232, 26.6830, 13.7603],
        [ 9.1347, 11.1155, 12.6486, 16.5428, 23.1214,  6.2048,  7.6727],
        [13.8411, 14.0224, 14.5239, 23.3883, 35.8188, 28.8038, 10.2050],
        [19.0729,  8.5487,  8.7311, 10.0495, 10.0972, 12.9275, 21.0311],
        [33.5807, 15.9537, 19.0735, 17.7005, 20.8643, 35.9725, 47.1797]])
ground truth
tensor([[1.1891e+01, 1.4598e+01, 1.7163e+01, 1.6610e+01, 1.8070e+01, 2.3611e+01,
         2.5595e+01],
        [1.2441e+01, 3.8269e+00, 1.9595e+00, 1.1007e+01, 9.3372e+00, 9.6397e+00,
         1.0008e+01],
        [3.0329e+01, 1.5916e+01, 1.2939e+01, 1.6596e+01, 1.5774e+01, 2.1372e+01,
         2.9393e+01],
        [4.2517e-02, 0.0000e+00, 5.1871e+00, 1.2982e+01, 2.8855e+01, 3.3404

batch_predictions
tensor([[14.0644, 13.7091, 14.6831, 24.3719, 30.9930, 28.4188, 11.5212],
        [17.8074, 34.1978, 29.2259, 11.6731, 13.8499, 12.6420, 13.7623],
        [20.4161, 32.9669, 37.7507, 26.6818, 14.1472, 15.1853, 16.0692],
        [11.2779, 11.6098, 12.3456, 14.7649, 21.3469, 14.4580,  9.1721],
        [ 8.1653,  9.0191, 10.4146, 13.6982, 19.4360,  7.9997,  7.7947],
        [ 9.7211, 12.3716, 12.2932, 12.6518, 17.5104, 27.1493, 22.5580],
        [18.9753, 28.1596, 37.6897, 21.4545, 14.7679, 15.6907, 16.8853],
        [16.4291, 33.2675, 40.5980, 25.5540, 18.6748, 16.5872, 14.9374]])
ground truth
tensor([[15.1502, 17.6162, 16.2557, 24.6740, 38.2511, 33.3759, 10.3600],
        [23.2426, 37.8543, 34.7931, 15.3628, 11.8764, 11.2245, 12.1882],
        [18.9637,  2.2225, 22.3172, 19.3451, 14.8080, 15.3998, 20.0552],
        [ 2.5644,  5.2078,  2.6302, 17.4250, 21.9095, 19.2267,  6.5097],
        [ 9.1268,  6.7859, 12.7959, 12.5592, 17.4250, 21.4361,  8.6665],
        [10.6786,  

batch_predictions
tensor([[16.0638, 15.5387, 16.5791, 32.0759, 41.7848, 32.8978, 15.3925],
        [13.3232, 14.4366, 18.9102, 26.1012, 27.5848,  9.9038, 12.2386],
        [22.2010, 25.3208, 29.0431, 19.2323, 18.4095, 17.0104, 16.9508],
        [17.9534, 28.1219, 35.6496, 32.7515, 15.1388, 17.5158, 17.0926],
        [27.5360, 34.3224,  8.1244, 13.5958, 15.1354, 16.0586, 21.1432],
        [ 6.4181,  8.3997,  9.7081,  6.6494,  4.7118,  4.8132,  5.8371],
        [21.3985, 12.6186, 14.6615, 15.2357, 16.4200, 21.3837, 29.7288],
        [36.8802, 36.1286, 12.6004, 15.7763, 15.6215, 15.7772, 21.1368]])
ground truth
tensor([[19.0760, 14.9943, 17.0918, 26.9133, 43.4382, 38.3078,  6.3350],
        [18.2256, 11.9615, 16.8226, 25.1276, 26.8991, 12.1457, 12.7693],
        [17.3856, 26.2756, 26.8148, 15.2157, 19.4240, 13.3482, 11.2178],
        [ 9.0278, 27.8912, 48.4694, 10.3741, 10.1474, 16.6241, 20.3231],
        [24.0400, 32.6407,  6.4703, 15.3077, 15.5576, 17.5302, 25.6970],
        [ 5.9240,  

batch_predictions
tensor([[30.8977, 12.3872, 16.0123, 15.0435, 15.7066, 21.5066, 30.6441],
        [11.6834, 11.9587, 17.2997, 23.3432, 23.7558,  8.1789, 10.4681],
        [17.5956, 23.1967, 21.1812, 13.0873, 13.9640, 12.7936, 14.9045],
        [13.6567, 17.2081, 22.5339, 24.3203, 13.8580, 13.1346, 13.2462],
        [12.9767, 21.7630, 27.5002, 21.9880, 11.6291, 12.6824, 12.5506],
        [14.3688, 13.9274, 14.8591, 22.8454, 30.2097, 28.1112, 13.1125],
        [22.3637, 20.5337, 10.1461, 11.9750, 11.0531, 12.0272, 17.3202],
        [ 9.4611,  8.1776,  7.6025,  7.7299,  9.8559, 14.9463, 20.6700]])
ground truth
tensor([[25.1276, 16.8651, 14.0590,  9.6088, 16.7800, 18.5091, 30.6406],
        [10.4734, 12.6134, 17.7721, 23.7103, 23.8237, 12.1032, 15.1786],
        [19.4240, 27.2225,  5.4182,  0.0000, 19.2267, 17.4513, 15.7680],
        [18.0563, 18.2272, 24.3425, 30.3656, 15.2025, 17.1357, 16.2678],
        [13.3220, 19.8129, 21.7687, 17.7438, 11.7489, 16.4541,  9.2404],
        [ 9.9915, 1

batch_predictions
tensor([[18.9508, 24.8488,  8.1665,  8.7428, 11.7276, 11.9983, 13.1243],
        [14.3283, 15.4321, 14.1818, 13.9939, 17.8301, 31.9132, 26.4834],
        [39.7763, 38.1581, 38.9429, 42.2379, 47.1483, 46.3331, 42.3061],
        [15.6314, 16.0014, 17.7815, 24.6914, 32.4855, 29.7118, 13.3861],
        [18.3428, 21.4474,  8.2852,  8.9240, 12.3951, 13.1773, 14.0461],
        [ 7.5664,  5.9714,  5.6273,  5.8124,  6.1310,  6.7627,  9.5765],
        [34.0520, 14.8822, 18.0378, 16.7062, 16.7776, 20.9898, 32.9084],
        [ 8.9923, 11.6774, 14.9942, 22.7441,  6.9465,  9.1373,  8.4070]])
ground truth
tensor([[18.8716,  5.4050,  6.6281,  3.7612,  3.3798,  8.4429, 11.6386],
        [11.2670, 14.8951, 13.4637, 16.3832, 18.5941, 30.3571, 24.9150],
        [65.8469, 52.4066, 44.9763, 55.1683, 64.6633, 47.9090, 42.1620],
        [16.3124, 14.6542, 13.3929, 24.7874, 35.6718, 29.2800, 15.3203],
        [15.5187, 18.5091,  8.7018,  3.0612, 14.7392, 14.1015, 15.0652],
        [ 9.1270,  

batch_predictions
tensor([[16.5485, 17.9599, 17.6369, 30.2747, 44.2565, 35.2153, 13.5316],
        [14.4944, 23.4289, 31.9481, 19.6464, 10.6634, 11.9293, 13.0238],
        [26.5330,  9.4795, 13.6931, 14.2709, 13.0550, 15.9205, 22.4990],
        [10.3568, 13.6874, 12.7855, 13.0816, 15.8326, 22.9698, 28.5176],
        [19.1050, 18.1990, 23.3977, 21.5579,  8.0042, 12.9619, 15.1220],
        [46.5541, 43.0831, 33.4445, 29.0262, 29.4973, 30.3960, 33.8538],
        [ 6.0478,  6.0958,  6.0353,  5.9489,  6.6855,  8.5837,  7.0878],
        [10.1395, 10.7525, 16.1942, 21.9650, 20.3242,  7.8289, 10.2473]])
ground truth
tensor([[18.1689, 20.3231, 13.8322, 27.9620, 48.5402, 38.7188, 13.3645],
        [14.7817, 21.5939, 42.0305, 30.3919, 13.3482, 12.4540, 14.3740],
        [27.7617,  7.7591, 11.2178, 12.4408, 16.4782, 12.6118, 16.2941],
        [ 9.5476, 11.0205,  7.5881, 10.5471, 11.3230, 24.2109, 29.6028],
        [38.7472, 41.0006, 33.2766,  7.9790,  3.2455, 13.8889, 14.2715],
        [60.9269, 4

batch_predictions
tensor([[23.8707,  7.2188, 10.8810, 12.0718, 11.8583, 12.2197, 21.9448],
        [16.1911, 21.9698, 35.0971, 28.6344, 16.4489, 16.6768, 15.3062],
        [36.4018, 13.4153, 14.6857, 16.0440, 17.2497, 20.9418, 31.8156],
        [15.6403, 16.2030, 15.7799, 24.7220, 39.1350, 38.9194, 21.5797],
        [12.8329, 13.6422, 18.7633, 28.5658, 31.0650,  8.5948, 13.4565],
        [10.7298, 13.8067, 13.6751, 14.1306, 19.3113, 26.9080, 24.6841],
        [15.1480, 15.9150, 21.2178, 24.7538,  9.0032, 14.2682, 14.7992],
        [12.9472, 13.8138, 16.6368, 25.1869, 31.3042,  9.8936, 12.9653]])
ground truth
tensor([[24.8027,  4.5634,  8.4824, 11.4413, 11.5860, 11.3624, 18.9900],
        [16.9785, 20.6633, 30.3997, 29.3367, 19.1327, 17.2194, 17.2902],
        [44.0476, 33.1066,  9.6372,  8.7443, 18.1831, 19.0334, 29.4076],
        [18.3248, 17.9989, 17.1485, 15.4478, 38.8322, 39.6542, 21.0601],
        [16.7659, 14.9660, 18.4524, 33.1774, 32.0011,  8.1207, 15.3486],
        [15.8730, 1

batch_predictions
tensor([[28.7752, 14.8202, 18.0021, 17.5554, 17.6225, 20.1148, 27.2903],
        [ 6.8065,  7.3160,  6.3447,  5.8148,  6.1552,  9.4801,  7.2669],
        [14.8224, 16.1054, 15.2842, 15.9738, 19.3962, 27.9459, 26.7883],
        [33.5809, 17.9988, 17.7959, 17.5313, 16.7385, 20.2361, 36.2705],
        [13.4226, 15.0949, 27.4170, 37.8358, 25.4299, 14.3463, 14.3679],
        [12.4199, 13.6945, 13.8594, 14.8580, 25.9061, 34.6034, 22.2087],
        [17.7469, 26.5514, 28.2298,  9.1497, 13.2432, 12.1704, 14.6823],
        [ 9.1934, 10.5846,  9.5528,  9.9026, 14.3331, 26.1776, 13.5629]])
ground truth
tensor([[42.6092, 21.2914, 21.0153, 25.0789, 20.9100, 20.9100, 29.1689],
        [ 6.2358,  8.9853,  8.6451,  6.3634,  6.9728, 11.7914,  7.3271],
        [13.8217, 12.8880, 13.8480, 17.9905, 22.1331, 29.6160, 29.3135],
        [36.8753, 27.6302, 12.7959, 13.6902, 17.1094, 25.4340, 41.3335],
        [11.9615, 14.0731, 23.7670, 36.3379, 29.3793,  9.9773, 11.6497],
        [11.6123, 1

batch_predictions
tensor([[13.7907, 15.4282, 15.0505, 16.6854, 23.6996, 30.3084, 24.3551],
        [15.2262, 15.4819, 16.7448, 26.0022, 37.9210, 27.3101, 15.0811],
        [16.3048, 19.2565, 18.1732, 18.4165, 20.9301, 27.8308, 33.1315],
        [12.8162, 14.4514, 22.5487, 31.8425, 23.1183,  9.4366, 13.0805],
        [39.9708, 43.0216, 43.7480, 40.3436, 37.3306, 37.4399, 39.9827],
        [16.7647, 24.6934, 32.6828,  7.8330, 13.7421, 13.4819, 14.8299],
        [14.1120, 14.4915, 14.7312, 21.5144, 30.3405, 26.1930, 10.5141],
        [27.8023,  7.4869, 13.4897, 13.7487, 14.0372, 20.0855, 28.6197]])
ground truth
tensor([[14.8526, 12.8118, 19.6854, 16.1139, 21.9104, 27.7494, 24.6173],
        [14.7028, 14.5055, 18.9506, 27.5250, 52.9853, 17.4119, 13.7033],
        [21.2914, 21.0153, 25.0789, 20.9100, 20.9100, 29.1689, 32.1278],
        [10.7143, 18.8634, 28.2455, 37.6984, 23.1576,  5.6973,  5.7540],
        [38.5324, 45.5155, 53.4850, 55.9179, 38.1378, 40.7286, 22.9090],
        [15.2157, 2

batch_predictions
tensor([[14.0810, 14.3932, 22.0533, 22.9639, 13.0907, 13.2101, 13.3408],
        [28.8356, 13.3669, 15.5367, 15.8485, 16.9987, 21.5704, 30.4090],
        [ 8.7155,  8.9333, 11.3011, 11.6488, 14.0990, 20.5810, 19.6270],
        [30.1228, 13.6520, 15.5050, 15.7865, 16.0930, 18.0693, 26.2619],
        [12.9059,  7.2863,  6.7662,  6.8254,  6.4296,  7.6614, 13.9003],
        [10.9633, 11.1567, 11.0725, 13.9761, 18.7626, 23.1178,  7.2139],
        [10.1822, 10.9161, 11.2333, 16.9060, 24.0879, 18.6881,  7.7269],
        [15.2112, 16.4284, 18.2870, 25.9792, 34.5229, 30.8076, 13.0509]])
ground truth
tensor([[14.6370, 14.3872, 22.6854, 23.7507, 12.4540, 13.6639, 13.1773],
        [24.8158, 10.0765, 13.6196, 13.5346, 17.0493, 26.4456, 26.6582],
        [ 6.6807,  7.4698,  9.7843,  6.1678, 11.9016, 18.9374, 21.0021],
        [30.4708, 15.0447, 15.9127, 10.5997, 20.1341, 17.9116, 18.6218],
        [15.8732,  5.2209,  6.6544,  6.6675,  6.4440, 10.5865, 13.0326],
        [ 8.2720, 1

batch_predictions
tensor([[16.3347, 16.4287, 21.8460, 37.4572, 40.1367, 23.9814, 17.3139],
        [40.8827, 39.5054, 20.0872, 22.0481, 21.6900, 25.6352, 33.8991],
        [13.3778, 14.8462, 20.0181, 31.8743, 32.9115,  9.3517, 13.4534],
        [34.1113, 10.9036, 16.0394, 16.3944, 17.6127, 25.7150, 37.1938],
        [36.9978, 21.9841, 16.3848, 16.0331, 15.2898, 19.5496, 39.8247],
        [13.4219,  8.4636,  7.5421,  8.3757,  8.8004,  7.6105, 10.8375],
        [12.7555, 15.1947, 16.7151, 18.1851, 18.9349, 22.0589, 10.8001],
        [27.3479, 11.2955, 13.7518, 14.0956, 14.9455, 18.1882, 26.6883]])
ground truth
tensor([[ 9.0610, 14.9395, 18.9637,  2.2225, 22.3172, 19.3451, 14.8080],
        [36.3946, 37.1740, 10.0624, 15.3203, 20.3090, 18.0414, 35.7426],
        [19.0618, 16.4116, 14.7392, 30.4563, 31.5476, 11.6922, 17.5170],
        [ 4.2234,  0.0000, 15.7455, 17.9138, 21.0176, 24.5323, 23.4410],
        [38.9031, 18.9626, 10.5584, 16.7517, 15.4478, 22.1797, 44.8696],
        [15.0935,  

batch_predictions
tensor([[20.9238, 21.6853, 24.9301, 38.7265, 45.4548, 42.8821, 21.5315],
        [21.7404, 25.8136, 19.4561,  8.9271, 13.9873, 15.2314, 16.1400],
        [32.9368, 28.7981, 11.4830, 14.2817, 12.7887, 13.1934, 19.5238],
        [20.4977, 29.8765, 32.9247, 13.1744, 16.9341, 17.4314, 16.7501],
        [45.8470, 39.9734, 30.5592, 26.5769, 25.7496, 27.4608, 37.1519],
        [39.0842, 32.6277, 31.6353, 32.4387, 37.6942, 43.0839, 45.6734],
        [13.5079, 12.4695, 13.5362, 17.7558, 21.7758, 22.1634, 12.7221],
        [15.0953, 21.5707, 31.4804, 29.9505, 13.1581, 15.2246, 14.9016]])
ground truth
tensor([[21.3435, 23.6111, 25.1417, 19.1893, 41.2557, 40.3345, 14.7817],
        [16.9926, 25.2551, 26.4314,  2.2959, 13.6054, 13.4779, 12.3724],
        [39.1723, 29.2234, 10.7143, 14.1440, 12.9535, 14.3282, 21.8821],
        [31.9586, 38.6480, 41.0856,  8.1633, 13.9881, 22.7466, 19.3878],
        [59.1128, 46.8821, 34.0561, 23.5686, 27.1684, 25.6519, 39.9802],
        [23.0668,  

batch_predictions
tensor([[ 9.6314, 10.9461, 15.5689, 21.4507, 19.5570,  7.4869,  9.7112],
        [14.0574, 14.9039, 15.6848, 23.6058, 32.7935, 27.1616,  9.3034],
        [10.9664, 11.7562, 13.5434, 18.9381, 23.1618, 16.0812, 10.5378],
        [47.6823, 40.8276, 32.6076, 30.2404, 29.9004, 32.0498, 40.3309],
        [28.5502, 12.7407, 11.6636, 13.0228, 13.3565, 17.2614, 23.4382],
        [14.0179, 19.9971, 22.5863, 18.8697,  9.4521, 10.8701, 11.8140],
        [ 8.3415,  9.9457, 16.1819, 21.7919, 10.7767,  7.7939,  8.4859],
        [14.4876, 15.0219, 16.5219, 22.7073, 30.9018, 28.2002, 12.5362]])
ground truth
tensor([[ 8.7980,  9.7843, 15.7417, 18.8979, 16.7675,  7.0226,  9.9421],
        [11.2309, 16.7017, 14.1241, 19.8448, 28.1562, 33.4955, 31.3519],
        [ 5.3393, 14.0189, 14.4792, 18.0957, 21.7386, 16.6754,  0.0000],
        [55.9382, 52.4802, 25.5385, 19.2460, 23.9938, 36.5363, 58.8861],
        [28.4864, 27.9053, 11.2245, 14.2715, 16.3124, 11.6780, 17.3895],
        [12.8617, 1

batch_predictions
tensor([[ 7.7779,  8.5682, 10.0325, 12.8492, 19.9703, 19.8302,  7.9675],
        [18.4459, 20.7790, 30.7574, 38.7708, 25.4582, 18.7425, 18.5786],
        [14.5292, 15.9701, 23.1761, 23.8760, 22.7450, 11.1742, 12.7636],
        [39.3518, 20.7455, 20.6595, 19.4246, 21.8121, 34.8586, 52.4919],
        [17.6180, 16.1220, 21.0182, 36.6215, 42.7119, 30.2915, 18.3216],
        [26.6463, 11.2755, 13.0966, 12.9598, 14.3892, 23.5294, 36.8442],
        [16.7718, 16.4340, 17.1058, 24.7262, 29.9197, 32.8178, 12.1801],
        [11.6836, 11.5344, 12.0705, 16.0565, 24.1424, 22.3327,  9.4643]])
ground truth
tensor([[ 6.0494, 11.0205, 11.5071, 11.8885, 13.4008, 17.7407,  6.7070],
        [16.4966, 21.1451, 19.3027, 44.0760, 30.8673, 18.3107, 14.0306],
        [11.1678, 12.6701, 20.5641, 26.7715,  6.5334,  8.2625, 15.0085],
        [39.2425, 21.1205, 17.0174, 17.5829, 24.5003, 37.8880, 62.0200],
        [17.8146, 17.4603, 23.7812, 35.0340, 51.0488, 28.5431, 16.6950],
        [27.5227, 1

batch_predictions
tensor([[18.9497, 18.4916, 21.2357, 35.9887, 45.1594, 36.5229, 20.6919],
        [29.3435, 27.9533, 14.0365, 16.7815, 16.6388, 19.4904, 23.7696],
        [11.5599, 12.6262, 14.3739, 14.7874, 21.2883, 27.1353,  9.5014],
        [34.8719, 12.7843, 15.3930, 16.6087, 18.3314, 27.1788, 39.3012],
        [36.1042, 29.3962, 13.0590, 15.6016, 15.3598, 16.6917, 27.9047],
        [11.9595, 16.6373, 22.8691,  8.5731,  8.4428, 10.1390, 10.1807],
        [11.0901, 11.5511, 11.6357, 17.8798, 23.2375, 19.3549,  8.5375],
        [11.7231, 17.8981, 22.2724,  6.7288,  7.8995,  8.4764,  9.3316]])
ground truth
tensor([[20.2656, 26.5387, 23.5928, 31.4966, 42.4776, 46.1599, 18.0563],
        [19.3582, 22.6986, 13.9006, 13.5981, 15.2946, 18.4377, 22.4882],
        [10.7575, 14.8738, 17.4119, 17.3330, 27.1173, 33.7454,  9.5476],
        [37.0890, 21.1026, 15.2920, 16.2557, 13.7613, 16.7092, 35.0198],
        [40.9580, 37.5283, 12.6417, 16.4399, 14.7676, 15.0794, 27.6644],
        [11.7701, 2

batch_predictions
tensor([[ 8.6965,  9.2482, 12.1834, 19.8345, 21.5518,  6.3626,  7.9562],
        [11.9349, 14.5372, 14.3932, 14.8557, 17.9545, 22.3948, 20.2044],
        [11.4427, 14.5093,  8.9257,  6.8731,  7.0230,  7.3907,  8.7790],
        [26.9250, 36.8131, 30.5713, 13.4340, 16.5898, 14.1737, 15.3406],
        [15.1098, 21.0190, 25.9681, 10.6396, 12.6349, 12.6308, 13.9566],
        [20.5632, 10.5812, 12.4432, 13.7012, 18.5326, 21.4619, 24.6929],
        [27.8268, 25.4214, 11.7052, 13.4603, 13.0215, 14.3562, 18.9485],
        [23.6051, 11.6131, 10.1562, 11.4917, 11.3025, 15.5346, 21.7202]])
ground truth
tensor([[ 5.7733,  5.4840, 10.7180, 15.5050, 21.9095,  3.4719,  7.5881],
        [ 9.0347, 12.9669, 14.9790, 14.4792, 17.7407, 20.6076, 26.1573],
        [10.4450, 15.6463,  9.0561,  0.2834,  6.6893, 10.3033,  7.9507],
        [14.8384, 35.4025, 33.6593,  8.3617, 17.2761, 14.5266, 15.1077],
        [13.4271, 15.6497, 25.3156, 32.5224, 13.6639, 20.7259, 19.1478],
        [ 5.7864,  

tensor([[19.3074, 27.0464, 26.7588,  9.7979, 14.0967, 13.7853, 15.3146],
        [18.6353, 22.3388,  7.7486,  7.5517,  9.8848, 11.2465, 12.3411],
        [16.7298, 16.7135, 22.4279, 36.6621, 37.8903, 30.1900, 18.9462],
        [17.8219, 14.8911,  9.8393, 11.1695, 12.6764, 11.7798, 16.3575],
        [18.6788, 13.7072, 10.5082, 10.3733, 10.4945, 12.3726, 15.7964],
        [29.3994, 24.9075,  8.2696, 12.3275, 11.4315, 12.9832, 18.7774],
        [10.7828, 10.7455, 14.3134,  6.4433,  6.2959,  8.4648,  9.1241],
        [34.7168, 42.8283, 26.3775, 21.2745, 21.6203, 21.6097, 24.2578]])
ground truth
tensor([[20.5287, 31.5623, 28.9453, 12.6381, 16.4256, 16.9385, 16.7806],
        [14.2162, 17.5565, 15.1762,  9.0347, 12.9800, 15.7286, 11.0731],
        [19.2319, 12.2166, 31.2075, 31.5618, 41.6241, 23.9654, 26.8424],
        [23.3824,  2.5381,  4.8264,  8.7849, 13.4008, 12.5197, 14.2162],
        [23.3298, 16.0310, 13.1247, 14.1504, 13.3482, 13.1247, 13.6244],
        [33.8294, 29.0249,  7.4263, 1

batch_predictions
tensor([[14.0789, 15.2588, 19.8211, 29.2663, 30.8645,  8.7603, 13.0952],
        [11.8112, 17.7156, 16.2692,  6.7379,  8.9256,  7.5631,  7.9471],
        [15.5674, 25.0552, 33.2717, 21.4466, 14.7115, 14.3568, 13.8701],
        [14.0223, 17.8330, 24.1516, 19.7713, 10.2812, 11.0089, 11.4063],
        [15.8893, 18.1572, 16.2469, 16.0736, 19.6158, 29.3887, 36.0802],
        [10.5385, 12.9736, 13.3075, 13.5840, 18.4596, 25.9861, 26.2640],
        [14.8938, 19.7015, 32.7014, 32.7632, 22.5800, 19.0253, 15.9969],
        [ 7.3653,  8.3360, 11.4832, 18.4556, 17.7704,  6.3653,  7.2345]])
ground truth
tensor([[ 0.3260,  3.7415, 17.3186, 32.7664, 28.9541,  1.9133, 10.9552],
        [11.4150, 15.7812, 12.6644,  4.9185,  9.5213,  6.4308,  7.5618],
        [17.0700, 27.7091, 45.5944, 26.5124, 16.7412, 14.7948, 12.6644],
        [13.9926,  4.5502, 26.9726, 19.1347, 13.3351, 11.2835, 10.2709],
        [20.2098, 19.1043, 16.6950, 19.1752, 21.9671, 32.2137, 35.4308],
        [12.1252, 1

batch_predictions
tensor([[25.6029, 28.2856, 15.5798, 15.9573, 14.2847, 13.8900, 17.8592],
        [15.2582, 15.3144, 15.9360, 29.3930, 45.4300, 36.1459, 11.9271],
        [10.2234, 10.1534, 10.5662, 16.3950, 23.5235, 19.9931,  9.4290],
        [15.9084, 15.8511, 21.4117, 38.2564, 38.1122, 11.8071, 16.1983],
        [33.1029, 38.4271, 21.9043, 19.2474, 18.9338, 17.9034, 18.8765],
        [ 6.8306,  7.9612,  7.8629,  9.1648, 13.5090, 14.2728, 19.7776],
        [10.1594, 16.9256, 22.8080, 21.1397,  9.3610, 11.3162, 11.0313],
        [15.5972, 22.1252, 30.0760, 32.8559,  9.9988, 13.0431, 13.7896]])
ground truth
tensor([[22.5340, 24.8724, 15.7738, 11.7063, 11.6497, 10.1049, 17.1910],
        [11.3662, 12.7409, 13.1236, 27.2676, 41.3265, 28.4297, 13.8889],
        [ 9.3372,  8.8901,  9.3766, 11.4545, 27.4724, 22.7380, 11.1257],
        [15.7171, 15.6746, 15.4478, 43.4524, 38.3929, 19.0051, 19.2744],
        [28.9541, 51.4031, 36.6071, 13.9598, 16.2415, 18.7500, 14.0873],
        [ 2.9053,  

batch_predictions
tensor([[28.5164, 22.5911, 12.9591, 14.1822, 13.3505, 14.7786, 22.7384],
        [24.2923, 33.8093, 19.0271, 10.9497, 12.3882, 12.7807, 15.4920],
        [21.8256, 38.2864, 36.2004, 20.9427, 20.0189, 16.5788, 15.8189],
        [ 5.3391,  5.7302,  5.9298,  5.9873,  7.3388,  8.5002,  5.3277],
        [16.9956, 17.2403, 16.1974, 23.1173, 37.4668, 44.2990, 29.5266],
        [20.2397, 28.4863, 38.6404, 26.0379, 17.8780, 17.4513, 16.5695],
        [13.8257, 13.1542, 14.4245, 18.7304,  6.3720,  6.2496, 13.0026],
        [14.4263, 24.6384, 25.6756,  9.5135, 12.5696, 11.6163, 12.6641]])
ground truth
tensor([[2.6361e+01, 2.7112e+01, 1.4994e+01, 1.4994e+01, 1.3109e+01, 1.4683e+01,
         1.8537e+01],
        [3.2596e+01, 3.8931e+01, 3.0499e+01, 1.3776e+01, 1.0686e+01, 1.1961e+01,
         1.4073e+01],
        [1.1409e+01, 4.9036e+01, 4.3155e+01, 1.3237e+01, 1.0771e+01, 1.6610e+01,
         2.1542e+01],
        [3.7557e+00, 3.3305e+00, 4.6344e+00, 5.9240e+00, 8.4042e+00, 1.0077

batch_predictions
tensor([[14.3673, 16.9941, 28.0709, 29.0253,  9.2013, 13.3720, 12.9802],
        [42.3203, 39.3881, 14.4766, 19.6007, 19.1055, 22.1963, 32.4616],
        [23.4471, 24.0908, 35.0022, 46.9087, 48.5382, 41.2413, 27.9647],
        [ 8.3714, 10.1216, 13.1057, 19.1512, 20.9745,  6.2428,  7.4770],
        [12.7964, 12.2193, 13.3217, 17.8859, 22.4831, 24.6319, 14.7331],
        [17.2670, 25.7865, 34.1729, 32.3456, 15.3192, 16.5992, 16.3264],
        [10.7066, 14.3490, 14.7388, 15.8413, 24.8992, 31.2931, 25.1749],
        [28.4263, 18.4639, 13.2073, 14.2081, 15.7016, 21.8898, 30.7300]])
ground truth
tensor([[14.2574, 15.2920, 25.0283, 28.8549,  8.7160, 13.2795,  9.4813],
        [40.6321, 40.6746, 20.7341, 18.4807, 22.3498, 21.8537, 32.0011],
        [30.7256, 29.5210, 26.2897, 39.3707, 54.6202, 49.0079, 31.3917],
        [ 5.9442, 11.3624, 15.5839, 21.3440, 16.7806,  4.0505,  7.8248],
        [12.3882, 13.0721, 18.0563, 18.2272, 24.3425, 30.3656, 15.2025],
        [24.5890, 2

batch_predictions
tensor([[26.4147, 21.0629, 10.6097, 10.8077, 10.0949, 11.4623, 20.3428],
        [20.4496, 24.0892, 16.6807,  9.6592, 11.6267, 11.7749, 14.6204],
        [13.8061, 15.2325, 17.7230, 25.1876, 27.3592, 13.7190, 13.9158],
        [25.7884, 34.6135, 37.0352, 24.9406, 21.4504, 19.5363, 19.9687],
        [38.0393, 30.0114, 16.4614, 15.4640, 15.6549, 18.0556, 27.5540],
        [13.4065, 19.4595, 30.3504, 33.6462, 13.7164, 14.3395, 12.8475],
        [13.1328, 17.0659,  7.0954,  7.7989,  7.1375,  7.2152,  9.6487],
        [13.1590, 14.9783, 17.5463, 19.4048,  8.8862, 10.8487, 12.4924]])
ground truth
tensor([[27.5250, 26.7359, 12.6907,  8.3509, 11.1783,  7.3251, 13.9795],
        [18.5823, 20.0947, 33.7322, 12.9932, 13.7822, 14.4266, 15.2551],
        [13.1641, 17.3330, 22.2120, 21.3309, 26.4335, 13.1510,  3.7612],
        [26.4203, 39.3477, 43.2667, 39.2031, 30.2078, 29.1294, 24.5266],
        [38.1803, 30.5839, 19.9972, 16.8226, 14.9518, 19.6429, 25.4819],
        [11.0994, 2

batch_predictions
tensor([[19.7739, 10.1531, 12.6768, 13.5722, 15.6407, 22.0881, 25.7894],
        [15.7989, 15.5890, 16.3309, 29.1032, 39.1354, 26.8170, 14.8308],
        [26.0267, 19.6810, 12.1086, 13.0432, 12.9009, 13.6683, 20.5951],
        [14.3481, 17.0748, 25.9557, 29.1181,  7.7587, 14.1741, 13.8880],
        [12.6540, 21.0363, 22.4184, 11.5568, 11.5690, 10.6147, 11.1447],
        [24.0775, 26.5537, 10.3732, 14.4454, 14.0339, 14.4442, 18.3585],
        [10.3893, 17.8909, 18.1257,  7.6698,  7.4528,  7.8647,  7.7847],
        [16.7614,  7.3739,  7.2372,  7.0957,  7.0066,  9.1115, 14.8573]])
ground truth
tensor([[25.2834, 14.2007, 13.6196, 17.4461, 17.9138, 17.0777, 20.8333],
        [15.2814, 18.6349, 14.2162, 31.9437, 31.1021, 32.0621, 13.2825],
        [28.5714, 24.9433, 12.0040, 15.2069, 11.9898, 13.3929, 18.0130],
        [13.2653, 19.4161, 29.8895, 36.1536,  4.9603, 15.6463, 11.2954],
        [12.1646, 22.1988, 17.7670,  2.3672,  5.1946,  9.1662, 12.0726],
        [23.8033, 3

batch_predictions
tensor([[11.8402, 17.0659, 24.4989, 24.5747, 12.5068, 11.8848, 11.6121],
        [23.2688, 10.1534, 13.1218, 13.4223, 13.7894, 15.6567, 17.1999],
        [17.3611, 14.0063, 14.6301, 16.1987, 23.6205, 36.8558, 31.4855],
        [15.3345, 15.1837, 15.6816, 18.0440, 22.8109, 26.6917, 22.6275],
        [32.5135, 10.1444, 15.3006, 15.6709, 17.1801, 22.1148, 36.1533],
        [14.7605, 19.8189, 28.3797, 21.6438, 11.4171, 13.3030, 13.9467],
        [15.1535, 15.2745, 20.1453, 25.9131, 28.7999,  8.1290, 13.4693],
        [15.1994, 17.7551, 18.5899, 22.2462, 40.8123, 52.1215, 43.7574]])
ground truth
tensor([[11.8622, 13.8874, 25.9206, 20.1473,  8.4035, 10.0999, 10.9153],
        [19.6081,  6.5492,  9.9421, 11.5729,  9.9290, 14.0058, 12.6381],
        [ 0.0000, 11.5930, 17.1769, 13.1519, 20.4790, 25.6519, 29.3793],
        [14.2031, 18.4245, 14.4661, 14.5450, 17.0042, 32.7985,  8.6139],
        [34.9348,  8.5317, 16.4966, 15.0935, 16.2840, 19.8129, 31.0232],
        [13.0852, 1

batch_predictions
tensor([[11.4405, 11.2991, 13.5505, 21.1532, 24.7475,  8.3426, 11.3166],
        [11.9533, 13.3185, 13.0531, 13.6181, 20.1589, 33.5962, 26.2417],
        [ 8.7315,  9.2477, 10.9057, 16.8630, 11.2044,  7.2745,  9.0494],
        [35.0352, 44.1616, 36.5065, 17.5161, 17.6266, 18.7789, 22.2059],
        [37.4196, 32.5064, 11.0762, 16.5173, 16.1475, 17.2548, 26.2011],
        [12.5143, 12.8766, 16.4100, 25.9715, 24.2725, 11.7250, 13.1235],
        [33.8582, 30.4903, 21.8045, 24.0825, 22.5486, 22.7532, 25.2384],
        [11.1672, 11.8509, 13.5537, 19.6190, 22.0958, 17.5637, 10.6550]])
ground truth
tensor([[ 1.0915,  7.8248, 13.5718, 13.7165, 24.7764,  7.3777, 13.3877],
        [10.0340, 10.8277, 10.7568, 15.2353, 22.4632, 33.8294, 27.1117],
        [11.0686,  9.0278, 11.4512, 13.5488,  8.1491,  7.6672,  7.9365],
        [50.0000, 37.0040, 27.3951, 15.5612, 18.5516, 16.9076, 20.6633],
        [37.9960, 35.3316, 10.3458, 17.2194, 19.1185, 16.4966, 29.2234],
        [10.4156, 1

batch_predictions
tensor([[ 9.1707, 10.4181, 12.4173, 14.9963, 11.9063,  7.7159,  8.6573],
        [33.4650, 43.9161, 36.6618, 15.9486, 18.3409, 17.2067, 18.5876],
        [14.2471, 12.8444, 13.2484, 18.1524, 24.4370, 22.9262, 12.3112],
        [14.3477, 14.6480, 15.9704, 22.4944, 31.7256, 29.5463, 11.6923],
        [ 5.4356,  5.3433,  5.4485,  6.0423,  7.1000,  6.8340,  5.2086],
        [10.2048, 13.3113, 20.9946, 21.2857,  5.9125,  7.1556,  7.8652],
        [ 9.1738, 10.2350, 13.1787, 19.3794, 13.2787,  7.4816,  8.4319],
        [12.5293, 14.1236, 14.2125, 20.5288, 31.1639, 25.1577,  8.6979]])
ground truth
tensor([[10.6917, 10.7838, 14.9395, 15.3209, 11.1915,  8.2851, 11.1783],
        [32.7948, 43.5799, 35.8418, 16.5675, 18.0981, 18.9768, 17.7721],
        [14.3566, 11.1395, 13.3787, 19.6145, 22.3923, 17.4320,  6.8452],
        [16.9926, 13.5062, 15.5187, 19.7846, 30.8673, 30.9382, 10.7710],
        [ 5.6973,  6.1650,  6.3492,  8.5317,  9.8073,  7.1570,  4.0675],
        [10.8759, 1

batch_predictions
tensor([[16.6235, 16.5000, 16.5895, 22.2600, 32.0469, 30.4368, 20.7906],
        [18.8164, 19.2317, 21.0703, 17.0276, 11.6457, 13.2907, 14.7104],
        [ 6.7910,  8.1194,  7.6215,  8.6746, 13.6091, 20.3210, 14.7641],
        [12.8495, 16.3579, 21.7322, 18.7107,  9.8514, 12.9615, 12.7553],
        [ 6.5882,  7.1343,  7.4116, 11.2194, 16.0465, 16.9071,  6.7771],
        [31.3965,  9.3747, 15.5891, 16.1607, 17.0135, 21.6252, 35.8958],
        [21.3470, 34.6755, 31.0987, 15.0879, 16.3174, 14.9331, 15.9712],
        [12.9762, 15.1738, 22.2784, 27.9670,  6.9585, 13.3060, 11.5700]])
ground truth
tensor([[18.5658, 13.9456, 14.0873, 22.8741, 32.3980, 30.8107, 23.5544],
        [18.5297, 20.0815, 22.5145,  4.5371,  0.0000,  8.0878,  7.8511],
        [ 0.0000, 12.4934,  7.8511,  8.6402, 12.6512, 18.1483, 15.5050],
        [14.5844, 19.8711, 23.6323, 30.7733,  9.0347, 12.9669, 14.9790],
        [ 6.9700,  7.0226,  6.7728, 10.2578, 12.9274,  6.1547,  0.0000],
        [37.3441,  

batch_predictions
tensor([[20.8934, 22.6576, 33.2599, 35.1734, 24.8758, 20.3727, 20.2104],
        [11.8638, 11.9840, 13.5734, 20.6240, 23.0518, 16.8588, 10.1189],
        [19.3175, 19.6466, 19.1835, 25.3190, 41.3049, 47.9891, 28.7409],
        [14.7201, 15.5144, 18.7178, 22.9417, 25.3782, 13.1098, 13.4722],
        [15.5884, 18.7105,  9.5299,  8.6274,  7.4292,  7.5634, 10.1332],
        [14.9758, 15.9568, 24.6710, 33.8906, 27.7417, 15.5625, 16.7437],
        [24.3726, 21.8538, 20.8309, 12.8785, 13.3090, 13.4418, 14.8410],
        [14.9611, 15.7175, 15.7480, 25.0635, 37.6979, 34.5013, 14.7147]])
ground truth
tensor([[24.5266, 28.7612, 33.3772, 10.2972,  0.0000, 15.8864, 25.6970],
        [ 7.6013,  9.8238, 16.7675, 20.7785, 21.5413,  8.6270,  9.3372],
        [ 0.0000, 15.4787, 21.9227, 27.8538, 38.6770, 61.2967, 59.5082],
        [10.3235, 10.2578, 13.5192, 22.2120, 20.7654, 10.4813, 13.4140],
        [15.0316, 18.8585,  9.1925,  7.1015,  5.5892,  6.1284, 10.7575],
        [13.8874, 1

batch_predictions
tensor([[10.2708, 11.6398, 14.7043, 22.7754, 21.4453,  9.0009, 11.0302],
        [19.7985, 28.7242, 37.3309, 12.8516, 13.8168, 13.5716, 14.8894],
        [36.2500, 15.4094, 17.7972, 16.4336, 17.4562, 27.1383, 45.2438],
        [26.7471, 11.4761, 13.4868, 13.5313, 14.6690, 22.0520, 27.2171],
        [15.4501, 15.0546, 17.2865, 28.5258, 34.4839, 25.9204, 14.8693],
        [13.5260, 14.9690, 19.0860, 29.1798, 25.0739, 11.2354, 13.9285],
        [11.3343, 12.3097, 16.5244, 22.0386, 20.3073,  8.0017,  9.9532],
        [26.8763, 14.7168, 15.7352, 14.8976, 17.8420, 32.7832, 42.7053]])
ground truth
tensor([[ 9.7506, 17.3895, 16.2982, 31.5901, 24.1780,  9.7931, 15.2636],
        [19.9500, 30.7601, 33.6533, 12.3225, 15.1236, 10.1394, 12.2304],
        [37.4717, 22.8175, 18.1973, 13.1236, 11.3662, 21.6837, 47.6190],
        [17.5171, 12.3882, 16.0179, 14.1636, 11.6649, 20.3314, 29.6686],
        [11.8481, 14.6684, 20.6633, 25.1701, 33.7302, 28.3872, 14.5125],
        [11.3804, 1

batch_predictions
tensor([[ 8.6263, 10.2543, 11.3359, 12.2752, 15.8992, 23.7559, 21.9856],
        [ 9.1500, 11.6508, 11.1680, 12.3037, 14.3553, 22.5985, 23.0912],
        [22.3785, 25.5509, 14.1351, 13.1364, 14.2684, 15.7532, 18.1969],
        [10.2396, 12.9022, 12.8685, 15.2823, 22.4664, 24.8326, 18.0574],
        [10.3719, 12.5841, 12.0956, 11.6297, 20.0487, 24.0891,  6.7370],
        [11.9045, 16.3570, 18.1671, 19.1055, 28.2933, 38.7381, 38.9565],
        [ 7.1382,  8.0206,  9.3871, 12.1448, 13.5040,  5.6547,  5.7716],
        [13.4153, 14.8623, 15.7302, 16.6963, 16.5307, 22.4305, 23.3568]])
ground truth
tensor([[ 0.0000,  1.1178, 10.6128,  8.8901, 11.4150, 20.9890, 19.7528],
        [ 5.8916,  9.5213,  8.7980, 11.7570, 12.1646, 23.0142, 17.9248],
        [19.0193,  8.9144, 10.2466, 11.2387,  9.7647,  9.6655, 18.4099],
        [12.2449, 13.0102, 11.0828,  8.4467, 19.9121, 23.8237, 21.4144],
        [ 8.5087,  9.3109, 10.3498, 11.4150, 18.9506, 24.8027,  4.5634],
        [ 7.7806, 1

batch_predictions
tensor([[10.7966, 11.7040, 15.7679, 22.8724, 24.0895, 11.6807, 12.2254],
        [28.0574, 15.6566, 16.8478, 17.1211, 19.8898, 27.3037, 29.4369],
        [ 7.7592, 10.6131, 13.0190,  7.3113,  6.6573,  6.2714,  6.5594],
        [25.6344, 12.0556, 14.9717, 14.5618, 15.0101, 19.5827, 27.9716],
        [ 9.7046, 10.9745, 10.8536, 10.9697, 13.2570, 21.2466, 23.2409],
        [19.3225, 10.1121, 11.9361, 13.0357, 14.5861, 22.8876, 26.8578],
        [ 7.4502,  8.5448, 10.7398, 12.1436, 18.3558,  6.7010,  6.4581],
        [11.9798, 13.3607, 15.5390, 27.6622, 27.1564, 10.5888, 11.6198]])
ground truth
tensor([[11.2387, 11.3520, 15.0935, 21.1026, 27.4376, 18.1831, 16.9218],
        [21.5420, 16.5108, 22.7608, 21.2302, 16.5249, 22.9875, 28.6706],
        [11.9756, 10.0624, 15.0510,  7.4263,  8.9427,  9.7647,  5.7823],
        [27.8932, 12.7301, 11.2967, 11.2572, 12.9011, 18.4377, 22.8169],
        [10.0473,  9.4950, 10.7443,  7.7722, 10.5865, 21.7649, 22.1725],
        [37.9252, 1

batch_predictions
tensor([[24.8168, 11.2844, 13.2616, 12.9402, 13.3492, 16.6240, 23.7509],
        [12.9380, 12.4360, 13.1336, 14.8160, 21.2189, 25.3445,  7.6283],
        [18.5797, 22.3036,  8.9110, 10.0130, 13.8673, 14.4103, 15.7113],
        [12.5089, 13.2266, 19.0562, 25.6709, 24.8790, 13.0035, 13.5458],
        [16.9630, 18.0093, 32.8832, 43.6902, 37.7997, 19.4025, 17.0920],
        [14.0115, 15.4011, 15.7719, 18.5941, 35.2951, 43.7704, 23.0646],
        [23.3049, 35.7899, 30.6675, 13.0851, 16.7516, 15.7948, 16.9200],
        [12.2122, 14.5627, 13.9888, 13.7989, 18.2078, 32.8810, 28.9154]])
ground truth
tensor([[23.6820, 12.2024, 12.3724, 11.0119, 13.8322, 17.1485, 26.1763],
        [15.3866,  4.3398, 14.4003, 17.5434, 25.1841, 35.3761,  5.0368],
        [19.5011, 20.4223,  5.4989,  0.0000,  7.8090, 15.5187, 14.1865],
        [10.2608, 15.4904, 18.0697, 24.2347, 28.7273, 14.1156, 11.0686],
        [12.6134, 15.9722, 36.5363, 51.1621, 36.5930, 17.1344, 17.1202],
        [12.3724, 2

batch_predictions
tensor([[14.6407, 21.3545, 25.8122, 10.2343, 14.8190, 13.1959, 13.4067],
        [24.5850, 22.6048, 27.1687, 40.6809, 50.5102, 43.8091, 29.5219],
        [29.1256,  8.6168, 13.1383, 15.4177, 16.9020, 25.8019, 34.5862],
        [12.9552, 18.9984, 27.6079, 25.0237, 10.2882, 11.9037, 12.7852],
        [32.0208, 18.7546, 13.9122, 13.9344, 14.0336, 15.7915, 24.2792],
        [24.2539, 34.4635, 36.1008, 26.1219, 22.6206, 20.7286, 19.5121],
        [10.6741, 12.1971, 13.6170, 17.7958, 27.0679, 26.8709,  8.4298],
        [17.7573, 18.4496, 31.3487, 47.2913, 42.1797, 23.5059, 19.8777]])
ground truth
tensor([[15.5576, 25.6575, 24.8816, 12.1778, 10.2841, 13.7559, 14.3083],
        [26.3039, 23.8946, 34.0703, 41.2132, 57.7523, 42.2194, 30.9382],
        [32.9649,  6.7319, 14.0873, 17.0493,  2.6927, 20.4649, 36.0402],
        [12.6644, 18.6875, 24.4477, 25.5523, 10.5208, 12.8748,  9.0610],
        [34.8498, 19.2602, 13.1661, 16.9643, 12.3724, 16.3549, 25.0709],
        [28.7612, 3

batch_predictions
tensor([[11.0894, 13.8258, 16.5730, 17.2149, 21.9579, 34.4425, 34.1398],
        [27.2263, 32.9861, 20.9455, 22.2285, 20.9558, 20.6458, 23.7263],
        [16.3038, 16.6616, 30.8103, 42.5931, 33.8331, 17.2349, 18.4119],
        [24.6478, 14.2314, 15.5732, 15.4263, 16.9154, 23.5701, 38.0434],
        [36.0625, 17.0655, 18.6263, 16.9210, 19.5675, 32.2949, 47.0608],
        [29.9657, 32.3026,  8.6814, 14.0495, 14.9679, 15.9975, 20.9920],
        [10.5838, 11.3419, 15.3698, 24.1264, 22.2205,  9.9710, 11.9107],
        [28.1478, 25.1970, 12.6702, 14.0718, 12.7536, 13.1966, 19.8671]])
ground truth
tensor([[ 8.4467, 11.5221, 10.1049, 18.5232, 21.3861, 28.4297, 24.2205],
        [25.0850, 32.9507, 16.3124, 19.0618, 22.1230, 22.5907, 21.8963],
        [21.9529, 15.8588, 33.5176, 37.0465, 36.8197, 18.4524, 15.8022],
        [29.1426, 19.5292, 14.7028, 14.5055, 18.9506, 27.5250, 52.9853],
        [40.0652, 16.5816, 21.7545, 18.0981, 21.0459, 41.6100, 43.6791],
        [29.2800, 3

batch_predictions
tensor([[17.2745, 26.0142, 21.2382,  9.1040, 10.9799, 10.7522, 12.2564],
        [26.1724, 35.6157, 33.1702, 23.0331, 23.3119, 21.4374, 21.9672],
        [ 8.0916, 12.3011, 13.4167, 16.2732, 18.0683, 19.2248,  7.6047],
        [31.7154, 10.6258, 15.4170, 16.4859, 15.8906, 17.4398, 27.8139],
        [12.9506, 13.8176, 14.9135, 20.6020, 28.4565, 28.4947, 12.5726],
        [15.6538, 20.1369, 29.4083, 29.0182, 13.1215, 15.0521, 14.9527],
        [18.6228, 22.5267, 24.9423, 14.1780, 13.5376, 13.0784, 15.7388],
        [ 8.5360,  8.7586, 10.0707, 12.4951, 18.2992, 16.2930,  8.0567]])
ground truth
tensor([[18.8059, 30.3130, 29.4319, 12.6118,  7.6539, 13.8743, 12.6644],
        [31.9569, 38.1378, 40.5576, 29.3004, 28.8138, 24.8422, 24.6844],
        [ 3.2880, 11.2245, 11.7063, 10.6151, 20.0255,  3.2313,  4.5210],
        [29.2800,  4.5493,  4.5493, 14.0164, 19.6854, 26.5306, 32.8231],
        [ 9.0278, 13.8889, 19.2035,  8.3617, 20.9467, 29.9178, 12.9393],
        [12.5855, 1

batch_predictions
tensor([[26.9959, 14.5405, 15.2077, 14.2585, 14.1947, 15.6893, 24.5041],
        [28.3802, 17.8958, 17.9698, 17.2401, 23.0459, 39.5164, 46.9242],
        [13.8706, 14.3249, 14.6008, 19.3368, 24.9971, 27.8437, 18.6992],
        [32.1571, 26.8672, 12.7720, 15.1352, 14.3636, 15.2140, 20.6150],
        [14.1328, 23.8841, 32.7659, 23.0003, 11.5603, 12.6276, 12.6274],
        [29.9211, 38.8882, 26.8716, 12.7068, 14.3557, 14.2502, 16.6011],
        [16.6825, 16.9610, 16.1228, 28.5710, 42.7672, 30.5771, 12.4545],
        [16.0996, 22.1412, 22.2248, 11.2699, 12.5532, 11.7422, 13.1207]])
ground truth
tensor([[24.6457, 16.6950, 16.4683, 16.0006, 12.7693, 15.9155, 25.4535],
        [30.8107, 20.5215, 14.9093,  7.4405, 21.4711, 35.7851, 46.2302],
        [17.4461, 12.3016, 14.5550, 12.2449, 21.7545, 27.1825, 14.0306],
        [30.3571, 24.9150, 17.1060, 16.3690, 16.3407, 16.7234, 21.2302],
        [14.6117, 25.2268, 38.4070, 28.1463, 11.8764, 14.8526,  2.1825],
        [31.5051, 4

batch_predictions
tensor([[17.1194, 17.3115, 16.9543, 18.0471, 26.7012, 38.7085, 31.2198],
        [13.1735, 16.8378, 25.9730, 23.5051, 11.5695, 11.5809, 11.6390],
        [19.0286, 28.4552, 27.4211,  7.7092, 12.1354, 12.4076, 14.8115],
        [21.7568, 10.8794, 10.5585, 11.3195, 12.3144, 18.0230, 21.1157],
        [34.7788, 34.1917, 16.2098, 17.9447, 16.4281, 16.5377, 28.9127],
        [26.2403, 29.8466, 28.3078, 13.5411, 13.8352, 14.3862, 18.0648],
        [ 9.9785, 13.0085, 19.5821,  7.6699,  7.2671,  8.7345,  9.0976],
        [28.8278, 24.6126,  9.6715, 12.6055, 14.8899, 17.3786, 22.6313]])
ground truth
tensor([[16.9516, 15.5313, 18.6744, 19.7922, 26.0521, 36.4150, 26.6570],
        [11.8753, 16.7938, 22.8564, 23.3167, 12.0989, 12.3619,  8.1273],
        [23.5119, 31.2642, 35.7568,  5.6831, 13.7330, 19.0618, 16.4116],
        [ 0.0000,  0.0000,  4.2478, 12.4671, 11.0600, 15.9916, 21.8175],
        [23.0442, 26.1480,  9.2404,  0.1417, 10.8135, 12.9252, 24.3056],
        [25.0425, 4

batch_predictions
tensor([[13.3871, 22.2634, 14.5217,  9.8197, 10.0045,  9.1668, 10.3128],
        [13.1520, 15.4865, 24.5660, 25.7350, 11.0417, 12.0350, 11.4698],
        [26.3102, 13.8115, 15.4706, 15.1100, 18.0140, 26.2277, 35.2351],
        [11.5698, 12.1139, 15.3237, 26.7853, 28.4854,  8.4480, 11.6737],
        [24.1247, 26.8318,  8.9074, 13.0668, 12.9330, 13.2115, 19.4249],
        [ 9.3515, 11.3835, 12.3025, 14.9670, 19.3294, 23.5713,  6.6959],
        [13.7009, 12.5596, 13.3762, 18.4319, 23.9510, 22.4866, 12.7438],
        [13.9777, 17.0709, 15.9890, 15.6220, 25.8673, 36.9353, 36.2762]])
ground truth
tensor([[16.4256, 22.5539, 15.7154, 16.8332, 13.2956, 21.3966, 12.8880],
        [15.0368, 17.4178, 29.0533, 18.4240,  0.0000,  6.9586, 12.9677],
        [41.6525, 22.2222, 10.4875, 16.7375, 18.6224, 22.6616, 31.2783],
        [ 9.9816, 12.2567, 17.4382, 31.0363, 31.7201,  8.6402, 12.5460],
        [21.3861, 23.0017, 10.7001, 10.0057, 10.9694, 11.1820, 12.2024],
        [ 4.7870, 1

batch_predictions
tensor([[14.4715, 19.4827, 27.0039, 30.6098,  7.9353, 15.3032, 15.0123],
        [12.9872, 19.8765, 25.3694, 23.3385,  9.5865, 12.2398, 12.2348],
        [14.5980, 14.9097, 14.4197, 19.9874, 28.9383, 26.1715, 14.0065],
        [19.8793, 23.1571,  9.9003, 11.4842,  9.9320,  9.8548, 13.4838],
        [19.4873, 26.3840, 23.0054, 12.8106, 13.4062, 12.0644, 14.1401],
        [11.9711, 12.8857, 12.7618, 13.3988, 17.7370, 24.4229, 23.8747],
        [11.2694, 15.3824, 15.4745, 14.9573, 18.3339, 27.9793, 26.8681],
        [24.8527, 31.9553, 21.0641, 14.7992, 15.4149, 15.0551, 18.8871]])
ground truth
tensor([[17.7154, 19.7562, 26.0204, 27.6502,  5.8532, 15.5896, 13.7755],
        [14.3707, 18.9201, 21.6695, 18.0839,  8.6168, 12.3158,  9.8073],
        [18.1122, 14.6117, 14.9943, 24.0221, 18.0981, 22.4206, 13.8464],
        [19.6344, 16.6623,  9.6528, 10.9679, 11.1126,  9.9553, 18.1220],
        [18.8209, 26.7574, 23.0867, 13.0811, 13.7330, 11.5363, 16.3974],
        [12.2307, 1

batch_predictions
tensor([[28.1803,  7.5933, 12.7469, 12.5552, 12.9297, 18.5725, 33.3601],
        [27.1061, 30.6301, 12.1428, 14.4438, 12.9296, 13.3663, 20.0273],
        [13.9034, 13.7244, 13.2737, 17.5987, 23.5248, 24.9609, 10.0279],
        [13.7458, 18.1507, 27.1509, 31.9727,  6.8964, 13.8456, 13.0136],
        [16.7261, 16.7888, 31.7536, 44.1932, 41.8726, 17.4672, 17.7995],
        [11.6836, 11.5177, 12.0337, 15.9261, 23.3659, 23.0500, 10.1315],
        [14.1453, 14.1029, 14.7710, 19.0617, 25.6346, 30.1828, 20.0304],
        [11.0370, 10.4058, 10.5454, 13.0738, 20.7239, 22.2088,  9.7811]])
ground truth
tensor([[32.2846, 13.9598, 15.1077, 15.8022, 12.8827, 20.3090, 30.7398],
        [28.8664, 28.7875,  9.0873, 11.7175,  8.7322, 11.5203, 16.5176],
        [13.4140, 11.4940, 11.8096, 13.1247, 24.0663, 25.8943, 11.7570],
        [13.4070, 20.7908, 36.0261, 45.8759, 11.9189, 18.8634,  3.6706],
        [16.3335, 15.3472, 33.8506, 45.4366, 79.2346, 34.7054,  8.9164],
        [ 9.3503, 1

tensor([[ 8.8865, 11.7624, 11.4007, 11.8033, 13.1129, 18.4537, 21.9571],
        [12.1599, 11.5362, 11.7789, 13.1806, 18.2921, 26.7649, 25.4845],
        [23.5058, 13.8279, 16.0989, 14.7889, 15.4055, 21.0629, 30.7335],
        [22.8877, 39.5485, 37.4210, 17.4732, 20.1476, 18.1579, 17.9653],
        [24.8698, 23.8833, 13.1522, 15.1889, 16.2763, 20.5096, 22.6199],
        [13.9762, 15.1101, 25.5137, 31.5651, 32.3067, 12.4145, 14.9058],
        [39.7117, 36.7972, 13.5701, 17.3725, 15.8729, 17.2253, 26.6504],
        [18.8220, 25.6571, 29.6927, 17.8961, 15.7099, 14.3366, 15.3882]])
ground truth
tensor([[ 6.9043, 10.9679, 12.3488,  9.4424, 13.3745, 20.8311, 25.2236],
        [12.0989, 12.3619,  8.1273,  9.4161, 14.6502, 27.1305, 18.0957],
        [23.3561, 13.5324, 16.0047, 12.3882, 13.2299, 12.5855, 34.4292],
        [36.0261, 36.7914, 25.6094, 18.4382, 21.9388, 10.5867, 15.7738],
        [19.2004, 17.1883,  8.9295, 13.5455, 12.2962, 11.9411, 15.5181],
        [17.6162, 16.2557, 24.6740, 3

batch_predictions
tensor([[ 8.3401,  6.9972,  7.5278,  9.0221,  9.8617, 12.0946, 12.7015],
        [11.4055, 12.7460, 15.8315, 21.5190, 21.8374,  9.3882, 11.2363],
        [15.9542, 15.4020,  6.3579,  7.1567,  6.9580,  8.0349, 11.2711],
        [35.7462, 31.4553, 17.8509, 17.3399, 17.0900, 22.6437, 35.5287],
        [ 7.4926,  8.3559,  8.8230, 10.3162, 14.6184, 16.9324,  6.5008],
        [19.8545, 17.8188, 16.9741, 18.9536, 28.4356, 46.7728, 42.6672],
        [21.4530, 33.3685, 34.5327,  9.0657, 15.6328, 15.2614, 16.4315],
        [16.6190, 17.5262, 17.5349, 20.2894, 29.2598, 40.6711, 31.3499]])
ground truth
tensor([[10.3630,  5.6549,  8.3772,  9.7054,  8.8901, 10.5471, 13.4929],
        [12.4858, 18.4240, 16.6808, 31.0516, 29.7761,  9.6514, 15.8872],
        [15.4787, 12.0200,  4.3661,  7.3777,  4.9448,  7.2199, 11.7964],
        [46.4994, 40.8163, 18.5516, 11.4229, 15.6746, 22.7183, 50.0000],
        [10.0736,  8.1010,  8.4166,  9.9816, 17.0437, 18.3456,  4.0373],
        [19.9106, 1

batch_predictions
tensor([[13.7793, 18.5169, 24.8943, 22.5311, 12.0676, 13.2093, 12.9031],
        [16.3697, 31.5963, 39.3127, 29.3889, 15.2628, 15.3063, 14.6114],
        [14.3975, 15.9280, 15.5732, 16.2602, 18.6095, 25.9004, 30.1078],
        [35.3977, 27.6359, 16.2955, 17.1816, 16.4255, 17.3842, 26.8891],
        [20.8918, 24.1511, 34.0530, 38.7919, 35.0736, 19.9066, 18.6787],
        [29.4371, 40.0871, 47.2978, 41.0492, 32.1035, 28.9972, 27.5507],
        [20.6475, 21.9982,  8.2001, 11.3963, 11.3021, 11.7631, 16.8572],
        [19.0399, 24.1698, 20.3232, 10.5476, 14.0349, 13.5018, 14.6545]])
ground truth
tensor([[14.4135, 14.4924, 25.1578, 21.4229, 13.8348, 15.7023, 15.9390],
        [21.4994, 33.3050, 50.2126, 28.7982, 10.8560, 16.0998, 19.7562],
        [15.3340, 16.7543, 12.5197, 14.6633, 16.9122, 21.1862, 26.3809],
        [42.0831, 36.7438, 16.9516, 15.5313, 18.6744, 19.7922, 26.0521],
        [23.0584, 26.9558, 38.6480, 33.7160, 32.8231, 18.2965, 18.3532],
        [22.2789, 4

batch_predictions
tensor([[25.7366, 20.7148, 18.3279, 18.2029, 29.3951, 41.4167, 37.3958],
        [19.8820, 28.8337, 39.5160, 30.4146, 12.7720, 15.1135, 16.4586],
        [12.4943, 13.7338, 15.3953, 20.1187, 29.8970, 29.1931, 16.2994],
        [11.2130, 12.7292, 12.0981, 12.1074, 14.1309, 19.9586, 22.2712],
        [25.7871, 34.4827, 37.8673, 23.0656, 20.3050, 20.1700, 20.6040],
        [16.5842, 11.0369,  6.7262,  7.9070,  7.2494,  7.6730, 11.9651],
        [21.9582, 26.4999, 34.4345, 19.7034, 20.2602, 19.2453, 19.9830],
        [11.3411, 11.9594, 13.6381, 13.2744, 19.4546, 15.5560,  7.3962]])
ground truth
tensor([[ 0.0000,  9.5096, 17.7721, 18.1264, 24.9717, 43.2540, 33.8294],
        [17.8146, 31.3776, 42.8713, 35.3600, 16.1706, 11.4796, 15.0368],
        [10.5584, 10.9694, 12.1740, 15.0510, 22.4065, 31.3917, 16.9218],
        [10.2972,  9.0347, 10.9285, 10.8627, 12.7696, 18.8585,  0.0000],
        [26.7715, 39.0023, 51.9841, 39.9660, 24.7449, 23.6111, 26.7999],
        [14.2425, 1

batch_predictions
tensor([[32.2990, 27.3705,  9.0410, 13.5993, 13.4269, 15.8563, 22.9866],
        [27.4072, 40.8338, 32.5854, 20.2311, 16.3739, 15.6482, 17.7299],
        [13.2523, 15.9106, 15.5648, 16.5253, 18.8329, 25.4733, 24.3780],
        [19.8195, 23.3192, 29.0302, 30.0504, 12.5898, 16.1875, 17.5298],
        [24.1644, 27.2256, 17.8251, 13.5148, 13.5756, 13.1593, 16.5528],
        [15.4335, 26.0361, 32.5058, 30.0089, 12.4254, 14.6451, 14.7878],
        [37.4638, 24.0583, 16.3860, 16.6092, 16.1108, 16.6720, 31.5633],
        [17.9289, 25.8359, 27.4225, 10.6156, 15.2142, 13.6297, 14.6258]])
ground truth
tensor([[31.2075, 22.7183,  7.0153, 15.6604, 10.5017, 14.0306, 20.5924],
        [24.7591, 40.5471, 30.4989, 18.8209, 18.5374, 12.8401, 17.5454],
        [13.3088, 16.9253, 13.9663, 15.7549, 21.3966, 21.0153, 27.1831],
        [21.0601, 29.0533, 28.1179, 28.2171, 22.1655, 24.0788, 22.2080],
        [22.5802, 26.7228, 14.9132, 11.9411, 13.9532, 17.5829, 15.5050],
        [16.2557, 2

batch_predictions
tensor([[15.9344,  7.8267, 10.9166, 12.6319, 13.6846, 16.9691, 17.9370],
        [18.3243, 17.7140, 15.7818, 17.0556, 24.2646, 41.5634, 40.2449],
        [14.0505, 16.2734, 19.8052, 29.6417, 31.6929, 10.8050, 13.5798],
        [ 6.9751,  6.6209,  7.2339,  6.9987,  9.3343, 10.8601, 13.6011],
        [12.9323, 14.4395, 14.8484, 18.7612, 26.7481, 30.7110, 11.1151],
        [ 5.7082,  6.4166,  7.3598,  7.0198,  5.5380,  4.7217,  4.8432],
        [15.2098, 15.6729, 16.5961, 25.9654, 36.6187, 32.3465,  9.5088],
        [31.2999, 41.0060, 33.4555, 19.3935, 20.0253, 17.9160, 19.5592]])
ground truth
tensor([[24.9079, 12.9537, 14.2162, 21.4361, 15.5839, 14.8080, 17.3067],
        [14.1110, 16.4256, 16.1231, 19.1215, 22.9748, 43.0037, 42.7012],
        [ 8.6007, 13.3614, 20.9890, 20.2130, 34.7712, 13.7559, 15.5576],
        [ 4.7475,  6.2336,  6.7859,  6.1152,  8.0878, 15.5708,  3.7612],
        [14.6896, 13.7428, 18.2009, 14.6107, 23.9348, 30.8916,  4.6949],
        [ 4.1241,  

batch_predictions
tensor([[26.5544, 29.6315, 11.8113, 14.1497, 13.1148, 13.4073, 19.7705],
        [20.0558, 17.1086,  6.5025,  7.3229,  8.0318,  9.1632, 16.2891],
        [16.7352, 16.1062, 16.2161, 21.0202, 29.0454, 35.1264, 28.8800],
        [17.6549, 17.8944, 19.6659, 34.6173, 41.3704, 36.7386, 19.3350],
        [15.0475, 16.4050, 22.1268, 33.7039, 35.4907,  9.4517, 15.2151],
        [ 9.3064, 12.0852, 20.6711, 16.2939, 12.9171,  9.7024,  9.3510],
        [20.3816, 37.2345, 34.9921, 21.4723, 20.0727, 17.4829, 16.0159],
        [13.6050, 15.4803, 14.8398, 15.0026, 19.9916, 34.8379, 29.7540]])
ground truth
tensor([[24.2063, 27.2251, 11.5221, 12.5283, 15.4053, 13.7755, 15.8872],
        [21.0284, 18.0037,  1.7622,  6.8648,  6.3782,  9.7186, 12.8090],
        [21.3966, 24.2635, 19.5160, 19.3319, 27.8538, 32.5750, 23.2246],
        [16.7375, 18.1548, 19.3169, 25.5952, 47.6332, 14.7534,  0.0000],
        [15.0935, 16.2840, 19.8129, 31.0232, 32.9082,  5.4563, 16.7234],
        [10.8364, 1

batch_predictions
tensor([[18.4889, 27.6448, 33.6025, 14.9508, 17.2304, 14.8713, 15.5686],
        [19.6575, 26.1874,  8.0504,  8.7078,  9.4241,  9.6436, 12.9880],
        [46.2681, 43.2443, 32.2919, 29.1075, 28.0102, 29.0152, 36.2385],
        [15.4319, 24.5694, 34.6387, 14.8423, 15.3721, 14.6200, 15.5337],
        [13.6151, 19.9465, 14.4708,  8.2309,  9.8364,  9.4670, 11.6889],
        [16.5876, 25.3955, 27.1188, 16.9498, 13.5453, 13.2434, 14.9197],
        [38.0811, 25.4588, 12.4328, 14.2431, 13.6594, 15.4084, 31.2728],
        [17.2245, 23.3613, 25.2470, 10.9004, 13.6370, 13.7890, 14.9432]])
ground truth
tensor([[21.4144, 30.7965, 40.4195, 20.2098, 19.1043, 16.6950, 19.1752],
        [12.9932, 23.7112,  6.9832,  6.9437,  9.1136,  7.9958, 11.7833],
        [53.4297, 50.6519, 31.7319, 24.9433, 20.9042, 23.3277, 36.7914],
        [13.3877, 18.7533, 33.4824, 13.8611, 13.7033, 11.3756, 13.8348],
        [10.2841, 13.1247, 16.4387,  7.7591,  7.9826,  7.5750, 11.3361],
        [19.4728, 2

batch_predictions
tensor([[21.8929, 28.4468, 36.3218, 27.6466, 21.1222, 19.6807, 19.5080],
        [ 7.6759,  7.7124,  7.3369,  7.6630,  9.8061, 14.7002, 18.7495],
        [32.7002, 12.5859, 16.2132, 17.3033, 19.4436, 25.1802, 29.5205],
        [17.6527, 19.8917, 19.1524, 20.1999, 24.9115, 37.3036, 44.1784],
        [27.5340, 14.6299, 16.2555, 16.3971, 18.8152, 33.9395, 37.1755],
        [31.4099, 21.4404, 13.7894, 13.6890, 14.3547, 15.4838, 25.5239],
        [14.3048, 14.7656, 20.2149, 28.8636, 28.1396,  8.8653, 12.5475],
        [10.7471, 12.6767, 20.3287, 25.0661,  8.5225, 12.4975, 12.0162]])
ground truth
tensor([[24.6712, 32.5881, 37.7433,  0.0000, 18.1220, 28.2746, 24.9079],
        [ 5.8107,  7.6531,  7.0011,  8.6026, 14.8810, 15.8730, 21.7545],
        [39.6633, 14.5713, 15.7943, 15.0184, 16.0705, 17.3593, 21.1862],
        [43.0037, 14.4792, 19.9369, 21.6991, 23.5271, 30.9048, 42.5434],
        [37.2307, 26.3747, 26.0629,  2.4093, 12.2307, 23.4127, 39.5125],
        [28.7273, 2

batch_predictions
tensor([[ 9.0341,  9.0183, 10.1482, 14.4479, 21.1696, 18.8026,  7.6657],
        [14.1936, 18.5168, 27.0513, 14.6437, 12.6217, 12.9341, 13.6034],
        [19.5326, 28.1499, 28.6442, 13.6829, 14.5260, 13.2720, 15.0546],
        [12.3200, 11.9812, 13.6535, 10.3223, 10.4379,  9.9414, 10.3569],
        [13.6493, 14.7991, 29.0003, 38.2057, 25.1130, 15.2038, 14.4478],
        [12.9888, 13.8322, 16.0710, 27.4682, 26.2120, 16.9135, 10.9815],
        [ 9.7587, 10.4675, 13.1741, 11.6651,  9.0197,  9.2039,  9.5578],
        [22.5621, 32.6965, 29.9809, 14.4588, 17.0638, 15.5725, 16.9984]])
ground truth
tensor([[10.1789,  7.1278, 10.9416, 15.2420, 24.6844, 15.9916,  0.0000],
        [21.4711, 29.8469, 33.6451, 14.3566, 11.7772, 12.5850, 16.2982],
        [17.8590, 23.9742, 20.8837, 11.8359,  8.0747, 11.1389, 15.8732],
        [ 7.2988, 21.2914, 20.2262,  4.5502,  0.0000, 11.3098, 18.6349],
        [12.9819, 15.2636, 35.8560, 45.8333, 36.7063, 15.6746, 12.2591],
        [16.4257, 1

batch_predictions
tensor([[19.5529, 20.4450, 25.8167, 32.0525, 34.4498, 20.9438, 20.6246],
        [14.6476, 12.6187,  6.6355,  7.7728,  6.9850,  7.2710, 11.8400],
        [11.4676, 13.2315, 17.2423, 27.6551, 32.2871, 10.6100, 14.6530],
        [10.5222, 14.0745, 13.9065, 14.0640, 17.2387, 24.8987, 30.4520],
        [20.7527, 31.5026, 33.0308, 14.6349, 16.3834, 14.5874, 15.9483],
        [19.7204, 26.3864, 25.7259, 12.6768, 13.9833, 13.5079, 14.7844],
        [21.4105, 25.7725, 28.1797, 30.5367, 13.0185, 15.5341, 17.2906],
        [21.0455,  8.8383,  9.4931, 11.5211, 14.0078, 16.7075, 20.1009]])
ground truth
tensor([[20.4103, 23.3561, 26.1573, 32.8643, 35.7312, 23.0800, 25.5655],
        [14.0978, 10.7180,  3.3140,  9.5082,  8.9953,  8.4035, 12.4803],
        [10.9410, 12.8118, 22.5198, 32.2137, 22.5198,  6.5901, 10.6293],
        [10.0736, 11.2835,  9.1005, 13.4534, 18.2272, 23.0537, 32.6539],
        [19.7421, 23.7670, 34.0561, 12.6417, 15.7596, 12.8968, 13.3078],
        [26.8991, 2

batch_predictions
tensor([[ 8.2365,  8.9636,  9.2916, 12.3114, 17.8201, 16.7699,  7.5141],
        [28.6683, 10.9286, 14.2327, 13.6250, 15.3748, 18.6589, 30.0224],
        [14.4538, 14.8986, 19.0802, 27.4107, 25.1754, 17.3365, 16.0863],
        [16.6242, 21.3390, 33.8399, 37.0450,  9.9674, 15.9291, 15.1425],
        [ 7.4199,  8.1091,  9.9285, 15.4209, 20.4830,  6.7691,  7.5221],
        [ 8.4786, 11.3740, 16.6875, 19.1068,  6.8795,  9.4817,  8.0432],
        [15.7926, 16.3676, 17.8568, 35.0410, 46.3419, 35.9971, 14.2031],
        [37.1787, 41.3866, 30.4465, 19.4570, 17.8388, 17.5259, 21.9821]])
ground truth
tensor([[ 7.7196,  7.3645,  8.8769, 10.6128, 16.6754, 16.5045,  7.0358],
        [29.7194, 11.1678, 13.4921, 14.3707, 19.9688, 19.8413, 27.0833],
        [12.5460, 17.2672, 17.3724, 17.2541, 21.9227,  2.1436, 12.1515],
        [18.9909, 22.5907, 34.0420,  9.3821, 11.8197, 26.7857, 20.2806],
        [ 7.0011,  8.6026, 14.8810, 15.8730, 21.7545,  6.9870,  7.2279],
        [ 9.5739, 1

batch_predictions
tensor([[14.8489, 15.7491, 25.1693, 33.0142, 21.1270, 14.2153, 12.8547],
        [18.3679, 16.7549, 16.3319, 25.6330, 37.3758, 36.2779, 23.4883],
        [13.8266, 21.8031, 19.7954,  7.8580,  9.5611,  8.6309, 10.5009],
        [12.8699, 17.3342, 22.2486, 18.9412, 10.0539, 10.7439, 11.3518],
        [30.1345, 27.9641, 11.2826, 13.6485, 13.6157, 14.6759, 23.2974],
        [29.5411, 21.0561, 20.8865, 22.5348, 29.3197, 37.0832, 32.2794],
        [24.5406, 33.6981, 24.0531, 12.2725, 13.4266, 12.9167, 15.8934],
        [25.5386, 11.0438, 10.0046, 10.9349, 10.9060, 11.5528, 19.2210]])
ground truth
tensor([[ 1.3464, 14.8951, 24.3056, 28.7273, 27.1684, 13.3220, 10.7710],
        [14.4274, 13.0385, 15.9864, 30.0737, 45.4507, 41.7234, 16.6950],
        [15.2420, 24.6844, 15.9916,  0.0000,  8.9164, 11.4808,  9.2714],
        [10.8759, 11.8622, 25.1447,  5.2341,  8.2720, 13.4929,  6.4440],
        [22.8564, 17.5171, 12.3882, 16.0179, 14.1636, 11.6649, 20.3314],
        [14.8474, 1

batch_predictions
tensor([[16.8090, 16.3264, 22.5868, 40.9130, 38.4307, 16.2980, 18.2411],
        [18.6145, 17.8450, 26.0233, 42.4783, 41.1751, 18.3970, 19.7855],
        [30.3234, 26.1063, 25.7951, 29.0428, 40.8947, 52.1625, 45.6970],
        [12.7964, 15.6184, 15.2738, 16.5547, 19.9519, 26.4438, 29.1113],
        [18.7412, 20.3471, 19.9708, 21.7716, 26.5517, 33.0893, 20.7082],
        [10.4858, 13.5723,  9.5234,  6.8747,  6.8573,  6.9893,  8.5944],
        [ 8.6053, 10.2331, 16.3938, 21.4394,  6.3996,  7.2655,  7.8709],
        [12.6573, 12.7613, 15.8728, 22.3780, 23.0328,  9.2677, 12.6762]])
ground truth
tensor([[12.7126, 21.8679, 24.0788, 35.5300, 43.9059,  8.9994,  1.3039],
        [22.6757, 14.5125, 25.9070, 42.4745, 41.1565, 12.5850, 21.3010],
        [27.5794, 19.7562, 24.5748, 25.4819,  7.4405, 90.2778, 50.4819],
        [17.0918, 13.9172, 12.9393, 15.3486, 18.7500, 29.6769, 30.7256],
        [23.6961, 20.7058, 23.4410, 26.3464, 25.0850, 32.9507, 16.3124],
        [ 9.6372, 1

batch_predictions
tensor([[29.9307, 40.9804, 31.5149, 23.3417, 21.4536, 19.4176, 21.1494],
        [13.7594, 13.4843, 13.7886, 19.9421, 26.8828, 27.3650, 10.5034],
        [17.0721, 15.8552,  7.9337,  8.2284,  8.1759,  8.4176, 11.4757],
        [11.3431, 16.0747, 23.2842, 21.1154,  8.7787, 11.2286, 11.3027],
        [ 9.6305, 11.6257, 18.8698, 24.3557,  7.2181, 10.0280, 10.3743],
        [22.5704, 33.5868, 37.4830, 32.9039, 13.6857, 16.8007, 17.5303],
        [ 8.8468,  6.7780,  7.6405,  8.3163,  8.4415, 10.9559, 14.1145],
        [12.0389, 13.2678, 13.0053, 14.4356, 19.2281, 30.1848, 28.8083]])
ground truth
tensor([[24.8948, 39.3214, 33.3772, 20.3709, 17.2409, 16.8332, 10.2183],
        [13.8348, 12.6644, 14.1504, 21.5413, 17.6749, 22.8827,  8.8506],
        [16.6754, 16.5045,  7.0358,  8.7717,  6.5886,  7.0621, 10.6654],
        [16.2152, 14.0715, 21.0021, 13.5587,  8.7059,  7.4566,  9.1794],
        [ 9.3766,  8.9953, 11.1520,  0.0000,  2.8932,  9.1794,  8.4298],
        [23.2851, 3

batch_predictions
tensor([[30.1367, 12.4500, 15.9209, 14.0985, 14.3439, 18.2943, 29.3840],
        [ 9.5615,  9.9569, 13.1879, 21.4306, 22.1672,  8.7766, 10.6397],
        [27.7698, 14.0553, 16.2766, 15.5141, 16.0816, 22.7024, 33.3021],
        [22.0178, 20.5828, 19.1889, 21.7573, 36.9128, 51.5076, 45.2173],
        [ 7.2869,  9.0276,  8.1134,  8.6868, 12.4539, 16.1263, 21.9964],
        [11.0444, 11.8490, 13.9486, 19.7405, 26.4580,  6.2107,  8.8033],
        [16.3491, 26.0814, 31.9499, 12.2356, 14.8143, 14.0275, 15.2206],
        [14.4501, 16.6430, 16.6070, 19.2772, 34.0225, 44.3627, 35.1066]])
ground truth
tensor([[35.4024, 13.5192, 13.5192, 11.6386, 11.6649, 19.7659, 28.7743],
        [11.1126,  9.9553, 18.1220, 26.0521, 23.8559,  8.0747, 11.0600],
        [ 8.7727, 10.9410, 17.0068, 16.3690, 21.0743, 26.0346, 40.4337],
        [29.8527, 28.1694, 20.3314, 23.0800, 29.5371, 49.6449, 44.5950],
        [ 6.2642,  9.7506,  6.8594,  6.5901, 16.1990, 17.6587, 20.8192],
        [13.9663,  

batch_predictions
tensor([[29.2746, 29.2568,  8.5227, 12.5800, 11.7011, 12.3627, 19.9713],
        [11.5094, 15.2446, 16.1349, 16.7962, 19.3645, 27.6436, 28.4273],
        [11.2699, 11.5766, 13.4546, 20.0902, 27.0547, 14.2085, 12.2137],
        [16.2031, 15.1815, 17.4574, 24.3529, 33.4328, 26.9622, 14.5599],
        [45.0746, 32.6901, 23.3421, 19.8046, 17.2868, 21.1386, 38.4994],
        [22.8636, 10.2312, 13.2454, 13.0394, 13.5784, 17.0540, 23.6958],
        [18.6623, 19.5433, 28.5803, 40.7553, 36.5762, 18.7851, 18.8547],
        [15.4684, 16.8485, 18.3791, 24.4269, 31.5951, 35.5771, 13.3722]])
ground truth
tensor([[27.4987, 28.8795, 11.9279, 16.9253, 13.9006, 13.3877, 21.2914],
        [18.8322, 22.5539, 17.7801, 20.5155, 19.1610, 27.4461, 30.4840],
        [ 9.1978, 11.6071, 14.2432, 18.1973, 29.0533, 11.8906, 10.8560],
        [14.8951, 17.4603, 17.1202, 22.7324, 28.7840, 30.0028, 13.1519],
        [45.5155, 50.2762, 20.6996, 12.9537, 16.2415, 24.3425, 35.3630],
        [23.3693,  

batch_predictions
tensor([[15.7265, 17.0792, 30.8194, 42.8816, 36.6448, 19.2345, 18.2480],
        [20.9704, 17.7129, 15.4919, 16.4672, 22.4738, 34.1881, 31.6882],
        [15.2095, 24.9033, 42.5002, 35.0626, 14.6983, 16.7615, 15.0980],
        [20.2236, 16.8327, 15.9559, 13.9618, 14.3753, 16.2914, 21.3711],
        [ 8.9457, 11.7727, 10.2559,  6.2331,  5.3865,  5.7926,  6.8762],
        [12.7037, 19.7472, 24.5920, 21.7048, 11.0654, 12.8782, 12.0311],
        [12.4959, 13.7405, 15.0105, 22.7452, 27.7939, 21.9445, 11.9526],
        [17.1918, 27.5294, 28.3742, 24.4402, 19.0566, 16.9484, 17.3841]])
ground truth
tensor([[15.4904, 14.5408, 31.6893, 56.1366, 41.2273, 17.2477, 16.0431],
        [15.9580, 15.0652, 19.6145, 21.2585, 22.5482, 28.3163, 32.4688],
        [13.9739, 23.7954, 48.1434, 11.4938, 17.3328, 17.7721, 14.3707],
        [24.3622, 17.4178, 17.5170, 25.4393, 19.6995, 20.6066, 21.7545],
        [ 7.1570,  5.0595,  9.1270,  4.8753,  6.2925,  6.6468,  7.3271],
        [ 7.9790, 1

batch_predictions
tensor([[16.8523, 25.8468, 35.3817, 26.2721, 14.4393, 15.2561, 14.8674],
        [14.7300, 16.6617, 22.8841, 30.5135,  7.4902, 13.3088, 14.1205],
        [22.8473, 35.7675, 25.4328, 12.8943, 15.0845, 15.2943, 16.2208],
        [13.8040, 21.9840, 16.9843, 11.5599, 12.4684, 10.3569, 12.2818],
        [14.3591, 15.9399, 23.8861, 31.2611, 19.1950, 14.8267, 15.4032],
        [14.4809, 16.1912, 30.1914, 41.1803, 27.2579, 17.7038, 17.4497],
        [22.6507, 33.2800, 27.0587, 10.3206, 12.5261, 12.7410, 15.2113],
        [ 6.9531,  7.2655,  9.7136, 10.6092, 12.0614, 13.9800, 18.0015]])
ground truth
tensor([[15.1631, 25.1578, 26.9069, 27.9853, 15.4787, 17.4382, 13.8874],
        [14.3477, 16.8990, 22.4619,  5.4971,  4.2872, 16.1231,  4.5897],
        [21.6978, 28.7840, 34.7506, 17.5454, 17.7863, 10.7285, 11.9331],
        [14.2031, 21.3835, 13.2562, 11.7438, 11.0337, 10.9548, 11.5992],
        [12.3724, 16.3549, 25.0709, 25.2551, 21.9388, 18.5091, 15.3061],
        [12.5000, 1

batch_predictions
tensor([[15.8304, 22.5479,  7.1778,  8.7975,  9.3277,  9.6303, 11.8422],
        [37.0929, 38.2414, 22.2499, 18.9662, 15.3194, 15.1106, 20.0426],
        [22.8041, 27.1551, 31.4716,  9.6697, 14.6331, 14.9022, 16.4664],
        [14.3949, 14.3302, 19.2085, 28.9928, 33.5375,  8.7552, 14.7425],
        [12.0060, 15.2155, 22.1954, 22.7376,  8.0902, 11.9841, 12.2930],
        [17.0392, 18.1184, 18.0439, 30.5658, 46.3781, 40.2053, 13.4828],
        [21.0783, 22.0373, 24.9120, 42.2898, 49.3981, 41.2328, 18.5698],
        [15.3904, 16.5250, 23.4070, 36.7710, 38.6090,  9.7017, 15.4290]])
ground truth
tensor([[18.4903, 21.1599,  3.9584,  8.5613,  9.2320, 10.9942, 12.2304],
        [45.7200, 38.7755, 22.5482, 19.2177, 24.3764, 17.8713, 11.4087],
        [17.7579, 24.2489, 24.4898, 12.2449, 14.6825, 14.1723, 12.1032],
        [11.2954,  1.8424,  5.6973, 20.2239, 28.6423,  5.0737, 11.9473],
        [12.6907, 16.0047, 22.6854, 23.3035,  9.1268, 11.8622,  9.9684],
        [19.8129, 1

tensor([[ 7.6804, 10.3819,  9.1171,  8.1948, 11.5674, 16.4263,  8.8888],
        [14.5709, 15.1404, 18.1766, 26.3709, 28.3802, 20.5788, 12.1755],
        [11.9473, 13.4650, 16.3008, 26.0508, 27.6698,  8.6799, 11.2136],
        [16.5153, 15.8043, 20.4208, 30.8234, 24.7133, 18.6914, 18.6035],
        [15.3800, 15.2323, 17.4780, 21.4790, 23.9000, 16.4691, 16.2774],
        [14.7082, 20.1655, 26.7886, 23.5572, 12.1157, 12.4397, 13.9947],
        [31.4401, 28.9385, 11.2556, 14.7957, 13.1599, 13.8602, 19.3292],
        [ 9.4818,  9.7725, 14.3752, 20.8518, 19.3648,  7.2825,  8.8238]])
ground truth
tensor([[ 3.1179,  7.0862,  8.3475,  9.2545, 15.2353, 17.0777, 26.5448],
        [11.5788, 13.4637, 13.2653, 25.7937, 28.3730, 25.9212,  7.1712],
        [ 7.7065, 15.0316, 14.3214, 23.7112, 24.3556,  7.4303,  9.9553],
        [22.8301, 23.6981, 19.1215, 26.5124, 29.1557, 15.7549, 19.2267],
        [11.0261, 15.0227, 16.4541, 21.7829, 24.3622, 17.4178, 17.5170],
        [13.0102, 21.1026, 24.3197, 1

batch_predictions
tensor([[11.5299, 13.9135, 10.2671, 10.4741, 12.0831, 10.1027,  8.2021],
        [17.3095, 16.7659, 15.7593, 22.4324, 35.5162, 45.2368, 31.3061],
        [21.8391, 24.4454, 29.5940, 12.5191, 15.6223, 14.6379, 15.8387],
        [15.4254, 18.9542, 30.9243, 32.5915, 18.4013, 18.1982, 16.6410],
        [17.0642, 27.1133, 22.5570, 10.2459, 13.0765, 12.0376, 14.0222],
        [ 9.0377, 13.8302, 13.4055, 14.0118, 18.5543, 24.0808, 20.5730],
        [14.5477, 15.2398, 16.2379, 24.0688, 36.9736, 33.2890,  8.5219],
        [ 8.6924,  7.6908,  8.8717,  9.5365,  9.8164, 11.3429, 15.5221]])
ground truth
tensor([[21.2914, 20.2262,  4.5502,  0.0000, 11.3098, 18.6349, 16.5176],
        [18.5516, 16.9359, 17.3895, 13.7046,  0.0000, 21.6837, 35.5867],
        [28.5147, 17.1627, 22.7466,  9.4388, 10.7001, 16.7800, 15.8447],
        [14.7948, 18.0826, 29.2478, 25.8022, 18.6086, 21.1862, 14.1110],
        [16.1281, 22.4773, 18.5941, 11.1820, 13.4637,  9.4813, 14.1723],
        [ 5.7470,  

batch_predictions
tensor([[15.2350, 18.1876,  6.7257,  8.3953,  7.4331,  7.5810,  8.3606],
        [ 5.6686,  6.0168,  6.5722,  6.8519,  8.6912,  7.8155,  5.9746],
        [18.2774, 17.0972, 18.1483, 23.1115, 37.9589, 41.0202, 24.4206],
        [10.6524, 16.2911, 16.8562,  6.8726,  6.8011,  7.5684,  7.5588],
        [10.4700, 11.2064, 14.0287, 22.2010, 22.2560,  9.0435, 10.7979],
        [19.2585, 25.4756, 38.8623, 35.6720, 12.6426, 16.0699, 16.6063],
        [36.8514, 45.7206, 37.6031, 23.9772, 22.2508, 21.0012, 24.4556],
        [14.3950, 14.8782, 15.7225, 19.0782, 27.4796, 37.6815, 27.6532]])
ground truth
tensor([[19.1873, 21.2783,  3.8138, 10.7312, 17.0174, 13.8611,  8.9690],
        [ 3.6297,  2.9721,  6.4045,  5.2209,  6.1678,  7.2462,  6.6807],
        [16.6097, 18.2009, 19.5818, 31.8122, 57.2988, 39.2425, 21.1205],
        [13.3614,  6.4703, 11.2441,  4.0900,  4.7344,  7.4698,  6.5492],
        [11.1915, 13.7296, 14.6107, 25.3288, 27.5776, 14.9001, 15.9390],
        [17.3044, 2

batch_predictions
tensor([[13.2609, 14.6222, 15.5685, 22.0421, 28.2745, 26.9805,  9.2451],
        [14.1461, 15.5847, 14.1312, 14.6854, 20.5366, 30.8434, 25.7148],
        [14.0799, 14.5255, 21.6108, 30.4263, 31.9754, 12.1859, 14.8383],
        [13.0268, 13.2757, 14.3659, 21.4301, 31.0039, 19.0840, 13.8783],
        [12.8233, 17.0969, 24.3230, 23.2926, 11.7505, 12.1333, 12.6148],
        [15.9304, 17.6622, 22.0734, 26.8193, 30.4043, 11.6722, 14.6248],
        [32.8965, 13.9952, 16.8326, 16.7426, 17.5451, 25.7058, 41.1594],
        [26.4595, 29.8356, 11.7947, 14.5665, 12.9736, 12.6909, 15.3077]])
ground truth
tensor([[10.2041, 10.3458, 13.8605, 13.2653, 30.3288,  4.9745,  5.2438],
        [15.4918, 10.4156, 15.3077, 17.2541, 19.4766, 24.8685, 19.9435],
        [12.2166, 14.4841, 20.1247, 34.1837, 32.4263, 15.2494, 12.7551],
        [20.2664, 20.7341, 22.3498, 26.3464, 32.0720, 18.7925, 13.6763],
        [15.8588, 14.3282, 28.0471, 26.0062, 13.0385, 11.0544, 11.5505],
        [13.1247, 1

batch_predictions
tensor([[11.7921, 16.0823, 23.4039, 22.0333,  7.8957, 11.2875, 12.3364],
        [ 8.6933, 13.8714, 13.5773, 14.0101, 17.7822, 27.0999, 30.6873],
        [31.7412, 41.7044, 40.5784, 15.6956, 19.8217, 19.7306, 20.4591],
        [18.3612,  7.5708,  7.6856,  8.2641,  7.9935,  7.9971, 15.1868],
        [ 6.8182,  6.9498,  7.3135,  7.7432, 10.6470, 12.5289, 13.3843],
        [ 7.0174,  7.6354,  9.3630,  6.7984,  4.7889,  7.3221,  6.6108],
        [19.5407, 25.1208, 27.6177, 10.7007, 15.0876, 14.0903, 16.1403],
        [18.0783, 25.4577, 20.1658, 10.4077, 12.7707, 12.4199, 14.0753]])
ground truth
tensor([[15.2157, 18.7533, 22.8958, 20.0947, 10.5471, 12.7433, 10.1262],
        [ 4.8328, 14.4416, 11.5930, 12.9393, 16.1565, 24.5607, 27.6644],
        [32.0011, 47.7041, 42.8713, 26.6440, 21.1026, 21.2868, 23.1151],
        [18.0957,  3.6691, 10.0736,  8.1010,  8.4166,  9.9816, 17.0437],
        [ 4.0900,  4.7344,  7.4698,  6.5492,  7.2593, 11.5466, 17.0174],
        [ 8.7980, 1

batch_predictions
tensor([[31.6183, 10.7819, 14.1475, 14.6565, 16.2644, 26.6005, 42.1382],
        [13.4748, 13.2123, 17.5196, 21.4025,  7.3150,  7.3486, 13.3838],
        [26.6515, 22.1684,  8.4466, 12.2492, 12.9041, 14.9590, 21.4455],
        [11.8702, 15.1517, 19.8669, 16.2120,  7.3307,  9.1754, 10.6434],
        [16.1809, 16.1750, 16.2402, 22.4288, 31.1880, 40.5687, 22.5678],
        [22.0716, 21.7542,  9.5527, 10.4736, 11.3983, 12.1248, 16.3229],
        [27.2159,  9.3288, 12.6497, 11.9246, 12.7869, 15.2710, 23.4562],
        [12.9631, 12.9579, 13.1082, 17.4398, 25.0541, 26.0375, 10.1542]])
ground truth
tensor([[31.7177, 14.2432, 23.1859, 16.0998, 16.3832, 43.9909, 28.9399],
        [14.1015, 15.0652, 19.6145, 25.0142,  9.7647,  3.2738, 18.2965],
        [35.7001,  2.4235,  0.3968, 10.8560, 10.7143, 18.8634, 28.2455],
        [ 6.8385, 13.7954, 20.6865,  4.6686,  5.0763,  9.2057, 12.2567],
        [ 7.6531, 21.3719, 21.3577, 29.6769, 48.4977, 23.2993, 19.9405],
        [24.0794, 3

batch_predictions
tensor([[22.7234, 29.5561, 19.0663, 12.6567, 14.3445, 14.1244, 16.6769],
        [21.1245, 21.8120, 11.9886, 13.4778, 13.4997, 13.1557, 17.3686],
        [15.0945, 27.0670, 31.1292, 22.3378, 12.2411, 13.0342, 13.7559],
        [ 9.5605, 13.4884, 13.6350, 14.2075, 18.5345, 29.3950, 30.4898],
        [22.0156, 10.0189, 13.7492, 15.1431, 16.6561, 18.9421, 25.1470],
        [22.9069, 10.4874, 12.0878, 11.3305, 10.3795, 12.2188, 20.6290],
        [13.6183, 13.3964, 14.6138, 23.6612, 31.5207, 20.3281,  8.1537],
        [13.3907, 14.4868, 16.8282, 23.7382, 31.1559, 19.1752, 13.2254]])
ground truth
tensor([[21.8537, 30.2863, 30.6264, 13.9739, 13.6905, 14.5692, 19.6570],
        [19.7988, 27.6927, 11.3520,  6.9444, 16.3832, 12.1457, 14.9376],
        [12.7959, 21.0153, 31.9437, 26.3414, 13.0589, 12.2962, 14.9264],
        [ 5.5272, 14.2574, 16.7659, 14.9660, 18.4524, 33.1774, 32.0011],
        [28.3140, 12.5723, 15.4655, 16.9911, 14.4266,  7.1147, 28.0773],
        [22.1725, 1

batch_predictions
tensor([[13.3674, 14.6339, 20.5869, 30.9839, 36.2231, 11.9683, 14.9863],
        [15.6533, 15.1502, 19.4434, 38.7375, 40.6531, 19.6221, 15.7685],
        [25.2125, 32.1203, 22.1005, 20.4618, 20.4101, 19.4554, 18.5466],
        [13.0895, 13.6555, 16.7902, 25.9661, 36.1073, 15.9402, 13.7069],
        [20.5145, 28.2486, 29.7911, 11.9854, 13.8811, 12.5191, 14.1932],
        [10.4722, 14.0580, 12.6124, 14.4447, 23.3729, 32.1050, 21.6256],
        [16.3546, 25.8409, 37.9125, 29.4506, 14.6938, 15.0117, 14.9270],
        [29.6774, 31.9254, 25.2959, 24.1451, 17.3330, 19.4206, 22.8366]])
ground truth
tensor([[12.2567, 15.0842, 16.8069, 23.2772, 35.8232, 11.9674, 15.7549],
        [16.7517, 15.4478, 22.1797, 44.8696, 36.2528,  9.3963, 10.6718],
        [25.3827, 31.5618, 18.1689, 19.3311, 26.9983, 21.7829, 20.6066],
        [10.9127, 11.3520, 18.0839, 31.3492, 25.8929, 12.8685, 13.3362],
        [16.2020, 28.8664, 28.7875,  9.0873, 11.7175,  8.7322, 11.5203],
        [ 6.5901, 1

batch_predictions
tensor([[25.7384, 23.7924, 20.0412, 17.1735, 16.9912, 20.5225, 28.4986],
        [13.5427, 14.9094, 15.6945, 21.6889, 29.8211, 29.0287,  8.8820],
        [13.6564, 13.1259, 13.4386, 21.0434, 27.9276, 29.1819,  9.8909],
        [40.5586, 32.1129, 14.4103, 16.0921, 15.2612, 16.2854, 27.4520],
        [18.5241, 31.1334, 41.2889, 31.8552, 16.4232, 18.3774, 16.9285],
        [22.2158, 10.8576, 13.0292, 13.4449, 16.9937, 19.2619, 28.0244],
        [13.5785, 14.2927, 15.5316, 20.6267, 23.8785,  7.4707,  9.9568],
        [ 9.0158, 10.3696, 19.0151, 22.9258,  6.6974,  9.0469,  9.4037]])
ground truth
tensor([[39.9132, 22.2383, 21.3966, 24.2635, 19.5160, 19.3319, 27.8538],
        [17.2477, 13.5488, 18.1406, 25.6236, 21.7687, 32.6247,  9.8781],
        [11.7175,  8.7322, 11.5203, 16.5176, 28.0905, 27.4592,  8.3377],
        [39.7676, 34.3112, 16.4399, 14.0023, 14.7392, 15.0368, 21.3152],
        [15.8588, 33.5176, 37.0465, 36.8197, 18.4524, 15.8022, 18.8492],
        [31.2730, 1

batch_predictions
tensor([[13.6172, 16.6314, 17.7157, 17.0464, 24.3724, 34.3173, 28.8735],
        [26.8209, 23.8117, 22.3263, 24.3590, 39.3615, 53.8742, 50.0935],
        [20.2685, 24.7201,  9.7782, 11.8678, 10.9538, 11.1960, 13.2618],
        [ 9.1110, 11.5370, 19.0408, 17.8725,  7.6828,  7.8226,  8.2255],
        [15.8456, 15.1346, 16.6756, 34.3232, 45.4103, 27.3699, 13.8701],
        [21.7792, 22.5477, 10.7308, 12.3611, 10.5820, 10.6022, 13.5463],
        [30.8554, 20.8206, 22.8769, 25.4867, 29.2886, 34.8860, 32.6656],
        [ 8.8971, 11.2468, 13.2215, 13.1427, 15.4182, 24.7228, 25.8332]])
ground truth
tensor([[ 8.1633, 13.9881, 22.7466, 19.3878, 22.7183, 31.7744, 35.8277],
        [29.5371, 24.8948, 20.8969, 20.3577, 37.9537, 71.5413, 77.3146],
        [21.6465, 20.6996,  7.7985, 10.0079,  7.9563, 16.5045, 12.0857],
        [ 7.2725,  8.1931, 16.4124, 15.9916,  7.5092,  8.5087,  8.2325],
        [16.1848, 11.9898, 19.2602, 30.3571, 38.0952, 26.9700, 11.0402],
        [21.4361, 2

batch_predictions
tensor([[11.1967, 16.0746, 15.6109, 16.8271, 22.4626, 32.3925, 37.7485],
        [10.5393, 11.0858, 15.5371, 22.5527, 20.7898,  7.2547,  9.8533],
        [ 9.9722, 15.6187, 21.7366, 10.0456,  7.6550,  7.6424,  7.9585],
        [13.7798, 15.5070, 20.0877, 34.4124, 40.6219,  9.5432, 16.4656],
        [10.8284, 17.1366, 24.2662, 14.4356,  9.7763,  9.7741,  9.7846],
        [ 7.2645,  8.6413, 10.8977, 15.6368,  5.9916,  6.2287,  7.1740],
        [16.0385, 25.3237, 36.9291, 32.3268, 12.4133, 15.8650, 15.5057],
        [28.4014, 40.1646, 33.0280, 20.2562, 19.9378, 17.1720, 17.3039]])
ground truth
tensor([[ 7.9790, 16.5249, 14.8810, 18.4949, 27.4093, 25.7653, 31.4201],
        [ 8.7585,  7.8117, 13.7691, 20.3314, 14.7817,  7.2462,  9.2451],
        [10.7575,  8.6928, 21.6070,  9.8106,  6.3782,  8.3509, 11.9411],
        [22.2931, 14.7817, 22.2080, 22.8741, 27.3101,  4.2942, 12.8543],
        [15.1219, 17.6020, 30.8957, 13.6763, 13.9739,  9.1978, 11.6071],
        [ 6.1152,  

batch_predictions
tensor([[19.4415,  8.8108,  8.6580,  8.7739, 10.2094, 11.4386, 19.6812],
        [43.7201, 37.7028, 16.8963, 19.3988, 18.1330, 18.8612, 27.2521],
        [18.3082, 24.1492, 27.1762, 10.0915, 14.8494, 13.2077, 15.0226],
        [13.7514, 16.0815, 22.2238, 24.5000, 15.0317, 14.2714, 14.7944],
        [14.8271, 15.1738, 15.2911, 20.7920, 27.9742, 27.8309, 11.4594],
        [25.4743, 34.9658, 37.1770,  9.0636, 15.7405, 16.4524, 17.6868],
        [ 6.7459,  8.0902,  7.3233,  8.0797, 10.0260, 16.0093, 19.6086],
        [17.4704, 27.1119, 39.0378, 28.3959, 14.4594, 16.1552, 15.9106]])
ground truth
tensor([[18.2536,  7.7854,  6.0494, 11.0205, 11.5071, 11.8885, 13.4008],
        [36.7914, 25.6094, 18.4382, 21.9388, 10.5867, 15.7738, 22.6049],
        [17.6223, 23.8033, 31.6281, 10.4945, 14.3346, 14.8211, 18.0957],
        [20.8050, 18.5941, 32.1145, 28.7273, 17.7721,  7.3129,  8.3759],
        [15.9155, 15.5896, 17.2477, 33.1066, 20.9892, 29.3651, 13.0952],
        [22.5907, 3

batch_predictions
tensor([[25.8070, 12.2141, 14.7509, 12.7190, 12.6727, 17.6148, 25.9530],
        [14.7358, 18.3995, 23.9372, 24.2247, 10.7935, 13.5953, 14.0893],
        [25.2599, 26.0901, 13.9763, 14.5365, 13.8250, 14.1084, 17.2477],
        [37.3115, 40.2681, 45.7493, 45.3423, 44.5993, 43.5309, 38.2848],
        [23.4255, 25.0487,  9.5539, 12.0614, 14.1097, 15.6995, 17.3826],
        [20.0073, 26.2877, 37.0076, 35.4187, 15.2275, 17.9792, 17.3944],
        [12.7702, 12.1662, 15.1146, 23.0395, 31.2170,  9.0671, 12.8867],
        [14.0430, 15.4549, 21.9779, 25.3780, 21.6165, 12.7933, 13.9851]])
ground truth
tensor([[24.5465, 14.0306, 16.0006, 18.8350, 15.1644, 26.6865, 23.6678],
        [12.8222, 15.7680, 23.1063, 28.9847, 17.1752, 18.5823, 19.5555],
        [26.1763, 22.8316, 13.5629, 13.8322, 12.4008, 16.2982, 17.0210],
        [39.3346, 41.8201, 59.8501, 54.9448, 42.2935, 41.8201, 37.6381],
        [17.7438, 19.7562,  8.7727,  3.8974, 14.4274, 17.0493, 12.3299],
        [17.6162, 2

batch_predictions
tensor([[15.1003, 15.5528, 19.9561, 25.1366, 25.0880, 22.4761, 15.2365],
        [ 6.6727,  7.5357, 10.2216, 13.0595, 14.5434, 19.4714, 16.5907],
        [22.4941, 18.9317, 10.2142, 11.3986, 11.4855, 13.2116, 20.4881],
        [12.0831, 15.3672, 14.1693, 14.4701, 19.1362, 26.0019, 28.0623],
        [16.6418, 22.3758, 12.5347,  9.9317, 11.6631, 11.7968, 11.8830],
        [30.1519, 13.4470, 15.0115, 16.6661, 20.9822, 34.8817, 40.4320],
        [24.5749, 10.7811, 13.8702, 14.9514, 15.9534, 15.5787, 19.6189],
        [17.0715, 17.2388, 17.3404, 32.9632, 46.5331, 44.4660, 18.6475]])
ground truth
tensor([[12.9274, 15.1631, 16.5176, 18.9506, 28.5113,  7.6144, 13.0195],
        [ 6.6281,  3.7612,  3.3798,  8.4429, 11.6386, 15.0053, 17.4908],
        [24.8159, 21.9621, 11.1257,  7.6013,  9.8238, 16.7675, 20.7785],
        [10.6391, 13.8348, 12.6644, 14.1504, 21.5413, 17.6749, 22.8827],
        [15.3998, 23.4876, 19.9895,  7.6013, 11.0074,  9.2057, 10.6654],
        [26.0771, 1

batch_predictions
tensor([[10.8466, 11.6529, 13.2582, 19.3804, 25.0732, 16.6602,  8.7060],
        [13.6879, 14.4372, 14.8965, 17.7655, 24.9230, 30.7036, 18.4737],
        [15.6484, 23.3055, 27.0921, 12.0310, 14.7345, 12.9523, 13.9506],
        [29.8193, 16.1507, 17.3042, 16.3965, 18.5212, 25.7833, 43.6093],
        [10.8184, 15.2907, 21.8132, 19.7802,  7.0692,  9.2934, 10.5851],
        [ 9.7084, 13.9393,  7.9500,  7.0007,  6.4711,  5.6128,  7.2969],
        [10.8029,  9.8511, 10.6582, 12.3530, 15.6912, 20.8373, 15.7225],
        [12.9031, 12.2708, 13.6227, 18.7280, 24.0946, 19.7154, 10.7093]])
ground truth
tensor([[ 7.4566,  9.1794,  8.5876, 14.7291, 18.8979, 19.0426,  9.8501],
        [11.5646, 13.9739, 10.7426, 11.7063, 18.8067, 34.9773, 31.1083],
        [16.2415, 20.1604, 24.2504,  9.9290, 12.6644, 12.7959, 15.2683],
        [13.1661,  0.0000, 10.7710, 18.4666, 19.7279, 25.7653, 50.7653],
        [14.1899, 25.9469, 17.1357, 15.6497,  8.4166,  9.9027, 10.1657],
        [ 9.6088,  

batch_predictions
tensor([[14.2081, 15.2572, 12.8591, 13.1569, 17.2814, 24.3080, 23.6367],
        [12.6387, 13.0738, 13.7529, 18.3934, 31.1308, 27.4362,  8.3837],
        [18.0374, 24.8386, 21.7874,  7.8750, 11.8404, 12.3855, 12.8834],
        [15.3540, 18.4412, 17.9333, 18.6606, 23.7765, 31.4307, 36.1029],
        [15.0118, 15.0899, 15.6406, 16.0328, 17.3119, 20.1168, 20.7338],
        [31.3780, 18.0827, 18.6479, 16.8275, 20.1876, 31.3069, 44.9050],
        [18.9035, 28.7231, 31.3928, 11.4728, 15.2608, 14.7867, 15.1076],
        [18.0898, 20.5642, 29.4340, 31.8084, 21.1944, 18.6465, 18.5653]])
ground truth
tensor([[11.3237, 14.3566, 11.1395, 13.3787, 19.6145, 22.3923, 17.4320],
        [13.9314, 10.6859, 17.8713, 23.9512, 31.2075, 22.7183,  7.0153],
        [16.3993, 27.6302, 29.9842,  8.0747, 11.5992,  9.1531, 12.2567],
        [10.8233, 13.6902, 19.5555, 17.6092, 20.3972, 25.6575, 39.6633],
        [18.4240, 16.9643, 11.0261, 15.0227, 16.4541, 21.7829, 24.3622],
        [32.0437, 2

batch_predictions
tensor([[ 8.5635, 11.4054, 12.2902, 12.7637, 18.3603, 26.0090, 24.4468],
        [15.7084, 16.0689, 21.3918, 31.6387, 29.9305, 11.2791, 15.8129],
        [17.0667, 16.5237, 17.9307, 35.5283, 45.7607, 35.4428, 14.0188],
        [13.8562, 14.4581, 14.0876, 17.3678, 28.0539, 31.3739, 13.4440],
        [23.3914,  7.4514, 10.2936,  9.1054,  9.7885, 13.2783, 18.9546],
        [24.9369, 30.6506, 25.6267, 20.3762, 19.8484, 19.4932, 19.1702],
        [15.7420, 16.7148, 26.2684, 41.6616, 37.7208, 15.3963, 16.0840],
        [24.9132, 13.0282, 13.1749, 12.4293, 13.3117, 21.0551, 31.0385]])
ground truth
tensor([[ 5.0237, 10.7180,  8.3377,  9.5213, 12.9932, 25.8811,  3.7612],
        [16.6228, 14.8606, 23.7112, 34.3240, 31.6675, 12.7170, 18.9769],
        [12.1173, 19.8696, 18.7075, 31.6752, 38.2937, 35.5867, 14.6259],
        [11.7772, 12.5850, 16.2982, 20.7908, 28.6423, 27.2534, 12.6276],
        [20.8192,  8.9286,  9.4388,  2.3668,  8.4042, 16.7659,  1.5590],
        [28.3872, 4

batch_predictions
tensor([[ 9.5559, 13.6767, 18.0305,  6.6414,  6.8671,  7.0293,  7.4198],
        [41.0062, 34.6743, 17.0546, 20.2216, 21.2439, 26.5588, 38.9046],
        [29.0502,  9.3749, 14.5683, 14.8218, 16.8293, 22.3860, 29.5757],
        [ 7.2389,  7.3568, 10.0024, 11.9103, 14.6528,  6.9040,  7.4228],
        [15.5529, 22.6829, 18.8542,  7.8042, 10.1833, 10.3387, 11.0102],
        [13.6102, 14.1216, 15.0354, 28.4206, 37.9482, 28.5584, 11.4835],
        [14.0010, 13.7758, 16.2853, 24.5284, 26.6787, 15.3430, 14.4722],
        [15.3655, 15.6111, 18.7459, 26.8624, 29.1904, 12.0259, 15.4390]])
ground truth
tensor([[ 7.8906, 14.6502, 14.9132,  5.0631,  5.1026,  4.6291,  5.3524],
        [29.7761, 38.5913, 18.5232, 19.3594, 23.0584, 26.9558, 38.6480],
        [ 0.7370,  0.0000,  0.7511, 13.7046, 15.3770, 20.3090, 36.6497],
        [ 9.7647,  5.7823,  8.8294,  9.3963, 12.0607,  6.5760,  6.0232],
        [15.5971, 17.4119, 16.7149,  6.3782,  8.9164, 10.7180, 14.1899],
        [11.6497, 1

batch_predictions
tensor([[14.5079, 16.3529, 15.5415, 18.7123, 34.2821, 43.5916, 30.4326],
        [22.0126, 14.2901,  9.2848,  9.9632,  9.9206, 11.1728, 15.6048],
        [24.3019, 13.3768, 12.3569, 12.7667, 12.8494, 14.6848, 20.2524],
        [13.6633, 15.4361, 18.8281, 23.7297, 26.6271, 13.9079, 14.3725],
        [41.6299, 40.2413, 23.7310, 22.1067, 22.5259, 26.4902, 36.5286],
        [17.7931, 18.1228, 24.8303, 40.0878, 41.6269, 21.8853, 18.6308],
        [29.5329, 13.1958, 16.1916, 18.0029, 19.7639, 26.5904, 34.8348],
        [15.1828, 15.2904, 25.5911, 33.0309, 25.3680, 14.3859, 16.2573]])
ground truth
tensor([[11.6922, 15.3770, 10.5726, 21.4994, 33.3050, 50.2126, 28.7982],
        [23.4219, 16.0442, 12.3751, 11.8227, 11.3887, 14.6239, 15.1631],
        [28.1694, 26.3414, 17.7933, 22.5802, 16.8201, 14.6239, 15.8075],
        [16.2678, 14.7028, 15.4524, 25.5918, 29.8790, 15.3340, 16.7543],
        [52.5513, 52.4461, 24.4345, 19.6739, 20.5418, 23.3693, 58.2588],
        [10.3741, 2

batch_predictions
tensor([[10.9801, 13.1815,  8.0721,  6.5604,  8.3809, 11.4363,  8.8410],
        [10.2750, 12.2515, 17.7881, 20.4135,  6.1456,  8.2936,  9.2827],
        [26.4722, 22.7515, 13.1018, 13.7056, 12.8332, 13.8082, 19.3793],
        [16.4873, 10.5927, 12.9458, 13.6915, 13.6982, 18.9481, 25.6904],
        [17.8106, 27.7193, 30.1709,  6.9563, 13.5248, 12.5631, 13.9361],
        [18.0288, 25.7497, 39.5057, 33.6148, 14.2451, 17.1510, 16.7947],
        [15.6467,  8.9141,  7.6374,  8.4431,  8.9715,  8.9889, 11.5134],
        [21.1207, 24.2804,  7.5735, 11.7987, 11.3807, 11.4428, 14.4279]])
ground truth
tensor([[11.2572, 12.1515, 10.3630,  5.6549,  8.3772,  9.7054,  8.8901],
        [ 9.0873,  9.2451, 14.0452, 22.1857,  3.5245,  8.8243, 10.4024],
        [26.7574, 23.0867, 13.0811, 13.7330, 11.5363, 16.3974, 21.9388],
        [22.9748,  9.5345, 14.9001, 14.0189, 14.5055, 16.2678, 22.1462],
        [15.4620, 21.0034, 27.8628,  4.8328, 14.4416, 11.5930, 12.9393],
        [15.5181, 2

batch_predictions
tensor([[15.0727, 15.4093, 17.0617, 27.6055, 38.6685, 31.1024, 14.5101],
        [26.0321, 26.7761,  8.4860, 11.1724, 12.5028, 14.3555, 17.7559],
        [33.4647, 43.3736, 36.7937, 15.7686, 16.9278, 16.4566, 18.2382],
        [33.2464, 29.7254, 10.8634, 14.6322, 14.1141, 14.8902, 23.8515],
        [ 7.6189,  7.4615,  8.4617, 12.3995, 16.6112, 22.2362,  6.8766],
        [13.1475, 13.4642, 14.4741, 15.8890, 24.3362, 35.9461, 27.8255],
        [10.8889, 12.6593, 22.0260, 21.0872,  9.0293, 10.4236, 11.2495],
        [19.3943, 26.4381, 22.6869, 11.2887, 13.9837, 13.4175, 15.3481]])
ground truth
tensor([[15.7029, 13.8747, 17.5028, 25.0000, 53.6990, 37.9819, 24.4756],
        [23.7112, 24.3556,  7.4303,  9.9553, 12.6775, 13.8085, 15.9784],
        [36.5363, 51.1621, 36.5930, 17.1344, 17.1202, 18.2398, 21.5136],
        [39.2149, 29.1808, 13.1094, 14.3141, 11.2245, 10.7568, 23.9371],
        [ 9.7506,  6.8594,  6.5901, 16.1990, 17.6587, 20.8192,  8.9286],
        [16.0856, 1

batch_predictions
tensor([[12.9127, 14.7484, 24.4218, 23.2500, 10.7512, 13.1294, 13.0274],
        [15.0476, 15.3416, 22.5877, 35.7732, 39.1610, 10.7589, 16.6386],
        [11.7085,  7.6429,  7.6381,  8.2964, 10.5588, 11.5145, 14.1912],
        [10.1456,  9.7062,  9.9895, 13.9689, 21.4694, 21.0092,  8.6487],
        [29.8823, 39.0672, 29.1090, 14.2507, 15.7047, 15.6585, 17.5088],
        [17.4317, 26.9941, 23.0738, 12.1398, 13.0616, 11.9518, 13.8802],
        [18.5330, 18.1565, 18.1652, 29.0713, 42.2934, 35.9165, 22.3035],
        [21.6363, 10.9858, 12.8283, 11.8140, 13.3449, 20.9001, 30.9107]])
ground truth
tensor([[11.3095, 12.6417, 25.2268,  0.0000,  0.0000,  9.5096, 13.9739],
        [ 3.6706, 17.7721, 20.7200, 29.7619, 36.1395,  4.8753, 19.7421],
        [16.2415,  6.8254,  7.1673,  6.3914,  9.8238,  8.5613, 15.6365],
        [10.9679, 11.1126,  9.9553, 18.1220, 26.0521, 23.8559,  8.0747],
        [27.6644, 40.9155, 36.8764, 14.4983, 19.7988, 16.1565, 15.7313],
        [18.5941, 2

batch_predictions
tensor([[10.9361, 14.4307, 14.7392, 14.8205, 18.8170, 27.0205, 30.6813],
        [15.5846, 14.7749, 26.6077, 31.1225, 11.7373, 15.4569, 15.1688],
        [22.3905, 26.5743, 12.9937, 16.1299, 15.5073, 15.3661, 17.6853],
        [11.0581, 12.0838, 15.8132, 23.5099, 18.9764,  9.3889, 12.8373],
        [14.0269, 15.7470, 19.3048, 21.6192, 25.6168, 11.4157, 13.2245],
        [26.3803, 24.5540, 19.9318, 15.7709, 16.2834, 18.4418, 27.9194],
        [16.7628, 16.1063, 18.5586, 35.2503, 39.9297, 23.2788, 17.3970],
        [13.0041, 12.6960, 12.7292, 19.3032, 26.1596, 26.1312, 13.2352]])
ground truth
tensor([[11.0074, 13.0589, 16.5308, 11.3756, 17.8196, 25.2893, 28.3930],
        [15.3061, 16.8509, 27.7069, 29.8611, 12.5709, 15.6604, 13.3078],
        [24.8027, 20.4366, 10.9942, 15.8469, 14.0978, 17.4513, 20.0947],
        [ 6.4440,  7.2593, 16.2941, 23.3693, 23.5402, 12.3619, 13.3482],
        [13.0385, 14.7109, 16.5391, 26.1338, 26.6582,  9.9348, 11.8764],
        [29.5493, 1

batch_predictions
tensor([[26.4462, 15.2018, 15.9017, 15.5526, 15.6010, 16.6034, 23.7594],
        [18.8230, 34.3010, 34.1835, 15.6587, 15.0795, 14.1243, 14.6042],
        [13.1484, 22.3751, 20.0123,  9.2140, 10.7792,  8.5516,  9.8760],
        [17.7191, 27.1317, 28.9136, 10.9839, 14.9329, 13.5906, 14.7224],
        [ 9.9216, 13.7015, 16.3799,  8.6826,  6.5826,  7.6362,  8.2657],
        [30.9692, 10.0095, 15.6959, 16.3324, 17.3502, 24.6495, 31.7667],
        [22.6429, 32.6836, 37.8689,  9.4238, 15.7556, 14.8661, 16.4070],
        [11.9459,  7.3712,  6.0737,  8.1032,  8.3401,  8.8376,  9.9922]])
ground truth
tensor([[36.6649, 20.5813, 19.4371, 15.0973, 16.3861, 13.9795, 18.4640],
        [13.7755, 30.1871, 27.5652, 11.6922, 13.7188, 10.7001, 11.4654],
        [13.4212, 17.9422,  4.3651,  6.6043,  9.8498,  8.8577,  9.0986],
        [20.0539, 28.6990, 35.4025, 10.5867, 12.2449, 10.8702, 13.6905],
        [ 9.5213, 13.2693, 15.0447, 13.0852,  6.8780, 11.0205,  8.8901],
        [40.4195, 1

batch_predictions
tensor([[21.4302, 19.8552, 23.4202, 42.3647, 45.2647, 39.4092, 23.9189],
        [21.8136, 10.6551, 12.7988, 12.3721, 12.6375, 18.3005, 25.0928],
        [18.4979, 19.1343, 36.5092, 44.9608, 26.6644, 19.9207, 19.6195],
        [30.1756, 15.9518, 16.6401, 15.5375, 16.3912, 20.8716, 32.8414],
        [12.0739, 14.1114, 14.5555, 18.9809, 27.6034, 28.2605, 22.1004],
        [11.4055, 14.4717, 14.6692, 16.0709, 19.9060, 27.5449, 22.8667],
        [14.8255, 16.0095, 18.8103, 25.4497, 29.3800, 27.3309, 14.1240],
        [13.3721, 22.4573, 33.5011, 28.0877, 10.8881, 13.0726, 13.1278]])
ground truth
tensor([[20.2130, 19.9369, 18.8848, 33.9427, 55.6418, 14.1899,  0.0000],
        [29.8264,  9.9290, 12.0857, 12.4277, 11.0205, 19.8185, 22.8432],
        [26.3282, 17.3856, 34.0873, 63.3877, 44.1610, 13.9137, 12.0594],
        [29.6344, 12.0040, 14.7251, 15.4195, 19.8271, 22.2222, 27.7636],
        [ 7.1712,  7.7806, 11.6355, 14.5975, 19.8554, 30.8957,  8.9569],
        [10.8759, 1

batch_predictions
tensor([[13.1244, 19.1539, 26.6584, 26.1682,  9.8699, 11.4936, 12.2322],
        [21.7187, 32.5185, 30.2025,  8.3939, 14.6894, 13.8466, 15.9717],
        [13.4233, 13.1203, 13.2069, 17.4764, 25.6618, 28.7325, 17.0914],
        [15.2418, 16.2887, 36.0728, 36.8410, 20.2468, 18.2094, 17.0941],
        [26.1803, 12.7003, 14.0080, 12.6181, 13.0122, 19.7656, 34.3223],
        [27.2798, 10.5639, 14.2737, 12.7895, 13.7490, 17.0409, 26.3530],
        [15.6300, 23.2157, 25.4309,  9.1569, 11.8892, 11.0029, 13.2206],
        [15.2298, 15.8683, 18.3384, 25.9079, 28.1211, 11.2851, 15.0196]])
ground truth
tensor([[ 6.0757, 10.8364, 18.6349, 27.3672,  0.0000,  3.7480, 11.5071],
        [19.6287, 40.8588, 39.0023, 12.6842, 16.7092,  8.2483, 16.2698],
        [11.7063, 11.6497, 10.1049, 17.1910, 29.9178, 32.9790, 14.0448],
        [21.5420, 36.0261, 36.7914, 25.6094, 18.4382, 21.9388, 10.5867],
        [19.6287,  9.0703, 12.6276,  8.5459, 10.4025, 19.8129, 31.0941],
        [30.8957, 1

batch_predictions
tensor([[22.7057, 34.2864, 27.7873, 11.0577, 13.4351, 12.9040, 15.6189],
        [11.2627, 14.8729, 16.3566, 19.4983, 22.0275, 29.3392, 27.4602],
        [20.6402, 28.6048, 24.8230, 12.7210, 14.5205, 13.1483, 14.6794],
        [20.7439, 19.8492, 10.3573, 10.9758, 10.7547, 12.9617, 13.8932],
        [16.5907, 16.1569,  7.0762,  8.6960,  7.8568,  7.5668, 11.8696],
        [11.4052, 11.5978, 12.2279, 17.8657, 27.1599, 17.0795,  7.9310],
        [20.2194,  9.8868, 10.7020,  9.9570, 11.0086, 19.7566, 30.2071],
        [17.5527, 17.0351, 16.4617, 17.3096, 27.6639, 45.6638, 40.8849]])
ground truth
tensor([[20.6633, 31.1791, 29.1525, 12.3724, 12.4150, 14.9518, 16.8226],
        [13.1247, 18.3588, 10.1131, 19.7791, 19.8185, 26.5255, 29.7212],
        [20.0964, 27.3810, 26.6298, 13.4354, 16.3974, 18.3390, 13.6905],
        [26.9726, 19.1347, 13.3351, 11.2835, 10.2709, 16.0442, 20.5550],
        [15.7812, 12.6644,  4.9185,  9.5213,  6.4308,  7.5618, 13.4008],
        [ 9.1005, 1

tensor(33.1955)

In [34]:
class TransformerVanilla(Module):
    
    
    def __init__(self, d_model=64, nhead=4, num_encoder_layers=3, num_decoder_layers=3,
                dim_feedforward=512, dropout=0.1, activation=F.gelu, use_norm=True):
        
        super(TransformerVanilla, self).__init__()
        self.transformer_model = TransformerModel(d_model=d_model, nhead=nhead, num_encoder_layers=num_encoder_layers,
                                                  num_decoder_layers=num_decoder_layers, dim_feedforward=dim_feedforward,
                                                  dropout=dropout, activation=activation, use_norm=use_norm)
        
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.input_projection1 = nn.Linear(1, d_model)
        self.input_projection2 = nn.Linear(d_model, d_model)
        self.output_projection = nn.Linear(d_model, 1)
        
    def forward(self, encoder_x, decoder_x, src_mask=None, tgt_mask=None, memory_mask=None):
        """ encoder_x should be vector of size input_seq_len x bs x 1   
            decoder_x should be vector of size output_seq_len x bs x 1
        """
        encoder_x = self.pos_encoder(self.input_projection2(self.input_projection1(encoder_x)))
        decoder_x = self.pos_encoder(self.input_projection2(self.input_projection1(decoder_x)))
        x = self.transformer_model(encoder_x, decoder_x, 
                                   src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=memory_mask)
        x = self.output_projection(x)
        return x

    
        
        
        

In [40]:
model_v0 = TransformerVanilla(d_model=64, nhead=nhead, num_encoder_layers=2, num_decoder_layers=2,
                dim_feedforward=256, dropout=0.1, activation=F.gelu)
model_v0.load_state_dict(torch.load('/home/mbaroody/Programming/Time series/NN5/transformer_vanilla_robustnull_v2-0.pt'))

<All keys matched successfully>

In [41]:
evaluate_model(model_v0, test_dataloader)

batch_predictions
tensor([[11.4240, 11.7356, 12.1022, 12.5022, 12.9190, 13.3243, 13.6408],
        [30.6787, 31.0758, 31.0345, 30.9378, 30.8638, 30.8237, 30.8084],
        [24.8778, 28.7262, 28.5320, 27.3244, 25.2732, 22.0480, 19.4626],
        [17.6461, 22.5739, 21.3124, 19.0141, 16.4467, 15.1427, 15.2202],
        [17.6647, 22.7957, 23.0727, 22.1033, 20.5371, 18.5566, 17.1777],
        [23.0011, 20.9399, 21.2400, 21.9991, 22.6732, 23.2022, 23.6021],
        [16.7230, 19.8401, 19.2410, 18.3349, 17.5219, 16.9240, 16.5867],
        [23.4345, 30.0252, 29.5433, 27.4890, 23.3897, 18.2426, 18.3207]])
ground truth
tensor([[13.5718, 20.6865,  9.6791,  6.8780,  6.9306,  8.3246, 11.8096],
        [24.9079, 29.5634, 34.5739, 45.1604, 18.3456, 21.1073, 22.9616],
        [15.0085, 27.8770, 36.7772, 37.0890, 21.1026, 15.2920, 16.2557],
        [18.0414, 15.5612, 12.9960, 16.9218, 23.0867, 32.1854, 14.1865],
        [21.3010, 35.1332, 33.7302, 13.0244, 12.7976, 12.6842, 13.4070],
        [30.8107, 2

batch_predictions
tensor([[20.8067, 29.1103, 29.0137, 27.0325, 22.3041, 17.1467, 17.3533],
        [20.4119, 21.0035, 20.8544, 20.5631, 20.3071, 20.1406, 20.0661],
        [18.7933, 19.9496, 19.6146, 19.0784, 18.6161, 18.3035, 18.1501],
        [25.6364, 26.1793, 26.1036, 25.8755, 25.6469, 25.4806, 25.3933],
        [20.4224, 21.5751, 21.3344, 20.8683, 20.4482, 20.1543, 20.0001],
        [16.5220, 15.1254, 15.5129, 16.2160, 16.8156, 17.2842, 17.6458],
        [21.0843, 20.0166, 19.7710, 19.9387, 20.2758, 20.6641, 21.0583],
        [24.3506, 28.5611, 28.0166, 26.7180, 24.9353, 22.3779, 20.2739]])
ground truth
tensor([[ 0.0000,  0.0000, 14.8384, 35.4025, 33.6593,  8.3617, 17.2761],
        [13.0527, 15.3345, 29.7761, 26.7999, 15.3345, 18.1122, 14.6117],
        [13.9795, 11.9411, 15.5576, 25.6575, 24.8816, 12.1778, 10.2841],
        [39.5125,  6.8736,  7.3838, 20.8192, 11.5930, 15.8022, 20.1531],
        [13.4271, 15.5445, 28.0247, 14.8211, 32.1278, 22.5013, 23.4482],
        [10.6786,  

batch_predictions
tensor([[19.3798, 19.5794, 19.5151, 19.3745, 19.2441, 19.1614, 19.1317],
        [13.2030, 12.7601, 12.5901, 12.5640, 12.6394, 12.7661, 12.9356],
        [ 5.0716,  4.3449,  4.3207,  4.3548,  4.3802,  4.3972,  4.4094],
        [21.3347, 22.3822, 22.1543, 21.6611, 21.1784, 20.8125, 20.5986],
        [19.8892, 19.7796, 19.8053, 19.8418, 19.8704, 19.8978, 19.9288],
        [23.1622, 25.5485, 25.1959, 24.4085, 23.5533, 22.7522, 22.1007],
        [30.2627, 27.7034, 27.1310, 27.5492, 28.2811, 29.0354, 29.7161],
        [17.9935, 23.7863, 22.9267, 20.7469, 17.5588, 16.0552, 16.3224]])
ground truth
tensor([[22.5013, 20.5024, 10.5734,  8.8769,  8.6796, 14.2951, 18.6612],
        [ 3.4850,  7.9695,  6.3914,  9.6002, 13.9269, 20.3577, 15.1105],
        [ 5.2721,  7.7239,  7.5113,  7.8656,  8.9569, 14.9660,  9.0703],
        [16.6754, 19.8054, 20.7391, 25.1578, 26.9726, 13.5192, 17.7012],
        [31.5623, 28.9453, 12.6381, 16.4256, 16.9385, 16.7806, 18.3719],
        [19.7279, 2

batch_predictions
tensor([[20.2508, 27.6320, 26.7556, 23.8845, 18.5791, 16.6841, 16.9938],
        [38.2789, 38.2719, 38.2730, 38.2732, 38.2730, 38.2728, 38.2728],
        [27.2057, 27.4831, 27.5000, 27.4149, 27.2873, 27.1776, 27.1157],
        [33.5748, 33.1687, 33.1489, 33.2084, 33.2730, 33.3416, 33.4157],
        [13.6617, 17.3827, 21.8854, 20.6916, 18.5013, 16.0529, 14.6150],
        [27.2129, 23.6230, 24.6123, 25.9462, 27.1088, 28.0112, 28.6583],
        [19.7465, 22.5189, 21.8739, 20.7657, 19.5445, 18.3660, 17.5782],
        [14.6181, 18.3109, 21.6524, 21.6582, 20.8906, 19.9369, 19.0848]])
ground truth
tensor([[16.5308, 18.2272, 17.1357, 25.2367, 37.9143, 25.9732, 12.2699],
        [58.4824, 60.6786, 56.6149, 65.8469, 52.4066, 44.9763, 55.1683],
        [13.8605, 17.7296, 13.3645, 15.6179, 40.0652, 53.6706, 24.0930],
        [29.7761, 38.5913, 18.5232, 19.3594, 23.0584, 26.9558, 38.6480],
        [10.5865,  9.5476, 10.0736, 18.5429, 27.4987, 28.8795, 11.9279],
        [14.3566, 4

batch_predictions
tensor([[26.4438, 34.9929, 37.1497, 36.4311, 34.6646, 31.0433, 25.2181],
        [23.2849, 26.6935, 26.1826, 24.8922, 23.0697, 20.8058, 19.2951],
        [13.6317, 12.8158, 13.3054, 14.0125, 14.6529, 15.0863, 15.3150],
        [13.4826, 17.2232, 20.0377, 20.9034, 20.6201, 20.0474, 19.5135],
        [31.1340, 29.8622, 29.4667, 29.5090, 29.7636, 30.1602, 30.6656],
        [20.8650, 21.4058, 21.2585, 20.9730, 20.7182, 20.5501, 20.4744],
        [19.4218, 18.3223, 18.1901, 18.5944, 19.1965, 19.8561, 20.5217],
        [19.0463, 22.0281, 20.9729, 19.3279, 17.3585, 15.7995, 15.5283]])
ground truth
tensor([[33.3640, 53.8532, 47.3567, 22.7117, 20.9232, 16.7938, 20.8443],
        [ 4.3226, 37.1457, 34.5947, 14.6117, 19.7279, 14.3991, 19.0760],
        [13.5850,  9.9816,  9.6923, 10.8890, 11.6781, 16.1494, 23.4219],
        [ 0.0000, 12.9819, 25.3118,  7.7098,  9.3396, 12.4858, 18.4240],
        [ 8.9164, 16.4256, 14.6370, 31.2599, 58.8901, 47.5802, 12.2830],
        [19.3027, 2

batch_predictions
tensor([[15.3520, 19.6411, 22.9015, 24.0992, 23.7341, 22.8954, 21.9650],
        [22.0974, 20.8243, 21.2551, 22.2344, 23.3537, 24.5108, 25.6248],
        [28.8209, 32.0102, 32.0948, 31.3066, 30.1416, 28.7573, 27.1906],
        [20.9436, 22.1412, 21.8427, 21.2592, 20.6872, 20.2513, 19.9894],
        [14.0910, 18.3234, 21.4162, 22.2700, 21.8505, 21.1432, 20.4847],
        [ 7.0016,  5.8869,  5.6891,  5.6995,  5.7297,  5.7515,  5.7671],
        [19.4592, 25.6227, 24.3653, 21.2222, 16.9391, 16.0209, 16.2694],
        [ 9.4287, 10.4196, 11.0885, 11.7087, 12.4295, 13.4694, 14.9995]])
ground truth
tensor([[ 0.0000,  5.1871, 12.9819, 28.8549, 33.4042, 16.5391, 10.2324],
        [13.0244, 12.7976, 12.6842, 13.4070, 23.0300, 35.2749, 32.9932],
        [26.6176, 51.0389, 45.8969, 22.3435, 16.8858, 16.3335, 20.9100],
        [13.9172, 12.9393, 15.3486, 18.7500, 29.6769, 30.7256, 16.7375],
        [11.3520, 12.0748, 16.2698, 17.8430, 34.0278, 31.5193, 12.6559],
        [ 8.7454,  

batch_predictions
tensor([[17.1527, 18.4297, 18.1940, 17.5330, 16.8511, 16.3089, 15.9699],
        [19.7329, 18.3243, 18.1707, 18.6189, 19.1712, 19.6893, 20.1480],
        [25.6944, 21.6685, 22.4947, 23.9591, 25.4001, 26.6402, 27.5830],
        [23.9560, 26.4983, 26.1450, 25.0248, 23.3990, 21.3961, 19.7698],
        [18.6392, 18.8674, 18.7205, 18.5370, 18.4049, 18.3384, 18.3247],
        [19.4334, 18.9493, 18.9385, 19.0759, 19.2311, 19.3722, 19.4951],
        [ 8.1858,  7.4357,  7.0766,  6.9519,  6.9209,  6.9217,  6.9255],
        [15.6443, 16.3370, 16.1960, 15.8621, 15.5711, 15.3847, 15.2991]])
ground truth
tensor([[16.5176, 11.2835, 13.5981, 17.0700, 20.7391, 17.7538, 11.2704],
        [31.4201, 30.3288, 15.9155, 12.9393, 16.5958, 15.7738, 21.3719],
        [38.2937, 22.5624, 21.9529, 20.8617, 18.8634, 24.3339, 44.2035],
        [ 0.2126, 12.0607, 20.5215, 20.4649, 31.9586, 38.6480, 41.0856],
        [12.3299, 17.6871, 21.5278,  7.2988,  3.4014, 14.8951, 12.7268],
        [24.4477, 1

batch_predictions
tensor([[17.2799, 17.1374, 17.1755, 17.2315, 17.2764, 17.3148, 17.3515],
        [16.2620, 19.6003, 20.0637, 19.5898, 19.1043, 18.7501, 18.5548],
        [29.3880, 36.7929, 37.0023, 36.2366, 35.0455, 33.3625, 30.7034],
        [23.6399, 25.1890, 25.0703, 24.4014, 23.5089, 22.5819, 21.7796],
        [14.4288, 19.3182, 22.6821, 21.3827, 19.0719, 15.7228, 14.3834],
        [17.6772, 18.6039, 18.4321, 18.0174, 17.6378, 17.3786, 17.2481],
        [15.2253, 18.8464, 19.3073, 18.4459, 17.4856, 16.6264, 15.9554],
        [17.1788, 21.2622, 20.1404, 18.2762, 16.2816, 14.8871, 14.6179]])
ground truth
tensor([[16.9385, 10.1657,  8.7454, 11.4808, 12.8617, 14.6107, 22.5934],
        [ 1.5255, 15.8206, 14.7422, 18.7533, 21.9227, 24.4477, 11.9542],
        [19.9369, 18.8848, 33.9427, 55.6418, 14.1899,  0.0000, 15.4787],
        [14.2432, 19.1185, 21.0317, 34.8639, 28.4722, 30.2863, 21.1593],
        [10.6293, 15.9439, 18.1689, 17.0635, 27.6219, 35.3175, 30.0737],
        [28.7132, 2

batch_predictions
tensor([[15.6108, 19.2451, 19.9290, 19.0974, 18.1402, 17.2737, 16.6055],
        [15.2633, 16.3018, 16.3842, 16.0913, 15.7623, 15.5114, 15.3563],
        [24.6761, 25.5426, 25.4089, 25.0453, 24.6760, 24.3937, 24.2281],
        [19.7211, 18.1080, 18.6380, 19.2863, 19.7901, 20.1605, 20.4302],
        [15.9947, 20.2612, 20.8492, 19.8768, 18.5969, 17.2353, 16.1449],
        [19.4270, 25.6100, 27.6598, 26.7499, 24.6152, 20.2984, 17.7197],
        [20.8541, 20.0912, 19.9490, 20.1067, 20.3898, 20.7205, 21.0598],
        [15.0917, 14.3273, 14.7493, 15.4311, 16.0211, 16.4626, 16.7569]])
ground truth
tensor([[11.5505, 14.0023, 16.9926, 26.7574, 25.0142, 15.5896, 13.0244],
        [14.1504, 13.3482, 13.1247, 13.6244, 15.4524, 14.1110,  9.6265],
        [17.4382, 19.1215, 30.2735, 42.0831, 36.7438, 16.9516, 15.5313],
        [36.6649, 20.5813, 19.4371, 15.0973, 16.3861, 13.9795, 18.4640],
        [11.7205, 12.4433, 22.0663, 24.0788, 25.8362,  9.4388, 10.5017],
        [19.2177, 3

batch_predictions
tensor([[21.2548, 24.0753, 23.6073, 22.5923, 21.4187, 20.2427, 19.3531],
        [19.9009, 26.0243, 25.1132, 22.6525, 18.8014, 16.6239, 16.8296],
        [24.8347, 21.0106, 21.4130, 22.7276, 24.0982, 25.3517, 26.4105],
        [27.1112, 27.9536, 27.9293, 27.6544, 27.3325, 27.0644, 26.8918],
        [21.3247, 22.7181, 22.3555, 21.6888, 21.0196, 20.4658, 20.0928],
        [22.0988, 29.6901, 28.8956, 26.8147, 22.9607, 18.3877, 18.4513],
        [38.2744, 38.1435, 38.1539, 38.1547, 38.1517, 38.1481, 38.1448],
        [13.2000, 16.4314, 20.3551, 19.3159, 17.7921, 16.3679, 15.3953]])
ground truth
tensor([[13.9795, 18.4640, 31.9963, 13.4403, 14.4398, 11.9805,  3.9716],
        [10.0624, 17.0068, 31.5760, 42.7863, 30.5981, 20.9042, 20.1247],
        [42.6871, 34.0703, 16.8226, 19.0760, 14.9943, 17.0918, 26.9133],
        [19.5160, 19.3319, 27.8538, 32.5750, 23.2246, 21.4229, 14.7685],
        [19.6995, 23.7103, 30.5414, 26.9416, 15.0652, 14.6967, 14.2007],
        [16.1494, 1

batch_predictions
tensor([[25.9251, 21.7233, 22.4939, 23.9396, 25.3931, 26.6868, 27.7104],
        [ 7.3609,  7.9312,  8.0958,  8.1452,  8.1570,  8.1617,  8.1686],
        [15.0373, 18.1964, 20.8717, 19.9071, 18.1455, 16.0313, 14.4908],
        [18.2398, 18.5614, 18.4625, 18.2821, 18.1329, 18.0451, 18.0134],
        [27.1629, 25.4714, 24.8741, 24.9204, 25.3491, 25.9850, 26.7386],
        [ 5.8350,  4.5750,  4.4819,  4.5222,  4.5577,  4.5828,  4.6011],
        [18.7618, 17.1061, 17.9106, 18.6621, 19.2290, 19.6320, 19.9118],
        [24.6438, 21.7762, 21.5058, 22.4119, 23.5081, 24.5797, 25.5829]])
ground truth
tensor([[4.2942e+00, 0.0000e+00, 0.0000e+00, 4.9745e+00, 3.5445e+01, 4.5196e+01,
         3.3759e+01],
        [8.1405e+00, 8.7059e+00, 1.0692e+01, 1.8674e+01, 1.3309e+01, 8.1799e+00,
         8.1405e+00],
        [2.8345e-02, 4.9036e+00, 8.7018e+00, 1.1026e+01, 2.0408e+01, 3.7217e+01,
         2.8160e+01],
        [1.3322e+01, 1.2296e+01, 1.0639e+01, 1.5952e+01, 2.3895e+01, 2.7196

batch_predictions
tensor([[18.8607, 18.1582, 18.1244, 18.3138, 18.5291, 18.7191, 18.8781],
        [ 7.2748,  7.5131,  7.5688,  7.5718,  7.5666,  7.5512,  7.5444],
        [20.5408, 23.4222, 22.5709, 20.9658, 18.7948, 16.8855, 16.4606],
        [28.9861, 30.1215, 29.9538, 29.6543, 29.4142, 29.2657, 29.1969],
        [12.2439, 11.2444, 10.1094,  8.4436,  7.0426,  6.5646,  6.4660],
        [16.6159, 18.8133, 18.7679, 18.2473, 17.7548, 17.4021, 17.1986],
        [26.1263, 29.9498, 29.6511, 28.4731, 26.6581, 23.9770, 21.3514],
        [19.2836, 22.4605, 21.6300, 20.1574, 18.3835, 16.7673, 16.2164]])
ground truth
tensor([[ 4.5371,  0.0000,  8.0878,  7.8511,  9.2188,  7.8248, 12.4408],
        [15.2288, 14.4924, 19.3188, 17.5171, 25.8285, 16.5965, 10.4287],
        [21.0884, 23.6961, 28.4014, 13.4921, 14.0023, 14.5550, 18.2115],
        [26.1573, 32.8643, 35.7312, 23.0800, 25.5655, 26.4466, 30.6023],
        [17.5829,  5.5234,  7.1936,  8.4298,  5.8390, 13.3614,  6.4703],
        [10.8233, 2

batch_predictions
tensor([[15.1943, 18.6542, 19.0010, 18.3992, 17.8123, 17.3839, 17.1397],
        [23.2005, 21.6141, 21.3306, 21.6768, 22.1984, 22.7233, 23.2036],
        [20.4069, 24.3000, 23.3523, 21.3868, 18.5645, 16.5751, 16.5549],
        [20.3493, 25.1718, 24.1246, 22.0446, 19.0894, 16.8915, 16.9748],
        [ 5.4589,  5.7603,  5.7946,  5.7949,  5.7951,  5.7922,  5.7929],
        [11.8832, 10.7767, 10.2740, 10.0266,  9.9437,  9.9763, 10.0824],
        [ 9.5470,  8.3123,  7.6198,  7.4051,  7.3646,  7.3708,  7.3819],
        [20.1798, 27.1976, 30.5113, 29.6586, 27.1249, 20.8222, 17.6650]])
ground truth
tensor([[17.1357, 17.3724, 18.3062, 25.9337,  5.7864,  4.6949, 11.0731],
        [16.3204, 14.5713, 11.7175, 15.6891, 25.7496, 33.9164, 29.0373],
        [15.5576, 17.5302, 25.6970, 20.2525, 29.4450,  4.5239,  7.2199],
        [11.5788, 13.2511, 19.8696, 28.7840, 25.1276, 16.8651, 14.0590],
        [ 4.9842,  6.4703,  6.2599, 11.5334, 15.2946, 18.3982,  5.5892],
        [ 5.8107,  

batch_predictions
tensor([[10.8255, 11.9853, 13.7212, 16.5380, 17.1180, 16.2713, 15.3551],
        [22.7367, 28.6629, 27.7890, 25.7246, 22.2148, 18.4730, 18.4681],
        [29.5708, 33.9164, 33.8849, 33.0081, 31.8828, 30.7254, 29.5242],
        [36.4580, 35.4834, 35.4032, 35.5471, 35.7206, 35.8801, 36.0181],
        [14.0313, 13.3185, 14.1026, 14.9619, 15.5450, 15.8574, 16.0187],
        [17.8319, 18.6483, 18.3855, 17.9385, 17.5544, 17.3044, 17.1920],
        [17.9869, 19.4306, 18.8141, 17.8862, 16.9917, 16.2864, 15.8806],
        [16.7252, 18.5659, 18.3343, 17.7733, 17.2759, 16.9284, 16.7372]])
ground truth
tensor([[13.5587, 10.4156, 13.3614, 18.2799, 27.3540, 36.9148,  7.8774],
        [15.4195, 23.0867, 37.2874, 24.0363, 28.5573,  9.2687,  9.6372],
        [21.2868, 23.1151, 35.6718, 45.7625, 36.7630,  8.0782, 17.1485],
        [67.6446, 59.4388, 35.5726, 34.4246, 21.9104, 25.3260, 37.4717],
        [20.4892, 12.4540, 14.1767, 13.7033, 11.2704, 15.3998, 23.4876],
        [21.8112, 2

batch_predictions
tensor([[20.2149, 20.7252, 20.4993, 20.1435, 19.8405, 19.6482, 19.5702],
        [16.7060, 21.2331, 20.3563, 18.5878, 16.4995, 15.0120, 14.6723],
        [26.7670, 28.4983, 28.3214, 27.8062, 27.2697, 26.8273, 26.5247],
        [17.9058, 19.0414, 18.7361, 18.1526, 17.6025, 17.1947, 16.9632],
        [14.8616, 14.0131, 14.1070, 14.5678, 15.0871, 15.5637, 15.9646],
        [22.4590, 31.3689, 32.0633, 30.1094, 25.6557, 18.6213, 18.3333],
        [24.8788, 25.1725, 25.1277, 24.9786, 24.8253, 24.7165, 24.6668],
        [16.3571, 20.5046, 19.6482, 18.1105, 16.4134, 15.0682, 14.5176]])
ground truth
tensor([[17.7275, 13.7428, 18.5429, 24.8948, 29.6028, 11.0994, 13.5455],
        [16.8226, 25.1276, 26.8991, 12.1457, 12.7693, 19.1468, 17.4603],
        [28.9683, 35.3175, 15.4337, 21.7120, 18.8776, 20.3231, 28.2738],
        [19.8271, 25.2126, 24.9717,  0.2409,  7.2562,  9.3396, 12.2591],
        [ 8.9427,  9.0084, 10.2183,  9.4161, 11.7701, 20.5550, 21.0416],
        [20.6491, 1

batch_predictions
tensor([[18.1553, 16.3175, 17.7689, 19.1681, 20.1381, 20.7703, 21.2094],
        [19.9754, 23.8702, 23.1049, 21.4205, 19.1397, 17.1245, 16.7495],
        [16.1648, 19.3266, 19.3556, 18.6347, 17.8744, 17.2514, 16.8228],
        [25.7888, 34.6035, 35.0384, 33.4712, 30.4533, 24.7767, 20.1396],
        [18.4083, 20.4321, 19.9672, 19.2247, 18.5420, 18.0136, 17.6813],
        [17.9544, 19.1159, 18.9127, 18.4937, 18.1352, 17.9013, 17.7890],
        [37.4878, 34.6606, 34.2748, 34.8539, 35.5938, 36.1944, 36.6045],
        [18.6773, 18.8973, 18.8345, 18.7175, 18.6217, 18.5672, 18.5495]])
ground truth
tensor([[ 3.8407,  8.8010, 10.3033,  5.9949, 11.0686, 16.6241, 30.0595],
        [15.7943, 18.1352, 23.6586, 39.3083, 29.1426, 19.5292, 14.7028],
        [12.2699, 14.5055, 20.3446, 28.6560, 33.0221, 12.0857, 15.4261],
        [20.5924, 25.5527, 33.3900, 51.8707, 41.5958, 21.7262, 21.3861],
        [13.6196, 17.4461, 17.9138, 17.0777, 20.8333, 21.5561, 16.9359],
        [19.5949, 2

batch_predictions
tensor([[15.9593, 21.2266, 22.0312, 20.5763, 18.0836, 15.2404, 14.5034],
        [11.7178, 14.5391, 18.6895, 20.4682, 19.1685, 17.3488, 15.3572],
        [ 8.2082,  9.0539,  9.4565,  9.6953,  9.8528,  9.9722, 10.0764],
        [ 8.2045,  8.1169,  8.0768,  8.0617,  8.0584,  8.0536,  8.0516],
        [13.4041, 12.5578, 12.5320, 12.9815, 13.7197, 14.5532, 15.2948],
        [16.9728, 16.7527, 16.8767, 17.0603, 17.2186, 17.3376, 17.4222],
        [22.1942, 20.3552, 20.1005, 20.6343, 21.3261, 21.9938, 22.6002],
        [28.1682, 29.5556, 29.5494, 29.1609, 28.6718, 28.2180, 27.8676]])
ground truth
tensor([[17.8713, 23.9512, 31.2075, 22.7183,  7.0153, 15.6604, 10.5017],
        [ 8.9427, 10.7143, 10.8277, 22.9308, 30.3571, 18.2540,  9.0561],
        [ 7.1410,  5.9442, 11.3624, 15.5839, 21.3440, 16.7806,  4.0505],
        [12.0068, 15.9127,  3.9716,  6.4308,  5.7207,  6.5755,  7.8906],
        [ 2.9053,  3.3305,  5.7115,  5.8673, 13.2937, 13.6054, 20.5357],
        [18.4508, 1

batch_predictions
tensor([[24.8884, 33.9287, 34.8712, 32.7491, 27.4905, 19.0618, 19.0285],
        [23.8060, 23.2749, 23.2040, 23.3121, 23.4766, 23.6618, 23.8506],
        [23.4889, 20.2084, 21.6232, 23.1483, 24.5571, 25.7381, 26.5794],
        [23.6664, 22.0476, 21.8092, 22.2120, 22.7724, 23.3166, 23.7962],
        [20.9359, 21.5223, 21.3747, 21.0793, 20.8138, 20.6371, 20.5554],
        [23.0082, 20.4049, 21.4898, 22.6425, 23.5905, 24.2996, 24.7917],
        [15.9430, 20.9022, 22.7954, 21.4512, 19.0766, 15.9559, 15.0075],
        [22.2390, 22.2434, 22.2520, 22.2456, 22.2486, 22.2359, 22.2069]])
ground truth
tensor([[1.4953e+01, 3.5902e+01, 6.9726e+01, 5.3353e+01, 1.8924e+01, 1.8280e+01,
         1.2178e+01],
        [4.0136e+01, 3.8435e+01, 2.0337e+01, 1.6582e+01, 1.3960e+01, 1.4314e+01,
         2.3030e+01],
        [1.4527e+01, 1.5944e+01, 1.9388e+01, 1.7517e+01, 3.3702e+01, 2.3484e+01,
         2.6488e+01],
        [1.5958e+01, 1.5065e+01, 1.9615e+01, 2.1259e+01, 2.2548e+01, 2.8316

batch_predictions
tensor([[15.1063, 14.2979, 14.3596, 14.8781, 15.6016, 16.4241, 17.2606],
        [27.1479, 29.9719, 29.7539, 28.9313, 27.8637, 26.6000, 25.1358],
        [19.0668, 19.2437, 19.1760, 19.0322, 18.8975, 18.8113, 18.7806],
        [15.9492, 19.1437, 19.0379, 17.9826, 16.7617, 15.6178, 14.7204],
        [18.0351, 21.0283, 20.1490, 18.7046, 17.0646, 15.5770, 14.9896],
        [20.8892, 21.2297, 21.1393, 20.9576, 20.7975, 20.6971, 20.6579],
        [18.9836, 17.5567, 17.5862, 18.0721, 18.5367, 18.9211, 19.2304],
        [17.9911, 22.9309, 22.9979, 21.7967, 19.8537, 17.6361, 16.4820]])
ground truth
tensor([[20.4235,  5.1420,  9.3766,  6.7728,  7.4566,  9.9027, 18.7138],
        [16.8464, 18.4903, 31.6018, 33.0615, 14.8606, 14.1373, 16.1494],
        [32.4688, 28.7698,  5.5272, 14.2574, 16.7659, 14.9660, 18.4524],
        [13.8605, 33.4467, 20.6774,  5.7115,  5.3713, 13.6338, 15.8872],
        [18.1406, 14.8668, 13.8180, 25.6803, 26.7007, 10.1616, 12.2732],
        [25.5811, 2

batch_predictions
tensor([[12.2548, 11.1616,  9.9219,  8.3022,  7.1732,  6.8101,  6.7338],
        [13.1411, 16.1304, 19.7057, 19.1843, 18.2264, 17.3402, 16.7370],
        [19.3193, 25.7519, 24.6590, 21.8516, 17.7030, 16.4318, 16.7548],
        [12.7033, 12.1837, 11.8394, 11.5437, 11.2266, 10.8144, 10.2026],
        [20.5288, 24.7878, 23.9056, 22.4399, 20.6815, 18.7808, 17.8257],
        [33.2853, 38.3237, 38.2090, 38.0564, 37.9296, 37.8361, 37.7747],
        [17.3599, 19.6532, 18.9557, 17.9705, 17.0504, 16.2942, 15.7884],
        [ 6.5223,  6.7429,  6.7675,  6.7634,  6.7602,  6.7528,  6.7500]])
ground truth
tensor([[12.8617,  4.3924,  7.4040,  8.0747,  7.3514, 11.3887, 15.9258],
        [14.2149, 12.8260, 13.8889, 18.3390, 31.4059, 24.7024,  5.2721],
        [12.7976, 12.6842, 13.4070, 23.0300, 35.2749, 32.9932, 14.9943],
        [ 8.7059,  8.1931,  8.5481, 10.1262, 11.1783, 19.6213, 14.9790],
        [19.3452, 14.2999, 18.0556, 19.4303, 31.5334,  9.8498, 14.8101],
        [28.7415, 3

batch_predictions
tensor([[18.7461, 19.2708, 19.0764, 18.6981, 18.3436, 18.0978, 17.9801],
        [17.1364, 17.6736, 17.4006, 17.0127, 16.7087, 16.5347, 16.4784],
        [24.8359, 20.9529, 22.2527, 24.0385, 25.7869, 27.2295, 28.2124],
        [18.0464, 22.7420, 21.3367, 19.1112, 16.7030, 15.3522, 15.4690],
        [22.8281, 24.6409, 24.5180, 23.6973, 22.5110, 21.1888, 20.0627],
        [27.7262, 32.1499, 32.1630, 30.9289, 28.8199, 25.3919, 21.7886],
        [ 9.7546, 10.6795, 11.4743, 12.3870, 13.6888, 15.5211, 16.5184],
        [18.2151, 23.9656, 24.6830, 23.4574, 21.0129, 17.5999, 16.5439]])
ground truth
tensor([[27.4660, 30.2438,  9.8073, 13.1094,  9.5522, 16.1706, 17.0493],
        [17.4382, 13.7296,  6.6149, 11.2572,  8.9032, 10.9548, 14.7159],
        [36.0119, 10.1616, 15.6604, 16.4541, 19.0051, 26.2188, 38.4921],
        [15.6497, 23.0011, 27.6565, 12.9274, 15.9258, 10.5339, 16.1362],
        [16.8651, 15.1786, 15.3061, 29.9603, 41.1139, 32.4830, 23.7528],
        [14.1723, 2

batch_predictions
tensor([[17.0130, 15.8211, 15.9093, 16.3916, 16.8479, 17.2221, 17.5194],
        [31.4065, 34.1951, 34.0309, 33.4558, 32.9112, 32.5081, 32.2584],
        [18.9551, 19.8217, 19.5672, 19.0912, 18.6558, 18.3515, 18.1969],
        [21.9830, 28.9033, 28.0246, 25.8569, 22.1443, 18.5321, 18.6248],
        [26.9264, 32.9463, 33.3060, 31.7809, 28.8561, 23.6566, 20.3733],
        [25.2796, 28.5073, 28.2335, 27.2259, 25.8367, 24.1390, 22.4732],
        [ 9.9273, 11.2373, 12.7448, 15.3082, 18.5934, 17.6961, 16.5118],
        [18.9028, 24.4604, 24.1060, 22.8444, 20.8997, 18.8063, 17.9007]])
ground truth
tensor([[31.0941, 27.3384, 10.6718, 13.1378, 11.7772, 14.3141, 21.2302],
        [22.3304, 27.7223, 30.8390, 40.4655, 32.8774, 23.9085, 26.1441],
        [14.1636, 11.6649, 20.3314, 29.6686, 30.1289, 14.5581, 11.0994],
        [14.3282, 25.6236, 46.3861, 13.1661,  0.0000, 10.7710, 18.4666],
        [18.5516, 25.9354, 29.1241, 54.7902, 53.2738, 58.0499, 28.7415],
        [44.0760, 3

batch_predictions
tensor([[18.3247, 17.1371, 16.9541, 17.2694, 17.7119, 18.1550, 18.5695],
        [14.8444, 19.1162, 20.4893, 18.9406, 16.7427, 14.3381, 13.3264],
        [16.0739, 16.7543, 16.5470, 16.1915, 15.9043, 15.7333, 15.6674],
        [15.9764, 20.6389, 21.1682, 19.6705, 17.2750, 14.8082, 14.1119],
        [22.7300, 24.3469, 24.0523, 23.3664, 22.6222, 21.9602, 21.4676],
        [13.7338, 17.8663, 21.1095, 19.6322, 17.5287, 15.1713, 13.9050],
        [ 7.1861,  6.3314,  6.1435,  6.1318,  6.1479,  6.1602,  6.1693],
        [12.8392, 16.8135, 21.3030, 21.7910, 20.9238, 19.7022, 18.4634]])
ground truth
tensor([[29.1525, 26.7857, 11.2103, 13.3362, 14.5833, 13.8180, 21.3152],
        [22.9308, 30.3571, 18.2540,  9.0561, 10.2041,  7.0578,  9.8781],
        [11.2704, 15.3998, 23.4876, 19.9895,  7.6013, 11.0074,  9.2057],
        [ 9.9421, 11.4940, 13.5850, 27.6433, 22.6065, 10.0079,  8.8769],
        [19.5862, 28.6990, 22.8600, 27.6361, 13.1944, 13.5062, 13.9172],
        [15.3472, 2

batch_predictions
tensor([[ 8.7352,  8.8285,  8.8703,  8.8781,  8.8691,  8.8479,  8.8334],
        [24.8066, 23.7117, 23.4965, 23.6964, 24.0656, 24.4857, 24.8972],
        [ 7.2602,  6.5110,  6.2542,  6.1978,  6.1998,  6.2096,  6.2180],
        [21.3566, 22.3383, 22.1030, 21.4924, 20.8030, 20.2003, 19.7825],
        [ 4.9944,  4.4212,  4.3936,  4.4205,  4.4412,  4.4548,  4.4648],
        [34.5057, 32.4293, 32.0163, 32.2925, 32.7998, 33.3477, 33.8538],
        [15.1030, 14.7427, 15.8834, 16.8398, 17.5253, 17.9387, 18.1523],
        [35.0936, 38.2103, 38.1003, 37.9316, 37.7988, 37.7062, 37.6484]])
ground truth
tensor([[ 7.9300,  5.9442,  8.6928, 10.6786, 17.0568,  6.3519,  5.8916],
        [15.7549, 19.2267, 12.5986, 16.4519, 18.3193, 32.4698, 31.8517],
        [15.6365, 16.6491,  9.4819,  6.2730,  7.1015, 11.1783, 10.8101],
        [17.8146, 20.8333, 29.2800, 31.6752,  8.4325, 17.2477, 13.5488],
        [ 6.4626,  8.5034,  6.1933,  4.5210,  5.4989, 10.0057,  7.4830],
        [49.2914,  

batch_predictions
tensor([[15.2399, 18.8910, 19.6144, 18.7061, 17.6897, 16.7763, 16.0656],
        [27.0201, 22.5465, 24.0059, 25.8798, 27.6914, 29.2450, 30.3964],
        [ 7.1870,  7.5873,  7.6760,  7.6840,  7.6766,  7.6621,  7.6557],
        [21.8052, 23.6945, 23.3382, 22.5948, 21.8112, 21.1202, 20.6129],
        [26.8072, 28.2371, 28.3031, 27.8651, 27.1821, 26.4067, 25.6494],
        [17.4484, 18.3586, 18.0878, 17.6305, 17.2435, 16.9942, 16.8827],
        [19.5419, 19.1166, 19.1083, 19.2164, 19.3322, 19.4361, 19.5280],
        [22.7478, 25.4519, 25.0483, 24.0680, 22.8780, 21.6143, 20.5573]])
ground truth
tensor([[ 7.4546, 16.6241, 35.5300, 35.5159, 15.8730, 15.5612, 11.4938],
        [41.7234, 16.6950,  7.6531, 21.3719, 21.3577, 29.6769, 48.4977],
        [ 2.6502, 14.3566, 15.0368, 13.7046, 14.6825, 20.9467,  2.2251],
        [22.0147, 20.1999, 25.2893, 42.6092, 21.2914, 21.0153, 25.0789],
        [39.5408, 46.9671, 42.4603, 19.1327, 19.2035, 17.8146, 19.9688],
        [19.5423, 1

batch_predictions
tensor([[17.6806, 22.4899, 21.5630, 19.8460, 17.8369, 16.4539, 16.1856],
        [14.3158, 17.7971, 20.3349, 19.5515, 18.6653, 17.9678, 17.5413],
        [21.8190, 22.2093, 22.1194, 21.9204, 21.7380, 21.6184, 21.5675],
        [24.3610, 28.8110, 28.4237, 27.0203, 24.7449, 21.3408, 19.3511],
        [22.9173, 28.1158, 27.5422, 25.5236, 21.6698, 17.8483, 17.8816],
        [ 7.4004,  7.4515,  7.4642,  7.4636,  7.4589,  7.4417,  7.4328],
        [19.3982, 21.9593, 21.4461, 20.6284, 19.8493, 19.2055, 18.7567],
        [24.3403, 24.2754, 24.2831, 24.2932, 24.2948, 24.3023, 24.3197]])
ground truth
tensor([[11.5079, 14.4983, 19.8696, 37.8260, 30.1304,  4.2234, 14.5975],
        [10.6151, 14.6542, 13.0952, 14.3141, 21.3861, 23.0017, 10.7001],
        [22.4487, 18.5429, 20.8180,  8.1931,  9.4555, 16.3598, 15.2420],
        [17.1485, 15.4478, 38.8322, 39.6542, 21.0601, 17.1627, 12.7268],
        [25.6519, 37.0465, 31.4201, 17.2761, 15.6746, 12.4717, 17.9280],
        [ 8.4298,  

batch_predictions
tensor([[21.1683, 21.1112, 21.1213, 21.1289, 21.1294, 21.1356, 21.1521],
        [31.6762, 31.3416, 31.3104, 31.3573, 31.4094, 31.4714, 31.5487],
        [26.0588, 22.5108, 22.7251, 23.8781, 25.0534, 26.0812, 26.9138],
        [19.2930, 23.1545, 22.0181, 19.9206, 17.2661, 15.8458, 15.9353],
        [12.3295, 14.3770, 16.8011, 16.9347, 16.2784, 15.6470, 15.1564],
        [19.2368, 24.9361, 23.6538, 21.0151, 17.5742, 16.1750, 16.4568],
        [24.6982, 21.0721, 21.7152, 23.1053, 24.4997, 25.7338, 26.7379],
        [11.8779, 11.7361, 11.5997, 11.4617, 11.3070, 11.1118, 10.8313]])
ground truth
tensor([[19.7704, 24.2914, 36.5363, 22.4348, 24.7307, 28.9116, 23.0159],
        [49.0646, 43.1122, 23.6395, 21.4711, 15.7596, 28.6990, 52.4518],
        [26.3747, 26.0629,  2.4093, 12.2307, 23.4127, 39.5125,  6.8736],
        [11.5930, 12.9393, 16.1565, 24.5607, 27.6644,  5.0879, 12.7126],
        [16.8464, 21.1731, 29.1557,  8.9295,  3.9321, 12.8222, 11.9411],
        [18.5374, 2

batch_predictions
tensor([[34.8446, 38.1969, 38.0789, 37.9217, 37.8021, 37.7199, 37.6685],
        [23.2711, 27.8506, 27.2733, 25.4873, 22.3679, 18.8742, 18.3206],
        [28.2891, 24.0384, 24.6150, 26.0134, 27.4003, 28.5751, 29.4694],
        [26.4086, 33.1961, 33.3025, 31.6778, 28.7308, 23.0613, 19.8500],
        [23.5535, 22.2257, 21.7700, 21.8245, 22.1769, 22.6913, 23.3095],
        [20.1848, 26.0256, 30.8805, 32.9656, 33.1656, 32.7029, 32.0569],
        [21.9794, 28.7772, 28.1683, 26.4339, 23.4954, 19.5922, 18.6526],
        [17.2335, 22.2381, 22.8480, 22.0954, 20.9341, 19.4666, 18.0421]])
ground truth
tensor([[25.0992, 25.0142, 38.7755, 61.8481, 20.0822, 26.4739, 27.9195],
        [21.1599, 36.8885, 26.7491, 14.8080,  8.9295,  0.0000,  0.0000],
        [24.1497, 16.7375, 18.1548, 19.3169, 25.5952, 47.6332, 14.7534],
        [16.9076, 20.6633, 30.9949, 46.7687, 34.8356, 19.1610, 16.2840],
        [39.4558, 36.1536, 14.2290, 10.3033, 16.8934, 16.6383, 33.7443],
        [17.6020, 2

batch_predictions
tensor([[18.8288, 20.9657, 20.6227, 19.9565, 19.3371, 18.8666, 18.5780],
        [22.5694, 21.1510, 20.7414, 20.9189, 21.3782, 21.9421, 22.5432],
        [11.5014, 14.1256, 16.9310, 19.0435, 19.3064, 18.7138, 17.9550],
        [18.7501, 19.1618, 19.0381, 18.7839, 18.5547, 18.4058, 18.3417],
        [15.3305, 18.2448, 18.2010, 17.4618, 16.7347, 16.1658, 15.7833],
        [20.0817, 27.7997, 27.2615, 24.4140, 18.3214, 16.2219, 16.4760],
        [13.5122, 12.6951, 13.2946, 14.2514, 15.0910, 15.5449, 15.7415],
        [26.8156, 24.6577, 24.9345, 25.6790, 26.3397, 26.8439, 27.2121]])
ground truth
tensor([[15.4787, 29.1426, 28.4850, 29.1031, 15.2683, 16.5834, 16.1757],
        [34.1412, 13.7613, 15.1077,  9.4813, 18.4099, 19.9972, 30.0595],
        [ 0.0000, 11.7772, 10.7143, 12.3299, 13.8605, 33.4467, 20.6774],
        [13.4779, 32.8656, 22.9592, 10.4167, 15.0227, 14.1015, 15.3770],
        [13.7559, 12.4145, 16.9516, 21.5939, 21.8569,  7.2199, 13.1641],
        [14.0306, 2

batch_predictions
tensor([[19.2146, 21.9736, 21.2977, 20.2369, 19.1432, 18.1402, 17.4412],
        [10.5045, 11.3952, 12.4583, 14.0134, 15.9299, 16.0555, 15.5294],
        [22.0250, 22.3359, 22.2880, 22.1540, 22.0318, 21.9548, 21.9240],
        [16.9231, 16.5700, 16.6044, 16.7421, 16.8884, 17.0191, 17.1302],
        [21.9740, 23.2009, 23.0224, 22.5155, 21.9934, 21.5829, 21.3297],
        [22.0206, 30.3371, 29.8208, 27.5523, 22.6497, 17.7520, 18.1130],
        [ 4.6886,  4.5080,  4.5074,  4.5228,  4.5332,  4.5395,  4.5444],
        [18.9935, 19.5312, 19.2948, 18.9378, 18.6462, 18.4698, 18.4038]])
ground truth
tensor([[ 7.1147, 28.0773, 24.0268,  8.8506, 19.0689, 16.7543, 17.9116],
        [ 2.3668,  8.4042, 16.7659,  1.5590,  0.0000,  0.0000, 12.9110],
        [15.9439, 18.4240, 16.9643, 11.0261, 15.0227, 16.4541, 21.7829],
        [22.8958, 20.0947, 10.5471, 12.7433, 10.1262, 14.4792, 19.5423],
        [31.9963, 13.4403, 14.4398, 11.9805,  3.9716, 16.8201, 23.9348],
        [13.0385, 1

batch_predictions
tensor([[37.5675, 38.1942, 38.1553, 38.0978, 38.0580, 38.0344, 38.0221],
        [19.2274, 18.5275, 18.4015, 18.5805, 18.9076, 19.3124, 19.7550],
        [ 5.5766,  5.1600,  5.1147,  5.1290,  5.1435,  5.1528,  5.1597],
        [26.9380, 22.4267, 23.8706, 25.7419, 27.5740, 29.1782, 30.3900],
        [29.9489, 30.4746, 30.3929, 30.2555, 30.1538, 30.0980, 30.0770],
        [16.5635, 21.2148, 23.5035, 23.0461, 22.0389, 20.6981, 18.9732],
        [13.9983, 17.5080, 21.3400, 20.5603, 19.0255, 17.2761, 15.7272],
        [21.1458, 21.9831, 21.7807, 21.3868, 21.0287, 20.7809, 20.6557]])
ground truth
tensor([[28.8549, 22.5057, 29.1241, 37.3866, 56.7460, 50.1701, 32.6531],
        [42.0305, 30.3919, 13.3482, 12.4540, 14.3740,  8.6796,  8.7454],
        [ 5.0595,  9.1270,  4.8753,  6.2925,  6.6468,  7.3271,  7.5680],
        [37.4717, 22.8175, 18.1973, 13.1236, 11.3662, 21.6837, 47.6190],
        [30.6023, 32.1410, 36.4413, 37.3225, 20.7259, 24.7764, 25.5392],
        [14.2007, 1

batch_predictions
tensor([[29.5861, 29.0110, 28.9898, 29.1067, 29.2379, 29.3613, 29.4747],
        [23.1207, 21.3065, 21.2699, 21.8494, 22.4544, 22.9701, 23.3865],
        [15.8314, 15.5204, 15.5397, 15.6386, 15.7446, 15.8413, 15.9257],
        [17.5350, 18.0574, 17.9245, 17.6666, 17.4522, 17.3232, 17.2712],
        [19.0763, 17.5159, 18.4717, 19.2579, 19.8288, 20.2238, 20.4918],
        [17.2493, 17.6668, 17.5174, 17.2818, 17.0991, 16.9973, 16.9634],
        [16.1079, 19.8925, 22.8549, 23.3157, 22.7187, 21.8280, 20.9099],
        [22.2913, 25.7480, 25.0568, 23.5306, 21.3165, 18.8481, 17.9375]])
ground truth
tensor([[23.2246, 21.4229, 14.7685, 17.8985, 20.6996, 26.6833, 31.8648],
        [23.5544, 16.4116, 11.4654, 18.3815, 17.3469, 40.4337, 25.8078],
        [21.4144, 21.6695, 13.4495, 11.8622, 11.2387, 11.3520, 15.0935],
        [15.4195, 15.5612, 27.0408, 17.0210, 11.4654, 11.7063, 11.5930],
        [26.0771, 26.4598, 15.4620, 13.5488, 16.9076, 12.5283, 22.9734],
        [18.9506, 2

batch_predictions
tensor([[23.7501, 31.8861, 31.7155, 29.9938, 26.9005, 20.9618, 19.0418],
        [18.1543, 23.5357, 22.0972, 18.9932, 15.4802, 14.6919, 14.9032],
        [17.1783, 23.4969, 24.6446, 23.4703, 21.1833, 17.3414, 16.0544],
        [23.1006, 27.8081, 27.3185, 25.4817, 22.0483, 18.2147, 17.9075],
        [12.2042, 11.0213, 10.3003,  9.6714,  9.1224,  8.6941,  8.4036],
        [ 9.3746, 11.1505, 12.8318, 15.1806, 17.5066, 18.2437, 18.0171],
        [ 8.1459,  8.4953,  8.6009,  8.6257,  8.6211,  8.6112,  8.6060],
        [16.8434, 21.4866, 20.9953, 19.7526, 18.2186, 16.8256, 16.1055]])
ground truth
tensor([[18.3982, 16.0836, 21.0021, 31.0889, 33.9690, 17.0042, 14.9527],
        [17.4887, 17.7438, 23.4977, 10.3741,  0.0000,  7.4263, 11.4371],
        [ 9.3821, 14.3566, 18.1548, 32.7523, 32.7806, 11.5505, 12.9960],
        [16.9501, 25.6519, 37.0465, 31.4201, 17.2761, 15.6746, 12.4717],
        [ 4.7080,  7.1410,  5.6418,  6.5755, 13.2956, 12.6249, 16.7149],
        [13.2693, 1

batch_predictions
tensor([[11.4941, 14.2524, 18.4899, 20.4452, 19.1051, 17.2477, 15.1921],
        [25.3785, 23.3088, 22.9470, 23.4178, 24.1037, 24.7847, 25.3986],
        [ 5.8203,  4.8252,  4.7500,  4.7854,  4.8161,  4.8370,  4.8521],
        [17.5916, 16.4169, 17.2588, 17.9176, 18.3725, 18.6826, 18.8956],
        [14.1727, 13.2428, 13.4421, 14.0571, 14.7323, 15.3178, 15.7778],
        [36.3475, 37.6233, 37.6255, 37.4750, 37.3339, 37.2350, 37.1781],
        [14.8052, 16.4999, 16.6053, 16.1360, 15.6460, 15.2672, 15.0182],
        [15.0106, 19.3301, 21.3830, 19.6795, 16.9569, 14.2310, 13.5793]])
ground truth
tensor([[10.3741, 19.6712, 32.2279, 17.3895,  6.9728,  8.9427, 10.7143],
        [17.5170, 16.3549, 21.1593, 24.5890, 27.0550, 36.8481, 29.9461],
        [ 6.3634,  6.9728, 11.7914,  7.3271,  6.2217,  9.5380,  8.2766],
        [ 6.1941,  5.1420, 13.7691, 10.0868, 14.3346, 18.4245, 22.4882],
        [ 4.7212,  7.8248,  9.1399, 10.6128, 11.7964, 23.1983, 22.5802],
        [38.6338, 6

batch_predictions
tensor([[18.7106, 17.3172, 17.6344, 18.2346, 18.7227, 19.0884, 19.3590],
        [ 7.2551,  6.6719,  6.5052,  6.4767,  6.4824,  6.4890,  6.4942],
        [37.4695, 37.4786, 37.4774, 37.4686, 37.4621, 37.4620, 37.4613],
        [13.2110, 17.1250, 20.4099, 21.6400, 21.2771, 20.4713, 19.6558],
        [21.6533, 21.7899, 21.7632, 21.6862, 21.6133, 21.5701, 21.5600],
        [28.5454, 23.0598, 24.4425, 26.4473, 28.4803, 30.3140, 31.8197],
        [17.7034, 16.2396, 17.2139, 18.0053, 18.5421, 18.8931, 19.1251],
        [12.9224, 15.9739, 18.7108, 18.5249, 18.0726, 17.7245, 17.5335]])
ground truth
tensor([[14.2007, 13.6196, 17.4461, 17.9138, 17.0777, 20.8333, 21.5561],
        [ 9.5082, 16.8201, 17.5171,  6.7596,  7.7196,  7.3645,  8.8769],
        [56.7460, 50.1701, 32.6531, 27.1259, 22.0947, 30.4138, 45.9892],
        [10.2578, 17.0437,  6.6938, 28.2614, 26.0521, 13.5060, 12.0068],
        [20.0947, 22.4487, 18.5429, 20.8180,  8.1931,  9.4555, 16.3598],
        [36.0686, 1

batch_predictions
tensor([[15.7814, 18.7022, 19.1792, 18.5632, 17.9239, 17.4301, 17.1162],
        [12.7760, 15.5661, 18.6280, 17.8844, 16.9187, 16.1156, 15.5732],
        [28.0099, 27.6013, 27.5921, 27.6803, 27.7771, 27.8713, 27.9630],
        [17.6145, 22.4249, 21.7194, 19.9939, 17.7961, 16.2246, 15.9046],
        [ 6.5645,  4.3508,  4.1351,  4.1644,  4.2011,  4.2317,  4.2565],
        [14.6806, 14.5898, 14.6193, 14.6544, 14.6821, 14.7085, 14.7350],
        [19.9101, 21.9469, 21.5569, 20.7968, 20.0324, 19.3915, 18.9544],
        [24.2967, 25.6660, 25.5277, 24.9783, 24.3297, 23.7354, 23.2812]])
ground truth
tensor([[13.2228, 15.0935, 27.9337, 23.6820, 11.7347, 16.7234, 11.5930],
        [11.7044, 10.6391, 19.1084, 25.5129, 22.0279, 10.6786,  7.4171],
        [15.8305, 17.9989, 16.4966, 21.1451, 19.3027, 44.0760, 30.8673],
        [10.5208, 13.2299, 19.6081, 23.9611, 38.5192,  6.0100, 12.4408],
        [ 5.4280,  5.8390,  5.0028,  4.6769,  5.0312,  4.9036, 11.7772],
        [17.0777, 2

batch_predictions
tensor([[20.7389, 27.0017, 25.7390, 22.4990, 17.0775, 16.1550, 16.3775],
        [13.8902, 17.0720, 19.7733, 18.7769, 17.4797, 16.2874, 15.3131],
        [20.6421, 22.0844, 21.6807, 20.7768, 19.7152, 18.7200, 18.0284],
        [30.4850, 31.1897, 31.2677, 31.1213, 30.9020, 30.7071, 30.5798],
        [23.2973, 24.7298, 24.7169, 24.1225, 23.2322, 22.2431, 21.3458],
        [19.9827, 20.2598, 20.1327, 19.9549, 19.8191, 19.7451, 19.7244],
        [17.2857, 21.8384, 21.8143, 20.9079, 19.6902, 18.4102, 17.4136],
        [14.7906, 18.3457, 19.5532, 18.5729, 17.4630, 16.4490, 15.6343]])
ground truth
tensor([[17.2477, 22.5198, 29.9461, 26.4456, 12.4575, 12.0040, 11.3379],
        [10.9679, 13.1115, 13.3745, 19.0558, 31.7464, 33.9032, 10.8759],
        [13.9663, 15.7549, 23.8559, 21.1468, 24.2504, 10.1526, 11.3493],
        [40.6321, 22.9734, 24.4898, 17.4036, 23.2426, 29.2092, 31.1650],
        [15.4787,  9.5608, 21.7386, 28.6165, 37.9537, 11.5334, 18.1089],
        [13.8180, 1

batch_predictions
tensor([[24.7747, 33.4828, 34.0219, 31.8947, 27.1973, 19.6303, 19.3634],
        [15.8567, 17.2546, 17.0401, 16.4956, 15.9952, 15.6315, 15.4187],
        [ 6.4543,  6.2817,  6.2547,  6.2585,  6.2635,  6.2654,  6.2671],
        [17.7166, 18.4490, 18.1631, 17.6255, 17.1191, 16.7596, 16.5784],
        [32.3755, 28.6969, 28.3014, 29.1350, 30.1478, 31.0834, 31.8956],
        [37.0180, 36.8260, 36.8417, 36.8582, 36.8697, 36.8843, 36.9040],
        [29.2051, 24.1690, 25.7544, 27.6712, 29.3708, 30.8164, 32.0566],
        [16.7563, 19.9313, 19.5903, 18.7840, 17.9995, 17.3838, 16.9942]])
ground truth
tensor([[16.0431,  4.3367, 20.2239, 32.6814, 48.1151, 36.6497, 19.9405],
        [10.0482, 17.1485, 20.3515, 25.9495,  9.1553,  8.1774, 10.4734],
        [ 7.7948, 12.3583,  9.8356,  8.4467,  7.7948,  6.2217,  7.9649],
        [20.2920, 22.6328, 31.2073,  4.9579, 11.5597, 12.5986, 11.7833],
        [18.0563, 19.1347, 21.9095, 25.5786, 32.4961, 39.3872, 32.4566],
        [59.1128, 4

batch_predictions
tensor([[12.5718, 11.5086, 11.1516, 11.0463, 11.1543, 11.4703, 12.0934],
        [25.3117, 21.7298, 21.4243, 22.4726, 23.7214, 24.9658, 26.1573],
        [37.2778, 37.2619, 37.2601, 37.2563, 37.2543, 37.2544, 37.2618],
        [19.3137, 17.4057, 18.6402, 19.7392, 20.5863, 21.1569, 21.5068],
        [19.4822, 20.1559, 20.0141, 19.7004, 19.4199, 19.2353, 19.1491],
        [ 8.6673,  6.5925,  5.8177,  5.7104,  5.7250,  5.7507,  5.7729],
        [19.4201, 21.8768, 21.2099, 19.8396, 18.1114, 16.5479, 15.9483],
        [18.0057, 16.9200, 16.9033, 17.2757, 17.6627, 17.9910, 18.2574]])
ground truth
tensor([[17.5171,  9.4424,  8.5087,  9.2451,  8.6665, 12.8222, 14.7685],
        [16.4399, 14.0023, 14.7392, 15.0368, 21.3152, 43.0414, 33.1207],
        [60.0482, 54.5210, 28.4297, 23.9654, 21.9246, 23.1718, 41.0714],
        [14.5581, 11.0994, 13.9400, 10.3367, 20.0158, 28.8401,  7.3251],
        [25.9495, 13.8322,  4.5210,  9.2120, 21.5136, 22.0947, 19.0193],
        [ 8.6310,  

batch_predictions
tensor([[25.1323, 24.5017, 24.4307, 24.5542, 24.7273, 24.9090, 25.0841],
        [ 7.7683,  7.1196,  6.9192,  6.8806,  6.8840,  6.8904,  6.8956],
        [14.7890, 17.8540, 18.2735, 17.5112, 16.7137, 16.0528, 15.5666],
        [21.1544, 21.6439, 21.4948, 21.1987, 20.9263, 20.7418, 20.6565],
        [19.0170, 21.1333, 20.6911, 19.8910, 19.0996, 18.4375, 17.9846],
        [ 6.2939,  6.3065,  6.3089,  6.3113,  6.3104,  6.2939,  6.2876],
        [17.2485, 18.8621, 18.5729, 17.9323, 17.3253, 16.8656, 16.5860],
        [16.0738, 19.2001, 19.0929, 18.3043, 17.4845, 16.8056, 16.3277]])
ground truth
tensor([[24.9605, 36.3624, 10.8233, 13.6902, 19.5555, 17.6092, 20.3972],
        [ 8.4035, 12.4803, 16.8201, 10.8233,  3.4850, 10.7049,  1.9069],
        [11.1783, 13.6376, 19.2399, 18.7270, 20.5287,  7.5487, 10.7443],
        [16.0179, 19.8843, 26.1967, 23.3430, 29.8527, 10.8233, 14.7554],
        [11.1536, 16.6383, 23.2426, 26.3605, 27.1117, 14.9943, 14.9943],
        [ 6.7859,  

batch_predictions
tensor([[23.0382, 21.5004, 21.2649, 21.6690, 22.2544, 22.8499, 23.4052],
        [17.9830, 16.4014, 16.4455, 17.1081, 17.7871, 18.3984, 18.9468],
        [ 9.4617, 10.6713, 11.5488, 12.4627, 13.6749, 15.2244, 16.3296],
        [23.9041, 25.7916, 25.6690, 24.9091, 23.8803, 22.7919, 21.8524],
        [19.3058, 23.3724, 22.1757, 20.1233, 17.5356, 15.9395, 16.0039],
        [17.2833, 16.0404, 17.2209, 18.1854, 18.8861, 19.3328, 19.5987],
        [20.4177, 26.8873, 27.6843, 26.5128, 24.0121, 19.8812, 18.1537],
        [29.7191, 34.9202, 35.1757, 34.2277, 32.6536, 30.5024, 27.5201]])
ground truth
tensor([[ 8.5034, 15.5187, 14.8810, 19.2177, 22.8458, 26.3605, 36.0969],
        [25.9863, 35.4024, 10.6128,  8.8769, 10.6654, 11.1520, 15.8469],
        [ 6.6043,  9.8498,  8.8577,  9.0986,  6.9728, 11.4371, 20.5074],
        [13.8874, 11.8096,  1.0521, 21.8306, 24.1846, 32.0358,  0.0000],
        [13.6338, 19.7562, 31.4201, 43.3673, 18.0130, 17.6446, 13.8605],
        [30.3713, 1

batch_predictions
tensor([[16.1903, 19.6503, 20.1570, 19.5810, 18.9665, 18.4867, 18.1909],
        [27.0565, 34.2604, 34.3746, 32.9569, 30.5532, 26.5231, 22.0948],
        [21.5307, 19.5081, 19.7728, 20.4873, 21.1242, 21.6363, 22.0371],
        [19.0830, 17.5168, 18.6405, 19.5692, 20.2619, 20.7397, 21.0510],
        [16.4392, 15.1241, 16.2355, 17.1718, 17.8303, 18.2338, 18.4751],
        [19.1548, 21.0192, 20.5300, 19.6227, 18.6955, 17.9284, 17.4563],
        [18.8068, 21.9616, 21.4323, 20.5495, 19.7034, 19.0279, 18.5863],
        [28.8655, 27.5989, 27.1452, 27.1506, 27.4299, 27.9059, 28.5274]])
ground truth
tensor([[14.9132, 11.3098, 12.6775, 16.2415, 20.1604, 24.2504,  9.9290],
        [21.9246, 35.3033, 43.8917, 49.7591, 20.1247, 27.7069, 30.1162],
        [22.7117, 14.4529, 13.7296, 16.2809, 19.8185, 26.4072, 44.7922],
        [22.6616, 11.3237, 14.3566, 11.1395, 13.3787, 19.6145, 22.3923],
        [22.4093, 11.5071, 10.7049,  9.6528, 10.0999,  6.8780, 25.2236],
        [24.3161, 2

batch_predictions
tensor([[35.7506, 34.4513, 34.4973, 34.7851, 35.0531, 35.2563, 35.4043],
        [26.1406, 35.9673, 36.7500, 35.5440, 32.8244, 27.1379, 21.0226],
        [23.1297, 26.6371, 26.1092, 24.9128, 23.3610, 21.4766, 19.9923],
        [18.3502, 17.3279, 18.6104, 19.7302, 20.6131, 21.2389, 21.6254],
        [ 9.4377, 10.1121, 10.5973, 11.0246, 11.4614, 11.9875, 12.7300],
        [17.6102, 21.6215, 24.8954, 26.6485, 26.6485, 25.8564, 24.6004],
        [ 7.7641,  7.6118,  7.5422,  7.5147,  7.5077,  7.5049,  7.5043],
        [38.0221, 37.4782, 37.5313, 37.6189, 37.6816, 37.7280, 37.7666]])
ground truth
tensor([[25.0263, 42.4250, 30.9442, 28.5245, 23.5928, 20.4235, 23.9479],
        [18.0826, 33.3640, 53.8532, 47.3567, 22.7117, 20.9232, 16.7938],
        [17.2902, 16.2698, 17.9705, 30.5556, 28.2455, 22.8883, 18.5658],
        [13.6639, 20.7259, 19.1478, 15.9390, 19.5292, 26.5650, 29.0373],
        [11.4150, 15.2157, 20.6207, 17.9905,  3.2614,  7.5355,  7.7854],
        [22.3172, 1

batch_predictions
tensor([[17.6206, 22.5861, 22.7977, 21.8575, 20.4050, 18.6458, 17.3379],
        [20.6095, 19.8146, 19.6406, 19.7697, 20.0297, 20.3482, 20.6924],
        [14.9365, 18.5113, 19.4367, 18.0377, 16.1408, 14.0670, 12.8787],
        [15.0506, 15.1065, 15.0773, 15.0226, 14.9779, 14.9590, 14.9621],
        [28.3659, 27.1796, 26.7463, 26.7396, 27.0098, 27.4757, 28.0926],
        [23.5683, 23.4432, 23.4575, 23.4802, 23.4984, 23.5203, 23.5497],
        [37.6611, 37.9906, 37.9696, 37.9315, 37.9051, 37.8912, 37.8861],
        [22.0361, 29.0857, 34.0554, 33.9748, 31.8382, 26.6599, 19.3886]])
ground truth
tensor([[18.3248, 14.3566, 22.6332, 28.1321, 11.6071, 13.2653, 13.8322],
        [35.8232, 11.9674, 15.7549, 17.1094, 17.5565, 17.6355, 27.3672],
        [13.6113, 11.6649, 11.1257, 10.7575, 28.5508, 28.7086, 10.5339],
        [ 6.3350, 17.6162,  2.9053,  3.3305,  5.7115,  5.8673, 13.2937],
        [45.2523, 31.4768, 19.2744, 17.8430, 15.4904, 14.5408, 31.6893],
        [24.1715, 2

batch_predictions
tensor([[38.3298, 38.3039, 38.3139, 38.3183, 38.3205, 38.3221, 38.3236],
        [15.7987, 18.6354, 18.6394, 17.8500, 17.0063, 16.2882, 15.7487],
        [29.0877, 24.1230, 25.1321, 26.8285, 28.4756, 29.8729, 30.9560],
        [18.1841, 16.2597, 17.2010, 18.2331, 19.0592, 19.6461, 20.0559],
        [30.8305, 35.7556, 35.6313, 35.0114, 34.4043, 33.9301, 33.6130],
        [15.4709, 20.0194, 23.4430, 24.1938, 23.6569, 22.8041, 21.9534],
        [18.2814, 23.5106, 22.2443, 19.4754, 16.0821, 15.0681, 15.2380],
        [24.5931, 28.7089, 28.3636, 27.0667, 25.0170, 21.9735, 19.7526]])
ground truth
tensor([[40.8206, 45.1210, 45.1604, 43.3588, 62.1120, 60.3893, 26.4335],
        [22.9308,  5.7540,  0.1134,  8.3333, 13.4212, 16.8226, 22.0805],
        [34.8356, 19.1610, 16.2840, 14.6400, 24.2772, 33.6168, 46.8679],
        [13.5060, 12.0068, 14.6107, 16.1494, 18.5955, 28.1168,  0.9074],
        [23.9479, 28.3535, 36.4282, 29.3924, 27.8012,  5.4314, 21.7123],
        [22.7749, 2

batch_predictions
tensor([[14.0621, 13.3270, 13.2752, 13.5420, 13.9187, 14.3109, 14.6674],
        [ 3.8838,  3.8173,  3.8287,  3.8383,  3.8445,  3.8468,  3.8478],
        [22.5463, 20.1255, 21.2369, 22.3475, 23.2457, 23.9144, 24.3798],
        [34.6345, 32.0235, 31.0704, 31.0489, 31.5310, 32.2484, 33.0819],
        [21.6660, 21.9257, 21.8242, 21.6529, 21.5088, 21.4233, 21.3958],
        [19.3836, 19.7765, 19.6715, 19.4913, 19.3485, 19.2660, 19.2353],
        [18.4553, 22.5084, 21.8885, 20.7528, 19.5282, 18.5082, 17.9124],
        [15.0761, 15.4666, 17.3525, 19.2241, 20.8007, 21.3367, 21.2468]])
ground truth
tensor([[14.4398,  5.4577,  5.3919,  4.9448,  8.2720, 11.0337, 19.6344],
        [ 5.9240,  8.4042, 10.0765,  4.1950,  3.1746,  2.8770,  5.4280],
        [27.6565, 11.7833, 16.4519, 18.2667, 16.6360, 24.4740, 36.4545],
        [38.2956, 22.6328, 25.6839, 18.3456, 25.4603, 35.7443, 51.3019],
        [21.8701, 16.8332, 24.3030, 29.6423, 13.1247, 18.3588, 10.1131],
        [17.5170, 2

batch_predictions
tensor([[12.2852, 11.8460, 11.4197, 10.9099, 10.1813,  9.0775,  7.8484],
        [18.3852, 21.7711, 21.6732, 20.5655, 19.0522, 17.5181, 16.5963],
        [15.2588, 14.3202, 15.3942, 16.3481, 17.0141, 17.3910, 17.5912],
        [33.6758, 28.7670, 28.8263, 29.9861, 31.2236, 32.3522, 33.3181],
        [25.6002, 23.0166, 23.0030, 23.7900, 24.6049, 25.3061, 25.8757],
        [24.8628, 35.5901, 37.0284, 36.3323, 35.1088, 33.2898, 30.2999],
        [10.5086,  9.7440,  8.6237,  7.4652,  6.8120,  6.5905,  6.5417],
        [ 5.2751,  5.4859,  5.4955,  5.4918,  5.4889,  5.4814,  5.4777]])
ground truth
tensor([[ 4.3924,  7.4040,  8.0747,  7.3514, 11.3887, 15.9258, 14.9527],
        [15.8872, 34.3821, 33.6593, 13.2795, 16.3407, 16.6100, 14.3849],
        [ 8.0747, 11.0600, 11.0863, 16.2152, 14.0715, 21.0021, 13.5587],
        [52.4461, 24.4345, 19.6739, 20.5418, 23.3693, 58.2588, 43.2272],
        [18.0414, 20.8333, 22.6049, 27.1967, 21.3152, 28.9683, 35.3175],
        [20.3090, 1

batch_predictions
tensor([[19.8381, 26.1630, 24.9912, 21.6165, 16.5371, 15.7765, 16.0055],
        [36.9233, 36.8934, 36.8922, 36.8894, 36.8889, 36.8907, 36.9018],
        [15.0481, 18.0877, 18.7453, 17.9883, 17.1620, 16.4632, 15.9433],
        [21.2422, 24.1100, 23.3554, 21.9217, 20.0138, 18.0774, 17.2835],
        [21.4597, 18.3857, 19.9041, 21.7444, 23.4173, 24.6252, 25.5345],
        [ 8.7822,  8.7784,  8.7802,  8.7822,  8.7759,  8.7549,  8.7516],
        [16.5721, 21.9386, 22.3701, 21.2504, 19.4284, 17.0757, 15.7322],
        [19.2410, 19.9571, 19.7487, 19.3099, 18.8895, 18.5881, 18.4314]])
ground truth
tensor([[13.6054, 12.1032, 12.1315, 16.6100, 24.4189, 28.4439, 18.6791],
        [24.9858, 19.6995, 32.3413, 30.0879, 54.1383, 53.5006, 39.6967],
        [ 8.3050, 13.2511, 13.8464, 15.0510, 22.3356, 23.6820, 12.2024],
        [ 9.7506, 16.0714, 20.5357, 25.2409,  0.7370,  0.0000,  0.7511],
        [11.5505, 12.9960, 11.6780, 13.7046, 20.8333, 30.4563, 28.7557],
        [ 7.9826, 1

batch_predictions
tensor([[20.8075, 20.1048, 20.0046, 20.1525, 20.3748, 20.6080, 20.8293],
        [20.1916, 19.6396, 19.6106, 19.7448, 19.8996, 20.0409, 20.1647],
        [13.4509, 13.1035, 13.0432, 13.1216, 13.2475, 13.3993, 13.5587],
        [17.4370, 23.6823, 24.0872, 22.6154, 19.5796, 15.9235, 15.6682],
        [18.6488, 24.6804, 24.4417, 22.9406, 20.2012, 17.3430, 16.9390],
        [19.8365, 25.4823, 24.2102, 21.7110, 18.3578, 16.5950, 16.8536],
        [18.8351, 24.0579, 22.8244, 20.2880, 17.1629, 15.9063, 16.1097],
        [16.4958, 15.8724, 15.8846, 16.1263, 16.4063, 16.6620, 16.8763]])
ground truth
tensor([[12.4291, 18.4949, 17.9847, 17.2902, 21.5420, 18.0272, 23.7954],
        [23.8953, 10.7575,  1.5255, 15.8206, 14.7422, 18.7533, 21.9227],
        [15.8730, 21.7545,  6.9870,  7.2279,  5.9524,  7.4121, 16.7092],
        [27.4987, 28.1168,  9.3766,  8.6796, 10.7838, 16.0705, 21.6597],
        [13.4008, 27.9590, 41.6228, 26.7622, 13.3351, 16.5308, 18.2272],
        [15.1894, 1

tensor([[19.7646, 21.1563, 20.8670, 20.3192, 19.8166, 19.4510, 19.2463],
        [ 9.5552,  9.6308,  9.6936,  9.7253,  9.7234,  9.7025,  9.6802],
        [15.0384, 14.8168, 14.9425, 15.1405, 15.3249, 15.4722, 15.5776],
        [15.1696, 18.5290, 20.0077, 19.2771, 18.4224, 17.6942, 17.1645],
        [27.6571, 33.7042, 33.8564, 32.3272, 29.5509, 24.4602, 20.5627],
        [23.3439, 22.9310, 22.8686, 22.9553, 23.0938, 23.2635, 23.4533],
        [25.2205, 21.4668, 22.5748, 24.1569, 25.6850, 26.9853, 27.9458],
        [19.6695, 23.4870, 22.6683, 20.7559, 18.0675, 16.1849, 16.1430]])
ground truth
tensor([[21.6202, 26.0652, 13.2956, 15.5839, 16.5045, 17.5829, 17.7670],
        [ 6.2730,  7.1015, 11.1783, 10.8101, 18.4508, 19.2267,  8.5350],
        [ 1.5590,  0.0000,  0.0000, 12.9110, 14.2007,  8.1349, 14.1582],
        [14.2031, 16.9253, 19.5029, 24.6186, 16.2152, 12.3882, 13.0721],
        [33.6735, 45.0680, 37.3158, 21.4569, 27.5368, 20.1814,  6.3776],
        [31.8452, 21.8679,  5.7965, 1

batch_predictions
tensor([[14.6114, 18.5840, 21.3995, 21.2386, 20.6688, 20.1320, 19.7698],
        [17.4785, 21.6591, 21.2079, 20.1094, 18.8348, 17.6951, 16.9985],
        [33.3860, 31.3233, 30.6877, 30.7831, 31.2415, 31.8756, 32.5893],
        [18.3561, 24.8982, 24.7018, 22.5747, 18.0811, 15.4361, 15.5930],
        [ 7.8400,  8.9830,  9.6002, 10.0665, 10.5023, 10.9891, 11.6098],
        [15.8721, 18.9312, 19.2933, 18.8418, 18.4072, 18.1077, 17.9527],
        [14.8982, 18.1940, 18.7642, 18.0154, 17.2350, 16.6005, 16.1553],
        [25.5699, 28.7252, 28.4830, 27.4341, 25.8233, 23.4351, 20.7789]])
ground truth
tensor([[ 9.1837, 19.4728, 18.7358, 18.8492, 27.2392, 25.5385,  5.8532],
        [21.7545, 27.1825, 14.0306, 14.0164, 10.9694, 14.2574, 12.7834],
        [ 4.6627,  0.0000, 17.6020, 21.4002, 23.2851, 34.4529, 42.4461],
        [ 9.9065, 23.2426, 37.8543, 34.7931, 15.3628, 11.8764, 11.2245],
        [ 3.9716,  4.3530,  8.7980, 11.3098, 19.4766, 19.7002,  3.3535],
        [18.0957, 1

batch_predictions
tensor([[12.2705, 14.6857, 18.4053, 18.9180, 17.6589, 16.1201, 14.6367],
        [28.0488, 34.1247, 34.5463, 33.2208, 30.6541, 26.1650, 21.7016],
        [18.1852, 22.3071, 21.4803, 19.8912, 18.0339, 16.5277, 16.0446],
        [ 5.6606,  4.4263,  4.3346,  4.3731,  4.4067,  4.4302,  4.4472],
        [14.0285, 13.6688, 14.8674, 16.0465, 16.8735, 17.2376, 17.3646],
        [23.2241, 22.3266, 22.1380, 22.3014, 22.6266, 23.0217, 23.4335],
        [15.7158, 20.1653, 24.3110, 25.3149, 24.6959, 23.3914, 21.5867],
        [30.0373, 33.2248, 33.0291, 32.2779, 31.4751, 30.7767, 30.2229]])
ground truth
tensor([[ 7.7591,  9.3766,  8.9953, 11.1520,  0.0000,  2.8932,  9.1794],
        [23.2001, 18.4382, 23.3277, 33.0499, 49.0646, 43.1122, 23.6395],
        [15.4904, 20.0539, 28.6990, 35.4025, 10.5867, 12.2449, 10.8702],
        [13.2430, 14.8606, 10.3235,  7.1278,  5.4182, 16.9779,  8.4429],
        [ 7.9826, 14.0715, 12.9274,  9.5608, 12.7564, 19.2530, 16.7675],
        [29.1557, 1

batch_predictions
tensor([[20.7304, 26.9038, 25.9695, 23.9474, 20.9138, 18.2719, 18.1847],
        [17.8435, 17.7987, 17.8174, 17.8274, 17.8293, 17.8360, 17.8523],
        [22.4192, 28.8404, 28.0092, 25.9609, 22.5646, 18.9991, 18.8634],
        [16.6501, 15.3062, 16.3326, 17.2063, 17.8103, 18.1928, 18.4340],
        [10.6422, 11.6526, 12.7982, 14.3965, 15.9994, 15.8777, 15.3796],
        [ 4.4438,  4.3835,  4.3871,  4.3951,  4.3998,  4.4022,  4.4046],
        [16.5382, 17.4607, 17.2508, 16.8505, 16.5176, 16.3085, 16.2148],
        [14.8621, 14.7663, 14.8015, 14.8431, 14.8758, 14.9052, 14.9329]])
ground truth
tensor([[15.1644, 18.7925, 23.2568, 39.2290, 29.7052, 17.3186, 17.4461],
        [26.3180, 24.5465, 14.0306, 16.0006, 18.8350, 15.1644, 26.6865],
        [49.8948, 23.3167, 28.6691, 25.3419, 22.4750, 48.3824, 43.4903],
        [28.1825, 15.1894,  6.0889, 15.0710, 13.6902, 17.2935, 23.4219],
        [ 9.2977,  9.3372, 16.3335, 20.1867, 19.2530,  8.6796, 11.4940],
        [ 6.7460, 1

batch_predictions
tensor([[20.3892, 27.2869, 28.6526, 27.6957, 25.5877, 21.2954, 18.2858],
        [21.3429, 22.7903, 22.3997, 21.6579, 20.9124, 20.3131, 19.9401],
        [19.0342, 21.1048, 20.6300, 19.6883, 18.6678, 17.7529, 17.1398],
        [18.8231, 17.9444, 17.8203, 18.0483, 18.3886, 18.7439, 19.0811],
        [23.1188, 26.4047, 25.9583, 24.8297, 23.3540, 21.5966, 20.1469],
        [27.6195, 33.2876, 33.6098, 32.1063, 29.1635, 23.9121, 20.4339],
        [19.5722, 17.5224, 18.3489, 19.2688, 20.0079, 20.5445, 20.9166],
        [33.0041, 35.2505, 35.3867, 35.0030, 34.4938, 34.0102, 33.6212]])
ground truth
tensor([[49.0363, 43.1548, 13.2370, 10.7710, 16.6100, 21.5420, 36.0261],
        [20.4103, 22.5276, 22.2777, 11.9148, 11.1520, 18.0957, 17.3593],
        [22.1372, 23.0867, 15.2353, 12.6276, 11.1820, 14.2999, 15.7596],
        [21.9813, 11.2528, 11.9473,  9.9915, 12.7409, 12.9960, 23.9938],
        [19.4870, 34.3537, 23.2851, 20.7200, 17.4887, 15.0652, 15.0794],
        [32.6814, 4

batch_predictions
tensor([[24.4847, 28.5512, 28.5389, 27.2916, 24.8762, 21.1038, 18.8501],
        [16.4472, 17.8887, 17.5041, 16.8490, 16.2739, 15.8625, 15.6312],
        [17.0768, 22.0949, 25.1688, 24.9146, 23.7521, 21.6962, 18.2108],
        [21.1832, 23.5192, 23.0039, 22.0836, 21.0868, 20.1226, 19.3648],
        [22.7967, 21.9338, 21.7987, 21.9839, 22.2802, 22.5999, 22.9054],
        [16.1014, 15.8461, 17.0556, 18.0441, 18.7612, 19.2164, 19.4698],
        [20.2212, 21.9247, 21.1504, 19.9143, 18.5214, 17.3229, 16.7452],
        [10.2947, 12.4666, 14.8460, 17.7527, 19.1872, 18.8190, 18.0260]])
ground truth
tensor([[10.5584, 16.7517, 15.4478, 22.1797, 44.8696, 36.2528,  9.3963],
        [14.7817, 17.4382, 13.7296,  6.6149, 11.2572,  8.9032, 10.9548],
        [17.6020, 18.4382, 19.0618, 23.0867, 34.8214, 38.9031,  6.5051],
        [19.8843, 35.9942, 27.6565, 11.7833, 16.4519, 18.2667, 16.6360],
        [15.5612, 16.5108, 20.7766, 32.7239, 29.2092, 20.3231, 18.0272],
        [29.7761, 1

batch_predictions
tensor([[18.5802, 24.9903, 24.5636, 22.9116, 19.8549, 17.0449, 17.2066],
        [19.9218, 26.1289, 30.7330, 32.8979, 33.2397, 33.0109, 32.7106],
        [26.5893, 32.1670, 32.0891, 30.6180, 28.0973, 23.5908, 20.2393],
        [21.8046, 21.4271, 21.3966, 21.5089, 21.6647, 21.8377, 22.0138],
        [27.0270, 23.1792, 23.1429, 24.3700, 25.7222, 26.9733, 28.0408],
        [20.3360, 21.6723, 21.2958, 20.5519, 19.7649, 19.0886, 18.6279],
        [21.8069, 26.1103, 25.3717, 23.3977, 20.0830, 17.2630, 17.2697],
        [16.2583, 22.4119, 23.4858, 21.6371, 17.5981, 13.9086, 13.6248]])
ground truth
tensor([[12.0331, 10.7049, 12.3356, 19.7265, 30.6812, 27.4329, 11.6386],
        [36.8622, 21.0547, 20.3051, 12.2436, 20.7654, 27.5907, 31.3519],
        [16.2840, 14.6400, 24.2772, 33.6168, 46.8679,  9.8498, 12.5567],
        [ 5.6973,  5.7540, 12.7834, 15.3061, 22.7749, 34.0136, 22.8316],
        [26.9700, 11.0402, 17.1910,  9.0703, 17.8430, 32.3554, 39.2007],
        [13.5981, 1

batch_predictions
tensor([[10.6586, 13.2112, 16.1470, 18.9512, 19.5994, 19.0338, 18.2299],
        [16.4730, 16.4051, 16.4246, 16.4415, 16.4507, 16.4624, 16.4802],
        [20.4595, 19.6654, 19.5045, 19.6439, 19.9044, 20.2071, 20.5171],
        [15.0958, 19.2468, 21.0840, 19.2704, 16.4163, 13.7640, 13.0767],
        [19.9273, 18.9485, 18.7509, 19.0016, 19.4689, 20.0288, 20.6210],
        [27.5452, 28.7629, 28.6895, 28.3119, 27.8908, 27.5388, 27.3014],
        [28.4695, 26.6784, 26.2531, 26.4998, 27.0188, 27.6088, 28.1809],
        [17.6147, 17.2222, 17.2619, 17.4330, 17.6225, 17.7949, 17.9409]])
ground truth
tensor([[11.0863, 11.2704, 13.0195, 18.8716,  5.4050,  6.6281,  3.7612],
        [24.4608, 29.8264,  9.9290, 12.0857, 12.4277, 11.0205, 19.8185],
        [29.8895,  4.9603,  0.0000, 10.9694,  9.1553, 14.2007, 22.8741],
        [15.9916, 21.8175, 26.6176, 10.6786, 12.5855,  9.4424, 11.6386],
        [33.6026, 34.4388,  8.4467, 11.5221, 10.1049, 18.5232, 21.3861],
        [21.5018, 2

batch_predictions
tensor([[19.1042, 19.3124, 19.2547, 19.1242, 19.0053, 18.9311, 18.9044],
        [25.7120, 27.1349, 26.9626, 26.5211, 26.0963, 25.7811, 25.5982],
        [20.6016, 18.7807, 20.3380, 21.8805, 23.2798, 24.4614, 25.2991],
        [18.5081, 19.7208, 19.3079, 18.5790, 17.8668, 17.3110, 16.9809],
        [21.4168, 24.6144, 24.1024, 22.6137, 20.3847, 18.0957, 17.3327],
        [24.6896, 31.1782, 30.8153, 29.0107, 25.6942, 20.4683, 19.3690],
        [18.3705, 18.4995, 18.4124, 18.3049, 18.2290, 18.1936, 18.1901],
        [10.5265, 11.8041, 13.3140, 15.4117, 16.7434, 16.2489, 15.7101]])
ground truth
tensor([[23.2772, 22.6460, 12.0200, 12.0200, 12.2699, 14.5055, 20.3446],
        [27.5652, 27.3526, 24.2489, 24.0788, 33.5176, 17.7012, 18.8067],
        [28.1463, 11.8764, 14.8526,  2.1825,  0.0000,  9.9915, 29.9178],
        [ 6.0889, 15.0710, 13.6902, 17.2935, 23.4219, 31.1547, 21.4624],
        [29.1383, 36.4654, 32.6247, 20.4223, 15.9722, 14.3424, 13.9031],
        [16.5176, 2

batch_predictions
tensor([[ 5.9251,  5.8370,  5.8274,  5.8329,  5.8367,  5.8379,  5.8393],
        [19.6463, 22.9280, 22.3614, 20.6797, 18.2275, 16.2039, 15.9044],
        [23.2162, 30.1932, 33.8624, 34.5246, 34.3878, 34.1525, 33.9543],
        [30.8607, 27.5538, 27.0747, 27.8619, 28.9017, 29.8715, 30.6947],
        [10.5152, 11.8121, 13.4881, 16.5014, 18.3273, 17.0399, 15.5168],
        [18.9899, 24.0174, 23.2067, 21.6998, 19.8275, 18.3107, 17.7879],
        [20.9281, 18.6178, 18.7850, 19.7471, 20.7048, 21.5507, 22.2963],
        [10.1122, 11.1957, 12.1291, 13.1984, 14.6181, 15.9257, 16.1541]])
ground truth
tensor([[ 4.7194,  5.8248,  3.4439,  5.7115,  7.1570,  5.0595,  9.1270],
        [12.2041, 13.4271, 15.1368, 13.3877, 18.7533, 33.4824, 13.8611],
        [18.2141, 24.2372, 23.6323, 30.2078, 31.9174, 39.7817, 26.2230],
        [26.6440, 21.1026, 21.2868, 23.1151, 35.6718, 45.7625, 36.7630],
        [12.2567, 12.9932, 23.7112,  6.9832,  6.9437,  9.1136,  7.9958],
        [18.8322, 1

batch_predictions
tensor([[11.6711, 11.3253, 10.9277, 10.4085,  9.6703,  8.7032,  7.8703],
        [16.8684, 17.8396, 17.6344, 17.1688, 16.7407, 16.4421, 16.2864],
        [16.8334, 16.6868, 16.7185, 16.7680, 16.8078, 16.8423, 16.8757],
        [25.9854, 27.5591, 27.4598, 26.8808, 26.1400, 25.3916, 24.7382],
        [13.4528, 12.5949, 12.8525, 13.4667, 14.1365, 14.6455, 14.9438],
        [22.0819, 21.2066, 21.0718, 21.2938, 21.6565, 22.0587, 22.4504],
        [ 4.2291,  4.2374,  4.2375,  4.2395,  4.2403,  4.2315,  4.2289],
        [18.5229, 24.2070, 23.0672, 19.9314, 15.7810, 14.9450, 15.1215]])
ground truth
tensor([[15.9127,  3.9716,  6.4308,  5.7207,  6.5755,  7.8906, 14.6502],
        [18.6481, 31.6938, 39.2031, 17.5697, 15.6891, 11.7833, 16.6886],
        [11.2704, 10.8759, 13.9400, 14.9921, 15.5050, 23.6323, 30.9179],
        [35.5867, 18.6366, 13.0952, 10.5726, 19.2177, 34.8214, 50.3543],
        [15.6891,  4.4319, 11.5992, 11.6649, 10.7575, 12.3751, 15.0447],
        [21.7687, 3

batch_predictions
tensor([[27.3127, 36.1574, 36.4301, 35.5542, 34.2178, 32.4592, 29.9198],
        [20.9001, 19.0718, 19.5528, 20.2721, 20.8570, 21.2960, 21.6189],
        [13.9636, 16.9966, 19.8274, 19.5326, 18.8457, 18.2345, 17.8366],
        [17.9227, 19.7394, 19.4313, 18.6791, 17.9078, 17.2722, 16.8513],
        [29.0999, 26.4986, 25.6065, 25.8319, 26.5984, 27.5461, 28.5265],
        [19.5006, 21.7886, 21.3662, 20.2183, 18.7654, 17.3785, 16.6128],
        [12.9910, 16.3570, 19.1607, 20.2969, 20.1189, 19.5313, 18.9244],
        [27.4852, 23.1484, 24.7276, 26.5673, 28.2305, 29.5984, 30.6381]])
ground truth
tensor([[21.3435, 23.6111, 25.1417, 19.1893, 41.2557, 40.3345, 14.7817],
        [29.3135, 12.6907, 20.6602, 15.5050, 15.8075, 18.5034, 27.3803],
        [ 9.0420, 10.8985,  7.4546, 16.6241, 35.5300, 35.5159, 15.8730],
        [18.5941, 28.0045, 27.3810, 14.5833, 13.9456, 12.4717, 14.1015],
        [42.7579, 27.2959, 18.2256, 17.8146, 17.4603, 23.7812, 35.0340],
        [18.2404, 2

batch_predictions
tensor([[18.6433, 18.8657, 18.7721, 18.6109, 18.4751, 18.3958, 18.3714],
        [19.5210, 25.9953, 24.8070, 21.7596, 17.3388, 16.3007, 16.5879],
        [23.4814, 30.6448, 30.8948, 29.3312, 26.2710, 20.9482, 19.1051],
        [11.3702,  9.4591,  7.3140,  6.3611,  6.1790,  6.1520,  6.1466],
        [19.3063, 25.2329, 24.2544, 21.8589, 18.3933, 16.5777, 16.7934],
        [16.6458, 20.5104, 21.9089, 21.5869, 21.0798, 20.6199, 20.2998],
        [28.3242, 23.4292, 24.8947, 26.7717, 28.5347, 30.0233, 31.1778],
        [12.3103, 15.7679, 19.5614, 21.0433, 20.4198, 19.4332, 18.4722]])
ground truth
tensor([[12.3882, 16.0179, 14.1636, 11.6649, 20.3314, 29.6686, 30.1289],
        [23.0300, 35.2749, 32.9932, 14.9943, 14.5692, 16.4683, 17.7438],
        [18.2009, 19.5818, 31.8122, 57.2988, 39.2425, 21.1205, 17.0174],
        [ 4.4713,  6.1284,  6.2073,  8.0747,  9.6134, 17.3067, 18.7796],
        [16.7375, 29.1383, 36.4654, 32.6247, 20.4223, 15.9722, 14.3424],
        [ 8.7322, 1

tensor([[12.3583,  9.8356,  8.4467,  7.7948,  6.2217,  7.9649,  8.6026],
        [20.8617, 18.8634, 24.3339, 44.2035, 14.3566, 46.0459, 18.3673],
        [ 9.0845, 12.6417, 14.2432, 13.9031, 22.7608, 27.5368, 27.2959],
        [ 7.6539,  8.5350, 10.8364, 12.4803, 20.9890, 28.2877, 11.4808],
        [25.9337,  6.1941,  5.1420, 13.7691, 10.0868, 14.3346, 18.4245],
        [20.0964, 17.7154, 19.7562, 26.0204, 27.6502,  5.8532, 15.5896],
        [ 6.0232,  8.2341,  7.8798,  9.4955,  9.6372, 16.3265,  8.8861],
        [21.7545, 31.9728, 26.0062, 17.4603, 12.3724, 17.7012, 13.6763]])
batch_predictions
tensor([[33.4229, 34.7775, 34.9419, 34.7334, 34.4217, 34.1311, 33.9166],
        [21.1793, 20.0288, 19.7484, 19.9239, 20.2922, 20.7205, 21.1594],
        [18.6951, 25.0886, 26.6771, 25.6537, 23.2235, 18.5168, 17.0505],
        [14.8978, 18.3714, 18.9849, 18.1695, 17.2844, 16.5184, 15.9308],
        [15.8690, 14.9193, 15.0704, 15.5210, 15.9374, 16.2690, 16.5211],
        [28.6665, 32.7269, 32.82

batch_predictions
tensor([[26.9290, 33.2283, 33.4574, 31.8447, 28.7587, 23.0366, 20.0664],
        [19.1302, 17.3662, 17.8824, 18.6023, 19.1889, 19.6330, 19.9614],
        [17.5368, 22.3938, 23.0331, 22.0136, 20.3446, 18.2577, 16.8900],
        [18.9795, 21.3840, 20.8918, 19.9375, 18.9281, 18.0238, 17.3958],
        [18.8011, 19.0977, 18.9622, 18.7434, 18.5617, 18.4544, 18.4188],
        [21.5199, 24.8741, 24.2743, 23.2178, 22.0583, 20.8721, 19.9085],
        [19.7036, 19.4857, 19.5110, 19.5846, 19.6580, 19.7258, 19.7900],
        [19.1382, 25.5014, 24.8509, 22.6203, 18.6813, 16.2489, 16.4136]])
ground truth
tensor([[19.4019, 33.9569, 44.1185, 32.1854, 15.4620, 20.9892, 17.4036],
        [24.3293, 15.0973, 16.0310, 16.0573, 15.5181, 17.6881, 27.7880],
        [16.8858, 16.9911, 18.5823, 12.9800, 26.6439, 19.6870,  0.0000],
        [18.0957, 15.2551, 25.9337, 31.1415, 11.2572, 15.4655, 15.7812],
        [23.4219, 31.1547, 21.4624,  4.2346, 11.2046, 13.3614, 20.2656],
        [17.7296, 1

batch_predictions
tensor([[26.0463, 30.0359, 29.9368, 28.7325, 26.5527, 22.7666, 19.5496],
        [13.4547, 17.6943, 22.4077, 22.6741, 21.7010, 20.1890, 18.4790],
        [37.9337, 35.0507, 35.1343, 35.8329, 36.4160, 36.7899, 37.0180],
        [ 5.2646,  5.0913,  5.0840,  5.0967,  5.1055,  5.1103,  5.1141],
        [25.7107, 31.4807, 31.3410, 29.8378, 27.2113, 22.2841, 19.3868],
        [28.8432, 24.7967, 25.1792, 26.4563, 27.7319, 28.8067, 29.6327],
        [29.8803, 28.0628, 27.6493, 27.9026, 28.4096, 28.9692, 29.4958],
        [29.7853, 29.2274, 29.1294, 29.2113, 29.3464, 29.5131, 29.7024]])
ground truth
tensor([[18.1689, 15.1502, 13.7046, 41.2132, 46.4853, 32.7948, 16.6241],
        [12.3724, 25.4252, 28.4722, 12.5425,  9.8498, 12.2307, 18.1689],
        [45.4932, 33.4892, 18.6083, 24.3339, 26.9274, 40.7171, 54.5351],
        [ 9.5380,  8.2766, 10.6718, 11.2103, 12.3583, 14.1015, 11.1536],
        [16.0836, 21.0021, 31.0889, 33.9690, 17.0042, 14.9527,  9.0610],
        [35.8418, 1

batch_predictions
tensor([[37.9095, 38.1597, 38.1418, 38.1158, 38.0984, 38.0896, 38.0865],
        [10.8076, 10.9851, 11.2119, 11.4706, 11.7651, 12.1299, 12.6182],
        [18.9869, 17.4420, 18.3585, 19.2679, 19.9151, 20.3556, 20.6558],
        [25.2300, 21.9696, 23.8693, 26.0016, 27.9947, 29.7112, 31.0612],
        [18.5595, 22.7719, 21.6486, 19.5416, 16.9456, 15.4033, 15.4402],
        [28.4085, 26.8479, 26.3675, 26.4668, 26.8716, 27.4244, 28.0361],
        [20.3269, 26.2369, 25.1945, 23.1002, 20.1230, 17.8137, 17.8802],
        [16.4033, 20.5788, 23.8284, 24.8128, 24.2922, 23.2206, 21.8396]])
ground truth
tensor([[45.1956, 55.9382, 52.4802, 25.5385, 19.2460, 23.9938, 36.5363],
        [18.6744, 13.3088,  8.1799,  8.1405,  7.0358, 10.5865, 10.4945],
        [25.5385,  5.8532, 13.0244, 10.4734,  9.9065, 13.4070, 19.8696],
        [24.7166, 12.9677, 19.9830, 12.5000, 17.3044, 30.4847, 56.2642],
        [15.6365, 11.1126, 14.6107, 16.7280, 27.7617,  7.7591, 11.2178],
        [46.8679,  

batch_predictions
tensor([[18.8097, 20.0965, 19.4642, 18.3581, 17.1265, 16.0688, 15.4969],
        [20.7520, 17.9558, 19.4958, 21.1386, 22.6844, 23.8514, 24.5878],
        [19.0910, 19.3672, 19.3009, 19.1442, 19.0022, 18.9122, 18.8758],
        [16.0281, 14.8338, 14.9955, 15.5771, 16.1507, 16.6483, 17.0712],
        [22.6446, 31.8307, 35.0872, 34.0465, 30.9613, 23.7972, 19.0869],
        [16.0447, 17.9765, 17.9745, 17.4400, 16.9114, 16.5159, 16.2695],
        [ 9.8409, 12.0914, 14.6492, 17.4585, 18.5809, 18.4955, 18.0707],
        [21.3309, 28.1064, 30.4409, 30.1182, 29.1948, 27.8308, 25.5795]])
ground truth
tensor([[20.3840, 23.9085,  5.0237, 10.7180,  8.3377,  9.5213, 12.9932],
        [38.9314, 30.4989, 13.7755, 10.6859, 11.9615, 14.0731, 23.7670],
        [26.6307, 24.3293, 15.0973, 16.0310, 16.0573, 15.5181, 17.6881],
        [ 7.7985, 10.0079,  7.9563, 16.5045, 12.0857, 17.5434, 21.5281],
        [12.9393, 15.3912, 10.3033, 22.9025, 43.4524, 40.4904, 24.7591],
        [11.6355, 1

batch_predictions
tensor([[15.8272, 21.2916, 22.3082, 21.1566, 19.2632, 16.5672, 15.0759],
        [22.1675, 24.8570, 24.7742, 23.8729, 22.6531, 21.3921, 20.3793],
        [18.3644, 16.9818, 16.8972, 17.4045, 17.9874, 18.5339, 19.0314],
        [20.6146, 28.0064, 28.6497, 27.2038, 24.2728, 19.6347, 18.8366],
        [11.6975, 13.5224, 15.9525, 17.3399, 16.6336, 15.9216, 15.3802],
        [19.9402, 22.8534, 22.3429, 21.1883, 19.9446, 18.9470, 18.4236],
        [28.2689, 31.1130, 30.9129, 30.1573, 29.2622, 28.3319, 27.3884],
        [ 9.2644, 10.5390, 11.4249, 12.3918, 13.8062, 15.8841, 17.2787]])
ground truth
tensor([[14.7817, 21.5939, 42.0305, 30.3919, 13.3482, 12.4540, 14.3740],
        [ 8.1633, 32.4263, 31.9019, 15.6604, 15.5896, 17.4887, 21.3010],
        [21.1731, 11.4019, 11.1520,  8.6796, 11.0600, 15.2946, 27.7617],
        [10.0198, 21.5561, 45.7200, 38.7755, 22.5482, 19.2177, 24.3764],
        [ 8.8112, 11.0600,  9.2583, 11.4282, 23.3561, 25.2762,  8.7585],
        [23.0274, 2

batch_predictions
tensor([[24.4941, 23.2287, 23.1554, 23.5103, 23.9275, 24.2976, 24.5987],
        [13.7522, 16.9473, 19.9571, 19.0641, 18.0280, 17.2164, 16.7313],
        [21.7043, 25.6192, 24.9344, 22.9969, 19.7034, 16.9580, 16.9390],
        [22.3774, 21.1602, 20.8680, 21.0624, 21.4692, 21.9489, 22.4457],
        [11.0056, 12.0704, 13.3847, 15.3255, 16.3661, 15.8461, 15.1656],
        [17.7915, 17.7090, 17.7378, 17.7686, 17.7895, 17.8103, 17.8366],
        [25.4572, 25.4726, 25.4735, 25.4422, 25.4514, 25.4530, 25.4264],
        [10.6804, 11.4437, 12.3310, 13.5665, 15.2004, 15.7451, 15.3323]])
ground truth
tensor([[12.7693, 14.0164, 19.0334, 18.0272, 23.9938, 11.1111,  8.6026],
        [14.0731, 10.9836, 18.5658, 22.6332, 28.4864, 27.9053, 11.2245],
        [15.0935, 16.2840, 19.8129, 31.0232, 32.9082,  5.4563, 16.7234],
        [29.2800,  4.5493,  4.5493, 14.0164, 19.6854, 26.5306, 32.8231],
        [15.7943, 19.5423,  6.2467,  8.0352,  9.3109, 10.2183, 11.9674],
        [27.4724, 2

batch_predictions
tensor([[ 8.2118,  8.7915,  9.0261,  9.1242,  9.1558,  9.1609,  9.1607],
        [17.5691, 22.0206, 26.0688, 28.5232, 28.8684, 28.1908, 26.8467],
        [13.4975, 16.3260, 17.9506, 17.4460, 16.9090, 16.5328, 16.3241],
        [22.7056, 32.1455, 32.9933, 31.1757, 27.2447, 20.0311, 18.8356],
        [24.5508, 26.2240, 26.0395, 25.4562, 24.8194, 24.2714, 23.8806],
        [17.6283, 20.0812, 19.7260, 19.0097, 18.3459, 17.8446, 17.5367],
        [18.3334, 22.8809, 24.3999, 23.2147, 20.9427, 17.9272, 16.6595],
        [16.0311, 19.7342, 19.6326, 18.7185, 17.6935, 16.7729, 16.0827]])
ground truth
tensor([[ 7.9790,  6.8736,  6.4059, 10.3600, 11.5646, 16.4399,  5.6406],
        [35.4450, 45.1956, 33.7585, 18.9909, 16.8651, 16.0006, 15.0794],
        [ 2.3951, 16.3549, 13.2370, 14.0164, 13.0244, 18.3673,  6.5051],
        [20.1672, 33.2766, 49.9575, 43.3107, 19.7846, 20.5782, 14.5975],
        [22.4206, 29.1241, 46.7120, 33.0924, 32.2279, 30.8107, 30.2438],
        [15.9439, 2

batch_predictions
tensor([[16.9961, 19.9183, 19.6613, 18.7651, 17.8086, 16.9831, 16.3889],
        [22.6634, 20.8156, 20.5002, 20.9754, 21.6399, 22.2965, 22.8990],
        [15.9801, 19.7704, 20.2190, 19.5297, 18.7658, 18.1224, 17.6779],
        [18.5295, 23.4402, 22.2722, 19.7119, 16.4847, 15.2828, 15.4285],
        [29.5868, 25.4549, 24.6533, 25.6364, 27.0582, 28.5077, 29.8700],
        [19.5436, 23.6261, 22.5259, 20.5665, 18.0024, 16.1306, 16.1119],
        [21.2973, 21.3481, 21.3372, 21.2919, 21.2449, 21.2187, 21.2191],
        [14.7514, 13.8205, 14.5409, 15.2466, 15.7314, 16.0290, 16.2092]])
ground truth
tensor([[23.0300, 19.5011, 11.3237, 12.9252, 12.9110, 11.3095, 12.6417],
        [28.3535, 17.9116,  4.6554,  9.5345, 14.7817, 15.7812, 18.1746],
        [12.1315, 11.1678, 12.6701, 20.5641, 26.7715,  6.5334,  8.2625],
        [11.7570, 19.0295, 27.5118, 32.4829, 11.0074, 13.0589, 16.5308],
        [24.7591, 16.7375, 13.4637, 13.9739, 23.7954, 48.1434, 11.4938],
        [22.2647, 3

batch_predictions
tensor([[16.1374, 16.9866, 16.8570, 16.4502, 16.0599, 15.7829, 15.6330],
        [28.5259, 30.4616, 30.3982, 29.8685, 29.2254, 28.6135, 28.1041],
        [22.8533, 26.5726, 25.9702, 24.3420, 21.6355, 18.3840, 17.5571],
        [20.4708, 21.9857, 21.6555, 20.8948, 20.0548, 19.3076, 18.7803],
        [16.7568, 16.4498, 16.5335, 16.7170, 16.9011, 17.0579, 17.1819],
        [28.2481, 29.2191, 29.2338, 28.9648, 28.6156, 28.2999, 28.0719],
        [12.8873, 13.9111, 14.6145, 14.7686, 14.5592, 14.2365, 13.9181],
        [21.5404, 22.3696, 22.1927, 21.6314, 20.9475, 20.3120, 19.8410]])
ground truth
tensor([[21.4361, 21.8306,  8.4429,  9.4424,  8.1536,  8.9821, 11.5992],
        [34.6088, 49.4898, 34.3537,  0.0000, 17.3328, 20.5924, 25.5527],
        [18.2115, 25.4393, 39.4983,  6.4201,  7.5680, 16.4683, 18.7358],
        [23.3844, 29.8895,  4.9603,  0.0000, 10.9694,  9.1553, 14.2007],
        [12.1120, 21.0547, 20.8969,  5.8916,  9.5213,  8.7980, 11.7570],
        [33.2983, 4

batch_predictions
tensor([[19.0113, 22.1314, 21.4396, 20.0362, 18.3359, 16.7723, 16.1468],
        [21.0086, 22.7141, 22.3691, 21.6739, 20.9675, 20.3784, 19.9792],
        [34.9285, 31.8544, 31.3046, 31.8075, 32.5976, 33.3946, 34.0936],
        [ 8.7267,  8.0934,  7.6855,  7.4860,  7.4061,  7.3819,  7.3776],
        [27.6765, 27.4457, 27.4184, 27.4559, 27.4961, 27.5490, 27.6249],
        [23.2360, 20.3835, 20.0499, 20.8591, 21.8374, 22.8076, 23.7552],
        [13.9084, 13.4475, 13.4540, 13.6522, 13.9075, 14.1589, 14.3729],
        [20.2095, 18.3163, 18.7952, 19.5698, 20.2200, 20.7223, 21.1006]])
ground truth
tensor([[18.7270, 26.9332, 24.7764,  8.1405,  9.2057, 13.8480, 12.6512],
        [26.0915, 23.6323, 12.9932, 20.8969, 14.2688, 13.7428, 18.7270],
        [ 7.4405, 13.5771, 21.7545, 21.0601, 32.4972, 40.6321, 40.6746],
        [13.4637, 19.8696,  5.8107,  7.6531,  7.0011,  8.6026, 14.8810],
        [14.6684, 20.6491, 19.7279, 19.8980, 34.5522, 41.2840, 27.0550],
        [20.9042, 2

batch_predictions
tensor([[19.5642, 26.0870, 24.8468, 22.1481, 18.2113, 16.5349, 16.9630],
        [15.8954, 14.5758, 15.0978, 15.8318, 16.4335, 16.8851, 17.2189],
        [26.4041, 23.5712, 23.5074, 24.3780, 25.3123, 26.1293, 26.7933],
        [12.1064, 13.2510, 14.4524, 15.0694, 14.9226, 14.5058, 14.0590],
        [20.1598, 26.3221, 25.1506, 22.6094, 18.8932, 17.0128, 17.3257],
        [21.6240, 21.2867, 21.2555, 21.3500, 21.4789, 21.6250, 21.7808],
        [29.1904, 24.5446, 25.9490, 27.6777, 29.1804, 30.3772, 31.2931],
        [22.7498, 19.9476, 20.6248, 21.8007, 22.8851, 23.7766, 24.4691]])
ground truth
tensor([[23.2710,  5.2721,  8.8152, 13.6338, 15.8447, 20.6774, 31.8594],
        [20.6996,  7.7985, 10.0079,  7.9563, 16.5045, 12.0857, 17.5434],
        [29.9461, 20.0822, 16.8793, 14.0731, 17.0351, 27.5794, 52.4943],
        [19.6344, 12.0857,  0.0000,  0.0000,  5.8785,  8.6007,  8.2194],
        [16.4966, 14.8810, 17.6871, 28.5856, 29.1241, 10.0057, 13.2795],
        [35.7851, 2

batch_predictions
tensor([[29.3900, 28.2100, 28.1004, 28.3510, 28.6744, 28.9779, 29.2355],
        [18.6703, 20.6204, 20.2786, 19.5502, 18.8286, 18.2454, 17.8639],
        [18.7068, 22.2391, 21.6386, 20.5780, 19.4852, 18.5630, 17.9514],
        [21.7174, 25.4498, 25.1080, 23.4291, 20.4875, 17.5985, 17.1298],
        [14.1918, 18.1281, 20.9956, 19.7731, 18.0607, 16.2114, 14.8155],
        [23.6373, 30.6041, 30.0138, 28.4116, 25.9201, 21.9010, 19.7927],
        [24.4722, 24.0571, 24.0112, 24.0984, 24.2181, 24.3516, 24.4920],
        [30.8915, 25.0076, 25.5459, 27.1619, 28.8443, 30.4450, 31.9897]])
ground truth
tensor([[36.7772, 15.2636, 17.6729, 27.6644, 30.0737, 36.8481, 50.2126],
        [16.5249, 13.1236, 13.3787, 17.0918, 18.1831, 21.0743, 12.9677],
        [15.3486, 15.6746, 16.9785, 24.2630, 15.3912,  0.0000,  9.2829],
        [36.0600, 37.0200, 14.1504,  8.9032, 16.5176, 18.2667, 23.1720],
        [12.1740, 12.7693, 13.0244, 16.4683, 28.0471, 29.3084, 28.6990],
        [17.9989, 1

batch_predictions
tensor([[ 8.7381,  8.8640,  8.9323,  8.9563,  8.9499,  8.9322,  8.9189],
        [28.8468, 29.0708, 29.0916, 29.0281, 28.9261, 28.8382, 28.7905],
        [12.8371, 11.8087, 11.4318, 11.2878, 11.3299, 11.5657, 12.0847],
        [25.7375, 33.6287, 33.9673, 31.9796, 27.7838, 20.6832, 19.8292],
        [16.3569, 20.0060, 20.0022, 19.2578, 18.4391, 17.7441, 17.2557],
        [ 7.9800,  8.4290,  8.5585,  8.5909,  8.5904,  8.5843,  8.5824],
        [16.7539, 21.0346, 23.5606, 23.3584, 22.6353, 21.7338, 20.7797],
        [26.5333, 27.2972, 27.3406, 27.1170, 26.7911, 26.4774, 26.2404]])
ground truth
tensor([[1.2954e+01, 1.7938e+01, 1.8924e+01, 4.4845e+00, 7.9958e+00, 5.7733e+00,
         5.4840e+00],
        [3.9768e+01, 4.3013e+01, 1.8211e+01, 1.6525e+01, 1.6738e+01, 1.7375e+01,
         3.2185e+01],
        [9.4424e+00, 8.5087e+00, 9.2451e+00, 8.6665e+00, 1.2822e+01, 1.4769e+01,
         1.7438e+01],
        [2.1046e+01, 4.1610e+01, 4.3679e+01, 3.8407e+01, 1.3676e+01, 9.6939

batch_predictions
tensor([[16.8777, 16.8358, 16.8508, 16.8580, 16.8584, 16.8634, 16.8766],
        [18.0217, 17.4800, 18.7484, 19.8392, 20.6885, 21.2998, 21.6965],
        [10.6786, 11.7143, 12.9216, 14.6251, 16.1236, 15.8463, 15.3226],
        [38.3853, 38.3687, 38.3733, 38.3753, 38.3762, 38.3769, 38.3777],
        [14.8909, 19.3501, 20.7070, 19.9694, 19.1388, 18.4287, 17.9158],
        [20.2763, 27.6159, 29.8707, 28.4317, 24.8956, 18.3700, 17.4984],
        [13.4811, 13.1362, 14.2146, 15.3812, 16.1111, 16.3860, 16.4852],
        [14.9471, 19.1960, 21.0191, 19.9216, 18.4197, 16.7188, 15.3706]])
ground truth
tensor([[2.0160e+01, 2.5460e+01, 1.0757e+01, 1.1599e+01, 9.2057e+00, 1.5742e+01,
         1.5229e+01],
        [1.8126e+01, 2.1556e+01, 1.2571e+01, 1.3223e+01, 1.3194e+01, 2.0139e+01,
         3.0371e+01],
        [9.3372e+00, 1.6334e+01, 2.0187e+01, 1.9253e+01, 8.6796e+00, 1.1494e+01,
         8.4429e+00],
        [4.6870e+01, 5.6418e+01, 6.2230e+01, 6.6531e+01, 2.0134e+01, 5.2604

batch_predictions
tensor([[19.0243, 17.4581, 17.6297, 18.2166, 18.7383, 19.1549, 19.4802],
        [23.9817, 21.2026, 22.6025, 24.0614, 25.2922, 26.2244, 26.8527],
        [25.8738, 31.9554, 31.9496, 30.4144, 27.6514, 22.8687, 20.0991],
        [16.0614, 15.3316, 15.3511, 15.6765, 16.0676, 16.4373, 16.7544],
        [15.3954, 19.5537, 23.1630, 22.8104, 21.6961, 20.2051, 18.4744],
        [17.4504, 17.3394, 17.3724, 17.4159, 17.4494, 17.4796, 17.5116],
        [17.6577, 16.0919, 16.5795, 17.3018, 17.8858, 18.3242, 18.6480],
        [18.9616, 25.8629, 25.1893, 23.2386, 19.7039, 17.1489, 17.5282]])
ground truth
tensor([[31.6893, 18.7358, 16.8651,  9.9490, 10.8418, 19.7562, 29.7902],
        [10.4734, 20.7766, 15.1077, 18.2965, 23.2143, 26.7999, 37.7409],
        [23.9216, 45.0158, 39.7291, 20.7785, 13.8348,  9.7449, 22.6328],
        [26.6156, 28.5289, 15.9439, 12.0323, 14.9660, 16.1990, 26.8991],
        [13.0102, 18.4240, 17.8146, 20.8333, 29.2800, 31.6752,  8.4325],
        [ 0.0000,  

batch_predictions
tensor([[15.3835, 14.9581, 15.0958, 15.3941, 15.6953, 15.9454, 16.1295],
        [ 9.9483, 10.6387, 11.2726, 11.9954, 13.0376, 15.0069, 17.9680],
        [16.9881, 16.2331, 17.3553, 18.2779, 18.9688, 19.4377, 19.7256],
        [ 7.5722,  8.6342,  9.1512,  9.5007,  9.7876, 10.0645, 10.3617],
        [24.0915, 33.0118, 34.5451, 33.1021, 30.2032, 24.8283, 20.9955],
        [31.3189, 29.8738, 29.7901, 30.1397, 30.5519, 30.9166, 31.2126],
        [14.5317, 14.1161, 14.1784, 14.3894, 14.6272, 14.8437, 15.0182],
        [20.9468, 25.2390, 24.3420, 22.6686, 20.4328, 18.1638, 17.6369]])
ground truth
tensor([[ 3.2614,  7.5355,  7.7854, 11.5466, 17.0437, 17.6618, 15.5839],
        [ 3.8664, 11.5071, 10.0473, 10.1262, 13.4140,  7.0095,  6.8911],
        [ 2.7778,  0.0000, 11.7772, 10.7143, 12.3299, 13.8605, 33.4467],
        [12.4934,  7.8511,  8.6402, 12.6512, 18.1483, 15.5050,  5.1420],
        [23.6395, 23.7245, 26.7715, 39.0023, 51.9841, 39.9660, 24.7449],
        [32.8774, 2

batch_predictions
tensor([[ 5.6672,  5.6020,  5.5965,  5.6023,  5.6058,  5.6069,  5.6084],
        [21.7323, 29.9862, 32.3392, 31.2933, 28.7759, 23.0644, 18.8753],
        [20.7116, 18.5683, 19.6962, 20.7582, 21.6029, 22.2120, 22.6176],
        [15.7320, 19.7838, 20.6311, 19.6456, 18.4039, 17.1437, 16.1514],
        [26.5828, 29.9047, 29.8791, 28.9603, 27.4918, 25.4087, 22.9352],
        [17.9326, 17.0631, 18.4725, 19.7268, 20.7069, 21.3426, 21.6724],
        [25.4857, 30.0269, 29.6977, 28.4535, 26.5702, 23.5471, 20.5837],
        [25.9225, 30.6021, 30.4274, 29.0374, 26.5310, 22.1752, 19.4341]])
ground truth
tensor([[ 8.2851,  6.0363,  0.0000,  4.2215, 15.7023, 11.3361,  7.2330],
        [21.5986, 20.8333, 43.3532, 41.5533, 32.8656, 15.5329, 17.9847],
        [13.1803, 10.2183, 14.1298, 15.5896, 20.0113, 31.8027, 27.9478],
        [ 0.0000, 10.2578, 17.0437,  6.6938, 28.2614, 26.0521, 13.5060],
        [19.5423, 25.8548, 25.1973, 33.2983, 45.6996, 19.9895, 17.2278],
        [ 7.4688, 1

batch_predictions
tensor([[28.3604, 28.5642, 28.5597, 28.4830, 28.3989, 28.3438, 28.3276],
        [19.4999, 19.5757, 19.5457, 19.4859, 19.4345, 19.4083, 19.4074],
        [34.6759, 38.1651, 38.0537, 37.8588, 37.6996, 37.5860, 37.5143],
        [38.2336, 38.0078, 38.0457, 38.0745, 38.0917, 38.1055, 38.1186],
        [29.1616, 28.0545, 27.7862, 27.9035, 28.1864, 28.5455, 28.9305],
        [25.8013, 26.7935, 26.8110, 26.4895, 26.0407, 25.5949, 25.2281],
        [14.3054, 18.0624, 20.0620, 18.5295, 16.4094, 14.0980, 13.0567],
        [18.2213, 16.7517, 17.7544, 18.5523, 19.1050, 19.4716, 19.7146]])
ground truth
tensor([[20.6602, 16.8464, 18.4903, 31.6018, 33.0615, 14.8606, 14.1373],
        [16.7543, 17.9116, 21.5150, 19.2399, 26.8937, 11.4019, 15.3998],
        [27.1684, 25.6519, 39.9802, 59.4813, 49.4331, 27.5794, 19.7562],
        [55.9382, 52.4802, 25.5385, 19.2460, 23.9938, 36.5363, 58.8861],
        [12.6984, 14.5975, 16.7234, 18.2540, 30.0028, 48.9938, 35.8135],
        [45.5782, 3

batch_predictions
tensor([[15.5954, 14.4926, 15.6160, 16.5486, 17.1245, 17.4529, 17.6482],
        [18.1368, 18.1453, 18.1423, 18.1239, 18.1196, 18.1176, 18.1108],
        [16.9889, 19.3370, 19.0930, 18.5333, 18.0570, 17.7406, 17.5814],
        [37.0302, 37.4760, 37.4635, 37.4008, 37.3508, 37.3228, 37.3126],
        [20.9037, 18.0415, 19.7485, 21.5324, 23.3113, 24.7654, 25.5582],
        [21.6781, 21.7822, 21.7721, 21.7073, 21.6360, 21.5890, 21.5758],
        [17.3993, 17.8439, 17.7436, 17.5428, 17.3793, 17.2834, 17.2456],
        [23.3469, 24.9176, 24.7952, 24.1595, 23.3551, 22.5679, 21.9282]])
ground truth
tensor([[ 5.9442,  8.7980,  0.9469, 18.5692, 12.1120, 21.0547, 20.8969],
        [22.0410, 15.2551,  9.9684, 13.3745, 16.0179, 16.5439, 14.5581],
        [ 4.8133, 10.1920, 10.7575,  8.8901, 18.7796, 23.3824,  2.5381],
        [23.5686, 27.1684, 25.6519, 39.9802, 59.4813, 49.4331, 27.5794],
        [15.3628, 11.8764, 11.2245, 12.1882, 19.9121, 43.1122, 37.0890],
        [35.1332, 3

batch_predictions
tensor([[17.8489, 20.9874, 20.6217, 19.7973, 18.9816, 18.3308, 17.9121],
        [21.7169, 21.2982, 21.2911, 21.4100, 21.5468, 21.6767, 21.7947],
        [ 8.6516,  9.6511, 10.2893, 10.8430, 11.4355, 12.2111, 13.4099],
        [13.7717, 13.0722, 13.8778, 14.8499, 15.5504, 15.8963, 16.0507],
        [11.7778, 12.7849, 13.9958, 15.0604, 15.1836, 14.7848, 14.2732],
        [20.5162, 18.6504, 18.6160, 19.2334, 19.8676, 20.4252, 20.9016],
        [29.4114, 23.9185, 25.5584, 27.5498, 29.3228, 30.9199, 32.4355],
        [18.8834, 20.1536, 19.8025, 19.0389, 18.2219, 17.5352, 17.0950]])
ground truth
tensor([[14.5125, 17.0210, 25.9495, 13.8322,  4.5210,  9.2120, 21.5136],
        [30.5414, 26.9416, 15.0652, 14.6967, 14.2007, 13.7472, 20.5641],
        [ 5.8815,  7.4972,  8.7727, 12.6134, 14.4558, 21.9529,  3.2880],
        [ 7.9300, 10.1789,  7.1278, 10.9416, 15.2420, 24.6844, 15.9916],
        [18.4903, 21.1599,  3.9584,  8.5613,  9.2320, 10.9942, 12.2304],
        [11.2670, 1

batch_predictions
tensor([[ 5.8719,  6.3689,  6.4575,  6.4647,  6.4619,  6.4532,  6.4520],
        [16.7290, 20.8866, 21.7166, 20.6902, 19.2235, 17.6054, 16.4716],
        [22.9847, 28.3660, 27.7276, 26.5937, 25.3130, 23.8695, 22.1921],
        [11.3595, 14.1303, 18.4267, 20.6671, 19.2693, 17.2861, 15.0321],
        [20.2409, 20.6294, 20.4021, 20.0592, 19.7770, 19.6033, 19.5375],
        [22.0057, 28.5774, 27.8265, 26.1812, 23.6493, 20.3142, 18.9519],
        [37.8440, 37.4258, 37.4748, 37.5358, 37.5761, 37.6063, 37.6332],
        [18.9985, 23.6092, 22.4862, 20.2007, 17.2546, 15.7530, 15.8865]])
ground truth
tensor([[ 1.8543,  9.0084, 11.5071, 17.7275, 18.9506, 27.6433,  6.3388],
        [13.9881, 16.2982, 24.0930, 34.8498, 19.2602, 13.1661, 16.9643],
        [22.8695, 16.4782, 18.4903, 25.6181, 45.0684, 32.9432, 21.1994],
        [ 9.6372,  7.1003, 10.3741, 19.6712, 32.2279, 17.3895,  6.9728],
        [17.5028, 16.9926, 25.2551, 26.4314,  2.2959, 13.6054, 13.4779],
        [17.0351, 2

batch_predictions
tensor([[19.9788, 20.0513, 20.0337, 19.9872, 19.9455, 19.9242, 19.9238],
        [28.0120, 23.7328, 24.8322, 26.3741, 27.7841, 28.9252, 29.7636],
        [14.9360, 14.4925, 15.7658, 16.8037, 17.4194, 17.7215, 17.8828],
        [20.6895, 23.5979, 22.9825, 22.0392, 21.0838, 20.1786, 19.4403],
        [21.1244, 20.9643, 20.9808, 21.0255, 21.0700, 21.1164, 21.1679],
        [20.5183, 20.4420, 20.4518, 20.4660, 20.4714, 20.4821, 20.5073],
        [18.6885, 21.6523, 20.9798, 19.8164, 18.5601, 17.3975, 16.6580],
        [25.9764, 35.1780, 35.4316, 34.2768, 32.4243, 29.7835, 25.1368]])
ground truth
tensor([[24.0400, 19.2004, 17.1883,  8.9295, 13.5455, 12.2962, 11.9411],
        [18.3673, 19.2319, 12.2166, 31.2075, 31.5618, 41.6241, 23.9654],
        [ 3.3022,  7.3980,  5.6689,  8.4184, 15.4478, 16.4399, 23.7528],
        [19.4240, 13.3482, 11.2178, 17.6618, 26.4072,  9.4424, 46.7254],
        [27.4592, 11.3361, 17.7275, 13.1773, 11.3230, 15.9127, 26.8017],
        [39.1723, 2

batch_predictions
tensor([[ 8.1393,  8.6809,  8.9375,  9.0679,  9.1291,  9.1580,  9.1763],
        [17.3822, 16.1417, 17.7739, 19.1590, 20.0490, 20.4585, 20.6559],
        [16.7624, 16.0108, 15.9915, 16.2601, 16.5787, 16.8728, 17.1235],
        [14.4095, 17.2281, 17.9425, 16.7981, 15.3529, 13.7856, 12.4093],
        [37.5675, 38.1942, 38.1553, 38.0978, 38.0580, 38.0344, 38.0221],
        [ 9.1702, 10.5368, 11.8146, 13.6323, 16.9993, 20.0828, 18.9946],
        [37.4011, 38.4058, 38.3991, 38.3946, 38.3892, 38.3843, 38.3791],
        [ 7.7952,  8.3186,  8.4935,  8.5513,  8.5634,  8.5640,  8.5666]])
ground truth
tensor([[ 9.7506,  6.8594,  6.5901, 16.1990, 17.6587, 20.8192,  8.9286],
        [27.6361, 12.2449, 15.8730, 17.4036, 16.2698, 24.3764, 27.9620],
        [10.5471, 12.7433, 10.1262, 14.4792, 19.5423, 18.1878, 16.9385],
        [16.0836, 22.6460,  9.9947,  7.5618,  7.4566,  8.8112, 10.8364],
        [28.8549, 22.5057, 29.1241, 37.3866, 56.7460, 50.1701, 32.6531],
        [ 8.5601,  

batch_predictions
tensor([[20.9619, 20.9211, 20.9289, 20.9363, 20.9292, 20.9266, 20.9397],
        [ 4.8148,  4.7945,  4.7986,  4.8038,  4.8060,  4.8066,  4.8078],
        [15.0577, 15.0832, 15.0798, 15.0585, 15.0373, 15.0289, 15.0344],
        [19.3289, 18.7416, 18.6849, 18.8285, 19.0191, 19.2078, 19.3804],
        [17.9201, 16.1820, 16.7953, 17.6150, 18.2778, 18.7711, 19.1325],
        [23.0479, 21.2777, 21.1156, 21.6197, 22.2201, 22.7680, 23.2346],
        [14.6713, 17.1961, 17.4268, 16.7723, 16.1122, 15.5914, 15.2296],
        [18.2017, 21.4339, 20.9285, 20.0304, 19.1577, 18.4640, 18.0231]])
ground truth
tensor([[16.4116, 21.6978, 28.7840, 34.7506, 17.5454, 17.7863, 10.7285],
        [ 4.6769,  5.0595, 10.6151, 11.5363,  7.5113,  4.4501,  4.1100],
        [21.3435, 25.5669, 11.6638, 11.4938,  9.0561, 11.4512, 13.9598],
        [32.4546, 30.3571,  0.5385, 13.6480, 13.7897, 15.6604, 23.0300],
        [ 3.8138,  0.0000,  1.0652, 12.7564, 15.1894, 28.5376, 28.7743],
        [13.1519, 1

batch_predictions
tensor([[21.2413, 20.1559, 20.0146, 20.2711, 20.6210, 20.9572, 21.2525],
        [16.7116, 18.7913, 18.4948, 17.7291, 16.9793, 16.3794, 15.9754],
        [28.0379, 25.2866, 24.0841, 24.1119, 24.8596, 25.9159, 27.1540],
        [19.1214, 18.0562, 17.8963, 18.1403, 18.4850, 18.8242, 19.1323],
        [19.9484, 19.6480, 19.6565, 19.7344, 19.8113, 19.8799, 19.9432],
        [18.2487, 22.2580, 21.4172, 19.8690, 18.0744, 16.5554, 16.0014],
        [17.6668, 22.5318, 22.6225, 21.5847, 19.9940, 18.1547, 16.9436],
        [30.2524, 34.7949, 34.8411, 33.9962, 32.8053, 31.4516, 29.8838]])
ground truth
tensor([[13.6763, 12.3299, 15.4478, 15.3486, 20.8617, 31.2075,  9.3112],
        [10.8101, 13.5718, 24.6318,  5.4708,  5.5234,  7.1541,  3.1168],
        [50.5102, 41.7375, 16.2415, 20.6349, 19.3169, 19.4019, 33.9569],
        [13.7559, 12.2173,  9.7712, 14.2031, 16.9253, 19.5029, 24.6186],
        [16.9359,  6.9019, 10.5726, 16.6383, 17.9280, 24.3622, 21.2443],
        [14.1636,  

batch_predictions
tensor([[11.1080,  9.2732,  7.7895,  7.2636,  7.1588,  7.1395,  7.1327],
        [15.1150, 18.8999, 19.4025, 18.4441, 17.3524, 16.3329, 15.5064],
        [17.6398, 20.7753, 19.9780, 18.4473, 16.6720, 15.1541, 14.6668],
        [17.0498, 18.1145, 17.7878, 17.1837, 16.6365, 16.2521, 16.0513],
        [17.6550, 16.2958, 17.6178, 18.7084, 19.4773, 19.9351, 20.1924],
        [19.0482, 20.2819, 19.8223, 18.8948, 17.8634, 16.9770, 16.4402],
        [24.6154, 24.0412, 23.9675, 24.0897, 24.2729, 24.4776, 24.6864],
        [16.6284, 16.3959, 16.4274, 16.5075, 16.5830, 16.6487, 16.7068]])
ground truth
tensor([[ 9.0561,  0.2834,  6.6893, 10.3033,  7.9507, 10.1474, 13.1519],
        [15.2946, 27.7617, 24.0400,  8.3772, 10.0342, 10.8364, 10.3761],
        [15.9653, 12.8090, 17.1226, 23.3693, 27.3014, 10.3367, 12.3225],
        [23.6323,  5.9705, 11.6649,  9.6002,  9.0873, 12.7696, 17.8985],
        [15.1894,  6.0889, 15.0710, 13.6902, 17.2935, 23.4219, 31.1547],
        [13.0385, 1

batch_predictions
tensor([[18.4307, 17.0243, 17.6552, 18.3056, 18.7799, 19.1140, 19.3500],
        [ 8.4135,  9.6931, 10.4771, 11.1843, 11.9970, 13.0844, 14.4782],
        [15.9590, 14.6808, 15.4818, 16.2203, 16.7563, 17.1128, 17.3438],
        [27.4363, 24.2918, 23.2266, 23.6590, 24.6820, 25.8899, 27.1831],
        [12.0440, 14.2345, 16.3567, 17.6333, 17.5531, 17.2068, 16.8938],
        [18.2764, 24.0697, 23.2947, 21.1275, 17.8612, 16.1994, 16.3985],
        [31.7018, 31.2914, 31.2942, 31.3901, 31.4944, 31.5967, 31.6945],
        [21.9603, 31.0419, 34.3959, 33.4423, 30.1418, 22.1602, 18.5859]])
ground truth
tensor([[25.2834, 14.2007, 13.6196, 17.4461, 17.9138, 17.0777, 20.8333],
        [26.8707, 16.0289,  0.0000,  0.0000,  8.9569, 12.4150, 14.1865],
        [20.9892, 10.6576, 11.7063,  8.7868, 13.9598, 18.4807, 27.9904],
        [47.1514,  2.0408, 12.3724, 20.7200, 21.9529, 15.8588, 33.5176],
        [ 3.2880, 11.2245, 11.7063, 10.6151, 20.0255,  3.2313,  4.5210],
        [10.8702, 1

batch_predictions
tensor([[13.3766, 15.4168, 17.2711, 17.1798, 16.7418, 16.3779, 16.1411],
        [21.5432, 26.5664, 25.9460, 23.6634, 19.3537, 16.5976, 16.8708],
        [37.6537, 36.8890, 36.9218, 37.0466, 37.1501, 37.2285, 37.2917],
        [21.2882, 23.4366, 22.9841, 22.1725, 21.3344, 20.5795, 20.0086],
        [13.6864, 16.5435, 18.9635, 18.3769, 17.7487, 17.3130, 17.0910],
        [23.0212, 31.8507, 32.0713, 30.5604, 27.7534, 22.3127, 18.8279],
        [13.1566, 16.4094, 20.4947, 19.8394, 18.5220, 17.1260, 16.0912],
        [27.5934, 24.6710, 24.3794, 25.2887, 26.4141, 27.4649, 28.3502]])
ground truth
tensor([[12.7959,  8.8638, 11.3756, 14.2031, 21.3835, 13.2562, 11.7438],
        [24.6457, 31.8452, 21.8679,  5.7965, 12.6701, 11.1820, 14.2149],
        [42.7154, 32.5964, 27.9904, 26.7857, 32.4121, 38.2795, 56.4484],
        [17.9248,  1.4598,  0.0000, 10.3235, 10.2578, 13.5192, 22.2120],
        [19.9121, 23.8237, 21.4144, 12.0748, 14.1015, 11.0261, 11.9189],
        [19.9369, 2

batch_predictions
tensor([[15.6897, 20.0340, 22.1199, 21.1159, 19.4880, 17.3895, 15.8519],
        [18.3931, 23.9098, 26.7330, 26.1265, 24.4484, 21.0727, 17.4986],
        [12.8050, 16.0365, 20.6779, 19.4536, 17.2264, 14.8613, 13.4359],
        [19.1170, 25.6375, 25.0981, 22.8512, 18.6235, 16.2099, 16.4250],
        [19.6346, 23.6022, 22.8291, 20.8343, 17.9451, 16.0316, 16.0489],
        [14.9434, 17.0636, 17.2896, 16.6270, 15.8841, 15.2410, 14.7349],
        [15.5618, 15.3738, 16.5416, 17.4921, 18.1825, 18.6247, 18.8724],
        [28.1557, 27.6333, 27.5688, 27.6587, 27.7884, 27.9343, 28.0862]])
ground truth
tensor([[10.5602, 13.5455, 11.7570, 19.0295, 27.5118, 32.4829, 11.0074],
        [ 8.7443, 18.1831, 19.0334, 29.4076, 38.8180, 36.0119, 10.1616],
        [10.7143, 10.8277, 22.9308, 30.3571, 18.2540,  9.0561, 10.2041],
        [18.5955, 15.6365, 22.9748, 35.8759, 27.3672, 12.2962, 12.1778],
        [13.3877, 18.7533, 33.4824, 13.8611, 13.7033, 11.3756, 13.8348],
        [23.4613, 2

batch_predictions
tensor([[18.6706, 16.8795, 17.5358, 18.4727, 19.2451, 19.8335, 20.2806],
        [22.2476, 20.0421, 20.4643, 21.3086, 22.0459, 22.6230, 23.0570],
        [13.1459, 12.2687, 12.5032, 13.1463, 13.9829, 14.7594, 15.2065],
        [15.5244, 15.3503, 16.6788, 17.8245, 18.6687, 19.1454, 19.3516],
        [17.2086, 20.3122, 20.1372, 19.3723, 18.5853, 17.9473, 17.5237],
        [24.0463, 28.5279, 27.9154, 26.9471, 25.9606, 24.9687, 23.8711],
        [12.9843, 16.5306, 20.2977, 19.1323, 17.6699, 16.4280, 15.6482],
        [27.7570, 28.2185, 28.2373, 28.1008, 27.9113, 27.7474, 27.6459]])
ground truth
tensor([[31.4177,  5.1815, 16.7938,  5.7075, 15.7812, 16.2415, 21.3572],
        [ 9.8498, 14.8101, 18.1973, 12.6559, 20.7766, 22.6616, 31.0516],
        [11.3624,  3.5245,  7.0095,  8.5613,  9.0479, 12.2962, 24.0137],
        [12.0726, 13.9532, 14.2951, 14.7159, 14.4924, 12.1515, 20.7128],
        [11.9756, 12.5142, 12.8118, 17.8005, 28.5714, 24.9433, 12.0040],
        [15.2420, 2

batch_predictions
tensor([[28.3766, 24.2183, 24.9206, 26.3646, 27.7561, 28.9131, 29.7895],
        [20.5332, 26.5052, 25.3347, 22.1464, 17.1259, 16.3833, 16.6578],
        [20.9745, 28.0979, 27.0551, 24.6442, 20.3281, 16.9183, 17.3211],
        [38.2573, 38.1956, 38.2020, 38.2029, 38.2019, 38.2005, 38.1994],
        [16.6229, 20.4109, 20.4853, 19.5165, 18.3197, 17.1674, 16.3085],
        [ 8.3768,  9.0162,  9.2797,  9.4016,  9.4545,  9.4787,  9.4945],
        [17.1653, 21.1887, 21.0046, 20.0220, 18.8453, 17.7500, 16.9870],
        [19.9447, 18.3905, 19.1106, 19.9978, 20.6798, 21.1683, 21.5132]])
ground truth
tensor([[43.1615, 17.2278, 17.4250, 16.5176, 20.7391, 24.9868, 52.8538],
        [27.6644, 40.9155, 36.8764, 14.4983, 19.7988, 16.1565, 15.7313],
        [20.1814, 33.9711, 33.8719,  6.3067, 17.9138, 15.1219, 20.2098],
        [42.2935, 41.8201, 37.6381, 38.2167, 41.5834, 57.9037, 48.6981],
        [13.0385, 13.4495, 20.5924, 25.3968, 28.3588, 10.8135,  8.6876],
        [ 8.2720, 1

batch_predictions
tensor([[16.1204, 19.2967, 19.3874, 18.7601, 18.1278, 17.6445, 17.3468],
        [23.0892, 23.1307, 23.1252, 23.0958, 23.0677, 23.0560, 23.0621],
        [22.1500, 24.9896, 24.4377, 23.1306, 21.3044, 19.2480, 18.0971],
        [24.9614, 21.8637, 23.6178, 25.5447, 27.3083, 28.8321, 30.0953],
        [26.3066, 22.5848, 22.9735, 24.3113, 25.6935, 26.9147, 27.8925],
        [20.2270, 25.1031, 23.7594, 21.0533, 17.1960, 16.1034, 16.3517],
        [15.1933, 19.6613, 23.8570, 23.8228, 22.8187, 21.1746, 18.9474],
        [17.9139, 18.8317, 18.7455, 18.1993, 17.5410, 16.9817, 16.6237]])
ground truth
tensor([[12.5329, 14.9264, 13.8217, 20.8837, 22.8958, 10.7838, 14.5055],
        [18.5658, 13.9456, 14.0873, 22.8741, 32.3980, 30.8107, 23.5544],
        [21.5845, 33.1491, 39.7251,  7.2279, 11.9473, 15.0794, 18.3957],
        [35.9694, 13.9598, 18.2256, 12.4717, 16.0714, 21.9388, 44.3452],
        [41.0856,  8.1633, 13.9881, 22.7466, 19.3878, 22.7183, 31.7744],
        [17.8005, 2

batch_predictions
tensor([[21.3962, 20.6505, 20.5083, 20.6617, 20.9429, 21.2804, 21.6367],
        [18.4222, 23.2855, 26.3681, 26.3524, 25.4116, 23.5950, 20.2790],
        [22.4832, 24.2438, 23.9385, 23.1428, 22.1982, 21.2689, 20.5123],
        [24.4137, 26.1578, 25.9369, 25.3273, 24.6714, 24.1052, 23.6958],
        [11.9457, 12.9841, 13.9362, 14.6968, 14.9271, 14.7673, 14.5078],
        [30.4058, 33.9984, 34.0903, 33.3401, 32.2822, 31.1320, 29.9584],
        [14.5689, 14.2215, 15.4421, 16.4606, 17.0839, 17.3815, 17.5282],
        [16.5826, 20.5337, 20.8562, 20.0692, 19.1187, 18.2311, 17.5567]])
ground truth
tensor([[32.6407,  6.4703, 15.3077, 15.5576, 17.5302, 25.6970, 20.2525],
        [16.2020, 20.5550, 29.0242, 45.4629, 50.9600, 16.7675, 20.5944],
        [11.9411, 18.4114, 25.6970, 39.1899, 27.2488, 18.0037, 18.5692],
        [20.0026, 24.8948, 12.3225, 10.8496, 11.0205, 12.8880, 11.1915],
        [10.4550,  0.0000,  0.0000,  0.0000,  3.3798, 12.6249, 20.0684],
        [27.6696, 3

batch_predictions
tensor([[17.5341, 23.0499, 23.4629, 22.3737, 20.4363, 17.8221, 16.5151],
        [18.9962, 21.1969, 20.8357, 20.1758, 19.5733, 19.1216, 18.8490],
        [ 4.9919,  4.5272,  4.4956,  4.5195,  4.5392,  4.5522,  4.5618],
        [28.4997, 30.9835, 30.9410, 30.2160, 29.2093, 28.0474, 26.7808],
        [13.7045, 15.1376, 15.5818, 15.2707, 14.7623, 14.2755, 13.8616],
        [14.6671, 18.2327, 19.0028, 18.1277, 17.1689, 16.3191, 15.6416],
        [20.7110, 29.3179, 32.4340, 30.8187, 25.6901, 17.0859, 17.4247],
        [ 9.2343,  9.8862, 10.2433, 10.4772, 10.6406, 10.7651, 10.8707]])
ground truth
tensor([[16.9926, 10.9694, 15.2778, 22.8883, 32.8940, 28.9683,  9.5096],
        [16.7806, 13.9400, 14.6370, 15.4787, 29.1426, 28.4850, 29.1031],
        [10.1131,  3.8138,  3.3403, 14.5581, 13.7428, 12.0594, 17.3856],
        [14.9943, 16.0714, 27.9620, 34.4529, 31.9161, 20.2664, 18.0556],
        [23.0274, 28.0773, 10.9811, 14.1504,  7.8248, 11.3756, 12.3751],
        [14.4529, 1

batch_predictions
tensor([[16.6752, 16.4857, 16.5576, 16.6634, 16.7512, 16.8197, 16.8744],
        [18.2936, 20.7297, 20.3047, 19.4293, 18.5365, 17.7750, 17.2459],
        [20.8671, 21.0567, 21.0227, 20.9107, 20.7974, 20.7213, 20.6914],
        [18.6251, 21.7035, 20.5887, 18.6805, 16.2909, 14.7918, 14.7512],
        [15.0345, 18.6022, 20.0438, 19.3236, 18.5015, 17.8118, 17.3249],
        [18.1783, 17.7868, 17.7767, 17.8990, 18.0452, 18.1882, 18.3200],
        [19.3463, 22.7867, 21.2031, 18.2905, 14.5810, 13.8896, 14.0581],
        [18.8783, 22.6991, 21.8235, 20.4267, 18.8909, 17.4908, 16.7474]])
ground truth
tensor([[24.2205,  9.9206,  2.3951, 16.3549, 13.2370, 14.0164, 13.0244],
        [16.6808, 22.7466, 21.3294, 14.8384, 16.5249, 13.1236, 13.3787],
        [32.4263, 34.3679, 13.7613, 18.9484, 11.0544, 18.3248, 14.3566],
        [22.8741, 36.7630, 25.0567, 15.3486, 18.2965, 15.1644, 21.2018],
        [18.8350, 26.4031, 23.1009, 12.6984, 16.3265, 14.2290, 11.9473],
        [22.8432, 2

batch_predictions
tensor([[16.6081, 19.6885, 22.0958, 22.8351, 22.1711, 20.7828, 18.9326],
        [21.3375, 18.6246, 20.0767, 21.8005, 23.2160, 24.2049, 24.9201],
        [19.7318, 21.2196, 20.8287, 20.1579, 19.5169, 19.0135, 18.6989],
        [19.1893, 18.7428, 18.7248, 18.8564, 19.0162, 19.1713, 19.3123],
        [19.1855, 17.6750, 17.8207, 18.5242, 19.1958, 19.7557, 20.2097],
        [25.6176, 24.4399, 24.3382, 24.6229, 24.9828, 25.3165, 25.5982],
        [20.4288, 21.4589, 20.9649, 20.1641, 19.3686, 18.7372, 18.3662],
        [24.8676, 24.8028, 24.8018, 24.8077, 24.7991, 24.7988, 24.8184]])
ground truth
tensor([[ 7.8231, 10.1616, 17.9138, 29.5918, 21.2868,  8.1066, 13.9031],
        [ 5.4971, 14.6896, 13.6113, 14.7291, 15.0579, 25.0526, 31.5097],
        [13.0984, 15.5708, 17.6223, 23.8033, 31.6281, 10.4945, 14.3346],
        [28.9979, 12.6249, 13.5192, 11.3361, 11.9937, 19.4240, 25.7496],
        [27.4855, 10.3104, 13.6770, 15.8864, 19.3977, 19.4503, 21.7912],
        [18.1689, 1

batch_predictions
tensor([[20.5775, 21.8685, 21.6277, 20.9906, 20.3008, 19.7202, 19.3342],
        [22.1878, 21.9813, 21.9980, 22.0473, 22.0932, 22.1372, 22.1828],
        [16.3647, 20.1989, 20.6437, 20.0277, 19.3313, 18.7455, 18.3487],
        [20.7763, 25.5061, 23.9207, 20.9702, 16.8350, 16.1700, 16.5230],
        [21.7943, 27.6279, 27.0194, 24.8788, 20.9025, 17.4330, 17.6325],
        [20.5617, 25.4478, 24.6639, 22.5214, 19.1852, 16.9222, 17.0209],
        [17.7627, 17.8654, 17.8308, 17.7322, 17.6305, 17.5620, 17.5362],
        [15.3361, 18.9031, 19.7522, 18.8418, 17.7863, 16.8067, 16.0223]])
ground truth
tensor([[19.4766, 24.8685, 19.9435, 14.2294, 11.1849, 11.6781, 11.2046],
        [24.5134, 19.3582, 22.6986, 13.9006, 13.5981, 15.2946, 18.4377],
        [20.1672, 26.8991, 29.4218, 12.9677, 19.1610, 17.7721, 18.5374],
        [30.4563, 31.5476, 11.6922, 17.5170, 15.1927, 16.6950, 19.5011],
        [19.6003, 25.1134, 40.8163, 33.6026,  0.0000, 10.6293, 19.5437],
        [21.6465, 3

batch_predictions
tensor([[21.0640, 24.1587, 23.5318, 22.4736, 21.3146, 20.1232, 19.1720],
        [16.4618, 17.8641, 17.6598, 17.1162, 16.6079, 16.2367, 16.0205],
        [20.9542, 20.4166, 20.3665, 20.4953, 20.6668, 20.8403, 21.0029],
        [13.1568, 16.8399, 20.6623, 19.0507, 16.7007, 14.2616, 13.0281],
        [21.0916, 24.1417, 23.6242, 22.4251, 20.9049, 19.2956, 18.2789],
        [ 9.1746, 10.2453, 11.1288, 12.1330, 13.5778, 15.8618, 17.5941],
        [10.7051, 11.6642, 12.7085, 14.0827, 15.5195, 15.6551, 15.2362],
        [13.3810, 17.6778, 21.5200, 20.1169, 17.9969, 15.6296, 14.1076]])
ground truth
tensor([[19.9500,  7.7065,  8.7322, 11.5071, 19.8317, 18.7138, 20.8837],
        [11.8622, 11.2387, 11.3520, 15.0935, 21.1026, 27.4376, 18.1831],
        [29.6627, 23.7670, 12.4291, 18.4949, 17.9847, 17.2902, 21.5420],
        [ 9.1268, 19.4240,  9.0216, 17.7012, 24.5529, 24.1057,  8.5613],
        [19.8185, 26.4072, 44.7922, 28.4061, 16.9911, 17.8327, 18.8585],
        [ 9.3766, 1

batch_predictions
tensor([[18.7785, 18.8998, 18.8469, 18.7363, 18.6336, 18.5708, 18.5529],
        [14.3161, 16.3082, 16.7220, 16.3641, 15.9629, 15.6656, 15.4839],
        [12.7016, 12.6504, 12.6384, 12.6483, 12.6423, 12.6316, 12.6291],
        [38.0700, 38.3364, 38.3182, 38.2968, 38.2827, 38.2749, 38.2715],
        [14.0245, 17.4907, 19.5483, 18.6899, 17.7436, 16.9737, 16.4427],
        [21.6248, 23.3292, 22.7429, 21.7743, 20.7162, 19.7180, 18.9930],
        [17.5526, 19.1366, 18.7397, 17.7626, 16.6523, 15.6763, 15.0715],
        [12.8930, 16.4859, 20.9718, 20.0519, 18.5336, 16.9988, 16.0261]])
ground truth
tensor([[11.9615, 16.8226, 25.1276, 26.8991, 12.1457, 12.7693, 19.1468],
        [ 9.9290, 14.0058, 12.6381, 23.6718, 10.2052, 13.6376, 11.1257],
        [12.5592, 17.4250, 21.4361,  8.6665, 11.5992, 11.2441, 11.6781],
        [28.6565, 26.4172, 27.2534, 44.3736, 65.2069, 49.7449, 34.4813],
        [14.6502, 27.1305, 18.0957,  0.0000,  9.3503, 11.7044, 10.6391],
        [24.3030, 2

batch_predictions
tensor([[17.6362, 23.2194, 28.7068, 30.8555, 30.7425, 29.9227, 28.7302],
        [16.4602, 15.7026, 15.6723, 15.9119, 16.1933, 16.4499, 16.6676],
        [15.0271, 17.4748, 18.0673, 17.5587, 17.0004, 16.5698, 16.2911],
        [18.1002, 22.7844, 22.7783, 21.4620, 19.3858, 17.0523, 16.1556],
        [20.6021, 27.5332, 27.2663, 25.6985, 22.6724, 19.0125, 18.2380],
        [12.6444, 14.8157, 16.7517, 16.6719, 16.0206, 15.3555, 14.7812],
        [27.5909, 31.2567, 31.2198, 30.2109, 28.6212, 26.3002, 23.4144],
        [10.9084, 12.9928, 15.3281, 17.7345, 18.5500, 18.4807, 18.1916]])
ground truth
tensor([[13.8889, 14.2715, 14.1865, 13.7613, 20.7058,  6.3350,  4.2800],
        [19.6344,  7.7722, 11.7307, 12.4803, 12.6907, 16.0047, 22.6854],
        [12.8685, 21.0176, 28.9683, 20.6916, 12.4008, 13.7188, 13.3220],
        [12.7409,  9.2262, 22.6616, 21.2443, 32.8798, 24.6882, 16.6950],
        [10.0198, 21.5561, 45.7200, 38.7755, 22.5482, 19.2177, 24.3764],
        [13.6902, 2

batch_predictions
tensor([[18.6613, 24.5567, 26.1993, 25.5068, 23.9939, 21.2051, 18.0059],
        [24.2699, 29.3836, 29.0434, 27.6034, 25.1926, 21.3596, 19.1217],
        [20.7060, 21.5210, 21.3393, 20.9940, 20.6959, 20.5018, 20.4119],
        [18.5768, 18.5421, 18.5630, 18.5741, 18.5774, 18.5846, 18.5997],
        [11.1358, 13.8845, 16.7291, 18.4149, 18.6724, 18.3356, 17.8941],
        [ 9.1749,  9.5116,  9.6998,  9.8043,  9.8522,  9.8689,  9.8736],
        [20.9781, 23.4659, 23.0524, 21.9625, 20.5680, 19.1199, 18.1479],
        [17.7548, 17.7866, 17.7626, 17.7251, 17.6961, 17.6854, 17.6899]])
ground truth
tensor([[18.0556, 20.7058, 24.1071, 41.1706, 26.4881,  8.7727, 13.9881],
        [30.0028, 50.0850, 34.5663, 29.7052, 22.5907, 14.9943, 16.0714],
        [17.7670, 23.0668, 23.8953, 10.7575,  1.5255, 15.8206, 14.7422],
        [ 8.9144, 10.2466, 11.2387,  9.7647,  9.6655, 18.4099, 21.1310],
        [ 1.2330, 12.6842,  9.7647, 14.8243, 22.7891, 22.1372, 11.6780],
        [10.4813,  

batch_predictions
tensor([[19.1401, 25.3488, 25.0844, 23.4595, 20.3758, 17.3237, 17.0531],
        [19.2591, 21.7789, 20.8221, 19.2449, 17.2724, 15.6463, 15.3054],
        [ 4.8487,  4.9672,  4.9666,  4.9636,  4.9618,  4.9555,  4.9494],
        [37.3487, 38.3100, 38.2484, 38.1810, 38.1349, 38.1058, 38.0885],
        [19.3723, 23.0660, 22.1647, 20.5236, 18.4802, 16.6634, 16.2514],
        [32.6525, 27.6758, 27.9691, 29.3434, 30.7618, 32.0544, 33.1817],
        [25.2369, 26.7456, 26.5932, 26.0048, 25.2898, 24.5964, 24.0167],
        [23.7331, 20.9878, 21.3686, 22.3432, 23.2545, 24.0076, 24.5976]])
ground truth
tensor([[24.2347, 28.6848, 41.6525, 22.2222, 10.4875, 16.7375, 18.6224],
        [19.8554, 30.8957,  8.9569, 12.1740, 12.7693, 13.0244, 16.4683],
        [ 4.1383,  5.1587,  5.7115,  6.0516,  7.0437, 11.2103, 13.4637],
        [35.7993, 53.4297, 50.6519, 31.7319, 24.9433, 20.9042, 23.3277],
        [15.0184, 16.5308, 24.3556, 26.4072,  8.7849, 14.1899, 12.3225],
        [41.5958, 2

batch_predictions
tensor([[20.5815, 18.6805, 18.8715, 19.5664, 20.2002, 20.7147, 21.1204],
        [20.5885, 28.7411, 29.3137, 27.8609, 24.5647, 18.6956, 18.0083],
        [38.2652, 38.1792, 38.1896, 38.1923, 38.1920, 38.1910, 38.1902],
        [14.2874, 17.1906, 17.8774, 17.3446, 16.8411, 16.4947, 16.3053],
        [32.7856, 32.5465, 32.5374, 32.5671, 32.5927, 32.6259, 32.6733],
        [19.4028, 24.6998, 24.0793, 22.6386, 20.6641, 18.8715, 18.1671],
        [26.7259, 28.2003, 28.0733, 27.5444, 26.9208, 26.3385, 25.8725],
        [24.7293, 30.3716, 30.4643, 29.1070, 26.5866, 22.3011, 19.6400]])
ground truth
tensor([[29.1525,  3.5856, 10.3033, 13.9881, 13.6338, 19.7562, 31.4201],
        [21.3152, 36.1536, 42.0777, 33.2341, 19.7704, 18.1689, 15.1502],
        [41.3204, 43.2799, 42.8064, 54.3135, 52.0779, 43.3456, 54.1163],
        [16.9648, 18.5692, 26.5781, 24.0531,  8.0615,  8.0747, 15.8338],
        [55.6418, 14.1899,  0.0000, 15.4787, 21.9227, 27.8538, 38.6770],
        [13.6196, 3

tensor([[18.3788, 20.2729, 19.9047, 18.9293, 17.7951, 16.7721, 16.1294],
        [17.9841, 16.4292, 17.7469, 18.8318, 19.5851, 20.0365, 20.3008],
        [20.1407, 18.1150, 18.3208, 19.1198, 19.8655, 20.4846, 20.9863],
        [16.8921, 21.4760, 20.5646, 18.6561, 16.3287, 14.7735, 14.5286],
        [18.0152, 17.8580, 17.8922, 17.9522, 18.0042, 18.0504, 18.0952],
        [21.1709, 28.6859, 28.3872, 26.7843, 23.6553, 19.3857, 18.3280],
        [18.0462, 18.2850, 18.1796, 17.9968, 17.8420, 17.7506, 17.7214],
        [22.4830, 30.1977, 29.7941, 27.4137, 22.0538, 17.3647, 17.7975]])
ground truth
tensor([[21.8701, 30.0368, 10.0736, 11.2835,  9.1005, 13.4534, 18.2272],
        [11.2245, 14.2715, 16.3124, 11.6780, 17.3895, 27.0833, 29.5777],
        [26.9133, 12.3866, 11.7489, 11.1961, 20.0397, 26.5448, 24.8866],
        [12.7693, 13.0244, 16.4683, 28.0471, 29.3084, 28.6990, 13.2086],
        [14.1723, 16.1990, 19.7988, 27.6927, 11.3520,  6.9444, 16.3832],
        [17.5454, 30.0028, 50.0850, 3

batch_predictions
tensor([[20.7141, 23.2897, 22.6998, 21.6362, 20.4225, 19.1979, 18.3190],
        [38.3120, 38.3008, 38.3030, 38.3036, 38.3037, 38.3037, 38.3037],
        [15.9778, 15.9263, 15.9786, 16.0251, 16.0566, 16.0830, 16.1077],
        [15.5139, 17.5281, 17.7367, 17.2627, 16.7713, 16.4028, 16.1734],
        [26.5461, 29.1181, 28.7757, 28.1313, 27.5187, 27.0197, 26.6594],
        [ 6.3801,  5.6794,  5.5723,  5.5817,  5.6002,  5.6127,  5.6216],
        [25.3132, 30.3704, 30.1478, 28.8545, 26.7527, 23.2326, 20.1795],
        [23.9880, 30.9490, 30.6601, 28.5397, 24.0076, 18.5858, 18.6160]])
ground truth
tensor([[18.5034, 27.3803, 27.6039, 16.6360, 13.2167, 13.9400, 12.5855],
        [38.5324, 45.5155, 53.4850, 55.9179, 38.1378, 40.7286, 22.9090],
        [19.9500, 18.7270, 18.4508, 11.4940, 14.8869, 14.0847, 20.0815],
        [ 9.0347, 10.9022, 10.8101, 14.4661, 19.8185, 13.3482, 11.2441],
        [22.0673, 26.4203, 39.3477, 43.2667, 39.2031, 30.2078, 29.1294],
        [ 6.0363,  

batch_predictions
tensor([[22.6389, 31.5711, 32.9758, 31.8910, 29.8217, 25.9165, 20.6946],
        [38.3277, 38.1904, 38.2059, 38.2105, 38.2102, 38.2087, 38.2072],
        [18.8561, 17.7091, 19.0189, 20.2301, 21.2527, 22.0495, 22.5898],
        [15.9331, 16.5027, 16.3469, 15.9374, 15.5289, 15.2291, 15.0645],
        [29.6701, 25.6571, 25.2159, 26.2856, 27.6161, 28.8789, 29.9822],
        [13.5358, 12.6267, 12.9665, 13.7691, 14.6460, 15.2748, 15.6454],
        [36.8777, 36.6424, 36.6503, 36.6785, 36.6979, 36.7201, 36.7495],
        [19.7175, 17.9887, 18.1367, 18.7877, 19.3867, 19.8774, 20.2697]])
ground truth
tensor([[15.4337, 23.6395, 23.7245, 26.7715, 39.0023, 51.9841, 39.9660],
        [49.6712, 44.2662, 50.1447, 36.9937, 41.6623, 47.6854, 66.2941],
        [ 9.8764, 14.3740, 12.2962, 11.5729, 15.7680, 18.2404, 28.8138],
        [15.8732, 12.7696, 17.6486, 25.8154, 25.4866,  8.0089, 12.8880],
        [21.7517, 16.2020, 18.3719, 18.5955, 23.9216, 45.0158, 39.7291],
        [ 4.4845,  

batch_predictions
tensor([[18.7484, 19.6519, 19.3026, 18.8119, 18.4112, 18.1579, 18.0508],
        [18.2183, 20.8106, 20.1448, 19.1302, 18.1467, 17.3191, 16.7678],
        [20.5116, 22.8690, 22.5759, 21.4527, 19.9081, 18.3228, 17.3743],
        [25.4246, 21.9207, 22.7243, 23.9691, 25.1304, 26.1046, 26.8563],
        [19.6241, 19.1897, 19.1620, 19.2777, 19.4206, 19.5618, 19.6937],
        [15.7177, 14.8366, 15.8391, 16.5965, 17.0896, 17.3927, 17.5800],
        [ 9.9305, 11.3007, 12.7564, 14.8462, 17.3390, 17.5247, 16.9579],
        [18.4325, 22.9971, 22.4682, 21.3957, 20.1622, 19.0726, 18.4381]])
ground truth
tensor([[ 9.9684, 13.3745, 16.0179, 16.5439, 14.5581, 19.2925, 14.6765],
        [15.0935, 23.4127, 20.0539, 21.7545, 31.9728, 26.0062, 17.4603],
        [17.8590, 23.9742, 20.8837, 11.8359,  8.0747, 11.1389, 15.8732],
        [39.6542, 21.0601, 17.1627, 12.7268, 10.0198, 21.5561, 45.7200],
        [25.8220, 12.2307, 12.5000,  8.5176, 11.3520, 17.8005, 23.3277],
        [20.6916,  

batch_predictions
tensor([[23.2419, 21.0704, 21.2958, 22.0872, 22.8184, 23.4048, 23.8541],
        [15.2733, 15.1345, 15.1633, 15.2105, 15.2513, 15.2886, 15.3243],
        [20.6057, 24.3157, 23.4638, 21.6667, 19.1021, 16.9775, 16.7765],
        [20.5463, 20.6331, 20.6123, 20.5529, 20.4965, 20.4648, 20.4610],
        [13.4510, 15.3540, 16.1809, 15.9346, 15.4880, 15.1022, 14.8219],
        [29.0579, 32.3751, 32.6047, 31.7256, 30.2139, 28.0983, 25.3279],
        [25.5400, 26.3336, 26.2136, 25.9470, 25.7132, 25.5615, 25.4911],
        [ 4.8065,  4.7688,  4.7718,  4.7780,  4.7812,  4.7824,  4.7840]])
ground truth
tensor([[16.5391, 17.0210, 15.3770, 16.2698, 32.8515, 24.7732, 21.5420],
        [12.3751, 11.8227, 11.3887, 14.6239, 15.1631, 23.4087, 15.0579],
        [19.7846, 30.8673, 30.9382, 10.7710, 16.9926, 10.9694, 15.2778],
        [18.1406, 11.5363, 12.7693, 13.2228, 14.3282, 22.5907, 25.2834],
        [14.1504, 20.5681, 24.4871,  4.7870,  9.2057, 12.2173, 15.2157],
        [50.0000, 3

batch_predictions
tensor([[28.3020, 34.3420, 34.4940, 33.1395, 30.7334, 26.6483, 22.2205],
        [13.2592, 17.1200, 21.6441, 21.1642, 19.8998, 18.3398, 17.0018],
        [28.2214, 28.7117, 28.5975, 28.4199, 28.2848, 28.2083, 28.1785],
        [18.3446, 18.8861, 18.6527, 18.3114, 18.0401, 17.8807, 17.8237],
        [21.7215, 25.7618, 25.0136, 22.9872, 19.5687, 16.9927, 17.0112],
        [22.3957, 23.8622, 23.5755, 22.9125, 22.1969, 21.5696, 21.1139],
        [16.0063, 18.0698, 17.8329, 17.0729, 16.3234, 15.7154, 15.2865],
        [14.0693, 15.9852, 16.3999, 15.7501, 14.9142, 14.0874, 13.2866]])
ground truth
tensor([[ 6.3776,  0.0000, 13.6196, 32.4405, 14.3141,  8.6876, 16.5249],
        [12.6134,  9.6514, 13.3503, 18.6933, 33.8294, 29.0249,  7.4263],
        [27.0252, 39.1899, 10.6260, 15.6760, 28.9321, 23.4219, 26.3940],
        [15.5445, 13.2167, 14.6370, 14.3872, 22.6854, 23.7507, 12.4540],
        [27.4235, 30.6831,  0.1559, 10.2041, 10.3458, 13.8605, 13.2653],
        [27.4660, 2

batch_predictions
tensor([[20.5080, 19.6235, 19.4798, 19.6602, 19.9413, 20.2334, 20.5064],
        [14.4095, 14.0661, 15.0654, 15.9257, 16.5349, 16.9051, 17.1040],
        [14.3746, 16.2404, 16.7680, 16.4599, 16.0451, 15.7136, 15.4930],
        [23.8689, 31.7835, 31.5230, 30.0902, 27.8750, 24.2049, 20.2932],
        [18.8652, 19.5745, 19.4158, 19.0856, 18.7961, 18.6082, 18.5222],
        [38.2490, 36.2783, 36.4664, 36.9353, 37.2598, 37.4506, 37.5671],
        [ 7.3970,  4.8564,  4.4242,  4.4299,  4.4665,  4.4986,  4.5255],
        [15.2560, 14.3939, 14.4310, 14.9337, 15.6255, 16.3987, 17.1899]])
ground truth
tensor([[ 0.0000,  0.0000,  6.4703, 14.7422, 15.2420, 20.5287, 31.5623],
        [11.1395,  9.9206,  9.7789,  9.0561, 11.0119, 15.9722, 25.1276],
        [11.0337, 10.9548, 11.5992, 12.4934, 15.1236, 13.5060, 10.4024],
        [15.5896, 18.8776, 23.0017, 34.6088, 49.4898, 34.3537,  0.0000],
        [17.0210, 25.9921, 24.9575, 12.8827, 12.2874, 11.7489,  6.9019],
        [50.1417, 3

batch_predictions
tensor([[15.5066, 19.7911, 23.5671, 23.9666, 23.2106, 22.0047, 20.5622],
        [15.3161, 17.7088, 17.9265, 17.3966, 16.8707, 16.4836, 16.2495],
        [21.8396, 23.2253, 22.9230, 22.3165, 21.7114, 21.2231, 20.9042],
        [ 6.0421,  3.8437,  3.7276,  3.9605,  3.7722,  3.9924,  3.8108],
        [18.0905, 17.1482, 17.0969, 17.4236, 17.8171, 18.1832, 18.5006],
        [16.8130, 15.8264, 15.8494, 16.2124, 16.5779, 16.8840, 17.1292],
        [12.6788, 11.7619, 10.9855,  9.9470,  8.5041,  7.3388,  6.8885],
        [18.7955, 18.3995, 18.4165, 18.5607, 18.7253, 18.8777, 19.0107]])
ground truth
tensor([[14.5975, 17.8005, 20.3373, 24.1638, 31.7177, 28.7982,  6.9728],
        [15.6497, 19.9500,  6.2599,  9.6791, 17.0700, 14.4398, 14.5844],
        [13.6763, 16.9643, 28.7982, 29.1383, 17.3328, 14.9660, 12.7834],
        [ 4.1950,  3.1746,  2.8770,  5.4280,  5.8957,  9.7931, 10.6859],
        [29.2741, 29.0110,  8.7585,  9.8501, 19.1741, 18.4771, 18.7927],
        [27.4376, 1

batch_predictions
tensor([[16.0084, 18.0457, 17.6634, 16.8073, 15.9754, 15.2863, 14.7870],
        [28.4429, 22.6671, 24.4436, 26.6794, 28.9236, 31.0864, 33.1979],
        [17.5479, 21.5344, 22.7997, 21.0691, 17.6437, 14.4664, 14.0159],
        [12.1927, 15.9928, 19.4828, 20.7617, 20.5499, 19.9530, 19.3605],
        [21.6233, 19.9394, 19.7445, 20.2300, 20.8253, 21.3779, 21.8598],
        [34.0971, 37.6239, 37.5095, 37.2181, 36.9696, 36.7901, 36.6757],
        [24.6409, 22.9046, 22.5703, 22.9107, 23.4545, 24.0145, 24.5341],
        [38.2631, 38.3508, 38.3520, 38.3517, 38.3512, 38.3505, 38.3495]])
ground truth
tensor([[17.6486, 25.8154, 25.4866,  8.0089, 12.8880, 13.9137,  9.9027],
        [15.0579, 13.8743, 16.3335, 15.3472, 33.8506, 45.4366, 79.2346],
        [21.0810, 23.4350, 11.9279, 15.8601,  9.1399,  8.4824, 12.7827],
        [ 6.4045, 24.9737, 27.1173, 24.8816, 11.5203, 10.8890, 11.2704],
        [33.1774, 17.2052, 14.9518, 18.3248, 19.5862, 28.6990, 22.8600],
        [29.7052, 3

batch_predictions
tensor([[19.5465, 18.2069, 18.0118, 18.3684, 18.8442, 19.3044, 19.7215],
        [23.1046, 26.6731, 26.1100, 24.7674, 22.8656, 20.4456, 18.9645],
        [21.0198, 25.2452, 25.0787, 23.5887, 21.1353, 18.3481, 17.4102],
        [16.9809, 17.4127, 17.1982, 16.8688, 16.6028, 16.4484, 16.3974],
        [38.3428, 38.3424, 38.3424, 38.3423, 38.3421, 38.3420, 38.3420],
        [18.0577, 16.8630, 16.9985, 17.4324, 17.8067, 18.0969, 18.3188],
        [13.9909, 13.7627, 13.7863, 13.8704, 13.9647, 14.0549, 14.1335],
        [15.6967, 15.6309, 15.6688, 15.7076, 15.7364, 15.7622, 15.7871]])
ground truth
tensor([[15.0973, 16.0310, 16.0573, 15.5181, 17.6881, 27.7880, 27.9853],
        [22.5624,  9.6655, 20.2381, 15.0652, 30.1871, 10.7993, 13.6763],
        [16.3993, 15.6234, 16.3993, 16.3730, 22.6854, 22.6460, 13.1904],
        [17.2672, 24.5660,  7.4171, 11.5729, 12.1383, 12.4934, 15.5313],
        [33.7454, 38.9663, 59.4029, 49.6712, 44.2662, 50.1447, 36.9937],
        [21.5986, 1

batch_predictions
tensor([[21.4486, 25.6562, 24.8072, 23.0618, 20.5222, 17.9892, 17.5962],
        [27.8829, 23.8873, 24.7186, 26.1423, 27.4683, 28.5370, 29.3137],
        [21.7785, 30.3002, 33.9273, 33.0660, 30.1179, 23.0480, 18.5818],
        [22.2305, 23.6654, 23.3666, 22.7213, 22.0457, 21.4706, 21.0677],
        [18.1289, 20.8954, 20.5211, 19.8034, 19.1428, 18.6501, 18.3567],
        [20.2085, 18.6370, 18.4874, 19.0198, 19.6523, 20.2395, 20.7583],
        [19.9406, 18.0374, 19.2990, 20.4379, 21.3774, 22.0820, 22.5486],
        [19.3187, 23.4593, 22.3056, 20.0402, 17.0565, 15.6379, 15.7812]])
ground truth
tensor([[22.6986, 28.0905, 29.7212, 12.5855, 18.8322, 15.5708, 17.2935],
        [32.8656, 20.5215, 15.8730, 11.0544, 17.0493, 24.7591, 40.5471],
        [15.9722, 36.5363, 51.1621, 36.5930, 17.1344, 17.1202, 18.2398],
        [19.4303, 31.5334,  9.8498, 14.8101, 18.1973, 12.6559, 20.7766],
        [13.9172, 12.9252, 13.6480, 22.7608, 26.3322, 22.4206, 12.6276],
        [28.3730, 2

batch_predictions
tensor([[10.2209,  8.5516,  7.1622,  6.6816,  6.5931,  6.5917,  6.6034],
        [37.5054, 38.2119, 38.1912, 38.1481, 38.1185, 38.1015, 38.0935],
        [11.4600, 13.3929, 15.8678, 18.4334, 17.5762, 16.5530, 15.7238],
        [15.7848, 16.7636, 16.6832, 16.2365, 15.7550, 15.3548, 15.0707],
        [21.9318, 22.7367, 22.7149, 22.3477, 21.8962, 21.5098, 21.2557],
        [13.8033, 17.1092, 19.2099, 17.9824, 16.5374, 15.1116, 13.8442],
        [15.0328, 20.4255, 24.8517, 26.0884, 25.6535, 24.7606, 23.7250],
        [16.5515, 18.4143, 18.3192, 17.8884, 17.5031, 17.2432, 17.1120]])
ground truth
tensor([[12.4717,  9.1553,  6.8452,  8.0782,  7.9223,  8.9002, 10.4450],
        [26.2897, 39.3707, 54.6202, 49.0079, 31.3917, 28.8549, 22.5057],
        [ 9.9027,  7.9695,  9.9421, 13.0589, 20.9890, 22.0673,  6.0757],
        [20.9890, 22.0673,  6.0757, 10.8890, 11.7833, 10.6654, 13.6902],
        [16.1706, 14.5550, 18.8350, 15.2778, 23.3702, 10.6151, 10.4592],
        [10.0605, 1

batch_predictions
tensor([[17.6979, 20.6638, 20.2276, 19.2912, 18.3311, 17.5178, 16.9520],
        [28.0552, 26.0847, 25.6568, 25.9910, 26.5924, 27.2310, 27.8187],
        [26.9925, 27.6584, 27.6452, 27.4163, 27.1392, 26.9052, 26.7549],
        [15.9682, 15.6493, 16.8762, 17.7793, 18.3434, 18.6644, 18.8501],
        [19.7097, 21.4468, 20.9943, 20.1029, 19.1467, 18.3154, 17.7737],
        [14.3314, 17.6386, 20.4552, 20.8461, 20.2622, 19.5502, 18.9699],
        [11.4280, 10.5214,  9.2231,  7.6164,  6.7054,  6.4614,  6.4221],
        [17.5586, 19.6512, 19.3231, 18.6228, 17.9597, 17.4479, 17.1241]])
ground truth
tensor([[18.0130, 23.8095, 24.6599, 11.0969, 14.9660, 11.6213,  9.5947],
        [26.5913, 22.9879, 22.9090, 22.2514, 24.1189, 23.3298, 40.5839],
        [40.5471, 30.4989, 18.8209, 18.5374, 12.8401, 17.5454, 30.0028],
        [16.0289,  0.0000,  0.0000,  8.9569, 12.4150, 14.1865, 17.1627],
        [20.9892, 32.8656, 33.2766, 15.2778,  9.0278, 13.8889, 19.2035],
        [10.2315, 1

batch_predictions
tensor([[16.3808, 19.6742, 19.6425, 18.9515, 18.2326, 17.6563, 17.2764],
        [16.0973, 14.8036, 15.6298, 16.4357, 17.0069, 17.3870, 17.6391],
        [12.8574, 12.0180, 12.2169, 12.9528, 14.0668, 15.0920, 15.5497],
        [30.3903, 25.9765, 25.9853, 27.3846, 28.9027, 30.2502, 31.3855],
        [16.5349, 15.8021, 17.0374, 18.0333, 18.7332, 19.1534, 19.3834],
        [19.8989, 21.9573, 21.5349, 20.8470, 20.2094, 19.7086, 19.3851],
        [19.5730, 20.4336, 20.2183, 19.7865, 19.3894, 19.1124, 18.9710],
        [21.8908, 29.8145, 29.4658, 27.9430, 25.2818, 21.0795, 18.6506]])
ground truth
tensor([[21.7971, 29.1525, 26.1054, 14.0731, 13.0811, 12.6276, 13.6054],
        [22.1372, 11.6780,  9.1978, 11.7772, 11.6638, 19.3027, 26.6156],
        [ 4.3226,  7.0720,  5.5414,  6.3492, 13.4070,  1.2188, 25.9495],
        [37.3158, 21.4569, 27.5368, 20.1814,  6.3776,  0.0000, 13.6196],
        [24.0400,  8.3772, 10.0342, 10.8364, 10.3761, 18.0168, 25.1578],
        [18.4377, 2

batch_predictions
tensor([[18.9282, 18.8433, 18.8550, 18.8692, 18.8767, 18.8869, 18.9039],
        [17.5142, 21.8052, 24.9293, 25.3436, 24.7816, 23.8482, 22.7321],
        [19.0446, 21.7056, 21.0653, 19.9023, 18.6029, 17.3866, 16.6574],
        [ 5.0464,  4.6430,  4.6245,  4.6471,  4.6641,  4.6750,  4.6830],
        [ 8.9317,  7.7093,  7.1054,  6.9376,  6.9143,  6.9261,  6.9384],
        [28.3679, 26.3718, 25.5626, 25.5181, 25.9865, 26.7337, 27.6499],
        [ 9.2332, 10.3814, 11.2063, 12.0288, 13.0623, 14.5244, 15.9992],
        [13.5120, 16.5165, 19.9324, 18.9430, 17.5842, 16.3668, 15.5011]])
ground truth
tensor([[27.2225,  5.4182,  0.0000, 19.2267, 17.4513, 15.7680, 18.8322],
        [37.1457, 34.5947, 14.6117, 19.7279, 14.3991, 19.0760, 22.2931],
        [14.4841, 18.8917, 21.8537, 30.2863, 30.6264, 13.9739, 13.6905],
        [ 5.1587,  4.9603,  4.6769,  5.0595, 10.6151, 11.5363,  7.5113],
        [12.1457,  8.5034,  5.9949,  7.0578,  7.3413,  8.9427,  8.8294],
        [46.4994, 4

batch_predictions
tensor([[17.4742, 16.0192, 17.2039, 18.1570, 18.8071, 19.2082, 19.4550],
        [18.1685, 16.7160, 17.8528, 18.7259, 19.3222, 19.7044, 19.9475],
        [20.2113, 18.2913, 18.6412, 19.5069, 20.2785, 20.8963, 21.3795],
        [21.3804, 22.5510, 22.2859, 21.7106, 21.1241, 20.6517, 20.3489],
        [25.2408, 28.6900, 28.2125, 27.3074, 26.3253, 25.2959, 24.2521],
        [17.3216, 16.6146, 18.4231, 20.2730, 21.8178, 22.4575, 22.5844],
        [22.3698, 20.2631, 19.8087, 20.3521, 21.1712, 22.0314, 22.8941],
        [23.1072, 22.7866, 22.7593, 22.8377, 22.9369, 23.0464, 23.1632]])
ground truth
tensor([[28.1037, 25.8503,  9.0136, 10.6576, 12.8401, 10.5159, 18.7217],
        [10.7312, 16.0442, 13.6376, 16.4519, 15.0579, 26.2362,  5.3524],
        [12.2567, 14.2031, 20.2656, 21.8701, 16.8332, 24.3030, 29.6423],
        [ 7.3129,  8.3759, 13.9031, 19.4728, 22.8316, 27.6644, 18.3673],
        [27.6959, 22.2251, 28.0510, 14.7817, 13.8874, 23.1326, 16.3204],
        [ 8.6402, 1

batch_predictions
tensor([[19.8688, 18.6597, 18.6216, 18.9931, 19.3794, 19.7044, 19.9662],
        [31.5670, 32.7695, 32.6501, 32.3194, 32.0222, 31.8231, 31.7209],
        [17.6557, 17.1784, 17.1766, 17.3379, 17.5247, 17.6986, 17.8502],
        [15.3269, 15.0478, 16.4109, 17.6470, 18.5943, 19.1094, 19.2932],
        [18.9661, 18.1927, 18.0485, 18.2051, 18.4793, 18.7923, 19.1109],
        [23.6589, 32.3884, 32.2540, 30.1883, 26.0025, 19.2446, 18.9728],
        [ 9.4833, 10.7828, 12.0833, 14.0399, 17.7847, 19.7257, 18.3275],
        [ 4.6123,  4.5933,  4.5974,  4.6026,  4.6050,  4.6057,  4.6069]])
ground truth
tensor([[12.9143, 19.5818, 14.9264, 17.1883, 21.1073, 24.3951, 26.5913],
        [27.2094, 37.2304, 39.7948,  6.1021, 18.2141, 24.2372, 23.6323],
        [23.8237, 21.4144, 12.0748, 14.1015, 11.0261, 11.9189, 21.4144],
        [15.4337, 15.3203, 11.9615, 10.0624, 14.5692, 25.7795, 25.0709],
        [17.7438, 11.7489, 16.4541,  9.2404, 11.4087, 20.9609, 24.1780],
        [15.7880, 1

batch_predictions
tensor([[12.8585, 17.2724, 22.3628, 22.9628, 22.1103, 20.7699, 19.1620],
        [18.0128, 21.1111, 20.7091, 19.8473, 18.9905, 18.2986, 17.8454],
        [16.6518, 21.0299, 21.1025, 20.1232, 18.8488, 17.5536, 16.5966],
        [21.3851, 28.3828, 27.7751, 25.0236, 19.3169, 16.8124, 17.1145],
        [25.3930, 31.5522, 31.5949, 29.9285, 26.6382, 20.7085, 19.0634],
        [16.5117, 21.1387, 25.8415, 26.8393, 26.0545, 24.1057, 20.3663],
        [20.5333, 20.8928, 20.7278, 20.4796, 20.2778, 20.1588, 20.1185],
        [17.8762, 23.1314, 25.1483, 23.8682, 21.0427, 17.0218, 16.2373]])
ground truth
tensor([[12.4540, 13.5850, 15.1499, 17.6486, 25.8548, 31.4177,  5.1815],
        [17.0918, 18.1831, 21.0743, 12.9677, 14.4416, 10.3741, 13.4921],
        [11.9937, 19.4240, 25.7496, 26.1704, 13.9663, 11.6386,  9.6791],
        [43.4524, 38.3929, 19.0051, 19.2744, 12.7126, 21.8679, 24.0788],
        [17.3611, 10.0057, 18.8776, 31.7035, 45.2523, 31.4768, 19.2744],
        [10.0624, 1

batch_predictions
tensor([[19.7865, 24.2050, 23.4268, 21.9665, 20.2396, 18.6481, 17.8374],
        [15.6878, 14.4999, 14.8749, 15.5047, 16.0286, 16.4199, 16.7037],
        [21.9432, 27.6653, 26.5592, 24.0300, 19.5928, 17.3123, 17.6473],
        [18.6535, 16.7296, 17.6230, 18.5313, 19.2655, 19.8000, 20.1667],
        [31.4069, 33.6157, 33.7602, 33.2995, 32.6625, 32.0493, 31.5593],
        [22.7847, 24.5740, 24.3000, 23.3682, 22.1100, 20.7500, 19.6480],
        [20.9186, 27.2976, 26.3309, 24.2135, 20.8188, 17.7585, 17.9334],
        [19.7142, 22.5578, 21.9057, 20.5895, 18.9634, 17.3944, 16.6564]])
ground truth
tensor([[18.4640, 31.9963, 13.4403, 14.4398, 11.9805,  3.9716, 16.8201],
        [ 6.2217, 10.9836, 12.9535, 10.0482, 17.1485, 20.3515, 25.9495],
        [11.9615, 18.2965, 23.6961, 36.9756, 39.6684,  9.2971, 21.1876],
        [34.3503,  9.8764, 14.3740, 12.2962, 11.5729, 15.7680, 18.2404],
        [32.4961, 39.3872, 32.4566, 15.3603,  3.2614, 12.0989, 18.3588],
        [16.4966, 2

batch_predictions
tensor([[ 4.0017,  3.5129,  3.5638,  3.5816,  3.5938,  3.6023,  3.6088],
        [25.0361, 29.2892, 30.4045, 29.9681, 28.9097, 27.2876, 24.7642],
        [28.3256, 29.1228, 28.9965, 28.7693, 28.5885, 28.4800, 28.4334],
        [29.7836, 30.9898, 31.1198, 30.9176, 30.6297, 30.3767, 30.2111],
        [17.2102, 20.6938, 20.2124, 19.2081, 18.1690, 17.3036, 16.7429],
        [27.0177, 34.1279, 38.1267, 38.2677, 38.2696, 38.2612, 38.2516],
        [21.9408, 31.2162, 33.7854, 32.6813, 29.7180, 22.9840, 18.4135],
        [ 7.3033,  6.1057,  5.8545,  5.8390,  5.8548,  5.8669,  5.8753]])
ground truth
tensor([[ 5.1587,  5.6973,  6.1650,  6.3492,  8.5317,  9.8073,  7.1570],
        [19.9688, 31.4201, 45.4223, 38.3645, 17.9280, 18.0414, 16.9785],
        [25.9206, 25.8285, 36.6649, 37.9932, 21.9227, 29.5634, 20.7654],
        [23.7245, 26.7715, 39.0023, 51.9841, 39.9660, 24.7449, 23.6111],
        [17.8713, 15.0227, 17.5028, 20.6774, 28.0045, 23.4269, 13.4779],
        [29.9579, 4

batch_predictions
tensor([[ 9.2155,  8.0558,  7.2612,  6.9367,  6.8526,  6.8514,  6.8672],
        [ 3.7370,  3.6530,  3.6666,  3.6784,  3.6863,  3.6912,  3.6946],
        [23.6813, 22.2755, 22.0203, 22.3037, 22.7569, 23.2263, 23.6610],
        [21.9867, 23.1890, 22.9634, 22.4472, 21.9344, 21.5375, 21.2964],
        [25.0881, 21.9409, 22.5543, 23.6665, 24.6814, 25.5092, 26.1435],
        [15.9271, 20.3379, 22.0896, 21.2196, 19.8520, 18.1489, 16.6183],
        [16.7173, 22.0137, 25.5525, 25.5370, 24.6259, 23.0449, 20.4886],
        [33.6580, 32.3936, 32.2484, 32.4555, 32.7414, 33.0237, 33.2759]])
ground truth
tensor([[14.7028, 12.8617,  4.3924,  7.4040,  8.0747,  7.3514, 11.3887],
        [ 3.1321,  4.2092,  2.9762,  2.8061,  5.4563,  7.8798,  7.6814],
        [25.8022, 18.6086, 21.1862, 14.1110, 13.3877, 16.9516, 27.6696],
        [17.2541, 22.0147, 20.1999, 25.2893, 42.6092, 21.2914, 21.0153],
        [28.1431, 11.4019, 12.9537, 12.1383, 17.3330, 21.1599, 36.8885],
        [13.8889, 1

batch_predictions
tensor([[10.9794, 11.9715, 13.3113, 15.3322, 16.3373, 15.8084, 15.0891],
        [10.4603,  8.2846,  6.3533,  5.8130,  5.7467,  5.7506,  5.7606],
        [16.2353, 19.7297, 19.9207, 19.2216, 18.4652, 17.8387, 17.4117],
        [21.4479, 20.8180, 20.7292, 20.8626, 21.0660, 21.2869, 21.5040],
        [15.3325, 20.0068, 20.8479, 19.6631, 18.0024, 16.0654, 14.7610],
        [11.4538, 12.3227, 13.4035, 14.6758, 15.2812, 15.0468, 14.5615],
        [19.9900, 20.5179, 20.3858, 20.1324, 19.9166, 19.7816, 19.7245],
        [19.1477, 19.1615, 19.1618, 19.1439, 19.1236, 19.1148, 19.1143]])
ground truth
tensor([[ 7.5618,  7.4566,  8.8112, 10.8364, 17.4908, 25.5129,  8.1799],
        [15.8732,  5.2209,  6.6544,  6.6675,  6.4440, 10.5865, 13.0326],
        [13.7559, 14.8211, 22.5802, 26.7228, 14.9132, 11.9411, 13.9532],
        [28.7612,  5.1683, 10.7049, 15.4129, 16.9253, 14.0715, 22.3304],
        [12.8748,  9.0610, 15.0316, 27.5776, 22.8564, 17.5171, 12.3882],
        [10.8627, 1

batch_predictions
tensor([[14.8153, 18.6677, 19.7437, 18.7552, 17.6353, 16.5990, 15.7679],
        [38.2595, 38.2524, 38.2545, 38.2548, 38.2545, 38.2541, 38.2537],
        [21.7986, 20.2499, 19.6693, 19.7732, 20.2859, 20.9986, 21.8591],
        [15.7946, 16.9177, 16.9242, 16.5751, 16.2229, 15.9689, 15.8220],
        [22.1640, 26.2388, 25.6615, 23.8742, 20.8567, 17.8877, 17.5439],
        [13.7580, 17.4754, 21.0963, 19.8160, 17.9387, 16.0299, 14.7163],
        [21.3141, 25.5212, 28.4128, 31.0642, 33.5238, 35.0750, 35.6173],
        [19.0069, 24.8760, 23.4229, 20.1385, 16.1907, 15.6075, 15.8897]])
ground truth
tensor([[11.0994, 14.4003, 18.8059, 15.8864, 27.0516, 31.2730, 10.7969],
        [42.5565, 49.4345, 49.9211, 39.3740, 31.8911, 32.3382, 44.3582],
        [35.6859, 35.4734, 12.5567, 11.6497,  8.8294, 12.7834, 24.8724],
        [11.2441,  9.6923, 10.5339, 13.6639, 12.3225, 18.5692, 14.7554],
        [15.1786, 15.3061, 29.9603, 41.1139, 32.4830, 23.7528, 16.6383],
        [13.3482, 1

batch_predictions
tensor([[22.9044, 21.8178, 21.5660, 21.7757, 22.2205, 22.7732, 23.3621],
        [ 9.3242, 10.0468, 10.5127, 10.8886, 11.2384, 11.6192, 12.1029],
        [26.3921, 25.6545, 25.4989, 25.6145, 25.8479, 26.1468, 26.4766],
        [19.6621, 24.1674, 23.1291, 21.2393, 18.8566, 16.9723, 16.8401],
        [16.2413, 15.5419, 16.8423, 17.8800, 18.5681, 18.9447, 19.1417],
        [17.4469, 19.3133, 18.8278, 18.0034, 17.2100, 16.5623, 16.1275],
        [16.6152, 16.8050, 16.7981, 16.6841, 16.5501, 16.4457, 16.3845],
        [10.9677, 11.8935, 13.0747, 14.7041, 15.7684, 15.4877, 14.9691]])
ground truth
tensor([[20.2806, 22.3356, 21.1026, 37.2024, 39.8243, 10.0198, 14.9235],
        [10.9416, 15.2420, 24.6844, 15.9916,  0.0000,  8.9164, 11.4808],
        [26.3605, 12.9960, 13.3787, 17.5312, 19.7988, 24.9150, 38.1661],
        [16.7543, 23.5139, 30.6286, 35.6260,  1.9858, 11.2309, 16.7017],
        [12.9252, 15.0085, 18.8776, 14.1723, 16.1990, 19.7988, 27.6927],
        [ 9.0986, 1

batch_predictions
tensor([[23.5086, 30.7088, 30.1643, 28.5012, 25.8145, 20.8562, 18.8597],
        [28.0065, 32.5709, 32.8083, 31.5806, 29.3471, 25.7716, 22.1220],
        [22.5590, 21.4222, 21.0848, 21.1865, 21.5213, 21.9763, 22.5019],
        [ 6.5737,  6.8044,  6.8293,  6.8257,  6.8227,  6.8157,  6.8128],
        [19.6353, 20.8995, 20.6517, 20.0232, 19.3609, 18.8208, 18.4769],
        [12.8795, 14.9884, 17.1463, 17.0483, 16.3477, 15.6545, 15.0818],
        [ 3.6760,  3.6530,  3.6594,  3.6651,  3.6688,  3.6709,  3.6723],
        [28.4126, 24.5563, 24.8811, 26.0774, 27.2666, 28.2662, 29.0363]])
ground truth
tensor([[16.5675, 25.8220, 38.1803, 30.5839, 19.9972, 16.8226, 14.9518],
        [41.6100, 43.6791, 38.4070, 13.6763,  9.6939, 12.6134, 15.9722],
        [13.2036, 14.8738, 11.0205, 12.7433, 23.8033, 28.5508, 30.7338],
        [ 9.5082,  8.9953,  8.4035, 12.4803, 16.8201, 10.8233,  3.4850],
        [ 9.2404, 11.4087, 20.9609, 24.1780, 26.8707, 11.0969, 14.5975],
        [12.3866, 2

batch_predictions
tensor([[19.7382, 19.9590, 19.8950, 19.7748, 19.6744, 19.6165, 19.5985],
        [20.8894, 22.4987, 22.1693, 21.4059, 20.5635, 19.8109, 19.2728],
        [ 4.8848,  4.0297,  4.0096,  4.0454,  4.0720,  4.0904,  4.1035],
        [31.9111, 25.7754, 26.8295, 28.2146, 29.3943, 30.4046, 31.2486],
        [22.4109, 19.9547, 20.3808, 21.3114, 22.1528, 22.8313, 23.3541],
        [21.3407, 25.4550, 24.5927, 22.8485, 20.3428, 17.9618, 17.5821],
        [17.8275, 23.4086, 23.0825, 21.7325, 19.6005, 17.4215, 16.7537],
        [19.9724, 26.0511, 24.9433, 21.8103, 17.0170, 16.0194, 16.2318]])
ground truth
tensor([[32.1145, 28.7273, 17.7721,  7.3129,  8.3759, 13.9031, 19.4728],
        [27.7749, 30.1289, 11.2309, 19.8185, 21.0284, 20.0947, 22.4487],
        [ 7.9365,  7.9082,  5.2721,  7.7239,  7.5113,  7.8656,  8.9569],
        [53.9058, 64.7291, 14.0978, 12.3882, 15.8732, 17.8722, 33.8769],
        [12.7170, 18.9769, 18.0694, 16.4519, 23.5534, 31.0626, 27.7749],
        [25.7496, 3

batch_predictions
tensor([[20.9940, 25.7985, 25.0072, 23.7634, 22.3585, 20.9491, 19.7175],
        [ 3.7030,  3.6031,  3.6158,  3.6284,  3.6370,  3.6427,  3.6469],
        [21.6155, 26.6507, 25.8585, 23.6470, 19.8204, 17.1252, 17.3421],
        [16.3830, 15.1553, 16.0922, 16.8490, 17.3755, 17.7169, 17.9342],
        [19.5621, 19.5036, 19.5194, 19.5302, 19.5330, 19.5415, 19.5632],
        [22.2258, 26.6232, 25.9963, 24.3534, 21.8111, 18.8569, 18.1109],
        [32.2209, 29.5065, 28.8976, 29.2988, 30.0237, 30.7894, 31.5095],
        [16.0019, 14.7575, 14.8682, 15.4486, 16.0742, 16.6640, 17.2210]])
ground truth
tensor([[23.1326, 16.3204, 17.3856, 26.2756, 26.8148, 15.2157, 19.4240],
        [ 5.2863,  4.8469,  3.8974,  4.9036,  4.2234,  6.5051,  5.1587],
        [14.3424, 13.9031, 23.1151, 34.4955, 34.6514, 18.4666, 14.7392],
        [29.8264,  9.9290, 12.0857, 12.4277, 11.0205, 19.8185, 22.8432],
        [31.0941, 20.7625,  8.2483, 12.2449,  6.0941, 12.0323, 18.5374],
        [21.7780, 3

batch_predictions
tensor([[ 4.5452,  4.5533,  4.5541,  4.5564,  4.5570,  4.5462,  4.5430],
        [26.0216, 25.9523, 25.9555, 25.9653, 25.9633, 25.9690, 25.9905],
        [ 7.0016,  6.8148,  6.7566,  6.7447,  6.7466,  6.7475,  6.7489],
        [20.3980, 18.4762, 19.5850, 20.5622, 21.3039, 21.8249, 22.1722],
        [37.8955, 35.3487, 35.4054, 36.0245, 36.5457, 36.8818, 37.0885],
        [ 8.2018,  8.9124,  9.2441,  9.4215,  9.5171,  9.5721,  9.6096],
        [19.7482, 19.5630, 19.5796, 19.6246, 19.6643, 19.7011, 19.7393],
        [10.0941, 10.6535, 11.1567, 11.6856, 12.3292, 13.2399, 14.5472]])
ground truth
tensor([[ 4.2942,  4.3226,  4.7477,  8.4892, 10.0057,  6.2217,  3.8124],
        [41.7806, 33.0352, 18.5823, 17.5039, 18.7533, 22.0410, 29.4845],
        [ 6.5492,  7.2593, 11.5466, 17.0174,  4.7475,  6.2336,  6.7859],
        [28.7273, 17.7721,  7.3129,  8.3759, 13.9031, 19.4728, 22.8316],
        [42.7296, 31.1366, 25.8645, 21.9246, 32.4830, 43.6366, 60.1757],
        [ 2.6302, 1

batch_predictions
tensor([[19.3809, 16.7725, 17.7680, 19.1088, 20.3426, 21.3171, 22.1365],
        [17.9585, 23.1836, 27.1660, 28.0103, 27.5411, 26.5429, 25.1269],
        [22.0349, 30.1141, 29.9629, 28.1724, 24.4072, 18.5705, 18.0700],
        [34.2670, 31.4196, 30.4969, 30.6197, 31.2244, 32.0171, 32.8759],
        [37.7532, 37.5339, 37.5594, 37.5832, 37.5985, 37.6135, 37.6306],
        [20.0773, 20.8880, 20.5368, 19.9846, 19.4977, 19.1689, 19.0166],
        [ 9.5445, 11.0104, 12.1773, 13.5146, 14.9592, 15.8701, 16.0910],
        [15.1000, 17.0775, 17.2242, 16.6492, 16.0338, 15.5338, 15.1774]])
ground truth
tensor([[ 7.9507, 10.6151, 10.7710, 11.9898, 19.0051, 38.7330, 51.1905],
        [18.9909, 22.5907, 34.0420,  9.3821, 11.8197, 26.7857, 20.2806],
        [15.5187, 35.8985, 38.9031, 18.9626, 10.5584, 16.7517, 15.4478],
        [24.3425, 20.2130, 19.9369, 18.8848, 33.9427, 55.6418, 14.1899],
        [57.7523, 42.2194, 30.9382, 25.8645, 21.2302, 29.7336, 42.8713],
        [22.4619,  

batch_predictions
tensor([[25.4557, 27.8394, 27.7743, 27.2349, 26.5131, 25.6921, 24.7771],
        [38.2368, 38.2356, 38.2357, 38.2356, 38.2354, 38.2352, 38.2351],
        [16.7962, 21.7628, 25.7079, 27.3998, 27.0630, 25.9031, 24.2533],
        [17.6361, 21.7645, 24.8954, 27.4442, 28.7650, 28.8219, 28.1498],
        [17.9497, 17.5631, 17.5416, 17.6076, 17.6684, 17.7217, 17.7710],
        [20.0802, 18.3877, 20.0744, 21.7210, 23.1808, 24.2382, 24.7795],
        [22.3671, 21.4655, 21.2325, 21.2539, 21.3515, 21.4878, 21.6448],
        [19.8806, 20.7551, 20.4554, 19.8998, 19.3752, 18.9967, 18.7983]])
ground truth
tensor([[22.0410, 29.4845, 44.7002, 34.2451, 21.2388, 20.1078, 22.8958],
        [41.8201, 59.8501, 54.9448, 42.2935, 41.8201, 37.6381, 38.2167],
        [28.8124, 15.3486, 12.7268, 33.6735, 54.1383, 29.3793, 15.5896],
        [ 4.9745, 35.4450, 45.1956, 33.7585, 18.9909, 16.8651, 16.0006],
        [19.5011, 11.3237, 12.9252, 12.9110, 11.3095, 12.6417, 25.2268],
        [ 3.7415, 1

batch_predictions
tensor([[12.3453, 15.2282, 19.6429, 18.9979, 17.7511, 16.4257, 15.3171],
        [15.6869, 15.9488, 15.8895, 15.7224, 15.5642, 15.4614, 15.4172],
        [31.9101, 25.2994, 26.5109, 28.1617, 29.6757, 31.0907, 32.4263],
        [15.4138, 20.2819, 24.7556, 26.4227, 26.0413, 24.9236, 23.3316],
        [19.9313, 24.0544, 23.9349, 22.9524, 21.6874, 20.3926, 19.4262],
        [21.6912, 19.4770, 20.2475, 21.2331, 22.0406, 22.6395, 23.0631],
        [18.8354, 18.0225, 17.9470, 18.1823, 18.4935, 18.7956, 19.0637],
        [14.8792, 18.0200, 19.1372, 18.5299, 17.8872, 17.3987, 17.0978]])
ground truth
tensor([[11.1783,  7.3251, 13.9795, 21.3046, 25.7891, 10.8627, 10.5208],
        [10.9127,  9.3679,  8.7727, 15.1219, 17.6020, 30.8957, 13.6763],
        [14.0978, 12.3882, 15.8732, 17.8722, 33.8769, 48.5402, 41.5965],
        [ 9.9915, 29.9178, 25.4535, 11.0544,  9.0703,  9.3963, 12.9252],
        [13.3877, 16.9516, 27.6696, 26.0521, 14.6502, 14.6896, 13.3614],
        [ 0.0000,  

batch_predictions
tensor([[18.2518, 17.6343, 17.5629, 17.7300, 17.9750, 18.2371, 18.4909],
        [32.0196, 26.4533, 27.3999, 28.9274, 30.3508, 31.6087, 32.7033],
        [19.8781, 24.6393, 23.7372, 21.7173, 18.9647, 17.0105, 16.9908],
        [13.6804, 17.5302, 21.1714, 19.9705, 18.2684, 16.6139, 15.3330],
        [23.3885, 28.1673, 27.9708, 26.3705, 23.2513, 19.1496, 18.2158],
        [21.4955, 27.0998, 26.1639, 24.0696, 20.7627, 17.8099, 17.8655],
        [15.0362, 14.0329, 14.4624, 15.0952, 15.6033, 15.9660, 16.2142],
        [29.0210, 31.1267, 31.0829, 30.4766, 29.7053, 28.9210, 28.1915]])
ground truth
tensor([[26.9726,  9.7449, 11.4150, 12.0594, 14.0847, 20.7128, 23.3956],
        [45.6207, 29.8527, 28.1694, 20.3314, 23.0800, 29.5371, 49.6449],
        [17.5170, 16.3832, 20.7341, 26.0204, 35.9552, 15.3061, 16.4966],
        [ 1.4172, 19.8696, 20.0113, 23.3985, 33.6026, 34.4388,  8.4467],
        [14.3707, 17.5454, 32.4405, 46.6270, 38.7472,  4.2942,  0.0000],
        [16.0705, 1

batch_predictions
tensor([[24.2970, 25.8030, 25.6920, 25.0472, 24.1741, 23.2467, 22.4123],
        [21.5736, 19.4414, 20.0809, 21.2059, 22.1987, 22.9903, 23.6088],
        [22.8099, 22.9725, 22.9450, 22.8685, 22.7995, 22.7601, 22.7513],
        [14.5182, 18.7388, 23.5101, 23.9096, 22.8216, 20.7741, 17.9259],
        [15.9927, 18.7603, 18.8928, 18.1494, 17.3445, 16.6644, 16.1671],
        [20.7379, 29.1559, 28.9441, 26.9954, 22.6392, 17.1335, 17.1859],
        [19.4126, 21.5512, 21.1086, 20.3795, 19.6961, 19.1524, 18.7956],
        [12.5451, 15.2811, 17.6542, 18.5043, 18.3926, 18.1305, 17.9063]])
ground truth
tensor([[ 8.8010, 16.6383, 20.1389, 29.5493, 43.7783, 26.2188,  0.0000],
        [31.5476, 11.6922, 17.5170, 15.1927, 16.6950, 19.5011, 31.4342],
        [20.1999, 25.2893, 42.6092, 21.2914, 21.0153, 25.0789, 20.9100],
        [11.6213,  8.2058, 17.2336, 20.7625, 30.7965, 33.7585,  7.8940],
        [19.1215, 21.4755, 21.8832, 14.6239, 15.6628,  9.9947, 12.4408],
        [30.2438, 5

batch_predictions
tensor([[21.8529, 19.4544, 19.8980, 20.8380, 21.6927, 22.3888, 22.9326],
        [26.9570, 23.0662, 23.6872, 25.0608, 26.3952, 27.5062, 28.3354],
        [19.1057, 25.5069, 25.5157, 23.7116, 19.8423, 16.3118, 16.1463],
        [ 6.7363,  6.7543,  6.7584,  6.7606,  6.7586,  6.7415,  6.7344],
        [17.1490, 21.3404, 20.7216, 19.2326, 17.3848, 15.8037, 15.1687],
        [13.8652, 17.1602, 19.9613, 19.0055, 17.9035, 17.0087, 16.4189],
        [16.3708, 19.0466, 19.3172, 18.8609, 18.4172, 18.1057, 17.9353],
        [29.2712, 28.5888, 28.5133, 28.6361, 28.8075, 28.9877, 29.1620]])
ground truth
tensor([[30.9382, 10.7710, 16.9926, 10.9694, 15.2778, 22.8883, 32.8940],
        [30.4989, 18.8209, 18.5374, 12.8401, 17.5454, 30.0028, 50.0850],
        [19.9500, 30.7601, 33.6533, 12.3225, 15.1236, 10.1394, 12.2304],
        [ 6.6675,  6.4440, 10.5865, 13.0326, 14.3083,  7.6144,  6.6544],
        [15.4053, 13.7755, 15.8872, 34.3821, 33.6593, 13.2795, 16.3407],
        [14.3083, 1

batch_predictions
tensor([[19.1858, 22.7309, 21.7541, 20.0683, 17.9893, 16.2579, 15.9686],
        [22.1431, 23.1230, 23.0290, 22.6246, 22.1867, 21.8386, 21.6368],
        [14.7730, 18.3105, 19.2122, 18.0217, 16.5476, 15.0179, 13.8250],
        [21.4461, 19.1313, 19.5537, 20.4266, 21.2094, 21.8420, 22.3336],
        [19.4305, 17.9354, 19.1511, 20.1865, 20.9875, 21.5599, 21.9359],
        [21.2866, 18.7693, 18.6864, 19.5347, 20.4494, 21.3120, 22.1217],
        [19.9752, 17.9800, 18.9829, 19.9399, 20.6931, 21.2340, 21.5998],
        [19.1664, 24.2072, 28.2112, 29.1276, 28.5780, 27.4590, 25.7077]])
ground truth
tensor([[18.2141, 22.2646, 29.2215, 10.6391, 13.8348, 12.6644, 14.1504],
        [ 2.7778, 13.7613, 13.9314, 19.1327, 33.1207, 25.1417, 16.0856],
        [24.2109, 29.6028,  8.2720, 10.3235, 10.0605, 12.7170, 15.8864],
        [18.9342,  2.7778, 13.7613, 13.9314, 19.1327, 33.1207, 25.1417],
        [30.1420, 23.2246, 20.0421, 12.5460, 17.2672, 17.3724, 17.2541],
        [11.8764, 1

batch_predictions
tensor([[14.0754, 16.6928, 18.5512, 19.0614, 18.9009, 18.6230, 18.3916],
        [20.9111, 20.4720, 20.4329, 20.5419, 20.6859, 20.8361, 20.9826],
        [26.0381, 34.1860, 35.6882, 34.0283, 29.1283, 20.2654, 19.5098],
        [ 5.4932,  5.8330,  5.8773,  5.8699,  5.8640,  5.8537,  5.8456],
        [16.9971, 22.1489, 25.9907, 26.4692, 25.8520, 24.7788, 23.3618],
        [19.9176, 18.2016, 19.3739, 20.3728, 21.1359, 21.6760, 22.0335],
        [14.1187, 13.5346, 13.5784, 13.9069, 14.3442, 14.7917, 15.1826],
        [21.4941, 18.4840, 19.9546, 21.5464, 23.0321, 24.1708, 24.8503]])
ground truth
tensor([[10.1920, 10.7575,  8.8901, 18.7796, 23.3824,  2.5381,  4.8264],
        [28.1321, 11.6071, 13.2653, 13.8322, 13.0527, 15.6463, 30.0028],
        [20.3051, 19.3714, 85.1920, 53.9058, 64.7291, 14.0978, 12.3882],
        [ 9.0084, 11.5071, 17.7275, 18.9506, 27.6433,  6.3388, 13.9926],
        [12.9252, 15.5612, 13.6621, 20.1814, 33.9711, 33.8719,  6.3067],
        [30.3571,  

batch_predictions
tensor([[24.0542, 21.1287, 21.8097, 22.9346, 23.9495, 24.7611, 25.3640],
        [12.4176, 14.8773, 17.4801, 17.1947, 16.4462, 15.7586, 15.2228],
        [24.2447, 21.1191, 22.5765, 24.3042, 25.8126, 26.9099, 27.5884],
        [13.6136, 15.8645, 16.7933, 16.3067, 15.6310, 15.0281, 14.5351],
        [24.8436, 33.9858, 35.4315, 33.5789, 28.6290, 20.0369, 19.5114],
        [23.8423, 21.6781, 21.7080, 22.4715, 23.2624, 23.9426, 24.4943],
        [19.1273, 22.4892, 21.5618, 19.7591, 17.4204, 15.7376, 15.6139],
        [20.1398, 18.5343, 18.4356, 19.0097, 19.6496, 20.2256, 20.7221]])
ground truth
tensor([[29.2800, 15.3203, 17.2336, 22.5624,  9.6655, 20.2381, 15.0652],
        [11.5729, 15.9258, 23.1326, 23.6981, 14.7291, 14.2688, 13.0852],
        [ 5.3430, 12.9252, 15.5612, 13.6621, 20.1814, 33.9711, 33.8719],
        [22.7891, 25.4819, 27.5794, 15.4337, 15.3203, 11.9615, 10.0624],
        [34.0873, 63.3877, 44.1610, 13.9137, 12.0594, 15.0316, 14.1241],
        [29.2092, 2

batch_predictions
tensor([[17.8973, 18.3705, 18.2582, 18.0384, 17.8576, 17.7500, 17.7070],
        [ 8.2612,  4.7894,  4.1097,  4.0827,  4.0903,  4.1007,  4.1112],
        [ 7.7181,  7.9785,  8.0481,  8.0582,  8.0526,  8.0407,  8.0360],
        [18.6411, 24.2352, 23.2220, 21.0458, 18.1196, 16.4238, 16.5265],
        [15.3079, 17.7360, 17.6884, 16.9026, 16.0775, 15.3657, 14.8019],
        [21.7137, 30.6095, 31.3264, 28.8602, 22.2881, 16.5321, 17.1057],
        [18.9773, 21.9048, 21.0829, 19.8777, 18.6234, 17.4582, 16.7335],
        [22.8367, 22.1332, 21.9368, 22.0195, 22.2544, 22.6011, 23.0341]])
ground truth
tensor([[11.8339, 12.8118, 18.2256, 23.6536, 21.5986, 14.3282, 15.3628],
        [ 7.3777,  8.7980, 13.2430, 14.8606, 10.3235,  7.1278,  5.4182],
        [ 8.3509, 11.9411, 12.8485, 18.6875, 25.3025,  0.3945,  2.7486],
        [10.5997,  8.6007, 13.3614, 20.9890, 20.2130, 34.7712, 13.7559],
        [28.9716, 21.1205, 10.5602, 14.1373, 10.6260, 11.3493, 13.4929],
        [14.5844, 3

batch_predictions
tensor([[26.6043, 26.8524, 26.8378, 26.7709, 26.7120, 26.6812, 26.6779],
        [14.7442, 19.5843, 24.0064, 25.2008, 24.5966, 23.3535, 21.7425],
        [19.7748, 24.1122, 25.5139, 24.8106, 23.2612, 20.9461, 18.6676],
        [10.9709, 12.6621, 15.3725, 18.4359, 17.2926, 15.8595, 14.4979],
        [ 5.1109,  4.9778,  4.9665,  4.9739,  4.9795,  4.9823,  4.9848],
        [18.4019, 24.0471, 28.2371, 28.0593, 26.4021, 22.4066, 16.7706],
        [ 3.9879,  4.0501,  4.0480,  4.0475,  4.0472,  4.0401,  4.0358],
        [17.1652, 17.3559, 17.2220, 17.0382, 16.8999, 16.8294, 16.8160]])
ground truth
tensor([[26.7290, 22.3073, 24.5323, 35.2749, 19.3311, 19.4586, 19.9546],
        [13.5913, 13.1236, 15.6604, 26.1054, 35.6859, 35.4734, 12.5567],
        [12.0890, 24.2347, 28.6848, 41.6525, 22.2222, 10.4875, 16.7375],
        [18.4771, 20.4235,  5.1420,  9.3766,  6.7728,  7.4566,  9.9027],
        [ 6.7035, 12.1032, 11.2670,  5.7540,  4.8753,  5.2154,  6.3776],
        [13.6196, 1

batch_predictions
tensor([[28.6567, 24.4684, 25.5368, 27.0341, 28.3722, 29.4333, 30.2128],
        [20.4142, 27.4336, 28.7962, 27.7751, 25.4920, 20.8067, 18.2034],
        [31.5873, 36.4267, 36.4067, 35.8275, 35.1598, 34.5368, 34.0137],
        [ 6.5374,  6.3563,  6.3279,  6.3316,  6.3366,  6.3384,  6.3401],
        [16.0105, 19.7678, 20.2742, 19.3803, 18.3191, 17.3231, 16.5507],
        [11.4873, 12.5460, 13.5013, 14.5221, 15.3218, 15.3657, 15.0401],
        [10.7740, 12.1469, 13.8178, 16.1857, 16.8823, 16.2133, 15.5817],
        [12.9082, 11.9767, 11.6449, 11.5556, 11.6428, 11.8662, 12.2351]])
ground truth
tensor([[14.7534,  0.0000,  5.5272, 25.3685, 22.8175, 31.7177, 48.8237],
        [21.6837, 47.6190, 44.4870, 23.9229,  9.2120, 19.2177, 14.5833],
        [24.9150, 20.4082, 24.6457, 30.6548, 41.8651, 41.1706, 19.3311],
        [ 9.3254,  8.3900, 12.6417,  9.8923,  6.0091,  9.4813,  9.5663],
        [13.4637,  9.4813, 14.1723, 16.2415, 22.9592, 19.5437, 11.1961],
        [ 9.7712, 1

batch_predictions
tensor([[20.6908, 20.6119, 20.6253, 20.6396, 20.6482, 20.6627, 20.6888],
        [20.1136, 21.9081, 21.5004, 20.6908, 19.8352, 19.0856, 18.5636],
        [16.0272, 15.5481, 16.8627, 17.9487, 18.6921, 19.0954, 19.2922],
        [18.6892, 19.7645, 19.4856, 18.9669, 18.5028, 18.1833, 18.0219],
        [24.5155, 21.4864, 21.4245, 22.4470, 23.5421, 24.5284, 25.3699],
        [ 7.8695,  8.1284,  8.1933,  8.2007,  8.1946,  8.1808,  8.1744],
        [22.6242, 20.1332, 21.5129, 22.8440, 23.9784, 24.8873, 25.5562],
        [17.5070, 17.2889, 17.3312, 17.4218, 17.5065, 17.5792, 17.6428]])
ground truth
tensor([[26.0346, 40.4337, 38.0811, 15.7313, 17.2902, 12.9677, 12.0890],
        [15.5181, 17.6881, 27.7880, 27.9853, 13.8217, 12.8880, 13.8480],
        [31.4768,  9.5380, 10.9836, 18.2256, 11.9615, 16.8226, 25.1276],
        [11.1820, 12.7834, 17.6446, 21.4427, 28.7557, 16.5391, 12.6701],
        [41.1423,  7.4688, 18.0556, 13.3503, 21.2302, 18.2965, 32.6247],
        [ 8.6402,  

batch_predictions
tensor([[17.5783, 23.3811, 23.5766, 22.7145, 21.3348, 19.3988, 17.6935],
        [17.2686, 23.7889, 25.4946, 24.2027, 21.1823, 16.1836, 15.7143],
        [19.7934, 18.8791, 18.7274, 19.0310, 19.5673, 20.2144, 20.8994],
        [27.0867, 24.6042, 24.3140, 24.9452, 25.7358, 26.4674, 27.0869],
        [17.4081, 21.0140, 20.3511, 19.1292, 17.8055, 16.6606, 15.9538],
        [20.5082, 27.1540, 27.0170, 25.3386, 21.9378, 18.1794, 17.6142],
        [21.6711, 24.3585, 23.9205, 22.7819, 21.3050, 19.7371, 18.6795],
        [ 9.8058,  9.9289, 10.0375, 10.1130, 10.1498, 10.1595, 10.1578]])
ground truth
tensor([[18.8776, 36.6071, 12.6276, 23.2710,  5.2721,  8.8152, 13.6338],
        [11.6497, 11.0544, 15.4337, 23.5969, 37.2591, 28.4297, 11.9898],
        [34.1005, 44.3977, 12.7696, 15.4655, 10.8890, 11.4413, 19.9106],
        [26.6570,  0.0000, 19.7659, 18.4377, 18.7796, 11.4545, 38.4271],
        [18.9342, 28.8265,  8.0074,  7.3980, 14.2149, 12.8260, 13.8889],
        [ 0.0000, 1

batch_predictions
tensor([[37.9151, 34.9278, 34.9307, 35.6853, 36.3520, 36.7861, 37.0472],
        [31.6329, 30.4850, 30.1875, 30.2739, 30.5281, 30.8798, 31.2907],
        [20.2446, 19.9488, 19.9671, 20.0592, 20.1563, 20.2450, 20.3254],
        [20.9483, 20.2989, 20.2386, 20.4147, 20.6569, 20.9015, 21.1251],
        [22.9056, 29.2635, 28.5031, 26.6874, 23.7833, 19.9504, 19.2252],
        [19.9212, 26.0551, 25.0001, 22.0507, 17.4090, 16.0887, 16.3134],
        [24.7912, 28.7969, 28.4571, 27.1191, 24.9178, 21.6763, 19.5508],
        [19.6040, 25.4072, 27.9121, 27.2077, 25.3925, 21.6388, 18.1184]])
ground truth
tensor([[49.4048, 30.6548, 27.6927, 27.5935, 34.5805, 45.1956, 55.9382],
        [55.0595, 37.2166, 10.4025, 12.1173, 19.8696, 18.7075, 31.6752],
        [31.9728, 26.0062, 17.4603, 12.3724, 17.7012, 13.6763, 16.9643],
        [30.7601, 28.3535, 18.5560, 15.0316, 13.0195, 16.5308, 20.0552],
        [18.6744, 19.7922, 26.0521, 36.4150, 26.6570,  0.0000, 19.7659],
        [19.9688, 3

batch_predictions
tensor([[14.3901, 15.9279, 16.1136, 15.7195, 15.2758, 14.9240, 14.6866],
        [14.9194, 14.8305, 15.1144, 15.3986, 15.6039, 15.7423, 15.8362],
        [24.0951, 29.2491, 28.5517, 27.0241, 24.7931, 21.6123, 19.8828],
        [29.9538, 29.1231, 28.9738, 29.0998, 29.3157, 29.5606, 29.8073],
        [22.1493, 21.8468, 21.8478, 21.9333, 22.0249, 22.1129, 22.1977],
        [13.1925, 15.9760, 17.5821, 18.1396, 18.1207, 17.9204, 17.7071],
        [19.0295, 20.3131, 20.0216, 19.3404, 18.6301, 18.0546, 17.6949],
        [10.1058, 11.4428, 13.0511, 15.3565, 16.9553, 16.2803, 15.5819]])
ground truth
tensor([[15.6234, 20.8969, 16.2809,  7.0095,  9.9947,  9.3635, 10.2709],
        [17.5171, 25.8285, 16.5965, 10.4287,  9.4555, 23.4219, 18.0431],
        [23.6718, 35.0342, 15.3077, 49.8948, 23.3167, 28.6691, 25.3419],
        [49.4898, 34.3537,  0.0000, 17.3328, 20.5924, 25.5527, 33.3900],
        [30.6286, 35.6260,  1.9858, 11.2309, 16.7017, 14.1241, 19.8448],
        [ 0.0000,  

batch_predictions
tensor([[28.8964, 23.0108, 24.0570, 26.0144, 28.1261, 30.1514, 32.0025],
        [19.1699, 17.3356, 17.5166, 18.2881, 19.0176, 19.6350, 20.1518],
        [20.0278, 20.6968, 20.4638, 20.0799, 19.7476, 19.5296, 19.4316],
        [18.7577, 17.3902, 18.5848, 19.5540, 20.2471, 20.7026, 20.9914],
        [18.3969, 17.7498, 17.7210, 17.9204, 18.1602, 18.3819, 18.5722],
        [12.6318, 15.1276, 17.3050, 16.9399, 16.3205, 15.8145, 15.4729],
        [16.8146, 20.1505, 19.6227, 18.6069, 17.5846, 16.7429, 16.1926],
        [14.3040, 15.0196, 17.1140, 19.2349, 20.9720, 21.1719, 20.9561]])
ground truth
tensor([[79.2346, 34.7054,  8.9164, 16.4256, 14.6370, 31.2599, 58.8901],
        [13.0385, 14.8810, 12.0748, 15.7596, 22.6332, 32.3554, 21.4994],
        [15.8305, 14.6117, 35.0198, 13.7613, 28.1037, 15.6463, 11.8339],
        [33.6451, 14.3566, 11.7772, 12.5850, 16.2982, 20.7908, 28.6423],
        [22.0947, 19.0193,  8.9144, 10.2466, 11.2387,  9.7647,  9.6655],
        [ 8.6026, 1

tensor([[ 8.3377, 13.3482, 11.1915, 12.3948, 12.3027, 24.9737,  7.7196],
        [ 7.1003, 12.6276, 18.7358, 16.4966, 27.4235, 30.6831,  0.1559],
        [12.0068,  3.6428,  9.0084,  7.3514,  7.4171, 11.2046, 14.0978],
        [ 7.4972,  8.6168,  7.4405, 12.2874,  6.3209,  4.7194,  5.8248],
        [ 8.1273,  9.4161, 14.6502, 27.1305, 18.0957,  0.0000,  9.3503],
        [10.9410, 11.7347, 23.1718, 16.5675, 22.6190, 12.2307, 15.0510],
        [ 4.8895,  6.7460, 11.3379,  7.3838,  6.4626,  8.5034,  6.1933],
        [18.5941, 11.1820, 13.4637,  9.4813, 14.1723, 16.2415, 22.9592]])
batch_predictions
tensor([[24.9929, 33.7334, 36.5796, 35.6083, 32.2872, 24.1352, 19.5285],
        [24.6778, 26.9930, 27.0466, 26.2683, 24.9697, 23.3128, 21.6626],
        [16.6957, 15.5003, 16.3945, 17.0878, 17.5713, 17.8947, 18.1081],
        [16.8387, 21.4546, 21.2136, 20.2333, 19.0453, 17.9164, 17.1503],
        [17.2036, 21.7722, 20.5988, 18.4847, 16.1421, 14.8219, 14.7620],
        [19.8656, 21.2465, 20.79

batch_predictions
tensor([[18.8524, 17.3211, 17.7718, 18.3755, 18.8520, 19.2073, 19.4693],
        [11.8623, 10.7565,  9.4294,  7.9810,  7.2632,  7.0844,  7.0538],
        [16.4903, 19.7125, 19.8567, 19.2171, 18.5500, 18.0239, 17.6876],
        [10.0259, 11.0449, 12.0531, 13.4753, 15.9372, 17.5158, 16.6020],
        [15.4248, 18.2747, 18.4531, 17.7962, 17.1327, 16.6166, 16.2784],
        [11.4276, 12.4684, 13.6782, 14.9741, 15.3527, 15.0361, 14.5953],
        [16.9189, 19.1220, 18.4920, 17.4780, 16.5251, 15.7612, 15.2786],
        [20.5680, 23.8333, 23.1193, 21.4155, 18.9170, 16.7916, 16.5073]])
ground truth
tensor([[22.8316, 13.5629, 13.8322, 12.4008, 16.2982, 17.0210, 25.9921],
        [ 6.4484,  4.2800,  7.4972,  8.6168,  7.4405, 12.2874,  6.3209],
        [14.5408, 14.5550, 14.5125, 16.5108, 22.5907, 22.0380, 13.8464],
        [14.0715, 12.9274,  9.5608, 12.7564, 19.2530, 16.7675,  0.0000],
        [21.8396, 31.0941, 27.3384, 10.6718, 13.1378, 11.7772, 14.3141],
        [13.1247, 1

batch_predictions
tensor([[14.0748, 16.4938, 17.3328, 16.7721, 16.0892, 15.5101, 15.0689],
        [15.7796, 18.3134, 18.3974, 17.7759, 17.1524, 16.6729, 16.3634],
        [ 9.5834, 10.0176, 10.3906, 10.7536, 11.1570, 11.6952, 12.5679],
        [11.5215, 13.1610, 15.2000, 16.8699, 16.4095, 15.8384, 15.4028],
        [20.0002, 20.5102, 20.2152, 19.7384, 19.2983, 18.9894, 18.8413],
        [17.5368, 19.7948, 19.3254, 18.5544, 17.8528, 17.3193, 16.9884],
        [19.6665, 24.9070, 23.6797, 21.2602, 18.0310, 16.3840, 16.5997],
        [27.8241, 34.8958, 35.0245, 33.7253, 31.5199, 28.0216, 23.4091]])
ground truth
tensor([[12.7038, 12.9669, 17.1357, 20.1604, 25.4603, 10.7575, 11.5992],
        [13.3219, 18.3325, 22.9748, 12.5197, 11.3887,  8.4955, 11.3098],
        [ 8.5350, 10.1789, 34.1136, 45.1999,  0.0000, 11.0468,  9.4950],
        [12.3356, 23.7507,  3.2351,  5.6549, 11.5860,  9.6002,  9.0347],
        [25.8548, 31.4177,  5.1815, 16.7938,  5.7075, 15.7812, 16.2415],
        [15.0794, 2

batch_predictions
tensor([[20.6760, 24.8700, 23.9970, 22.2706, 19.8993, 17.6549, 17.2853],
        [ 7.9131,  7.7080,  7.5902,  7.5307,  7.5065,  7.5023,  7.5010],
        [25.4593, 26.8368, 26.6967, 26.1970, 25.6331, 25.1373, 24.7739],
        [20.0319, 17.8696, 19.3773, 20.8040, 21.9551, 22.7191, 23.1617],
        [15.7271, 17.7457, 17.7101, 17.1396, 16.5871, 16.1756, 15.9193],
        [20.8623, 20.7451, 20.7683, 20.8022, 20.8291, 20.8565, 20.8890],
        [17.0511, 22.0709, 21.6073, 20.0210, 17.7351, 15.8205, 15.4153],
        [37.4593, 34.3173, 33.5850, 34.1236, 35.0925, 36.0087, 36.6641]])
ground truth
tensor([[16.3549, 15.3345, 22.6899, 38.6196,  8.4184, 12.9819, 14.3849],
        [21.4361, 24.6055, 14.4924, 16.4519, 13.3219, 14.5187, 18.2667],
        [20.4507, 23.7670, 21.2018, 24.2772, 33.4467, 26.0771, 21.3294],
        [23.2426,  5.3288, 16.4257, 11.0686, 14.4274, 21.9246, 35.7001],
        [15.5313, 15.0973,  9.3898,  0.0000,  0.4471,  7.2330,  0.0000],
        [11.4654, 1

batch_predictions
tensor([[10.9599, 13.1051, 15.6644, 18.7694, 18.5110, 17.6093, 16.7087],
        [11.0359, 12.8435, 14.6408, 16.5793, 17.3967, 17.1779, 16.7665],
        [15.9910, 18.2343, 18.2505, 17.6777, 17.1156, 16.6956, 16.4351],
        [21.3703, 21.0687, 21.0692, 21.1663, 21.2767, 21.3857, 21.4895],
        [17.8662, 21.1121, 20.6770, 19.7819, 18.8837, 18.1542, 17.6815],
        [18.0306, 20.2576, 19.6784, 18.6459, 17.5574, 16.5874, 15.9491],
        [22.0175, 26.3760, 25.5447, 23.6120, 20.5131, 17.7115, 17.6144],
        [19.7790, 19.7137, 19.7245, 19.7338, 19.7365, 19.7436, 19.7589]])
ground truth
tensor([[ 3.7612,  3.3798,  8.4429, 11.6386, 15.0053, 17.4908,  6.3388],
        [ 7.1541,  3.1168,  9.9553, 25.1052, 29.4450, 27.8406,  9.5871],
        [11.3887,  8.4955, 11.3098, 18.7533, 26.3545, 28.5245, 15.0053],
        [24.0400, 32.6407,  6.4703, 15.3077, 15.5576, 17.5302, 25.6970],
        [17.3330, 22.2120, 21.3309, 26.4335, 13.1510,  3.7612, 10.0605],
        [ 8.5744, 1

tensor([[ 9.7506,  4.5918,  3.9966,  4.9461,  5.6406,  5.9807,  8.9144],
        [ 2.1831, 10.7312, 11.4676, 13.7691, 20.3314, 17.7012,  6.6807],
        [21.8963, 24.9150, 20.4082, 24.6457, 30.6548, 41.8651, 41.1706],
        [18.1220, 28.2746, 24.9079, 29.5634, 34.5739, 45.1604, 18.3456],
        [42.8713, 35.3600, 16.1706, 11.4796, 15.0368, 24.1780, 33.7160],
        [ 9.1268,  6.7859, 12.7959, 12.5592, 17.4250, 21.4361,  8.6665],
        [16.2020,  7.1147, 12.0068,  4.9448,  9.5739, 12.4803, 18.8585],
        [19.0295, 27.5118, 32.4829, 11.0074, 13.0589, 16.5308, 11.3756]])
batch_predictions
tensor([[17.9269, 17.8578, 17.8838, 17.9105, 17.9287, 17.9467, 17.9681],
        [21.4557, 22.8661, 22.5223, 21.6690, 20.6432, 19.6598, 18.9364],
        [18.9412, 18.8283, 18.8413, 18.8628, 18.8777, 18.8937, 18.9149],
        [12.9921, 12.5195, 12.2915, 12.1850, 12.1649, 12.2201, 12.3317],
        [17.8444, 18.7983, 18.5068, 17.9514, 17.4309, 17.0549, 16.8531],
        [19.3142, 21.3508, 20.86

batch_predictions
tensor([[38.3163, 38.3162, 38.3162, 38.3161, 38.3158, 38.3157, 38.3157],
        [28.0717, 32.5993, 32.7363, 31.5815, 29.6247, 26.5441, 22.7596],
        [15.7565, 17.4529, 17.4090, 16.9357, 16.4891, 16.1712, 15.9880],
        [15.7323, 19.8631, 22.7432, 22.8577, 22.2733, 21.5737, 20.9679],
        [21.2532, 28.7907, 28.1979, 26.1098, 22.0052, 17.5745, 17.8091],
        [17.3846, 18.0373, 17.8776, 17.5445, 17.2537, 17.0678, 16.9856],
        [20.9187, 27.4265, 32.4229, 33.2183, 32.2636, 30.1373, 25.8488],
        [20.1533, 20.6952, 20.4712, 20.1558, 19.9097, 19.7667, 19.7163]])
ground truth
tensor([[45.5155, 53.4850, 55.9179, 38.1378, 40.7286, 22.9090, 46.3177],
        [43.3532, 41.5533, 32.8656, 15.5329, 17.9847, 18.8209, 19.7988],
        [16.0047, 22.6854, 23.3035,  9.1268, 11.8622,  9.9684, 10.4419],
        [ 7.2199, 17.6355, 15.8206, 15.5050, 19.8974, 31.9306,  4.0505],
        [14.3141, 23.0300, 45.5782, 37.4717, 22.8175, 18.1973, 13.1236],
        [21.1026, 2

batch_predictions
tensor([[19.1608, 23.7641, 22.8315, 20.7954, 18.0857, 16.2096, 16.1897],
        [17.2712, 18.5797, 18.1889, 17.4053, 16.6163, 15.9849, 15.5957],
        [22.4045, 20.4477, 20.1183, 20.6909, 21.4845, 22.2855, 23.0509],
        [22.3519, 20.1296, 21.3131, 22.4504, 23.3517, 24.0129, 24.4675],
        [17.4670, 19.2859, 18.9076, 18.3077, 17.8061, 17.4686, 17.2967],
        [15.9533, 20.2338, 20.1696, 19.0379, 17.5848, 16.1009, 15.0220],
        [17.8304, 16.7275, 16.7275, 17.1008, 17.4742, 17.7853, 18.0349],
        [13.7879, 13.4163, 13.4479, 13.6370, 13.8897, 14.1530, 14.3866]])
ground truth
tensor([[15.0842, 15.7943, 18.1352, 23.6586, 39.3083, 29.1426, 19.5292],
        [14.7159, 14.4924, 12.1515, 20.7128,  8.2062, 15.1368, 12.1515],
        [33.4042, 16.5391, 10.2324,  9.3821, 14.3566, 18.1548, 32.7523],
        [22.4348, 24.7307, 28.9116, 23.0159, 21.3435, 25.3827, 31.5618],
        [13.2937, 19.4586, 28.3305, 25.7086, 13.8039, 15.0935, 23.4127],
        [33.5317, 3

batch_predictions
tensor([[26.4794, 22.4454, 23.2945, 24.7215, 26.0995, 27.2714, 28.1608],
        [21.3469, 22.4208, 22.2031, 21.6276, 21.0147, 20.5370, 20.2361],
        [17.6994, 20.4841, 19.8819, 18.7890, 17.6452, 16.6242, 15.9265],
        [24.3478, 32.0240, 33.5073, 32.0931, 29.2557, 24.3683, 21.1520],
        [21.4585, 29.8022, 30.1548, 29.0585, 27.1224, 23.3742, 19.1019],
        [14.3460, 13.9601, 13.9835, 14.1283, 14.2972, 14.4554, 14.5885],
        [15.6163, 20.2017, 20.9062, 20.0341, 18.9012, 17.6956, 16.6413],
        [11.9975, 15.7776, 20.9145, 23.7011, 23.0735, 21.7456, 20.0775]])
ground truth
tensor([[33.6026,  0.0000, 10.6293, 19.5437, 17.3044, 22.1514, 36.2812],
        [24.2489, 24.4898, 12.2449, 14.6825, 14.1723, 12.1032, 18.4807],
        [ 9.7931, 17.0777, 13.4779, 32.8656, 22.9592, 10.4167, 15.0227],
        [14.6400, 24.2772, 33.6168, 46.8679,  9.8498, 12.5567, 16.1706],
        [23.4127, 39.5125,  6.8736,  7.3838, 20.8192, 11.5930, 15.8022],
        [14.0321,  

batch_predictions
tensor([[36.4363, 37.6342, 37.5869, 37.4495, 37.3413, 37.2727, 37.2355],
        [11.3721, 12.9281, 14.6426, 16.6456, 16.8735, 16.2907, 15.7502],
        [27.6403, 28.0762, 28.0666, 27.9161, 27.7427, 27.6090, 27.5382],
        [16.4756, 15.2959, 15.8145, 16.3787, 16.7937, 17.0870, 17.2936],
        [20.7028, 27.1918, 26.1701, 23.5070, 19.0133, 16.8180, 17.1103],
        [17.7237, 17.9961, 17.8954, 17.7292, 17.5972, 17.5234, 17.4999],
        [15.4378, 15.6539, 15.5983, 15.4885, 15.4008, 15.3531, 15.3376],
        [19.2818, 20.7324, 20.3799, 19.7224, 19.0891, 18.5976, 18.2982]])
ground truth
tensor([[39.1298, 59.1128, 46.8821, 34.0561, 23.5686, 27.1684, 25.6519],
        [ 9.1005, 10.1131, 11.1915, 13.7296, 14.6107, 25.3288, 27.5776],
        [ 2.4093, 12.2307, 23.4127, 39.5125,  6.8736,  7.3838, 20.8192],
        [23.3035,  9.1268, 11.8622,  9.9684, 10.4419, 17.2278, 20.8574],
        [19.1185, 21.0317, 34.8639, 28.4722, 30.2863, 21.1593, 20.4507],
        [22.7117, 2

batch_predictions
tensor([[23.3573, 23.5937, 23.5393, 23.3983, 23.2596, 23.1670, 23.1309],
        [18.5532, 24.2130, 25.1992, 24.1218, 22.0412, 18.8364, 17.1704],
        [ 8.6403,  8.8162,  8.8854,  8.9028,  8.8942,  8.8768,  8.8645],
        [17.8932, 20.6798, 20.0980, 19.0940, 18.0822, 17.2117, 16.6064],
        [22.9323, 31.0417, 31.0086, 29.6478, 27.2681, 23.0949, 19.3857],
        [34.3855, 35.1159, 35.2179, 35.1142, 34.9514, 34.8081, 34.7142],
        [17.6991, 16.5022, 17.5623, 18.3643, 18.9118, 19.2695, 19.5028],
        [19.9124, 26.8376, 26.4872, 24.8411, 21.6774, 18.1709, 17.8122]])
ground truth
tensor([[21.1026, 37.2024, 39.8243, 10.0198, 14.9235, 22.3781, 19.3027],
        [12.8968, 13.5771, 10.0624, 17.0068, 31.5760, 42.7863, 30.5981],
        [11.9937, 15.0710, 15.6628,  8.5087,  7.9300,  5.9442,  8.6928],
        [15.6102, 15.8075, 17.2541, 28.0642, 31.3914, 13.7822, 16.7543],
        [25.9070, 42.4745, 41.1565, 12.5850, 21.3010, 23.5402, 17.7154],
        [56.7596, 5

batch_predictions
tensor([[38.3425, 38.3170, 38.3252, 38.3286, 38.3301, 38.3313, 38.3324],
        [35.6460, 37.1407, 37.1617, 36.9836, 36.8050, 36.6759, 36.6000],
        [18.5056, 18.8456, 18.7402, 18.5276, 18.3409, 18.2243, 18.1776],
        [24.2084, 32.6935, 32.7397, 30.5644, 25.9083, 18.9949, 18.9871],
        [22.8709, 20.2159, 20.8806, 21.9765, 22.9643, 23.7601, 24.3631],
        [10.6591, 10.2521,  9.7334,  9.1230,  8.5444,  8.1494,  7.9491],
        [12.9780, 16.2367, 18.7760, 17.6275, 16.2627, 14.9602, 13.7482],
        [15.2931, 17.3112, 17.1958, 16.4858, 15.7609, 15.1531, 14.6915]])
ground truth
tensor([[52.1699, 63.6376, 64.6107, 41.6491, 42.5829, 45.9890, 42.5565],
        [31.0091, 42.5454, 67.6446, 59.4388, 35.5726, 34.4246, 21.9104],
        [20.7391, 17.7538, 11.2704, 10.8759, 13.9400, 14.9921, 15.5050],
        [17.3753, 32.1854, 55.2721, 40.0652, 16.5816, 21.7545, 18.0981],
        [31.3519, 22.5671, 19.6476, 16.4913, 17.6092, 33.2062, 34.9816],
        [10.5471, 1

batch_predictions
tensor([[18.2402, 19.0969, 18.9114, 18.5428, 18.2262, 18.0227, 17.9298],
        [20.8251, 21.3116, 21.1973, 20.9581, 20.7452, 20.6072, 20.5467],
        [15.4739, 18.7322, 19.1431, 18.5457, 17.9502, 17.5081, 17.2473],
        [19.4740, 19.0987, 19.0933, 19.1905, 19.2935, 19.3868, 19.4710],
        [25.6109, 27.0227, 26.8470, 26.3774, 25.9087, 25.5460, 25.3227],
        [ 8.9671,  7.6135,  6.9977,  6.8510,  6.8381,  6.8521,  6.8655],
        [19.8423, 18.3950, 19.7245, 20.9428, 21.9607, 22.7580, 23.3208],
        [23.6945, 32.7453, 33.4388, 30.7789, 23.9927, 17.2314, 17.9141]])
ground truth
tensor([[13.0669, 13.8889, 16.5675, 20.2806,  0.0000,  0.0000,  9.8498],
        [29.6769, 30.7256, 16.7375, 12.0465, 12.2732, 16.7659, 21.1593],
        [ 8.4467, 19.9121, 23.8237, 21.4144, 12.0748, 14.1015, 11.0261],
        [33.0089, 27.3277, 14.9264, 19.4897, 12.3093, 14.4266, 20.7522],
        [20.9100, 29.1689, 32.1278, 23.1720, 25.7759, 17.2935, 18.1878],
        [15.6463,  

batch_predictions
tensor([[18.4324, 19.2092, 18.7919, 18.2295, 17.7667, 17.4738, 17.3562],
        [17.0851, 16.3691, 17.7905, 19.0752, 20.0894, 20.7209, 21.0076],
        [16.0268, 19.7119, 19.9276, 18.8251, 17.4491, 16.1003, 15.1170],
        [22.6423, 30.7933, 35.7162, 35.3834, 33.2568, 28.1030, 20.0753],
        [18.8497, 17.9818, 17.8700, 18.1020, 18.4359, 18.7774, 19.0955],
        [ 7.5017,  7.0140,  6.8241,  6.7673,  6.7584,  6.7609,  6.7638],
        [15.2833, 18.0071, 18.3189, 17.6916, 17.0519, 16.5574, 16.2354],
        [24.8533, 25.8627, 25.7601, 25.3686, 24.9398, 24.5890, 24.3627]])
ground truth
tensor([[17.9705, 24.5607, 29.8895, 31.8027,  6.1933, 10.5867, 17.9989],
        [23.1009,  6.5476,  9.8781, 11.3379, 11.5079, 33.5317, 33.4467],
        [10.8560, 11.7772, 10.4308, 13.2937, 20.3656, 36.5363, 21.5703],
        [31.2599, 58.8901, 47.5802, 12.2830, 14.0058, 14.4135, 12.1383],
        [28.9847, 29.7344, 16.2152, 17.7538, 17.7275, 13.7428, 18.5429],
        [10.9942, 1

batch_predictions
tensor([[14.9755, 14.9061, 14.9483, 14.9906, 15.0218, 15.0498, 15.0769],
        [17.8152, 21.9455, 20.7959, 18.7072, 16.2553, 14.7750, 14.7314],
        [38.3157, 38.3865, 38.3816, 38.3763, 38.3730, 38.3715, 38.3704],
        [12.9096, 12.4576, 12.2146, 12.0716, 11.9916, 11.9577, 11.9565],
        [17.7748, 21.2495, 20.6662, 19.3675, 17.8758, 16.5404, 15.8046],
        [17.6481, 22.1831, 21.5909, 20.0313, 18.0018, 16.3440, 15.8232],
        [20.4514, 19.0769, 18.8775, 19.2296, 19.7018, 20.1563, 20.5634],
        [14.6247, 18.9158, 23.2282, 24.2871, 23.5628, 22.0845, 20.0442]])
ground truth
tensor([[22.1331, 23.2641,  9.7317, 13.5587, 10.4156, 13.3614, 18.2799],
        [ 9.5522, 16.1706, 17.0493, 28.8974, 29.7194, 11.1678, 13.4921],
        [40.9864, 57.5680, 33.1207, 20.5357, 16.6100, 18.3248, 19.6003],
        [ 9.4292,  9.5082,  7.8511,  8.8638, 10.9679, 16.6886, 14.1767],
        [20.5924, 25.3968, 28.3588, 10.8135,  8.6876,  9.1978, 13.0527],
        [12.8543, 1

batch_predictions
tensor([[10.5284, 11.7709, 13.1474, 15.2530, 17.2377, 16.6233, 15.8733],
        [13.1993, 16.3365, 20.1713, 19.5857, 18.4592, 17.3312, 16.5054],
        [16.2667, 20.9914, 22.1162, 21.3599, 20.2341, 18.8884, 17.5550],
        [21.3115, 19.1641, 19.6092, 20.4482, 21.1772, 21.7484, 22.1796],
        [ 6.3610,  6.5043,  6.5221,  6.5189,  6.5161,  6.5050,  6.5006],
        [26.3256, 29.9543, 29.6329, 28.6791, 27.4726, 25.9984, 24.2834],
        [18.8184, 25.4119, 25.2055, 23.6556, 20.6328, 17.2740, 17.0484],
        [20.3441, 23.4694, 22.8162, 21.5764, 20.0988, 18.5955, 17.7067]])
ground truth
tensor([[16.0856, 11.6780, 14.2149, 22.2080, 19.5295, 10.4592, 14.9235],
        [ 8.1931, 12.0594, 23.1983, 25.9863, 35.4024, 10.6128,  8.8769],
        [14.1582, 22.2647, 30.3713, 24.2489, 31.9019, 28.4297, 13.7046],
        [29.6344, 12.0040, 14.7251, 15.4195, 19.8271, 22.2222, 27.7636],
        [ 4.6291,  5.3524, 12.6775, 14.3609, 11.4150,  6.0757,  6.4045],
        [27.8538, 3

batch_predictions
tensor([[25.9853, 31.1061, 30.7113, 29.6441, 28.3822, 26.8818, 24.6640],
        [25.9626, 26.3909, 26.3244, 26.1794, 26.0582, 25.9866, 25.9603],
        [11.0017, 13.0984, 15.1034, 16.6922, 17.2463, 17.2197, 17.0393],
        [37.6264, 37.5105, 37.5235, 37.5297, 37.5327, 37.5396, 37.5514],
        [19.0584, 19.9883, 19.7372, 19.0981, 18.3903, 17.7966, 17.4179],
        [14.0625, 17.7285, 19.3539, 18.1884, 16.8121, 15.4302, 14.2253],
        [11.5939, 13.8251, 16.9055, 18.0643, 17.2466, 16.4784, 15.9172],
        [30.1812, 28.0405, 27.4209, 27.6717, 28.3254, 29.1015, 29.8721]])
ground truth
tensor([[27.6644, 30.0737, 36.8481, 50.2126, 12.7551, 15.4337, 23.6395],
        [20.7058, 23.4410, 26.3464, 25.0850, 32.9507, 16.3124, 19.0618],
        [14.7392, 14.1015, 15.0652, 19.6145, 25.0142,  9.7647,  3.2738],
        [54.6202, 49.0079, 31.3917, 28.8549, 22.5057, 29.1241, 37.3866],
        [28.0905, 27.4592,  8.3377, 13.3482, 11.1915, 12.3948, 12.3027],
        [ 9.2057, 1

batch_predictions
tensor([[18.6184, 22.9726, 21.5991, 19.5339, 17.3422, 16.0151, 16.0056],
        [19.6932, 17.7477, 18.1327, 18.9444, 19.6497, 20.2035, 20.6273],
        [21.5819, 19.5745, 21.4235, 23.3652, 25.2444, 26.9082, 28.0199],
        [11.2195, 12.4533, 14.1612, 16.1593, 16.1545, 15.5526, 14.9479],
        [27.5797, 22.7940, 24.0743, 25.8336, 27.5569, 29.0622, 30.2150],
        [19.8040, 18.0615, 19.3521, 20.4822, 21.3727, 22.0067, 22.4104],
        [21.3536, 19.4609, 19.8490, 20.7411, 21.5407, 22.1821, 22.6821],
        [12.5077, 14.4649, 16.5947, 16.8642, 16.3760, 15.9363, 15.6340]])
ground truth
tensor([[ 2.3277, 14.9395, 23.0274, 29.5371, 28.3535,  3.8269,  9.6265],
        [12.3866, 11.7489, 11.1961, 20.0397, 26.5448, 24.8866, 26.0771],
        [14.9802, 14.2857, 10.4025, 18.3815, 27.1825, 26.0913, 27.6502],
        [12.9932, 22.1462, 22.4487, 10.0999, 11.8096,  9.2451,  8.7454],
        [38.3645, 17.9280, 18.0414, 16.9785, 22.2931, 36.5221, 42.7863],
        [14.0448, 1

batch_predictions
tensor([[36.2520, 36.3922, 36.3932, 36.3590, 36.3275, 36.3134, 36.3159],
        [17.0430, 21.1992, 21.3579, 20.6928, 19.8952, 19.1675, 18.6294],
        [16.9447, 21.5590, 21.6362, 20.7603, 19.5709, 18.2932, 17.2783],
        [29.2462, 28.8100, 28.8303, 28.9277, 29.0257, 29.1134, 29.1923],
        [23.4296, 29.0157, 28.2555, 26.1215, 22.0840, 17.9584, 18.1250],
        [ 7.4874,  7.5631,  7.5801,  7.5784,  7.5735,  7.5590,  7.5514],
        [11.5380, 13.8056, 16.1947, 17.9257, 18.1276, 17.7674, 17.3466],
        [16.2777, 17.6403, 17.2775, 16.6171, 16.0192, 15.5767, 15.3135]])
ground truth
tensor([[54.6202, 46.4002, 33.4751, 26.0913, 21.0176, 29.5210, 39.1298],
        [12.8880, 11.1915, 18.8585, 24.8422,  1.4861,  0.0000, 13.7822],
        [20.0397, 26.5448, 24.8866, 26.0771,  7.8090, 10.5017, 14.7392],
        [20.4103, 23.3561, 26.1573, 32.8643, 35.7312, 23.0800, 25.5655],
        [17.0635, 31.1366,  0.6236, 22.7749, 26.4881, 15.5612, 16.5108],
        [ 8.6928, 1

batch_predictions
tensor([[19.7079, 21.1487, 20.6311, 19.9035, 19.2437, 18.7436, 18.4513],
        [21.2092, 26.2578, 25.4429, 23.1622, 19.2768, 16.8451, 17.0719],
        [21.3137, 27.3300, 26.3208, 23.7923, 19.2296, 16.5596, 16.8733],
        [13.6585, 16.7017, 19.4812, 18.5223, 17.2685, 16.1352, 15.2386],
        [22.2343, 22.4217, 22.4190, 22.3253, 22.2130, 22.1279, 22.0869],
        [21.4733, 21.4613, 21.4607, 21.4596, 21.4580, 21.4560, 21.4668],
        [17.7892, 22.5658, 23.7495, 21.9313, 18.2385, 15.1352, 14.8216],
        [23.4149, 20.5983, 20.9156, 21.9754, 23.0083, 23.8950, 24.6228]])
ground truth
tensor([[13.0102, 21.1026, 24.3197, 18.2540, 13.8180, 12.7409, 11.8056],
        [17.2902, 12.9677, 12.0890, 24.2347, 28.6848, 41.6525, 22.2222],
        [22.3781, 19.3027, 25.9779, 33.6735, 41.1423,  7.4688, 18.0556],
        [11.0205,  7.5881, 10.5471, 11.3230, 24.2109, 29.6028,  8.2720],
        [39.3740, 16.0179, 18.2930, 21.4624, 14.6633, 21.6728, 28.5113],
        [18.0272, 2

batch_predictions
tensor([[17.5629, 21.5274, 20.0976, 17.7681, 15.1275, 13.9333, 13.9012],
        [38.3899, 38.3587, 38.3666, 38.3698, 38.3711, 38.3720, 38.3730],
        [24.9087, 28.4616, 28.3253, 27.2843, 25.6301, 23.2812, 20.9403],
        [21.8366, 19.4721, 19.5490, 20.4864, 21.4518, 22.3171, 23.0788],
        [17.9087, 17.5730, 17.6024, 17.7250, 17.8526, 17.9649, 18.0606],
        [18.9633, 18.1753, 18.0533, 18.2212, 18.4778, 18.7446, 18.9961],
        [14.6515, 14.2414, 14.3246, 14.5710, 14.8478, 15.0987, 15.2981],
        [19.1450, 19.4013, 19.3383, 19.1968, 19.0711, 18.9932, 18.9636]])
ground truth
tensor([[ 9.1662, 15.3472, 20.3840, 23.9085,  5.0237, 10.7180,  8.3377],
        [42.5039, 56.3388, 23.0668,  0.0000, 29.9579, 43.4377, 46.8701],
        [22.9748, 43.0037, 42.7012, 17.3330, 18.4903, 15.9390, 22.3304],
        [36.0969,  6.7460, 14.0590, 14.7251, 18.8067, 21.1876, 37.0040],
        [26.8141, 26.7432, 17.0777, 12.3441, 15.3628, 13.0102, 21.1026],
        [23.0867, 1

batch_predictions
tensor([[23.5074, 25.8062, 25.4845, 24.5385, 23.2999, 21.9071, 20.6611],
        [11.8905, 13.8543, 16.5695, 17.8503, 16.9927, 16.1418, 15.4808],
        [17.3331, 16.8840, 16.8811, 17.0167, 17.1662, 17.3013, 17.4181],
        [24.4895, 29.5520, 29.0227, 27.5336, 25.2343, 21.7544, 19.7968],
        [10.4884, 11.5283, 12.6438, 14.2265, 16.2702, 16.3914, 15.7870],
        [29.0145, 28.4841, 28.3836, 28.4622, 28.5997, 28.7761, 28.9811],
        [19.0618, 19.7135, 19.3871, 18.8428, 18.3432, 17.9965, 17.8316],
        [18.3967, 22.6904, 22.1921, 21.1268, 19.9289, 18.9009, 18.3013]])
ground truth
tensor([[14.7422, 15.1631, 25.1578, 26.9069, 27.9853, 15.4787, 17.4382],
        [11.3361, 13.3877, 14.4266, 16.2283, 24.3556, 23.2378,  6.5623],
        [11.1961, 12.6701, 15.4337, 15.4195, 15.5612, 27.0408, 17.0210],
        [18.7533, 22.0410, 29.4845, 44.7002, 34.2451, 21.2388, 20.1078],
        [ 9.0986,  6.9728, 11.4371, 20.5074, 17.1202,  4.6910, 10.3175],
        [51.7715, 3

batch_predictions
tensor([[28.6590, 30.1266, 30.1227, 29.7369, 29.2613, 28.8291, 28.5020],
        [30.7605, 36.6681, 36.7319, 36.1108, 35.3245, 34.4918, 33.6482],
        [13.7242, 13.3830, 14.4571, 15.5258, 16.2304, 16.5435, 16.6681],
        [15.7956, 15.8385, 15.8223, 15.7830, 15.7482, 15.7326, 15.7350],
        [15.8835, 19.6559, 20.8388, 20.1222, 19.2337, 18.4131, 17.7684],
        [15.4641, 19.4472, 21.0162, 20.3168, 19.4104, 18.5510, 17.8422],
        [16.3740, 18.6976, 18.2902, 17.3790, 16.4904, 15.7603, 15.2475],
        [20.8158, 20.8549, 20.8480, 20.8057, 20.7582, 20.7297, 20.7279]])
ground truth
tensor([[16.7375, 18.1548, 19.3169, 25.5952, 47.6332, 14.7534,  0.0000],
        [24.8299, 29.9178, 44.8413, 49.2914,  7.4405, 13.5771, 21.7545],
        [18.6612,  4.7212,  7.8248,  9.1399, 10.6128, 11.7964, 23.1983],
        [17.4119, 16.7149,  6.3782,  8.9164, 10.7180, 14.1899, 25.9469],
        [12.0465, 19.2885, 30.4280, 22.6616, 11.3237, 14.3566, 11.1395],
        [11.8359, 1

tensor(75.5777)

In [None]:
e, d, t, gn, ea, da, dea = se, d, t, gn, ea, da, dea = sequence_iter(dl, 3)
print(gn)

In [None]:
t = torch.Tensor([[1, np.nan], [2, 2], [np.nan, np.nan]]).isnan().nonzero().T
print(t)
t[1][t[0] == 0]


In [None]:
d = SequenceDatasetForTransformer(train_sequences)
dl = DataLoader(d, batch_size=8, shuffle=True)

In [None]:
e, d, t, ea, da, dea = next(iter(dl))

In [None]:
t

In [None]:
ea.repeat(3, 1, 1).shape

In [None]:
process_batch_nulls(e,d,t)

In [None]:
# testing this vanilla transformer. 
# Note that to batch, examples within a single batch must have same enc and dec input lengths
test_transformer = Transformer()
test_encoder_x = torch.unsqueeze(torch.Tensor([[1, 2, 3], [7, 8, 9]]).T, -1)
test_decoder_x = torch.unsqueeze(torch.Tensor([[0, 4], [0, 10]]).T, -1)
ground_truth = torch.Tensor([[4, 5], [10, 11]]).T
test_dec_output = torch.squeeze(test_transformer(test_encoder_x, test_decoder_x), -1)
# ground truth and output are now output_len x bs
test_loss = F.mse_loss(test_dec_output, ground_truth)
test_loss.backward()


In [None]:
input1 = torch.Tensor([0, 1, 2, 3, 4])
input2 = torch.Tensor([0, 1, 2, 3, 4])
w1 = torch.nn.Parameter(torch.Tensor([[0, 1, 2, 3, 4],
                                     [0, 1, 2, 3, 4],
                                     [0, 1, 2, 3, 4],
                                     [0, 1, 2, 3, 4], 
                                      [0, 1, 2, 3, 4]]), requires_grad=True)
output1 = input1 @ w1
output2 = input2 @ w1
target1 = torch.Tensor([3, 4, 1, np.nan, 4])
target2 = torch.Tensor([3, np.nan, 1, np.nan, 4])
loss1 = F.mse_loss(output1, target1, reduction='none')
loss_without_nan1 = loss1[~torch.isnan(loss1)]
loss2 = F.mse_loss(output2, target2, reduction='none')
loss_without_nan2 = loss2[~torch.isnan(loss2)]
losses = torch.cat((loss_without_nan1, loss_without_nan2), 0)
final_loss = torch.sum(losses) / (len(loss_without_nan1) + len(loss_without_nan2))
final_loss.backward()


In [None]:
w1.grad

In [None]:
input1 = torch.Tensor([2, 1, 2, 3, 4])
input2 = torch.Tensor([3, 1, 2, 3, 4])
w1 = torch.nn.Parameter(torch.Tensor([[5, 1, 2, 3, 4],
                                     [5, 1, 2, 3, 4],
                                     [6, 1, 2, 3, 4],
                                     [8, 1, 2, 3, 4], 
                                      [9, 1, 2, 3, 4]]), requires_grad=True)
output1 = input1 @ w1
output2 = input2 @ w1
target1 = torch.Tensor([3, 4, np.nan, 4, 4])
indices1 = list((~target1.isnan()).nonzero(as_tuple=True))
output1_nonull = output1[indices1]
target1_nonull = target1[indices1]
target2 = torch.Tensor([3, np.nan, 1, np.nan, 4])
indices2 = list((~target2.isnan()).nonzero(as_tuple=True))
output2_nonull = output2[indices2]
target2_nonull = target2[indices2]
loss1 = F.mse_loss(output1_nonull, target1_nonull)
loss2 = F.mse_loss(output2_nonull, target2_nonull)
final_loss = (loss1 + loss2) / 2
final_loss.backward()


In [None]:
loss1

In [None]:
# pytorch test

inputs = torch.Tensor([4, 0, 7])
weights_vec1 = torch.nn.Parameter(torch.Tensor([[5, 4, 2], [3, 4, 1], [5, 3, 2], [1, 2, 3]]), requires_grad=True)
weights_vec2 = torch.nn.Parameter(torch.Tensor([[1, 2, 3], [4, 5, 6], [2, 3, 2], [4, 3, 1]]), requires_grad=True)
vec3 = torch.Tensor([[1, 2, 3], [4, 5, 6], [2, 3, 2], [4, 3, 1]])
q = weights_vec1 * inputs
k = weights_vec2 * inputs
logits = (q.T @ k) + torch.Tensor([[0, float('-inf'), 0], [0, float('-inf'), 0], [0, float('-inf'), 0]])
logits = logits / 1000  # div by 1000 just for stability for demonstration
logits.retain_grad()
print('logits: ', logits, '\n')
soft = torch.nn.functional.softmax(logits, dim=1)
soft.retain_grad()
print('softmax: ', soft, '\n')
logits2 = (soft @ vec3.T) @ vec3 + torch.Tensor([[0, float('-inf'), 0], [0, float('-inf'), 0], [0, float('-inf'), 0]])
print('logits2: ', logits2, '\n')
soft2 = torch.nn.functional.softmax(logits2, dim=1)
print('softmax: ', soft2, '\n')
final_m = (soft2 @ vec3.T).T
print('final m: ', final_m, '\n')

final = torch.sum(final_m)  # just to get some scalar value at the end
print('final: ', final, '\n')
final.backward()
print('soft gradients: ', soft.grad, '\n')
print('logits gradients: ', logits.grad, '\n')
print('wq gradients: ', weights_vec1.grad, '\n')
print('wk gradients: ', weights_vec2.grad, '\n')

In [None]:
# pytorch test

inputs = torch.Tensor([4, 0, 7])
weights_vec1 = torch.nn.Parameter(torch.Tensor([5, 4, 2]), requires_grad=True)
weights_vec2 = torch.nn.Parameter(torch.Tensor([1, 2, 3]), requires_grad=True)
q = torch.unsqueeze(weights_vec1, 1) @ torch.unsqueeze(inputs, 1).T
k = torch.unsqueeze(weights_vec2, 1) @ torch.unsqueeze(inputs, 1).T
q.retain_grad()
k.retain_grad()
logits = (q.T @ k) + torch.Tensor([[0, float('-inf'), 0], [0, 0, 0], [0, float('-inf'), 0]])
logits = torch.cat([logits[:1], logits[2:]]) / 1000  # div by 1000 just for stability for demonstration
logits.retain_grad()
print('logits: ', logits, '\n')
soft = torch.nn.functional.softmax(logits, dim=1)
soft = torch.cat([soft[:1], torch.Tensor([[0, 0, 0]]), soft[1:]])
soft.retain_grad()
print('softmax: ', soft, '\n')
final = torch.sum(torch.prod(soft + 1, dim=0))  # just to get some scalar value at the end
print('final: ', final, '\n')
final.backward()
print('soft gradients: ', soft.grad, '\n')
print('logits gradients: ', logits.grad, '\n')
print('q gradients: ', q.grad, '\n')
print('k gradients: ', k.grad, '\n')
print('wq gradients: ', weights_vec1.grad, '\n')
print('wk gradients: ', weights_vec2.grad, '\n')

In [None]:
# for decoder w1 -- projecting from scalar to matrix dimensionwise

tens1 = torch.Tensor([1, 2])  # timesteps
tens2 = torch.Tensor([1, 2, 3, 4])  # pos encoding
tens1 = tens1.repeat(4, 1)
tens1.T * tens2

In [None]:
t1 = torch.Tensor([[1, 2], [1 ,2], [1, 2]])
t2 = torch.Tensor([[1, 4], [2 ,5], [3, 6], [4, 5]])


In [None]:
t1 = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
v = torch.nn.Parameter(torch.Tensor([3, 3.1, 3.2, 3.4]), requires_grad=True)
z = t1 * v
s = F.softmax(z, dim=1)
i_ = torch.multinomial(s, 1).squeeze()
print(i_)
selected = torch.cat((torch.unsqueeze(z[0][i_[0]], -1) , torch.unsqueeze(z[1][i_[1]], -1)))
target = torch.Tensor([4, 8])
l = torch.mean(target - z.T[i_])
l.backward()

In [None]:
z

In [None]:
v.grad

In [328]:
l = nn.LayerNorm(4)
t = torch.Tensor([[[1, 2, 3, 4], [2, 4, 6, 8]], [[1, 2, 3, 4], [2, 4, 4.2, 5.3]]])
l(t)

tensor([[[-1.3416, -0.4472,  0.4472,  1.3416],
         [-1.3416, -0.4472,  0.4472,  1.3416]],

        [[-1.3416, -0.4472,  0.4472,  1.3416],
         [-1.5752,  0.1050,  0.2730,  1.1971]]],
       grad_fn=<NativeLayerNormBackward>)

In [333]:
p = PositionalEncoding(64, max_len=5000)
p.pe

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00]],

        [[ 8.4147e-01,  5.4030e-01,  6.8156e-01,  ...,  1.0000e+00,
           1.3335e-04,  1.0000e+00]],

        [[ 9.0930e-01, -4.1615e-01,  9.9748e-01,  ...,  1.0000e+00,
           2.6670e-04,  1.0000e+00]],

        ...,

        [[ 9.5625e-01, -2.9254e-01,  6.4315e-01,  ...,  6.3049e-01,
           6.1813e-01,  7.8608e-01]],

        [[ 2.7050e-01, -9.6272e-01, -5.1133e-02,  ...,  6.3036e-01,
           6.1823e-01,  7.8599e-01]],

        [[-6.6395e-01, -7.4778e-01, -7.1816e-01,  ...,  6.3022e-01,
           6.1834e-01,  7.8591e-01]]])