In [1]:
import numpy as np
import numpy.random as npr
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split

import matplotlib.pyplot as plt

import torchvision
import torchvision.datasets as datasets

import itertools
import math
import string
import os

from macq import generate, extract
from macq.observation import IdentityObservation, AtomicPartialObservation





# First-order Methods

In [2]:
class Type:
    def __init__(self, name, parent):
        self.name = name
        self.parent = parent
        
    def is_child(self, another_type):
        if self.name == another_type.name:
            return True
        elif self.parent is None:
            return False
        else:
            return self.parent.is_child(another_type)

In [3]:
class Predicate:
  def __init__(self, name, params):
    self.name = name
    # params are dicts {Type: num}
    self.params = params
    self.params_types = sorted(params.keys(), key=lambda x: x.name)

  def proposition(self, sorted_obj_lists):
    return (self.name + ' ' + ' '.join([f'{self.params_types[i].name} '+f' {self.params_types[i].name} '.join(sorted_obj_lists[i])
                                       for i in range(len(sorted_obj_lists))])).strip()

  def ground(self, objects):
    '''
    Input a list of objects in the form {Type: []}
    Return all the propositions grounded from this predicates with the objects
    '''
    propositions = []
    obj_lists_per_params = {params_type:[] for params_type in self.params_types}
    for params_type in self.params_types:
        for obj_type in objects.keys():
            if obj_type.is_child(params_type):
                obj_lists_per_params[params_type].extend(objects[obj_type])
    for obj_lists in itertools.product(*[itertools.permutations(obj_lists_per_params[params_type], self.params[params_type])\
                                        for params_type in self.params_types]):
      propositions.append(self.proposition(obj_lists))
    return propositions
    
  def ground_num(self, objects):
    '''
    Return how many propositions this predicate can ground on the objects
    '''
    n_ground = 1
    for params_type in self.params_types:
      n_obj = 0
      for obj_type in objects.keys():
        if obj_type.is_child(params_type):
          n_obj += len(objects[obj_type])
      n_ground *= math.perm(n_obj, self.params[params_type])
    return n_ground

In [4]:
class Action_Schema(nn.Module):
  def __init__(self, name, params):
    super(Action_Schema, self).__init__()
    self.name = name
    # params are dicts {Type: num}
    self.params = params
    self.params_types = sorted(params.keys(), key=lambda x: x.name)
    # predicates that are relevant
    self.predicates = []

  def initialise(self, predicates, device):
    '''
    Input all predicates and generate the model for action schema
    '''
    n_features = 0
    for predicate in predicates:
      # A predicate is relevant to an action schema iff for each of its param type,
      # the number of objects required is leq the number of objects there is 
      # for the same type or children type in the action schema
      is_relevant = True
      # Also calculate how many propositions there are when predicate is grounded on "variables"
      # e.g. on X Y; on Y X when X and Y are variables
      n_ground = 1
      for params_type in predicate.params_types:
        n_params = 0
        for model_params_type in self.params:
          if model_params_type.is_child(params_type):
            n_params += self.params[model_params_type]
        if predicate.params[params_type]>n_params:
          is_relevant = False
          break
        else:
          n_ground *= math.perm(n_params, predicate.params[params_type])
      if is_relevant:
        self.predicates.append(predicate)
        n_features += n_ground
    n_features = int(n_features)
    
    self.randn = torch.randn(n_features, 128, device=device, requires_grad=True)
    self.mlp = nn.Sequential(
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, 32),
        nn.ReLU(),
        nn.Linear(32, 16),
        nn.ReLU(),
        nn.Linear(16, 4),
        nn.Softmax(dim=1)
    )
    self.mlp.to(device)
    
  def forward(self):
    return self.mlp(self.randn)

  def ground(self, objects, is_single_action=False):
    if is_single_action:
      propositions = []
      for predicate in self.predicates:
        propositions.extend(predicate.ground(objects))
      return propositions
    else:
      propositions = []
      obj_lists_per_params = {params_type:[] for params_type in self.params_types}
      for params_type in self.params_types:
        for obj_type in objects.keys():
          if obj_type.is_child(params_type):
            obj_lists_per_params[params_type].extend(objects[obj_type])  
      for obj_list in itertools.product(*[itertools.permutations(obj_lists_per_params[params_type], self.params[params_type])\
                                          for params_type in self.params_types]):
        objects_per_action = {}
        for i in range(len(self.params_types)):
          objects_per_action[self.params_types[i]] = obj_list[i]
        propositions_per_action = []
        for predicate in self.predicates:
          propositions_per_action.extend(predicate.ground(objects_per_action))
        propositions.append(propositions_per_action)
      return propositions
    
  def pretty_print(self):
    var = {}
    n = 0
    for param_type in self.params_types:
      var[param_type] = list(string.ascii_lowercase)[n:n+self.params[param_type]]
      n += self.params[param_type]
    print(f'{self.name}' + ' ' + ' '.join([k.name+' '+v for k in var.keys() for v in var[k]]))
    propositions = self.ground(var, True)
    precon_list = []
    addeff_list = []
    deleff_list = []
    result = torch.argmax(self(), dim=1)
    for i in range(len(propositions)):
        if result[i]==1:
            addeff_list.append(propositions[i])
        elif result[i]==2:
            precon_list.append(propositions[i])
        elif result[i]==3:
            precon_list.append(propositions[i])
            deleff_list.append(propositions[i])
    print(', '.join(precon_list))
    print(', '.join(addeff_list))
    print(', '.join(deleff_list))

