In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

# Larger-than-default font sizes so the figures below remain legible when reused in slides.
parameters = {'axes.labelsize': 25,
              'axes.titlesize': 35,
              'xtick.labelsize': 20,
              'ytick.labelsize': 20,
              'legend.fontsize': 20,
              }
plt.rcParams.update(parameters)

# Site-specific locations. Each can be overridden with an environment variable so the
# notebook can run on another machine without editing this cell; the defaults are the
# original author's paths.
PERSONAL_DIR = os.environ.get("ND_LAR_PERSONAL_DIR",
                              "/cluster/tufts/minos/jwolcott/app/personal/nd/nd-lar-reco")
TRAIN_DIR = os.environ.get("ND_LAR_TRAIN_DIR",
                           "/media/hdd1/jwolcott/data/dune/nd/nd-lar-reco/train")
VALID_DIR = os.environ.get("ND_LAR_VALID_DIR",
                           "/media/hdd1/jwolcott/data/dune/nd/nd-lar-reco/valid")

# Training run to inspect: the name of a subdirectory of TRAIN_DIR (and of VALID_DIR).
# Other runs that have been examined with this notebook:
#   "uresnet+ppn-380Kevs-25Kits-batch32"
#   "uresnet+ppn-380Kevs-50Kits-batch32"
#   "track+showergnn-380Kevs-15Kits-batch32"
#   "track+showergnn-380Kevs-15Kits-batch16-attempt2"
#   "tests"
SAMPLE = os.environ.get("ND_LAR_SAMPLE", "track+intergnn-1400evs-1000Kits-batch8")

In [None]:
# Load every training CSV log for SAMPLE and stitch them into one chronological frame `df`.
# A run may be split across several CSVs (e.g. one per resumed job).
target_dir = os.path.join(TRAIN_DIR, SAMPLE)

# sorted(): os.listdir order is arbitrary; sorting makes the printed summary reproducible.
csvs = sorted(os.path.join(target_dir, f) for f in os.listdir(target_dir) if f.endswith('.csv'))
if not csvs:
    raise FileNotFoundError("no training CSV logs found in %s" % target_dir)
dfs = [pd.read_csv(f) for f in csvs]

# Order the logs by the first iteration each one contains (computed once, stable on ties).
order = np.argsort([d["iter"].min() for d in dfs], kind="stable")
for idx in order:
    print(csvs[idx], dfs[idx]["iter"].min(), '=>', dfs[idx]["iter"].max())

# ignore_index: every CSV restarts its row index at 0, so without it the combined frame
# has duplicate index labels and anything plotted against the index is scrambled.
df = pd.concat([dfs[idx] for idx in order], ignore_index=True)
print(sorted(df.keys()))
print("losses:", sorted(k for k in df.keys() if "loss" in k))

In [None]:
import pathlib
import re

# Load the inference (validation) CSVs for SAMPLE into `df_valid`, or leave it None if none exist.
valid_dir = os.path.join(VALID_DIR, SAMPLE, "log_inference")
print("validation file dir:", valid_dir)

# File paths are expected to contain "log-<N>"; N is stored as the `iter` value of every
# row read from that file. Raw string: "\d" in a plain literal is an invalid escape
# sequence (SyntaxWarning on Python 3.12+).
filepattern = re.compile(r'.*log-(\d+).*')

dfs_valid = []
# sorted(): glob order is filesystem-dependent.
for path in sorted(pathlib.Path(valid_dir).glob("**/*.csv")):
    matches = filepattern.match(str(path))
    if not matches:
        continue
    frame = pd.read_csv(path)
    frame['iter'] = int(matches.group(1))
    dfs_valid.append(frame)

# None (rather than an empty frame) signals "no validation data" to later cells.
df_valid = None
if dfs_valid:
    valid_order = np.argsort([d['iter'].min() for d in dfs_valid], kind='stable')
    # ignore_index avoids duplicate row labels from the per-file indices.
    df_valid = pd.concat([dfs_valid[idx] for idx in valid_order], ignore_index=True)

In [None]:
# Directory for this run's figures; created (with any missing parents) if needed.
plotdir = os.path.join("/media/hdd1/jwolcott/data/dune/nd/nd-lar-reco/plots", SAMPLE)
os.makedirs(plotdir, exist_ok=True)

