In [15]:
# #@markdown ### **Installing pip packages**
# #@markdown - Diffusion Model: [PyTorch](https://pytorch.org) & [HuggingFace diffusers](https://huggingface.co/docs/diffusers/index)
# #@markdown - Dataset Loading: [Zarr](https://zarr.readthedocs.io/en/stable/) & numcodecs
# #@markdown - Push-T Env: gym, pygame, pymunk & shapely
# !python --version
# !pip3 install zarr scikit-video torch==1.13.1 torchvision==0.14.1 diffusers==0.18.2 \
# scikit-image==0.19.3 scikit-video==1.1.11 zarr==2.12.0 numcodecs==0.10.2 \
# pygame==2.1.2 pymunk==6.2.1 gym==0.26.2 shapely==1.8.4 \
# &> /dev/null # mute output
# # Reinstall pymunk explicitly to ensure it's correctly installed
# !pip install pymunk==6.2.1

In [16]:
#@markdown ### **Imports**
# diffusion policy import
from typing import Tuple, Sequence, Dict, Union, Optional, Callable
import numpy as np
import math
import torch
import torch.nn as nn
import torchvision
import collections
import zarr
# from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
# from diffusers.training_utils import EMAModel
# from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

# env import
import gym
from gym import spaces
import pygame
import pymunk
import pymunk.pygame_util
from pymunk.space_debug_draw_options import SpaceDebugColor
from pymunk.vec2d import Vec2d
import shapely.geometry as sg
import cv2
import skimage.transform as st
from skvideo.io import vwrite
from IPython.display import Video
# import gdown
import os

In [17]:
#@markdown ### **Environment**
#@markdown Defines a PyMunk-based Push-T environment `PushTEnv`.
#@markdown And it's subclass `PushTImageEnv`.
#@markdown
#@markdown **Goal**: push the gray T-block into the green area.
#@markdown
#@markdown Adapted from [Implicit Behavior Cloning](https://implicitbc.github.io/)


positive_y_is_up: bool = False
"""Make increasing values of y point upwards.

When True::

    y
    ^
    |      . (3, 3)
    |
    |   . (2, 2)
    |
    +------ > x

When False::

    +------ > x
    |
    |   . (2, 2)
    |
    |      . (3, 3)
    v
    y

"""

def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
    """Convenience method to convert pymunk coordinates to pygame surface
    local coordinates.

    Note that in case positive_y_is_up is False, this function wont actually do
    anything except converting the point to integers.
    """
    if positive_y_is_up:
        return round(p[0]), surface.get_height() - round(p[1])
    else:
        return round(p[0]), round(p[1])


def light_color(color: SpaceDebugColor):
    color = np.minimum(1.2 * np.float32([color.r, color.g, color.b, color.a]), np.float32([255]))
    color = SpaceDebugColor(r=color[0], g=color[1], b=color[2], a=color[3])
    return color

class DrawOptions(pymunk.SpaceDebugDrawOptions):
    def __init__(self, surface: pygame.Surface) -> None:
        """Draw a pymunk.Space on a pygame.Surface object.

        Typical usage::

        >>> import pymunk
        >>> surface = pygame.Surface((10,10))
        >>> space = pymunk.Space()
        >>> options = pymunk.pygame_util.DrawOptions(surface)
        >>> space.debug_draw(options)

        You can control the color of a shape by setting shape.color to the color
        you want it drawn in::

        >>> c = pymunk.Circle(None, 10)
        >>> c.color = pygame.Color("pink")

        See pygame_util.demo.py for a full example

        Since pygame uses a coordiante system where y points down (in contrast
        to many other cases), you either have to make the physics simulation
        with Pymunk also behave in that way, or flip everything when you draw.

        The easiest is probably to just make the simulation behave the same
        way as Pygame does. In that way all coordinates used are in the same
        orientation and easy to reason about::

        >>> space = pymunk.Space()
        >>> space.gravity = (0, -1000)
        >>> body = pymunk.Body()
        >>> body.position = (0, 0) # will be positioned in the top left corner
        >>> space.debug_draw(options)

        To flip the drawing its possible to set the module property
        :py:data:`positive_y_is_up` to True. Then the pygame drawing will flip
        the simulation upside down before drawing::

        >>> positive_y_is_up = True
        >>> body = pymunk.Body()
        >>> body.position = (0, 0)
        >>> # Body will be position in bottom left corner

        :Parameters:
                surface : pygame.Surface
                    Surface that the objects will be drawn on
        """
        self.surface = surface
        super(DrawOptions, self).__init__()

    def draw_circle(
        self,
        pos: Vec2d,
        angle: float,
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        p = to_pygame(pos, self.surface)

        pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0)
        pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius-4), 0)

        circle_edge = pos + Vec2d(radius, 0).rotated(angle)
        p2 = to_pygame(circle_edge, self.surface)
        line_r = 2 if radius > 20 else 1
        # pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r)

    def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None:
        p1 = to_pygame(a, self.surface)
        p2 = to_pygame(b, self.surface)

        pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2])

    def draw_fat_segment(
        self,
        a: Tuple[float, float],
        b: Tuple[float, float],
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        p1 = to_pygame(a, self.surface)
        p2 = to_pygame(b, self.surface)

        r = round(max(1, radius * 2))
        pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r)
        if r > 2:
            orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])]
            if orthog[0] == 0 and orthog[1] == 0:
                return
            scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5
            orthog[0] = round(orthog[0] * scale)
            orthog[1] = round(orthog[1] * scale)
            points = [
                (p1[0] - orthog[0], p1[1] - orthog[1]),
                (p1[0] + orthog[0], p1[1] + orthog[1]),
                (p2[0] + orthog[0], p2[1] + orthog[1]),
                (p2[0] - orthog[0], p2[1] - orthog[1]),
            ]
            pygame.draw.polygon(self.surface, fill_color.as_int(), points)
            pygame.draw.circle(
                self.surface,
                fill_color.as_int(),
                (round(p1[0]), round(p1[1])),
                round(radius),
            )
            pygame.draw.circle(
                self.surface,
                fill_color.as_int(),
                (round(p2[0]), round(p2[1])),
                round(radius),
            )

    def draw_polygon(
        self,
        verts: Sequence[Tuple[float, float]],
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        ps = [to_pygame(v, self.surface) for v in verts]
        ps += [ps[0]]

        radius = 2
        pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps)

        if radius > 0:
            for i in range(len(verts)):
                a = verts[i]
                b = verts[(i + 1) % len(verts)]
                self.draw_fat_segment(a, b, radius, fill_color, fill_color)

    def draw_dot(
        self, size: float, pos: Tuple[float, float], color: SpaceDebugColor
    ) -> None:
        p = to_pygame(pos, self.surface)
        pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0)


def pymunk_to_shapely(body, shapes):
    geoms = list()
    for shape in shapes:
        if isinstance(shape, pymunk.shapes.Poly):
            verts = [body.local_to_world(v) for v in shape.get_vertices()]
            verts += [verts[0]]
            geoms.append(sg.Polygon(verts))
        else:
            raise RuntimeError(f'Unsupported shape type {type(shape)}')
    geom = sg.MultiPolygon(geoms)
    return geom

# env
class PushTEnv(gym.Env):
    metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10}
    reward_range = (0., 1.)

    def __init__(self,
            legacy=False,
            block_cog=None, damping=None,
            render_action=True,
            render_size=96,
            reset_to_state=None
        ):
        self._seed = None
        self.seed()
        self.window_size = ws = 512  # The size of the PyGame window
        self.render_size = render_size
        self.sim_hz = 100
        # Local controller params.
        self.k_p, self.k_v = 100, 20    # PD control.z
        self.control_hz = self.metadata['video.frames_per_second']
        # legcay set_state for data compatiblity
        self.legacy = legacy

        # agent_pos, block_pos, block_angle
        self.observation_space = spaces.Box(
            low=np.array([0,0,0,0,0], dtype=np.float64),
            high=np.array([ws,ws,ws,ws,np.pi*2], dtype=np.float64),
            shape=(5,),
            dtype=np.float64
        )

        # positional goal for agent
        self.action_space = spaces.Box(
            low=np.array([0,0], dtype=np.float64),
            high=np.array([ws,ws], dtype=np.float64),
            shape=(2,),
            dtype=np.float64
        )

        self.block_cog = block_cog
        self.damping = damping
        self.render_action = render_action

        """
        If human-rendering is used, `self.window` will be a reference
        to the window that we draw to. `self.clock` will be a clock that is used
        to ensure that the environment is rendered at the correct framerate in
        human-mode. They will remain `None` until human-mode is used for the
        first time.
        """
        self.window = None
        self.clock = None
        self.screen = None

        self.space = None
        self.teleop = None
        self.render_buffer = None
        self.latest_action = None
        self.reset_to_state = reset_to_state

    def reset(self):
        seed = self._seed
        self._setup()
        if self.block_cog is not None:
            self.block.center_of_gravity = self.block_cog
        if self.damping is not None:
            self.space.damping = self.damping

        # use legacy RandomState for compatiblity
        state = self.reset_to_state
        if state is None:
            rs = np.random.RandomState(seed=seed)
            state = np.array([
                rs.randint(50, 450), rs.randint(50, 450),
                rs.randint(100, 400), rs.randint(100, 400),
                rs.randn() * 2 * np.pi - np.pi
                ])
        self._set_state(state)

        obs = self._get_obs()
        info = self._get_info()
        return obs, info

    def step(self, action):
        dt = 1.0 / self.sim_hz
        self.n_contact_points = 0
        n_steps = self.sim_hz // self.control_hz
        if action is not None:
            self.latest_action = action
            for i in range(n_steps):
                # Step PD control.
                # self.agent.velocity = self.k_p * (act - self.agent.position)    # P control works too.
                acceleration = self.k_p * (action - self.agent.position) + self.k_v * (Vec2d(0, 0) - self.agent.velocity)
                self.agent.velocity += acceleration * dt

                # Step physics.
                self.space.step(dt)

        # compute reward
        goal_body = self._get_goal_pose_body(self.goal_pose)
        goal_geom = pymunk_to_shapely(goal_body, self.block.shapes)
        block_geom = pymunk_to_shapely(self.block, self.block.shapes)

        intersection_area = goal_geom.intersection(block_geom).area
        goal_area = goal_geom.area
        coverage = intersection_area / goal_area
        reward = np.clip(coverage / self.success_threshold, 0, 1)
        done = coverage > self.success_threshold
        terminated = done
        truncated = done

        observation = self._get_obs()
        info = self._get_info()

        return observation, reward, terminated, truncated, info

    def render(self, mode):
        return self._render_frame(mode)

    def teleop_agent(self):
        TeleopAgent = collections.namedtuple('TeleopAgent', ['act'])
        def act(obs):
            act = None
            mouse_position = pymunk.pygame_util.from_pygame(Vec2d(*pygame.mouse.get_pos()), self.screen)
            if self.teleop or (mouse_position - self.agent.position).length < 30:
                self.teleop = True
                act = mouse_position
            return act
        return TeleopAgent(act)

    def _get_obs(self):
        obs = np.array(
            tuple(self.agent.position) \
            + tuple(self.block.position) \
            + (self.block.angle % (2 * np.pi),))
        return obs

    def _get_goal_pose_body(self, pose):
        mass = 1
        inertia = pymunk.moment_for_box(mass, (50, 100))
        body = pymunk.Body(mass, inertia)
        # preserving the legacy assignment order for compatibility
        # the order here dosn't matter somehow, maybe because CoM is aligned with body origin
        body.position = pose[:2].tolist()
        body.angle = pose[2]
        return body

    def _get_info(self):
        n_steps = self.sim_hz // self.control_hz
        n_contact_points_per_step = int(np.ceil(self.n_contact_points / n_steps))
        info = {
            'pos_agent': np.array(self.agent.position),
            'vel_agent': np.array(self.agent.velocity),
            'block_pose': np.array(list(self.block.position) + [self.block.angle]),
            'goal_pose': self.goal_pose,
            'n_contacts': n_contact_points_per_step}
        return info

    def _render_frame(self, mode):

        if self.window is None and mode == "human":
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode((self.window_size, self.window_size))
        if self.clock is None and mode == "human":
            self.clock = pygame.time.Clock()

        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))
        self.screen = canvas

        draw_options = DrawOptions(canvas)

        # Draw goal pose.
        goal_body = self._get_goal_pose_body(self.goal_pose)
        for shape in self.block.shapes:
            goal_points = [pymunk.pygame_util.to_pygame(goal_body.local_to_world(v), draw_options.surface) for v in shape.get_vertices()]
            goal_points += [goal_points[0]]
            pygame.draw.polygon(canvas, self.goal_color, goal_points)

        # Draw agent and block.
        self.space.debug_draw(draw_options)

        if mode == "human":
            # The following line copies our drawings from `canvas` to the visible window
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()

            # the clock is aleady ticked during in step for "human"


        img = np.transpose(
                np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
            )
        img = cv2.resize(img, (self.render_size, self.render_size))
        if self.render_action:
            if self.render_action and (self.latest_action is not None):
                action = np.array(self.latest_action)
                coord = (action / 512 * 96).astype(np.int32)
                marker_size = int(8/96*self.render_size)
                thickness = int(1/96*self.render_size)
                cv2.drawMarker(img, coord,
                    color=(255,0,0), markerType=cv2.MARKER_CROSS,
                    markerSize=marker_size, thickness=thickness)
        return img


    def close(self):
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()

    def seed(self, seed=None):
        if seed is None:
            seed = np.random.randint(0,25536)
        self._seed = seed
        self.np_random = np.random.default_rng(seed)

    def _handle_collision(self, arbiter, space, data):
        self.n_contact_points += len(arbiter.contact_point_set.points)

    def _set_state(self, state):
        if isinstance(state, np.ndarray):
            state = state.tolist()
        pos_agent = state[:2]
        pos_block = state[2:4]
        rot_block = state[4]
        self.agent.position = pos_agent
        # setting angle rotates with respect to center of mass
        # therefore will modify the geometric position
        # if not the same as CoM
        # therefore should be modified first.
        if self.legacy:
            # for compatiblity with legacy data
            self.block.position = pos_block
            self.block.angle = rot_block
        else:
            self.block.angle = rot_block
            self.block.position = pos_block

        # Run physics to take effect
        self.space.step(1.0 / self.sim_hz)

    def _set_state_local(self, state_local):
        agent_pos_local = state_local[:2]
        block_pose_local = state_local[2:]
        tf_img_obj = st.AffineTransform(
            translation=self.goal_pose[:2],
            rotation=self.goal_pose[2])
        tf_obj_new = st.AffineTransform(
            translation=block_pose_local[:2],
            rotation=block_pose_local[2]
        )
        tf_img_new = st.AffineTransform(
            matrix=tf_img_obj.params @ tf_obj_new.params
        )
        agent_pos_new = tf_img_new(agent_pos_local)
        new_state = np.array(
            list(agent_pos_new[0]) + list(tf_img_new.translation) \
                + [tf_img_new.rotation])
        self._set_state(new_state)
        return new_state

    def _setup(self):
        self.space = pymunk.Space()
        self.space.gravity = 0, 0
        self.space.damping = 0
        self.teleop = False
        self.render_buffer = list()

        # Add walls.
        walls = [
            self._add_segment((5, 506), (5, 5), 2),
            self._add_segment((5, 5), (506, 5), 2),
            self._add_segment((506, 5), (506, 506), 2),
            self._add_segment((5, 506), (506, 506), 2)
        ]
        self.space.add(*walls)

        # Add agent, block, and goal zone.
        self.agent = self.add_circle((256, 400), 15)
        self.block = self.add_tee((256, 300), 0)
        self.goal_color = pygame.Color('LightGreen')
        self.goal_pose = np.array([256,256,np.pi/4])  # x, y, theta (in radians)

        # Add collision handeling
        self.collision_handeler = self.space.add_collision_handler(0, 0)
        self.collision_handeler.post_solve = self._handle_collision
        self.n_contact_points = 0

        self.max_score = 50 * 100
        self.success_threshold = 0.95    # 95% coverage.

    def _add_segment(self, a, b, radius):
        shape = pymunk.Segment(self.space.static_body, a, b, radius)
        shape.color = pygame.Color('LightGray')    # https://htmlcolorcodes.com/color-names
        return shape

    def add_circle(self, position, radius):
        body = pymunk.Body(body_type=pymunk.Body.KINEMATIC)
        body.position = position
        body.friction = 1
        shape = pymunk.Circle(body, radius)
        shape.color = pygame.Color('RoyalBlue')
        self.space.add(body, shape)
        return body

    def add_box(self, position, height, width):
        mass = 1
        inertia = pymunk.moment_for_box(mass, (height, width))
        body = pymunk.Body(mass, inertia)
        body.position = position
        shape = pymunk.Poly.create_box(body, (height, width))
        shape.color = pygame.Color('LightSlateGray')
        self.space.add(body, shape)
        return body

    def add_tee(self, position, angle, scale=30, color='LightSlateGray', mask=pymunk.ShapeFilter.ALL_MASKS()):
        mass = 1
        length = 4
        vertices1 = [(-length*scale/2, scale),
                                 ( length*scale/2, scale),
                                 ( length*scale/2, 0),
                                 (-length*scale/2, 0)]
        inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
        vertices2 = [(-scale/2, scale),
                                 (-scale/2, length*scale),
                                 ( scale/2, length*scale),
                                 ( scale/2, scale)]
        inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
        body = pymunk.Body(mass, inertia1 + inertia2)
        shape1 = pymunk.Poly(body, vertices1)
        shape2 = pymunk.Poly(body, vertices2)
        shape1.color = pygame.Color(color)
        shape2.color = pygame.Color(color)
        shape1.filter = pymunk.ShapeFilter(mask=mask)
        shape2.filter = pymunk.ShapeFilter(mask=mask)
        body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
        body.position = position
        body.angle = angle
        body.friction = 1
        self.space.add(body, shape1, shape2)
        return body