In [5]:
class Domain_Model(nn.Module):
  def __init__(self, predicates, action_schemas, device):
    super(Domain_Model, self).__init__()
    self.predicates = predicates
    self.action_schemas = action_schemas
    self.device = device
    for action_schema in action_schemas:
      action_schema.initialise(predicates, self.device)

  def ground(self, objects):
    # Ground predicates to propositions
    # Record in a dictionary with values as indices, for later lookup
    self.propositions = {}
    for predicate in self.predicates:
      for proposition in predicate.ground(objects):
        self.propositions[proposition] = len(self.propositions)

    # For each action schema, ground to actions and then find the indices
    self.indices = []
    # Also need to know which action schema each action is from
    self.action_to_schema = []
    for action_schema in self.action_schemas:
      for propositions in action_schema.ground(objects):
        self.indices.append([self.propositions[p] for p in propositions])
        self.action_to_schema.append(action_schema)
        

  def build(self, actions):
    '''
    actions is a list of numbers
    '''
    precon = torch.zeros((len(actions), len(self.propositions)), device=self.device, requires_grad=False)
    addeff = torch.zeros((len(actions), len(self.propositions)), device=self.device, requires_grad=False)
    deleff = torch.zeros((len(actions), len(self.propositions)), device=self.device, requires_grad=False)
    for i in range(len(actions)):
      y_indices = self.indices[actions[i]]
      schema = self.action_to_schema[actions[i]]
      y_indices_set = set(y_indices)
      
      schema_prams = schema()
      schema_precon = schema_prams @ torch.tensor([0.0, 0.0, 1.0, 1.0], device=self.device)
      schema_addeff = schema_prams @ torch.tensor([0.0, 1.0, 0.0, 0.0], device=self.device)
      schema_deleff = schema_prams @ torch.tensor([0.0, 0.0, 0.0, 1.0], device=self.device)

      if len(y_indices)>len(y_indices_set):
        # There are duplicate indices in y_indices
        # Multiple predicates are grounded to one same proposition
        # We need to combine the contribution from different predicates to one proposition
        applied = set()
        for y_idx in y_indices:
          if y_idx not in applied:
            precon[i, y_idx] += schema_precon[y_idx]
            addeff[i, y_idx] += schema_addeff[y_idx]
            deleff[i, y_idx] += schema_deleff[y_idx]
            applied.add(y_idx)
          else:
            # The multiple effects are combined with "or"
            # p v q = not ((not p)^(not q))
            precon[i, y_idx] = 1 - (1-precon[i, y_idx])*(1-schema_precon[y_idx])
            addeff[i, y_idx] = 1 - (1-addeff[i, y_idx])*(1-schema_addeff[y_idx])
            deleff[i, y_idx] = 1 - (1-deleff[i, y_idx])*(1-schema_deleff[y_idx])
      else:
        x_indices = [i]*len(y_indices)
        precon[x_indices, y_indices] += schema_precon
        addeff[x_indices, y_indices] += schema_addeff
        deleff[x_indices, y_indices] += schema_deleff
    return precon, addeff, deleff

