In [None]:
import gc
import glob
import json
import random
from collections import defaultdict
from pathlib import Path
from tempfile import gettempdir

import numpy as np
import pandas as pd
import plotly
import plotly.graph_objects as go
import scipy as sp
import scipy.constants
from IPython.core.display import HTML, display
from jupyter_dash import JupyterDash
from numpy import ma

from common import *

# Plotting setup: let JupyterDash pick up the Jupyter proxy configuration and
# keep a handle on Plotly's default color cycle for the figures below.
JupyterDash.infer_jupyter_proxy_config()
COLORS = plotly.colors.DEFAULT_PLOTLY_COLORS

# Locate the solver output: prefer the directory under the system temporary
# folder and fall back to the mounted remote location when it does not exist.
_local_output = Path(gettempdir()) / "boltzmann_solver" / "leptogenesis"
_remote_output = Path("/media/ssh/uni-josh/tmp/josh/boltzmann_solver/leptogenesis")
OUTPUT_DIR = _local_output if _local_output.is_dir() else _remote_output

# Decay Only

## 1 Generation

In [None]:
# Load the decay-only, single-generation run.
data = read_csv(OUTPUT_DIR.joinpath("decay_only", "1gen", "n.csv"))

# Short summary of the run, followed by the integration plot.
n_table = data["n"]
print("Integration steps:", len(n_table.index))
print("Final B-L:", n_table["ΔB-L"].iloc[-1])
plot_integration(data)

In [None]:
# Asymmetry plot for the entries labelled H, L1 and N1 of the run loaded above
# (`plot_asymmetry` is presumably provided by `common`).
plot_asymmetry(data, ["H", "L1", "N1"])

In [None]:
# Density plot for the entries labelled H, L1 and N1 of the run loaded above
# (`plot_density` is presumably provided by `common`).
plot_density(data, ["H", "L1", "N1"])

## 3 Generations

In [None]:
# Load the decay-only, three-generation run.
data = read_csv(OUTPUT_DIR.joinpath("decay_only", "3gen", "n.csv"))

# Short summary of the run, followed by the integration plot.
n_table = data["n"]
print("Integration steps:", len(n_table.index))
print("Final B-L:", n_table["ΔB-L"].iloc[-1])
plot_integration(data)

In [None]:
# Asymmetry plot for H plus all three L and N entries of the run loaded above
# (`plot_asymmetry` is presumably provided by `common`).
plot_asymmetry(data, ["H", "L1", "L2", "L3", "N1", "N2", "N3"])

In [None]:
# Density plot for H plus all three L and N entries of the run loaded above
# (`plot_density` is presumably provided by `common`).
plot_density(data, ["H", "L1", "L2", "L3", "N1", "N2", "N3"])

# $\Delta L = 2$ Only

## 1 Generation

In [None]:
# Load the washout-only (ΔL = 2), single-generation run.
# NOTE(review): this cell previously re-read "decay_only/3gen", which
# contradicted the section heading; the path now follows the
# "<scenario>/<generations>/n.csv" layout used by every other section.
data = read_csv(OUTPUT_DIR / "washout_only" / "1gen" / "n.csv")

# Short summary of the run, followed by the integration plot.
print("Integration steps:", len(data["n"].index))
print("Final B-L:", data["n"]["ΔB-L"].iloc[-1])
plot_integration(data)

In [None]:
# Asymmetry plot for the entries labelled H, L1 and N1 of the run loaded above
# (`plot_asymmetry` is presumably provided by `common`).
plot_asymmetry(data, ["H", "L1", "N1"])

In [None]:
# Density plot for the entries labelled H, L1 and N1 of the run loaded above
# (`plot_density` is presumably provided by `common`).
plot_density(data, ["H", "L1", "N1"])

## 3 Generations

In [None]:
# Load the washout-only, three-generation run.
data = read_csv(OUTPUT_DIR.joinpath("washout_only", "3gen", "n.csv"))

# Short summary of the run, followed by the integration plot.
n_table = data["n"]
print("Integration steps:", len(n_table.index))
print("Final B-L:", n_table["ΔB-L"].iloc[-1])
plot_integration(data)

In [None]:
# Asymmetry plot for H plus all three L and N entries of the run loaded above
# (`plot_asymmetry` is presumably provided by `common`).
plot_asymmetry(data, ["H", "L1", "L2", "L3", "N1", "N2", "N3"])

In [None]:
# Density plot for H plus all three L and N entries of the run loaded above
# (`plot_density` is presumably provided by `common`).
plot_density(data, ["H", "L1", "L2", "L3", "N1", "N2", "N3"])

# Decay + $\Delta L = 2$

## 1 Generation

In [None]:
# Load the combined decay + washout, single-generation run.
data = read_csv(OUTPUT_DIR.joinpath("decay_washout", "1gen", "n.csv"))

