In [2]:
import time
from design import *
import importlib
import shutil
from utils import *
from openai import OpenAI
from prompts import *
import json
import numpy as np
from gymnasium.envs.robodesign.GPTCheetah import GPTCheetahEnv
import os

In [None]:
import prompts
class DGA:
    def __init__(self):
        api_key = "<api_key>"
        self.client = OpenAI(api_key=api_key)
        self.model = "gpt-4o-mini"
        
    def extract_code(self, text):
        match = re.search(r'```python\n(.*?)\n```', text, re.DOTALL)
        return match.group(1).strip() if match else None

    def indent_code(self, code):
        return "\n".join(line if line.strip() else line for line in code.split("\n"))

    def generate_rewardfunc(self, rewardfunc_nums, folder_name):

        messages = [
            {"role": "system", "content": "You are a reinforcement learning reward function designer"},
            {"role": "user", "content": rewardfunc_prompts + zeroshot_rewardfunc_format}
        ]

        responses = self.client.chat.completions.create(
            model=self.model, messages=messages, n=rewardfunc_nums
        )
        files = []
        for i, choice in enumerate(responses.choices):
            reward_code = self.extract_code(choice.message.content)
            if reward_code:
                full_code = self.indent_code(reward_code) + "\n"
                file_name =  f"GPTCheetah_{i}.py"
                file_path = os.path.join(folder_name, "env", file_name)
                with open(file_path, "w") as fp:
                    fp.write(full_code)

                with open(file_path, "w") as fp:
                    fp.write(full_code)
                files.append(file_path)
                print(f"Saved: {file_path}")
        return files
    
    def generate_rewardfunc_div(self, rewardfunc_nums, folder_name):

        messages = [
            {"role": "system", "content": "You are a reinforcement learning reward function designer"},
            {"role": "user", "content": rewardfunc_prompts + zeroshot_rewardfunc_format}
        ]

        # 生成初始 Reward Function
        response = self.client.chat.completions.create(
            model=self.model, messages=messages, n=1, timeout=10
        )

        rewardfunc_files = []

        initial_code = self.extract_code(response.choices[0].message.content)
        if initial_code:
            reward_code = "import numpy as np\n" + self.indent_code(initial_code) + "\n"

            file_path = os.path.join(folder_name, "env", "GPTrewardfunc_0.py")
            with open(file_path, "w") as fp:
                fp.write(reward_code)
            rewardfunc_files.append(file_path)
            print(f"initial Saved: {file_path}")
        messages.append({"role": "assistant", "content": initial_code})

        # 生成不同的多样化 Reward Functions
        for i in range(1, rewardfunc_nums):
            diverse_messages = messages + [
                {"role": "user", "content": rewardfunc_div_prompts + zeroshot_rewardfunc_format}
            ]
            # print(diverse_messages)
            response = self.client.chat.completions.create(
                model=self.model, messages=diverse_messages, n=1
            )
            diverse_code = self.extract_code(response.choices[0].message.content)
            messages.append({"role": "assistant", "content": diverse_code})

            if diverse_code:
                reward_code =  "import numpy as np\n" + self.indent_code(diverse_code) + "\n"
                file_path = os.path.join(folder_name, "env", f"GPTrewardfunc_{i}.py")
                with open(file_path, "w") as fp:
                    fp.write(reward_code)
                rewardfunc_files.append(file_path)
                print(f"Saved: {file_path}")

        return rewardfunc_files


    def generate_morphology(self, morphology_nums, folder_name):
        messages = [
            {"role": "system", "content": "You are a helpful mujoco robot designer"},
            {"role": "user", "content": morphology_prompts + morphology_format}
        ]
        
        responses = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            response_format={'type': 'json_object'},
            n=morphology_nums                                                                                                                                                                                                                                                                                   
        )

        # 解析所有 response 里的参数
        for i, choice in enumerate(responses.choices):
            print(f"Response {i}:")
            print(json.dumps(choice.message.content, indent=4))

        parameter_list = [json.loads(choice.message.content).get('parameters', []) for choice in responses.choices]
        material_list = [compute_cheetah_volume(parameter) for parameter in parameter_list]

        xml_files = []
        for i, parameter in enumerate(parameter_list):
            if not isinstance(parameter, list):
                print(f"Skipping invalid parameter {i}: {parameter}")
                continue

            xml_file = cheetah_design(parameter)  
            filename = f"GPTCheetah_{i}.xml"
            file_path = os.path.join(folder_name, "assets", filename)
            xml_files.append(file_path)
            with open(file_path, "w") as fp:
                fp.write(xml_file)
            print(f"Successfully saved {filename}")
            
        return xml_files, material_list, parameter_list
    
    def generate_morphology_div(self, morphology_nums, folder_name):

        material_list = []
        xml_files = []
        parameter_list = []
        
        # 生成初始 morphology
        messages = [
            {"role": "system", "content": "You are a helpful mujoco robot designer"},
            {"role": "user", "content": morphology_prompts + morphology_format}
        ]
        
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            response_format={'type': 'json_object'},
            n=1
        )

        initial_parameter = json.loads(response.choices[0].message.content)
        parameter_list.append(initial_parameter['parameters'])
        material_list.append(compute_cheetah_volume(initial_parameter['parameters']))
        messages.append({"role": "assistant", "content": json.dumps(initial_parameter)})
        logging.info(f"generate initial_parameter{initial_parameter['parameters']}" )
        xml_file = cheetah_design(initial_parameter['parameters'])  

        filename = f"GPTCheetah_0.xml"
        file_path = os.path.join(folder_name, "assets", filename)
        with open(file_path, "w") as fp:
            fp.write(xml_file)

        xml_files.append(file_path)

        # 生成不同的多样化设计
        for i in range(1, morphology_nums):
            diverse_messages = messages + [
                {"role": "user", "content": morphology_div_prompts + morphology_format}
            ]
            
            response = self.client.chat.completions.create(
                model=self.model,
                messages=diverse_messages,
                response_format={'type': 'json_object'},
                n=1
            )

            diverse_parameter = json.loads(response.choices[0].message.content)
            material_list.append(compute_cheetah_volume(diverse_parameter['parameters'])) 
            parameter_list.append(diverse_parameter['parameters'])
            messages.append({"role": "assistant", "content": json.dumps(diverse_parameter)})
            logging.info(f"generate diverse_parameter{ diverse_parameter['parameters']}")
            xml_file = cheetah_design(diverse_parameter['parameters'])  
            filename = f"GPTCheetah_{i}.xml"
            file_path = os.path.join(folder_name, "assets", filename)
            with open(file_path, "w") as fp:
                fp.write(xml_file)
            xml_files.append(file_path)

        return xml_files, material_list, parameter_list


    def improve_rewardfunc(self, best_rewardfunc, rewardfunc_list, fitness_list, folder_name, rewardfunc_index, morphology_index, iteration):
        reward_improve_prompts = prompts.reward_improve_prompts

        for reward_filename, fitness in zip(rewardfunc_list, fitness_list):
            with open(reward_filename, 'r') as f:
                reward_content = f.read()
            reward_improve_prompts += f"reward function:\n{reward_content}\nfitness: {fitness}\n"
            
        with open(best_rewardfunc, 'r') as f:
            best_reward_content = f.read()

        reward_improve_prompts += f"This is best reward function, please carefully review it :\n{best_reward_content}\nbest fitness: {max(fitness_list)}"
        # print(reward_improve_prompts)
        messages = [
            {"role": "system", "content": "You are a reinforcement learning reward function designer"},
            {"role": "user", "content": rewardfunc_prompts + rewardfunc_format}
        ]

        response = self.client.chat.completions.create(
            model=self.model, messages=messages
        )

        # print(response)
        reward_code = self.extract_code(response.choices[0].message.content)

        if reward_code:
            full_code = "import numpy as np \n" + self.indent_code(reward_code) + "\n"
            file_name =  f"GPTSwimmer_refine_{rewardfunc_index}_{morphology_index}_{iteration}.py"
            file_path = os.path.join(folder_name, "env", file_name)
            with open(file_path, "w") as fp:
                fp.write(full_code)

        return file_path

    def improve_morphology(self, best_parameter, parameter_list, fitness_list, folder_name, rewardfunc_index, morphology_index, iteration):
        morphology_improve_prompts = prompts.morphology_improve_prompts
        for parameter_content, fitness in zip(parameter_list, fitness_list):
            morphology_improve_prompts = morphology_improve_prompts + f"parameter:{parameter_content} \n" + f"fintess:{fitness}"
        morphology_improve_prompts = morphology_improve_prompts + f"best parameter:{best_parameter} \n" + f"best fintess:{max(fitness_list)}" 

        messages = [
            {"role": "system", "content": "You are a helpful mujoco robot designer"},
            {"role": "user", "content": morphology_improve_prompts + morphology_format}
        ]
        
        responses = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            response_format={'type': 'json_object'},
        )
        # print(responses)
        parameter = json.loads(responses.choices[0].message.content).get('parameters', []) 
        print(parameter)
        xml_file = cheetah_design(parameter)  
        filename = f"GPTCheetah_refine_{rewardfunc_index}_{morphology_index}_{iteration}.xml"
        file_path = os.path.join(folder_name, "assets", filename)

        with open(file_path, "w") as fp:
            fp.write(xml_file)

        print(f"Successfully saved {filename}")
        return file_path, parameter