# Blockworld

## Model

In [6]:
obj = Type("object", None)

In [7]:
model = Domain_Model(
    [
        Predicate('arm-empty', {}),
        Predicate('clear', {obj:1}),
        Predicate('on-table', {obj:1}),
        Predicate('holding', {obj:1}),
        Predicate('on', {obj:2}),
    ],
    [
        Action_Schema('pickup', {obj:1}),
        Action_Schema('putdown', {obj:1}),
        Action_Schema('stack', {obj:2}),
        Action_Schema('unstack', {obj:2}),
    ]
, device="cpu")

In [8]:
objects = {obj: ['block1', 'block2', 'block3', 'block4', 'block5']}

In [9]:
model.ground(objects)

In [10]:
all_grounded_actions = {}
for action_schema in model.action_schemas:
    obj_lists_per_params = {params_type:[] for params_type in action_schema.params_types}
    for params_type in action_schema.params_types:
        for obj_type in objects.keys():
            if obj_type.is_child(params_type):
                obj_lists_per_params[params_type].extend(objects[obj_type])  
    for obj_list in itertools.product(*[itertools.permutations(obj_lists_per_params[params_type], action_schema.params[params_type])\
                                          for params_type in action_schema.params_types]):
        objects_per_action = {}
        constructed = action_schema.name + ' ' + ' '.join([f'{action_schema.params_types[i].name} '
                                                           +f' {action_schema.params_types[i].name} '.join(obj_list[i])
                                                           for i in range(len(obj_list))])
        all_grounded_actions[constructed] = len(all_grounded_actions)

## Data

In [11]:
steps_state1 = []
steps_action = []
steps_state2 = []

traces = generate.pddl.VanillaSampling(dom='./blockworld/domain.pddl', prob='./blockworld/prob01.pddl', plan_len = 10, num_traces = 10).traces
for trace in traces:
    # last step no action
    for t in range(len(trace.steps)-1):
        fluents_in_state1 = {f._serialize()[1:-1] for f in trace.steps[t].state if trace.steps[t].state[f] is True}
        fluents_in_state2 = {f._serialize()[1:-1] for f in trace.steps[t+1].state if trace.steps[t+1].state[f] is True}
        state1 = [1 if p in fluents_in_state1 else 0 for p in model.propositions]
        state2 = [1 if p in fluents_in_state2 else 0 for p in model.propositions]
        
        # action parameter order may not be the same as our model
        action = trace.steps[t].action
        action_obj_params = sorted([o for o in action.obj_params], key=lambda o:o.obj_type)
        
        steps_action.append(all_grounded_actions[f"{action.name} {' '.join([o.details()for o in action_obj_params])}"])
        
        steps_state1.append(state1)
        steps_state2.append(state2)

100%|██████████████████████████████████████████| 10/10 [00:00<00:00, 110.34it/s]


In [12]:
steps_state1_tensor = torch.tensor(np.array(steps_state1)).float()
steps_action_tensor = torch.tensor(np.array(steps_action))
steps_state2_tensor = torch.tensor(np.array(steps_state2)).float()

In [13]:
batch_sz = 1000
dataset = TensorDataset(steps_state1_tensor, steps_action_tensor, steps_state2_tensor)
dataloader = DataLoader(dataset, batch_size=batch_sz, shuffle=False)

## Training

