# Analyzing logs

In [34]:
import re
import subprocess
from pathlib import Path
from typing import Optional

import fabric
import pandas as pd

class LocalWandbApi(object):
    """Read wandb run folders and their console logs straight from disk.

    Shell commands (``ls`` / ``cat``) are executed through a
    ``fabric.Connection`` when ``ssh_host`` is given, and through a local
    shell otherwise.

    Paths are interpolated into the commands unquoted so that a leading ``~``
    in ``project_dir`` is still expanded by the shell; ``project_dir`` must
    therefore not contain spaces or shell metacharacters.
    """

    # Matches lines such as
    #   [08/05 18:18:58 d2.evaluation.evaluator]: some message
    # and captures date, time, logger and message. The logger group is lazy so
    # it stops at the first "]: " even when the message itself contains "]: ".
    _LOG_LINE_PAT = re.compile(r"^\[(\d{2}/\d{2}) (\d{2}:\d{2}:\d{2}) (.*?)\]: (.+)")
    _LOG_COLUMNS = ["date", "time", "logger", "message"]

    def __init__(self, project_dir: str, ssh_host: Optional[str] = None):
        """
        Args:
            project_dir: Root of the project checkout; the wandb output is
                looked up under ``<project_dir>/detrex/wandb_output/wandb``.
            ssh_host: Host passed to ``fabric.Connection``. When omitted,
                commands run on the local machine.
        """
        self.project_dir = project_dir
        self.ssh_host = ssh_host
        # None means "run commands locally" (see _run).
        self.runner = fabric.Connection(ssh_host) if ssh_host else None
        self.wandb_dir = f"{self.project_dir}/detrex/wandb_output/wandb"
        # Applied with re.match, i.e. as a prefix test, so run ids longer than
        # seven characters (e.g. run-20240805_162409-67c1veuz) are accepted.
        self.folder_pat = re.compile(r"run-\d{8}_\d{6}-\w{7}")

    def _run(self, command: str) -> str:
        """Execute ``command`` in a shell and return its standard output.

        Uses the SSH connection when one was configured, otherwise a local
        shell. A non-zero exit status raises in both cases (fabric's own
        error remotely, ``subprocess.CalledProcessError`` locally).
        """
        if self.runner is not None:
            return self.runner.run(command, hide=True).stdout
        completed = subprocess.run(
            command, shell=True, check=True, capture_output=True, text=True
        )
        return completed.stdout

    def list_runs(self) -> dict[str, Path]:
        """ Return a dictionary run_id: run_folder.

        Only directory entries whose name starts with the
        ``run-<date>_<time>-<id>`` pattern are kept; the run id is the part
        after the last ``-``.
        """
        listing = self._run(f"ls {self.wandb_dir}")
        return {
            name.split("-")[-1]: Path(f"{self.wandb_dir}/{name}")
            for name in listing.splitlines()
            if self.folder_pat.match(name)
        }

    def get_run_folder(self, run_id: str) -> Path:
        """ Return the folder of a run.

        Raises:
            KeyError: if no run folder with this id exists in ``wandb_dir``.
        """
        # Example: run-20240805_162409-67c1veuz
        runs = self.list_runs()
        try:
            return runs[run_id]
        except KeyError:
            raise KeyError(f"run {run_id!r} not found in {self.wandb_dir}") from None

    def get_run_logs_raw(self, run_id: str) -> list[str]:
        """ Return the raw log lines of a run (``files/output.log`` split on newlines). """
        run_folder = self.get_run_folder(run_id)
        output = self._run(f"cat {run_folder}/files/output.log")
        return output.split("\n")

    def get_run_logs(self, run_id: str) -> pd.DataFrame:
        """ Return a DataFrame with columns date, time, logger, message.

        Lines without the ``[mm/dd hh:mm:ss logger]: `` prefix are skipped.
        """
        # Examples of accepted lines:
        # [08/05 18:18:58 d2.evaluation.evaluator]: asdfsadfj asdfas
        # [08/05 18:20:13 d2.evaluation.fast_eval_api]: asdfsdf asdfasd
        matches = map(self._LOG_LINE_PAT.match, self.get_run_logs_raw(run_id))
        rows = [m.groups() for m in matches if m]
        return pd.DataFrame(rows, columns=self._LOG_COLUMNS)

In [35]:
# Connect to the project checkout on the "amsterdam" SSH host and preview
# the first two runs found there.
api = LocalWandbApi(project_dir="~/development/edge", ssh_host="amsterdam")
{run_id: folder for run_id, folder in list(api.list_runs().items())[:2]}