# Short summary of the run, followed by the integration plot.
n_table = data["n"]
print("Integration steps:", len(n_table.index))
print("Final B-L:", n_table["ΔB-L"].iloc[-1])
plot_integration(data)

In [None]:
# Asymmetry plot for the entries labelled H, L1 and N1 of the run loaded above
# (`plot_asymmetry` is presumably provided by `common`).
plot_asymmetry(data, ["H", "L1", "N1"])

In [None]:
# Density plot for the entries labelled H, L1 and N1 of the run loaded above
# (`plot_density` is presumably provided by `common`).
plot_density(data, ["H", "L1", "N1"])

## 3 Generations

In [None]:
# Load the combined decay + washout, three-generation run.
data = read_csv(OUTPUT_DIR.joinpath("decay_washout", "3gen", "n.csv"))

# Short summary of the run, followed by the integration plot.
n_table = data["n"]
print("Integration steps:", len(n_table.index))
print("Final B-L:", n_table["ΔB-L"].iloc[-1])
plot_integration(data)

In [None]:
# Asymmetry plot for H plus all three L and N entries of the run loaded above
# (`plot_asymmetry` is presumably provided by `common`).
plot_asymmetry(data, ["H", "L1", "L2", "L3", "N1", "N2", "N3"])

In [None]:
# Density plot for H plus all three L and N entries of the run loaded above
# (`plot_density` is presumably provided by `common`).
plot_density(data, ["H", "L1", "L2", "L3", "N1", "N2", "N3"])

# Miscellaneous

## Evolution

In [None]:
# Load the evolution data. `read_evolution` (presumably from `common`) returns
# the table indexed as `data["beta"]` / `data[p, "mass"]` in the cells below,
# together with the keys `ptcls` used to select each particle's columns.
data, ptcls = read_evolution(OUTPUT_DIR / "evolution.json")

In [None]:
# Mass of each particle against inverse temperature, on log-log axes.
mass_traces = [go.Scatter(name=p, x=data["beta"], y=data[p, "mass"]) for p in ptcls]
log_axis = dict(type="log", exponentformat="power")

go.Figure(
    data=mass_traces,
    layout=go.Layout(
        xaxis=go.layout.XAxis(title="Inverse Temperature [GeV⁻¹]", **log_axis),
        yaxis=go.layout.YAxis(title="Mass [GeV]", **log_axis),
    ),
)

In [None]:
# Mass of each particle in units of the temperature (mass × β) against inverse
# temperature, on log-log axes.
go.Figure(
    data=[
        go.Scatter(name=p, x=data["beta"], y=data[p, "mass"] * data["beta"])
        for p in ptcls
    ],
    layout=go.Layout(
        xaxis=go.layout.XAxis(
            title="Inverse Temperature [GeV⁻¹]", type="log", exponentformat="power",
        ),
        yaxis=go.layout.YAxis(
            # Axis title previously read "Temperatre".
            title="Mass / Temperature", type="log", exponentformat="power"
        ),
    ),
)

In [None]:
# Width-to-mass ratio of each particle against inverse temperature, on log-log
# axes.
ratio_traces = [
    go.Scatter(name=p, x=data["beta"], y=data[p, "width"] / data[p, "mass"])
    for p in ptcls
]
log_axis = dict(type="log", exponentformat="power")

go.Figure(
    data=ratio_traces,
    layout=go.Layout(
        xaxis=go.layout.XAxis(title="Inverse Temperature [GeV⁻¹]", **log_axis),
        yaxis=go.layout.YAxis(title="Width / Mass", **log_axis),
    ),
)

## Higgs Equilibrium

In [None]:
# Read every Higgs-equilibrium CSV, in sorted filename order so that the
# position of a run in `datas` is reproducible.
higgs_pattern = str(OUTPUT_DIR / "higgs_equilibrium" / "*.csv")
datas = [read_csv(csv_file) for csv_file in sorted(glob.glob(higgs_pattern))]

In [None]:
# One colored curve per run for the "H" entry; the color encodes the run's
# position in `datas`. The loop variable is named `run` so it does not shadow
# the notebook-level `data`.
run_traces = [
    go.Scatter(
        x=run["n"]["beta"],
        y=run["n"]["H"],
        mode="lines",
        line=go.scatter.Line(color=cmap("viridis", i / len(datas))),
        showlegend=False,
    )
    for i, run in enumerate(datas)
]

# Matching "(H)" entry of each run in black.
# NOTE(review): presumably the equilibrium value the runs relax towards —
# confirm against `common.read_csv`.
reference_traces = [
    go.Scatter(
        x=run["n"]["beta"],
        y=run["n"]["(H)"],
        mode="lines",
        line=go.scatter.Line(color="black"),
        showlegend=False,
    )
    for run in datas
]

