In [None]:
import torch
import numpy as np
import gymnasium as gym

import torch
import torch.nn as nn
from torch.optim import Adam
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
from PIL import Image
from pyvirtualdisplay import Display

# os.environ["DISPLAY"] = "99"
rand_seed=123
device = torch.device(f"cuda:{2}" if torch.cuda.is_available() else "cpu")

In [2]:
disp = Display(visible=0, size=(480, 480))
disp.start()

<pyvirtualdisplay.display.Display at 0x7f47c6b95550>

In [3]:
env_name = "Reacher-v5"
env = gym.make(env_name, render_mode="rgb_array")

num_prev = 3
data_length = 1000000
obs, _ = env.reset(seed=rand_seed)

dataset = []  # List of (obs, action, image)
next_imgs = []
obs_arr = []
acts_arr = []

fail_counter = 0
for obs_counter in tqdm(range(data_length)):
    # if obs_counter + 3 >= data_length:
    #     break
    if len(next_imgs) == 0:
        rendered_img = env.render()
        # rendered_img = rendered_img[:, 50:-50, :] # Acrobot
        rendered_img = rendered_img[100:-50, 100:-100, :]
        img_resize = Image.fromarray(rendered_img).resize((64, 64)) #.resize((cropped.shape[1] // 8 * 8, cropped.shape[0] // 8 * 8))
        next_imgs.append(np.asarray(img_resize))
        obs_arr.append(obs.copy())

    act = env.action_space.sample()
    next_obs, rew, done, _, _ = env.step(act)
    acts_arr.append(act)
    
    if done:
        fail_counter += 1
        obs, _ = env.reset()
        next_imgs = []
        obs_arr = []
        acts_arr = []
    else:
        if len(next_imgs) == num_prev + 1:
            obs_arr.pop(0)
            next_imgs.pop(0)
            acts_arr.pop(0)
        next_img = env.render()
        # next_img = next_img[:, 50:-50, :] # [y dir, x dir, color]
        next_img = next_img[100:-50, 100:-100, :] # [y dir, x dir, color]
        next_img = Image.fromarray(next_img).resize((64, 64)) #.resize((cropped.shape[1] // 8 * 8, cropped.shape[0] // 8 * 8)).resize((cropped.shape[1] // 8 * 8, cropped.shape[0] // 8 * 8))
        next_imgs.append(np.asarray(next_img))
        obs_arr.append(next_obs.copy())
        
        if len(next_imgs) == num_prev + 1:
            """
            fig, ax = plt.subplots(1, 3, figsize=(8, 10))
            ax[0].set_title("Current Image")
            ax[1].set_title("Current Image 2")
            ax[2].set_title("Future Image")
            for i, img in enumerate(next_imgs):
                ax[i].imshow(img)
                ax[i].set_xlabel(np.array2string(obs_arr[i], formatter={'float_kind': lambda x: "%.2f" % x}))
            plt.show()
            """
            # Store only valid transitions
            dataset.append({
                'obs': np.asarray(obs_arr).copy(),
                'action': np.asarray(acts_arr).copy(),
                'image': np.asarray(next_imgs[:-1]).copy(),
                'next_image': next_imgs[-1].copy()
            })
        obs = next_obs
    
    if (obs_counter+1) % 25000 == 0 or (obs_counter+1) == data_length:
        print("Saving at", obs_counter+1)
        all_obs = np.array([d['obs'] for d in dataset])
        all_obs = all_obs.reshape(all_obs.shape[0], -1)
        all_acts = np.array([d['action'] for d in dataset])
        all_acts = all_acts.reshape(-1, all_acts.shape[-1])
        if not os.path.exists("../data/" + env_name.lower() + "/actions.txt"):
            np.savetxt("../data/" + env_name.lower() + "/obs.txt", all_obs)
            np.savetxt("../data/" + env_name.lower() + "/actions.txt", all_acts)
        else:
            prev_obs = np.loadtxt("../data/" + env_name.lower() + "/obs.txt")
            prev_acts = np.loadtxt("../data/" + env_name.lower() + "/actions.txt")
            all_obs = np.concatenate((prev_obs, all_obs), axis=0)
            all_acts = np.concatenate((prev_acts, all_acts), axis=0)
            np.savetxt("../data/" + env_name.lower() + "/obs.txt", all_obs)
            np.savetxt("../data/" + env_name.lower() + "/actions.txt", all_acts)

        # make sure to delete old images from directory before saving new ones
        for i, d in enumerate(dataset):
            index = i + (obs_counter+1) - 25000
            for j, img in enumerate(d['image']):
                Image.fromarray(img).save(f"../data/{env_name.lower()}/{index:05d}_{j:05d}_curr.png")
            Image.fromarray(d['next_image']).save(f"../data/{env_name.lower()}/{index:05d}_next.png")
        
        next_imgs = []
        obs_arr = []
        acts_arr = []
        dataset = []

