In [1]:
import torch
from models.gen.edm import EDM
from models.gen.blocks import UNet
import tqdm
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
from data.data import SequencesDataset
import torchvision.transforms as transforms
import os
import numpy as np
import matplotlib.pyplot as plt
import random
import tqdm

In [2]:
# Notebook-wide settings consumed by the model, dataset and render cells.
input_channels = 3
context_length = 4
actions_count = 5
batch_size = 1
num_workers = 2
FPS = 1

# Compute backend: Apple MPS (Mac OS) takes precedence when available,
# otherwise CUDA, otherwise plain CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# All project files are addressed relative to this directory.
ROOT_PATH = "../"


def local_path(path):
    """Return ``path`` joined onto ``ROOT_PATH``."""
    return os.path.join(ROOT_PATH, path)


MODEL_PATH = local_path("models/model.pth")

In [3]:
# First UNet argument: input_channels * (context_length + 1).
# Presumably the frame being denoised stacked with the context frames along
# the channel axis -- TODO confirm against models.gen.blocks.UNet.
unet_in_channels = input_channels * (context_length + 1)

edm = EDM(
    p_mean=-1.2,
    p_std=1.2,
    sigma_data=0.5,
    model=UNet(unet_in_channels, 3, None, actions_count, context_length),
    context_length=context_length,
    device=device,
)

# NOTE(review): torch.load unpickles the checkpoint, so only load files from a
# trusted source; consider passing weights_only=True once it is confirmed that
# the checkpoint holds nothing but tensors and plain Python containers.
# The call stays the last expression so the cell still displays the
# load_state_dict result.
edm.load_state_dict(torch.load(MODEL_PATH, map_location=device)["model"])

  edm.load_state_dict(torch.load(MODEL_PATH, map_location=device)["model"])


<All keys matched successfully>

In [4]:
# Image preprocessing: ToTensor followed by a per-channel normalisation with
# mean 0.5 and std 0.5 (get_np_img applies the inverse scaling for display).
transform_to_tensor = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)),
    ]
)

# Recorded snapshots and their actions, served in sequences of
# `context_length`.
dataset = SequencesDataset(
    images_dir=local_path("training_data/snapshots"),
    actions_path=local_path("training_data/actions"),
    seq_length=context_length,
    transform=transform_to_tensor,
)

In [10]:
from IPython.display import display, clear_output, Image as iImage
import ipywidgets as widgets
from PIL import Image
import time
import torch
import numpy as np
import io
import random

class State:
    """Mutable bookkeeping shared by the widget callbacks and the render loop."""

    def __init__(self):
        # Most recently selected action id; reset() leaves it unchanged.
        self.action = 0
        self.reset()

    def reset(self):
        """Go back to idle: not running, frame counter at zero, buffers cleared."""
        self.is_running = False
        self.frame_number = 0
        self.gen_imgs = None
        self.actions = None


# Single shared instance used by every callback below.
state = State()

def on_button_click(input_action):
    """Store ``input_action`` as the current action on the shared ``state``."""
    state.action = input_action

# Action id -> label printed above each rendered frame.
directions = {
    0: "Right",
    1: "Left",
    2: "Up",
    3: "Down",
}

# One button per direction, plus the start/stop controls.
left_button = widgets.Button(description='Left')
right_button = widgets.Button(description='Right')
up_button = widgets.Button(description='Up')
down_button = widgets.Button(description='Down')
start_button = widgets.Button(description='Start')
stop_button = widgets.Button(description='Stop')

# Each direction button records its action id via on_button_click. The id is
# bound as a default argument so every callback keeps its own value.
for _button, _action in (
    (right_button, 0),
    (left_button, 1),
    (up_button, 2),
    (down_button, 3),
):
    _button.on_click(lambda b, _action=_action: on_button_click(_action))

# Horizontal row of controls; Up/Down are stacked vertically in the middle.
buttons = widgets.HBox(
    [
        left_button,
        widgets.VBox([up_button, down_button]),
        right_button,
        start_button,
        stop_button,
    ]
)

# Separate output areas: one for the controls, one for the rendered frames.
button_output = widgets.Output()
image_output = widgets.Output()

with button_output:
    display(buttons)

