In [None]:
import random
import gym
from PIL import Image
import os
import shutil
from multiprocessing import Process, Queue

# Create Dataset for VAE and MDNRNN

## Definition of the function to be run concurrently by multiple processes

In [None]:
# Root directory in which the whole dataset is written.
PATH_DIR_DATASET = os.path.join(".", "DATASET_ROLLOUTS")

# Field separator of the csv files; every separator is followed by one space.
CSV_SEP = ";"

# Header line of a per-rollout csv file (one column per recorded field).
COLUMNS_NAME = f"{CSV_SEP} ".join(
    [
        "Step number",
        "Path image observation",
        "Action",
        "Path image next observation",
        "Reward",
        "Terminated",
        "Die",
        "Win",
    ]
) + "\n"

In [None]:
def proc_function(q_inp, q_res, PATH_DIR_DATASET, COLUMNS_NAME):
    """Worker loop: play random rollouts of the game and store them on disk.

    The worker repeatedly takes a rollout index from ``q_inp``, plays one
    episode with random actions, saves the observed frames as PNG images and
    describes every step in a per-rollout csv file. After each rollout a
    summary tuple is sent back through ``q_res``. The loop stops as soon as
    ``None`` is received from ``q_inp``.

    Args:
        q_inp: Queue for the communication Main Process -> Child Process.
            Each item is the index of a rollout to execute, or ``None`` to
            make the worker terminate.
        q_res: Queue for the communication Child Process -> Main Process.
            Receives ``(rollout, PATH_DIR_ROLLOUT, PATH_CSV_ROLLOUT, step, win)``
            once per executed rollout.
        PATH_DIR_DATASET: path of the (already existing) dataset directory in
            which the ``rollout_<index>`` directories are created.
        COLUMNS_NAME: header line of the resulting csv file.

    Notes:
        The csv field separator is read from the module-level ``CSV_SEP``.
    """
    rollout = q_inp.get()
    while rollout is not None:
        # create rollout's directory, it will contain all the frames and the csv of the rollout
        PATH_DIR_ROLLOUT = os.path.join(PATH_DIR_DATASET, f"rollout_{rollout}")
        os.mkdir(PATH_DIR_ROLLOUT)

        # create directory of observations (i.e., the image frames)
        PATH_DIR_OBSERVATIONS = os.path.join(PATH_DIR_ROLLOUT, "observations")
        os.mkdir(PATH_DIR_OBSERVATIONS)

        # csv lines are collected in a list and written in one go at the end
        PATH_CSV_ROLLOUT = os.path.join(PATH_DIR_ROLLOUT, f"rollout_{rollout}.csv")
        csv_lines = [COLUMNS_NAME]

        # create gym environment
        env = gym.make("procgen:procgen-leaper-v0", start_level=0, num_levels=0, render_mode="rgb_array")
        try:
            obs = env.reset()
            step = 0

            # save first image observation
            PATH_IMG_OBSERVATION = os.path.join(PATH_DIR_OBSERVATIONS, f"step_{step}.png")
            Image.fromarray(obs).save(PATH_IMG_OBSERVATION)
            step += 1

            # continue rollout until the environment reports termination
            terminated = 0
            while not terminated:
                die = 0
                win = 0
                # draw a random action in [0, 4]; the environment receives it
                # shifted by 2, while the csv stores the un-shifted value
                action = random.randint(0, 4)
                obs, reward, terminated, info = env.step(action + 2)

                if not terminated:
                    # save the obtained frame as next observation of the previous frame
                    PATH_IMG_NEXT_OBSERVATION = os.path.join(PATH_DIR_OBSERVATIONS, f"step_{step}.png")
                    Image.fromarray(obs).save(PATH_IMG_NEXT_OBSERVATION)
                else:
                    # no new frame is saved on termination: the last saved frame is
                    # reused as next observation, this is useful for the training of the MDNRNN
                    PATH_IMG_NEXT_OBSERVATION = os.path.join(PATH_DIR_OBSERVATIONS, f"step_{step-1}.png")

                    # a final reward below 10 is recorded as a loss, otherwise as a win
                    if reward < 10:
                        die = 1
                    else:
                        win = 1

                # record step's infos (f-string formatting is kept on purpose so
                # that the textual form of the values is unchanged)
                csv_lines.append(f"{step}{CSV_SEP} "
                                 f"{PATH_IMG_OBSERVATION}{CSV_SEP} "
                                 f"{action}{CSV_SEP} "
                                 f"{PATH_IMG_NEXT_OBSERVATION}{CSV_SEP} "
                                 f"{reward}{CSV_SEP} "
                                 f"{terminated}{CSV_SEP} "
                                 f"{die}{CSV_SEP} "
                                 f"{win}\n")

                # the obtained frame becomes the current observation of the next step
                PATH_IMG_OBSERVATION = PATH_IMG_NEXT_OBSERVATION
                step += 1
        finally:
            # release the environment even if the rollout fails half-way
            env.close()

        # write rollout's infos into csv; the file is closed even if the write fails
        with open(PATH_CSV_ROLLOUT, "w") as f_csv:
            f_csv.write("".join(csv_lines))

        # return infos to main process
        # NOTE: step was incremented once more after the last recorded row, so the
        # value sent here is the number of csv rows plus one
        q_res.put((rollout, PATH_DIR_ROLLOUT, PATH_CSV_ROLLOUT, step, win))

        # wait for a next rollout or for a termination
        rollout = q_inp.get()