# Configuration

In [4]:

folder_name = "results/Random_m25_r5"
log_file = os.path.join(folder_name, "parameters.log")
logging.basicConfig(filename=log_file, level=logging.INFO, format="%(asctime)s - %(message)s")

# folder_name = setup_logging(div_flag=True)

best_fitness = float('-inf')  
best_morphology = None  
best_rewardfunc = None  
best_reward = None
best_material = None
best_efficiency = None

morphology_nums = 25
rewardfunc_nums = 5

fitness_matrix = np.array([[None for _ in range(morphology_nums)] for _ in range(rewardfunc_nums)])
efficiency_matrix = np.array([[None for _ in range(morphology_nums)] for _ in range(rewardfunc_nums)])
fitness_list = []
designer = DGA()


# print configuration info

In [4]:
logging.info(f"start!")

In [8]:
designer = DGA()
morphology_list, material_list, parameter_list = designer.generate_morphology_div(morphology_nums, folder_name)

In [9]:
designer = DGA()
rewardfunc_list = designer.generate_rewardfunc_div(rewardfunc_nums, folder_name)

initial Saved: results/Div_m50_r10\env\GPTrewardfunc_0.py
Saved: results/Div_m50_r10\env\GPTrewardfunc_1.py
Saved: results/Div_m50_r10\env\GPTrewardfunc_2.py
Saved: results/Div_m50_r10\env\GPTrewardfunc_3.py
Saved: results/Div_m50_r10\env\GPTrewardfunc_4.py
Saved: results/Div_m50_r10\env\GPTrewardfunc_5.py
Saved: results/Div_m50_r10\env\GPTrewardfunc_6.py
Saved: results/Div_m50_r10\env\GPTrewardfunc_7.py
Saved: results/Div_m50_r10\env\GPTrewardfunc_8.py
Saved: results/Div_m50_r10\env\GPTrewardfunc_9.py
Saved: results/Div_m50_r10\env\GPTrewardfunc_10.py