# Loss columns to plot: key = column name in the train/validation CSVs, value = legend title.
# Uncomment the entries relevant to the model being inspected (every entry keeps a trailing
# comma so uncommenting one never produces a syntax error).
loss_types = {
#    "ppn_loss": "PPN loss",
#    "seg_loss": "SS loss",
#    "uresnet_loss": "SS loss",
#    "loss_ppn1": "PPN1 loss",
#    "loss_ppn2": "PPN2 loss",
#    "shower_edge_loss": "Shower GNN edge loss",
#    "shower_node_loss": "Shower GNN node loss",
#    "track_edge_loss": "Track GNN edge loss",
    "inter_edge_loss": "Interaction GNN edge loss",
#    "loss": "Total loss",
}

fig, ax = plt.subplots(figsize=(12, 8), facecolor='w')

colors = {}  # legend title -> colour assigned to the training curve
test = {}    # legend title -> validation Line2D (recoloured below)
for loss_name, loss_title in loss_types.items():
    print("considering loss:", loss_name)
    if loss_name in df:
        p = ax.plot(df["iter"], df[loss_name], label=loss_title + " (train)", alpha=0.75)
        if loss_title not in colors:
            colors[loss_title] = p[-1].get_color()

    # df_valid is None when no validation logs were found. Compare against None explicitly:
    # `if df_valid` on a DataFrame raises "The truth value of a DataFrame is ambiguous".
    if df_valid is not None and loss_name in df_valid:
        test[loss_title] = ax.plot(df_valid["iter"], df_valid[loss_name],
                                   label=loss_title + " (test)", marker='o')[0]

# Give each validation curve the same colour as its training counterpart.
for title, color in colors.items():
    if title in test:
        test[title].set_color(color)

ax.set_yscale('log')
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
ax.legend()
ax.grid(True)

# Save before showing so the written files are exactly what is displayed.
for ext in ("pdf", "png"):
    fig.savefig(os.path.join(plotdir, "loss." + ext))
plt.show()

In [None]:


# Zoom on the start of training: PPN loss and total loss for iterations below the cutoff.
EARLY_ITER_CUTOFF = 15000

fig, ax = plt.subplots(figsize=(12, 8), facecolor='w')
sdf = df[df["iter"] < EARLY_ITER_CUTOFF]
# Not every model logs every loss (e.g. GNN-only runs may have no PPN loss); skip
# missing columns instead of failing with AttributeError.
for column, label in (("ppn_loss", "PPN loss"), ("loss", "total loss")):
    if column in sdf:
        ax.plot(sdf["iter"], sdf[column], label=label)
    else:
        print("column not in training log, skipping:", column)
ax.set_yscale('log')
ax.set_ylim(0.01, 5.0)
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
ax.grid(True)
ax.legend()
plt.show()

In [None]:
# GNN accuracy metrics from the training log (key = CSV column, value = legend label).
accuracy_columns = {
    "shower_node_accuracy": "shower node accuracy",
    "shower_edge_accuracy": "shower edge accuracy",
    "track_edge_accuracy": "track edge accuracy",
}

fig, ax = plt.subplots(figsize=(12, 8), facecolor='w')
for column, label in accuracy_columns.items():
    if column not in df:
        # Metrics depend on which GNN stages the run trained; skip absent ones.
        print("column not in training log, skipping:", column)
        continue
    # Plot against the logged iteration, not the row index: df is concatenated from
    # several CSVs and its row index need not correspond to the iteration number.
    ax.plot(df["iter"], df[column], label=label)
ax.set_xlabel("Iteration")
ax.set_ylabel("Accuracy")
ax.legend()
plt.show()

In [None]:
import re

# Per-iteration training time scraped from the training job's text log.
TRAIN_TIME_LOG = os.path.join(TRAIN_DIR, "train-inter-gnn.20210718.log")

# Matches log lines containing "train time" and captures the seconds in "(<x> [s])".
pat = re.compile(r".*train time.*\(([0-9.]+) \[s\]\).*")
# `with` guarantees the log file is closed even if parsing fails.
with open(TRAIN_TIME_LOG) as log_file:
    times = np.array([float(m.group(1)) for m in map(pat.match, log_file) if m])

fig, ax = plt.subplots(figsize=(12, 8), facecolor='w')
ax.plot(times)
ax.set_xlabel("Iteration")
ax.set_ylabel("Train time (s)")

# Number of initial entries excluded from the average (presumably warm-up; TODO confirm).
start_place = 15
if len(times) > start_place:
    avg_time = times[start_place:].mean()
    # %.2f rather than %.0f so sub-second averages are not displayed as "y=0".
    ax.plot((start_place, len(times) - 1), (avg_time, avg_time), label="y=%.2f" % avg_time)
    ax.legend()
else:
    # Averaging an empty slice would divide by zero.
    print("only %d timing entries found; need more than %d to compute an average"
          % (len(times), start_place))
plt.show()