class PushTImageEnv(PushTEnv):
    metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10}

    def __init__(self,
            legacy=False,
            block_cog=None,
            damping=None,
            render_size=96):
        super().__init__(
            legacy=legacy,
            block_cog=block_cog,
            damping=damping,
            render_size=render_size,
            render_action=False)
        ws = self.window_size
        self.observation_space = spaces.Dict({
            'image': spaces.Box(
                low=0,
                high=1,
                shape=(3,render_size,render_size),
                dtype=np.float32
            ),
            'agent_pos': spaces.Box(
                low=0,
                high=ws,
                shape=(2,),
                dtype=np.float32
            )
        })
        self.render_cache = None

    def _get_obs(self):
        img = super()._render_frame(mode='rgb_array')

        agent_pos = np.array(self.agent.position)
        img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0)
        obs = {
            'image': img_obs,
            'agent_pos': agent_pos
        }

        # draw action
        if self.latest_action is not None:
            action = np.array(self.latest_action)
            coord = (action / 512 * 96).astype(np.int32)
            marker_size = int(8/96*self.render_size)
            thickness = int(1/96*self.render_size)
            cv2.drawMarker(img, coord,
                color=(255,0,0), markerType=cv2.MARKER_CROSS,
                markerSize=marker_size, thickness=thickness)
        self.render_cache = img

        return obs

    def render(self, mode):
        assert mode == 'rgb_array'

        if self.render_cache is None:
            self._get_obs()

        return self.render_cache



In [18]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import os # For creating directories if needed

def plot_and_save_grayscale(output_2d_tensor, filename="grayscale_output.png", title="Autoencoder Output (Grayscale)"):
    plt.ioff()
    """
    Plots a 1xHxW tensor as a grayscale image and saves it.
    Scales values from -1.0 to 1.0 to 0.0 to 1.0 for plotting.

    Args:
        output_2d_tensor (torch.Tensor): The output tensor of shape (1, H, W) or (H, W).
                                         Values are expected to be between -1.0 and 1.0.
        filename (str): The name of the file to save the plot.
        title (str): The title of the plot.
    """
    if output_2d_tensor.dim() == 3 and output_2d_tensor.shape[0] == 1:
        # Remove the channel dimension if it's 1
        plot_data = output_2d_tensor.squeeze(0).cpu().numpy()
    elif output_2d_tensor.dim() == 2:
        plot_data = output_2d_tensor.cpu().numpy()
    else:
        raise ValueError("output_2d_tensor must be of shape (1, H, W) or (H, W)")

    # Scale values from [-1, 1] to [0, 1] for typical image display
    plot_data_scaled = (plot_data + 1.0) / 2.0

    plt.figure(figsize=(6, 6))
    plt.imshow(plot_data_scaled, cmap='gray', vmin=0, vmax=1)
    plt.title(title)
    plt.colorbar(label="Scaled Value (0.0 to 1.0)")
    plt.axis('off') # Hide axes for cleaner image
    plt.tight_layout()
    plt.savefig(filename)
    plt.close() # Close the plot to free memory

    # print(f"Grayscale plot saved to {filename}")

def plot_and_save_heatmap(output_2d_tensor, filename="heatmap_output.png", title="Autoencoder Output (Heatmap)"):
    plt.ioff()
    """
    Plots a 1xHxW tensor as a heatmap and saves it.
    Values are expected to be between -1.0 and 1.0.

    Args:
        output_2d_tensor (torch.Tensor): The output tensor of shape (1, H, W) or (H, W).
                                         Values are expected to be between -1.0 and 1.0.
        filename (str): The name of the file to save the plot.
        title (str): The title of the plot.
    """
    if output_2d_tensor.dim() == 3 and output_2d_tensor.shape[0] == 1:
        # Remove the channel dimension if it's 1
        plot_data = output_2d_tensor.squeeze(0).cpu().numpy()
    elif output_2d_tensor.dim() == 2:
        plot_data = output_2d_tensor.cpu().numpy()
    else:
        raise ValueError("output_2d_tensor must be of shape (1, H, W) or (H, W)")

    plt.figure(figsize=(7, 6)) # Slightly wider for colorbar
    plt.imshow(plot_data, cmap='viridis', vmin=-1.0, vmax=1.0) # 'viridis' or 'plasma' are good perceptual colormaps
    plt.title(title)
    plt.colorbar(label="Value (-1.0 to 1.0)")
    plt.axis('off') # Hide axes for cleaner image
    plt.tight_layout()
    plt.savefig(filename)
    plt.close() # Close the plot to free memory

    # print(f"Heatmap plot saved to {filename}")
def viz(dummy_output_2d):
    output_dir = "model_outputs"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Plot and save as grayscale
    plot_and_save_grayscale(dummy_output_2d, filename=os.path.join(output_dir, "output_grayscale.png"))

    # Plot and save as heatmap
    plot_and_save_heatmap(dummy_output_2d, filename=os.path.join(output_dir, "output_heatmap.png"))


In [19]:
#@markdown ### **Env Demo**
#@markdown Standard Gym Env (0.21.0 API)

# 0. create env object
env = PushTImageEnv()

# 1. seed env for initial state.
# Seed 0-200 are used for the demonstration dataset.
env.seed(1000)

# 2. must reset before use
obs, info = env.reset()

# 3. 2D positional action space [0,512]
action = env.action_space.sample()

# 4. Standard gym step method
obs, reward, terminated, truncated, info = env.step(action)

# prints and explains each dimension of the observation and action vectors
with np.printoptions(precision=4, suppress=True, threshold=5):
    print("obs['image'].shape:", obs['image'].shape, "float32, [0,1]")
    print("obs['agent_pos'].shape:", obs['agent_pos'].shape, "float32, [0,512]")
    print("action.shape: ", action.shape, "float32, [0,512]")

obs['image'].shape: (3, 96, 96) float32, [0,1]
obs['agent_pos'].shape: (2,) float32, [0,512]
action.shape:  (2,) float32, [0,512]


In [None]:
def plot_agent_positions(naction,path, batch_idx=0, max_trajectories=8, figsize=(15, 10)):
    plt.ioff()
    if isinstance(naction, torch.Tensor):
        naction_np = naction.detach().cpu().numpy()
    else:
        naction_np = naction
    if isinstance(path, torch.Tensor):
        path = path.detach().cpu().numpy()
    else:
        path = path
    

    B, T, coord_dim = naction_np.shape
    # print(f"Agent position tensor shape: {naction_np.shape}")
    # print(f"Batch size: {B}, Time steps: {T}, Coordinates: {coord_dim}")
    
    # Create subplots
    fig = plt.figure(figsize=figsize)
    
    plt.subplot(2, 3, 1)
    traj = naction_np[batch_idx]  # Shape: (T, 2)
    x_coords = traj[:, 0]
    y_coords = traj[:, 1]
    traj_gt = path[batch_idx]
    traj_gt_x_coords = traj_gt[:, 0]
    traj_gt_y_coords = traj_gt[:, 1]
    
    plt.plot(x_coords, y_coords, 'b-', linewidth=2, alpha=0.7, label='Trajectory')
    plt.plot(traj_gt_x_coords, traj_gt_y_coords, 'r-', linewidth=2, alpha=0.7, label='Trajectory')
    plt.tight_layout()
    # plt.show()
    plt.savefig("model_outputs/plot.jpg")
    

