In [1]:
import sys
# Make the directory two levels up importable (presumably where the
# quantumspectra_2024 package lives — confirm against repo layout).
sys.path += ["../../"]

import os
# Cap the fraction of GPU memory XLA grabs for this process. It must be set
# before JAX initialises its backend, so it stays ahead of any JAX usage.
os.environ.update(XLA_PYTHON_CLIENT_MEM_FRACTION=".50")

import numpy as np

In [2]:
# Constant Parameters
# -------------------
# REGENERATE: when True, a later cell precomputes and pickles spectra for every
#             point of the slider grid.
# PICKLE_PREFIX: path prefix under which those pickles are stored.
REGENERATE = True
PICKLE_PREFIX = "pickles/stark_sliders/"

# Arguments shared by all models; each model later keeps only the keywords its
# constructor accepts.
COMMON_ARGS = dict(
    temperature_kelvin=300,
    mode_frequencies=[1200, 100],
    # NOTE: compute_spectra overwrites mode_couplings with the g1/g2 slider values.
    mode_couplings=[0.7, 2.0],
)

# Two-state (QM) model settings.
TWO_STATE_ARGS = dict(
    transfer_integral=100,
    broadening=200,
    mode_basis_sets=[20, 200],
)

# MLJ model settings.
MLJ_ARGS = dict(
    disorder_meV=0,
    basis_size=20,
)

# Stark-model settings (in addition to the field sliders).
STARK_ARGS = dict(
    positive_field_sum_percent=0.5,
)

In [3]:
from dataclasses import dataclass
import ipywidgets as widgets

@dataclass
class DefaultMinMaxNum:
    """Specification of one slider: a default value plus a discrete value grid.

    The grid is ``np.linspace(min, max, num)``. The REGENERATE loop sweeps over this
    grid, and ``to_slider`` uses the grid spacing as its step so slider positions
    line up with the precomputed grid points.
    """

    default: float  # initial slider value
    min: float      # first grid point / slider minimum
    max: float      # last grid point / slider maximum
    num: int        # number of grid points (1 => single-valued, effectively fixed slider)

    def to_range(self):
        """Return the ``num`` evenly spaced grid values from ``min`` to ``max`` inclusive."""
        return np.linspace(self.min, self.max, self.num)

    def to_slider(self, **slider_kwargs):
        """Build an ``ipywidgets.FloatSlider`` over ``[min, max]`` starting at ``default``.

        The step equals the spacing of ``to_range()``. If the grid has a single point,
        or zero spacing (``min == max`` with ``num > 1``), the step falls back to 0.1,
        because a zero step is not a usable slider increment.

        Any extra keyword arguments (e.g. ``description``, ``continuous_update``) are
        forwarded unchanged to ``FloatSlider``.
        """
        step = 0.1
        if self.num > 1:
            grid = self.to_range()
            spacing = grid[1] - grid[0]
            if spacing > 0:
                step = spacing
        return widgets.FloatSlider(
            min=self.min, max=self.max, step=step, value=self.default, **slider_kwargs
        )

# Slider Parameters: default, min, max and number of grid points for each slider.
# Each entry's grid (DefaultMinMaxNum.to_range) is what the REGENERATE loop
# precomputes, and its slider step is derived from that same grid.
SLIDER_ARGS = dict(
    energy_gap=DefaultMinMaxNum(default=8_000, min=0, max=8_000, num=4),
    g1=DefaultMinMaxNum(default=0.7, min=0.7, max=0.7, num=1),
    g2=DefaultMinMaxNum(default=2, min=2, max=8, num=4),
    positive_field_strength=DefaultMinMaxNum(default=0.01, min=0.01, max=0.1, num=4),
    field_delta_dipole=DefaultMinMaxNum(default=38, min=0, max=38, num=4),
    field_delta_polarizability=DefaultMinMaxNum(default=1_000, min=0, max=1_000, num=4),
)

In [4]:
from itertools import product
import inspect
import os  # also imported in the first cell; re-importing is harmless
import jax
# Disable JAX's compilation cache (config flag "jax_enable_compilation_cache").
# Compiled/staged functions are additionally cleared with jax.clear_caches()
# between computations further down.
jax.config.update("jax_enable_compilation_cache", False)

