## config.py

In [None]:
from dataclasses import dataclass

@dataclass
class Config:
    """Global hyper-parameters and run settings, read as class attributes.

    NOTE(review): none of the attributes below carry type annotations, so
    ``@dataclass`` generates no fields for them; they are plain class
    attributes and the rest of the code reads them as ``Config.<name>``
    without instantiating the class.
    """
    # --- Run / checkpoint identification ---
    model_name = "Version_one"  # also the sub-directory of img_dir that samples are saved into
    load_model = False  # when True, main() calls trainer.load_checkpoint() before training
    loaded_model = "v2"  # presumably read by Trainer.load_checkpoint — TODO confirm
    loaded_checkpoint = "abc"  # presumably read by Trainer.load_checkpoint — TODO confirm
    

    # --- Sampling / inference ---
    sample_every_x_batches = 48 # chosen so it is not divisible by latent_persistence_turns
    inference_samples = 300  # number of Euler steps in sample_next_img
    inference_step_size = 1  # total integration span; each step uses dx = inference_step_size / inference_samples
    img_dir = 'images'  # root directory for saved sample PNGs

    # --- Architecture ---
    model_resolution = (384, 512)  # (height, width) for screenshot resizing and the preview canvas
    features = [64, 128, 256, 512, 1024]  # encoder channel widths; each stage halves the spatial size
    # Channel progression through the UNet (channels/width):
    #64 -> 128 -> 256 -> 512 -> 1024/16 ->| 2048/16 |  -> 1024/32 append 2048/32 -> 1024/32 -> 512/64 -> 256/128 -> 128/256 -> 64/512
    # Spatial size after each stride-2 stage, width | height:
    #256 -> 128 -> 64 -> 32 -> 16 |  192 -> 96 -> 48 -> 24 -> 12
    
    # --- Training schedule ---
    latent_persistence_turns = 5  # presumably how many steps a latent is carried before reset — TODO confirm in train.py
    predictions_per_image = 1

    # --- Conditioning sizes ---
    time_embedding_dim = 512 # Must be even: UNet splits it into sin and cos halves
    movement_embedding_dim = 512  # output size of UNet.embed_movement
    latent_dimension = 1024  # size of the Dynamics latent state
    max_batches = 1000
    rotation_probability = 0.6  # chance a simulator move is a drag-rotation rather than a click
    # One browser tab is opened per URL below.
    initial_pages = [ 
        r"https://www.google.com/maps/@-38.5922817,176.8199381,3a,75y,271.3h,104.26t/data=!3m7!1e1!3m5!1s1Pt-bx9x0-vdMyd-R05xZw!2e0!6shttps:%2F%2Fstreetviewpixels-pa.googleapis.com%2Fv1%2Fthumbnail%3Fcb_client%3Dmaps_sv.tactile%26w%3D900%26h%3D600%26pitch%3D-14.260361780245958%26panoid%3D1Pt-bx9x0-vdMyd-R05xZw%26yaw%3D271.30228279891134!7i13312!8i6656?entry=ttu&g_ep=EgoyMDI1MDkwNy4wIKXMDSoASAFQAw%3D%3D",
        #r"https://www.google.com/maps/place/Pukaki+Canal+Road,+Canterbury+Region+7999/@-44.3213197,170.0447441,3a,75y,195.92h,96.45t/data=!3m7!1e1!3m5!1sI9B-vhHlUJpLWYoqiLarOw!2e0!6shttps:%2F%2Fstreetviewpixels-pa.googleapis.com%2Fv1%2Fthumbnail%3Fcb_client%3Dmaps_sv.tactile%26w%3D900%26h%3D600%26pitch%3D-6.447370168190574%26panoid%3DI9B-vhHlUJpLWYoqiLarOw%26yaw%3D195.9163922970516!7i16384!8i8192!4m6!3m5!1s0x6d2ae1d22564538b:0xac6cf5a0c84835c2!8m2!3d-44.2098275!4d170.0758422!16s%2Fg%2F11h6yd4mrn?entry=ttu&g_ep=EgoyMDI1MDkwNy4wIKXMDSoASAFQAw%3D%3D",
       # r"https://www.google.com/maps/@-45.1338186,168.7592437,3a,75y,211.88h,90t/data=!3m7!1e1!3m5!1slN5kBn2x9rltM8rmRvc6iw!2e0!6shttps:%2F%2Fstreetviewpixels-pa.googleapis.com%2Fv1%2Fthumbnail%3Fcb_client%3Dmaps_sv.tactile%26w%3D900%26h%3D600%26pitch%3D0%26panoid%3DlN5kBn2x9rltM8rmRvc6iw%26yaw%3D211.87823!7i13312!8i6656?entry=ttu&g_ep=EgoyMDI1MDkwNy4wIKXMDSoASAFQAw%3D%3D"
    ] # make one of these huia

    # --- Optimiser (presumably consumed by the Trainer — TODO confirm) ---
    learning_rate = 4e-5
    weight_decay = 0.01

    # --- Logging / checkpointing (consumers not visible here — TODO confirm units) ---
    graph_update_freq = 1
    recent_losses_shown = 1500
    loss_bucket_size = 10
    save_freq = 100





