In [8]:
# Imports and shared constants for evaluating the Decision Transformer ABR policy.
import numpy as np
import torch
import os

from decision_transformer.models.decision_transformer import DecisionTransformer
from qoe_to_go import QoE_predictor_model
import matplotlib.pyplot as plt

import envs.fixed_env_vmaf as env_test
from utils.data_loader import get_throughput_char, get_qoe2go_estimation

# Factor-of-1000 unit conversion used throughout the notebook
# (e.g. ms timestamps -> seconds when logging, throughput / chunk-size scaling).
M_IN_K = 1000.

Load the QoE-to-go (QoE2Go) prediction model

In [9]:
# Run on the first GPU when one is available, otherwise fall back to CPU.
device_id = 'cuda:0'
device = torch.device(device_id if torch.cuda.is_available() else "cpu")
q2go_model = QoE_predictor_model().to(device)

model_checkpoint_path = "./checkpoints/q2go/Q2GO_predictor.pt"
# map_location remaps stored tensors onto `device`; without it a checkpoint that
# was saved from a GPU cannot be loaded on a CPU-only machine, defeating the
# CPU fallback above.
q2go_model.load_state_dict(torch.load(model_checkpoint_path, map_location=device))
q2go_model.eval()  # inference mode; last expression so the model repr is displayed

QoE_predictor_model(
  (QoE_predictor_model): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=128, out_features=32, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=32, out_features=1, bias=True)
  )
)

Load the ABR (adaptive bitrate) simulation environment (VMAF version)

In [10]:
# Paths: network traces to replay, where per-trace logs are written, and the
# per-chunk video size / VMAF metadata consumed by the environment.
test_traces = '../test_traces/'
log_save_dir = './results_valid/'
# Idempotent on re-run and also creates any missing parent directories.
os.makedirs(log_save_dir, exist_ok=True)
log_path_ini = log_save_dir + 'log_test_karmar'

video_size_file = '../video_size/ori/video_size_'  # video = 'origin'
video_vmaf_file = './video_vmaf/chunk_vmaf'
from utils import load_trace
all_cooked_time, all_cooked_bw, all_file_names = load_trace.load_trace(test_traces)
# The environment's keyword is `video_psnr_file`, but this VMAF version of the
# environment is fed the per-chunk VMAF file through it.
test_env = env_test.Environment(
    all_cooked_time=all_cooked_time,
    all_cooked_bw=all_cooked_bw,
    all_file_names=all_file_names,
    video_size_file=video_size_file,
    video_psnr_file=video_vmaf_file,
)

S_INFO = 7
S_LEN = 8  # maximum length of states
C_LEN = 0  # content length
VIDEO_BIT_RATE = [300, 750, 1200, 1850, 2850, 4300]  # kbps
# DT state size: 4 scalar features + next-chunk size and next-chunk VMAF for
# every bitrate (the layout built in the simulation loop).
S_DIM = 16
assert S_DIM == 4 + 2 * len(VIDEO_BIT_RATE), \
    "S_DIM must equal 4 + 2 * len(VIDEO_BIT_RATE) to match the state layout"
TOTAL_CHUNK_NUM = 49
# QoE weights passed to the environment. NOTE(review): the original comment gave
# the quality weight's unit as dB, but quality here is VMAF — confirm.
QUALITY_PENALTY = 0.8469011
REBUF_PENALTY = 28.79591348
SMOOTH_PENALTY_P = -0.29797156
SMOOTH_PENALTY_N = 1.06099887
test_env.set_env_info(S_INFO, S_LEN, C_LEN, TOTAL_CHUNK_NUM, VIDEO_BIT_RATE,
                      QUALITY_PENALTY, REBUF_PENALTY,
                      SMOOTH_PENALTY_P, SMOOTH_PENALTY_N)

Load the Decision Transformer (DT) policy model

In [11]:
# Load the trained Decision Transformer policy. The hyper-parameters below must
# match the ones the checkpoint was trained with, or load_state_dict will fail.
dt_model_checkpoint_path = "./checkpoints/dt/dt_model.pt"

dt_model = DecisionTransformer(
    state_dim=S_DIM,
    act_dim=len(VIDEO_BIT_RATE),  # one logit per bitrate level
    max_length=4,
    max_ep_len=512,
    action_tanh=False,
    hidden_size=32,
    n_layer=3,
    n_head=1,
    n_inner=4 * 32,
    activation_function='relu',
    n_positions=1024,
    resid_pdrop=0.1,
    attn_pdrop=0.1,
).to(device)

