In [1]:
import pandas as pd

from actors import VesselCategory
from utils import load_serialized_neighbors, find_and_serialize_neighbors


# Input tables. Forward slashes are used so the paths resolve on every OS
# (the previous raw strings with backslashes only worked on Windows).
# `grid` provides the "lat"/"lon" columns used for the neighbour search,
# `routes` is sampled for start/destination points, and every row of
# `vessels` becomes one ship object.
grid = pd.read_excel("data/velocity.xlsx")
routes = pd.read_excel("data/routes.xlsx")
vessels = pd.read_excel("data/vessels.xlsx")

In [2]:
from envs.actors import ShipTimed, IceBreaker, Geopoint, VesselMoveStatus
time_tick = 10 / 60  # length of one simulation step: 10 minutes as a fraction of an hour

def make_ship(config: dict, random_state=None):
    """Build a vessel object from one row of the vessels table.

    A row whose ``iceclass`` value contains an underscore becomes an
    ``IceBreaker``; every other row becomes a ``ShipTimed``. The start point
    (and, for regular ships, the requested destination) comes from a single
    randomly sampled row of ``routes.iloc[:, 3:]``.

    Args:
        config: Mapping describing one vessel. Must contain ``'iceclass'``
            (a string); ``'name'`` is optional.
        random_state: Forwarded to ``DataFrame.sample``. ``None`` (default)
            keeps the original unseeded behaviour. To get reproducible but
            distinct draws across several calls, share one random generator
            object between the calls; passing the same integer to every call
            would pick the same route row each time.

    Returns:
        An ``IceBreaker`` or ``ShipTimed`` in ``waiting`` status with zero
        average and current speed.
    """
    # One sampled row serves both branches.
    # Presumably laid out as (start lat, start lon, end lat, end lon) — TODO confirm
    # against the routes.xlsx column order.
    start_geo = routes.iloc[:, 3:].sample(1, random_state=random_state).values[0]
    start_point = Geopoint(start_geo[0], start_geo[1])

    # NOTE(review): category and max_speed are hard-coded per branch instead of
    # being read from `config` — confirm this is intended.
    if '_' in config['iceclass']:
        return IceBreaker(
            name=config.get('name'),
            category=VesselCategory.arc92,
            location_point=start_point,
            route_request=None,  # icebreakers have no destination of their own
            status=VesselMoveStatus.waiting,
            max_speed=18.50,
            avg_speed=0.,
            curr_speed=0.,
            tick=time_tick,
        )
    return ShipTimed(
        name=config.get('name'),
        category=VesselCategory.arc6,
        location_point=start_point,
        route_request=Geopoint(start_geo[2], start_geo[3]),
        status=VesselMoveStatus.waiting,
        max_speed=19,
        avg_speed=0.,
        curr_speed=0.,
        tick=time_tick,
    )

In [3]:
# `time_tick` is intentionally not redefined here: make_ship already reads the
# module-level value, and a second copy could silently drift from it, leaving
# the ships and the environment with different step lengths.

# Ships and icebreakers: one object per row of the vessels table.
ships_list = [make_ship(ship_stat) for ship_stat in vessels.to_dict(orient='records')]

# Grid: compute the neighbour structure for the (lat, lon) nodes and store it
# in the "data" folder, from where it is loaded again for the config below.
find_and_serialize_neighbors(grid=grid[["lat", "lon"]].values, static_folder="data")

# Configuration dictionary handed to the environment.
config = {
    'max_episode_steps': 400,
    'neighbors_shape': (1, 32, ),
    'serialized_neighbors': load_serialized_neighbors("data"),
    'grid': grid,
    # Currently disabled option:
    #'routes': routes.iloc[:, 3:],
    'date_start': 2,
    'time_tick': time_tick,
    # NOTE(review): the two counts are hard-coded; confirm they match the
    # contents of ships_list.
    'ice_breaker_counts': 4,
    'ships_count': 42,
    # TODO hard reset
    'ships_list': ships_list
}

In [4]:
from envs.waterworld_multiagent import WaterWorldMultiEnv
import warnings
# Ignore every warning raised from here on.
warnings.simplefilter("ignore")