In [21]:
#@markdown ### **Dataset**
#@markdown
#@markdown Defines `PushTImageDataset` and helper functions
#@markdown
#@markdown The dataset class
#@markdown - Load data ((image, agent_pos), action) from a zarr storage
#@markdown - Normalizes each dimension of agent_pos and action to [-1,1]
#@markdown - Returns
#@markdown  - All possible segments with length `pred_horizon`
#@markdown  - Pads the beginning and the end of each episode with repetition
#@markdown  - key `image`: shape (obs_hoirzon, 3, 96, 96)
#@markdown  - key `agent_pos`: shape (obs_hoirzon, 2)
#@markdown  - key `action`: shape (pred_horizon, 2)

def create_sample_indices(
        episode_ends:np.ndarray, sequence_length:int,
        pad_before: int=0, pad_after: int=0):
    indices = list()
    for i in range(len(episode_ends)):
        start_idx = 0
        if i > 0:
            start_idx = episode_ends[i-1]
        end_idx = episode_ends[i]
        episode_length = end_idx - start_idx

        min_start = -pad_before
        max_start = episode_length - sequence_length + pad_after

        # range stops one idx before end
        for idx in range(min_start, max_start+1):
            buffer_start_idx = max(idx, 0) + start_idx
            buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx
            start_offset = buffer_start_idx - (idx+start_idx)
            end_offset = (idx+sequence_length+start_idx) - buffer_end_idx
            sample_start_idx = 0 + start_offset
            sample_end_idx = sequence_length - end_offset
            indices.append([
                buffer_start_idx, buffer_end_idx,
                sample_start_idx, sample_end_idx])
    indices = np.array(indices)
    return indices


def sample_sequence(train_data, sequence_length,
                    buffer_start_idx, buffer_end_idx,
                    sample_start_idx, sample_end_idx):
    result = dict()
    for key, input_arr in train_data.items():
        sample = input_arr[buffer_start_idx:buffer_end_idx]
        data = sample
        if (sample_start_idx > 0) or (sample_end_idx < sequence_length):
            data = np.zeros(
                shape=(sequence_length,) + input_arr.shape[1:],
                dtype=input_arr.dtype)
            if sample_start_idx > 0:
                data[:sample_start_idx] = sample[0]
            if sample_end_idx < sequence_length:
                data[sample_end_idx:] = sample[-1]
            data[sample_start_idx:sample_end_idx] = sample
        result[key] = data
    return result

# normalize data
def get_data_stats(data):
    data = data.reshape(-1,data.shape[-1])
    stats = {
        'min': np.min(data, axis=0),
        'max': np.max(data, axis=0)
    }
    return stats

def normalize_data(data, stats):
    # nomalize to [0,1]
    ndata = (data - stats['min']) / (stats['max'] - stats['min'])
    # normalize to [-1, 1]
    ndata = ndata * 2 - 1
    return ndata

def normalize_data_torch(data, stats):
    # nomalize to [0,1]
    ndata = (data - torch.Tensor(stats['min']).to("cuda")) / torch.Tensor(stats['max'] - stats['min']).to("cuda")
    # normalize to [-1, 1]
    ndata = ndata * 2 - 1
    return ndata

def scale_agent_position_in_pixel(data):
    ndata = (data/512)*96 
    return ndata
# def scale_agent_position_in_pixel(data):
#     ndata = (data/512)*96 
#     return ndata
def unnormalize_data(ndata, stats):
    ndata = (ndata + 1) / 2
    data = ndata * (stats['max'] - stats['min']) + stats['min']
    return data

# dataset
class PushTImageDataset(torch.utils.data.Dataset):
    def __init__(self,
                 dataset_path: str,
                 pred_horizon: int,
                 obs_horizon: int,
                 action_horizon: int):

        # read from zarr dataset
        dataset_root = zarr.open(dataset_path, 'r')

        # float32, [0,1], (N,96,96,3)
        train_image_data = dataset_root['data']['img'][:]
        train_image_data = np.moveaxis(train_image_data, -1,1)
        # (N,3,96,96)

        # (N, D)
        train_data = {
            # first two dims of state vector are agent (i.e. gripper) locations
            'agent_pos': scale_agent_position_in_pixel(dataset_root['data']['state'][:,:2]),
            'action': dataset_root['data']['action'][:]
        }
        episode_ends = dataset_root['meta']['episode_ends'][:]

        # compute start and end of each state-action sequence
        # also handles padding
        indices = create_sample_indices(
            episode_ends=episode_ends,
            sequence_length=pred_horizon,
            pad_before=obs_horizon-1,
            pad_after=action_horizon-1)

        # compute statistics and normalized data to [-1,1]
        stats = dict()
        normalized_train_data = dict()
        for key, data in train_data.items():
            stats[key] = get_data_stats(data)
            # if key !="agent_pos":
            normalized_train_data[key] = normalize_data(data, stats[key])
            # else:


        # images are already normalized
        normalized_train_data['image'] = train_image_data

        self.indices = indices
        self.stats = stats
        self.normalized_train_data = normalized_train_data
        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # get the start/end indices for this datapoint
        buffer_start_idx, buffer_end_idx, \
            sample_start_idx, sample_end_idx = self.indices[idx]

        # get nomralized data using these indices
        nsample = sample_sequence(
            train_data=self.normalized_train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buffer_start_idx,
            buffer_end_idx=buffer_end_idx,
            sample_start_idx=sample_start_idx,
            sample_end_idx=sample_end_idx
        )

        # discard unused observations
        nsample['image'] = nsample['image'][:self.obs_horizon,:]
        nsample['agent_pos'] = nsample['agent_pos']
        # nsample["action"]=
        return nsample


In [22]:
#@markdown ### **Dataset Demo**

# download demonstration data from Google Drive
dataset_path = "/home/kojogyaase/Projects/Research/neural_potential_fields/pusht_cchi_v7_replay.zarr.zip"
# if not os.path.isfile(dataset_path):
#     id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t"
#     gdown.download(id=id, output=dataset_path, quiet=False)

# parameters
pred_horizon = 8
obs_horizon = 1
action_horizon = 8
device="cuda"
#|o|                             observations: 1
#| |a|a|a|a|a|a|a|a|               actions executed: 8
#|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16

# create dataset from file
dataset = PushTImageDataset(
    dataset_path=dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon
)
# save training data statistics (min, max) for each dim
stats = dataset.stats

# create dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill worker process afte each epoch
    persistent_workers=True
)



In [23]:
# visualize data in batch
batch = next(iter(dataloader))
print("batch['image'].shape:", batch['image'].shape)
print("batch['agent_pos'].shape:", batch['agent_pos'].shape)
# print("batch['action'].shape", batch['action'].shape)
# plot_agent_positions(batch["agent_pos"])

batch['image'].shape: torch.Size([32, 1, 3, 96, 96])
batch['agent_pos'].shape: torch.Size([32, 8, 2])


In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RGBTo2DAutoencoder(nn.Module):
    """
    Autoencoder that takes an RGB image (3x96x96) and encodes it to a latent vector,
    then decodes it into a 2D potential field (1x96x96), which is then squeezed to (96x96).
    This version uses GroupNorm for better batch size flexibility.
    """
    def __init__(self, input_channels=3, latent_dim=512):
        super(RGBTo2DAutoencoder, self).__init__()

        # Encoder: RGB Image -> Latent Vector (using GroupNorm instead of BatchNorm)
        # Input: 3 x 96 x 96
        self.encoder = nn.Sequential(
            # Conv1: 96x96 -> 48x48
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),  # (96 - 4 + 2*1)/2 + 1 = 48
            nn.GroupNorm(8, 64),
            nn.ReLU(True),

            # Conv2: 48x48 -> 24x24
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # (48 - 4 + 2*1)/2 + 1 = 24
            nn.GroupNorm(16, 128),
            nn.ReLU(True),

            # Conv3: 24x24 -> 12x12
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # (24 - 4 + 2*1)/2 + 1 = 12
            nn.GroupNorm(32, 256),
            nn.ReLU(True),

            # Conv4: 12x12 -> 6x6
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1), # (12 - 4 + 2*1)/2 + 1 = 6
            nn.GroupNorm(32, 512), # Added GroupNorm here as well
            nn.ReLU(True),

            # Conv5: 6x6 -> 1x1 (final feature map before flatten)
            nn.Conv2d(512, 512, kernel_size=6, stride=1, padding=0), # (6 - 6 + 2*0)/1 + 1 = 1
            nn.ReLU(True),

            nn.Flatten(),  # 512 * 1 * 1 = 512
            nn.Linear(512, latent_dim),
            nn.ReLU(True)
        )

        # Decoder: Latent Vector -> 96x96 2D Field
        self.decoder = nn.Sequential(
            # Project latent vector to a feature map that can be upsampled
            # Start with a 1x1 latent space from the encoder, which expands to 512 channels
            # We need to project it back to a feature map that can be upsampled to 96x96
            # The last encoder conv outputs 512x1x1, so we start the decoder from that point.
            # We need to reverse the operations.
            # Decoder starts from latent_dim -> 512 (same as the last encoder conv layer output channels)
            nn.Linear(latent_dim, 512 * 1 * 1), # Project latent_dim to the starting feature map size
            nn.ReLU(True),
            nn.Unflatten(1, (512, 1, 1)),  # Reshape to (Batch, 512, 1, 1)

            # Up1: 1x1 -> 6x6 (reverse of Conv5: 6x6 -> 1x1)
            nn.ConvTranspose2d(512, 256, kernel_size=6, stride=1, padding=0), # (1-1)*1 - 2*0 + 6 + 0 + 1 = 6
            nn.GroupNorm(32, 256),
            nn.ReLU(True),

            # Up2: 6x6 -> 12x12 (reverse of Conv4: 12x12 -> 6x6)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # (6-1)*2 - 2*1 + 4 + 0 + 1 = 12
            nn.GroupNorm(16, 128),
            nn.ReLU(True),

            # Up3: 12x12 -> 24x24 (reverse of Conv3: 24x24 -> 12x12)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), # (12-1)*2 - 2*1 + 4 + 0 + 1 = 24
            nn.GroupNorm(8, 64),
            nn.ReLU(True),

            # Up4: 24x24 -> 48x48 (reverse of Conv2: 48x48 -> 24x24)
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), # (24-1)*2 - 2*1 + 4 + 0 + 1 = 48
            nn.GroupNorm(4, 32), # Adjusted group norm if necessary
            nn.ReLU(True),

            # Up5: 48x48 -> 96x96 (reverse of Conv1: 96x96 -> 48x48)
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1), # (48-1)*2 - 2*1 + 4 + 0 + 1 = 96
            nn.Tanh() # Using Tanh to keep the output values typically in a range like -1 to 1
        )

    def forward(self, x):
        latent = self.encoder(x)
        output_2d = self.decoder(latent)
        # Squeeze the channel dimension if it's 1, resulting in (B, 96, 96)
        output_2d = output_2d.squeeze(1)
        return output_2d

In [None]:
# import torch
# import torch.nn.functional as F
# import matplotlib.pyplot as plt
# import numpy as np
# import os

# def soft_clamp(x, min_val, max_val, temperature=1.0):
#     """
#     Differentiable soft clamp function using sigmoid.
#     """
#     # Normalize to [0, 1]
#     x_norm = (x - min_val) / (max_val - min_val)
#     # Apply sigmoid for soft clamping
#     x_clamped = torch.sigmoid(temperature * (x_norm - 0.5)) * 0.98 + 0.01
#     # Scale back to original range
#     return x_clamped * (max_val - min_val) + min_val

# def differentiable_grid_sample(field, coords, x_min=-5, x_max=5, y_min=-5, y_max=5):
#     """
#     Differentiable bilinear interpolation using grid_sample.
    
#     Args:
#         field: tensor of shape (B, 1, H, W) - batch of potential fields
#         coords: tensor of shape (B, N, 2) - batch of coordinates in world space
        
#     Returns:
#         values: tensor of shape (B, N) - interpolated values
#     """
#     B, _, H, W = field.shape
#     N = coords.shape[1]
    
#     # Normalize coordinates to [-1, 1] for grid_sample
#     x_norm = 2 * (coords[:, :, 0] - x_min) / (x_max - x_min) - 1
#     y_norm = 2 * (coords[:, :, 1] - y_min) / (y_max - y_min) - 1
    
#     # Stack and reshape for grid_sample: (B, N, 1, 2)
#     grid = torch.stack([x_norm, y_norm], dim=-1).unsqueeze(2)
    
#     # Sample from field: (B, 1, N, 1)
#     sampled = F.grid_sample(field, grid, mode='bilinear', padding_mode='border', align_corners=True)
    
#     # Reshape to (B, N)
#     return sampled.squeeze(1).squeeze(-1)

# def compute_gradient_differentiable(fields, coords, x_min=-5, x_max=5, y_min=-5, y_max=5, h=0.01):
#     """
#     Compute gradients differentiably using finite differences with grid_sample.
    