# enter coarse optimization stage

In [7]:
morphology_list = [f'results/Random_m25_r5/assets/GPTCheetah_{i}.xml' for i in range(0,25) ]
rewardfunc_list = [f'results/Random_m25_r5/env/GPTrewardfunc_{i}.py' for i in range(0,5)]

parameter_list = np.array([[-0.5999967320245699, 0.49084053751554363, 0.6511745036991378, 0.1614083168735254, 0.18830934934836743, -0.11349168768545505, -0.3475450932508038, -0.10357132549708853, -0.0744812872354166, -0.31010973094184535, -0.17759714219158085, -0.16225314007868108, 0.23512113856022238, -0.028947983847077252, 0.10684882355721198, -0.15868590213791245, 0.016905166658045104, 0.018187539672510133, 0.048757529286982945, 0.072896505140696, 0.12790632987128808, 0.0722737813430202, 0.06012375155551704, 0.06125134139068388], [-0.5231398071058978, 0.6282396166313019, 0.8738025202027287, 0.23883778214376641, 0.14078411248327305, -0.2713951519567949, -0.31908999003349814, 0.017574979957207665, -0.11010272159891002, -0.27951753816070857, -0.23678062294768046, -0.21318739742586781, 0.21597124820200372, -0.0858867464655582, 0.09850072824867669, -0.14420757279358135, 0.06477312079162964, 0.12933514174783511, 0.039537803469185014, 0.04594792892714418, 0.1763562338584656, 0.1461928071143182, 0.05914706793250758, 0.10492367486360436], [-0.5639770701416063, 0.538126285619105, 0.7396743651706692, 0.21406694962648923, 0.24539857425921646, -0.15544747913926332, -0.20308094569412477, -0.05176623954213938, -0.10115706092077395, -0.07027539998578439, 0.10453720585798545, -0.123454089673045, 0.059108333374711435, -0.20314653245097497, 0.0654298550281028, -0.24277138580771684, 0.3560156682904343, 0.037417826800615714, 0.026268435465092943, 0.07267995915436117, 0.09157452396559769, 0.005509090591452645, 0.2355573684067816, 0.20971361423906065], [-0.4787653805360263, 0.49744618804164686, 1.0997386796431923, 0.02447160032778184, 0.02579303707483044, -0.3009992684923034, -0.10318091617557881, -0.15459888648939418, 0.05460397275736416, -0.27339956168678714, -0.11797053783729211, -0.28278525792287146, 0.008082623790246171, -0.2329960970423259, 0.04862671104263769, -0.10994860739188304, 0.08706753684852597, 0.05134914398795494, 0.004674490654662321, 0.08788217326642542, 0.07745032234576407, 0.030424072579060768, 0.0847250084813206, 0.04295458861275301], [-0.5377005746194281, 0.5391952269055693, 0.9233258650328013, 0.16424987743121344, 0.041973691903114646, -0.15870677059838506, -0.22537432417377332, -0.0561899850967089, 0.13350178767884235, -0.06445309567676993, -0.2796106684716757, -0.2529353930573327, 0.08989005058130407, -0.21114795134002232, 0.05705864555675591, -0.09987379830421819, 0.18746317007882057, 0.04428721811959234, 0.024177319893483432, 0.013828969595350273, 0.0821069106285293, 0.03552861200609467, 0.13867369867305013, 0.03933125723296012], [-0.5728232274006421, 0.4814066737100376, 0.7790485872746069, 0.3208155359211797, 0.28982981764352134, -0.18735289793053558, -0.24919477364700987, -0.037206506004074755, -0.019806383198797337, -0.13912125714620582, -0.43125776023888734, -0.4130451799449614, 0.10082161498274095, -0.3561249187087504, 0.16685950555235213, -0.17743398401303417, 0.16408144943042235, 0.15001566355839774, 0.007532207781009728, 0.08761472402623861, 0.0916405891096633, 0.09269716506406613, 0.15842613604281078, 0.2300653924283344], [-0.8396726328954147, 0.6447948556922497, 0.8652800763254329, 0.08960399824969899, 0.22567450191232485, -0.2846782267274637, -0.31791003506588933, -0.1796513570411655, 0.09012792414691141, -0.14216124172673383, -0.09614339873700589, -0.25189389087961367, 0.09108765587333673, -0.21748354020204166, 0.1823572900008653, -0.14559221461897556, 0.16320527308165078, 0.01528047113426513, 0.22494525765210294, 0.06281137703015634, 0.14861620751502896, 0.11925533153932184, 0.13955226385884875, 0.17248083333165437], [-0.7785514798033687, 0.5388019514792561, 0.8055142893829808, 0.28425542993411534, 0.2633490211735402, -0.1024188249937622, -0.3621871407058539, -0.19186030788216618, 0.08149269196101677, -0.07461191026431617, -0.2761326393329029, -0.1873893706340301, 0.11628257666053043, -0.1773665874689818, 0.0575493025838696, -0.10341366263002547, 0.1111311257790452, 0.06058254044930694, 0.06332510249356711, 0.023869765436529952, 0.18377139917029733, 0.016256065393853103, 0.06792025238952128, 0.04474332521008602], [-0.599442737858779, 0.8057422673289204, 0.8055929833744548, 0.3426795887676316, 0.07015445199918519, -0.3389802949641335, -0.4833656630294757, -0.2091533825426115, -0.03989154757023648, -0.17761647053127266, -0.2427417801895812, -0.02985211584079872, 0.0766332918890465, -0.11470105047241236, -0.011146896241956375, -0.16565172278546225, 0.03164795828373698, 0.23747552105764574, 0.16628122330670653, 0.013170635809467937, 0.12192985755747018, 0.03318152714316672, 0.12043117292851145, 0.05457504836566177], [-0.6000072688861986, 0.6492994381816628, 0.762362632049048, 0.35816052309911106, 0.0997387988647219, -0.17897569390313908, -0.3561536063859976, 0.08555844223100798, -0.1607796018681618, -0.2087404872497639, -0.1701066969588162, -0.13348611446630493, 0.20823347938858944, -0.22280309850992802, -0.03273985947286191, -0.023677816564153797, 0.16845996280823167, 0.053151763795165466, 0.010172249229903597, 0.04503266567025078, 0.22971409951860383, 0.1643885389739264, 0.021355919347901736, 0.017114490485100293], [-0.7443487600265639, 0.6396012270473944, 0.8855029279562611, 0.19521354816121822, 0.06915544104321246, -0.31162289137967547, -0.041986491176203594, -0.3363197030153834, 0.15908199362058736, -0.09556748858682693, -0.20534965190072318, -0.3220284422078046, 0.07163505622827657, -0.2563558791696804, 0.021011485688540714, -0.3467816720410479, 0.03082898429812239, 0.021334817379955295, 0.007230075083742085, 0.06785675644697822, 0.12580335711358065, 0.06010519939089749, 0.09983789709669573, 0.0623602580164525], [-0.4137651143784302, 0.5152343897768821, 0.8524658188461917, 0.11259040959016636, 0.15720808083425236, -0.31470263738147497, -0.17522730464258102, -0.06091639156120622, -0.012579148650063918, -0.24650924294949217, -0.18525407057978663, -0.2228621039905708, 0.18245550516611725, -0.35014103365628396, 0.035698369926271636, -0.11169361347336072, 0.20043024048875174, 0.04812627138856683, 0.013601431711506871, 0.13919549710588403, 0.009658033776319165, 0.0981221431118132, 0.13716635707531866, 0.09987270873012905], [-0.5470509700592858, 0.6604838187789683, 0.6140846683751607, 0.38253389419069905, 0.20255304294727453, -0.2177371718991925, -0.2669221726704434, -0.216426310160282, 0.0253094532576292, -0.08899751530019472, -0.12857879061893063, -0.2230762622714756, 0.10792762876576534, -0.2975930129015766, 0.06421548118629024, -0.27259602342566974, 0.17952194062263144, 0.20507091225684015, 0.21349451397540536, 0.10165958987760233, 0.17463308835837166, 0.1491415880057384, 0.06283761621542028, 0.18822816094899736], [-0.5635540777798165, 0.4586746318054781, 0.7488697311444383, 0.3448633296438979, 0.30677572045144086, -0.09743882865349218, -0.4267834244949335, -0.18354783462902993, 0.07564787967027853, -0.1508595424512788, -0.18444582049451017, -0.2887644762625777, 0.02598757885586772, -0.30424555586728463, 0.08085886791299333, -0.0690747149093837, 0.0912986026239513, 0.01350381848096395, 0.09776786917087994, 0.09193439634527677, 0.02477762592968455, 0.06647232075126229, 0.09320571846622495, 0.011695502465502099], [-0.6567725509961825, 0.40165308587162163, 0.8292129305710837, 0.17418664121084318, 0.33453833089471996, -0.2952601718883504, -0.2696342689662281, -0.35980401783269933, 0.08857112674491696, -0.17697683933164893, -0.048721019120439274, -0.050641147372706946, 0.11342088447616887, -0.09197768741416128, -0.03385452026999908, -0.05397082037813076, 0.11449031335306394, 0.09074172603169922, 0.04426423080094815, 0.13280485631354882, 0.01118211156803868, 0.11221142517866818, 0.16023922570041793, 0.13294044256282372], [-0.5729572879927375, 0.6102931392056209, 0.8454598884474464, 0.3070315608146493, 0.36598274778036344, -0.17968779115334205, -0.3125300243291903, -0.025688165450229283, 0.14877693116291574, -0.1897435130713629, -0.17254798277899905, -0.16048406987248037, 0.13144616973582982, -0.09260077259333309, 0.12325835119585414, 0.023912582279511138, 0.07368368786790719, 0.03595560302089988, 0.13904488447439234, 0.10988518760542643, 0.03722631565552115, 0.008553416480552398, 0.006106930808757252, 0.1577657908128085], [-0.5881204685791764, 0.711502581939375, 0.8648994529979136, 0.22349650361697232, 0.23620823809489114, -0.24521312060744133, -0.28559738181855615, -0.07582424179113001, 0.04694758559139188, -0.21319238365187107, -0.12479690639289075, -0.31865226327770013, 0.14479076351016748, -0.2136581310460674, 0.23082689554656904, -0.20156760204596613, 0.013288028097562847, 0.004757933889477643, 0.06554750625385726, 0.19140918549464042, 0.08221334968720029, 0.07804846205530011, 0.12465977437755328, 0.03238695975833158], [-0.6713967900332695, 0.6646739791344567, 1.012124797207203, 0.332939545412835, 0.1362820490956276, -0.17977012600670766, -0.2859807817175208, -0.07016325517329813, 0.07851049122827754, -0.2399722494962443, -0.06925953625961741, -0.32870605512044104, 0.1368364055489497, -0.2528922068321138, 0.1605775235468842, -0.18373226150147137, 0.11387971273989089, 0.05918456128720251, 0.05144096729927786, 0.19210913174895433, 0.030207349161844147, 0.10832089770538533, 0.020002235853296677, 0.04210904669846152], [-0.5363391735596659, 0.49863079864798826, 0.8068861988816155, 0.32655610189357936, 0.14175500483541714, -0.17354721688579522, -0.3238597258216502, -0.14520731936183304, 0.014325611889942656, -0.22170354478420776, -0.1847883277953605, -0.25681516668790483, 0.19752275233883643, -0.09707926361046612, 0.281232824079261, -0.16811281611524237, 0.0658029217815687, 0.06769770171051287, 0.09846958761778708, 0.09125553714940861, 0.003184789273309077, 0.04807580469985205, 0.06936659529665185, 0.1689511014373181], [-0.42019965985611674, 0.709199603670981, 0.6276960161034528, 0.15230175058293474, 0.20586816600141217, -0.11546419183558677, -0.38157144971052964, 0.002451791552547933, 0.183739529346835, -0.24744360731430887, -0.12873403182656454, -0.0588641646247037, 0.17536492196054645, -0.2279490186162561, 0.033558005444344376, -0.06579088020197386, 0.09292787456121211, 0.014105422481821811, 0.1481623669679546, 0.14722397171308899, 0.027264764936932546, 0.068554796600844, 0.02138824806925929, 0.058360404181433706], [-0.6280774271301617, 0.6595886241927091, 0.9044792323144125, 0.17447170627765773, 0.14886097527741504, -0.22661502685376936, -0.41511185812555385, -0.28286622523015137, 0.09896957667172387, -0.269522154576286, -0.10985201399909467, -0.3443532915129136, -0.0165910967567397, -0.28463282610771085, 0.03888802995615088, -0.027996249953625985, 0.07737659344616331, 0.021053517611094655, 0.014186252516133556, 0.0533711601318735, 0.10668028379497392, 0.010867396761029234, 0.1403221645682666, 0.13446324737098564], [-0.5389060509831135, 0.7404345315826613, 0.8978850594413325, 0.06803543446570681, 0.0355396612316187, -0.21999083815467152, -0.34162475503385276, -0.09654764132716151, 0.08965923902769521, -0.1620884201723471, -0.17707409073790442, -0.3074423544224512, 0.09621885678039335, -0.10781354448707735, 0.1622426681758274, -0.137883963432536, 0.033862743050501194, 0.08426583160669655, 0.023204619944538023, 0.016347570869922834, 0.06791374824804813, 0.104300321836181, 0.03299586198785628, 0.05022321536067799], [-0.5473321864635754, 0.6263109248339055, 0.6504173672715937, 0.269148669241662, 0.16439789366316768, -0.19428458694424344, -0.3174685172365057, -0.0560293425082625, 0.15120642431389425, -0.22802321456681163, -0.2180006169410217, -0.1594125848117451, 0.12555368200395683, -0.32035714537544424, 0.08270292223123339, 0.015706441264956725, 0.11132761693065005, 0.007846702882341287, 0.11707858829492405, 0.043022700783573495, 0.07286179473546275, 0.043310069222946494, 0.23192861181740093, 0.06970529706879817], [-0.49781440690611584, 0.7902426090343498, 0.7644400556120129, 0.235900971199939, 0.006853349063400738, -0.1675184375591376, -0.21051839890622553, -0.1511311320851399, 0.15901789580250747, -0.1194874474703072, -0.3200550379762468, -0.2926979840345907, 0.32761005777511953, -0.11651303789619755, 0.09389943655870538, -0.26878137596260493, 0.048800466586609526, 0.05121592213678731, 0.05916442848619124, 0.04780769391227128, 0.1483752117410613, 0.04648398837882287, 0.14161055604529216, 0.036846011443462515], [-0.4521544289718611, 0.5043160924286081, 0.858065074458278, 0.09219904606179319, 0.31268864243669714, -0.13759847177786538, -0.15992587242882822, -0.09565321307426466, -0.12843299783945294, -0.22527456432883305, -0.004337664563801857, -0.1927863359616791, 0.11997365556895132, -0.2981212507294918, 0.23469788392414437, -0.0789009636774621, 0.046247666733065615, 0.13008058329878092, 0.07484290880484433, 0.142196954641706, 0.10012063657484523, 0.03308160145841314, 0.08231618283087827, 0.1949574463960267], [-0.5077588732786983, 0.6628917083654555, 0.7492151076373185, 0.0434592593384745, 0.17491654593501865, -0.07524492576904165, -0.37201864104028964, -0.16198152831043328, 0.10275968783736665, -0.25117848114640057, -0.0975860955874385, -0.12361710921959418, 0.1608302140359406, -0.07296773670485705, 0.1820086745726789, 0.002444789785485013, 0.13695070370134105, 0.1627832597252991, 0.1831688175064085, 0.019968919199436636, 0.008363621590611478, 0.17341921987109032, 0.05602240998695383, 0.2870512152072602]])