In [14]:
parameters = []
for schema in model.action_schemas:
    parameters.append({'params': schema.parameters(), 'lr': 1e-3})
optimizer = optim.Adam(parameters)

In [15]:
for epoch in range(100):
  optimizer.zero_grad()
  loss_final = 0.0
  for i, (state_1, executed_actions, state_2) in enumerate(dataloader):
    precon, addeff, deleff = model.build(executed_actions)
    # The result of applying a in s is (s\Del(a)) U Add(a)
    # We can simplfy this to be:
    # ((p in state 1) ^ (not p in Del(a))) v ((not p in state 1) ^ (p in Add(a)))
    # Note we implicitly apply the constraint that add effects and preconditions
    # cannot intersect and only preconditions can be deleted
    # The "or" can be translated to an addition as the two sides and exclusive
#     preds = addeff + (1-addeff)*state_1*(1-deleff)
    preds = 1- (1-state_1*(1-deleff)) * (1-(1-state_1)*addeff)
    
    # Since we view the state_2 as true targets, we can binary cross-entropy loss
    # If state_2 is also predicated, use KL-divergence to ensure two distributions are close?
    loss = F.mse_loss(preds, state_2, reduction='sum')
    # Add in validity constraint
    # Since executed actions are applicable in state_1
    # p in Pre(a) -> p in state_1 for all a in executed_actions
    # not ((p in Pre(a)) ^ (not p in state_1))
    validity_constraint = (1-state_1) * (precon)
    loss += F.mse_loss(validity_constraint, torch.zeros(validity_constraint.shape, dtype=validity_constraint.dtype), reduction='sum')
    loss += 0.2*F.mse_loss(precon, torch.ones(precon.shape, dtype=precon.dtype), reduction='sum')
#     loss += model.constraint_loss()
    loss.backward()
    optimizer.step()
    loss_final += loss.item() / batch_sz
  if epoch%10 == 0:
    print('Epoch {} RESULTS: Average loss: {:.10f}'.format(epoch, loss_final))

Epoch 0 RESULTS: Average loss: 0.8820755615
Epoch 10 RESULTS: Average loss: 0.8452726440
Epoch 20 RESULTS: Average loss: 0.7793428955
Epoch 30 RESULTS: Average loss: 0.6977715454
Epoch 40 RESULTS: Average loss: 0.6401904297
Epoch 50 RESULTS: Average loss: 0.6134895630
Epoch 60 RESULTS: Average loss: 0.6045328979
Epoch 70 RESULTS: Average loss: 0.5974949341
Epoch 80 RESULTS: Average loss: 0.5947376709
Epoch 90 RESULTS: Average loss: 0.5938486328


In [16]:
for action_schema in model.action_schemas:
    action_schema.pretty_print()
    print()

pickup object a
arm-empty, clear object a, on-table object a
holding object a
arm-empty, clear object a, on-table object a

putdown object a
holding object a
arm-empty, clear object a, on-table object a
holding object a

stack object a object b
clear object b, holding object a
arm-empty, clear object a, on object a object b
clear object b, holding object a

unstack object a object b
arm-empty, clear object a, on object a object b
clear object b, holding object a
arm-empty, clear object a, on object a object b



# Gripper

## Model (Problem Specific)

In [17]:
base = Type("object", None)
room = Type("room", base)
ball = Type("ball", base)
gripper = Type("gripper", base)

In [18]:
model = Domain_Model([
                        Predicate('at-robby', {room:1}),
                        Predicate('at', {ball:1, room:1}),
                        Predicate('free', {gripper:1}),
                        Predicate('carry', {ball:1, gripper:1}),],
                     [
                        Action_Schema('move', {room:2}),
                        Action_Schema('pick', {ball:1, room:1, gripper:1}),
                        Action_Schema('drop', {ball:1, room:1, gripper:1}),
                     ], device='cpu')

In [19]:
objects = {
    room: ['rooma', 'roomb'],
    ball: ['ball1', 'ball2', 'ball3', 'ball4', 'ball5', 'ball6'],
    gripper: ['left', 'right']
}

