In [1]:
import json, os
from typing import Optional, Dict, Any, List
import utils as util

from scienceworld import ScienceWorldEnv
from textworld import EnvInfos

# from alfworld.agents.environment import get_environment
# import alfworld.agents.modules.generic as generic

# TextWorld

### Test Data Creation

In [2]:
# Number of times generate_data is invoked below.
num_gens = 1

# Output location: a "generated_data" folder under the current working directory.
save_path = f"{os.getcwd()}/generated_data"

# Two sampling configurations: one with rollout_length=1, one with rollout_length=5.
sampling_args = [
    util.SamplingArgs(
        name="cooking_singlemove",
        max_samples_per_env=20,
        rollout_length=1,
        total_samples=25,
        max_step_overlap=0,
    ),
    util.SamplingArgs(
        name="cooking_multimove",
        max_samples_per_env=8,
        rollout_length=5,
        total_samples=25,
        max_step_overlap=1,
    ),
]

# Environment wrapper handed to the sampling manager.
cooking_wrapper = util.Textworld_Cooking_Wrapper_Env()

# Build the sampler and run it num_gens times.
data_sampler = util.Textworld_Sampling_Manager(wrapper=cooking_wrapper, sampling_args=sampling_args)

for gen_idx in range(num_gens):
    data_sampler.generate_data(save_path=save_path, max_iters=None)

### Wrapper

In [3]:
# Flip to True to run the cooking wrapper until its game state reports done.
test_cooking_wrapper = False

if test_cooking_wrapper:
    cooking_wrapper = util.Textworld_Cooking_Wrapper_Env()

    # Create a new game with randomized generation args, then load it.
    cooking_wrapper.generate_new_game(randomize_gen_args=True)
    cooking_wrapper.load_env()

    # Step with no explicit command; the loop ends once game_state.done is set.
    while not cooking_wrapper.game_state.done:
        cooking_wrapper.step_env()
        # full_state=True is requested only when the move counter equals 10.
        show_full = cooking_wrapper.game_state.moves == 10
        print(cooking_wrapper.get_env_state(full_state=show_full))

### Playable TextWorld

In [4]:
# Flip to True to run the TextWorld play loop below.
test_playable_tw = False

if test_playable_tw:
    # True reads each command from stdin; False passes None to step_env.
    play_textworld = True

    # Build the wrapper and generate a game from fixed (non-randomized) args.
    cooking_wrapper = util.Textworld_Cooking_Wrapper_Env()
    cooking_gen_args = dict(
        recipe=5,
        take=3,
        go=1,
        open_=True,
        cook=True,
        cut=True,
        drop=True,
    )
    cooking_wrapper.generate_new_game(randomize_gen_args=False, game_gen_args=cooking_gen_args)
    cooking_wrapper.load_env()

    # Show the full state once before the first step.
    util.print_env_state(cooking_wrapper.get_env_state(full_state=True))

    # Play until step_env returns a truthy value or the user types "stop".
    done = False
    while not done:
        command = input("> ") if play_textworld else None
        if command == "stop":
            break
        done = cooking_wrapper.step_env(command)
        util.print_env_state(cooking_wrapper.get_env_state(full_state=False))

# Science World

### Test Data Creation

In [5]:
# Number of times generate_data is invoked below.
num_gens = 1

# Output location: a "generated_data" folder under the current working directory.
save_path = f"{os.getcwd()}/generated_data"

# Two sampling configurations: one with rollout_length=1, one with rollout_length=5.
sampling_args = [
    util.SamplingArgs(
        name="scienceworld_singlemove",
        max_samples_per_env=20,
        rollout_length=1,
        total_samples=25,
        max_step_overlap=0,
    ),
    util.SamplingArgs(
        name="scienceworld_multimove",
        max_samples_per_env=8,
        rollout_length=5,
        total_samples=25,
        max_step_overlap=1,
    ),
]

# Environment wrapper handed to the sampling manager.
scienceworld_wrapper = util.Scienceworld_Wrapper_Env()

# Build the sampler and run it num_gens times.
data_sampler = util.Textworld_Sampling_Manager(wrapper=scienceworld_wrapper, sampling_args=sampling_args)

for gen_idx in range(num_gens):
    data_sampler.generate_data(save_path=save_path, max_iters=None)

### Play around with ScienceWorld

