Skip to content

Commit

Permalink
Merge pull request #1207 from kithib/main
Browse files Browse the repository at this point in the history
Add new Android operations (open app, exit, stop)
  • Loading branch information
geekan committed May 17, 2024
2 parents fa164ac + 2b08ad9 commit d9ed99e
Show file tree
Hide file tree
Showing 5 changed files with 580 additions and 10 deletions.
138 changes: 130 additions & 8 deletions metagpt/environment/android/android_ext_env.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : The Android external environment to integrate with Android apps

import subprocess
import clip
import time
from pathlib import Path
from typing import Any, Optional

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

from PIL import Image
from pydantic import Field

from metagpt.environment.android.text_icon_localization import *
from metagpt.environment.android.const import ADB_EXEC_FAIL
from metagpt.environment.android.env_space import (
EnvAction,
Expand All @@ -17,6 +23,20 @@
EnvObsValType,
)
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.logs import logger
from metagpt.utils.common import download_model
from metagpt.const import DEFAULT_WORKSPACE_ROOT


def load_cv_model(device: str = "cpu") -> tuple[Any, Any, Any]:
    """Load the OCR pipelines and the GroundingDINO detector.

    Args:
        device: Device string forwarded to ``load_model`` when loading the
            GroundingDINO checkpoint.

    Returns:
        A ``(ocr_detection, ocr_recognition, groundingdino_model)`` tuple.
    """
    ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
    ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo")
    # Use the `resolve` endpoint: a Hugging Face `blob` URL serves the HTML
    # file-viewer page, not the raw checkpoint bytes.
    file_url = "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth"
    target_folder = Path(f"{DEFAULT_WORKSPACE_ROOT}/weights")
    file_path = download_model(file_url, target_folder)
    # The model is switched to eval mode before being handed to callers.
    groundingdino_model = load_model(file_path, device=device).eval()
    return ocr_detection, ocr_recognition, groundingdino_model


class AndroidExtEnv(ExtEnv):
Expand All @@ -25,26 +45,29 @@ class AndroidExtEnv(ExtEnv):
xml_dir: Optional[Path] = Field(default=None)
width: int = Field(default=720, description="device screen width")
height: int = Field(default=1080, description="device screen height")
# Model handles are annotated with typing.Any: the builtin `any` is a
# function, not a type, and must not be used as an annotation.
ocr_detection: Any = Field(default=None, description="ocr detection model")
ocr_recognition: Any = Field(default=None, description="ocr recognition model")
groundingdino_model: Any = Field(default=None, description="clip groundingdino model")

def __init__(self, **data: Any):
    """Initialise the environment, load the CV models and bind to a device.

    The OCR and GroundingDINO models are always loaded. When ``device_id``
    is present in ``data`` the device is checked against ``list_devices()``,
    the screen size is taken from ``device_shape`` unless ``width`` /
    ``height`` were passed explicitly, and the screenshot and XML
    directories are created on the device.

    Raises:
        RuntimeError: if ``device_id`` is given but is not a listed device.
    """
    super().__init__(**data)
    (
        self.ocr_detection,
        self.ocr_recognition,
        self.groundingdino_model,
    ) = load_cv_model()

    target_device = data.get("device_id")
    if not target_device:
        # Nothing to bind to; keep the default field values.
        return

    if target_device not in self.list_devices():
        raise RuntimeError(f"device-id: {target_device} not found")

    # Explicit width/height passed by the caller win over the probed shape.
    probed_width, probed_height = self.device_shape
    self.width = data.get("width", probed_width)
    self.height = data.get("height", probed_height)

    for device_dir in (self.screenshot_dir, self.xml_dir):
        self.create_device_path(device_dir)

def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
self,
*,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
super().reset(seed=seed, options=options)

Expand Down Expand Up @@ -149,14 +172,26 @@ def get_screenshot(self, ss_name: str, local_save_dir: Path) -> Path:
ss_remote_path = Path(self.screenshot_dir).joinpath(f"{ss_name}.png")
ss_cmd = f"{self.adb_prefix_shell} screencap -p {ss_remote_path}"
ss_res = self.execute_adb_with_cmd(ss_cmd)

time.sleep(0.1)
res = ADB_EXEC_FAIL
if ss_res != ADB_EXEC_FAIL:
ss_local_path = Path(local_save_dir).joinpath(f"{ss_name}.png")
pull_cmd = f"{self.adb_prefix} pull {ss_remote_path} {ss_local_path}"
pull_res = self.execute_adb_with_cmd(pull_cmd)
time.sleep(0.1)
if pull_res != ADB_EXEC_FAIL:
res = ss_local_path
else:
ss_cmd = f"{self.adb_prefix_shell} rm /sdcard/{ss_name}.png"
ss_res = self.execute_adb_with_cmd(ss_cmd)
time.sleep(0.1)
ss_cmd = f"{self.adb_prefix_shell} screencap -p /sdcard/{ss_name}.png"
ss_res = self.execute_adb_with_cmd(ss_cmd)
time.sleep(0.1)
ss_cmd = f"{self.adb_prefix} pull /sdcard/{ss_name}.png {self.screenshot_dir}"
ss_res = self.execute_adb_with_cmd(ss_cmd)
image_path = Path(f"{self.screenshot_dir}/{ss_name}.png")
res = image_path
return Path(res)