go.Figure(
    data=run_traces + reference_traces,
    layout=go.Layout(
        xaxis=go.layout.XAxis(
            title="Inverse Temperature [GeV⁻¹]", type="log", exponentformat="power",
        ),
        # The title used to read "Width / Mass" (copied from the evolution
        # plots) although the "H" entry of data["n"] is what is plotted.
        # NOTE(review): assumes data["n"] holds number densities — confirm.
        yaxis=go.layout.YAxis(
            title="H Number Density", type="log", exponentformat="power"
        ),
    ),
)

## Lepton Equilibrium

In [None]:
# Read every lepton-equilibrium CSV, in sorted filename order so that the
# position of a run in `datas` is reproducible.
lepton_pattern = str(OUTPUT_DIR / "lepton_equilibrium" / "*.csv")
datas = [read_csv(csv_file) for csv_file in sorted(glob.glob(lepton_pattern))]

In [None]:
# One colored curve per run for the "L1" entry; the color encodes the run's
# position in `datas`. The loop variable is named `run` so it does not shadow
# the notebook-level `data`.
run_traces = [
    go.Scatter(
        x=run["n"]["beta"],
        y=run["n"]["L1"],
        mode="lines",
        line=go.scatter.Line(color=cmap("viridis", i / len(datas))),
        showlegend=False,
    )
    for i, run in enumerate(datas)
]

# Matching "(L1)" entry of each run in black.
# NOTE(review): presumably the equilibrium value the runs relax towards —
# confirm against `common.read_csv`.
reference_traces = [
    go.Scatter(
        x=run["n"]["beta"],
        y=run["n"]["(L1)"],
        mode="lines",
        line=go.scatter.Line(color="black"),
        showlegend=False,
    )
    for run in datas
]

go.Figure(
    data=run_traces + reference_traces,
    layout=go.Layout(
        xaxis=go.layout.XAxis(
            title="Inverse Temperature [GeV⁻¹]", type="log", exponentformat="power",
        ),
        # The title used to read "Width / Mass" (copied from the evolution
        # plots) although the "L1" entry of data["n"] is what is plotted.
        # NOTE(review): assumes data["n"] holds number densities — confirm.
        yaxis=go.layout.YAxis(
            title="L1 Number Density", type="log", exponentformat="power"
        ),
    ),
)

## Gammas

In [None]:
# Collect the interaction-rate tables as data[group][label] -> DataFrame.
# On-disk layout: gamma/<spline|raw>/<group>/<name>.csv
data = defaultdict(dict)

# Splines are loaded before the raw samples so that, within every group, the
# "[Spline]" entries precede the "[Raw]" ones; files are sorted so the
# insertion order (and hence the trace colors chosen later) is reproducible.
for kind, label in (("spline", "Spline"), ("raw", "Raw")):
    pattern = str(OUTPUT_DIR / "gamma" / kind / "*" / "*.csv")
    for file in sorted(glob.glob(pattern)):
        # pathlib instead of splitting on "/" so this also works with
        # platform-specific path separators.
        csv_path = Path(file)
        group = csv_path.parent.name
        name = csv_path.stem
        data[group][f"{name} [{label}]"] = pd.read_csv(file)

In [None]:
# Plot the normalized interaction rates, one figure per group.
for group, rates in data.items():
    fig = go.Figure(
        layout=go.Layout(
            xaxis=go.layout.XAxis(
                title="Inverse Temperature", type="log", exponentformat="power",
            ),
            yaxis=go.layout.YAxis(
                title="Interaction Rate",
                type="log",
                exponentformat="power",
                # On a log axis the range is given as powers of ten.
                range=[-20, 20],
            ),
        )
    )

    # Groups whose name has three whitespace-separated parts show splines only.
    skip_raw = len(group.split()) == 3

    colors = dict()
    for name, rate in rates.items():
        if skip_raw and "Raw" in name:
            continue

        # "<name> [Spline]" and "<name> [Raw]" share one color.
        short_name = name.split("[")[0]
        if short_name not in colors:
            # Walk the even palette slots first and switch to the odd ones once
            # the palette wraps, instead of indexing past its end (the previous
            # COLORS[2 * i] raised IndexError from the sixth entry on).
            slot = 2 * len(colors)
            slot += slot // len(COLORS)
            colors[short_name] = COLORS[slot % len(COLORS)]

        fig.add_trace(
            go.Scatter(
                name=name,
                x=rate["beta"],
                y=rate["gamma [normalized]"],
                line=go.scatter.Line(
                    dash="solid" if "Spline" in name else "dot",
                    color=colors[short_name],
                ),
            )
        )

    # Draw the region where processes go from being fast to slow
    fig.add_shape(
        type="rect",
        x0=1e-17, x1=1e-2,
        y0=1e-1, y1=1e1,
        fillcolor="Grey",
        line_color="Grey",
        opacity=0.2
    )

    fig.show()