In [None]:
def init_processes(num_proc, q_inp, q_res, PATH_DIR_DATASET, COLUMNS_NAME):
    """Create and start the worker processes.

    Args:
        num_proc: number of processes to launch.
        q_inp: queue handed to every worker as its input queue.
        q_res: queue handed to every worker as its result queue.
        PATH_DIR_DATASET: dataset directory forwarded to every worker.
        COLUMNS_NAME: csv header forwarded to every worker.

    Returns:
        The list of started ``Process`` objects, in launch order.
    """
    worker_args = (q_inp, q_res, PATH_DIR_DATASET, COLUMNS_NAME)
    workers = []
    for _ in range(num_proc):
        # each worker runs proc_function and is started right after creation
        worker = Process(target=proc_function, args=worker_args)
        workers.append(worker)
        worker.start()
    return workers

## Create the dataset

In [None]:
# how many rollouts the dataset is made of
NUM_ROLLOUTS = 10_000
# how many worker processes execute rollouts concurrently
num_proc = 11

In [None]:
# reset the dataset: drop a previous version if there is one, then start from an empty directory
if os.path.exists(PATH_DIR_DATASET):
    shutil.rmtree(PATH_DIR_DATASET)
os.mkdir(PATH_DIR_DATASET)

# path of the main csv, one line per rollout
PATH_CSV_ROLLOUTS = os.path.join(PATH_DIR_DATASET, "rollouts.csv")

# init queues
q_inp = Queue()
q_res = Queue()

# start processes
processes = init_processes(num_proc, q_inp, q_res, PATH_DIR_DATASET, COLUMNS_NAME)

# start concurrent executions of the rollouts
for i in range(NUM_ROLLOUTS):
    q_inp.put(i)

# queue one None per worker behind the rollouts: a worker stops when it reads None,
# without these sentinels every worker would block forever on the empty queue
for _ in processes:
    q_inp.put(None)

# retrieve the results of the rollouts and save them to the main csv;
# the file is closed even if collecting the results fails
with open(PATH_CSV_ROLLOUTS, "w") as f_csv_rollouts:
    f_csv_rollouts.write(f"Rollout number{CSV_SEP} "
                         f"Path directory{CSV_SEP} "
                         f"Path csv{CSV_SEP} "
                         f"Number steps{CSV_SEP} "
                         "Win\n")

    for i in range(NUM_ROLLOUTS):
        rollout, PATH_DIR_ROLLOUT, PATH_CSV_ROLLOUT, step, win = q_res.get()

        # write into csv the results of the obtained rollout
        f_csv_rollouts.write(f"{rollout}{CSV_SEP} "
                             f"{PATH_DIR_ROLLOUT}{CSV_SEP} "
                             f"{PATH_CSV_ROLLOUT}{CSV_SEP} "
                             f"{step}{CSV_SEP} "
                             f"{win}\n")

        if (i+1)%1000 == 0:
            print(f"Number of rollouts: {i+1}")

# all results have been read from q_res, so the workers can be joined safely
for process in processes:
    process.join()