# Create the multi-agent environment and reset it once; `w` keeps the first
# element of the pair returned by reset().
wwme = WaterWorldMultiEnv(config)
w, _ = wwme.reset()

  from .autonotebook import tqdm as notebook_tqdm
2024-06-16 20:01:04,769	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2024-06-16 20:01:07,391	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


### алгоритм

In [8]:
from ray.rllib.algorithms import PPOConfig


# Stop an algorithm left over from an earlier run of the notebook before a new
# one is configured. On a fresh kernel no `algo` exists yet, so there is
# nothing to do.
if 'algo' in globals():
    try:
        algo.stop()
    except Exception as exc:
        # Best effort: a failed shutdown must not prevent building a new
        # algorithm, but it is reported instead of being silently swallowed.
        # KeyboardInterrupt/SystemExit are no longer caught.
        print(f"Could not stop the previous algorithm cleanly: {exc!r}")

# Keyword arguments for PPOConfig.training().
# Options that are currently switched off (kept for reference):
#   lr=1e-05, num_sgd_iter=22,
#   model: fcnet_activation='relu', fcnet_hiddens=[256, 256], use_lstm=True
training_options = dict(
    train_batch_size=1024,
    sgd_minibatch_size=128,
    model=dict(vf_share_layers=False),
)

# Assemble the PPO configuration one step at a time.
# Evaluation is currently disabled:
#   .evaluation(evaluation_num_workers=0, evaluation_interval=1)
# TODO: separate policies for each ship; with a single shared model the
# trajectories have to line up (all agents must finish at the same time).
ppo_config = PPOConfig()
ppo_config = ppo_config.environment(env=WaterWorldMultiEnv, env_config=config)
ppo_config = ppo_config.framework('torch')
ppo_config = ppo_config.rollouts(num_rollout_workers=0, rollout_fragment_length=8)
ppo_config = ppo_config.training(**training_options)



In [9]:
# Build the algorithm object from the PPO configuration.
algo = ppo_config.build()



2024-06-16 20:01:34,164	INFO tensorboardx.py:45 -- pip install "ray[tune]" to see TensorBoard files.


### обучение

In [None]:
# algo.load_checkpoint(checkpoint_dir='./checkpoint_v1/')
# algo.save_checkpoint(checkpoint_dir='./checkpoint_v1/')

In [None]:
# Run EPOCHS training iterations and keep every result for later inspection.
# NOTE(review): RLlib documents `Algorithm.train()` as the public entry point
# for one training iteration; `step()` is kept here to preserve the original
# call — confirm which one is intended.
EPOCHS = 50
LOG_EVERY = 5  # print a progress line on every LOG_EVERY-th iteration

hist = []
for i in range(EPOCHS):
    result = algo.step()
    hist.append(result)
    if i % LOG_EVERY == 0:
        # Print one compact metric instead of the whole history.
        # Assumes the result is a dict with a top-level 'episode_reward_mean'
        # key — TODO confirm for the installed RLlib version; `.get` yields
        # None when the key is absent.
        print(f"Epoch = {i}\tepisode_reward_mean = {result.get('episode_reward_mean')}")

### инференс

