In [1]:
import time
from pathlib import Path
from datetime import datetime
import gymnasium as gym
import json
import numpy as np
import torch
from torch import nn, optim
from torch.distributions import Normal
from torch.nn.functional import mse_loss
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm

In [2]:
# from pyvirtualdisplay import Display
# display = Display(visible=0, size=(800, 600))
# display.start()

In [2]:
class Args:
    pass

args = Args()
args.env_id = "HalfCheetah-v4"
args.total_timesteps = 10_000_000
args.num_envs = 16
args.num_steps = 5
args.learning_rate = 5e-4
args.actor_layers = [64, 64]
args.critic_layers  = [64, 64]
args.gamma = 0.99
args.gae = 1.0
args.value_coef = 0.5
args.entropy_coef = 0.01
args.clip_grad_norm = 0.5
args.seed = 0

args.batch_size = int(args.num_envs * args.num_steps)
args.num_updates = int(args.total_timesteps // args.batch_size)

In [3]:
def make_env(env_id, capture_video=False, run_dir="."):
    def thunk():
        if capture_video:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(
                env=env,
                video_folder=f"{run_dir}/videos",
                episode_trigger=lambda x: x,
                disable_logger=True,
            )
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = gym.wrappers.FlattenObservation(env)
        env = gym.wrappers.ClipAction(env)
        env = gym.wrappers.NormalizeObservation(env)
        env = gym.wrappers.TransformObservation(env, lambda state: np.clip(state, -10, 10))
        env = gym.wrappers.NormalizeReward(env)
        env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))

        return env

    return thunk

In [4]:
def compute_advantages(rewards, flags, values, last_value, args):
    advantages = torch.zeros((args.num_steps, args.num_envs))
    adv = torch.zeros(args.num_envs)

    for i in reversed(range(args.num_steps)):
        returns = rewards[i] + args.gamma * flags[i] * last_value
        delta = returns - values[i]

        adv = delta + args.gamma * args.gae * flags[i] * adv
        advantages[i] = adv

        last_value = values[i]

    return advantages

In [5]:
class RolloutBuffer:
    def __init__(self, num_steps, num_envs, observation_shape, action_shape):
        self.states = np.zeros((num_steps, num_envs, *observation_shape), dtype=np.float32)
        self.actions = np.zeros((num_steps, num_envs, *action_shape), dtype=np.float32)
        self.rewards = np.zeros((num_steps, num_envs), dtype=np.float32)
        self.flags = np.zeros((num_steps, num_envs), dtype=np.float32)
        self.values = np.zeros((num_steps, num_envs), dtype=np.float32)

        self.step = 0
        self.num_steps = num_steps

    def push(self, state, action, reward, flag, value):
        self.states[self.step] = state
        self.actions[self.step] = action
        self.rewards[self.step] = reward
        self.flags[self.step] = flag
        self.values[self.step] = value

        self.step = (self.step + 1) % self.num_steps

    def get(self):
        return (
            torch.from_numpy(self.states),
            torch.from_numpy(self.actions),
            torch.from_numpy(self.rewards),
            torch.from_numpy(self.flags),
            torch.from_numpy(self.values),
        )

In [6]:
class ActorCriticNet(nn.Module):
    def __init__(self, observation_shape, action_dim, actor_layers, critic_layers):
        super().__init__()

        self.actor_net = self._build_net(observation_shape, actor_layers)
        self.critic_net = self._build_net(observation_shape, critic_layers)

        self.actor_net.append(self._build_linear(actor_layers[-1], action_dim, std=0.01))
        self.actor_logstd = nn.Parameter(torch.zeros(1, action_dim))

        self.critic_net.append(self._build_linear(critic_layers[-1], 1, std=1.0))

    def _build_linear(self, in_size, out_size, apply_init=True, std=np.sqrt(2), bias_const=0.0):
        layer = nn.Linear(in_size, out_size)

        if apply_init:
            torch.nn.init.orthogonal_(layer.weight, std)
            torch.nn.init.constant_(layer.bias, bias_const)

        return layer

    def _build_net(self, observation_shape, hidden_layers):
        layers = nn.Sequential()
        in_size = np.prod(observation_shape)

        for out_size in hidden_layers:
            layers.append(self._build_linear(in_size, out_size))
            layers.append(nn.Tanh())
            in_size = out_size

        return layers

    def forward(self, state):
        mean = self.actor_net(state)
        std = self.actor_logstd
        std = std.expand_as(mean).exp()
        distribution = Normal(mean, std)

        action = distribution.sample()

        value = self.critic_net(state).squeeze(-1)

        return action, value

    def evaluate(self, states, actions):
        mean = self.actor_net(states)
        std = self.actor_logstd.expand_as(mean).exp()
        distribution = Normal(mean, std)

        log_probs = distribution.log_prob(actions).sum(-1)
        entropy = distribution.entropy().sum(-1)

        values = self.critic_net(states).squeeze(-1)

        return log_probs, values, entropy

    def critic(self, state):
        return self.critic_net(state).squeeze(-1)

In [7]:
def train(args, run_name, run_dir):
    # Create tensorboard writer and save hyperparameters
    writer = SummaryWriter(run_dir)
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # Create vectorized environment(s)
    envs = gym.vector.AsyncVectorEnv([make_env(args.env_id) for _ in range(args.num_envs)])

    # Metadata about the environment
    observation_shape = envs.single_observation_space.shape
    action_shape = envs.single_action_space.shape
    action_dim = np.prod(action_shape)

    # Set seed for reproducibility
    if args.seed:
        torch.manual_seed(args.seed)
        state, _ = envs.reset(seed=args.seed)
    else:
        state, _ = envs.reset()

    # Create policy network and optimizer
    policy = ActorCriticNet(observation_shape, action_dim, args.actor_layers, args.critic_layers)
    optimizer = optim.RMSprop(policy.parameters(), lr=args.learning_rate, alpha=0.99, eps=1e-5)

    # Create buffers
    rollout_buffer = RolloutBuffer(args.num_steps, args.num_envs, observation_shape, action_shape)

    # Remove unnecessary variables
    del action_dim

    global_step = 0
    log_episodic_returns, log_episodic_lengths = [], []
    start_time = time.process_time()

    # Main loop
    for iter in tqdm(range(args.num_updates)):
        for _ in range(args.num_steps):
            # Update global step
            global_step += 1 * args.num_envs

            with torch.no_grad():
                # Get action
                action, value = policy(torch.from_numpy(state).float())

            # Perform action
            action = action.cpu().numpy()
            next_state, reward, terminated, truncated, infos = envs.step(action)

            # Store transition
            flag = 1.0 - np.logical_or(terminated, truncated)
            value = value.cpu().numpy()
            rollout_buffer.push(state, action, reward, flag, value)

            state = next_state

            if "final_info" not in infos:
                continue

            # Log episodic return and length
            for info in infos["final_info"]:
                if info is None:
                    continue

                log_episodic_returns.append(info["episode"]["r"])
                log_episodic_lengths.append(info["episode"]["l"])
                writer.add_scalar("rollout/episodic_return", np.mean(log_episodic_returns[-5:]), global_step)
                writer.add_scalar("rollout/episodic_length", np.mean(log_episodic_lengths[-5:]), global_step)

        # Get transition batch
        states, actions, rewards, flags, values = rollout_buffer.get()

        with torch.no_grad():
            last_value = policy.critic(torch.from_numpy(next_state).float())

        # Calculate advantages and TD target
        advantages = compute_advantages(rewards, flags, values, last_value, args)
        td_target = advantages + values

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Flatten batch
        states = states.reshape(-1, *observation_shape)
        actions = actions.reshape(-1, *action_shape)
        td_target = td_target.reshape(-1)
        advantages = advantages.reshape(-1)

        # Compute losses
        log_probs, td_predict, entropy = policy.evaluate(states, actions)

        actor_loss = (-log_probs * advantages).mean()
        critic_loss = mse_loss(td_target, td_predict)
        entropy_loss = entropy.mean()

        loss = actor_loss + critic_loss * args.value_coef - entropy_loss * args.entropy_coef

        # Update policy network
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(policy.parameters(), args.clip_grad_norm)
        optimizer.step()

        # Log training metrics
        writer.add_scalar("rollout/SPS", int(global_step / (time.process_time() - start_time)), global_step)
        writer.add_scalar("train/loss", loss, global_step)
        writer.add_scalar("train/actor_loss", actor_loss, global_step)
        writer.add_scalar("train/critic_loss", critic_loss, global_step)

        if iter % 1_000 == 0:
            torch.save(policy.state_dict(), f"{run_dir}/policy.pt")

    # Save final policy
    torch.save(policy.state_dict(), f"{run_dir}/policy.pt")
    print(f"Saved policy to {run_dir}/policy.pt")

    # Close the environment
    envs.close()
    writer.close()

    # Average of episodic returns (for the last 5% of the training)
    indexes = int(len(log_episodic_returns) * 0.05)
    mean_train_return = np.mean(log_episodic_returns[-indexes:])
    writer.add_scalar("rollout/mean_train_return", mean_train_return, global_step)

    return mean_train_return

In [8]:
def eval_and_render(args, run_dir):
    # Create environment
    env = gym.vector.SyncVectorEnv([make_env(args.env_id, capture_video=True, run_dir=run_dir)])

    # Metadata about the environment
    observation_shape = env.single_observation_space.shape
    action_shape = env.single_action_space.shape
    action_dim = np.prod(action_shape)

    # Load policy
    policy = ActorCriticNet(observation_shape, action_dim, args.actor_layers, args.critic_layers)
    filename = f"{run_dir}/policy.pt"
    print(f"reading {filename}...")
    policy.load_state_dict(torch.load(filename))
    policy.eval()

    count_episodes = 0
    list_rewards = []

    state, _ = env.reset()

    # Run episodes
    while count_episodes < 30:
        with torch.no_grad():
            action, _ = policy(torch.from_numpy(state).float())

        action = action.cpu().numpy()
        state, _, _, _, infos = env.step(action)

        if "final_info" in infos:
            info = infos["final_info"][0]
            returns = info["episode"]["r"][0]
            count_episodes += 1
            list_rewards.append(returns)
            print(f"-> Episode {count_episodes}: {returns} returns")

    env.close()

    return np.mean(list_rewards)

In [9]:
# Create run directory
run_time = str(datetime.now()).replace(" ", "_" ).replace(":", "-").split(".")[0].replace("-", "_")
print(f"run time: {run_time}")
run_name = "A2C_PyTorch"

run_dir = Path(f"runs/{run_name}/{run_time}")
run_dir.mkdir(parents=True, exist_ok=True)

with open(run_dir / "args.json", "w") as fp:
    json.dump(args.__dict__, fp)

run time: 2024_01_25_14_30_58


In [10]:
print(f"Commencing training of {run_name} on {args.env_id} for {args.total_timesteps} timesteps.")
print(f"Results will be saved to: {run_dir}")
mean_train_return = train(args=args, run_name=run_name, run_dir=run_dir)
print(f"Training - Mean returns achieved: {mean_train_return}.")