In [20]:
model.ground(objects)

In [21]:
all_grounded_actions = {}
for action_schema in model.action_schemas:
    obj_lists_per_params = {params_type:[] for params_type in action_schema.params_types}
    for params_type in action_schema.params_types:
        for obj_type in objects.keys():
            if obj_type.is_child(params_type):
                obj_lists_per_params[params_type].extend(objects[obj_type])  
    for obj_list in itertools.product(*[itertools.permutations(obj_lists_per_params[params_type], action_schema.params[params_type])\
                                          for params_type in action_schema.params_types]):
        objects_per_action = {}
        constructed = action_schema.name + ' ' + ' '.join([f'{action_schema.params_types[i].name} '
                                                           +f' {action_schema.params_types[i].name} '.join(obj_list[i])
                                                           for i in range(len(obj_list))])
        all_grounded_actions[constructed] = len(all_grounded_actions)

In [22]:
all_grounded_actions

{'move room rooma room roomb': 0,
 'move room roomb room rooma': 1,
 'pick ball ball1 gripper left room rooma': 2,
 'pick ball ball1 gripper left room roomb': 3,
 'pick ball ball1 gripper right room rooma': 4,
 'pick ball ball1 gripper right room roomb': 5,
 'pick ball ball2 gripper left room rooma': 6,
 'pick ball ball2 gripper left room roomb': 7,
 'pick ball ball2 gripper right room rooma': 8,
 'pick ball ball2 gripper right room roomb': 9,
 'pick ball ball3 gripper left room rooma': 10,
 'pick ball ball3 gripper left room roomb': 11,
 'pick ball ball3 gripper right room rooma': 12,
 'pick ball ball3 gripper right room roomb': 13,
 'pick ball ball4 gripper left room rooma': 14,
 'pick ball ball4 gripper left room roomb': 15,
 'pick ball ball4 gripper right room rooma': 16,
 'pick ball ball4 gripper right room roomb': 17,
 'pick ball ball5 gripper left room rooma': 18,
 'pick ball ball5 gripper left room roomb': 19,
 'pick ball ball5 gripper right room rooma': 20,
 'pick ball ball5 g

## Data

In [23]:
steps_state1 = []
steps_action = []
steps_state2 = []

traces = generate.pddl.VanillaSampling(dom='./gripper/domain.pddl', prob='./gripper/prob01.pddl', plan_len = 10, num_traces = 10).traces
for trace in traces:
    # last step no action
    for t in range(len(trace.steps)-1):
        fluents_in_state1 = {f._serialize()[1:-1] for f in trace.steps[t].state if trace.steps[t].state[f] is True}
        fluents_in_state2 = {f._serialize()[1:-1] for f in trace.steps[t+1].state if trace.steps[t+1].state[f] is True}
        state1 = [1 if p in fluents_in_state1 else 0 for p in model.propositions]
        state2 = [1 if p in fluents_in_state2 else 0 for p in model.propositions]
        
        # action parameter order may not be the same as our model
        action = trace.steps[t].action
        action_obj_params = sorted([o for o in action.obj_params], key=lambda o:o.obj_type)
        
        steps_action.append(all_grounded_actions[f"{action.name} {' '.join([o.details()for o in action_obj_params])}"])
        
        steps_state1.append(state1)
        steps_state2.append(state2)

100%|██████████████████████████████████████████| 10/10 [00:00<00:00, 157.10it/s]


In [24]:
steps_state1_tensor = torch.tensor(np.array(steps_state1)).float()
steps_action_tensor = torch.tensor(np.array(steps_action))
steps_state2_tensor = torch.tensor(np.array(steps_state2)).float()

In [25]:
batch_sz = 1000
dataset = TensorDataset(steps_state1_tensor, steps_action_tensor, steps_state2_tensor)
dataloader = DataLoader(dataset, batch_size=batch_sz, shuffle=False)

## Training

In [26]:
parameters = []
for schema in model.action_schemas:
    parameters.append({'params': schema.parameters(), 'lr': 1e-3})
optimizer = optim.Adam(parameters)

In [27]:
for epoch in range(100):
  loss_final = 0.0
  for i, (state_1, executed_actions, state_2) in enumerate(dataloader):
    optimizer.zero_grad()
    precon, addeff, deleff = model.build(executed_actions)
    # The result of applying a in s is (s\Del(a)) U Add(a)
    # We can simplfy this to be:
    # ((p in state 1) ^ (not p in Del(a))) v ((not p in state 1) ^ (p in Add(a)))
    # Note we implicitly apply the constraint that add effects and preconditions
    # cannot intersect and only preconditions can be deleted
    # The "or" can be translated to an addition as the two sides and exclusive
#     preds = addeff + (1-addeff)*state_1*(1-deleff)
    preds = state_1*(1-deleff) + (1-state_1)*addeff
    
    # Since we view the state_2 as true targets, we can binary cross-entropy loss
    # If state_2 is also predicated, use KL-divergence to ensure two distributions are close?
    loss = F.mse_loss(preds, state_2, reduction='sum')
    # Add in validity constraint
    # Since executed actions are applicable in state_1
    # p in Pre(a) -> p in state_1 for all a in executed_actions
    # not ((p in Pre(a)) ^ (not p in state_1))
    validity_constraint = (1-state_1) * (precon)
    loss += F.mse_loss(validity_constraint, torch.zeros(validity_constraint.shape, dtype=validity_constraint.dtype), reduction='sum')
#     loss += model.constraint_loss()
    loss += 0.2*F.mse_loss(precon, torch.ones(precon.shape, dtype=precon.dtype), reduction='sum')
    loss.backward()
    optimizer.step()
    loss_final += loss.item() / batch_sz
  if epoch%10 == 0:
    print('Epoch {} RESULTS: Average loss: {:.10f}'.format(epoch, loss_final))

Epoch 0 RESULTS: Average loss: 0.6226490479
Epoch 10 RESULTS: Average loss: 0.5924045410
Epoch 20 RESULTS: Average loss: 0.5525383301
Epoch 30 RESULTS: Average loss: 0.5116485291
Epoch 40 RESULTS: Average loss: 0.4767666321
Epoch 50 RESULTS: Average loss: 0.4681913452
Epoch 60 RESULTS: Average loss: 0.4654275513
Epoch 70 RESULTS: Average loss: 0.4632035828
Epoch 80 RESULTS: Average loss: 0.4623683167
Epoch 90 RESULTS: Average loss: 0.4620728760


In [28]:
for action_schema in model.action_schemas:
    action_schema.pretty_print()
    print()

move room a room b
at-robby room a
at-robby room b
at-robby room a

pick ball a gripper b room c
at-robby room c, at ball a room c, free gripper b
carry ball a gripper b
at ball a room c, free gripper b

drop ball a gripper b room c
at-robby room c, carry ball a gripper b
at ball a room c, free gripper b
carry ball a gripper b



# Logistics

## Model

In [29]:
base = Type("object", None)
movable = Type("movable", base)
location = Type("location", base)
city = Type("city", base)
obj = Type("obj", movable)
transport = Type("transport", movable)
truck = Type("truck", transport)
airplane = Type("airplane", transport)
airport = Type("airport", location)

In [30]:
model = Domain_Model([
                        Predicate('at', {movable:1, location: 1}),
                        Predicate('in', {obj:1, transport:1}),
                        Predicate('in-city', {location:1, city:1}),],
                     [
                        Action_Schema('load-truck', {obj:1, truck:1, location:1}),
                        Action_Schema('load-airplane', {obj:1, airplane:1, airport:1}),
                        Action_Schema('unload-truck', {obj:1, truck:1, location:1}),
                        Action_Schema('unload-airplane', {obj:1, airplane:1, airport:1}),
                        Action_Schema('drive-truck', {truck:1, location:2, city:1}),
                        Action_Schema('fly-airplane', {airplane:1, airport:2})
                     ], device="cpu")

In [31]:
objects = {
    location: [
        'city1-1', 'city2-1'
    ],
    city: [
        'city1', 'city2'
    ],
    obj: ['package1', 'package2', 'package3', 'package4', 'package5', 'package6',],
    truck: ['truck1', 'truck2'],
    airplane: ['plane1', 'plane2'],
    airport: ['city1-2', 'city2-2']
}

In [32]:
model.ground(objects)

In [33]:
all_grounded_actions = {}
for action_schema in model.action_schemas:
    obj_lists_per_params = {params_type:[] for params_type in action_schema.params_types}
    for params_type in action_schema.params_types:
        for obj_type in objects.keys():
            if obj_type.is_child(params_type):
                obj_lists_per_params[params_type].extend(objects[obj_type])  
    for obj_list in itertools.product(*[itertools.permutations(obj_lists_per_params[params_type], action_schema.params[params_type])\
                                          for params_type in action_schema.params_types]):
        objects_per_action = {}
        constructed = action_schema.name + ' ' + ' '.join([f'{action_schema.params_types[i].name} '
                                                           +f' {action_schema.params_types[i].name} '.join(obj_list[i])
                                                           for i in range(len(obj_list))])
        all_grounded_actions[constructed] = len(all_grounded_actions)

## Data

In [34]:
steps_state1 = []
steps_action = []
steps_state2 = []

traces = generate.pddl.VanillaSampling(dom='./logistics/domain.pddl', prob='./logistics/prob02.pddl', plan_len=10, num_traces=10).traces
for trace in traces:
    # last step no action
    for t in range(len(trace.steps)-1):
        fluents_in_state1 = set()
        for f in trace.steps[t].state:
            if trace.steps[t].state[f] is True:
                if f.name=='at':
                    serialized_list = f._serialize()[1:-1].split(' ')
                    fluents_in_state1.add(f'at location {serialized_list[4]} movable {serialized_list[2]}')
                elif f.name=='in':
                    serialized_list = f._serialize()[1:-1].split(' ')
                    fluents_in_state1.add(f'in obj {serialized_list[2]} transport {serialized_list[4]}')
                elif f.name=='in-city':
                    serialized_list = f._serialize()[1:-1].split(' ')
                    fluents_in_state1.add(f'in-city city {serialized_list[4]} location {serialized_list[2]}')
        fluents_in_state2 = set()
        for f in trace.steps[t+1].state:
            if trace.steps[t+1].state[f] is True:
                if f.name=='at':
                    serialized_list = f._serialize()[1:-1].split(' ')
                    fluents_in_state2.add(f'at location {serialized_list[4]} movable {serialized_list[2]}')
                elif f.name=='in':
                    serialized_list = f._serialize()[1:-1].split(' ')
                    fluents_in_state2.add(f'in obj {serialized_list[2]} transport {serialized_list[4]}')
                elif f.name=='in-city':
                    serialized_list = f._serialize()[1:-1].split(' ')
                    fluents_in_state2.add(f'in-city city {serialized_list[4]} location {serialized_list[2]}')
        state1 = [1 if p in fluents_in_state1 else 0 for p in model.propositions]
        state2 = [1 if p in fluents_in_state2 else 0 for p in model.propositions]
        
        # action parameter order may not be the same as our model
        action = trace.steps[t].action
        action_obj_params = []
        for o in action.obj_params:
            if 'airplane' not in action.name and o.obj_type=='airport':
                o.obj_type = 'location'
            action_obj_params.append(o)
        action_obj_params = sorted(action_obj_params, key=lambda o:o.obj_type)
        
        steps_action.append(all_grounded_actions[f"{action.name} {' '.join([o.details()for o in action_obj_params])}"])
        
        steps_state1.append(state1)
        steps_state2.append(state2)

100%|███████████████████████████████████████████| 10/10 [00:00<00:00, 64.44it/s]


In [35]:
steps_state1_tensor = torch.tensor(np.array(steps_state1)).float()
steps_state1_tensor[:, [64, 66, 69, 71]] = 1.0
steps_action_tensor = torch.tensor(np.array(steps_action))
steps_state2_tensor = torch.tensor(np.array(steps_state2)).float()
steps_state2_tensor[:, [64, 66, 69, 71]] = 1.0

In [36]:
batch_sz = 1000
dataset = TensorDataset(steps_state1_tensor, steps_action_tensor, steps_state2_tensor)
dataloader = DataLoader(dataset, batch_size=batch_sz, shuffle=False)

## Training

In [37]:
parameters = []
for schema in model.action_schemas:
    parameters.append({'params': schema.parameters(), 'lr': 1e-3})
optimizer = optim.Adam(parameters)

In [38]:
for epoch in range(100):
  optimizer.zero_grad()
  loss_final = 0.0
  for i, (state_1, executed_actions, state_2) in enumerate(dataloader):
    precon, addeff, deleff = model.build(executed_actions)
    # The result of applying a in s is (s\Del(a)) U Add(a)
    # We can simplfy this to be:
    # ((p in state 1) ^ (not p in Del(a))) v ((not p in state 1) ^ (p in Add(a)))
    # Note we implicitly apply the constraint that add effects and preconditions
    # cannot intersect and only preconditions can be deleted
    # The "or" can be translated to an addition as the two sides and exclusive
#     preds = addeff + (1-addeff)*state_1*(1-deleff)
    preds = 1- (1-state_1*(1-deleff)) * (1-(1-state_1)*addeff)
    
    # Since we view the state_2 as true targets, we can binary cross-entropy loss
    # If state_2 is also predicated, use KL-divergence to ensure two distributions are close?
    loss = F.mse_loss(preds, state_2, reduction='sum')
    # Add in validity constraint
    # Since executed actions are applicable in state_1
    # p in Pre(a) -> p in state_1 for all a in executed_actions
    # not ((p in Pre(a)) ^ (not p in state_1))
    validity_constraint = (1-state_1) * (precon)
    loss += F.mse_loss(validity_constraint, torch.zeros(validity_constraint.shape, dtype=validity_constraint.dtype), reduction='sum')
#     loss += model.constraint_loss()
    loss += 0.2*F.mse_loss(precon, torch.ones(precon.shape, dtype=precon.dtype), reduction='sum')
    loss.backward()
    optimizer.step()
    loss_final += loss.item() / batch_sz
  if epoch%10 == 0:
    print('Epoch {} RESULTS: Average loss: {:.10f}'.format(epoch, loss_final))

Epoch 0 RESULTS: Average loss: 1.3849838867
Epoch 10 RESULTS: Average loss: 1.3639166260
Epoch 20 RESULTS: Average loss: 1.3282530518
Epoch 30 RESULTS: Average loss: 1.2905253906
Epoch 40 RESULTS: Average loss: 1.2696861572
Epoch 50 RESULTS: Average loss: 1.2635958252
Epoch 60 RESULTS: Average loss: 1.2611717529
Epoch 70 RESULTS: Average loss: 1.2597830811
Epoch 80 RESULTS: Average loss: 1.2593234863
Epoch 90 RESULTS: Average loss: 1.2591214600


In [39]:
for action_schema in model.action_schemas:
    action_schema.pretty_print()
    print()

load-truck location a obj b truck c
at location a movable b, at location a movable c
in obj b transport c
at location a movable b

load-airplane airplane a airport b obj c
at location b movable a, at location b movable c
in obj c transport a
at location b movable c

unload-truck location a obj b truck c
at location a movable c, in obj b transport c
at location a movable b
in obj b transport c

unload-airplane airplane a airport b obj c
at location b movable a, in obj c transport a
at location b movable c
in obj c transport a

drive-truck city a location b location c truck d
at location b movable d, in-city city a location b, in-city city a location c
at location c movable d
at location b movable d

fly-airplane airplane a airport b airport c
at location b movable a
at location c movable a
at location b movable a

