In [1]:
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union
from types import FunctionType
import torch
import torch.nn as nn
import pytorch_lightning as pl

# Vision Transformer

In [2]:
class ConvStemConfig(NamedTuple):
    out_channels: int = 64
    kernel_size: int = 3
    stride: int = 2
    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
    activation_layer: Callable[..., nn.Module] = nn.ReLU

In [3]:
def log_api_usage_once(obj: Any) -> None:

    """
    Logs API usage(module and name) within an organization.
    In a large ecosystem, it's often useful to track the PyTorch and
    TorchVision APIs usage. This API provides the similar functionality to the
    logging module in the Python stdlib. It can be used for debugging purpose
    to log which methods are used and by default it is inactive, unless the user
    manually subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_.
    Please note it is triggered only once for the same API call within a process.
    It does not collect any data from open-source users since it is no-op by default.
    For more information, please refer to
    * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
    * Logging policy: https://github.com/pytorch/vision/issues/5052;

    Args:
        obj (class instance or method): an object to extract info from.
    """
    module = obj.__module__
    if not module.startswith("torchvision"):
        module = f"torchvision.internal.{module}"
    name = obj.__class__.__name__
    if isinstance(obj, FunctionType):
        name = obj.__name__
    torch._C._log_api_usage_once(f"{module}.{name}")

In [4]:
def make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
    """
    Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
    Otherwise, we will make a tuple of length n, all with value of x.
    reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8

    Args:
        x (Any): input value
        n (int): length of the resulting tuple
    """
    if isinstance(x, collections.abc.Iterable):
        return tuple(x)
    return tuple(repeat(x, n))

In [5]:
def expand_index_like(index: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Expands the index along the last dimension of the input tokens.

    Args:
        index:
            Index tensor with shape (batch_size, idx_length) where each entry is
            an index in [0, sequence_length).
        tokens:
            Tokens tensor with shape (batch_size, sequence_length, dim).

    Returns:
        Index tensor with shape (batch_size, idx_length, dim) where the original
        indices are repeated dim times along the last dimension.

    """
    dim = tokens.shape[-1]
    index = index.unsqueeze(-1).expand(-1, -1, dim)
    return index

In [6]:
def get_at_index(tokens: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Selects tokens at index.

    Args:
        tokens:
            Token tensor with shape (batch_size, sequence_length, dim).
        index:
            Index tensor with shape (batch_size, index_length) where each entry is
            an index in [0, sequence_length).

    Returns:
        Token tensor with shape (batch_size, index_length, dim) containing the
        selected tokens.

    """
    index = expand_index_like(index, tokens)
    return torch.gather(tokens, 1, index)

In [7]:
class ConvNormActivation(torch.nn.Sequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]] = 3,
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Optional[Union[int, Tuple[int, ...], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, ...]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
        conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
    ) -> None:

        if padding is None:
            if isinstance(kernel_size, int) and isinstance(dilation, int):
                padding = (kernel_size - 1) // 2 * dilation
            else:
                _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation)
                kernel_size = make_ntuple(kernel_size, _conv_dim)
                dilation = make_ntuple(dilation, _conv_dim)
                padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
        if bias is None:
            bias = norm_layer is None

        layers = [
            conv_layer(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        ]

        if norm_layer is not None:
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            params = {} if inplace is None else {"inplace": inplace}
            layers.append(activation_layer(**params))
        super().__init__(*layers)
        log_api_usage_once(self)
        self.out_channels = out_channels

        if self.__class__ == ConvNormActivation:
            warnings.warn(
                "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
            )

In [8]:
class Conv2dNormActivation(ConvNormActivation):
    """
    Configurable block used for Convolution2d-Normalization-Activation blocks.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
        kernel_size: (int, optional): Size of the convolving kernel. Default: 3
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        dilation (int): Spacing between kernel elements. Default: 1
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]] = 3,
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Optional[Union[int, Tuple[int, int], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, int]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups,
            norm_layer,
            activation_layer,
            dilation,
            inplace,
            bias,
            torch.nn.Conv2d,
        )

In [9]:
class MLP(torch.nn.Sequential):
    """This block implements the multi-layer perceptron (MLP) module.

    Args:
        in_channels (int): Number of channels of the input
        hidden_channels (List[int]): List of the hidden channel dimensions
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place.
            Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer.
        bias (bool): Whether to use bias in the linear layer. Default ``True``
        dropout (float): The probability for the dropout layer. Default: 0.0
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        inplace: Optional[bool] = None,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
        # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
        params = {} if inplace is None else {"inplace": inplace}

        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
            if norm_layer is not None:
                layers.append(norm_layer(hidden_dim))
            layers.append(activation_layer(**params))
            layers.append(torch.nn.Dropout(dropout, **params))
            in_dim = hidden_dim

        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
        layers.append(torch.nn.Dropout(dropout, **params))

        super().__init__(*layers)
        log_api_usage_once(self)

In [10]:
class MLPBlock(MLP):
    """Transformer MLP block."""

    _version = 2

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
            for i in range(2):
                for type in ["weight", "bias"]:
                    old_key = f"{prefix}linear_{i+1}.{type}"
                    new_key = f"{prefix}{3*i}.{type}"
                    if old_key in state_dict:
                        state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

In [11]:
class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y

In [12]:
class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        num_cls_tokens: int,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6)
    ):
        super().__init__()
        self.num_cls_tokens = num_cls_tokens
        self.seq_length = seq_length
        # Note that batch_size is on the first dim because
        # we have batch_first=True in nn.MultiAttention() by default

        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT

        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self,
                input: torch.Tensor,
                idx_keep: Optional[torch.Tensor] = None
               ):

        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")

        input = input + self._interpolate_pos_encoding(input,self.num_cls_tokens)
        if idx_keep is not None:
            input = get_at_index(input, idx_keep)
        return self.ln(self.layers(self.dropout(input)))

    def _interpolate_pos_encoding(self,
                                  input: torch.Tensor,
                                  num_cls_tokens: int = 3):
        """Returns the interpolated positional embedding for the given input.

        This function interpolates self.pos_embedding for all tokens in the input,
        ignoring the class token. This allows encoding variable sized images.

        Args:
            input:
               Input tensor with shape (batch_size, num_sequences).
            num_cls_tokens:
               number of classification tokens prepended to the
        """

        npatch = input.shape[1] - num_cls_tokens
        N = self.pos_embedding.shape[1] - num_cls_tokens
        diff = num_cls_tokens -1
        if npatch == N:
            return self.pos_embedding

        else:
            npatch += diff
            class_emb = self.pos_embedding[:, 1]
            pos_embedding = self.pos_embedding[:, 1+diff:]
        dim = input.shape[-1]
        pos_embedding = nn.functional.interpolate(
            pos_embedding.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=math.sqrt(npatch / N),
            mode="bicubic",
        )
        pos_embedding = pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1)

In [13]:
class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929."""
    # defualt settings vits16
    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 3,
        hidden_dim: int = 192,
        mlp_dim: int = 768,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        num_cls_tokens: int = 3,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
    ):
        super().__init__()
        log_api_usage_once(self)
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer
        self.num_cls_tokens = num_cls_tokens

        if conv_stem_configs is not None:
            # As per https://arxiv.org/abs/2106.14881
            seq_proj = nn.Sequential()
            prev_channels = 3
            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
                seq_proj.add_module(
                    f"conv_bn_relu_{i}",
                    Conv2dNormActivation(
                        in_channels=prev_channels,
                        out_channels=conv_stem_layer_config.out_channels,
                        kernel_size=conv_stem_layer_config.kernel_size,
                        stride=conv_stem_layer_config.stride,
                        norm_layer=conv_stem_layer_config.norm_layer,
                        activation_layer=conv_stem_layer_config.activation_layer,
                    ),
                )
                prev_channels = conv_stem_layer_config.out_channels
            seq_proj.add_module(
                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
            )
            self.conv_proj: nn.Module = seq_proj
        else:
            self.conv_proj = nn.Conv2d(
                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
            )

        self.seq_length = (image_size // patch_size) ** 2
        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, self.num_cls_tokens, hidden_dim))
        self.seq_length += self.num_cls_tokens

        self.encoder = Encoder(
            self.seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            num_cls_tokens,
            norm_layer)
        #self.seq_length = seq_length #+ self.num_cls_tokens

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if isinstance(self.conv_proj, nn.Conv2d):
            # Init the patchify stem
            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
            if self.conv_proj.bias is not None:
                nn.init.zeros_(self.conv_proj.bias)
        elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
            # Init the last 1x1 conv of the conv stem
            nn.init.normal_(
                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
            )
            if self.conv_proj.conv_last.bias is not None:
                nn.init.zeros_(self.conv_proj.conv_last.bias)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        p = self.patch_size
        #torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        #torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x: torch.Tensor,
                branch: str = 'target',
                idx_keep: Optional[torch.Tensor] = None):
        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        if branch == 'target':
            batch_class_token = self.class_token.expand(n, -1, -1)
            x = torch.cat([batch_class_token, x], dim=1)
            x = self.encoder(x)
            return x[:,0:self.num_cls_tokens]
        elif branch == 'anchor':
            batch_class_token = self.class_token.expand(n, -1, -1)
            x = torch.cat([batch_class_token, x], dim=1)
            x = torch.cat((x[:,:1],x[:,self.num_cls_tokens:]),dim = 1)
            x = self.encoder(x,idx_keep=idx_keep)
            return x[:,0]

        # Classifier "token" as used by standard language architectures
        # else:
        #x = x[:, 0]
        #x = self.heads(x)

        # return x

In [14]:
# stemconfig = [ConvStemConfig(out_channels = 64, kernel_size = 3 , stride = 2) for i in range(4)]
transformer = VisionTransformer(image_size=224,
                                patch_size=16,
                                num_layers=12,
                                num_heads=3,
                                hidden_dim=192,
                                mlp_dim=768,
                                dropout=0.0,
                                attention_dropout=0.0,
                                num_cls_tokens=3,
#                                 conv_stem_configs=stemconfig
                               )

In [15]:
x= torch.randn(64,3,224,224)
transformer(x,branch='anchor')

tensor([[ 0.5596, -1.3186, -2.4964,  ...,  1.8443,  1.0027,  0.7661],
        [-1.4058, -0.5556, -1.9416,  ...,  0.8630,  0.3313, -0.3766],
        [ 0.2584, -1.4279, -0.2105,  ...,  2.4241,  0.0840, -1.0784],
        ...,
        [ 1.2345, -1.6663, -0.9487,  ...,  1.0602, -0.1681, -0.2912],
        [-0.0206,  0.6050, -1.1963,  ...,  1.8352,  1.9001, -0.1919],
        [-0.2642, -1.6112, -1.3572,  ...,  1.9571, -0.2857, -0.1676]],
       grad_fn=<SelectBackward0>)

# ChexMSN

In [16]:
import copy

In [17]:
def count_parameters(model:nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [18]:
def deactivate_requires_grad(model: nn.Module):
    """Deactivates the requires_grad flag for all parameters of a model.

    This has the same effect as permanently executing the model within a `torch.no_grad()`
    context. Use this method to disable gradient computation and therefore
    training for a model.

    Examples:
        >>> backbone = resnet18()
        >>> deactivate_requires_grad(backbone)
    """
    for param in model.parameters():
        param.requires_grad = False

In [19]:
@torch.no_grad()
def update_momentum(model: nn.Module, model_ema: nn.Module, m: float):
    """Updates parameters of `model_ema` with Exponential Moving Average of `model`

    Momentum encoders are a crucial component fo models such as MoCo or BYOL.

    Examples:
        >>> backbone = resnet18()
        >>> projection_head = MoCoProjectionHead()
        >>> backbone_momentum = copy.deepcopy(moco)
        >>> projection_head_momentum = copy.deepcopy(projection_head)
        >>>
        >>> # update momentum
        >>> update_momentum(moco, moco_momentum, m=0.999)
        >>> update_momentum(projection_head, projection_head_momentum, m=0.999)
    """
    for model_ema, model in zip(model_ema.parameters(), model.parameters()):
        model_ema.data = model_ema.data * m + model.data * (1.0 - m)

In [20]:
def random_token_mask(
    size: Tuple[int, int],
    mask_ratio: float = 0.15,
    mask_class_token: bool = False,
    device: Optional[Union[torch.device, str]] = None,
) -> torch.Tensor:
    """Creates random token masks.

    Args:
        size:
            Size of the token batch for which to generate masks.
            Should be (batch_size, sequence_length).
        mask_ratio:
            Percentage of tokens to mask.
        mask_class_token:
            If False the class token is never masked. If True the class token
            might be masked.
        device:
            Device on which to create the index masks.

    Returns:
        A (index_keep, index_mask) tuple where each index is a tensor.
        index_keep contains the indices of the unmasked tokens and has shape
        (batch_size, num_keep). index_mask contains the indices of the masked
        tokens and has shape (batch_size, sequence_length - num_keep).
        num_keep is equal to sequence_length * (1- mask_ratio).

    """
    batch_size, sequence_length = size
    num_keep = int(sequence_length * (1 - mask_ratio))

    noise = torch.rand(batch_size, sequence_length, device=device)
    if not mask_class_token and sequence_length > 0:
        # make sure that class token is not masked
        noise[:, 0] = -1
        num_keep = max(1, num_keep)

    # get indices of tokens to keep
    indices = torch.argsort(noise, dim=1)
    idx_keep = indices[:, :num_keep]
    idx_mask = indices[:, num_keep:]

    return idx_keep, idx_mask

In [21]:
class DenseBlock(nn.Module):
    def __init__(self,
                 in_features: int = 768,
                 out_features: int = 2048,
               bias: bool = False,
              ) -> None:

        super().__init__()

        self.dense_block = nn.Sequential(nn.Linear(in_features= in_features,
                                                   out_features= out_features,
                                               bias=bias),
                                         nn.LayerNorm(normalized_shape= out_features),
                                         nn.GELU()
                                        )
    def forward(self,
                x: torch.Tensor
                ) -> torch.Tensor:
        x = self.dense_block(x)
        return x

In [22]:
class ProjectionHead(nn.Module):
    def __init__(self,
                 in_features: int = 192,
                 hidden_features: int = 768,
                 out_features: int = 192,
                 bias : bool = False
                 ) -> None:

        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.projection_head = nn.Sequential(DenseBlock(in_features= self.in_features,
                                                    out_features= self.hidden_features,
                                                    bias= bias
                                                    ),
                                         DenseBlock(in_features= self.hidden_features,
                                                    out_features= self.hidden_features,
                                                    bias= bias
                                                    ),
                                         nn.Linear(in_features= self.hidden_features,
                                                   out_features= self.out_features,
                                                   bias= bias
                                                   )
                                         )

    def forward(self,
                x: torch.Tensor
                ) -> torch.Tensor:
        x = self.projection_head(x)
        return x

In [23]:
projection_head = ProjectionHead(in_features=192,
                                 hidden_features=768,
                                 out_features=192)

In [24]:
class ChexMSN(nn.Module):
    def __init__(self,
               backbone: nn.Module,
               projection_head: nn.Module,
               masking_ratio: float = 0.15,
               ema_p: float = 0.996,
               focal: bool = True
              ) -> None:
        super().__init__()

        self.masking_ratio = masking_ratio
        self.ema_p = ema_p
        self.focal = focal

        self.backbone = backbone
        self.projection_head = projection_head

        self.target_backbone = copy.deepcopy(self.backbone)
        self.target_projection_head = copy.deepcopy(self.projection_head)

        deactivate_requires_grad(self.target_backbone)
        deactivate_requires_grad(self.target_projection_head)



    def forward(self,
                views:list[torch.Tensor],
                focal: bool = True
                ) -> tuple[torch.Tensor]:

        update_momentum(model= self.backbone,
                        model_ema= self.target_backbone,
                        m = self.ema_p,
                       )
        update_momentum(model= self.projection_head,
                        model_ema= self.target_projection_head,
                        m = self.ema_p,
                       )
        projections = self._forward_all(batch=views,focal=self.focal)
        

        return projections

    def _target_forward(self,
                        view: torch.Tensor
                        ) -> torch.Tensor:

        target_encodings = self.target_backbone(x= view,
                                                branch='target')
        target_projections = self.target_projection_head(x= target_encodings)

        return target_projections


    def _anchor_forward(self,
                        view: torch.Tensor
                        ) -> torch.Tensor:

        batch_size, _, _, width = view.shape
        seq_length = (width // self.backbone.patch_size) ** 2
        idx_keep, idx_mask = random_token_mask(size= (view.shape[0],seq_length),
                                               mask_ratio= self.masking_ratio,
                                               device=torch.device("cuda"))

        anchor_encodings = self.backbone(x= view,
                                         branch= 'anchor',
                                         idx_keep= idx_keep)
        anchor_projections = self.projection_head(x= anchor_encodings)

        return anchor_projections
    
    
    def _forward_all(self,
                     batch: list,
                     focal: bool = True
                     ) -> torch.Tensor:
        
        target_view = batch[0][0].to('cuda')
        anchor_view_sim = batch[0][1].to('cuda')
        focal_views_sim = torch.concat(batch[0][2:],dim=0).to('cuda')
        anchor_view_age = batch[1][1].to('cuda')
        focal_views_age = torch.concat(batch[1][2:],dim=0).to('cuda')
        anchor_view_gender = batch[2][1].to('cuda')
        focal_views_gender = torch.concat(batch[2][2:],dim=0).to('cuda')
        
        
        target_projections = self._target_forward(target_view)
        
        anchor_projections_sim = self._anchor_forward(anchor_view_sim)
        if focal:
            anchor_focal_projections_sim = self._anchor_forward(focal_views_sim)
            similarity_projections = self._arrange_tokens(anchor_projections_sim,
                                                          anchor_focal_projections_sim,
                                                          num_focal = 10)
                                                          
        
        anchor_projections_age = self._anchor_forward(anchor_view_age)
        if focal:
            anchor_focal_projections_age =  self._anchor_forward(focal_views_age)
            age_projections = self._arrange_tokens(anchor_projections_age,
                                                   anchor_focal_projections_age,
                                                   num_focal = 10)
                                                    
        
        
        anchor_projections_gender = self._anchor_forward(anchor_view_gender)
        if focal:
            anchor_focal_projections_gender = self._anchor_forward(focal_views_gender)  
            gender_projections = self._arrange_tokens(anchor_projections_gender,
                                                      anchor_focal_projections_gender,
                                                      num_focal = 10)
                                                          
        if focal:
            anchor_projections = torch.stack((similarity_projections,
                                              age_projections,
                                              gender_projections
                                              ),
                                         dim= 0)
        else:
            anchor_projections = torch.stack((anchor_projections_sim,
                                              anchor_projections_age,
                                              anchor_projections_gender
                                              ),
                                             dim= 1)       

        return (anchor_projections,
                target_projections)
    
    def _arrange_tokens(self,
                        tensor1: torch.Tensor,
                        tensor2:torch.Tensor,
                        num_focal: int = 10
                        ) ->torch.Tensor:

        a = torch.stack(torch.split(tensor1,1),0)
        b = torch.stack(torch.split(tensor2,num_focal),0)
        c = torch.cat((a.expand(-1,num_focal,-1),b),dim=1)[:,num_focal-1:]
        arranged_tokens = torch.cat(c.split(1),1).squeeze(0)
        return arranged_tokens

In [25]:
msn = ChexMSN(backbone=transformer,
             projection_head=projection_head,
             masking_ratio=0.15,
             ema_p=0.996,
             focal=True)

# MSN Loss

In [26]:
import math
import warnings
from typing import Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

In [27]:
def prototype_probabilities(queries: torch.Tensor,
                            prototypes: torch.Tensor,
                            temperature: float,
                            ) -> torch.Tensor:
    """Returns probability for each query to belong to each prototype.

    Args:
        queries:
            Tensor with shape (batch_size, dim), projection head output
        prototypes:
            Tensor with shape (num_prototypes, dim)
        temperature:
            Inverse scaling factor for the similarity.

    Returns:
        Probability tensor with shape (batch_size, num_prototypes) which sums to 1 along
        the num_prototypes dimension.

    """                           
    return F.softmax(torch.matmul(queries, prototypes.T) / temperature, dim=1)

In [28]:
def sharpen(probabilities: torch.Tensor, 
            temperature: float
           ) -> torch.Tensor:
    """Sharpens the probabilities with the given temperature.

    Args:
        probabilities:
            Tensor with shape (batch_size, dim)
        temperature:
            Temperature in (0, 1]. Lower temperature results in stronger sharpening (
            output probabilities are less uniform).
    Returns:
        Probabilities tensor with shape (batch_size, num_prototypes).

    """
    probabilities = probabilities ** (1.0 / temperature)
    probabilities /= torch.sum(probabilities, dim=1, keepdim=True)
    return probabilities

In [29]:
@torch.no_grad()
def sinkhorn(probabilities: torch.Tensor,
             iterations: int = 3,
             gather_distributed: bool = False,
            ) -> torch.Tensor:
    """Runs sinkhorn normalization on the probabilities as described in [0].

    Code inspired by [1].

    - [0]: Masked Siamese Networks, 2022, https://arxiv.org/abs/2204.07141
    - [1]: https://github.com/facebookresearch/msn

    Args:
        probabilities:
            Probabilities tensor with shape (batch_size, num_prototypes).
        iterations:
            Number of iterations of the sinkhorn algorithms. Set to 0 to disable.
        gather_distributed:
            If True then features from all gpus are gathered during normalization.
    Returns:
        A normalized probabilities tensor.

    """
    if iterations <= 0:
        return probabilities


    num_targets, num_prototypes = probabilities.shape
    probabilities = probabilities.T
    sum_probabilities = torch.sum(probabilities)

    probabilities = probabilities / sum_probabilities

    for _ in range(iterations):
        # normalize rows
        row_sum = torch.sum(probabilities, dim=1, keepdim=True)

        probabilities /= row_sum
        probabilities /= num_prototypes

        # normalize columns
        probabilities /= torch.sum(probabilities, dim=0, keepdim=True)
        probabilities /= num_targets

    probabilities *= num_targets
    return probabilities.T

In [30]:
def regularization_loss(mean_anchor_probs: torch.Tensor
                       ) -> torch.Tensor:
    """Calculates mean entropy regularization loss."""
    loss = -torch.sum(torch.log(mean_anchor_probs ** (-mean_anchor_probs)))
    loss += math.log(float(len(mean_anchor_probs)))
    return loss

In [31]:
class MSNLoss(nn.Module):
    def __init__(self,
                 temperature: float = 0.1,
                 sinkhorn_iterations: int = 3,
                 similarity_weight: float = 1.0,
                 age_weight: float = 1.0,
                 gender_weight: float = 1.0,
                 regularization_weight: float = 1.0,
                 gather_distributed: bool = False,
               ) -> None:

        super().__init__()
        
        self.temperature = temperature
        self.sinkhorn_iterations = sinkhorn_iterations
        self.similarity_weight = similarity_weight
        self.age_weight = age_weight
        self.gender_weight = gender_weight
        self.regularization_weight = regularization_weight

        
    def forward(self,
                anchors: torch.Tensor,
                targets: torch.Tensor,
                prototypes: torch.Tensor,
                target_sharpen_temperature: float = 0.25,
                focal: bool = True
                      ) -> torch.Tensor:
        
        similarity_loss = self.similarity_weight * self._forward_loss(anchors=anchors[0] if focal else anchors[:,0],
                                                                      targets=targets[:,0],
                                                                      prototypes=prototypes[0].weight.data
                                                                     )

        age_loss = self.age_weight * self._forward_loss(anchors=anchors[1] if focal else anchors[:,1],
                                                           targets=targets[:,1],
                                                           prototypes=prototypes[1].weight.data
                                                          )

        gender_loss = self.gender_weight * self._forward_loss(anchors=anchors[2] if focal else anchors[:,2],
                                                              targets=targets[:,2],
                                                              prototypes=prototypes[2].weight.data
                                                             )

        loss = similarity_loss + age_loss + gender_loss
        return loss
    
    def _forward_loss(self,
                anchors: torch.Tensor,
                targets: torch.Tensor,
                prototypes: torch.Tensor,
                target_sharpen_temperature: float = 0.25,
               ) -> torch.Tensor:

        num_views = anchors.shape[0] // targets.shape[0]
        anchors = F.normalize(anchors, dim=1)
        targets = F.normalize(targets, dim=1)
        prototypes = F.normalize(prototypes, dim=1)

        anchor_probs = prototype_probabilities(anchors, 
                                               prototypes, 
                                               temperature=self.temperature
                                              )

        with torch.no_grad():
            target_probs = prototype_probabilities(targets, 
                                                   prototypes, 
                                                   temperature=self.temperature
                                                   )
            target_probs = sharpen(target_probs, temperature=target_sharpen_temperature)
            if self.sinkhorn_iterations > 0:
                target_probs = sinkhorn(probabilities=target_probs,
                                        iterations=self.sinkhorn_iterations,
                                        )
            target_probs = torch.repeat_interleave(target_probs, repeats=num_views, dim=0)
        
        loss = torch.mean(torch.sum(torch.log(anchor_probs ** (-target_probs)), dim=1))

        # regularization loss
        if self.regularization_weight > 0:
            mean_anchor_probs = torch.mean(anchor_probs, dim=0)
            reg_loss = regularization_loss(mean_anchor_probs=mean_anchor_probs)
            loss += self.regularization_weight * reg_loss
         
        return loss

In [32]:
criterterion = MSNLoss(temperature=0.1,
                       sinkhorn_iterations=0,
                       similarity_weight=1.0,
                       age_weight=1.0,
                       gender_weight=1.0,
                       regularization_weight=0.0)

# MSN transforms

In [33]:
from typing import Dict, List, Optional, Tuple, Union

import torchvision.transforms as T
from PIL.Image import Image
from torch import Tensor


from lightly.transforms.multi_view_transform import MultiViewTransform
IMAGENET_STAT = {"mean":torch.tensor([0.4884, 0.4550, 0.4171]),
                 "std":torch.tensor([0.2596, 0.2530, 0.2556])}

MIMIC_NORMALIZE ={"mean":torch.tensor([0.4723, 0.4723, 0.4723]), 
                  "std":torch.tensor([0.3023, 0.3023, 0.3023])}

In [34]:
class MSNTransform(MultiViewTransform):
    """Implements the transformations for MSN [0].

    Input to this transform:
        PIL Image or Tensor.

    Output of this transform:
        List of Tensor of length 2 * random_views + focal_views. (12 by default)

    Applies the following augmentations by default:
        - Random resized crop
        - Random horizontal flip
        - ImageNet normalization

    Generates a set of random and focal views for each input image. The generated output
    is (views, target, filenames) where views is list with the following entries:
    [random_views_0, random_views_1, ..., focal_views_0, focal_views_1, ...].

    - [0]: Masked Siamese Networks, 2022: https://arxiv.org/abs/2204.07141

    Attributes:
        random_size:
            Size of the random image views in pixels.
        focal_size:
            Size of the focal image views in pixels.
        random_views:
            Number of random views to generate.
        focal_views:
            Number of focal views to generate.
        random_crop_scale:
            Minimum and maximum size of the randomized crops for the relative to random_size.
        focal_crop_scale:
            Minimum and maximum size of the randomized crops relative to focal_size.
        hf_prob:
            Probability that horizontal flip is applied.
        vf_prob:
            Probability that vertical flip is applied.
        normalize:
            Dictionary with 'mean' and 'std' for torchvision.transforms.Normalize.
    """

    def __init__(
        self,
        random_size: int = 224,
        focal_size: int = 96,
        random_views: int = 2,
        focal_views: int = 10,
        affine_dgrees: int = 15,
        affine_scale: Tuple[float,float]= (.9, 1.1),
        affine_shear: int = 0,
        affine_translate: Tuple[float,float] = (0.1, 0.1),
        random_crop_scale: Tuple[float, float] = (0.3, 1.0),
        focal_crop_scale: Tuple[float, float] = (0.05, 0.3),
        hf_prob: float = 0.5,
        vf_prob: float = 0.0,
        normalize: Dict[str, List[float]] = MIMIC_NORMALIZE,
    ):
        random_view_transform = MSNViewTransform(
            affine_dgrees=affine_dgrees,
            affine_scale=affine_scale,
            affine_shear=affine_shear,
            affine_translate=affine_translate,
            crop_size=random_size,
            crop_scale=random_crop_scale,
            hf_prob=hf_prob,
            vf_prob=vf_prob,
            normalize=normalize,
        )
        focal_view_transform = MSNViewTransform(
            affine_dgrees=affine_dgrees,
            affine_scale=affine_scale,
            affine_shear=affine_shear,
            affine_translate=affine_translate,
            crop_size=focal_size,
            crop_scale=focal_crop_scale,
            hf_prob=hf_prob,
            vf_prob=vf_prob,
            normalize=normalize,
        )
        transforms = [random_view_transform] * random_views
        transforms += [focal_view_transform] * focal_views
        super().__init__(transforms=transforms)

In [35]:
class MSNViewTransform:
    def __init__(
        self,
        affine_dgrees: int = 15,
        affine_scale: Tuple[float,float]= (.9, 1.1),
        affine_shear: int = 0,
        affine_translate: Tuple[float,float] = (0.1, 0.1),
        crop_size: int = 224,
        crop_scale: Tuple[float, float] = (0.3, 1.0),
        hf_prob: float = 0.5,
        vf_prob: float = 0.0,
        normalize: Dict[str, List[float]] = MIMIC_NORMALIZE,
    ):

        transform = [
            T.RandomAffine(degrees=affine_dgrees, 
                           scale=affine_scale, 
                           shear=affine_shear, 
                           translate=affine_translate),
            T.RandomResizedCrop(size=crop_size, scale=crop_scale),
            T.RandomHorizontalFlip(p=hf_prob),
            T.RandomVerticalFlip(p=vf_prob),
            T.ToTensor(),
            T.Normalize(mean=normalize["mean"], std=normalize["std"]),
        ]

        self.transform = T.Compose(transform)

    def __call__(self, image: Union[torch.Tensor, Image]) -> torch.Tensor:
        """
        Applies the transforms to the input image.

        Args:
            image:
                The input image to apply the transforms to.

        Returns:
            The transformed image.

        """
        transformed = self.transform(image)
        return transformed

# MSN dataset

In [36]:
import os
import torch
import torch.nn as nn 
import pytorch_lightning as pl

import random
from collections import Counter
from typing import Tuple, Optional
from torch import Tensor
from PIL import Image
from torch.utils.data import Dataset

import os
import torch
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn as nn
import torchvision.transforms as T
from tqdm import tqdm
from torch import Tensor
from copy import deepcopy
from typing import List, Tuple, Dict
from torch.utils.data import DataLoader



In [37]:
class ChexMSNDataset(Dataset):
    def __init__(self, 
                 data_dir: str,
                 transforms: nn.Module,
                 same = True
                 ) -> None:
      
        self.meta = pd.read_csv(data_dir)
        self.all_images = list(self.meta.path)
        self.transform = transforms
        self.same = same
        
    def __len__(self
                ) -> int:
        return len(self.all_images)
    
    def __getitem__(self,
                    index: int
                    ) -> Tuple[torch.Tensor]:

        
        target_path = self.all_images[index]
        image_id = target_path.split('/')[-1][:-4]
        img_age_path, img_gender_path = self._retrieve_anchors(image_id=image_id,
                                                               meta = self.meta,
                                                               same=self.same)

        img_target = Image.open(fp=target_path).convert('RGB')
        img_target = self.transform(img_target)
        
        img_age = Image.open(fp=img_age_path).convert('RGB')
        img_age = self.transform(img_age)

        img_gender = Image.open(fp=img_gender_path).convert('RGB')
        img_gender = self.transform(img_gender)

        return (img_target,img_age,img_gender)
    
    
    def _retrieve_anchors(self,
                          image_id: str,
                          meta: pd.DataFrame,
                          same: bool = False) -> Tuple[str]:
        record = meta[meta.dicom_id == image_id]
    
        subject_id = list(record.subject_id)[0]
        age_groub =list(record.ageR5)[0] 
        gender = list(record.gender)[0]
    
        group = meta[meta.ageR5 == age_groub]
    
        if same:
            candidate_anchors = group[group.gender == gender]
            candidate_anchors = candidate_anchors[candidate_anchors.subject_id != subject_id]
            images= list(candidate_anchors.path)
            sampled_images = random.sample(images,k=2)
            image_age, image_gender = sampled_images[0],sampled_images[1]
            return image_age, image_gender
        else:
            candidate_anchors = group
            candidate_anchors = candidate_anchors[candidate_anchors.subject_id != subject_id]
            images= list(candidate_anchors.path)
            image_age = random.sample(images,k=1)[0]
            candidate_anchors = candidate_anchors[candidate_anchors.gender == gender]
            images= list(candidate_anchors.path)
            image_gender = random.sample(images,k=1)[0]
            return image_age, image_gender

In [38]:
def get_mean_and_std(dataloader):
    channels_sum, channels_squared_sum, num_batches = 0, 0, 0
    for data in tqdm(dataloader):
        # Mean over batch, height and width, but not over the channels
        channels_sum += torch.mean(data, dim=[0,2,3])
        channels_squared_sum += torch.mean(data**2, dim=[0,2,3])
        num_batches += 1
    
    mean = channels_sum / num_batches

    # std = sqrt(E[X^2] - (E[X])^2)
    std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5

    return mean, std

In [39]:
transforms = MSNTransform()
dataset = ChexMSNDataset(data_dir='../data/meta.csv',transforms= transforms,same=True)
dataloader = DataLoader(dataset=dataset,batch_size=10,num_workers=24,pin_memory=True,shuffle=True)

In [50]:
import torch.nn.functional as F
class ChexMSNModel(pl.LightningModule):
    def __init__(self,
                 model: nn.Module,
                 criterion: nn.Module,
                 num_prototypes: int = 1024,
                 learning_rate: float =  1e-3,
                 weight_decay: float = 0.0,
                 max_epochs: int = 50,
                 focal: bool = True

                ) -> None:
        super().__init__()

        self.model = model
        self.criterion = criterion
        self.num_prototypes = num_prototypes
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.max_epochs = max_epochs
        self.focal = focal
        self.in_prototypes = self.model.projection_head.out_features
        self.prototypes = nn.ModuleList([nn.Linear(in_features=self.in_prototypes,out_features=num_prototypes,bias=False),
                                         nn.Linear(in_features=self.in_prototypes,out_features=num_prototypes,bias=False),
                                         nn.Linear(in_features=self.in_prototypes,out_features=num_prototypes,bias=False)])
    
    
    def training_step(self, 
                      batch: List[torch.Tensor], 
                      batch_idx: int
                     ) -> float:
        
        
        anchors, target = self.model(batch) 
        print('\n',target)
        loss = self.criterion(anchors,target,self.prototypes,focal=self.focal)
        self.log("train_loss", loss, on_epoch= True,on_step=True, logger=True,)
        print('\n',loss.detach())
        return loss


    def configure_optimizers(self):

        params = [
            *list(self.model.backbone.parameters()),
            *list(self.model.projection_head.parameters()),
            self.prototypes[0].weight,
            self.prototypes[1].weight,
            self.prototypes[2].weight,
        ]
        
        optimizer = torch.optim.AdamW(params, 
                                      lr=self.learning_rate,
                                      weight_decay=self.weight_decay)
        
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                               eta_min=0.00001,
                                                               T_max=self.max_epochs)
        return {'optimizer': optimizer,
               'lr_scheduler': scheduler
               }


In [51]:

transforms = MSNTransform()
dataset = ChexMSNDataset(data_dir='../data/meta.csv',transforms= transforms,same=True)
dataloader = DataLoader(dataset=dataset,batch_size=64,num_workers=24,pin_memory=True,shuffle=True)

In [52]:
criterterion = MSNLoss(temperature=0.1,
                       sinkhorn_iterations=0,
                       similarity_weight=1.0,
                       age_weight=0.0,
                       gender_weight=0.0,
                       regularization_weight=5)

In [53]:
model = ChexMSNModel(msn,criterterion,focal=True)



In [54]:
trainer = pl.Trainer(devices="auto",
                     accelerator="auto",
                     strategy="auto",
                     log_every_n_steps=1,
                     max_epochs=50,
                     precision=16,
                     )

/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/lightning_fabric/connector.py:565: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/sas10092/.conda/envs/chexmsn-env/lib/python3.9 ...
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model=model, train_dataloaders=dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type       | Params
------------------------------------------
0 | model      | ChexMSN    | 13.2 M
1 | criterion  | MSNLoss    | 0     
2 | prototypes | ModuleList | 589 K 
------------------------------------------
7.2 M     Trainable params
6.6 M     Non-trainable params
13.8 M    Total params
55.207    Total estimated model params size (MB)


Epoch 0:   0%|          | 0/5879 [00:00<?, ?it/s] 
 tensor([[[-0.9858, -2.0000,  0.5649,  ...,  1.3418, -0.1093, -1.0986],
         [-0.9985, -2.0723,  0.5146,  ...,  1.2842, -0.0777, -1.0557],
         [-0.9292, -1.7441,  0.6270,  ...,  1.4316, -0.2272, -1.1270]],

        [[-0.9849, -1.9844,  0.5679,  ...,  1.3447, -0.1246, -1.1055],
         [-0.9980, -2.0723,  0.5132,  ...,  1.2783, -0.0837, -1.0586],
         [-0.9097, -1.6914,  0.6328,  ...,  1.4375, -0.2483, -1.1143]],

        [[-0.4844,  0.3589,  0.0652,  ...,  0.1214,  2.1953, -1.8428],
         [-0.4153,  0.3186,  0.0768,  ...,  0.1132,  2.0527, -1.7744],
         [-0.4797,  0.5215, -0.0115,  ...,  0.1394,  2.5078, -1.8115]],

        ...,

        [[-1.0059, -2.0977,  0.5195,  ...,  1.2715, -0.0591, -1.0654],
         [-1.0137, -2.1738,  0.4424,  ...,  1.1816, -0.0419, -0.9946],
         [-0.9019, -1.6621,  0.6338,  ...,  1.4434, -0.2625, -1.1143]],

        [[-1.0107, -2.1426,  0.4912,  ...,  1.2275, -0.0394, -1.0352],
   

Epoch 0:   0%|          | 6/5879 [00:16<4:25:09,  0.37it/s, v_num=7]
 tensor([[[-1.0244, -2.2578,  0.4160,  ...,  1.1162,  0.0383, -0.9971],
         [-1.0127, -2.2988,  0.2881,  ...,  0.9722,  0.0102, -0.8833],
         [-0.8657, -1.5977,  0.6245,  ...,  1.4453, -0.2639, -1.1260]],

        [[-0.1627,  0.2820, -0.6641,  ...,  0.1545,  0.1186, -0.8262],
         [-0.1456, -0.0912, -0.8613,  ..., -0.0031, -0.4609, -0.3669],
         [-0.4397,  0.6753, -0.0984,  ...,  0.1534,  2.3359, -1.6709]],

        [[-1.0186, -2.1621,  0.5015,  ...,  1.2373,  0.0156, -1.0830],
         [-1.0264, -2.2285,  0.4214,  ...,  1.1436,  0.0245, -1.0088],
         [-0.9092, -1.7129,  0.6260,  ...,  1.4316, -0.2045, -1.1514]],

        ...,

        [[-0.3811,  0.3171, -0.1218,  ...,  0.1633,  1.4473, -1.5625],
         [-0.2141,  0.2260, -0.2959,  ...,  0.1438,  0.8057, -1.1719],
         [-0.4863,  0.5337,  0.0215,  ...,  0.1779,  2.4648, -1.7812]],

        [[-0.3108, -0.6084, -0.9683,  ..., -0.0640, -0.7

Epoch 0:   0%|          | 12/5879 [00:22<3:03:11,  0.53it/s, v_num=7]
 tensor([[[-3.0151e-01, -6.5674e-01, -1.0176e+00,  ..., -1.2469e-01,
          -7.2900e-01,  3.0762e-02],
         [-3.3179e-01, -9.2432e-01, -9.9512e-01,  ..., -2.5269e-01,
          -6.7480e-01,  8.3008e-02],
         [-3.8818e-01,  8.5205e-01, -1.9519e-01,  ...,  1.9470e-01,
           2.0488e+00, -1.4736e+00]],

        [[-9.1455e-01, -2.2910e+00,  8.9502e-04,  ...,  5.8057e-01,
           7.0429e-04, -5.8887e-01],
         [-8.7451e-01, -2.2402e+00, -1.2976e-01,  ...,  4.3604e-01,
          -6.1676e-02, -4.9683e-01],
         [-7.7002e-01, -1.4023e+00,  5.6982e-01,  ...,  1.4355e+00,
          -3.1494e-01, -1.0967e+00]],

        [[-1.0127e+00, -2.1621e+00,  5.0537e-01,  ...,  1.2520e+00,
           8.0566e-02, -1.1631e+00],
         [-1.0264e+00, -2.2402e+00,  4.4092e-01,  ...,  1.1709e+00,
           1.0828e-01, -1.1045e+00],
         [-9.1797e-01, -1.8008e+00,  5.9912e-01,  ...,  1.3994e+00,
          -9.8511

Epoch 0:   0%|          | 17/5879 [00:25<2:24:49,  0.67it/s, v_num=7]
 tensor([[[-0.3765, -0.9741, -0.9839,  ..., -0.2211, -0.6709,  0.1041],
         [-0.3906, -1.1367, -0.9727,  ..., -0.3179, -0.6128,  0.1066],
         [-0.2595,  1.0830, -0.3743,  ...,  0.3679,  1.0732, -1.1641]],

        [[-0.8862, -2.2793, -0.1307,  ...,  0.3987,  0.0242, -0.5308],
         [-0.7700, -2.0996, -0.3774,  ...,  0.1229, -0.1216, -0.3250],
         [-0.7783, -1.4434,  0.5444,  ...,  1.4170, -0.2455, -1.1367]],

        [[-0.5518,  0.1718,  0.2910,  ...,  0.3020,  1.9863, -1.7686],
         [-0.5127,  0.1536,  0.2979,  ...,  0.2883,  1.9121, -1.7354],
         [-0.5942,  0.3213,  0.2822,  ...,  0.3184,  2.4570, -1.8301]],

        ...,

        [[-0.9463, -2.3594, -0.0192,  ...,  0.5415,  0.0911, -0.6655],
         [-0.8491, -2.2168, -0.2524,  ...,  0.2732, -0.0454, -0.4622],
         [-0.8325, -1.5801,  0.5625,  ...,  1.4180, -0.1744, -1.1758]],

        [[-0.7856, -2.1172, -0.3079,  ...,  0.2023, -0.

Epoch 0:   0%|          | 23/5879 [00:31<2:12:11,  0.74it/s, v_num=7]
 tensor([[[-0.5317, -1.4629, -0.8188,  ..., -0.2969, -0.4365,  0.0756],
         [-0.4343, -1.2695, -0.9634,  ..., -0.4062, -0.5459,  0.1032],
         [-0.3303,  1.0332, -0.2898,  ...,  0.3396,  1.3896, -1.2266]],

        [[-1.0293, -2.4922,  0.0646,  ...,  0.6562,  0.2844, -0.8848],
         [-0.9849, -2.4414, -0.0862,  ...,  0.4856,  0.2052, -0.7568],
         [-0.8623, -1.6973,  0.5093,  ...,  1.3975, -0.0563, -1.2471]],

        [[-1.0488, -2.3164,  0.4023,  ...,  1.1699,  0.2971, -1.2686],
         [-1.0645, -2.3965,  0.3335,  ...,  1.0732,  0.3257, -1.1992],
         [-0.9531, -1.9521,  0.5044,  ...,  1.3525,  0.0905, -1.3076]],

        ...,

        [[-1.0215, -2.2031,  0.4470,  ...,  1.2588,  0.2223, -1.3027],
         [-1.0342, -2.2656,  0.4094,  ...,  1.2139,  0.2581, -1.2734],
         [-0.9717, -2.0059,  0.4937,  ...,  1.3359,  0.1185, -1.3174]],

        [[-0.5977, -1.7109, -0.6943,  ..., -0.2142, -0.

Epoch 0:   0%|          | 29/5879 [00:32<1:49:50,  0.89it/s, v_num=7]
 tensor([[[-1.1035, -2.5684,  0.1296,  ...,  0.7915,  0.4543, -1.1055],
         [-1.0537, -2.5410, -0.0422,  ...,  0.5659,  0.3701, -0.9219],
         [-0.9478, -1.8994,  0.4478,  ...,  1.3711,  0.1147, -1.3496]],

        [[-1.0869, -2.4023,  0.3284,  ...,  1.1299,  0.4150, -1.3301],
         [-1.0967, -2.4688,  0.2700,  ...,  1.0479,  0.4417, -1.2734],
         [-0.9937, -2.0371,  0.4346,  ...,  1.3359,  0.1963, -1.3770]],

        [[-0.9121, -2.3398, -0.2952,  ...,  0.1824,  0.1886, -0.5693],
         [-0.8545, -2.2559, -0.4209,  ...,  0.0430,  0.1054, -0.4634],
         [-0.8652, -1.6494,  0.4409,  ...,  1.4014, -0.0393, -1.2891]],

        ...,

        [[-1.0635, -2.2949,  0.3718,  ...,  1.2227,  0.3430, -1.3682],
         [-1.0801, -2.3750,  0.3254,  ...,  1.1514,  0.3894, -1.3252],
         [-1.0039, -2.0664,  0.4294,  ...,  1.3252,  0.2134, -1.3828]],

        [[-1.0762, -2.3496,  0.3523,  ...,  1.1807,  0.

Epoch 0:   1%|          | 35/5879 [00:41<1:55:52,  0.84it/s, v_num=7]
 tensor([[[-0.3127,  0.0812, -0.0392,  ...,  0.3391,  0.4133, -0.9341],
         [-0.2416,  0.0565, -0.1410,  ...,  0.2874,  0.1523, -0.7808],
         [-0.7188,  0.1396,  0.5420,  ...,  0.5024,  2.2676, -1.6572]],

        [[-1.1240, -2.4473,  0.2742,  ...,  1.1348,  0.5063, -1.4209],
         [-1.1367, -2.5195,  0.2191,  ...,  1.0449,  0.5420, -1.3623],
         [-1.0498, -2.1465,  0.3608,  ...,  1.3096,  0.3181, -1.4570]],

        [[-0.4441, -1.2549, -1.0186,  ..., -0.4993, -0.5186,  0.1127],
         [-0.4434, -1.3428, -1.0264,  ..., -0.5596, -0.4614,  0.0995],
         [-0.7070, -0.9634,  0.2808,  ...,  1.3379, -0.2544, -1.1699]],

        ...,

        [[-1.1455, -2.5430,  0.2186,  ...,  1.0186,  0.5635, -1.3643],
         [-1.1465, -2.5996,  0.1356,  ...,  0.8906,  0.5703, -1.2686],
         [-1.0342, -2.1016,  0.3689,  ...,  1.3232,  0.2925, -1.4521]],

        [[-1.1162, -2.4160,  0.2876,  ...,  1.1670,  0.

Epoch 0:   1%|          | 41/5879 [00:45<1:48:26,  0.90it/s, v_num=7]
 tensor([[[-1.1807, -2.5840,  0.1804,  ...,  1.0430,  0.6548, -1.4736],
         [-1.1875, -2.6484,  0.1169,  ...,  0.9395,  0.6816, -1.4023],
         [-1.0850, -2.1934,  0.3059,  ...,  1.3018,  0.4016, -1.5312]],

        [[-1.1846, -2.5898,  0.1781,  ...,  1.0361,  0.6582, -1.4639],
         [-1.1865, -2.6367,  0.1313,  ...,  0.9663,  0.6777, -1.4141],
         [-1.0947, -2.2168,  0.3018,  ...,  1.2979,  0.4150, -1.5322]],

        [[-1.1934, -2.6777,  0.0948,  ...,  0.8613,  0.6982, -1.3652],
         [-1.1777, -2.6992,  0.0055,  ...,  0.7271,  0.6777, -1.2588],
         [-1.0820, -2.1797,  0.3081,  ...,  1.3105,  0.3923, -1.5273]],

        ...,

        [[-1.0225, -2.5137, -0.3293,  ...,  0.1478,  0.4507, -0.7808],
         [-0.9321, -2.3809, -0.4890,  ..., -0.0570,  0.3181, -0.6030],
         [-1.0352, -2.0293,  0.3225,  ...,  1.3525,  0.2927, -1.4990]],

        [[-1.1699, -2.5195,  0.2140,  ...,  1.1191,  0.

Epoch 0:   1%|          | 46/5879 [00:49<1:45:24,  0.92it/s, v_num=7]
 tensor([[[-1.2314e+00, -2.7461e+00,  3.9032e-02,  ...,  8.2031e-01,
           7.9443e-01, -1.4160e+00],
         [-1.2080e+00, -2.7637e+00, -6.3843e-02,  ...,  6.5088e-01,
           7.6514e-01, -1.2852e+00],
         [-1.1279e+00, -2.2559e+00,  2.5903e-01,  ...,  1.2969e+00,
           4.8462e-01, -1.5889e+00]],

        [[-1.0732e+00, -2.5977e+00, -3.3521e-01,  ...,  1.5259e-01,
           5.7275e-01, -8.8281e-01],
         [-9.9951e-01, -2.4961e+00, -4.6362e-01,  ..., -1.2062e-02,
           4.6411e-01, -7.3926e-01],
         [-1.0918e+00, -2.1523e+00,  2.7173e-01,  ...,  1.3320e+00,
           4.1382e-01, -1.5693e+00]],

        [[-1.1621e+00, -2.7227e+00, -1.7542e-01,  ...,  4.2261e-01,
           7.0654e-01, -1.1143e+00],
         [-1.0947e+00, -2.6406e+00, -3.2446e-01,  ...,  2.1375e-01,
           5.9961e-01, -9.3311e-01],
         [-1.0996e+00, -2.1738e+00,  2.6929e-01,  ...,  1.3223e+00,
           4.3091

Epoch 0:   1%|          | 51/5879 [00:51<1:37:17,  1.00it/s, v_num=7]
 tensor([[[-1.2412, -2.5801,  0.1450,  ...,  1.1680,  0.7358, -1.6494],
         [-1.2549, -2.6602,  0.1028,  ...,  1.0859,  0.7896, -1.6074],
         [-1.2021, -2.4160,  0.1921,  ...,  1.2607,  0.6270, -1.6670]],

        [[-0.1779, -0.3994, -1.0977,  ..., -0.3962, -0.9565,  0.1384],
         [-0.1970, -0.5205, -1.1211,  ..., -0.4875, -0.9009,  0.1541],
         [-0.7661,  0.2595,  0.5239,  ...,  0.6079,  1.9199, -1.3662]],

        [[-0.5464, -1.5098, -1.0146,  ..., -0.7148, -0.2683,  0.0028],
         [-0.4963, -1.4551, -1.0654,  ..., -0.7612, -0.3176,  0.0335],
         [-1.0518, -1.9580,  0.2308,  ...,  1.3701,  0.3225, -1.5566]],

        ...,

        [[-0.2703,  0.0097, -0.0638,  ...,  0.3362, -0.1213, -0.5591],
         [-0.2288, -0.0096, -0.1406,  ...,  0.2795, -0.2769, -0.4670],
         [-0.8413, -0.0823,  0.7935,  ...,  0.6660,  2.0742, -1.5166]],

        [[-0.4001, -1.0898, -1.1260,  ..., -0.6987, -0.

Epoch 0:   1%|          | 57/5879 [00:52<1:29:25,  1.09it/s, v_num=7]
 tensor([[[-0.6948, -0.2703,  0.7588,  ...,  0.6313,  1.2100, -1.1934],
         [-0.6860, -0.2734,  0.7520,  ...,  0.6167,  1.1738, -1.1777],
         [-0.8711, -0.2781,  0.9414,  ...,  0.7329,  1.9062, -1.4795]],

        [[-1.1318, -2.6914, -0.4204,  ...,  0.0612,  0.7603, -0.9912],
         [-1.0547, -2.5859, -0.5386,  ..., -0.1074,  0.6426, -0.8325],
         [-1.1982, -2.3340,  0.1770,  ...,  1.3164,  0.6245, -1.7109]],

        [[-1.3164, -2.8633, -0.0311,  ...,  0.8330,  0.9966, -1.5791],
         [-1.3018, -2.8926, -0.1093,  ...,  0.6895,  0.9878, -1.4707],
         [-1.2324, -2.4395,  0.1593,  ...,  1.2754,  0.6978, -1.7275]],

        ...,

        [[-1.2998, -2.8926, -0.1220,  ...,  0.6328,  0.9941, -1.4482],
         [-1.2539, -2.8691, -0.2378,  ...,  0.4382,  0.9277, -1.2832],
         [-1.2207, -2.3984,  0.1661,  ...,  1.2920,  0.6733, -1.7256]],

        [[-0.5723, -0.1973,  0.5474,  ...,  0.5498,  0.

Epoch 0:   1%|          | 63/5879 [00:54<1:23:23,  1.16it/s, v_num=7]
 tensor([[[-0.6274, -0.2612,  0.6519,  ...,  0.6006,  0.7915, -0.9697],
         [-0.6113, -0.2585,  0.6323,  ...,  0.5786,  0.7314, -0.9438],
         [-0.9121, -0.3315,  0.9980,  ...,  0.7778,  1.8350, -1.4150]],

        [[-0.1946, -0.4724, -1.1182,  ..., -0.5664, -0.9551,  0.1989],
         [-0.1965, -0.5146, -1.1250,  ..., -0.6021, -0.9297,  0.1919],
         [-0.7412,  0.4414,  0.2937,  ...,  0.7637,  1.2617, -1.2070]],

        [[-1.3164, -2.9434, -0.2260,  ...,  0.4719,  1.0684, -1.4131],
         [-1.2607, -2.8945, -0.3369,  ...,  0.2896,  0.9868, -1.2529],
         [-1.2715, -2.4766,  0.1249,  ...,  1.2930,  0.7710, -1.7900]],

        ...,

        [[-0.3589, -0.9336, -1.1631,  ..., -0.8037, -0.6450,  0.1667],
         [-0.3657, -1.0156, -1.1689,  ..., -0.8525, -0.5830,  0.1538],
         [-0.9238, -1.1348,  0.1382,  ...,  1.3535,  0.1241, -1.4609]],

        [[-1.3672, -2.9180, -0.0516,  ...,  0.8843,  1.

Epoch 0:   1%|          | 69/5879 [01:01<1:26:35,  1.12it/s, v_num=7]
 tensor([[[-7.0850e-01, -1.8154e+00, -9.7900e-01,  ..., -8.3594e-01,
           1.1676e-01, -2.2742e-01],
         [-6.8896e-01, -1.8086e+00, -1.0029e+00,  ..., -8.5547e-01,
           1.0461e-01, -2.1631e-01],
         [-1.2754e+00, -2.4023e+00,  1.0968e-01,  ...,  1.3506e+00,
           7.4658e-01, -1.8164e+00]],

        [[-2.4695e-01, -6.0693e-01, -1.1475e+00,  ..., -7.2021e-01,
          -8.7012e-01,  2.1021e-01],
         [-2.5293e-01, -6.8506e-01, -1.1611e+00,  ..., -7.8174e-01,
          -8.1445e-01,  2.0898e-01],
         [-8.5840e-01,  2.5464e-01,  5.4199e-01,  ...,  7.6221e-01,
           1.5127e+00, -1.2080e+00]],

        [[-1.9604e-01,  4.3297e-03, -2.4329e-01,  ...,  1.9299e-01,
          -7.3682e-01, -1.6663e-01],
         [-1.6956e-01, -2.0695e-03, -3.2861e-01,  ...,  1.2469e-01,
          -8.3887e-01, -1.0779e-01],
         [-9.6045e-01, -3.1396e-01,  1.0068e+00,  ...,  8.1104e-01,
           1.8242

Epoch 0:   1%|▏         | 75/5879 [01:03<1:21:42,  1.18it/s, v_num=7]
 tensor([[[-0.6421, -1.6064, -1.0508,  ..., -0.9556, -0.0161, -0.1187],
         [-0.6113, -1.5898, -1.0742,  ..., -0.9810, -0.0418, -0.0966],
         [-1.2861, -2.2988,  0.0955,  ...,  1.3945,  0.7026, -1.8311]],

        [[-0.5166, -1.3213, -1.1201,  ..., -0.9883, -0.2637,  0.0214],
         [-0.5215, -1.3701, -1.1221,  ..., -0.9976, -0.2191, -0.0040],
         [-1.2715, -2.3086,  0.0883,  ...,  1.3906,  0.7041, -1.8252]],

        [[-1.4473, -3.0898, -0.2239,  ...,  0.6001,  1.3057, -1.6484],
         [-1.4287, -3.0859, -0.2742,  ...,  0.5171,  1.2812, -1.5752],
         [-1.3984, -2.6719,  0.0502,  ...,  1.2822,  0.9873, -1.9199]],

        ...,

        [[-0.8296, -2.0254, -0.9160,  ..., -0.7964,  0.3496, -0.3877],
         [-0.7378, -1.8818, -0.9907,  ..., -0.8867,  0.2039, -0.2644],
         [-1.3008, -2.3555,  0.0934,  ...,  1.3896,  0.7432, -1.8457]],

        [[-1.0664, -2.5195, -0.7065,  ..., -0.4292,  0.

Epoch 0:   1%|▏         | 81/5879 [01:04<1:17:20,  1.25it/s, v_num=7]
 tensor([[[-1.5273, -3.1309, -0.1694,  ...,  0.7983,  1.3936, -1.8291],
         [-1.5146, -3.1504, -0.2177,  ...,  0.6992,  1.3936, -1.7568],
         [-1.4609, -2.7559,  0.0248,  ...,  1.2793,  1.0820, -1.9756]],

        [[-0.1190, -0.2578, -1.0137,  ..., -0.5796, -1.1250,  0.2443],
         [-0.1246, -0.3000, -1.0254,  ..., -0.6182, -1.0996,  0.2407],
         [-1.0117, -0.2349,  0.9561,  ...,  0.8628,  1.6328, -1.1953]],

        [[-0.6016, -0.3030,  0.5991,  ...,  0.5752,  0.2415, -0.6475],
         [-0.5488, -0.2646,  0.5146,  ...,  0.5210,  0.0801, -0.5708],
         [-0.9951, -0.5068,  1.1162,  ...,  0.8730,  1.5312, -1.2188]],

        ...,

        [[-0.0775, -0.1339, -0.8916,  ..., -0.4246, -1.1904,  0.2302],
         [-0.0831, -0.1672, -0.9111,  ..., -0.4692, -1.1738,  0.2324],
         [-1.0215, -0.3955,  1.0693,  ...,  0.8745,  1.6484, -1.2285]],

        [[-1.5225, -3.0000, -0.0639,  ...,  1.1045,  1.

Epoch 0:   1%|▏         | 87/5879 [01:06<1:13:31,  1.31it/s, v_num=7]
 tensor([[[-1.2227e+00, -2.7070e+00, -6.5234e-01,  ..., -3.5425e-01,
           1.0439e+00, -1.0273e+00],
         [-1.0635e+00, -2.4434e+00, -7.9297e-01,  ..., -6.1572e-01,
           7.9297e-01, -7.5439e-01],
         [-1.4756e+00, -2.6836e+00,  3.9917e-02,  ...,  1.3545e+00,
           1.0518e+00, -2.0098e+00]],

        [[-1.5381e+00, -3.1836e+00, -2.9517e-01,  ...,  4.8779e-01,
           1.4639e+00, -1.6963e+00],
         [-1.4395e+00, -3.0859e+00, -4.2725e-01,  ...,  1.8823e-01,
           1.3477e+00, -1.4580e+00],
         [-1.5127e+00, -2.8008e+00,  1.5358e-02,  ...,  1.2979e+00,
           1.1455e+00, -2.0293e+00]],

        [[-1.5908e+00, -3.1465e+00, -1.3989e-01,  ...,  9.3506e-01,
           1.4375e+00, -1.9492e+00],
         [-1.5840e+00, -3.1895e+00, -1.9336e-01,  ...,  8.0029e-01,
           1.4609e+00, -1.8701e+00],
         [-1.5322e+00, -2.8516e+00,  1.3056e-03,  ...,  1.2734e+00,
           1.1836

Epoch 0:   2%|▏         | 92/5879 [01:13<1:16:36,  1.26it/s, v_num=7]
 tensor([[[-1.6338e+00, -3.0938e+00, -8.1360e-02,  ...,  1.1348e+00,
           1.4014e+00, -2.0586e+00],
         [-1.6328e+00, -3.1074e+00, -9.3445e-02,  ...,  1.1211e+00,
           1.4092e+00, -2.0449e+00],
         [-1.6055e+00, -2.9492e+00, -2.0996e-02,  ...,  1.2598e+00,
           1.2744e+00, -2.0820e+00]],

        [[-1.5332e+00, -3.1719e+00, -3.7207e-01,  ...,  3.0273e-01,
           1.4844e+00, -1.6113e+00],
         [-1.4219e+00, -3.0332e+00, -4.9731e-01,  ...,  1.9424e-02,
           1.3408e+00, -1.3750e+00],
         [-1.5566e+00, -2.8301e+00,  1.5114e-02,  ...,  1.3193e+00,
           1.1855e+00, -2.0703e+00]],

        [[-3.8428e-01, -8.9307e-01, -1.1211e+00,  ..., -1.0811e+00,
          -5.5322e-01,  1.3977e-01],
         [-3.5742e-01, -8.7695e-01, -1.1279e+00,  ..., -1.0918e+00,
          -5.7715e-01,  1.5015e-01],
         [-1.3545e+00, -2.2402e+00,  7.1960e-02,  ...,  1.4541e+00,
           7.4365

Epoch 0:   2%|▏         | 97/5879 [01:16<1:15:41,  1.27it/s, v_num=7]
 tensor([[[-2.9468e-01, -6.3916e-01, -1.0840e+00,  ..., -1.0205e+00,
          -7.6660e-01,  2.1594e-01],
         [-2.8833e-01, -6.4355e-01, -1.0850e+00,  ..., -1.0303e+00,
          -7.5342e-01,  2.0740e-01],
         [-1.0889e+00, -1.2219e-01,  6.3330e-01,  ...,  1.0098e+00,
           1.1299e+00, -1.1543e+00]],

        [[-1.6816e+00, -3.2793e+00, -2.2302e-01,  ...,  7.2314e-01,
           1.6094e+00, -1.9258e+00],
         [-1.6475e+00, -3.2754e+00, -2.9248e-01,  ...,  5.4639e-01,
           1.5918e+00, -1.8047e+00],
         [-1.6221e+00, -2.9121e+00,  5.8403e-03,  ...,  1.3105e+00,
           1.2666e+00, -2.1172e+00]],

        [[-5.9766e-01, -1.3262e+00, -1.0615e+00,  ..., -1.1650e+00,
          -1.0260e-01, -3.7445e-02],
         [-5.1953e-01, -1.1992e+00, -1.0859e+00,  ..., -1.1836e+00,
          -2.3877e-01,  3.2196e-02],
         [-1.4463e+00, -2.3867e+00,  7.8674e-02,  ...,  1.4697e+00,
           8.5156

Epoch 0:   2%|▏         | 102/5879 [01:22<1:17:37,  1.24it/s, v_num=7]
 tensor([[[-3.1543e-01, -9.5642e-02,  1.7578e-02,  ...,  1.9348e-01,
          -9.5410e-01,  2.2461e-02],
         [-2.4390e-01, -5.4718e-02, -1.2622e-01,  ...,  7.5928e-02,
          -1.0986e+00,  1.1322e-01],
         [-1.0928e+00, -6.9434e-01,  1.1572e+00,  ...,  9.4287e-01,
           1.2100e+00, -1.0000e+00]],

        [[-1.3245e-01, -2.6465e-01, -9.6484e-01,  ..., -8.0322e-01,
          -1.1172e+00,  3.0347e-01],
         [-1.3611e-01, -2.9639e-01, -9.7510e-01,  ..., -8.4277e-01,
          -1.0791e+00,  3.0200e-01],
         [-1.1367e+00, -5.8789e-01,  1.1133e+00,  ...,  9.5996e-01,
           1.3477e+00, -1.0283e+00]],

        [[-1.1611e+00, -2.4473e+00, -7.6367e-01,  ..., -6.9824e-01,
           9.6338e-01, -8.3105e-01],
         [-9.5020e-01, -2.0645e+00, -8.9746e-01,  ..., -9.6582e-01,
           6.1230e-01, -5.0586e-01],
         [-1.6113e+00, -2.7715e+00,  5.0293e-02,  ...,  1.4082e+00,
           1.147

Epoch 0:   2%|▏         | 107/5879 [01:23<1:15:01,  1.28it/s, v_num=7]
 tensor([[[-1.7402, -3.3262, -0.2930,  ...,  0.4866,  1.7070, -1.8584],
         [-1.4951, -3.0195, -0.5244,  ..., -0.1396,  1.4453, -1.3672],
         [-1.7119, -2.9512,  0.0319,  ...,  1.3545,  1.3125, -2.1836]],

        [[-0.9185, -0.6475,  0.9521,  ...,  0.7910,  0.4929, -0.6880],
         [-0.8296, -0.5630,  0.8164,  ...,  0.6895,  0.2113, -0.5532],
         [-1.0791, -0.7451,  1.1240,  ...,  0.9258,  1.0010, -0.9116]],

        [[-1.7881, -3.1719, -0.0540,  ...,  1.2070,  1.5059, -2.1875],
         [-1.7930, -3.2070, -0.0770,  ...,  1.1660,  1.5332, -2.1680],
         [-1.7725, -3.0898, -0.0198,  ...,  1.2744,  1.4316, -2.2031]],

        ...,

        [[-1.6494, -3.2305, -0.3916,  ...,  0.2095,  1.6377, -1.6631],
         [-1.5342, -3.0762, -0.4978,  ..., -0.0549,  1.5029, -1.4424],
         [-1.7041, -2.9355,  0.0355,  ...,  1.3623,  1.2969, -2.1816]],

        [[-0.2886, -0.5854, -1.0273,  ..., -1.0859, -0

Epoch 0:   2%|▏         | 113/5879 [01:25<1:12:43,  1.32it/s, v_num=7]
 tensor([[[-1.1680, -2.3457, -0.7598,  ..., -0.8154,  0.9932, -0.7900],
         [-0.9595, -1.9609, -0.8721,  ..., -1.0566,  0.6479, -0.4792],
         [-1.7080, -2.8145,  0.0838,  ...,  1.4434,  1.1807, -2.1816]],

        [[-1.8633, -3.3574, -0.1438,  ...,  0.9146,  1.7168, -2.1289],
         [-1.8135, -3.3750, -0.2566,  ...,  0.5630,  1.7441, -1.9277],
         [-1.8008, -3.0605,  0.0280,  ...,  1.3242,  1.4102, -2.2363]],

        [[-0.3845, -0.7690, -1.0098,  ..., -1.2070, -0.5195,  0.1595],
         [-0.3257, -0.6841, -1.0117,  ..., -1.1934, -0.6133,  0.2061],
         [-1.2344, -0.6548,  0.4563,  ...,  1.3018,  0.7051, -1.4160]],

        ...,

        [[-1.5361, -3.0059, -0.5142,  ..., -0.1843,  1.5146, -1.3965],
         [-1.1963, -2.4316, -0.7441,  ..., -0.7603,  1.0342, -0.8335],
         [-1.7373, -2.8945,  0.0714,  ...,  1.4141,  1.2549, -2.2031]],

        [[-1.8594, -3.2773, -0.0754,  ...,  1.1172,  1

Epoch 0:   2%|▏         | 119/5879 [01:34<1:15:56,  1.26it/s, v_num=7]
 tensor([[[-0.6558, -1.2783, -0.9409,  ..., -1.2852,  0.0743, -0.0679],
         [-0.5557, -1.1162, -0.9629,  ..., -1.3076, -0.1086,  0.0315],
         [-1.6826, -2.5996,  0.1354,  ...,  1.5098,  1.0078, -2.1426]],

        [[-0.0789, -0.1249, -0.8394,  ..., -0.8340, -1.2168,  0.3794],
         [-0.0859, -0.1614, -0.8511,  ..., -0.8789, -1.1680,  0.3772],
         [-1.2188, -0.7637,  1.1006,  ...,  0.9912,  1.0898, -0.8950]],

        [[-0.0675, -0.0958, -0.8184,  ..., -0.7856, -1.2617,  0.3887],
         [-0.0718, -0.1174, -0.8281,  ..., -0.8228, -1.2246,  0.3865],
         [-1.2139, -0.7998,  1.1152,  ...,  0.9873,  1.0605, -0.8867]],

        ...,

        [[-0.0808, -0.1292, -0.8467,  ..., -0.8354, -1.2188,  0.3767],
         [-0.0756, -0.1205, -0.8442,  ..., -0.8306, -1.2227,  0.3723],
         [-1.2197, -0.7539,  1.0908,  ...,  0.9956,  1.0957, -0.8911]],

        [[-1.9170, -3.2559, -0.0240,  ...,  1.2168,  1

Epoch 0:   2%|▏         | 125/5879 [01:36<1:14:00,  1.30it/s, v_num=7]
 tensor([[[-0.5342, -0.9946, -0.9189,  ..., -1.3164, -0.1613,  0.0726],
         [-0.4690, -0.8970, -0.9277,  ..., -1.3184, -0.2766,  0.1305],
         [-1.6963, -2.4980,  0.1757,  ...,  1.5371,  0.9380, -2.1387]],

        [[-1.9404, -3.4336, -0.2133,  ...,  0.5908,  1.8594, -2.0273],
         [-1.6992, -3.1562, -0.4199,  ..., -0.0502,  1.6836, -1.5547],
         [-1.8994, -3.0742,  0.0889,  ...,  1.3848,  1.4043, -2.3027]],

        [[-1.9814, -3.4375, -0.1376,  ...,  0.8364,  1.8203, -2.1621],
         [-1.8955, -3.4023, -0.2622,  ...,  0.4338,  1.8311, -1.9160],
         [-1.9199, -3.1230,  0.0735,  ...,  1.3525,  1.4512, -2.3086]],

        ...,

        [[-0.0451,  0.0128, -0.6919,  ..., -0.6465, -1.4111,  0.4387],
         [-0.0427,  0.0047, -0.7021,  ..., -0.6748, -1.3916,  0.4353],
         [-1.2354, -0.8564,  1.0928,  ...,  0.9834,  0.9438, -0.8345]],

        [[-1.9590, -3.4414, -0.1876,  ...,  0.6758,  1

Epoch 0:   2%|▏         | 131/5879 [01:38<1:12:03,  1.33it/s, v_num=7]
 tensor([[[-2.0488, -3.4297, -0.0695,  ...,  0.9951,  1.7871, -2.2578],
         [-2.0273, -3.4668, -0.1548,  ...,  0.7300,  1.8594, -2.1270],
         [-1.9883, -3.1719,  0.0881,  ...,  1.3506,  1.4844, -2.3418]],

        [[-0.7393, -0.4802,  0.4685,  ...,  0.4502, -0.6138, -0.1184],
         [-0.5957, -0.3386,  0.2913,  ...,  0.2915, -0.8984,  0.0541],
         [-1.2080, -0.9258,  1.0439,  ...,  0.9370,  0.6738, -0.7349]],

        [[-0.3623, -0.6182, -0.8735,  ..., -1.2598, -0.5186,  0.2378],
         [-0.3254, -0.5674, -0.8721,  ..., -1.2559, -0.5742,  0.2620],
         [-1.3975, -0.6982,  0.6875,  ...,  1.1934,  0.8076, -1.2422]],

        ...,

        [[-2.0488, -3.4336, -0.0741,  ...,  0.9814,  1.7930, -2.2539],
         [-2.0273, -3.4668, -0.1490,  ...,  0.7490,  1.8574, -2.1367],
         [-1.9893, -3.1719,  0.0877,  ...,  1.3486,  1.4834, -2.3398]],

        [[-2.0488, -3.4160, -0.0586,  ...,  1.0312,  1

Epoch 0:   2%|▏         | 137/5879 [01:40<1:10:16,  1.36it/s, v_num=7]
 tensor([[[-2.0977e+00, -3.3086e+00,  5.1941e-02,  ...,  1.2773e+00,
           1.6104e+00, -2.3672e+00],
         [-2.0996e+00, -3.3359e+00,  3.2379e-02,  ...,  1.2451e+00,
           1.6387e+00, -2.3555e+00],
         [-2.0918e+00, -3.2715e+00,  6.8604e-02,  ...,  1.3096e+00,
           1.5713e+00, -2.3789e+00]],

        [[-1.1025e+00, -1.9766e+00, -7.0850e-01,  ..., -1.0791e+00,
           9.2773e-01, -5.6738e-01],
         [-8.3936e-01, -1.5146e+00, -7.8955e-01,  ..., -1.2822e+00,
           4.8730e-01, -2.2595e-01],
         [-1.9688e+00, -3.0137e+00,  1.5881e-01,  ...,  1.4648e+00,
           1.3047e+00, -2.3340e+00]],

        [[-4.7046e-01, -7.8809e-01, -8.2715e-01,  ..., -1.3252e+00,
          -2.7295e-01,  1.7725e-01],
         [-4.3481e-01, -7.3340e-01, -8.2812e-01,  ..., -1.3213e+00,
          -3.2812e-01,  1.9739e-01],
         [-1.5088e+00, -1.0840e+00,  5.3662e-01,  ...,  1.3730e+00,
           6.445

Epoch 0:   2%|▏         | 143/5879 [01:53<1:15:36,  1.26it/s, v_num=7]
 tensor([[[-0.0718,  0.0710, -0.5151,  ..., -0.5845, -1.5537,  0.5366],
         [-0.0525,  0.0545, -0.5601,  ..., -0.6787, -1.4980,  0.5469],
         [-1.2979, -0.9976,  0.9814,  ...,  0.9331,  0.6147, -0.7144]],

        [[-2.1641, -3.4668, -0.0363,  ...,  1.0000,  1.8271, -2.3047],
         [-2.1270, -3.4961, -0.1357,  ...,  0.6821,  1.9189, -2.1406],
         [-2.0938, -3.1953,  0.1365,  ...,  1.3838,  1.4824, -2.3945]],

        [[-0.1278, -0.1675, -0.7383,  ..., -1.0918, -1.0400,  0.4497],
         [-0.1313, -0.1838, -0.7402,  ..., -1.1084, -1.0156,  0.4534],
         [-1.3584, -0.8579,  0.9312,  ...,  0.9951,  0.8135, -0.8257]],

        ...,

        [[-0.8169, -1.3652, -0.7446,  ..., -1.3037,  0.4219, -0.1375],
         [-0.6597, -1.1045, -0.7690,  ..., -1.3574,  0.1356,  0.0249],
         [-1.9111, -2.7109,  0.2329,  ...,  1.5605,  1.0361, -2.2715]],

        [[-2.1660, -3.4395, -0.0086,  ...,  1.0791,  1

Epoch 0:   3%|▎         | 149/5879 [01:54<1:13:24,  1.30it/s, v_num=7]
 tensor([[[-2.0625, -3.3945, -0.2217,  ...,  0.3118,  1.9697, -1.9141],
         [-1.6240, -2.7754, -0.4536,  ..., -0.5112,  1.6152, -1.2051],
         [-2.0625, -3.0020,  0.2139,  ...,  1.5127,  1.2656, -2.3730]],

        [[-0.2947, -0.4043, -0.7222,  ..., -1.2656, -0.6514,  0.3674],
         [-0.2698, -0.3894, -0.7241,  ..., -1.2715, -0.6860,  0.3984],
         [-1.4404, -0.8335,  0.8281,  ...,  1.0479,  0.7617, -0.9565]],

        [[-0.0498,  0.0446, -0.5830,  ..., -0.7783, -1.4766,  0.5796],
         [-0.0572, -0.0056, -0.6250,  ..., -0.9102, -1.3398,  0.5737],
         [-1.3301, -1.0361,  0.9307,  ...,  0.9097,  0.5269, -0.6895]],

        ...,

        [[-0.3994, -0.5894, -0.7285,  ..., -1.3291, -0.4146,  0.2852],
         [-0.3591, -0.5337, -0.7290,  ..., -1.3242, -0.4839,  0.3164],
         [-1.6846, -1.8018,  0.3916,  ...,  1.5557,  0.6440, -1.9258]],

        [[-1.1797, -0.9512,  0.8252,  ...,  0.7778,  0

Epoch 0:   3%|▎         | 155/5879 [01:56<1:11:29,  1.33it/s, v_num=7]
 tensor([[[-1.2861, -2.1172, -0.5562,  ..., -0.9927,  1.2041, -0.6763],
         [-0.9092, -1.4717, -0.6519,  ..., -1.3066,  0.6011, -0.1787],
         [-2.1055, -3.0156,  0.2333,  ...,  1.5264,  1.2559, -2.3848]],

        [[-2.2676, -3.4941, -0.0134,  ...,  0.9731,  1.8672, -2.3105],
         [-2.1289, -3.4297, -0.1840,  ...,  0.3660,  1.9648, -1.9492],
         [-2.1836, -3.1914,  0.1874,  ...,  1.4326,  1.4512, -2.4238]],

        [[-2.2676, -3.4824,  0.0054,  ...,  1.0293,  1.8359, -2.3340],
         [-2.2031, -3.4961, -0.1238,  ...,  0.5894,  1.9639, -2.0996],
         [-2.1914, -3.2109,  0.1812,  ...,  1.4170,  1.4697, -2.4238]],

        ...,

        [[-0.2327, -0.2554, -0.6680,  ..., -1.2266, -0.8271,  0.4431],
         [-0.2168, -0.2554, -0.6680,  ..., -1.2334, -0.8384,  0.4568],
         [-1.7500, -1.8467,  0.4333,  ...,  1.5752,  0.6440, -1.9336]],

        [[-2.2676, -3.5000, -0.0237,  ...,  0.9409,  1

Epoch 0:   3%|▎         | 161/5879 [01:58<1:09:56,  1.36it/s, v_num=7]
 tensor([[[-0.8843, -1.3330, -0.6099,  ..., -1.3262,  0.5342, -0.1201],
         [-0.7134, -1.0537, -0.6270,  ..., -1.3896,  0.2220,  0.0592],
         [-2.1387, -3.0020,  0.2556,  ...,  1.5518,  1.2246, -2.3867]],

        [[-0.6304, -0.8940, -0.6299,  ..., -1.3916,  0.0362,  0.1492],
         [-0.5190, -0.7314, -0.6382,  ..., -1.4004, -0.1776,  0.2546],
         [-2.0840, -2.8672,  0.2749,  ...,  1.5889,  1.0996, -2.3477]],

        [[-1.1221, -0.8901,  0.5894,  ...,  0.5791, -0.3767, -0.2959],
         [-1.0215, -0.7725,  0.4771,  ...,  0.4563, -0.6216, -0.1539],
         [-1.3057, -1.0771,  0.7832,  ...,  0.7871,  0.1840, -0.5698]],

        ...,

        [[-0.0994, -0.0317, -0.6001,  ..., -1.0674, -1.2402,  0.5757],
         [-0.1045, -0.0424, -0.6001,  ..., -1.0898, -1.2031,  0.5737],
         [-1.3857, -1.0801,  0.8271,  ...,  0.8647,  0.4431, -0.6758]],

        [[-0.1113, -0.0524, -0.6055,  ..., -1.0938, -1

Epoch 0:   3%|▎         | 166/5879 [01:59<1:08:47,  1.38it/s, v_num=7]
 tensor([[[-0.8730, -1.2568, -0.5708,  ..., -1.3506,  0.4834, -0.0511],
         [-0.6611, -0.9292, -0.5918,  ..., -1.4189,  0.0812,  0.1742],
         [-2.1172, -2.8652,  0.2966,  ...,  1.6094,  1.0918, -2.3516]],

        [[-2.3535, -3.4941,  0.0225,  ...,  1.0010,  1.8730, -2.3184],
         [-2.1582, -3.3691, -0.1748,  ...,  0.2527,  1.9844, -1.8584],
         [-2.2480, -3.1641,  0.2375,  ...,  1.4912,  1.4053, -2.4316]],

        [[-2.3242, -3.3457,  0.1487,  ...,  1.3379,  1.6221, -2.4238],
         [-2.3320, -3.3652,  0.1327,  ...,  1.3086,  1.6475, -2.4160],
         [-2.2988, -3.2559,  0.1927,  ...,  1.4219,  1.5117, -2.4414]],

        ...,

        [[-0.7764, -1.0889, -0.5796,  ..., -1.3789,  0.2993,  0.0477],
         [-0.6123, -0.8301, -0.5898,  ..., -1.4141, -0.0186,  0.2102],
         [-2.0938, -2.8105,  0.3040,  ...,  1.6201,  1.0439, -2.3359]],

        [[-1.0127, -1.5010, -0.5532,  ..., -1.2793,  0

Epoch 0:   3%|▎         | 172/5879 [02:13<1:13:55,  1.29it/s, v_num=7]
 tensor([[[-0.2318, -0.1564, -0.5464,  ..., -1.2500, -0.9585,  0.5420],
         [-0.2365, -0.1671, -0.5449,  ..., -1.2646, -0.9365,  0.5479],
         [-1.4316, -1.0957,  0.7261,  ...,  0.8286,  0.3962, -0.6973]],

        [[-0.1775, -0.0793, -0.5415,  ..., -1.1934, -1.1113,  0.5845],
         [-0.1936, -0.1096, -0.5415,  ..., -1.2217, -1.0547,  0.5845],
         [-1.4209, -1.1074,  0.7256,  ...,  0.8140,  0.3640, -0.6768]],

        [[-2.3633, -3.5000, -0.0306,  ...,  0.7666,  1.9775, -2.1934],
         [-2.1543, -3.2969, -0.1750,  ...,  0.1704,  1.9971, -1.7861],
         [-2.2734, -3.1348,  0.2634,  ...,  1.5283,  1.3682, -2.4258]],

        ...,

        [[-0.5073, -0.5723, -0.5410,  ..., -1.3984, -0.3086,  0.3389],
         [-0.4731, -0.5347, -0.5449,  ..., -1.4033, -0.3723,  0.3750],
         [-2.0781, -2.6875,  0.3379,  ...,  1.6602,  0.9429, -2.2949]],

        [[-0.1946, -0.1042, -0.5435,  ..., -1.2129, -1

Epoch 0:   3%|▎         | 178/5879 [02:15<1:12:07,  1.32it/s, v_num=7]
 tensor([[[-0.2289, -0.0986, -0.5054,  ..., -1.2461, -1.0576,  0.5850],
         [-0.2389, -0.1207, -0.5034,  ..., -1.2656, -1.0186,  0.5903],
         [-1.4346, -1.1182,  0.6704,  ...,  0.7920,  0.3311, -0.6860]],

        [[-2.4141, -3.4004,  0.1340,  ...,  1.2529,  1.7158, -2.3848],
         [-2.4258, -3.4883,  0.0329,  ...,  0.9512,  1.8887, -2.2695],
         [-2.3320, -3.1797,  0.2610,  ...,  1.5186,  1.4189, -2.4277]],

        [[-1.0566, -1.4072, -0.4663,  ..., -1.2979,  0.7310, -0.1685],
         [-0.8086, -1.0127, -0.4883,  ..., -1.4150,  0.2683,  0.1092],
         [-2.2168, -2.9375,  0.3206,  ...,  1.6416,  1.1455, -2.3730]],

        ...,

        [[-1.2119, -1.0098,  0.4910,  ...,  0.4976, -0.4573, -0.3064],
         [-1.0791, -0.8462,  0.3521,  ...,  0.3301, -0.7910, -0.1076],
         [-1.3438, -1.1523,  0.6260,  ...,  0.6611, -0.0115, -0.5254]],

        [[-0.8428, -1.0537, -0.4836,  ..., -1.3965,  0

Epoch 0:   3%|▎         | 184/5879 [02:16<1:10:38,  1.34it/s, v_num=7]
 tensor([[[-1.0322, -1.2842, -0.4233,  ..., -1.3447,  0.6226, -0.0798],
         [-0.8096, -0.9404, -0.4414,  ..., -1.4395,  0.1949,  0.1644],
         [-2.2480, -2.9453,  0.3359,  ...,  1.6729,  1.1504, -2.3672]],

        [[-0.6670, -0.6709, -0.4368,  ..., -1.4307, -0.1168,  0.2905],
         [-0.6011, -0.5908, -0.4438,  ..., -1.4443, -0.2463,  0.3567],
         [-2.1875, -2.8066,  0.3560,  ...,  1.7070,  1.0186, -2.3242]],

        [[-2.4434, -3.3809,  0.1538,  ...,  1.2881,  1.7031, -2.3770],
         [-2.4414, -3.4785,  0.0069,  ...,  0.8105,  1.9492, -2.1797],
         [-2.3477, -3.1445,  0.2859,  ...,  1.5684,  1.3818, -2.4180]],

        ...,

        [[-2.3496, -3.4043, -0.0600,  ...,  0.5054,  2.0371, -1.9941],
         [-1.8145, -2.6523, -0.2778,  ..., -0.5332,  1.7393, -1.1465],
         [-2.3145, -3.0879,  0.3066,  ...,  1.6064,  1.3086, -2.4043]],

        [[-2.4102, -3.2754,  0.2214,  ...,  1.4561,  1

Epoch 0:   3%|▎         | 190/5879 [02:18<1:09:19,  1.37it/s, v_num=7]
 tensor([[[-0.3318, -0.1207, -0.4153,  ..., -1.3369, -0.9814,  0.5874],
         [-0.3413, -0.1378, -0.4136,  ..., -1.3516, -0.9517,  0.5903],
         [-1.4346, -1.1875,  0.5771,  ...,  0.6982,  0.2220, -0.6763]],

        [[-0.6060, -0.4907, -0.3914,  ..., -1.4365, -0.3457,  0.3823],
         [-0.5728, -0.4526, -0.3945,  ..., -1.4414, -0.4102,  0.4148],
         [-2.1738, -2.7344,  0.3745,  ...,  1.7568,  0.9463, -2.2949]],

        [[-1.1670, -0.9844,  0.3191,  ...,  0.3049, -0.7544, -0.1793],
         [-1.0244, -0.7993,  0.1846,  ...,  0.1231, -1.0781,  0.0341],
         [-1.3467, -1.2109,  0.5156,  ...,  0.5557, -0.1288, -0.4956]],

        ...,

        [[-0.4741, -0.2966, -0.3984,  ..., -1.3994, -0.6465,  0.4888],
         [-0.4629, -0.2903, -0.3999,  ..., -1.4072, -0.6611,  0.5005],
         [-1.9863, -2.2266,  0.4478,  ...,  1.7451,  0.6577, -2.0879]],

        [[-0.1869,  0.0842, -0.3982,  ..., -0.9561, -1

Epoch 0:   3%|▎         | 196/5879 [02:24<1:09:50,  1.36it/s, v_num=7]
 tensor([[[-0.4661, -0.2122, -0.3557,  ..., -1.4072, -0.7773,  0.5327],
         [-0.4580, -0.2028, -0.3560,  ..., -1.4102, -0.7876,  0.5376],
         [-1.6924, -1.3652,  0.5479,  ...,  1.2422,  0.4153, -1.3271]],

        [[-0.3025, -0.0282, -0.3870,  ..., -1.3145, -1.1914,  0.6528],
         [-0.3169, -0.0463, -0.3816,  ..., -1.3350, -1.1455,  0.6504],
         [-1.4014, -1.2432,  0.5117,  ...,  0.6001,  0.0783, -0.5986]],

        [[-0.3953, -0.1337, -0.3711,  ..., -1.3770, -0.9473,  0.5835],
         [-0.3977, -0.1392, -0.3694,  ..., -1.3877, -0.9331,  0.5894],
         [-1.4453, -1.2129,  0.5327,  ...,  0.6875,  0.2074, -0.7046]],

        ...,

        [[-1.0762, -1.1523, -0.3254,  ..., -1.3828,  0.5439, -0.0119],
         [-0.8887, -0.8740, -0.3362,  ..., -1.4668,  0.1714,  0.1949],
         [-2.2637, -2.8496,  0.3652,  ...,  1.7725,  1.0586, -2.3242]],

        [[-0.5918, -0.3857, -0.3403,  ..., -1.4502, -0

Epoch 0:   3%|▎         | 202/5879 [02:26<1:08:42,  1.38it/s, v_num=7]
 tensor([[[-2.5059, -3.2617,  0.2247,  ...,  1.5039,  1.5693, -2.3613],
         [-2.5488, -3.3594,  0.1547,  ...,  1.3125,  1.7178, -2.3203],
         [-2.4258, -3.0918,  0.3096,  ...,  1.6865,  1.3418, -2.3770]],

        [[-0.3044,  0.0240, -0.3564,  ..., -1.3135, -1.3340,  0.6978],
         [-0.3198,  0.0037, -0.3489,  ..., -1.3359, -1.2793,  0.6938],
         [-1.3916, -1.2744,  0.4624,  ...,  0.5522,  0.0384, -0.5894]],

        [[-0.2479,  0.0754, -0.3694,  ..., -1.2217, -1.5596,  0.7578],
         [-0.2808,  0.0394, -0.3574,  ..., -1.2881, -1.4199,  0.7407],
         [-1.3701, -1.2920,  0.4448,  ...,  0.5005, -0.0776, -0.5303]],

        ...,

        [[-0.5771, -0.2837, -0.2981,  ..., -1.4580, -0.6362,  0.4917],
         [-0.5728, -0.2864, -0.2991,  ..., -1.4678, -0.6372,  0.5020],
         [-2.0723, -2.3613,  0.4326,  ...,  1.8447,  0.6919, -2.1328]],

        [[-0.6172, -0.3379, -0.2917,  ..., -1.4668, -0

Epoch 0:   4%|▎         | 208/5879 [02:28<1:07:32,  1.40it/s, v_num=7]
 tensor([[[-0.8159, -0.5269, -0.2279,  ..., -1.5010, -0.2257,  0.3364],
         [-0.7808, -0.4868, -0.2329,  ..., -1.5127, -0.2927,  0.3708],
         [-2.3008, -2.8008,  0.3716,  ...,  1.8740,  1.0146, -2.2852]],

        [[-0.7710, -0.4629, -0.2291,  ..., -1.5010, -0.3301,  0.3875],
         [-0.7588, -0.4639, -0.2345,  ..., -1.5166, -0.3477,  0.4080],
         [-2.2734, -2.7480,  0.3818,  ...,  1.8896,  0.9604, -2.2676]],

        [[-0.5557, -0.1863, -0.2590,  ..., -1.4678, -0.8071,  0.5420],
         [-0.5547, -0.1887, -0.2598,  ..., -1.4746, -0.8037,  0.5483],
         [-1.5527, -1.3174,  0.4944,  ...,  0.9277,  0.2881, -0.9897]],

        ...,

        [[-0.6797, -0.3350, -0.2377,  ..., -1.4912, -0.5273,  0.4502],
         [-0.6758, -0.3467, -0.2408,  ..., -1.5068, -0.5249,  0.4656],
         [-2.2227, -2.6504,  0.3921,  ...,  1.9023,  0.8726, -2.2344]],

        [[-0.6924, -0.3503, -0.2368,  ..., -1.4902, -0

Epoch 0:   4%|▎         | 214/5879 [02:30<1:06:13,  1.43it/s, v_num=7]
 tensor([[[-0.5874, -0.1481, -0.2085,  ..., -1.4854, -0.8608,  0.5576],
         [-0.5938, -0.1605, -0.2067,  ..., -1.4980, -0.8364,  0.5596],
         [-1.5088, -1.3545,  0.4402,  ...,  0.7915,  0.2106, -0.8804]],

        [[-2.6211, -3.3496,  0.0864,  ...,  1.0361,  1.9219, -2.1465],
         [-2.2578, -2.8789, -0.0714,  ..., -0.0940,  1.9521, -1.4414],
         [-2.4180, -2.9453,  0.3484,  ...,  1.8613,  1.1934, -2.3242]],

        [[-0.7905, -0.3936, -0.1718,  ..., -1.5156, -0.4092,  0.4065],
         [-0.7808, -0.3967, -0.1763,  ..., -1.5312, -0.4226,  0.4250],
         [-2.2598, -2.6523,  0.3904,  ...,  1.9580,  0.8760, -2.2246]],

        ...,

        [[-0.5752, -0.1373, -0.2119,  ..., -1.4844, -0.8857,  0.5674],
         [-0.5859, -0.1547, -0.2101,  ..., -1.4971, -0.8564,  0.5684],
         [-1.5352, -1.3662,  0.4497,  ...,  0.8687,  0.2371, -0.9448]],

        [[-1.1377, -0.9199, -0.1633,  ..., -1.4561,  0

Epoch 0:   4%|▎         | 220/5879 [02:35<1:06:43,  1.41it/s, v_num=7]
 tensor([[[-0.7568, -0.2629, -0.1249,  ..., -1.5283, -0.6074,  0.4678],
         [-0.7598, -0.2808, -0.1271,  ..., -1.5439, -0.5942,  0.4768],
         [-2.1914, -2.4551,  0.4099,  ...,  2.0215,  0.7217, -2.1406]],

        [[-0.6362, -0.1321, -0.1554,  ..., -1.5127, -0.8721,  0.5503],
         [-0.6426, -0.1462, -0.1545,  ..., -1.5254, -0.8506,  0.5537],
         [-1.5459, -1.4258,  0.4138,  ...,  0.8989,  0.2147, -0.9722]],

        [[-1.1172, -0.7793, -0.1038,  ..., -1.4961,  0.1761,  0.1530],
         [-1.0527, -0.6890, -0.1069,  ..., -1.5234,  0.0450,  0.2192],
         [-2.3594, -2.7734,  0.3687,  ...,  1.9863,  1.0029, -2.2578]],

        ...,

        [[-0.6167, -0.1105, -0.1603,  ..., -1.5068, -0.9209,  0.5693],
         [-0.6191, -0.1157, -0.1593,  ..., -1.5156, -0.9082,  0.5723],
         [-1.5410, -1.4258,  0.4087,  ...,  0.8804,  0.2080, -0.9531]],

        [[-0.5767, -0.0714, -0.1736,  ..., -1.4961, -1

Epoch 0:   4%|▍         | 225/5879 [02:36<1:05:41,  1.43it/s, v_num=7]
 tensor([[[-0.6694, -0.1082, -0.1124,  ..., -1.5361, -0.9043,  0.5532],
         [-0.6763, -0.1219, -0.1121,  ..., -1.5479, -0.8843,  0.5566],
         [-1.4277, -1.4297,  0.3379,  ...,  0.5640,  0.0571, -0.7212]],

        [[-0.7539, -0.1951, -0.0884,  ..., -1.5479, -0.7207,  0.4939],
         [-0.7500, -0.1941, -0.0899,  ..., -1.5557, -0.7227,  0.5010],
         [-1.9619, -1.9727,  0.4189,  ...,  1.8691,  0.4407, -1.8203]],

        [[-0.9556, -0.4475, -0.0585,  ..., -1.5518, -0.2803,  0.3367],
         [-0.9351, -0.4297, -0.0635,  ..., -1.5635, -0.3184,  0.3589],
         [-2.3652, -2.7305,  0.3643,  ...,  2.0469,  0.9688, -2.2363]],

        ...,

        [[-2.5781, -3.0781,  0.2688,  ...,  1.8076,  1.4082, -2.3105],
         [-2.6328, -3.1699,  0.2169,  ...,  1.6738,  1.5303, -2.3027],
         [-2.5195, -2.9707,  0.3118,  ...,  1.9150,  1.2656, -2.3125]],

        [[-0.4253,  0.0942, -0.2095,  ..., -1.4414, -1

Epoch 0:   4%|▍         | 230/5879 [02:38<1:04:53,  1.45it/s, v_num=7]
 tensor([[[-0.9902, -0.4106, -0.0084,  ..., -1.5664, -0.3127,  0.3350],
         [-0.9790, -0.3977, -0.0098,  ..., -1.5713, -0.3315,  0.3442],
         [-2.3652, -2.6797,  0.3582,  ...,  2.1074,  0.9238, -2.2129]],

        [[-0.9585, -0.3696, -0.0104,  ..., -1.5693, -0.3857,  0.3713],
         [-0.9521, -0.3726, -0.0150,  ..., -1.5850, -0.3958,  0.3850],
         [-2.3379, -2.6250,  0.3665,  ...,  2.1191,  0.8696, -2.1914]],

        [[-0.8628, -0.2487, -0.0237,  ..., -1.5693, -0.5913,  0.4424],
         [-0.8647, -0.2627, -0.0259,  ..., -1.5840, -0.5811,  0.4507],
         [-2.2285, -2.4277,  0.3916,  ...,  2.1230,  0.7021, -2.1152]],

        ...,

        [[-0.9634, -0.3674, -0.0115,  ..., -1.5645, -0.3848,  0.3635],
         [-0.9390, -0.3489, -0.0184,  ..., -1.5801, -0.4270,  0.3872],
         [-2.3613, -2.6641,  0.3611,  ...,  2.1113,  0.9097, -2.2031]],

        [[-0.8662, -0.2539, -0.0273,  ..., -1.5723, -0

Epoch 0:   4%|▍         | 235/5879 [02:39<1:03:56,  1.47it/s, v_num=7]
 tensor([[[-2.7031e+00, -3.1641e+00,  1.7944e-01,  ...,  1.6523e+00,
           1.6094e+00, -2.2500e+00],
         [-2.7480e+00, -3.2422e+00,  9.5703e-02,  ...,  1.2119e+00,
           1.8613e+00, -2.1133e+00],
         [-2.5234e+00, -2.8750e+00,  3.1372e-01,  ...,  2.0508e+00,
           1.1768e+00, -2.2695e+00]],

        [[-1.1914e+00, -1.3877e+00,  3.7750e-02,  ..., -8.8379e-02,
          -9.0332e-01, -1.4124e-01],
         [-1.1660e+00, -1.3389e+00,  5.7259e-03,  ..., -1.4270e-01,
          -1.0117e+00, -7.8613e-02],
         [-1.2480e+00, -1.5029e+00,  1.1554e-01,  ...,  5.4840e-02,
          -5.5615e-01, -3.0396e-01]],

        [[-6.3965e-01,  1.9058e-02, -6.7017e-02,  ..., -1.5674e+00,
          -1.1875e+00,  6.3867e-01],
         [-6.5869e-01, -2.6188e-03, -5.9662e-02,  ..., -1.5801e+00,
          -1.1387e+00,  6.3086e-01],
         [-1.3320e+00, -1.5078e+00,  2.0959e-01,  ...,  2.9810e-01,
          -1.578

Epoch 0:   4%|▍         | 241/5879 [02:46<1:04:55,  1.45it/s, v_num=7]
 tensor([[[-0.8018, -0.0486,  0.0318,  ..., -1.6133, -0.9609,  0.5439],
         [-0.8057, -0.0553,  0.0324,  ..., -1.6201, -0.9492,  0.5444],
         [-1.9639, -1.9414,  0.3564,  ...,  1.9395,  0.3635, -1.7139]],

        [[-0.7534, -0.0137,  0.0109,  ..., -1.6064, -1.0576,  0.5742],
         [-0.7656, -0.0261,  0.0156,  ..., -1.6172, -1.0283,  0.5718],
         [-1.2881, -1.5713,  0.1310,  ...,  0.1599, -0.2856, -0.4321]],

        [[-0.9678, -0.2111,  0.0833,  ..., -1.6094, -0.6094,  0.4158],
         [-0.9639, -0.2102,  0.0817,  ..., -1.6191, -0.6133,  0.4219],
         [-2.3301, -2.5039,  0.3425,  ...,  2.2480,  0.7686, -2.1250]],

        ...,

        [[-0.8672, -0.1078,  0.0538,  ..., -1.6143, -0.8232,  0.4880],
         [-0.8652, -0.1109,  0.0514,  ..., -1.6230, -0.8203,  0.4937],
         [-1.8555, -1.8252,  0.3313,  ...,  1.7344,  0.2888, -1.5410]],

        [[-0.9800, -0.2268,  0.0866,  ..., -1.6152, -0

Epoch 0:   4%|▍         | 246/5879 [02:56<1:07:13,  1.40it/s, v_num=7]
 tensor([[[-1.1582, -1.4658, -0.0516,  ..., -0.2277, -1.0107, -0.0805],
         [-1.1182, -1.3789, -0.1036,  ..., -0.3228, -1.1875,  0.0271],
         [-1.2129, -1.5986,  0.0318,  ..., -0.0643, -0.6294, -0.2573]],

        [[-1.0381, -0.2202,  0.1381,  ..., -1.6348, -0.5664,  0.3823],
         [-1.0342, -0.2209,  0.1349,  ..., -1.6426, -0.5708,  0.3892],
         [-2.3926, -2.5625,  0.3157,  ...,  2.2891,  0.8379, -2.1426]],

        [[-1.0410, -0.2223,  0.1412,  ..., -1.6299, -0.5630,  0.3860],
         [-1.0381, -0.2250,  0.1384,  ..., -1.6396, -0.5640,  0.3921],
         [-2.3574, -2.5000,  0.3240,  ...,  2.3008,  0.7725, -2.1152]],

        ...,

        [[-1.0146, -0.1934,  0.1337,  ..., -1.6338, -0.6211,  0.4023],
         [-1.0117, -0.1937,  0.1316,  ..., -1.6406, -0.6221,  0.4067],
         [-2.3535, -2.5078,  0.3208,  ...,  2.2988,  0.7759, -2.1211]],

        [[-2.6895, -3.0020,  0.2065,  ...,  1.9463,  1

Epoch 0:   4%|▍         | 252/5879 [02:57<1:06:10,  1.42it/s, v_num=7]
 tensor([[[-1.1533, -0.2607,  0.2131,  ..., -1.6475, -0.4575,  0.3228],
         [-1.1475, -0.2615,  0.2087,  ..., -1.6582, -0.4670,  0.3325],
         [-2.4297, -2.5547,  0.2920,  ...,  2.3516,  0.8481, -2.1309]],

        [[-0.9292, -0.0353,  0.1493,  ..., -1.6621, -0.9214,  0.4919],
         [-0.9302, -0.0389,  0.1477,  ..., -1.6689, -0.9150,  0.4946],
         [-1.4844, -1.6953,  0.1846,  ...,  0.7793,  0.0262, -0.8530]],

        [[-0.9727, -0.0710,  0.1661,  ..., -1.6611, -0.8345,  0.4592],
         [-0.9722, -0.0750,  0.1636,  ..., -1.6689, -0.8315,  0.4644],
         [-2.0098, -2.0020,  0.3066,  ...,  2.0938,  0.3591, -1.7471]],

        ...,

        [[-2.6816, -2.9180,  0.2101,  ...,  2.0801,  1.3223, -2.2266],
         [-2.7227, -2.9746,  0.1842,  ...,  1.9932,  1.4033, -2.2246],
         [-2.5859, -2.7754,  0.2615,  ...,  2.2344,  1.1240, -2.2148]],

        [[-1.0605, -0.1583,  0.1959,  ..., -1.6543, -0

Epoch 0:   4%|▍         | 258/5879 [02:59<1:05:17,  1.43it/s, v_num=7]
 tensor([[[-2.6973, -2.8691,  0.1949,  ...,  2.1543,  1.2910, -2.2051],
         [-2.7305, -2.9141,  0.1755,  ...,  2.0879,  1.3545, -2.2051],
         [-2.5996, -2.7285,  0.2423,  ...,  2.3066,  1.0947, -2.1914]],

        [[-1.0654, -0.0803,  0.2402,  ..., -1.6758, -0.7749,  0.4180],
         [-1.0625, -0.0762,  0.2397,  ..., -1.6768, -0.7778,  0.4192],
         [-2.3262, -2.3633,  0.2871,  ...,  2.4414,  0.6470, -2.0410]],

        [[-0.9473,  0.0135,  0.1910,  ..., -1.6855, -1.0078,  0.5078],
         [-0.9468,  0.0140,  0.1906,  ..., -1.6904, -1.0049,  0.5088],
         [-1.2549, -1.7236,  0.0297,  ...,  0.0829, -0.3225, -0.4250]],

        ...,

        [[-1.1328, -1.5908, -0.1208,  ..., -0.3406, -1.0391, -0.0526],
         [-1.1182, -1.5537, -0.1418,  ..., -0.3823, -1.1191, -0.0092],
         [-1.1719, -1.7041, -0.0526,  ..., -0.1936, -0.7075, -0.2052]],

        [[-0.9561,  0.0083,  0.1958,  ..., -1.6855, -0

Epoch 0:   4%|▍         | 264/5879 [03:01<1:04:24,  1.45it/s, v_num=7]
 tensor([[[-1.0996, -0.0367,  0.2913,  ..., -1.7031, -0.8257,  0.4126],
         [-1.0977, -0.0367,  0.2888,  ..., -1.7100, -0.8267,  0.4182],
         [-1.5430, -1.8076,  0.1194,  ...,  0.9746,  0.0126, -0.9277]],

        [[-1.1943, -0.1288,  0.3242,  ..., -1.6924, -0.6318,  0.3447],
         [-1.1934, -0.1302,  0.3220,  ..., -1.6992, -0.6318,  0.3491],
         [-2.4043, -2.4062,  0.2539,  ...,  2.5039,  0.7114, -2.0527]],

        [[-2.7051, -2.8125,  0.1810,  ...,  2.2383,  1.2451, -2.1855],
         [-2.7676, -2.8926,  0.1471,  ...,  2.1152,  1.3613, -2.1855],
         [-2.6094, -2.6777,  0.2231,  ...,  2.3828,  1.0586, -2.1680]],

        ...,

        [[-1.1641, -0.0969,  0.3154,  ..., -1.6953, -0.6963,  0.3669],
         [-1.1631, -0.0992,  0.3127,  ..., -1.7031, -0.6958,  0.3728],
         [-2.3574, -2.3535,  0.2600,  ...,  2.5078,  0.6484, -2.0293]],

        [[-1.1133, -1.6338, -0.1665,  ..., -0.4199, -1

Epoch 0:   5%|▍         | 269/5879 [03:14<1:07:41,  1.38it/s, v_num=7]
 tensor([[[-1.3125, -0.1794,  0.3918,  ..., -1.6934, -0.4941,  0.2686],
         [-1.3096, -0.1812,  0.3879,  ..., -1.7051, -0.4983,  0.2756],
         [-2.4766, -2.4570,  0.2251,  ...,  2.5449,  0.7866, -2.0723]],

        [[-1.8467, -0.9331,  0.3630,  ..., -1.4023,  0.6665, -0.2876],
         [-1.7344, -0.7632,  0.3762,  ..., -1.5117,  0.4204, -0.1527],
         [-2.5410, -2.5449,  0.2225,  ...,  2.5078,  0.9004, -2.1094]],

        [[-2.7246, -2.7832,  0.1620,  ...,  2.2871,  1.2344, -2.1680],
         [-2.7812, -2.8516,  0.1332,  ...,  2.1758,  1.3359, -2.1699],
         [-2.6172, -2.6387,  0.2068,  ...,  2.4453,  1.0293, -2.1484]],

        ...,

        [[-2.8008, -2.8730,  0.1265,  ...,  2.1230,  1.3975, -2.1621],
         [-2.9180, -2.9609,  0.0615,  ...,  1.3750,  1.8223, -1.9824],
         [-2.5859, -2.6016,  0.2156,  ...,  2.4785,  0.9766, -2.1348]],

        [[-1.3184, -0.1886,  0.3909,  ..., -1.6963, -0

Epoch 0:   5%|▍         | 275/5879 [03:16<1:06:40,  1.40it/s, v_num=7]
 tensor([[[-0.9551,  0.1445,  0.2737,  ..., -1.7764, -1.3164,  0.5854],
         [-0.9639,  0.1385,  0.2766,  ..., -1.7842, -1.2979,  0.5825],
         [-1.1338, -1.8838, -0.1221,  ..., -0.2705, -0.6152, -0.2158]],

        [[-2.5703, -2.1660,  0.2172,  ..., -0.1107,  1.8301, -1.2646],
         [-2.2246, -1.4775,  0.3330,  ..., -0.9351,  1.3193, -0.7285],
         [-2.5488, -2.4922,  0.1957,  ...,  2.5859,  0.8608, -2.0840]],

        [[-1.2285, -0.0177,  0.4136,  ..., -1.7422, -0.7930,  0.3606],
         [-1.2314, -0.0264,  0.4106,  ..., -1.7510, -0.7852,  0.3660],
         [-1.9414, -1.9512,  0.1569,  ...,  2.0742,  0.1940, -1.5244]],

        ...,

        [[-2.2852, -1.5957,  0.3152,  ..., -0.8066,  1.4385, -0.8203],
         [-1.8789, -0.8833,  0.4131,  ..., -1.4463,  0.5894, -0.2671],
         [-2.5469, -2.4902,  0.1959,  ...,  2.5859,  0.8560, -2.0820]],

        [[-1.3242, -0.1092,  0.4446,  ..., -1.7256, -0

Epoch 0:   5%|▍         | 281/5879 [03:18<1:05:47,  1.42it/s, v_num=7]
 tensor([[[-1.7988, -0.6035,  0.5161,  ..., -1.5605,  0.2808, -0.1350],
         [-1.7373, -0.5239,  0.5186,  ..., -1.6123,  0.1477, -0.0658],
         [-2.5488, -2.4395,  0.1669,  ...,  2.6602,  0.8145, -2.0547]],

        [[-1.2432,  0.0363,  0.4541,  ..., -1.7744, -0.8848,  0.3767],
         [-1.2432,  0.0340,  0.4519,  ..., -1.7812, -0.8818,  0.3806],
         [-1.7119, -1.9316,  0.0561,  ...,  1.5303,  0.0423, -1.1631]],

        [[-0.9805, -1.5381, -0.3560,  ..., -0.8271, -1.5830,  0.2744],
         [-0.9624, -1.4814, -0.3723,  ..., -0.8730, -1.6533,  0.3198],
         [-1.0898, -1.9209, -0.1892,  ..., -0.4133, -0.7803, -0.1295]],

        ...,

        [[-1.3232, -0.0275,  0.4893,  ..., -1.7598, -0.7285,  0.3147],
         [-1.3232, -0.0302,  0.4866,  ..., -1.7666, -0.7271,  0.3186],
         [-2.2930, -2.1660,  0.1877,  ...,  2.6445,  0.4634, -1.8945]],

        [[-2.7891, -2.7422,  0.1024,  ...,  2.3555,  1

Epoch 0:   5%|▍         | 287/5879 [03:20<1:04:56,  1.43it/s, v_num=7]
 tensor([[[-1.3730e+00,  3.3131e-03,  5.4346e-01,  ..., -1.7812e+00,
          -7.5732e-01,  3.0054e-01],
         [-1.3730e+00, -2.8586e-04,  5.4053e-01,  ..., -1.7891e+00,
          -7.5439e-01,  3.0518e-01],
         [-2.3320e+00, -2.1680e+00,  1.5295e-01,  ...,  2.7324e+00,
           4.7266e-01, -1.8936e+00]],

        [[-1.4072e+00, -2.6886e-02,  5.5615e-01,  ..., -1.7744e+00,
          -6.8945e-01,  2.7490e-01],
         [-1.4102e+00, -3.3325e-02,  5.5322e-01,  ..., -1.7812e+00,
          -6.8213e-01,  2.7759e-01],
         [-2.4824e+00, -2.3105e+00,  1.3904e-01,  ...,  2.7617e+00,
           6.6309e-01, -1.9814e+00]],

        [[-2.7949e+00, -2.6875e+00,  7.8918e-02,  ...,  2.4375e+00,
           1.2021e+00, -2.1055e+00],
         [-2.8750e+00, -2.7656e+00,  4.8340e-02,  ...,  2.2715e+00,
           1.3447e+00, -2.1016e+00],
         [-2.6406e+00, -2.5039e+00,  1.2842e-01,  ...,  2.6621e+00,
           9.252

Epoch 0:   5%|▍         | 292/5879 [03:25<1:05:40,  1.42it/s, v_num=7]
 tensor([[[-1.4121,  0.0302,  0.5884,  ..., -1.8018, -0.7856,  0.2905],
         [-1.4121,  0.0264,  0.5840,  ..., -1.8096, -0.7842,  0.2947],
         [-2.4277, -2.2227,  0.1158,  ...,  2.8145,  0.5576, -1.9258]],

        [[-1.5088, -0.0542,  0.6187,  ..., -1.7773, -0.5938,  0.2156],
         [-1.5098, -0.0632,  0.6138,  ..., -1.7861, -0.5874,  0.2191],
         [-2.4844, -2.2754,  0.1110,  ...,  2.8145,  0.6343, -1.9580]],

        [[-1.5000, -0.0461,  0.6162,  ..., -1.7812, -0.6118,  0.2245],
         [-1.5010, -0.0500,  0.6143,  ..., -1.7881, -0.6084,  0.2267],
         [-2.5020, -2.2988,  0.1104,  ...,  2.8125,  0.6641, -1.9688]],

        ...,

        [[-1.1279,  0.1830,  0.4421,  ..., -1.8584, -1.3125,  0.5312],
         [-1.1406,  0.1755,  0.4482,  ..., -1.8643, -1.2852,  0.5249],
         [-1.1162, -2.0469, -0.1958,  ..., -0.2678, -0.5508, -0.2522]],

        [[-2.7734, -2.6250,  0.0652,  ...,  2.5430,  1

Epoch 0:   5%|▌         | 297/5879 [03:27<1:05:09,  1.43it/s, v_num=7]
 tensor([[[-2.8145, -2.6230,  0.0297,  ...,  2.5371,  1.1621, -2.0664],
         [-2.9277, -2.7246, -0.0041,  ...,  2.2871,  1.3662, -2.0566],
         [-2.6406, -2.4297,  0.0757,  ...,  2.7793,  0.8643, -2.0273]],

        [[-1.2852,  0.1553,  0.5513,  ..., -1.8604, -1.1172,  0.4187],
         [-1.2969,  0.1421,  0.5532,  ..., -1.8691, -1.0928,  0.4167],
         [-1.0596, -2.0957, -0.2335,  ..., -0.4001, -0.6436, -0.1787]],

        [[-1.7930, -0.2998,  0.6865,  ..., -1.7061, -0.1151, -0.0074],
         [-1.7588, -0.2595,  0.6841,  ..., -1.7314, -0.1910,  0.0286],
         [-2.5488, -2.3125,  0.0799,  ...,  2.8477,  0.7041, -1.9688]],

        ...,

        [[-1.3594,  0.1197,  0.5913,  ..., -1.8467, -0.9795,  0.3547],
         [-1.3594,  0.1170,  0.5894,  ..., -1.8525, -0.9771,  0.3579],
         [-1.0928, -2.0977, -0.2162,  ..., -0.2927, -0.5601, -0.2477]],

        [[-1.5049,  0.0152,  0.6548,  ..., -1.8086, -0

Epoch 0:   5%|▌         | 303/5879 [03:33<1:05:27,  1.42it/s, v_num=7]
 tensor([[[-1.7959, -0.1984,  0.7466,  ..., -1.7588, -0.2522,  0.0258],
         [-1.7852, -0.1875,  0.7441,  ..., -1.7695, -0.2756,  0.0380],
         [-2.5547, -2.2832,  0.0414,  ...,  2.9102,  0.6846, -1.9482]],

        [[-1.6523, -0.0405,  0.7329,  ..., -1.8115, -0.5527,  0.1565],
         [-1.6572, -0.0504,  0.7300,  ..., -1.8193, -0.5410,  0.1555],
         [-2.5215, -2.2422,  0.0413,  ...,  2.9238,  0.6299, -1.9258]],

        [[-1.5498,  0.0508,  0.7070,  ..., -1.8389, -0.7515,  0.2347],
         [-1.5498,  0.0476,  0.7031,  ..., -1.8467, -0.7495,  0.2388],
         [-2.4492, -2.1797,  0.0446,  ...,  2.9375,  0.5308, -1.8896]],

        ...,

        [[-1.6748, -0.0625,  0.7373,  ..., -1.8047, -0.5049,  0.1345],
         [-1.6768, -0.0717,  0.7324,  ..., -1.8125, -0.4976,  0.1360],
         [-2.5352, -2.2598,  0.0410,  ...,  2.9180,  0.6519, -1.9355]],

        [[-1.3994,  0.1486,  0.6377,  ..., -1.8789, -1

Epoch 0:   5%|▌         | 309/5879 [03:34<1:04:33,  1.44it/s, v_num=7]
 tensor([[[-2.8926e+00, -2.6074e+00, -5.3314e-02,  ...,  2.5195e+00,
           1.2236e+00, -2.0195e+00],
         [-3.0391e+00, -2.6836e+00, -5.7007e-02,  ...,  2.0273e+00,
           1.5371e+00, -1.9561e+00],
         [-2.6387e+00, -2.3496e+00,  3.3331e-04,  ...,  2.9062e+00,
           7.9395e-01, -1.9746e+00]],

        [[-1.4902e+00,  1.5222e-01,  7.0752e-01,  ..., -1.9004e+00,
          -9.7754e-01,  3.0591e-01],
         [-1.4883e+00,  1.5088e-01,  7.0459e-01,  ..., -1.9062e+00,
          -9.7803e-01,  3.0957e-01],
         [-1.0312e+00, -2.2207e+00, -2.7148e-01,  ..., -3.9941e-01,
          -6.1572e-01, -2.0508e-01]],

        [[-1.6328e+00,  5.4596e-02,  7.7002e-01,  ..., -1.8623e+00,
          -7.1826e-01,  1.9788e-01],
         [-1.6318e+00,  5.4779e-02,  7.6807e-01,  ..., -1.8662e+00,
          -7.1875e-01,  2.0044e-01],
         [-2.4727e+00, -2.1660e+00,  2.0943e-03,  ...,  2.9922e+00,
           5.336

Epoch 0:   5%|▌         | 315/5879 [03:36<1:03:52,  1.45it/s, v_num=7]
 tensor([[[-2.7676, -2.4531, -0.0662,  ...,  2.8066,  0.9683, -1.9902],
         [-2.8203, -2.5059, -0.0787,  ...,  2.7246,  1.0488, -2.0000],
         [-2.6562, -2.3320, -0.0454,  ...,  2.9473,  0.7866, -1.9580]],

        [[-2.7285, -2.4082, -0.0589,  ...,  2.8672,  0.9033, -1.9795],
         [-2.7461, -2.4277, -0.0630,  ...,  2.8438,  0.9287, -1.9854],
         [-2.6641, -2.3379, -0.0478,  ...,  2.9414,  0.7979, -1.9609]],

        [[-2.2949, -0.6992,  0.7744,  ..., -1.4668,  0.6025, -0.4338],
         [-2.2109, -0.5654,  0.8022,  ..., -1.5791,  0.3989, -0.3289],
         [-2.5742, -2.2324, -0.0422,  ...,  3.0156,  0.6519, -1.9102]],

        ...,

        [[-1.6846,  0.0826,  0.8213,  ..., -1.8955, -0.7451,  0.1876],
         [-1.6826,  0.0817,  0.8179,  ..., -1.9014, -0.7471,  0.1914],
         [-2.3965, -2.0859, -0.0357,  ...,  3.0332,  0.4114, -1.8135]],

        [[-1.7285,  0.0432,  0.8330,  ..., -1.8818, -0

Epoch 0:   5%|▌         | 321/5879 [03:42<1:04:15,  1.44it/s, v_num=7]
 tensor([[[-1.8506,  0.0079,  0.9053,  ..., -1.8906, -0.5469,  0.0856],
         [-1.8516,  0.0044,  0.9019,  ..., -1.8965, -0.5444,  0.0875],
         [-2.5117, -2.1348, -0.0878,  ...,  3.0898,  0.5210, -1.8486]],

        [[-1.7002,  0.1333,  0.8584,  ..., -1.9414, -0.8374,  0.2050],
         [-1.6973,  0.1346,  0.8564,  ..., -1.9434, -0.8408,  0.2076],
         [-2.2617, -2.0117, -0.0754,  ...,  2.9824,  0.2556, -1.7119]],

        [[-1.4336,  0.2515,  0.7119,  ..., -2.0195, -1.2910,  0.4243],
         [-1.4326,  0.2522,  0.7114,  ..., -2.0215, -1.2900,  0.4255],
         [-0.9053, -2.3281, -0.3528,  ..., -0.6963, -0.8374, -0.0574]],

        ...,

        [[-2.0566, -0.2245,  0.9087,  ..., -1.7910, -0.1152, -0.1028],
         [-2.0312, -0.1986,  0.9062,  ..., -1.8154, -0.1709, -0.0762],
         [-2.5488, -2.1680, -0.0861,  ...,  3.0801,  0.5767, -1.8701]],

        [[-1.8340,  0.0211,  0.8979,  ..., -1.9004, -0

Epoch 0:   6%|▌         | 327/5879 [03:44<1:03:35,  1.46it/s, v_num=7]
 tensor([[[-1.9541e+00, -1.2909e-02,  9.6436e-01,  ..., -1.9014e+00,
          -4.7314e-01,  3.1921e-02],
         [-1.9512e+00, -1.3206e-02,  9.6143e-01,  ..., -1.9062e+00,
          -4.7656e-01,  3.5004e-02],
         [-2.5527e+00, -2.1328e+00, -1.2964e-01,  ...,  3.1230e+00,
           5.4395e-01, -1.8496e+00]],

        [[-8.4766e-01, -2.1309e+00, -4.9976e-01,  ..., -1.1162e+00,
          -1.4873e+00,  2.6416e-01],
         [-8.4180e-01, -2.0918e+00, -5.1270e-01,  ..., -1.1582e+00,
          -1.5479e+00,  3.0200e-01],
         [-8.5986e-01, -2.3613e+00, -3.9990e-01,  ..., -8.1836e-01,
          -9.6338e-01,  1.0315e-02]],

        [[-1.9375e+00, -6.8545e-06,  9.6045e-01,  ..., -1.9072e+00,
          -5.0391e-01,  4.7485e-02],
         [-1.9365e+00,  2.9411e-03,  9.6094e-01,  ..., -1.9092e+00,
          -5.0537e-01,  4.8218e-02],
         [-2.5586e+00, -2.1426e+00, -1.2939e-01,  ...,  3.1211e+00,
           5.537

Epoch 0:   6%|▌         | 332/5879 [03:45<1:02:55,  1.47it/s, v_num=7]
 tensor([[[-1.8467e+00,  1.3550e-01,  9.6729e-01,  ..., -1.9756e+00,
          -7.8760e-01,  1.5051e-01],
         [-1.8467e+00,  1.3086e-01,  9.6289e-01,  ..., -1.9814e+00,
          -7.8662e-01,  1.5381e-01],
         [-2.4922e+00, -2.0508e+00, -1.6211e-01,  ...,  3.1777e+00,
           4.2383e-01, -1.7979e+00]],

        [[-1.9316e+00,  6.7932e-02,  9.9365e-01,  ..., -1.9473e+00,
          -6.2354e-01,  8.4656e-02],
         [-1.9307e+00,  6.5552e-02,  9.9072e-01,  ..., -1.9521e+00,
          -6.2354e-01,  8.6853e-02],
         [-2.5469e+00, -2.0938e+00, -1.6101e-01,  ...,  3.1699e+00,
           5.0146e-01, -1.8271e+00]],

        [[-1.8926e+00,  1.0126e-01,  9.8438e-01,  ..., -1.9600e+00,
          -6.9873e-01,  1.1646e-01],
         [-1.8906e+00,  1.0187e-01,  9.8193e-01,  ..., -1.9648e+00,
          -7.0215e-01,  1.1969e-01],
         [-2.5039e+00, -2.0566e+00, -1.6138e-01,  ...,  3.1797e+00,
           4.394

Epoch 0:   6%|▌         | 338/5879 [03:47<1:02:09,  1.49it/s, v_num=7]
 tensor([[[-1.9248,  0.1404,  1.0293,  ..., -1.9922, -0.7632,  0.1226],
         [-1.9248,  0.1410,  1.0283,  ..., -1.9941, -0.7632,  0.1237],
         [-2.3750, -1.9668, -0.1906,  ...,  3.1855,  0.2644, -1.7217]],

        [[-1.9336,  0.1334,  1.0312,  ..., -1.9893, -0.7471,  0.1178],
         [-1.9346,  0.1315,  1.0303,  ..., -1.9922, -0.7461,  0.1179],
         [-2.4727, -2.0059, -0.1963,  ...,  3.2227,  0.3652, -1.7676]],

        [[-0.7988, -2.2129, -0.5435,  ..., -1.2246, -1.5342,  0.3110],
         [-0.7988, -2.2188, -0.5396,  ..., -1.2148, -1.5225,  0.3008],
         [-0.8047, -2.4609, -0.4387,  ..., -0.9004, -0.9897,  0.0378]],

        ...,

        [[-2.0293,  0.0561,  1.0586,  ..., -1.9531, -0.5610,  0.0371],
         [-2.0293,  0.0516,  1.0547,  ..., -1.9590, -0.5601,  0.0392],
         [-2.5547, -2.0605, -0.1982,  ...,  3.2109,  0.4722, -1.8076]],

        [[-1.9443,  0.1250,  1.0352,  ..., -1.9863, -0

Epoch 0:   6%|▌         | 344/5879 [04:02<1:04:57,  1.42it/s, v_num=7]
 tensor([[[-2.8047, -2.2832, -0.2432,  ...,  3.0176,  0.8179, -1.8984],
         [-2.8359, -2.3086, -0.2472,  ...,  2.9766,  0.8589, -1.9043],
         [-2.6660, -2.1465, -0.2305,  ...,  3.1836,  0.6099, -1.8555]],

        [[-1.9951,  0.1530,  1.0898,  ..., -2.0059, -0.7549,  0.1009],
         [-1.9951,  0.1527,  1.0879,  ..., -2.0098, -0.7563,  0.1034],
         [-2.4824, -1.9766, -0.2347,  ...,  3.2656,  0.3435, -1.7529]],

        [[-2.0000,  0.1451,  1.0898,  ..., -2.0059, -0.7427,  0.0995],
         [-2.0000,  0.1450,  1.0889,  ..., -2.0078, -0.7432,  0.1003],
         [-2.5059, -1.9902, -0.2360,  ...,  3.2637,  0.3711, -1.7637]],

        ...,

        [[-0.7788, -2.4277, -0.5103,  ..., -1.0957, -1.2549,  0.1724],
         [-0.7783, -2.4238, -0.5122,  ..., -1.1025, -1.2646,  0.1777],
         [-0.7734, -2.5020, -0.4673,  ..., -0.9736, -1.0537,  0.0713]],

        [[-1.8340,  0.2427,  1.0127,  ..., -2.0664, -1

Epoch 0:   6%|▌         | 350/5879 [04:03<1:04:13,  1.43it/s, v_num=7]
 tensor([[[-0.7539, -2.4160, -0.5601,  ..., -1.2217, -1.3926,  0.2532],
         [-0.7534, -2.4102, -0.5620,  ..., -1.2295, -1.4033,  0.2593],
         [-0.7466, -2.5645, -0.4844,  ..., -0.9961, -1.0264,  0.0681]],

        [[-2.0000,  0.2035,  1.1143,  ..., -2.0469, -0.8682,  0.1273],
         [-1.9990,  0.2054,  1.1133,  ..., -2.0469, -0.8701,  0.1287],
         [-2.0762, -1.9502, -0.2854,  ...,  2.8945,  0.0035, -1.4990]],

        [[-2.4004, -0.1880,  1.1543,  ..., -1.8350, -0.0483, -0.2432],
         [-2.3594, -0.1376,  1.1582,  ..., -1.8760, -0.1451, -0.1965],
         [-2.5918, -2.0254, -0.2712,  ...,  3.2812,  0.4587, -1.7920]],

        ...,

        [[-2.0586,  0.1661,  1.1396,  ..., -2.0234, -0.7603,  0.0839],
         [-2.0605,  0.1641,  1.1387,  ..., -2.0254, -0.7578,  0.0837],
         [-2.5332, -1.9746, -0.2744,  ...,  3.3027,  0.3767, -1.7588]],

        [[-2.0410,  0.1820,  1.1348,  ..., -2.0312, -0

Epoch 0:   6%|▌         | 355/5879 [04:05<1:03:39,  1.45it/s, v_num=7]
 tensor([[[-2.1133,  0.1766,  1.1855,  ..., -2.0371, -0.7642,  0.0684],
         [-2.1191,  0.1732,  1.1855,  ..., -2.0391, -0.7588,  0.0662],
         [-2.5469, -1.9551, -0.3074,  ...,  3.3301,  0.3640, -1.7490]],

        [[-2.0508,  0.2244,  1.1621,  ..., -2.0566, -0.8784,  0.1116],
         [-2.0488,  0.2274,  1.1631,  ..., -2.0566, -0.8799,  0.1121],
         [-2.3008, -1.8877, -0.2993,  ...,  3.2344,  0.1238, -1.6367]],

        [[-2.9336, -2.2949, -0.3103,  ...,  2.8945,  0.9380, -1.8740],
         [-3.0059, -2.3301, -0.3047,  ...,  2.7559,  1.0449, -1.8691],
         [-2.6426, -2.0566, -0.2983,  ...,  3.2754,  0.5171, -1.8115]],

        ...,

        [[-2.1406,  0.1600,  1.1953,  ..., -2.0273, -0.7124,  0.0453],
         [-2.1406,  0.1580,  1.1943,  ..., -2.0293, -0.7114,  0.0460],
         [-2.5527, -1.9609, -0.3076,  ...,  3.3281,  0.3735, -1.7539]],

        [[-2.0469,  0.2250,  1.1582,  ..., -2.0625, -0

Epoch 0:   6%|▌         | 361/5879 [04:07<1:03:01,  1.46it/s, v_num=7]
 tensor([[[-2.1035e+00,  2.4658e-01,  1.2129e+00,  ..., -2.0781e+00,
          -9.0039e-01,  9.9426e-02],
         [-2.0996e+00,  2.4780e-01,  1.2080e+00,  ..., -2.0840e+00,
          -9.1016e-01,  1.0565e-01],
         [-1.5732e+00, -2.2402e+00, -4.3384e-01,  ...,  1.7783e+00,
          -2.6196e-01, -9.8389e-01]],

        [[-2.1055e+00,  2.4524e-01,  1.2109e+00,  ..., -2.0801e+00,
          -8.9893e-01,  9.9670e-02],
         [-2.1016e+00,  2.4658e-01,  1.2090e+00,  ..., -2.0820e+00,
          -9.0332e-01,  1.0303e-01],
         [-2.1270e+00, -1.9004e+00, -3.5156e-01,  ...,  3.0430e+00,
          -1.7151e-02, -1.5117e+00]],

        [[-2.1250e+00,  2.3364e-01,  1.2217e+00,  ..., -2.0723e+00,
          -8.6426e-01,  8.3252e-02],
         [-2.1230e+00,  2.3511e-01,  1.2197e+00,  ..., -2.0742e+00,
          -8.6816e-01,  8.5693e-02],
         [-2.2676e+00, -1.8633e+00, -3.4253e-01,  ...,  3.2402e+00,
           7.116

Epoch 0:   6%|▌         | 367/5879 [04:19<1:04:53,  1.42it/s, v_num=7]
 tensor([[[-2.2012e+00,  2.4268e-01,  1.2754e+00,  ..., -2.0879e+00,
          -8.4619e-01,  5.4596e-02],
         [-2.2012e+00,  2.4377e-01,  1.2744e+00,  ..., -2.0898e+00,
          -8.4814e-01,  5.5573e-02],
         [-2.3945e+00, -1.8389e+00, -3.9062e-01,  ...,  3.3789e+00,
           1.3831e-01, -1.6475e+00]],

        [[-2.1562e+00,  2.6611e-01,  1.2520e+00,  ..., -2.1035e+00,
          -9.2188e-01,  9.0149e-02],
         [-2.1543e+00,  2.6685e-01,  1.2480e+00,  ..., -2.1074e+00,
          -9.2920e-01,  9.4849e-02],
         [-2.2891e+00, -1.8389e+00, -3.8696e-01,  ...,  3.2988e+00,
           5.5878e-02, -1.5957e+00]],

        [[-2.2793e+00,  1.8274e-01,  1.3027e+00,  ..., -2.0586e+00,
          -6.9775e-01, -3.3283e-03],
         [-2.2793e+00,  1.8188e-01,  1.3008e+00,  ..., -2.0625e+00,
          -7.0068e-01, -1.8282e-03],
         [-2.5488e+00, -1.8916e+00, -3.9355e-01,  ...,  3.4004e+00,
           3.027

Epoch 0:   6%|▋         | 373/5879 [04:20<1:04:09,  1.43it/s, v_num=7]
 tensor([[[-2.2480,  0.2612,  1.3242,  ..., -2.1211, -0.8750,  0.0489],
         [-2.2500,  0.2605,  1.3223,  ..., -2.1230, -0.8760,  0.0502],
         [-2.4844, -1.8428, -0.4421,  ...,  3.4277,  0.1958, -1.6719]],

        [[-2.3789,  0.1654,  1.3652,  ..., -2.0723, -0.6294, -0.0540],
         [-2.3789,  0.1643,  1.3623,  ..., -2.0762, -0.6328, -0.0518],
         [-2.5684, -1.8955, -0.4358,  ...,  3.4141,  0.3057, -1.7148]],

        [[-0.6616, -2.5859, -0.6411,  ..., -1.4150, -1.4521,  0.3337],
         [-0.6621, -2.5957, -0.6396,  ..., -1.4102, -1.4414,  0.3298],
         [-0.6416, -2.7676, -0.5527,  ..., -1.1377, -1.0322,  0.1148]],

        ...,

        [[-2.8164, -2.1230, -0.4280,  ...,  3.1660,  0.6631, -1.8115],
         [-2.9492, -2.1992, -0.4243,  ...,  2.9727,  0.8350, -1.8232],
         [-2.6328, -1.9697, -0.4275,  ...,  3.3672,  0.4089, -1.7578]],

        [[-2.2949,  0.2283,  1.3438,  ..., -2.1055, -0

Epoch 0:   6%|▋         | 379/5879 [04:22<1:03:26,  1.44it/s, v_num=7]
 tensor([[[-2.2949,  0.2776,  1.3770,  ..., -2.1641, -0.8965,  0.0423],
         [-2.2949,  0.2761,  1.3730,  ..., -2.1680, -0.8965,  0.0444],
         [-2.4980, -1.8506, -0.4980,  ...,  3.4395,  0.1960, -1.6777]],

        [[-2.2969,  0.2773,  1.3789,  ..., -2.1621, -0.8906,  0.0392],
         [-2.3008,  0.2744,  1.3770,  ..., -2.1660, -0.8882,  0.0385],
         [-2.4688, -1.8369, -0.5000,  ...,  3.4395,  0.1637, -1.6641]],

        [[-2.7109, -0.1738,  1.3730,  ..., -1.9014,  0.0227, -0.3430],
         [-2.6230, -0.0504,  1.3965,  ..., -1.9990, -0.2085, -0.2462],
         [-2.5801, -1.9170, -0.4871,  ...,  3.4102,  0.3132, -1.7256]],

        ...,

        [[-2.2852,  0.2825,  1.3730,  ..., -2.1680, -0.9097,  0.0486],
         [-2.2852,  0.2817,  1.3701,  ..., -2.1699, -0.9116,  0.0508],
         [-2.4785, -1.8418, -0.4993,  ...,  3.4395,  0.1748, -1.6689]],

        [[-2.7305, -2.0645, -0.4788,  ...,  3.2695,  0

Epoch 0:   7%|▋         | 384/5879 [04:24<1:02:58,  1.45it/s, v_num=7]
 tensor([[[-0.6294, -2.6465, -0.6753,  ..., -1.4912, -1.4912,  0.3945],
         [-0.6304, -2.6777, -0.6689,  ..., -1.4648, -1.4463,  0.3733],
         [-0.6040, -2.8496, -0.5835,  ..., -1.1836, -1.0371,  0.1560]],

        [[-2.3027,  0.3149,  1.3926,  ..., -2.2070, -0.9692,  0.0651],
         [-2.3008,  0.3137,  1.3877,  ..., -2.2129, -0.9736,  0.0694],
         [-2.3730, -1.8125, -0.5430,  ...,  3.4336,  0.0598, -1.6328]],

        [[-2.5195,  0.1451,  1.4600,  ..., -2.1172, -0.5474, -0.1090],
         [-2.5488,  0.1207,  1.4570,  ..., -2.1113, -0.4983, -0.1394],
         [-2.5645, -1.8945, -0.5308,  ...,  3.4297,  0.2686, -1.7178]],

        ...,

        [[-2.3008,  0.3135,  1.3887,  ..., -2.2109, -0.9731,  0.0683],
         [-2.3008,  0.3125,  1.3857,  ..., -2.2148, -0.9751,  0.0706],
         [-2.3789, -1.8184, -0.5415,  ...,  3.4336,  0.0663, -1.6367]],

        [[-2.8164, -2.1250, -0.5142,  ...,  3.1719,  0

Epoch 0:   7%|▋         | 389/5879 [04:34<1:04:40,  1.41it/s, v_num=7]
 tensor([[[-2.2910e+00,  3.5620e-01,  1.3828e+00,  ..., -2.2598e+00,
          -1.0771e+00,  1.1444e-01],
         [-2.2832e+00,  3.5718e-01,  1.3740e+00,  ..., -2.2676e+00,
          -1.0908e+00,  1.2427e-01],
         [-1.5879e+00, -2.1738e+00, -6.4502e-01,  ...,  2.2090e+00,
          -2.9810e-01, -1.0859e+00]],

        [[-2.4316e+00,  2.8979e-01,  1.4648e+00,  ..., -2.2051e+00,
          -8.4717e-01,  2.1057e-03],
         [-2.4375e+00,  2.8442e-01,  1.4600e+00,  ..., -2.2090e+00,
          -8.4180e-01, -1.3723e-03],
         [-2.5293e+00, -1.8506e+00, -5.7178e-01,  ...,  3.4590e+00,
           1.9409e-01, -1.7021e+00]],

        [[-2.4062e+00,  3.0811e-01,  1.4531e+00,  ..., -2.2148e+00,
          -8.9404e-01,  2.0096e-02],
         [-2.4062e+00,  3.0420e-01,  1.4492e+00,  ..., -2.2168e+00,
          -8.9209e-01,  2.0462e-02],
         [-2.4805e+00, -1.8242e+00, -5.7812e-01,  ...,  3.4688e+00,
           1.353

Epoch 0:   7%|▋         | 394/5879 [04:36<1:04:04,  1.43it/s, v_num=7]
 tensor([[[-0.5962, -2.8320, -0.6802,  ..., -1.3857, -1.3291,  0.3538],
         [-0.5957, -2.8359, -0.6802,  ..., -1.3828, -1.3213,  0.3516],
         [-0.5776, -2.9238, -0.6265,  ..., -1.2266, -1.0967,  0.2327]],

        [[-2.1816,  0.4111,  1.2998,  ..., -2.3555, -1.3223,  0.2537],
         [-2.1641,  0.4114,  1.2861,  ..., -2.3652, -1.3447,  0.2705],
         [-1.3184, -2.3965, -0.6680,  ...,  1.5713, -0.3789, -0.8101]],

        [[-2.3730,  0.3677,  1.4326,  ..., -2.2734, -1.0449,  0.0856],
         [-2.3691,  0.3684,  1.4287,  ..., -2.2793, -1.0527,  0.0918],
         [-1.1641, -2.5273, -0.6792,  ...,  1.0859, -0.4294, -0.6274]],

        ...,

        [[-2.4668,  0.3262,  1.4883,  ..., -2.2383, -0.8940,  0.0087],
         [-2.4727,  0.3220,  1.4873,  ..., -2.2402, -0.8877,  0.0052],
         [-2.4785, -1.8135, -0.6074,  ...,  3.4805,  0.1081, -1.6865]],

        [[-2.6055,  0.2208,  1.5283,  ..., -2.1816, -0

Epoch 0:   7%|▋         | 399/5879 [04:37<1:03:29,  1.44it/s, v_num=7]
 tensor([[[-0.5894, -2.7988, -0.7266,  ..., -1.4824, -1.4678,  0.4539],
         [-0.5845, -2.8750, -0.7085,  ..., -1.4004, -1.3438,  0.3899],
         [-0.5566, -2.9727, -0.6450,  ..., -1.2080, -1.0723,  0.2477]],

        [[-2.5996,  0.2998,  1.5469,  ..., -2.2383, -0.7661, -0.0567],
         [-2.6133,  0.2886,  1.5479,  ..., -2.2363, -0.7456, -0.0658],
         [-2.5469, -1.8516, -0.6221,  ...,  3.4766,  0.1754, -1.7363]],

        [[-2.6133,  0.2908,  1.5527,  ..., -2.2305, -0.7397, -0.0674],
         [-2.6133,  0.2898,  1.5488,  ..., -2.2363, -0.7461, -0.0659],
         [-2.5449, -1.8486, -0.6226,  ...,  3.4785,  0.1713, -1.7344]],

        ...,

        [[-2.5059,  0.3591,  1.5068,  ..., -2.2715, -0.9307,  0.0155],
         [-2.5078,  0.3572,  1.5039,  ..., -2.2734, -0.9326,  0.0170],
         [-2.3535, -1.7832, -0.6440,  ...,  3.4785, -0.0200, -1.6494]],

        [[-2.6523,  0.2551,  1.5586,  ..., -2.2129, -0

Epoch 0:   7%|▋         | 405/5879 [04:39<1:02:55,  1.45it/s, v_num=7]
 tensor([[[-2.6582,  0.3406,  1.5752,  ..., -2.2734, -0.7925, -0.0566],
         [-2.6641,  0.3374,  1.5732,  ..., -2.2754, -0.7905, -0.0584],
         [-2.5332, -1.8359, -0.6504,  ...,  3.5117,  0.1326, -1.7480]],

        [[-2.6465,  0.3513,  1.5703,  ..., -2.2793, -0.8169, -0.0481],
         [-2.6602,  0.3433,  1.5713,  ..., -2.2773, -0.8008, -0.0594],
         [-2.5312, -1.8340, -0.6519,  ...,  3.5117,  0.1304, -1.7461]],

        [[-2.5547,  0.3999,  1.5186,  ..., -2.3125, -0.9683,  0.0228],
         [-2.5391,  0.4031,  1.5029,  ..., -2.3242, -1.0000,  0.0415],
         [-2.3633, -1.7773, -0.6748,  ...,  3.5176, -0.0387, -1.6709]],

        ...,

        [[-2.5957,  0.3838,  1.5420,  ..., -2.2969, -0.9067, -0.0090],
         [-2.5957,  0.3818,  1.5391,  ..., -2.3008, -0.9087, -0.0071],
         [-2.4707, -1.7988, -0.6636,  ...,  3.5273,  0.0566, -1.7158]],

        [[-2.5527,  0.4001,  1.5176,  ..., -2.3145, -0

Epoch 0:   7%|▋         | 411/5879 [04:41<1:02:22,  1.46it/s, v_num=7]
 tensor([[[-2.6348,  0.4248,  1.5615,  ..., -2.3438, -0.9551,  0.0054],
         [-2.6289,  0.4258,  1.5547,  ..., -2.3496, -0.9658,  0.0129],
         [-2.3652, -1.7676, -0.7026,  ...,  3.5586, -0.0673, -1.6914]],

        [[-2.6035,  0.4368,  1.5391,  ..., -2.3594, -1.0039,  0.0346],
         [-2.6094,  0.4331,  1.5381,  ..., -2.3613, -1.0020,  0.0332],
         [-2.3555, -1.7686, -0.7031,  ...,  3.5586, -0.0742, -1.6885]],

        [[-2.7227,  0.3801,  1.6162,  ..., -2.3105, -0.8115, -0.0610],
         [-2.7383,  0.3716,  1.6191,  ..., -2.3066, -0.7930, -0.0762],
         [-2.5293, -1.8232, -0.6777,  ...,  3.5488,  0.0977, -1.7646]],

        ...,

        [[-2.9297, -2.0977, -0.6108,  ...,  3.1055,  0.6479, -1.8945],
         [-2.9766, -2.1152, -0.6001,  ...,  3.0293,  0.7056, -1.8965],
         [-2.6133, -1.9219, -0.6558,  ...,  3.4883,  0.2305, -1.8262]],

        [[-2.7715,  0.3481,  1.6318,  ..., -2.2891, -0

Epoch 0:   7%|▋         | 417/5879 [04:48<1:02:52,  1.45it/s, v_num=7]
 tensor([[[-2.6582,  0.4656,  1.5791,  ..., -2.4082, -1.0303,  0.0423],
         [-2.6309,  0.4678,  1.5527,  ..., -2.4277, -1.0732,  0.0710],
         [-2.4043, -1.7598, -0.7397,  ...,  3.6035, -0.0721, -1.7188]],

        [[-2.7539,  0.4321,  1.6416,  ..., -2.3652, -0.8848, -0.0408],
         [-2.7559,  0.4309,  1.6396,  ..., -2.3691, -0.8862, -0.0393],
         [-2.4824, -1.7852, -0.7266,  ...,  3.5938,  0.0047, -1.7539]],

        [[-2.5176,  0.4900,  1.4727,  ..., -2.4707, -1.2178,  0.1648],
         [-2.4297,  0.4890,  1.3838,  ..., -2.5195, -1.3018,  0.2355],
         [-1.4121, -2.2422, -0.8315,  ...,  2.2930, -0.4004, -1.0557]],

        ...,

        [[-2.6797,  0.4622,  1.5908,  ..., -2.3984, -1.0020,  0.0256],
         [-2.6699,  0.4612,  1.5781,  ..., -2.4102, -1.0195,  0.0388],
         [-2.4160, -1.7637, -0.7373,  ...,  3.6016, -0.0615, -1.7256]],

        [[-0.5103, -3.0801, -0.7520,  ..., -1.3242, -1

Epoch 0:   7%|▋         | 422/5879 [04:49<1:02:27,  1.46it/s, v_num=7]
 tensor([[[-2.9336e+00,  3.7695e-01,  1.7227e+00,  ..., -2.3574e+00,
          -6.7822e-01, -1.4001e-01],
         [-2.9414e+00,  3.7134e-01,  1.7227e+00,  ..., -2.3574e+00,
          -6.6846e-01, -1.4673e-01],
         [-2.5391e+00, -1.8271e+00, -7.4561e-01,  ...,  3.5918e+00,
           5.8533e-02, -1.7979e+00]],

        [[-2.7852e+00,  4.6094e-01,  1.6514e+00,  ..., -2.4199e+00,
          -9.3164e-01, -1.8326e-02],
         [-2.7852e+00,  4.6118e-01,  1.6445e+00,  ..., -2.4258e+00,
          -9.4189e-01, -1.4565e-02],
         [-2.4629e+00, -1.7676e+00, -7.6904e-01,  ...,  3.6211e+00,
          -4.4800e-02, -1.7510e+00]],

        [[-2.8828e+00,  4.1382e-01,  1.7090e+00,  ..., -2.3828e+00,
          -7.7295e-01, -9.3994e-02],
         [-2.9043e+00,  4.0112e-01,  1.7139e+00,  ..., -2.3750e+00,
          -7.4365e-01, -1.1407e-01],
         [-2.5410e+00, -1.8271e+00, -7.4707e-01,  ...,  3.5918e+00,
           6.280

Epoch 0:   7%|▋         | 427/5879 [04:51<1:02:03,  1.46it/s, v_num=7]
 tensor([[[-2.9121e+00,  4.5264e-01,  1.7129e+00,  ..., -2.4375e+00,
          -8.3008e-01, -7.3608e-02],
         [-2.9316e+00,  4.4336e-01,  1.7207e+00,  ..., -2.4316e+00,
          -8.0664e-01, -9.0088e-02],
         [-2.5000e+00, -1.7881e+00, -7.9736e-01,  ...,  3.6289e+00,
          -1.9485e-02, -1.7773e+00]],

        [[-4.6289e-01, -3.2168e+00, -7.7344e-01,  ..., -1.2764e+00,
          -1.2168e+00,  3.9038e-01],
         [-4.5947e-01, -3.2305e+00, -7.6758e-01,  ..., -1.2578e+00,
          -1.1885e+00,  3.7744e-01],
         [-4.5044e-01, -3.2539e+00, -7.4023e-01,  ..., -1.2051e+00,
          -1.1172e+00,  3.3643e-01]],

        [[-2.9395e+00,  4.3628e-01,  1.7266e+00,  ..., -2.4277e+00,
          -7.8564e-01, -9.2712e-02],
         [-2.9648e+00,  4.2383e-01,  1.7363e+00,  ..., -2.4199e+00,
          -7.5488e-01, -1.1432e-01],
         [-2.5137e+00, -1.8018e+00, -7.9248e-01,  ...,  3.6230e+00,
           1.236

Epoch 0:   7%|▋         | 432/5879 [04:53<1:01:41,  1.47it/s, v_num=7]
 tensor([[[-2.9238e+00,  4.8901e-01,  1.7031e+00,  ..., -2.5039e+00,
          -9.0820e-01, -3.1769e-02],
         [-2.9414e+00,  4.8315e-01,  1.7100e+00,  ..., -2.4980e+00,
          -8.9062e-01, -4.8553e-02],
         [-2.4668e+00, -1.7627e+00, -8.4424e-01,  ...,  3.6602e+00,
          -8.0933e-02, -1.7617e+00]],

        [[-2.8809e+00,  5.0000e-01,  1.6709e+00,  ..., -2.5234e+00,
          -9.7021e-01,  5.6458e-03],
         [-2.8848e+00,  4.9634e-01,  1.6689e+00,  ..., -2.5254e+00,
          -9.6826e-01,  2.9449e-03],
         [-2.4219e+00, -1.7402e+00, -8.5547e-01,  ...,  3.6719e+00,
          -1.3379e-01, -1.7402e+00]],

        [[-2.9785e+00,  4.6704e-01,  1.7354e+00,  ..., -2.4785e+00,
          -8.2422e-01, -7.7820e-02],
         [-2.9922e+00,  4.6191e-01,  1.7432e+00,  ..., -2.4766e+00,
          -8.0811e-01, -8.8501e-02],
         [-2.4805e+00, -1.7734e+00, -8.3936e-01,  ...,  3.6543e+00,
          -6.390

Epoch 0:   7%|▋         | 437/5879 [04:58<1:01:56,  1.46it/s, v_num=7]
 tensor([[[-2.9629,  0.5093,  1.7109,  ..., -2.5508, -0.9355, -0.0166],
         [-2.9863,  0.5044,  1.7246,  ..., -2.5449, -0.9092, -0.0411],
         [-2.3887, -1.7256, -0.8975,  ...,  3.6973, -0.1892, -1.7285]],

        [[-3.0176,  0.4946,  1.7539,  ..., -2.5293, -0.8589, -0.0616],
         [-3.0430,  0.4854,  1.7656,  ..., -2.5195, -0.8311, -0.0831],
         [-2.4609, -1.7578, -0.8804,  ...,  3.6816, -0.1084, -1.7627]],

        [[-0.4329, -3.2988, -0.8247,  ..., -1.3135, -1.2773,  0.4392],
         [-0.4197, -3.3418, -0.8042,  ..., -1.2451, -1.1855,  0.3906],
         [-0.4080, -3.3633, -0.7681,  ..., -1.1816, -1.1025,  0.3457]],

        ...,

        [[-3.0605,  0.4756,  1.7773,  ..., -2.5098, -0.7939, -0.0959],
         [-3.0859,  0.4651,  1.7930,  ..., -2.5020, -0.7627, -0.1187],
         [-2.4824, -1.7803, -0.8716,  ...,  3.6699, -0.0760, -1.7793]],

        [[-3.3418,  0.2062,  1.7979,  ..., -2.2910, -0

Epoch 0:   8%|▊         | 443/5879 [05:01<1:01:42,  1.47it/s, v_num=7]
 tensor([[[-2.9102e+00,  5.4590e-01,  1.6572e+00,  ..., -2.6504e+00,
          -1.0898e+00,  8.3435e-02],
         [-2.9922e+00,  5.3516e-01,  1.7051e+00,  ..., -2.6191e+00,
          -9.7900e-01,  3.2330e-03],
         [-2.2637e+00, -1.7021e+00, -9.6045e-01,  ...,  3.7285e+00,
          -3.1665e-01, -1.6836e+00]],

        [[-3.3203e+00,  3.5352e-01,  1.8662e+00,  ..., -2.4395e+00,
          -4.2798e-01, -2.8345e-01],
         [-3.3047e+00,  3.7207e-01,  1.8682e+00,  ..., -2.4590e+00,
          -4.7534e-01, -2.6587e-01],
         [-2.5039e+00, -1.8125e+00, -9.0283e-01,  ...,  3.6738e+00,
          -6.0455e-02, -1.8037e+00]],

        [[-3.0176e+00,  5.3076e-01,  1.7363e+00,  ..., -2.6035e+00,
          -9.5557e-01, -7.1030e-03],
         [-3.0996e+00,  5.1367e-01,  1.7891e+00,  ..., -2.5684e+00,
          -8.5010e-01, -9.0942e-02],
         [-2.3809e+00, -1.7109e+00, -9.4385e-01,  ...,  3.7285e+00,
          -2.280

Epoch 0:   8%|▊         | 448/5879 [05:03<1:01:21,  1.48it/s, v_num=7]
 tensor([[[-3.2129,  0.5063,  1.8711,  ..., -2.5781, -0.7651, -0.1240],
         [-3.2227,  0.5020,  1.8691,  ..., -2.5742, -0.7568, -0.1353],
         [-2.4629, -1.7607, -0.9536,  ...,  3.7188, -0.1461, -1.7764]],

        [[-3.2383,  0.4937,  1.8809,  ..., -2.5645, -0.7217, -0.1482],
         [-3.2500,  0.4873,  1.8838,  ..., -2.5625, -0.7100, -0.1609],
         [-2.4648, -1.7627, -0.9526,  ...,  3.7168, -0.1444, -1.7783]],

        [[-3.1289,  0.5337,  1.8066,  ..., -2.6113, -0.8823, -0.0599],
         [-3.1641,  0.5239,  1.8301,  ..., -2.5996, -0.8423, -0.0911],
         [-2.4121, -1.7129, -0.9736,  ...,  3.7402, -0.2208, -1.7412]],

        ...,

        [[-3.1973,  0.5117,  1.8623,  ..., -2.5859, -0.7900, -0.1106],
         [-3.2129,  0.5049,  1.8652,  ..., -2.5801, -0.7778, -0.1307],
         [-2.4473, -1.7422, -0.9604,  ...,  3.7285, -0.1705, -1.7646]],

        [[-3.2324,  0.4927,  1.8789,  ..., -2.5684, -0

Epoch 0:   8%|▊         | 454/5879 [05:05<1:00:53,  1.48it/s, v_num=7]
 tensor([[[-3.1953,  0.5532,  1.8438,  ..., -2.6543, -0.8901, -0.0625],
         [-3.2344,  0.5430,  1.8740,  ..., -2.6387, -0.8452, -0.0995],
         [-2.4238, -1.7100, -1.0137,  ...,  3.7598, -0.2333, -1.7490]],

        [[-3.1719,  0.5591,  1.8252,  ..., -2.6641, -0.9194, -0.0446],
         [-3.2324,  0.5449,  1.8662,  ..., -2.6387, -0.8481, -0.1041],
         [-2.3652, -1.6768, -1.0312,  ...,  3.7773, -0.3042, -1.7197]],

        [[-2.8828,  0.5747,  1.6006,  ..., -2.7832, -1.2461,  0.1969],
         [-3.0723,  0.5669,  1.7412,  ..., -2.7129, -1.0352,  0.0323],
         [-2.2090, -1.6719, -1.0537,  ...,  3.7793, -0.4165, -1.6641]],

        ...,

        [[-3.2832,  0.5269,  1.9062,  ..., -2.6113, -0.7598, -0.1389],
         [-3.2891,  0.5244,  1.9053,  ..., -2.6113, -0.7544, -0.1445],
         [-2.4492, -1.7393, -1.0020,  ...,  3.7461, -0.1906, -1.7695]],

        [[-3.1445,  0.5625,  1.8066,  ..., -2.6777, -0

Epoch 0:   8%|▊         | 460/5879 [05:07<1:00:19,  1.50it/s, v_num=7]
 tensor([[[-3.2695,  0.5713,  1.8848,  ..., -2.6855, -0.8799, -0.0795],
         [-3.2930,  0.5654,  1.9004,  ..., -2.6777, -0.8535, -0.1049],
         [-2.3789, -1.6611, -1.0732,  ...,  3.7969, -0.3259, -1.7217]],

        [[-3.3828,  0.5332,  1.9619,  ..., -2.6309, -0.7148, -0.1792],
         [-3.3848,  0.5317,  1.9580,  ..., -2.6309, -0.7168, -0.1858],
         [-2.4062, -1.6807, -1.0635,  ...,  3.7871, -0.2903, -1.7383]],

        [[-3.4375,  0.4993,  1.9922,  ..., -2.6074, -0.6284, -0.2163],
         [-3.4160,  0.5181,  1.9863,  ..., -2.6230, -0.6816, -0.1997],
         [-2.4434, -1.7217, -1.0479,  ...,  3.7695, -0.2318, -1.7666]],

        ...,

        [[-2.9414,  0.5879,  1.6338,  ..., -2.8203, -1.2471,  0.1940],
         [-3.1230,  0.5796,  1.7656,  ..., -2.7578, -1.0527,  0.0389],
         [-2.3047, -1.6426, -1.0918,  ...,  3.8125, -0.3958, -1.6904]],

        [[-3.5938,  0.3340,  1.9688,  ..., -2.4473, -0

Epoch 0:   8%|▊         | 466/5879 [05:14<1:00:47,  1.48it/s, v_num=7]
 tensor([[[-3.4785,  0.5396,  2.0117,  ..., -2.6484, -0.6709, -0.2174],
         [-3.4727,  0.5415,  1.9980,  ..., -2.6504, -0.6743, -0.2223],
         [-2.3828, -1.6426, -1.1172,  ...,  3.8203, -0.3557, -1.7188]],

        [[-2.9883,  0.5986,  1.6621,  ..., -2.8652, -1.2734,  0.2037],
         [-1.9336,  0.2673,  0.6494,  ..., -3.0254, -1.8799,  0.8750],
         [-0.6021, -3.0488, -1.0430,  ...,  1.0361, -0.6313, -0.4656]],

        [[-3.3027,  0.5942,  1.9043,  ..., -2.7363, -0.9233, -0.0603],
         [-3.1875,  0.5879,  1.7539,  ..., -2.7812, -0.9829,  0.0152],
         [-2.1270, -1.6377, -1.1611,  ...,  3.8203, -0.5391, -1.6260]],

        ...,

        [[-3.7383,  0.2498,  1.9893,  ..., -2.3555, -0.0447, -0.5410],
         [-3.7734,  0.1876,  1.9551,  ..., -2.2754,  0.0656, -0.6250],
         [-2.4531, -1.7188, -1.0869,  ...,  3.7812, -0.2463, -1.7725]],

        [[-3.3652,  0.5815,  1.9531,  ..., -2.7051, -0

Epoch 0:   8%|▊         | 472/5879 [05:15<1:00:15,  1.50it/s, v_num=7]
 tensor([[[-3.4863,  0.5908,  2.0352,  ..., -2.7129, -0.7852, -0.1675],
         [-3.4941,  0.5884,  2.0293,  ..., -2.7109, -0.7749, -0.1804],
         [-2.3691, -1.6152, -1.1660,  ...,  3.8496, -0.4058, -1.7051]],

        [[-3.2832,  0.6152,  1.8662,  ..., -2.8047, -1.0166, -0.0046],
         [-3.3281,  0.6079,  1.8994,  ..., -2.7871, -0.9697, -0.0439],
         [-2.2949, -1.5908, -1.1885,  ...,  3.8711, -0.4792, -1.6689]],

        [[-0.2881, -3.6719, -0.9048,  ..., -1.3076, -1.2900,  0.4995],
         [-0.2686, -3.7246, -0.8779,  ..., -1.2119, -1.1748,  0.4294],
         [-0.2622, -3.7441, -0.8364,  ..., -1.0381, -1.0361,  0.3169]],

        ...,

        [[-3.5586,  0.5508,  2.0742,  ..., -2.6758, -0.6592, -0.2335],
         [-3.5762,  0.5420,  2.0762,  ..., -2.6660, -0.6387, -0.2605],
         [-2.3965, -1.6406, -1.1543,  ...,  3.8398, -0.3638, -1.7246]],

        [[-3.2617,  0.6152,  1.8477,  ..., -2.8164, -1

Epoch 0:   8%|▊         | 478/5879 [05:17<59:43,  1.51it/s, v_num=7]
 tensor([[[-3.6289,  0.5747,  2.1094,  ..., -2.6973, -0.6558, -0.2561],
         [-3.6191,  0.5410,  2.0352,  ..., -2.6719, -0.5933, -0.2856],
         [-2.3633, -1.5947, -1.2109,  ...,  3.8809, -0.4431, -1.6963]],

        [[-3.7949,  0.4209,  2.1133,  ..., -2.5430, -0.2871, -0.4504],
         [-3.8164,  0.3931,  2.0996,  ..., -2.5137, -0.2323, -0.4934],
         [-2.4082, -1.6445, -1.1914,  ...,  3.8574, -0.3713, -1.7314]],

        [[-3.7559,  0.4697,  2.1250,  ..., -2.5977, -0.3945, -0.3953],
         [-3.7520,  0.4739,  2.1250,  ..., -2.6035, -0.4077, -0.3936],
         [-2.3945, -1.6318, -1.1963,  ...,  3.8652, -0.3911, -1.7217]],

        ...,

        [[-0.2527, -3.7930, -0.8838,  ..., -1.2158, -1.1738,  0.4368],
         [-0.2484, -3.8008, -0.8765,  ..., -1.2012, -1.1533,  0.4260],
         [-0.2476, -3.8105, -0.8560,  ..., -1.1855, -1.1299,  0.4089]],

        [[-3.5156,  0.6162,  2.0410,  ..., -2.7578, -0.8

Epoch 0:   8%|▊         | 484/5879 [05:24<1:00:22,  1.49it/s, v_num=7]
 tensor([[[-3.2305,  0.6416,  1.8105,  ..., -2.9355, -1.1943,  0.1208],
         [-1.1182, -2.2305, -1.4062,  ...,  2.9395, -0.6855, -1.1045],
         [-1.2490, -2.1172, -1.3457,  ...,  3.0742, -0.7490, -1.2119]],

        [[-0.2507, -3.7734, -0.9355,  ..., -1.3486, -1.3213,  0.5435],
         [-0.2174, -3.8555, -0.8940,  ..., -1.2021, -1.1494,  0.4353],
         [-0.2169, -3.8691, -0.8560,  ..., -1.0742, -1.0527,  0.3452]],

        [[-3.5938,  0.6353,  2.1113,  ..., -2.7793, -0.8208, -0.1757],
         [-3.5859,  0.4231,  1.7734,  ..., -2.5664, -0.3796, -0.3396],
         [-2.3320, -1.5537, -1.2676,  ...,  3.9141, -0.5215, -1.6719]],

        ...,

        [[-3.3242,  0.6450,  1.8877,  ..., -2.9043, -1.1074,  0.0522],
         [-0.1670, -3.8457, -1.0049,  ..., -1.0586, -1.0928,  0.3855],
         [-0.9956, -2.4062, -1.3076,  ...,  2.5469, -0.7344, -0.9971]],

        [[-3.6074,  0.6328,  2.1211,  ..., -2.7734, -0

Epoch 0:   8%|▊         | 490/5879 [05:33<1:01:05,  1.47it/s, v_num=7]
 tensor([[[-0.3865, -2.5449, -1.0098,  ..., -2.3242, -2.1348,  1.1973],
         [-0.1584, -3.9004, -0.9893,  ..., -0.9746, -1.0391,  0.3308],
         [-0.9004, -2.4727, -1.3467,  ...,  2.4531, -0.7642, -0.9448]],

        [[-3.6699,  0.6479,  2.1738,  ..., -2.7930, -0.8042, -0.2078],
         [-3.5664,  0.6240,  2.0195,  ..., -2.8242, -0.8325, -0.1488],
         [-2.3379, -1.5332, -1.3135,  ...,  3.9316, -0.5552, -1.6699]],

        [[-3.3594,  0.6572,  1.9160,  ..., -2.9355, -1.1309,  0.0522],
         [-0.1525, -3.9121, -0.9604,  ..., -0.9912, -1.0332,  0.3416],
         [-0.4062, -3.3223, -1.1006,  ...,  0.7988, -0.6855, -0.3450]],

        ...,

        [[-3.7305,  0.6304,  2.2188,  ..., -2.7656, -0.7246, -0.2537],
         [-3.7051,  0.6411,  2.2012,  ..., -2.7793, -0.7715, -0.2345],
         [-2.3770, -1.5713, -1.2959,  ...,  3.9160, -0.4963, -1.6992]],

        [[-3.7207,  0.6328,  2.2070,  ..., -2.7656, -0

Epoch 0:   8%|▊         | 496/5879 [05:34<1:00:34,  1.48it/s, v_num=7]
 tensor([[[-3.3496,  0.6401,  1.8340,  ..., -2.9590, -1.0723,  0.0723],
         [-2.3672, -1.5107, -1.3496,  ...,  3.9375, -0.5728, -1.6670],
         [-2.2207, -1.4971, -1.3936,  ...,  3.9688, -0.6992, -1.6191]],

        [[-0.2186, -3.5176, -1.0664,  ..., -1.7812, -1.6807,  0.8740],
         [-2.1660, -1.5195, -1.4219,  ...,  3.9629, -0.7104, -1.5830],
         [-0.4873, -3.1348, -1.1973,  ...,  1.2510, -0.7046, -0.4941]],

        [[-3.7305,  0.6592,  2.1914,  ..., -2.8086, -0.7842, -0.2339],
         [-2.3906, -1.5303, -1.3379,  ...,  3.9219, -0.5356, -1.6816],
         [-2.3164, -1.4961, -1.3701,  ...,  3.9531, -0.6240, -1.6504]],

        ...,

        [[-0.1390, -3.9707, -0.9614,  ..., -1.0176, -1.0352,  0.3547],
         [-1.0098, -2.3164, -1.4717,  ...,  2.8301, -0.7646, -1.0498],
         [-1.3750, -1.9443, -1.4561,  ...,  3.3789, -0.8599, -1.3096]],

        [[-3.6191,  0.6768,  2.1172,  ..., -2.8691, -0

Epoch 0:   9%|▊         | 502/5879 [05:36<1:00:03,  1.49it/s, v_num=7]
 tensor([[[-0.2004, -3.9902, -0.9238,  ..., -1.3408, -1.2480,  0.5332],
         [-0.1675, -4.0469, -0.8931,  ..., -1.2334, -1.1211,  0.4543],
         [-0.1714, -4.0547, -0.8672,  ..., -1.2197, -1.1055,  0.4358]],

        [[-3.0254,  0.6152,  1.6328,  ..., -3.0996, -1.4893,  0.3767],
         [-2.3477, -1.4707, -1.4023,  ...,  3.9648, -0.6313, -1.6436],
         [-2.1191, -1.4873, -1.4551,  ...,  3.9902, -0.7900, -1.5791]],

        [[-0.1831, -4.0273, -0.9062,  ..., -1.2734, -1.1738,  0.4834],
         [-0.1716, -4.0430, -0.8940,  ..., -1.2441, -1.1338,  0.4604],
         [-0.1722, -4.0547, -0.8687,  ..., -1.2285, -1.1123,  0.4417]],

        ...,

        [[-4.0352,  0.5210,  2.3145,  ..., -2.6406, -0.3452, -0.4983],
         [-2.4316, -1.5596, -1.3633,  ...,  3.9082, -0.5083, -1.7051],
         [-2.3848, -1.5400, -1.3828,  ...,  3.9453, -0.5591, -1.6943]],

        [[-3.7676,  0.6914,  2.2422,  ..., -2.8555, -0

Epoch 0:   9%|▊         | 508/5879 [05:38<59:35,  1.50it/s, v_num=7]
 tensor([[[-3.6914,  0.7148,  2.1660,  ..., -2.9414, -0.9897, -0.1068],
         [-2.0977, -1.4648, -1.5098,  ...,  4.0078, -0.8286, -1.5469],
         [-1.7783, -1.6084, -1.5371,  ...,  3.8652, -0.9282, -1.4717]],

        [[-0.0920, -4.0664, -1.0029,  ..., -0.8447, -0.9360,  0.2847],
         [-1.0527, -2.1816, -1.5645,  ...,  3.0430, -0.8438, -1.1104],
         [-1.5322, -1.7559, -1.5527,  ...,  3.6562, -0.9443, -1.3828]],

        [[-0.1610, -4.0977, -0.8955,  ..., -1.2705, -1.1396,  0.4814],
         [-0.1603, -4.0977, -0.8940,  ..., -1.2676, -1.1348,  0.4783],
         [-0.1620, -4.1094, -0.8691,  ..., -1.2549, -1.1172,  0.4597]],

        ...,

        [[-0.3208, -2.7637, -1.0723,  ..., -2.3516, -2.0664,  1.2344],
         [-2.1426, -1.4609, -1.4971,  ...,  4.0117, -0.8047, -1.5684],
         [-2.1758, -1.4473, -1.4922,  ...,  4.0117, -0.8066, -1.5840]],

        [[-1.2051, -0.7070, -0.0726,  ..., -3.0020, -2.3

Epoch 0:   9%|▊         | 514/5879 [05:51<1:01:10,  1.46it/s, v_num=7]
 tensor([[[-3.8574,  0.7334,  2.3164,  ..., -2.9141, -0.8945, -0.2059],
         [-2.2070, -1.4121, -1.5283,  ...,  4.0312, -0.8140, -1.5674],
         [-0.7661, -2.5762, -1.4746,  ...,  2.4004, -0.8311, -0.8872]],

        [[-0.1976, -3.7324, -1.0498,  ..., -1.7939, -1.6182,  0.8809],
         [-2.0293, -1.4590, -1.5625,  ...,  4.0156, -0.8857, -1.5166],
         [-1.4082, -1.8418, -1.5928,  ...,  3.5488, -0.9644, -1.3271]],

        [[-4.0352,  0.6831,  2.4297,  ..., -2.8125, -0.6494, -0.3633],
         [-4.0234,  0.6904,  2.4277,  ..., -2.8203, -0.6714, -0.3550],
         [-2.3418, -1.4541, -1.4873,  ...,  4.0000, -0.7021, -1.6484]],

        ...,

        [[-0.1575, -4.1367, -0.9038,  ..., -1.3203, -1.1738,  0.5171],
         [-0.1489, -4.1523, -0.8936,  ..., -1.2979, -1.1426,  0.4985],
         [-0.1488, -4.1641, -0.8721,  ..., -1.2871, -1.1240,  0.4836]],

        [[-3.9434,  0.7090,  2.3691,  ..., -2.8613, -0

Epoch 0:   9%|▉         | 520/5879 [05:53<1:00:39,  1.47it/s, v_num=7]
 tensor([[[-0.1333, -4.2109, -0.8984,  ..., -1.3193, -1.1396,  0.5137],
         [-0.1321, -4.2148, -0.8960,  ..., -1.3154, -1.1338,  0.5103],
         [-0.1338, -4.2227, -0.8755,  ..., -1.3086, -1.1201,  0.4971]],

        [[-4.1406,  0.6660,  2.4785,  ..., -2.7930, -0.5557, -0.4299],
         [-4.1211,  0.6807,  2.4785,  ..., -2.8105, -0.5981, -0.4099],
         [-2.3555, -1.4570, -1.5186,  ...,  4.0156, -0.7080, -1.6572]],

        [[-0.1519, -4.1797, -0.9150,  ..., -1.3691, -1.2002,  0.5508],
         [-0.1285, -4.2148, -0.8926,  ..., -1.3086, -1.1240,  0.5044],
         [-0.1322, -4.2227, -0.8730,  ..., -1.3047, -1.1152,  0.4944]],

        ...,

        [[-4.0117,  0.7432,  2.4355,  ..., -2.8828, -0.7900, -0.2986],
         [-4.0000,  0.7446,  2.4277,  ..., -2.8887, -0.8062, -0.2900],
         [-2.2598, -1.3965, -1.5596,  ...,  4.0469, -0.8306, -1.5967]],

        [[-3.8516,  0.7510,  2.3047,  ..., -2.9609, -0

Epoch 0:   9%|▉         | 526/5879 [05:54<1:00:10,  1.48it/s, v_num=7]
 tensor([[[-3.5645,  0.7305,  2.0664,  ..., -3.1152, -1.2607,  0.1199],
         [-3.6270,  0.7153,  1.9951,  ..., -3.0566, -1.0439,  0.0253],
         [-2.1367, -1.3965, -1.6270,  ...,  4.0742, -0.9443, -1.5537]],

        [[-0.1691, -3.9766, -1.0107,  ..., -1.7197, -1.4951,  0.8193],
         [-0.1164, -4.1719, -0.9829,  ..., -1.4678, -1.2842,  0.6357],
         [-1.8691, -1.4941, -1.6641,  ...,  4.0078, -1.0352, -1.4814]],

        [[-4.0664,  0.7607,  2.4766,  ..., -2.9004, -0.7998, -0.3118],
         [-4.0586,  0.7612,  2.4727,  ..., -2.9043, -0.8091, -0.3069],
         [-2.2852, -1.3799, -1.5918,  ...,  4.0625, -0.8447, -1.6025]],

        ...,

        [[-4.0742,  0.7583,  2.4863,  ..., -2.8984, -0.7847, -0.3181],
         [-4.0625,  0.7607,  2.4766,  ..., -2.9043, -0.8018, -0.3083],
         [-2.2637, -1.3750, -1.5996,  ...,  4.0664, -0.8638, -1.5918]],

        [[-3.8809,  0.7642,  2.3105,  ..., -2.9941, -0

Epoch 0:   9%|▉         | 532/5879 [05:56<59:44,  1.49it/s, v_num=7]
 tensor([[[-3.6758,  0.7593,  2.1602,  ..., -3.1152, -1.2227,  0.0489],
         [-0.1348, -4.1562, -0.9795,  ..., -1.6025, -1.3662,  0.7197],
         [-1.8516, -1.4834, -1.7080,  ...,  4.0117, -1.0752, -1.4756]],

        [[-0.1060, -4.3125, -0.9004,  ..., -1.3721, -1.1289,  0.5415],
         [-0.0981, -4.3164, -0.8989,  ..., -1.3525, -1.1113,  0.5283],
         [-0.1025, -4.3281, -0.8672,  ..., -1.3301, -1.0889,  0.4983]],

        [[-4.1055,  0.7803,  2.5234,  ..., -2.9277, -0.8247, -0.3079],
         [-4.0977,  0.7822,  2.5176,  ..., -2.9336, -0.8359, -0.3010],
         [-2.2090, -1.3613, -1.6553,  ...,  4.0859, -0.9438, -1.5713]],

        ...,

        [[-0.1158, -4.3164, -0.8955,  ..., -1.3887, -1.1445,  0.5483],
         [-0.1152, -4.3164, -0.8955,  ..., -1.3877, -1.1436,  0.5479],
         [-0.1122, -4.3320, -0.8662,  ..., -1.3691, -1.1133,  0.5239]],

        [[-4.1719,  0.7563,  2.5625,  ..., -2.8887, -0.7

Epoch 0:   9%|▉         | 538/5879 [06:08<1:00:59,  1.46it/s, v_num=7]
 tensor([[[-1.5254, -1.6533, -1.7959,  ...,  3.8203, -1.0752, -1.3389],
         [-1.7393, -1.5244, -1.7734,  ...,  3.9688, -1.0879, -1.4180],
         [-1.1660, -1.9932, -1.7490,  ...,  3.3574, -1.0791, -1.2324]],

        [[-0.0981, -4.3672, -0.8779,  ..., -1.4072, -1.1172,  0.5503],
         [-0.0975, -4.3672, -0.8765,  ..., -1.4053, -1.1143,  0.5483],
         [-0.1022, -4.3828, -0.8521,  ..., -1.4004, -1.1045,  0.5337]],

        [[-4.0586,  0.8013,  2.4922,  ..., -2.9961, -0.9404, -0.2363],
         [-3.9941,  0.7988,  2.4395,  ..., -3.0273, -1.0039, -0.1791],
         [-1.5713, -1.6631, -1.7598,  ...,  3.8125, -1.1143, -1.3994]],

        ...,

        [[-0.1783, -4.0977, -0.9604,  ..., -1.7715, -1.4639,  0.8237],
         [-0.1494, -4.2305, -0.9375,  ..., -1.6260, -1.3418,  0.7144],
         [-0.2925, -3.2656, -1.3926,  ...,  1.3320, -0.7822, -0.4885]],

        [[-4.1602,  0.7930,  2.5781,  ..., -2.9414, -0

Epoch 0:   9%|▉         | 544/5879 [06:10<1:00:29,  1.47it/s, v_num=7]
 tensor([[[-0.0928, -4.4102, -0.8657,  ..., -1.4590, -1.1191,  0.5713],
         [-0.0934, -4.4102, -0.8662,  ..., -1.4600, -1.1201,  0.5723],
         [-0.0948, -4.4297, -0.8369,  ..., -1.4473, -1.1006,  0.5508]],

        [[-4.1328,  0.8149,  2.5723,  ..., -3.0039, -0.9248, -0.2683],
         [-4.0977,  0.8140,  2.5430,  ..., -3.0215, -0.9604, -0.2386],
         [-1.1152, -2.0352, -1.7695,  ...,  3.2930, -1.0791, -1.2051]],

        [[-4.2734,  0.7793,  2.6699,  ..., -2.9219, -0.7368, -0.3960],
         [-4.2656,  0.7866,  2.6680,  ..., -2.9297, -0.7607, -0.3879],
         [-2.3242, -1.3584, -1.7002,  ...,  4.0742, -0.9126, -1.6211]],

        ...,

        [[-2.3320, -1.3262, -1.6953,  ...,  4.0703, -0.9175, -1.6035],
         [-2.2695, -1.3154, -1.7188,  ...,  4.0859, -0.9736, -1.5732],
         [-2.2090, -1.3271, -1.7402,  ...,  4.0938, -1.0225, -1.5645]],

        [[-4.1875,  0.8110,  2.6191,  ..., -2.9746, -0

Epoch 0:   9%|▉         | 550/5879 [06:11<1:00:00,  1.48it/s, v_num=7]
 tensor([[[-2.3438, -1.3242, -1.7227,  ...,  4.0820, -0.9316, -1.6045],
         [-2.3418, -1.3223, -1.7246,  ...,  4.0820, -0.9365, -1.6025],
         [-2.0000, -1.3848, -1.8125,  ...,  4.0859, -1.1416, -1.5068]],

        [[-4.3477,  0.7817,  2.7012,  ..., -2.9316, -0.6987, -0.4221],
         [-4.3438,  0.7837,  2.7012,  ..., -2.9336, -0.7046, -0.4189],
         [-2.3066, -1.3301, -1.7471,  ...,  4.0938, -0.9727, -1.6006]],

        [[-4.3633,  0.7656,  2.7051,  ..., -2.9141, -0.6621, -0.4375],
         [-4.3867,  0.7480,  2.7051,  ..., -2.8926, -0.6167, -0.4675],
         [-2.3242, -1.3525, -1.7373,  ...,  4.0898, -0.9414, -1.6172]],

        ...,

        [[-4.1602,  0.8311,  2.5840,  ..., -3.0410, -0.9541, -0.2546],
         [-4.1602,  0.8306,  2.5801,  ..., -3.0430, -0.9575, -0.2512],
         [-2.1855, -1.3105, -1.7822,  ...,  4.1094, -1.0752, -1.5488]],

        [[-4.1758,  0.8276,  2.5938,  ..., -3.0352, -0

Epoch 0:   9%|▉         | 556/5879 [06:13<59:31,  1.49it/s, v_num=7]
 tensor([[[-4.3359,  0.8330,  2.7051,  ..., -3.0000, -0.8213, -0.3623],
         [-4.3320,  0.8345,  2.7031,  ..., -3.0039, -0.8276, -0.3599],
         [-2.2559, -1.2979, -1.7998,  ...,  4.1250, -1.0586, -1.5664]],

        [[-0.0910, -4.5352, -0.8584,  ..., -1.5410, -1.1143,  0.6069],
         [-0.0903, -4.5352, -0.8574,  ..., -1.5391, -1.1113,  0.6055],
         [-0.0912, -4.5469, -0.8369,  ..., -1.5342, -1.0977,  0.5938]],

        [[-4.3047,  0.8413,  2.6816,  ..., -3.0156, -0.8589, -0.3364],
         [-4.3047,  0.8413,  2.6816,  ..., -3.0156, -0.8589, -0.3364],
         [-2.1035, -1.3311, -1.8330,  ...,  4.1250, -1.1416, -1.5234]],

        ...,

        [[-1.9873, -0.1740,  0.6743,  ..., -3.2988, -2.1328,  1.2012],
         [-0.2080, -3.9863, -0.9829,  ..., -2.1016, -1.6094,  1.0332],
         [-2.1543, -1.3057, -1.8242,  ...,  4.1289, -1.1260, -1.5322]],

        [[-0.9307, -2.1367, -1.8604,  ...,  3.1777, -1.0

Epoch 0:  10%|▉         | 562/5879 [06:20<59:55,  1.48it/s, v_num=7]
 tensor([[[-4.4023,  0.8442,  2.7305,  ..., -3.0254, -0.8052, -0.3718],
         [-4.3945,  0.8472,  2.7246,  ..., -3.0312, -0.8213, -0.3655],
         [-2.2363, -1.2959, -1.8369,  ...,  4.1562, -1.0996, -1.5547]],

        [[-0.4785, -2.9492, -0.8687,  ..., -2.6602, -1.9941,  1.3877],
         [-0.4617, -3.0078, -0.8813,  ..., -2.6387, -1.9795,  1.3760],
         [-1.1162, -1.9688, -1.8916,  ...,  3.4395, -1.1885, -1.2188]],

        [[-4.3711,  0.8535,  2.7109,  ..., -3.0430, -0.8545, -0.3467],
         [-4.3711,  0.8540,  2.7109,  ..., -3.0449, -0.8564, -0.3462],
         [-2.0098, -1.3584, -1.8809,  ...,  4.1445, -1.2070, -1.4941]],

        ...,

        [[-2.2754, -1.2852, -1.8193,  ...,  4.1484, -1.0674, -1.5576],
         [-2.2812, -1.2900, -1.8174,  ...,  4.1484, -1.0605, -1.5635],
         [-1.9316, -1.3955, -1.8916,  ...,  4.1250, -1.2256, -1.4766]],

        [[-2.3359, -1.3320, -1.7861,  ...,  4.1250, -0.9

Epoch 0:  10%|▉         | 568/5879 [06:21<59:27,  1.49it/s, v_num=7]
 tensor([[[-4.3906,  0.8730,  2.6660,  ..., -3.1113, -0.8906, -0.3013],
         [-4.3945,  0.8730,  2.6680,  ..., -3.1094, -0.8867, -0.3037],
         [-1.5908, -1.5908, -1.9453,  ...,  4.0000, -1.2891, -1.3887]],

        [[-4.4531,  0.8652,  2.7168,  ..., -3.0781, -0.8174, -0.3528],
         [-4.4453,  0.8662,  2.7129,  ..., -3.0820, -0.8267, -0.3494],
         [-2.1348, -1.3145, -1.8916,  ...,  4.2070, -1.1895, -1.5166]],

        [[-0.1925, -4.0625, -1.0039,  ..., -2.1445, -1.5762,  1.0635],
         [-0.1799, -4.1172, -1.0020,  ..., -2.1055, -1.5469,  1.0361],
         [-1.6250, -1.5547, -1.9492,  ...,  4.0273, -1.3027, -1.3955]],

        ...,

        [[-2.2559, -1.2969, -1.8486,  ...,  4.1914, -1.0977, -1.5479],
         [-2.2539, -1.2959, -1.8496,  ...,  4.1914, -1.0996, -1.5469],
         [-2.0977, -1.3164, -1.8984,  ...,  4.2070, -1.2129, -1.5039]],

        [[-1.8301, -0.4053,  0.4390,  ..., -3.3223, -2.1

Epoch 0:  10%|▉         | 574/5879 [06:23<59:00,  1.50it/s, v_num=7]
 tensor([[[-4.4727,  0.8931,  2.6680,  ..., -3.1465, -0.8564, -0.3147],
         [-4.4805,  0.8926,  2.6758,  ..., -3.1406, -0.8462, -0.3220],
         [-1.7598, -1.4775, -1.9658,  ...,  4.1836, -1.3262, -1.4229]],

        [[-0.0233, -4.7227, -0.9072,  ..., -1.5283, -1.0527,  0.6147],
         [-0.0233, -4.7227, -0.9072,  ..., -1.5283, -1.0527,  0.6143],
         [-0.0235, -4.7383, -0.8843,  ..., -1.5215, -1.0342,  0.5986]],

        [[-4.5352,  0.8799,  2.7148,  ..., -3.1113, -0.7764, -0.3674],
         [-4.5352,  0.8794,  2.7148,  ..., -3.1113, -0.7754, -0.3674],
         [-2.1797, -1.3213, -1.9023,  ...,  4.2578, -1.1807, -1.5303]],

        ...,

        [[-1.6475, -0.6230,  0.2043,  ..., -3.3125, -2.1289,  1.3877],
         [-1.7344, -0.5259,  0.2886,  ..., -3.3301, -2.1172,  1.3604],
         [-1.4453, -1.6846, -1.9746,  ...,  3.9707, -1.3242, -1.3457]],

        [[-4.5039,  0.8882,  2.6914,  ..., -3.1270, -0.8

Epoch 0:  10%|▉         | 579/5879 [06:25<58:50,  1.50it/s, v_num=7]
 tensor([[[-4.6094,  0.8887,  2.7266,  ..., -3.1289, -0.7197, -0.3955],
         [-4.6094,  0.8887,  2.7266,  ..., -3.1289, -0.7192, -0.3955],
         [-2.1660, -1.3418, -1.9248,  ...,  4.3008, -1.2041, -1.5342]],

        [[-4.5742,  0.9053,  2.7070,  ..., -3.1543, -0.7803, -0.3596],
         [-4.5742,  0.9053,  2.7070,  ..., -3.1523, -0.7783, -0.3604],
         [-2.1230, -1.3281, -1.9404,  ...,  4.3047, -1.2461, -1.5107]],

        [[-0.0049, -4.7578, -0.9316,  ..., -1.5215, -1.0605,  0.6245],
         [-0.0049, -4.7578, -0.9321,  ..., -1.5225, -1.0625,  0.6260],
         [ 0.0080, -4.7891, -0.8989,  ..., -1.4902, -1.0127,  0.5889]],

        ...,

        [[-4.5312,  0.9155,  2.6758,  ..., -3.1797, -0.8379, -0.3206],
         [-4.5312,  0.9150,  2.6758,  ..., -3.1797, -0.8369, -0.3213],
         [-1.8320, -1.4385, -1.9854,  ...,  4.2695, -1.3555, -1.4395]],

        [[-4.5938,  0.8970,  2.7246,  ..., -3.1426, -0.7

Epoch 0:  10%|▉         | 585/5879 [06:31<59:05,  1.49it/s, v_num=7]
 tensor([[[-4.6094,  0.9351,  2.7090,  ..., -3.2148, -0.8066, -0.3447],
         [-4.6094,  0.9351,  2.7070,  ..., -3.2148, -0.8066, -0.3442],
         [-1.9814, -1.3691, -2.0039,  ...,  4.3516, -1.3604, -1.4775]],

        [[-2.9590,  0.4250,  1.3389,  ..., -3.5176, -1.7656,  0.8491],
         [-3.1934,  0.5454,  1.5361,  ..., -3.5215, -1.6768,  0.7158],
         [-1.8701, -1.4092, -2.0195,  ...,  4.3320, -1.4033, -1.4502]],

        [[-0.0472, -4.5938, -1.0039,  ..., -1.7832, -1.2793,  0.8389],
         [-0.0478, -4.5938, -1.0039,  ..., -1.7861, -1.2812,  0.8413],
         [ 0.0630, -4.8477, -0.9199,  ..., -1.3975, -0.9668,  0.5342]],

        ...,

        [[-2.2051, -1.3574, -1.9287,  ...,  4.3008, -1.1768, -1.5430],
         [-2.1660, -1.3408, -1.9492,  ...,  4.3398, -1.2324, -1.5303],
         [-1.7344, -1.4902, -2.0312,  ...,  4.2852, -1.4199, -1.4248]],

        [[-3.6777,  0.7432,  1.9385,  ..., -3.4941, -1.4

Epoch 0:  10%|█         | 591/5879 [06:33<58:38,  1.50it/s, v_num=7]
 tensor([[[-2.1230, -1.3389, -2.0078,  ...,  4.3789, -1.3105, -1.5195],
         [-2.1211, -1.3379, -2.0098,  ...,  4.3828, -1.3125, -1.5186],
         [-1.9531, -1.3799, -2.0527,  ...,  4.3906, -1.4170, -1.4775]],

        [[-4.6953,  0.9473,  2.7461,  ..., -3.2480, -0.7671, -0.3730],
         [-4.6953,  0.9473,  2.7461,  ..., -3.2480, -0.7671, -0.3733],
         [-2.0293, -1.3545, -2.0391,  ...,  4.3906, -1.3828, -1.4971]],

        [[-2.1230, -1.3340, -2.0039,  ...,  4.3789, -1.3115, -1.5166],
         [-2.1289, -1.3369, -2.0020,  ...,  4.3789, -1.3037, -1.5195],
         [-1.9160, -1.3877, -2.0586,  ...,  4.3867, -1.4385, -1.4658]],

        ...,

        [[ 0.0385, -4.8516, -0.9780,  ..., -1.5273, -1.0967,  0.6519],
         [ 0.0369, -4.8477, -0.9790,  ..., -1.5322, -1.1006,  0.6558],
         [ 0.1898, -4.7773, -1.0557,  ..., -0.6816, -0.7827,  0.1799]],

        [[-4.6992,  0.9448,  2.7480,  ..., -3.2461, -0.7

Epoch 0:  10%|█         | 597/5879 [06:34<58:12,  1.51it/s, v_num=7]
 tensor([[[-2.0996, -1.3389, -2.0586,  ...,  4.4141, -1.3730, -1.5176],
         [-2.0996, -1.3389, -2.0586,  ...,  4.4141, -1.3730, -1.5176],
         [-1.6973, -1.5059, -2.1230,  ...,  4.3555, -1.5273, -1.4297]],

        [[-3.1152,  0.4695,  1.4023,  ..., -3.5918, -1.6934,  0.8081],
         [-2.5977,  0.1522,  0.9580,  ..., -3.5547, -1.8486,  1.0723],
         [-2.0020, -1.3525, -2.0898,  ...,  4.4219, -1.4492, -1.4912]],

        [[-2.1289, -1.3604, -2.0410,  ...,  4.4023, -1.3340, -1.5381],
         [-2.1270, -1.3662, -2.0469,  ...,  4.4062, -1.3359, -1.5400],
         [-2.0879, -1.3574, -2.0664,  ...,  4.4180, -1.3799, -1.5273]],

        ...,

        [[-4.7695,  0.9585,  2.7793,  ..., -3.2910, -0.7461, -0.3877],
         [-4.7656,  0.9590,  2.7773,  ..., -3.2910, -0.7466, -0.3872],
         [-2.1113, -1.3867, -2.0527,  ...,  4.4141, -1.3389, -1.5488]],

        [[-4.7344,  0.9688,  2.7578,  ..., -3.3105, -0.7

Epoch 0:  10%|█         | 603/5879 [06:47<59:26,  1.48it/s, v_num=7]
 tensor([[[-4.8320,  0.9722,  2.8105,  ..., -3.3320, -0.7324, -0.3945],
         [-4.8320,  0.9717,  2.8105,  ..., -3.3320, -0.7310, -0.3950],
         [-2.0723, -1.3506, -2.1152,  ...,  4.4375, -1.4375, -1.5225]],

        [[-4.8203,  0.9780,  2.8027,  ..., -3.3418, -0.7524, -0.3838],
         [-4.8203,  0.9775,  2.8027,  ..., -3.3418, -0.7524, -0.3833],
         [-2.0469, -1.3408, -2.1230,  ...,  4.4414, -1.4629, -1.5088]],

        [[-4.8516,  0.9600,  2.8184,  ..., -3.3184, -0.6963, -0.4158],
         [-4.8516,  0.9600,  2.8164,  ..., -3.3184, -0.6963, -0.4160],
         [-2.0781, -1.3594, -2.1113,  ...,  4.4375, -1.4268, -1.5293]],

        ...,

        [[-4.8242,  0.9756,  2.8086,  ..., -3.3379, -0.7446, -0.3879],
         [-4.8242,  0.9756,  2.8086,  ..., -3.3379, -0.7441, -0.3879],
         [-2.0547, -1.3428, -2.1211,  ...,  4.4414, -1.4570, -1.5117]],

        [[ 0.1991, -4.9961, -1.0244,  ..., -1.0771, -0.8

Epoch 0:  10%|█         | 609/5879 [06:49<59:00,  1.49it/s, v_num=7]
 tensor([[[-2.0703, -1.3242, -2.1523,  ...,  4.4492, -1.4873, -1.5117],
         [-2.0703, -1.3242, -2.1523,  ...,  4.4492, -1.4873, -1.5117],
         [-1.6777, -1.4893, -2.2129,  ...,  4.3945, -1.6367, -1.4229]],

        [[-4.8789,  0.9946,  2.8418,  ..., -3.3828, -0.7480, -0.3882],
         [-4.8789,  0.9951,  2.8418,  ..., -3.3828, -0.7476, -0.3879],
         [-1.9014, -1.3682, -2.1973,  ...,  4.4531, -1.5898, -1.4688]],

        [[-4.8516,  1.0000,  2.8242,  ..., -3.4004, -0.7866, -0.3638],
         [-4.8516,  1.0000,  2.8223,  ..., -3.4004, -0.7900, -0.3623],
         [-1.9834, -1.3330, -2.1816,  ...,  4.4570, -1.5576, -1.4844]],

        ...,

        [[-4.9062,  0.9829,  2.8574,  ..., -3.3633, -0.7104, -0.4167],
         [-4.9023,  0.9849,  2.8574,  ..., -3.3672, -0.7227, -0.4121],
         [-2.0684, -1.3594, -2.1543,  ...,  4.4570, -1.4717, -1.5312]],

        [[-4.2383,  0.8867,  2.2930,  ..., -3.6074, -1.2

Epoch 0:  10%|█         | 615/5879 [06:51<58:38,  1.50it/s, v_num=7]
 tensor([[[-0.5215, -2.8027, -0.9243,  ..., -2.8555, -1.8906,  1.5449],
         [-0.5293, -2.7832, -0.9185,  ..., -2.8613, -1.8936,  1.5469],
         [-1.9521, -1.3213, -2.2285,  ...,  4.4727, -1.6221, -1.4707]],

        [[-2.0332, -1.2979, -2.2051,  ...,  4.4648, -1.5615, -1.4814],
         [-2.0332, -1.2979, -2.2051,  ...,  4.4609, -1.5615, -1.4805],
         [-0.9380, -2.0625, -2.1914,  ...,  3.7871, -1.5850, -1.2256]],

        [[-4.9023,  1.0156,  2.8535,  ..., -3.4414, -0.7920, -0.3630],
         [-4.9023,  1.0156,  2.8535,  ..., -3.4414, -0.7920, -0.3630],
         [-1.8955, -1.3506, -2.2402,  ...,  4.4688, -1.6436, -1.4619]],

        ...,

        [[-4.9414,  1.0088,  2.8828,  ..., -3.4199, -0.7427, -0.3965],
         [-4.9414,  1.0088,  2.8828,  ..., -3.4199, -0.7427, -0.3967],
         [-2.0469, -1.3242, -2.2070,  ...,  4.4727, -1.5488, -1.5107]],

        [[-2.0566, -1.3066, -2.1992,  ...,  4.4648, -1.5

Epoch 0:  11%|█         | 621/5879 [06:52<58:15,  1.50it/s, v_num=7]
 tensor([[[-5.0234,  1.0098,  2.9336,  ..., -3.4395, -0.6934, -0.4275],
         [-5.0195,  1.0107,  2.9336,  ..., -3.4395, -0.6958, -0.4258],
         [-2.0547, -1.3350, -2.2402,  ...,  4.4883, -1.5664, -1.5234]],

        [[-5.0352,  1.0049,  2.9395,  ..., -3.4297, -0.6768, -0.4426],
         [-5.0352,  1.0049,  2.9395,  ..., -3.4297, -0.6768, -0.4426],
         [-2.0469, -1.3320, -2.2441,  ...,  4.4883, -1.5752, -1.5195]],

        [[-4.5664,  0.9683,  2.5430,  ..., -3.6309, -1.0781, -0.0282],
         [-4.5625,  0.9668,  2.5391,  ..., -3.6328, -1.0811, -0.0238],
         [-1.8633, -1.3389, -2.2852,  ...,  4.4805, -1.7070, -1.4482]],

        ...,

        [[-5.0117,  1.0225,  2.9277,  ..., -3.4492, -0.7266, -0.4150],
         [-5.0078,  1.0225,  2.9277,  ..., -3.4492, -0.7261, -0.4150],
         [-2.0430, -1.3174, -2.2480,  ...,  4.4883, -1.5908, -1.5107]],

        [[-4.9922,  1.0273,  2.9199,  ..., -3.4609, -0.7

Epoch 0:  11%|█         | 627/5879 [07:04<59:15,  1.48it/s, v_num=7]
 tensor([[[-0.5806, -2.6914, -0.8887,  ..., -2.9258, -1.8760,  1.5791],
         [-0.4836, -2.9414, -0.9575,  ..., -2.8418, -1.8438,  1.5518],
         [-1.8789, -1.3203, -2.3203,  ...,  4.5000, -1.7324, -1.4463]],

        [[-5.0586,  1.0400,  2.9629,  ..., -3.4902, -0.7383, -0.4146],
         [-5.0586,  1.0400,  2.9629,  ..., -3.4902, -0.7383, -0.4143],
         [-2.0273, -1.2881, -2.2949,  ...,  4.5039, -1.6533, -1.4941]],

        [[-5.0820,  1.0293,  2.9746,  ..., -3.4727, -0.7012, -0.4380],
         [-5.0820,  1.0293,  2.9746,  ..., -3.4727, -0.7012, -0.4382],
         [-2.0508, -1.3232, -2.2793,  ...,  4.5000, -1.6074, -1.5195]],

        ...,

        [[-3.1172,  0.3879,  1.3242,  ..., -3.7285, -1.6621,  0.9214],
         [-3.0723,  0.3618,  1.2871,  ..., -3.7246, -1.6738,  0.9438],
         [-2.0000, -1.2734, -2.2988,  ...,  4.5078, -1.6816, -1.4756]],

        [[-5.0312,  1.0479,  2.9395,  ..., -3.5039, -0.7

Epoch 0:  11%|█         | 633/5879 [07:06<58:52,  1.48it/s, v_num=7]
 tensor([[[-5.1445,  1.0430,  3.0137,  ..., -3.5000, -0.6860, -0.4521],
         [-5.1445,  1.0430,  3.0137,  ..., -3.5000, -0.6855, -0.4524],
         [-2.0312, -1.2939, -2.3281,  ...,  4.5234, -1.6729, -1.5020]],

        [[-5.0977,  1.0625,  2.9824,  ..., -3.5352, -0.7690, -0.4055],
         [-5.0977,  1.0625,  2.9824,  ..., -3.5352, -0.7690, -0.4058],
         [-2.0195, -1.2695, -2.3359,  ...,  4.5234, -1.7012, -1.4844]],

        [[-5.1406,  1.0430,  3.0137,  ..., -3.5020, -0.6895, -0.4514],
         [-5.1445,  1.0420,  3.0137,  ..., -3.5020, -0.6860, -0.4529],
         [-2.0293, -1.2910, -2.3301,  ...,  4.5234, -1.6768, -1.5000]],

        ...,

        [[-2.0078, -1.2461, -2.3320,  ...,  4.5117, -1.7100, -1.4609],
         [-2.0098, -1.2461, -2.3320,  ...,  4.5117, -1.7090, -1.4609],
         [-1.3867, -1.6074, -2.3730,  ...,  4.3047, -1.8281, -1.3447]],

        [[-3.6445,  0.6392,  1.7510,  ..., -3.7812, -1.5

Epoch 0:  11%|█         | 639/5879 [07:07<58:29,  1.49it/s, v_num=7]
 tensor([[[-5.1523,  1.0762,  3.0234,  ..., -3.5645, -0.7681, -0.4163],
         [-5.1523,  1.0762,  3.0234,  ..., -3.5645, -0.7671, -0.4167],
         [-1.9492, -1.2559, -2.3906,  ...,  4.5430, -1.7891, -1.4512]],

        [[-5.2070,  1.0576,  3.0605,  ..., -3.5234, -0.6768, -0.4714],
         [-5.2070,  1.0576,  3.0605,  ..., -3.5234, -0.6763, -0.4714],
         [-2.1348, -1.3271, -2.3145,  ...,  4.4609, -1.5889, -1.5420]],

        [[-2.0273, -1.2383, -2.3652,  ...,  4.5352, -1.7324, -1.4697],
         [-2.0273, -1.2383, -2.3652,  ...,  4.5352, -1.7324, -1.4707],
         [-1.8086, -1.3115, -2.4082,  ...,  4.5273, -1.8486, -1.4189]],

        ...,

        [[-5.1562,  1.0762,  3.0273,  ..., -3.5605, -0.7559, -0.4219],
         [-5.1602,  1.0762,  3.0273,  ..., -3.5625, -0.7544, -0.4221],
         [-1.9209, -1.2666, -2.3965,  ...,  4.5430, -1.8037, -1.4443]],

        [[ 0.1243, -5.1289, -1.0400,  ..., -1.7080, -1.2

Epoch 0:  11%|█         | 645/5879 [07:09<58:08,  1.50it/s, v_num=7]
 tensor([[[-5.1953,  1.0908,  3.0586,  ..., -3.5977, -0.7793, -0.4197],
         [-5.1953,  1.0908,  3.0586,  ..., -3.5977, -0.7793, -0.4202],
         [-1.9619, -1.2363, -2.4277,  ...,  4.5547, -1.8223, -1.4502]],

        [[-5.2500,  1.0752,  3.0996,  ..., -3.5625, -0.6963, -0.4729],
         [-5.2500,  1.0752,  3.0996,  ..., -3.5625, -0.6963, -0.4729],
         [-2.0273, -1.2754, -2.4023,  ...,  4.5547, -1.7441, -1.4961]],

        [[-5.2344,  1.0859,  3.0898,  ..., -3.5781, -0.7354, -0.4517],
         [-5.2344,  1.0859,  3.0898,  ..., -3.5781, -0.7354, -0.4517],
         [-2.0137, -1.2490, -2.4121,  ...,  4.5586, -1.7764, -1.4785]],

        ...,

        [[-5.2539,  1.0732,  3.0996,  ..., -3.5605, -0.6914, -0.4773],
         [-5.2539,  1.0732,  3.0996,  ..., -3.5605, -0.6914, -0.4773],
         [-2.0215, -1.2715, -2.4062,  ...,  4.5547, -1.7539, -1.4932]],

        [[-0.1176, -3.9961, -1.2012,  ..., -2.5137, -1.6

Epoch 0:  11%|█         | 651/5879 [07:20<58:54,  1.48it/s, v_num=7]
 tensor([[[-5.1680,  1.0977,  2.9883,  ..., -3.6758, -0.8472, -0.3457],
         [-5.1797,  1.0986,  2.9980,  ..., -3.6699, -0.8403, -0.3547],
         [-1.8535, -1.2588, -2.4824,  ...,  4.5625, -1.9072, -1.4199]],

        [[-5.3438,  1.0742,  3.1289,  ..., -3.5664, -0.6274, -0.5220],
         [-5.3438,  1.0723,  3.1270,  ..., -3.5625, -0.6221, -0.5249],
         [-2.0488, -1.2773, -2.4277,  ...,  4.5508, -1.7471, -1.5059]],

        [[-5.3203,  1.0898,  3.1211,  ..., -3.5918, -0.6772, -0.4878],
         [-5.3203,  1.0898,  3.1211,  ..., -3.5918, -0.6763, -0.4883],
         [-2.0098, -1.2529, -2.4492,  ...,  4.5742, -1.7998, -1.4844]],

        ...,

        [[ 0.1989, -5.4141, -0.9854,  ..., -1.5127, -1.0645,  0.7002],
         [ 0.2002, -5.4180, -0.9849,  ..., -1.5088, -1.0625,  0.6973],
         [-1.9414, -1.2178, -2.4668,  ...,  4.5703, -1.8750, -1.4375]],

        [[-5.3359,  1.0762,  3.1250,  ..., -3.5742, -0.6

Epoch 0:  11%|█         | 657/5879 [07:22<58:37,  1.48it/s, v_num=7]
 tensor([[[-5.3633,  1.1299,  3.0977,  ..., -3.6484, -0.7285, -0.4614],
         [-5.3633,  1.1299,  3.0977,  ..., -3.6484, -0.7280, -0.4612],
         [-1.9893, -1.2041, -2.4883,  ...,  4.5938, -1.8770, -1.4551]],

        [[-5.3633,  1.1299,  3.0977,  ..., -3.6465, -0.7266, -0.4626],
         [-5.3633,  1.1299,  3.0977,  ..., -3.6465, -0.7266, -0.4624],
         [-1.9336, -1.2080, -2.5020,  ...,  4.5898, -1.9160, -1.4346]],

        [[-5.3633,  1.1309,  3.0977,  ..., -3.6484, -0.7280, -0.4607],
         [-5.3594,  1.1309,  3.0977,  ..., -3.6484, -0.7290, -0.4597],
         [-1.9912, -1.2100, -2.4883,  ...,  4.5938, -1.8730, -1.4600]],

        ...,

        [[-0.9937, -1.8320, -2.4902,  ...,  4.0820, -1.8896, -1.2197],
         [-0.9912, -1.8350, -2.4883,  ...,  4.0781, -1.8896, -1.2197],
         [ 0.3440, -4.4062, -1.4766,  ...,  0.8306, -0.9595, -0.2644]],

        [[-2.0215, -1.2041, -2.4707,  ...,  4.5859, -1.8

Epoch 0:  11%|█▏        | 663/5879 [07:24<58:14,  1.49it/s, v_num=7]
 tensor([[[-5.4375,  1.1641,  3.0996,  ..., -3.6641, -0.7129, -0.4666],
         [-5.4375,  1.1631,  3.0977,  ..., -3.6660, -0.7144, -0.4651],
         [-1.9883, -1.2256, -2.5039,  ...,  4.6328, -1.8867, -1.4736]],

        [[-5.4375,  1.1621,  3.0977,  ..., -3.6680, -0.7173, -0.4656],
         [-5.4375,  1.1621,  3.0977,  ..., -3.6680, -0.7168, -0.4661],
         [-1.9824, -1.2051, -2.5078,  ...,  4.6328, -1.9062, -1.4609]],

        [[-5.4688,  1.1455,  3.1113,  ..., -3.6387, -0.6538, -0.5049],
         [-5.4688,  1.1455,  3.1113,  ..., -3.6387, -0.6538, -0.5049],
         [-1.9922, -1.2285, -2.5000,  ...,  4.6328, -1.8818, -1.4766]],

        ...,

        [[ 0.1987, -5.4180, -1.0498,  ..., -1.6094, -1.1348,  0.8423],
         [ 0.1960, -5.4102, -1.0518,  ..., -1.6191, -1.1406,  0.8506],
         [-1.8867, -1.1973, -2.5273,  ...,  4.6289, -1.9814, -1.4160]],

        [[-5.4766,  1.1426,  3.1172,  ..., -3.6309, -0.6

Epoch 0:  11%|█▏        | 669/5879 [07:25<57:52,  1.50it/s, v_num=7]
 tensor([[[ 0.2876, -5.6406, -0.9966,  ..., -1.3574, -0.9800,  0.6807],
         [ 0.2869, -5.6406, -0.9966,  ..., -1.3574, -0.9805,  0.6812],
         [ 0.2795, -5.6641, -0.9692,  ..., -1.3613, -0.9722,  0.6689]],

        [[-0.1613, -4.0156, -1.1758,  ..., -2.5215, -1.6016,  1.5205],
         [-0.1060, -4.2031, -1.1865,  ..., -2.4375, -1.5654,  1.4736],
         [-1.9375, -1.1689, -2.5352,  ...,  4.6758, -1.9932, -1.4355]],

        [[-4.3359,  0.8877,  2.0762,  ..., -3.9277, -1.2900,  0.4204],
         [-4.4961,  0.9458,  2.2090,  ..., -3.9219, -1.2354,  0.3169],
         [-1.7676, -1.2412, -2.5625,  ...,  4.6602, -2.0566, -1.3955]],

        ...,

        [[-5.4883,  1.2041,  3.0957,  ..., -3.6914, -0.7334, -0.4531],
         [-5.4844,  1.2041,  3.0938,  ..., -3.6934, -0.7354, -0.4512],
         [-1.9639, -1.1855, -2.5273,  ...,  4.6797, -1.9619, -1.4541]],

        [[-5.5078,  1.2002,  3.1152,  ..., -3.6816, -0.7

Epoch 0:  11%|█▏        | 675/5879 [07:36<58:39,  1.48it/s, v_num=7]
 tensor([[[-1.9014, -1.1611, -2.5566,  ...,  4.7109, -2.0312, -1.4121],
         [-1.9150, -1.1572, -2.5547,  ...,  4.7109, -2.0254, -1.4180],
         [ 0.4661, -4.8828, -1.3662,  ...,  0.4102, -0.9404, -0.0829]],

        [[-5.2031,  1.1699,  2.7793,  ..., -3.8496, -0.9634, -0.1523],
         [-5.2188,  1.1738,  2.7930,  ..., -3.8457, -0.9551, -0.1663],
         [-1.8516, -1.1914, -2.5723,  ...,  4.7227, -2.0645, -1.4150]],

        [[-5.5625,  1.2363,  3.1309,  ..., -3.7129, -0.7275, -0.4675],
         [-5.5625,  1.2363,  3.1309,  ..., -3.7129, -0.7275, -0.4675],
         [-1.9385, -1.1826, -2.5547,  ...,  4.7305, -2.0098, -1.4521]],

        ...,

        [[-5.5703,  1.2354,  3.1348,  ..., -3.7090, -0.7173, -0.4744],
         [-5.5703,  1.2354,  3.1348,  ..., -3.7090, -0.7168, -0.4746],
         [-1.9395, -1.1924, -2.5547,  ...,  4.7344, -2.0020, -1.4580]],

        [[-5.5742,  1.2324,  3.1445,  ..., -3.7031, -0.7

Epoch 0:  12%|█▏        | 681/5879 [07:39<58:26,  1.48it/s, v_num=7]
 tensor([[[-5.3242,  1.2109,  2.8613,  ..., -3.8672, -0.9292, -0.2032],
         [-5.4258,  1.2344,  2.9551,  ..., -3.8340, -0.8745, -0.2898],
         [-1.9062, -1.1904, -2.5938,  ...,  4.7695, -2.0566, -1.4521]],

        [[-5.6289,  1.2646,  3.1699,  ..., -3.7422, -0.7197, -0.4731],
         [-5.6289,  1.2646,  3.1699,  ..., -3.7422, -0.7197, -0.4731],
         [-1.9316, -1.2393, -2.5762,  ...,  4.7617, -2.0000, -1.4834]],

        [[-5.6367,  1.2598,  3.1738,  ..., -3.7344, -0.6982, -0.4832],
         [-5.6367,  1.2598,  3.1738,  ..., -3.7344, -0.6987, -0.4827],
         [-1.9150, -1.2266, -2.5840,  ...,  4.7695, -2.0234, -1.4746]],

        ...,

        [[-5.6367,  1.2588,  3.1738,  ..., -3.7344, -0.6982, -0.4797],
         [-5.6367,  1.2588,  3.1738,  ..., -3.7344, -0.6973, -0.4795],
         [-1.9170, -1.2314, -2.5820,  ...,  4.7695, -2.0176, -1.4775]],

        [[-5.6211,  1.2686,  3.1602,  ..., -3.7480, -0.7

Epoch 0:  12%|█▏        | 687/5879 [07:41<58:04,  1.49it/s, v_num=7]
 tensor([[[-5.6719,  1.2930,  3.1875,  ..., -3.7930, -0.7432, -0.4573],
         [-5.6719,  1.2930,  3.1875,  ..., -3.7930, -0.7432, -0.4570],
         [-1.8906, -1.2236, -2.6328,  ...,  4.8047, -2.0723, -1.4736]],

        [[-5.6797,  1.2939,  3.1914,  ..., -3.7891, -0.7329, -0.4631],
         [-5.6797,  1.2939,  3.1914,  ..., -3.7891, -0.7329, -0.4629],
         [-1.8818, -1.2041, -2.6387,  ...,  4.8008, -2.0938, -1.4609]],

        [[-2.4258, -0.2703,  0.4309,  ..., -3.7285, -1.6816,  1.4404],
         [-2.4863, -0.2173,  0.4805,  ..., -3.7441, -1.6748,  1.4180],
         [-1.8389, -1.1973, -2.6445,  ...,  4.7891, -2.1289, -1.4395]],

        ...,

        [[-5.6680,  1.2959,  3.1777,  ..., -3.7949, -0.7534, -0.4514],
         [-5.6680,  1.2959,  3.1777,  ..., -3.7949, -0.7529, -0.4514],
         [-1.8643, -1.1953, -2.6426,  ...,  4.7930, -2.1113, -1.4482]],

        [[-5.5234,  1.2715,  3.0215,  ..., -3.8574, -0.8

Epoch 0:  12%|█▏        | 693/5879 [07:42<57:41,  1.50it/s, v_num=7]
 tensor([[[-5.7266,  1.3203,  3.2129,  ..., -3.8359, -0.7524, -0.4512],
         [-5.7266,  1.3203,  3.2129,  ..., -3.8359, -0.7520, -0.4514],
         [-1.8652, -1.2051, -2.6855,  ...,  4.8203, -2.1328, -1.4639]],

        [[-5.7500,  1.3105,  3.2363,  ..., -3.8164, -0.7109, -0.4768],
         [-5.7500,  1.3105,  3.2344,  ..., -3.8164, -0.7109, -0.4768],
         [-1.8730, -1.2422, -2.6797,  ...,  4.8242, -2.1016, -1.4863]],

        [[-5.7305,  1.3193,  3.2207,  ..., -3.8301, -0.7471, -0.4536],
         [-5.7305,  1.3193,  3.2207,  ..., -3.8301, -0.7471, -0.4539],
         [-1.8711, -1.2412, -2.6797,  ...,  4.8242, -2.1035, -1.4854]],

        ...,

        [[ 0.3342, -5.8281, -1.0518,  ..., -1.3926, -1.0771,  0.7690],
         [ 0.3337, -5.8281, -1.0518,  ..., -1.3926, -1.0771,  0.7690],
         [ 0.4387, -5.8438, -1.0430,  ..., -1.0303, -0.9521,  0.5044]],

        [[-5.7656,  1.2988,  3.2461,  ..., -3.7988, -0.6

Epoch 0:  12%|█▏        | 699/5879 [07:47<57:46,  1.49it/s, v_num=7]
 tensor([[[-5.7891,  1.3418,  3.2559,  ..., -3.8652, -0.7598, -0.4500],
         [-5.7891,  1.3428,  3.2578,  ..., -3.8652, -0.7598, -0.4500],
         [-1.8779, -1.2598, -2.7129,  ...,  4.8359, -2.1133, -1.4980]],

        [[-5.7695,  1.3428,  3.2344,  ..., -3.8750, -0.7700, -0.4360],
         [-5.7695,  1.3428,  3.2344,  ..., -3.8750, -0.7695, -0.4360],
         [-1.8594, -1.2246, -2.7266,  ...,  4.8477, -2.1523, -1.4785]],

        [[-5.7969,  1.3389,  3.2656,  ..., -3.8633, -0.7437, -0.4595],
         [-5.7969,  1.3389,  3.2656,  ..., -3.8633, -0.7427, -0.4600],
         [-1.8564, -1.2266, -2.7266,  ...,  4.8477, -2.1523, -1.4785]],

        ...,

        [[-5.5352,  1.2910,  2.9785,  ..., -3.9668, -0.9082, -0.2333],
         [-5.4766,  1.2754,  2.9199,  ..., -3.9844, -0.9375, -0.1827],
         [-1.8105, -1.2080, -2.7363,  ...,  4.8320, -2.2012, -1.4482]],

        [[-5.7930,  1.3428,  3.2617,  ..., -3.8633, -0.7

Epoch 0:  12%|█▏        | 705/5879 [07:57<58:22,  1.48it/s, v_num=7]
 tensor([[[-5.8672,  1.3545,  3.3184,  ..., -3.8828, -0.7192, -0.4729],
         [-5.8672,  1.3535,  3.3184,  ..., -3.8828, -0.7192, -0.4729],
         [-1.8799, -1.2637, -2.7480,  ...,  4.8438, -2.1328, -1.5039]],

        [[-5.8672,  1.3535,  3.3184,  ..., -3.8809, -0.7139, -0.4756],
         [-5.8672,  1.3535,  3.3164,  ..., -3.8809, -0.7148, -0.4751],
         [-1.8652, -1.2607, -2.7559,  ...,  4.8555, -2.1484, -1.4980]],

        [[-5.8516,  1.3604,  3.3008,  ..., -3.8965, -0.7471, -0.4592],
         [-5.8516,  1.3604,  3.3008,  ..., -3.8965, -0.7471, -0.4590],
         [-1.8613, -1.2549, -2.7598,  ...,  4.8594, -2.1562, -1.4932]],

        ...,

        [[-5.8594,  1.3604,  3.3105,  ..., -3.8926, -0.7407, -0.4626],
         [-5.8594,  1.3604,  3.3105,  ..., -3.8926, -0.7402, -0.4631],
         [-1.8525, -1.2480, -2.7637,  ...,  4.8633, -2.1680, -1.4893]],

        [[-5.8555,  1.3613,  3.3105,  ..., -3.8926, -0.7

Epoch 0:  12%|█▏        | 711/5879 [07:58<58:00,  1.48it/s, v_num=7]
 tensor([[[-5.9219,  1.3721,  3.3613,  ..., -3.9082, -0.7134, -0.4785],
         [-5.9219,  1.3721,  3.3613,  ..., -3.9082, -0.7134, -0.4785],
         [-5.9609,  1.3740,  3.4004,  ..., -3.8867, -0.7305, -0.5239]],

        [[-5.9023,  1.3857,  3.3379,  ..., -3.9297, -0.7676, -0.4541],
         [-5.9023,  1.3857,  3.3359,  ..., -3.9297, -0.7681, -0.4539],
         [-1.8447, -1.2148, -2.8105,  ...,  4.8828, -2.2227, -1.4746]],

        [[-5.9102,  1.3838,  3.3496,  ..., -3.9238, -0.7490, -0.4619],
         [-5.9102,  1.3838,  3.3496,  ..., -3.9238, -0.7490, -0.4622],
         [-1.8496, -1.2461, -2.8027,  ...,  4.8789, -2.1953, -1.4902]],

        ...,

        [[-5.9141,  1.3818,  3.3555,  ..., -3.9199, -0.7437, -0.4663],
         [-5.9141,  1.3818,  3.3555,  ..., -3.9199, -0.7441, -0.4666],
         [-1.8613, -1.2559, -2.7949,  ...,  4.8672, -2.1797, -1.4980]],

        [[-5.9141,  1.3838,  3.3555,  ..., -3.9180, -0.7

Epoch 0:  12%|█▏        | 717/5879 [08:00<57:42,  1.49it/s, v_num=7]
 tensor([[[-5.9531,  1.4082,  3.3828,  ..., -3.9570, -0.7729, -0.4595],
         [-5.9531,  1.4082,  3.3828,  ..., -3.9570, -0.7729, -0.4597],
         [-1.8330, -1.1855, -2.8535,  ...,  4.8945, -2.2773, -1.4600]],

        [[-5.9570,  1.4092,  3.3848,  ..., -3.9531, -0.7715, -0.4604],
         [-5.9570,  1.4092,  3.3848,  ..., -3.9531, -0.7715, -0.4602],
         [-6.0391,  1.3496,  3.4492,  ..., -3.8574, -0.6592, -0.5908]],

        [[-5.9648,  1.4033,  3.3926,  ..., -3.9512, -0.7573, -0.4678],
         [-5.9648,  1.4033,  3.3926,  ..., -3.9512, -0.7568, -0.4678],
         [-1.8418, -1.2285, -2.8457,  ...,  4.8984, -2.2402, -1.4844]],

        ...,

        [[-5.9336,  1.4053,  3.3555,  ..., -3.9609, -0.7959, -0.4409],
         [-5.9336,  1.4053,  3.3555,  ..., -3.9609, -0.7954, -0.4409],
         [-1.8408, -1.2031, -2.8496,  ...,  4.8984, -2.2598, -1.4727]],

        [[ 0.3704, -6.0352, -1.0449,  ..., -1.4473, -1.1

Epoch 0:  12%|█▏        | 723/5879 [08:02<57:23,  1.50it/s, v_num=7]
 tensor([[[-6.0234,  1.4180,  3.4473,  ..., -3.9629, -0.7432, -0.4846],
         [-6.0234,  1.4180,  3.4473,  ..., -3.9629, -0.7427, -0.4849],
         [-1.8457, -1.2295, -2.8789,  ...,  4.9102, -2.2598, -1.4912]],

        [[-6.0156,  1.4277,  3.4395,  ..., -3.9727, -0.7710, -0.4722],
         [-6.0156,  1.4277,  3.4395,  ..., -3.9727, -0.7710, -0.4722],
         [-1.8457, -1.2236, -2.8809,  ...,  4.9141, -2.2656, -1.4873]],

        [[-6.0156,  1.4248,  3.4395,  ..., -3.9707, -0.7622, -0.4775],
         [-6.0156,  1.4248,  3.4395,  ..., -3.9707, -0.7627, -0.4771],
         [-1.8467, -1.2266, -2.8789,  ...,  4.9141, -2.2617, -1.4893]],

        ...,

        [[-6.0195,  1.4238,  3.4473,  ..., -3.9688, -0.7617, -0.4756],
         [-6.0195,  1.4238,  3.4473,  ..., -3.9668, -0.7617, -0.4756],
         [-6.0586,  1.4219,  3.4902,  ..., -3.9492, -0.7705, -0.5122]],

        [[-6.0039,  1.4297,  3.4238,  ..., -3.9785, -0.7

Epoch 0:  12%|█▏        | 729/5879 [08:14<58:16,  1.47it/s, v_num=7]
 tensor([[[-6.0742,  1.2656,  3.4062,  ..., -3.6699, -0.3638, -0.7827],
         [-6.0742,  1.2646,  3.4062,  ..., -3.6699, -0.3633, -0.7832],
         [-3.7891, -0.0151,  0.0817,  ...,  0.8926, -0.1708, -1.4414]],

        [[-6.0508,  1.4512,  3.4707,  ..., -3.9941, -0.7993, -0.4697],
         [-6.0508,  1.4512,  3.4727,  ..., -3.9941, -0.7993, -0.4697],
         [-1.8535, -1.2178, -2.9141,  ...,  4.9258, -2.2871, -1.4893]],

        [[-6.0703,  1.4434,  3.4941,  ..., -3.9863, -0.7666, -0.4868],
         [-6.0703,  1.4434,  3.4922,  ..., -3.9863, -0.7661, -0.4863],
         [-1.8447, -1.2158, -2.9180,  ...,  4.9297, -2.2930, -1.4873]],

        ...,

        [[-6.0625,  1.4502,  3.4805,  ..., -3.9961, -0.7935, -0.4773],
         [-6.0625,  1.4492,  3.4805,  ..., -3.9941, -0.7935, -0.4771],
         [-1.8379, -1.1963, -2.9238,  ...,  4.9375, -2.3125, -1.4775]],

        [[-6.0547,  1.4512,  3.4746,  ..., -3.9980, -0.7

Epoch 0:  13%|█▎        | 735/5879 [08:16<57:53,  1.48it/s, v_num=7]
 tensor([[[-6.1211,  1.4639,  3.5371,  ..., -4.0039, -0.7827, -0.4932],
         [-6.1211,  1.4648,  3.5371,  ..., -4.0078, -0.7832, -0.4929],
         [-1.8555, -1.2090, -2.9473,  ...,  4.9336, -2.3105, -1.4883]],

        [[-6.1094,  1.4707,  3.5254,  ..., -4.0117, -0.8037, -0.4858],
         [-6.1094,  1.4707,  3.5254,  ..., -4.0117, -0.8037, -0.4861],
         [-1.8496, -1.2012, -2.9551,  ...,  4.9453, -2.3242, -1.4814]],

        [[-6.0273,  1.4541,  3.4160,  ..., -4.0430, -0.8574, -0.4207],
         [-6.0273,  1.4541,  3.4180,  ..., -4.0430, -0.8569, -0.4207],
         [-1.8232, -1.1465, -2.9688,  ...,  4.9453, -2.3828, -1.4492]],

        ...,

        [[-5.4844,  1.3086,  2.8633,  ..., -4.1758, -1.0732,  0.0292],
         [-5.4648,  1.3027,  2.8457,  ..., -4.1797, -1.0801,  0.0436],
         [-1.8027, -1.1494, -2.9707,  ...,  4.9375, -2.3965, -1.4424]],

        [[-6.1133,  1.4688,  3.5293,  ..., -4.0117, -0.7

Epoch 0:  13%|█▎        | 741/5879 [08:17<57:31,  1.49it/s, v_num=7]
 tensor([[[-6.1680,  1.4893,  3.5723,  ..., -4.0273, -0.7979, -0.5000],
         [-6.1680,  1.4893,  3.5723,  ..., -4.0273, -0.7969, -0.4998],
         [-1.8350, -1.1846, -2.9961,  ...,  4.9648, -2.3633, -1.4727]],

        [[-6.1562,  1.4941,  3.5527,  ..., -4.0352, -0.8218, -0.4905],
         [-6.1562,  1.4941,  3.5527,  ..., -4.0352, -0.8218, -0.4902],
         [-1.8252, -1.1318, -3.0078,  ...,  4.9648, -2.4082, -1.4443]],

        [[-6.1562,  1.4941,  3.5547,  ..., -4.0352, -0.8228, -0.4907],
         [-6.1562,  1.4941,  3.5547,  ..., -4.0352, -0.8228, -0.4907],
         [-1.8369, -1.1719, -2.9980,  ...,  4.9727, -2.3711, -1.4678]],

        ...,

        [[-6.1680,  1.4922,  3.5664,  ..., -4.0312, -0.8081, -0.4998],
         [-6.1680,  1.4922,  3.5664,  ..., -4.0312, -0.8076, -0.4995],
         [-6.2148,  1.4854,  3.6211,  ..., -3.9902, -0.8037, -0.5552]],

        [[-6.1602,  1.4932,  3.5605,  ..., -4.0312, -0.8

Epoch 0:  13%|█▎        | 747/5879 [08:19<57:09,  1.50it/s, v_num=7]
 tensor([[[-6.2070,  1.5195,  3.5898,  ..., -4.0430, -0.8364, -0.5029],
         [-6.2070,  1.5195,  3.5898,  ..., -4.0430, -0.8364, -0.5029],
         [-1.8525, -1.1631, -3.0215,  ...,  4.9727, -2.3809, -1.4727]],

        [[-6.2148,  1.5166,  3.6016,  ..., -4.0391, -0.8252, -0.5063],
         [-6.2148,  1.5166,  3.6016,  ..., -4.0391, -0.8252, -0.5063],
         [-1.8965, -1.1807, -2.9941,  ...,  4.9336, -2.3320, -1.4893]],

        [[ 0.2776, -5.8164, -1.1699,  ..., -1.9307, -1.2646,  1.2812],
         [ 0.2991, -5.9023, -1.1523,  ..., -1.8691, -1.2402,  1.2295],
         [-1.8105, -1.1084, -3.0430,  ...,  4.9727, -2.4531, -1.4326]],

        ...,

        [[-6.2148,  1.5127,  3.6016,  ..., -4.0391, -0.8120, -0.5127],
         [-6.2148,  1.5127,  3.6016,  ..., -4.0391, -0.8115, -0.5127],
         [-6.2109,  1.5234,  3.6016,  ..., -4.0312, -0.8765, -0.5005]],

        [[-6.2227,  1.5098,  3.6113,  ..., -4.0352, -0.8

Epoch 0:  13%|█▎        | 753/5879 [08:32<58:08,  1.47it/s, v_num=7]
 tensor([[[-6.2578,  1.5430,  3.6367,  ..., -4.0508, -0.8506, -0.5156],
         [-6.2578,  1.5430,  3.6367,  ..., -4.0508, -0.8506, -0.5151],
         [-1.8486, -1.1211, -3.0566,  ...,  5.0000, -2.4297, -1.4619]],

        [[-6.2852,  1.4316,  3.6250,  ..., -3.8926, -0.5864, -0.6929],
         [-6.2852,  1.4316,  3.6250,  ..., -3.8926, -0.5859, -0.6934],
         [-6.1758,  1.2412,  3.4434,  ..., -3.4629, -0.2827, -1.0049]],

        [[-6.2734,  1.5254,  3.6543,  ..., -4.0312, -0.7939, -0.5420],
         [-6.2734,  1.5244,  3.6543,  ..., -4.0312, -0.7935, -0.5425],
         [-4.6016,  0.4700,  1.4297,  ..., -0.7441,  0.0990, -1.4385]],

        ...,

        [[-0.5176, -3.2754, -1.1191,  ..., -3.0117, -1.5889,  1.9453],
         [-0.5576, -3.1797, -1.0986,  ..., -3.0391, -1.5938,  1.9521],
         [-1.8359, -1.0742, -3.0664,  ...,  4.9922, -2.4707, -1.4316]],

        [[-6.2578,  1.5449,  3.6348,  ..., -4.0508, -0.8

Epoch 0:  13%|█▎        | 759/5879 [08:33<57:46,  1.48it/s, v_num=7]
 tensor([[[-6.3242,  1.5508,  3.7031,  ..., -4.0469, -0.8257, -0.5503],
         [-6.3242,  1.5508,  3.7012,  ..., -4.0469, -0.8252, -0.5508],
         [-6.3594,  1.5645,  3.7402,  ..., -4.0273, -0.8623, -0.5830]],

        [[-6.3125,  1.5703,  3.6875,  ..., -4.0586, -0.8638, -0.5352],
         [-6.3125,  1.5703,  3.6875,  ..., -4.0586, -0.8633, -0.5352],
         [-1.8613, -1.1045, -3.0801,  ...,  5.0117, -2.4453, -1.4668]],

        [[-6.3164,  1.5645,  3.6973,  ..., -4.0547, -0.8462, -0.5381],
         [-6.3164,  1.5635,  3.6973,  ..., -4.0547, -0.8462, -0.5381],
         [-6.3438,  1.5771,  3.7285,  ..., -4.0352, -0.8813, -0.5659]],

        ...,

        [[-6.3164,  1.5605,  3.6973,  ..., -4.0508, -0.8340, -0.5459],
         [-6.3164,  1.5605,  3.6973,  ..., -4.0508, -0.8340, -0.5454],
         [-6.3516,  1.5703,  3.7324,  ..., -4.0312, -0.8740, -0.5747]],

        [[-5.7695,  1.4326,  3.0859,  ..., -4.2070, -1.0

Epoch 0:  13%|█▎        | 765/5879 [08:35<57:26,  1.48it/s, v_num=7]
 tensor([[[-6.3555,  1.5938,  3.7324,  ..., -4.0664, -0.8843, -0.5483],
         [-6.3555,  1.5938,  3.7324,  ..., -4.0664, -0.8838, -0.5483],
         [-1.8613, -1.0439, -3.1250,  ...,  5.0391, -2.5059, -1.4492]],

        [[-6.3633,  1.5938,  3.7441,  ..., -4.0664, -0.8818, -0.5547],
         [-6.3633,  1.5938,  3.7441,  ..., -4.0625, -0.8818, -0.5547],
         [-6.3984,  1.5918,  3.7793,  ..., -4.0391, -0.8857, -0.5942]],

        [[-6.3555,  1.5908,  3.7422,  ..., -4.0625, -0.8726, -0.5503],
         [-6.3555,  1.5908,  3.7422,  ..., -4.0625, -0.8726, -0.5503],
         [-1.8652, -1.0791, -3.1133,  ...,  5.0312, -2.4746, -1.4658]],

        ...,

        [[ 0.4597, -6.4805, -0.9824,  ..., -1.4014, -0.9863,  0.8359],
         [ 0.4600, -6.4805, -0.9824,  ..., -1.4014, -0.9863,  0.8354],
         [ 0.4375, -6.5312, -0.9434,  ..., -1.4463, -0.9839,  0.8594]],

        [[-1.0146, -2.2617, -0.8364,  ..., -3.3262, -1.6

Epoch 0:  13%|█▎        | 771/5879 [08:37<57:05,  1.49it/s, v_num=7]
 tensor([[[-6.4062,  1.6162,  3.7930,  ..., -4.0742, -0.8921, -0.5693],
         [-6.4062,  1.6162,  3.7930,  ..., -4.0742, -0.8926, -0.5688],
         [-1.8672, -1.0449, -3.1484,  ...,  5.0586, -2.5137, -1.4600]],

        [[-3.1738,  0.1985,  0.8540,  ..., -4.0078, -1.5264,  1.4082],
         [-3.1680,  0.1940,  0.8491,  ..., -4.0039, -1.5264,  1.4111],
         [-0.8252, -1.8379, -2.9941,  ...,  4.2578, -2.4414, -1.1855]],

        [[-6.4062,  1.6074,  3.7969,  ..., -4.0703, -0.8760, -0.5728],
         [-6.4062,  1.6074,  3.7969,  ..., -4.0703, -0.8765, -0.5723],
         [-1.8838, -1.0596, -3.1367,  ...,  5.0430, -2.4863, -1.4717]],

        ...,

        [[-6.4023,  1.6152,  3.7852,  ..., -4.0781, -0.9023, -0.5664],
         [-6.4023,  1.6152,  3.7852,  ..., -4.0781, -0.9023, -0.5659],
         [-1.8711, -1.0508, -3.1465,  ...,  5.0547, -2.5039, -1.4639]],

        [[-6.4023,  1.6182,  3.7832,  ..., -4.0742, -0.9

Epoch 0:  13%|█▎        | 777/5879 [08:48<57:52,  1.47it/s, v_num=7]
 tensor([[[-6.4141,  1.6367,  3.7891,  ..., -4.0938, -0.9346, -0.5640],
         [-6.4141,  1.6367,  3.7891,  ..., -4.0938, -0.9346, -0.5640],
         [-1.8574, -0.9810, -3.1875,  ...,  5.0586, -2.5801, -1.4326]],

        [[-6.3906,  1.6299,  3.7559,  ..., -4.1016, -0.9531, -0.5415],
         [-6.3906,  1.6299,  3.7559,  ..., -4.1016, -0.9526, -0.5415],
         [-1.8652, -0.9780, -3.1855,  ...,  5.0625, -2.5781, -1.4365]],

        [[-6.4531,  1.6338,  3.8477,  ..., -4.0781, -0.8979, -0.5918],
         [-6.4531,  1.6338,  3.8477,  ..., -4.0781, -0.8979, -0.5923],
         [-1.9180, -1.0430, -3.1504,  ...,  5.0352, -2.4863, -1.4814]],

        ...,

        [[-6.4414,  1.6406,  3.8262,  ..., -4.0859, -0.9292, -0.5806],
         [-6.4414,  1.6406,  3.8262,  ..., -4.0859, -0.9292, -0.5811],
         [-1.8906, -1.0303, -3.1699,  ...,  5.0625, -2.5195, -1.4678]],

        [[-6.4414,  1.6406,  3.8262,  ..., -4.0859, -0.9

Epoch 0:  13%|█▎        | 783/5879 [08:50<57:31,  1.48it/s, v_num=7]
 tensor([[[-6.5039,  1.6240,  3.9004,  ..., -4.0547, -0.8535, -0.6470],
         [-6.5039,  1.6240,  3.9004,  ..., -4.0547, -0.8525, -0.6470],
         [-6.1641,  1.2412,  3.4453,  ..., -3.1543, -0.2246, -1.1973]],

        [[-6.4922,  1.6582,  3.8984,  ..., -4.0859, -0.9224, -0.6050],
         [-6.4922,  1.6582,  3.8984,  ..., -4.0859, -0.9224, -0.6050],
         [-6.5312,  1.6494,  3.9316,  ..., -4.0547, -0.9370, -0.6528]],

        [[-6.4922,  1.6611,  3.8906,  ..., -4.0898, -0.9404, -0.6045],
         [-6.4922,  1.6611,  3.8906,  ..., -4.0898, -0.9399, -0.6050],
         [-1.8896, -0.9692, -3.2148,  ...,  5.0859, -2.5781, -1.4502]],

        ...,

        [[-6.4727,  1.6602,  3.8594,  ..., -4.0977, -0.9526, -0.5938],
         [-6.4727,  1.6602,  3.8594,  ..., -4.0977, -0.9521, -0.5933],
         [-1.8740, -0.9502, -3.2168,  ...,  5.0703, -2.6035, -1.4336]],

        [[-6.4922,  1.6611,  3.8906,  ..., -4.0898, -0.9

Epoch 0:  13%|█▎        | 789/5879 [08:51<57:11,  1.48it/s, v_num=7]
 tensor([[[-6.1875,  1.5967,  3.5273,  ..., -4.2031, -1.0918, -0.3206],
         [-6.1953,  1.5996,  3.5371,  ..., -4.1992, -1.0889, -0.3279],
         [-1.4648, -1.1865, -3.2422,  ...,  4.8711, -2.6719, -1.3340]],

        [[ 0.4402, -6.6680, -0.9302,  ..., -1.5137, -0.9424,  0.8955],
         [ 0.4402, -6.6680, -0.9307,  ..., -1.5137, -0.9429,  0.8955],
         [ 0.4192, -6.7227, -0.8940,  ..., -1.5518, -0.9380,  0.9155]],

        [[-6.5352,  1.6816,  3.9395,  ..., -4.0977, -0.9580, -0.6177],
         [-6.5352,  1.6807,  3.9395,  ..., -4.0977, -0.9580, -0.6177],
         [-6.5703,  1.6846,  3.9844,  ..., -4.0742, -0.9668, -0.6509]],

        ...,

        [[ 0.4282, -6.6758, -0.9272,  ..., -1.5449, -0.9487,  0.9219],
         [ 0.4250, -6.6797, -0.9243,  ..., -1.5498, -0.9492,  0.9268],
         [ 0.4185, -6.7227, -0.8931,  ..., -1.5547, -0.9390,  0.9185]],

        [[-6.5391,  1.6787,  3.9434,  ..., -4.0977, -0.9

Epoch 0:  14%|█▎        | 795/5879 [08:53<56:50,  1.49it/s, v_num=7]
 tensor([[[-1.9297, -0.8892, -3.2637,  ...,  5.1016, -2.6172, -1.4170],
         [-1.9287, -0.8892, -3.2637,  ...,  5.1016, -2.6172, -1.4170],
         [-0.5498, -2.1113, -2.9551,  ...,  3.8613, -2.3418, -1.0605]],

        [[-6.5820,  1.6963,  3.9844,  ..., -4.1055, -0.9644, -0.6338],
         [-6.5781,  1.6973,  3.9844,  ..., -4.1055, -0.9644, -0.6333],
         [-6.6133,  1.6650,  4.0117,  ..., -4.0586, -0.9136, -0.6919]],

        [[-6.5742,  1.6963,  3.9805,  ..., -4.1055, -0.9658, -0.6304],
         [-6.5742,  1.6963,  3.9805,  ..., -4.1055, -0.9663, -0.6299],
         [-1.9141, -0.9502, -3.2715,  ...,  5.1133, -2.5938, -1.4521]],

        ...,

        [[-6.5156,  1.6855,  3.8848,  ..., -4.1250, -1.0088, -0.5864],
         [-6.5156,  1.6855,  3.8848,  ..., -4.1250, -1.0078, -0.5864],
         [-1.8730, -0.9097, -3.2852,  ...,  5.0938, -2.6523, -1.4111]],

        [[-1.9346, -0.8965, -3.2637,  ...,  5.1055, -2.6

Epoch 0:  14%|█▎        | 801/5879 [09:06<57:44,  1.47it/s, v_num=7]
 tensor([[[-6.6250,  1.7109,  4.0195,  ..., -4.1133, -0.9756, -0.6489],
         [-6.6250,  1.7119,  4.0195,  ..., -4.1133, -0.9761, -0.6494],
         [-6.6602,  1.7090,  4.0508,  ..., -4.0781, -0.9990, -0.6924]],

        [[-6.6211,  1.7168,  4.0156,  ..., -4.1172, -0.9888, -0.6426],
         [-6.6211,  1.7168,  4.0156,  ..., -4.1172, -0.9883, -0.6426],
         [-1.9121, -0.9014, -3.3164,  ...,  5.1445, -2.6426, -1.4287]],

        [[-6.6250,  1.7148,  4.0195,  ..., -4.1133, -0.9785, -0.6470],
         [-6.6250,  1.7148,  4.0195,  ..., -4.1133, -0.9785, -0.6470],
         [-6.6562,  1.7188,  4.0508,  ..., -4.0859, -1.0146, -0.6816]],

        ...,

        [[-6.6250,  1.7139,  4.0156,  ..., -4.1133, -0.9805, -0.6479],
         [-6.6250,  1.7139,  4.0156,  ..., -4.1133, -0.9805, -0.6479],
         [-1.9375, -0.9365, -3.2930,  ...,  5.1211, -2.5938, -1.4512]],

        [[-1.4492, -1.1396, -3.3145,  ...,  4.9141, -2.6

Epoch 0:  14%|█▎        | 807/5879 [09:07<57:23,  1.47it/s, v_num=7]
 tensor([[[-6.6484,  1.7383,  4.0234,  ..., -4.1289, -1.0244, -0.6514],
         [-6.6484,  1.7393,  4.0234,  ..., -4.1289, -1.0244, -0.6514],
         [-1.9180, -0.8706, -3.3496,  ...,  5.1641, -2.6738, -1.4180]],

        [[-6.6055,  1.7256,  3.9629,  ..., -4.1406, -1.0420, -0.6157],
         [-6.6055,  1.7256,  3.9629,  ..., -4.1406, -1.0420, -0.6157],
         [-1.9170, -0.8662, -3.3516,  ...,  5.1641, -2.6758, -1.4150]],

        [[-6.6680,  1.7324,  4.0586,  ..., -4.1211, -0.9922, -0.6626],
         [-6.6680,  1.7324,  4.0547,  ..., -4.1211, -0.9922, -0.6626],
         [-6.6992,  1.6924,  4.0781,  ..., -4.0664, -0.9316, -0.7363]],

        ...,

        [[ 0.4138, -6.8125, -0.9214,  ..., -1.6123, -0.9194,  0.9697],
         [ 0.4087, -6.8125, -0.9189,  ..., -1.6162, -0.9229,  0.9736],
         [-1.7461, -0.9380, -3.3613,  ...,  5.0859, -2.7324, -1.3643]],

        [[ 0.5229, -6.6250, -1.0000,  ..., -1.2842, -0.8

Epoch 0:  14%|█▍        | 813/5879 [09:09<57:04,  1.48it/s, v_num=7]
 tensor([[[-6.6641,  1.7539,  4.0352,  ..., -4.1406, -1.0498, -0.6479],
         [-6.6641,  1.7539,  4.0352,  ..., -4.1406, -1.0488, -0.6484],
         [-1.9209, -0.8457, -3.3867,  ...,  5.1836, -2.7012, -1.4092]],

        [[ 0.3857, -6.8086, -0.9272,  ..., -1.6836, -0.9336,  1.0322],
         [ 0.3833, -6.8008, -0.9297,  ..., -1.6904, -0.9370,  1.0381],
         [-0.8140, -1.7441, -3.1914,  ...,  4.2734, -2.5430, -1.1387]],

        [[-6.7031,  1.7520,  4.0938,  ..., -4.1289, -1.0137, -0.6724],
         [-6.7031,  1.7510,  4.0938,  ..., -4.1289, -1.0127, -0.6724],
         [-6.7383,  1.7588,  4.1250,  ..., -4.1016, -1.0508, -0.7085]],

        ...,

        [[-6.6914,  1.7607,  4.0742,  ..., -4.1328, -1.0400, -0.6670],
         [-6.6953,  1.7607,  4.0742,  ..., -4.1328, -1.0410, -0.6675],
         [-1.9121, -0.8423, -3.3867,  ...,  5.1758, -2.7109, -1.4033]],

        [[-6.7031,  1.7598,  4.0898,  ..., -4.1367, -1.0

Epoch 0:  14%|█▍        | 819/5879 [09:11<56:47,  1.48it/s, v_num=7]
 tensor([[[-6.7305,  1.7793,  4.1133,  ..., -4.1445, -1.0615, -0.6797],
         [-6.7305,  1.7793,  4.1133,  ..., -4.1445, -1.0615, -0.6797],
         [-1.9180, -0.8228, -3.4199,  ...,  5.1992, -2.7344, -1.3994]],

        [[-6.7383,  1.7773,  4.1211,  ..., -4.1445, -1.0586, -0.6836],
         [-6.7383,  1.7783,  4.1211,  ..., -4.1445, -1.0576, -0.6836],
         [-1.9287, -0.8389, -3.4180,  ...,  5.2109, -2.7129, -1.4111]],

        [[ 0.4417, -6.8828, -0.8999,  ..., -1.5859, -0.8633,  0.9263],
         [ 0.4419, -6.8828, -0.8999,  ..., -1.5859, -0.8633,  0.9263],
         [ 0.4177, -6.9375, -0.8657,  ..., -1.6348, -0.8604,  0.9600]],

        ...,

        [[-6.7344,  1.7793,  4.1211,  ..., -4.1445, -1.0576, -0.6821],
         [-6.7344,  1.7783,  4.1211,  ..., -4.1445, -1.0576, -0.6821],
         [-6.7773,  1.7871,  4.1680,  ..., -4.1172, -1.0732, -0.7153]],

        [[-6.7266,  1.7832,  4.1094,  ..., -4.1445, -1.0

Epoch 0:  14%|█▍        | 825/5879 [09:17<56:57,  1.48it/s, v_num=7]
 tensor([[[-6.4883,  1.7256,  3.8105,  ..., -4.2344, -1.1748, -0.4529],
         [-6.4883,  1.7266,  3.8125,  ..., -4.2344, -1.1748, -0.4534],
         [-1.5664, -1.0127, -3.4609,  ...,  5.0312, -2.8066, -1.3086]],

        [[-6.7617,  1.7988,  4.1445,  ..., -4.1523, -1.0840, -0.6914],
         [-6.7617,  1.7998,  4.1445,  ..., -4.1523, -1.0840, -0.6914],
         [-1.9248, -0.8022, -3.4531,  ...,  5.2188, -2.7559, -1.3945]],

        [[-1.9590, -0.7969, -3.4336,  ...,  5.2148, -2.7227, -1.3906],
         [-1.9590, -0.7969, -3.4316,  ...,  5.2148, -2.7227, -1.3916],
         [-1.7910, -0.8608, -3.4629,  ...,  5.1562, -2.7949, -1.3535]],

        ...,

        [[-1.3965, -1.1143, -3.4434,  ...,  4.9375, -2.7676, -1.2363],
         [-1.4111, -1.1025, -3.4453,  ...,  4.9453, -2.7695, -1.2393],
         [ 0.4521, -6.9258, -0.8823,  ..., -1.5459, -0.8403,  0.8804]],

        [[-6.7773,  1.7979,  4.1680,  ..., -4.1484, -1.0

Epoch 0:  14%|█▍        | 831/5879 [09:19<56:37,  1.49it/s, v_num=7]
 tensor([[[-6.8086,  1.8154,  4.2070,  ..., -4.1562, -1.0928, -0.7090],
         [-6.8086,  1.8154,  4.2070,  ..., -4.1562, -1.0928, -0.7090],
         [-1.9375, -0.7896, -3.4883,  ...,  5.2422, -2.7695, -1.3955]],

        [[-6.8086,  1.8154,  4.2031,  ..., -4.1562, -1.0996, -0.7100],
         [-6.8086,  1.8154,  4.2031,  ..., -4.1562, -1.0996, -0.7100],
         [-6.8398,  1.8154,  4.2305,  ..., -4.1406, -1.1250, -0.7339]],

        [[-6.8086,  1.8145,  4.2070,  ..., -4.1562, -1.0977, -0.7075],
         [-6.8086,  1.8145,  4.2031,  ..., -4.1562, -1.0967, -0.7075],
         [-1.9365, -0.7935, -3.4883,  ...,  5.2461, -2.7676, -1.3975]],

        ...,

        [[ 0.3887, -6.9609, -0.8906,  ..., -1.7012, -0.8706,  1.0273],
         [ 0.3879, -6.9570, -0.8916,  ..., -1.7041, -0.8721,  1.0293],
         [ 0.6216, -6.2383, -1.1592,  ..., -0.5562, -0.9092,  0.3462]],

        [[-6.7930,  1.8164,  4.1797,  ..., -4.1602, -1.1

Epoch 0:  14%|█▍        | 837/5879 [09:20<56:18,  1.49it/s, v_num=7]
 tensor([[[-6.8438,  1.8301,  4.2500,  ..., -4.1602, -1.1172, -0.7207],
         [-6.8438,  1.8301,  4.2500,  ..., -4.1602, -1.1172, -0.7212],
         [-6.8789,  1.8398,  4.2852,  ..., -4.1367, -1.1436, -0.7539]],

        [[-6.8398,  1.7998,  4.2500,  ..., -4.1367, -1.0498, -0.7520],
         [-6.8398,  1.7998,  4.2500,  ..., -4.1367, -1.0498, -0.7520],
         [-1.9629, -0.8076, -3.5039,  ...,  5.2461, -2.7461, -1.4131]],

        [[-6.8477,  1.8242,  4.2539,  ..., -4.1602, -1.1064, -0.7227],
         [-6.8477,  1.8242,  4.2539,  ..., -4.1602, -1.1064, -0.7231],
         [-6.8828,  1.8271,  4.2969,  ..., -4.1367, -1.1113, -0.7617]],

        ...,

        [[-6.8281,  1.8301,  4.2227,  ..., -4.1680, -1.1328, -0.7163],
         [-6.8320,  1.8301,  4.2227,  ..., -4.1680, -1.1328, -0.7163],
         [-1.9316, -0.7598, -3.5234,  ...,  5.2500, -2.8086, -1.3799]],

        [[-1.9619, -0.7534, -3.5039,  ...,  5.2461, -2.7

Epoch 0:  14%|█▍        | 843/5879 [09:22<55:59,  1.50it/s, v_num=7]
 tensor([[[-6.7734,  1.8193,  4.1484,  ..., -4.1914, -1.1787, -0.6572],
         [-6.7734,  1.8193,  4.1484,  ..., -4.1914, -1.1787, -0.6572],
         [-1.9502, -0.7412, -3.5547,  ...,  5.2695, -2.8203, -1.3789]],

        [[-6.8516,  1.8486,  4.2578,  ..., -4.1719, -1.1553, -0.7246],
         [-6.8555,  1.8506,  4.2617,  ..., -4.1719, -1.1523, -0.7280],
         [-1.9385, -0.7378, -3.5547,  ...,  5.2617, -2.8301, -1.3730]],

        [[ 0.3435, -3.4004, -2.4395,  ...,  2.2617, -1.7402, -0.5396],
         [ 0.3440, -3.4004, -2.4375,  ...,  2.2598, -1.7402, -0.5386],
         [ 0.4397, -7.0469, -0.8608,  ..., -1.6133, -0.7979,  0.9160]],

        ...,

        [[-1.9727, -0.7383, -3.5371,  ...,  5.2617, -2.7949, -1.3750],
         [-1.9727, -0.7393, -3.5371,  ...,  5.2617, -2.7930, -1.3750],
         [-1.4951, -1.0205, -3.5488,  ...,  5.0078, -2.8594, -1.2695]],

        [[-6.8750,  1.8486,  4.2930,  ..., -4.1680, -1.1

Epoch 0:  14%|█▍        | 849/5879 [09:29<56:14,  1.49it/s, v_num=7]
 tensor([[[-1.9268, -0.7163, -3.5723,  ...,  5.2383, -2.8438, -1.3379],
         [-1.9297, -0.7148, -3.5723,  ...,  5.2422, -2.8438, -1.3389],
         [ 0.5269, -6.8555, -0.9365,  ..., -1.2764, -0.7920,  0.6860]],

        [[-1.9707, -0.7075, -3.5703,  ...,  5.2656, -2.8281, -1.3604],
         [-1.9707, -0.7070, -3.5703,  ...,  5.2656, -2.8281, -1.3604],
         [ 0.6118, -5.1758, -1.6133,  ...,  0.5371, -1.1621, -0.0451]],

        [[-6.9023,  1.8633,  4.3398,  ..., -4.1719, -1.1592, -0.7480],
         [-6.9023,  1.8633,  4.3398,  ..., -4.1719, -1.1592, -0.7485],
         [-1.9590, -0.7251, -3.5879,  ...,  5.2852, -2.8379, -1.3789]],

        ...,

        [[-6.8945,  1.8682,  4.3242,  ..., -4.1719, -1.1709, -0.7476],
         [-6.8945,  1.8682,  4.3242,  ..., -4.1719, -1.1719, -0.7471],
         [-1.9541, -0.7178, -3.5859,  ...,  5.2812, -2.8457, -1.3750]],

        [[-6.7148,  1.8125,  4.0898,  ..., -4.2188, -1.2

Epoch 0:  15%|█▍        | 855/5879 [09:31<55:55,  1.50it/s, v_num=7]
 tensor([[[-6.9297,  1.8740,  4.3906,  ..., -4.1719, -1.1787, -0.7632],
         [-6.9297,  1.8740,  4.3906,  ..., -4.1719, -1.1787, -0.7632],
         [-1.9717, -0.7007, -3.6152,  ...,  5.2969, -2.8594, -1.3779]],

        [[-6.9258,  1.8809,  4.3828,  ..., -4.1719, -1.1914, -0.7622],
         [-6.9258,  1.8809,  4.3828,  ..., -4.1719, -1.1904, -0.7617],
         [-1.9707, -0.6978, -3.6152,  ...,  5.2930, -2.8633, -1.3760]],

        [[-6.9258,  1.8779,  4.3789,  ..., -4.1719, -1.1934, -0.7622],
         [-6.9258,  1.8779,  4.3789,  ..., -4.1719, -1.1934, -0.7617],
         [-1.9707, -0.7002, -3.6172,  ...,  5.2930, -2.8613, -1.3770]],

        ...,

        [[-6.7891,  1.8447,  4.1992,  ..., -4.2031, -1.2324, -0.6523],
         [-6.7891,  1.8438,  4.1953,  ..., -4.2031, -1.2324, -0.6509],
         [-1.9053, -0.7090, -3.6211,  ...,  5.2500, -2.8965, -1.3477]],

        [[-6.9219,  1.8789,  4.3711,  ..., -4.1719, -1.1

Epoch 0:  15%|█▍        | 861/5879 [09:32<55:37,  1.50it/s, v_num=7]
 tensor([[[-0.4238, -4.1602, -1.1348,  ..., -3.0684, -1.3076,  2.1426],
         [-0.4341, -4.1289, -1.1318,  ..., -3.0762, -1.3105,  2.1465],
         [-1.9590, -0.6646, -3.6387,  ...,  5.2734, -2.9062, -1.3555]],

        [[-6.9258,  1.8906,  4.3984,  ..., -4.1680, -1.2266, -0.7666],
         [-6.9258,  1.8906,  4.3984,  ..., -4.1680, -1.2266, -0.7671],
         [-1.9893, -0.6636, -3.6367,  ...,  5.2969, -2.8867, -1.3711]],

        [[ 0.5874, -6.6055, -1.0430,  ..., -1.0156, -0.8052,  0.5620],
         [ 0.5898, -6.5938, -1.0479,  ..., -1.0029, -0.8066,  0.5557],
         [ 0.3901, -7.1758, -0.8032,  ..., -1.7607, -0.7573,  1.0137]],

        ...,

        [[-2.0078, -0.6699, -3.6230,  ...,  5.2930, -2.8613, -1.3711],
         [-2.0078, -0.6699, -3.6230,  ...,  5.2930, -2.8613, -1.3711],
         [-1.6934, -0.8203, -3.6465,  ...,  5.1367, -2.9375, -1.2930]],

        [[-6.9531,  1.8945,  4.4336,  ..., -4.1680, -1.2

Epoch 0:  15%|█▍        | 867/5879 [09:34<55:18,  1.51it/s, v_num=7]
 tensor([[[-6.9688,  1.9062,  4.4883,  ..., -4.1641, -1.2363, -0.7939],
         [-6.9688,  1.9062,  4.4883,  ..., -4.1641, -1.2363, -0.7939],
         [-2.0215, -0.6782, -3.6445,  ...,  5.3008, -2.8613, -1.3916]],

        [[-6.5742,  1.7900,  3.9941,  ..., -4.2656, -1.3486, -0.4487],
         [-6.5703,  1.7891,  3.9883,  ..., -4.2656, -1.3496, -0.4438],
         [-1.9531, -0.6470, -3.6621,  ...,  5.2773, -2.9336, -1.3447]],

        [[-6.9688,  1.9023,  4.4883,  ..., -4.1641, -1.2305, -0.7939],
         [-6.9688,  1.9023,  4.4883,  ..., -4.1641, -1.2314, -0.7930],
         [-7.0156,  1.9082,  4.5312,  ..., -4.1367, -1.2451, -0.8340]],

        ...,

        [[-2.0156, -0.6299, -3.6445,  ...,  5.2930, -2.8945, -1.3545],
         [-2.0156, -0.6299, -3.6445,  ...,  5.2930, -2.8945, -1.3545],
         [ 0.3013, -3.5996, -2.3535,  ...,  1.9990, -1.7012, -0.4912]],

        [[-6.7188,  1.8389,  4.1602,  ..., -4.2344, -1.3

Epoch 0:  15%|█▍        | 873/5879 [09:41<55:32,  1.50it/s, v_num=7]
 tensor([[[ 0.3442, -7.2266, -0.8047,  ..., -1.8682, -0.7554,  1.1055],
         [ 0.3433, -7.2266, -0.8047,  ..., -1.8691, -0.7563,  1.1074],
         [-1.3652, -1.0625, -3.6426,  ...,  4.8828, -2.9141, -1.2061]],

        [[-5.1758,  1.2783,  2.6328,  ..., -4.3125, -1.5391,  0.5474],
         [-5.1953,  1.2861,  2.6504,  ..., -4.3125, -1.5381,  0.5356],
         [-1.7715, -0.7432, -3.6934,  ...,  5.1914, -2.9766, -1.2910]],

        [[-6.9531,  1.9121,  4.4766,  ..., -4.1680, -1.2852, -0.7871],
         [-6.9531,  1.9121,  4.4766,  ..., -4.1680, -1.2852, -0.7871],
         [-2.0078, -0.6167, -3.6855,  ...,  5.3203, -2.9355, -1.3535]],

        ...,

        [[-6.9766,  1.9229,  4.5156,  ..., -4.1641, -1.2744, -0.8022],
         [-6.9766,  1.9229,  4.5156,  ..., -4.1641, -1.2734, -0.8022],
         [-2.0156, -0.6196, -3.6836,  ...,  5.3242, -2.9297, -1.3584]],

        [[-6.8750,  1.8906,  4.3711,  ..., -4.1875, -1.3

Epoch 0:  15%|█▍        | 879/5879 [09:48<55:45,  1.49it/s, v_num=7]
 tensor([[[-1.3066, -1.0762, -3.6680,  ...,  4.8711, -2.9043, -1.1543],
         [-1.3086, -1.0742, -3.6680,  ...,  4.8711, -2.9043, -1.1543],
         [ 0.3713, -7.2461, -0.7856,  ..., -1.8203, -0.7124,  1.0225]],

        [[-6.9766,  1.9268,  4.5273,  ..., -4.1719, -1.3135, -0.8052],
         [-6.9766,  1.9268,  4.5273,  ..., -4.1719, -1.3135, -0.8052],
         [-2.0137, -0.6094, -3.7129,  ...,  5.3398, -2.9512, -1.3477]],

        [[-7.0078,  1.9307,  4.5781,  ..., -4.1680, -1.2939, -0.8198],
         [-7.0078,  1.9307,  4.5781,  ..., -4.1680, -1.2939, -0.8198],
         [-2.0195, -0.6123, -3.7148,  ...,  5.3477, -2.9473, -1.3516]],

        ...,

        [[ 0.2605, -7.0000, -0.8950,  ..., -2.1094, -0.8408,  1.3379],
         [ 0.2625, -7.0078, -0.8926,  ..., -2.1035, -0.8379,  1.3320],
         [-1.9893, -0.6128, -3.7148,  ...,  5.3203, -2.9629, -1.3340]],

        [[-6.9219,  1.9053,  4.4570,  ..., -4.1836, -1.3

Epoch 0:  15%|█▌        | 885/5879 [09:50<55:34,  1.50it/s, v_num=7]
 tensor([[[-7.0234,  1.9346,  4.6250,  ..., -4.1680, -1.3086, -0.8359],
         [-7.0234,  1.9346,  4.6250,  ..., -4.1680, -1.3096, -0.8359],
         [-2.0254, -0.6118, -3.7480,  ...,  5.3672, -2.9570, -1.3477]],

        [[-7.0039,  1.9414,  4.5938,  ..., -4.1680, -1.3291, -0.8281],
         [-7.0039,  1.9414,  4.5977,  ..., -4.1680, -1.3301, -0.8281],
         [-2.0156, -0.6021, -3.7441,  ...,  5.3516, -2.9707, -1.3379]],

        [[-6.9531,  1.9209,  4.5156,  ..., -4.1836, -1.3477, -0.7837],
         [-6.9492,  1.9209,  4.5156,  ..., -4.1836, -1.3486, -0.7827],
         [-1.9971, -0.6055, -3.7461,  ...,  5.3359, -2.9805, -1.3291]],

        ...,

        [[-2.0430, -0.6011, -3.7285,  ...,  5.3555, -2.9414, -1.3369],
         [-2.0430, -0.6011, -3.7285,  ...,  5.3555, -2.9414, -1.3369],
         [-1.0752, -1.3379, -3.6035,  ...,  4.5820, -2.8281, -1.1230]],

        [[-7.0234,  1.9365,  4.6250,  ..., -4.1719, -1.3

Epoch 0:  15%|█▌        | 891/5879 [09:52<55:18,  1.50it/s, v_num=7]
 tensor([[[-1.9268, -0.6372, -3.7734,  ...,  5.3008, -2.9922, -1.2842],
         [-1.9277, -0.6367, -3.7734,  ...,  5.3008, -2.9902, -1.2842],
         [ 0.5195, -6.8047, -0.9697,  ..., -1.2139, -0.7207,  0.6113]],

        [[-2.0469, -0.6094, -3.7637,  ...,  5.3750, -2.9473, -1.3359],
         [-2.0469, -0.6094, -3.7637,  ...,  5.3750, -2.9473, -1.3359],
         [-1.8203, -0.7036, -3.7891,  ...,  5.2578, -3.0176, -1.2764]],

        [[-4.0664,  0.6768,  1.6738,  ..., -4.1875, -1.5840,  1.1758],
         [-4.2734,  0.8032,  1.8594,  ..., -4.2188, -1.5879,  1.0684],
         [-1.9180, -0.6465, -3.7891,  ...,  5.3125, -3.0098, -1.3008]],

        ...,

        [[-6.7227,  1.8545,  4.2539,  ..., -4.2539, -1.4248, -0.5679],
         [-6.7227,  1.8545,  4.2539,  ..., -4.2539, -1.4248, -0.5679],
         [-1.9668, -0.6187, -3.7852,  ...,  5.3359, -3.0020, -1.3125]],

        [[-7.0391,  1.9463,  4.6641,  ..., -4.1758, -1.3

Epoch 0:  15%|█▌        | 897/5879 [09:54<55:00,  1.51it/s, v_num=7]
 tensor([[[-1.2246, -2.4219, -0.7085,  ..., -3.5332, -1.3809,  2.2383],
         [-1.2334, -2.4043, -0.7021,  ..., -3.5371, -1.3828,  2.2363],
         [-1.0029, -1.4121, -3.6445,  ...,  4.5117, -2.8125, -1.0918]],

        [[-2.0391, -0.5986, -3.8027,  ...,  5.3867, -2.9688, -1.3184],
         [-2.0391, -0.5981, -3.8027,  ...,  5.3867, -2.9688, -1.3184],
         [ 0.3784, -3.8867, -2.2715,  ...,  1.6377, -1.5635, -0.3862]],

        [[-7.0625,  1.9502,  4.6953,  ..., -4.1875, -1.3613, -0.8569],
         [-7.0586,  1.9502,  4.6953,  ..., -4.1875, -1.3613, -0.8569],
         [-2.0195, -0.6079, -3.8262,  ...,  5.4023, -2.9883, -1.3301]],

        ...,

        [[-7.0586,  1.9541,  4.6953,  ..., -4.1836, -1.3633, -0.8584],
         [-7.0586,  1.9551,  4.6914,  ..., -4.1836, -1.3643, -0.8574],
         [-2.0215, -0.6094, -3.8262,  ...,  5.4023, -2.9863, -1.3311]],

        [[-7.0273,  1.9443,  4.6367,  ..., -4.1875, -1.3

Epoch 0:  15%|█▌        | 903/5879 [09:59<55:04,  1.51it/s, v_num=7]
 tensor([[[-1.5459, -0.8892, -3.8496,  ...,  5.1289, -3.0098, -1.1768],
         [-1.5625, -0.8770, -3.8516,  ...,  5.1406, -3.0117, -1.1797],
         [ 0.3997, -7.2773, -0.8286,  ..., -1.7734, -0.6255,  0.9243]],

        [[-7.0312,  1.9395,  4.6055,  ..., -4.2188, -1.3994, -0.8179],
         [-7.0312,  1.9385,  4.6055,  ..., -4.2188, -1.3994, -0.8179],
         [-1.9932, -0.6128, -3.8672,  ...,  5.4219, -3.0117, -1.3076]],

        [[-7.0938,  1.9590,  4.6992,  ..., -4.2070, -1.3799, -0.8618],
         [-7.0938,  1.9590,  4.6992,  ..., -4.2070, -1.3799, -0.8623],
         [-1.9980, -0.6172, -3.8691,  ...,  5.4258, -3.0059, -1.3115]],

        ...,

        [[-7.0977,  1.9541,  4.7109,  ..., -4.2070, -1.3662, -0.8643],
         [-7.0977,  1.9541,  4.7109,  ..., -4.2070, -1.3662, -0.8643],
         [-2.0000, -0.6338, -3.8691,  ...,  5.4336, -2.9902, -1.3213]],

        [[-7.0859,  1.9551,  4.6836,  ..., -4.2070, -1.3

Epoch 0:  15%|█▌        | 909/5879 [10:10<55:35,  1.49it/s, v_num=7]
 tensor([[[ 0.4465, -7.1719, -0.9336,  ..., -1.7207, -0.6387,  0.9248],
         [ 0.3921, -7.2891, -0.8823,  ..., -1.8740, -0.6294,  1.0322],
         [-1.8721, -0.6724, -3.9082,  ...,  5.3945, -3.0488, -1.2578]],

        [[ 0.6084, -6.5586, -1.1484,  ..., -0.9146, -0.7192,  0.4683],
         [ 0.6089, -6.5547, -1.1504,  ..., -0.9106, -0.7197,  0.4663],
         [ 0.3301, -7.4648, -0.7837,  ..., -1.9785, -0.5947,  1.0996]],

        [[-7.1406,  1.9551,  4.7188,  ..., -4.2266, -1.3398, -0.8804],
         [-7.1406,  1.9551,  4.7148,  ..., -4.2266, -1.3428, -0.8784],
         [-7.1836,  1.9531,  4.7578,  ..., -4.1992, -1.3457, -0.9224]],

        ...,

        [[-7.1484,  1.9678,  4.7148,  ..., -4.2305, -1.3721, -0.8774],
         [-7.1484,  1.9678,  4.7148,  ..., -4.2305, -1.3721, -0.8774],
         [-1.9658, -0.6387, -3.9082,  ...,  5.4609, -3.0195, -1.2930]],

        [[-7.1445,  1.9668,  4.7148,  ..., -4.2305, -1.3

Epoch 0:  16%|█▌        | 915/5879 [10:11<55:18,  1.50it/s, v_num=7]
 tensor([[[-7.1758,  1.9824,  4.6992,  ..., -4.2539, -1.3643, -0.8853],
         [-7.1758,  1.9824,  4.6992,  ..., -4.2539, -1.3643, -0.8853],
         [-1.9268, -0.6553, -3.9375,  ...,  5.4805, -3.0410, -1.2744]],

        [[-7.1836,  1.9756,  4.7148,  ..., -4.2539, -1.3486, -0.8853],
         [-7.1836,  1.9756,  4.7148,  ..., -4.2539, -1.3486, -0.8853],
         [-1.9355, -0.6724, -3.9395,  ...,  5.4961, -3.0234, -1.2881]],

        [[-1.9443, -0.6538, -3.9238,  ...,  5.4805, -3.0215, -1.2705],
         [-1.9443, -0.6538, -3.9238,  ...,  5.4805, -3.0215, -1.2705],
         [-1.2373, -1.1719, -3.8633,  ...,  4.9414, -2.9746, -1.1260]],

        ...,

        [[-7.1836,  1.9775,  4.7148,  ..., -4.2539, -1.3525, -0.8848],
         [-7.1836,  1.9775,  4.7148,  ..., -4.2578, -1.3525, -0.8848],
         [-1.9346, -0.6646, -3.9395,  ...,  5.4961, -3.0293, -1.2842]],

        [[-1.9492, -0.6597, -3.9219,  ...,  5.4844, -3.0

Epoch 0:  16%|█▌        | 921/5879 [10:13<55:01,  1.50it/s, v_num=7]
 tensor([[[-5.2188,  1.2666,  2.6133,  ..., -4.3906, -1.5400,  0.5845],
         [-5.2422,  1.2773,  2.6367,  ..., -4.3945, -1.5400,  0.5698],
         [-1.8193, -0.7080, -3.9688,  ...,  5.4531, -3.0820, -1.2393]],

        [[-7.2188,  1.9971,  4.7148,  ..., -4.2852, -1.3545, -0.8921],
         [-7.2188,  1.9971,  4.7148,  ..., -4.2852, -1.3555, -0.8921],
         [-1.9033, -0.6772, -3.9707,  ...,  5.5195, -3.0566, -1.2715]],

        [[-7.2188,  1.9980,  4.7148,  ..., -4.2852, -1.3574, -0.8921],
         [-7.2188,  1.9980,  4.7148,  ..., -4.2852, -1.3584, -0.8921],
         [-1.9014, -0.6743, -3.9707,  ...,  5.5156, -3.0605, -1.2695]],

        ...,

        [[-7.0430,  1.9443,  4.4766,  ..., -4.3203, -1.3975, -0.7480],
         [-7.0273,  1.9404,  4.4609,  ..., -4.3242, -1.4004, -0.7349],
         [-1.8438, -0.6938, -3.9688,  ...,  5.4688, -3.0801, -1.2461]],

        [[-7.1016,  1.9600,  4.5508,  ..., -4.3047, -1.3

Epoch 0:  16%|█▌        | 927/5879 [10:14<54:44,  1.51it/s, v_num=7]
 tensor([[[-7.2500,  2.0156,  4.7305,  ..., -4.3164, -1.3555, -0.8975],
         [-7.2500,  2.0156,  4.7305,  ..., -4.3164, -1.3555, -0.8975],
         [-1.8789, -0.6914, -4.0078,  ...,  5.5430, -3.0840, -1.2666]],

        [[-3.7344,  0.4263,  1.2705,  ..., -4.1875, -1.4844,  1.3730],
         [-3.7891,  0.4634,  1.3174,  ..., -4.1953, -1.4873,  1.3467],
         [-1.1475, -1.2559, -3.9102,  ...,  4.9453, -3.0000, -1.1162]],

        [[-7.2578,  2.0137,  4.7344,  ..., -4.3164, -1.3594, -0.8989],
         [-7.2578,  2.0137,  4.7344,  ..., -4.3164, -1.3594, -0.8989],
         [-1.8779, -0.6929, -4.0078,  ...,  5.5430, -3.0840, -1.2676]],

        ...,

        [[-7.2578,  2.0117,  4.7383,  ..., -4.3164, -1.3506, -0.8960],
         [-7.2539,  2.0098,  4.7383,  ..., -4.3125, -1.3447, -0.8965],
         [-1.8877, -0.7207, -3.9980,  ...,  5.5430, -3.0527, -1.2852]],

        [[-7.2578,  2.0137,  4.7344,  ..., -4.3164, -1.3

Epoch 0:  16%|█▌        | 933/5879 [10:26<55:23,  1.49it/s, v_num=7]
 tensor([[[-7.2969,  2.0312,  4.7422,  ..., -4.3438, -1.3662, -0.9033],
         [-7.2969,  2.0293,  4.7422,  ..., -4.3438, -1.3652, -0.9038],
         [-1.8613, -0.7158, -4.0469,  ...,  5.5742, -3.1016, -1.2725]],

        [[-1.8682, -0.6968, -4.0312,  ...,  5.5547, -3.0996, -1.2539],
         [-1.8682, -0.6968, -4.0312,  ...,  5.5547, -3.0996, -1.2539],
         [ 0.4265, -3.8457, -2.4648,  ...,  1.9014, -1.6416, -0.4785]],

        [[-7.2461,  2.0156,  4.6758,  ..., -4.3438, -1.3701, -0.8789],
         [-7.2461,  2.0156,  4.6758,  ..., -4.3438, -1.3701, -0.8789],
         [-1.8438, -0.7002, -4.0469,  ...,  5.5508, -3.1250, -1.2549]],

        ...,

        [[ 0.3672, -7.5586, -0.8604,  ..., -1.9170, -0.5601,  1.1006],
         [-7.1562,  1.9863,  4.5586,  ..., -4.3672, -1.3877, -0.7983],
         [-1.8359, -0.7036, -4.0469,  ...,  5.5430, -3.1250, -1.2520]],

        [[-7.2969,  2.0312,  4.7578,  ..., -4.3438, -1.3

Epoch 0:  16%|█▌        | 939/5879 [10:28<55:07,  1.49it/s, v_num=7]
 tensor([[[-7.3047,  2.0449,  4.6992,  ..., -4.3594, -1.3711, -0.9038],
         [-7.3047,  2.0449,  4.6992,  ..., -4.3594, -1.3711, -0.9038],
         [-1.8262, -0.7085, -4.0859,  ...,  5.5703, -3.1543, -1.2529]],

        [[-7.3398,  2.0469,  4.7539,  ..., -4.3594, -1.3496, -0.9146],
         [-7.3398,  2.0469,  4.7539,  ..., -4.3594, -1.3496, -0.9146],
         [-1.8408, -0.7217, -4.0898,  ...,  5.5938, -3.1348, -1.2695]],

        [[-7.3242,  2.0547,  4.7305,  ..., -4.3555, -1.3682, -0.9194],
         [-7.3242,  2.0547,  4.7266,  ..., -4.3555, -1.3691, -0.9170],
         [-1.8213, -0.7100, -4.0859,  ...,  5.5664, -3.1562, -1.2510]],

        ...,

        [[-6.9297,  1.9199,  4.2383,  ..., -4.4414, -1.4395, -0.5688],
         [-7.0234,  1.9521,  4.3477,  ..., -4.4258, -1.4248, -0.6519],
         [-1.7422, -0.7559, -4.0859,  ...,  5.5117, -3.1699, -1.2295]],

        [[-7.3203,  2.0449,  4.7148,  ..., -4.3555, -1.3

Epoch 0:  16%|█▌        | 945/5879 [10:30<54:51,  1.50it/s, v_num=7]
 tensor([[[-1.8340, -0.7080, -4.1055,  ...,  5.5938, -3.1621, -1.2480],
         [-1.8340, -0.7080, -4.1055,  ...,  5.5938, -3.1641, -1.2480],
         [ 0.6709, -5.2930, -1.8125,  ...,  0.6377, -1.1182, -0.1075]],

        [[-7.3984,  2.0879,  4.7578,  ..., -4.3672, -1.3652, -0.9390],
         [-7.3984,  2.0879,  4.7578,  ..., -4.3672, -1.3652, -0.9395],
         [-1.8213, -0.7144, -4.1211,  ...,  5.6016, -3.1797, -1.2578]],

        [[-7.3281,  2.0664,  4.6602,  ..., -4.3672, -1.3799, -0.9014],
         [-7.3242,  2.0645,  4.6523,  ..., -4.3711, -1.3809, -0.8979],
         [-1.7891, -0.7231, -4.1211,  ...,  5.5742, -3.1914, -1.2441]],

        ...,

        [[-7.3984,  2.0820,  4.7539,  ..., -4.3672, -1.3652, -0.9360],
         [-7.3984,  2.0820,  4.7539,  ..., -4.3672, -1.3652, -0.9355],
         [-1.8252, -0.7236, -4.1250,  ...,  5.6133, -3.1699, -1.2656]],

        [[-7.3945,  2.0820,  4.7539,  ..., -4.3672, -1.3

Epoch 0:  16%|█▌        | 951/5879 [10:32<54:37,  1.50it/s, v_num=7]
 tensor([[[-1.8105, -0.7114, -4.1406,  ...,  5.6133, -3.1973, -1.2432],
         [-1.8115, -0.7119, -4.1406,  ...,  5.6133, -3.1973, -1.2432],
         [ 0.6650, -5.0898, -1.9365,  ...,  0.8833, -1.2129, -0.1613]],

        [[-1.8223, -0.7197, -4.1445,  ...,  5.6289, -3.1895, -1.2539],
         [-1.8213, -0.7197, -4.1406,  ...,  5.6289, -3.1875, -1.2539],
         [-1.7617, -0.7314, -4.1523,  ...,  5.5977, -3.2227, -1.2373]],

        [[-7.4375,  2.1113,  4.7500,  ..., -4.3633, -1.3584, -0.9463],
         [-7.4375,  2.1113,  4.7500,  ..., -4.3633, -1.3594, -0.9463],
         [-1.8086, -0.7285, -4.1602,  ...,  5.6406, -3.1992, -1.2656]],

        ...,

        [[-7.2969,  2.0742,  4.5547,  ..., -4.3867, -1.4004, -0.8457],
         [-7.2969,  2.0742,  4.5547,  ..., -4.3867, -1.4004, -0.8462],
         [-1.7305, -0.7534, -4.1562,  ...,  5.5742, -3.2285, -1.2305]],

        [[-7.2852,  2.0703,  4.5391,  ..., -4.3906, -1.4

Epoch 0:  16%|█▋        | 957/5879 [10:43<55:08,  1.49it/s, v_num=7]
 tensor([[[-2.0645, -0.5918, -3.9629,  ...,  5.3945, -3.0293, -1.3252],
         [-2.0664, -0.5903, -3.9609,  ...,  5.3906, -3.0273, -1.3252],
         [-1.9834, -0.6455, -4.0391,  ...,  5.4922, -3.0918, -1.3174]],

        [[-7.4922,  2.1602,  4.7578,  ..., -4.3633, -1.3896, -0.9644],
         [-7.4922,  2.1602,  4.7578,  ..., -4.3633, -1.3896, -0.9644],
         [-1.7939, -0.7236, -4.1875,  ...,  5.6680, -3.2324, -1.2617]],

        [[-7.4805,  2.1582,  4.7500,  ..., -4.3594, -1.3887, -0.9678],
         [-7.4844,  2.1602,  4.7539,  ..., -4.3633, -1.3887, -0.9678],
         [-1.7803, -0.7188, -4.1875,  ...,  5.6484, -3.2461, -1.2520]],

        ...,

        [[-7.4219,  2.1387,  4.6641,  ..., -4.3633, -1.3975, -0.9282],
         [-7.4219,  2.1387,  4.6641,  ..., -4.3633, -1.3975, -0.9297],
         [-1.7344, -0.7422, -4.1875,  ...,  5.6172, -3.2559, -1.2363]],

        [[-7.4844,  2.1523,  4.7617,  ..., -4.3633, -1.3

Epoch 0:  16%|█▋        | 963/5879 [10:44<54:52,  1.49it/s, v_num=7]
 tensor([[[-7.3828,  2.1484,  4.5742,  ..., -4.3750, -1.4248, -0.8672],
         [-7.3828,  2.1484,  4.5742,  ..., -4.3750, -1.4238, -0.8677],
         [-1.7578, -0.7212, -4.2148,  ...,  5.6758, -3.2754, -1.2500]],

        [[-6.0508,  1.7051,  3.1719,  ..., -4.4570, -1.5293,  0.1456],
         [-6.0430,  1.7021,  3.1641,  ..., -4.4570, -1.5293,  0.1506],
         [-1.5576, -0.8589, -4.2148,  ...,  5.5469, -3.2891, -1.2051]],

        [[-7.5156,  2.1836,  4.7656,  ..., -4.3477, -1.3633, -0.9761],
         [-7.5156,  2.1836,  4.7656,  ..., -4.3477, -1.3643, -0.9761],
         [-1.7734, -0.7251, -4.2227,  ...,  5.6992, -3.2637, -1.2637]],

        ...,

        [[-7.5117,  2.1836,  4.7656,  ..., -4.3477, -1.3623, -0.9766],
         [-7.5117,  2.1836,  4.7656,  ..., -4.3477, -1.3623, -0.9766],
         [-1.7754, -0.7310, -4.2227,  ...,  5.7031, -3.2578, -1.2676]],

        [[-7.5273,  2.1992,  4.7695,  ..., -4.3594, -1.4

Epoch 0:  16%|█▋        | 969/5879 [10:46<54:36,  1.50it/s, v_num=7]
 tensor([[[-7.5664,  2.2344,  4.7930,  ..., -4.3594, -1.4023, -0.9775],
         [-7.5664,  2.2344,  4.7930,  ..., -4.3594, -1.4023, -0.9775],
         [-1.7588, -0.7227, -4.2500,  ...,  5.7266, -3.2930, -1.2637]],

        [[-7.5703,  2.2266,  4.7969,  ..., -4.3477, -1.3857, -0.9854],
         [-7.5703,  2.2266,  4.7969,  ..., -4.3477, -1.3857, -0.9854],
         [-1.7773, -0.7490, -4.2344,  ...,  5.7148, -3.2539, -1.2852]],

        [[-7.5703,  2.2402,  4.7969,  ..., -4.3594, -1.4111, -0.9785],
         [-7.5703,  2.2402,  4.7969,  ..., -4.3594, -1.4111, -0.9785],
         [-1.7588, -0.7202, -4.2500,  ...,  5.7227, -3.2949, -1.2607]],

        ...,

        [[-7.5742,  2.2402,  4.8008,  ..., -4.3594, -1.4072, -0.9810],
         [-7.5703,  2.2402,  4.8008,  ..., -4.3594, -1.4082, -0.9810],
         [-1.7578, -0.7212, -4.2500,  ...,  5.7227, -3.2949, -1.2607]],

        [[-2.7812, -0.1410,  0.2983,  ..., -3.8613, -1.3

Epoch 0:  17%|█▋        | 975/5879 [10:48<54:23,  1.50it/s, v_num=7]
 tensor([[[-7.5820,  2.2695,  4.7773,  ..., -4.3555, -1.4297, -0.9775],
         [-7.5820,  2.2695,  4.7773,  ..., -4.3555, -1.4307, -0.9775],
         [-1.7285, -0.7295, -4.2773,  ...,  5.7305, -3.3262, -1.2559]],

        [[-7.6094,  2.2812,  4.8242,  ..., -4.3594, -1.4238, -0.9854],
         [-7.6094,  2.2793,  4.8242,  ..., -4.3594, -1.4238, -0.9858],
         [-1.7363, -0.7256, -4.2773,  ...,  5.7383, -3.3242, -1.2588]],

        [[-7.5664,  2.2637,  4.7578,  ..., -4.3555, -1.4316, -0.9688],
         [-7.5625,  2.2637,  4.7539,  ..., -4.3555, -1.4316, -0.9678],
         [-1.7275, -0.7305, -4.2773,  ...,  5.7305, -3.3262, -1.2549]],

        ...,

        [[-7.6055,  2.2656,  4.8242,  ..., -4.3477, -1.4014, -0.9878],
         [-7.6055,  2.2656,  4.8242,  ..., -4.3477, -1.4023, -0.9878],
         [-1.7500, -0.7344, -4.2812,  ...,  5.7617, -3.3086, -1.2725]],

        [[-7.5742,  2.2461,  4.7969,  ..., -4.3281, -1.3

Epoch 0:  17%|█▋        | 981/5879 [10:54<54:29,  1.50it/s, v_num=7]
 tensor([[[ 0.4927, -7.6953, -1.0771,  ..., -1.4297, -0.6519,  1.1221],
         [ 0.4927, -7.6914, -1.0771,  ..., -1.4297, -0.6519,  1.1221],
         [-0.6289, -1.7861, -3.9922,  ...,  4.6992, -3.0215, -1.0566]],

        [[-7.5078,  2.2617,  4.6641,  ..., -4.3750, -1.4600, -0.8901],
         [-7.5078,  2.2617,  4.6641,  ..., -4.3750, -1.4600, -0.8911],
         [-1.6807, -0.7598, -4.3086,  ...,  5.7305, -3.3535, -1.2480]],

        [[-5.5117,  1.5469,  2.6270,  ..., -4.3789, -1.5479,  0.5249],
         [-5.5273,  1.5527,  2.6406,  ..., -4.3789, -1.5479,  0.5176],
         [-1.5723, -0.8398, -4.3086,  ...,  5.6641, -3.3594, -1.2275]],

        ...,

        [[-7.4492,  2.2461,  4.5898,  ..., -4.3867, -1.4688, -0.8442],
         [-7.4531,  2.2461,  4.5938,  ..., -4.3867, -1.4678, -0.8457],
         [-1.7012, -0.7456, -4.3086,  ...,  5.7461, -3.3516, -1.2539]],

        [[-7.6406,  2.3145,  4.8555,  ..., -4.3633, -1.4

Epoch 0:  17%|█▋        | 987/5879 [10:56<54:12,  1.50it/s, v_num=7]
 tensor([[[-7.5938,  2.3105,  4.7617,  ..., -4.3711, -1.4658, -0.9375],
         [-7.6094,  2.3184,  4.7891,  ..., -4.3711, -1.4629, -0.9541],
         [-1.6738, -0.7646, -4.3398,  ...,  5.7578, -3.3750, -1.2529]],

        [[-7.6797,  2.3418,  4.8828,  ..., -4.3750, -1.4531, -0.9849],
         [-7.6758,  2.3418,  4.8828,  ..., -4.3750, -1.4531, -0.9844],
         [-1.7217, -0.7349, -4.3398,  ...,  5.7930, -3.3672, -1.2666]],

        [[-7.6680,  2.3301,  4.8750,  ..., -4.3711, -1.4326, -0.9829],
         [-7.6680,  2.3301,  4.8750,  ..., -4.3711, -1.4326, -0.9824],
         [-1.7314, -0.7339, -4.3398,  ...,  5.8008, -3.3613, -1.2715]],

        ...,

        [[-7.6680,  2.3340,  4.8789,  ..., -4.3711, -1.4395, -0.9805],
         [-7.6680,  2.3340,  4.8750,  ..., -4.3711, -1.4395, -0.9800],
         [-1.7266, -0.7339, -4.3398,  ...,  5.7969, -3.3652, -1.2686]],

        [[ 0.3179, -6.7617, -1.3350,  ..., -1.9561, -0.8

Epoch 0:  17%|█▋        | 993/5879 [10:57<53:56,  1.51it/s, v_num=7]
 tensor([[[-7.6523,  2.3457,  4.8281,  ..., -4.3828, -1.4824, -0.9619],
         [-7.6523,  2.3457,  4.8281,  ..., -4.3828, -1.4824, -0.9614],
         [-1.7002, -0.7505, -4.3672,  ...,  5.7969, -3.3906, -1.2646]],

        [[ 0.5210, -7.8125, -1.0713,  ..., -1.3232, -0.6548,  1.0674],
         [ 0.5205, -7.8125, -1.0713,  ..., -1.3232, -0.6548,  1.0674],
         [ 0.7671, -6.4922, -1.5322,  ...,  0.1018, -0.9609,  0.1420]],

        [[-7.7031,  2.3672,  4.9102,  ..., -4.3867, -1.4697, -0.9888],
         [-7.7031,  2.3672,  4.9102,  ..., -4.3867, -1.4697, -0.9888],
         [-1.7070, -0.7456, -4.3711,  ...,  5.8047, -3.3887, -1.2676]],

        ...,

        [[-7.7070,  2.3633,  4.9141,  ..., -4.3867, -1.4609, -0.9824],
         [-7.7070,  2.3633,  4.9141,  ..., -4.3867, -1.4609, -0.9824],
         [-1.7246, -0.7383, -4.3711,  ...,  5.8242, -3.3828, -1.2754]],

        [[-7.6914,  2.3535,  4.9023,  ..., -4.3789, -1.4

Epoch 0:  17%|█▋        | 999/5879 [10:59<53:40,  1.52it/s, v_num=7]
 tensor([[[-7.7148,  2.3633,  4.9336,  ..., -4.3750, -1.4316, -1.0049],
         [-7.7109,  2.3633,  4.9336,  ..., -4.3711, -1.4316, -1.0049],
         [-1.7236, -0.7617, -4.4023,  ...,  5.8516, -3.3848, -1.2910]],

        [[-7.7188,  2.3652,  4.9375,  ..., -4.3750, -1.4375, -1.0088],
         [-7.7188,  2.3672,  4.9375,  ..., -4.3750, -1.4375, -1.0088],
         [-1.7246, -0.7490, -4.4023,  ...,  5.8516, -3.3945, -1.2842]],

        [[-7.7383,  2.3887,  4.9531,  ..., -4.3984, -1.4766, -0.9902],
         [-7.7383,  2.3887,  4.9531,  ..., -4.4023, -1.4775, -0.9902],
         [-1.7207, -0.7461, -4.3984,  ...,  5.8359, -3.4023, -1.2773]],

        ...,

        [[-7.7148,  2.3633,  4.9336,  ..., -4.3750, -1.4326, -1.0059],
         [-7.7148,  2.3633,  4.9336,  ..., -4.3750, -1.4316, -1.0049],
         [-1.7246, -0.7529, -4.4023,  ...,  5.8516, -3.3945, -1.2861]],

        [[-7.7344,  2.3906,  4.9531,  ..., -4.3984, -1.4

Epoch 0:  17%|█▋        | 1005/5879 [11:06<53:51,  1.51it/s, v_num=7]
 tensor([[[-7.7617,  2.4141,  4.9922,  ..., -4.4102, -1.4951, -1.0000],
         [-7.7617,  2.4141,  4.9922,  ..., -4.4102, -1.4951, -1.0000],
         [-1.7119, -0.7505, -4.4297,  ...,  5.8516, -3.4277, -1.2773]],

        [[-5.5312,  1.6152,  2.6660,  ..., -4.3867, -1.5840,  0.5664],
         [-5.5352,  1.6172,  2.6680,  ..., -4.3867, -1.5840,  0.5645],
         [-1.3105, -1.0928, -4.3867,  ...,  5.5508, -3.4004, -1.2070]],

        [[-7.7461,  2.4023,  4.9805,  ..., -4.4062, -1.4736, -0.9927],
         [-7.7461,  2.4023,  4.9805,  ..., -4.4062, -1.4736, -0.9927],
         [-1.7041, -0.7534, -4.4297,  ...,  5.8438, -3.4297, -1.2734]],

        ...,

        [[-7.7656,  2.4082,  4.9922,  ..., -4.4141, -1.4902, -0.9980],
         [-7.7656,  2.4082,  4.9922,  ..., -4.4141, -1.4902, -0.9976],
         [-1.7207, -0.7480, -4.4297,  ...,  5.8633, -3.4219, -1.2812]],

        [[-7.7578,  2.4023,  4.9883,  ..., -4.4062, -1.

Epoch 0:  17%|█▋        | 1011/5879 [11:07<53:36,  1.51it/s, v_num=7]
 tensor([[[-7.7812,  2.4238,  5.0312,  ..., -4.4180, -1.4922, -1.0068],
         [-7.7812,  2.4238,  5.0312,  ..., -4.4180, -1.4922, -1.0068],
         [-1.7207, -0.7554, -4.4648,  ...,  5.8789, -3.4395, -1.2852]],

        [[-7.7852,  2.4258,  5.0352,  ..., -4.4180, -1.4941, -1.0078],
         [-7.7812,  2.4258,  5.0352,  ..., -4.4180, -1.4941, -1.0078],
         [-1.7207, -0.7563, -4.4648,  ...,  5.8789, -3.4375, -1.2861]],

        [[-7.7500,  2.3906,  5.0000,  ..., -4.3789, -1.4355, -1.0342],
         [-7.7500,  2.3906,  5.0000,  ..., -4.3789, -1.4365, -1.0342],
         [-1.7188, -0.7544, -4.4609,  ...,  5.8789, -3.4414, -1.2842]],

        ...,

        [[-7.7852,  2.4258,  5.0312,  ..., -4.4219, -1.4990, -1.0049],
         [-7.7852,  2.4258,  5.0312,  ..., -4.4219, -1.5000, -1.0049],
         [-1.7168, -0.7549, -4.4609,  ...,  5.8711, -3.4434, -1.2822]],

        [[-7.7773,  2.4141,  5.0234,  ..., -4.4141, -1.

Epoch 0:  17%|█▋        | 1017/5879 [11:15<53:49,  1.51it/s, v_num=7]
 tensor([[[-7.8086,  2.4395,  5.0742,  ..., -4.4336, -1.5059, -1.0156],
         [-7.8086,  2.4395,  5.0742,  ..., -4.4336, -1.5059, -1.0156],
         [-1.7197, -0.7666, -4.4961,  ...,  5.8945, -3.4512, -1.2891]],

        [[-7.8008,  2.4453,  5.0703,  ..., -4.4375, -1.5137, -1.0078],
         [-7.8008,  2.4453,  5.0703,  ..., -4.4375, -1.5137, -1.0088],
         [-1.6934, -0.7666, -4.4922,  ...,  5.8594, -3.4707, -1.2725]],

        [[-7.8125,  2.4492,  5.0742,  ..., -4.4375, -1.5244, -1.0166],
         [-7.8125,  2.4492,  5.0742,  ..., -4.4375, -1.5234, -1.0156],
         [-1.7158, -0.7563, -4.4922,  ...,  5.8828, -3.4629, -1.2812]],

        ...,

        [[-7.7930,  2.4336,  5.0625,  ..., -4.4258, -1.4932, -1.0127],
         [-7.7930,  2.4355,  5.0625,  ..., -4.4258, -1.4941, -1.0127],
         [-1.7129, -0.7563, -4.4922,  ...,  5.8789, -3.4648, -1.2783]],

        [[-7.7070,  2.4141,  4.9414,  ..., -4.4336, -1.

Epoch 0:  17%|█▋        | 1023/5879 [11:16<53:33,  1.51it/s, v_num=7]
 tensor([[[-7.7852,  2.4512,  5.0586,  ..., -4.4414, -1.5371, -1.0088],
         [-7.7852,  2.4512,  5.0586,  ..., -4.4414, -1.5371, -1.0088],
         [-1.6816, -0.7798, -4.5195,  ...,  5.8594, -3.4922, -1.2676]],

        [[-7.8359,  2.4629,  5.1172,  ..., -4.4492, -1.5322, -1.0244],
         [-7.8359,  2.4629,  5.1172,  ..., -4.4492, -1.5322, -1.0244],
         [-1.7139, -0.7603, -4.5273,  ...,  5.8906, -3.4844, -1.2783]],

        [[-7.8242,  2.4629,  5.1133,  ..., -4.4453, -1.5273, -1.0195],
         [-7.8242,  2.4629,  5.1172,  ..., -4.4453, -1.5273, -1.0195],
         [-1.7061, -0.7622, -4.5234,  ...,  5.8828, -3.4863, -1.2744]],

        ...,

        [[-7.7734,  2.4434,  5.0391,  ..., -4.4414, -1.5381, -1.0020],
         [-7.7734,  2.4434,  5.0391,  ..., -4.4414, -1.5381, -1.0020],
         [-1.6914, -0.7710, -4.5234,  ...,  5.8672, -3.4922, -1.2695]],

        [[-6.5039,  1.8740,  3.7363,  ..., -2.8125, -0.

Epoch 0:  18%|█▊        | 1029/5879 [11:18<53:17,  1.52it/s, v_num=7]
 tensor([[[-7.8438,  2.4766,  5.1523,  ..., -4.4609, -1.5439, -1.0264],
         [-7.8477,  2.4766,  5.1523,  ..., -4.4609, -1.5439, -1.0264],
         [-1.6875, -0.7759, -4.5547,  ...,  5.8750, -3.5098, -1.2666]],

        [[-7.8398,  2.4648,  5.1523,  ..., -4.4453, -1.5215, -1.0352],
         [-7.8359,  2.4648,  5.1484,  ..., -4.4453, -1.5215, -1.0352],
         [-1.7148, -0.7637, -4.5586,  ...,  5.9023, -3.5020, -1.2783]],

        [[-7.8359,  2.4766,  5.1406,  ..., -4.4531, -1.5498, -1.0312],
         [-7.8359,  2.4766,  5.1406,  ..., -4.4531, -1.5488, -1.0312],
         [-1.6904, -0.7749, -4.5547,  ...,  5.8750, -3.5098, -1.2676]],

        ...,

        [[-7.8398,  2.4746,  5.1562,  ..., -4.4570, -1.5361, -1.0283],
         [-7.8398,  2.4746,  5.1562,  ..., -4.4570, -1.5361, -1.0283],
         [-1.7090, -0.7642, -4.5547,  ...,  5.8945, -3.5059, -1.2744]],

        [[-7.7852,  2.4570,  5.0703,  ..., -4.4531, -1.

Epoch 0:  18%|█▊        | 1035/5879 [11:26<53:32,  1.51it/s, v_num=7]
 tensor([[[-7.8750,  2.4902,  5.2031,  ..., -4.4648, -1.5566, -1.0439],
         [-7.8750,  2.4902,  5.2031,  ..., -4.4648, -1.5566, -1.0439],
         [-1.7148, -0.7725, -4.5898,  ...,  5.9141, -3.5156, -1.2803]],

        [[-7.8555,  2.4668,  5.1875,  ..., -4.4414, -1.5156, -1.0576],
         [-7.8555,  2.4668,  5.1875,  ..., -4.4414, -1.5156, -1.0576],
         [-1.7119, -0.7783, -4.5898,  ...,  5.9141, -3.5117, -1.2822]],

        [[-7.8750,  2.4902,  5.2031,  ..., -4.4688, -1.5547, -1.0420],
         [-7.8750,  2.4902,  5.2031,  ..., -4.4688, -1.5547, -1.0420],
         [-1.7090, -0.7686, -4.5898,  ...,  5.9023, -3.5234, -1.2734]],

        ...,

        [[-7.3633,  2.3301,  4.5859,  ..., -4.5117, -1.6133, -0.6426],
         [-7.3633,  2.3281,  4.5820,  ..., -4.5117, -1.6143, -0.6401],
         [-1.5752, -0.8726, -4.5742,  ...,  5.8008, -3.5312, -1.2451]],

        [[-7.8438,  2.4648,  5.1797,  ..., -4.4375, -1.

Epoch 0:  18%|█▊        | 1041/5879 [11:35<53:50,  1.50it/s, v_num=7]
 tensor([[[-7.8828,  2.5078,  5.2383,  ..., -4.4766, -1.5771, -1.0508],
         [-7.8828,  2.5078,  5.2383,  ..., -4.4766, -1.5771, -1.0508],
         [-1.6885, -0.7812, -4.6172,  ...,  5.8867, -3.5488, -1.2646]],

        [[-7.7578,  2.4668,  5.0781,  ..., -4.4805, -1.5908, -0.9678],
         [-7.7539,  2.4668,  5.0781,  ..., -4.4766, -1.5908, -0.9678],
         [-1.6836, -0.7871, -4.6133,  ...,  5.8789, -3.5488, -1.2627]],

        [[-6.6172,  1.9297,  3.9219,  ..., -2.9629, -0.8403, -1.5693],
         [-6.6172,  1.9297,  3.9199,  ..., -2.9629, -0.8398, -1.5693],
         [-7.8281,  2.3809,  5.1719,  ..., -4.3047, -1.3984, -1.2100]],

        ...,

        [[-7.8750,  2.4902,  5.2344,  ..., -4.4648, -1.5449, -1.0557],
         [-7.8750,  2.4902,  5.2344,  ..., -4.4648, -1.5459, -1.0557],
         [-1.7070, -0.7725, -4.6211,  ...,  5.9062, -3.5430, -1.2715]],

        [[-7.8633,  2.4746,  5.2227,  ..., -4.4414, -1.

Epoch 0:  18%|█▊        | 1047/5879 [11:36<53:36,  1.50it/s, v_num=7]
 tensor([[[-7.6523,  2.3164,  5.0312,  ..., -4.1406, -1.2637, -1.2949],
         [-7.6484,  2.3164,  5.0312,  ..., -4.1406, -1.2627, -1.2949],
         [-1.7109, -0.7988, -4.6406,  ...,  5.9062, -3.5332, -1.2861]],

        [[-7.8984,  2.5059,  5.2812,  ..., -4.4766, -1.5635, -1.0635],
         [-7.8984,  2.5059,  5.2812,  ..., -4.4766, -1.5635, -1.0635],
         [-1.6982, -0.7783, -4.6484,  ...,  5.8984, -3.5664, -1.2666]],

        [[-7.8789,  2.4902,  5.2617,  ..., -4.4609, -1.5381, -1.0732],
         [-7.8789,  2.4902,  5.2617,  ..., -4.4609, -1.5391, -1.0732],
         [-1.6963, -0.7783, -4.6523,  ...,  5.8984, -3.5664, -1.2666]],

        ...,

        [[-7.9062,  2.5156,  5.2891,  ..., -4.4805, -1.5781, -1.0625],
         [-7.9062,  2.5156,  5.2891,  ..., -4.4805, -1.5781, -1.0625],
         [-1.7041, -0.7798, -4.6523,  ...,  5.9062, -3.5605, -1.2715]],

        [[-7.8828,  2.4941,  5.2656,  ..., -4.4648, -1.

Epoch 0:  18%|█▊        | 1053/5879 [11:38<53:23,  1.51it/s, v_num=7]
 tensor([[[-7.8906,  2.4941,  5.3047,  ..., -4.4492, -1.5371, -1.1006],
         [-7.8906,  2.4922,  5.3047,  ..., -4.4492, -1.5371, -1.1006],
         [-1.6963, -0.7832, -4.6836,  ...,  5.9102, -3.5820, -1.2676]],

        [[-7.9258,  2.5254,  5.3359,  ..., -4.4883, -1.5908, -1.0723],
         [-7.9258,  2.5254,  5.3359,  ..., -4.4883, -1.5908, -1.0723],
         [-1.7002, -0.7832, -4.6836,  ...,  5.9141, -3.5801, -1.2695]],

        [[-7.8867,  2.5137,  5.2734,  ..., -4.4883, -1.6064, -1.0557],
         [-7.8867,  2.5137,  5.2734,  ..., -4.4883, -1.6064, -1.0557],
         [-1.6729, -0.7944, -4.6797,  ...,  5.8828, -3.5898, -1.2588]],

        ...,

        [[-5.4102,  1.5957,  2.6855,  ..., -4.3945, -1.6377,  0.6807],
         [-5.4062,  1.5928,  2.6797,  ..., -4.3945, -1.6377,  0.6836],
         [-1.5205, -0.9292, -4.6602,  ...,  5.7656, -3.5840, -1.2334]],

        [[-7.8438,  2.4512,  5.2578,  ..., -4.3867, -1.

Epoch 0:  18%|█▊        | 1059/5879 [11:40<53:08,  1.51it/s, v_num=7]
 tensor([[[-7.9375,  2.5312,  5.3789,  ..., -4.4844, -1.5947, -1.0898],
         [-7.9414,  2.5312,  5.3789,  ..., -4.4844, -1.5947, -1.0898],
         [-1.6934, -0.7896, -4.7148,  ...,  5.9102, -3.6016, -1.2656]],

        [[-7.9492,  2.5391,  5.3906,  ..., -4.4922, -1.6094, -1.0869],
         [-7.9492,  2.5391,  5.3906,  ..., -4.4922, -1.6094, -1.0869],
         [-1.6895, -0.7891, -4.7109,  ...,  5.9023, -3.6055, -1.2617]],

        [[-7.9453,  2.5410,  5.3867,  ..., -4.4961, -1.6113, -1.0820],
         [-7.9453,  2.5410,  5.3867,  ..., -4.4961, -1.6113, -1.0830],
         [-1.6934, -0.7886, -4.7148,  ...,  5.9102, -3.6016, -1.2656]],

        ...,

        [[-7.9414,  2.5430,  5.3750,  ..., -4.4961, -1.6191, -1.0801],
         [-7.9414,  2.5430,  5.3750,  ..., -4.4961, -1.6191, -1.0801],
         [-1.6777, -0.7974, -4.7070,  ...,  5.8867, -3.6074, -1.2578]],

        [[-7.8789,  2.4746,  5.3203,  ..., -4.4062, -1.

Epoch 0:  18%|█▊        | 1065/5879 [11:48<53:21,  1.50it/s, v_num=7]
 tensor([[[-6.0430,  1.8711,  3.3398,  ..., -4.4805, -1.6865,  0.3115],
         [-6.0508,  1.8760,  3.3496,  ..., -4.4805, -1.6865,  0.3049],
         [-1.4619, -0.9902, -4.7031,  ...,  5.7227, -3.6172, -1.2217]],

        [[-7.9648,  2.5547,  5.4297,  ..., -4.5000, -1.6289, -1.0928],
         [-7.9648,  2.5547,  5.4297,  ..., -4.5000, -1.6299, -1.0928],
         [-1.6816, -0.7949, -4.7422,  ...,  5.8984, -3.6250, -1.2598]],

        [[-7.7578,  2.4922,  5.1719,  ..., -4.5156, -1.6572, -0.9463],
         [-7.7539,  2.4902,  5.1680,  ..., -4.5156, -1.6572, -0.9448],
         [-1.6484, -0.8164, -4.7383,  ...,  5.8711, -3.6309, -1.2500]],

        ...,

        [[-7.6758,  2.4609,  5.0703,  ..., -4.5195, -1.6611, -0.8760],
         [-7.6797,  2.4609,  5.0742,  ..., -4.5195, -1.6611, -0.8774],
         [-1.5439, -0.9136, -4.7188,  ...,  5.7852, -3.6270, -1.2334]],

        [[-7.9609,  2.5547,  5.4258,  ..., -4.5000, -1.

Epoch 0:  18%|█▊        | 1071/5879 [11:49<53:06,  1.51it/s, v_num=7]
 tensor([[[-7.9180,  2.5039,  5.4219,  ..., -4.4180, -1.5391, -1.1689],
         [-7.9180,  2.5039,  5.4219,  ..., -4.4180, -1.5391, -1.1689],
         [-1.6787, -0.8022, -4.7695,  ...,  5.9062, -3.6426, -1.2607]],

        [[-7.9883,  2.5645,  5.4883,  ..., -4.4961, -1.6387, -1.1152],
         [-7.9883,  2.5645,  5.4883,  ..., -4.4961, -1.6387, -1.1152],
         [-1.6797, -0.8047, -4.7695,  ...,  5.9062, -3.6387, -1.2617]],

        [[-7.9766,  2.5547,  5.4805,  ..., -4.4844, -1.6211, -1.1182],
         [-7.9766,  2.5547,  5.4805,  ..., -4.4844, -1.6211, -1.1182],
         [-1.6777, -0.8018, -4.7695,  ...,  5.9023, -3.6426, -1.2588]],

        ...,

        [[-0.0928, -4.7109, -1.5615,  ..., -2.4980, -0.9819,  2.3750],
         [-0.1031, -4.6680, -1.5605,  ..., -2.5078, -0.9849,  2.3809],
         [-1.5752, -0.8823, -4.7539,  ...,  5.8164, -3.6484, -1.2373]],

        [[-7.9609,  2.5371,  5.4688,  ..., -4.4609, -1.

Epoch 0:  18%|█▊        | 1077/5879 [11:51<52:51,  1.51it/s, v_num=7]
 tensor([[[-7.9531,  2.5332,  5.4883,  ..., -4.4453, -1.5850, -1.1582],
         [-7.9531,  2.5352,  5.4922,  ..., -4.4453, -1.5850, -1.1572],
         [-1.6572, -0.8174, -4.7930,  ...,  5.8828, -3.6680, -1.2520]],

        [[-6.9961,  2.2422,  4.3672,  ..., -4.5469, -1.7236, -0.3315],
         [-6.9883,  2.2402,  4.3594,  ..., -4.5469, -1.7236, -0.3254],
         [-1.4854, -0.9736, -4.7617,  ...,  5.7461, -3.6582, -1.2246]],

        [[-7.7969,  2.4180,  5.3359,  ..., -4.2383, -1.4033, -1.3066],
         [-7.7969,  2.4180,  5.3359,  ..., -4.2383, -1.4033, -1.3066],
         [-1.6729, -0.8262, -4.7930,  ...,  5.9023, -3.6426, -1.2676]],

        ...,

        [[-7.9961,  2.5879,  5.5273,  ..., -4.5000, -1.6768, -1.1182],
         [-7.9961,  2.5879,  5.5273,  ..., -4.5000, -1.6768, -1.1182],
         [-1.6729, -0.8110, -4.7969,  ...,  5.9062, -3.6582, -1.2598]],

        [[-6.8242,  2.1738,  4.1758,  ..., -4.5391, -1.

Epoch 0:  18%|█▊        | 1083/5879 [11:52<52:36,  1.52it/s, v_num=7]
 tensor([[[-8.0078,  2.5938,  5.5781,  ..., -4.5000, -1.6855, -1.1318],
         [-8.0078,  2.5938,  5.5781,  ..., -4.5000, -1.6855, -1.1318],
         [-1.6592, -0.8184, -4.8242,  ...,  5.8906, -3.6836, -1.2529]],

        [[-7.9844,  2.5820,  5.5352,  ..., -4.5000, -1.6865, -1.1162],
         [-7.9844,  2.5820,  5.5352,  ..., -4.5000, -1.6865, -1.1172],
         [-1.6289, -0.8433, -4.8164,  ...,  5.8594, -3.6875, -1.2441]],

        [[-7.9297,  2.5625,  5.4648,  ..., -4.5000, -1.6943, -1.0791],
         [-7.9297,  2.5625,  5.4648,  ..., -4.5000, -1.6943, -1.0791],
         [-1.6465, -0.8276, -4.8203,  ...,  5.8789, -3.6875, -1.2480]],

        ...,

        [[-8.0078,  2.5820,  5.5781,  ..., -4.4883, -1.6611, -1.1387],
         [-8.0078,  2.5820,  5.5781,  ..., -4.4883, -1.6611, -1.1387],
         [-1.6641, -0.8198, -4.8242,  ...,  5.9023, -3.6797, -1.2568]],

        [[ 0.5459, -8.3281, -0.9980,  ..., -1.0820, -0.

Epoch 0:  19%|█▊        | 1089/5879 [12:00<52:50,  1.51it/s, v_num=7]
 tensor([[[-8.0156,  2.5879,  5.6211,  ..., -4.4805, -1.6680, -1.1543],
         [-8.0156,  2.5879,  5.6211,  ..., -4.4805, -1.6680, -1.1543],
         [-1.6514, -0.8291, -4.8516,  ...,  5.8867, -3.7031, -1.2490]],

        [[-8.0078,  2.5742,  5.6133,  ..., -4.4609, -1.6475, -1.1729],
         [-8.0078,  2.5762,  5.6172,  ..., -4.4609, -1.6475, -1.1729],
         [-1.6572, -0.8306, -4.8516,  ...,  5.8945, -3.6973, -1.2529]],

        [[-8.0156,  2.6035,  5.6133,  ..., -4.5000, -1.7061, -1.1387],
         [-8.0156,  2.6035,  5.6133,  ..., -4.5000, -1.7061, -1.1387],
         [-1.6387, -0.8374, -4.8477,  ...,  5.8750, -3.7051, -1.2451]],

        ...,

        [[-7.8906,  2.4844,  5.5039,  ..., -4.3164, -1.5029, -1.2861],
         [-7.8906,  2.4844,  5.5039,  ..., -4.3164, -1.5029, -1.2861],
         [-1.6562, -0.8442, -4.8477,  ...,  5.8945, -3.6797, -1.2607]],

        [[-7.9570,  2.5898,  5.5430,  ..., -4.5039, -1.

Epoch 0:  19%|█▊        | 1095/5879 [12:02<52:35,  1.52it/s, v_num=7]
 tensor([[[-7.9062,  2.5762,  5.5156,  ..., -4.5000, -1.7354, -1.0654],
         [-7.9062,  2.5762,  5.5156,  ..., -4.5000, -1.7354, -1.0654],
         [-1.6084, -0.8667, -4.8711,  ...,  5.8477, -3.7227, -1.2373]],

        [[-8.0312,  2.6133,  5.6641,  ..., -4.4961, -1.7197, -1.1553],
         [-8.0312,  2.6113,  5.6680,  ..., -4.4961, -1.7197, -1.1562],
         [-1.6484, -0.8398, -4.8789,  ...,  5.8945, -3.7129, -1.2500]],

        [[-8.0391,  2.6113,  5.6836,  ..., -4.4922, -1.7041, -1.1641],
         [-8.0391,  2.6113,  5.6836,  ..., -4.4922, -1.7041, -1.1641],
         [-1.6406, -0.8418, -4.8750,  ...,  5.8789, -3.7207, -1.2451]],

        ...,

        [[-8.0391,  2.6172,  5.6797,  ..., -4.4961, -1.7197, -1.1582],
         [-8.0391,  2.6172,  5.6797,  ..., -4.4961, -1.7197, -1.1582],
         [-1.6328, -0.8467, -4.8750,  ...,  5.8711, -3.7227, -1.2422]],

        [[-8.0000,  2.5742,  5.6484,  ..., -4.4453, -1.

Epoch 0:  19%|█▊        | 1101/5879 [12:08<52:43,  1.51it/s, v_num=7]
 tensor([[[-8.0469,  2.6250,  5.7227,  ..., -4.4961, -1.7373, -1.1699],
         [-8.0469,  2.6250,  5.7227,  ..., -4.4961, -1.7363, -1.1699],
         [-1.6377, -0.8545, -4.9062,  ...,  5.8867, -3.7285, -1.2471]],

        [[-7.8203,  2.4492,  5.5039,  ..., -4.2109, -1.4570, -1.3730],
         [-7.8203,  2.4492,  5.5039,  ..., -4.2109, -1.4570, -1.3730],
         [-1.6377, -0.8647, -4.9023,  ...,  5.8867, -3.7168, -1.2520]],

        [[-8.0234,  2.5918,  5.7070,  ..., -4.4531, -1.6748, -1.2031],
         [-8.0234,  2.5918,  5.7070,  ..., -4.4531, -1.6748, -1.2021],
         [-1.6377, -0.8506, -4.9062,  ...,  5.8828, -3.7344, -1.2441]],

        ...,

        [[-7.9922,  2.5664,  5.6758,  ..., -4.4180, -1.6348, -1.2324],
         [-7.9961,  2.5664,  5.6758,  ..., -4.4180, -1.6348, -1.2324],
         [-1.6367, -0.8530, -4.9062,  ...,  5.8867, -3.7305, -1.2451]],

        [[-8.0391,  2.6074,  5.7188,  ..., -4.4727, -1.

Epoch 0:  19%|█▉        | 1107/5879 [12:10<52:28,  1.52it/s, v_num=7]
 tensor([[[-4.7070,  1.3721,  1.7314,  ..., -0.7104, -0.8623, -1.6963],
         [-4.7109,  1.3730,  1.7393,  ..., -0.7168, -0.8618, -1.6973],
         [-1.6289, -0.8823, -4.9258,  ...,  5.8711, -3.7246, -1.2510]],

        [[-8.0156,  2.5859,  5.7383,  ..., -4.4297, -1.6641, -1.2363],
         [-8.0156,  2.5859,  5.7344,  ..., -4.4297, -1.6641, -1.2354],
         [-1.6289, -0.8662, -4.9336,  ...,  5.8789, -3.7461, -1.2412]],

        [[-8.0625,  2.6348,  5.7773,  ..., -4.4922, -1.7451, -1.1816],
         [-8.0625,  2.6348,  5.7773,  ..., -4.4922, -1.7441, -1.1816],
         [-1.6133, -0.8721, -4.9297,  ...,  5.8555, -3.7559, -1.2344]],

        ...,

        [[-8.0469,  2.6348,  5.7617,  ..., -4.4961, -1.7578, -1.1709],
         [-8.0469,  2.6348,  5.7578,  ..., -4.4961, -1.7578, -1.1699],
         [-1.5908, -0.8906, -4.9258,  ...,  5.8359, -3.7578, -1.2295]],

        [[-8.0469,  2.6387,  5.7656,  ..., -4.5000, -1.

Epoch 0:  19%|█▉        | 1113/5879 [12:13<52:19,  1.52it/s, v_num=7]
 tensor([[[-8.0703,  2.6445,  5.8125,  ..., -4.4961, -1.7705, -1.1924],
         [-8.0703,  2.6445,  5.8125,  ..., -4.4961, -1.7705, -1.1924],
         [-1.6152, -0.8779, -4.9648,  ...,  5.8711, -3.7676, -1.2354]],

        [[-7.4609,  2.4414,  5.0820,  ..., -4.5234, -1.8096, -0.6968],
         [-7.4688,  2.4434,  5.0898,  ..., -4.5234, -1.8096, -0.7031],
         [-1.5195, -0.9619, -4.9414,  ...,  5.7773, -3.7695, -1.2168]],

        [[-8.0625,  2.6484,  5.8164,  ..., -4.4922, -1.7725, -1.1895],
         [-8.0625,  2.6484,  5.8164,  ..., -4.4922, -1.7715, -1.1895],
         [-1.6084, -0.8818, -4.9570,  ...,  5.8555, -3.7715, -1.2314]],

        ...,

        [[-8.0234,  2.5957,  5.7812,  ..., -4.4297, -1.6826, -1.2461],
         [-8.0234,  2.5957,  5.7812,  ..., -4.4297, -1.6826, -1.2461],
         [-1.6191, -0.8838, -4.9609,  ...,  5.8750, -3.7578, -1.2393]],

        [[-6.7188,  2.1641,  4.2539,  ..., -4.4961, -1.

Epoch 0:  19%|█▉        | 1119/5879 [12:14<52:04,  1.52it/s, v_num=7]
 tensor([[[-8.0156,  2.5957,  5.8164,  ..., -4.4062, -1.6816, -1.2783],
         [-8.0156,  2.5957,  5.8164,  ..., -4.4062, -1.6816, -1.2783],
         [-1.6094, -0.9038, -4.9844,  ...,  5.8711, -3.7676, -1.2393]],

        [[-8.0703,  2.6445,  5.8633,  ..., -4.4727, -1.7598, -1.2197],
         [-8.0703,  2.6445,  5.8594,  ..., -4.4727, -1.7598, -1.2197],
         [-1.6084, -0.8916, -4.9883,  ...,  5.8672, -3.7832, -1.2324]],

        [[-7.8203,  2.5742,  5.5430,  ..., -4.5039, -1.8135, -1.0020],
         [-7.8203,  2.5742,  5.5430,  ..., -4.5039, -1.8125, -1.0010],
         [-1.5557, -0.9336, -4.9727,  ...,  5.8086, -3.7891, -1.2197]],

        ...,

        [[-8.0625,  2.6309,  5.8555,  ..., -4.4531, -1.7363, -1.2363],
         [-8.0625,  2.6309,  5.8555,  ..., -4.4531, -1.7363, -1.2363],
         [-1.6094, -0.8931, -4.9883,  ...,  5.8672, -3.7832, -1.2334]],

        [[-6.5742,  2.1074,  4.1289,  ..., -4.4727, -1.

Epoch 0:  19%|█▉        | 1125/5879 [12:21<52:13,  1.52it/s, v_num=7]
 tensor([[[-7.5352,  2.4824,  5.2422,  ..., -4.5078, -1.8428, -0.7637],
         [-7.5391,  2.4844,  5.2422,  ..., -4.5078, -1.8428, -0.7651],
         [-1.5010, -0.9917, -4.9883,  ...,  5.7656, -3.8008, -1.2100]],

        [[-8.0469,  2.6621,  5.8711,  ..., -4.4883, -1.8154, -1.1914],
         [-8.0469,  2.6621,  5.8711,  ..., -4.4883, -1.8154, -1.1914],
         [-1.5566, -0.9385, -5.0000,  ...,  5.8164, -3.8066, -1.2188]],

        [[-8.0859,  2.6680,  5.9102,  ..., -4.4883, -1.8057, -1.2148],
         [-8.0859,  2.6660,  5.9141,  ..., -4.4844, -1.7998, -1.2178],
         [-1.5986, -0.9082, -5.0117,  ...,  5.8633, -3.7969, -1.2305]],

        ...,

        [[-8.0859,  2.6641,  5.9180,  ..., -4.4805, -1.7930, -1.2256],
         [-8.0859,  2.6641,  5.9180,  ..., -4.4805, -1.7930, -1.2256],
         [-1.5977, -0.9082, -5.0117,  ...,  5.8633, -3.7969, -1.2295]],

        [[-8.0859,  2.6641,  5.9141,  ..., -4.4805, -1.

Epoch 0:  19%|█▉        | 1131/5879 [12:23<51:59,  1.52it/s, v_num=7]
 tensor([[[-8.0938,  2.6738,  5.9609,  ..., -4.4727, -1.8066, -1.2402],
         [-8.0859,  2.6719,  5.9609,  ..., -4.4727, -1.8057, -1.2402],
         [-1.5820, -0.9253, -5.0352,  ...,  5.8477, -3.8203, -1.2227]],

        [[-3.6914,  0.6533,  1.3115,  ..., -3.9004, -1.5098,  1.6318],
         [-3.6836,  0.6484,  1.3057,  ..., -3.8984, -1.5078,  1.6348],
         [-1.4893, -1.0117, -5.0117,  ...,  5.7578, -3.8164, -1.2080]],

        [[-8.0859,  2.6816,  5.9570,  ..., -4.4805, -1.8262, -1.2236],
         [-8.0859,  2.6816,  5.9570,  ..., -4.4805, -1.8262, -1.2236],
         [-1.5605, -0.9429, -5.0273,  ...,  5.8203, -3.8223, -1.2178]],

        ...,

        [[-8.0859,  2.6816,  5.9570,  ..., -4.4805, -1.8232, -1.2266],
         [-8.0859,  2.6816,  5.9570,  ..., -4.4805, -1.8242, -1.2266],
         [-1.5830, -0.9243, -5.0352,  ...,  5.8477, -3.8203, -1.2227]],

        [[-8.0391,  2.6191,  5.9141,  ..., -4.3984, -1.

Epoch 0:  19%|█▉        | 1137/5879 [12:25<51:49,  1.53it/s, v_num=7]
 tensor([[[-7.8867,  2.6211,  5.7383,  ..., -4.4844, -1.8613, -1.0752],
         [-7.8867,  2.6211,  5.7383,  ..., -4.4844, -1.8613, -1.0762],
         [-1.5146, -0.9941, -5.0430,  ...,  5.7812, -3.8359, -1.2090]],

        [[-7.7695,  2.5840,  5.5977,  ..., -4.4883, -1.8691, -0.9751],
         [-7.7695,  2.5820,  5.5977,  ..., -4.4883, -1.8691, -0.9736],
         [-1.4385, -1.0703, -5.0234,  ...,  5.7148, -3.8242, -1.1992]],

        [[-8.0859,  2.6758,  6.0039,  ..., -4.4570, -1.8115, -1.2617],
         [-8.0859,  2.6758,  6.0039,  ..., -4.4570, -1.8125, -1.2617],
         [-1.5781, -0.9404, -5.0586,  ...,  5.8516, -3.8320, -1.2217]],

        ...,

        [[-8.0781,  2.6641,  5.9922,  ..., -4.4414, -1.7900, -1.2783],
         [-8.0781,  2.6621,  5.9922,  ..., -4.4375, -1.7891, -1.2793],
         [-1.5791, -0.9419, -5.0586,  ...,  5.8516, -3.8301, -1.2236]],

        [[-8.0859,  2.6895,  5.9922,  ..., -4.4766, -1.

Epoch 0:  19%|█▉        | 1143/5879 [12:26<51:35,  1.53it/s, v_num=7]
 tensor([[[-8.0859,  2.6797,  6.0391,  ..., -4.4414, -1.8164, -1.2822],
         [-8.0781,  2.6797,  6.0391,  ..., -4.4414, -1.8164, -1.2812],
         [-1.5674, -0.9604, -5.0820,  ...,  5.8438, -3.8477, -1.2188]],

        [[-8.0547,  2.6562,  6.0156,  ..., -4.4141, -1.7793, -1.3076],
         [-8.0547,  2.6562,  6.0156,  ..., -4.4141, -1.7793, -1.3076],
         [-1.5684, -0.9609, -5.0820,  ...,  5.8438, -3.8457, -1.2188]],

        [[ 0.0983, -6.0469, -1.5117,  ..., -2.0918, -0.5938,  2.1738],
         [ 0.0978, -6.0430, -1.5117,  ..., -2.0918, -0.5938,  2.1738],
         [-1.3154, -1.2109, -5.0078,  ...,  5.5938, -3.8047, -1.1836]],

        ...,

        [[-7.8984,  2.6426,  5.8047,  ..., -4.4805, -1.8828, -1.0898],
         [-7.8984,  2.6406,  5.8008,  ..., -4.4805, -1.8828, -1.0869],
         [-1.4746, -1.0420, -5.0586,  ...,  5.7500, -3.8477, -1.2021]],

        [[-8.0469,  2.6445,  6.0078,  ..., -4.3945, -1.

Epoch 0:  20%|█▉        | 1149/5879 [12:36<51:56,  1.52it/s, v_num=7]
 tensor([[[-8.0859,  2.6875,  6.0820,  ..., -4.4336, -1.8301, -1.3018],
         [-8.0859,  2.6875,  6.0820,  ..., -4.4336, -1.8301, -1.3008],
         [-1.5576, -0.9800, -5.1055,  ...,  5.8359, -3.8633, -1.2148]],

        [[-0.2161, -4.5547, -1.5430,  ..., -2.4414, -0.7319,  2.4922],
         [-0.2725, -4.3281, -1.5303,  ..., -2.4922, -0.7559,  2.5234],
         [-0.6621, -2.0605, -4.5977,  ...,  4.7227, -3.4160, -1.1094]],

        [[-8.0703,  2.6758,  6.0703,  ..., -4.4180, -1.8086, -1.3154],
         [-8.0703,  2.6758,  6.0703,  ..., -4.4180, -1.8096, -1.3154],
         [-1.5586, -0.9805, -5.1094,  ...,  5.8359, -3.8613, -1.2158]],

        ...,

        [[-8.0859,  2.7012,  6.0898,  ..., -4.4492, -1.8496, -1.2822],
         [-8.0859,  2.6992,  6.0859,  ..., -4.4492, -1.8496, -1.2822],
         [-1.5547, -0.9810, -5.1055,  ...,  5.8320, -3.8652, -1.2139]],

        [[-4.7773,  1.2920,  2.4199,  ..., -4.1055, -1.

Epoch 0:  20%|█▉        | 1155/5879 [12:38<51:42,  1.52it/s, v_num=7]
 tensor([[[-7.8164,  2.6328,  5.7852,  ..., -4.4727, -1.9189, -1.0381],
         [-7.8164,  2.6328,  5.7852,  ..., -4.4727, -1.9189, -1.0391],
         [-1.3457, -1.1992, -5.0703,  ...,  5.6250, -3.8477, -1.1855]],

        [[-8.0938,  2.7129,  6.1367,  ..., -4.4453, -1.8672, -1.3008],
         [-8.0938,  2.7129,  6.1367,  ..., -4.4453, -1.8672, -1.2998],
         [-1.5449, -1.0020, -5.1328,  ...,  5.8281, -3.8789, -1.2119]],

        [[-6.3125,  2.0352,  4.0547,  ..., -4.3672, -1.8662,  0.1266],
         [-6.3047,  2.0332,  4.0469,  ..., -4.3672, -1.8652,  0.1301],
         [-1.3916, -1.1504, -5.0859,  ...,  5.6680, -3.8613, -1.1904]],

        ...,

        [[-8.0859,  2.7207,  6.1211,  ..., -4.4609, -1.8945, -1.2676],
         [-8.0859,  2.7207,  6.1211,  ..., -4.4609, -1.8945, -1.2676],
         [-1.5361, -1.0068, -5.1289,  ...,  5.8125, -3.8828, -1.2090]],

        [[-8.0938,  2.7090,  6.1328,  ..., -4.4375, -1.

Epoch 0:  20%|█▉        | 1161/5879 [12:39<51:27,  1.53it/s, v_num=7]
 tensor([[[-8.0781,  2.7051,  6.1602,  ..., -4.4219, -1.8555, -1.3252],
         [-8.0781,  2.7051,  6.1602,  ..., -4.4219, -1.8545, -1.3252],
         [-1.5371, -1.0264, -5.1562,  ...,  5.8281, -3.8848, -1.2119]],

        [[-1.8203, -1.0029, -0.3955,  ..., -3.2480, -1.1445,  2.4375],
         [-1.8428, -0.9761, -0.3750,  ..., -3.2559, -1.1494,  2.4297],
         [-1.3750, -1.1758, -5.1133,  ...,  5.6602, -3.8750, -1.1865]],

        [[-8.1016,  2.7305,  6.1836,  ..., -4.4492, -1.9004, -1.3008],
         [-8.1016,  2.7305,  6.1836,  ..., -4.4492, -1.9004, -1.3008],
         [-1.5352, -1.0254, -5.1562,  ...,  5.8203, -3.8926, -1.2100]],

        ...,

        [[-8.0859,  2.7266,  6.1602,  ..., -4.4570, -1.9033, -1.2852],
         [-8.0859,  2.7266,  6.1602,  ..., -4.4570, -1.9033, -1.2852],
         [-1.5371, -1.0273, -5.1562,  ...,  5.8281, -3.8848, -1.2119]],

        [[-8.0781,  2.7109,  6.1680,  ..., -4.4219, -1.

Epoch 0:  20%|█▉        | 1167/5879 [12:43<51:21,  1.53it/s, v_num=7]
 tensor([[[-8.0469,  2.7324,  6.1523,  ..., -4.4531, -1.9385, -1.2568],
         [-8.0469,  2.7324,  6.1562,  ..., -4.4531, -1.9385, -1.2578],
         [-1.5078, -1.0586, -5.1719,  ...,  5.7930, -3.9121, -1.2021]],

        [[-7.8320,  2.6504,  5.8789,  ..., -4.4531, -1.9463, -1.0811],
         [-7.8359,  2.6523,  5.8828,  ..., -4.4531, -1.9463, -1.0840],
         [-1.4854, -1.0781, -5.1680,  ...,  5.7695, -3.9102, -1.1982]],

        [[-7.3516,  2.4688,  5.2969,  ..., -4.4414, -1.9453, -0.6733],
         [-7.3516,  2.4688,  5.2969,  ..., -4.4414, -1.9453, -0.6738],
         [-1.4756, -1.0879, -5.1641,  ...,  5.7578, -3.9102, -1.1973]],

        ...,

        [[-8.0938,  2.7402,  6.2148,  ..., -4.4453, -1.9189, -1.3057],
         [-8.0938,  2.7402,  6.2148,  ..., -4.4453, -1.9189, -1.3057],
         [-1.5166, -1.0518, -5.1758,  ...,  5.8047, -3.9102, -1.2041]],

        [[-8.0859,  2.7266,  6.2188,  ..., -4.4219, -1.

Epoch 0:  20%|█▉        | 1173/5879 [12:48<51:22,  1.53it/s, v_num=7]
 tensor([[[-8.0938,  2.7480,  6.2578,  ..., -4.4297, -1.9238, -1.3301],
         [-8.0859,  2.7480,  6.2578,  ..., -4.4297, -1.9248, -1.3301],
         [-1.5059, -1.0762, -5.1992,  ...,  5.8008, -3.9258, -1.2012]],

        [[-1.7129, -1.1445, -0.5054,  ..., -3.1699, -1.0996,  2.4844],
         [-1.7158, -1.1416, -0.5029,  ..., -3.1719, -1.1006,  2.4824],
         [-1.1846, -1.4209, -5.0742,  ...,  5.4531, -3.8320, -1.1650]],

        [[-8.0391,  2.6934,  6.2148,  ..., -4.3594, -1.8359, -1.4043],
         [-8.0391,  2.6934,  6.2148,  ..., -4.3594, -1.8359, -1.4043],
         [-1.5117, -1.0791, -5.2031,  ...,  5.8125, -3.9121, -1.2061]],

        ...,

        [[-8.0781,  2.7520,  6.2461,  ..., -4.4414, -1.9414, -1.3086],
         [-8.0781,  2.7520,  6.2461,  ..., -4.4414, -1.9414, -1.3076],
         [-1.5010, -1.0801, -5.1992,  ...,  5.7930, -3.9258, -1.2002]],

        [[-8.0859,  2.7500,  6.2461,  ..., -4.4414, -1.

Epoch 0:  20%|██        | 1179/5879 [12:51<51:15,  1.53it/s, v_num=7]
 tensor([[[-7.9180,  2.7090,  6.0781,  ..., -4.4375, -1.9775, -1.1807],
         [-7.9258,  2.7109,  6.0898,  ..., -4.4375, -1.9775, -1.1865],
         [-1.4658, -1.1260, -5.2148,  ...,  5.7578, -3.9395, -1.1934]],

        [[-7.8242,  2.6699,  5.9531,  ..., -4.4336, -1.9795, -1.0977],
         [-7.8203,  2.6699,  5.9531,  ..., -4.4336, -1.9785, -1.0977],
         [-1.4541, -1.1367, -5.2070,  ...,  5.7461, -3.9375, -1.1914]],

        [[-7.5352,  2.5605,  5.5977,  ..., -4.4258, -1.9785, -0.8506],
         [-7.5312,  2.5586,  5.5938,  ..., -4.4258, -1.9785, -0.8477],
         [-1.3867, -1.2070, -5.1875,  ...,  5.6758, -3.9238, -1.1836]],

        ...,

        [[-8.0859,  2.7637,  6.2969,  ..., -4.4297, -1.9541, -1.3311],
         [-8.0859,  2.7637,  6.2969,  ..., -4.4297, -1.9541, -1.3311],
         [-1.4883, -1.1055, -5.2227,  ...,  5.7852, -3.9395, -1.1982]],

        [[-7.9922,  2.7344,  6.1680,  ..., -4.4375, -1.

Epoch 0:  20%|██        | 1185/5879 [12:54<51:06,  1.53it/s, v_num=7]
 tensor([[[-8.0781,  2.7598,  6.3320,  ..., -4.4023, -1.9385, -1.3750],
         [-8.0781,  2.7578,  6.3320,  ..., -4.4023, -1.9375, -1.3750],
         [-1.4814, -1.1289, -5.2422,  ...,  5.7812, -3.9531, -1.1943]],

        [[-7.5430,  2.5742,  5.6484,  ..., -4.4141, -1.9941, -0.8701],
         [-7.5430,  2.5742,  5.6484,  ..., -4.4141, -1.9941, -0.8706],
         [-1.3799, -1.2285, -5.2070,  ...,  5.6719, -3.9395, -1.1816]],

        [[-8.0547,  2.7695,  6.3008,  ..., -4.4258, -1.9785, -1.3184],
         [-8.0547,  2.7695,  6.3008,  ..., -4.4258, -1.9785, -1.3193],
         [-1.4619, -1.1455, -5.2383,  ...,  5.7578, -3.9531, -1.1914]],

        ...,

        [[-7.9062,  2.7188,  6.1055,  ..., -4.4297, -1.9951, -1.1855],
         [-7.9062,  2.7188,  6.1016,  ..., -4.4297, -1.9951, -1.1836],
         [-1.4912, -1.1250, -5.2461,  ...,  5.8008, -3.9414, -1.1982]],

        [[-8.0625,  2.7695,  6.3086,  ..., -4.4258, -1.

Epoch 0:  20%|██        | 1191/5879 [13:04<51:26,  1.52it/s, v_num=7]
 tensor([[[-8.0625,  2.7812,  6.3555,  ..., -4.4141, -1.9873, -1.3506],
         [-8.0625,  2.7812,  6.3555,  ..., -4.4141, -1.9873, -1.3506],
         [-1.4707, -1.1543, -5.2656,  ...,  5.7734, -3.9668, -1.1914]],

        [[-8.0391,  2.7773,  6.3281,  ..., -4.4180, -1.9961, -1.3262],
         [-8.0391,  2.7773,  6.3281,  ..., -4.4180, -1.9961, -1.3262],
         [-1.4639, -1.1602, -5.2617,  ...,  5.7656, -3.9688, -1.1904]],

        [[-7.5742,  2.5918,  5.7227,  ..., -4.4023, -2.0020, -0.9194],
         [-7.5742,  2.5918,  5.7227,  ..., -4.4023, -2.0020, -0.9194],
         [-1.4258, -1.1953, -5.2500,  ...,  5.7227, -3.9648, -1.1846]],

        ...,

        [[-2.6387,  0.0745, -2.5801,  ...,  3.0195, -2.3438, -1.6436],
         [-2.6348,  0.0704, -2.5938,  ...,  3.0312, -2.3496, -1.6426],
         [-1.4814, -1.1650, -5.2539,  ...,  5.7812, -3.9336, -1.2021]],

        [[-7.2266,  2.3047,  5.5234,  ..., -3.5684, -1.

Epoch 0:  20%|██        | 1197/5879 [13:05<51:13,  1.52it/s, v_num=7]
 tensor([[[-7.8398,  2.7148,  6.1094,  ..., -4.4062, -2.0215, -1.1680],
         [-7.8398,  2.7148,  6.1094,  ..., -4.4062, -2.0215, -1.1680],
         [-1.4072, -1.2305, -5.2656,  ...,  5.7070, -3.9785, -1.1807]],

        [[-8.0312,  2.7852,  6.3594,  ..., -4.4102, -2.0098, -1.3408],
         [-8.0312,  2.7852,  6.3594,  ..., -4.4102, -2.0098, -1.3408],
         [-1.4609, -1.1797, -5.2852,  ...,  5.7695, -3.9805, -1.1865]],

        [[-8.0625,  2.7852,  6.4062,  ..., -4.3945, -1.9834, -1.3926],
         [-8.0625,  2.7871,  6.4062,  ..., -4.3945, -1.9824, -1.3936],
         [-1.4629, -1.1797, -5.2852,  ...,  5.7734, -3.9805, -1.1875]],

        ...,

        [[-8.0625,  2.7871,  6.4023,  ..., -4.3984, -1.9893, -1.3857],
         [-8.0625,  2.7891,  6.4023,  ..., -4.3984, -1.9893, -1.3857],
         [-1.4668, -1.1758, -5.2852,  ...,  5.7812, -3.9785, -1.1885]],

        [[-8.0625,  2.7891,  6.4023,  ..., -4.3984, -1.

Epoch 0:  20%|██        | 1203/5879 [13:07<51:01,  1.53it/s, v_num=7]
 tensor([[[-8.0391,  2.7617,  6.4336,  ..., -4.3398, -1.9346, -1.4814],
         [-8.0391,  2.7617,  6.4336,  ..., -4.3398, -1.9336, -1.4814],
         [-1.4570, -1.2100, -5.3047,  ...,  5.7773, -3.9805, -1.1875]],

        [[-2.2422, -0.5439,  0.0125,  ..., -3.2109, -1.1777,  2.3262],
         [-2.2480, -0.5371,  0.0186,  ..., -3.2129, -1.1787,  2.3223],
         [-1.3828, -1.2695, -5.2852,  ...,  5.6836, -3.9922, -1.1748]],

        [[-3.9219,  1.0635,  1.0762,  ..., -0.0910, -1.0420, -1.9326],
         [-3.9219,  1.0635,  1.0742,  ..., -0.0899, -1.0420, -1.9316],
         [-1.4590, -1.2168, -5.2930,  ...,  5.7695, -3.9648, -1.1914]],

        ...,

        [[-7.6992,  2.6699,  5.9727,  ..., -4.3828, -2.0371, -1.0645],
         [-7.7031,  2.6719,  5.9766,  ..., -4.3828, -2.0371, -1.0674],
         [-1.3301, -1.3271, -5.2617,  ...,  5.6250, -3.9746, -1.1689]],

        [[-8.0312,  2.7969,  6.4062,  ..., -4.3984, -2.

Epoch 0:  21%|██        | 1209/5879 [13:11<50:56,  1.53it/s, v_num=7]
 tensor([[[-7.3672,  2.5391,  5.5977,  ..., -4.3438, -2.0332, -0.8013],
         [-7.3750,  2.5430,  5.6055,  ..., -4.3438, -2.0352, -0.8071],
         [-1.3682, -1.3027, -5.3008,  ...,  5.6719, -4.0039, -1.1689]],

        [[-8.0547,  2.8008,  6.4883,  ..., -4.3633, -1.9961, -1.4561],
         [-8.0547,  2.8008,  6.4883,  ..., -4.3633, -1.9961, -1.4561],
         [-1.4453, -1.2314, -5.3281,  ...,  5.7695, -4.0078, -1.1787]],

        [[-8.0391,  2.8086,  6.4688,  ..., -4.3789, -2.0234, -1.4082],
         [-8.0391,  2.8086,  6.4688,  ..., -4.3789, -2.0234, -1.4082],
         [-1.4258, -1.2461, -5.3203,  ...,  5.7422, -4.0117, -1.1758]],

        ...,

        [[-7.7812,  2.7109,  6.1211,  ..., -4.3789, -2.0469, -1.1650],
         [-7.7812,  2.7109,  6.1172,  ..., -4.3750, -2.0469, -1.1641],
         [-1.4414, -1.2324, -5.3281,  ...,  5.7617, -4.0117, -1.1777]],

        [[-8.0234,  2.8066,  6.4453,  ..., -4.3867, -2.

Epoch 0:  21%|██        | 1215/5879 [13:17<51:01,  1.52it/s, v_num=7]
 tensor([[[-7.7695,  2.6035,  6.2734,  ..., -4.0664, -1.7207, -1.7402],
         [-7.7695,  2.6035,  6.2734,  ..., -4.0664, -1.7197, -1.7402],
         [-1.4365, -1.2617, -5.3438,  ...,  5.7656, -4.0156, -1.1758]],

        [[-5.5781,  1.7402,  3.5547,  ..., -4.0195, -1.8525,  0.5239],
         [-5.5898,  1.7451,  3.5664,  ..., -4.0234, -1.8545,  0.5166],
         [-1.3008, -1.3906, -5.2930,  ...,  5.6016, -3.9980, -1.1582]],

        [[-8.0391,  2.8125,  6.5195,  ..., -4.3555, -2.0176, -1.4658],
         [-8.0391,  2.8125,  6.5195,  ..., -4.3555, -2.0176, -1.4658],
         [-1.4307, -1.2607, -5.3438,  ...,  5.7578, -4.0234, -1.1738]],

        ...,

        [[-8.0391,  2.8125,  6.5156,  ..., -4.3594, -2.0215, -1.4600],
         [-8.0391,  2.8125,  6.5156,  ..., -4.3594, -2.0215, -1.4600],
         [-1.4277, -1.2627, -5.3438,  ...,  5.7539, -4.0273, -1.1729]],

        [[-1.5488, -1.4131, -0.6626,  ..., -2.9121, -0.

Epoch 0:  21%|██        | 1221/5879 [13:19<50:50,  1.53it/s, v_num=7]
 tensor([[[-7.9531,  2.8086,  6.4453,  ..., -4.3633, -2.0703, -1.3730],
         [-7.9531,  2.8086,  6.4414,  ..., -4.3633, -2.0703, -1.3730],
         [-1.4004, -1.3057, -5.3555,  ...,  5.7266, -4.0391, -1.1660]],

        [[-8.0234,  2.8066,  6.5547,  ..., -4.3242, -2.0020, -1.5225],
         [-8.0234,  2.8066,  6.5547,  ..., -4.3242, -2.0020, -1.5225],
         [-1.4209, -1.2891, -5.3633,  ...,  5.7539, -4.0391, -1.1689]],

        [[-8.0156,  2.8262,  6.5312,  ..., -4.3594, -2.0508, -1.4492],
         [-8.0156,  2.8262,  6.5352,  ..., -4.3594, -2.0508, -1.4492],
         [-1.4102, -1.2969, -5.3594,  ...,  5.7383, -4.0430, -1.1670]],

        ...,

        [[-7.9219,  2.7969,  6.4023,  ..., -4.3633, -2.0742, -1.3398],
         [-7.9180,  2.7969,  6.3984,  ..., -4.3633, -2.0742, -1.3369],
         [-1.3770, -1.3271, -5.3477,  ...,  5.6992, -4.0391, -1.1631]],

        [[-8.0234,  2.8281,  6.5430,  ..., -4.3594, -2.

Epoch 0:  21%|██        | 1227/5879 [13:21<50:40,  1.53it/s, v_num=7]
 tensor([[[-2.2520, -0.4199, -3.6367,  ...,  3.9316, -2.9375, -1.6016],
         [-2.2500, -0.4204, -3.6367,  ...,  3.9316, -2.9375, -1.6006],
         [-1.4141, -1.3330, -5.3672,  ...,  5.7461, -4.0195, -1.1738]],

        [[-7.9141,  2.8066,  6.4453,  ..., -4.3516, -2.0879, -1.3662],
         [-7.9141,  2.8066,  6.4453,  ..., -4.3516, -2.0879, -1.3662],
         [-1.4023, -1.3223, -5.3789,  ...,  5.7422, -4.0547, -1.1631]],

        [[-6.1406,  2.0293,  4.2656,  ..., -4.0859, -1.9521,  0.1019],
         [-6.1680,  2.0410,  4.2969,  ..., -4.0898, -1.9551,  0.0838],
         [-1.1914, -1.5498, -5.2812,  ...,  5.4805, -3.9883, -1.1445]],

        ...,

        [[-7.9180,  2.7285,  6.5117,  ..., -4.2031, -1.8838, -1.6807],
         [-7.9102,  2.7227,  6.5039,  ..., -4.1914, -1.8760, -1.6875],
         [-1.4121, -1.3203, -5.3828,  ...,  5.7578, -4.0469, -1.1670]],

        [[-8.0156,  2.8340,  6.5898,  ..., -4.3398, -2.

Epoch 0:  21%|██        | 1233/5879 [13:23<50:28,  1.53it/s, v_num=7]
 tensor([[[-8.0078,  2.8457,  6.6250,  ..., -4.3320, -2.0723, -1.5166],
         [-8.0078,  2.8438,  6.6211,  ..., -4.3320, -2.0723, -1.5156],
         [-1.3936, -1.3506, -5.4023,  ...,  5.7500, -4.0664, -1.1592]],

        [[-7.7930,  2.7715,  6.3320,  ..., -4.3242, -2.1055, -1.2812],
         [-7.7891,  2.7715,  6.3320,  ..., -4.3242, -2.1035, -1.2803],
         [-1.3389, -1.4014, -5.3789,  ...,  5.6758, -4.0625, -1.1533]],

        [[-8.0078,  2.8359,  6.6289,  ..., -4.3164, -2.0508, -1.5420],
         [-8.0078,  2.8359,  6.6289,  ..., -4.3164, -2.0508, -1.5430],
         [-1.3916, -1.3525, -5.3984,  ...,  5.7422, -4.0703, -1.1592]],

        ...,

        [[-7.9688,  2.8379,  6.5664,  ..., -4.3398, -2.0938, -1.4541],
         [-7.9648,  2.8379,  6.5625,  ..., -4.3398, -2.0938, -1.4531],
         [-1.3975, -1.3496, -5.4023,  ...,  5.7539, -4.0625, -1.1602]],

        [[-8.0078,  2.8438,  6.6289,  ..., -4.3281, -2.

Epoch 0:  21%|██        | 1239/5879 [13:29<50:30,  1.53it/s, v_num=7]
 tensor([[[-7.9023,  2.8262,  6.5273,  ..., -4.3242, -2.1094, -1.4238],
         [-7.9023,  2.8262,  6.5273,  ..., -4.3242, -2.1094, -1.4229],
         [-1.3770, -1.3848, -5.4180,  ...,  5.7383, -4.0820, -1.1533]],

        [[-7.9453,  2.8418,  6.5820,  ..., -4.3281, -2.1094, -1.4648],
         [-7.9453,  2.8438,  6.5820,  ..., -4.3281, -2.1094, -1.4668],
         [-1.3691, -1.3906, -5.4141,  ...,  5.7266, -4.0820, -1.1523]],

        [[-7.5664,  2.6875,  6.0820,  ..., -4.2852, -2.1094, -1.1055],
         [-7.5664,  2.6875,  6.0859,  ..., -4.2852, -2.1094, -1.1064],
         [-1.3193, -1.4404, -5.3945,  ...,  5.6602, -4.0742, -1.1475]],

        ...,

        [[-7.9414,  2.8418,  6.5742,  ..., -4.3281, -2.1055, -1.4600],
         [-7.9375,  2.8398,  6.5742,  ..., -4.3281, -2.1055, -1.4580],
         [-1.3076, -1.4531, -5.3867,  ...,  5.6445, -4.0703, -1.1475]],

        [[ 0.1209, -7.9023, -0.8945,  ..., -1.5400,  0.

Epoch 0:  21%|██        | 1245/5879 [13:34<50:29,  1.53it/s, v_num=7]
 tensor([[[-7.9766,  2.8574,  6.6875,  ..., -4.3047, -2.0898, -1.5703],
         [-7.9766,  2.8574,  6.6875,  ..., -4.3047, -2.0898, -1.5713],
         [-1.3682, -1.4131, -5.4336,  ...,  5.7383, -4.0977, -1.1484]],

        [[-7.0977,  2.4922,  5.5312,  ..., -4.1953, -2.0898, -0.7202],
         [-7.0977,  2.4922,  5.5312,  ..., -4.1953, -2.0898, -0.7202],
         [-1.3184, -1.4600, -5.4141,  ...,  5.6719, -4.0898, -1.1416]],

        [[-7.9453,  2.8086,  6.6758,  ..., -4.2383, -1.9990, -1.6836],
         [-7.9414,  2.8086,  6.6758,  ..., -4.2383, -1.9990, -1.6846],
         [-1.3711, -1.4111, -5.4375,  ...,  5.7461, -4.0938, -1.1494]],

        ...,

        [[-7.8789,  2.8301,  6.5430,  ..., -4.3125, -2.1250, -1.4346],
         [-7.8789,  2.8320,  6.5430,  ..., -4.3125, -2.1250, -1.4346],
         [-1.3418, -1.4355, -5.4258,  ...,  5.7031, -4.0977, -1.1455]],

        [[-7.9766,  2.8574,  6.6875,  ..., -4.3047, -2.

Epoch 0:  21%|██▏       | 1251/5879 [13:35<50:18,  1.53it/s, v_num=7]
 tensor([[[-7.9648e+00,  2.8613e+00,  6.7266e+00,  ..., -4.2852e+00,
          -2.0859e+00, -1.6230e+00],
         [-7.9648e+00,  2.8613e+00,  6.7266e+00,  ..., -4.2852e+00,
          -2.0879e+00, -1.6221e+00],
         [-1.3574e+00, -1.4443e+00, -5.4531e+00,  ...,  5.7422e+00,
          -4.1094e+00, -1.1426e+00]],

        [[-7.6680e+00,  2.7578e+00,  6.3125e+00,  ..., -4.2695e+00,
          -2.1387e+00, -1.2627e+00],
         [-7.6211e+00,  2.7363e+00,  6.2500e+00,  ..., -4.2656e+00,
          -2.1367e+00, -1.2207e+00],
         [-1.3184e+00, -1.4785e+00, -5.4375e+00,  ...,  5.6875e+00,
          -4.1094e+00, -1.1387e+00]],

        [[ 1.0986e-02, -7.2266e+00, -1.0801e+00,  ..., -1.7783e+00,
           1.6016e-01,  1.5566e+00],
         [ 5.5237e-03, -7.1836e+00, -1.0928e+00,  ..., -1.7891e+00,
           1.5344e-01,  1.5850e+00],
         [ 1.0352e-01, -3.8516e+00, -3.4961e+00,  ...,  2.5840e+00,
          -2.3457

Epoch 0:  21%|██▏       | 1257/5879 [13:37<50:05,  1.54it/s, v_num=7]
 tensor([[[-7.6211,  2.7520,  6.3008,  ..., -4.2500, -2.1465, -1.2549],
         [-7.6211,  2.7520,  6.3008,  ..., -4.2500, -2.1465, -1.2549],
         [-1.2939, -1.5215, -5.4453,  ...,  5.6680, -4.1250, -1.1318]],

        [[-5.7930,  1.8984,  4.0352,  ..., -3.8652, -1.9326,  0.2710],
         [-5.8477,  1.9268,  4.1016,  ..., -3.8789, -1.9424,  0.2286],
         [-1.2256, -1.5967, -5.4102,  ...,  5.5742, -4.0977, -1.1260]],

        [[-7.8750,  2.8594,  6.6484,  ..., -4.2852, -2.1426, -1.5098],
         [-7.8750,  2.8594,  6.6484,  ..., -4.2852, -2.1426, -1.5098],
         [-1.3193, -1.4961, -5.4570,  ...,  5.7031, -4.1289, -1.1338]],

        ...,

        [[-7.9258,  2.8398,  6.7461,  ..., -4.2383, -2.0469, -1.7061],
         [-7.9258,  2.8398,  6.7461,  ..., -4.2383, -2.0469, -1.7061],
         [-1.3467, -1.4746, -5.4648,  ...,  5.7461, -4.1250, -1.1348]],

        [[-7.9375,  2.8750,  6.7344,  ..., -4.2852, -2.

Epoch 0:  21%|██▏       | 1263/5879 [13:48<50:26,  1.53it/s, v_num=7]
 tensor([[[-7.8320,  2.8555,  6.6367,  ..., -4.2695, -2.1582, -1.5029],
         [-7.8320,  2.8574,  6.6406,  ..., -4.2695, -2.1562, -1.5049],
         [-1.3154, -1.5186, -5.4766,  ...,  5.7188, -4.1445, -1.1279]],

        [[-7.3125,  2.6250,  5.9375,  ..., -4.1836, -2.1367, -0.9976],
         [-7.3125,  2.6250,  5.9375,  ..., -4.1836, -2.1367, -0.9980],
         [-1.2578, -1.5781, -5.4492,  ...,  5.6367, -4.1328, -1.1240]],

        [[-7.9180,  2.8730,  6.7656,  ..., -4.2617, -2.1133, -1.6514],
         [-7.9180,  2.8730,  6.7656,  ..., -4.2617, -2.1133, -1.6504],
         [-1.3320, -1.5068, -5.4805,  ...,  5.7422, -4.1406, -1.1299]],

        ...,

        [[-7.8203,  2.8516,  6.6172,  ..., -4.2695, -2.1582, -1.4902],
         [-7.8203,  2.8516,  6.6172,  ..., -4.2695, -2.1582, -1.4912],
         [-1.3057, -1.5283, -5.4727,  ...,  5.7031, -4.1445, -1.1279]],

        [[-7.4531,  2.6914,  6.1289,  ..., -4.2070, -2.

Epoch 0:  22%|██▏       | 1269/5879 [13:53<50:28,  1.52it/s, v_num=7]
 tensor([[[-5.4766,  1.7520,  3.7246,  ..., -3.7207, -1.8887,  0.4685],
         [-5.4570,  1.7422,  3.7051,  ..., -3.7168, -1.8857,  0.4814],
         [-1.2480, -1.6055, -5.4648,  ...,  5.6406, -4.1484, -1.1191]],

        [[-7.6289,  2.7832,  6.4062,  ..., -4.2266, -2.1680, -1.3350],
         [-7.6289,  2.7852,  6.4062,  ..., -4.2266, -2.1680, -1.3369],
         [-1.2764, -1.5752, -5.4766,  ...,  5.6797, -4.1562, -1.1211]],

        [[-2.2578, -0.5454,  0.0681,  ..., -2.8164, -1.0938,  2.3750],
         [-2.2871, -0.5161,  0.0983,  ..., -2.8242, -1.1025,  2.3633],
         [-1.2178, -1.6377, -5.4492,  ...,  5.5977, -4.1367, -1.1172]],

        ...,

        [[-7.9023,  2.8848,  6.7969,  ..., -4.2539, -2.1289, -1.6758],
         [-7.9023,  2.8848,  6.7969,  ..., -4.2539, -2.1289, -1.6758],
         [-1.3154, -1.5391, -5.4961,  ...,  5.7383, -4.1602, -1.1240]],

        [[-0.0169, -7.4375, -0.9565,  ..., -1.7695,  0.

Epoch 0:  22%|██▏       | 1275/5879 [13:55<50:16,  1.53it/s, v_num=7]
 tensor([[[-7.5664,  2.7695,  6.3711,  ..., -4.1992, -2.1777, -1.3115],
         [-7.5664,  2.7695,  6.3711,  ..., -4.1992, -2.1777, -1.3125],
         [-1.2725, -1.5977, -5.4922,  ...,  5.6914, -4.1758, -1.1143]],

        [[-5.8906,  1.9775,  4.2656,  ..., -3.7949, -1.9736,  0.1432],
         [-5.9219,  1.9922,  4.3008,  ..., -3.8027, -1.9785,  0.1218],
         [-1.1270, -1.7598, -5.4102,  ...,  5.4766, -4.1133, -1.1074]],

        [[-0.7773, -3.1777, -1.3691,  ..., -2.2676, -0.4509,  2.7520],
         [-0.7905, -3.1387, -1.3613,  ..., -2.2734, -0.4592,  2.7559],
         [-0.8760, -2.0801, -5.2031,  ...,  5.0625, -3.9219, -1.0967]],

        ...,

        [[-7.8867,  2.8945,  6.8164,  ..., -4.2461, -2.1465, -1.6924],
         [-7.8867,  2.8945,  6.8164,  ..., -4.2461, -2.1465, -1.6924],
         [-1.3018, -1.5703, -5.5039,  ...,  5.7383, -4.1758, -1.1162]],

        [[-1.7695, -1.1562, -0.4473,  ..., -2.6113, -0.

Epoch 0:  22%|██▏       | 1281/5879 [13:57<50:05,  1.53it/s, v_num=7]
 tensor([[[-1.5195, -1.5420, -0.7227,  ..., -2.4941, -0.8042,  2.6875],
         [-1.5244, -1.5342, -0.7178,  ..., -2.4961, -0.8062,  2.6855],
         [-1.0449, -1.8730, -5.3672,  ...,  5.3711, -4.0820, -1.0947]],

        [[-7.7461,  2.8652,  6.6680,  ..., -4.2266, -2.1895, -1.5439],
         [-7.7461,  2.8652,  6.6680,  ..., -4.2266, -2.1895, -1.5439],
         [-1.2852, -1.6035, -5.5156,  ...,  5.7383, -4.1953, -1.1074]],

        [[-6.7734,  2.4180,  5.3906,  ..., -4.0000, -2.1152, -0.6118],
         [-6.7656,  2.4141,  5.3789,  ..., -3.9980, -2.1152, -0.6045],
         [-1.2578, -1.6289, -5.5039,  ...,  5.6953, -4.1914, -1.1055]],

        ...,

        [[-6.2969,  2.0723,  5.2031,  ..., -2.8867, -1.2197, -2.4863],
         [-6.3047,  2.0762,  5.2148,  ..., -2.8945, -1.2227, -2.4863],
         [-1.2979, -1.5996, -5.5156,  ...,  5.7578, -4.1797, -1.1094]],

        [[-7.5859,  2.7949,  6.4492,  ..., -4.1914, -2.

Epoch 0:  22%|██▏       | 1287/5879 [14:04<50:14,  1.52it/s, v_num=7]
 tensor([[[-3.9980,  0.8999,  2.0801,  ..., -3.1895, -1.5732,  1.4482],
         [-4.1016,  0.9688,  2.2012,  ..., -3.2188, -1.5996,  1.3818],
         [-1.1396, -1.7754, -5.4531,  ...,  5.5391, -4.1641, -1.0928]],

        [[-7.4648,  2.7559,  6.3359,  ..., -4.1523, -2.1934, -1.2930],
         [-7.4648,  2.7559,  6.3359,  ..., -4.1523, -2.1934, -1.2930],
         [-1.2578, -1.6475, -5.5195,  ...,  5.7188, -4.2109, -1.0967]],

        [[-7.5234,  2.7812,  6.4141,  ..., -4.1641, -2.1992, -1.3467],
         [-7.5195,  2.7793,  6.4102,  ..., -4.1641, -2.1973, -1.3457],
         [-1.2549, -1.6494, -5.5195,  ...,  5.7148, -4.2109, -1.0967]],

        ...,

        [[-7.8125,  2.9004,  6.8125,  ..., -4.2266, -2.1855, -1.6797],
         [-7.8125,  2.9004,  6.8125,  ..., -4.2266, -2.1855, -1.6797],
         [-1.2734, -1.6328, -5.5273,  ...,  5.7461, -4.2109, -1.0977]],

        [[-7.8438,  2.9023,  6.8750,  ..., -4.2148, -2.

Epoch 0:  22%|██▏       | 1293/5879 [14:09<50:14,  1.52it/s, v_num=7]
 tensor([[[-7.6172,  2.8398,  6.5938,  ..., -4.1797, -2.2070, -1.4932],
         [-7.6211,  2.8398,  6.5938,  ..., -4.1797, -2.2070, -1.4932],
         [-1.2471, -1.6758, -5.5352,  ...,  5.7266, -4.2266, -1.0859]],

        [[-6.8203,  2.4609,  5.5352,  ..., -3.9727, -2.1348, -0.7144],
         [-6.8203,  2.4609,  5.5312,  ..., -3.9727, -2.1348, -0.7134],
         [-1.2100, -1.7129, -5.5156,  ...,  5.6680, -4.2188, -1.0859]],

        [[-6.6094,  2.3594,  5.2617,  ..., -3.9141, -2.1074, -0.5234],
         [-6.5938,  2.3516,  5.2461,  ..., -3.9102, -2.1035, -0.5107],
         [-1.0771, -1.8662, -5.4297,  ...,  5.4570, -4.1523, -1.0830]],

        ...,

        [[-7.8086,  2.8867,  6.8984,  ..., -4.1797, -2.1133, -1.8701],
         [-7.8086,  2.8867,  6.8984,  ..., -4.1797, -2.1133, -1.8691],
         [-1.2695, -1.6611, -5.5430,  ...,  5.7656, -4.2188, -1.0889]],

        [[-7.7969,  2.9082,  6.8516,  ..., -4.2109, -2.

Epoch 0:  22%|██▏       | 1299/5879 [14:11<50:01,  1.53it/s, v_num=7]
 tensor([[[-7.6719,  2.8730,  6.7070,  ..., -4.1836, -2.2109, -1.5986],
         [-7.6680,  2.8711,  6.7070,  ..., -4.1836, -2.2109, -1.5967],
         [-1.2412, -1.7012, -5.5508,  ...,  5.7383, -4.2461, -1.0771]],

        [[-7.4336,  2.7637,  6.3789,  ..., -4.1211, -2.2031, -1.3438],
         [-7.4297,  2.7637,  6.3750,  ..., -4.1211, -2.2031, -1.3398],
         [-1.2256, -1.7168, -5.5430,  ...,  5.7148, -4.2422, -1.0771]],

        [[-7.7656,  2.9102,  6.8477,  ..., -4.1992, -2.1973, -1.7314],
         [-7.7617,  2.9082,  6.8398,  ..., -4.1992, -2.1973, -1.7256],
         [-1.2490, -1.6943, -5.5547,  ...,  5.7539, -4.2422, -1.0771]],

        ...,

        [[-7.3672,  2.7324,  6.2891,  ..., -4.1055, -2.1992, -1.2744],
         [-7.3672,  2.7324,  6.2891,  ..., -4.1055, -2.1992, -1.2744],
         [-1.2256, -1.7158, -5.5430,  ...,  5.7148, -4.2422, -1.0762]],

        [[-7.0195,  2.5703,  5.8359,  ..., -4.0078, -2.

Epoch 0:  22%|██▏       | 1305/5879 [14:12<49:49,  1.53it/s, v_num=7]
 tensor([[[-7.5156,  2.8184,  6.5391,  ..., -4.1328, -2.2148, -1.4746],
         [-7.5156,  2.8184,  6.5391,  ..., -4.1328, -2.2148, -1.4746],
         [-1.2402, -1.7217, -5.5664,  ...,  5.7617, -4.2617, -1.0654]],

        [[-6.5469,  2.3457,  5.2656,  ..., -3.8477, -2.1074, -0.5273],
         [-6.5547,  2.3496,  5.2734,  ..., -3.8516, -2.1074, -0.5347],
         [-1.1855, -1.7725, -5.5430,  ...,  5.6719, -4.2539, -1.0645]],

        [[-7.7148,  2.8555,  6.8984,  ..., -4.1172, -2.0605, -2.0039],
         [-7.7148,  2.8555,  6.8984,  ..., -4.1172, -2.0605, -2.0039],
         [-1.2402, -1.7217, -5.5703,  ...,  5.7617, -4.2617, -1.0654]],

        ...,

        [[-7.5820,  2.8477,  6.6328,  ..., -4.1523, -2.2168, -1.5508],
         [-7.5859,  2.8496,  6.6328,  ..., -4.1523, -2.2168, -1.5518],
         [-1.2344, -1.7266, -5.5664,  ...,  5.7500, -4.2617, -1.0654]],

        [[-0.1009, -7.5430, -0.7896,  ..., -1.8955,  0.

Epoch 0:  22%|██▏       | 1311/5879 [14:20<49:57,  1.52it/s, v_num=7]
 tensor([[[-7.7305,  2.9199,  6.8984,  ..., -4.1797, -2.1934, -1.8252],
         [-7.7305,  2.9199,  6.8984,  ..., -4.1797, -2.1934, -1.8252],
         [-1.2363, -1.7510, -5.5742,  ...,  5.7695, -4.2617, -1.0557]],

        [[-0.6641, -3.7832, -1.4160,  ..., -2.0449, -0.1780,  2.7227],
         [-0.6543, -3.8223, -1.4189,  ..., -2.0430, -0.1705,  2.7168],
         [-0.2421, -3.2285, -4.2852,  ...,  3.4824, -3.0879, -1.0381]],

        [[-4.8555,  1.4385,  3.1934,  ..., -3.2812, -1.7744,  0.8379],
         [-4.8477,  1.4336,  3.1816,  ..., -3.2773, -1.7725,  0.8447],
         [-1.0615, -1.9326, -5.4766,  ...,  5.4766, -4.2148, -1.0547]],

        ...,

        [[-7.7070,  2.9121,  6.8633,  ..., -4.1719, -2.2031, -1.7812],
         [-7.7070,  2.9121,  6.8633,  ..., -4.1719, -2.2012, -1.7842],
         [-1.2070, -1.7695, -5.5703,  ...,  5.7227, -4.2812, -1.0537]],

        [[-6.9492,  2.5566,  5.8203,  ..., -3.9531, -2.

Epoch 0:  22%|██▏       | 1317/5879 [14:25<49:59,  1.52it/s, v_num=7]
 tensor([[[-7.4141e+00,  2.7988e+00,  6.4883e+00,  ..., -4.0820e+00,
          -2.2266e+00, -1.4531e+00],
         [-7.4141e+00,  2.7988e+00,  6.4883e+00,  ..., -4.0820e+00,
          -2.2266e+00, -1.4531e+00],
         [-1.1807e+00, -1.8135e+00, -5.5742e+00,  ...,  5.6992e+00,
          -4.3008e+00, -1.0430e+00]],

        [[-7.6797e+00,  2.9160e+00,  6.8672e+00,  ..., -4.1602e+00,
          -2.2148e+00, -1.8008e+00],
         [-7.6797e+00,  2.9160e+00,  6.8711e+00,  ..., -4.1602e+00,
          -2.2129e+00, -1.8018e+00],
         [-1.2168e+00, -1.7793e+00, -5.5938e+00,  ...,  5.7617e+00,
          -4.3047e+00, -1.0430e+00]],

        [[-2.0312e+00, -8.9600e-01, -1.5454e-01,  ..., -2.3652e+00,
          -9.0918e-01,  2.5684e+00],
         [-2.0391e+00, -8.8525e-01, -1.4465e-01,  ..., -2.3672e+00,
          -9.1260e-01,  2.5664e+00],
         [-9.4238e-01, -2.0957e+00, -5.3906e+00,  ...,  5.2695e+00,
          -4.1523

Epoch 0:  23%|██▎       | 1323/5879 [14:27<49:47,  1.53it/s, v_num=7]
 tensor([[[-7.6836,  2.9141,  6.9531,  ..., -4.1367, -2.1641, -1.9805],
         [-7.6836,  2.9141,  6.9531,  ..., -4.1367, -2.1641, -1.9805],
         [-1.2070, -1.8057, -5.6016,  ...,  5.7734, -4.3242, -1.0312]],

        [[-7.5859,  2.8320,  6.8828,  ..., -4.0547, -2.0352, -2.1387],
         [-7.5859,  2.8301,  6.8828,  ..., -4.0508, -2.0332, -2.1387],
         [-1.2090, -1.8047, -5.6016,  ...,  5.7695, -4.3203, -1.0312]],

        [[-7.6641,  2.9199,  6.9023,  ..., -4.1484, -2.2070, -1.8672],
         [-7.6641,  2.9199,  6.9023,  ..., -4.1484, -2.2070, -1.8672],
         [-1.1973, -1.8125, -5.6016,  ...,  5.7539, -4.3281, -1.0312]],

        ...,

        [[-7.6719,  2.9199,  6.9180,  ..., -4.1484, -2.1992, -1.8926],
         [-7.6719,  2.9199,  6.9180,  ..., -4.1484, -2.1992, -1.8926],
         [-1.2119, -1.8057, -5.5977,  ...,  5.7773, -4.3086, -1.0322]],

        [[-0.3467, -5.5781, -1.2861,  ..., -2.0020,  0.

Epoch 0:  23%|██▎       | 1329/5879 [14:29<49:36,  1.53it/s, v_num=7]
 tensor([[[-7.5664,  2.8984,  6.7969,  ..., -4.1211, -2.2402, -1.7520],
         [-7.5664,  2.8984,  6.7969,  ..., -4.1211, -2.2402, -1.7510],
         [-1.1807, -1.8447, -5.6055,  ...,  5.7500, -4.3477, -1.0205]],

        [[-7.1055,  2.5586,  6.4609,  ..., -3.6953, -1.6992, -2.5234],
         [-7.1055,  2.5586,  6.4609,  ..., -3.6953, -1.6992, -2.5234],
         [-1.2012, -1.8379, -5.5977,  ...,  5.7773, -4.3164, -1.0254]],

        [[-7.1172,  2.6758,  6.1641,  ..., -3.9668, -2.2168, -1.2217],
         [-7.1133,  2.6758,  6.1602,  ..., -3.9668, -2.2168, -1.2188],
         [-1.1758, -1.8496, -5.5977,  ...,  5.7422, -4.3477, -1.0195]],

        ...,

        [[-7.4648,  2.8516,  6.6562,  ..., -4.0859, -2.2441, -1.6143],
         [-7.4648,  2.8535,  6.6562,  ..., -4.0859, -2.2441, -1.6152],
         [-1.1846, -1.8408, -5.6055,  ...,  5.7617, -4.3477, -1.0205]],

        [[-7.6523,  2.9277,  6.9297,  ..., -4.1445, -2.

Epoch 0:  23%|██▎       | 1335/5879 [14:37<49:48,  1.52it/s, v_num=7]
 tensor([[[-7.5117,  2.8887,  6.7695,  ..., -4.1016, -2.2500, -1.7363],
         [-7.5156,  2.8906,  6.7734,  ..., -4.1016, -2.2480, -1.7402],
         [-1.1826, -1.8604, -5.6172,  ...,  5.7852, -4.3672, -1.0107]],

        [[-7.3594,  2.8164,  6.5547,  ..., -4.0469, -2.2480, -1.5410],
         [-7.3633,  2.8164,  6.5586,  ..., -4.0469, -2.2480, -1.5430],
         [-1.1768, -1.8652, -5.6172,  ...,  5.7773, -4.3711, -1.0107]],

        [[-7.6328,  2.9258,  6.9688,  ..., -4.1289, -2.1934, -2.0117],
         [-7.6328,  2.9258,  6.9688,  ..., -4.1289, -2.1934, -2.0117],
         [-1.1865, -1.8594, -5.6172,  ...,  5.7930, -4.3594, -1.0107]],

        ...,

        [[-7.1758,  2.7227,  6.3008,  ..., -3.9766, -2.2344, -1.3242],
         [-7.1758,  2.7246,  6.3008,  ..., -3.9766, -2.2344, -1.3252],
         [-1.1494, -1.8906, -5.6016,  ...,  5.7266, -4.3672, -1.0117]],

        [[-6.2656,  2.2480,  5.0977,  ..., -3.6328, -2.

Epoch 0:  23%|██▎       | 1341/5879 [14:41<49:43,  1.52it/s, v_num=7]
 tensor([[[-7.5703,  2.9238,  6.9102,  ..., -4.1133, -2.2363, -1.9072],
         [-7.5703,  2.9238,  6.9102,  ..., -4.1133, -2.2363, -1.9082],
         [-1.1729, -1.8877, -5.6250,  ...,  5.8047, -4.3828, -1.0010]],

        [[-5.2227,  1.6680,  3.7930,  ..., -3.2188, -1.8711,  0.4951],
         [-5.3242,  1.7285,  3.9238,  ..., -3.2578, -1.8955,  0.4084],
         [-1.1191, -1.9375, -5.5977,  ...,  5.6953, -4.3789, -1.0039]],

        [[-7.4102,  2.8535,  6.6719,  ..., -4.0547, -2.2539, -1.6533],
         [-7.4023,  2.8496,  6.6641,  ..., -4.0547, -2.2539, -1.6455],
         [-1.1729, -1.8867, -5.6250,  ...,  5.8008, -4.3828, -1.0000]],

        ...,

        [[-0.2037, -7.2500, -0.7637,  ..., -2.1230,  0.6055,  1.0518],
         [-0.2032, -7.2500, -0.7632,  ..., -2.1230,  0.6060,  1.0508],
         [ 0.1101, -6.9023, -1.0605,  ..., -1.2188, -0.0522, -0.4456]],

        [[-7.4766,  2.8848,  6.7656,  ..., -4.0820, -2.

Epoch 0:  23%|██▎       | 1347/5879 [14:43<49:32,  1.52it/s, v_num=7]
 tensor([[[-7.5703,  2.9238,  6.9766,  ..., -4.1016, -2.2012, -2.0625],
         [-7.5703,  2.9238,  6.9805,  ..., -4.1016, -2.2012, -2.0625],
         [-1.1582, -1.9160, -5.6289,  ...,  5.8047, -4.4102, -0.9888]],

        [[-6.9844,  2.6484,  6.1289,  ..., -3.8789, -2.2246, -1.2031],
         [-6.9844,  2.6484,  6.1289,  ..., -3.8809, -2.2246, -1.2041],
         [-1.1211, -1.9512, -5.6094,  ...,  5.7305, -4.4062, -0.9922]],

        [[-6.1172,  2.1797,  4.9766,  ..., -3.5312, -2.0742, -0.3203],
         [-6.1172,  2.1797,  4.9766,  ..., -3.5312, -2.0742, -0.3215],
         [-1.1289, -1.9414, -5.6133,  ...,  5.7461, -4.4062, -0.9907]],

        ...,

        [[-7.5312,  2.8867,  6.9648,  ..., -4.0703, -2.1309, -2.1738],
         [-7.5312,  2.8867,  6.9648,  ..., -4.0703, -2.1289, -2.1738],
         [-1.1650, -1.9131, -5.6250,  ...,  5.8125, -4.3984, -0.9893]],

        [[-7.1445,  2.7324,  6.3477,  ..., -3.9453, -2.

Epoch 0:  23%|██▎       | 1353/5879 [14:45<49:20,  1.53it/s, v_num=7]
 tensor([[[-6.9102,  2.6172,  6.0625,  ..., -3.8379, -2.2227, -1.1562],
         [-6.9023,  2.6152,  6.0547,  ..., -3.8359, -2.2227, -1.1504],
         [-1.1182, -1.9688, -5.6211,  ...,  5.7539, -4.4297, -0.9775]],

        [[-7.5391,  2.9258,  6.9805,  ..., -4.0938, -2.2129, -2.0801],
         [-7.5391,  2.9277,  6.9805,  ..., -4.0977, -2.2148, -2.0723],
         [-1.1523, -1.9404, -5.6328,  ...,  5.8242, -4.4219, -0.9751]],

        [[-6.4180,  2.3535,  5.4102,  ..., -3.6309, -2.1406, -0.6338],
         [-6.4258,  2.3574,  5.4180,  ..., -3.6328, -2.1426, -0.6406],
         [-1.0361, -2.0605, -5.5586,  ...,  5.5859, -4.3867, -0.9873]],

        ...,

        [[-6.9688,  2.6504,  6.1445,  ..., -3.8633, -2.2305, -1.2207],
         [-6.9766,  2.6562,  6.1562,  ..., -3.8672, -2.2324, -1.2324],
         [-1.1152, -1.9717, -5.6172,  ...,  5.7500, -4.4297, -0.9780]],

        [[-7.3555,  2.8516,  6.6875,  ..., -4.0273, -2.

Epoch 0:  23%|██▎       | 1359/5879 [14:55<49:38,  1.52it/s, v_num=7]
 tensor([[[-6.0938,  2.1816,  5.0195,  ..., -3.4727, -2.0820, -0.3418],
         [-6.0977,  2.1816,  5.0234,  ..., -3.4727, -2.0820, -0.3433],
         [-1.1094, -1.9932, -5.6289,  ...,  5.7695, -4.4570, -0.9629]],

        [[-7.2930,  2.8340,  6.6445,  ..., -3.9961, -2.2676, -1.6699],
         [-7.2930,  2.8340,  6.6484,  ..., -3.9961, -2.2676, -1.6709],
         [-1.1279, -1.9746, -5.6367,  ...,  5.8086, -4.4570, -0.9604]],

        [[-6.4102,  2.3555,  5.4297,  ..., -3.6113, -2.1426, -0.6597],
         [-6.4102,  2.3555,  5.4297,  ..., -3.6113, -2.1426, -0.6602],
         [-1.1123, -1.9902, -5.6289,  ...,  5.7773, -4.4570, -0.9629]],

        ...,

        [[-0.2201, -7.3477, -0.6626,  ..., -2.2656,  0.6812,  0.7993],
         [-0.2196, -7.3477, -0.6626,  ..., -2.2676,  0.6812,  0.7979],
         [-0.1250, -7.6523, -0.5542,  ..., -2.2871,  0.6035,  0.1086]],

        [[-7.5000,  2.9102,  7.0000,  ..., -4.0742, -2.

Epoch 0:  23%|██▎       | 1365/5879 [14:57<49:28,  1.52it/s, v_num=7]
 tensor([[[-7.4336,  2.9141,  6.8984,  ..., -4.0586, -2.2637, -1.9658],
         [-7.4336,  2.9141,  6.8984,  ..., -4.0586, -2.2637, -1.9658],
         [-1.1152, -2.0039, -5.6367,  ...,  5.8125, -4.4766, -0.9453]],

        [[-7.4688,  2.9219,  6.9727,  ..., -4.0703, -2.2227, -2.1152],
         [-7.4688,  2.9219,  6.9727,  ..., -4.0703, -2.2246, -2.1152],
         [-1.1318, -1.9912, -5.6367,  ...,  5.8438, -4.4648, -0.9438]],

        [[-6.0391,  2.1543,  4.9805,  ..., -3.4238, -2.0703, -0.3140],
         [-6.0391,  2.1543,  4.9805,  ..., -3.4258, -2.0703, -0.3159],
         [-1.0615, -2.0605, -5.5977,  ...,  5.6953, -4.4570, -0.9546]],

        ...,

        [[-7.4531,  2.9199,  6.9414,  ..., -4.0664, -2.2461, -2.0391],
         [-7.4531,  2.9199,  6.9414,  ..., -4.0664, -2.2461, -2.0391],
         [-1.1318, -1.9922, -5.6406,  ...,  5.8438, -4.4648, -0.9443]],

        [[-6.3086,  2.2598,  5.8086,  ..., -3.2422, -1.

Epoch 0:  23%|██▎       | 1371/5879 [14:59<49:17,  1.52it/s, v_num=7]
 tensor([[[-5.3984,  1.7842,  4.1758,  ..., -3.1094, -1.9209,  0.2888],
         [-5.3945,  1.7822,  4.1719,  ..., -3.1074, -1.9199,  0.2917],
         [-0.9756, -2.1738, -5.5352,  ...,  5.5312, -4.4297, -0.9531]],

        [[-4.9531,  1.5088,  3.5996,  ..., -2.9238, -1.8027,  0.6821],
         [-4.9570,  1.5088,  3.5996,  ..., -2.9238, -1.8037,  0.6821],
         [-0.9282, -2.2324, -5.4844,  ...,  5.4258, -4.3828, -0.9600]],

        [[-7.4062,  2.9141,  6.9062,  ..., -4.0430, -2.2598, -2.0117],
         [-7.4062,  2.9141,  6.9062,  ..., -4.0430, -2.2598, -2.0137],
         [-1.1094, -2.0254, -5.6406,  ...,  5.8281, -4.5000, -0.9272]],

        ...,

        [[-7.3008,  2.8652,  6.7383,  ..., -3.9961, -2.2793, -1.8018],
         [-7.3047,  2.8652,  6.7422,  ..., -3.9980, -2.2793, -1.8057],
         [-1.1172, -2.0195, -5.6445,  ...,  5.8477, -4.4961, -0.9263]],

        [[-5.0898,  1.5947,  3.7812,  ..., -2.9707, -1.

Epoch 0:  23%|██▎       | 1377/5879 [15:01<49:06,  1.53it/s, v_num=7]
 tensor([[[-7.3789,  2.9062,  6.9414,  ..., -4.0312, -2.2246, -2.1504],
         [-7.3828,  2.9062,  6.9453,  ..., -4.0312, -2.2227, -2.1582],
         [-1.1123, -2.0391, -5.6445,  ...,  5.8594, -4.5156, -0.9087]],

        [[-7.2109,  2.8301,  6.6523,  ..., -3.9473, -2.2852, -1.7285],
         [-7.2148,  2.8320,  6.6523,  ..., -3.9492, -2.2852, -1.7314],
         [-1.1113, -2.0391, -5.6445,  ...,  5.8555, -4.5156, -0.9082]],

        [[-1.4307, -1.5312, -4.8672,  ...,  4.8555, -3.8359, -1.4043],
         [-1.4307, -1.5322, -4.8672,  ...,  4.8555, -3.8359, -1.4033],
         [-1.1182, -2.0430, -5.6172,  ...,  5.8320, -4.4688, -0.9229]],

        ...,

        [[-7.2461,  2.8496,  6.7031,  ..., -3.9668, -2.2832, -1.7852],
         [-7.2461,  2.8496,  6.7031,  ..., -3.9668, -2.2832, -1.7861],
         [-1.0928, -2.0547, -5.6367,  ...,  5.8125, -4.5234, -0.9121]],

        [[-7.2852,  2.8691,  6.7578,  ..., -3.9863, -2.

Epoch 0:  24%|██▎       | 1383/5879 [15:07<49:10,  1.52it/s, v_num=7]
 tensor([[[-2.8438, -0.1920,  0.9224,  ..., -1.9893, -1.0771,  2.3574],
         [-2.8770, -0.1571,  0.9673,  ..., -2.0000, -1.0908,  2.3340],
         [-0.9263, -2.2578, -5.5000,  ...,  5.4570, -4.4414, -0.9302]],

        [[-7.3516,  2.8906,  6.9805,  ..., -4.0234, -2.1582, -2.3984],
         [-7.3516,  2.8906,  6.9805,  ..., -4.0234, -2.1582, -2.3984],
         [-1.1016, -2.0625, -5.6445,  ...,  5.8633, -4.5391, -0.8896]],

        [[-7.1602,  2.8184,  6.6211,  ..., -3.9219, -2.2852, -1.7197],
         [-7.1602,  2.8184,  6.6211,  ..., -3.9219, -2.2852, -1.7207],
         [-1.0830, -2.0762, -5.6367,  ...,  5.8242, -4.5469, -0.8931]],

        ...,

        [[-0.3110, -6.6719, -0.8198,  ..., -2.3086,  0.7173,  1.2773],
         [-0.3110, -6.6719, -0.8198,  ..., -2.3086,  0.7173,  1.2773],
         [ 0.1537, -5.9141, -1.6689,  ..., -0.6040, -0.7153, -0.9385]],

        [[-7.3594,  2.9160,  6.9336,  ..., -4.0312, -2.

Epoch 0:  24%|██▎       | 1389/5879 [15:09<48:59,  1.53it/s, v_num=7]
 tensor([[[-5.6602,  1.9521,  4.6172,  ..., -3.1309, -1.9961, -0.0086],
         [-5.6875,  1.9697,  4.6523,  ..., -3.1445, -2.0039, -0.0376],
         [-0.9927, -2.1875, -5.5625,  ...,  5.6367, -4.5234, -0.8950]],

        [[-0.2825, -6.9922, -0.6890,  ..., -2.4375,  0.7656,  0.9121],
         [-0.2827, -6.9922, -0.6890,  ..., -2.4375,  0.7661,  0.9131],
         [-0.0430, -7.1797, -0.7256,  ..., -2.1582,  0.3311, -0.4761]],

        [[-1.5635, -1.9414, -0.6499,  ..., -1.6309, -0.3894,  2.9746],
         [-1.5820, -1.9082, -0.6294,  ..., -1.6328, -0.4009,  2.9707],
         [-0.8032, -2.4219, -5.3789,  ...,  5.1875, -4.3594, -0.9453]],

        ...,

        [[-5.9492,  2.1270,  4.9961,  ..., -3.2812, -2.0684, -0.3191],
         [-5.9531,  2.1270,  5.0000,  ..., -3.2812, -2.0684, -0.3198],
         [-1.0732, -2.0996, -5.6328,  ...,  5.8281, -4.5703, -0.8755]],

        [[-7.1953,  2.8047,  6.8789,  ..., -3.9414, -2.

Epoch 0:  24%|██▎       | 1395/5879 [15:10<48:47,  1.53it/s, v_num=7]
 tensor([[[-5.3906,  1.7900,  4.2930,  ..., -2.9688, -1.9277,  0.2411],
         [-5.3945,  1.7910,  4.2969,  ..., -2.9707, -1.9277,  0.2389],
         [-1.0449, -2.1387, -5.6211,  ...,  5.7930, -4.5898, -0.8633]],

        [[-7.1992,  2.8613,  6.7656,  ..., -3.9492, -2.2930, -1.9355],
         [-7.1992,  2.8613,  6.7656,  ..., -3.9492, -2.2930, -1.9355],
         [-1.0869, -2.0996, -5.6367,  ...,  5.8906, -4.5859, -0.8521]],

        [[-2.3750, -0.7627,  0.3291,  ..., -1.7490, -0.8296,  2.6738],
         [-2.3750, -0.7617,  0.3303,  ..., -1.7490, -0.8301,  2.6738],
         [-0.7197, -2.5449, -5.2617,  ...,  4.9688, -4.2695, -0.9565]],

        ...,

        [[-4.2617,  1.0215,  2.8008,  ..., -2.4453, -1.5859,  1.2910],
         [-4.2656,  1.0244,  2.8066,  ..., -2.4473, -1.5869,  1.2871],
         [-0.9512, -2.2461, -5.5273,  ...,  5.5625, -4.5234, -0.8906]],

        [[-7.2148,  2.8711,  6.7852,  ..., -3.9609, -2.

Epoch 0:  24%|██▍       | 1401/5879 [15:12<48:38,  1.53it/s, v_num=7]
 tensor([[[-4.5078,  1.1953,  3.1562,  ..., -2.4980, -1.6611,  1.0859],
         [-4.5078,  1.1963,  3.1582,  ..., -2.4980, -1.6611,  1.0850],
         [-0.9619, -2.2422, -5.5469,  ...,  5.6211, -4.5664, -0.8691]],

        [[-6.4922,  2.4668,  5.8047,  ..., -3.5332, -2.2051, -0.9902],
         [-6.4492,  2.4414,  5.7461,  ..., -3.5098, -2.1953, -0.9390],
         [-1.0605, -2.1348, -5.6250,  ...,  5.8594, -4.6172, -0.8374]],

        [[-7.2695,  2.9062,  6.9609,  ..., -4.0078, -2.2383, -2.3379],
         [-7.2695,  2.9062,  6.9609,  ..., -4.0078, -2.2363, -2.3379],
         [-1.0781, -2.1211, -5.6367,  ...,  5.9023, -4.6133, -0.8330]],

        ...,

        [[-7.2109,  2.8809,  6.8398,  ..., -3.9629, -2.2871, -2.0703],
         [-7.2109,  2.8828,  6.8398,  ..., -3.9629, -2.2871, -2.0703],
         [-1.0723, -2.1230, -5.6328,  ...,  5.8906, -4.6172, -0.8330]],

        [[-6.9844,  2.7539,  6.4883,  ..., -3.8203, -2.

Epoch 0:  24%|██▍       | 1407/5879 [15:19<48:42,  1.53it/s, v_num=7]
 tensor([[[-7.1484,  2.8398,  6.9023,  ..., -3.9570, -2.1270, -2.5820],
         [-7.1523,  2.8438,  6.9062,  ..., -3.9609, -2.1328, -2.5723],
         [-1.0674, -2.1406, -5.6289,  ...,  5.9102, -4.6406, -0.8120]],

        [[-7.1680,  2.8730,  6.8203,  ..., -3.9453, -2.2949, -2.0742],
         [-7.1680,  2.8730,  6.8203,  ..., -3.9453, -2.2949, -2.0742],
         [-1.0742, -2.1348, -5.6289,  ...,  5.9180, -4.6289, -0.8115]],

        [[-4.8320,  1.4150,  3.6035,  ..., -2.6191, -1.7588,  0.7734],
         [-4.8398,  1.4219,  3.6172,  ..., -2.6230, -1.7617,  0.7646],
         [-1.0439, -2.1621, -5.6133,  ...,  5.8477, -4.6367, -0.8203]],

        ...,

        [[-6.8594,  2.6953,  6.3516,  ..., -3.7461, -2.2832, -1.5215],
         [-6.8594,  2.6953,  6.3516,  ..., -3.7461, -2.2832, -1.5225],
         [-1.0654, -2.1406, -5.6289,  ...,  5.8984, -4.6406, -0.8135]],

        [[-6.7812,  2.6484,  6.2422,  ..., -3.6992, -2.

Epoch 0:  24%|██▍       | 1413/5879 [15:21<48:32,  1.53it/s, v_num=7]
 tensor([[[-6.8789,  2.7188,  6.4219,  ..., -3.7598, -2.2891, -1.6201],
         [-6.8828,  2.7188,  6.4219,  ..., -3.7598, -2.2891, -1.6221],
         [-1.0166, -2.2012, -5.5898,  ...,  5.8008, -4.6523, -0.8076]],

        [[-6.2617,  2.3418,  5.5586,  ..., -3.3730, -2.1660, -0.7891],
         [-6.2578,  2.3379,  5.5547,  ..., -3.3691, -2.1660, -0.7827],
         [-1.0352, -2.1816, -5.6094,  ...,  5.8516, -4.6641, -0.8018]],

        [[-7.0625,  2.8242,  6.6875,  ..., -3.8789, -2.3105, -1.9160],
         [-7.0625,  2.8242,  6.6875,  ..., -3.8789, -2.3105, -1.9160],
         [-1.0596, -2.1582, -5.6250,  ...,  5.9141, -4.6680, -0.7925]],

        ...,

        [[-7.0625,  2.8281,  6.7070,  ..., -3.8809, -2.2969, -1.9590],
         [-7.0664,  2.8281,  6.7070,  ..., -3.8828, -2.2949, -1.9629],
         [-1.0225, -2.1934, -5.6016,  ...,  5.8203, -4.6641, -0.8062]],

        [[-6.5781,  2.5332,  5.9883,  ..., -3.5684, -2.

Epoch 0:  24%|██▍       | 1419/5879 [15:23<48:21,  1.54it/s, v_num=7]
 tensor([[[-3.0840, -0.0569,  1.2939,  ..., -1.7188, -1.0928,  2.3125],
         [-3.0840, -0.0555,  1.2959,  ..., -1.7197, -1.0928,  2.3105],
         [-1.0107, -2.2168, -5.5859,  ...,  5.8086, -4.6758, -0.7876]],

        [[-7.1406,  2.8828,  6.8828,  ..., -3.9531, -2.2891, -2.2637],
         [-7.1406,  2.8828,  6.8828,  ..., -3.9531, -2.2891, -2.2637],
         [-1.0557, -2.1719, -5.6211,  ...,  5.9297, -4.6875, -0.7690]],

        [[-5.6328,  1.9492,  4.7305,  ..., -2.9863, -2.0039, -0.0848],
         [-5.6094,  1.9336,  4.6992,  ..., -2.9727, -1.9971, -0.0585],
         [-0.9868, -2.2422, -5.5703,  ...,  5.7500, -4.6680, -0.7974]],

        ...,

        [[-1.1455, -2.9590, -1.0137,  ..., -1.4092,  0.0740,  3.0273],
         [-1.1455, -2.9590, -1.0137,  ..., -1.4092,  0.0737,  3.0273],
         [-0.8682, -2.3809, -5.4336,  ...,  5.4336, -4.5508, -0.8413]],

        [[-6.9609,  2.7773,  6.5781,  ..., -3.8145, -2.

Epoch 0:  24%|██▍       | 1425/5879 [15:24<48:09,  1.54it/s, v_num=7]
 tensor([[[-6.2148,  2.3262,  5.5664,  ..., -3.3184, -2.1680, -0.8027],
         [-6.2188,  2.3281,  5.5703,  ..., -3.3203, -2.1699, -0.8076],
         [-0.9438, -2.2988, -5.5273,  ...,  5.6602, -4.6641, -0.7935]],

        [[-7.1055,  2.8672,  6.9297,  ..., -3.9688, -2.2168, -2.5859],
         [-7.1055,  2.8672,  6.9297,  ..., -3.9688, -2.2168, -2.5859],
         [-1.0586, -2.1816, -5.6094,  ...,  5.9531, -4.6992, -0.7451]],

        [[-5.6680,  1.9717,  4.8086,  ..., -2.9824, -2.0176, -0.1472],
         [-5.6719,  1.9736,  4.8125,  ..., -2.9844, -2.0176, -0.1500],
         [-1.0195, -2.2168, -5.5898,  ...,  5.8555, -4.7070, -0.7603]],

        ...,

        [[-2.7227,  0.5435,  0.7046,  ..., -0.4050, -1.0254, -3.0312],
         [-2.6504,  0.4939,  0.5508,  ..., -0.3044, -1.0508, -3.0000],
         [-1.0605, -2.1836, -5.5938,  ...,  5.9375, -4.6680, -0.7539]],

        [[-7.0820,  2.8516,  6.9180,  ..., -3.9629, -2.

Epoch 0:  24%|██▍       | 1431/5879 [15:37<48:35,  1.53it/s, v_num=7]
 tensor([[[-2.9258, -0.2644,  1.1006,  ..., -1.5557, -0.9727,  2.4648],
         [-2.8438, -0.3567,  0.9863,  ..., -1.5264, -0.9316,  2.5234],
         [-0.6050, -2.7441, -5.0898,  ...,  4.6953, -4.2773, -0.9282]],

        [[-6.0195,  2.1992,  5.3164,  ..., -3.1855, -2.1191, -0.5913],
         [-6.0195,  2.2012,  5.3164,  ..., -3.1855, -2.1211, -0.5928],
         [-1.0195, -2.2266, -5.5859,  ...,  5.8750, -4.7344, -0.7354]],

        [[-6.5703,  2.5566,  6.0977,  ..., -3.5488, -2.2637, -1.3262],
         [-6.5703,  2.5566,  6.0977,  ..., -3.5488, -2.2637, -1.3262],
         [-1.0098, -2.2363, -5.5820,  ...,  5.8555, -4.7344, -0.7397]],

        ...,

        [[-5.0430,  1.5508,  3.9883,  ..., -2.5938, -1.8203,  0.5186],
         [-5.0430,  1.5518,  3.9902,  ..., -2.5938, -1.8203,  0.5171],
         [-0.9844, -2.2637, -5.5508,  ...,  5.7812, -4.7109, -0.7515]],

        [[-5.3047,  1.7324,  4.3555,  ..., -2.7324, -1.

Epoch 0:  24%|██▍       | 1437/5879 [15:39<48:23,  1.53it/s, v_num=7]
 tensor([[[-5.5781,  1.9189,  4.7539,  ..., -2.8828, -1.9941, -0.0895],
         [-5.5859,  1.9238,  4.7656,  ..., -2.8867, -1.9971, -0.0968],
         [-0.9902, -2.2656, -5.5547,  ...,  5.8125, -4.7422, -0.7261]],

        [[-2.5137, -0.7891,  0.5312,  ..., -1.3516, -0.7295,  2.7715],
         [-2.5176, -0.7861,  0.5342,  ..., -1.3525, -0.7310,  2.7695],
         [-0.8135, -2.4707, -5.3750,  ...,  5.3281, -4.5938, -0.8145]],

        [[-6.3125,  2.4004,  5.7656,  ..., -3.3691, -2.2090, -1.0156],
         [-6.3203,  2.4023,  5.7734,  ..., -3.3730, -2.2090, -1.0225],
         [-0.9980, -2.2578, -5.5664,  ...,  5.8398, -4.7539, -0.7217]],

        ...,

        [[-6.8711,  2.7578,  6.5781,  ..., -3.7773, -2.3203, -1.8984],
         [-6.8711,  2.7559,  6.5742,  ..., -3.7754, -2.3203, -1.8926],
         [-1.0312, -2.2246, -5.5898,  ...,  5.9297, -4.7578, -0.7046]],

        [[-6.6797,  2.6348,  6.2891,  ..., -3.6289, -2.

Epoch 0:  25%|██▍       | 1443/5879 [15:41<48:13,  1.53it/s, v_num=7]
 tensor([[[-6.9727,  2.8340,  6.7930,  ..., -3.8809, -2.3262, -2.2598],
         [-6.9727,  2.8340,  6.7930,  ..., -3.8809, -2.3262, -2.2617],
         [-1.0400, -2.2266, -5.5859,  ...,  5.9727, -4.7773, -0.6743]],

        [[-6.9414,  2.8105,  6.8516,  ..., -3.9277, -2.1855, -2.7637],
         [-6.9414,  2.8105,  6.8516,  ..., -3.9258, -2.1875, -2.7637],
         [-1.0361, -2.2305, -5.5859,  ...,  5.9609, -4.7812, -0.6763]],

        [[-6.8203,  2.7402,  6.7500,  ..., -3.8867, -2.0762, -2.9844],
         [-6.8203,  2.7402,  6.7500,  ..., -3.8867, -2.0781, -2.9844],
         [-1.0410, -2.2266, -5.5859,  ...,  5.9766, -4.7734, -0.6733]],

        ...,

        [[-6.0586,  2.2383,  5.4453,  ..., -3.1836, -2.1445, -0.7153],
         [-6.0547,  2.2363,  5.4414,  ..., -3.1816, -2.1426, -0.7129],
         [-1.0176, -2.2461, -5.5703,  ...,  5.9102, -4.7773, -0.6860]],

        [[-7.0039,  2.8516,  6.8438,  ..., -3.9102, -2.

Epoch 0:  25%|██▍       | 1449/5879 [15:43<48:03,  1.54it/s, v_num=7]
 tensor([[[-0.5737, -4.7891, -1.0811,  ..., -1.8027,  0.7173,  2.2930],
         [-0.5737, -4.7852, -1.0811,  ..., -1.8027,  0.7173,  2.2930],
         [-0.8882, -2.3945, -5.4414,  ...,  5.5625, -4.7227, -0.7349]],

        [[-6.8867,  2.7891,  6.6875,  ..., -3.8145, -2.3418, -2.1094],
         [-6.8867,  2.7891,  6.6875,  ..., -3.8145, -2.3418, -2.1094],
         [-1.0332, -2.2422, -5.5742,  ...,  5.9766, -4.8008, -0.6509]],

        [[-6.9805,  2.8496,  6.8906,  ..., -3.9434, -2.2793, -2.6250],
         [-6.9805,  2.8496,  6.8906,  ..., -3.9434, -2.2793, -2.6270],
         [-1.0381, -2.2383, -5.5742,  ...,  5.9844, -4.8008, -0.6494]],

        ...,

        [[-6.1406,  2.2949,  5.5859,  ..., -3.2344, -2.1758, -0.8701],
         [-6.1719,  2.3145,  5.6289,  ..., -3.2559, -2.1836, -0.9106],
         [-1.0264, -2.2480, -5.5703,  ...,  5.9570, -4.8047, -0.6543]],

        [[-6.5859,  2.5898,  6.2266,  ..., -3.5664, -2.

Epoch 0:  25%|██▍       | 1455/5879 [15:55<48:25,  1.52it/s, v_num=7]
 tensor([[[-6.2734,  2.3887,  5.8047,  ..., -3.3242, -2.2227, -1.0938],
         [-6.2734,  2.3887,  5.8047,  ..., -3.3262, -2.2246, -1.0967],
         [-1.0342, -2.2500, -5.5586,  ...,  5.9844, -4.8203, -0.6250]],

        [[-6.3008,  2.4062,  5.8438,  ..., -3.3438, -2.2305, -1.1299],
         [-6.3008,  2.4062,  5.8477,  ..., -3.3438, -2.2305, -1.1318],
         [-0.9985, -2.2832, -5.5391,  ...,  5.8906, -4.8242, -0.6455]],

        [[-5.2734,  1.7021,  4.4062,  ..., -2.6074, -1.8936,  0.2200],
         [-5.2734,  1.7021,  4.4062,  ..., -2.6074, -1.8936,  0.2200],
         [-0.9473, -2.3379, -5.4922,  ...,  5.7461, -4.7969, -0.6772]],

        ...,

        [[-6.8633,  2.7832,  6.6875,  ..., -3.8086, -2.3516, -2.1426],
         [-6.8633,  2.7832,  6.6875,  ..., -3.8086, -2.3516, -2.1426],
         [-1.0361, -2.2480, -5.5625,  ...,  5.9961, -4.8203, -0.6235]],

        [[-4.4648,  1.1055,  3.2949,  ..., -2.0898, -1.

Epoch 0:  25%|██▍       | 1461/5879 [15:56<48:13,  1.53it/s, v_num=7]
 tensor([[[-2.2539, -1.2646,  0.1733,  ..., -1.0439, -0.4331,  3.0234],
         [-2.2539, -1.2637,  0.1746,  ..., -1.0439, -0.4338,  3.0215],
         [-0.8755, -2.4238, -5.3984,  ...,  5.5352, -4.7539, -0.6997]],

        [[-6.7461,  2.7109,  6.5391,  ..., -3.7109, -2.3496, -1.9561],
         [-6.7500,  2.7168,  6.5469,  ..., -3.7168, -2.3496, -1.9717],
         [-1.0225, -2.2676, -5.5430,  ...,  5.9727, -4.8477, -0.6035]],

        [[-6.7969,  2.7480,  6.6211,  ..., -3.7598, -2.3594, -2.0645],
         [-6.7969,  2.7461,  6.6172,  ..., -3.7598, -2.3594, -2.0605],
         [-1.0361, -2.2578, -5.5469,  ...,  6.0039, -4.8359, -0.5977]],

        ...,

        [[-5.7656,  2.0430,  5.1133,  ..., -2.9355, -2.0703, -0.4182],
         [-5.7734,  2.0488,  5.1250,  ..., -2.9395, -2.0723, -0.4285],
         [-0.9160, -2.3789, -5.4492,  ...,  5.6641, -4.8008, -0.6729]],

        [[-6.8281,  2.7695,  6.6719,  ..., -3.7949, -2.

Epoch 0:  25%|██▍       | 1467/5879 [15:58<48:03,  1.53it/s, v_num=7]
 tensor([[[-6.7578,  2.7266,  6.5938,  ..., -3.7344, -2.3633, -2.0625],
         [-6.7578,  2.7305,  6.6016,  ..., -3.7383, -2.3633, -2.0723],
         [-1.0283, -2.2715, -5.5312,  ...,  6.0000, -4.8672, -0.5732]],

        [[-2.8711, -0.4683,  1.0381,  ..., -1.1650, -0.7886,  2.6582],
         [-2.8711, -0.4683,  1.0381,  ..., -1.1650, -0.7886,  2.6582],
         [-0.9565, -2.3418, -5.4766,  ...,  5.7930, -4.8516, -0.6221]],

        [[-5.3047,  1.7236,  4.5039,  ..., -2.5801, -1.9092,  0.1488],
         [-5.3281,  1.7402,  4.5352,  ..., -2.5957, -1.9180,  0.1202],
         [-0.9951, -2.3008, -5.5117,  ...,  5.9062, -4.8711, -0.5952]],

        ...,

        [[-6.0117,  2.3457,  5.9492,  ..., -3.5312, -1.7256, -3.5723],
         [-6.0078,  2.3438,  5.9453,  ..., -3.5293, -1.7256, -3.5723],
         [-1.0361, -2.2656, -5.5273,  ...,  6.0117, -4.8477, -0.5723]],

        [[-4.0859,  0.7729,  2.7910,  ..., -1.7725, -1.

Epoch 0:  25%|██▌       | 1473/5879 [16:00<47:53,  1.53it/s, v_num=7]
 tensor([[[-4.5859,  1.1758,  3.5137,  ..., -2.0449, -1.6152,  1.0088],
         [-4.5586,  1.1533,  3.4727,  ..., -2.0273, -1.6035,  1.0400],
         [-0.9790, -2.3242, -5.4805,  ...,  5.8672, -4.8867, -0.5801]],

        [[-3.6641,  1.1699,  2.6895,  ..., -1.8418, -1.0303, -3.6777],
         [-3.6328,  1.1543,  2.6387,  ..., -1.8145, -1.0283, -3.6699],
         [-1.0361, -2.2715, -5.5039,  ...,  6.0078, -4.8516, -0.5508]],

        [[-6.8438,  2.8027,  6.8164,  ..., -3.8828, -2.3496, -2.5762],
         [-6.8438,  2.8027,  6.8164,  ..., -3.8848, -2.3477, -2.5840],
         [-1.0283, -2.2773, -5.5156,  ...,  6.0117, -4.8867, -0.5454]],

        ...,

        [[-6.5156,  2.5684,  6.2617,  ..., -3.5215, -2.3242, -1.6455],
         [-6.5156,  2.5684,  6.2617,  ..., -3.5215, -2.3242, -1.6465],
         [-0.9790, -2.3242, -5.4805,  ...,  5.8672, -4.8867, -0.5801]],

        [[-2.0938,  0.1248, -0.4751,  ...,  0.1052, -1.

Epoch 0:  25%|██▌       | 1479/5879 [16:12<48:12,  1.52it/s, v_num=7]
 tensor([[[-4.8164,  1.3486,  3.8574,  ..., -2.1699, -1.7080,  0.7339],
         [-4.7539,  1.3018,  3.7715,  ..., -2.1270, -1.6816,  0.8071],
         [-0.9668, -2.3398, -5.4531,  ...,  5.8398, -4.9023, -0.5635]],

        [[-5.6016,  1.9297,  4.9688,  ..., -2.7578, -2.0273, -0.2791],
         [-5.6055,  1.9336,  4.9766,  ..., -2.7617, -2.0273, -0.2859],
         [-1.0078, -2.3008, -5.4883,  ...,  5.9609, -4.9180, -0.5322]],

        [[-4.7109,  1.2656,  3.7012,  ..., -2.0996, -1.6631,  0.8555],
         [-4.7109,  1.2686,  3.7051,  ..., -2.1016, -1.6641,  0.8521],
         [-0.9399, -2.3672, -5.4258,  ...,  5.7578, -4.8867, -0.5840]],

        ...,

        [[-4.7539,  1.2998,  3.7676,  ..., -2.1250, -1.6807,  0.8076],
         [-4.7578,  1.3037,  3.7754,  ..., -2.1289, -1.6836,  0.8018],
         [-0.9746, -2.3320, -5.4609,  ...,  5.8633, -4.9062, -0.5576]],

        [[-2.7754, -0.6514,  0.8901,  ..., -0.9946, -0.

Epoch 0:  25%|██▌       | 1485/5879 [16:13<48:01,  1.53it/s, v_num=7]
 tensor([[[-4.5703,  1.1455,  3.5176,  ..., -1.9668, -1.5938,  1.0205],
         [-4.5625,  1.1396,  3.5059,  ..., -1.9619, -1.5908,  1.0293],
         [-0.9624, -2.3496, -5.4336,  ...,  5.8320, -4.9219, -0.5415]],

        [[-3.1680, -0.1967,  1.4648,  ..., -1.0996, -0.8750,  2.5137],
         [-3.0586, -0.3267,  1.2988,  ..., -1.0488, -0.8086,  2.6074],
         [-0.7383, -2.6035, -5.1719,  ...,  5.1367, -4.7031, -0.7124]],

        [[-6.4531,  2.5332,  6.2266,  ..., -3.4824, -2.3359, -1.6553],
         [-6.4570,  2.5352,  6.2344,  ..., -3.4863, -2.3359, -1.6621],
         [-0.9941, -2.3184, -5.4609,  ...,  5.9297, -4.9336, -0.5166]],

        ...,

        [[-6.7188,  2.7285,  6.6641,  ..., -3.7637, -2.3887, -2.3047],
         [-6.7188,  2.7285,  6.6641,  ..., -3.7637, -2.3887, -2.3047],
         [-1.0234, -2.2930, -5.4805,  ...,  6.0156, -4.9336, -0.4941]],

        [[-6.7383,  2.7656,  6.8047,  ..., -3.9180, -2.

Epoch 0:  25%|██▌       | 1491/5879 [16:15<47:49,  1.53it/s, v_num=7]
 tensor([[[-4.2461,  0.8618,  3.0684,  ..., -1.6924, -1.4287,  1.4170],
         [-4.2578,  0.8711,  3.0840,  ..., -1.6992, -1.4336,  1.4053],
         [-0.9268, -2.3906, -5.3789,  ...,  5.7305, -4.9219, -0.5454]],

        [[-6.7344,  2.7637,  6.7930,  ..., -3.8848, -2.3574, -2.7793],
         [-6.7344,  2.7637,  6.7930,  ..., -3.8848, -2.3594, -2.7812],
         [-1.0254, -2.2969, -5.4648,  ...,  6.0273, -4.9531, -0.4653]],

        [[-5.5195,  1.8623,  4.8945,  ..., -2.6641, -2.0000, -0.2107],
         [-5.6055,  1.9238,  5.0156,  ..., -2.7324, -2.0352, -0.3296],
         [-0.9800, -2.3359, -5.4336,  ...,  5.8945, -4.9531, -0.5010]],

        ...,

        [[-4.9102,  1.4062,  4.0273,  ..., -2.1777, -1.7393,  0.6079],
         [-4.8555,  1.3652,  3.9512,  ..., -2.1387, -1.7158,  0.6758],
         [-0.9385, -2.3770, -5.3906,  ...,  5.7656, -4.9297, -0.5352]],

        [[-6.5625,  2.6777,  6.6680,  ..., -3.8945, -2.

Epoch 0:  25%|██▌       | 1497/5879 [16:17<47:40,  1.53it/s, v_num=7]
 tensor([[[-2.1816, -1.5596,  0.0429,  ..., -0.6914, -0.1299,  3.2109],
         [-2.1855, -1.5557,  0.0464,  ..., -0.6919, -0.1316,  3.2090],
         [-0.8013, -2.5312, -5.2148,  ...,  5.3438, -4.8203, -0.6235]],

        [[-6.5273,  2.6016,  6.4258,  ..., -3.5879, -2.3828, -1.9746],
         [-6.5273,  2.6016,  6.4258,  ..., -3.5879, -2.3828, -1.9775],
         [-1.0166, -2.3066, -5.4414,  ...,  6.0117, -4.9727, -0.4446]],

        [[-5.0156,  1.4834,  4.1992,  ..., -2.2383, -1.7842,  0.4492],
         [-5.0195,  1.4844,  4.1992,  ..., -2.2402, -1.7842,  0.4473],
         [-0.9873, -2.3340, -5.4180,  ...,  5.9219, -4.9766, -0.4695]],

        ...,

        [[-5.4922,  1.8447,  4.8906,  ..., -2.6270, -1.9941, -0.2035],
         [-5.5664,  1.9004,  5.0000,  ..., -2.6895, -2.0254, -0.3118],
         [-0.9707, -2.3496, -5.4023,  ...,  5.8711, -4.9688, -0.4844]],

        [[-3.9043,  1.3115,  3.1680,  ..., -2.2637, -1.

Epoch 0:  26%|██▌       | 1503/5879 [16:23<47:43,  1.53it/s, v_num=7]
 tensor([[[-6.6172,  2.7188,  6.7578,  ..., -3.9199, -2.3164, -3.1367],
         [-6.6172,  2.7188,  6.7578,  ..., -3.9199, -2.3164, -3.1367],
         [-1.0312, -2.2988, -5.4219,  ...,  6.0469, -4.9766, -0.4094]],

        [[-6.6523,  2.7168,  6.7148,  ..., -3.8105, -2.4043, -2.6055],
         [-6.6523,  2.7148,  6.7070,  ..., -3.8027, -2.4062, -2.5801],
         [-1.0156, -2.3125, -5.4219,  ...,  6.0117, -4.9961, -0.4192]],

        [[-6.2188,  2.3789,  5.9883,  ..., -3.2773, -2.2949, -1.4150],
         [-6.2109,  2.3730,  5.9766,  ..., -3.2715, -2.2930, -1.4023],
         [-1.0156, -2.3125, -5.4219,  ...,  6.0117, -4.9961, -0.4194]],

        ...,

        [[-4.6484,  1.1787,  3.6836,  ..., -1.9160, -1.6045,  0.9307],
         [-4.6484,  1.1797,  3.6855,  ..., -1.9160, -1.6045,  0.9292],
         [-0.9561, -2.3672, -5.3672,  ...,  5.8242, -4.9766, -0.4729]],

        [[-6.2539,  2.4043,  6.0391,  ..., -3.3145, -2.

Epoch 0:  26%|██▌       | 1509/5879 [16:24<47:32,  1.53it/s, v_num=7]
 tensor([[[-6.5547e+00,  2.6426e+00,  6.5703e+00,  ..., -3.6816e+00,
          -2.4141e+00, -2.3047e+00],
         [-6.5547e+00,  2.6426e+00,  6.5703e+00,  ..., -3.6836e+00,
          -2.4141e+00, -2.3105e+00],
         [-9.9805e-01, -2.3301e+00, -5.3906e+00,  ...,  5.9648e+00,
          -5.0156e+00, -4.0845e-01]],

        [[-4.0508e+00,  6.4844e-01,  2.8066e+00,  ..., -1.4316e+00,
          -1.2861e+00,  1.6670e+00],
         [-4.0508e+00,  6.4990e-01,  2.8105e+00,  ..., -1.4326e+00,
          -1.2871e+00,  1.6650e+00],
         [-9.6240e-01, -2.3633e+00, -5.3594e+00,  ...,  5.8555e+00,
          -5.0039e+00, -4.4141e-01]],

        [[-8.7793e-01, -3.8457e+00, -9.7607e-01,  ..., -1.1963e+00,
           8.7793e-01,  2.7285e+00],
         [-8.8086e-01, -3.8379e+00, -9.7559e-01,  ..., -1.1914e+00,
           8.7598e-01,  2.7324e+00],
         [ 4.3488e-03, -3.9551e+00, -3.3320e+00,  ...,  1.7607e+00,
          -2.9805

Epoch 0:  26%|██▌       | 1515/5879 [16:26<47:22,  1.54it/s, v_num=7]
 tensor([[[-5.1172,  1.5420,  4.4102,  ..., -2.2461, -1.8291,  0.2712],
         [-5.0039,  1.4531,  4.2461,  ..., -2.1504, -1.7744,  0.4304],
         [-1.0137, -2.3184, -5.3750,  ...,  6.0078, -5.0273, -0.3701]],

        [[-5.9453,  2.1758,  5.6367,  ..., -3.0137, -2.2129, -1.0264],
         [-5.9492,  2.1777,  5.6406,  ..., -3.0156, -2.2148, -1.0303],
         [-1.0186, -2.3164, -5.3828,  ...,  6.0312, -5.0312, -0.3640]],

        [[-5.9180,  2.3613,  6.0352,  ..., -3.7012, -1.9521, -3.7988],
         [-5.9180,  2.3613,  6.0352,  ..., -3.7012, -1.9521, -3.8008],
         [-1.0312, -2.3066, -5.3867,  ...,  6.0586, -5.0156, -0.3564]],

        ...,

        [[-1.7852, -2.2578, -0.4387,  ..., -0.5542,  0.3142,  3.3262],
         [-1.7998, -2.2344, -0.4241,  ..., -0.5508,  0.3030,  3.3262],
         [-0.7036, -2.6543, -5.0234,  ...,  5.0234, -4.7656, -0.6548]],

        [[-1.7100, -2.3750, -0.5083,  ..., -0.5679,  0.

In [46]:
!nvidia-smi

Tue Nov 21 13:54:58 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.106.00   Driver Version: 460.106.00   CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  A100-SXM-80GB       Off  | 00000000:01:00.0 Off |                    0 |
| N/A   28C    P0    62W / 500W |      0MiB / 81251MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  A100-SXM-80GB       Off  | 00000000:41:00.0 Off |                    0 |
| N/A   26C    P0    59W / 500W |      0MiB / 81251MiB |      0%      Default |
|       

In [48]:
dist.is_available()

True