#     Args:
#         fields: tensor of shape (B, 1, H, W) - batch of potential fields
#         coords: tensor of shape (B, 2) - batch of (x, y) coordinates
        
#     Returns:
#         gradients: tensor of shape (B, 2) - batch of gradients
#     """
#     B = fields.shape[0]
#     device = fields.device
    
#     # Create offset coordinates for finite differences
#     coords_expanded = coords.unsqueeze(1)  # (B, 1, 2)
    
#     # Create perturbation vectors
#     dx_vec = torch.tensor([h, 0.0], device=device).expand(B, 1, 2)
#     dy_vec = torch.tensor([0.0, h], device=device).expand(B, 1, 2)
    
#     coords_x_plus = coords_expanded + dx_vec
#     coords_x_minus = coords_expanded - dx_vec
#     coords_y_plus = coords_expanded + dy_vec
#     coords_y_minus = coords_expanded - dy_vec
    
#     # Soft clamp coordinates to stay within bounds
#     coords_x_plus[:, :, 0] = soft_clamp(coords_x_plus[:, :, 0], x_min, x_max)
#     coords_x_minus[:, :, 0] = soft_clamp(coords_x_minus[:, :, 0], x_min, x_max)
#     coords_y_plus[:, :, 1] = soft_clamp(coords_y_plus[:, :, 1], y_min, y_max)
#     coords_y_minus[:, :, 1] = soft_clamp(coords_y_minus[:, :, 1], y_min, y_max)
    
#     # Sample values at perturbed coordinates
#     val_x_plus = differentiable_grid_sample(fields, coords_x_plus, x_min, x_max, y_min, y_max)
#     val_x_minus = differentiable_grid_sample(fields, coords_x_minus, x_min, x_max, y_min, y_max)
#     val_y_plus = differentiable_grid_sample(fields, coords_y_plus, x_min, x_max, y_min, y_max)
#     val_y_minus = differentiable_grid_sample(fields, coords_y_minus, x_min, x_max, y_min, y_max)
    
#     # Compute gradients
#     grad_x = (val_x_plus - val_x_minus) / (2 * h)
#     grad_y = (val_y_plus - val_y_minus) / (2 * h)
    
#     return torch.stack([grad_x.squeeze(), grad_y.squeeze()], dim=-1)

# def find_extrema_differentiable(fields, x_min=-5, x_max=5, y_min=-5, y_max=5, temperature=40):
#     """
#     Find extrema using differentiable operations with high-temperature softmax.
#     This provides a much better approximation to true argmax/argmin.
    
#     Args:
#         fields: tensor of shape (B, H, W) or (B, 1, H, W)
#         temperature: Higher values make the softmax sharper (closer to true argmax)
        
#     Returns:
#         max_coords: tensor of shape (B, 2) - approximate world coordinates of maxima
#         min_coords: tensor of shape (B, 2) - approximate world coordinates of minima
#     """
#     if len(fields.shape) == 3:
#         fields = fields.unsqueeze(1)  # Add channel dimension
    
#     B, _, H, W = fields.shape
#     device = fields.device
    
#     # Create coordinate grids
#     x_coords = torch.linspace(x_min, x_max, W, device=device)
#     y_coords = torch.linspace(y_min, y_max, H, device=device)
#     Y_grid, X_grid = torch.meshgrid(y_coords, x_coords, indexing='ij')
    
#     # Flatten coordinates
#     X_flat = X_grid.flatten()  # (H*W,)
#     Y_flat = Y_grid.flatten()  # (H*W,)
    
#     max_coords = torch.zeros(B, 2, device=device)
#     min_coords = torch.zeros(B, 2, device=device)
    
#     for i in range(B):
#         field = fields[i, 0]  # (H, W)
#         field_flat = field.flatten()  # (H*W,)
        
#         # Normalize field values to prevent overflow with high temperature
#         field_normalized = (field_flat - field_flat.mean()) / (field_flat.std() + 1e-8)
        
#         # For maximum: use high-temperature softmax
#         max_weights = F.softmax(field_normalized * temperature, dim=0)
#         max_x = torch.sum(max_weights * X_flat)
#         max_y = torch.sum(max_weights * Y_flat)
#         max_coords[i] = torch.stack([max_x, max_y])
        
#         # For minimum: use high-temperature softmax on negative normalized values
#         min_weights = F.softmax(-field_normalized * temperature, dim=0)
#         min_x = torch.sum(min_weights * X_flat)
#         min_y = torch.sum(min_weights * Y_flat)
#         min_coords[i] = torch.stack([min_x, min_y])
    
#     return max_coords, min_coords

# def differentiable_gradient_descent_batch(fields, learning_rate=0.01, num_steps=100, 
#                                         num_path_points=10, x_min=-5, x_max=5, 
#                                         y_min=-5, y_max=5, target_end_points=None):
#     """
#     Fully differentiable gradient descent path finding.
    
#     Args:
#         fields: tensor of shape (B, H, W) or (B, 1, H, W) - batch of potential fields
#         target_end_points: Optional tensor of shape (B, 2) - if provided, use these as targets
        
#     Returns:
#         paths: tensor of shape (B, num_path_points, 2) - batch of paths
#     """
#     # Handle input shape
#     if len(fields.shape) == 3:  # (B, H, W)
#         fields = fields.unsqueeze(1)  # Convert to (B, 1, H, W)
#     elif len(fields.shape) == 4 and fields.shape[-1] == 1:  # (B, H, W, 1)
#         fields = fields.permute(0, 3, 1, 2)  # Convert to (B, 1, H, W)
    
#     B, _, H, W = fields.shape
#     device = fields.device
    
#     # Find start and end points - these will be differentiable
#     start_coords, end_coords = find_extrema_differentiable(fields, x_min, x_max, y_min, y_max)
    
#     if target_end_points is not None:
#         end_coords = target_end_points
    
#     # Initialize path storage - we'll store every step
#     all_positions = torch.zeros(B, num_steps + 1, 2, device=device)
#     all_positions[:, 0] = start_coords
    
#     current_positions = start_coords.clone()
    
#     # Gradient descent loop - fully differentiable
#     for step in range(num_steps):
#         # Compute gradients at current positions
#         gradients = compute_gradient_differentiable(fields, current_positions, 
#                                                   x_min, x_max, y_min, y_max)
        
#         # Update positions
#         current_positions = current_positions - learning_rate * gradients
        
#         # Soft clamp to bounds
#         current_positions[:, 0] = soft_clamp(current_positions[:, 0], x_min, x_max)
#         current_positions[:, 1] = soft_clamp(current_positions[:, 1], y_min, y_max)
        
#         all_positions[:, step + 1] = current_positions
    
#     # Differentiable path resampling using interpolation
#     # Create indices for resampling
#     step_indices = torch.linspace(0, num_steps, num_path_points, device=device)
    
#     # Use linear interpolation to get path points at desired indices
#     paths = torch.zeros(B, num_path_points, 2, device=device)
    
#     for i in range(num_path_points):
#         idx = step_indices[i]
        
#         # Get integer parts and fractional part for interpolation
#         idx_floor = torch.floor(idx).long()
#         idx_ceil = torch.ceil(idx).long()
#         alpha = idx - idx_floor.float()
        
#         # Clamp indices
#         idx_floor = torch.clamp(idx_floor, 0, num_steps)
#         idx_ceil = torch.clamp(idx_ceil, 0, num_steps)
        
#         # Linear interpolation between adjacent points
#         if idx_floor == idx_ceil:
#             paths[:, i] = all_positions[:, idx_floor]
#         else:
#             paths[:, i] = (1 - alpha) * all_positions[:, idx_floor] + alpha * all_positions[:, idx_ceil]
    
#     return paths

# def save_batch_plots_differentiable(fields, paths, output_dir="model/outputs", batch_idx=0, 
#                                   x_min=-1, x_max=1, y_min=-1, y_max=1, max_plots=4):
#     """
#     Save plots for a subset of the batch (non-differentiable, for visualization only).
#     """
#     os.makedirs(output_dir, exist_ok=True)
    
#     # Handle different input shapes
#     if len(fields.shape) == 4:
#         if fields.shape[1] == 1:  # (B, 1, H, W)
#             fields_plot = fields.squeeze(1)
#         else:  # (B, H, W, 1)
#             fields_plot = fields.squeeze(-1)
#     else:  # (B, H, W)
#         fields_plot = fields
    
#     B = min(fields_plot.shape[0], max_plots)
    
#     for i in range(B):
#         field = fields_plot[i].detach().cpu().numpy()
#         path = paths[i].detach().cpu().numpy()
        
#         plt.figure(figsize=(10, 8))
#         plt.imshow(field, origin='lower', cmap='viridis', 
#                    extent=[x_min, x_max, y_min, y_max])
#         plt.colorbar(label='Potential Value')
#         plt.title(f'Batch {batch_idx}, Sample {i}: Differentiable Path')
#         plt.xlabel('X')
#         plt.ylabel('Y')

#         # Plot path
#         plt.plot(path[:, 0], path[:, 1], 'r--', linewidth=2, label='Differentiable Path')
#         plt.plot(path[:, 0], path[:, 1], 'go', markersize=4, label='Path Points')
#         plt.plot(path[0, 0], path[0, 1], 'ro', markersize=8, label='Start')
#         plt.plot(path[-1, 0], path[-1, 1], 'bx', markersize=8, label='End')
        
#         plt.legend()
#         plt.grid(True)
#         plt.tight_layout()
#         plt.savefig(os.path.join(output_dir, f"differentiable_batch_{batch_idx}_sample_{i}.png"), 
#                    dpi=300, bbox_inches='tight')
#         plt.close()

# def process_batch_fields_differentiable(field_batch, learning_rate=0.05, num_steps=100, 
#                                       num_path_points=10, output_dir="model/outputs", 
#                                       batch_idx=0, save_plots_flag=False, 
#                                       x_min=-1, x_max=1, y_min=-1, y_max=1):
#     """
#     Process a batch of potential fields to find differentiable gradient descent paths.
    
#     Args:
#         field_batch: tensor of shape (B, 96, 96, 1) or (B, 96, 96) or (B, 1, 96, 96)
        
#     Returns:
#         paths: tensor of shape (B, num_path_points, 2) - batch of paths (differentiable)
#     """
#     # Find paths for the entire batch
#     paths = differentiable_gradient_descent_batch(
#         field_batch, 
#         learning_rate=learning_rate, 
#         num_steps=num_steps,
#         num_path_points=num_path_points,
#         x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max
#     )
    
#     # Save plots if requested (this part is not differentiable)
#     if save_plots_flag:
#         save_batch_plots_differentiable(field_batch, paths, output_dir, batch_idx, 
#                                       x_min, x_max, y_min, y_max)
    
#     return paths



In [None]:
model=RGBTo2DAutoencoder()
# field = model(torch.randn((1,3,96,96)))
# shape=field.shape
# paths=[]
# for _ in  range(shape[0]):
# path = process_batch_fields_differentiable(
#         field,
#         learning_rate=0.05,
#         num_steps=50,  # Reduced for efficiency
#         num_path_points=10,
#         output_dir="model/outputs",
#         batch_idx=0,
#         save_plots_flag=True
# )

model=model.to(device)


In [None]:
# import torch

# def soft_clamp(x, min_val, max_val, temperature=1.0):
#     x_norm = (x - min_val) / (max_val - min_val)
#     return (torch.sigmoid(temperature * (x_norm - 0.5)) * 0.98 + 0.01) * (max_val - min_val) + min_val

# def linear_interp1d(x, xp, yp):
#     indices = torch.searchsorted(xp, x).clamp(1, len(xp) - 1)
#     x1, x2 = xp[indices-1], xp[indices]
#     y1, y2 = yp[indices-1], yp[indices]
#     return y1 + (x - x1) * (y2 - y1) / (x2 - x1 + 1e-8)

# def compute_differentiable_path(potential_field, step_size=0.1, num_simulation_steps=100, num_path_points=10, convergence_threshold=1e-4):
#     B, H, W = potential_field.shape
#     device = potential_field.device
    
#     flat_potentials = potential_field.view(B, -1)
#     start_indices = torch.argmax(flat_potentials, dim=1)
#     end_indices = torch.argmin(flat_potentials, dim=1)
    
#     start_pos = torch.stack([(start_indices % W).float(), (start_indices // W).float()], dim=1)
#     end_pos = torch.stack([(end_indices % W).float(), (end_indices // W).float()], dim=1)
    
#     current_pos = start_pos.clone()
#     converged = torch.zeros(B, dtype=torch.bool, device=device)
#     path_lengths = torch.full((B,), num_simulation_steps, device=device)
#     all_positions = torch.zeros(num_simulation_steps+1, B, 2, device=device)
#     all_positions[0] = start_pos