{'t0i3c54i': PosixPath('~/development/edge/detrex/wandb_output/wandb/run-20240802_161521-t0i3c54i'),
 'pt1ugkkz': PosixPath('~/development/edge/detrex/wandb_output/wandb/run-20240802_180626-pt1ugkkz')}

In [59]:
# Parse the output log of the run under analysis into a DataFrame
# (columns: date, time, logger, message).
logsdf = api.get_run_logs("mzczh48c")

In [60]:
from datetime import datetime  # no longer used by this cell; kept in case other cells rely on it

LOG_YEAR = 2024  # log lines carry only mm/dd, so the year has to be assumed
EPOCH = pd.Timestamp("1970-01-01")

# Keep only the records emitted by the d2.utils.events logger.
events = logsdf[logsdf.logger == "d2.utils.events"]

# Extract the iteration counter; rows of this logger that have no "iter: N"
# field are dropped instead of making the integer conversion fail on NaN.
iters = events.message.str.extract(r"iter: (\d+)", expand=False)
has_iter = iters.notna()
# .assign returns a new frame, so nothing is written onto a filtered view
# (avoids SettingWithCopyWarning).
logsdf = events[has_iter].assign(iter=iters[has_iter].astype(int))

# logsdf["date"] in format mm/dd, logsdf["time"] in format hh:mm:ss
stamps = pd.to_datetime(
    f"{LOG_YEAR}-" + logsdf.date + " " + logsdf.time,
    format="%Y-%m/%d %H:%M:%S",
)
logsdf = logsdf.assign(
    datetime=stamps,
    # whole seconds since the epoch, independent of the datetime resolution
    timestamp=(stamps - EPOCH) // pd.Timedelta(seconds=1),
)
logsdf["timestamp_delta"] = logsdf.timestamp - logsdf.timestamp.min()
logsdf


Unnamed: 0,date,time,logger,message,iter,datetime,timestamp,timestamp_delta
4,08/05,15:23:02,d2.utils.events,eta: 2:08:55 iter: 82519 total_loss: 10.74 ...,82519,2024-08-05 15:23:02,1722871382,0
5,08/05,15:23:25,d2.utils.events,eta: 2:06:55 iter: 82539 total_loss: 10.18 ...,82539,2024-08-05 15:23:25,1722871405,23
6,08/05,15:23:44,d2.utils.events,eta: 2:02:15 iter: 82559 total_loss: 10.25 ...,82559,2024-08-05 15:23:44,1722871424,42
7,08/05,15:24:07,d2.utils.events,eta: 2:01:55 iter: 82579 total_loss: 10.01 ...,82579,2024-08-05 15:24:07,1722871447,65
8,08/05,15:24:26,d2.utils.events,eta: 2:01:13 iter: 82599 total_loss: 10.36 ...,82599,2024-08-05 15:24:26,1722871466,84
...,...,...,...,...,...,...,...,...
187,08/05,16:08:18,d2.utils.events,eta: 1:22:14 iter: 85019 total_loss: 9.656 ...,85019,2024-08-05 16:08:18,1722874098,2716
188,08/05,16:08:38,d2.utils.events,eta: 1:22:00 iter: 85039 total_loss: 9.582 ...,85039,2024-08-05 16:08:38,1722874118,2736
189,08/05,16:08:58,d2.utils.events,eta: 1:21:39 iter: 85059 total_loss: 10.31 ...,85059,2024-08-05 16:08:58,1722874138,2756
190,08/05,16:09:17,d2.utils.events,eta: 1:21:12 iter: 85079 total_loss: 9.904 ...,85079,2024-08-05 16:09:17,1722874157,2775


In [72]:
from scipy.optimize import curve_fit
import plotly.graph_objects as go
from sklearn.metrics import r2_score

# linear regression
def linear(x, a, b):
    """Evaluate the straight line with slope ``a`` and intercept ``b`` at ``x``."""
    scaled = a * x
    return scaled + b

# Fit elapsed seconds as a linear function of the iteration number.
x = logsdf.iter.astype(float)
y = logsdf.timestamp_delta.astype(float)
popt, pcov = curve_fit(linear, x, y)
y_pred = linear(x, *popt)

# Goodness of fit, reported in the figure title.
r2 = r2_score(y, y_pred)

fig = go.Figure()
fig.add_trace(go.Scatter(x=x, y=y, mode="markers", name="data"))
fig.add_trace(go.Scatter(x=x, y=y_pred, name="fit"))
# y is in seconds and x in iterations, so the fitted slope popt[0] is
# seconds per iteration (not iterations per second).
fig.update_layout(
    title=f"Time x Iteration: {popt[0]:.3f} s/iter, r^2={r2:.3f}",
    xaxis_title="iteration",
    yaxis_title="elapsed time (s)",
    showlegend=True,
)
fig.show()