In [10]:
def alg_eval(env, algo, ship_id, max_steps=100):
    """Roll out the policy greedily for a single ship and summarise the run.

    The environment is reset, then for at most ``max_steps`` steps an action
    is computed for ``ship_id`` with exploration disabled and applied through
    ``env.step``. The rollout stops early as soon as either the ship itself or
    the whole episode (``'__all__'``) is reported as terminated or truncated.

    Args:
        env: Multi-agent environment. Must provide ``reset()``, ``step()`` and
            ``envs[ship_id].ship.location_point``.
        algo: Object providing ``compute_single_action(obs, explore=...)``.
        ship_id: Key of the controlled ship in the observation, reward,
            termination and truncation dictionaries and in ``env.envs``.
        max_steps: Upper bound on the number of environment steps
            (defaults to 100, the previously hard-coded value).

    Returns:
        dict with keys ``'total_profit'`` (sum of the ship's rewards),
        ``'actions'`` (actions taken, in order) and ``'cords'`` (the ship's
        ``location_point`` read after every step).
    """
    all_actions = []
    cords = []
    total_profit = 0
    obs, _ = env.reset()
    for _ in range(max_steps):
        action = algo.compute_single_action(obs[ship_id], explore=False)
        obs, profit, term, trunc, _ = env.step({ship_id: action})

        all_actions.append(action)
        total_profit += profit.get(ship_id, 0)
        cords.append(env.envs[ship_id].ship.location_point)

        agent_done = term.get(ship_id, False) or trunc.get(ship_id, False)
        # Fixed: the episode-wide check used to test trunc[ship_id] a second
        # time, so a global truncation never ended the rollout.
        episode_done = term.get('__all__', False) or trunc.get('__all__', False)
        if agent_done or episode_done:
            break

    return {
        'total_profit': total_profit,
        'actions': all_actions,
        'cords': cords,
    }

In [29]:
# Evaluate `algo` for the vessel at index 12: print the vessel's description,
# then display the summary dictionary returned by alg_eval.
ship_id = 12
print(ships_list[ship_id])
alg_eval(wwme, algo, ship_id)

ShipTimed(name='ШТУРМАН АЛЬБАНОВ', category=<VesselCategory.arc6: 6>, location_point=Geopoint(latitude=69.9, longitude=44.6), route_request=Geopoint(latitude=69.15, longitude=57.68), status=<VesselMoveStatus.waiting: 0>, max_speed=19, avg_speed=0.0, curr_speed=0.0, total_time=0.0, tick=0.16666666666666666, _last_x=None, _last_y=None)


{'total_profit': -10.87325686200906,
 'actions': [1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2],
 'cords': [Geopoint(latitude=69.9, longitude=44.6),
  Geopoint(latitude=66.73333333333333, longitude=44.6),
  Geopoint(latitude=66.73333333333333, longitude=47.766666666666666),
  Geopoint(latitude=66.73333333333333, longitude=50.93333333333333),
  Geopoint(latitude=66.73333333333333, longitude=54.099999999999994),
  Geopoint(latitude=66.73333333333333, longitude=57.26666666666666),
  Geopoint(latitude=66.73333333333333, longitude=60.43333333333332),
  Geopoint(latitude=66.73333333333333, longitude=63.59999999999999),
  Geopoint(latitude=66.73333333333333, longitude=66.76666666666665),
  Geopoint(latitude=66.73333333333333, longitude=69.93333333333332),
  Geopoint(latitude=66.73333333333333, longitude=69.93333333333332)]}

In [34]:
# Evaluate `algo` for the vessel at index 41: print the vessel's description,
# then display the summary dictionary returned by alg_eval.
ship_id = 41
print(ships_list[ship_id])
alg_eval(wwme, algo, ship_id)

ShipTimed(name='ТАЙБОЛА', category=<VesselCategory.arc6: 6>, location_point=Geopoint(latitude=69.9, longitude=44.6), route_request=Geopoint(latitude=69.5, longitude=33.75), status=<VesselMoveStatus.waiting: 0>, max_speed=19, avg_speed=0.0, curr_speed=0.0, total_time=0.0, tick=0.16666666666666666, _last_x=None, _last_y=None)


{'total_profit': -9.5,
 'actions': [2, 2, 2, 2, 2, 2, 2, 2, 2],
 'cords': [Geopoint(latitude=69.9, longitude=44.6),
  Geopoint(latitude=69.9, longitude=47.766666666666666),
  Geopoint(latitude=69.9, longitude=50.93333333333333),
  Geopoint(latitude=69.9, longitude=54.099999999999994),
  Geopoint(latitude=69.9, longitude=57.26666666666666),
  Geopoint(latitude=69.9, longitude=60.43333333333332),
  Geopoint(latitude=69.9, longitude=63.59999999999999),
  Geopoint(latitude=69.9, longitude=66.76666666666665),
  Geopoint(latitude=69.9, longitude=66.76666666666665)]}