#     for step in range(1, num_simulation_steps+1):
#         gradients = compute_fast_gradient(potential_field, current_pos)
#         step_direction = -gradients / (torch.norm(gradients, dim=1, keepdim=True) + 1e-8)
        
#         new_pos = current_pos + step_size * step_direction
#         new_pos = torch.stack([
#             torch.clamp(new_pos[:, 0], 0.0, W-1.0),
#             torch.clamp(new_pos[:, 1], 0.0, H-1.0)
#         ], dim=1)
        
#         distances = torch.norm(new_pos - end_pos, dim=1)
#         newly_converged = (distances < convergence_threshold) & ~converged
#         path_lengths = torch.where(newly_converged, step, path_lengths)
#         converged = converged | newly_converged
        
#         current_pos = torch.where(converged.unsqueeze(1), current_pos, new_pos)
#         all_positions[step] = current_pos
        
#         if converged.all():
#             all_positions = all_positions[:step+1]
#             break

#     final_paths = torch.zeros(B, num_path_points, 2, device=device)
#     interp_indices = torch.linspace(0, 1, num_path_points, device=device)
    
#     for b in range(B):
#         n_steps = min(path_lengths[b].item()+1, all_positions.shape[0])
#         path = all_positions[:n_steps, b]
#         t = torch.linspace(0, 1, n_steps, device=device)
#         for dim in range(2):
#             final_paths[b, :, dim] = linear_interp1d(interp_indices, t, path[:, dim])
    
#     return final_paths

# def compute_fast_gradient(potential_field, positions):
#     B, H, W = potential_field.shape
#     x, y = positions[:, 0], positions[:, 1]
#     eps = 0.5
    
#     x_plus = torch.clamp(x + eps, 0.0, W-1.0)
#     x_minus = torch.clamp(x - eps, 0.0, W-1.0)
#     y_plus = torch.clamp(y + eps, 0.0, H-1.0)
#     y_minus = torch.clamp(y - eps, 0.0, H-1.0)
    
#     val_x_plus = bilinear_sample_batch(potential_field, x_plus, y)
#     val_x_minus = bilinear_sample_batch(potential_field, x_minus, y)
#     val_y_plus = bilinear_sample_batch(potential_field, x, y_plus)
#     val_y_minus = bilinear_sample_batch(potential_field, x, y_minus)
    
#     return torch.stack([
#         (val_x_plus - val_x_minus) / (x_plus - x_minus + 1e-8),
#         (val_y_plus - val_y_minus) / (y_plus - y_minus + 1e-8)
#     ], dim=1)

# def bilinear_sample_batch(field, x, y):
#     B, H, W = field.shape
#     x0 = torch.floor(x).long().clamp(0, W-2)
#     y0 = torch.floor(y).long().clamp(0, H-2)
#     x1, y1 = x0 + 1, y0 + 1
    
#     wx = x - x0.float()
#     wy = y - y0.float()
    
#     return ((1-wx)*(1-wy)*field[torch.arange(B), y0, x0] + 
#             wx*(1-wy)*field[torch.arange(B), y0, x1] + 
#             (1-wx)*wy*field[torch.arange(B), y1, x0] + 
#             wx*wy*field[torch.arange(B), y1, x1])



In [None]:


# Training loop
import random
import numpy as np
import torch.nn as nn
from tqdm import tqdm

def train_model(model, dataloader, device, obs_horizon, action_horizon, num_epochs=10):
    """Train the neural network model."""
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=1e-2, 
        # weight_decay=1e-6
    )
    
    epoch_losses = []
    batch_count = 0
    
    with tqdm(range(num_epochs), desc='Epoch') as tglobal:
        for epoch_idx in tglobal:
            epoch_loss = []
            
            with tqdm(dataloader, desc='Batch', leave=False) as tepoch:
                for batch_idx, nbatch in enumerate(tepoch):
                    batch_count += 1
                    
                    # Data transfer to device
                    nimage = nbatch['image'][:, :obs_horizon].to(device).squeeze(1)
                    nagent_pos = nbatch['agent_pos'][:, :obs_horizon].to(device)
                    naction = nbatch['agent_pos'].to(device)
                    
                    # Forward pass
                    field = model(nimage)
                    B = nagent_pos.shape[0]
                    
                    # Process batch fields differentiably
                    path = compute_differentiable_path(
                        field,
                        num_path_points=action_horizon
                    ).to(device)
                    
                    # Normalize path if needed
                    normalized_path = normalize_data_torch(path,dataset.stats["agent_pos"])
                    
                    # Visualization (occasional)
                    if batch_idx % 50 == 0:
                        try:
                            plot_agent_positions(naction, normalized_path)
                            viz(field.clone().detach().cpu()[random.randint(0, min(63, B-1))])
                        except:
                            pass  # Skip visualization if functions not available
                    
                    # Compute loss
                    loss = nn.functional.mse_loss(normalized_path, naction)
                    
                    # Backward pass and optimization
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    
                    # Logging
                    loss_cpu = loss.item()
                    epoch_loss.append(loss_cpu)
                    
                    batch_avg_loss = np.mean(epoch_loss)
                    tepoch.set_postfix({
                        'batch_loss': f'{loss_cpu:.6f}',
                        'batch_avg': f'{batch_avg_loss:.6f}',
                        'batch': f'{batch_idx+1}/{len(dataloader)}'
                    })
                    
                    # # Print detailed batch stats every 20 batches
                    if (batch_idx + 1) % 20 == 0:
                        print(f"\nBatch {batch_idx+1}/{len(dataloader)} | "
                              f"Loss: {loss_cpu:.6f} | "
                              f"Avg Loss: {batch_avg_loss:.6f} | "
                              f"Field Shape: {field.shape} | "
                              f"Path Shape: {path.shape}")
            
            # Epoch statistics
            epoch_avg_loss = np.mean(epoch_loss)
            epoch_losses.append(epoch_avg_loss)
            
            tglobal.set_postfix({
                'epoch_loss': f'{epoch_avg_loss:.6f}',
                'best_loss': f'{min(epoch_losses):.6f}',
                'batches': len(dataloader)
            })
    
    return epoch_losses

# Usage example:
epoch_losses = train_model(model, dataloader, device, obs_horizon, action_horizon, num_epochs=30)

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]


Batch 20/802 | Loss: 0.587764 | Avg Loss: 0.450931 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 40/802 | Loss: 0.650118 | Avg Loss: 0.530918 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 60/802 | Loss: 0.699989 | Avg Loss: 0.482667 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 80/802 | Loss: 0.716899 | Avg Loss: 0.522848 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 100/802 | Loss: 0.682860 | Avg Loss: 0.546233 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 120/802 | Loss: 0.646738 | Avg Loss: 0.567073 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 140/802 | Loss: 0.654084 | Avg Loss: 0.580565 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 160/802 | Loss: 0.492462 | Avg Loss: 0.587585 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 180/802 | Loss: 0.608223 | Avg Loss: 0.577737 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 200/802 | Loss: 0.546355 | Avg Loss: 0.569716 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 220/802 | Loss: 0.564522 | Avg Loss: 0.554888 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 240/802 | Loss: 0.620572 | Avg Loss: 0.544963 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 260/802 | Loss: 0.452841 | Avg Loss: 0.538729 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 280/802 | Loss: 0.340405 | Avg Loss: 0.535594 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 300/802 | Loss: 0.465101 | Avg Loss: 0.532429 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 320/802 | Loss: 0.458906 | Avg Loss: 0.529020 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 340/802 | Loss: 0.412039 | Avg Loss: 0.525332 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 360/802 | Loss: 0.397871 | Avg Loss: 0.523877 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 380/802 | Loss: 0.511898 | Avg Loss: 0.521894 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 400/802 | Loss: 0.432478 | Avg Loss: 0.521470 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


  plt.figure(figsize=(6, 6))
  plt.figure(figsize=(7, 6)) # Slightly wider for colorbar



Batch 420/802 | Loss: 0.512776 | Avg Loss: 0.519833 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 440/802 | Loss: 0.665649 | Avg Loss: 0.519677 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


  fig = plt.figure(figsize=figsize)




Batch 460/802 | Loss: 0.761151 | Avg Loss: 0.521704 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Batch:  57%|█████▋    | 461/802 [03:30<01:15,  4.49it/s, batch_loss=0.727817, batch_avg=0.522430, batch=462/802][A



Batch 480/802 | Loss: 0.805210 | Avg Loss: 0.529228 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Batch:  60%|█████▉    | 481/802 [03:32<00:32,  9.88it/s, batch_loss=0.710195, batch_avg=0.529862, batch=482/802][A


Batch 500/802 | Loss: 0.606881 | Avg Loss: 0.536901 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 520/802 | Loss: 0.810126 | Avg Loss: 0.543739 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])



Batch:  67%|██████▋   | 540/802 [03:38<00:25, 10.20it/s, batch_loss=0.735693, batch_avg=0.550899, batch=542/802]


Batch 540/802 | Loss: 0.782193 | Avg Loss: 0.550382 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


[A
Batch:  70%|███████   | 562/802 [03:40<00:24,  9.65it/s, batch_loss=0.761081, batch_avg=0.558003, batch=562/802]


Batch 560/802 | Loss: 0.633847 | Avg Loss: 0.557424 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


[A



Batch 580/802 | Loss: 0.674008 | Avg Loss: 0.563394 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Batch:  72%|███████▏  | 581/802 [03:42<00:22,  9.74it/s, batch_loss=0.746661, batch_avg=0.564044, batch=582/802][A


Batch 600/802 | Loss: 0.688997 | Avg Loss: 0.569468 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 620/802 | Loss: 0.763105 | Avg Loss: 0.574263 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])



Batch:  80%|███████▉  | 640/802 [03:48<00:16, 10.04it/s, batch_loss=0.718150, batch_avg=0.579076, batch=642/802]


Batch 640/802 | Loss: 0.680744 | Avg Loss: 0.578682 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


[A


Batch 660/802 | Loss: 0.858972 | Avg Loss: 0.583342 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 680/802 | Loss: 0.825478 | Avg Loss: 0.587625 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 700/802 | Loss: 0.512884 | Avg Loss: 0.591398 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 720/802 | Loss: 0.713524 | Avg Loss: 0.594950 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])






Batch 740/802 | Loss: 0.691054 | Avg Loss: 0.598170 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Batch:  92%|█████████▏| 740/802 [03:59<00:06, 10.05it/s, batch_loss=0.791067, batch_avg=0.598680, batch=742/802][A


Batch 760/802 | Loss: 0.746876 | Avg Loss: 0.601859 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 780/802 | Loss: 0.604236 | Avg Loss: 0.604471 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 800/802 | Loss: 0.671545 | Avg Loss: 0.608041 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Epoch:   3%|▎         | 1/30 [04:05<1:58:41, 245.58s/it, epoch_loss=0.608634, best_loss=0.608634, batches=802]


Batch 20/802 | Loss: 0.688348 | Avg Loss: 0.719648 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 40/802 | Loss: 0.712180 | Avg Loss: 0.738729 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])



Batch:   7%|▋         | 60/802 [00:06<01:21,  9.09it/s, batch_loss=0.774438, batch_avg=0.732756, batch=62/802]


Batch 60/802 | Loss: 0.728512 | Avg Loss: 0.732834 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


[A


Batch 80/802 | Loss: 0.802682 | Avg Loss: 0.725565 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 100/802 | Loss: 0.613004 | Avg Loss: 0.725399 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 120/802 | Loss: 0.632914 | Avg Loss: 0.723152 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 140/802 | Loss: 0.733487 | Avg Loss: 0.721150 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 160/802 | Loss: 0.663925 | Avg Loss: 0.723375 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 180/802 | Loss: 0.593841 | Avg Loss: 0.724300 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 200/802 | Loss: 0.696843 | Avg Loss: 0.726740 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 220/802 | Loss: 0.789181 | Avg Loss: 0.727431 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])






