In [1]:
import pickle
import os

import torch
import torch.nn.functional as F
import numpy as np

from torch.utils.data import Dataset, DataLoader 
from tqdm import tqdm

from dataloader import *
import utils
import visualizer as visual

# Visualization
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns



In [2]:
def create_dataset(n_nodes: int, env_type: str, trajectory_length: int, num_desired_trajectories: int, args = None, seed = 70, dump=False, dir=None, fname = None):
    """Build a seeded GraphEnv and return train/test dataloaders drawn from it.

    The random seed is applied via ``utils.set_random_seed`` before the
    environment is constructed (with ``unique=True`` and ``device=None``).
    Two datasets are then produced by calling ``env.gen_dataset()`` twice,
    first for train and then for test.

    Args:
        n_nodes: Number of possible observations (passed as ``n_items``).
        env_type: Environment type forwarded as ``GraphEnv(env=...)``,
            e.g. ``'grid'`` or ``'tree'``.
        trajectory_length: Forwarded to ``GraphEnv`` as ``batch_size``.
        num_desired_trajectories: Forwarded to ``GraphEnv``.
        args: Environment-specific options forwarded to ``GraphEnv``,
            e.g. ``{"rows": 3, "cols": 3}`` for a grid.
        seed: Seed applied before the environment is built.
        dump: If True, pickle both datasets, a metadata dict and the env
            to ``os.path.join(dir, fname)``.
        dir: Output directory for the dump. ``None`` or ``""`` means the
            current working directory; missing directories are created.
        fname: Output file name. ``None`` means derive one with
            ``utils.generate_data_name``.

    Returns:
        ``(train_dataloader, test_dataloader, env)``. Both dataloaders use
        ``batch_size=1`` and ``shuffle=True``.
    """
    utils.set_random_seed(seed)

    # Initialize the environment
    env = GraphEnv(
        n_items=n_nodes,                     # number of possible observations
        env=env_type,
        batch_size=trajectory_length,
        num_desired_trajectories=num_desired_trajectories,
        device=None,
        unique=True,                         # each state is assigned a unique observation if true
        args=args,
    )

    # Both splits come from the same env. NOTE(review): they presumably differ
    # only because RNG state advances between the two calls — confirm that
    # gen_dataset is stochastic, otherwise train and test would be identical.
    train_dataset = env.gen_dataset()
    test_dataset = env.gen_dataset()

    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)

    if dump:
        # Treat "" like None: os.makedirs("") would raise FileNotFoundError.
        if not dir:
            dir = os.getcwd()
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(dir, exist_ok=True)

        if fname is None:
            fname = utils.generate_data_name(n_nodes, env_type, trajectory_length, num_desired_trajectories, args, seed)
        dump_path = os.path.join(dir, fname)

        # Mirrors the GraphEnv constructor arguments so the env can be rebuilt.
        metadata = {
            'n_nodes': n_nodes,
            'trajectory_length': trajectory_length,
            'num_desired_trajectories': num_desired_trajectories,
            'env_config': {
                'n_items': n_nodes,
                'env_type': env_type,
                'batch_size': trajectory_length,
                'num_desired_trajectories': num_desired_trajectories,
                'unique': True,
                'args': args
            },
            'seed': seed,
        }

        # The file is a pickle: reading it back with pickle.load can execute
        # arbitrary code, so only load dumps from trusted sources.
        with open(dump_path, 'wb') as f:
            pickle.dump({
                'train_dataset': train_dataset,
                'test_dataset': test_dataset,
                'metadata': metadata,
                'env': env,
            }, f)
        print(f"Datasets dumped to {dump_path}")

    return train_dataloader, test_dataloader, env

In [7]:
# Generate (and optionally inspect) one dataset per seed.
# Seeds used across experiments: 65, 70, 75, 80.
# For a tree environment instead of a grid, use env_type='tree', args={"levels": 4}.
vis = False  # set to True to plot the environment and print one sample trajectory
for seed in [80]:
    train, test, env = create_dataset(
        n_nodes=9,
        env_type='grid',
        trajectory_length=16,
        num_desired_trajectories=30,
        args={"rows": 3, "cols": 3},
        seed=seed,
        dump=True,          # set to False to skip writing the pickle
        dir="./data",
        fname=None,         # None -> file name derived by utils.generate_data_name
    )

    if vis:
        visual.visualize_env(env)
        # Show only the first batch; an empty loader prints nothing.
        sample = next(iter(train), None)
        if sample is not None:
            print(f'Sample trajectory of shape {sample[0].shape}')
            print(sample[0])

Datasets dumped to ./data/data_n_nodes_9_env_grid_traj_len_16_n_traj_30_args_{'rows': 3, 'cols': 3}_seed_80.pickle