# map_location keeps a GPU-saved checkpoint loadable when `device` is the CPU.
dt_model.load_state_dict(torch.load(dt_model_checkpoint_path, map_location=device))
dt_model.eval()  # disable the dropout layers configured above for inference

DecisionTransformer(
  (transformer): GPT2Model(
    (wte): Embedding(1, 32)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): Block(
        (ln_1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((32,), eps=1e-05, elementwise_

Simulate one streaming session per test trace with the DT policy

In [13]:

# Replay every test trace through the Decision Transformer (DT) policy.
# One log file '<log_path_ini>_<trace name>' is written per trace, with one
# tab-separated line per chunk:
#   time (s), bitrate (kbps), buffer_size, rebuf, chunk size, delay, reward
# followed by an empty line marking the end of the session.
DEFAULT_QUALITY = 1  # bitrate index used for the first chunk of every trace
BUFFER_NORM_FACTOR = 10.  # buffer size is divided by this before use as a feature
# Constant offset subtracted from every per-chunk reward.
# NOTE(review): the origin of this value is not documented here — confirm.
REWARD_OFFSET = 2.661618558192494

state_dim = S_DIM
s_info, s_len, c_len, total_chunk_num, bitrate_versions, \
    quality_penalty, rebuffer_penalty, smooth_penalty_p, smooth_penalty_n \
        = test_env.get_env_info()
a_dim = len(bitrate_versions)
all_file_name = test_env.all_file_names

for video_count in range(len(all_file_name)):
    '''simulate one session over one network trace'''
    # Reset ALL per-session state here. Previously `bit_rate` was not reset at
    # the end of a video, so every trace after the first downloaded its first
    # chunk at the previous trace's last predicted bitrate while the DT history
    # (row 0 of `actions`) and `last_quality` assumed DEFAULT_QUALITY.
    bit_rate = DEFAULT_QUALITY
    last_quality = test_env.chunk_psnr[DEFAULT_QUALITY][0]
    state = np.zeros(state_dim)
    action_vec = np.zeros(a_dim)
    action_vec[bit_rate] = 1
    past_throughputs = np.zeros((1, 4))  # window of the last 4 observed throughputs

    # DT context: row 0 pairs an all-zero state with the default action.
    states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
    actions = torch.from_numpy(action_vec).reshape(1, a_dim).to(device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)
    returns_est = torch.zeros((1, 1), device=device, dtype=torch.float32)
    timesteps = torch.zeros((1, 1), device=device, dtype=torch.long)
    time_stamp = 0
    chunk_id = 0

    log_path = log_path_ini + '_' + all_file_name[test_env.trace_idx]
    # `with` guarantees the log is closed even if a step raises mid-session.
    with open(log_path, 'w') as log_file:
        while True:
            delay, sleep_time, buffer_size, rebuf, \
                video_chunk_size, next_video_chunk_sizes, next_video_chunk_psnrs, \
                    end_of_video, video_chunk_remain, _, curr_chunk_psnrs \
                        = test_env.get_video_chunk(bit_rate)

            time_stamp += delay  # in ms
            time_stamp += sleep_time  # in ms

            # reward = weighted quality - rebuffer penalty
            #          - separately weighted quality increase / decrease - offset
            curr_quality = curr_chunk_psnrs[bit_rate]
            sm_dif_p = max(curr_quality - last_quality, 0)
            sm_dif_n = max(last_quality - curr_quality, 0)
            reward = quality_penalty * curr_quality \
                        - rebuffer_penalty * rebuf \
                            - smooth_penalty_p * sm_dif_p \
                                - smooth_penalty_n * sm_dif_n \
                                    - REWARD_OFFSET
            last_quality = curr_quality

            log_file.write(str(time_stamp / M_IN_K) + '\t' +
                           str(bitrate_versions[bit_rate]) + '\t' +
                           str(buffer_size) + '\t' +
                           str(rebuf) + '\t' +
                           str(video_chunk_size) + '\t' +
                           str(delay) + '\t' +
                           str(reward) + '\n')
            log_file.flush()

            # ========== pad the DT context for the action about to be predicted ============
            actions = torch.cat([actions, torch.zeros((1, a_dim), device=device)], dim=0)
            rewards = torch.cat([rewards, torch.zeros(1, device=device)])

            # ========== build the observation (shared by the DT state and QoE2Go) ============
            throughput = video_chunk_size / delay / M_IN_K  # kilo byte / ms
            buffer_norm = float(buffer_size / BUFFER_NORM_FACTOR)
            remain_frac = np.minimum(video_chunk_remain, total_chunk_num) / float(total_chunk_num)

            state[0] = throughput
            state[1] = buffer_norm
            state[2] = delay / M_IN_K  # chunk download time
            state[3] = remain_frac  # fraction of remaining chunks
            state[4 : 4 + a_dim] = np.array(next_video_chunk_sizes) / M_IN_K / M_IN_K  # next chunk sizes
            state[4 + a_dim : 4 + 2 * a_dim] = np.array(next_video_chunk_psnrs) / 100  # next chunk VMAF, [0, 100] -> [0, 1]

            # float32 conversion copies, so later in-place writes to `state` do not alter history
            cur_state = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
            states = torch.cat([states, cur_state], dim=0)

            # ========== QoE2Go estimate, used as the DT return-to-go ============
            throughput_mean, throughput_std = get_throughput_char(past_throughputs, throughput)
            past_throughputs = np.roll(past_throughputs, -1, axis=1)
            past_throughputs[0, -1] = throughput

            q2go_obs = [throughput_mean, throughput_std, buffer_norm, remain_frac]
            q2go_estimation = get_qoe2go_estimation(q2go_obs, q2go_model, device)
            cur_return = torch.from_numpy(np.array([q2go_estimation])).reshape(1, 1).to(device=device, dtype=torch.float32)
            returns_est = torch.cat([returns_est, cur_return], dim=0)

            # timesteps grows along dim 1, i.e. it is kept as (1, T) while states
            # are (T, state_dim). NOTE(review): confirm get_action expects this.
            timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (chunk_id + 1)], dim=1)

            # ========== action prediction ============
            with torch.no_grad():
                prob = dt_model.get_action(states, actions, rewards, returns_est, timesteps)
            bit_rate = int(torch.argmax(prob).item())

            # Fill the padded row with the chosen one-hot action and the observed reward.
            actions[-1, bit_rate] = 1.0
            rewards[-1] = reward

            chunk_id += 1

            if end_of_video:
                log_file.write('\n')
                break

Results analysis: per-trace mean QoE statistics

In [3]:
import os
import numpy as np

# Summarise per-trace QoE from the logs written by the simulation cell.
# Every log line ends with that chunk's reward; a blank line ends the session.
test_log_folder = './results_valid/'

rewards = []  # mean per-chunk reward of each usable trace log
for test_log_file in sorted(os.listdir(test_log_folder)):
    log_file_path = os.path.join(test_log_folder, test_log_file)
    if not os.path.isfile(log_file_path):
        continue  # ignore sub-directories or other non-file entries
    reward = []
    with open(log_file_path, "r") as f:
        for line in f:
            parse = line.split()
            if not parse:  # blank separator line: end of the session
                break
            reward.append(float(parse[-1]))
    # The first chunk is excluded because its bitrate is set before the DT
    # policy acts in that session. A log with no further chunks would make
    # np.mean([]) == nan and silently turn every statistic below into nan,
    # so such logs are reported and skipped instead.
    if len(reward) < 2:
        print(f"skipping {test_log_file}: fewer than 2 chunks logged")
        continue
    rewards.append(np.mean(reward[1:]))

if not rewards:
    raise ValueError(f"no usable reward logs found in {test_log_folder!r}")
rewards = np.array(rewards)

rewards_min = np.min(rewards)
rewards_5per = np.percentile(rewards, 5)
rewards_mean = np.mean(rewards)
rewards_median = np.percentile(rewards, 50)
rewards_95per = np.percentile(rewards, 95)
rewards_max = np.max(rewards)

print(f"rewards_min: {rewards_min}")
print(f"rewards_5per: {rewards_5per}")
print(f"rewards_mean: {rewards_mean}")
print(f"rewards_median: {rewards_median}")
print(f"rewards_95per: {rewards_95per}")
print(f"rewards_max: {rewards_max}")


NameError: name 'np' is not defined