Batch 240/802 | Loss: 0.703347 | Avg Loss: 0.726351 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Batch:  30%|██▉       | 240/802 [00:25<00:53, 10.51it/s, batch_loss=0.705414, batch_avg=0.726052, batch=242/802][A


Batch 260/802 | Loss: 0.713876 | Avg Loss: 0.724064 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 280/802 | Loss: 0.724121 | Avg Loss: 0.724667 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 300/802 | Loss: 0.771433 | Avg Loss: 0.727118 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 320/802 | Loss: 0.620512 | Avg Loss: 0.728597 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 340/802 | Loss: 0.662825 | Avg Loss: 0.728648 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 360/802 | Loss: 0.760567 | Avg Loss: 0.728495 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 380/802 | Loss: 0.797051 | Avg Loss: 0.728345 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 400/802 | Loss: 0.691104 | Avg Loss: 0.728286 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 420/802 | Loss: 0.811678 | Avg Loss: 0.727854 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 440/802 | Loss: 0.707385 | Avg Loss: 0.728324 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 460/802 | Loss: 0.699053 | Avg Loss: 0.728193 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 480/802 | Loss: 0.731422 | Avg Loss: 0.728499 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 500/802 | Loss: 0.988216 | Avg Loss: 0.729905 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 520/802 | Loss: 1.257149 | Avg Loss: 0.745157 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 540/802 | Loss: 1.511546 | Avg Loss: 0.762625 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 560/802 | Loss: 1.406076 | Avg Loss: 0.783126 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 580/802 | Loss: 1.297586 | Avg Loss: 0.802197 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 600/802 | Loss: 1.433803 | Avg Loss: 0.820890 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 620/802 | Loss: 1.361646 | Avg Loss: 0.837612 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 640/802 | Loss: 1.341511 | Avg Loss: 0.853513 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 660/802 | Loss: 1.496309 | Avg Loss: 0.869742 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 680/802 | Loss: 1.365250 | Avg Loss: 0.884807 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 700/802 | Loss: 1.300582 | Avg Loss: 0.897791 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 720/802 | Loss: 1.448894 | Avg Loss: 0.911052 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 740/802 | Loss: 1.303125 | Avg Loss: 0.923148 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 760/802 | Loss: 1.350074 | Avg Loss: 0.934736 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 780/802 | Loss: 1.377925 | Avg Loss: 0.946252 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 800/802 | Loss: 1.251534 | Avg Loss: 0.957217 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Epoch:   7%|▋         | 2/30 [07:10<1:37:52, 209.72s/it, epoch_loss=0.958583, best_loss=0.608634, batches=802]


Batch 20/802 | Loss: 1.351059 | Avg Loss: 1.394061 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 40/802 | Loss: 1.376521 | Avg Loss: 1.379434 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 60/802 | Loss: 1.463892 | Avg Loss: 1.377431 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 80/802 | Loss: 1.364043 | Avg Loss: 1.375811 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 100/802 | Loss: 1.245652 | Avg Loss: 1.372285 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 120/802 | Loss: 1.227510 | Avg Loss: 1.366291 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 140/802 | Loss: 1.355831 | Avg Loss: 1.364961 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 160/802 | Loss: 1.379351 | Avg Loss: 1.362852 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 180/802 | Loss: 1.319665 | Avg Loss: 1.361503 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 200/802 | Loss: 1.469905 | Avg Loss: 1.361298 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 220/802 | Loss: 1.316605 | Avg Loss: 1.365249 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 240/802 | Loss: 1.524111 | Avg Loss: 1.364693 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 260/802 | Loss: 1.342606 | Avg Loss: 1.369682 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 280/802 | Loss: 1.217159 | Avg Loss: 1.367316 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 300/802 | Loss: 1.395518 | Avg Loss: 1.368702 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 320/802 | Loss: 1.173432 | Avg Loss: 1.368491 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 340/802 | Loss: 1.296190 | Avg Loss: 1.368382 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 360/802 | Loss: 1.384969 | Avg Loss: 1.367739 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 380/802 | Loss: 1.187173 | Avg Loss: 1.369129 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 400/802 | Loss: 1.324022 | Avg Loss: 1.369837 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 420/802 | Loss: 1.454073 | Avg Loss: 1.369538 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 440/802 | Loss: 1.442573 | Avg Loss: 1.371136 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 460/802 | Loss: 1.321055 | Avg Loss: 1.371027 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 480/802 | Loss: 1.365732 | Avg Loss: 1.370772 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 500/802 | Loss: 1.413503 | Avg Loss: 1.370005 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 520/802 | Loss: 1.400601 | Avg Loss: 1.370492 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 540/802 | Loss: 1.268209 | Avg Loss: 1.369855 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 560/802 | Loss: 1.412834 | Avg Loss: 1.368457 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 580/802 | Loss: 1.339345 | Avg Loss: 1.368462 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 600/802 | Loss: 1.191214 | Avg Loss: 1.366740 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 620/802 | Loss: 1.324149 | Avg Loss: 1.366176 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 640/802 | Loss: 1.335883 | Avg Loss: 1.366253 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 660/802 | Loss: 1.265672 | Avg Loss: 1.366618 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 680/802 | Loss: 1.361205 | Avg Loss: 1.366925 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 700/802 | Loss: 1.147187 | Avg Loss: 1.365872 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 720/802 | Loss: 1.468407 | Avg Loss: 1.367316 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 740/802 | Loss: 1.241564 | Avg Loss: 1.365751 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 760/802 | Loss: 1.338039 | Avg Loss: 1.365431 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 780/802 | Loss: 1.248630 | Avg Loss: 1.366031 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 800/802 | Loss: 1.500770 | Avg Loss: 1.365913 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Epoch:  10%|█         | 3/30 [12:44<2:00:03, 266.81s/it, epoch_loss=1.365933, best_loss=0.608634, batches=802]


Batch 20/802 | Loss: 1.332771 | Avg Loss: 1.358795 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 40/802 | Loss: 1.343263 | Avg Loss: 1.369401 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 60/802 | Loss: 1.342473 | Avg Loss: 1.364364 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 80/802 | Loss: 1.448154 | Avg Loss: 1.362390 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 100/802 | Loss: 1.301805 | Avg Loss: 1.368782 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 120/802 | Loss: 1.421135 | Avg Loss: 1.371070 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 140/802 | Loss: 1.394486 | Avg Loss: 1.372495 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 160/802 | Loss: 1.444103 | Avg Loss: 1.377917 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 180/802 | Loss: 1.383004 | Avg Loss: 1.373077 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 200/802 | Loss: 1.588835 | Avg Loss: 1.373686 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 220/802 | Loss: 1.291030 | Avg Loss: 1.372518 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 240/802 | Loss: 1.238286 | Avg Loss: 1.370572 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 260/802 | Loss: 1.460681 | Avg Loss: 1.370244 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 280/802 | Loss: 1.164383 | Avg Loss: 1.371220 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 300/802 | Loss: 1.485396 | Avg Loss: 1.369952 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 320/802 | Loss: 1.331254 | Avg Loss: 1.368029 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 340/802 | Loss: 1.523229 | Avg Loss: 1.367679 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 360/802 | Loss: 1.379559 | Avg Loss: 1.368663 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 380/802 | Loss: 1.312060 | Avg Loss: 1.369503 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 400/802 | Loss: 1.254816 | Avg Loss: 1.370184 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 420/802 | Loss: 1.252317 | Avg Loss: 1.369108 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 440/802 | Loss: 1.307106 | Avg Loss: 1.367972 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 460/802 | Loss: 1.348188 | Avg Loss: 1.367065 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 480/802 | Loss: 1.302052 | Avg Loss: 1.366346 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 500/802 | Loss: 1.396358 | Avg Loss: 1.366646 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 520/802 | Loss: 1.292418 | Avg Loss: 1.365839 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 540/802 | Loss: 1.242507 | Avg Loss: 1.366782 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 560/802 | Loss: 1.586451 | Avg Loss: 1.366296 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 580/802 | Loss: 1.235495 | Avg Loss: 1.368007 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 600/802 | Loss: 1.256133 | Avg Loss: 1.367599 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 620/802 | Loss: 1.330848 | Avg Loss: 1.368743 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 640/802 | Loss: 1.244790 | Avg Loss: 1.368824 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 660/802 | Loss: 1.532374 | Avg Loss: 1.369232 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 680/802 | Loss: 1.433589 | Avg Loss: 1.369796 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 700/802 | Loss: 1.419050 | Avg Loss: 1.369276 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 720/802 | Loss: 1.055248 | Avg Loss: 1.368284 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 740/802 | Loss: 1.270555 | Avg Loss: 1.367691 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 760/802 | Loss: 1.168746 | Avg Loss: 1.367168 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 780/802 | Loss: 1.260293 | Avg Loss: 1.366359 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 800/802 | Loss: 1.320768 | Avg Loss: 1.366333 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Epoch:  13%|█▎        | 4/30 [17:45<2:01:28, 280.31s/it, epoch_loss=1.365850, best_loss=0.608634, batches=802]


Batch 20/802 | Loss: 1.337064 | Avg Loss: 1.364879 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 40/802 | Loss: 1.366432 | Avg Loss: 1.358406 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 60/802 | Loss: 1.346766 | Avg Loss: 1.359180 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 80/802 | Loss: 1.379515 | Avg Loss: 1.359079 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 100/802 | Loss: 1.295108 | Avg Loss: 1.356078 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 120/802 | Loss: 1.268922 | Avg Loss: 1.360242 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 140/802 | Loss: 1.329926 | Avg Loss: 1.357042 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 160/802 | Loss: 1.324764 | Avg Loss: 1.353711 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 180/802 | Loss: 1.365802 | Avg Loss: 1.353600 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 200/802 | Loss: 1.317194 | Avg Loss: 1.358330 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 220/802 | Loss: 1.515398 | Avg Loss: 1.360650 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 240/802 | Loss: 1.339432 | Avg Loss: 1.361474 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 260/802 | Loss: 1.404346 | Avg Loss: 1.359981 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 280/802 | Loss: 1.251914 | Avg Loss: 1.360536 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 300/802 | Loss: 1.353393 | Avg Loss: 1.361157 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 320/802 | Loss: 1.603944 | Avg Loss: 1.358810 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 340/802 | Loss: 1.160027 | Avg Loss: 1.361773 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 360/802 | Loss: 1.354922 | Avg Loss: 1.360685 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 380/802 | Loss: 1.246618 | Avg Loss: 1.360631 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 400/802 | Loss: 1.339454 | Avg Loss: 1.361009 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 420/802 | Loss: 1.319709 | Avg Loss: 1.360815 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 440/802 | Loss: 1.511380 | Avg Loss: 1.362770 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 460/802 | Loss: 1.356385 | Avg Loss: 1.363451 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 480/802 | Loss: 1.363045 | Avg Loss: 1.363117 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 500/802 | Loss: 1.441435 | Avg Loss: 1.363786 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 520/802 | Loss: 1.399062 | Avg Loss: 1.364982 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 540/802 | Loss: 1.394416 | Avg Loss: 1.365569 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 560/802 | Loss: 1.379884 | Avg Loss: 1.366159 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 580/802 | Loss: 1.343356 | Avg Loss: 1.366535 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 600/802 | Loss: 1.381045 | Avg Loss: 1.368259 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 620/802 | Loss: 1.382764 | Avg Loss: 1.368455 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 640/802 | Loss: 1.470136 | Avg Loss: 1.367767 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 660/802 | Loss: 1.367772 | Avg Loss: 1.367263 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 680/802 | Loss: 1.453496 | Avg Loss: 1.366538 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 700/802 | Loss: 1.525326 | Avg Loss: 1.366204 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 720/802 | Loss: 1.369173 | Avg Loss: 1.365668 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 740/802 | Loss: 1.408435 | Avg Loss: 1.365660 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 760/802 | Loss: 1.324453 | Avg Loss: 1.365505 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 780/802 | Loss: 1.226526 | Avg Loss: 1.366213 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 800/802 | Loss: 1.613083 | Avg Loss: 1.365742 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Epoch:  17%|█▋        | 5/30 [22:47<1:59:59, 287.98s/it, epoch_loss=1.365935, best_loss=0.608634, batches=802]


Batch 20/802 | Loss: 1.251314 | Avg Loss: 1.372674 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 40/802 | Loss: 1.464332 | Avg Loss: 1.351812 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 60/802 | Loss: 1.341715 | Avg Loss: 1.353498 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 80/802 | Loss: 1.438183 | Avg Loss: 1.357893 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 100/802 | Loss: 1.365314 | Avg Loss: 1.358867 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 120/802 | Loss: 1.438013 | Avg Loss: 1.355461 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 140/802 | Loss: 1.460002 | Avg Loss: 1.356364 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 160/802 | Loss: 1.357270 | Avg Loss: 1.359092 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 180/802 | Loss: 1.371486 | Avg Loss: 1.356788 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 200/802 | Loss: 1.421995 | Avg Loss: 1.359717 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 220/802 | Loss: 1.628123 | Avg Loss: 1.359451 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 240/802 | Loss: 1.556743 | Avg Loss: 1.357989 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 260/802 | Loss: 1.461094 | Avg Loss: 1.360643 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 280/802 | Loss: 1.545888 | Avg Loss: 1.359459 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 300/802 | Loss: 1.475539 | Avg Loss: 1.358704 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 320/802 | Loss: 1.315338 | Avg Loss: 1.362272 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 340/802 | Loss: 1.303137 | Avg Loss: 1.364083 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 360/802 | Loss: 1.453633 | Avg Loss: 1.363417 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 380/802 | Loss: 1.306260 | Avg Loss: 1.361226 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 400/802 | Loss: 1.448723 | Avg Loss: 1.360783 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 420/802 | Loss: 1.380110 | Avg Loss: 1.361210 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 440/802 | Loss: 1.406619 | Avg Loss: 1.360626 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 460/802 | Loss: 1.480291 | Avg Loss: 1.360955 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 480/802 | Loss: 1.361089 | Avg Loss: 1.361029 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 500/802 | Loss: 1.469974 | Avg Loss: 1.361084 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 520/802 | Loss: 1.410595 | Avg Loss: 1.360897 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 540/802 | Loss: 1.655366 | Avg Loss: 1.361107 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 560/802 | Loss: 1.271186 | Avg Loss: 1.361766 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 580/802 | Loss: 1.327528 | Avg Loss: 1.361967 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 600/802 | Loss: 1.411964 | Avg Loss: 1.362554 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 620/802 | Loss: 1.070272 | Avg Loss: 1.362425 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 640/802 | Loss: 1.398132 | Avg Loss: 1.363422 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 660/802 | Loss: 1.555165 | Avg Loss: 1.363488 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 680/802 | Loss: 1.288216 | Avg Loss: 1.362097 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 700/802 | Loss: 1.346919 | Avg Loss: 1.362653 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 720/802 | Loss: 1.487878 | Avg Loss: 1.363387 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 740/802 | Loss: 1.420847 | Avg Loss: 1.364009 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 760/802 | Loss: 1.431632 | Avg Loss: 1.364653 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 780/802 | Loss: 1.232164 | Avg Loss: 1.364729 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 800/802 | Loss: 1.406024 | Avg Loss: 1.365686 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Epoch:  20%|██        | 6/30 [27:48<1:56:56, 292.37s/it, epoch_loss=1.365878, best_loss=0.608634, batches=802]


Batch 20/802 | Loss: 1.340834 | Avg Loss: 1.351384 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 40/802 | Loss: 1.570557 | Avg Loss: 1.341494 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 60/802 | Loss: 1.179552 | Avg Loss: 1.337352 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 80/802 | Loss: 1.256272 | Avg Loss: 1.360015 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 100/802 | Loss: 1.348654 | Avg Loss: 1.360328 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 120/802 | Loss: 1.438115 | Avg Loss: 1.357363 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 140/802 | Loss: 1.516220 | Avg Loss: 1.356406 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 160/802 | Loss: 1.458093 | Avg Loss: 1.359805 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 180/802 | Loss: 1.249001 | Avg Loss: 1.358716 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 200/802 | Loss: 1.449229 | Avg Loss: 1.362563 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 220/802 | Loss: 1.322163 | Avg Loss: 1.361728 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 240/802 | Loss: 1.418818 | Avg Loss: 1.360097 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 260/802 | Loss: 1.317649 | Avg Loss: 1.360869 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 280/802 | Loss: 1.216597 | Avg Loss: 1.359929 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 300/802 | Loss: 1.438201 | Avg Loss: 1.359748 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 320/802 | Loss: 1.550083 | Avg Loss: 1.359106 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 340/802 | Loss: 1.623917 | Avg Loss: 1.359415 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 360/802 | Loss: 1.318180 | Avg Loss: 1.358646 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 380/802 | Loss: 1.123769 | Avg Loss: 1.358981 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 400/802 | Loss: 1.227463 | Avg Loss: 1.358997 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 420/802 | Loss: 1.341895 | Avg Loss: 1.358571 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 440/802 | Loss: 1.552564 | Avg Loss: 1.358056 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 460/802 | Loss: 1.437574 | Avg Loss: 1.359045 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 480/802 | Loss: 1.208274 | Avg Loss: 1.358447 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 500/802 | Loss: 1.418380 | Avg Loss: 1.359035 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 520/802 | Loss: 1.534825 | Avg Loss: 1.360003 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 540/802 | Loss: 1.291035 | Avg Loss: 1.360430 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 560/802 | Loss: 1.538388 | Avg Loss: 1.361716 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 580/802 | Loss: 1.453310 | Avg Loss: 1.363156 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 600/802 | Loss: 1.359503 | Avg Loss: 1.362260 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 620/802 | Loss: 1.296522 | Avg Loss: 1.362950 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 640/802 | Loss: 1.531050 | Avg Loss: 1.363882 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 660/802 | Loss: 1.327564 | Avg Loss: 1.364362 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 680/802 | Loss: 1.474087 | Avg Loss: 1.364025 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 700/802 | Loss: 1.216528 | Avg Loss: 1.363890 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 720/802 | Loss: 1.407002 | Avg Loss: 1.365478 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 740/802 | Loss: 1.164376 | Avg Loss: 1.364697 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 760/802 | Loss: 1.149812 | Avg Loss: 1.365592 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 780/802 | Loss: 1.397279 | Avg Loss: 1.365321 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 800/802 | Loss: 1.454371 | Avg Loss: 1.365911 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Epoch:  23%|██▎       | 7/30 [32:49<1:53:10, 295.23s/it, epoch_loss=1.366007, best_loss=0.608634, batches=802]


Batch 20/802 | Loss: 1.333288 | Avg Loss: 1.382883 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 40/802 | Loss: 1.543653 | Avg Loss: 1.387049 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 60/802 | Loss: 1.385045 | Avg Loss: 1.382343 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 80/802 | Loss: 1.104728 | Avg Loss: 1.375985 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 100/802 | Loss: 1.250648 | Avg Loss: 1.374854 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 120/802 | Loss: 1.447311 | Avg Loss: 1.387750 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 140/802 | Loss: 1.395215 | Avg Loss: 1.385056 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 160/802 | Loss: 1.326295 | Avg Loss: 1.381723 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 180/802 | Loss: 1.323610 | Avg Loss: 1.381731 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 200/802 | Loss: 1.448988 | Avg Loss: 1.382830 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 220/802 | Loss: 1.472986 | Avg Loss: 1.379369 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 240/802 | Loss: 1.282513 | Avg Loss: 1.373881 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 260/802 | Loss: 1.282000 | Avg Loss: 1.370869 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 280/802 | Loss: 1.468919 | Avg Loss: 1.367017 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 300/802 | Loss: 1.268322 | Avg Loss: 1.367967 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 320/802 | Loss: 1.436458 | Avg Loss: 1.368777 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 340/802 | Loss: 1.169754 | Avg Loss: 1.367935 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 360/802 | Loss: 1.415570 | Avg Loss: 1.367530 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 380/802 | Loss: 1.429157 | Avg Loss: 1.367964 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 400/802 | Loss: 1.304953 | Avg Loss: 1.367462 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 420/802 | Loss: 1.477560 | Avg Loss: 1.368819 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 440/802 | Loss: 1.269282 | Avg Loss: 1.369565 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 460/802 | Loss: 1.401880 | Avg Loss: 1.369673 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 480/802 | Loss: 1.340683 | Avg Loss: 1.370104 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 500/802 | Loss: 1.279822 | Avg Loss: 1.368231 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 520/802 | Loss: 1.390968 | Avg Loss: 1.368929 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 540/802 | Loss: 1.211915 | Avg Loss: 1.367200 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 560/802 | Loss: 1.310326 | Avg Loss: 1.366923 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 580/802 | Loss: 1.356747 | Avg Loss: 1.366832 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 600/802 | Loss: 1.343652 | Avg Loss: 1.366599 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 620/802 | Loss: 1.269109 | Avg Loss: 1.366140 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 640/802 | Loss: 1.294410 | Avg Loss: 1.365436 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 660/802 | Loss: 1.315322 | Avg Loss: 1.365553 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 680/802 | Loss: 1.262615 | Avg Loss: 1.366662 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 700/802 | Loss: 1.345387 | Avg Loss: 1.366120 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 720/802 | Loss: 1.300455 | Avg Loss: 1.366516 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 740/802 | Loss: 1.476547 | Avg Loss: 1.366641 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 760/802 | Loss: 1.299059 | Avg Loss: 1.366218 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 780/802 | Loss: 1.516197 | Avg Loss: 1.366400 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 800/802 | Loss: 1.144650 | Avg Loss: 1.365995 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Epoch:  27%|██▋       | 8/30 [37:50<1:48:57, 297.15s/it, epoch_loss=1.365838, best_loss=0.608634, batches=802]


Batch 20/802 | Loss: 1.435043 | Avg Loss: 1.419699 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 40/802 | Loss: 1.448235 | Avg Loss: 1.389818 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 60/802 | Loss: 1.316180 | Avg Loss: 1.369126 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 80/802 | Loss: 1.240264 | Avg Loss: 1.369407 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 100/802 | Loss: 1.307702 | Avg Loss: 1.372993 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 120/802 | Loss: 1.304625 | Avg Loss: 1.363374 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 140/802 | Loss: 1.153451 | Avg Loss: 1.360258 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 160/802 | Loss: 1.492824 | Avg Loss: 1.360207 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 180/802 | Loss: 1.288094 | Avg Loss: 1.359599 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 200/802 | Loss: 1.600796 | Avg Loss: 1.361831 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 220/802 | Loss: 1.306265 | Avg Loss: 1.364566 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 240/802 | Loss: 1.518619 | Avg Loss: 1.365183 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 260/802 | Loss: 1.398551 | Avg Loss: 1.363756 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 280/802 | Loss: 1.180806 | Avg Loss: 1.363949 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 300/802 | Loss: 1.228848 | Avg Loss: 1.360933 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 320/802 | Loss: 1.248221 | Avg Loss: 1.361942 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 340/802 | Loss: 1.264889 | Avg Loss: 1.360398 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 360/802 | Loss: 1.257039 | Avg Loss: 1.361175 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 380/802 | Loss: 1.333596 | Avg Loss: 1.361618 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 400/802 | Loss: 1.339541 | Avg Loss: 1.362532 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 420/802 | Loss: 1.244629 | Avg Loss: 1.363016 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 440/802 | Loss: 1.459195 | Avg Loss: 1.362843 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 460/802 | Loss: 1.404738 | Avg Loss: 1.363533 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 480/802 | Loss: 1.296839 | Avg Loss: 1.363221 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 500/802 | Loss: 1.318635 | Avg Loss: 1.361971 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 520/802 | Loss: 1.495289 | Avg Loss: 1.363083 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 540/802 | Loss: 1.350607 | Avg Loss: 1.362833 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 560/802 | Loss: 1.232352 | Avg Loss: 1.363057 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 580/802 | Loss: 1.367901 | Avg Loss: 1.363202 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 600/802 | Loss: 1.348752 | Avg Loss: 1.363262 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 620/802 | Loss: 1.366966 | Avg Loss: 1.363296 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 640/802 | Loss: 1.239318 | Avg Loss: 1.363327 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 660/802 | Loss: 1.357534 | Avg Loss: 1.363644 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 680/802 | Loss: 1.484482 | Avg Loss: 1.364064 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 700/802 | Loss: 1.331629 | Avg Loss: 1.364113 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 720/802 | Loss: 1.446372 | Avg Loss: 1.364846 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 740/802 | Loss: 1.330610 | Avg Loss: 1.365186 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 760/802 | Loss: 1.352648 | Avg Loss: 1.365566 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 780/802 | Loss: 1.337925 | Avg Loss: 1.365702 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 800/802 | Loss: 1.161305 | Avg Loss: 1.365680 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Epoch:  30%|███       | 9/30 [42:52<1:44:32, 298.69s/it, epoch_loss=1.366057, best_loss=0.608634, batches=802]


Batch 20/802 | Loss: 1.316678 | Avg Loss: 1.352308 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 40/802 | Loss: 1.309725 | Avg Loss: 1.344780 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 60/802 | Loss: 1.338534 | Avg Loss: 1.352589 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 80/802 | Loss: 1.314954 | Avg Loss: 1.362756 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 100/802 | Loss: 1.537582 | Avg Loss: 1.375960 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 120/802 | Loss: 1.226813 | Avg Loss: 1.379160 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 140/802 | Loss: 1.451720 | Avg Loss: 1.383156 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 160/802 | Loss: 1.426197 | Avg Loss: 1.380827 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 180/802 | Loss: 1.293812 | Avg Loss: 1.380209 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 200/802 | Loss: 1.252684 | Avg Loss: 1.378074 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 220/802 | Loss: 1.489627 | Avg Loss: 1.378803 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 240/802 | Loss: 1.163507 | Avg Loss: 1.376292 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 260/802 | Loss: 1.363022 | Avg Loss: 1.374435 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 280/802 | Loss: 1.256696 | Avg Loss: 1.374287 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 300/802 | Loss: 1.340786 | Avg Loss: 1.375079 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 320/802 | Loss: 1.369541 | Avg Loss: 1.375085 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 340/802 | Loss: 1.591034 | Avg Loss: 1.374270 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 360/802 | Loss: 1.337727 | Avg Loss: 1.372244 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 380/802 | Loss: 1.453859 | Avg Loss: 1.372322 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 400/802 | Loss: 1.378533 | Avg Loss: 1.371146 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 420/802 | Loss: 1.225804 | Avg Loss: 1.370100 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 440/802 | Loss: 1.385014 | Avg Loss: 1.370997 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 460/802 | Loss: 1.537313 | Avg Loss: 1.371484 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 480/802 | Loss: 1.283535 | Avg Loss: 1.369442 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 500/802 | Loss: 1.517047 | Avg Loss: 1.370093 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 520/802 | Loss: 1.278048 | Avg Loss: 1.369166 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 540/802 | Loss: 1.477786 | Avg Loss: 1.367976 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 560/802 | Loss: 1.447562 | Avg Loss: 1.368004 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 580/802 | Loss: 1.656132 | Avg Loss: 1.366709 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 600/802 | Loss: 1.088738 | Avg Loss: 1.366998 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 620/802 | Loss: 1.288676 | Avg Loss: 1.366548 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 640/802 | Loss: 1.175946 | Avg Loss: 1.365641 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 660/802 | Loss: 1.163179 | Avg Loss: 1.365096 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 680/802 | Loss: 1.351546 | Avg Loss: 1.365582 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 700/802 | Loss: 1.494431 | Avg Loss: 1.365687 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 720/802 | Loss: 1.387066 | Avg Loss: 1.365475 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 740/802 | Loss: 1.459120 | Avg Loss: 1.365354 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 760/802 | Loss: 1.218796 | Avg Loss: 1.365848 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 780/802 | Loss: 1.229078 | Avg Loss: 1.365801 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 800/802 | Loss: 1.527328 | Avg Loss: 1.365822 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Epoch:  33%|███▎      | 10/30 [47:55<1:39:56, 299.82s/it, epoch_loss=1.365881, best_loss=0.608634, batches=802]


Batch 20/802 | Loss: 1.236020 | Avg Loss: 1.345294 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 40/802 | Loss: 1.311183 | Avg Loss: 1.339927 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 60/802 | Loss: 1.305727 | Avg Loss: 1.372951 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 80/802 | Loss: 1.485468 | Avg Loss: 1.372429 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 100/802 | Loss: 1.431764 | Avg Loss: 1.373048 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 120/802 | Loss: 1.301341 | Avg Loss: 1.373225 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 140/802 | Loss: 1.287422 | Avg Loss: 1.371505 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 160/802 | Loss: 1.274733 | Avg Loss: 1.366694 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 180/802 | Loss: 1.239717 | Avg Loss: 1.361522 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 200/802 | Loss: 1.314736 | Avg Loss: 1.365376 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 220/802 | Loss: 1.413767 | Avg Loss: 1.368499 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 240/802 | Loss: 1.269496 | Avg Loss: 1.366568 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 260/802 | Loss: 1.334257 | Avg Loss: 1.367650 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 280/802 | Loss: 1.282381 | Avg Loss: 1.368566 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 300/802 | Loss: 1.449608 | Avg Loss: 1.371028 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 320/802 | Loss: 1.291061 | Avg Loss: 1.370536 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 340/802 | Loss: 1.362169 | Avg Loss: 1.371098 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 360/802 | Loss: 1.356918 | Avg Loss: 1.372522 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 380/802 | Loss: 1.395250 | Avg Loss: 1.371995 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 400/802 | Loss: 1.301446 | Avg Loss: 1.370527 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 420/802 | Loss: 1.462229 | Avg Loss: 1.370032 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 440/802 | Loss: 1.422024 | Avg Loss: 1.370473 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 460/802 | Loss: 1.370413 | Avg Loss: 1.369477 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 480/802 | Loss: 1.407892 | Avg Loss: 1.371362 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 500/802 | Loss: 1.463485 | Avg Loss: 1.370977 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 520/802 | Loss: 1.321935 | Avg Loss: 1.370372 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 540/802 | Loss: 1.495728 | Avg Loss: 1.371961 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 560/802 | Loss: 1.418446 | Avg Loss: 1.371539 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 580/802 | Loss: 1.271354 | Avg Loss: 1.370225 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 600/802 | Loss: 1.345423 | Avg Loss: 1.371496 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 620/802 | Loss: 1.334440 | Avg Loss: 1.371258 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 640/802 | Loss: 1.002134 | Avg Loss: 1.369233 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 660/802 | Loss: 1.500548 | Avg Loss: 1.369629 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 680/802 | Loss: 1.425904 | Avg Loss: 1.368627 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 700/802 | Loss: 1.354166 | Avg Loss: 1.367348 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 720/802 | Loss: 1.556460 | Avg Loss: 1.367187 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 740/802 | Loss: 1.270162 | Avg Loss: 1.367690 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 760/802 | Loss: 1.351671 | Avg Loss: 1.366303 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 780/802 | Loss: 1.281934 | Avg Loss: 1.367460 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 800/802 | Loss: 1.547743 | Avg Loss: 1.366129 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Epoch:  37%|███▋      | 11/30 [53:02<1:35:37, 301.99s/it, epoch_loss=1.365833, best_loss=0.608634, batches=802]


Batch 20/802 | Loss: 1.349160 | Avg Loss: 1.360904 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 40/802 | Loss: 1.434670 | Avg Loss: 1.370623 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 60/802 | Loss: 1.374538 | Avg Loss: 1.363948 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 80/802 | Loss: 1.395782 | Avg Loss: 1.372347 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 100/802 | Loss: 1.469575 | Avg Loss: 1.364154 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 120/802 | Loss: 1.359512 | Avg Loss: 1.366611 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 140/802 | Loss: 1.137129 | Avg Loss: 1.365535 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 160/802 | Loss: 1.532499 | Avg Loss: 1.366720 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 180/802 | Loss: 1.578404 | Avg Loss: 1.368243 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 200/802 | Loss: 1.503469 | Avg Loss: 1.369162 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 220/802 | Loss: 1.268501 | Avg Loss: 1.368603 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 240/802 | Loss: 1.368101 | Avg Loss: 1.369192 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 260/802 | Loss: 1.408097 | Avg Loss: 1.368407 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 280/802 | Loss: 1.399958 | Avg Loss: 1.368181 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 300/802 | Loss: 1.322743 | Avg Loss: 1.366160 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 320/802 | Loss: 1.117352 | Avg Loss: 1.366570 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 340/802 | Loss: 1.409963 | Avg Loss: 1.365734 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 360/802 | Loss: 1.334734 | Avg Loss: 1.366951 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 380/802 | Loss: 1.150941 | Avg Loss: 1.367479 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 400/802 | Loss: 1.302334 | Avg Loss: 1.367872 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 420/802 | Loss: 1.595875 | Avg Loss: 1.369138 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 440/802 | Loss: 1.144783 | Avg Loss: 1.368681 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 460/802 | Loss: 1.364649 | Avg Loss: 1.368823 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 480/802 | Loss: 1.460590 | Avg Loss: 1.368452 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 500/802 | Loss: 1.439599 | Avg Loss: 1.368444 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 520/802 | Loss: 1.264805 | Avg Loss: 1.367102 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 540/802 | Loss: 1.458588 | Avg Loss: 1.368589 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 560/802 | Loss: 1.272224 | Avg Loss: 1.368561 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 580/802 | Loss: 1.332599 | Avg Loss: 1.366720 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 600/802 | Loss: 1.445250 | Avg Loss: 1.366486 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 620/802 | Loss: 1.503728 | Avg Loss: 1.365498 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 640/802 | Loss: 1.484413 | Avg Loss: 1.365405 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 660/802 | Loss: 1.465970 | Avg Loss: 1.366069 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 680/802 | Loss: 1.445010 | Avg Loss: 1.366880 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 700/802 | Loss: 1.365322 | Avg Loss: 1.366903 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 720/802 | Loss: 1.454026 | Avg Loss: 1.366716 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 740/802 | Loss: 1.355599 | Avg Loss: 1.366905 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 760/802 | Loss: 1.419766 | Avg Loss: 1.366371 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 780/802 | Loss: 1.206501 | Avg Loss: 1.365879 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 800/802 | Loss: 1.416720 | Avg Loss: 1.366193 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])