material_list = [compute_cheetah_volume(parameter) for parameter in parameter_list]
parameter_list.shape

(26, 24)

In [9]:
logging.info(f'folder_name:{folder_name}')
logging.info(f'morphology_nums:{morphology_nums}')
logging.info(f'rewardfunc_nums:{rewardfunc_nums}')
logging.info(f'parameter_list:{parameter_list}')
logging.info(f'morphology_list:{morphology_list}')
logging.info(f'material_list:{material_list}')
logging.info(f'_________________________________enter coarse optimization stage_________________________________')

In [None]:
for i, rewardfunc in enumerate(rewardfunc_list):
    for j, morphology in enumerate(morphology_list):
        # if i not in [5]:
        #     continue
        # if j < 24:
        #     continue
        print(i, rewardfunc)
        print(j, morphology)
        shutil.copy(morphology, "GPTCheetah.xml")
        shutil.copy(rewardfunc, "GPTrewardfunc.py")         

        import GPTrewardfunc
        importlib.reload(GPTrewardfunc)  # 重新加载模块
        from GPTrewardfunc import _get_rew
        GPTCheetahEnv._get_rew = _get_rew

        model_path = Train(j,  i, folder_name, total_timesteps=5e5)
        # model_path = f"results/Div_m50_r10/SAC_morphology{j}_rewardfunc{i}_500000.0steps"
        fitness, reward = Eva(model_path)
        material = material_list[j]
        efficiency = fitness/material
        fitness_matrix[i][j] = fitness
        efficiency_matrix[i][j] = efficiency
        
        logging.info("___________________finish coarse optimization_____________________")
        logging.info(f"morphology: {j}, rewardfunc: {i}, material cost: {material} reward: {reward} fitness: {fitness} efficiency: {efficiency}")

        if fitness > best_fitness:
            best_fitness = fitness
            best_morphology = morphology
            best_efficiency = efficiency
            best_rewardfunc = rewardfunc
            best_material = material

