In [None]:
import flax
import jax
import jax.numpy as jnp
import os
from flax.training import checkpoints
import numpy as np
import matplotlib.pyplot as plt
import imageio
import openai
import copy
import cv2

from env.env import PICK_TARGETS, PLACE_TARGETS, PickPlaceEnv
import clip
import torch
from moviepy.editor import ImageSequenceClip
from clipport.model import TransporterNets, n_params
from clipport.train import train_step, eval_step
from clipport.run import run_cliport

from llm.score import gpt3_scoring, make_options
from llm.helper import *
from llm.affordance import affordance_scoring, affordance_score2
from llm.davnici import lm_planner_unct
from llm.run import run
from vild.forward import vild

In [None]:
# Create the pick-and-place environment instance that the later cells reset, render and act on.
env = PickPlaceEnv()

In [None]:
# Read the OpenAI API key from the environment instead of pasting it into the notebook,
# so the secret never ends up in the saved file or its history.
# Falls back to the empty string (the previous hard-coded value) when OPENAI_API_KEY is unset.
openai.api_key = os.environ.get("OPENAI_API_KEY", "")

## User Configuration

In [None]:
# Natural-language goal that is later registered with the planner via set_goal().
raw_input = "put one block in red bowl."

# Scene description passed to env.reset(): objects to pick and receptacles to place into.
config = dict(
    pick=["red block", "yellow block", "blue block"],
    place=["green bowl", "red bowl"],
)

# Step text that stops plan execution when it is encountered.
termination_string = "done()"

In [None]:
# Reset the environment with the user configuration and keep the returned observation.
obs = env.reset(config)

# Fetch a camera frame and show it inline as a quick check of the scene.
img_top = env.get_camera_image()
# NOTE(review): img_top_rgb is computed but never used in this cell; the unconverted
# img_top is what gets displayed. Confirm which channel order the camera returns.
img_top_rgb = cv2.cvtColor(img_top, cv2.COLOR_BGR2RGB)
plt.imshow(img_top)
plt.show()

## Setup

In [None]:
# Load the CLIP ViT-B/32 model together with its matching preprocessing transform.
clip_model, clip_preprocess = clip.load("ViT-B/32")
# Move the model to the GPU and switch it to eval mode; this call needs a CUDA device.
# As the last expression of the cell, the returned model is also echoed in the cell output.
clip_model.cuda().eval()

In [None]:
# Coordinate grid over a 224x224 image with both axes scaled to [-1, 1].
# coords[i, j] holds the (x, y) pair for row i / column j ('ij' indexing).
axis_ticks = np.linspace(-1, 1, 224)
coord_x, coord_y = np.meshgrid(axis_ticks, axis_ticks, sparse=False, indexing='ij')
coords = np.stack((coord_x, coord_y), axis=2)

In [None]:
# Detector vocabulary: every color paired with "block" first, then with "bowl".
_colors = ['blue', 'red', 'green', 'orange', 'yellow',
           'purple', 'pink', 'cyan', 'brown', 'gray']
category_names = [f'{color} {shape}'
                  for shape in ('block', 'bowl')
                  for color in _colors]

#@markdown ViLD settings.
category_name_string = ";".join(category_names)
max_boxes_to_draw = 8 #@param {type:"integer"}

# Extra prompt engineering: swap A with B for every (A, B) in list.
prompt_swaps = [('block', 'cube')]

nms_threshold = 0.4 #@param {type:"slider", min:0, max:0.9, step:0.05}
min_rpn_score_thresh = 0.4  #@param {type:"slider", min:0, max:1, step:0.01}
min_box_area = 10 #@param {type:"slider", min:0, max:10000, step:1.0}
max_box_area = 3000  #@param {type:"slider", min:0, max:10000, step:1.0}

# Bundle the detection settings in the order the vild(...) constructor call receives them.
vild_params = (
    max_boxes_to_draw,
    nms_threshold,
    min_rpn_score_thresh,
    min_box_area,
    max_box_area,
)

## ViLD

