In [1]:
import re
import os

import pandas as pd

In [2]:
def parse_profiler_results(filename):
    with open(filename, 'r') as f:
        results = f.readlines()

    total_pattern = '[0-9.]+'
    total_expr = re.compile(total_pattern)

    total_line = results[0]
    total_stats = (total_func_calls, total_prim_calls, total_time) = \
        re.findall(total_expr, total_line)

    col_names = results[4]
    col_names = col_names.strip().split('  ')
    col_names = col_names[:-1] + col_names[-1].split(' ')

    data_lines = results[5:-2]
    data_lines = [line.strip() for line in data_lines]
    data_lines = [line.split('   ') for line in data_lines]
    data_lines = [[item.strip() for item in line]  for line in data_lines]
    data_lines = [line[:-1] + line[-1].split(' ', 1) for line in data_lines]

    df = pd.DataFrame(data_lines, columns=col_names)
    df = df.loc[:,~df.columns.duplicated()]
    df[['func_calls', 'prim_calls']] = df['ncalls'].str.split('/', expand=True)
    num_cols = ['func_calls', 'prim_calls', 'tottime', 'percall', 'cumtime']
    df[num_cols] = df[num_cols].astype(float)
    df['prim_calls'] = df['prim_calls'].fillna(0)
    df['filename:lineno(function)'] = df['filename:lineno(function)'].str.replace('/home/nicholas','')
    df['filename:lineno(function)'] = df['filename:lineno(function)'].str.replace('/Desktop/CARBONE-NICHOLAS','')
    
    return total_stats, df

In [3]:
line_break = '--------------------------------'
func_s = '({}) function: {}'
stats_s = 'function calls: {:.0f} | primitive calls: {:.0f} | {}: {:.2f}'
total_s = 'function calls: {} | primitive calls: {} | total runtime: {}'
keys = ['percall', 'tottime', 'cumtime']

def print_results(alias, total_stats, results):
    print('Printing results for {} code.'.format(alias))
    print(line_break)
    print(line_break + '\n')
    print('Overall stats')
    print(line_break)
    print(line_break)
    print(total_s.format(*total_stats))
    print(line_break)
    print(line_break + '\n')
    
    for key in keys:
        print('Results sorted by {}.'.format(key))
        print(line_break)
        print(line_break)
        
        i = 0
        for row in results.sort_values(key, ascending=False).iterrows():
            row = row[1]
            func_s_row = func_s.format(
                i+1,
                row['filename:lineno(function)'],
            )

            stats_s_row = stats_s.format(
                row['func_calls'],
                row['prim_calls'],
                key,
                row[key]
            )
            print(func_s_row)
            print(line_break)
            print(stats_s_row)

            i += 1
            if i < 10:
                print('')
            else:
                break

        print(line_break)
        print(line_break + '\n')

In [4]:
old_total_stats, old_results = parse_profiler_results('profiler_results_old.txt')
new_total_stats, new_results = parse_profiler_results('profiler_results_new.txt')

In [5]:
print_results('old', old_total_stats, old_results)

Printing results for old code.
--------------------------------
--------------------------------

Overall stats
--------------------------------
--------------------------------
function calls: 3345725 | primitive calls: 3252870 | total runtime: 38.447
--------------------------------
--------------------------------

Results sorted by percall.
--------------------------------
--------------------------------
(1) function: /miniconda3/envs/fin-env/lib/python3.8/site-packages/ray/rllib/policy/sample_batch.py:213(shuffle)
--------------------------------
function calls: 30 | primitive calls: 0 | percall: 0.36

(2) function: /miniconda3/envs/fin-env/lib/python3.8/site-packages/ray/rllib/utils/sgd.py:102(do_minibatch_sgd)
--------------------------------
function calls: 1 | primitive calls: 0 | percall: 0.12

(3) function: /miniconda3/envs/fin-env/lib/python3.8/json/decoder.py:343(raw_decode)
--------------------------------
function calls: 2 | primitive calls: 0 | percall: 0.02