# Project-local modules (presumably importable thanks to the sys.path tweak in the
# first cell — confirm against repo layout).
from quantumspectra_2024.models import StarkModel, TwoStateModel, MLJModel
# save_file / open_file are used by get_spectra to write / read cached spectra.
from plot_utils import save_file, open_file

# PREPROCESS SPECTRA

# Constructor keyword names accepted by each model. compute_spectra uses them to
# pick, from one merged argument dict, only the keywords each model accepts.
def _constructor_param_names(model_cls):
    """Return the names of ``model_cls``'s constructor parameters, in signature order."""
    return [name for name in inspect.signature(model_cls).parameters]

two_state_valids = _constructor_param_names(TwoStateModel)
mlj_valids = _constructor_param_names(MLJModel)
stark_valids = _constructor_param_names(StarkModel)

def compute_spectra(slider_params):
    """Build the four spectrum models for one slider setting and return their absorptions.

    The configuration dicts are merged in the order COMMON_ARGS, slider_params,
    TWO_STATE_ARGS, MLJ_ARGS, STARK_ARGS, with later dicts winning on key clashes.
    Each model then receives only the keywords its constructor accepts. The
    ``g1``/``g2`` slider values become ``mode_couplings`` for the two neutral models,
    and each neutral model is also wrapped in a ``StarkModel``.

    Returns
    -------
    tuple
        ``(two_state_abs, mlj_abs, two_state_stark_abs, mlj_stark_abs)``: the
        ``get_absorption()`` results, in that order.
    """
    merged_args = {}
    for source in (COMMON_ARGS, slider_params, TWO_STATE_ARGS, MLJ_ARGS, STARK_ARGS):
        merged_args.update(source)

    def accepted(valid_names):
        # keep only the keywords a given constructor accepts, preserving merge order
        return {key: value for key, value in merged_args.items() if key in valid_names}

    two_state_kwargs = accepted(two_state_valids)
    mlj_kwargs = accepted(mlj_valids)
    stark_kwargs = accepted(stark_valids)

    # the g1/g2 sliders replace whatever mode_couplings came from COMMON_ARGS
    two_state_kwargs["mode_couplings"] = [slider_params["g1"], slider_params["g2"]]
    mlj_kwargs["mode_couplings"] = [slider_params["g1"], slider_params["g2"]]

    neutral_models = (TwoStateModel(**two_state_kwargs), MLJModel(**mlj_kwargs))
    stark_models = tuple(
        StarkModel(neutral_submodel=neutral, **stark_kwargs) for neutral in neutral_models
    )

    return tuple(model.get_absorption() for model in (*neutral_models, *stark_models))

def _spectra_cache_path(slider_params, key_order):
    """Return the pickle path that caches the spectra for ``slider_params``.

    The filename embeds the repr of the parameter dict, so keys are first put into
    a canonical order. Names listed in ``key_order`` come first, in that order,
    followed by any remaining names sorted alphabetically. Callers that build the
    same parameters with a different key order therefore share one cache file.
    """
    rank = {name: position for position, name in enumerate(key_order)}
    canonical = dict(
        sorted(slider_params.items(), key=lambda item: (rank.get(item[0], len(rank)), item[0]))
    )
    # NOTE(review): the repr contains ':' and quotes, which are fine on POSIX
    # filesystems but invalid in Windows filenames.
    return f"{PICKLE_PREFIX}{canonical}.pkl"


