In [3]:
#@markdown ### **Installing pip packages**
#@markdown - Diffusion Model: [PyTorch](https://pytorch.org) & [HuggingFace diffusers](https://huggingface.co/docs/diffusers/index)
#@markdown - Dataset Loading: [Zarr](https://zarr.readthedocs.io/en/stable/) & numcodecs
#@markdown - Push-T Env: gym, pygame, pymunk & shapely
!python --version
!pip3 install torch==1.13.1 torchvision==0.14.1 diffusers==0.18.2 \
scikit-image==0.19.3 scikit-video==1.1.11 zarr==2.12.0 numcodecs==0.10.2 \
pygame==2.1.2 pymunk==6.2.1 gym==0.26.2 shapely==1.8.4 \
&> /dev/null # mute output

Python 3.9.15


In [4]:
#@markdown ### **Imports**
# diffusion policy import
from typing import Tuple, Sequence, Dict, Union, Optional
import numpy as np
import math
import torch
import torch.nn as nn
import collections
import zarr
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

# env import
import gym
from gym import spaces
import pygame
import pymunk
import pymunk.pygame_util
from pymunk.space_debug_draw_options import SpaceDebugColor
from pymunk.vec2d import Vec2d
import shapely.geometry as sg
import cv2
import skimage.transform as st
from skvideo.io import vwrite
from IPython.display import Video
import gdown
import os

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
#@markdown ### **Environment**
#@markdown Defines a PyMunk-based Push-T environment `PushTEnv`.
#@markdown
#@markdown **Goal**: push the gray T-block into the green area.
#@markdown
#@markdown Adapted from [Implicit Behavior Cloning](https://implicitbc.github.io/)


# Module-level switch read by to_pygame(): when False, pymunk coordinates are
# used as pygame coordinates unchanged (y grows downward); when True, y is
# flipped against the surface height.
positive_y_is_up: bool = False
"""Make increasing values of y point upwards.

When True::

    y
    ^
    |      . (3, 3)
    |
    |   . (2, 2)
    |
    +------ > x

When False::

    +------ > x
    |
    |   . (2, 2)
    |
    |      . (3, 3)
    v
    y

"""

def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
    """Map a pymunk point to integer pygame surface coordinates.

    Both components are rounded to ints. The y component is additionally
    mirrored against ``surface.get_height()`` only when the module flag
    ``positive_y_is_up`` is True; otherwise this is just rounding.
    """
    px = round(p[0])
    py = round(p[1])
    if positive_y_is_up:
        py = surface.get_height() - py
    return px, py


def light_color(color: SpaceDebugColor):
    """Return a lighter copy of ``color``.

    Every channel, alpha included, is multiplied by 1.2 in float32 and capped
    at 255.
    """
    rgba = np.float32([color.r, color.g, color.b, color.a]) * 1.2
    rgba = np.minimum(rgba, np.float32([255]))
    return SpaceDebugColor(r=rgba[0], g=rgba[1], b=rgba[2], a=rgba[3])

class DrawOptions(pymunk.SpaceDebugDrawOptions):
    """pymunk debug-draw callbacks that render a Space onto a pygame Surface.

    Compared with a plain outline renderer, circles get a lighter inner disc
    and polygons are filled with a lighter color and outlined with 2-px-radius
    fat segments.
    """

    def __init__(self, surface: pygame.Surface) -> None:
        """Draw a pymunk.Space on a pygame.Surface object.

        Typical usage::

        >>> import pymunk
        >>> surface = pygame.Surface((10,10))
        >>> space = pymunk.Space()
        >>> options = pymunk.pygame_util.DrawOptions(surface)
        >>> space.debug_draw(options)

        You can control the color of a shape by setting shape.color to the color
        you want it drawn in::

        >>> c = pymunk.Circle(None, 10)
        >>> c.color = pygame.Color("pink")

        See pygame_util.demo.py for a full example

        Since pygame uses a coordinate system where y points down (in contrast
        to many other cases), you either have to make the physics simulation
        with Pymunk also behave in that way, or flip everything when you draw.

        The easiest is probably to just make the simulation behave the same
        way as Pygame does. In that way all coordinates used are in the same
        orientation and easy to reason about::

        >>> space = pymunk.Space()
        >>> space.gravity = (0, -1000)
        >>> body = pymunk.Body()
        >>> body.position = (0, 0) # will be positioned in the top left corner
        >>> space.debug_draw(options)

        To flip the drawing its possible to set the module property
        :py:data:`positive_y_is_up` to True. Then the pygame drawing will flip
        the simulation upside down before drawing::

        >>> positive_y_is_up = True
        >>> body = pymunk.Body()
        >>> body.position = (0, 0)
        >>> # Body will be position in bottom left corner

        :Parameters:
                surface : pygame.Surface
                    Surface that the objects will be drawn on
        """
        self.surface = surface
        super(DrawOptions, self).__init__()

    def draw_circle(
        self,
        pos: Vec2d,
        angle: float,
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        """Draw a filled circle with a lighter inner disc inset by 4 px.

        ``angle`` and ``outline_color`` are part of the pymunk callback
        signature but unused: no rotation indicator line is drawn.
        """
        p = to_pygame(pos, self.surface)

        pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0)
        pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius-4), 0)

    def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None:
        """Draw a thin anti-aliased line from ``a`` to ``b``."""
        p1 = to_pygame(a, self.surface)
        p2 = to_pygame(b, self.surface)

        pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2])

    def draw_fat_segment(
        self,
        a: Tuple[float, float],
        b: Tuple[float, float],
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        """Draw a segment of half-width ``radius`` with rounded end caps.

        A thick pygame line is drawn first. For widths above 2 px it is
        reinforced by a rectangle spanning +/- ``radius`` around the segment
        and by filled circles at both endpoints.
        """
        p1 = to_pygame(a, self.surface)
        p2 = to_pygame(b, self.surface)

        r = round(max(1, radius * 2))
        pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r)
        if r > 2:
            dx = p2[0] - p1[0]
            dy = p2[1] - p1[1]
            if dx == 0 and dy == 0:
                return
            # The normal of direction (dx, dy) is (-dy, dx). Taking abs() of
            # both components instead gives a vector *parallel* to the segment
            # whenever dx*dy > 0, collapsing the rectangle below into a line.
            scale = radius / (dx * dx + dy * dy) ** 0.5
            orthog = [round(-dy * scale), round(dx * scale)]
            points = [
                (p1[0] - orthog[0], p1[1] - orthog[1]),
                (p1[0] + orthog[0], p1[1] + orthog[1]),
                (p2[0] + orthog[0], p2[1] + orthog[1]),
                (p2[0] - orthog[0], p2[1] - orthog[1]),
            ]
            pygame.draw.polygon(self.surface, fill_color.as_int(), points)
            pygame.draw.circle(
                self.surface,
                fill_color.as_int(),
                (round(p1[0]), round(p1[1])),
                round(radius),
            )
            pygame.draw.circle(
                self.surface,
                fill_color.as_int(),
                (round(p2[0]), round(p2[1])),
                round(radius),
            )

    def draw_polygon(
        self,
        verts: Sequence[Tuple[float, float]],
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        """Fill the polygon with a lighter color, then outline its edges."""
        ps = [to_pygame(v, self.surface) for v in verts]
        ps += [ps[0]]

        # The shape's own radius is deliberately ignored: every polygon is
        # outlined with a fixed 2-px-radius border.
        radius = 2
        pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps)

        if radius > 0:
            for i in range(len(verts)):
                a = verts[i]
                b = verts[(i + 1) % len(verts)]
                self.draw_fat_segment(a, b, radius, fill_color, fill_color)

    def draw_dot(
        self, size: float, pos: Tuple[float, float], color: SpaceDebugColor
    ) -> None:
        """Draw a filled dot of radius ``size`` at ``pos``."""
        p = to_pygame(pos, self.surface)
        pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0)


def pymunk_to_shapely(body, shapes):
    """Convert a body's polygon shapes into one shapely MultiPolygon.

    Vertices are mapped to world coordinates via ``body.local_to_world`` and
    each ring is closed explicitly. Any shape that is not a ``pymunk.Poly``
    raises ``RuntimeError``.
    """
    polygons = []
    for shape in shapes:
        if not isinstance(shape, pymunk.shapes.Poly):
            raise RuntimeError(f'Unsupported shape type {type(shape)}')
        ring = [body.local_to_world(v) for v in shape.get_vertices()]
        ring.append(ring[0])
        polygons.append(sg.Polygon(ring))
    return sg.MultiPolygon(polygons)