(4) funct

In [6]:
print_results('new', new_total_stats, new_results)

Printing results for new code.
--------------------------------
--------------------------------

Overall stats
--------------------------------
--------------------------------
function calls: 2898403 | primitive calls: 2799314 | total runtime: 29.046
--------------------------------
--------------------------------

Results sorted by percall.
--------------------------------
--------------------------------
(1) function: /miniconda3/envs/fin-env/lib/python3.8/site-packages/ray/rllib/policy/sample_batch.py:213(shuffle)
--------------------------------
function calls: 30 | primitive calls: 0 | percall: 0.34

(2) function: /miniconda3/envs/fin-env/lib/python3.8/site-packages/ray/rllib/utils/sgd.py:102(do_minibatch_sgd)
--------------------------------
function calls: 1 | primitive calls: 0 | percall: 0.12

(3) function: /miniconda3/envs/fin-env/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py:452(_env_runner)
--------------------------------
function calls: 2 | primitive call

In [7]:
list_comp_funcs = [
    '/model.py:121(<listcomp>)',
    '/model.py:126(<listcomp>)'
]

old_results.loc[old_results['filename:lineno(function)'].isin(list_comp_funcs)]

Unnamed: 0,ncalls,tottime,percall,cumtime,filename:lineno(function),func_calls,prim_calls
4,3324,1.199,0.0,3.62,/model.py:121(<listcomp>),3324.0,0.0
5,3324,1.194,0.0,3.775,/model.py:126(<listcomp>),3324.0,0.0


Due to the highly optimized nature of libraries like PyTorch and RLLib, it can be difficult to achieve much speedup on algorithms when implemented properly. However, I compared one iteration for my model to one iteration for an implementation of a simple environment in RLLib and found a difference of approximately 28 seconds. I anticipated that my model would take a bit longer to run, given that it undergoes multiple forward passes of the model, but this seemed like a larger disparity. Therefore, I sought to identify areas of excess in my code.

I was able to trim down a lot of unnecessary references to objects throughout, but, for the most part, these only marginally sped up the algorithm. I then sought to investigate individual functions and their runtime using cProfile (method contained within the profiler.py script). Many of the functions listed as consuming the most time were PyTorch functions related to backprop and largely out of my control. However, I noted two particular functions of my own that seemed to be eating up a significant amount of time:

<ol>
    <li>/model.py:121(&ltlistcomp&gt)</li>
    <li>/model.py:126(&ltlistcomp&gt)</li>
</ol>

These are list comprehension and were shown to consume the fifth and sixth most total time spent in function call (see cell 5 - "tottime" section) in the original script. As can be seen from cell 7, they collectively consume over 7 seconds of runtime. They refer to these lines of code from my model:

<ol>
    <li>action_embeddings = torch.stack([self.action_embeddings[i,int(j.item())] for i, j in enumerate(a1_vec)])</li>
    <li>action_mask = torch.stack([self.action_mask[i,int(j.item())] for i, j in enumerate(a1_vec)])</li>
</ol>

Initially, I implemented this as a quick solution to indexing a Torch tensor based on the values in "a1_vec" to get the code running. I was aware upon implementation that this would copy tensors, but did not expect it to result in so much added time. As I became aware of their impact on speed, I decided to rewrite them using proper indexing as follows:

<ol>
    <li>action_embeddings = self.action_embeddings[torch.arange(self.action_embeddings.size(0)), a1_vec]</li>
    <li>action_mask = self.action_mask[torch.arange(self.action_mask.size(0)), a1_vec]</li>
</ol>

These new functions do not appear in the top 10 total time spent in function call for the new script. Cleaning these up also reduced the amount of copying that needed to be done within PyTorch's backend, which further optimized the code. Altogether, this improvement shaved off approximately 9 seconds from the runtime per training iteration, taking the original 38-second iteration down to 29 seconds, which is a speedup of approximately 24%. There is probably room for improvement elsewhere, requiring further analysis of the new cProfile results.