def get_spectra(slider_params, key_order=None):
    """Return cached spectra for ``slider_params``, computing and caching them on a miss.

    Parameters
    ----------
    slider_params : dict
        Slider name -> value. Float values (Python or NumPy) are rounded to 3
        decimal places and converted to built-in ``float`` before use.
    key_order : sequence of str, optional
        Canonical key order for the cache filename. Defaults to the order of
        ``SLIDER_ARGS``, which is also the order the REGENERATE grid uses. Pickles
        generated by that loop (with NumPy < 2) therefore keep their filenames.

    Returns
    -------
    tuple
        ``(ts_abs, mlj_abs, ts_stark_abs, mlj_stark_abs)`` as produced by
        ``compute_spectra`` (or loaded from its pickle).
    """
    # Round floats to 3 decimal places so nearly-equal values (e.g. a slider value
    # and the np.linspace grid value it corresponds to) map to the same cache file.
    # Convert to built-in float: under NumPy >= 2 the repr of np.float64 is
    # "np.float64(...)", which would otherwise change the filename.
    slider_params = {
        k: round(float(v), 3) if isinstance(v, (float, np.floating)) else v
        for k, v in slider_params.items()
    }
    if key_order is None:
        key_order = list(SLIDER_ARGS)
    cache_path = _spectra_cache_path(slider_params, key_order)

    if os.path.exists(cache_path):
        ts_abs, mlj_abs, ts_stark_abs, mlj_stark_abs = open_file(cache_path)
    else:
        ts_abs, mlj_abs, ts_stark_abs, mlj_stark_abs = compute_spectra(slider_params)
        save_file((ts_abs, mlj_abs, ts_stark_abs, mlj_stark_abs), cache_path)

    return ts_abs, mlj_abs, ts_stark_abs, mlj_stark_abs


if REGENERATE:
    # Precompute and pickle spectra for every point of the slider grid, so the
    # interactive plot can load them instead of computing on demand. get_spectra
    # skips points whose pickle already exists, so re-running this cell resumes
    # after an interruption (such as the GPU out-of-memory error in the output).
    range_dict = dict((name, spec.to_range()) for name, spec in SLIDER_ARGS.items())

    # Cartesian product of all slider grids -> one keyword dict per grid point.
    param_values = [*product(*range_dict.values())]
    slider_names = tuple(range_dict)
    labeled_param_values = [dict(zip(slider_names, combo)) for combo in param_values]
    total = len(labeled_param_values)

    for idx in range(total):
        slider_params = labeled_param_values[idx]
        get_spectra(slider_params)
        print(f"{idx + 1}/{total}", end="\r")
        # drop JAX's compilation/staging caches between grid points
        jax.clear_caches()

11/1024

2024-05-13 00:45:41.818671: W external/tsl/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.48GiB (rounded to 1590595072)requested by op 
2024-05-13 00:45:41.818799: W external/tsl/tsl/framework/bfc_allocator.cc:494] *************************************************************************x***************___________
E0513 00:45:41.818852   35977 pjrt_stream_executor_client.cc:2809] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1590594900 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    1.48GiB
              constant allocation:         0B
        maybe_live_out allocation:    1.48GiB
     preallocated temp allocation:         0B
                 total allocation:    2.96GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 1.48GiB
		Entry Parameter Subshape: f32[2001,198725]

	Buffer 2:
		Size: 1.48GiB
		Ope

ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1590594900 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    1.48GiB
              constant allocation:         0B
        maybe_live_out allocation:    1.48GiB
     preallocated temp allocation:         0B
                 total allocation:    2.96GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 1.48GiB
		Entry Parameter Subshape: f32[2001,198725]
		==========================

	Buffer 2:
		Size: 1.48GiB
		Operator: op_name="jit(fn)/jit(main)/mul" source_file="/home/ben/Python-Projects/QuantumSpectra-2024/quantumspectra_2024/models/two_state/TwoStateComputation.py" source_line=40
		XLA Label: fusion
		Shape: f32[2001,198725]
		==========================

	Buffer 3:
		Size: 776.3KiB
		Entry Parameter Subshape: f32[198725]
		==========================



In [5]:
import matplotlib.pyplot as plt
from plot_utils import get_stats, fix_twinx_ticks

