Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved readability of the VIN model, in addition to minor changes #12

Merged
merged 11 commits into from Oct 2, 2020
5 changes: 4 additions & 1 deletion .gitignore
Expand Up @@ -107,4 +107,7 @@ venv.bak/
*.npz

# pth
*.pth
*.pth

# jetbrains project settings
.idea
4 changes: 2 additions & 2 deletions dataset/dataset.py
Expand Up @@ -16,7 +16,7 @@ def __init__(self,
self.imsize = imsize
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set
self.train = train # Training set or test set

self.images, self.S1, self.S2, self.labels = \
self._process(file, self.train)
Expand Down Expand Up @@ -58,7 +58,7 @@ def _process(self, file, train):
images = images.astype(np.float32)
S1 = S1.astype(int) # (S1, S2) location are integers
S2 = S2.astype(int)
labels = labels.astype(int) # labels are integers
labels = labels.astype(int) # Labels are integers
# Print number of samples
if train:
print("Number of Train Samples: {0}".format(images.shape[0]))
Expand Down
21 changes: 18 additions & 3 deletions dataset/make_training_data.py
Expand Up @@ -3,6 +3,8 @@
import numpy as np
from dataset import *

import argparse

sys.path.append('.')
from domains.gridworld import *
from generators.obstacle_gen import *
Expand Down Expand Up @@ -49,7 +51,7 @@ def make_data(dom_size, n_domains, max_obs, max_obs_size, n_traj,
# Get final map
im = obs.get_final()
# Generate gridworld from obstacle map
G = gridworld(im, goal[0], goal[1])
G = GridWorld(im, goal[0], goal[1])
# Get value prior
value_prior = G.t_get_reward_prior()
# Sample random trajectories to our goal
Expand Down Expand Up @@ -89,7 +91,7 @@ def make_data(dom_size, n_domains, max_obs, max_obs_size, n_traj,
return X_f, S1_f, S2_f, Labels_f


def main(dom_size=[28, 28],
def main(dom_size=(28, 28),
n_domains=5000,
max_obs=50,
max_obs_size=2,
Expand All @@ -113,4 +115,17 @@ def main(dom_size=[28, 28],


if __name__ == '__main__':
main()

parser = argparse.ArgumentParser()
parser.add_argument("--size", "-s", type=int, help="size of the domain", default=28)
parser.add_argument("--n_domains", "-nd", type=int, help="number of domains", default=5000)
parser.add_argument("--max_obs", "-no", type=int, help="maximum number of obstacles", default=50)
parser.add_argument("--max_obs_size", "-os", type=int, help="maximum obstacle size", default=2)
parser.add_argument("--n_traj", "-nt", type=int, help="number of trajectories", default=7)
parser.add_argument("--state_batch_size", "-bs", type=int, help="state batch size", default=1)

args = parser.parse_args()
size = args.size

main(dom_size=(size, size), n_domains=args.n_domains, max_obs=args.max_obs,
max_obs_size=args.max_obs_size, n_traj=args.n_traj, state_batch_size=args.state_batch_size)