@mark_as_readable
Expand Down Expand Up @@ -224,7 +259,94 @@ def user_swipe(self, x: int, y: int, orient: str = "up", dist: str = "medium", i
return swipe_res

@mark_as_writeable
def user_swipe_to(self, start: tuple[int, int], end: tuple[int, int], duration: int = 400) -> str:
    """Swipe from ``start`` to ``end`` on the device.

    Args:
        start: ``(x, y)`` where the swipe begins.
        end: ``(x, y)`` where the swipe ends.
        duration: Appended as the last argument of the swipe command
            (presumably milliseconds -- TODO confirm).

    Returns:
        Whatever ``execute_adb_with_cmd`` returns for the swipe command.
    """
    from_x, from_y = start[0], start[1]
    to_x, to_y = end[0], end[1]
    swipe_cmd = f"{self.adb_prefix_si} swipe {from_x} {from_y} {to_x} {to_y} {duration}"
    return self.execute_adb_with_cmd(swipe_cmd)

@mark_as_writeable
def user_exit(self) -> str:
    """Start the HOME-category MAIN intent on the device.

    Returns:
        Whatever ``execute_adb_with_cmd`` returns for the ``am start`` command.
    """
    home_intent = "am start -a android.intent.action.MAIN -c android.intent.category.HOME"
    return self.execute_adb_with_cmd(f"{self.adb_prefix_shell} {home_intent}")

def _ocr_text(self, text: str) -> list:
    """Take a screenshot and run OCR on it looking for ``text``.

    Args:
        text: The text passed to ``ocr`` as the search target.

    Returns:
        ``[in_coordinate, out_coordinate, x, y, iw, ih, image]`` where the
        two coordinate values come from ``ocr``, ``x, y`` come from
        ``device_shape``, ``iw, ih`` are the screenshot size and ``image``
        is the screenshot path.
    """
    image = self.get_screenshot("screenshot", self.screenshot_dir)
    # Image.open is lazy and keeps the file handle open; the context
    # manager releases it as soon as the size has been read.
    with Image.open(image) as screenshot:
        iw, ih = screenshot.size
    x, y = self.device_shape
    if iw > ih:
        # Screenshot is wider than tall: swap both pairs so the device
        # shape and the image size stay in the same orientation.
        x, y = y, x
        iw, ih = ih, iw
    in_coordinate, out_coordinate = ocr(image, text, self.ocr_detection, self.ocr_recognition, iw, ih)
    return [in_coordinate, out_coordinate, x, y, iw, ih, image]

@mark_as_writeable
def user_open_app(self, app_name: str) -> str:
    """Tap the app whose label matches ``app_name`` on the current screen.

    The first OCR match is used. The tap lands above the centre of the
    matched text by ``round(50 / y, 2)`` of the screen height (presumably
    to hit the icon above its label -- TODO confirm).

    Returns:
        ``"no app here"`` when OCR finds no match, otherwise the result of
        ``system_tap``.
    """
    matches, _, dev_w, dev_h, img_w, img_h = self._ocr_text(app_name)[:6]
    if len(matches) == 0:
        logger.info(f"No App named {app_name}.")
        return "no app here"

    first_box = matches[0]
    center_x = (first_box[0] + first_box[2]) / 2
    center_y = (first_box[1] + first_box[3]) / 2
    # Express the centre as a fraction of the screenshot, rounded to 2 decimals.
    ratio_x = round(center_x / img_w, 2)
    ratio_y = round(center_y / img_h, 2)
    upward_shift = round(50 / dev_h, 2)
    return self.system_tap(ratio_x * dev_w, (ratio_y - upward_shift) * dev_h)

@mark_as_writeable
def user_click_text(self, text: str) -> Optional[str]:
    """Tap the centre of ``text`` when OCR finds it exactly once.

    Args:
        text: The on-screen text to look for.

    Returns:
        The result of ``system_tap`` when exactly one match is found;
        ``None`` when the text is missing or ambiguous (a message is
        logged in both cases).
    """
    in_coordinate, out_coordinate, x, y, iw, ih = self._ocr_text(text)[:6]
    match_count = len(out_coordinate)

    if match_count == 0:
        logger.info(
            f"Failed to execute action click text ({text}). The text \"{text}\" is not detected in the screenshot.")
        return None
    if match_count > 1:
        logger.info(
            f"Failed to execute action click text ({text}). There are too many text \"{text}\" in the screenshot.")
        return None

    # NOTE(review): the match count is taken from out_coordinate but the box
    # is read from in_coordinate -- presumably parallel lists; confirm in ocr().
    box = in_coordinate[0]
    tap_coordinate = [(box[0] + box[2]) / 2, (box[1] + box[3]) / 2]
    # Express the centre as a fraction of the screenshot, rounded to 2 decimals.
    tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
    return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y)