0 results/Random_m25_r5/env/GPTrewardfunc_0.py
0 results/Random_m25_r5/assets/GPTCheetah_0.xml




0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
0 results/Random_m25_r5/env/GPTrewardfunc_0.py
1 results/Random_m25_r5/assets/GPTCheetah_1.xml
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
0 results/Random_m25_r5/env/GPTrewardfunc_0.py
2 results/Random_m25_r5/assets/GPTCheetah_2.xml
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79


In [16]:
fitness_matrix = np.array([[105.74359155598759, 212.16867845995213, 9.290034086806587,
        213.1820754226635, 269.06758857919203, 110.76376441773273,
        227.2402234229139, 321.25174702953274, 310.0811819898056,
        207.08820466541306, 157.71647852423584, 28.08470929895438,
        94.86410707841213, 253.05393914809926, 225.60727228903642,
        291.2915379578935, 247.53037681549188, 252.75206057050585,
        130.7053159556534, 172.40977205595496, 164.125278051098,
        66.37884303244769, 59.953225963208205, 218.33154323364386,
        112.65127918940892],
       [105.6446095028666, 207.8307572778086, 5.633397217174836,
        209.52737837995343, 256.27071821418525, 72.18829568623416,
        267.1473503660532, 322.9122158475668, 279.144256206135,
        205.0025752142732, 167.69240416813489, 27.376293246634024,
        89.68133817497701, 228.19161294164874, 233.2704445704459,
        298.8228611890579, 249.82103821343117, 259.9490427890532,
        178.7277534791408, 169.82829760017438, 176.13625411283317,
        78.68592541253466, 48.308555471960155, 149.70598918140874,
        89.87861758359404],
       [125.17124788689142, 214.2513368583829, 4.868965549513075,
        202.41627693346328, 254.4737214394124, 69.22400606334585,
        336.63737802502794, 319.4380007302696, 309.3335305279719,
        208.34085037176607, 171.05853721775173, 29.658131947429982,
        86.31347647843566, 249.72544832912862, 212.2872505956001,
        267.5210731621721, 230.80415113617448, 263.14848577469314,
        136.5487390614856, 160.9237609675422, 172.46141597580115,
        75.74329257272174, 58.77683787805382, 199.8957066710696,
        99.78584956921254],
       [112.8205538474683, 224.97327674499462, 21.70075190116985,
        214.41106661725283, 244.3399597095655, 99.22655751185818,
        295.1417738423967, 353.2173743183276, 308.898540963276,
        208.92277655194516, 162.43759882605116, 25.48123817719553,
        86.7254319171109, 238.21218470292064, 230.40583318427772,
        281.393058589059, 238.75803987061042, 259.3941457213482,
        138.2146037936391, 165.60970035355496, 163.23004333747045,
        96.56578439025952, 54.460860389646335, 179.2580279667622,
        118.19592407825623],
       [117.60862070463945, 212.21016572672113, 21.712176086295955,
        208.11996689703884, 240.25351420479572, 81.64362509108406,
        333.49311971772744, 328.9152796334914, 301.68046041508757,
        213.3196349745808, 150.41631228898237, 26.622859737286884,
        92.04154281573032, 236.9872183076211, 227.26118507933046,
        -12.051344701205648, -2.180247818943431, -1.3176172236223274,
        123.50802782421957, 167.675083226677, 160.72618149409607,
        73.79990332360931, 56.57118383831371, 212.0702707048405,
        93.52341280601371]], dtype=object)