# env
class PushTEnv(gym.Env):
    """Push-T: push the T-shaped block onto the goal pose with a circular agent.

    Observation: ``[agent_x, agent_y, block_x, block_y, block_angle]`` (float64),
    with the angle wrapped to [0, 2*pi).
    Action: target agent position ``[x, y]`` in window pixels, tracked by a PD
    controller over several physics substeps.
    API follows gym 0.26: ``reset`` returns ``(obs, info)`` and ``step``
    returns ``(obs, reward, terminated, truncated, info)``.
    """
    metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10}
    reward_range = (0., 1.)

    def __init__(self,
            legacy=False,
            block_cog=None, damping=None,
            render_action=True,
            render_size=96,
            reset_to_state=None
        ):
        """
        Args:
            legacy: in ``_set_state``, set block position before angle (the
                legacy order) for compatibility with previously recorded data.
            block_cog: optional override of the block's center of gravity.
            damping: optional override of the physics space damping.
            render_action: draw a marker at the latest action in ``render``.
            render_size: side length (pixels) of the image ``render`` returns.
            reset_to_state: optional fixed state that ``reset`` uses instead
                of a seeded random one.
        """
        self._seed = None
        self.seed()
        self.window_size = ws = 512  # The size of the PyGame window
        self.render_size = render_size
        self.sim_hz = 100
        # Local controller params.
        self.k_p, self.k_v = 100, 20    # PD control gains.
        self.control_hz = self.metadata['video.frames_per_second']
        # legacy set_state for data compatibility
        self.legacy = legacy

        # agent_pos, block_pos, block_angle
        self.observation_space = spaces.Box(
            low=np.array([0,0,0,0,0], dtype=np.float64),
            high=np.array([ws,ws,ws,ws,np.pi*2], dtype=np.float64),
            shape=(5,),
            dtype=np.float64
        )

        # positional goal for agent
        self.action_space = spaces.Box(
            low=np.array([0,0], dtype=np.float64),
            high=np.array([ws,ws], dtype=np.float64),
            shape=(2,),
            dtype=np.float64
        )

        self.block_cog = block_cog
        self.damping = damping
        self.render_action = render_action

        # `self.window` / `self.clock` are created lazily by `_render_frame`
        # the first time "human" rendering is requested; otherwise they stay
        # `None`. `self.screen` holds the most recently drawn canvas.
        self.window = None
        self.clock = None
        self.screen = None

        self.space = None
        self.teleop = None
        self.render_buffer = None
        self.latest_action = None
        self.reset_to_state = reset_to_state

    def reset(self):
        """Rebuild the simulation and place the agent and block at a start state.

        The start state is ``reset_to_state`` when given. Otherwise it is
        drawn from a legacy ``np.random.RandomState`` seeded with the env
        seed, so a given seed always reproduces the same start.

        Returns:
            (obs, info)
        """
        seed = self._seed
        self._setup()
        if self.block_cog is not None:
            self.block.center_of_gravity = self.block_cog
        if self.damping is not None:
            self.space.damping = self.damping

        # use legacy RandomState for compatibility
        state = self.reset_to_state
        if state is None:
            rs = np.random.RandomState(seed=seed)
            state = np.array([
                rs.randint(50, 450), rs.randint(50, 450),
                rs.randint(100, 400), rs.randint(100, 400),
                # NOTE(review): randn (normal), not a uniform draw. Kept as-is,
                # since changing it would alter every seeded initial state.
                rs.randn() * 2 * np.pi - np.pi
                ])
        self._set_state(state)

        obs = self._get_obs()
        info = self._get_info()
        return obs, info

    def step(self, action):
        """Advance one control step toward target position ``action``.

        Runs ``sim_hz // control_hz`` physics substeps with PD control on the
        agent velocity. ``None`` skips the physics and only re-scores the
        current state. Reward is goal coverage divided by
        ``success_threshold``, clipped to [0, 1].

        Returns:
            (observation, reward, terminated, truncated, info)
        """
        dt = 1.0 / self.sim_hz
        self.n_contact_points = 0
        n_steps = self.sim_hz // self.control_hz
        if action is not None:
            self.latest_action = action
            for _ in range(n_steps):
                # Step PD control.
                # self.agent.velocity = self.k_p * (act - self.agent.position)    # P control works too.
                acceleration = self.k_p * (action - self.agent.position) + self.k_v * (Vec2d(0, 0) - self.agent.velocity)
                self.agent.velocity += acceleration * dt

                # Step physics.
                self.space.step(dt)

        # compute reward
        goal_body = self._get_goal_pose_body(self.goal_pose)
        goal_geom = pymunk_to_shapely(goal_body, self.block.shapes)
        block_geom = pymunk_to_shapely(self.block, self.block.shapes)

        intersection_area = goal_geom.intersection(block_geom).area
        goal_area = goal_geom.area
        coverage = intersection_area / goal_area
        reward = np.clip(coverage / self.success_threshold, 0, 1)
        done = coverage > self.success_threshold
        # Reaching the goal terminates the episode. Truncation is reserved for
        # time limits, which this env does not impose itself.
        terminated = done
        truncated = False

        observation = self._get_obs()
        info = self._get_info()

        return observation, reward, terminated, truncated, info

    def render(self, mode):
        """Render the current scene. Returns an RGB image of shape (render_size, render_size, 3)."""
        return self._render_frame(mode)

    def teleop_agent(self):
        """Return an object whose ``act(obs)`` follows the mouse.

        Control is taken once the cursor comes within 30 px of the agent.
        """
        TeleopAgent = collections.namedtuple('TeleopAgent', ['act'])
        def act(obs):
            act = None
            mouse_position = pymunk.pygame_util.from_pygame(Vec2d(*pygame.mouse.get_pos()), self.screen)
            if self.teleop or (mouse_position - self.agent.position).length < 30:
                self.teleop = True
                act = mouse_position
            return act
        return TeleopAgent(act)

    def _get_obs(self):
        """Return [agent_x, agent_y, block_x, block_y, block_angle mod 2*pi]."""
        obs = np.array(
            tuple(self.agent.position) \
            + tuple(self.block.position) \
            + (self.block.angle % (2 * np.pi),))
        return obs

    def _get_goal_pose_body(self, pose):
        """Build an unattached body placed at ``pose`` = (x, y, theta), used for goal geometry."""
        mass = 1
        inertia = pymunk.moment_for_box(mass, (50, 100))
        body = pymunk.Body(mass, inertia)
        # preserving the legacy assignment order for compatibility
        # the order here doesn't matter, presumably because CoM is aligned with body origin
        body.position = pose[:2].tolist()
        body.angle = pose[2]
        return body

    def _get_info(self):
        """Return a diagnostics dict: agent pos/vel, block pose, goal pose, contacts per substep."""
        n_steps = self.sim_hz // self.control_hz
        n_contact_points_per_step = int(np.ceil(self.n_contact_points / n_steps))
        info = {
            'pos_agent': np.array(self.agent.position),
            'vel_agent': np.array(self.agent.velocity),
            'block_pose': np.array(list(self.block.position) + [self.block.angle]),
            'goal_pose': self.goal_pose,
            'n_contacts': n_contact_points_per_step}
        return info

    def _render_frame(self, mode):
        """Draw goal, agent and block, and return the resized RGB image.

        In "human" mode the full-size canvas is also shown in a pygame window.
        When ``render_action`` is set, a red cross marks the latest action.
        """
        if self.window is None and mode == "human":
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode((self.window_size, self.window_size))
        if self.clock is None and mode == "human":
            self.clock = pygame.time.Clock()

        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))
        self.screen = canvas

        draw_options = DrawOptions(canvas)

        # Draw goal pose.
        goal_body = self._get_goal_pose_body(self.goal_pose)
        for shape in self.block.shapes:
            goal_points = [pymunk.pygame_util.to_pygame(goal_body.local_to_world(v), draw_options.surface) for v in shape.get_vertices()]
            goal_points += [goal_points[0]]
            pygame.draw.polygon(canvas, self.goal_color, goal_points)

        # Draw agent and block.
        self.space.debug_draw(draw_options)

        if mode == "human":
            # The following line copies our drawings from `canvas` to the visible window
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()

        # surfarray is indexed (x, y, c); transpose to row-major (y, x, c).
        img = np.transpose(
                np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
            )
        img = cv2.resize(img, (self.render_size, self.render_size))
        if self.render_action and (self.latest_action is not None):
            action = np.array(self.latest_action)
            # Actions are in window pixels; rescale to the resized image.
            coord = (action / self.window_size * self.render_size).astype(np.int32)
            marker_size = int(8/96*self.render_size)
            # OpenCV rejects a zero thickness, which int() gives for render_size < 96.
            thickness = max(1, int(1/96*self.render_size))
            cv2.drawMarker(img, tuple(int(c) for c in coord),
                color=(255,0,0), markerType=cv2.MARKER_CROSS,
                markerSize=marker_size, thickness=thickness)
        return img

    def close(self):
        """Shut down the pygame display if human rendering opened one."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            # Drop stale handles so a later human render re-initializes cleanly.
            self.window = None
            self.clock = None

    def seed(self, seed=None):
        """Store the seed that ``reset`` uses; draw one from the global numpy RNG if None."""
        if seed is None:
            seed = np.random.randint(0,25536)
        self._seed = seed
        self.np_random = np.random.default_rng(seed)

    def _handle_collision(self, arbiter, space, data):
        """post_solve callback: accumulate contact points for ``_get_info``."""
        self.n_contact_points += len(arbiter.contact_point_set.points)

    def _set_state(self, state):
        """Set agent position and block pose from ``[ax, ay, bx, by, angle]``, then step physics once."""
        if isinstance(state, np.ndarray):
            state = state.tolist()
        pos_agent = state[:2]
        pos_block = state[2:4]
        rot_block = state[4]
        self.agent.position = pos_agent
        # Setting the angle rotates about the center of mass, which moves the
        # geometric position when the two differ, so the angle is set first.
        if self.legacy:
            # for compatibility with legacy data
            self.block.position = pos_block
            self.block.angle = rot_block
        else:
            self.block.angle = rot_block
            self.block.position = pos_block

        # Run physics to take effect
        self.space.step(1.0 / self.sim_hz)

    def _set_state_local(self, state_local):
        """Set the state from a block pose expressed relative to the goal pose.

        ``state_local[2:]`` is the block pose (x, y, theta) in the goal frame.
        ``state_local[:2]`` is the agent position in that new block frame.
        Returns the resulting world-frame state.
        """
        agent_pos_local = state_local[:2]
        block_pose_local = state_local[2:]
        tf_img_obj = st.AffineTransform(
            translation=self.goal_pose[:2],
            rotation=self.goal_pose[2])
        tf_obj_new = st.AffineTransform(
            translation=block_pose_local[:2],
            rotation=block_pose_local[2]
        )
        tf_img_new = st.AffineTransform(
            matrix=tf_img_obj.params @ tf_obj_new.params
        )
        agent_pos_new = tf_img_new(agent_pos_local)
        new_state = np.array(
            list(agent_pos_new[0]) + list(tf_img_new.translation) \
                + [tf_img_new.rotation])
        self._set_state(new_state)
        return new_state

    def _setup(self):
        """Create a fresh pymunk space with walls, agent, T-block, goal and collision counting."""
        self.space = pymunk.Space()
        self.space.gravity = 0, 0
        # NOTE(review): pymunk treats damping as the fraction of velocity kept
        # per second, so 0 presumably removes residual motion — confirm.
        self.space.damping = 0
        self.teleop = False
        self.render_buffer = list()

        # Add walls.
        walls = [
            self._add_segment((5, 506), (5, 5), 2),
            self._add_segment((5, 5), (506, 5), 2),
            self._add_segment((506, 5), (506, 506), 2),
            self._add_segment((5, 506), (506, 506), 2)
        ]
        self.space.add(*walls)

        # Add agent, block, and goal zone.
        self.agent = self.add_circle((256, 400), 15)
        self.block = self.add_tee((256, 300), 0)
        self.goal_color = pygame.Color('LightGreen')
        self.goal_pose = np.array([256,256,np.pi/4])  # x, y, theta (in radians)

        # Add collision handling
        self.collision_handeler = self.space.add_collision_handler(0, 0)
        self.collision_handeler.post_solve = self._handle_collision
        self.n_contact_points = 0

        self.max_score = 50 * 100
        self.success_threshold = 0.95    # 95% coverage.

    def _add_segment(self, a, b, radius):
        """Create (but do not add) a static wall segment."""
        shape = pymunk.Segment(self.space.static_body, a, b, radius)
        shape.color = pygame.Color('LightGray')    # https://htmlcolorcodes.com/color-names
        return shape

    def add_circle(self, position, radius):
        """Add the kinematic circular agent and return its body."""
        body = pymunk.Body(body_type=pymunk.Body.KINEMATIC)
        body.position = position
        # NOTE(review): friction is normally a Shape property in pymunk; on the
        # Body this is likely a no-op. Left unchanged to keep dynamics identical.
        body.friction = 1
        shape = pymunk.Circle(body, radius)
        shape.color = pygame.Color('RoyalBlue')
        self.space.add(body, shape)
        return body

    def add_box(self, position, height, width):
        """Add a dynamic box of mass 1 and return its body."""
        mass = 1
        inertia = pymunk.moment_for_box(mass, (height, width))
        body = pymunk.Body(mass, inertia)
        body.position = position
        shape = pymunk.Poly.create_box(body, (height, width))
        shape.color = pygame.Color('LightSlateGray')
        self.space.add(body, shape)
        return body

    def add_tee(self, position, angle, scale=30, color='LightSlateGray', mask=pymunk.ShapeFilter.ALL_MASKS()):
        """Add the T-shaped block (a crossbar plus a stem polygon) and return its body."""
        mass = 1
        length = 4
        vertices1 = [(-length*scale/2, scale),
                                 ( length*scale/2, scale),
                                 ( length*scale/2, 0),
                                 (-length*scale/2, 0)]
        inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
        vertices2 = [(-scale/2, scale),
                                 (-scale/2, length*scale),
                                 ( scale/2, length*scale),
                                 ( scale/2, scale)]
        # NOTE(review): computed from vertices1, not vertices2 — looks like a
        # typo, but kept so the block dynamics match the recorded demonstrations.
        inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
        body = pymunk.Body(mass, inertia1 + inertia2)
        shape1 = pymunk.Poly(body, vertices1)
        shape2 = pymunk.Poly(body, vertices2)
        shape1.color = pygame.Color(color)
        shape2.color = pygame.Color(color)
        shape1.filter = pymunk.ShapeFilter(mask=mask)
        shape2.filter = pymunk.ShapeFilter(mask=mask)
        body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
        body.position = position
        body.angle = angle
        # NOTE(review): see add_circle — Body.friction likely has no effect.
        body.friction = 1
        self.space.add(body, shape1, shape2)
        return body


In [6]:
#@markdown ### **Env Demo**
#@markdown Standard Gym Env (0.26 API: `reset` returns `(obs, info)`,
#@markdown `step` returns `(obs, reward, terminated, truncated, info)`)

# 0. create env object
env = PushTEnv()

# 1. seed env for initial state.
# Seed 0-200 are used for the demonstration dataset.
env.seed(1000)

# 2. must reset before use
obs, info = env.reset()

# 3. 2D positional action space [0,512]
action = env.action_space.sample()

# 4. Standard gym step method
obs, reward, terminated, truncated, info = env.step(action)

# prints and explains each dimension of the observation and action vectors
with np.printoptions(precision=4, suppress=True, threshold=5):
    print("Obs: ", repr(obs))
    print("Obs:        [agent_x,  agent_y,  block_x,  block_y,    block_angle]")
    print("Action: ", repr(action))
    print("Action:   [target_agent_x, target_agent_y]")

Obs:  array([160.0377, 135.4359, 292.    , 351.    ,   2.9196])
Obs:        [agent_x,  agent_y,  block_x,  block_y,    block_angle]
Action:  array([214.9165, 169.8241])
Action:   [target_agent_x, target_agent_y]


In [7]:
#@markdown ### **Dataset**
#@markdown
#@markdown Defines `PushTStateDataset` and helper functions
#@markdown
#@markdown The dataset class
#@markdown - Load data (obs, action) from a zarr storage
#@markdown - Normalizes each dimension of obs and action to [-1,1]
#@markdown - Returns
#@markdown  - All possible segments with length `pred_horizon`
#@markdown  - Pads the beginning and the end of each episode with repetition
#@markdown  - key `obs`: shape (obs_horizon, obs_dim)
#@markdown  - key `action`: shape (pred_horizon, action_dim)

def create_sample_indices(
        episode_ends:np.ndarray, sequence_length:int,
        pad_before: int=0, pad_after: int=0):
    """Enumerate every fixed-length window of every episode.

    ``episode_ends`` holds one-past-the-last index of each episode in the
    concatenated buffer. A window may start up to ``pad_before`` steps before
    an episode and end up to ``pad_after`` steps after it. The missing part is
    recorded as an offset so the sampler can pad it.

    Returns an int array with one row per window:
    ``[buffer_start, buffer_end, sample_start, sample_end]``.
    """
    rows = []
    episode_start = 0
    for episode_end in episode_ends:
        n_steps = episode_end - episode_start
        first = -pad_before
        last = n_steps - sequence_length + pad_after
        for offset in range(first, last + 1):
            window_origin = episode_start + offset
            lo = episode_start + max(offset, 0)
            hi = episode_start + min(offset + sequence_length, n_steps)
            head_pad = lo - window_origin
            tail_pad = (window_origin + sequence_length) - hi
            rows.append([lo, hi, head_pad, sequence_length - tail_pad])
        episode_start = episode_end
    return np.array(rows)


def sample_sequence(train_data, sequence_length,
                    buffer_start_idx, buffer_end_idx,
                    sample_start_idx, sample_end_idx):
    """Slice one window from every array in ``train_data``, padding if needed.

    ``[buffer_start_idx, buffer_end_idx)`` is copied into positions
    ``[sample_start_idx, sample_end_idx)`` of a length-``sequence_length``
    output. Leading slots repeat the first copied row and trailing slots
    repeat the last one. Without padding, the raw slice is returned.
    """
    needs_padding = sample_start_idx > 0 or sample_end_idx < sequence_length
    out = {}
    for name, arr in train_data.items():
        chunk = arr[buffer_start_idx:buffer_end_idx]
        if not needs_padding:
            out[name] = chunk
            continue
        padded = np.zeros((sequence_length,) + arr.shape[1:], dtype=arr.dtype)
        if sample_start_idx > 0:
            padded[:sample_start_idx] = chunk[0]
        if sample_end_idx < sequence_length:
            padded[sample_end_idx:] = chunk[-1]
        padded[sample_start_idx:sample_end_idx] = chunk
        out[name] = padded
    return out

# normalize data
def get_data_stats(data):
    """Per-feature min and max over all leading axes (last axis = features)."""
    flat = data.reshape(-1, data.shape[-1])
    return {'min': flat.min(axis=0), 'max': flat.max(axis=0)}

def normalize_data(data, stats):
    """Min-max scale ``data`` to [-1, 1] per feature using ``stats['min'/'max']``.

    Features whose min equals max would divide by zero. Their range is
    treated as 1, so constant features map to -1 instead of NaN/inf.
    ``unnormalize_data`` still maps them back to ``min`` exactly.
    """
    data_range = np.asarray(stats['max']) - np.asarray(stats['min'])
    data_range = np.where(data_range == 0, 1, data_range)
    # normalize to [0, 1]
    ndata = (data - stats['min']) / data_range
    # normalize to [-1, 1]
    ndata = ndata * 2 - 1
    return ndata

def unnormalize_data(ndata, stats):
    """Invert ``normalize_data``: map [-1, 1] back to ``[stats['min'], stats['max']]``."""
    lo, hi = stats['min'], stats['max']
    unit = (ndata + 1) / 2
    return unit * (hi - lo) + lo

# dataset
class PushTStateDataset(torch.utils.data.Dataset):
    """Windows of normalized (obs, action) pairs read from a Push-T zarr store.

    Each item is a dict:
      - 'obs':    (obs_horizon, obs_dim)
      - 'action': (pred_horizon, action_dim)
    Windows overlapping an episode boundary are padded by repeating the
    episode's first/last frame. ``self.stats`` keeps the per-key min/max used
    for normalization.
    """
    def __init__(self, dataset_path,
                 pred_horizon, obs_horizon, action_horizon):
        root = zarr.open(dataset_path, 'r')
        # Demonstration episodes are concatenated along the first axis (N steps).
        raw = {
            'action': root['data']['action'][:],   # (N, action_dim)
            'obs': root['data']['state'][:],       # (N, obs_dim)
        }
        # One-past-the-last index of each episode.
        episode_ends = root['meta']['episode_ends'][:]

        # Window bookkeeping; the padding lets every timestep appear in a window.
        self.indices = create_sample_indices(
            episode_ends=episode_ends,
            sequence_length=pred_horizon,
            pad_before=obs_horizon - 1,
            pad_after=action_horizon - 1)

        # Per-feature min/max, then scale every key to [-1, 1].
        self.stats = {name: get_data_stats(arr) for name, arr in raw.items()}
        self.normalized_train_data = {
            name: normalize_data(arr, self.stats[name])
            for name, arr in raw.items()
        }
        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon

    def __len__(self):
        """Number of windows across all episodes."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the ``idx``-th normalized window as a dict of arrays."""
        buf_lo, buf_hi, smp_lo, smp_hi = self.indices[idx]
        window = sample_sequence(
            train_data=self.normalized_train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buf_lo,
            buffer_end_idx=buf_hi,
            sample_start_idx=smp_lo,
            sample_end_idx=smp_hi,
        )
        # Keep only the first obs_horizon observations; the rest are unused.
        window['obs'] = window['obs'][:self.obs_horizon, :]
        return window


In [8]:
#@markdown ### **Dataset Demo**

# download demonstration data from Google Drive
dataset_path = "pusht_cchi_v7_replay.zarr.zip"
if not os.path.isfile(dataset_path):
    # NOTE(review): "&confirm=t" presumably rides along in the Drive download
    # URL to skip the large-file confirmation page — confirm with gdown.
    gdrive_file_id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t"
    gdown.download(id=gdrive_file_id, output=dataset_path, quiet=False)

# parameters
pred_horizon = 16
obs_horizon = 2
action_horizon = 8
#|o|o|                             observations: 2
#| |a|a|a|a|a|a|a|a|               actions executed: 8
#|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16

# create dataset from file
dataset = PushTStateDataset(
    dataset_path=dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon
)
# save training data statistics (min, max) for each dim
stats = dataset.stats

# create dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    num_workers=1,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill worker process after each epoch
    persistent_workers=True
)

# visualize data in batch
batch = next(iter(dataloader))
print("batch['obs'].shape:", batch['obs'].shape)
print("batch['action'].shape", batch['action'].shape)

batch['obs'].shape: torch.Size([256, 2, 5])
batch['action'].shape torch.Size([256, 16, 2])


In [9]:
#@markdown ### **Network**
#@markdown
#@markdown Defines a 1D UNet architecture `ConditionalUnet1D`
#@markdown as the noise prediction network
#@markdown
#@markdown Components
#@markdown - `SinusoidalPosEmb` Positional encoding for the diffusion iteration k
#@markdown - `Downsample1d` Strided convolution to reduce temporal resolution
#@markdown - `Upsample1d` Transposed convolution to increase temporal resolution
#@markdown - `Conv1dBlock` Conv1d --> GroupNorm --> Mish
#@markdown - `ConditionalResidualBlock1D` Takes two inputs `x` and `cond`. \
#@markdown `x` is passed through 2 `Conv1dBlock` stacked together with residual connection.
#@markdown `cond` is applied to `x` with [FiLM](https://arxiv.org/abs/1709.07871) conditioning.

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D batch of scalars (e.g. diffusion step k).

    Output is ``[sin(x * f_i), cos(x * f_i)]`` over ``dim // 2`` geometrically
    spaced frequencies from 1 down to 1/10000, giving shape ``(batch, dim)``.
    An odd ``dim`` gets one trailing zero column so the width always equals
    ``dim``.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """x: 1-D tensor of shape (batch,). Returns (batch, dim)."""
        device = x.device
        half_dim = self.dim // 2
        # max(..., 1) guards dim in {2, 3}, where half_dim - 1 == 0.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Without this, odd dims would return dim - 1 features.
            emb = nn.functional.pad(emb, (0, 1))
        return emb


class Downsample1d(nn.Module):
    """Stride-2 convolution: (B, dim, L) -> (B, dim, ceil(L / 2))."""
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        downsampled = self.conv(x)
        return downsampled

class Upsample1d(nn.Module):
    """Transposed convolution that doubles temporal length: (B, dim, L) -> (B, dim, 2L)."""
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        upsampled = self.conv(x)
        return upsampled


