In [None]:
import pathlib
from typing import List

import matplotlib.pyplot as plt

In [None]:
RANK = 4  # number of ranks whose mem-<rank>.txt files are read per experiment


class ExpDir:
    """One experiment output directory named ``<dataset>-<config>-<seq_len>``.

    The directory is expected to hold one ``mem-<rank>.txt`` file per rank
    (``rank`` in ``range(RANK)``), each containing a single byte count.
    """

    def __init__(self, exp_dir: pathlib.Path):
        """Parse ``dataset``, ``config`` and integer ``seq_len`` from the dir name.

        Raises:
            ValueError: if the name does not split into exactly three
                ``-``-separated parts, or if ``seq_len`` is not an integer.
        """
        self.exp_dir = exp_dir

        # parse exp_dir name: <dataset>-<config>-<seq_len>
        parts = exp_dir.name.split("-")
        if len(parts) != 3:
            raise ValueError(
                f"cannot parse experiment dir name {exp_dir.name!r}; "
                "expected '<dataset>-<config>-<seq_len>'"
            )
        self.dataset, self.config, seq_len = parts
        self.seq_len = int(seq_len)

    def __repr__(self):
        return f"ExpDir({self.exp_dir.name}, seq_len={self.seq_len})"

    @staticmethod
    def format_res(res: str) -> float:
        """Convert a byte count (as text) to GiB (1024**3 bytes), rounded to 3 decimals."""
        return round(float(res) / 1024 / 1024 / 1024, 3)

    def read_single_res(self, res_file: pathlib.Path) -> float:
        """Read one ``mem-<rank>.txt`` file and return its value in GiB."""
        with open(res_file, "r") as f:
            res = f.read()
        return self.format_res(res)

    def read_res(self) -> List[float]:
        """Return the memory (GiB) of every rank, ordered by rank.

        All-or-nothing: if any rank's file is missing, a warning is printed
        and an empty list is returned, so callers either get ``RANK`` values
        or none at all.
        """
        res_list = []
        for rank in range(RANK):
            res_file = self.exp_dir / f"mem-{rank}.txt"
            if not res_file.exists():
                print(f"[warning]: {res_file} not exists, return empty list")
                return []
            res_list.append(self.read_single_res(res_file))
        return res_list

In [None]:
base_dir = pathlib.Path("../output")
group_dir_list = [base_dir / exp for exp in ["20231107_5"]]

# group name -> seq_lens for which every rank has a result
group_x = {}
# group name -> data_y, where data_y[rank] is aligned element-wise with group_x[name]
group_y = {}
# longest list of seq_lens (with or without results) seen in any group
longest_x = []

for group_dir in group_dir_list:
    print(group_dir)

    # Only visible subdirectories are experiments; stray files (logs, .DS_Store)
    # or hidden dirs (.ipynb_checkpoints) would otherwise fail ExpDir name
    # parsing. Sorting makes duplicate-seq_len resolution deterministic.
    exp_dir_dict = {}
    for path in sorted(group_dir.glob("*")):
        if not path.is_dir() or path.name.startswith("."):
            continue
        exp = ExpDir(path)
        if exp.seq_len in exp_dir_dict:
            print(
                f"[warning]: duplicate seq_len {exp.seq_len} in {group_dir}: "
                f"{path.name} overrides {exp_dir_dict[exp.seq_len].exp_dir.name}"
            )
        exp_dir_dict[exp.seq_len] = exp

    data_x = sorted(exp_dir_dict)
    if len(data_x) > len(longest_x):
        longest_x = data_x
    group_x[group_dir.name] = []

    # read_res returns either one value per rank or [], so appending to every
    # rank together with the seq_len keeps data_y and group_x aligned.
    data_y = [[] for _ in range(RANK)]
    for seq_len in data_x:
        res_list = exp_dir_dict[seq_len].read_res()
        if not res_list:
            continue
        for rank, res in enumerate(res_list):
            data_y[rank].append(res)
        group_x[group_dir.name].append(seq_len)

    group_y[group_dir.name] = data_y

In [None]:
from mem_model import MemModel

# Value passed as MemModel's ``L`` argument. NOTE(review): presumably the
# number of layers — confirm it matches the configuration of the runs above.
MEM_MODEL_L = 8

# Theoretical memory for every seq_len in longest_x, converted with
# MemModel.bytes_to_gb so it can be compared with the measured curves.
data_ground_truth = []
for x in longest_x:
    m = MemModel(s=x, L=MEM_MODEL_L)
    data_ground_truth.append(m.bytes_to_gb(m.total))
data_ground_truth

In [None]:
# load label: on each line of experiments.md, the text before the first ':'
# is the experiment (group dir) name; the whole stripped line is its legend label.
label_file = base_dir / "experiments.md"
label_dict = {}
with open(label_file, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue  # blank lines would otherwise create an empty-string key
        exp_name = line.split(":")[0].strip()
        label_dict[exp_name] = line
label_dict

In [None]:
# One subplot per rank, two per row. Round the row count up so an odd RANK
# still gets an axis for every rank; squeeze=False keeps axs 2-D for any shape.
n_cols = 2
n_rows = (RANK + n_cols - 1) // n_cols
fig, axs = plt.subplots(n_rows, n_cols, figsize=(12, 6 * n_rows), squeeze=False)

# Flatten the grid of axes for easier indexing
axs = axs.ravel()

# Set to True to overlay the theoretical curve from the MemModel cell.
SHOW_THEORETICAL = False

for rank in range(RANK):
    ax = axs[rank]

    for exp_name in sorted(group_y):
        ax.plot(
            group_x[exp_name],
            group_y[exp_name][rank],
            "x-",
            # fall back to the raw group name when experiments.md has no entry
            label=label_dict.get(exp_name, exp_name),
        )
    if SHOW_THEORETICAL:
        # data_ground_truth is computed over longest_x, so plot against it
        ax.plot(
            longest_x, data_ground_truth, "o-", color="grey", alpha=0.5, label="theoretical"
        )
    ax.grid()
    ax.set_xlabel("seq_len")
    ax.set_ylabel("mem (GB)")
    ax.set_title(f"rank {rank}")
    ax.legend()

# Hide leftover axes when RANK does not fill the last row
for ax in axs[RANK:]:
    ax.set_visible(False)

# Adjust the layout and display the plot
plt.tight_layout()
plt.show()