In [17]:
efficiency_matrix= np.array([[2156.697086350333, 1189.4477551713346, 11.487063556590321,
        3546.531444728028, 1445.5754298063487, 312.82675558925774,
        617.8954380396254, 2755.7227262262363, 1039.4635114564423,
        755.0957052037743, 3007.957513486344, 119.36055242053637,
        188.1324618206517, 3254.151031213589, 1376.2503955314417,
        2727.651219419602, 2306.321460378778, 1578.600804187415,
        1193.9094823516284, 1567.4984676003303, 1815.46429369211,
        1280.2776158158292, 311.0763274660367, 2471.540013775534,
        638.0586820838062],
       [2154.6782944554693, 1165.1287527169513, 6.965657097545765,
        3485.731314336072, 1376.8237771983415, 203.87922394796564,
        726.4080565893023, 2769.966358207562, 935.7558136774174,
        747.4909754039933, 3198.217661159728, 116.34977063004507,
        177.8541057291869, 2934.4335641176126, 1422.9973100981133,
        2798.1744592615305, 2327.664301643333, 1623.5506332502412,
        1632.5637414055072, 1544.0284681604367, 1948.3231587700539,
        1517.6496664182262, 250.65620373170188, 1694.6902728010537,
        509.0739554450976],
       [2552.9345243542384, 1201.1234340452536, 6.020442573137122,
        3367.429881926116, 1367.1693465076366, 195.50727026591298,
        915.3603927248697, 2740.1642679681754, 1036.9572116259394,
        759.6631666607767, 3262.416312389655, 126.04762881195992,
        171.1750346710671, 3211.3482522195054, 1294.9955452862177,
        2505.064810831267, 2150.477746440982, 1643.5332330307647,
        1247.2854158738176, 1463.071064412974, 1907.6741039620397,
        1460.8938269930732, 304.9724610047189, 2262.844068714276,
        565.1886789469378],
       [2301.035516033483, 1261.2321523621258, 26.83283118073301,
        3566.9771407757057, 1312.7253421385744, 280.24257046428096,
        802.5284999510817, 3029.926388594838, 1035.4990264582223,
        761.7850159476968, 3097.99838566912, 108.29575028937742,
        171.99201585842297, 3063.293261145323, 1405.5225961260965,
        2634.9619517873098, 2224.5867286099397, 1620.0849406052707,
        1262.5020249726242, 1505.6742342870064, 1805.5616956501526,
        1862.5062830352865, 282.5783628565424, 2029.222998878504,
        669.4636411384853],
       [2398.690699559608, 1189.6803386831175, 26.846957107441096,
        3462.317389549329, 1290.7707646729978, 230.58362530427564,
        906.810749379011, 2821.4611110138003, 1011.3023586551369,
        777.8170681675966, 2868.7292598346135, 113.14766378498095,
        182.5348129338457, 3047.5407869025366, 1386.339601028774,
        -112.8486783390227, -20.31408184507945, -8.229375476213677,
        1128.16685754323, 1524.452082259429, 1777.8653418223828,
        1423.4108333042627, 293.52809337779524, 2400.6616360436888,
        529.7181350124886]], dtype=object)