In [None]:
# File the top-down snapshot is written to; the detector cell below reads it back from this path.
image_path = "./2db.png"
# Seed NumPy's global RNG so the random fallback scene below is repeatable.
np.random.seed(2)
# Fallback: sample a random scene only when no explicit config was supplied.
if config is None:
  pick_items = list(PICK_TARGETS.keys())
  # Choose between 1 and 4 distinct pick objects.
  pick_items = np.random.choice(pick_items, size=np.random.randint(1, 5), replace=False)

  # NOTE(review): [:-9] drops the last nine PLACE_TARGETS keys; which entries those are
  # depends on the dict's ordering, which is not visible here.
  place_items = list(PLACE_TARGETS.keys())[:-9]
  # Choose at least one place target, keeping picks + places at five objects or fewer.
  place_items = np.random.choice(place_items, size=np.random.randint(1, 6 - len(pick_items)), replace=False)
  config = {"pick":  pick_items,
            "place": place_items}
  print(pick_items, place_items)

# Reset the environment with the chosen configuration and keep the new observation.
obs = env.reset(config)

# Capture the top-down camera view for the detector.
img_top = env.get_camera_image_top()
# NOTE(review): img_top_rgb is computed but never used; the unconverted img_top is both
# displayed and saved. Confirm which channel order the camera returns.
img_top_rgb = cv2.cvtColor(img_top, cv2.COLOR_BGR2RGB)
plt.imshow(img_top)

# Persist the snapshot so it can be loaded from image_path.
imageio.imsave(image_path, img_top)

In [None]:
# Build the ViLD detector from the CLIP model, the ';'-joined category string
# and the detection settings tuple.
vild_model = vild(clip_model, category_name_string, vild_params)

In [None]:
# Run detection on the saved snapshot with plotting enabled; the results feed the planner below.
found_objects, boxes = vild_model.infer(image_path, plot_on=True)


## LLM

In [None]:
# Create the planner that reports an uncertainty value alongside its proposed tasks.
# NOTE(review): the meaning of the constructor argument 2 is not visible here — confirm.
lm_planner = lm_planner_unct(2)
# Hand the planner an independent copy of the detected objects so it does not alias found_objects.
lm_planner.objects = copy.deepcopy(found_objects)
# Register the user's natural-language instruction as the planning goal.
lm_planner.set_goal(raw_input)

In [None]:
# Uncertainty threshold: when the planner's reported uncertainty exceeds this value, the
# planning loop asks the user a clarifying question instead of recording the selected step.
THRES = 0.3

In [None]:
# Interactive planning loop.
#
# Each round asks the planner for candidate tasks, their scores and an
# uncertainty value, then:
#   * stops when any candidate mentions 'done' or the planner returns no tasks;
#   * otherwise walks the candidates from highest to lowest score and selects
#     the first one with a positive affordance score;
#   * if the uncertainty exceeds THRES, asks the user a question and feeds the
#     answer back to the planner instead of recording the step;
#   * otherwise records the selected task in steps_text for later execution.
# At most max_tasks rounds are run.
done = False
num_tasks = 0
max_tasks = 5
steps_text = []
while not done:
    num_tasks += 1
    if num_tasks > max_tasks:
        break
    tasks, scores, unct = lm_planner.plan_with_unct()

    # Guard against a missing plan BEFORE iterating over it; checking afterwards
    # would raise TypeError on None instead of ending the loop cleanly.
    if tasks is None:
        break

    # Any candidate mentioning 'done' means the planner considers the goal reached.
    if any('done' in t for t in tasks):
        done = True
        break

    selected_task = None
    if len(scores) > 0:
        scores = np.asarray(scores)
        idxs = np.argsort(scores)
        # Highest-scoring candidate first.
        for idx in idxs[::-1]:
            try:
                aff = affordance_score2(tasks[idx], found_objects)
            except Exception as err:
                # Best effort: a candidate that cannot be scored is treated as
                # not affordable. Catching Exception (not a bare except) keeps
                # Ctrl-C / kernel interrupts working.
                print(tasks[idx])
                print(f'affordance scoring failed: {err!r}')
                aff = 0
            if aff > 0:
                selected_task = tasks[idx]
                lm_planner.append(None, None, selected_task)
                break
        # aff/idx hold the values from the last candidate examined above.
        # NOTE(review): a score of 2 ends planning immediately. The task was
        # already appended to the planner inside the loop, so it is appended a
        # second time here, and it is never added to steps_text — confirm that
        # both effects are intended.
        if aff == 2:
            done = True
            lm_planner.append(None, None, tasks[idx])
            # steps_text.append("done()")
            # uncts.append(unct)
            break

    if unct > THRES:
        # Too uncertain: ask the user and give the answer to the planner.
        reason, ques = lm_planner.question_generation()
        answer = input("Answer: ")
        lm_planner.answer(answer)
    elif selected_task is not None:
        steps_text.append(selected_task)