def plot(axes, abs, stark_dipole, stark_polarizability):
    """Fill one column of the three-row comparison figure.

    Row 0: the absorption spectrum, plus a dashed red vertical line at the first
    value returned by ``get_stats`` (labelled $\\bar{E}$, presumably the mean energy).
    Row 1: the dipole Stark spectrum, overlaid in orange on a twin y-axis with the
    numerical second derivative of the absorption spectrum.
    Row 2: the polarizability Stark spectrum, overlaid likewise with the first
    derivative.

    Parameters
    ----------
    axes : sequence of matplotlib Axes
        The column's axes; ``axes[0]``, ``axes[1]`` and ``axes[2]`` are drawn on.
    abs, stark_dipole, stark_polarizability
        Spectrum objects exposing ``.energies`` and ``.intensities``.
    """
    spectrum = abs  # local alias so the builtin-shadowing name is used only here
    energies = spectrum.energies
    intensities = spectrum.intensities

    absorption_ax = axes[0]
    absorption_ax.plot(energies, intensities)

    mean_energy, *_ = get_stats(energies, intensities)
    absorption_ax.axvline(x=mean_energy, color="red", linestyle="--", label=r"$\bar{E}$")
    absorption_ax.legend()

    slope = np.gradient(intensities, energies)
    curvature = np.gradient(slope, energies)

    stark_axes = axes[1:]
    derivative_axes = [stark_ax.twinx() for stark_ax in stark_axes]

    derivative_axes[0].plot(energies, curvature, color="orange", zorder=1)
    derivative_axes[1].plot(energies, slope, color="orange", zorder=1)

    axes[1].plot(stark_dipole.energies, stark_dipole.intensities, zorder=2)
    axes[2].plot(stark_polarizability.energies, stark_polarizability.intensities, zorder=2)

    for stark_ax, derivative_ax in zip(stark_axes, derivative_axes):
        fix_twinx_ticks(stark_ax, derivative_ax)

def display_plot(**params):
    """Render the QM vs. MLJ comparison figure for one set of slider values.

    Called by ``widgets.interact`` with one keyword argument per slider. Two
    parameter sets are derived from ``params``:

    - ``field_delta_polarizability`` forced to 0.0, for the dipole-only Stark spectra.
    - ``field_delta_dipole`` forced to 0.0, for the polarizability-only Stark spectra.

    The unperturbed absorption spectra are taken from the first set. The left column
    shows the two-state ("QM") model and the right column the MLJ model.
    """
    jax.clear_caches()

    # Override a single field term while keeping every key in its original slot.
    # get_spectra derives its cache filename from the parameter dict's repr, which
    # may depend on key order, so reordering keys can miss spectra that are
    # already cached.
    dipole_params = {**params, "field_delta_polarizability": 0.0}
    polarizability_params = {**params, "field_delta_dipole": 0.0}

    # Compute or load the spectra before creating the figure, so a failure here
    # (e.g. GPU out-of-memory) does not leave an empty 3x2 figure behind.
    ts_abs, mlj_abs, ts_stark_abs_dipole, mlj_stark_abs_dipole = get_spectra(dipole_params)
    _, _, ts_stark_abs_polarizability, mlj_stark_abs_polarizability = get_spectra(polarizability_params)

    # Close stale figures instead of calling plt.clf(). clf() acts on the current
    # figure and creates one if none is open, which produced an extra empty figure
    # ("<Figure ... with 0 Axes>") next to the real one.
    plt.close("all")
    fig, axes = plt.subplots(3, 2)

    plot(axes[:, 0], ts_abs, ts_stark_abs_dipole, ts_stark_abs_polarizability)
    plot(axes[:, 1], mlj_abs, mlj_stark_abs_dipole, mlj_stark_abs_polarizability)

    axes[0, 0].set_title("QM")
    axes[0, 1].set_title("MLJ")

    fig.tight_layout()
    
# One FloatSlider per SLIDER_ARGS entry. interact labels each slider with its key
# and passes the values to display_plot as keyword arguments.
slider_widgets = {name: spec.to_slider() for name, spec in SLIDER_ARGS.items()}
for slider in slider_widgets.values():
    # Re-render only when a slider is released. Every render calls get_spectra,
    # which runs the full model computation for any value not yet cached.
    slider.continuous_update = False
widgets.interact(display_plot, **slider_widgets);  # ";" hides the returned function's repr

interactive(children=(FloatSlider(value=0.7, description='g1', max=0.7, min=0.7), FloatSlider(value=2.0, descr…

<function __main__.display_plot(**params)>