Epoch:  40%|████      | 12/30 [58:05<1:30:44, 302.47s/it, epoch_loss=1.365920, best_loss=0.608634, batches=802]


Batch 20/802 | Loss: 1.346591 | Avg Loss: 1.361709 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 40/802 | Loss: 1.412809 | Avg Loss: 1.362800 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 60/802 | Loss: 1.371987 | Avg Loss: 1.362607 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 80/802 | Loss: 1.452440 | Avg Loss: 1.361778 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 100/802 | Loss: 1.450931 | Avg Loss: 1.360499 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 120/802 | Loss: 1.405366 | Avg Loss: 1.359037 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 140/802 | Loss: 1.237107 | Avg Loss: 1.364655 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 160/802 | Loss: 1.398776 | Avg Loss: 1.368288 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 180/802 | Loss: 1.386072 | Avg Loss: 1.368524 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])





Batch 200/802 | Loss: 1.460121 | Avg Loss: 1.370697 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])




In [None]:
# #@markdown ### **Inference**

# # limit enviornment interaction to 200 steps before termination
# max_steps = 200
# env = PushTImageEnv()
# # use a seed >200 to avoid initial states seen in the training dataset
# Batch 20/802 | Loss: 0.639597 | Avg Loss: 0.438793 | Field Shape: torch.Size([32, 96, 96]) | Path Shape: torch.Size([32, 8, 2])
env.seed(100000)