Commencing training of A2C_PyTorch on HalfCheetah-v4 for 10000000 timesteps.
Results will be saved to: runs\A2C_PyTorch\2024_01_25_14_30_58


  0%|          | 0/125000 [00:00<?, ?it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -9.5079,  -6.9118,  -8.7038,  -7.7109,  -7.1455,  -7.3211, -11.3247,
         -9.4816, -10.7889,  -7.1531,  -9.8028, -11.1268,  -7.5115, -12.2696,
        -10.4919,  -7.2353, -11.5606,  -9.6581, -10.4265,  -8.8145,  -6.2998,
         -9.3900,  -8.9101,  -7.9214,  -9.1861,  -9.5402,  -7.4585,  -9.2766,
        -13.0857,  -8.0035,  -7.2780,  -7.4284, -10.6952,  -8.6997,  -8.4272,
         -7.3089,  -8.4338,  -5.9591,  -8.2958,  -6.7976,  -7.7187,  -7.7464,
         -8.6061,  -8.5287,  -7.0908,  -6.9176, -11.4913, -10.0649,  -7.4291,
         -7.8363,  -8.4858,  -8.2495,  -7.2911,  -6.3886,  -9.3578,  -9.3288,
        -13.1002,  -6.2564,  -7.4353,  -6.7484,  -8.6085,  -7.5471,  -7.85

  0%|          | 4/125000 [00:00<4:05:15,  8.49it/s] 

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -7.7544,  -8.5907,  -7.6455, -12.2562,  -7.0277,  -7.2115,  -8.3066,
         -9.3771,  -6.9276, -13.7222,  -7.6208,  -7.9477,  -6.0989,  -8.6889,
         -8.8624, -11.2606,  -5.8080,  -8.1976,  -7.9613,  -8.0889, -11.4663,
         -6.5733,  -9.2791,  -7.2414,  -7.4477, -11.8833, -11.7853,  -7.8465,
         -8.6765,  -6.7869,  -7.5355,  -7.2574,  -6.3438,  -8.3131,  -9.9237,
         -8.0688,  -7.6758,  -8.7705, -10.4571,  -9.6726,  -9.9571,  -7.3799,
         -8.5467, -11.4938,  -6.1687,  -8.4237, -11.8567, -10.2552,  -6.8779,
         -7.7484,  -9.6371,  -8.2765,  -9.0310,  -8.6967,  -8.8948,  -7.0017,
         -8.5431,  -8.0583,  -8.6952, -10.0346, -11.8954,  -8.1456,  -8.26

  0%|          | 10/125000 [00:00<2:08:12, 16.25it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -9.9686,  -8.4454,  -7.2826,  -9.3204,  -8.0143,  -7.4761,  -8.3649,
        -11.7406,  -6.8677, -10.0598,  -7.9500,  -9.4349,  -7.9759,  -7.6545,
         -7.6467,  -8.3358,  -8.2924,  -7.6363,  -7.3144,  -7.7396,  -8.0694,
         -7.7859,  -6.0713,  -7.4696,  -9.2641,  -6.9547,  -9.0060, -10.2414,
         -7.0705,  -7.3466, -11.3184,  -6.2076, -12.4472,  -8.2315,  -6.8057,
         -8.2238,  -7.9054,  -8.6038,  -8.3651,  -8.5245, -10.7489,  -7.1244,
        -10.6587,  -7.9327,  -9.7106,  -8.7508,  -9.2503,  -7.7541,  -7.4530,
         -9.3428,  -7.7995,  -7.5737,  -7.7500,  -9.7192,  -7.8430, -10.1909,
         -6.6997,  -7.9301,  -8.7471,  -7.7137,  -7.7727,  -8.0886,  -6.2829,
         -9.5271, -11.9399, -10.681

  0%|          | 13/125000 [00:00<1:52:43, 18.48it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -8.0521,  -7.5382,  -7.1556, -10.2611, -12.4066,  -7.9320,  -6.8213,
         -8.2269, -10.1735,  -7.2419,  -8.4942,  -6.8100,  -7.6587, -11.1885,
         -9.5208,  -8.4341,  -6.7002,  -7.6178,  -8.7700,  -8.1197, -11.9384,
         -9.2847,  -7.3881,  -9.7939,  -7.9069,  -7.5665,  -9.8419, -10.2193,
        -12.3477,  -6.3874,  -7.9395,  -7.8904,  -7.5380,  -6.6702, -10.4791,
         -7.5585,  -9.3748,  -7.1458, -16.7662,  -7.2460,  -7.4174,  -7.2622,
         -8.6130, -10.2148, -10.3157,  -8.6125,  -7.2317,  -6.7608, -11.9364,
         -6.2567,  -9.9149,  -8.8679, -15.0988, -12.2689,  -7.0615,  -7.7375,
        -17.9930,  -8.7926, -10.0942,  -8.7016, -11.9192,  -8.6866,  -9.2455,
        -10.0376,  -6.0308,  -7.0711,  -7.5533,  -8.9160,  -9.0678,  -8.93

  0%|          | 19/125000 [00:01<1:45:51, 19.68it/s]

tensor([ -6.1539,  -7.7827,  -7.4278,  -7.3412,  -7.1959,  -7.4819,  -8.8024,
         -8.1373,  -8.1696,  -7.1476,  -7.6891,  -8.8229,  -8.0104,  -6.4610,
         -7.2151,  -9.1929,  -7.0861,  -6.9434,  -6.6069,  -6.2133, -10.0576,
         -7.1366,  -5.8865, -11.4990,  -8.1901, -16.4174,  -7.8737,  -8.4883,
         -8.1594,  -9.7374,  -7.7250,  -8.7545, -11.4820,  -7.3002,  -8.2571,
         -9.0536,  -7.6510,  -7.6714,  -9.2588,  -9.9628, -10.3161,  -9.6105,
         -9.9440,  -9.7443,  -8.0638,  -8.9857,  -6.3858,  -7.9967,  -8.3351,
         -6.5226,  -6.6530,  -9.9337,  -8.3317, -10.9070,  -6.6746,  -7.1588,
         -8.3928,  -8.4035, -11.3042, -10.3509,  -6.4870,  -7.6972, -10.6909,
         -8.8249,  -7.2327,  -8.2979,  -8.2019, -11.1885,  -6.8578,  -9.1126,
         -8.6912,  -7.3479, -10.1592,  -7.5416,  -9.6424,  -9.2774, -11.4874,
        -10.6271, -10.7033,  -8.9980], grad_fn=<SumBackward1>) tensor([-0.1377, -0.5134, -0.3126, -0.1780, -0.0105, -0.3035, -0.9931, -0.4142,

  0%|          | 22/125000 [00:01<1:44:53, 19.86it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([-10.3292,  -8.2734,  -6.5345,  -7.3713,  -9.4597,  -8.2855,  -7.8687,
         -8.0359,  -9.8064,  -9.3320,  -7.2729,  -8.4262,  -5.6279,  -8.6445,
        -10.9525,  -7.1275,  -7.4067,  -9.6669, -12.3602, -11.1457,  -8.5259,
         -6.9113,  -8.9990, -11.9985,  -7.3539,  -9.4063,  -7.7042,  -6.5517,
         -7.3298,  -6.1552,  -6.2057,  -9.3695,  -8.7015,  -7.5756,  -7.1291,
        -12.5104,  -6.0371,  -8.2972,  -9.3507,  -8.0428,  -8.1161,  -8.0459,
         -7.2824, -10.0115,  -7.0693,  -9.3230,  -9.8803,  -9.5077,  -7.5905,
         -7.5125,  -7.5321,  -7.8212,  -8.5018, -12.6406,  -8.1713,  -7.4849,
         -8.6708,  -8.1510,  -8.7636,  -6.0992,  -6.1839, -10.3588,  -5.84

  0%|          | 28/125000 [00:01<1:40:00, 20.83it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -7.5336,  -6.8681,  -7.5018,  -6.7765,  -9.3960, -11.4270,  -7.2139,
        -12.7618,  -8.0416,  -7.2085,  -6.7830,  -8.4048,  -8.5835,  -6.8169,
        -12.4818,  -7.1998,  -7.4777,  -8.2812, -11.4162,  -9.8877, -10.1795,
         -8.3742,  -7.7061,  -9.5000,  -6.4895,  -9.0578,  -9.8917,  -9.8344,
         -7.5028,  -6.8508,  -7.8546,  -8.5662,  -8.4024,  -6.9886,  -8.8705,
         -7.5441, -11.1041,  -8.0618,  -7.2700,  -8.8943, -12.9402,  -6.7422,
         -8.5218,  -6.9141,  -9.5635,  -6.6321, -10.3887,  -7.4395,  -6.7543,
         -8.2818,  -6.3598,  -7.7758,  -6.7989, -10.7763, -12.1949, -10.7458,
         -7.8638,  -6.3645,  -6.8305,  -9.5371,  -8.0975, -10.0267,  -8.7580,
        -12.0549,  -9.9893,  -7.7841,  -8.6491, -10.2652,  -6.9614,  -7.79

  0%|          | 31/125000 [00:01<1:50:43, 18.81it/s]

torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -8.1102,  -8.6585,  -8.1446,  -7.5463,  -6.8123,  -9.7756,  -6.6562,
         -7.9179,  -6.5403,  -8.2802,  -9.6104,  -8.3991,  -8.3224,  -9.8705,
         -7.6187,  -7.6913,  -6.7033,  -7.9339,  -6.4521,  -8.8099, -10.4894,
        -11.2822,  -6.3414,  -6.3154, -10.8046,  -7.9261,  -9.9788, -11.1608,
         -7.6609,  -9.3359,  -7.0796,  -8.4092,  -7.5760,  -6.8672, -12.4917,
         -9.8301,  -6.9225,  -6.2605, -11.9182,  -6.9851, -10.7745, -10.1099,
        -11.3171,  -8.9209,  -6.7218,  -7.2579,  -9.1062,  -6.2447,  -6.4471,
         -9.8827,  -6.2529,  -8.5313,  -7.1289,  -8.2210, -13.1478,  -7.4205,
         -9.6481,  -6.5991, -10.3090,  -8.0549,  -7.3479,  -6.9543,  -6.3450,
         -9.9300,  -7.1834, -11.0258,  -8.6018, -12.3592,  -6.6789, -10.0519,
         -6.6507,  -5.8939,  -8.4852, -11.0050,  -7.5077,  -6.5243,  -7.3106,
         -6.9356,  -7.5499,  -6.969

  0%|          | 35/125000 [00:02<1:59:48, 17.38it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -6.2889,  -8.7638,  -6.3860,  -6.1138,  -8.8874,  -7.1035,  -8.2880,
         -9.7679,  -6.1771,  -6.3512, -14.5047, -10.8957,  -7.3911,  -7.9737,
         -8.4961,  -9.3268,  -6.8842, -10.3084,  -6.7515,  -7.4074,  -8.5045,
         -8.8520,  -7.8706,  -6.0702,  -7.2977,  -7.6860,  -6.2828,  -7.1183,
         -8.8708,  -8.0267,  -8.6824,  -9.6758,  -7.9806,  -7.3329,  -9.4154,
         -8.3904,  -8.0181,  -5.8958,  -6.0647,  -8.8485,  -6.2582,  -6.6148,
         -9.4110,  -7.4956, -13.2448,  -7.4117,  -6.7124,  -9.6425,  -9.1713,
         -7.6548,  -6.5082,  -9.2650, -10.3499,  -9.7600,  -8.2577,  -9.8254,
         -6.1232, -11.1878, -12.6506,  -9.5182,  -7.5724, -11.3513,  -6.89

  0%|          | 37/125000 [00:02<2:01:33, 17.13it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -8.2351, -12.5687,  -7.4748,  -8.3043,  -8.8380, -13.9473,  -8.5746,
         -8.0504,  -8.7995,  -8.4131, -12.6137,  -7.5486,  -9.7345,  -6.7627,
        -11.0026, -11.6636,  -8.4826,  -7.7553,  -8.0701,  -7.6207, -11.2473,
        -10.0109,  -8.2197,  -9.1808, -10.0711,  -6.4866,  -6.7153,  -9.3301,
         -7.1359,  -7.3258, -11.2654,  -8.8481,  -8.1841,  -8.4955, -10.9290,
         -8.1838, -10.4854,  -8.0590,  -9.3842, -11.1866, -12.9772,  -7.1161,
         -8.9366,  -8.0164,  -7.8525,  -6.0038, -12.1225,  -5.9785, -13.6126,
         -8.5454,  -9.4143,  -7.2489,  -8.6502,  -8.9796,  -6.6984,  -8.6272,
         -7.0861,  -8.2991,  -7.2104,  -6.5065,  -7.0621,  -9.1271, -11.52

  0%|          | 41/125000 [00:02<2:06:30, 16.46it/s]

tensor([ -6.3625, -11.6996,  -7.6331,  -8.6726,  -7.3099,  -8.6481, -10.2879,
         -8.3851,  -7.5091,  -8.2931,  -9.5125,  -7.5847,  -7.0286,  -9.2457,
         -7.0502,  -7.4506,  -9.2573,  -7.5717,  -8.7989,  -6.4830,  -8.1863,
         -7.2218,  -9.9826,  -8.3156,  -7.3555,  -9.8321,  -7.7349,  -9.1806,
         -6.6943,  -8.1496, -11.0944,  -6.3252, -10.7265,  -5.6895,  -8.6192,
         -7.6042,  -7.3785,  -7.0670,  -9.6313, -12.2738,  -9.4104,  -8.7267,
         -6.7345, -10.0858,  -9.6951, -13.2292,  -9.2283,  -8.8538,  -7.7230,
         -8.3529,  -7.2193,  -9.0619,  -7.1020, -12.4385,  -9.5972, -10.1079,
         -6.9862,  -6.8032,  -7.6930,  -7.4857,  -8.2375, -11.2682,  -7.5893,
         -6.1374,  -9.9993,  -7.0632,  -7.7713,  -9.1542,  -7.9177,  -9.7112,
         -7.5133,  -9.7784,  -6.4315, -10.8162,  -7.5779,  -8.8264,  -7.4727,
         -9.4788,  -8.5030,  -6.3404], grad_fn=<SumBackward1>) tensor([-0.0576,  0.1495, -0.4903, -0.3299,  0.5741, -0.4944, -0.0651, -0.7622,

  0%|          | 43/125000 [00:02<2:13:58, 15.54it/s]

tensor([ -7.2528,  -6.5855,  -8.0845,  -7.9772,  -9.4031,  -7.2685,  -9.0065,
         -7.2696, -12.8957,  -7.6890,  -8.2727,  -7.8203,  -8.7701, -10.2791,
        -10.1975,  -6.0501,  -9.2231,  -8.5945,  -7.2394,  -9.7392,  -6.2066,
         -9.7982,  -6.7522,  -9.3878,  -7.5393,  -9.4039,  -7.6650, -15.0861,
        -10.1704,  -9.2669, -13.5745,  -6.5964,  -8.2689,  -8.3737, -10.1637,
         -6.5197,  -6.1946, -10.7099,  -6.5380,  -9.3536,  -8.8492,  -6.4201,
         -7.9530,  -8.1611,  -7.5031,  -6.5463,  -6.8337,  -9.9135,  -6.2096,
         -6.4172,  -9.9970,  -9.4591,  -7.6128,  -7.2592,  -7.9799,  -7.4037,
         -9.9132,  -9.5886,  -7.5557,  -6.9310,  -7.1750, -15.3629,  -9.0975,
        -10.6019,  -7.0096,  -9.0942,  -7.7103,  -7.4874,  -9.0510,  -9.8599,
         -7.0019,  -6.7807,  -9.6473,  -6.1630,  -6.7871,  -7.1266,  -8.0150,
         -7.2915,  -9.3109,  -7.3527], grad_fn=<SumBackward1>) tensor([ 0.1633,  0.0482, -0.3283, -0.1580, -0.5822, -0.8122,  0.0485, -0.4605,

  0%|          | 47/125000 [00:02<2:19:47, 14.90it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -6.9163,  -8.8826, -10.1764, -13.4830,  -8.3668,  -7.6176,  -7.7690,
         -6.0381,  -8.8418,  -6.3225,  -9.5409,  -7.7867, -10.0066,  -7.6968,
         -7.6791,  -9.3088,  -6.7079,  -8.4315,  -7.1689,  -7.6573, -10.5059,
         -6.7056,  -7.4369,  -8.2739,  -8.9474,  -6.3708,  -6.6294, -10.9533,
         -5.8574,  -6.7045,  -7.8333, -10.4618,  -7.6966, -10.8723,  -6.3849,
         -8.8281,  -9.9589, -10.6062, -12.6099,  -6.5020,  -8.6089,  -9.1809,
         -9.1523,  -6.5581,  -6.8620,  -8.3259, -10.0696,  -9.0698,  -9.3578,
         -6.8943, -10.7437,  -9.3640,  -8.6481, -11.6619,  -6.2614,  -6.9259,
         -6.5758, -14.7594, -10.2021,  -8.0419,  -7.7590, -11.1298,  -9.2519,
         -7.4323,  -7.2477,  -6.2156,  -8.1862,  -9.7475, -10.3983,  -8.9180,
        -11.5726,  -8.2859,  -7.269

  0%|          | 49/125000 [00:03<2:10:42, 15.93it/s]

torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([-10.6008,  -9.8204,  -8.8129,  -7.4697,  -8.4970,  -9.3097,  -7.7046,
         -8.4601, -13.5749, -10.6668, -10.2045,  -9.5440,  -7.5788,  -6.7873,
         -8.0646,  -6.7964,  -8.7681,  -6.1040,  -8.1913,  -7.7198,  -8.9100,
         -8.9984,  -8.1954, -11.7981,  -6.9448,  -7.5277,  -7.1491,  -7.3767,
        -11.6525,  -8.7494, -10.2079,  -8.5449,  -5.7083,  -6.9710,  -8.3488,
         -6.6963, -11.1957,  -7.3426,  -6.2867,  -8.5253,  -7.8981, -10.3927,
         -9.4434,  -6.8803, -12.9015,  -8.6139, -13.9513,  -7.7448,  -6.3440,
         -6.4559,  -7.6665,  -8.8595,  -7.7829,  -8.8975,  -8.7834,  -8.3462,
         -7.2968,  -6.8202, -10.7802,  -7.0045,  -9.6822,  -6.6109,  -6.7056,
         -9.2114,  -8.3681,  -8.5869,  -6.4773, -10.2925,  -7.5047, -11.5399,
         -8.7492, -11.7503, -13.8120,  -7.5938,  -8.2223, -10.6396,  -9.5398,
         -8.2783,  -8.3873,  -8.697

  0%|          | 53/125000 [00:03<2:11:36, 15.82it/s]

torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -7.7634, -10.1738, -15.2493,  -6.4083, -11.5542, -14.5889,  -6.5680,
         -7.8826,  -6.4101,  -7.0732,  -8.4428,  -7.1851,  -7.7626,  -9.4417,
         -7.4889,  -8.1142,  -7.6153,  -7.1729, -11.5724,  -8.1855,  -7.7160,
         -9.7665,  -8.5172,  -8.4535,  -6.9871,  -7.8987,  -7.4854,  -6.6814,
         -6.2332,  -5.9975,  -7.1407,  -8.5943,  -8.7412,  -7.0311,  -7.9005,
        -11.3515,  -7.8537,  -8.5888,  -6.0913,  -7.1032,  -8.6530,  -8.1380,
         -6.9494,  -8.6778,  -7.4609,  -6.5111,  -8.4411,  -7.9705,  -6.5158,
         -9.4615, -11.0550,  -8.4063, -10.8345,  -9.1810,  -8.0749,  -9.6031,
         -6.1107,  -6.8140,  -6.8775,  -9.4907,  -9.2540,  -8.2619, -11.9805,
         -9.2302,  -7.5394,  -8.2319,  -8.0462, -13.8047,  -7.8286,  -7.0371,
         -7.8059,  -9.7068, -12.0116,  -8.2143,  -7.8350,  -7.5031,  -8.8408,
         -8.5865,  -7.7989,  -8.244

  0%|          | 57/125000 [00:03<2:09:47, 16.04it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -9.1712,  -7.4342,  -7.6553,  -8.8203,  -5.8769, -10.1750,  -8.1138,
         -7.4471,  -5.5959,  -7.9609, -11.8900, -10.1731,  -8.1759,  -7.5584,
         -9.1065,  -8.2547,  -8.9752,  -7.0853,  -8.3643,  -8.2165,  -7.8508,
         -9.0121,  -9.0524,  -9.2539,  -8.5590,  -7.5499, -11.1725, -12.1051,
         -7.0396,  -7.1941,  -6.5121,  -8.4074,  -7.4362, -11.7677,  -9.8463,
         -7.3468,  -8.2180,  -8.7114,  -8.4663,  -7.7746,  -7.4309, -13.3921,
         -9.3240,  -5.6611,  -7.8694, -11.1961, -10.1652,  -8.2131,  -6.7169,
         -7.1311,  -6.6365,  -7.1872, -11.0371, -10.3252,  -7.0503,  -7.3837,
         -8.0016, -11.3630,  -7.4002,  -8.5919, -10.5596,  -7.6590,  -8.2355,
         -6.9221,  -7.4574,  -7.4791,  -6.6466,  -8.9227,  -7.1895, -10.2785,
         -9.8813,  -8.6297,  -9.0531,  -8.0319,  -9.1750,  -6.5921,  -7.69

  0%|          | 59/125000 [00:03<2:07:33, 16.32it/s]

torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -6.3705,  -6.9743,  -7.1972,  -7.3972,  -7.1419,  -6.2889,  -8.7233,
        -11.5436,  -7.6405,  -7.2487,  -7.7671,  -9.6579,  -7.4359, -10.0630,
        -12.6017,  -7.4111, -10.5361,  -7.1545,  -5.9446,  -7.3018,  -6.1133,
         -8.3244, -13.0720,  -7.5956,  -8.8076,  -8.8457, -10.4610,  -8.2288,
        -11.0973,  -8.9550, -10.3419,  -6.5633,  -7.6975, -11.8571, -11.8605,
         -7.5390,  -8.1320,  -9.7337, -10.6877,  -6.9811,  -9.3679,  -7.3502,
         -7.3872,  -7.4658,  -6.2038,  -8.0708,  -9.3932,  -7.5295,  -7.5920,
         -8.3922,  -7.8707,  -7.5825,  -9.1624,  -7.5215, -11.9092, -11.6958,
         -7.5234,  -7.8970,  -6.7803,  -7.7532,  -7.4755,  -6.5186,  -7.5062,
         -7.9173,  -7.4793,  -7.8612,  -9.1870, -11.5351,  -6.1942,  -7.9717,
         -9.8550,  -7.9322,  -8.5885,  -9.8283,  -8.0361, -10.7224,  -8.4010,
         -8.6001,  -8.4635,  -7.701

  0%|          | 63/125000 [00:03<2:11:54, 15.79it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -8.2048,  -8.4905, -10.4981,  -8.5280,  -7.4260,  -7.1832,  -6.6437,
         -8.3527,  -7.7173,  -6.3149,  -6.4241,  -9.9291,  -7.1512,  -8.9003,
         -9.3011,  -9.6524,  -8.0130,  -8.5162,  -9.0630, -10.4871, -10.4881,
         -9.6986,  -8.9455,  -9.6987, -12.0549,  -7.3492, -10.1405,  -7.0728,
         -7.5030, -11.6368,  -7.1189,  -9.0730,  -7.5920,  -7.2203,  -8.7242,
         -7.0253,  -8.3797, -10.0123,  -9.5820,  -7.4789,  -7.9528, -10.6809,
         -9.4630,  -8.5860,  -9.0807,  -6.3398,  -9.4156,  -8.8958,  -8.1133,
         -6.4208,  -9.7246,  -7.3570,  -9.9708,  -9.8910, -10.0968, -11.0597,
        -13.8525,  -5.9336,  -7.8517,  -7.8000,  -9.5064,  -6.0979,  -7.96

  0%|          | 65/125000 [00:04<2:17:56, 15.10it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -7.9796,  -6.2406, -13.0201,  -7.1132,  -8.8782,  -9.2174,  -7.9910,
         -9.0454, -10.9360,  -8.6553,  -7.1622,  -6.6502,  -6.6680,  -8.3716,
         -6.4046, -10.1915,  -9.3899,  -8.9167,  -7.9542,  -9.9319,  -7.2289,
        -11.5883, -10.3908,  -7.9701,  -8.6730,  -9.6158,  -5.9460,  -8.6479,
         -9.4207,  -9.4169,  -9.3838,  -7.9114,  -7.0919,  -8.7293, -11.2765,
        -10.0769,  -9.2682,  -6.7415, -10.4302,  -6.0542,  -8.0561,  -8.0391,
         -8.1172,  -8.9643,  -8.6760, -10.0135,  -5.9132,  -6.1456,  -9.3531,
         -7.3282,  -7.2131,  -6.4589,  -8.9751,  -8.2441, -10.6513,  -9.0806,
         -8.6368,  -8.3184,  -6.1187, -12.4890,  -7.2242,  -6.8217,  -9.1681,
         -7.8839, -11.1836,  -6.027

  0%|          | 70/125000 [00:04<2:17:40, 15.12it/s]

tensor([ -6.4768,  -7.1651,  -7.7893,  -6.4894,  -5.8557, -10.5807, -15.7658,
         -8.9050,  -8.9157,  -7.8276, -12.3853,  -8.5138,  -7.0669,  -7.0900,
        -10.0686,  -7.8364,  -6.8345,  -7.6266,  -6.4830,  -7.3304, -13.9097,
         -8.1336,  -8.8530,  -9.3915,  -9.7090,  -8.1197,  -8.3534,  -7.2377,
         -8.2616, -12.5221,  -7.2260, -12.5400, -11.3084,  -6.2436, -10.2839,
         -8.1419,  -8.3509,  -7.3431,  -9.5294,  -7.2170,  -7.8752, -10.9925,
         -7.9929, -10.5427, -10.4820, -10.0704,  -9.3226,  -7.5546,  -9.6545,
         -8.6285, -10.6479,  -7.5378, -13.1523,  -8.4471,  -7.9833,  -9.1822,
         -9.7521, -10.2904,  -6.2823,  -6.4935,  -7.2792,  -9.8024, -14.1042,
         -7.0456,  -8.1758,  -7.8342,  -8.4657,  -7.1477, -12.3191,  -6.3684,
         -9.7437,  -6.7689, -12.7935,  -7.0666,  -6.2088,  -8.3628, -10.0419,
         -9.2153,  -7.8316,  -9.5846], grad_fn=<SumBackward1>) tensor([ 0.1847, -0.7377, -0.4219, -1.0562, -0.2811, -0.4735, -0.9578, -0.4009,

  0%|          | 72/125000 [00:04<2:26:56, 14.17it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -6.7136,  -9.6029, -10.1777,  -9.7583,  -7.2514,  -8.4099,  -8.3858,
         -6.5205,  -6.9004,  -8.8319,  -6.6871,  -7.3420,  -6.3769, -11.5753,
         -6.3051,  -8.0155,  -7.2315,  -9.2689,  -8.8192,  -8.9999,  -7.0610,
         -7.5460, -11.2212,  -6.7247,  -7.2282,  -8.4185,  -9.5894,  -8.8145,
         -8.7618,  -7.1156,  -9.6990,  -6.7318,  -9.8971,  -8.2126,  -6.4838,
         -9.5864,  -9.5840,  -8.7000, -11.4193,  -8.4118, -11.2084,  -9.8028,
         -9.2596,  -6.6002,  -7.1985,  -7.3851,  -9.5632,  -7.9460,  -9.3992,
         -6.5893,  -9.9433,  -7.8096, -11.1992,  -9.1893,  -6.7899, -14.5779,
        -12.3107,  -8.9023,  -8.3019,  -8.0261,  -6.5605,  -6.6276,  -8.2218,
         -7.1716,  -6.0084,  -9.7731,  -8.9164,  -7.3550,  -6.4099,  -9.4603,
         -9.3951,  -7.5377,  -6.6927,  -7.4927,  -8.3332,  -8.0827,  -7.76

  0%|          | 74/125000 [00:04<2:39:05, 13.09it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -8.6086,  -7.0074,  -6.9842,  -9.2106,  -9.2294,  -8.5077,  -8.3060,
         -7.7723,  -7.7441,  -9.3215, -11.6203,  -8.9970, -10.2307,  -6.8578,
         -9.9975,  -9.6776,  -8.5660,  -8.6504,  -8.7903,  -9.0841,  -7.7770,
         -7.6051,  -8.3011, -10.2705,  -7.0394,  -6.3919, -10.2118, -10.3387,
        -13.6828,  -7.0127,  -7.1271,  -8.2134,  -9.0073, -12.1896, -13.4666,
         -8.9903,  -6.5549,  -5.6190,  -7.8314,  -9.5059,  -9.3626,  -7.8058,
        -11.7629,  -9.8594,  -9.9378,  -9.7692,  -8.4762, -12.3992,  -8.5444,
         -7.4574,  -8.1226,  -7.5178,  -9.0641,  -8.9611,  -9.8946,  -7.4955,
         -6.2698,  -6.5085,  -8.0137,  -9.1925,  -6.8120,  -7.2647,  -6.84

  0%|          | 78/125000 [00:04<2:32:24, 13.66it/s]

torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -9.0414, -10.5382,  -8.1176,  -6.5771,  -6.9735, -10.4757,  -9.2354,
         -6.4845,  -7.0865,  -7.2232,  -9.7561,  -7.9516,  -9.0443,  -7.2873,
         -7.7080,  -7.2622,  -7.4969, -10.1889,  -6.7175,  -7.6886,  -8.9520,
        -10.9699,  -8.7368,  -8.8251, -10.2518,  -6.5672,  -6.5386,  -8.5038,
        -14.5451,  -9.4743,  -7.2362, -13.8207,  -7.7267,  -8.4790,  -8.4763,
         -9.0139, -10.8406,  -9.6793, -10.9854,  -9.3159,  -8.6196,  -9.4752,
        -11.8457,  -7.7650, -10.9364,  -6.4579,  -7.8428,  -8.7881,  -6.4082,
         -8.7516,  -6.6368, -11.1911,  -6.9256,  -7.3224,  -9.5761, -10.3171,
         -7.1251,  -6.7976,  -8.2539,  -6.8760, -12.4483,  -8.8968,  -6.1754,
         -9.2007,  -7.2264,  -6.7367,  -6.9185,  -9.7820,  -7.1627, -12.5667,
         -6.2609,  -6.8413,  -9.8958, -13.1455,  -8.7396,  -7.2809,  -6.5982,
         -6.2876,  -6.4338, -10.492

  0%|          | 80/125000 [00:05<2:34:41, 13.46it/s]

tensor([8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765,
        8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765,
        8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765,
        8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765,
        8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765,
        8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765,
        8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765,
        8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765,
        8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765, 8.4765],
       grad_fn=<SumBackward1>)
tensor(0.0485, grad_fn=<SubBackward0>)
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]

  0%|          | 84/125000 [00:05<2:25:58, 14.26it/s]

tensor([-1.0335, -1.2169, -0.6830, -0.9273, -0.0071, -0.6825, -0.9644, -0.6065,
        -0.5916, -0.5486, -0.3460, -0.9091, -0.5525, -0.1534, -1.0376, -0.5065,
        -1.3337, -1.0017, -0.6870, -0.8125, -0.1777, -0.4495, -1.3275, -0.4910,
        -0.5955, -0.7503, -0.3135, -0.9803, -0.8541, -0.3123, -0.9818, -0.2593,
        -1.0248, -1.2381, -0.4005, -0.8681, -0.2866, -0.6319, -1.1598, -0.6939,
        -0.7429, -0.6358, -0.5242, -1.0316, -0.9587, -0.1698, -1.0332, -0.3966,
        -1.1460, -1.2515, -0.2435, -1.0746, -0.3118, -0.6511, -1.0689, -0.6261,
        -0.6828, -0.6675, -0.6722, -0.8883, -1.0757, -0.1369, -0.8745, -0.5881,
        -1.3491, -1.3045, -0.3487, -0.9945, -0.2597, -0.3860, -1.0553, -0.6845,
        -0.7868, -0.5699, -0.7350, -0.8876, -0.8395, -0.0614, -1.0516, -0.3665],
       grad_fn=<SqueezeBackward1>) tensor([8.4768, 8.4768, 8.4768, 8.4768, 8.4768, 8.4768, 8.4768, 8.4768, 8.4768,
        8.4768, 8.4768, 8.4768, 8.4768, 8.4768, 8.4768, 8.4768, 8.4768, 8.4768,
    

  0%|          | 88/125000 [00:05<2:13:48, 15.56it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -7.6845,  -9.1103,  -7.1801,  -9.3723, -10.9706,  -9.2949,  -9.4032,
         -8.2703,  -8.8359,  -5.8661,  -6.2724,  -9.0772, -11.2083,  -9.6798,
        -12.1201,  -8.1325,  -5.9436,  -7.0992,  -7.5711,  -9.3987,  -6.5649,
         -6.9431,  -8.2020,  -9.9486,  -7.5289,  -7.3233,  -9.7049,  -9.0546,
         -8.0244,  -8.4565,  -9.8601,  -6.3865,  -8.5394, -11.7828,  -8.4996,
         -8.3374,  -7.3472,  -9.5712, -14.0310,  -8.0854,  -6.5613,  -9.1921,
         -9.9373,  -8.7922, -13.2001,  -6.0878,  -8.5531,  -9.5692,  -6.3766,
         -8.0654,  -9.9128,  -8.1021, -12.1011,  -7.4256,  -8.1332,  -6.3137,
        -12.9011,  -7.0166, -11.3784, -10.8722,  -7.4448,  -9.3081,  -7.0903,
         -9.7705,  -6.7543,  -5.8998,  -6.8838,  -6.6873, -11.9960,  -7.8517,
         -8.7407,  -6.1921,  -6.948

  0%|          | 92/125000 [00:05<2:08:35, 16.19it/s]

tensor([-1.2501, -1.2279, -0.0109, -1.0572, -0.4036, -0.9675, -1.0359, -0.6278,
        -0.8268, -0.9324, -0.3042, -0.7542, -0.7832, -0.4639, -0.5777, -0.6584,
        -1.4634, -0.9910, -0.3587, -0.7009, -0.3682, -1.0920, -1.1074, -0.7489,
        -0.7191, -0.9933, -0.3331, -0.6410, -1.0999, -0.6721, -1.0194, -0.6740,
        -1.1497, -0.8866, -0.2770, -0.8322, -0.1245, -0.8500, -0.7724, -0.6575,
        -0.5083, -0.7046, -0.4803, -0.7984, -1.0698, -0.2820, -0.9744, -0.6261,
        -1.1425, -1.1503, -0.7668, -0.9633, -0.6643, -0.2183, -1.2442, -0.8165,
        -0.3248, -0.5230, -0.4502, -1.0780, -0.8293, -0.3855, -0.9802, -0.7328,
        -0.9213, -0.8902, -0.6849, -0.5989, -0.5251, -0.5484, -0.7449, -0.8480,
        -0.6054, -0.7709, -0.1336, -1.0623, -0.9576, -0.4589, -1.0748, -0.6935],
       grad_fn=<SqueezeBackward1>) tensor([8.4762, 8.4762, 8.4762, 8.4762, 8.4762, 8.4762, 8.4762, 8.4762, 8.4762,
        8.4762, 8.4762, 8.4762, 8.4762, 8.4762, 8.4762, 8.4762, 8.4762, 8.4762,
    

  0%|          | 94/125000 [00:05<2:02:45, 16.96it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([-11.2609, -10.1433,  -7.8261,  -7.1592,  -7.9618,  -8.6262,  -7.4257,
         -7.4188,  -6.7197,  -9.4614,  -7.5781,  -6.7318,  -6.6493,  -7.6034,
         -8.1719,  -9.3206,  -8.3964,  -7.1358,  -7.1833,  -6.4729,  -7.6143,
         -7.2978,  -7.9060,  -7.1998,  -8.8018,  -8.3702,  -9.1116,  -8.0325,
         -9.9662, -12.7886,  -9.3398, -10.0258,  -7.9745,  -6.8886,  -8.1216,
         -6.9121,  -6.2175,  -9.2978,  -8.5209,  -5.7573,  -7.7144,  -9.1871,
         -8.4746, -11.3915, -10.2583,  -7.5900, -11.0707,  -9.0951,  -9.3364,
        -16.0135,  -7.8151,  -8.9138, -10.1455,  -7.8860,  -9.5865,  -6.9385,
         -8.8158,  -8.6628,  -8.0514,  -6.9770,  -8.6542,  -9.8232,  -9.91

  0%|          | 98/125000 [00:06<2:06:40, 16.43it/s]

tensor([ -9.4725,  -8.5645,  -7.0399,  -7.1342,  -7.7167,  -7.3242,  -8.1139,
         -6.6905, -10.7147,  -7.2130,  -8.1856,  -9.1307,  -9.3362, -18.1223,
         -9.4779,  -6.4923,  -7.6330, -12.7872,  -7.2460,  -7.3093,  -9.7933,
         -8.0607,  -7.8972,  -6.8670,  -6.6376,  -7.0639,  -8.6931, -11.1397,
         -7.3894,  -7.7917,  -7.1764,  -9.0024,  -8.2059, -11.0776,  -8.8398,
        -10.0634,  -6.7167,  -8.2522,  -6.3797,  -7.9618,  -8.9907,  -8.7291,
         -9.8479, -12.4156,  -7.9713,  -9.3927,  -9.8010,  -6.9507, -11.3172,
         -7.5975, -10.6736,  -7.7440,  -9.3054,  -7.1185,  -8.3071, -10.1673,
        -11.2276,  -8.2746,  -8.5030, -10.4907,  -8.4808, -10.8009,  -7.1563,
         -9.0813,  -7.9651,  -9.1959,  -7.7141,  -8.5929,  -8.3777,  -7.7939,
         -8.8186,  -8.4146,  -8.5373,  -6.9496,  -9.0251, -12.8634,  -7.4281,
         -6.6215,  -8.2903,  -8.1194], grad_fn=<SumBackward1>) tensor([-1.2573, -1.4097, -0.7464, -0.6697, -0.3384, -0.6163, -1.3972, -0.5847,

  0%|          | 102/125000 [00:06<2:05:32, 16.58it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([-11.9275,  -7.9284,  -9.7516,  -6.1249,  -7.6316,  -8.1903,  -8.1055,
         -6.1425,  -7.2911, -10.3707, -10.0825,  -7.4297,  -7.9606,  -7.7017,
         -7.7097,  -7.0017,  -8.4939,  -6.3361,  -7.5051, -10.2124,  -7.1299,
         -9.2008,  -8.7922,  -7.0325,  -9.8869,  -8.0036,  -7.4278,  -6.6951,
         -7.7343, -14.0521, -11.3454,  -6.0872,  -7.0553,  -8.1564,  -6.3285,
        -10.3148,  -9.2121, -10.1620,  -6.9151, -11.5434, -10.6939,  -7.7122,
         -8.8765,  -5.6415,  -7.2511,  -7.8581,  -9.4647, -12.5710,  -6.1750,
         -8.0085, -10.1863, -10.8889,  -7.7777,  -9.1807,  -9.8068,  -6.6376,
         -9.8119,  -9.9289,  -7.6643,  -8.3208,  -6.9160, -10.9975,  -6.27

  0%|          | 104/125000 [00:06<2:03:32, 16.85it/s]

tensor([ -8.6511,  -8.8504,  -7.8054,  -8.6706,  -9.2766,  -9.6406,  -7.0713,
         -6.9323, -10.1827,  -7.1628,  -6.5918, -11.4820,  -9.7461,  -8.6998,
         -8.3446,  -9.1764, -13.5098,  -9.6866,  -9.4853,  -8.2007,  -8.0592,
         -8.1072,  -9.9560,  -8.1700,  -6.6503,  -7.2961,  -6.5477, -13.7702,
         -6.9235,  -6.2429,  -7.2294,  -8.9694,  -8.7118,  -6.8951, -10.9664,
         -9.5428, -12.5020,  -9.3268,  -7.1283,  -8.0419,  -8.7573, -12.7456,
         -8.6687, -11.4487,  -6.6647,  -6.0063, -10.9673, -10.2322, -10.3500,
         -7.3528,  -7.8193,  -6.1737,  -9.5565,  -8.8335,  -7.6281,  -7.1736,
         -9.0761,  -8.2362,  -6.8190,  -6.1744,  -8.6785,  -6.6110,  -8.1563,
         -7.7750, -11.7386,  -6.9105,  -6.5414,  -6.3397,  -8.2097,  -6.0413,
         -9.1906, -12.1944,  -7.7978,  -8.0157,  -7.4577,  -5.8703, -10.9369,
        -10.1761, -10.1161, -10.1988], grad_fn=<SumBackward1>) tensor([-1.1978, -1.0477, -0.4306, -0.3682, -0.1178,  0.7629, -1.0557, -0.0672,

  0%|          | 108/125000 [00:06<2:03:07, 16.91it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -7.2150, -12.1077, -10.1569,  -8.2422,  -7.9840,  -7.1945, -10.1442,
        -10.0703,  -8.6160,  -6.8435,  -7.9456,  -7.1743,  -6.4563,  -7.5957,
         -9.8801,  -8.1145,  -7.2871,  -6.8465, -16.2221,  -8.8724, -11.3471,
         -7.3581,  -9.5183,  -8.4341, -10.1326,  -6.7671,  -8.4662,  -9.6987,
        -10.5689,  -6.0350,  -7.7705,  -7.5341,  -8.0045,  -9.1815,  -6.9839,
         -8.0655,  -6.8071,  -8.2605,  -6.0308,  -7.4328,  -8.9207,  -8.8571,
         -8.4458,  -8.6205,  -7.9034,  -8.1423,  -7.1774, -12.6804,  -8.6585,
         -7.3025,  -7.8490,  -6.5333, -12.3444, -10.4519,  -6.1022,  -8.3409,
         -7.0258, -13.2362,  -6.0444,  -9.0627, -10.2832,  -6.0495,  -7.78

  0%|          | 112/125000 [00:07<2:05:07, 16.64it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -6.2195,  -7.0034,  -6.5057,  -8.8551,  -8.0736,  -8.8831,  -9.4947,
         -6.9537,  -6.9111,  -9.3468,  -6.2309,  -6.5364,  -7.2459,  -6.6806,
         -7.8866,  -7.5829,  -8.2109, -10.0844, -10.3590, -10.9777, -12.7310,
         -8.4838, -11.1443,  -7.6151,  -8.6536,  -7.3489,  -7.2451,  -7.5061,
         -7.4399,  -8.6908,  -7.9619, -11.9044,  -7.0463,  -6.8030,  -6.6139,
         -8.7992,  -6.6695,  -6.1393,  -6.0879,  -9.3628, -11.7369, -10.9273,
         -7.6204,  -8.0407,  -7.2284, -10.5186,  -7.4200,  -8.1671,  -5.8392,
         -7.8170,  -9.2409,  -7.0787, -14.1097,  -5.8327, -11.6823,  -6.4983,
         -6.8692,  -6.8002,  -7.3561,  -9.9262,  -7.3761,  -9.2768,  -7.3993,
        -10.0029,  -8.0760,  -9.7378,  -9.3334, -13.6646,  -6.7306, -11.83

  0%|          | 116/125000 [00:07<2:00:43, 17.24it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -8.2617,  -9.8819,  -6.4954,  -8.2970,  -6.8618,  -7.0891,  -7.8278,
         -7.2732,  -8.1314,  -7.4408, -11.6665,  -9.0036,  -8.0930,  -6.0327,
        -11.4350,  -6.2074,  -8.1117, -11.7449,  -8.3232,  -6.9386,  -8.5883,
         -9.6945,  -7.8288,  -8.7487,  -7.4415,  -7.8499,  -7.5803,  -7.1935,
         -7.6763,  -7.0100,  -8.4896,  -6.7073,  -7.4599,  -9.4598,  -7.0589,
         -7.6137,  -8.0436,  -9.1302,  -7.9192,  -7.5984,  -9.8698,  -9.1203,
         -8.4166, -11.0801, -16.0813,  -6.7429,  -9.6185,  -9.1567,  -9.2767,
        -11.3599,  -7.6885,  -6.4932,  -7.1359,  -9.8490,  -6.9201,  -5.9278,
         -7.8877,  -6.7946,  -7.4608,  -6.5611,  -6.8736,  -6.9099,  -6.93

  0%|          | 118/125000 [00:07<2:01:29, 17.13it/s]

torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -5.8941,  -8.9978,  -8.3463,  -8.8380,  -6.3969,  -8.2102,  -8.6531,
         -8.8183,  -7.9604,  -7.2207,  -7.3503,  -9.0519,  -7.4550,  -6.2223,
        -11.3961,  -7.9800,  -8.2893,  -6.4909,  -6.5247,  -5.8742,  -9.2556,
         -6.7941,  -7.6356,  -6.6979,  -6.5550,  -9.1862, -14.4210, -10.1088,
         -6.6276,  -7.6809,  -7.3642,  -9.3578,  -8.6650,  -9.8422,  -7.7766,
         -8.1178, -11.7330, -12.4902,  -8.5462,  -7.2704, -12.5676,  -9.4126,
        -10.6335,  -7.8149, -13.1731,  -7.5326,  -7.4272,  -6.5394,  -9.3401,
         -7.2918,  -8.9639, -10.8898,  -8.1017,  -9.1412,  -7.7904,  -6.7165,
         -7.6717,  -7.9919,  -8.0545,  -7.7660,  -8.3528,  -7.7197, -12.5017,
         -7.2448,  -8.8736,  -6.3396,  -7.5558,  -8.0347, -10.6365,  -8.6056,
         -7.5084,  -8.7168,  -7.8512,  -8.8978,  -6.4352,  -6.6541, -10.1053,
         -8.7894, -11.2904, -10.062

  0%|          | 122/125000 [00:07<2:04:07, 16.77it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -8.8916,  -6.9096,  -7.2967,  -8.3388,  -9.3602,  -7.2334,  -8.2015,
        -12.6697,  -9.0189,  -9.6599,  -7.1067, -12.2886,  -7.8284,  -7.5137,
        -11.2915,  -7.3769,  -7.5104,  -9.2327,  -9.7414, -13.3483,  -6.3744,
         -9.9195, -13.0955,  -8.5137,  -7.6442,  -8.4846,  -9.2080,  -6.6464,
         -8.3867,  -8.3535, -11.6116,  -8.0940,  -9.3202,  -8.3199,  -8.2508,
         -9.9741,  -7.1538,  -6.6409, -12.5256, -10.8069,  -9.1094, -11.3024,
         -7.3488,  -8.5067,  -9.0619,  -7.8119,  -7.1859,  -6.3043, -11.7056,
        -12.9290,  -9.9616,  -7.4239,  -9.8733,  -6.5688,  -9.2414,  -7.5310,
         -7.8396,  -7.0163,  -9.0483,  -8.2648,  -7.8240,  -8.7546,  -7.95

  0%|          | 126/125000 [00:07<2:14:27, 15.48it/s]

tensor([8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966,
        8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966,
        8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966,
        8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966,
        8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966,
        8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966,
        8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966,
        8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966,
        8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966, 8.4966],
       grad_fn=<SumBackward1>)
tensor(-0.2752, grad_fn=<SubBackward0>)
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17

  0%|          | 128/125000 [00:08<2:12:15, 15.74it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -7.9908, -11.2218,  -8.7124,  -7.6571,  -9.0286,  -8.3198,  -7.3259,
         -7.0473, -12.4836,  -6.7312,  -8.2041,  -7.0484,  -8.0820,  -6.7749,
         -8.4902,  -5.9359,  -7.2629,  -8.4072,  -8.1954,  -9.1135,  -8.8793,
         -7.0328,  -6.5737,  -7.1768,  -7.3687,  -9.7621, -15.1419,  -7.9254,
        -11.1509,  -6.9196,  -6.7785,  -8.5082,  -7.9837,  -9.7022,  -7.2700,
         -7.5408,  -8.2826,  -8.2825,  -7.6454,  -7.2727, -10.1651,  -9.3097,
         -8.5982,  -9.9922, -10.8313,  -6.3259,  -8.0857,  -7.4360,  -7.5375,
         -7.4632, -11.6434,  -8.1920,  -6.1999,  -8.7510, -12.2550,  -7.0357,
         -8.6289, -12.7160,  -8.4008,  -6.4470,  -9.3418,  -8.0425,  -9.3524,
        -15.8616,  -8.9754,  -7.1713,  -7.8810,  -9.4806,  -7.5925,  -7.5846,
         -8.4846,  -5.7142,  -8.1641,  -9.8501,  -7.2847, -11.6181, -12.53

  0%|          | 132/125000 [00:08<2:11:22, 15.84it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([-16.5476,  -8.1737,  -9.9390,  -9.8324, -10.0946,  -6.6520, -11.5217,
         -7.1118,  -7.7843,  -8.7628, -10.7705,  -8.9553,  -9.4611,  -6.9319,
         -7.6599,  -7.7302,  -6.3220,  -9.1996,  -8.1422, -10.2793,  -9.9091,
         -8.2501,  -8.9760,  -8.8476,  -9.4264,  -7.9075,  -7.0631,  -7.2434,
         -9.4068,  -7.0251,  -8.4784,  -9.7783,  -9.0394,  -7.2983,  -8.7887,
         -8.4302,  -8.8361,  -6.1829,  -6.7355, -10.2393,  -7.7687, -11.6332,
         -8.1375,  -9.0721,  -6.8001,  -8.4621,  -8.5624, -10.4929,  -8.9150,
         -8.4650,  -8.8707, -14.5845,  -8.7275,  -6.5509, -11.1021,  -9.4378,
         -9.7110,  -6.5059,  -8.0630, -10.0339,  -8.9334,  -7.1193, -10.26

  0%|          | 136/125000 [00:08<2:11:23, 15.84it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([-10.3298,  -8.6487,  -7.3019,  -8.2431,  -7.4720,  -9.2848,  -6.2618,
         -9.1694,  -7.1945,  -6.7842,  -9.0263,  -7.4012, -11.4955,  -8.8297,
         -8.6552, -10.3845,  -7.0043,  -8.9066,  -5.7720,  -9.8625,  -8.4090,
         -6.6841,  -9.2136,  -7.5138,  -9.3928, -11.2639, -11.0987, -10.9094,
         -9.1514, -10.0198,  -8.0130,  -7.5592,  -7.8000,  -6.7621,  -7.8789,
         -9.3414, -11.1355, -11.2618,  -8.0305,  -7.9040,  -7.9841,  -9.5450,
         -9.0966,  -8.2086,  -7.3632,  -7.4697,  -6.7403,  -7.1688,  -8.5442,
         -6.3398,  -8.2337,  -7.1295,  -9.3853, -10.5387,  -7.6639,  -9.6435,
         -7.4632,  -7.3375,  -6.6673,  -7.2824,  -8.1936,  -7.5577,  -8.5190,
         -9.9653,  -6.9215,  -6.3989,  -9.3027,  -8.3201,  -9.5128,  -6.99

  0%|          | 138/125000 [00:08<2:10:11, 15.98it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -7.1325, -12.6919,  -7.1119,  -9.3629, -12.5098, -11.8277,  -7.4087,
        -10.8369,  -7.3087,  -9.5430,  -6.8134,  -7.8300,  -7.1155, -10.3555,
         -9.6875,  -9.3279,  -9.8653, -10.3621,  -7.6231,  -7.8644, -12.7778,
         -8.9214, -10.2003,  -6.9700,  -8.3050, -10.5518,  -8.2547,  -8.2193,
         -7.3488,  -8.9308,  -6.4950, -11.7605,  -7.5648,  -8.3636,  -6.9581,
         -9.0834, -10.8433, -10.6763,  -8.5433,  -7.0266,  -7.9833,  -7.9835,
         -6.7542,  -8.0299,  -6.5011,  -7.8596,  -6.6945,  -7.6696,  -7.0286,
         -7.2756,  -9.4777,  -8.7817,  -6.0767,  -7.1767,  -7.0318, -11.5411,
         -7.7395,  -6.6669, -10.3718,  -7.4730,  -9.4740,  -8.0098,  -7.70

  0%|          | 140/125000 [00:08<2:14:19, 15.49it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -7.7908,  -7.8491,  -7.7688, -11.1502,  -7.4620,  -7.6947,  -8.2660,
         -7.6056,  -9.1300, -10.7943,  -7.3346,  -7.2071,  -8.0780,  -9.2838,
         -8.4010, -11.9721,  -6.8516,  -7.3741,  -6.7298, -10.4162,  -9.2621,
         -7.7499, -10.4396,  -8.4345,  -6.9978,  -8.9373, -11.5135,  -7.7413,
         -6.8064,  -7.4218,  -9.4840, -10.8918,  -7.8230,  -8.6021,  -8.1585,
         -9.3694,  -7.5939,  -6.6944,  -7.2379,  -6.8265,  -6.5056,  -7.9913,
         -7.2437,  -6.1996,  -8.3658,  -7.0118,  -6.3639,  -7.2326,  -9.7353,
         -5.7353,  -7.9546,  -6.5837,  -7.9732,  -7.3362,  -7.2012,  -7.7617,
         -7.6850,  -9.8183,  -8.5142,  -8.8585,  -8.7019, -11.2626,  -8.7737,
         -8.2015,  -8.6594,  -7.466

  0%|          | 144/125000 [00:09<2:11:15, 15.85it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([-10.1898,  -6.3785,  -7.1700, -10.4986, -10.0765, -13.6169,  -7.0619,
        -10.8618,  -9.0138,  -7.5892, -11.9309, -16.7090,  -8.7755,  -9.6999,
         -6.4396,  -8.9039, -10.2561,  -9.0069,  -6.4238,  -7.6452,  -6.7518,
         -8.7238,  -7.4698, -10.4232,  -6.4426,  -7.4638,  -9.5155, -11.9645,
        -10.8104,  -7.5350,  -6.6878, -12.2558,  -7.8131, -11.8402,  -8.5011,
        -11.2110,  -8.3109,  -9.6749,  -7.1038,  -7.4451,  -7.3993,  -9.0192,
         -6.1799,  -8.4389,  -6.8525,  -6.9459,  -7.9928, -11.3026,  -8.2164,
        -10.1478,  -6.5351,  -7.4981,  -7.0799, -11.6070,  -9.6968,  -9.6633,
        -13.7524,  -6.4765,  -8.8517,  -7.3372,  -7.3791,  -7.2266,  -9.03

  0%|          | 148/125000 [00:09<2:05:58, 16.52it/s]

tensor([ -9.4478,  -8.9015,  -7.6210,  -8.7674,  -7.2096,  -6.9533, -11.4753,
         -8.5620,  -6.6191,  -8.4988,  -6.5657,  -9.4262,  -7.2385, -10.2432,
         -8.2730, -10.1847,  -9.2630,  -6.2864,  -8.6381,  -9.1410,  -8.1466,
         -9.7242,  -7.7858,  -6.5276,  -6.1886,  -8.1924,  -7.8476,  -7.8946,
         -7.5598,  -9.7335,  -7.3030,  -8.2273,  -6.7415, -11.1898, -10.8392,
        -10.7243,  -6.9139,  -8.2624,  -7.7827,  -7.4036, -10.8027,  -7.3316,
         -7.7481,  -8.9033,  -7.3196,  -7.1990,  -7.4253,  -7.7142,  -8.4229,
         -6.9442,  -7.7235,  -9.6336,  -9.0531,  -6.9844,  -6.9282,  -9.7367,
         -8.8566,  -9.5654,  -6.3321,  -9.4683,  -7.3869,  -8.3551,  -7.5142,
         -6.7006,  -9.2202,  -8.8407,  -7.2348, -10.1193, -10.8482, -11.1156,
        -10.9600,  -9.0741,  -7.6947, -11.3595,  -7.6498, -10.8250,  -7.6307,
         -7.7764, -13.3585,  -7.2215], grad_fn=<SumBackward1>) tensor([-1.5656, -1.2702, -0.8024, -0.6781, -1.9334, -0.3861, -0.9043, -1.1268,

  0%|          | 152/125000 [00:09<2:20:41, 14.79it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -9.8194,  -8.9885, -10.3326,  -8.3889,  -8.0892,  -9.5233,  -8.5395,
         -7.5365,  -7.1225,  -6.7990,  -6.8226,  -9.4866,  -7.4633,  -8.7502,
        -10.3640,  -8.8073,  -5.6471,  -7.8873,  -8.7222,  -8.7853,  -7.2769,
         -7.4565,  -8.7423,  -8.6956,  -8.5261,  -8.3005,  -8.0832,  -6.4467,
        -10.8266,  -7.7171,  -8.0483,  -9.1509,  -8.7774,  -6.6582,  -7.2269,
         -7.6332,  -6.2551,  -8.2985,  -7.0237,  -8.3105,  -8.4508, -10.7028,
         -7.5241,  -6.1040,  -9.3389,  -7.7036, -12.6195,  -8.8943,  -9.4813,
         -7.1153,  -8.0771,  -9.6473,  -7.0115,  -6.8367,  -7.9718,  -6.3785,
         -8.0078,  -8.8774,  -9.6749,  -8.7922,  -6.7216,  -6.9459,  -9.36

  0%|          | 154/125000 [00:09<2:14:07, 15.51it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -8.9930, -10.3449,  -9.9453,  -8.1330,  -7.3075,  -8.5979,  -8.0796,
         -6.6131,  -6.8076, -11.8380,  -7.1088,  -8.3910,  -6.5066,  -7.3708,
         -9.3510,  -7.0031,  -9.7141, -11.2519,  -9.2815,  -8.3741,  -8.6460,
         -6.8114,  -9.9392,  -7.0868, -12.6165,  -9.3041,  -9.3582,  -7.9510,
        -13.6609,  -9.6276,  -7.7586,  -9.0997,  -7.0732,  -6.6788, -11.0017,
         -8.6424,  -7.9441,  -6.7920,  -8.2773, -10.9310,  -9.6853,  -7.2465,
         -9.3226,  -9.7407,  -8.4603,  -8.3243,  -8.1732,  -6.8377, -10.8352,
         -9.6719,  -6.0263,  -8.1077,  -6.9574,  -7.2454, -10.6710,  -7.2369,
         -9.6503,  -7.4636,  -8.3978,  -6.7262,  -7.1660, -10.4002,  -8.13

  0%|          | 158/125000 [00:09<2:22:38, 14.59it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -9.3264,  -7.0144, -10.0152,  -9.2226, -12.5249,  -8.9010,  -7.6960,
         -8.7966,  -7.9868,  -6.3986,  -7.0469,  -7.5956,  -8.6019,  -9.0347,
         -9.0101,  -7.3870,  -7.5139,  -8.7920,  -8.2721,  -7.2575,  -9.6933,
         -8.4963,  -9.4078,  -6.7042,  -8.7670,  -7.6183,  -7.3150, -10.0482,
         -7.6999,  -8.0501,  -7.2291,  -7.2172,  -9.2784,  -6.9660,  -8.0787,
         -6.2308,  -6.5230,  -9.1027,  -6.7730,  -7.0482,  -8.4376,  -6.2170,
        -11.6711,  -6.7282,  -6.7689,  -7.5342, -12.7045, -10.4984,  -7.7225,
         -8.0422,  -7.0398,  -6.6668,  -6.7093, -10.8020, -10.1508, -12.2249,
         -6.9319,  -8.5201,  -9.0074,  -8.8246,  -8.9216, -12.1567,  -6.1717,
         -9.6605,  -6.9253,  -8.0262,  -7.6507, -13.7364,  -7.7984,  -6.7429,
         -6.2283,  -9.8483,  -6.254

  0%|          | 160/125000 [00:10<2:25:40, 14.28it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -7.7441,  -6.6958,  -7.6799,  -9.8305, -10.6339,  -7.6095,  -7.8159,
         -9.7839,  -8.2090,  -8.1569,  -8.1655,  -8.0519,  -7.8080,  -8.7714,
         -7.4958, -10.6701,  -8.1039,  -9.3517, -12.9866,  -8.8461,  -7.2677,
         -7.1030, -10.2339,  -6.6482,  -6.5578,  -9.6058,  -6.8963,  -9.9161,
         -9.5300,  -7.5128,  -8.0363,  -7.8191,  -7.8434,  -6.6363,  -8.4555,
        -10.0814,  -8.0771,  -8.4097,  -6.1841,  -9.4465,  -6.7938, -10.1182,
         -9.1043,  -6.6546,  -6.9007,  -7.3351,  -6.7090,  -9.9492,  -9.3439,
         -7.5878,  -9.9821,  -8.0726,  -6.5906,  -6.2470,  -6.6250, -10.3424,
         -7.7967,  -8.0312,  -8.5346,  -7.4150,  -7.5824,  -9.9736,  -7.73

  0%|          | 162/125000 [00:10<2:38:06, 13.16it/s]

tensor([-11.5130,  -8.7172,  -7.5788, -10.5444, -11.0655, -10.3876,  -7.8736,
         -8.5586,  -8.0441,  -7.3301,  -7.1875, -10.0914,  -7.5892,  -6.3726,
         -6.2139, -10.6839, -12.4305,  -7.2057,  -8.2226,  -9.0431, -10.6961,
        -10.4938,  -7.3255,  -7.6930,  -6.6312,  -7.8300,  -7.9644,  -6.9717,
         -7.9775, -11.0523, -10.6037,  -7.1347,  -8.5997,  -8.4025,  -8.9391,
         -6.0152,  -9.3105, -12.9177,  -7.0432,  -8.7717,  -8.0102,  -8.2844,
         -6.7765,  -8.1003,  -9.6289,  -7.0864,  -7.3155,  -7.7772,  -6.0567,
         -6.9560,  -7.8714,  -9.0521,  -9.9221,  -8.2930,  -8.1362,  -7.6858,
         -6.2271,  -9.1584,  -9.2806, -10.7457,  -8.3637,  -7.9361,  -9.4449,
         -9.0394,  -7.8323,  -6.5886,  -6.8589,  -8.0669, -10.9453,  -6.1461,
         -8.3152,  -8.9276,  -7.3162,  -8.7060,  -8.8416,  -9.1747,  -7.2780,
        -11.0051,  -7.1084,  -8.1533], grad_fn=<SumBackward1>) tensor([-1.3007, -0.9618, -0.5932, -1.4629, -1.5570, -0.4202, -1.2020, -0.7548,

  0%|          | 164/125000 [00:10<2:44:01, 12.68it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -6.8177,  -6.0193,  -8.4927, -14.2666, -11.4573,  -6.8319,  -6.9228,
         -6.5454, -10.6554,  -9.5765,  -7.0431,  -6.7389,  -8.1891,  -7.9922,
        -11.1711,  -8.4072,  -7.5948,  -7.1630,  -6.6451,  -8.9589,  -9.9226,
         -8.4466,  -7.0907,  -8.6431,  -7.5009,  -7.8274,  -7.6351,  -7.8988,
         -9.5956,  -6.7040, -10.0323, -11.1490,  -7.3936,  -6.3441, -11.0037,
         -5.7373, -11.3194, -10.3691,  -8.4782,  -6.5332, -12.5634,  -8.0879,
         -8.5605,  -8.7369, -10.0885,  -8.1097,  -6.0298,  -6.7460,  -9.7049,
        -10.9101,  -8.2698,  -9.4973,  -6.5970,  -9.5965,  -7.0432,  -8.9524,
        -10.5819, -10.1534,  -8.1927,  -7.0681,  -7.5467, -12.4316,  -7.1487,
         -6.2043,  -8.2833,  -6.4545,  -8.6628, -11.4874,  -9.1180,  -7.0809,
         -6.7877,  -7.1558, -10.145

  0%|          | 168/125000 [00:10<2:46:15, 12.51it/s]

torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -7.4460,  -9.8353,  -8.6253,  -9.3530,  -9.2004,  -7.0980,  -8.4558,
        -11.4677,  -6.7079,  -9.3855,  -7.0180,  -8.6594,  -8.1497,  -7.9748,
         -7.2013,  -6.9447,  -8.2489,  -7.7732,  -8.6067, -15.8396,  -9.8950,
        -13.0023,  -9.3331, -12.1415, -10.5964,  -9.9533,  -7.9010,  -7.1922,
         -6.3541,  -7.1124,  -8.3898, -11.7008,  -7.3778,  -5.6750,  -7.0260,
         -7.7535,  -7.6466,  -9.5435,  -6.9344,  -9.9752,  -6.7432,  -7.6014,
         -9.9493,  -7.1740,  -7.2623,  -8.5289,  -8.0097,  -7.5812,  -9.9376,
        -12.2742,  -9.5596,  -7.8926,  -6.8616,  -8.8648,  -7.5764,  -8.4977,
         -8.6117,  -9.2570,  -6.6114,  -9.1546,  -7.1858,  -7.8310,  -7.9865,
        -10.9276,  -8.2450, -11.3308,  -9.2929,  -8.2506, -11.2538,  -6.6454,
         -9.3150,  -7.7732,  -7.1851, -10.1300,  -9.3397,  -7.9795,  -9.4835,
         -8.1005, -10.5425,  -8.864

  0%|          | 170/125000 [00:11<2:52:15, 12.08it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([-13.3479,  -9.2438, -10.2112,  -6.7729,  -9.3618, -13.9349,  -9.1762,
         -8.4814, -11.1641, -12.1682,  -6.6111,  -7.9644,  -7.4239, -10.2599,
         -6.0347,  -6.8124,  -7.1722,  -8.4000,  -8.4050,  -6.3065,  -6.8174,
         -9.7605,  -7.4025,  -7.1293,  -9.2413,  -6.1177,  -6.6373,  -7.7836,
        -10.1110, -12.7716,  -8.2617,  -8.2947,  -7.4993,  -8.8559,  -7.2479,
         -8.9448,  -8.6533,  -8.0903,  -8.4207, -10.3075, -13.4809, -11.8558,
        -10.8472,  -7.9820,  -9.0528,  -8.7989,  -7.6855,  -8.4441,  -8.4245,
         -6.2553,  -7.6816, -10.2903,  -8.4024,  -7.5461,  -7.3952,  -8.2803,
         -6.6430,  -6.8741,  -7.3334, -10.1690, -10.6847, -10.5657, -12.17

  0%|          | 172/125000 [00:11<2:56:28, 11.79it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -7.2576,  -7.6247,  -7.5969,  -7.8655,  -6.5055,  -6.8100,  -8.2108,
         -8.4601, -10.5454,  -7.5242,  -5.9401,  -9.7686,  -9.3506,  -7.3111,
         -8.0803,  -7.0349,  -8.6788,  -9.0409,  -8.1073, -10.4186,  -7.1000,
         -8.7616,  -7.9121,  -7.3758,  -9.7616, -10.0688,  -7.7318,  -7.9293,
         -9.2800,  -7.1832,  -9.1997,  -8.7293,  -8.3871, -12.3806,  -7.0964,
         -9.6290,  -7.0397,  -7.1648,  -6.4714,  -7.0126,  -6.7216,  -9.5622,
         -7.9422, -10.5407,  -7.1038,  -6.8769,  -7.8485, -11.0743,  -7.9411,
         -8.0547,  -9.7120,  -7.4164,  -7.6332,  -6.0919,  -8.7604,  -9.5873,
         -7.0704,  -8.6548,  -8.1936, -10.9661,  -8.0205,  -8.1159,  -7.5441,
         -9.9497,  -9.5085,  -8.3009,  -8.5989,  -6.3902, -11.0406,  -6.5245,
         -9.0941, -12.0142,  -6.2683,  -8.5809, -10.3325,  -7.3264, -12.89

  0%|          | 174/125000 [00:11<2:59:55, 11.56it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -6.7838,  -9.6469,  -8.7817,  -7.7375,  -8.3954,  -7.8911,  -6.9071,
        -11.9249, -11.0107,  -6.4847,  -7.8138,  -6.9732, -10.7489,  -7.2547,
         -9.3233,  -7.1500,  -6.9595,  -9.1959,  -8.7886,  -7.5516,  -7.2351,
         -9.8167, -12.8757,  -6.4311,  -7.3096,  -6.6835,  -7.5370,  -8.5471,
         -7.2987,  -8.5066,  -9.9424,  -6.5292,  -7.5244,  -7.6679,  -7.5013,
         -9.7877,  -7.2286, -10.5352,  -7.6452, -10.3087,  -9.3802,  -7.4247,
         -7.0969, -13.1898,  -7.7583,  -6.7681,  -8.0918,  -8.4526,  -7.5985,
         -8.1468, -10.8215, -13.5750,  -7.3793,  -7.0313,  -9.8325,  -8.4629,
         -9.1416,  -8.6027,  -9.8097,  -7.7242,  -8.1980, -11.7703, -11.28

  0%|          | 176/125000 [00:11<2:56:27, 11.79it/s]

torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -7.8226, -12.8143,  -9.8851,  -7.1396,  -6.4308,  -6.7229,  -8.9582,
        -11.4051, -11.2502,  -8.7900,  -6.2236,  -6.5304,  -8.9488,  -9.4554,
         -6.8354,  -5.9107,  -6.7903,  -7.6531,  -7.4430,  -9.2693,  -9.2704,
         -6.2262,  -7.6127,  -7.2729, -10.8329, -10.6517,  -8.5304,  -5.8832,
        -10.5955,  -9.0661,  -7.5394,  -9.9838,  -7.4417,  -8.9478,  -7.0471,
         -9.9993,  -8.3285,  -9.5863,  -6.2023,  -9.5947, -10.7834,  -7.1306,
         -6.6911,  -8.4223,  -8.3478,  -8.0929,  -9.0676,  -9.4751,  -9.8022,
         -6.9465,  -7.6486,  -8.5187,  -6.7697,  -9.0536,  -7.7922,  -6.6513,
         -8.4999,  -8.2687,  -7.2940,  -9.5957,  -7.7717,  -7.7341,  -8.2217,
         -7.3852,  -7.5414,  -6.8559,  -8.4198,  -8.1967,  -6.8548,  -7.0529,
         -6.9059,  -8.1095,  -9.6342,  -7.7517,  -8.5276,  -8.5203,  -7.0726,
         -8.9462,  -9.7766,  -6.324

  0%|          | 180/125000 [00:11<2:36:14, 13.31it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -7.3465,  -7.6306,  -9.6337,  -6.4017,  -8.2658,  -6.8390, -10.1060,
        -10.7326,  -8.0016,  -8.6557,  -8.9012,  -7.7790,  -8.3501,  -7.7707,
         -6.3270,  -8.2311,  -7.0153,  -8.2680,  -6.5035, -10.5826,  -7.0174,
        -11.3765,  -8.9411,  -6.7084,  -7.7038,  -7.4130,  -7.1562,  -7.9634,
         -7.0445, -10.6483,  -6.8285,  -6.1228,  -6.7931, -11.9253,  -7.0698,
         -7.8258,  -8.1359,  -8.3988,  -8.0397, -11.6802,  -7.3009,  -7.8802,
         -7.6512, -11.5576,  -7.2486,  -6.3947,  -7.0786,  -6.6497,  -8.2437,
         -8.9695,  -7.2077,  -8.8975,  -7.3427,  -7.4168,  -6.9178,  -7.1698,
         -9.0623,  -9.2050,  -6.3378,  -7.5095,  -6.5694,  -7.5298,  -6.58

  0%|          | 184/125000 [00:12<2:25:33, 14.29it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -8.7144,  -8.7410, -11.0538,  -8.2862,  -9.8969, -10.1473,  -9.3978,
        -10.0033, -11.0892,  -6.9666,  -7.1169,  -8.2770,  -6.6307,  -8.8944,
        -11.4389,  -7.1859,  -7.1864, -12.6967,  -7.3770,  -8.4054,  -8.6765,
        -11.0956,  -6.3869,  -8.5558, -10.0483,  -9.3458,  -7.6619,  -8.0590,
        -11.1437,  -6.6576, -14.0187,  -7.0777,  -6.6662,  -7.0503,  -7.0796,
        -11.5799, -13.7066,  -6.3547,  -7.0208,  -8.4611,  -7.3288,  -7.7582,
         -7.7818,  -7.2500,  -6.9063,  -9.3628,  -8.4861,  -6.9221, -10.5242,
         -9.6726,  -6.8473,  -6.8782,  -7.7970,  -8.2699,  -6.3854,  -9.5721,
         -8.0227, -10.5665,  -6.0656,  -7.7901,  -7.1006,  -8.7187,  -8.0888,
         -8.1279, -10.1552,  -6.2620,  -9.0241,  -7.8800,  -6.1282,  -6.5356,
         -7.3907,  -8.7588,  -6.9008,  -7.4755,  -7.0033,  -7.5254,  -8.66

  0%|          | 186/125000 [00:12<2:22:16, 14.62it/s]

tensor([ -9.1652,  -8.7846,  -7.9288,  -7.9823, -12.5874,  -6.8450,  -7.4073,
         -6.3887,  -8.0206,  -9.7464,  -7.3042,  -6.9932,  -6.1952,  -9.3167,
         -9.1008,  -6.9913,  -7.2457, -10.6439,  -6.7977, -12.5549,  -7.5285,
         -9.7717,  -7.3375,  -7.5484,  -9.8481,  -8.8150,  -8.5508, -12.7650,
         -9.5545, -10.4283,  -7.6388,  -7.5749, -11.0325,  -7.3878,  -9.2730,
         -8.6708,  -7.0742, -11.3370, -10.9922,  -6.4662, -13.7597,  -7.3116,
         -6.4372, -10.7707,  -7.6244,  -8.1703, -14.3250,  -9.7299,  -9.1431,
         -7.5138,  -6.2926,  -9.5009,  -7.3187,  -6.3428,  -8.5146,  -7.6113,
         -9.2366,  -8.1623, -10.3409, -10.3509,  -7.3212,  -7.6650,  -7.7323,
         -7.8186,  -6.9166,  -9.8235,  -6.6702,  -7.1790,  -6.1304,  -9.2847,
         -7.5003, -11.1301,  -9.6005,  -6.6263,  -7.0564,  -9.6601,  -7.4151,
         -7.7922,  -8.1427,  -6.6324], grad_fn=<SumBackward1>) tensor([-1.2736, -1.2430, -1.0014, -0.1715, -1.5330, -0.4601, -1.3980, -0.6522,

  0%|          | 190/125000 [00:12<2:14:24, 15.48it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -9.6505,  -9.5438,  -6.4804,  -8.4256, -15.1102,  -8.1008,  -6.3378,
         -7.0999, -13.3815,  -6.6534,  -7.9000,  -8.3786,  -9.0645,  -7.1202,
         -6.5482, -12.5900, -10.9525,  -8.2575,  -6.4164, -10.7181,  -8.0906,
         -9.4667,  -7.3838,  -7.3930, -10.9654,  -6.6631,  -7.6634,  -9.7963,
         -6.6243,  -6.1904, -11.1316,  -8.4538, -10.1275,  -8.7715,  -8.2145,
        -10.9379,  -7.8590,  -6.5213,  -7.5937,  -9.3981,  -9.5955,  -8.8661,
         -9.0421,  -8.0105,  -6.6139,  -8.3201,  -7.0896,  -9.3923,  -9.7479,
         -9.3541,  -8.5548,  -7.2516,  -7.0119,  -6.6929, -12.7348,  -8.5093,
         -8.5878,  -7.9476,  -8.5653,  -8.7550,  -9.3344, -11.6685,  -8.72

  0%|          | 194/125000 [00:12<2:05:51, 16.53it/s]

tensor([ -8.9811,  -6.7604, -11.3255,  -7.5355, -10.8598,  -7.0327, -13.2597,
         -8.3815,  -7.4144,  -9.3069, -10.6958,  -7.6887,  -9.4971,  -6.5977,
         -7.2665,  -9.8745,  -8.1422,  -8.7126,  -9.9817,  -9.2952, -11.4870,
         -9.6809,  -8.8357,  -7.1953,  -9.9957,  -6.5730,  -8.9409,  -9.2234,
         -7.2944,  -7.4907, -10.5555,  -6.1273,  -7.4770,  -7.9283,  -7.0336,
         -8.0472, -12.1815,  -7.6147,  -7.3664, -10.3121, -14.8973,  -6.7699,
         -6.3851,  -7.9259,  -8.7571,  -6.1660,  -6.4240, -10.1067,  -9.6245,
         -6.1766,  -6.8027, -10.7069, -10.2328, -10.0331,  -8.0967,  -9.8149,
        -13.0020, -10.7751,  -6.6858,  -8.7074,  -7.0455, -14.8764, -10.4210,
         -8.7005, -12.7555, -10.9299,  -7.3423, -10.7890, -13.2089,  -7.6792,
         -9.7171,  -7.9504,  -7.8149,  -8.2127,  -7.6494,  -7.8389,  -6.8962,
         -8.1560,  -8.7923,  -7.5001], grad_fn=<SumBackward1>) tensor([-1.4316, -0.9798, -1.1970, -0.6396, -1.8639, -0.7809, -1.5042, -1.2883,

  0%|          | 198/125000 [00:12<2:07:06, 16.37it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -8.1355,  -6.4730,  -9.9609,  -9.8406, -10.5119,  -7.8147,  -7.5498,
         -6.9438,  -7.7248,  -7.8363,  -8.6676,  -8.3202,  -7.6803,  -7.1623,
        -11.9646,  -9.5007,  -8.1693,  -8.7720,  -8.1616,  -7.3148,  -8.9783,
        -11.5988, -10.5766,  -6.8435,  -7.0555,  -6.9732,  -6.4925,  -8.3231,
         -8.3574,  -9.4217,  -7.6109,  -9.0542, -10.1796,  -6.9843,  -8.4987,
         -8.1846,  -6.7169,  -6.8810,  -7.5461,  -8.6358,  -9.0778,  -7.2709,
         -6.4264,  -6.9939,  -6.9025,  -8.1594,  -7.8508, -10.1182, -10.2316,
         -6.6131,  -8.2660,  -7.9517,  -8.0369,  -8.5938,  -7.9631,  -9.3753,
        -11.2961,  -8.4391,  -7.9436,  -8.8978, -10.9855, -11.2927,  -9.9256,
         -6.6668,  -9.9183,  -7.602

  0%|          | 200/125000 [00:13<2:18:32, 15.01it/s]

torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -8.6336, -12.0949,  -6.1650,  -9.1085,  -6.6822,  -7.6301,  -7.1829,
         -7.9360,  -6.3730,  -7.3375,  -6.4285,  -7.5316,  -8.5917, -10.4233,
         -7.7467,  -7.6158, -10.7483,  -9.1334, -11.6833,  -8.3276,  -7.7624,
         -9.7829,  -7.7504,  -9.8759, -10.4687,  -8.3048,  -6.1634,  -7.4640,
        -12.0488,  -6.9358,  -9.8225,  -9.9768,  -8.4069,  -7.1877,  -7.7831,
        -10.0579,  -7.1669,  -8.0042,  -8.7971,  -8.5306,  -8.0902, -13.4417,
         -8.3124,  -7.7272,  -7.6430,  -8.5147, -10.9564,  -8.6335,  -9.3009,
         -9.3439, -10.1234,  -8.0412,  -6.1148,  -7.8834, -10.3918,  -8.2262,
         -9.3374,  -8.2326,  -9.1151,  -8.5416,  -6.6018,  -6.6479,  -9.67

  0%|          | 202/125000 [00:13<2:16:15, 15.26it/s]


torch.Size([16, 6])
torch.Size([1, 6])
torch.Size([5, 16, 17]) torch.Size([5, 16, 6]) torch.Size([5, 16]) torch.Size([5, 16]) torch.Size([5, 16])
tensor([ -8.8133,  -8.8690, -10.7052,  -7.8553,  -7.9867, -12.6453, -11.0884,
         -8.0597, -14.0783,  -8.9305, -12.2697,  -8.3338,  -7.9580,  -7.1213,
         -6.1177, -10.2597,  -7.8068,  -8.4635,  -7.0204, -12.8072,  -6.6489,
        -11.0278,  -7.7099, -11.5533,  -7.8390,  -6.7872,  -7.6721,  -9.0330,
         -8.2619,  -7.7314, -13.0279,  -6.5950, -12.5153,  -9.0618,  -7.0773,
         -8.0243,  -9.5197,  -8.2567, -10.2596,  -8.4986,  -9.4049,  -6.8511,
         -9.6313, -10.5266,  -6.3229,  -7.6154, -10.4019,  -9.2263,  -7.9004,
         -9.0895,  -7.9749,  -7.8254,  -8.1590,  -8.1087,  -8.7948, -11.4508,
        -10.0966,  -6.7778,  -5.7149, -10.6557,  -7.8548,  -8.3381,  -8.5026,
         -8.1725,  -9.3994,  -6.0307,  -7.4261,  -7.8728, -10.2089,  -6.5280,
         -8.6227, -11.3244,  -8.7865,  -8.6625,  -7.6596,  -9.1468,  -9.19

KeyboardInterrupt: 

In [17]:
rundir_to_eval = Path(f"runs/{run_name}/2024_01_14_20_44_28")

In [19]:
print(f"Evaluating and capturing videos on {args.env_id}.")
mean_eval_return = eval_and_render(args=args, run_dir=rundir_to_eval)
print(f"Evaluation - Mean returns achieved: {mean_eval_return}.")

Evaluating and capturing videos on HalfCheetah-v4.
reading runs\A2C_PyTorch\2024_01_14_20_44_28/policy.pt...


  logger.warn(


-> Episode 1: 715.8059692382812 returns
-> Episode 2: 1270.4517822265625 returns


KeyboardInterrupt: 

In [15]:
from IPython.display import Video

Video("half-cheetah-video.mp4")

ValueError: To embed videos, you must pass embed=True (this may make your notebook files huge)
Consider passing Video(url='...')