In [None]:
import gymnasium as gym
import numpy as np
from stable_baselines3 import SAC, TD3, A2C
import matplotlib.pyplot as plt
import pickle
import os
import argparse
import asyncio

In [None]:
# Output locations used by the rest of the notebook: model checkpoints are
# saved under models_dir and logs_dir is passed as the TensorBoard log folder.
models_dir = 'models'
logs_dir = 'logs'

# exist_ok=True keeps this cell safe to re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
for directory in (models_dir, logs_dir):
    os.makedirs(directory, exist_ok=True)

In [None]:
def train(env, sb3_algo, max_iters=4, model=None, curr_name=None):
    if model is None:
        match sb3_algo:
            case 'SAC':
                model = SAC('MlpPolicy', env, verbose=1, tensorboard_log=logs_dir)
            case 'TD3':
                model = TD3('MlpPolicy', env, verbose=1, tensorboard_log=logs_dir)
            case 'A2C':
                model = A2C('MlpPolicy', env, verbose=1, tensorboard_log=logs_dir)
            case _:
                print('Invalid algorithm')
                return
        name = f'{models_dir}/{sb3_algo}'
    else:
        if curr_name is None:
            print('Please provide a name for the model')
            return
        name = curr_name
        model.set_env(env)

    TIMESTEPS = 25000
    iters = 0
    while True and iters < max_iters:
        iters +=1
        model.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False)
        model.save(f'{name}_{TIMESTEPS*iters}')

In [None]:
# Environment without rendering (render_mode=None), intended as the `env`
# argument of train(). The training calls below are commented out; uncomment
# one to train the corresponding algorithm from scratch.
gymenv = gym.make('Humanoid-v4', render_mode=None)
# train(gymenv, 'SAC')
# train(gymenv, 'TD3')
# train(gymenv, 'A2C')

In [None]:

def test(env, sb3_algo, path_to_model):
    match sb3_algo:
        case 'SAC':
            model = SAC.load(path_to_model)
        case 'TD3':
            model = TD3.load(path_to_model)
        case 'A2C':
            model = A2C.load(path_to_model)
        case _:
            print('Invalid algorithm')
            return

    obs = env.reset()[0]
    done = False
    extra_steps = 500
    while True:
        action, _states = model.predict(obs)
        obs, _, done, _, _ = env.step(action)
        
        if done:
            extra_steps -= 1

        if extra_steps < 0:
            break

In [None]:
# Separate environment created with render_mode='human'; it is the one
# passed to test() in the cells below.
gymenv_test = gym.make('Humanoid-v4', render_mode='human')

Na podstawie TensorBoarda stwierdzamy, że SAC jest najlepszy.

In [None]:
# Run the SAC checkpoint stored at models/SAC_100000 in the rendering
# environment (the name train() produces for 'SAC' after 4 rounds of 25000).
test(gymenv_test, 'SAC', 'models/SAC_100000')

Pouczymy model jeszcze przez kolejne 100000 kroków (timesteps), zaczynając od poprzedniego stanu.

In [None]:
# Reload the saved SAC checkpoint and create a fresh non-rendering environment
# so training can continue from it. The train() call is commented out; with
# max_iters=4 it would write checkpoints named models/SAC2_continued_<steps>.
model = SAC.load('models/SAC_100000')
gymenv = gym.make('Humanoid-v4', render_mode=None)
# train(gymenv, 'SAC', model=model, max_iters=4, curr_name='models/SAC2_continued')

In [None]:

# Run the checkpoint written by the continued-training call
# (prefix models/SAC2_continued, suffix 100000) in the rendering environment.
test(gymenv_test, 'SAC', 'models/SAC2_continued_100000')