In [11]:
# Flip to True to run the ScienceWorld wrapper until its game state reports done.
test_scienceworld_wrapper = False

if test_scienceworld_wrapper:
    scienceworld_wrapper = util.Scienceworld_Wrapper_Env()
    scienceworld_wrapper.load_env(randomize_gen_args=True)

    # The loop condition reads game_state.done, so step_env's return value is
    # intentionally not stored (avoids leaking a stale `done` into the kernel).
    while not scienceworld_wrapper.game_state.done:
        scienceworld_wrapper.step_env()
        # full_state=True is requested only when the move counter equals 5.
        show_full = scienceworld_wrapper.game_state.moves == 5
        print(scienceworld_wrapper.get_env_state(full_state=show_full))

In [2]:
# Pretty print helper
def list_tasks(env: "ScienceWorldEnv", print_tasks=False) -> List[str]:
    """Return the task names reported by ``env``, optionally printing them.

    Args:
        env: Object providing ``get_task_names()`` and
            ``get_max_variations(task_name)``.
        print_tasks: When true, print one line per task containing its index,
            the value returned by ``env.get_max_variations`` for it, and its
            name.

    Returns:
        The value returned by ``env.get_task_names()``.
    """
    tasks = env.get_task_names()
    if print_tasks:
        for idx, task_name in enumerate(tasks):
            # Variation counts are only needed for display, so they are not
            # queried at all when printing is disabled.
            max_variations = env.get_max_variations(task_name)
            print(f"{idx:2d} [{max_variations:>4}]: {task_name}")
    return tasks

# Toggle: set to False to skip building a wrapper and printing the task list.
print_tasks = True

if print_tasks:
    scienceworld_wrapper = util.Scienceworld_Wrapper_Env()
    # print_tasks is necessarily True inside this branch, so the listing is printed.
    tasks = list_tasks(env=scienceworld_wrapper.env, print_tasks=print_tasks)

 0 [  30]: boil
 1 [  30]: change-the-state-of-matter-of
 2 [  32]: chemistry-mix
 3 [  36]: chemistry-mix-paint-secondary-color
 4 [  36]: chemistry-mix-paint-tertiary-color
 5 [ 300]: find-animal
 6 [ 300]: find-living-thing
 7 [ 300]: find-non-living-thing
 8 [ 300]: find-plant
 9 [  30]: freeze
10 [ 126]: grow-fruit
11 [ 126]: grow-plant
12 [  14]: identify-life-stages-1
13 [  10]: identify-life-stages-2
14 [ 168]: inclined-plane-determine-angle
15 [1386]: inclined-plane-friction-named-surfaces
16 [ 162]: inclined-plane-friction-unnamed-surfaces
17 [ 125]: lifespan-longest-lived
18 [ 125]: lifespan-longest-lived-then-shortest-lived
19 [ 125]: lifespan-shortest-lived
20 [ 436]: measure-melting-point-known-substance
21 [ 300]: measure-melting-point-unknown-substance
22 [  30]: melt
23 [ 120]: mendelian-genetics-known-plant
24 [ 480]: mendelian-genetics-unknown-plant
25 [  20]: power-component
26 [  20]: power-component-renewable-vs-nonrenewable-energy
27 [ 900]: test-conductivity
28 

In [3]:
# Set to False to skip the ScienceWorld play loop below.
test_playable_sw = True

# The wrapper is constructed regardless of the flag above.
scienceworld_wrapper = util.Scienceworld_Wrapper_Env()

if test_playable_sw:
    # True reads each command from stdin; False passes None to step_env.
    play_scienceworld = False

    # Select the task by its index in the list returned by list_tasks.
    task_idx = 8
    task_variation_number = 0
    tasks = list_tasks(scienceworld_wrapper.env, print_tasks=False)
    task_name = tasks[task_idx]

    print(f"Loading task #{task_idx}: {task_name}.")

    # Load the chosen task/variation with non-randomized args.
    task_args = (task_name, task_variation_number, "")
    scienceworld_wrapper.load_env(randomize_gen_args=False, task_args=task_args)

    # Show the full state once before the first step.
    util.print_env_state(scienceworld_wrapper.get_env_state(full_state=True))

    # Play until step_env returns a truthy value or the user types "stop".
    done = False
    while not done:
        command = input("> ") if play_scienceworld else None
        if command == "stop":
            break
        done = scienceworld_wrapper.step_env(command)
        util.print_env_state(scienceworld_wrapper.get_env_state(full_state=False))

