In [None]:
import pathlib
from typing import List, Optional

import matplotlib.pyplot as plt

In [None]:
# Number of per-rank memory files (mem-0.txt … mem-{RANK-1}.txt) expected in
# every experiment directory; also the number of per-rank series/panels below.
RANK = 4

class ExpDir:
    """One experiment output directory and the per-rank memory numbers it holds.

    The directory name must have the form ``<date>-<dataset>-<config>-<seq_len>``
    (exactly four ``-``-separated fields, the last one an integer). The directory
    is expected to contain one ``mem-<rank>.txt`` file per rank, each holding a
    single number of bytes as text.
    """

    def __init__(self, exp_dir: pathlib.Path):
        """Parse experiment metadata out of ``exp_dir.name``.

        Raises:
            ValueError: if the name does not consist of exactly four
                ``-``-separated fields, or if the last field is not an integer.
        """
        self.exp_dir = exp_dir

        # parse exp_dir name: <date>-<dataset>-<config>-<seq_len>
        parts = exp_dir.name.split('-')
        if len(parts) != 4:
            # Fail with a readable message instead of "too many values to unpack".
            raise ValueError(
                f'expected <date>-<dataset>-<config>-<seq_len>, got {exp_dir.name!r}'
            )
        self.date, self.dataset, self.config, seq_len = parts
        self.seq_len = int(seq_len)

    def __repr__(self):
        return f'ExpDir({self.exp_dir.name}, seq_len={self.seq_len})'

    @staticmethod
    def format_res(res: str) -> float:
        """Convert a byte count given as text to GiB (2**30 bytes), rounded to 3 decimals."""
        return round(float(res) / 1024 / 1024 / 1024, 3)

    def read_single_res(self, res_file: pathlib.Path) -> float:
        """Read one memory file (a byte count as text) and return it in GiB."""
        return self.format_res(pathlib.Path(res_file).read_text())

    def read_res(self, num_ranks: Optional[int] = None) -> List[float]:
        """Return memory usage in GiB for ranks ``0 .. num_ranks-1``, in rank order.

        Args:
            num_ranks: how many ``mem-<rank>.txt`` files to read. ``None``
                (the default) means the module-level ``RANK``, looked up at call
                time so later changes to ``RANK`` are honoured.

        Raises:
            AssertionError: if any expected ``mem-<rank>.txt`` file is missing.
        """
        if num_ranks is None:
            num_ranks = RANK
        res_list = []
        for rank in range(num_ranks):
            res_file = self.exp_dir / f'mem-{rank}.txt'
            assert res_file.exists(), f'{res_file} not exists'
            res_list.append(self.read_single_res(res_file))
        return res_list

In [None]:
# Collect every run of the 20231102_1 batch, keyed by sequence length.
base_dir = pathlib.Path('../output')
# sorted() makes iteration order independent of the filesystem's listing order;
# stray files matching the pattern are skipped because ExpDir expects a directory.
exp_dir_list = sorted(p for p in base_dir.glob('20231102_1-*') if p.is_dir())
assert exp_dir_list, f'no experiment directories matching 20231102_1-* in {base_dir.resolve()}'

exp_dir_dict = {}
for path in exp_dir_list:
    exp_dir = ExpDir(path)
    # Two runs with the same seq_len would otherwise silently overwrite each
    # other, with the survivor depending on glob order.
    if exp_dir.seq_len in exp_dir_dict:
        raise ValueError(
            f'duplicate seq_len={exp_dir.seq_len}: '
            f'{exp_dir_dict[exp_dir.seq_len].exp_dir} and {path}'
        )
    exp_dir_dict[exp_dir.seq_len] = exp_dir

# data_x: sequence lengths in increasing order.
# data_y[rank][i]: memory (GiB, from ExpDir.read_res) of `rank` at data_x[i].
data_x = sorted(exp_dir_dict)
data_y = [[] for _ in range(RANK)]
for seq_len in data_x:
    for rank, res in enumerate(exp_dir_dict[seq_len].read_res()):
        data_y[rank].append(res)

In [None]:
from calculator import MemModel


def _theoretical_mem_gb(seq_len):
    """Analytical memory estimate for one sequence length, as converted by MemModel.bytes_to_gb."""
    # NOTE(review): L=8 is forwarded unchanged to MemModel; presumably the layer
    # count of the measured config — confirm against calculator.MemModel.
    model = MemModel(s=seq_len, L=8)
    return model.bytes_to_gb(model.total)


# Model prediction aligned index-for-index with data_x.
data_ground_truth = [_theoretical_mem_gb(seq_len) for seq_len in data_x]
data_ground_truth

In [None]:
# One panel per rank: measured memory vs. the analytical model's prediction.
ncols = 2
# Round up so an odd RANK (including RANK == 1) still gets a panel for every rank;
# the original RANK // 2 rows produced too few axes (or zero rows) in that case.
nrows = (RANK + ncols - 1) // ncols
# squeeze=False keeps axs 2-D even for a single row, so ravel() always works.
fig, axs = plt.subplots(nrows, ncols, figsize=(12, 6 * nrows), squeeze=False)
axs = axs.ravel()

for rank in range(RANK):
    ax = axs[rank]
    ax.plot(data_x, data_y[rank], 'o-', label='actual')
    ax.plot(data_x, data_ground_truth, 'o-', color='grey', alpha=0.5, label='theoretical')
    ax.grid()
    ax.set_xlabel('seq_len')
    ax.set_ylabel('mem (GB)')
    ax.set_title(f'rank {rank}')
    ax.legend()

# Hide the leftover empty panel that appears when RANK is odd.
for ax in axs[RANK:]:
    ax.set_visible(False)

fig.tight_layout()
plt.show()