print(f"Failures (resets): {fail_counter}")

  2%|▏         | 24992/1000000 [01:58<1:17:43, 209.07it/s]

Saving at 25000


  5%|▍         | 49980/1000000 [04:16<1:13:13, 216.23it/s]

Saving at 50000


  7%|▋         | 74998/1000000 [06:34<1:11:57, 214.22it/s]

Saving at 75000


 10%|▉         | 99986/1000000 [08:52<1:09:06, 217.06it/s]

Saving at 100000


 12%|█▏        | 124979/1000000 [11:11<1:10:54, 205.67it/s]

Saving at 125000


 15%|█▍        | 149982/1000000 [13:30<1:07:43, 209.17it/s]

Saving at 150000


 17%|█▋        | 174983/1000000 [15:52<1:07:43, 203.05it/s]

Saving at 175000


 20%|█▉        | 199979/1000000 [18:15<1:03:31, 209.92it/s]

Saving at 200000


 22%|██▏       | 224998/1000000 [20:39<59:48, 215.95it/s]  

Saving at 225000


 25%|██▍       | 249991/1000000 [23:05<1:06:53, 186.89it/s] 

Saving at 250000


 27%|██▋       | 274992/1000000 [25:32<1:02:00, 194.89it/s] 

Saving at 275000


 30%|██▉       | 299992/1000000 [27:58<55:07, 211.65it/s]  

Saving at 300000


 32%|███▏      | 324983/1000000 [30:24<53:27, 210.46it/s]  

Saving at 325000


 35%|███▍      | 349998/1000000 [32:53<51:52, 208.84it/s]  

Saving at 350000


 37%|███▋      | 374991/1000000 [35:23<51:43, 201.37it/s]  

Saving at 375000


 40%|███▉      | 399988/1000000 [37:51<46:25, 215.37it/s]  

Saving at 400000


 42%|████▏     | 424988/1000000 [40:22<47:45, 200.69it/s]  

Saving at 425000


 45%|████▍     | 449998/1000000 [43:18<51:03, 179.52it/s]  

Saving at 450000


 47%|████▋     | 474989/1000000 [46:19<49:46, 175.80it/s]   

Saving at 475000


 50%|████▉     | 499991/1000000 [49:19<47:46, 174.46it/s]   

Saving at 500000


 52%|█████▏    | 524997/1000000 [52:21<46:57, 168.57it/s]   

Saving at 525000


 55%|█████▍    | 549988/1000000 [55:23<40:46, 183.96it/s]   

Saving at 550000


 57%|█████▋    | 574992/1000000 [58:27<39:22, 179.91it/s]  

Saving at 575000


 60%|█████▉    | 599981/1000000 [1:01:31<37:23, 178.34it/s]

Saving at 600000


 62%|██████▏   | 624986/1000000 [1:04:08<29:36, 211.06it/s]  

Saving at 625000


 65%|██████▍   | 649984/1000000 [1:06:42<27:11, 214.47it/s]  

Saving at 650000


 67%|██████▋   | 674990/1000000 [1:09:19<25:23, 213.35it/s]  

Saving at 675000


 70%|██████▉   | 699987/1000000 [1:11:55<23:24, 213.56it/s]  

Saving at 700000


 72%|███████▏  | 724981/1000000 [1:14:32<21:07, 217.01it/s]  

Saving at 725000


 75%|███████▍  | 749993/1000000 [1:17:11<19:38, 212.13it/s]  

Saving at 750000


 77%|███████▋  | 774984/1000000 [1:19:52<17:31, 213.92it/s]  

Saving at 775000


 80%|███████▉  | 799983/1000000 [1:22:33<15:55, 209.27it/s]  

Saving at 800000


 82%|████████▏ | 824980/1000000 [1:25:13<13:53, 209.92it/s]  

Saving at 825000


 85%|████████▍ | 849991/1000000 [1:28:19<13:33, 184.36it/s]  

Saving at 850000


 87%|████████▋ | 874990/1000000 [1:31:31<11:20, 183.73it/s]  

Saving at 875000


 90%|████████▉ | 899998/1000000 [1:34:43<09:32, 174.80it/s]  

Saving at 900000


 92%|█████████▏| 924983/1000000 [1:37:57<07:09, 174.55it/s]  

Saving at 925000


 95%|█████████▍| 949984/1000000 [1:41:12<04:31, 183.91it/s]  

Saving at 950000


 97%|█████████▋| 974989/1000000 [1:44:28<02:19, 178.82it/s]  

Saving at 975000