class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        # padding = kernel_size // 2 preserves sequence length for odd kernels.
        layers = [
            nn.Conv1d(inp_channels, out_channels, kernel_size,
                      padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """x: (B, inp_channels, L) -> (B, out_channels, L) for odd kernel_size."""
        y = self.block(x)
        return y


class ConditionalResidualBlock1D(nn.Module):
    """Two Conv1dBlocks with FiLM conditioning in between and a residual skip.

    The conditioning vector is mapped to a per-channel (scale, bias) pair that
    modulates the output of the first conv block
    (FiLM, https://arxiv.org/abs/1709.07871). A 1x1 conv on the skip path
    matches channels when ``in_channels != out_channels``.
    """
    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()

        first = Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
        second = Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
        self.blocks = nn.ModuleList([first, second])

        # cond -> (B, 2 * out_channels, 1): stacked per-channel scale and bias.
        self.out_channels = out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, out_channels * 2),
            nn.Unflatten(-1, (-1, 1))
        )

        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, cond):
        '''
            x : [ batch_size x in_channels x horizon ]
            cond : [ batch_size x cond_dim]

            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        h = self.blocks[0](x)
        film = self.cond_encoder(cond)
        film = film.reshape(film.shape[0], 2, self.out_channels, 1)
        scale, bias = film.unbind(dim=1)
        h = self.blocks[1](scale * h + bias)
        return h + self.residual_conv(x)


class ConditionalUnet1D(nn.Module):
    """1-D temporal U-Net that predicts the noise added to an action sequence.

    Every residual block is FiLM-conditioned on the diffusion-step embedding,
    concatenated with an optional global conditioning vector (usually the
    flattened observation history).
    """

    def __init__(self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level.
          The length of this sequence determines the number of levels.
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm
        """

        super().__init__()
        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        # every residual block is conditioned on [step embedding, global_cond]
        cond_dim = dsed + global_cond_dim

        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
        ])

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            # the deepest level does not downsample
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            # NOTE: ind only reaches len(in_out) - 2 here, so is_last is never
            # True and every up level upsamples. This matches the
            # len(in_out) - 1 downsamples above. Kept as-is so the module
            # layout (and pretrained state_dict keys) stay unchanged.
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_out*2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            global_cond=None):
        """
        sample: (B,T,input_dim) noisy action sequence
        timestep: (B,) / (1,) / 0-d tensor, or a Python number; diffusion step
        global_cond: (B,global_cond_dim), or None only when the network was
          built with global_cond_dim=0 (the FiLM Linear layers expect
          dsed + global_cond_dim inputs)
        output: (B,T,input_dim)

        NOTE(review): T presumably must be divisible by 2**(len(down_dims)-1)
        so the skip connections line up after Downsample1d/Upsample1d
        (defined elsewhere) -- TODO confirm.
        """
        # (B,T,C) -> (B,C,T): Conv1d expects channels before the time axis
        sample = sample.moveaxis(-1, -2)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        else:
            if timesteps.ndim == 0:
                timesteps = timesteps[None]
            # move 1-D timesteps as well, so a CPU tensor works with a GPU model
            timesteps = timesteps.to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)

        if global_cond is not None:
            global_feature = torch.cat([global_feature, global_cond], dim=-1)

        x = sample
        h = []
        for resnet, resnet2, downsample in self.down_modules:
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            # skip connection, captured before downsampling
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        # there is one fewer up level than down levels, so h[0] (the first
        # level's skip) is never consumed
        for resnet, resnet2, upsample in self.up_modules:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T) -> (B,T,C)
        x = x.moveaxis(-1, -2)
        return x


In [8]:
#@markdown ### **Network Demo**

# observation and action dimensions corresponding to
# the output of PushTEnv
obs_dim = 5
action_dim = 2

# create network object
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
)

# example inputs
noised_action = torch.randn((1, pred_horizon, action_dim))
obs = torch.zeros((1, obs_horizon, obs_dim))
diffusion_iter = torch.zeros((1,))

# the noise prediction network
# takes noisy action, diffusion iteration and observation as input
# predicts the noise added to action
noise = noise_pred_net(
    sample=noised_action,
    timestep=diffusion_iter,
    global_cond=obs.flatten(start_dim=1))
# the predicted noise has the same shape as the noisy action sequence
assert noise.shape == noised_action.shape, noise.shape

# illustration of removing noise
# the actual noise removal is performed by NoiseScheduler
# and is dependent on the diffusion noise schedule
denoised_action = noised_action - noise

# for this demo, we use DDPMScheduler with 100 diffusion iterations
num_diffusion_iters = 100
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    # the choice of beta schedule has big impact on performance
    # we found squared cosine works the best
    beta_schedule='squaredcos_cap_v2',
    # clip output to [-1,1] to improve stability
    clip_sample=True,
    # our network predicts noise (instead of denoised action)
    prediction_type='epsilon'
)

# device transfer
# fall back to CPU so the notebook still runs on machines without a GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_ = noise_pred_net.to(device)

number of parameters: 6.535322e+07
x torch.Size([1, 1024, 4])
h_pop torch.Size([1, 1024, 4])
x torch.Size([1, 512, 8])
h_pop torch.Size([1, 512, 8])


In [9]:
#@markdown ### **Training**
#@markdown
#@markdown Takes about an hour. If you don't want to wait, skip to the next cell
#@markdown to load pre-trained weights

num_epochs = 100

# Exponential Moving Average
# accelerates training and improves stability
# holds a copy of the model weights
ema = EMAModel(
    parameters=noise_pred_net.parameters(),
    power=0.75)

# Standard ADAM optimizer
# Note that EMA parameters are not optimized
optimizer = torch.optim.AdamW(
    params=noise_pred_net.parameters(),
    lr=1e-4, weight_decay=1e-6)

# Cosine LR schedule with linear warmup
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(dataloader) * num_epochs
)

with tqdm(range(num_epochs), desc='Epoch') as tglobal:
    # epoch loop
    for epoch_idx in tglobal:
        epoch_loss = list()
        # batch loop
        with tqdm(dataloader, desc='Batch', leave=False) as tepoch:
            for nbatch in tepoch:
                # data normalized in dataset
                # device transfer
                nobs = nbatch['obs'].to(device)
                naction = nbatch['action'].to(device)
                B = nobs.shape[0]

                # observation as FiLM conditioning
                # (B, obs_horizon, obs_dim)
                obs_cond = nobs[:,:obs_horizon,:]
                # (B, obs_horizon * obs_dim)
                obs_cond = obs_cond.flatten(start_dim=1)

                # sample noise to add to actions
                noise = torch.randn(naction.shape, device=device)

                # sample a diffusion iteration for each data point
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps,
                    (B,), device=device
                ).long()

                # add noise to the clean actions according to the noise magnitude at each diffusion iteration
                # (this is the forward diffusion process)
                noisy_actions = noise_scheduler.add_noise(
                    naction, noise, timesteps)

                # predict the noise residual
                noise_pred = noise_pred_net(
                    noisy_actions, timesteps, global_cond=obs_cond)

                # L2 loss
                loss = nn.functional.mse_loss(noise_pred, noise)

                # optimize
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                # step lr scheduler every batch
                # this is different from standard pytorch behavior
                lr_scheduler.step()

                # update Exponential Moving Average of the model weights
                ema.step(noise_pred_net.parameters())

                # logging
                loss_cpu = loss.item()
                epoch_loss.append(loss_cpu)
                tepoch.set_postfix(loss=loss_cpu)
        tglobal.set_postfix(loss=np.mean(epoch_loss))

# Weights of the EMA model
# is used for inference
# NOTE: this is an alias, not a copy -- copy_to writes the EMA weights
# into noise_pred_net's own parameters in place
ema_noise_pred_net = noise_pred_net
ema.copy_to(ema_noise_pred_net.parameters())

Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:   1%|          | 1/100 [00:21<34:39, 21.01s/it, loss=0.713]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:   2%|▏         | 2/100 [00:42<34:20, 21.03s/it, loss=0.115]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:   3%|▎         | 3/100 [01:02<33:28, 20.71s/it, loss=0.0742]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  38%|███▊      | 36/95 [00:07<00:12,  4.90it/s, loss=0.0587][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  49%|████▉     | 47/95 [00:10<00:09,  4.89it/s, loss=0.0578][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:   4%|▍         | 4/100 [01:21<32:26, 20.27s/it, loss=0.0639]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:   5%|▌         | 5/100 [01:42<32:05, 20.27s/it, loss=0.062] 

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:   6%|▌         | 6/100 [02:02<31:44, 20.26s/it, loss=0.0563]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:   7%|▋         | 7/100 [02:22<31:25, 20.28s/it, loss=0.0526]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:   8%|▊         | 8/100 [02:42<31:01, 20.23s/it, loss=0.0502]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:   9%|▉         | 9/100 [03:03<30:36, 20.18s/it, loss=0.0489]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  10%|█         | 10/100 [03:23<30:12, 20.14s/it, loss=0.047]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  11%|█         | 11/100 [03:42<29:46, 20.07s/it, loss=0.0447]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  43%|████▎     | 41/95 [00:08<00:11,  4.87it/s, loss=0.0379][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  12%|█▏        | 12/100 [04:02<29:21, 20.02s/it, loss=0.0427]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  86%|████████▋ | 82/95 [00:17<00:02,  4.90it/s, loss=0.0527][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  13%|█▎        | 13/100 [04:22<28:54, 19.94s/it, loss=0.0426]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  36%|███▌      | 34/95 [00:07<00:12,  4.79it/s, loss=0.039] [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  14%|█▍        | 14/100 [04:42<28:38, 19.98s/it, loss=0.0403]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  15%|█▌        | 15/100 [05:02<28:24, 20.05s/it, loss=0.0403]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  16%|█▌        | 16/100 [05:23<28:17, 20.21s/it, loss=0.0387]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  17%|█▋        | 17/100 [05:43<28:04, 20.30s/it, loss=0.0389]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  18%|█▊        | 18/100 [06:04<28:00, 20.50s/it, loss=0.0375]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  19%|█▉        | 19/100 [06:25<27:42, 20.53s/it, loss=0.0373]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  20%|██        | 20/100 [06:47<27:48, 20.86s/it, loss=0.0361]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  21%|██        | 21/100 [07:08<27:47, 21.11s/it, loss=0.0366]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  22%|██▏       | 22/100 [07:30<27:42, 21.31s/it, loss=0.0361]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  23%|██▎       | 23/100 [07:52<27:30, 21.44s/it, loss=0.035] 

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  24%|██▍       | 24/100 [08:14<27:15, 21.52s/it, loss=0.0346]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  25%|██▌       | 25/100 [08:35<26:55, 21.55s/it, loss=0.0341]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  26%|██▌       | 26/100 [08:57<26:40, 21.63s/it, loss=0.0341]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  27%|██▋       | 27/100 [09:19<26:21, 21.67s/it, loss=0.032] 

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  28%|██▊       | 28/100 [09:40<25:56, 21.61s/it, loss=0.0324]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  29%|██▉       | 29/100 [10:02<25:32, 21.58s/it, loss=0.0326]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  30%|███       | 30/100 [10:23<25:08, 21.54s/it, loss=0.0322]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  31%|███       | 31/100 [10:45<24:45, 21.53s/it, loss=0.031] 

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  32%|███▏      | 32/100 [11:06<24:23, 21.52s/it, loss=0.0316]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  33%|███▎      | 33/100 [11:28<23:59, 21.49s/it, loss=0.0309]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  34%|███▍      | 34/100 [11:49<23:35, 21.45s/it, loss=0.0308]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  35%|███▌      | 35/100 [12:11<23:14, 21.46s/it, loss=0.0303]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  36%|███▌      | 36/100 [12:32<22:55, 21.49s/it, loss=0.0299]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  37%|███▋      | 37/100 [12:54<22:38, 21.57s/it, loss=0.0297]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  38%|███▊      | 38/100 [13:16<22:20, 21.62s/it, loss=0.0299]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  39%|███▉      | 39/100 [13:37<22:02, 21.67s/it, loss=0.0299]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  40%|████      | 40/100 [13:59<21:40, 21.68s/it, loss=0.0286]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  41%|████      | 41/100 [14:21<21:23, 21.76s/it, loss=0.0288]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  42%|████▏     | 42/100 [14:43<21:04, 21.80s/it, loss=0.028] 

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  43%|████▎     | 43/100 [15:05<20:40, 21.76s/it, loss=0.0277]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  44%|████▍     | 44/100 [15:26<20:13, 21.67s/it, loss=0.0275]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  45%|████▌     | 45/100 [15:48<19:49, 21.62s/it, loss=0.0271]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  46%|████▌     | 46/100 [16:09<19:25, 21.58s/it, loss=0.0272]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  47%|████▋     | 47/100 [16:31<19:02, 21.56s/it, loss=0.0272]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  48%|████▊     | 48/100 [16:52<18:43, 21.60s/it, loss=0.0263]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  49%|████▉     | 49/100 [17:14<18:22, 21.63s/it, loss=0.0271]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  50%|█████     | 50/100 [17:36<18:03, 21.67s/it, loss=0.0264]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  51%|█████     | 51/100 [17:58<17:45, 21.74s/it, loss=0.0251]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  52%|█████▏    | 52/100 [18:19<17:25, 21.78s/it, loss=0.0257]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  53%|█████▎    | 53/100 [18:41<17:04, 21.79s/it, loss=0.0252]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  54%|█████▍    | 54/100 [19:03<16:40, 21.76s/it, loss=0.0251]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  55%|█████▌    | 55/100 [19:25<16:18, 21.75s/it, loss=0.0245]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  56%|█████▌    | 56/100 [19:47<15:58, 21.77s/it, loss=0.0247]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  57%|█████▋    | 57/100 [20:08<15:33, 21.71s/it, loss=0.0241]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  58%|█████▊    | 58/100 [20:30<15:09, 21.66s/it, loss=0.0245]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  59%|█████▉    | 59/100 [20:51<14:46, 21.63s/it, loss=0.0233]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  60%|██████    | 60/100 [21:13<14:23, 21.60s/it, loss=0.0234]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  61%|██████    | 61/100 [21:34<14:01, 21.57s/it, loss=0.0227]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  62%|██████▏   | 62/100 [21:56<13:39, 21.57s/it, loss=0.023] 

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  63%|██████▎   | 63/100 [22:18<13:21, 21.67s/it, loss=0.0224]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  24%|██▍       | 23/95 [00:05<00:15,  4.54it/s, loss=0.0194][A
Batch:  24%|██▍       | 23/95 [00:05<00:15,  4.54it/s, loss=0.0216]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  34%|███▎      | 32/95 [00:07<00:13,  4.85it/s, loss=0.023] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  42%|████▏     | 40/95 [00:09<00:11,  5.00it/s, loss=0.0236][A
Batch:  43%|████▎     | 41/95 [00:09<00:10,  4.98it/s, loss=0.019] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  44%|████▍     | 42/95 [00:09<00:10,  4.97it/s, loss=0.0217]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  45%|████▌     | 43/95 [00:09<00:10,  4.96it/s, loss=0.025] [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  47%|████▋     | 45/95 [00:10<00:10,  4.94it/s, loss=0.0213]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  48%|████▊     | 46/95 [00:10<00:09,  4.95it/s, loss=0.024] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  49%|████▉     | 47/95 [00:10<00:09,  4.94it/s, loss=0.0168]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  51%|█████     | 48/95 [00:10<00:09,  4.95it/s, loss=0.0287]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  53%|█████▎    | 50/95 [00:11<00:09,  4.94it/s, loss=0.0198]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  55%|█████▍    | 52/95 [00:11<00:08,  4.95it/s, loss=0.0193][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  56%|█████▌    | 53/95 [00:11<00:08,  4.94it/s, loss=0.0206][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  58%|█████▊    | 55/95 [00:12<00:08,  4.94it/s, loss=0.0191][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  62%|██████▏   | 59/95 [00:12<00:07,  4.94it/s, loss=0.0162]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  63%|██████▎   | 60/95 [00:13<00:07,  4.94it/s, loss=0.0203][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  65%|██████▌   | 62/95 [00:13<00:06,  4.95it/s, loss=0.0202][A
Batch:  65%|██████▌   | 62/95 [00:13<00:06,  4.95it/s, loss=0.0201]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  67%|██████▋   | 64/95 [00:13<00:06,  4.95it/s, loss=0.0286][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  67%|██████▋   | 64/95 [00:13<00:06,  4.95it/s, loss=0.0201][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])


Epoch:  64%|██████▍   | 64/100 [22:38<12:41, 21.15s/it, loss=0.0218]

x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:   1%|          | 1/95 [00:00<00:19,  4.71it/s, loss=0.0225]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:   2%|▏         | 2/95 [00:00<00:19,  4.84it/s, loss=0.0179]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   3%|▎         | 3/95 [00:00<00:18,  4.89it/s, loss=0.0163][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   5%|▌         | 5/95 [00:01<00:18,  4.91it/s, loss=0.0269][A
Batch:   6%|▋         | 6/95 [00:01<00:18,  4.92it/s, loss=0.0177]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   8%|▊         | 8/95 [00:01<00:17,  4.94it/s, loss=0.0189][A
Batch:   8%|▊         | 8/95 [00:01<00:17,  4.94it/s, loss=0.0161]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   9%|▉         | 9/95 [00:02<00:17,  4.94it/s, loss=0.0249][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  12%|█▏        | 11/95 [00:02<00:16,  4.95it/s, loss=0.0193][A
Batch:  12%|█▏        | 11/95 [00:02<00:16,  4.95it/s, loss=0.0195]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  14%|█▎        | 13/95 [00:02<00:16,  4.95it/s, loss=0.0221]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  18%|█▊        | 17/95 [00:03<00:15,  4.95it/s, loss=0.0183]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  18%|█▊        | 17/95 [00:03<00:15,  4.95it/s, loss=0.0182][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  20%|██        | 19/95 [00:03<00:15,  4.95it/s, loss=0.0206][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  21%|██        | 20/95 [00:04<00:15,  4.94it/s, loss=0.0264][A
Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.94it/s, loss=0.0267]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  23%|██▎       | 22/95 [00:04<00:14,  4.94it/s, loss=0.0188]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  24%|██▍       | 23/95 [00:04<00:14,  4.95it/s, loss=0.0238][A
Batch:  25%|██▌       | 24/95 [00:05<00:14,  4.95it/s, loss=0.0244]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  26%|██▋       | 25/95 [00:05<00:14,  4.95it/s, loss=0.0216][A
Batch:  27%|██▋       | 26/95 [00:05<00:13,  4.94it/s, loss=0.0231]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  28%|██▊       | 27/95 [00:05<00:13,  4.95it/s, loss=0.0181][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  31%|███       | 29/95 [00:05<00:13,  4.95it/s, loss=0.0246][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  33%|███▎      | 31/95 [00:06<00:12,  4.94it/s, loss=0.0205]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  33%|███▎      | 31/95 [00:06<00:12,  4.94it/s, loss=0.0243][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  35%|███▍      | 33/95 [00:06<00:12,  4.95it/s, loss=0.0234][A
Batch:  36%|███▌      | 34/95 [00:07<00:12,  4.94it/s, loss=0.0239]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  38%|███▊      | 36/95 [00:07<00:11,  4.94it/s, loss=0.0235][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  39%|███▉      | 37/95 [00:07<00:11,  4.94it/s, loss=0.0258][A
Batch:  41%|████      | 39/95 [00:07<00:11,  4.94it/s, loss=0.0231]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  41%|████      | 39/95 [00:08<00:11,  4.94it/s, loss=0.0222]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  43%|████▎     | 41/95 [00:08<00:10,  4.95it/s, loss=0.0237]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  43%|████▎     | 41/95 [00:08<00:10,  4.95it/s, loss=0.0224][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  45%|████▌     | 43/95 [00:08<00:10,  4.95it/s, loss=0.0249][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  46%|████▋     | 44/95 [00:08<00:10,  4.96it/s, loss=0.0192][A
Batch:  46%|████▋     | 44/95 [00:09<00:10,  4.96it/s, loss=0.0222]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  48%|████▊     | 46/95 [00:09<00:09,  4.96it/s, loss=0.0241][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])


Epoch:  65%|██████▌   | 65/100 [22:57<11:59, 20.55s/it, loss=0.0224]

x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])





x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   1%|          | 1/95 [00:00<00:19,  4.79it/s, loss=0.0198][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])



Batch:   3%|▎         | 3/95 [00:00<00:18,  4.91it/s, loss=0.02]

x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])





x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   3%|▎         | 3/95 [00:00<00:18,  4.91it/s, loss=0.0194][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  95%|█████████▍| 90/95 [00:18<00:01,  4.93it/s, loss=0.0233][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  97%|█████████▋| 92/95 [00:18<00:00,  4.94it/s, loss=0.0191][A
Batch:  98%|█████████▊| 93/95 [00:18<00:00,  4.93it/s, loss=0.0215]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  66%|██████▌   | 66/100 [23:16<11:24, 20.12s/it, loss=0.0224]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:   1%|          | 1/95 [00:00<00:20,  4.69it/s, loss=0.022]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   3%|▎         | 3/95 [00:00<00:18,  4.89it/s, loss=0.0173][A
Batch:   4%|▍         | 4/95 [00:01<00:18,  4.91it/s, loss=0.0195]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:   5%|▌         | 5/95 [00:01<00:18,  4.92it/s, loss=0.0262]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   6%|▋         | 6/95 [00:01<00:18,  4.93it/s, loss=0.022] [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   9%|▉         | 9/95 [00:02<00:17,  4.93it/s, loss=0.0217][A
Batch:  11%|█         | 10/95 [00:02<00:17,  4.94it/s, loss=0.0202]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  13%|█▎        | 12/95 [00:02<00:16,  4.95it/s, loss=0.027][A
Batch:  14%|█▎        | 13/95 [00:02<00:16,  4.95it/s, loss=0.022]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  14%|█▎        | 13/95 [00:02<00:16,  4.95it/s, loss=0.0227][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  15%|█▍        | 14/95 [00:03<00:16,  4.95it/s, loss=0.0218][A
Batch:  16%|█▌        | 15/95 [00:03<00:16,  4.95it/s, loss=0.0259]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  18%|█▊        | 17/95 [00:03<00:15,  4.95it/s, loss=0.0227]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  18%|█▊        | 17/95 [00:03<00:15,  4.95it/s, loss=0.0249][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  19%|█▉        | 18/95 [00:03<00:15,  4.95it/s, loss=0.0193][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.96it/s, loss=0.0219]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.96it/s, loss=0.0179][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  25%|██▌       | 24/95 [00:05<00:14,  4.94it/s, loss=0.024]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  26%|██▋       | 25/95 [00:05<00:14,  4.94it/s, loss=0.0167][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  45%|████▌     | 43/95 [00:08<00:10,  5.01it/s, loss=0.0211]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  64%|██████▍   | 61/95 [00:12<00:06,  5.00it/s, loss=0.0254]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  65%|██████▌   | 62/95 [00:12<00:06,  4.99it/s, loss=0.0179]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  66%|██████▋   | 63/95 [00:12<00:06,  4.96it/s, loss=0.0203]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  67%|██████▋   | 64/95 [00:13<00:06,  4.96it/s, loss=0.0257][A
Batch:  68%|██████▊   | 65/95 [00:13<00:06,  4.95it/s, loss=0.0219]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  69%|██████▉   | 66/95 [00:13<00:05,  4.95it/s, loss=0.0267][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  72%|███████▏  | 68/95 [00:13<00:05,  4.93it/s, loss=0.0273]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  73%|███████▎  | 69/95 [00:14<00:05,  4.93it/s, loss=0.0256][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  75%|███████▍  | 71/95 [00:14<00:04,  4.94it/s, loss=0.0178][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  75%|███████▍  | 71/95 [00:14<00:04,  4.94it/s, loss=0.0238][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  77%|███████▋  | 73/95 [00:14<00:04,  4.93it/s, loss=0.0248][A
Batch:  78%|███████▊  | 74/95 [00:15<00:04,  4.93it/s, loss=0.0245]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  79%|███████▉  | 75/95 [00:15<00:04,  4.93it/s, loss=0.0235][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  80%|████████  | 76/95 [00:15<00:03,  4.93it/s, loss=0.0174][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  81%|████████  | 77/95 [00:15<00:03,  4.94it/s, loss=0.0257][A
Batch:  82%|████████▏ | 78/95 [00:15<00:03,  4.94it/s, loss=0.0171]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  83%|████████▎ | 79/95 [00:16<00:03,  4.94it/s, loss=0.0255][A
Batch:  84%|████████▍ | 80/95 [00:16<00:03,  4.94it/s, loss=0.0236]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  85%|████████▌ | 81/95 [00:16<00:02,  4.94it/s, loss=0.0248]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  86%|████████▋ | 82/95 [00:16<00:02,  4.94it/s, loss=0.0195]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  87%|████████▋ | 83/95 [00:16<00:02,  4.94it/s, loss=0.0187][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  89%|████████▉ | 85/95 [00:17<00:02,  4.94it/s, loss=0.0161][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  89%|████████▉ | 85/95 [00:17<00:02,  4.94it/s, loss=0.0206][A
Batch:  91%|█████████ | 86/95 [00:17<00:01,  4.94it/s, loss=0.0223]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  93%|█████████▎| 88/95 [00:17<00:01,  4.93it/s, loss=0.0184]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  94%|█████████▎| 89/95 [00:18<00:01,  4.94it/s, loss=0.0231]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  95%|█████████▍| 90/95 [00:18<00:01,  4.94it/s, loss=0.0233]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  96%|█████████▌| 91/95 [00:18<00:00,  4.94it/s, loss=0.02]  [A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  97%|█████████▋| 92/95 [00:18<00:00,  4.94it/s, loss=0.0286][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  67%|██████▋   | 67/100 [23:35<10:54, 19.83s/it, loss=0.0223][A

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])





x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   1%|          | 1/95 [00:00<00:19,  4.70it/s, loss=0.0178][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])



Batch:   7%|▋         | 7/95 [00:01<00:17,  4.99it/s, loss=0.0217]

x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   8%|▊         | 8/95 [00:01<00:17,  4.99it/s, loss=0.0183][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  24%|██▍       | 23/95 [00:04<00:14,  5.00it/s, loss=0.0226][A
Batch:  25%|██▌       | 24/95 [00:04<00:14,  4.99it/s, loss=0.0153]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  25%|██▌       | 24/95 [00:05<00:14,  4.99it/s, loss=0.0182][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  26%|██▋       | 25/95 [00:05<00:14,  4.97it/s, loss=0.0221][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  27%|██▋       | 26/95 [00:05<00:13,  4.96it/s, loss=0.0233][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  29%|██▉       | 28/95 [00:05<00:13,  4.94it/s, loss=0.0183]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  31%|███       | 29/95 [00:06<00:13,  4.95it/s, loss=0.0191]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  32%|███▏      | 30/95 [00:06<00:13,  4.95it/s, loss=0.0282]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  34%|███▎      | 32/95 [00:06<00:12,  4.95it/s, loss=0.0231][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  35%|███▍      | 33/95 [00:06<00:12,  4.94it/s, loss=0.0184][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  37%|███▋      | 35/95 [00:07<00:12,  4.95it/s, loss=0.0216][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  38%|███▊      | 36/95 [00:07<00:11,  4.94it/s, loss=0.0191][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  39%|███▉      | 37/95 [00:07<00:11,  4.94it/s, loss=0.0219][A
Batch:  40%|████      | 38/95 [00:07<00:11,  4.94it/s, loss=0.0185]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  42%|████▏     | 40/95 [00:08<00:11,  4.94it/s, loss=0.022] [A
Batch:  43%|████▎     | 41/95 [00:08<00:10,  4.94it/s, loss=0.025]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  46%|████▋     | 44/95 [00:09<00:10,  4.94it/s, loss=0.022] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  49%|████▉     | 47/95 [00:09<00:09,  4.93it/s, loss=0.0215]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  51%|█████     | 48/95 [00:09<00:09,  4.94it/s, loss=0.0236][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  52%|█████▏    | 49/95 [00:10<00:09,  4.94it/s, loss=0.0216][A
Batch:  53%|█████▎    | 50/95 [00:10<00:09,  4.94it/s, loss=0.0226]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  56%|█████▌    | 53/95 [00:10<00:08,  4.94it/s, loss=0.0236][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  64%|██████▍   | 61/95 [00:12<00:06,  4.95it/s, loss=0.0167][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  68%|██████▊   | 65/95 [00:13<00:06,  4.95it/s, loss=0.0202][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  71%|███████   | 67/95 [00:13<00:05,  4.95it/s, loss=0.0178][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  80%|████████  | 76/95 [00:15<00:03,  4.98it/s, loss=0.0261]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  97%|█████████▋| 92/95 [00:18<00:00,  4.96it/s, loss=0.0229][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  68%|██████▊   | 68/100 [23:54<10:27, 19.61s/it, loss=0.0217][A

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   1%|          | 1/95 [00:00<00:19,  4.76it/s, loss=0.0196][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   2%|▏         | 2/95 [00:00<00:19,  4.86it/s, loss=0.0272][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   4%|▍         | 4/95 [00:01<00:18,  4.92it/s, loss=0.0154][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   8%|▊         | 8/95 [00:01<00:17,  4.93it/s, loss=0.0197][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  17%|█▋        | 16/95 [00:03<00:15,  4.95it/s, loss=0.0157]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  34%|███▎      | 32/95 [00:06<00:12,  4.99it/s, loss=0.0172][A
Batch:  35%|███▍      | 33/95 [00:06<00:12,  4.98it/s, loss=0.0204]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  35%|███▍      | 33/95 [00:06<00:12,  4.98it/s, loss=0.021] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])


Epoch:  69%|██████▉   | 69/100 [24:13<10:03, 19.47s/it, loss=0.0207]

x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  64%|██████▍   | 61/95 [00:12<00:06,  4.98it/s, loss=0.0236]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  64%|██████▍   | 61/95 [00:12<00:06,  4.98it/s, loss=0.02]  

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  77%|███████▋  | 73/95 [00:14<00:04,  5.00it/s, loss=0.0272][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  78%|███████▊  | 74/95 [00:15<00:04,  4.98it/s, loss=0.028] [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  80%|████████  | 76/95 [00:15<00:03,  4.95it/s, loss=0.0247]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  81%|████████  | 77/95 [00:15<00:03,  4.95it/s, loss=0.019] [A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  82%|████████▏ | 78/95 [00:15<00:03,  4.95it/s, loss=0.0197][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  83%|████████▎ | 79/95 [00:16<00:03,  4.94it/s, loss=0.0267][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  84%|████████▍ | 80/95 [00:16<00:03,  4.94it/s, loss=0.0184][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  85%|████████▌ | 81/95 [00:16<00:02,  4.94it/s, loss=0.0146][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  86%|████████▋ | 82/95 [00:16<00:02,  4.93it/s, loss=0.0215][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  87%|████████▋ | 83/95 [00:16<00:02,  4.93it/s, loss=0.0156][A
Batch:  88%|████████▊ | 84/95 [00:17<00:02,  4.94it/s, loss=0.017] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  92%|█████████▏| 87/95 [00:17<00:01,  4.94it/s, loss=0.0248][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  93%|█████████▎| 88/95 [00:17<00:01,  4.94it/s, loss=0.018] [A
Batch:  94%|█████████▎| 89/95 [00:18<00:01,  4.94it/s, loss=0.0172]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  97%|█████████▋| 92/95 [00:18<00:00,  4.94it/s, loss=0.0206]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  98%|█████████▊| 93/95 [00:18<00:00,  4.94it/s, loss=0.03]  

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  70%|███████   | 70/100 [24:32<09:40, 19.36s/it, loss=0.0209]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  15%|█▍        | 14/95 [00:03<00:16,  4.97it/s, loss=0.0182]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  16%|█▌        | 15/95 [00:03<00:16,  4.96it/s, loss=0.0233][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  21%|██        | 20/95 [00:04<00:15,  4.94it/s, loss=0.0215][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  71%|███████   | 67/95 [00:13<00:05,  4.96it/s, loss=0.0267][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  73%|███████▎  | 69/95 [00:14<00:05,  4.96it/s, loss=0.0218][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  75%|███████▍  | 71/95 [00:14<00:04,  4.96it/s, loss=0.0211][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])


Epoch:  71%|███████   | 71/100 [24:51<09:19, 19.28s/it, loss=0.0206]

x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   1%|          | 1/95 [00:00<00:20,  4.67it/s, loss=0.016] [A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   2%|▏         | 2/95 [00:00<00:19,  4.82it/s, loss=0.02] [A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   3%|▎         | 3/95 [00:00<00:18,  4.87it/s, loss=0.0206][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:   6%|▋         | 6/95 [00:01<00:18,  4.90it/s, loss=0.0228]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   7%|▋         | 7/95 [00:01<00:17,  4.91it/s, loss=0.0218][A
Batch:   8%|▊         | 8/95 [00:01<00:17,  4.92it/s, loss=0.0188]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  11%|█         | 10/95 [00:02<00:17,  4.93it/s, loss=0.0216][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  15%|█▍        | 14/95 [00:03<00:16,  4.92it/s, loss=0.0265]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.94it/s, loss=0.0201]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.94it/s, loss=0.0201][A
Batch:  23%|██▎       | 22/95 [00:04<00:14,  4.94it/s, loss=0.0234]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  51%|█████     | 48/95 [00:09<00:09,  4.98it/s, loss=0.0191][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  53%|█████▎    | 50/95 [00:10<00:09,  4.95it/s, loss=0.0252][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  54%|█████▎    | 51/95 [00:10<00:08,  4.94it/s, loss=0.0236][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  55%|█████▍    | 52/95 [00:10<00:08,  4.94it/s, loss=0.0163][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  60%|██████    | 57/95 [00:11<00:07,  4.94it/s, loss=0.0259][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  61%|██████    | 58/95 [00:11<00:07,  4.93it/s, loss=0.0249][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  63%|██████▎   | 60/95 [00:12<00:07,  4.93it/s, loss=0.0172]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  64%|██████▍   | 61/95 [00:12<00:06,  4.93it/s, loss=0.019] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  65%|██████▌   | 62/95 [00:12<00:06,  4.94it/s, loss=0.0197][A
Batch:  66%|██████▋   | 63/95 [00:12<00:06,  4.94it/s, loss=0.0271]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  67%|██████▋   | 64/95 [00:13<00:06,  4.94it/s, loss=0.0151][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  69%|██████▉   | 66/95 [00:13<00:05,  4.93it/s, loss=0.016] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  71%|███████   | 67/95 [00:13<00:05,  4.94it/s, loss=0.0189]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  73%|███████▎  | 69/95 [00:13<00:05,  4.95it/s, loss=0.0221]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  75%|███████▍  | 71/95 [00:14<00:04,  4.94it/s, loss=0.0173][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  76%|███████▌  | 72/95 [00:14<00:04,  4.94it/s, loss=0.0173][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])


Epoch:  72%|███████▏  | 72/100 [25:11<08:58, 19.24s/it, loss=0.0202]

x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   2%|▏         | 2/95 [00:00<00:19,  4.84it/s, loss=0.0235][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:   5%|▌         | 5/95 [00:01<00:18,  4.94it/s, loss=0.0172]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:   7%|▋         | 7/95 [00:01<00:17,  4.95it/s, loss=0.0184]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  13%|█▎        | 12/95 [00:02<00:16,  4.97it/s, loss=0.0183]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  13%|█▎        | 12/95 [00:02<00:16,  4.97it/s, loss=0.0202][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  14%|█▎        | 13/95 [00:02<00:16,  4.95it/s, loss=0.0161][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  19%|█▉        | 18/95 [00:03<00:15,  4.95it/s, loss=0.0168]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  20%|██        | 19/95 [00:04<00:15,  4.95it/s, loss=0.0176]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.95it/s, loss=0.0176][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  24%|██▍       | 23/95 [00:04<00:14,  4.95it/s, loss=0.0224][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  26%|██▋       | 25/95 [00:05<00:14,  4.94it/s, loss=0.02]  [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  63%|██████▎   | 60/95 [00:12<00:07,  4.99it/s, loss=0.0249]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])


Epoch:  73%|███████▎  | 73/100 [25:30<08:38, 19.20s/it, loss=0.0196]

x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])





x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   1%|          | 1/95 [00:00<00:19,  4.81it/s, loss=0.0249][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  14%|█▎        | 13/95 [00:02<00:16,  4.96it/s, loss=0.0237]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  15%|█▍        | 14/95 [00:03<00:16,  4.96it/s, loss=0.0229][A
Batch:  16%|█▌        | 15/95 [00:03<00:16,  4.95it/s, loss=0.0192]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  18%|█▊        | 17/95 [00:03<00:15,  4.96it/s, loss=0.0227]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  18%|█▊        | 17/95 [00:03<00:15,  4.96it/s, loss=0.0254][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  89%|████████▉ | 85/95 [00:17<00:02,  4.99it/s, loss=0.0183][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  91%|█████████ | 86/95 [00:17<00:01,  4.97it/s, loss=0.0198][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  92%|█████████▏| 87/95 [00:17<00:01,  4.96it/s, loss=0.0171][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  94%|█████████▎| 89/95 [00:18<00:01,  4.94it/s, loss=0.0186]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  96%|█████████▌| 91/95 [00:18<00:00,  4.95it/s, loss=0.0216][A
Batch:  96%|█████████▌| 91/95 [00:18<00:00,  4.95it/s, loss=0.0127]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  97%|█████████▋| 92/95 [00:18<00:00,  4.95it/s, loss=0.0166][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  74%|███████▍  | 74/100 [25:49<08:18, 19.17s/it, loss=0.0196][A

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:   1%|          | 1/95 [00:00<00:20,  4.69it/s, loss=0.0172]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   2%|▏         | 2/95 [00:00<00:19,  4.84it/s, loss=0.019] [A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   3%|▎         | 3/95 [00:00<00:18,  4.88it/s, loss=0.0164][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   4%|▍         | 4/95 [00:01<00:18,  4.90it/s, loss=0.0219][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  11%|█         | 10/95 [00:02<00:17,  4.87it/s, loss=0.0162][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  43%|████▎     | 41/95 [00:08<00:10,  4.97it/s, loss=0.019] [A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  44%|████▍     | 42/95 [00:08<00:10,  4.96it/s, loss=0.0179][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  46%|████▋     | 44/95 [00:09<00:10,  4.94it/s, loss=0.0164]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  47%|████▋     | 45/95 [00:09<00:10,  4.94it/s, loss=0.0179]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  48%|████▊     | 46/95 [00:09<00:09,  4.94it/s, loss=0.0167]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  49%|████▉     | 47/95 [00:09<00:09,  4.95it/s, loss=0.0217]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  53%|█████▎    | 50/95 [00:10<00:09,  4.94it/s, loss=0.0183]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  54%|█████▎    | 51/95 [00:10<00:08,  4.94it/s, loss=0.0175][A
Batch:  55%|█████▍    | 52/95 [00:10<00:08,  4.94it/s, loss=0.0165]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  57%|█████▋    | 54/95 [00:11<00:08,  4.94it/s, loss=0.0146]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  58%|█████▊    | 55/95 [00:11<00:08,  4.94it/s, loss=0.0173][A
Batch:  59%|█████▉    | 56/95 [00:11<00:07,  4.94it/s, loss=0.0213]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  60%|██████    | 57/95 [00:11<00:07,  4.94it/s, loss=0.014] [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  62%|██████▏   | 59/95 [00:12<00:07,  4.93it/s, loss=0.0232]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  63%|██████▎   | 60/95 [00:12<00:07,  4.94it/s, loss=0.0254][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  65%|██████▌   | 62/95 [00:12<00:06,  4.95it/s, loss=0.0196][A
Batch:  65%|██████▌   | 62/95 [00:12<00:06,  4.95it/s, loss=0.0118]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  67%|██████▋   | 64/95 [00:13<00:06,  4.93it/s, loss=0.0245]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  71%|███████   | 67/95 [00:13<00:05,  4.94it/s, loss=0.0207][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  74%|███████▎  | 70/95 [00:14<00:05,  4.94it/s, loss=0.0159]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  75%|███████▌  | 75/100 [26:08<07:58, 19.15s/it, loss=0.0191]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   1%|          | 1/95 [00:00<00:19,  4.75it/s, loss=0.0188][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:   3%|▎         | 3/95 [00:00<00:18,  4.89it/s, loss=0.0261]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:   4%|▍         | 4/95 [00:01<00:18,  4.91it/s, loss=0.0185]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:   6%|▋         | 6/95 [00:01<00:18,  4.94it/s, loss=0.022]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   7%|▋         | 7/95 [00:01<00:17,  4.95it/s, loss=0.0217][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   7%|▋         | 7/95 [00:01<00:17,  4.95it/s, loss=0.0186][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  11%|█         | 10/95 [00:02<00:17,  4.96it/s, loss=0.0211][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  13%|█▎        | 12/95 [00:02<00:16,  4.96it/s, loss=0.0185][A
Batch:  14%|█▎        | 13/95 [00:02<00:16,  4.95it/s, loss=0.0251]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  16%|█▌        | 15/95 [00:03<00:16,  4.96it/s, loss=0.018][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  17%|█▋        | 16/95 [00:03<00:15,  4.96it/s, loss=0.0214][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  18%|█▊        | 17/95 [00:03<00:15,  4.96it/s, loss=0.017][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  18%|█▊        | 17/95 [00:03<00:15,  4.96it/s, loss=0.0188][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  20%|██        | 19/95 [00:03<00:15,  4.96it/s, loss=0.0203][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  21%|██        | 20/95 [00:04<00:15,  4.96it/s, loss=0.0162][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.96it/s, loss=0.0209][A
Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.96it/s, loss=0.0173]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  27%|██▋       | 26/95 [00:05<00:13,  4.96it/s, loss=0.0185]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  27%|██▋       | 26/95 [00:05<00:13,  4.96it/s, loss=0.0179][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  29%|██▉       | 28/95 [00:05<00:13,  4.96it/s, loss=0.0202][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  29%|██▉       | 28/95 [00:05<00:13,  4.96it/s, loss=0.0193][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  32%|███▏      | 30/95 [00:06<00:13,  4.96it/s, loss=0.02]  [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])


Epoch:  76%|███████▌  | 76/100 [26:27<07:39, 19.14s/it, loss=0.0194]

x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])



Batch:   5%|▌         | 5/95 [00:01<00:18,  4.94it/s, loss=0.0162]

x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])



Batch:   7%|▋         | 7/95 [00:01<00:17,  4.95it/s, loss=0.0218]

x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:   9%|▉         | 9/95 [00:02<00:17,  4.95it/s, loss=0.0235]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  12%|█▏        | 11/95 [00:02<00:16,  4.95it/s, loss=0.0191][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  19%|█▉        | 18/95 [00:03<00:15,  4.98it/s, loss=0.0175][A
Batch:  20%|██        | 19/95 [00:04<00:15,  4.97it/s, loss=0.0144]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  21%|██        | 20/95 [00:04<00:15,  4.96it/s, loss=0.019] [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  23%|██▎       | 22/95 [00:04<00:14,  4.95it/s, loss=0.0186][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  25%|██▌       | 24/95 [00:04<00:14,  4.95it/s, loss=0.0185][A
Batch:  25%|██▌       | 24/95 [00:05<00:14,  4.95it/s, loss=0.0197]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  28%|██▊       | 27/95 [00:05<00:13,  4.95it/s, loss=0.0197][A
Batch:  29%|██▉       | 28/95 [00:05<00:13,  4.95it/s, loss=0.0219]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  31%|███       | 29/95 [00:06<00:13,  4.94it/s, loss=0.0278]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  32%|███▏      | 30/95 [00:06<00:13,  4.94it/s, loss=0.0222]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  35%|███▍      | 33/95 [00:06<00:12,  4.94it/s, loss=0.0197][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  37%|███▋      | 35/95 [00:07<00:12,  4.94it/s, loss=0.0178][A
Batch:  37%|███▋      | 35/95 [00:07<00:12,  4.94it/s, loss=0.0141]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  38%|███▊      | 36/95 [00:07<00:11,  4.94it/s, loss=0.0203]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  39%|███▉      | 37/95 [00:07<00:11,  4.95it/s, loss=0.0184][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  40%|████      | 38/95 [00:07<00:11,  4.94it/s, loss=0.02]  [A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  42%|████▏     | 40/95 [00:08<00:11,  4.95it/s, loss=0.0171][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  43%|████▎     | 41/95 [00:08<00:10,  4.95it/s, loss=0.0189][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  46%|████▋     | 44/95 [00:09<00:10,  4.95it/s, loss=0.0163][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])


Epoch:  77%|███████▋  | 77/100 [26:46<07:19, 19.12s/it, loss=0.0198]

x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  28%|██▊       | 27/95 [00:05<00:13,  5.00it/s, loss=0.0206][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  29%|██▉       | 28/95 [00:05<00:13,  4.98it/s, loss=0.0159][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  31%|███       | 29/95 [00:06<00:13,  4.97it/s, loss=0.0159][A
Batch:  32%|███▏      | 30/95 [00:06<00:13,  4.96it/s, loss=0.0139]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  34%|███▎      | 32/95 [00:06<00:12,  4.94it/s, loss=0.0181]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  37%|███▋      | 35/95 [00:07<00:12,  4.95it/s, loss=0.0167][A
Batch:  37%|███▋      | 35/95 [00:07<00:12,  4.95it/s, loss=0.0174]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  39%|███▉      | 37/95 [00:07<00:11,  4.94it/s, loss=0.0183][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  40%|████      | 38/95 [00:07<00:11,  4.94it/s, loss=0.0179][A
Batch:  41%|████      | 39/95 [00:08<00:11,  4.94it/s, loss=0.0186]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  42%|████▏     | 40/95 [00:08<00:11,  4.94it/s, loss=0.0214][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  92%|█████████▏| 87/95 [00:17<00:01,  4.98it/s, loss=0.0176][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  93%|█████████▎| 88/95 [00:17<00:01,  4.97it/s, loss=0.022] [A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  94%|█████████▎| 89/95 [00:18<00:01,  4.96it/s, loss=0.0111][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  95%|█████████▍| 90/95 [00:18<00:01,  4.95it/s, loss=0.0184][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  97%|█████████▋| 92/95 [00:18<00:00,  4.93it/s, loss=0.02]  [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  78%|███████▊  | 78/100 [27:05<07:00, 19.11s/it, loss=0.019] 

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])





x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   2%|▏         | 2/95 [00:00<00:19,  4.87it/s, loss=0.0148][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])





x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   4%|▍         | 4/95 [00:01<00:18,  4.92it/s, loss=0.0207][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])



Batch:   6%|▋         | 6/95 [00:01<00:18,  4.94it/s, loss=0.0164]

x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])





x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   7%|▋         | 7/95 [00:01<00:17,  4.93it/s, loss=0.0226][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])





x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   9%|▉         | 9/95 [00:02<00:17,  4.95it/s, loss=0.0195][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  24%|██▍       | 23/95 [00:04<00:14,  4.98it/s, loss=0.0198][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])


Epoch:  79%|███████▉  | 79/100 [27:24<06:41, 19.10s/it, loss=0.0188]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  41%|████      | 39/95 [00:08<00:11,  5.01it/s, loss=0.0214]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  43%|████▎     | 41/95 [00:08<00:10,  4.99it/s, loss=0.0133][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  57%|█████▋    | 54/95 [00:11<00:08,  4.95it/s, loss=0.0195][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  59%|█████▉    | 56/95 [00:11<00:07,  4.93it/s, loss=0.0192]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  62%|██████▏   | 59/95 [00:12<00:07,  4.82it/s, loss=0.0219]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  64%|██████▍   | 61/95 [00:12<00:06,  4.88it/s, loss=0.0157]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  66%|██████▋   | 63/95 [00:12<00:06,  4.91it/s, loss=0.022] [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  68%|██████▊   | 65/95 [00:13<00:06,  4.93it/s, loss=0.0225][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  71%|███████   | 67/95 [00:13<00:05,  4.94it/s, loss=0.0196]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  73%|███████▎  | 69/95 [00:13<00:05,  4.95it/s, loss=0.0135][A
Batch:  73%|███████▎  | 69/95 [00:14<00:05,  4.95it/s, loss=0.0168]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  76%|███████▌  | 72/95 [00:14<00:04,  4.95it/s, loss=0.0189][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  76%|███████▌  | 72/95 [00:14<00:04,  4.95it/s, loss=0.0197][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  78%|███████▊  | 74/95 [00:14<00:04,  4.95it/s, loss=0.0163][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  78%|███████▊  | 74/95 [00:15<00:04,  4.95it/s, loss=0.0209][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  79%|███████▉  | 75/95 [00:15<00:04,  4.94it/s, loss=0.0188][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  80%|████████  | 76/95 [00:15<00:03,  4.94it/s, loss=0.0174][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  82%|████████▏ | 78/95 [00:15<00:03,  4.95it/s, loss=0.0184][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  85%|████████▌ | 81/95 [00:16<00:02,  4.94it/s, loss=0.0178][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  87%|████████▋ | 83/95 [00:16<00:02,  4.95it/s, loss=0.0149][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  91%|█████████ | 86/95 [00:17<00:01,  4.94it/s, loss=0.0179]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  92%|█████████▏| 87/95 [00:17<00:01,  4.94it/s, loss=0.0171][A
Batch:  94%|█████████▎| 89/95 [00:17<00:01,  4.95it/s, loss=0.017]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  95%|█████████▍| 90/95 [00:18<00:01,  4.94it/s, loss=0.0192]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  98%|█████████▊| 93/95 [00:18<00:00,  4.94it/s, loss=0.014] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  80%|████████  | 80/100 [27:43<06:22, 19.12s/it, loss=0.0187]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   4%|▍         | 4/95 [00:01<00:18,  4.98it/s, loss=0.0224][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   5%|▌         | 5/95 [00:01<00:18,  4.96it/s, loss=0.0195][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:   8%|▊         | 8/95 [00:01<00:17,  4.95it/s, loss=0.0231]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  11%|█         | 10/95 [00:02<00:17,  4.96it/s, loss=0.0158]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  19%|█▉        | 18/95 [00:03<00:15,  4.99it/s, loss=0.0165][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  21%|██        | 20/95 [00:04<00:15,  4.97it/s, loss=0.0155][A
Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.96it/s, loss=0.0199]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  25%|██▌       | 24/95 [00:04<00:14,  4.96it/s, loss=0.0192][A
Batch:  26%|██▋       | 25/95 [00:05<00:14,  4.95it/s, loss=0.0168]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  26%|██▋       | 25/95 [00:05<00:14,  4.95it/s, loss=0.0164][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  29%|██▉       | 28/95 [00:05<00:13,  4.95it/s, loss=0.0201][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  29%|██▉       | 28/95 [00:05<00:13,  4.95it/s, loss=0.0136][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  31%|███       | 29/95 [00:06<00:13,  4.95it/s, loss=0.0204][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  32%|███▏      | 30/95 [00:06<00:13,  4.95it/s, loss=0.0211][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  34%|███▎      | 32/95 [00:06<00:12,  4.95it/s, loss=0.014] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  36%|███▌      | 34/95 [00:06<00:12,  4.95it/s, loss=0.02]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  37%|███▋      | 35/95 [00:07<00:12,  4.96it/s, loss=0.0183][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  39%|███▉      | 37/95 [00:07<00:11,  4.95it/s, loss=0.0192]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  40%|████      | 38/95 [00:07<00:11,  4.95it/s, loss=0.0208][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  40%|████      | 38/95 [00:07<00:11,  4.95it/s, loss=0.0255][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  42%|████▏     | 40/95 [00:08<00:11,  4.95it/s, loss=0.0242][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  44%|████▍     | 42/95 [00:08<00:10,  4.95it/s, loss=0.0232][A
Batch:  44%|████▍     | 42/95 [00:08<00:10,  4.95it/s, loss=0.0193]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  46%|████▋     | 44/95 [00:08<00:10,  4.96it/s, loss=0.0145]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  48%|████▊     | 46/95 [00:09<00:09,  4.95it/s, loss=0.0195][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])


Epoch:  81%|████████  | 81/100 [28:03<06:03, 19.11s/it, loss=0.0182]

x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])



Batch:   1%|          | 1/95 [00:00<00:19,  4.80it/s, loss=0.0182]

x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  75%|███████▍  | 71/95 [00:14<00:04,  4.98it/s, loss=0.0233]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  89%|████████▉ | 85/95 [00:17<00:02,  4.97it/s, loss=0.0158][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  91%|█████████ | 86/95 [00:17<00:01,  4.96it/s, loss=0.0167][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  93%|█████████▎| 88/95 [00:17<00:01,  4.94it/s, loss=0.0253][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  96%|█████████▌| 91/95 [00:18<00:00,  4.94it/s, loss=0.0144]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  96%|█████████▌| 91/95 [00:18<00:00,  4.94it/s, loss=0.0111][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  82%|████████▏ | 82/100 [28:22<05:43, 19.09s/it, loss=0.0177][A

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   1%|          | 1/95 [00:00<00:19,  4.75it/s, loss=0.0177][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   3%|▎         | 3/95 [00:00<00:18,  4.89it/s, loss=0.0217][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   3%|▎         | 3/95 [00:00<00:18,  4.89it/s, loss=0.0195][A
Batch:   4%|▍         | 4/95 [00:01<00:18,  4.91it/s, loss=0.0212]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   6%|▋         | 6/95 [00:01<00:18,  4.94it/s, loss=0.0149][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   6%|▋         | 6/95 [00:01<00:18,  4.94it/s, loss=0.0147][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  11%|█         | 10/95 [00:02<00:17,  4.95it/s, loss=0.0145]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  14%|█▎        | 13/95 [00:02<00:16,  4.95it/s, loss=0.0136]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  77%|███████▋  | 73/95 [00:14<00:04,  4.96it/s, loss=0.0158]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  87%|████████▋ | 83/95 [00:16<00:02,  4.98it/s, loss=0.0165]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  87%|████████▋ | 83/95 [00:16<00:02,  4.98it/s, loss=0.0145]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  91%|█████████ | 86/95 [00:17<00:01,  4.97it/s, loss=0.023]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  98%|█████████▊| 93/95 [00:18<00:00,  5.00it/s, loss=0.0164]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  83%|████████▎ | 83/100 [28:41<05:24, 19.09s/it, loss=0.0173][A

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   1%|          | 1/95 [00:00<00:19,  4.73it/s, loss=0.0142][A
Batch:   2%|▏         | 2/95 [00:00<00:19,  4.85it/s, loss=0.0169]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   3%|▎         | 3/95 [00:00<00:18,  4.89it/s, loss=0.0118][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   5%|▌         | 5/95 [00:01<00:18,  4.91it/s, loss=0.0216][A
Batch:   6%|▋         | 6/95 [00:01<00:18,  4.91it/s, loss=0.0179]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:   8%|▊         | 8/95 [00:01<00:17,  4.94it/s, loss=0.0158]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:   9%|▉         | 9/95 [00:02<00:17,  4.92it/s, loss=0.017] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  11%|█         | 10/95 [00:02<00:17,  4.93it/s, loss=0.0168][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  15%|█▍        | 14/95 [00:02<00:16,  4.95it/s, loss=0.0153]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  17%|█▋        | 16/95 [00:03<00:16,  4.94it/s, loss=0.0177][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  19%|█▉        | 18/95 [00:03<00:15,  4.94it/s, loss=0.016][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  19%|█▉        | 18/95 [00:03<00:15,  4.94it/s, loss=0.017][A
Batch:  20%|██        | 19/95 [00:04<00:15,  4.94it/s, loss=0.0195]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.95it/s, loss=0.0155][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  24%|██▍       | 23/95 [00:04<00:14,  4.95it/s, loss=0.017][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  24%|██▍       | 23/95 [00:04<00:14,  4.95it/s, loss=0.0197][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  26%|██▋       | 25/95 [00:05<00:14,  4.95it/s, loss=0.0178][A
Batch:  27%|██▋       | 26/95 [00:05<00:13,  4.95it/s, loss=0.0184]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  35%|███▍      | 33/95 [00:06<00:12,  5.00it/s, loss=0.0159]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])


Epoch:  84%|████████▍ | 84/100 [29:00<05:05, 19.10s/it, loss=0.0178]

x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   1%|          | 1/95 [00:00<00:19,  4.73it/s, loss=0.0189][A
Batch:   2%|▏         | 2/95 [00:00<00:19,  4.85it/s, loss=0.0216]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   3%|▎         | 3/95 [00:00<00:18,  4.90it/s, loss=0.019] [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:   7%|▋         | 7/95 [00:01<00:17,  4.94it/s, loss=0.0214]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  12%|█▏        | 11/95 [00:02<00:16,  4.94it/s, loss=0.018] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  14%|█▎        | 13/95 [00:02<00:16,  4.95it/s, loss=0.0155][A
Batch:  16%|█▌        | 15/95 [00:03<00:16,  4.95it/s, loss=0.014]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  18%|█▊        | 17/95 [00:03<00:15,  4.95it/s, loss=0.0144]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  19%|█▉        | 18/95 [00:03<00:15,  4.95it/s, loss=0.0151][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  21%|██        | 20/95 [00:04<00:15,  4.95it/s, loss=0.0211][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.95it/s, loss=0.0145][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  24%|██▍       | 23/95 [00:04<00:14,  4.95it/s, loss=0.016] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  25%|██▌       | 24/95 [00:05<00:14,  4.95it/s, loss=0.0151][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  27%|██▋       | 26/95 [00:05<00:13,  4.95it/s, loss=0.0198][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  28%|██▊       | 27/95 [00:05<00:13,  4.95it/s, loss=0.0202]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  31%|███       | 29/95 [00:06<00:13,  4.94it/s, loss=0.0143]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  35%|███▍      | 33/95 [00:06<00:12,  4.95it/s, loss=0.0189][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  35%|███▍      | 33/95 [00:06<00:12,  4.95it/s, loss=0.0173][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  37%|███▋      | 35/95 [00:07<00:12,  4.95it/s, loss=0.0201][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  38%|███▊      | 36/95 [00:07<00:11,  4.95it/s, loss=0.021]  [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  40%|████      | 38/95 [00:07<00:11,  4.95it/s, loss=0.0148][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  42%|████▏     | 40/95 [00:08<00:11,  4.95it/s, loss=0.018][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  43%|████▎     | 41/95 [00:08<00:10,  4.95it/s, loss=0.0224][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  44%|████▍     | 42/95 [00:08<00:10,  4.96it/s, loss=0.0144][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  45%|████▌     | 43/95 [00:08<00:10,  4.95it/s, loss=0.0128][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  45%|████▌     | 43/95 [00:08<00:10,  4.95it/s, loss=0.0118][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  47%|████▋     | 45/95 [00:09<00:10,  4.96it/s, loss=0.0171][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  60%|██████    | 57/95 [00:11<00:07,  4.99it/s, loss=0.0174][A
Batch:  62%|██████▏   | 59/95 [00:11<00:07,  4.97it/s, loss=0.0156]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  69%|██████▉   | 66/95 [00:13<00:05,  5.01it/s, loss=0.0206][A
Batch:  71%|███████   | 67/95 [00:13<00:05,  4.99it/s, loss=0.0222]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  73%|███████▎  | 69/95 [00:14<00:05,  4.96it/s, loss=0.015] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  75%|███████▍  | 71/95 [00:14<00:04,  4.96it/s, loss=0.0163][A
Batch:  75%|███████▍  | 71/95 [00:14<00:04,  4.96it/s, loss=0.0145]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  77%|███████▋  | 73/95 [00:14<00:04,  4.95it/s, loss=0.0204][A
Batch:  78%|███████▊  | 74/95 [00:15<00:04,  4.95it/s, loss=0.0226]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  79%|███████▉  | 75/95 [00:15<00:04,  4.95it/s, loss=0.0234]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  81%|████████  | 77/95 [00:15<00:03,  4.95it/s, loss=0.0192][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  84%|████████▍ | 80/95 [00:16<00:03,  4.94it/s, loss=0.0195]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  84%|████████▍ | 80/95 [00:16<00:03,  4.94it/s, loss=0.023] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  85%|████████▌ | 81/95 [00:16<00:02,  4.94it/s, loss=0.0184]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  87%|████████▋ | 83/95 [00:16<00:02,  4.96it/s, loss=0.017] [A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  88%|████████▊ | 84/95 [00:17<00:02,  4.95it/s, loss=0.0175][A
Batch:  91%|█████████ | 86/95 [00:17<00:01,  4.95it/s, loss=0.0151]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  92%|█████████▏| 87/95 [00:17<00:01,  4.96it/s, loss=0.0187]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  94%|█████████▎| 89/95 [00:18<00:01,  4.95it/s, loss=0.0127][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  96%|█████████▌| 91/95 [00:18<00:00,  4.95it/s, loss=0.0167][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])


Epoch:  85%|████████▌ | 85/100 [29:19<04:46, 19.10s/it, loss=0.0177]

x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.99it/s, loss=0.0174]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  23%|██▎       | 22/95 [00:04<00:14,  4.98it/s, loss=0.0159]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  24%|██▍       | 23/95 [00:04<00:14,  4.97it/s, loss=0.0173][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  26%|██▋       | 25/95 [00:05<00:14,  4.96it/s, loss=0.0188]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  28%|██▊       | 27/95 [00:05<00:13,  4.96it/s, loss=0.0134][A
Batch:  28%|██▊       | 27/95 [00:05<00:13,  4.96it/s, loss=0.0182]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  29%|██▉       | 28/95 [00:05<00:13,  4.95it/s, loss=0.018] [A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  32%|███▏      | 30/95 [00:06<00:13,  4.95it/s, loss=0.0223][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  32%|███▏      | 30/95 [00:06<00:13,  4.95it/s, loss=0.0203][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  34%|███▎      | 32/95 [00:06<00:12,  4.95it/s, loss=0.0133][A
Batch:  35%|███▍      | 33/95 [00:06<00:12,  4.94it/s, loss=0.0175]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  36%|███▌      | 34/95 [00:07<00:12,  4.95it/s, loss=0.0162][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  51%|█████     | 48/95 [00:09<00:09,  4.95it/s, loss=0.0145][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  59%|█████▉    | 56/95 [00:11<00:07,  4.97it/s, loss=0.0179]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  60%|██████    | 57/95 [00:11<00:07,  4.96it/s, loss=0.0155][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  81%|████████  | 77/95 [00:15<00:03,  5.00it/s, loss=0.0257][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  82%|████████▏ | 78/95 [00:15<00:03,  4.98it/s, loss=0.0216][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  84%|████████▍ | 80/95 [00:16<00:03,  4.96it/s, loss=0.0163][A
Batch:  85%|████████▌ | 81/95 [00:16<00:02,  4.95it/s, loss=0.021] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  88%|████████▊ | 84/95 [00:16<00:02,  4.95it/s, loss=0.0172][A
Batch:  89%|████████▉ | 85/95 [00:17<00:02,  4.96it/s, loss=0.0119]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  89%|████████▉ | 85/95 [00:17<00:02,  4.96it/s, loss=0.0165]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  91%|█████████ | 86/95 [00:17<00:01,  4.95it/s, loss=0.0147]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  94%|█████████▎| 89/95 [00:17<00:01,  4.95it/s, loss=0.0163][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  96%|█████████▌| 91/95 [00:18<00:00,  4.96it/s, loss=0.0129][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  96%|█████████▌| 91/95 [00:18<00:00,  4.96it/s, loss=0.014] [A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  97%|█████████▋| 92/95 [00:18<00:00,  4.95it/s, loss=0.0181][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  86%|████████▌ | 86/100 [29:38<04:27, 19.09s/it, loss=0.0169][A

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   2%|▏         | 2/95 [00:00<00:19,  4.87it/s, loss=0.016] [A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   4%|▍         | 4/95 [00:00<00:18,  4.91it/s, loss=0.0163][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   5%|▌         | 5/95 [00:01<00:18,  4.93it/s, loss=0.019][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   6%|▋         | 6/95 [00:01<00:18,  4.93it/s, loss=0.0215][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  67%|██████▋   | 64/95 [00:12<00:06,  4.88it/s, loss=0.0225][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  69%|██████▉   | 66/95 [00:13<00:05,  4.92it/s, loss=0.0164]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  71%|███████   | 67/95 [00:13<00:05,  4.93it/s, loss=0.0171][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  73%|███████▎  | 69/95 [00:13<00:05,  4.94it/s, loss=0.0216][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  94%|█████████▎| 89/95 [00:17<00:01,  5.03it/s, loss=0.0196][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  94%|█████████▎| 89/95 [00:18<00:01,  5.03it/s, loss=0.0112][A
Batch:  95%|█████████▍| 90/95 [00:18<00:00,  5.00it/s, loss=0.0227]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  96%|█████████▌| 91/95 [00:18<00:00,  4.99it/s, loss=0.0176][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  87%|████████▋ | 87/100 [29:57<04:08, 19.09s/it, loss=0.0173][A

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   1%|          | 1/95 [00:00<00:19,  4.74it/s, loss=0.0151][A
Batch:   2%|▏         | 2/95 [00:00<00:19,  4.85it/s, loss=0.0139]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   3%|▎         | 3/95 [00:00<00:18,  4.89it/s, loss=0.0153][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   5%|▌         | 5/95 [00:01<00:18,  4.92it/s, loss=0.0175][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  35%|███▍      | 33/95 [00:06<00:12,  4.95it/s, loss=0.0161][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  37%|███▋      | 35/95 [00:07<00:12,  4.95it/s, loss=0.0186][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  38%|███▊      | 36/95 [00:07<00:11,  4.94it/s, loss=0.0138][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  39%|███▉      | 37/95 [00:07<00:11,  4.94it/s, loss=0.0249][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  47%|████▋     | 45/95 [00:09<00:10,  4.99it/s, loss=0.0139][A
Batch:  48%|████▊     | 46/95 [00:09<00:09,  4.97it/s, loss=0.019] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  51%|█████     | 48/95 [00:09<00:09,  4.95it/s, loss=0.0183]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  52%|█████▏    | 49/95 [00:10<00:09,  4.95it/s, loss=0.0155]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  54%|█████▎    | 51/95 [00:10<00:08,  4.95it/s, loss=0.0122][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  55%|█████▍    | 52/95 [00:10<00:08,  4.94it/s, loss=0.0167][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  56%|█████▌    | 53/95 [00:10<00:08,  4.93it/s, loss=0.0217][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  58%|█████▊    | 55/95 [00:11<00:08,  4.94it/s, loss=0.0212][A
Batch:  58%|█████▊    | 55/95 [00:11<00:08,  4.94it/s, loss=0.0195]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  60%|██████    | 57/95 [00:11<00:07,  4.93it/s, loss=0.0149]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  61%|██████    | 58/95 [00:11<00:07,  4.93it/s, loss=0.0178][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  88%|████████▊ | 88/100 [30:16<03:49, 19.10s/it, loss=0.0172]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:   2%|▏         | 2/95 [00:00<00:19,  4.84it/s, loss=0.0224]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:   3%|▎         | 3/95 [00:00<00:18,  4.89it/s, loss=0.0199]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   4%|▍         | 4/95 [00:01<00:18,  4.91it/s, loss=0.0185][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   6%|▋         | 6/95 [00:01<00:18,  4.93it/s, loss=0.012][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:   8%|▊         | 8/95 [00:01<00:17,  4.94it/s, loss=0.0151]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  11%|█         | 10/95 [00:02<00:17,  4.94it/s, loss=0.0169]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  12%|█▏        | 11/95 [00:02<00:17,  4.94it/s, loss=0.0198][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  15%|█▍        | 14/95 [00:02<00:16,  4.95it/s, loss=0.0176][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  19%|█▉        | 18/95 [00:03<00:15,  4.95it/s, loss=0.0163][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  21%|██        | 20/95 [00:04<00:15,  4.96it/s, loss=0.0168][A
Batch:  21%|██        | 20/95 [00:04<00:15,  4.96it/s, loss=0.0125]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.96it/s, loss=0.0179][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  24%|██▍       | 23/95 [00:04<00:14,  4.95it/s, loss=0.0166]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  26%|██▋       | 25/95 [00:05<00:14,  4.95it/s, loss=0.0162][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  26%|██▋       | 25/95 [00:05<00:14,  4.95it/s, loss=0.0145][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  27%|██▋       | 26/95 [00:05<00:13,  4.95it/s, loss=0.0193][A
Batch:  28%|██▊       | 27/95 [00:05<00:13,  4.94it/s, loss=0.0144]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  31%|███       | 29/95 [00:05<00:13,  4.95it/s, loss=0.0176][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  32%|███▏      | 30/95 [00:06<00:13,  4.96it/s, loss=0.0187][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  35%|███▍      | 33/95 [00:06<00:12,  4.96it/s, loss=0.0204][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  39%|███▉      | 37/95 [00:07<00:11,  4.96it/s, loss=0.0159]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  39%|███▉      | 37/95 [00:07<00:11,  4.96it/s, loss=0.0207][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  41%|████      | 39/95 [00:08<00:11,  4.96it/s, loss=0.0178]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  66%|██████▋   | 63/95 [00:12<00:06,  5.00it/s, loss=0.0153][A
Batch:  66%|██████▋   | 63/95 [00:12<00:06,  5.00it/s, loss=0.0132]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  67%|██████▋   | 64/95 [00:13<00:06,  4.98it/s, loss=0.0168]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  69%|██████▉   | 66/95 [00:13<00:05,  4.97it/s, loss=0.0193][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  72%|███████▏  | 68/95 [00:13<00:05,  4.96it/s, loss=0.0151][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  72%|███████▏  | 68/95 [00:13<00:05,  4.96it/s, loss=0.0135][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  74%|███████▎  | 70/95 [00:14<00:05,  4.95it/s, loss=0.0247]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  75%|███████▍  | 71/95 [00:14<00:04,  4.95it/s, loss=0.0157][A
Batch:  77%|███████▋  | 73/95 [00:14<00:04,  4.95it/s, loss=0.0157]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  77%|███████▋  | 73/95 [00:14<00:04,  4.95it/s, loss=0.0189][A
Batch:  78%|███████▊  | 74/95 [00:15<00:04,  4.94it/s, loss=0.0138]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  80%|████████  | 76/95 [00:15<00:03,  4.95it/s, loss=0.0218]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  81%|████████  | 77/95 [00:15<00:03,  4.94it/s, loss=0.0189]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  95%|█████████▍| 90/95 [00:18<00:01,  4.96it/s, loss=0.0206][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  95%|█████████▍| 90/95 [00:18<00:01,  4.96it/s, loss=0.0164][A
Batch:  96%|█████████▌| 91/95 [00:18<00:00,  4.95it/s, loss=0.0135]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  97%|█████████▋| 92/95 [00:18<00:00,  4.95it/s, loss=0.0124][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])


Epoch:  89%|████████▉ | 89/100 [30:35<03:30, 19.10s/it, loss=0.0172]

x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])



Batch:   5%|▌         | 5/95 [00:01<00:18,  4.97it/s, loss=0.0143]

x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])



Batch:  12%|█▏        | 11/95 [00:02<00:16,  4.96it/s, loss=0.0156]

x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])





x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  19%|█▉        | 18/95 [00:03<00:15,  5.01it/s, loss=0.0163][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  21%|██        | 20/95 [00:04<00:15,  4.98it/s, loss=0.0139]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.97it/s, loss=0.0146][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  25%|██▌       | 24/95 [00:05<00:14,  4.97it/s, loss=0.0152]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  26%|██▋       | 25/95 [00:05<00:14,  4.96it/s, loss=0.0181]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  27%|██▋       | 26/95 [00:05<00:13,  4.96it/s, loss=0.0124]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  28%|██▊       | 27/95 [00:05<00:13,  4.95it/s, loss=0.016] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  31%|███       | 29/95 [00:06<00:13,  4.96it/s, loss=0.0169]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  33%|███▎      | 31/95 [00:06<00:12,  4.96it/s, loss=0.0188][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  73%|███████▎  | 69/95 [00:13<00:05,  4.96it/s, loss=0.0167][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  81%|████████  | 77/95 [00:15<00:03,  4.96it/s, loss=0.0142][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  81%|████████  | 77/95 [00:15<00:03,  4.96it/s, loss=0.019] [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  84%|████████▍ | 80/95 [00:16<00:03,  4.96it/s, loss=0.0209]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  86%|████████▋ | 82/95 [00:16<00:02,  4.96it/s, loss=0.0164]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  87%|████████▋ | 83/95 [00:16<00:02,  4.96it/s, loss=0.0163][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  88%|████████▊ | 84/95 [00:16<00:02,  4.96it/s, loss=0.014][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  89%|████████▉ | 85/95 [00:17<00:02,  4.96it/s, loss=0.0161][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  92%|█████████▏| 87/95 [00:17<00:01,  4.95it/s, loss=0.0148][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  93%|█████████▎| 88/95 [00:17<00:01,  4.96it/s, loss=0.0159][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  94%|█████████▎| 89/95 [00:18<00:01,  4.96it/s, loss=0.0125][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  95%|█████████▍| 90/95 [00:18<00:01,  4.95it/s, loss=0.0223][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  97%|█████████▋| 92/95 [00:18<00:00,  4.94it/s, loss=0.0144][A
Batch:  99%|█████████▉| 94/95 [00:18<00:00,  4.95it/s, loss=0.0159]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  90%|█████████ | 90/100 [30:54<03:10, 19.10s/it, loss=0.0164]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])





x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   1%|          | 1/95 [00:00<00:19,  4.78it/s, loss=0.0167][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  55%|█████▍    | 52/95 [00:10<00:08,  4.96it/s, loss=0.0182][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  57%|█████▋    | 54/95 [00:11<00:08,  4.96it/s, loss=0.015] [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  59%|█████▉    | 56/95 [00:11<00:07,  4.95it/s, loss=0.0188][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  85%|████████▌ | 81/95 [00:16<00:02,  4.99it/s, loss=0.0164][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  86%|████████▋ | 82/95 [00:16<00:02,  4.97it/s, loss=0.0184][A
Batch:  87%|████████▋ | 83/95 [00:16<00:02,  4.96it/s, loss=0.0144]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  88%|████████▊ | 84/95 [00:17<00:02,  4.96it/s, loss=0.0126][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  91%|█████████ | 86/95 [00:17<00:01,  4.94it/s, loss=0.0206][A
Batch:  92%|█████████▏| 87/95 [00:17<00:01,  4.94it/s, loss=0.0165]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  93%|█████████▎| 88/95 [00:17<00:01,  4.94it/s, loss=0.0181][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  95%|█████████▍| 90/95 [00:18<00:01,  4.94it/s, loss=0.0193][A
Batch:  96%|█████████▌| 91/95 [00:18<00:00,  4.94it/s, loss=0.0225]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  97%|█████████▋| 92/95 [00:18<00:00,  4.94it/s, loss=0.0131]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  91%|█████████ | 91/100 [31:13<02:51, 19.08s/it, loss=0.0161][A

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:   3%|▎         | 3/95 [00:00<00:18,  4.87it/s, loss=0.019]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   3%|▎         | 3/95 [00:00<00:18,  4.87it/s, loss=0.0154][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   5%|▌         | 5/95 [00:01<00:18,  4.92it/s, loss=0.0188][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   5%|▌         | 5/95 [00:01<00:18,  4.92it/s, loss=0.0159][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:   8%|▊         | 8/95 [00:01<00:17,  4.94it/s, loss=0.0191]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  11%|█         | 10/95 [00:02<00:17,  4.94it/s, loss=0.0131]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  12%|█▏        | 11/95 [00:02<00:16,  4.94it/s, loss=0.0157][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  25%|██▌       | 24/95 [00:05<00:14,  4.97it/s, loss=0.0157][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  27%|██▋       | 26/95 [00:05<00:13,  4.96it/s, loss=0.0164][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  36%|███▌      | 34/95 [00:07<00:12,  5.03it/s, loss=0.0166][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  38%|███▊      | 36/95 [00:07<00:11,  4.96it/s, loss=0.0135][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  39%|███▉      | 37/95 [00:07<00:11,  4.95it/s, loss=0.0132][A
Batch:  40%|████      | 38/95 [00:07<00:11,  4.94it/s, loss=0.0155]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  41%|████      | 39/95 [00:08<00:11,  4.94it/s, loss=0.0175][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  42%|████▏     | 40/95 [00:08<00:11,  4.94it/s, loss=0.0153][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  43%|████▎     | 41/95 [00:08<00:10,  4.94it/s, loss=0.0119][A
Batch:  44%|████▍     | 42/95 [00:08<00:10,  4.93it/s, loss=0.0174]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  45%|████▌     | 43/95 [00:08<00:10,  4.93it/s, loss=0.016] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  46%|████▋     | 44/95 [00:09<00:10,  4.93it/s, loss=0.0145]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  48%|████▊     | 46/95 [00:09<00:09,  4.93it/s, loss=0.0164][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  49%|████▉     | 47/95 [00:09<00:09,  4.93it/s, loss=0.0167][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  51%|█████     | 48/95 [00:09<00:09,  4.92it/s, loss=0.0138][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  54%|█████▎    | 51/95 [00:10<00:08,  4.92it/s, loss=0.0195]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  55%|█████▍    | 52/95 [00:10<00:08,  4.93it/s, loss=0.018] [A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  57%|█████▋    | 54/95 [00:10<00:08,  4.94it/s, loss=0.0158][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  57%|█████▋    | 54/95 [00:11<00:08,  4.94it/s, loss=0.0133][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  59%|█████▉    | 56/95 [00:11<00:07,  4.93it/s, loss=0.0136][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  63%|██████▎   | 60/95 [00:12<00:07,  4.94it/s, loss=0.0203][A
Batch:  63%|██████▎   | 60/95 [00:12<00:07,  4.94it/s, loss=0.0167]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  64%|██████▍   | 61/95 [00:12<00:06,  4.94it/s, loss=0.0162]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  65%|██████▌   | 62/95 [00:12<00:06,  4.94it/s, loss=0.0202]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  67%|██████▋   | 64/95 [00:12<00:06,  4.95it/s, loss=0.0138][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  69%|██████▉   | 66/95 [00:13<00:05,  4.94it/s, loss=0.0147]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  80%|████████  | 76/95 [00:15<00:03,  4.95it/s, loss=0.0141][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  95%|█████████▍| 90/95 [00:18<00:01,  4.98it/s, loss=0.0151]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])


Epoch:  92%|█████████▏| 92/100 [31:33<02:32, 19.09s/it, loss=0.0161]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  14%|█▎        | 13/95 [00:02<00:16,  4.96it/s, loss=0.0166][A
Batch:  15%|█▍        | 14/95 [00:03<00:16,  4.96it/s, loss=0.0129]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  16%|█▌        | 15/95 [00:03<00:16,  4.95it/s, loss=0.0151][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  18%|█▊        | 17/95 [00:03<00:15,  4.96it/s, loss=0.0249]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  20%|██        | 19/95 [00:04<00:15,  4.95it/s, loss=0.0135]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.95it/s, loss=0.0143][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.95it/s, loss=0.0174][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  41%|████      | 39/95 [00:07<00:11,  5.01it/s, loss=0.0155]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  42%|████▏     | 40/95 [00:08<00:11,  5.00it/s, loss=0.0168]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  44%|████▍     | 42/95 [00:08<00:10,  4.98it/s, loss=0.0179]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  92%|█████████▏| 87/95 [00:17<00:01,  4.97it/s, loss=0.0151]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  93%|█████████▎| 93/100 [31:52<02:13, 19.09s/it, loss=0.0163]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   8%|▊         | 8/95 [00:01<00:17,  5.04it/s, loss=0.0118][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  12%|█▏        | 11/95 [00:02<00:16,  4.97it/s, loss=0.0143][A
Batch:  13%|█▎        | 12/95 [00:02<00:16,  4.96it/s, loss=0.0123]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  14%|█▎        | 13/95 [00:02<00:16,  4.95it/s, loss=0.0189]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  16%|█▌        | 15/95 [00:03<00:16,  4.94it/s, loss=0.0133]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  18%|█▊        | 17/95 [00:03<00:15,  4.95it/s, loss=0.0134][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  18%|█▊        | 17/95 [00:03<00:15,  4.95it/s, loss=0.0181][A
Batch:  19%|█▉        | 18/95 [00:03<00:15,  4.94it/s, loss=0.0156]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  20%|██        | 19/95 [00:04<00:15,  4.94it/s, loss=0.0144][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.94it/s, loss=0.0141]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  24%|██▍       | 23/95 [00:04<00:14,  4.95it/s, loss=0.0183]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  25%|██▌       | 24/95 [00:05<00:14,  4.94it/s, loss=0.0153][A
Batch:  26%|██▋       | 25/95 [00:05<00:14,  4.94it/s, loss=0.0115]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  28%|██▊       | 27/95 [00:05<00:13,  4.94it/s, loss=0.0202][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  28%|██▊       | 27/95 [00:05<00:13,  4.94it/s, loss=0.0152][A
Batch:  29%|██▉       | 28/95 [00:05<00:13,  4.94it/s, loss=0.0142]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  31%|███       | 29/95 [00:06<00:13,  4.94it/s, loss=0.0171][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  35%|███▍      | 33/95 [00:06<00:12,  4.96it/s, loss=0.0102][A
Batch:  36%|███▌      | 34/95 [00:07<00:12,  4.95it/s, loss=0.0186]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  38%|███▊      | 36/95 [00:07<00:11,  4.95it/s, loss=0.018] [A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  40%|████      | 38/95 [00:07<00:11,  4.95it/s, loss=0.0142][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  42%|████▏     | 40/95 [00:08<00:11,  4.95it/s, loss=0.0145] [A
Batch:  44%|████▍     | 42/95 [00:08<00:10,  4.96it/s, loss=0.0138]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  44%|████▍     | 42/95 [00:08<00:10,  4.96it/s, loss=0.0154][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  69%|██████▉   | 66/95 [00:13<00:05,  5.08it/s, loss=0.0203][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  72%|███████▏  | 68/95 [00:13<00:05,  4.99it/s, loss=0.0176][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  74%|███████▎  | 70/95 [00:14<00:05,  4.95it/s, loss=0.0167][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  76%|███████▌  | 72/95 [00:14<00:04,  4.95it/s, loss=0.0224][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  78%|███████▊  | 74/95 [00:15<00:04,  4.94it/s, loss=0.0136][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  79%|███████▉  | 75/95 [00:15<00:04,  4.94it/s, loss=0.0172][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  80%|████████  | 76/95 [00:15<00:03,  4.94it/s, loss=0.022] [A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  81%|████████  | 77/95 [00:15<00:03,  4.94it/s, loss=0.0168][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  82%|████████▏ | 78/95 [00:15<00:03,  4.94it/s, loss=0.0122][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  86%|████████▋ | 82/95 [00:16<00:02,  4.94it/s, loss=0.0188][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  88%|████████▊ | 84/95 [00:17<00:02,  4.93it/s, loss=0.0134][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  92%|█████████▏| 87/95 [00:17<00:01,  4.94it/s, loss=0.0198]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  93%|█████████▎| 88/95 [00:17<00:01,  4.95it/s, loss=0.0189]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  93%|█████████▎| 88/95 [00:17<00:01,  4.95it/s, loss=0.0168]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  95%|█████████▍| 90/95 [00:18<00:01,  4.95it/s, loss=0.0142][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  95%|█████████▍| 90/95 [00:18<00:01,  4.95it/s, loss=0.0151][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  97%|█████████▋| 92/95 [00:18<00:00,  4.95it/s, loss=0.0143][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  98%|█████████▊| 93/95 [00:18<00:00,  4.95it/s, loss=0.0151][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  94%|█████████▍| 94/100 [32:11<01:54, 19.10s/it, loss=0.0161][A

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.99it/s, loss=0.00948]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  22%|██▏       | 21/95 [00:04<00:14,  4.99it/s, loss=0.0127] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  23%|██▎       | 22/95 [00:04<00:14,  4.98it/s, loss=0.0124][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  25%|██▌       | 24/95 [00:05<00:14,  4.95it/s, loss=0.0147]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  26%|██▋       | 25/95 [00:05<00:14,  4.95it/s, loss=0.0162]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  28%|██▊       | 27/95 [00:05<00:13,  4.95it/s, loss=0.0178]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  29%|██▉       | 28/95 [00:05<00:13,  4.94it/s, loss=0.0162][A
Batch:  31%|███       | 29/95 [00:06<00:13,  4.94it/s, loss=0.0145]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  32%|███▏      | 30/95 [00:06<00:13,  4.94it/s, loss=0.0157]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  33%|███▎      | 31/95 [00:06<00:12,  4.94it/s, loss=0.0157][A
Batch:  34%|███▎      | 32/95 [00:06<00:12,  4.94it/s, loss=0.0145]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  35%|███▍      | 33/95 [00:06<00:12,  4.94it/s, loss=0.0162][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  38%|███▊      | 36/95 [00:07<00:11,  4.94it/s, loss=0.0118]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  40%|████      | 38/95 [00:07<00:11,  4.95it/s, loss=0.0141][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  42%|████▏     | 40/95 [00:08<00:11,  4.94it/s, loss=0.0155][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  44%|████▍     | 42/95 [00:08<00:10,  4.96it/s, loss=0.0168]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  46%|████▋     | 44/95 [00:09<00:10,  4.95it/s, loss=0.016] [A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  48%|████▊     | 46/95 [00:09<00:09,  4.95it/s, loss=0.0159][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  49%|████▉     | 47/95 [00:09<00:09,  4.95it/s, loss=0.0182][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  57%|█████▋    | 54/95 [00:11<00:08,  4.95it/s, loss=0.0164][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  82%|████████▏ | 78/95 [00:15<00:03,  5.08it/s, loss=0.0132][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  84%|████████▍ | 80/95 [00:16<00:03,  4.99it/s, loss=0.0161][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  86%|████████▋ | 82/95 [00:16<00:02,  4.97it/s, loss=0.0169][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  87%|████████▋ | 83/95 [00:16<00:02,  4.95it/s, loss=0.0139][A
Batch:  88%|████████▊ | 84/95 [00:17<00:02,  4.94it/s, loss=0.012] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  89%|████████▉ | 85/95 [00:17<00:02,  4.94it/s, loss=0.0199]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  91%|█████████ | 86/95 [00:17<00:01,  4.94it/s, loss=0.0171][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  94%|█████████▎| 89/95 [00:18<00:01,  4.93it/s, loss=0.0128][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  98%|█████████▊| 93/95 [00:18<00:00,  4.94it/s, loss=0.0197][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  95%|█████████▌| 95/100 [32:30<01:35, 19.09s/it, loss=0.0162]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   2%|▏         | 2/95 [00:00<00:19,  4.81it/s, loss=0.0176][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   3%|▎         | 3/95 [00:00<00:18,  4.86it/s, loss=0.0169][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   4%|▍         | 4/95 [00:01<00:18,  4.89it/s, loss=0.0157][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  19%|█▉        | 18/95 [00:03<00:15,  5.02it/s, loss=0.0228][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  29%|██▉       | 28/95 [00:05<00:13,  4.90it/s, loss=0.0139]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  39%|███▉      | 37/95 [00:07<00:11,  5.01it/s, loss=0.0154][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  42%|████▏     | 40/95 [00:08<00:11,  4.96it/s, loss=0.017] [A
Batch:  43%|████▎     | 41/95 [00:08<00:10,  4.95it/s, loss=0.016]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  51%|█████     | 48/95 [00:09<00:09,  4.95it/s, loss=0.0103][A
Batch:  52%|█████▏    | 49/95 [00:10<00:09,  4.94it/s, loss=0.0147]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  54%|█████▎    | 51/95 [00:10<00:08,  4.95it/s, loss=0.0166]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  54%|█████▎    | 51/95 [00:10<00:08,  4.95it/s, loss=0.0196][A
Batch:  55%|█████▍    | 52/95 [00:10<00:08,  4.95it/s, loss=0.0163]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  57%|█████▋    | 54/95 [00:11<00:08,  4.94it/s, loss=0.0144]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  58%|█████▊    | 55/95 [00:11<00:08,  4.94it/s, loss=0.0171][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  60%|██████    | 57/95 [00:11<00:07,  4.95it/s, loss=0.0134][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  60%|██████    | 57/95 [00:11<00:07,  4.95it/s, loss=0.015] [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  63%|██████▎   | 60/95 [00:12<00:07,  4.94it/s, loss=0.0184][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  66%|██████▋   | 63/95 [00:12<00:06,  4.95it/s, loss=0.0105][A
Batch:  66%|██████▋   | 63/95 [00:12<00:06,  4.95it/s, loss=0.0152]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  69%|██████▉   | 66/95 [00:13<00:05,  4.95it/s, loss=0.0209]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  69%|██████▉   | 66/95 [00:13<00:05,  4.95it/s, loss=0.0132][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  71%|███████   | 67/95 [00:13<00:05,  4.94it/s, loss=0.0153][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  74%|███████▎  | 70/95 [00:14<00:05,  4.94it/s, loss=0.017][A
Batch:  74%|███████▎  | 70/95 [00:14<00:05,  4.94it/s, loss=0.0158]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  80%|████████  | 76/95 [00:15<00:03,  4.96it/s, loss=0.016] [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  96%|█████████▌| 96/100 [32:49<01:16, 19.09s/it, loss=0.0162]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   2%|▏         | 2/95 [00:00<00:19,  4.84it/s, loss=0.0126][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  25%|██▌       | 24/95 [00:05<00:14,  4.95it/s, loss=0.0225][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  56%|█████▌    | 53/95 [00:10<00:08,  4.95it/s, loss=0.015] 

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  57%|█████▋    | 54/95 [00:11<00:08,  4.95it/s, loss=0.0165][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  58%|█████▊    | 55/95 [00:11<00:08,  4.94it/s, loss=0.0133][A
Batch:  59%|█████▉    | 56/95 [00:11<00:07,  4.94it/s, loss=0.0177]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  62%|██████▏   | 59/95 [00:12<00:07,  4.94it/s, loss=0.0161][A
Batch:  63%|██████▎   | 60/95 [00:12<00:07,  4.94it/s, loss=0.0169]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  64%|██████▍   | 61/95 [00:12<00:06,  4.94it/s, loss=0.0149]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  67%|██████▋   | 64/95 [00:12<00:06,  4.94it/s, loss=0.013][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  67%|██████▋   | 64/95 [00:13<00:06,  4.94it/s, loss=0.0192][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  69%|██████▉   | 66/95 [00:13<00:05,  4.95it/s, loss=0.0139][A
Batch:  69%|██████▉   | 66/95 [00:13<00:05,  4.95it/s, loss=0.0176]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  71%|███████   | 67/95 [00:13<00:05,  4.95it/s, loss=0.0135][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  74%|███████▎  | 70/95 [00:14<00:05,  4.95it/s, loss=0.016] [A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  78%|███████▊  | 74/95 [00:14<00:04,  4.95it/s, loss=0.0131]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  78%|███████▊  | 74/95 [00:15<00:04,  4.95it/s, loss=0.0142]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch:  97%|█████████▋| 97/100 [33:08<00:57, 19.10s/it, loss=0.0159]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  46%|████▋     | 44/95 [00:09<00:10,  4.96it/s, loss=0.0102]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  55%|█████▍    | 52/95 [00:10<00:08,  5.02it/s, loss=0.0131][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  67%|██████▋   | 64/95 [00:13<00:06,  5.02it/s, loss=0.0174][A
Batch:  68%|██████▊   | 65/95 [00:13<00:06,  5.00it/s, loss=0.0196]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  69%|██████▉   | 66/95 [00:13<00:05,  4.98it/s, loss=0.0155][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  72%|███████▏  | 68/95 [00:13<00:05,  4.95it/s, loss=0.0123]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  73%|███████▎  | 69/95 [00:14<00:05,  4.95it/s, loss=0.0163][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  74%|███████▎  | 70/95 [00:14<00:05,  4.94it/s, loss=0.0187][A
Batch:  76%|███████▌  | 72/95 [00:14<00:04,  4.95it/s, loss=0.0157]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  78%|███████▊  | 74/95 [00:15<00:04,  4.94it/s, loss=0.0128]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  79%|███████▉  | 75/95 [00:15<00:04,  4.94it/s, loss=0.0138][A
Batch:  80%|████████  | 76/95 [00:15<00:03,  4.94it/s, loss=0.0194]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  81%|████████  | 77/95 [00:15<00:03,  4.94it/s, loss=0.00909]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  82%|████████▏ | 78/95 [00:15<00:03,  4.94it/s, loss=0.0204] [A
Batch:  83%|████████▎ | 79/95 [00:16<00:03,  4.94it/s, loss=0.0194]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  84%|████████▍ | 80/95 [00:16<00:03,  4.94it/s, loss=0.0153]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  86%|████████▋ | 82/95 [00:16<00:02,  4.95it/s, loss=0.0156]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  88%|████████▊ | 84/95 [00:16<00:02,  4.94it/s, loss=0.0159][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  88%|████████▊ | 84/95 [00:17<00:02,  4.94it/s, loss=0.0132][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  94%|█████████▎| 89/95 [00:18<00:01,  4.95it/s, loss=0.0162][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])


Epoch:  98%|█████████▊| 98/100 [33:27<00:38, 19.10s/it, loss=0.0162]

x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])





x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   6%|▋         | 6/95 [00:01<00:17,  4.97it/s, loss=0.0155][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])





x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:   8%|▊         | 8/95 [00:01<00:17,  4.97it/s, loss=0.0199][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  20%|██        | 19/95 [00:03<00:15,  4.99it/s, loss=0.017][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  23%|██▎       | 22/95 [00:04<00:14,  4.97it/s, loss=0.0148][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  26%|██▋       | 25/95 [00:05<00:14,  4.95it/s, loss=0.0131]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  28%|██▊       | 27/95 [00:05<00:13,  4.94it/s, loss=0.0136]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  31%|███       | 29/95 [00:06<00:13,  4.95it/s, loss=0.0212]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  33%|███▎      | 31/95 [00:06<00:12,  4.95it/s, loss=0.0128][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  35%|███▍      | 33/95 [00:06<00:12,  4.95it/s, loss=0.0151][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  37%|███▋      | 35/95 [00:07<00:12,  4.95it/s, loss=0.0153][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  37%|███▋      | 35/95 [00:07<00:12,  4.95it/s, loss=0.018] [A
Batch:  38%|███▊      | 36/95 [00:07<00:11,  4.94it/s, loss=0.0144]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  39%|███▉      | 37/95 [00:07<00:11,  4.94it/s, loss=0.0116]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  41%|████      | 39/95 [00:07<00:11,  4.95it/s, loss=0.0169]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  43%|████▎     | 41/95 [00:08<00:10,  4.95it/s, loss=0.0162][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  45%|████▌     | 43/95 [00:08<00:10,  4.95it/s, loss=0.0245]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  47%|████▋     | 45/95 [00:09<00:10,  4.95it/s, loss=0.0177][A
Batch:  48%|████▊     | 46/95 [00:09<00:09,  4.95it/s, loss=0.0132]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  48%|████▊     | 46/95 [00:09<00:09,  4.95it/s, loss=0.0167][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  49%|████▉     | 47/95 [00:09<00:09,  4.94it/s, loss=0.0169][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])


Epoch:  99%|█████████▉| 99/100 [33:46<00:19, 19.09s/it, loss=0.0162]

x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  26%|██▋       | 25/95 [00:05<00:14,  4.97it/s, loss=0.0193][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  28%|██▊       | 27/95 [00:05<00:13,  4.94it/s, loss=0.0163]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  29%|██▉       | 28/95 [00:05<00:13,  4.94it/s, loss=0.0153][A
Batch:  31%|███       | 29/95 [00:06<00:13,  4.93it/s, loss=0.0132]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  32%|███▏      | 30/95 [00:06<00:13,  4.94it/s, loss=0.0189]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])



Batch:  34%|███▎      | 32/95 [00:06<00:12,  4.92it/s, loss=0.0159]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  35%|███▍      | 33/95 [00:06<00:12,  4.93it/s, loss=0.0215][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  36%|███▌      | 34/95 [00:07<00:12,  4.92it/s, loss=0.0171][A
Batch:  37%|███▋      | 35/95 [00:07<00:12,  4.92it/s, loss=0.0102]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  38%|███▊      | 36/95 [00:07<00:11,  4.93it/s, loss=0.0163][A
Batch:  39%|███▉      | 37/95 [00:07<00:11,  4.92it/s, loss=0.0213]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  40%|████      | 38/95 [00:07<00:11,  4.92it/s, loss=0.0165][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  41%|████      | 39/95 [00:08<00:11,  4.92it/s, loss=0.0143][A
Batch:  42%|████▏     | 40/95 [00:08<00:11,  4.92it/s, loss=0.0128]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  44%|████▍     | 42/95 [00:08<00:10,  4.92it/s, loss=0.0202][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  45%|████▌     | 43/95 [00:08<00:10,  4.92it/s, loss=0.0179][A
Batch:  46%|████▋     | 44/95 [00:09<00:10,  4.92it/s, loss=0.0141]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  47%|████▋     | 45/95 [00:09<00:10,  4.92it/s, loss=0.0159]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  48%|████▊     | 46/95 [00:09<00:09,  4.93it/s, loss=0.0145][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  49%|████▉     | 47/95 [00:09<00:09,  4.92it/s, loss=0.0135][A
Batch:  51%|█████     | 48/95 [00:09<00:09,  4.92it/s, loss=0.0167]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A
Batch:  52%|█████▏    | 49/95 [00:10<00:09,  4.92it/s, loss=0.0171]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  54%|█████▎    | 51/95 [00:10<00:08,  4.93it/s, loss=0.0113][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  55%|█████▍    | 52/95 [00:10<00:08,  4.93it/s, loss=0.0146][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  56%|█████▌    | 53/95 [00:10<00:08,  4.92it/s, loss=0.0237][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  58%|█████▊    | 55/95 [00:11<00:08,  4.94it/s, loss=0.0136][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  59%|█████▉    | 56/95 [00:11<00:07,  4.94it/s, loss=0.0126][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  61%|██████    | 58/95 [00:11<00:07,  4.94it/s, loss=0.0121][A
Batch:  61%|██████    | 58/95 [00:11<00:07,  4.94it/s, loss=0.0174]

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


[A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  68%|██████▊   | 65/95 [00:13<00:06,  4.96it/s, loss=0.0152][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  68%|██████▊   | 65/95 [00:13<00:06,  4.96it/s, loss=0.0202][A


nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  71%|███████   | 67/95 [00:13<00:05,  4.96it/s, loss=0.0204][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])





nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Batch:  86%|████████▋ | 82/95 [00:16<00:02,  5.01it/s, loss=0.0126][A

nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])




x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])




nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])




x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])
nobs.shape: torch.Size([256, 2, 5])
naction.shape: torch.Size([256, 16, 2])
x torch.Size([256, 1024, 4])
h_pop torch.Size([256, 1024, 4])
x torch.Size([256, 512, 8])
h_pop torch.Size([256, 512, 8])


Epoch: 100%|██████████| 100/100 [34:05<00:00, 20.46s/it, loss=0.0159]

nobs.shape: torch.Size([144, 2, 5])
naction.shape: torch.Size([144, 16, 2])
x torch.Size([144, 1024, 4])
h_pop torch.Size([144, 1024, 4])
x torch.Size([144, 512, 8])
h_pop torch.Size([144, 512, 8])





In [None]:
#@markdown ### **Loading Pretrained Checkpoint**
#@markdown Set `load_pretrained = True` to load pretrained weights.

# When enabled, fetch the published checkpoint (only if it is not already on
# disk) and load its weights into `noise_pred_net`, which the inference cell
# then uses under the name `ema_noise_pred_net`.
# NOTE(review): when disabled, `ema_noise_pred_net` is presumably expected to
# come from the training cell above — confirm before running inference.
load_pretrained = False
if load_pretrained:
    ckpt_path = "pusht_state_100ep.ckpt"
    if not os.path.isfile(ckpt_path):
        # Google Drive file id (the "&confirm=t" suffix is kept from the
        # original link). Named `file_id` so the builtin `id` is not shadowed
        # for the rest of the notebook session.
        file_id = "1mHDr_DEZSdiGo9yecL50BBQYzR8Fjhl_&confirm=t"
        gdown.download(id=file_id, output=ckpt_path, quiet=False)

    # Load onto CPU: `load_state_dict` copies the values into the network's
    # existing parameters on whatever device they live, so this works on both
    # CUDA and CPU-only machines (hard-coding 'cuda' failed without a GPU).
    state_dict = torch.load(ckpt_path, map_location='cpu')
    # Plain alias (not an EMA copy): inference reads `ema_noise_pred_net`.
    ema_noise_pred_net = noise_pred_net
    ema_noise_pred_net.load_state_dict(state_dict)
    print('Pretrained weights loaded.')
else:
    print("Skipped pretrained weight loading.")

: 

In [None]:
#@markdown ### **Inference**

# Roll out the trained diffusion policy in a fresh Push-T environment with
# receding-horizon control: predict `pred_horizon` actions, execute
# `action_horizon` of them, then replan from the newest observations.

# limit environment interaction to 200 steps before termination
max_steps = 200
env = PushTEnv()
# use a seed >200 to avoid initial states seen in the training dataset
env.seed(100000)

# get first observation
obs, info = env.reset()

# keep a queue of the last obs_horizon observations, pre-filled with the
# initial observation so the first prediction has a full history
obs_deque = collections.deque(
    [obs] * obs_horizon, maxlen=obs_horizon)
# save visualization and rewards
imgs = [env.render(mode='rgb_array')]
rewards = list()
done = False
step_idx = 0

# batch size: a single environment is rolled out
B = 1

with tqdm(total=max_steps, desc="Eval PushTStateEnv") as pbar:
    while not done:
        # stack the last obs_horizon number of observations
        obs_seq = np.stack(obs_deque)
        # normalize observation
        nobs = normalize_data(obs_seq, stats=stats['obs'])
        # device transfer
        nobs = torch.from_numpy(nobs).to(device, dtype=torch.float32)

        # infer action
        with torch.no_grad():
            # reshape observation to (B, obs_horizon*obs_dim)
            obs_cond = nobs.unsqueeze(0).flatten(start_dim=1)

            # initialize action from Gaussian noise
            noisy_action = torch.randn(
                (B, pred_horizon, action_dim), device=device)
            naction = noisy_action

            # (re)initialize the scheduler for every plan so any internal
            # stepping state starts fresh
            noise_scheduler.set_timesteps(num_diffusion_iters)

            for k in noise_scheduler.timesteps:
                # predict noise
                noise_pred = ema_noise_pred_net(
                    sample=naction,
                    timestep=k,
                    global_cond=obs_cond
                )

                # inverse diffusion step (remove noise)
                naction = noise_scheduler.step(
                    model_output=noise_pred,
                    timestep=k,
                    sample=naction
                ).prev_sample

        # unnormalize action
        # (B, pred_horizon, action_dim)
        naction = naction.detach().to('cpu').numpy()
        # (pred_horizon, action_dim)
        naction = naction[0]
        action_pred = unnormalize_data(naction, stats=stats['action'])

        # only take action_horizon number of actions, starting at the
        # prediction aligned with the most recent observation
        start = obs_horizon - 1
        end = start + action_horizon
        # (action_horizon, action_dim)
        action = action_pred[start:end, :]

        # execute action_horizon number of steps
        # without replanning
        for i in range(len(action)):
            # stepping env
            obs, reward, done, _, info = env.step(action[i])
            # save observations
            obs_deque.append(obs)
            # and reward/vis
            rewards.append(reward)
            imgs.append(env.render(mode='rgb_array'))

            # update progress bar
            step_idx += 1
            pbar.update(1)
            pbar.set_postfix(reward=reward)
            # `>=` so at most max_steps env steps are taken (the previous `>`
            # allowed max_steps + 1 and overflowed the progress bar)
            if step_idx >= max_steps:
                done = True
            if done:
                break

# print out the maximum target coverage
print('Score: ', max(rewards))

# visualize (Video is imported in the imports cell)
vwrite('vis.mp4', imgs)
Video('vis.mp4', embed=True, width=256, height=256)

Eval PushTStateEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Score:  0.9421109132517091
