In [None]:
# import os
# os.chdir('/content/drive/My Drive/Projects/cell-free-run')

# import numpy as np
# from utils import Environment
# power = 30
# env = Environment(10**(power/10))
# np.random.seed(1949)
# G_valid, F_valid, H_valid = [], [], []
# for i in range(501):
#   G, F, H = env.getCSI()
#   G_valid.append(G)
#   F_valid.append(F)
#   H_valid.append(H)
# G_valid = np.stack(G_valid)
# F_valid = np.stack(F_valid)
# H_valid = np.stack(H_valid)

# path = "./data"
# np.save(path+"/G_%s"%(str(power)), G_valid)
# np.save(path+"/F_%s"%(str(power)), F_valid)
# np.save(path+"/H_%s"%(str(power)), H_valid)


In [None]:
# Show the attached GPU(s); `train` falls back to CPU when CUDA is unavailable.
!nvidia-smi

In [None]:
import os
os.chdir('/content/drive/My Drive/Projects/cell-free-run')
import time
import random
import math

from IPython.display import clear_output
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F 
from tqdm import tqdm
import matplotlib.pyplot as plt
%config InlineBackend.figure_format='retina'
%matplotlib inline
import seaborn as sns
sns.set()

from utils import Environment, Memory
from net import TD3, OUStrategy


def explore(csi, state, timestep):
    """Take one noisy policy step in the environment and store the transition.

    Reads the globals ``model``, ``OUnoise``, ``env`` and ``memory``.

    Parameters:
    csi : current CSI (as produced by ``env.reset`` / ``env.getCSI``)
    state : current state (as produced by ``env.reset`` / ``env.getState``)
    timestep : current timestep, forwarded to the OU exploration noise

    Returns:
    (csi_next, state_next, avg_rate) -- the freshly sampled CSI, its state,
    and the rate achieved by the action taken on ``csi``.
    """
    model.eval()
    state_tensor = torch.tensor(state, dtype=torch.float, device=model.device).unsqueeze(0)
    # deterministic policy output -> add OU exploration noise
    raw1, raw2 = model.actor(state_tensor)
    noisy1, noisy2 = OUnoise.getActFromRaw(raw1, raw2, timestep)
    noisy2 = model.act2Trans(noisy2, env)
    act1, act2 = (a.squeeze(0).detach().cpu().numpy() for a in (noisy1, noisy2))
    avg_rate = env.getRate(csi, act1, act2)
    # sample the next channel realisation and derive its state
    csi_next = env.getCSI()
    state_next = env.getState(csi_next, None, None)
    # the replay buffer stores the concatenated (act1, transformed act2) action
    memory.push(state, np.concatenate((act1, act2)), avg_rate, state_next)
    return csi_next, state_next, avg_rate

def validation(G_total, F_total, H_total, act1_init=None, act2_init=None):
    """Evaluate the current deterministic policy on a fixed validation set.

    Reads the globals ``model`` and ``env``. For each channel triple the
    actor's (noise-free) action is applied, the resulting ``env.getRate`` is
    accumulated, and the largest per-BS value of ``sum(W**2)`` produced by
    ``env.actTrans`` is tracked so the caller can check the power limit.

    Parameters:
    G_total, F_total, H_total : sequences of channel matrices, iterated in
        lockstep with ``zip`` (the shortest one bounds the sample count)
    act1_init, act2_init : unused; kept for backward compatibility

    Returns:
    (mean_rate, max_power)

    Raises:
    ValueError if the validation set is empty.
    """
    model.eval()
    n_bs, n_ris, n_user = env.getCount()
    power, rate = 0, 0
    n_samples = 0
    # Pure inference: disable autograd so no graph is built per sample
    # (the caller in `train` runs this with gradients enabled).
    with torch.no_grad():
        # channel names avoid shadowing the module-level `F` (torch.nn.functional)
        for G_ch, F_ch, H_ch in zip(G_total, F_total, H_total):
            csi = (G_ch, F_ch, H_ch)
            state = env.getState(csi, None, None)
            act1, act2 = model.actor(torch.tensor(state, dtype=torch.float, device=model.device).unsqueeze(0))
            act2 = model.act2Trans(act2, env)
            act1, act2 = act1.squeeze(0).detach().cpu().numpy(), act2.squeeze(0).detach().cpu().numpy()
            rate += env.getRate(csi, act1, act2)
            n_samples += 1
            # check the output beamforming and reflecting matrix
            W, Phi = env.actTrans(act1, act2)
            W = W.reshape(n_bs, n_user, 2, env.M)
            # per-BS power: sum over every axis except the first (BS) one
            power = max(power, np.amax(np.sum(W**2, (1, 2, 3))))
            if n_samples in (100, 200):
                print('validate action\n', act1[:5], '\n', act2[:5])
    if n_samples == 0:
        raise ValueError("validation set is empty")
    print("maximum used power:", power)
    return rate/n_samples, power