## grapher.py

In [None]:
import matplotlib.pyplot as plt

class Grapher:
    """Live two-panel loss plot (both panels log-scaled) using matplotlib's interactive mode.

    Each panel holds four lines: Training, Validation, Reconstruction Val and
    Divergence Val. ``update_line`` replaces one line's data wholesale and
    redraws the figure.
    """

    # Maps each accepted ``line_type`` to the attribute prefix of that line's
    # stored data (``<prefix>_x_N`` / ``<prefix>_y_N``) and Line2D
    # (``<prefix>_line_N``). Both spellings of the reconstruction/divergence
    # names are accepted on BOTH panels: previously panel 0 only understood
    # "... loss" and panel 1 only "... Val", silently dropping the other.
    _LINE_PREFIXES = {
        "Training": "train",
        "Validation": "val",
        "Reconstruction loss": "recon_val",
        "Reconstruction Val": "recon_val",
        "Divergence loss": "div_val",
        "Divergence Val": "div_val",
    }

    def __init__(self):
        plt.ion()
        self.fig, (self.ax1, self.ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # Data storage for graph 1
        self.train_x_1 = []
        self.train_y_1 = []
        self.val_x_1 = []
        self.val_y_1 = []
        self.recon_val_x_1 = []
        self.recon_val_y_1 = []
        self.div_val_x_1 = []
        self.div_val_y_1 = []

        # Data storage for graph 2
        self.train_x_2 = []
        self.train_y_2 = []
        self.val_x_2 = []
        self.val_y_2 = []
        self.recon_val_x_2 = []
        self.recon_val_y_2 = []
        self.div_val_x_2 = []
        self.div_val_y_2 = []

        # Line objects for graph 1
        self.train_line_1, = self.ax1.plot([], [], 'r-', label='Training', linewidth=1)
        self.val_line_1, = self.ax1.plot([], [], 'b-', label='Validation')
        self.recon_val_line_1, = self.ax1.plot([], [], 'g-', label='Reconstruction Val', linewidth=1)
        self.div_val_line_1, = self.ax1.plot([], [], 'm-', label='Divergence Val', linewidth=1)

        # Line objects for graph 2
        self.train_line_2, = self.ax2.plot([], [], 'r-', label='Training', linewidth=1)
        self.val_line_2, = self.ax2.plot([], [], 'b-', label='Validation')
        self.recon_val_line_2, = self.ax2.plot([], [], 'g-', label='Reconstruction Val', linewidth=1)
        self.div_val_line_2, = self.ax2.plot([], [], 'm-', label='Divergence Val', linewidth=1)

        self.ax1.set_yscale("log")
        self.ax2.set_yscale("log")

        self.ax1.legend()
        self.ax2.legend()

    def update_line(self, graph_index, line_type, x_data, y_data):
        """Replace the data of one line, rescale its panel and redraw.

        Args:
            graph_index: 0 selects the left panel; any other value selects the right one.
            line_type: "Training", "Validation", "Reconstruction loss" or
                "Reconstruction Val", "Divergence loss" or "Divergence Val".
                An unrecognised name updates no line, but the panel is still
                rescaled and redrawn.
            x_data, y_data: the full x / y sequences for the line (stored by reference).
        """
        if graph_index == 0:
            suffix, axis = "1", self.ax1
        else:
            suffix, axis = "2", self.ax2

        prefix = self._LINE_PREFIXES.get(line_type)
        if prefix is not None:
            setattr(self, f"{prefix}_x_{suffix}", x_data)
            setattr(self, f"{prefix}_y_{suffix}", y_data)
            getattr(self, f"{prefix}_line_{suffix}").set_data(x_data, y_data)

        axis.relim()
        axis.autoscale_view()

        plt.draw()
        plt.pause(0.001)

## inference.py

In [None]:
from config import Config
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
import tqdm

fig, ax, im = None, None, None  # live-preview figure, created lazily by the first sample_next_img call


def _chw_to_hwc(tensor):
    """Return a (C, H, W) tensor, optionally with a leading size-1 dim, as an (H, W, C) numpy array."""
    return np.transpose(tensor.squeeze(0).detach().cpu().numpy(), (1, 2, 0))


@torch.no_grad()
def sample_next_img(model, device, sample_name, prev_img, movement, latent, next_img=None):
    """Generate the next frame by Euler-integrating the model's predicted deltas.

    Starting from ``prev_img``, runs ``Config.inference_samples`` steps. Each
    step adds ``model.predict_delta(img, t, movement, latent) * dx`` (with
    ``dx = Config.inference_step_size / Config.inference_samples``) and clamps
    the result to [0, 1]. Every intermediate frame is drawn in a live
    matplotlib window. The final frame, the starting frame and (if given)
    ``next_img`` are saved as PNGs under ``Config.img_dir/Config.model_name``.

    Args:
        model: an ``nn.Module`` exposing ``predict_delta``; its train/eval mode
            is restored on exit.
        device: device the per-step timestep tensor is moved to.
        sample_name: suffix used in the saved file names.
        prev_img: starting image, assumed (1, 3, H, W) in [0, 1] — TODO confirm.
            It is not modified.
        movement, latent: conditioning passed straight to ``predict_delta``.
        next_img: optional ground-truth next frame, saved as ``*_target.png``.
    """
    global fig, ax, im
    if fig is None:
        plt.ion()
        fig, ax = plt.subplots()
        im = ax.imshow(np.zeros((Config.model_resolution[0], Config.model_resolution[1], 3), dtype=np.float32))
        ax.axis("off")
        plt.show()

    num_steps = Config.inference_samples
    # Guarded so a zero step count does not divide by zero (the loop is then skipped).
    dx = Config.inference_step_size / num_steps if num_steps else 0.0

    was_training = model.training
    model.eval()
    try:
        # Out-of-place update: the original `prev_img += ...` wrote the first
        # step into the caller's tensor.
        img = prev_img
        for step in tqdm.tqdm(range(num_steps), desc='Sampling'):
            time_step = torch.tensor([step]).to(device)
            delta = model.predict_delta(img, time_step, movement, latent)
            img = torch.clamp(img + delta * dx, 0, 1)

            im.set_data(_chw_to_hwc(img))
            fig.canvas.draw()
            fig.canvas.flush_events()
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    model_dir = os.path.join(Config.img_dir, Config.model_name)
    os.makedirs(model_dir, exist_ok=True)

    plt.imsave(os.path.join(model_dir, f"sample_{sample_name}.png"), np.clip(_chw_to_hwc(img), 0, 1))
    if next_img is not None:
        plt.imsave(os.path.join(model_dir, f"sample_{sample_name}_target.png"),
                   np.clip(_chw_to_hwc(next_img), 0, 1))
    # prev_img is untouched, so it is still the starting frame.
    plt.imsave(os.path.join(model_dir, f"sample_{sample_name}_base.png"), np.clip(_chw_to_hwc(prev_img), 0, 1))



## model.py

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from config import Config


class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers plus a 1x1-conv shortcut.

    When ``is_conditioned`` is True, a linear layer maps the conditioning
    vector (movement + latent + time embeddings) to a per-channel scale and
    shift, applied after the second BatchNorm as ``h * scale + shift`` before
    the residual sum.
    """

    def __init__(self, in_channels, out_channels, is_conditioned=True):
        super().__init__()
        # Submodules are created in this exact order so parameter registration
        # order (state_dict / optimizer state) and init RNG draws are stable.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)

        self.is_conditioned = is_conditioned
        if is_conditioned:
            cond_dim = Config.movement_embedding_dim + Config.latent_dimension + Config.time_embedding_dim
            self.condition = nn.Linear(cond_dim, out_channels * 2)
        self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)

    def forward(self, x, conditioning=None):
        # TODO: try without BatchNorm.
        h = self.conv1(x)
        h = self.relu(self.bn1(h))
        h = self.bn2(self.conv2(h))

        if self.is_conditioned:
            scale, shift = torch.chunk(self.condition(conditioning), 2, dim=1)
            # TODO: try a layer/group norm before modulating (closer to adaGN).
            h = h * scale[..., None, None] + shift[..., None, None]

        return self.relu(h + self.shortcut(x))

class ConvBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with per-channel scale/shift conditioning (no shortcut).

    The conditioning vector (movement + latent + time embeddings) is mapped by
    a linear layer to ``gamma`` and ``beta``, applied after the second
    BatchNorm and followed by a LeakyReLU.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Creation order preserved for stable parameter order and initialisation.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)

        cond_dim = Config.movement_embedding_dim + Config.latent_dimension + Config.time_embedding_dim
        self.condition = nn.Linear(cond_dim, out_channels * 2)

    def forward(self, x, conditioning):
        # TODO: try without BatchNorm.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        film = self.condition(conditioning)
        gamma, beta = torch.chunk(film, 2, dim=1)
        # TODO: try a layer/group norm before modulating (closer to adaGN).
        out = out * gamma[..., None, None] + beta[..., None, None]
        return self.relu(out)

class Dynamics(nn.Module):
    """Latent-state update driven by the current image.

    The image is encoded by unconditioned ResBlocks, each followed by a
    stride-2 conv, then global-average-pooled and projected to
    ``Config.latent_dimension``. The projection is concatenated with the
    previous state and linearly mapped to the new state.
    """

    def __init__(self, in_channels=3):
        # TODO: also give the dynamics the current timestep (integer, not the 0-1 float).
        super().__init__()
        widths = Config.features
        encoder = []
        channels = in_channels
        for width in widths:
            # List literal evaluates left to right: ResBlock is built before the Conv2d.
            encoder += [
                ResBlock(channels, width, is_conditioned=False),
                nn.Conv2d(width, width, kernel_size=3, stride=2, padding=1),
            ]
            channels = width
        encoder.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.down = nn.Sequential(*encoder)

        self.project_img = nn.Linear(widths[-1], Config.latent_dimension)
        self.predict = nn.Linear(2 * Config.latent_dimension, Config.latent_dimension)

    def forward(self, img, prev_state):
        pooled = self.down(img)  # (B, C, 1, 1) after the adaptive pool
        img_latent = self.project_img(torch.flatten(pooled, start_dim=1))
        return self.predict(torch.cat([img_latent, prev_state], dim=1))


class UNet(nn.Module):
    """Conditioned UNet predicting a per-pixel delta for its input image.

    Encoder: a conditioned ResBlock per entry of ``Config.features``, each
    followed by a stride-2 conv. Bottleneck: a conditioned ResBlock doubling
    the channels. Decoder: per level, a stride-2 transposed conv, concatenation
    with the matching skip connection, and a conditioned ConvBlock. All blocks
    are conditioned on ``[movement_emb, latent, time_emb]``.

    NOTE: the decoder's ``ConvTranspose2d(feature * 2, feature)`` layers assume
    each entry of ``Config.features`` is double the previous one.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        features = Config.features
        self.downs = nn.ModuleList()
        self.strides = nn.ModuleList()
        for feature in features:
            self.downs.append(ResBlock(in_channels, feature))
            self.strides.append(nn.Conv2d(feature, feature, kernel_size=3, stride=2, padding=1))
            in_channels = feature

        self.bottleneck = ResBlock(features[-1], features[-1] * 2)

        self.ups = nn.ModuleList()
        self.up_convs = nn.ModuleList()
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.up_convs.append(ConvBlock(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

        # 6 = length of the simulator's movement vector [qw, qx, qy, qz, dx, dy].
        self.embed_movement = nn.Linear(6, Config.movement_embedding_dim)

    @staticmethod
    def _time_embedding(t, dim):
        """Return the sinusoidal embedding of timesteps ``t`` as a (num_timesteps, dim) float32 tensor.

        ``t`` may be a scalar, (B,) or (B, 1) tensor of integer or float steps.
        The first ``dim // 2`` columns are sines, the rest cosines, with
        frequencies ``10000 ** (-i / (dim // 2))``. ``dim`` must be even.
        """
        t = t.reshape(-1, 1).float()
        half = dim // 2
        exponents = torch.arange(half, dtype=torch.float32, device=t.device)
        freqs = torch.exp(-math.log(10000) * exponents / half).unsqueeze(0)
        angles = t * freqs
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x, t, m, l_emb):
        """Predict the per-pixel delta for ``x``.

        Args:
            x: input images, assumed (B, 3, H, W) — TODO confirm.
            t: timestep(s): one per batch item, or a single value shared by the batch.
            m: movement vectors of shape (B, 6).
            l_emb: latent state of shape (B, Config.latent_dimension).
        Returns:
            Tensor of shape (B, out_channels, H, W): delta predictions per channel.
        """
        batch = x.shape[0]
        t_emb = self._time_embedding(t, Config.time_embedding_dim)
        if t_emb.shape[0] == 1 and batch > 1:
            # A single timestep (e.g. torch.tensor([step]) during sampling) applies to the whole batch.
            t_emb = t_emb.expand(batch, -1)

        m_emb = self.embed_movement(m)
        conditioning_emb = torch.cat([m_emb, l_emb, t_emb], dim=1)

        skip_connections = []
        for down, stride in zip(self.downs, self.strides):
            x = down(x, conditioning_emb)
            skip_connections.append(x)
            x = stride(x)
            # TODO: maybe apply conditioning here.

        x = self.bottleneck(x, conditioning_emb)  # TODO: add transformer layers here.

        for up, up_conv, skip in zip(self.ups, self.up_convs, reversed(skip_connections)):
            x = up(x)
            if x.shape != skip.shape:
                # With an odd spatial size the stride-2 down/up pair overshoots by one pixel.
                x = F.interpolate(x, size=skip.shape[2:])
            x = torch.cat([skip, x], dim=1)
            x = up_conv(x, conditioning_emb)

        return self.final_conv(x)  # delta predictions in each channel.

class WorldModel(nn.Module):
    """Pairs the image-delta UNet (``backbone``) with the latent ``dynamics`` model."""

    def __init__(self):
        super().__init__()
        # Backbone is built first, as before, so parameter order and init RNG are unchanged.
        self.backbone = UNet()
        self.dynamics = Dynamics()

    def predict_delta(self, x, t, m, l_emb):
        """Return the backbone's delta for image ``x`` at step ``t``, given movement ``m`` and latent ``l_emb``."""
        delta_net = self.backbone
        return delta_net(x, t, m, l_emb)

    def predict_dynamics(self, img, prev_state):
        """Return the next latent state from the current image and the previous state."""
        step_fn = self.dynamics
        return step_fn(img, prev_state)


def get_model():
    """Build and return a freshly initialised WorldModel."""
    model = WorldModel()
    return model

## simulator.py

In [None]:

import asyncio
from playwright.async_api import async_playwright
from PIL import Image
import io
from torchvision import transforms
from config import Config
import torch
import random
import math
import numpy as np

def convert_to_vector(x, y, w, h):
    """Map a point inside a ``w`` x ``h`` box to a unit vector on a virtual trackball.

    Coordinates are normalised to [-1, 1] (y flipped so up is positive).
    Points inside the unit circle are lifted onto the hemisphere with
    ``z = sqrt(1 - r^2)``; points outside are pulled back onto the circle's
    rim with ``z = 0``.
    """
    px = (2.0 * x - w) / w
    py = (h - 2.0 * y) / h
    r2 = px * px + py * py

    if r2 > 1.0:
        inv_len = 1.0 / np.sqrt(r2)
        return np.array([px * inv_len, py * inv_len, 0.0])

    return np.array([px, py, np.sqrt(1.0 - r2)])

class StreetViewTab:
    """One browser tab showing Street View: captures frames and performs random moves.

    ``move`` returns a 6-element action ``[qw, qx, qy, qz, dx, dy]`` of Python
    floats/ints. A drag is encoded as a unit rotation quaternion with zero
    click offset. A click is the identity quaternion plus the click position
    relative to the canvas centre, as fractions of the canvas size in
    [-0.5, 0.5].
    """

    # Identity action: no rotation, no click offset.
    _NO_MOVE = (1.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    # Upper bound on random probes for a clickable canvas pixel before giving up.
    _MAX_SAMPLE_ATTEMPTS = 1000

    def __init__(self, id, page):
        self.id = id
        self.page = page

    async def take_screenshot(self):
        """Screenshot the page and return it as a [0, 1] float tensor resized to Config.model_resolution."""
        screenshot_bytes = await self.page.screenshot(timeout=180_000)
        image = Image.open(io.BytesIO(screenshot_bytes)).convert('RGB')
        transform = transforms.Compose([
            transforms.Resize(Config.model_resolution),
            transforms.ToTensor(),             # Convert to tensor
            # transforms.Normalize(              # Normalize tensor
            #     mean=[0.485, 0.456, 0.406],
            #     std=[0.229, 0.224, 0.225]
            # )
        ])

        tensor = transform(image)
        return tensor

    @staticmethod
    def _rotation_quaternion(v1, v2):
        """Return the unit quaternion ``[w, x, y, z]`` (Python floats) rotating unit vector v1 onto v2.

        When v1 and v2 are parallel or anti-parallel, ``cross(v1, v2)`` is zero
        and normalising it would yield NaNs, so those cases are handled
        explicitly. Identical points give the identity rotation. For trackball
        vectors (z >= 0), opposite points can only lie on the rim (z == 0), so
        a half-turn about z maps v1 onto v2.
        """
        axis = np.cross(v1, v2)
        axis_norm = np.linalg.norm(axis)
        if axis_norm < 1e-9:
            if np.dot(v1, v2) > 0:
                return [1.0, 0.0, 0.0, 0.0]
            return [0.0, 0.0, 0.0, 1.0]

        axis = axis / axis_norm
        theta = np.arccos(np.clip(np.dot(v1, v2), -1.0, 1.0))
        s = np.sin(theta / 2.0)
        return [float(np.cos(theta / 2.0)), *(float(c) for c in axis * s)]

    async def move(self):
        """Perform one random action on the Street View canvas and return it as a 6-element list.

        With probability ``Config.rotation_probability`` it drags between two
        random canvas points, returning ``[qw, qx, qy, qz, 0, 0]``. Otherwise it
        clicks one random canvas point, returning ``[1, 0, 0, 0, rx - 0.5, ry - 0.5]``.
        It then waits 4 s for the view to update. If the canvas has no bounding
        box, no action is taken and the identity action is returned.

        Raises:
            RuntimeError: if no point on the canvas can be hit after
                ``_MAX_SAMPLE_ATTEMPTS`` random probes.
        """
        await self.page.wait_for_selector('canvas.aFsglc', timeout=180_000)
        element = self.page.locator('canvas.aFsglc').first
        box = await element.bounding_box()
        if not box:
            # Previously returned None, which broke torch.tensor() on the batch of actions.
            return list(self._NO_MOVE)

        async def sample_point():
            """Sample random points in ``box`` until one is actually on canvas.aFsglc."""
            for _ in range(self._MAX_SAMPLE_ATTEMPTS):
                rx = random.random()
                ry = random.random()
                x = box['x'] + rx * box['width']
                y = box['y'] + ry * box['height']
                hit = await self.page.evaluate(
                    """([x, y]) => document.elementFromPoint(x, y)?.className || null""",
                    [x, y]
                )
                if hit and "aFsglc" in hit:
                    return x, y, rx, ry
            raise RuntimeError(
                f"Tab {self.id}: no point on canvas.aFsglc found after "
                f"{self._MAX_SAMPLE_ATTEMPTS} attempts"
            )

        if random.random() < Config.rotation_probability:
            x1, y1, _, _ = await sample_point()
            x2, y2, _, _ = await sample_point()

            await self.page.mouse.move(x1, y1)
            await self.page.mouse.down(button="left")
            await self.page.mouse.move(x2, y2, steps=15)
            await self.page.mouse.up(button="left")
            await asyncio.sleep(4)

            # Trackball vectors are taken relative to the canvas box.
            v1 = convert_to_vector(x1 - box['x'], y1 - box['y'], box['width'], box['height'])
            v2 = convert_to_vector(x2 - box['x'], y2 - box['y'], box['width'], box['height'])
            return [*self._rotation_quaternion(v1, v2), 0.0, 0.0]

        x, y, rx, ry = await sample_point()
        await self.page.mouse.move(x, y)
        await self.page.mouse.click(x, y)
        await asyncio.sleep(4)
        return [1, 0, 0, 0, rx - 0.5, ry - 0.5]

class Simulator:
    """Drives one Playwright Chromium tab per URL in ``Config.initial_pages``."""

    def __init__(self):
        self.initial_pages = Config.initial_pages
        self.browser = None
        self.context = None
        self.playwright = None
        self.tabs = []

    async def setup(self):
        """Start Playwright, open a visible Chromium window and load each initial page in its own tab."""
        self.playwright = await async_playwright().start()
        self.browser = await self.playwright.chromium.launch(headless=False)
        self.context = await self.browser.new_context()

        pages = await asyncio.gather(*(self.context.new_page() for _ in self.initial_pages))
        self.tabs = [StreetViewTab(i, page) for i, page in enumerate(pages)]
        await asyncio.gather(*(tab.page.goto(self.initial_pages[i]) for i, tab in enumerate(self.tabs)))
        print("All tabs loaded")

    async def close(self):
        """Close whatever ``setup`` managed to open.

        This is safe to call after a failed or partial setup, and safe to call
        repeatedly.
        """
        if self.context is not None:
            await self.context.close()
            self.context = None
        if self.browser is not None:
            await self.browser.close()
            self.browser = None
        if self.playwright is not None:
            await self.playwright.stop()
            self.playwright = None

    async def get_images(self):
        """Screenshot every tab concurrently and stack the results into [num_tabs, 3, H, W]."""
        images_list = await asyncio.gather(*(tab.take_screenshot() for tab in self.tabs))
        return torch.stack(images_list)

    async def move(self):
        """Perform one random move in every tab concurrently and return the actions as [num_tabs, 6] float32."""
        move_list = await asyncio.gather(*(tab.move() for tab in self.tabs))
        # Fix the dtype: rotation actions may contain numpy float64 scalars, which
        # would otherwise make torch infer a float64 tensor on some steps only.
        return torch.tensor(move_list, dtype=torch.float32)





## main.py

In [None]:
from config import Config
from model import WorldModel
from train import Trainer
from simulator import Simulator
import asyncio

async def main():
    """Build the model and browser simulator, optionally restore a checkpoint, then train.

    Objects are constructed before the ``try`` so that a failure during
    construction propagates as-is, instead of being masked by a NameError
    on ``simulator`` in the ``finally``. Once constructed, the simulator is
    always closed, even if setup or training fails.
    """
    model = WorldModel()
    simulator = Simulator()
    try:
        await simulator.setup()
        trainer = Trainer(model, simulator)
        if Config.load_model:
            trainer.load_checkpoint()
        await trainer.train()
    finally:
        await simulator.close()


if __name__ == '__main__':
    asyncio.run(main())