In [3]:
import os, re
import argparse
import torch
import datetime
import numpy as np
import random
import multiprocessing

from agents import PPO, PG, QPO, QPPO, SPSA
from envs import ToyEnv, RSMA,RSMA_NO_RIS,NOMA,TDMA


class Options(object):
    """Hyper-parameter container for a single training run.

    The constructor builds an ``argparse`` parser whose defaults are the
    experiment configuration; ``parse`` turns it into a namespace and adds
    the run-specific fields ``seed``, ``device`` and ``path``.

    The parser is always evaluated against an empty argument list, so the
    interpreter's own ``sys.argv`` is never consulted and every option
    takes its default value.
    """

    def __init__(self, algo_name):
        """Register the options shared by all algorithms plus those of ``algo_name``.

        Args:
            algo_name: Algorithm identifier. ``'PPO'``, ``'QPPO'`` and
                ``'SPSA'`` register additional algorithm-specific options;
                any other value registers only the common ones.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument('--env_name', type=str, default='ToyEnv')
        parser.add_argument('--algo_name', type=str, default=algo_name)
        parser.add_argument('--q_alpha', type=float, default=0.25)
        parser.add_argument('--est_interval', type=int, default=100)
        parser.add_argument('--log_interval', type=int, default=100)
        parser.add_argument('--max_episode', type=int, default=150000)
        # ``type=list`` would split a command-line string into single
        # characters; ``nargs='+'`` with ``type=int`` yields a list of ints
        # and leaves the default untouched.
        parser.add_argument('--emb_dim', type=int, nargs='+', default=[128, 128])
        parser.add_argument('--init_std', type=float, default=np.sqrt(1e-1))
        parser.add_argument('--gamma', type=float, default=0.99)

        # lr = a / (b + episode) ** c
        # With these defaults the schedule starts at 1e-3 (theta) and
        # 1e-2 (q) at episode 0.
        parser.add_argument('--theta_a', type=float, default=(10000**0.9)*1e-3)
        parser.add_argument('--theta_b', type=float, default=10000)
        parser.add_argument('--theta_c', type=float, default=0.9)
        parser.add_argument('--q_a', type=float, default=(10000**0.6)*1e-2)
        parser.add_argument('--q_b', type=float, default=10000)
        parser.add_argument('--q_c', type=float, default=0.6)

        if algo_name in ('PPO', 'QPPO'):
            # PPO and QPPO share these options; only ``upd_step`` differs.
            parser.add_argument('--lambda_gae_adv', type=float, default=0.95)
            parser.add_argument('--clip_eps', type=float, default=0.2)
            parser.add_argument('--vf_coef', type=float, default=0.5)
            parser.add_argument('--ent_coef', type=float, default=0.00)
            parser.add_argument('--upd_interval', type=int, default=2000)
            parser.add_argument('--upd_step', type=int,
                                default=10 if algo_name == 'PPO' else 5)
            parser.add_argument('--mini_batch', type=int, default=100)
        if algo_name == 'QPPO':
            parser.add_argument('--T', type=int, default=10)
            parser.add_argument('--T0', type=int, default=5)
        if algo_name == 'SPSA':
            parser.add_argument('--spsa_batch', type=int, default=5)
            parser.add_argument('--perturb_c', type=float, default=1.9)
            parser.add_argument('--perturb_gamma', type=float, default=1/6)
        self.parser = parser

    def parse(self, seed=0, device='0'):
        """Return the option namespace for one run and create its log directory.

        Args:
            seed: Stored as ``args.seed`` and appended to the log path.
            device: CUDA device index (string or int). Used only when
                ``torch.cuda.is_available()``; otherwise the CPU is selected.

        Returns:
            The parsed namespace with ``seed``, ``device`` and ``path`` added.

        Raises:
            FileExistsError: If the log directory for this environment,
                algorithm, timestamp and seed already exists.
        """
        args = self.parser.parse_args(args=[])
        args.seed = seed
        args.device = torch.device("cuda:" + str(device) if torch.cuda.is_available() else "cpu")

        # Month-day-hour-minute-second as ten digits. Formatting explicitly
        # avoids slicing ``str(datetime)``, which drops real digits whenever
        # the microsecond field is zero and therefore omitted from the string.
        current_time = datetime.datetime.now().strftime('%m%d%H%M%S')
        args.path = './logs/' + args.env_name + '/' + args.algo_name + '_' + current_time + '_' + str(args.seed)
        os.makedirs(args.path)
        return args


def run(algo_name, seed, device):
    """Train one agent of the requested algorithm on the RSMA environment.

    Args:
        algo_name: One of ``'PPO'``, ``'QPPO'``, ``'PG'``, ``'QPO'`` or
            ``'SPSA'``.
        seed: Seed applied to ``random``, NumPy and PyTorch.
        device: CUDA device index; converted to ``str`` before being handed
            to ``Options.parse``.

    Raises:
        ValueError: If ``algo_name`` is not a supported algorithm.
    """
    agent_classes = {
        'PPO': PPO,
        'QPPO': QPPO,
        'PG': PG,
        'QPO': QPO,
        'SPSA': SPSA,
    }
    # Reject unknown names up front instead of silently falling back to
    # SPSA: Options registers the SPSA-specific arguments only when the name
    # is exactly 'SPSA', and checking here avoids creating a stray log
    # directory for a run that cannot start.
    if algo_name not in agent_classes:
        raise ValueError(
            'Unknown algo_name %r; expected one of %s'
            % (algo_name, sorted(agent_classes))
        )

    args = Options(algo_name).parse(seed, str(device))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Environment under study. The other imported environments can be
    # swapped in here, e.g. ToyEnv(n=10), RSMA(n_users=...),
    # NOMA(n_users=...), TDMA(n_users=...) or RSMA_NO_RIS(n_users=...).
    # NOTE(review): the meaning of the two positional arguments is defined
    # by envs.RSMA and is not visible here -- confirm.
    # NOTE(review): args.env_name is not consulted, so the log path built by
    # Options.parse keeps its default environment name -- confirm intended.
    env = RSMA(4, 2)

    agent = agent_classes[args.algo_name](args, env)
    print(args.algo_name + ' running')
    agent.train()


if __name__ == '__main__':
    # Number of independent runs; each gets its own worker process and seed.
    n = 1
    # Algorithm for every run; use 'SPSA' or 'QPPO' here to train those instead.
    algos = ['PPO'] * n
    seeds = list(range(n))
    # Every run is placed on CUDA device 0.
    devices = [0] * n
    jobs = list(zip(algos, seeds, devices))
    # The context manager tears the workers down even when starmap raises
    # (for example on KeyboardInterrupt), which a trailing close()/join()
    # pair would never reach. starmap blocks until every run has finished.
    with multiprocessing.Pool(processes=n) as pool:
        pool.starmap(run, jobs)





8 128
128 128
PPO running




Epi:00100 || disc_a_r:-131.218 disc_q_r:-134.785
Epi:00200 || disc_a_r:-130.938 disc_q_r:-134.567
Epi:00222 || model Updated with lr:9.80e-04
Epi:00300 || disc_a_r:-129.851 disc_q_r:-133.127
Epi:00400 || disc_a_r:-129.577 disc_q_r:-132.749
Epi:00445 || model Updated with lr:9.62e-04
Epi:00500 || disc_a_r:-127.089 disc_q_r:-131.093
Epi:00600 || disc_a_r:-122.435 disc_q_r:-130.953
Epi:00668 || model Updated with lr:9.43e-04
Epi:00700 || disc_a_r:-117.105 disc_q_r:-125.351
Epi:00800 || disc_a_r:-111.414 disc_q_r:-122.473
Epi:00891 || model Updated with lr:9.26e-04
Epi:00900 || disc_a_r:-114.274 disc_q_r:-123.462
Epi:01000 || disc_a_r:-144.725 disc_q_r:-160.953
Epi:01100 || disc_a_r:-144.417 disc_q_r:-154.532


Process ForkPoolWorker-6:
Traceback (most recent call last):


KeyboardInterrupt: 

  File "/home/eb_khosravi/.conda/envs/eb_khosravi/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/eb_khosravi/.conda/envs/eb_khosravi/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/eb_khosravi/.conda/envs/eb_khosravi/lib/python3.12/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
                    ^^^^^^^^^^^^^^^^^^^
  File "/home/eb_khosravi/.conda/envs/eb_khosravi/lib/python3.12/multiprocessing/pool.py", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_2276862/3615699706.py", line 135, in run
    agent.train()
  File "/home/eb_khosravi/agents/ppo.py", line 101, in train
    self.update()
  File "/home/eb_khosravi/agents/ppo.py", line 179, in update
    logprobs, state_values, dist_entropy = self.evaluate(old_states[idx], 