def save_opt(fpath):
    """Save the actor and critic optimizer states to ``<fpath>/optim.bin``.

    Reads the globals ``opt_actor`` and ``opt_critic``; the file layout
    (keys ``'opt_act'`` / ``'opt_critic'``) is what ``load_opt`` expects.
    The directory is created if missing: ``train`` calls this before
    ``model.save``, so on a fresh run ``torch.save`` would otherwise fail on
    a non-existent ``fpath``.
    """
    state_dicts = {'opt_act': opt_actor.state_dict(), 'opt_critic': opt_critic.state_dict()}
    os.makedirs(fpath, exist_ok=True)
    torch.save(state_dicts, fpath+'/optim.bin')

def load_opt(fpath):
    """Restore the actor/critic optimizer states written by ``save_opt``.

    Storages are mapped back as-is (to CPU), then copied into the globals
    ``opt_actor`` and ``opt_critic`` via ``load_state_dict``.

    Parameters:
    fpath : directory containing ``optim.bin``
    """
    checkpoint = torch.load(fpath + '/optim.bin',
                            map_location=lambda storage, loc: storage)
    for optimizer, key in ((opt_actor, 'opt_act'), (opt_critic, 'opt_critic')):
        optimizer.load_state_dict(checkpoint[key])

def numTrans(number):
    """Return a coarse, positive scale for ``number``.

    Only the magnitude matters: magnitudes below 1 give 1, magnitudes in
    [1, 10) are floored to an int, and larger magnitudes are rounded down to
    a multiple of ten (keeping the input's numeric type).
    """
    magnitude = abs(number)
    if magnitude < 1:
        return 1
    if magnitude < 10:
        return math.floor(magnitude)
    return (magnitude // 10) * 10

In [None]:
def train(epoches, steps, dBm, out_path='./checkpoint', **args):
  """Train a TD3 agent on ``Environment`` and keep the best checkpoint.

  Side effects: rebinds the module-level globals ``env``, ``model``,
  ``memory``, ``OUnoise``, ``OUnoise_target``, ``opt_critic`` and
  ``opt_actor`` (the helpers ``explore``, ``validation``, ``save_opt`` and
  ``load_opt`` read them), and writes the best model/optimizer state to
  ``out_path``.

  Parameters:
  epoches : number of epochs; both OU noise processes are reset each epoch
  steps : environment + optimisation steps per epoch
  dBm : power setting; the environment gets ``10**(dBm/10)`` and the
      validation data is read from ``./data/{G,F,H}_<dBm>.npy``
  out_path : directory for the best model (``model.save``) and optimizer
      state (``save_opt``)
  **args : must contain lr_critic, lr_actor, critic_hidden_size,
      actor_hidden_size, n_critic_hidden, n_actor_hidden,
      critic_weight_decay, actor_weight_decay

  Returns:
  (best_score, trial) : best validation rate that stayed within
      ``env.power_max`` (0. if none did) and the number of non-improving
      validations (after the 5 warm-up epochs) since the last improvement.
  """
  global env, model, memory, OUnoise, OUnoise_target, opt_critic, opt_actor
  # -------- environment --------
  env = Environment(10**(dBm/10))
  print("environment initialized...")

  # -------- buffer --------
  memory = Memory(int(1e4))
  print("buffer initialized...")

  # -------- validation data --------
  G_valid = np.load("./data/G_%s.npy"%(str(dBm)))
  F_valid = np.load("./data/F_%s.npy"%(str(dBm)))
  H_valid = np.load("./data/H_%s.npy"%(str(dBm)))

  # ---- model hyperparameters ----
  # Seed every RNG in use. numpy is included because the data-generation
  # cell seeds numpy before env.getCSI, i.e. the environment presumably
  # samples channels with numpy.
  torch.manual_seed(2020)
  random.seed(2020)
  np.random.seed(2020)
  batch_size = 64
  reward_decay = 0.1   # discount factor applied to the target critic
  policy_upfreq = 2    # actor / target update every `policy_upfreq` critic updates
  sync_rate = 0.001    # soft-update rate passed to model.sync
  lr_decay = 0.999
  grad_clip = 10
  lr_c = args['lr_critic']
  lr_a = args['lr_actor']
  c_h_size = args['critic_hidden_size']
  a_h_size = args['actor_hidden_size']
  n_c_hidden = args['n_critic_hidden']
  n_a_hidden = args['n_actor_hidden']
  c_w_decay = args['critic_weight_decay']
  a_w_decay = args['actor_weight_decay']
  
  # -------- create model --------
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  torch.backends.cudnn.benchmark = True
  model = TD3(env.state_size,
         (env.act_space['dim1'], env.act_space['dim2']),
         a_h_size,
         c_h_size, 
         n_c_hidden, 
         n_a_hidden,
         max_act=(env.act_space['bound1'], env.act_space['bound2']))
  model.to(device)
  model.init()
  # model.load(out_path)
  
  opt_critic = torch.optim.Adam(model.critic.parameters(),
                  weight_decay=c_w_decay,
                  lr=lr_c)
  opt_actor = torch.optim.Adam(model.actor.parameters(),
                  weight_decay=a_w_decay,
                  lr=lr_a)
  # load_opt(out_path)
  # lr_c = opt_critic.param_groups[0]['lr']
  # lr_a = opt_actor.param_groups[0]['lr']

  # -------- noise --------
  OUnoise = OUStrategy(model.device, env.act_space, max_sigma=(1/100, 1/20))
  OUnoise_target = OUStrategy(model.device, env.act_space, theta=1, max_sigma=(1/200, 1/40), noise_bound=(1/100, 1/20))
  print("noise initialized...")

  # -------- explore --------
  # warm up the replay buffer with 1000 transitions before any update
  with torch.no_grad():
    csi_, state_ = env.reset()
    for t in tqdm(range(1000), desc='exploring'):
      csi_, state_, _ = explore(csi_, state_, t)
  
  # -------- training --------
  info_interval = 100
  validate_interval = 1000
  start_time, train_time = time.time(), time.time()
  iter_n, critic_count, trial = 0, 0, 0
  best_score = 0.
  print("now begin to train networks ......")

  csi_explore, state_explore = env.reset()
  for epoch in range(epoches):
    if epoch % 5 == 0: clear_output()
    rate_info, loss_c_info, loss_a_info, q_info = [], [], [], []
    OUnoise.reset()
    OUnoise_target.reset()

    for step in range(steps):
      iter_n += 1
      # ---- explore the environment ----
      with torch.no_grad():
        csi_explore, state_explore, rate_explore = explore(csi_explore, state_explore, step)
      rate_info.append(rate_explore)
          
      # ---- get batch transition memory ----
      state, action, reward, state_next = memory.getBatch(batch_size)
      state = torch.tensor(state, dtype=torch.float, device=model.device)
      action = torch.tensor(action, dtype=torch.float, device=model.device)
      reward = torch.tensor(reward, dtype=torch.float, device=model.device).view(-1, 1)
      state_next = torch.tensor(state_next, dtype=torch.float, device=model.device)
  
      # ---- update critic networks ----
      model.actor_target.eval()
      model.critic_target.eval()
      # TD target from the target actor (plus OUnoise_target noise) and the
      # target critic. It is a constant for the critic loss, so compute it
      # without building an autograd graph.
      with torch.no_grad():
        act1_next, act2_next = model.actor_target(state_next)
        act1_next, act2_next = OUnoise_target.getActFromRaw(act1_next, act2_next, step)
        act2_next = model.act2Trans(act2_next, env)
        act_next = torch.cat((act1_next, act2_next), dim=1)
        Q_prime = reward + reward_decay*model.critic_target(state_next, act_next)
  
      # both critics regress onto the same target
      model.critic.train()
      Q1 = model.critic.critic1(state, action)
      Q2 = model.critic.critic2(state, action)
      loss_c = F.mse_loss(Q1, Q_prime) + F.mse_loss(Q2, Q_prime)
      loss_c_info.append(loss_c.item())
      
      opt_critic.zero_grad()
      loss_c.backward()
      nn.utils.clip_grad_norm_(model.critic.critic1.parameters(), grad_clip)
      nn.utils.clip_grad_norm_(model.critic.critic2.parameters(), grad_clip)
      opt_critic.step()
  
      # delayed policy update
      critic_count += 1
      if critic_count % policy_upfreq == 0:
        # ---- update actor networks ----
        critic_count = 0
        model.actor.train()
        # get action from policy
        act1_eager, act2_eager = model.actor(state)
        act2_eager = model.act2Trans(act2_eager, env)
        act_eager = torch.cat((act1_eager, act2_eager), 1)
        # calculate punishment
        # NOTE(review): presumably a penalty for act1 violating a constraint
        # (e.g. power) -- confirm against model.act1Check.
        punish1 = model.act1Check(act1_eager, env)
        # punish2 = torch.sum(model.actor.output1.weight**2) + torch.sum(model.actor.output2.weight**2)

        # calculate loss: maximise Q while subtracting the penalty
        model.critic.eval()
        Q_achieve = model.critic.getQ(state, act_eager).mean()
        # scale_num = numTrans(Q_achieve.item())
        loss_a = -1.0 * (Q_achieve - 1 * punish1) 
        q_info.append(Q_achieve.item())
        loss_a_info.append(loss_a.item())
   
        # critic grads accumulated here are cleared by opt_critic.zero_grad()
        # before the next critic step
        opt_actor.zero_grad()
        loss_a.backward()
        nn.utils.clip_grad_norm_(model.actor.parameters(), grad_clip)
        opt_actor.step()
  
        # ---- synchronize network ----
        with torch.no_grad():
          model.sync(sync_rate)
  
      # ---- loss information ----
      if iter_n % info_interval == 0:
        print("epoch: %d, iter: %d, best: %.4f, avg.rate: %.4f, loss_c: %.4f, loss_a: %.2f, Q: %.2f, speed: %.2f/iter, time eclapsed: %dsec"
          % (epoch,
            iter_n,
            best_score,
            np.mean(rate_info[-info_interval:]),
            np.mean(loss_c_info[-info_interval:]),
            np.mean(loss_a_info[-info_interval:]),
            np.mean(q_info[-info_interval:]),
            (time.time()-train_time)/info_interval,
            time.time()-start_time)
        )
        train_time = time.time()

      # ---- validation ----
      if iter_n % validate_interval == 0:
        print("validation begins, check whether the model upgraded......")
        with torch.no_grad():
          avg_rate, power_used = validation(G_valid, F_valid, H_valid)
        print("----> avg.rate: %.4f" % avg_rate)

        # no checkpointing / lr decay during the first 5 warm-up epochs
        if epoch<5:
          continue
  
        # only accept improvements that respect the power limit
        if avg_rate > best_score and power_used<=env.power_max:
          print("model upgraded...save to folder %s"%out_path)
          trial = 0
          best_score = avg_rate
          save_opt(out_path)
          model.save(out_path)
        else:
          trial += 1
          print("hit trial %d"%trial)

          # decay both learning rates, floored at 1e-4
          lr_c = max(1e-4, lr_c*lr_decay)
          print("lr_critic change to %.4e"%lr_c)
          for param in opt_critic.param_groups:
            param['lr'] = lr_c

          lr_a = max(1e-4, lr_a*lr_decay)
          print("lr_actor change to %.4e"%lr_a)
          for param in opt_actor.param_groups:
            param['lr'] = lr_a
          
          # if trial%20 == 0:
          #   print("reload model")
          #   model.load(out_path)
          #   load_opt(out_path)
          #   lr_c = opt_critic.param_groups[0]['lr']
          #   lr_a = opt_actor.param_groups[0]['lr']
        
  return best_score, trial

In [None]:
# Full training run at 30 dBm. NOTE(review): the -1 sizes are passed through
# to TD3 unchanged -- presumably "use the default"; confirm in net.TD3.
epoches, steps = 500, int(5e3)
args = {'lr_critic':3e-4, 'lr_actor':3e-4, 'critic_hidden_size': -1, 'actor_hidden_size': -1,
      'n_critic_hidden': -1, 'n_actor_hidden': -1, 'critic_weight_decay': 0, 'actor_weight_decay': 0}
score, trial = train(epoches, steps, 30, **args)
# `train` returns only (best_score, trial); there is no loss value to report.
print('\n\n', score, trial)

In [None]:
# Rebuild the agent exactly as `train` does (env.act_space, (dim1, dim2)
# tuple, max_act=(bound1, bound2)) and load the saved weights.
EVAL_DBM = 30
CHECKPOINT_DIR = './checkpoint'
# NOTE(review): these must match the run that wrote CHECKPOINT_DIR; the
# training cell passes -1 for all four. Action bounds come from
# env.act_space, as in `train`.
ACTOR_HIDDEN_SIZE, CRITIC_HIDDEN_SIZE = -1, -1
N_CRITIC_HIDDEN, N_ACTOR_HIDDEN = -1, -1

env = Environment(10**(EVAL_DBM/10))
print("environment initialized...")

# -------- create model --------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True
model = TD3(env.state_size,
            (env.act_space['dim1'], env.act_space['dim2']),
            ACTOR_HIDDEN_SIZE,
            CRITIC_HIDDEN_SIZE,
            N_CRITIC_HIDDEN,
            N_ACTOR_HIDDEN,
            max_act=(env.act_space['bound1'], env.act_space['bound2']))
model.to(device)
model.init()
model.load(CHECKPOINT_DIR)
model.eval()

In [None]:
# Inspect the running statistics tracked by the actor's `bn_out2` layer.
print(model.actor.bn_out2.running_mean,
      model.actor.bn_out2.running_var,
      sep='\n')