## Execute

In [None]:
# Initialize model weights using dummy tensors.
rng = jax.random.PRNGKey(0)
rng, key = jax.random.split(rng)
# Dummy inputs with a leading dimension of 4: (4, 224, 224, 5) float images,
# (4, 512) float text features and (4, 2) integer pixel coordinates.
init_img = jnp.ones((4, 224, 224, 5), jnp.float32)
init_text = jnp.ones((4, 512), jnp.float32)
init_pix = jnp.zeros((4, 2), np.int32)
# Only the 'params' collection of the initialised variables is kept.
init_params = TransporterNets().init(key, init_img, init_text, init_pix)['params']
# gpus = jax.devices('gpu')
# init_params = jax.jit(init_params,device=gpus[1])
print(f'Model parameters: {n_params(init_params):,}')
# NOTE(review): flax.optim was deprecated and later removed from Flax in favour of Optax;
# this cell needs a Flax release that still ships it.
optim = flax.optim.Adam(learning_rate=1e-4).create(init_params)


# Name of the checkpoint to restore; it is downloaded with gdown only if it is not already on disk.
ckpt_path = f'ckpt_{40000}'
if not os.path.exists(ckpt_path):
    !gdown --id 1Nq0q1KbqHOA5O7aRSu4u7-u27EMMXqgP
# Restore the saved state, using the freshly created optimizer as the target structure.
optim = checkpoints.restore_checkpoint(ckpt_path, optim)
print('Loaded:', ckpt_path)

In [None]:
def save_img(env, image_path):
    """Write the environment's current top-down camera image to ``image_path``.

    Args:
        env: object exposing ``get_camera_image_top()``.
        image_path: destination file path handed to ``imageio.imsave``.
    """
    imageio.imsave(image_path, env.get_camera_image_top())

In [None]:
print('Initial state:')
plt.imshow(env.get_camera_image())
# Flush the figure now; otherwise the final-state image below is drawn over it
# and the initial state is never shown.
plt.show()

# Execute every planned step in order. The list is deliberately not trimmed:
# the planning loop never appends the termination string, so slicing off the
# last entry would silently skip a real action. If a terminator or an empty
# step does appear, the check below stops execution there.
for step in steps_text:
  if step == '' or step == termination_string:
    break
  print(step)
  nlp_step = step_to_nlp(step)
  print('GPT-3 says next step:', nlp_step)
  # Carry the observation forward so the next step starts from the updated scene.
  obs = run_cliport(env, clip_model, coords, optim, obs, nlp_step)
  # TODO: retry each step until a success detector confirms it (sketch below).
  # success = False
  # while not success:
  #   obs = run_cliport(env,clip_model,coords, optim, obs, nlp_step)
  #   save_img(env, image_path)
  #   category_name_string = ";".join(found_objects)
  #   vild_model.category_name_string = copy.deepcopy(category_name_string)
  #   found_objects, boxes = vild_model.infer(image_path)
  #   success = success_detector(found_objects,boxes,step)

# Show camera image after all steps have run.
print('Final state:')
plt.imshow(env.get_camera_image())
plt.show()

In [None]:
# Build a 25 fps clip from the frames stored in env.cache_video and show it inline,
# autoplaying and looping.
debug_clip = ImageSequenceClip(env.cache_video, fps=25)
display(debug_clip.ipython_display(autoplay=1, loop=1, center=False))
# Empty the frame cache so the next clip contains only newly recorded frames.
env.cache_video = []