# # get first observation
# obs, info = env.reset()

# # keep a queue of last 2 steps of observations
# obs_deque = collections.deque(
#     [obs] * obs_horizon, maxlen=obs_horizon)
# # save visualization and rewards
# imgs = [env.render(mode='rgb_array')]
# rewards = list()
# done = False
# step_idx = 0

# with tqdm(total=max_steps, desc="Eval PushTImageEnv") as pbar:
#     while not done:
#         B = 1
#         # stack the last obs_horizon number of observations
#         images = np.stack([x['image'] for x in obs_deque])
#         agent_poses = np.stack([x['agent_pos'] for x in obs_deque])

#         # normalize observation
#         nagent_poses = normalize_data(agent_poses, stats=stats['agent_pos'])
#         # images are already normalized to [0,1]
#         nimages = images

#         # device transfer
#         nimages = torch.from_numpy(nimages).to(device, dtype=torch.float32)
#         # (2,3,96,96)
#         nagent_poses = torch.from_numpy(nagent_poses).to(device, dtype=torch.float32)
#         # (2,2)

#         # infer action
#         with torch.no_grad():
#             # get image features
#             image_features = ema_nets['vision_encoder'](nimages)
#             # (2,512)

#             # concat with low-dim observations
#             obs_features = torch.cat([image_features, nagent_poses], dim=-1)

#             # reshape observation to (B,obs_horizon*obs_dim)
#             obs_cond = obs_features.unsqueeze(0).flatten(start_dim=1)

#             # initialize action from Guassian noise
#             noisy_action = torch.randn(
#                 (B, pred_horizon, action_dim), device=device)
#             naction = noisy_action

#             # init scheduler
#             noise_scheduler.set_timesteps(num_diffusion_iters)

#             for k in noise_scheduler.timesteps:
#                 # predict noise
#                 noise_pred = ema_nets['noise_pred_net'](
#                     sample=naction,
#                     timestep=k,
#                     global_cond=obs_cond
#                 )

#                 # inverse diffusion step (remove noise)
#                 naction = noise_scheduler.step(
#                     model_output=noise_pred,
#                     timestep=k,
#                     sample=naction
#                 ).prev_sample

#         # unnormalize action
#         naction = naction.detach().to('cpu').numpy()
#         # (B, pred_horizon, action_dim)
#         naction = naction[0]
#         action_pred = unnormalize_data(naction, stats=stats['action'])

#         # only take action_horizon number of actions
#         start = obs_horizon - 1
#         end = start + action_horizon
#         action = action_pred[start:end,:]
#         # (action_horizon, action_dim)

#         # execute action_horizon number of steps
#         # without replanning
#         for i in range(len(action)):
#             # stepping env
#             obs, reward, done, _, info = env.step(action[i])
#             # save observations
#             obs_deque.append(obs)
#             # and reward/vis
#             rewards.append(reward)
#             imgs.append(env.render(mode='rgb_array'))

#             # update progress bar
#             step_idx += 1
#             pbar.update(1)
#             pbar.set_postfix(reward=reward)
#             if step_idx > max_steps:
#                 done = True
#             if done:
#                 break

# # print out the maximum target coverage
# print('Score: ', max(rewards))

# # visualize
# from IPython.display import Video
# vwrite('vis.mp4', imgs)
# Video('vis.mp4', embed=True, width=256, height=256)