mean = np.mean(efficiency_matrix)

std = np.std(efficiency_matrix)

print("平均值：", mean)
print("标准差：", std)

平均值： 1451.4350419276939
标准差： 1027.8972760848528


In [18]:
none_coords = np.argwhere(efficiency_matrix == None)
print(none_coords)

[]


In [19]:
efficiency_matrix_select = efficiency_matrix[:10, :50]
efficiency_matrix_select.shape

(5, 25)

# print coarse optimization info

In [20]:
logging.info(f'_________________________________end coarse optimization stage_________________________________')
logging.info(f"Stage1: Final best morphology: {best_morphology}, Fitness: {best_fitness}, best_efficiency: {best_efficiency}, best reward function: {best_rewardfunc}, Material cost: {best_material}, Reward: {best_reward}")
logging.info(f'folder_name:{folder_name}')
logging.info(f'parameter_list:{parameter_list}')
logging.info(f'fitness_matrix:{fitness_matrix}')
logging.info(f'efficiency_matrix:{efficiency_matrix}')
logging.info(f'_________________________________enter fine optimization stage_________________________________')

# configuration of fine optimization

In [21]:
# 获取矩阵中所有非 None 的值和它们的坐标
all_values_with_coords = []
for i in range(len(efficiency_matrix_select)):
    for j in range(len(efficiency_matrix_select[0])):
        value = efficiency_matrix_select[i][j]
        if value is not None:
            all_values_with_coords.append(((i, j), value))

# 按值降序排序
sorted_values = sorted(all_values_with_coords, key=lambda x: x[1], reverse=True)


top_k = max(1, int(len(sorted_values) * 0.1))

efficiency_coarse_best = [coord for coord, val in sorted_values[:top_k]]

logging.info(f"fitness_coarse_best {efficiency_coarse_best}")
logging.info(f"fitness_coarse_best values {sorted_values[:top_k]}")


In [22]:
coarse_best = efficiency_coarse_best
coarse_best

[(3, 3),
 (0, 3),
 (1, 3),
 (4, 3),
 (2, 3),
 (2, 10),
 (0, 13),
 (2, 13),
 (1, 10),
 (3, 10),
 (3, 13),
 (4, 13)]

# enter fine optimization stage

In [23]:
final_optimized_results = []  # 用来记录每个 coarse_best 的最优结果