Loading task #8: find-plant.
Full Observation: Move 0 | Score = 0/100
Actionable Verbs: ['activate', 'close', 'connect OBJ to', 'deactivate', 'disconnect', 'dunk OBJ in', 'eat', 'flush', 'focus on', 'go', 'inventory', 'look around', 'look at', 'look in', 'mix', 'move OBJ to', 'open', 'pick up', 'pour OBJ in', 'put down', 'read', 'task', 'use OBJ on', 'wait', 'wait1']
Objective: Your task is to find a(n) plant. First, focus on the thing. Then, move it to the red box in the kitchen.
Location: Hallway
Inventory: In your inventory, you see:
	an orange

Current Observation: This room is called the hallway. In it, you see: 
	the agent
	a substance called air
	a picture
You also see:
	A door to the art studio (that is closed)
	A door to the bedroom (that is closed)
	A door to the greenhouse (that is closed)
	A door to the kitchen (that is closed)
	A door to the living room (that is closed)
	A door to the workshop (that is closed)




User: open door to greenhouse


Environment: The door is no

# Other Helpers

### Code to check the distribution of our move sampling

In [None]:
# from collections import deque, Counter
# import math, random

# multi_sample_args = util.SamplingArgs(
#     name="Test Multi",
#     max_samples_per_env=8,
#     rollout_length=5,
#     max_step_overlap=1
# )
# single_sample_args = util.SamplingArgs(
#     name="Test Single",
#     max_samples_per_env=20,
#     rollout_length=1,
#     max_step_overlap=0
# )


# def _get_sample_indices(self, gold_path: List[str], sampling_args) -> deque[int]:
#     # params
#     L = math.ceil(sampling_args.rollout_length / 2)
#     d = max(1, L - int(sampling_args.max_step_overlap))
#     U = len(gold_path)             # latest allowed start
#     if U < 0: return deque()

#     # random lattice offset to avoid 0-spike
#     o = random.randint(0, min(d - 1, max(0, U)))
#     cap = 1 + (U - o) // d
#     n = min(int(sampling_args.max_samples_per_env), cap)
#     if n <= 0: return deque()
#     if n == 1: return deque([random.randint(o, U)])

#     # slack after tight pack at o, o+d, ...
#     S = U - (o + d * (n - 1))
#     if S <= 0: return deque([o + i * d for i in range(n)])

#     # stars-and-bars over (lead, inter..., tail) summing to S
#     cuts = sorted(random.sample(range(1, S + n), n))
#     prev = 0; parts = []
#     for c in cuts + [S + n]:
#         parts.append(c - prev - 1); prev = c
#     lead, inter = parts[0], parts[1:n]

#     xs, acc = [], o + lead
#     for i in range(n):
#         xi = acc + i * d
#         xs.append(xi)
#         if i < n - 1: acc += inter[i]

#     assert all(0 <= x <= U for x in xs)
#     assert all(xs[i+1] - xs[i] >= d for i in range(n - 1))
#     return deque(xs)



# # ---- Simulation & plotting ----
# def simulate_and_plot_move_index_freqs(num_envs=10000, min_len=30, max_len=60, seed=0, sampler=None):
#     import random, matplotlib.pyplot as plt
#     if sampler is None:
#         sampler = lambda gp, args: _get_sample_indices(None, gp, args)  # your exact signature
#     random.seed(seed)
#     single = [0]*(max_len+1); multi = [0]*(max_len+1)
#     for _ in range(num_envs):
#         L = random.randint(min_len, max_len); gp = ['a']*L
#         for x in sampler(gp, single_sample_args):
#             if 0 <= x <= max_len: single[x] += 1
#         for x in sampler(gp, multi_sample_args):
#             if 0 <= x <= max_len:  multi[x]  += 1
#     xs = range(max_len+1)
#     print(single)
#     print(multi)
#     plt.figure(); plt.plot(xs, single, label="single"); plt.plot(xs, multi, label="multi")
#     plt.xlabel("Move index (0-based)"); plt.ylabel("Frequency"); plt.legend(); plt.tight_layout(); plt.show()

# simulate_and_plot_move_index_freqs()