100%|█████████▉| 999998/1000000 [1:47:45<00:00, 177.98it/s] 

Saving at 1000000


100%|██████████| 1000000/1000000 [1:48:43<00:00, 153.29it/s]

Failures (resets): 0





In [3]:
all_obs = np.array([d['obs'] for d in dataset])
all_obs = all_obs.reshape(all_obs.shape[0], -1)
# all_obs.reshape(all_obs.shape[0], num_prev+1, env.observation_space.shape[0]).shape
np.savetxt("../data/" + env_name.lower() + "/obs.txt", all_obs)
all_acts = np.array([d['action'] for d in dataset])
all_acts = all_acts.reshape(-1, all_acts.shape[-1])
np.savetxt("../data/" + env_name.lower() + "/actions.txt", all_acts)

# make sure to delete old images from directory before saving new ones
for i, d in tqdm(enumerate(dataset), total=len(dataset)):
    for j, img in enumerate(d['image']):
        Image.fromarray(img).save(f"../data/{env_name.lower()}/{i:05d}_{j:05d}_curr.png")
    Image.fromarray(d['next_image']).save(f"../data/{env_name.lower()}/{i:05d}_next.png")

disp.stop()

100%|██████████| 99998/99998 [01:16<00:00, 1301.14it/s]


<pyvirtualdisplay.display.Display at 0x7f9321502490>

In [None]:
env_name = "HalfCheetah-v5"
env = gym.make(env_name, render_mode="rgb_array")

data_length = 20000
obs, _ = env.reset(seed=rand_seed)

dataset = []  # List of (obs, action, image)
next_imgs = []
obs_arr = []

num_prev = 4

fail_counter = 0
for obs_counter in tqdm(range(data_length)):
    # if obs_counter + 3 >= data_length:
    #     break
    if len(next_imgs) == 0:
        rendered_img = env.render()
        img_resize = Image.fromarray(rendered_img).resize((64, 64)) #.resize((cropped.shape[1] // 8 * 8, cropped.shape[0] // 8 * 8))
        next_imgs.append(np.asarray(img_resize))
        obs_arr.append(obs.copy())

    act = env.action_space.sample()
    next_obs, rew, done, _, _ = env.step(act)
    
    if done or obs_counter % 25 == 0:
        fail_counter += 1
        obs, _ = env.reset()
        next_imgs = []
        obs_arr = []
    else:
        if len(next_imgs) == num_prev + 1:
            obs_arr.pop(0)
            next_imgs.pop(0)
        next_img = env.render()
        next_img = Image.fromarray(next_img).resize((64, 64)) #.resize((cropped.shape[1] // 8 * 8, cropped.shape[0] // 8 * 8)).resize((cropped.shape[1] // 8 * 8, cropped.shape[0] // 8 * 8))
        next_imgs.append(np.asarray(next_img))
        obs_arr.append(next_obs.copy())
        
        if len(next_imgs) == num_prev + 1:
            """fig, ax = plt.subplots(1, 3, figsize=(8, 10))
            ax[0].set_title("Current Image")
            ax[1].set_title("Current Image 2")
            ax[2].set_title("Future Image")
            for i, img in enumerate(next_imgs):
                ax[i].imshow(img)
                ax[i].set_xlabel(np.array2string(obs_arr[i], formatter={'float_kind': lambda x: "%.2f" % x}))
            plt.show()"""
            # Store only valid transitions
            dataset.append({
                'obs': np.asarray(obs_arr).copy(),
                'action': act,
                'image': np.asarray(next_imgs[:-1]).copy(),
                'next_image': next_imgs[-1].copy()
            })
        obs = next_obs

print(f"Failures (resets): {fail_counter}")

 12%|█▏        | 2393/20000 [00:15<01:52, 156.70it/s]

In [None]:
all_obs = np.array([d['obs'] for d in dataset])
print(all_obs.shape)
all_obs = all_obs.reshape(all_obs.shape[0]-1, -1)
# all_obs.reshape(data_length-1, 3, env.observation_space.shape[0]).shape
np.savetxt("../data/" + env_name.lower() + "/obs.txt", all_obs)
all_acts = np.array([d['action'] for d in dataset])
np.savetxt("../data/" + env_name.lower() + "/actions.txt", all_acts)

# make sure to delete old images from directory before saving new ones
for i, d in tqdm(enumerate(dataset), total=len(dataset)):
    for j, img in enumerate(d['image']):
        Image.fromarray(img).save(f"../data/{env_name.lower()}/{i:05d}_{j:05d}_curr.png")
    Image.fromarray(d['next_image']).save(f"../data/{env_name.lower()}/{i:05d}_next.png")

(19997, 5, 17)


ValueError: cannot reshape array of size 1699745 into shape (19999,newaxis)