for rewardfunc_index, morphology_index in coarse_best:
    
    morphology = morphology_list[morphology_index]
    parameter = parameter_list[morphology_index]
    rewardfunc = rewardfunc_list[rewardfunc_index]
    
    best_efficiency = efficiency_matrix_select[rewardfunc_index][morphology_index]
    best_fitness = fitness_matrix[rewardfunc_index][morphology_index]
    best_morphology = morphology
    best_parameter = parameter
    best_rewardfunc = rewardfunc
    best_material = compute_cheetah_volume(parameter)
    
    
    logging.info(f"Initial morphology:{morphology}")
    logging.info(f"Initial parameter:{parameter}" )
    logging.info(f"Initial rewardfunc:{rewardfunc}" )
    logging.info(f"Initial fitness:{best_fitness}" )
    logging.info(f"Initial efficiency:{best_efficiency}" )
    iteration = 0

    while True:
        improved = False  # 标记是否有改进，方便控制循环

        designer = DGA()

        # -------- 优化 morphology --------
        improved_morphology, improved_parameter = designer.improve_morphology(
            best_parameter,
            parameter_list,
            efficiency_matrix_select[rewardfunc_index, :],
            folder_name,
            rewardfunc_index, 
            morphology_index,
            iteration
            
        )

        shutil.copy(improved_morphology, "GPTCheetah.xml")
        shutil.copy(best_rewardfunc, "GPTrewardfunc.py")
        
        import GPTrewardfunc
        importlib.reload(GPTrewardfunc)  # 重新加载模块
        from GPTrewardfunc import _get_rew
        GPTCheetahEnv._get_rew = _get_rew
        
        model_path = Train(morphology_index, rewardfunc_index, folder_name, stage='fine', total_timesteps=5e5)
        improved_fitness, _ = Eva(model_path)
        improved_material = compute_cheetah_volume(improved_parameter)
        improved_efficiency = improved_fitness / improved_material

        if improved_efficiency > best_efficiency:

            best_fitness = improved_fitness
            best_morphology = improved_morphology
            best_parameter = improved_parameter
            best_material = improved_material
            best_efficiency = improved_efficiency
            improved = True
            iteration +=1
            logging.info(f"Morphology optimization improved iteration {iteration}: material={improved_material}, fitness={improved_fitness}, efficiency={improved_efficiency}")

        # -------- 没有进一步改进，跳出循环 --------
        if not improved:
            logging.info("Not improved Morphology!")
            logging.info("____________________________________________")
            break
            
        
        # -------- 优化 reward function --------
        improved_rewardfunc = designer.improve_rewardfunc(
            best_rewardfunc,
            rewardfunc_list,
            efficiency_matrix_select[:, morphology_index],
            folder_name,
            rewardfunc_index, 
            morphology_index,
            iteration
        )

        shutil.copy(best_morphology, "GPTCheetah.xml")
        shutil.copy(improved_rewardfunc, "GPTrewardfunc.py")
        
        import GPTrewardfunc
        importlib.reload(GPTrewardfunc)  # 重新加载模块
        from GPTrewardfunc import _get_rew
        GPTCheetahEnv._get_rew = _get_rew
        
        model_path = Train(morphology_index, rewardfunc_index, folder_name, stage='fine', total_timesteps=5e5)
        improved_fitness, _ = Eva(model_path)
        improved_material = compute_cheetah_volume(best_parameter)
        improved_efficiency = improved_fitness / improved_material


        if improved_efficiency > best_efficiency:
            best_fitness = improved_fitness
            best_rewardfunc = improved_rewardfunc
            best_material = improved_material
            best_efficiency = improved_efficiency
            improved = True
            iteration +=1
            logging.info(f"Reward optimization improved iteration {iteration}: material={improved_material}, fitness={improved_fitness}, efficiency={improved_efficiency}")
        
        if not improved:
            logging.info("Not improved Reward!")
            logging.info("____________________________________________")
            break
        
            
    # 保存当前 coarse_best 的最终最优结果
    final_optimized_results.append({
        "best_morphology": best_morphology,
        "best_parameter": best_parameter,
        "best_rewardfunc": best_rewardfunc,
        "best_fitness": best_fitness,
        "best_material": best_material,
        "best_efficiency": best_efficiency,
        "best_iteration":iteration
    })

    logging.info(f"Final optimized result: rewardfunc_index{rewardfunc_index} morphology_index{morphology_index}")
    logging.info(f"  Morphology: {best_morphology}")
    logging.info(f"  Parameter: {best_parameter}")
    logging.info(f"  Rewardfunc: {best_rewardfunc}")
    logging.info(f"  Fitness: {best_fitness}")
    logging.info(f"  Material: {best_material}")
    logging.info(f"  Efficiency: {best_efficiency}")
    logging.info("____________________________________________")

[-0.47876538, 0.49744619, 1.09973868, 0.0244716, 0.02579304, -0.30099927, -0.10318092, -0.15459889, 0.05460397, -0.27339956, -0.11797054, -0.28278526, 0.00808262, -0.2329961, 0.04862671, -0.10994861, 0.08706754, 0.05134914, 0.00467449, 0.08788217, 0.07745032, 0.03042407, 0.08472501, 0.04295459]
Successfully saved GPTCheetah_refine_3_3_0.xml




0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
[-0.47876538, 0.49744619, 1.09973868, 0.0244716, 0.02579304, -0.30099927, -0.10318092, -0.15459889, 0.05460397, -0.27339956, -0.11797054, -0.28278526, 0.00808262, -0.2329961, 0.04862671, -0.10994861, 0.08706754, 0.05134914, 0.00467449, 0.08788217, 0.07745032, 0.03042407, 0.08472501, 0.04295459]
Successfully saved GPTCheetah_refine_3_3_1.xml
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28


In [24]:

logging.info(f"{final_optimized_results}")

# logging.info(f"fine optimization end: best material cost: {best_material}  fitness: {improved_fitness} merterial_efficiency: {improved_material_efficiency}")

In [25]:
final_optimized_results

[{'best_morphology': 'results/Random_m25_r5/assets/GPTCheetah_refine_3_3_0.xml',
  'best_parameter': [-0.47876538,
   0.49744619,
   1.09973868,
   0.0244716,
   0.02579304,
   -0.30099927,
   -0.10318092,
   -0.15459889,
   0.05460397,
   -0.27339956,
   -0.11797054,
   -0.28278526,
   0.00808262,
   -0.2329961,
   0.04862671,
   -0.10994861,
   0.08706754,
   0.05134914,
   0.00467449,
   0.08788217,
   0.07745032,
   0.03042407,
   0.08472501,
   0.04295459],
  'best_rewardfunc': 'results/Random_m25_r5/env/GPTrewardfunc_3.py',
  'best_fitness': 217.91934798441721,
  'best_material': 0.06011001952157502,
  'best_efficiency': 3625.3414941281194,
  'best_iteration': 1},
 {'best_morphology': 'results/Random_m25_r5/assets/GPTCheetah_3.xml',
  'best_parameter': array([-0.47876538,  0.49744619,  1.09973868,  0.0244716 ,  0.02579304,
         -0.30099927, -0.10318092, -0.15459889,  0.05460397, -0.27339956,
         -0.11797054, -0.28278526,  0.00808262, -0.2329961 ,  0.04862671,
         -0