def get_np_img(tensor: torch.Tensor) -> np.ndarray:
    """Turn a 3-D image tensor into a displayable ``uint8`` NumPy array.

    Values are mapped with ``x * 127.5 + 127.5``, cast to integers, clipped to
    ``[0, 255]`` and axis 0 is moved to the end (``permute(1, 2, 0)``).
    """
    scaled = tensor * 127.5 + 127.5
    pixels = scaled.long().clip(0, 255)
    axis0_last = pixels.permute(1, 2, 0)
    return axis0_last.detach().cpu().numpy().astype(np.uint8)

def _seed_context():
    """Load a random dataset sample into ``state`` as the starting frames/actions."""
    index = random.randint(0, len(dataset) - 1)
    _, last_imgs, actions = dataset[index]
    state.gen_imgs = last_imgs.clone().to(device)
    state.actions = actions.to(device)


def _show_frame(gen_img, max_frames):
    """Render ``gen_img`` as a 360x360 PNG in ``image_output`` with a status header."""
    display_img = get_np_img(gen_img)
    buffer = io.BytesIO()
    Image.fromarray(display_img).resize((360, 360), Image.Resampling.LANCZOS).save(buffer, format='PNG')

    with image_output:
        clear_output(wait=True)
        print(f'Direction: {directions[state.action]}')
        print(f'Frame: {state.frame_number + 1}/{max_frames}')
        display(iImage(data=buffer.getvalue()))


def render_frame(max_frames=80, sampling_steps=10):
    """Generate and display frames until stopped or ``max_frames`` is reached.

    Each iteration picks a random direction, appends it to ``state.actions``,
    samples the next frame from ``edm`` conditioned on the last
    ``context_length`` frames and actions, appends it to ``state.gen_imgs``
    and shows it. Iterations are paced to ``frame_time`` seconds.

    The loop is iterative rather than recursive, so earlier frames' locals are
    not kept alive on the call stack, and it sleeps between frames instead of
    spinning the CPU.

    Args:
        max_frames: Number of frames after which ``stop_rendering()`` is called.
        sampling_steps: Forwarded as the first positional argument of
            ``edm.sample`` (presumably the number of sampling steps -- TODO
            confirm against ``EDM.sample``).

    NOTE(review): this runs synchronously inside the Start button callback, so
    the Stop click is presumably not processed until the loop returns --
    confirm; running it off the main thread would be needed for a live Stop.
    NOTE(review): the random choice below overwrites whatever action the
    direction buttons stored in ``state.action``.
    """
    while state.is_running:
        if state.frame_number >= max_frames:
            stop_rendering()
            return

        start_time = time.monotonic()

        # Initialize on first frame
        if state.frame_number == 0:
            _seed_context()

        # Automatically change direction for each frame
        state.action = random.choice(list(directions.keys()))
        state.actions = torch.concat((state.actions, torch.tensor([state.action], device=device)))

        with torch.no_grad():
            gen_img = edm.sample(
                sampling_steps,
                state.gen_imgs[0].shape,
                state.gen_imgs[-context_length:].unsqueeze(0),
                state.actions[-context_length:].unsqueeze(0)
            )[0]

        # Append the new frame so it becomes context for the next iteration.
        state.gen_imgs = torch.concat([state.gen_imgs, gen_img[None, :, :, :]], dim=0)

        _show_frame(gen_img, max_frames)
        state.frame_number += 1

        # Maintain frame rate: sleep for whatever is left of this frame's budget.
        delay = max(0, frame_time - (time.monotonic() - start_time))
        if state.is_running and delay > 0:
            time.sleep(delay)

def start_rendering(b):
    """Start button callback: begin a fresh render run unless one is active.

    ``b`` is the clicked button passed by the widget and is not used.
    """
    if not state.is_running:
        state.reset()
        state.is_running = True
        render_frame()

def stop_rendering(b=None):
    """Stop button callback (also called by render_frame): reset and announce.

    ``b`` is the clicked button when invoked by the widget; it is not used.
    """
    # Clearing is_running via reset() is what ends the render loop.
    state.reset()
    with image_output:
        # Replace whatever frame is currently shown with the status message.
        clear_output(wait=True)
        print('Stopped rendering')

# Frame pacing: FPS frames per second gives frame_time seconds per frame.
# render_frame reads frame_time when it is called.
FPS = 0.2  # Adjust this to match your desired frame rate
frame_time = 1 / FPS

# Wire up the start/stop controls.
start_button.on_click(start_rendering)
stop_button.on_click(stop_rendering)

# Show the controls first, then the area the frames are drawn into.
display(button_output, image_output)

Output()

Output()