@mark_as_writeable
def user_stop(self):
    """Log that the task finished; no command is sent to the device."""
    completion_message = "Successful execution of tasks"
    logger.info(completion_message)

@mark_as_writeable
def user_click_icon(self, icon_shape_color: str) -> str:
    """Tap the icon that best matches the ``icon_shape_color`` description.

    Icons are detected on a fresh screenshot with the GroundingDINO model.
    A single detection is tapped directly; otherwise each candidate is
    cropped to a temp directory and CLIP picks the crop matching
    ``icon_shape_color``.

    Args:
        icon_shape_color: Text description passed to ``clip_for_icon``.

    Returns:
        The result of ``system_tap`` on the centre of the chosen box.
    """
    image = self.get_screenshot("screenshot", self.screenshot_dir)
    # Image.open is lazy and keeps the file handle open; close it once the
    # size has been read.
    with Image.open(image) as screenshot:
        iw, ih = screenshot.size
    x, y = self.device_shape
    if iw > ih:
        # Screenshot is wider than tall: swap both pairs so the device
        # shape and the image size stay in the same orientation.
        x, y = y, x
        iw, ih = ih, iw

    in_coordinate, out_coordinate = det(image, "icon", self.groundingdino_model)  # detect icons
    if len(out_coordinate) == 1:  # only one icon
        final_box = in_coordinate[0]
    else:
        temp_file = Path(f"{DEFAULT_WORKSPACE_ROOT}/temp")
        temp_file.mkdir(parents=True, exist_ok=True)
        hash_table, clip_filter = [], []
        # Candidates keep their original detection index as the crop file
        # name; hash_table maps the CLIP result back to the box.
        for i, (td, _) in enumerate(zip(in_coordinate, out_coordinate)):
            if crop_for_clip(image, td, i, temp_file):
                hash_table.append(td)
                clip_filter.append(temp_file.joinpath(f"{i}.png"))
        # `device` was an undefined name here and raised NameError.
        # NOTE(review): "cpu" mirrors load_cv_model's default device;
        # confirm if GPU inference is wanted.
        clip_model, clip_preprocess = clip.load("ViT-B/32", device="cpu")
        clip_filter = clip_for_icon(clip_model, clip_preprocess, clip_filter, icon_shape_color)
        final_box = hash_table[clip_filter]

    tap_coordinate = [(final_box[0] + final_box[2]) / 2, (final_box[1] + final_box[3]) / 2]
    # Express the centre as a fraction of the screenshot, rounded to 2 decimals.
    tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
    logger.debug(f"icon tap coordinate: {tap_coordinate[0] * x}, {tap_coordinate[1] * y}")
    return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y)
43 changes: 43 additions & 0 deletions metagpt/environment/android/grounding_dino_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Flat, module-level hyperparameters for the model named "groundingdino".
# NOTE(review): presumably read when the GroundingDINO detector is built;
# confirm against the model loader before changing any value.
batch_size = 1
modelname = "groundingdino"
# Backbone and position-embedding settings.
backbone = "swin_T_224_1k"
position_embedding = "sine"
pe_temperatureH = 20
pe_temperatureW = 20
return_interm_indices = [1, 2, 3]
backbone_freeze_keywords = None
# Encoder/decoder sizes (enc_*/dec_* and related dimensions).
enc_layers = 6
dec_layers = 6
pre_norm = False
dim_feedforward = 2048
hidden_dim = 256
dropout = 0.0
nheads = 8
num_queries = 900
query_dim = 4
num_patterns = 0
num_feature_levels = 4
enc_n_points = 4
dec_n_points = 4
# two_stage_* options.
two_stage_type = "standard"
two_stage_bbox_embed_share = False
two_stage_class_embed_share = False
transformer_activation = "relu"
dec_pred_bbox_embed_share = True
# dn_* options (presumably denoising-training settings -- TODO confirm).
dn_box_noise_scale = 1.0
dn_label_noise_ratio = 0.5
dn_label_coef = 1.0
dn_bbox_coef = 1.0
embed_init_tgt = True
dn_labelbook_size = 2000
# Text-encoder and text/image fusion options.
max_text_len = 256
text_encoder_type = "bert-base-uncased"
use_text_enhancer = True
use_fusion_layer = True
use_checkpoint = True
use_transformer_ckpt = True
use_text_cross_attention = True
text_dropout = 0.0
fusion_dropout = 0.0
fusion_droppath = 0.1
sub_sentence_present = True

0 comments on commit d9ed99e

Please sign in to comment.