In [1]:
import json
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import torch
import torch.nn as nn

# ----------------------------
# Target scaler: rebuild the fitted MinMaxScaler from its JSON dump
# ----------------------------
with open('/Users/justint/Library/CloudStorage/OneDrive-Personal/Desktop/Academic Stuff/Arias Research/Materials_NN/Paper_Materials_2025/training/batch_3/scaler_y_0_025eV.json', 'r') as f:
    scaler_params = json.load(f)

scaler_y = MinMaxScaler(feature_range=tuple(scaler_params['feature_range']))
# Restore every fitted attribute as a NumPy array taken from the JSON entry
# of the same name (same order as the attributes are listed here).
for _fitted_attr in ('min_', 'scale_', 'data_min_', 'data_max_', 'data_range_'):
    setattr(scaler_y, _fitted_attr, np.array(scaler_params[_fitted_attr]))

# ----------------------------
# Device: CUDA when available, CPU otherwise
# ----------------------------
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device:", device)

# ----------------------------
# Inverse scaler (same API)
# ----------------------------
class InverseScaler(nn.Module):
    """Affine map that undoes min-max scaling: ``y = x * a + b``.

    Args:
        feature_range: ``(feature_min, feature_max)`` of the scaled space.
        data_min: minimum of the unscaled data.
        data_range: span of the unscaled data (the scaler's ``data_range_``).

    The slope and intercept are stored as float32 buffers named
    ``a_tensor`` and ``b_tensor``, so they follow the module across devices.
    """
    def __init__(self, feature_range, data_min, data_range):
        super().__init__()
        lo, hi = feature_range
        # One unit of scaled space corresponds to `slope` units of data space.
        slope = data_range / (hi - lo)
        # Chosen so that x == lo maps back to data_min.
        intercept = data_min - lo * slope
        self.register_buffer('a_tensor', torch.tensor(slope, dtype=torch.float32))
        self.register_buffer('b_tensor', torch.tensor(intercept, dtype=torch.float32))

    def forward(self, x):
        """Map scaled values ``x`` back to the original data range."""
        rescaled = x * self.a_tensor
        return rescaled + self.b_tensor

# ----------------------------
# Max-Affine predictor (hard max)
# y(x) = max_j (w_j^T x + b_j)
# ----------------------------
class MaxAffine(nn.Module):
    """Hard max-affine predictor: ``y(x) = max_j (w_j^T x + b_j)``.

    Args:
        W: weight tensor of shape ``[K, d]`` (one row per affine piece).
        b: bias tensor of shape ``[K]``; a single-element tensor is also
           accepted and is shared by all pieces through broadcasting.

    ``forward`` takes ``x`` of shape ``[N, d]`` and returns ``[N, 1]``.
    """
    def __init__(self, W, b):
        super().__init__()
        # Stored as frozen Parameters so they move with .to(device) and are
        # saved in the state dict, but are never updated by an optimizer.
        self.W = nn.Parameter(W, requires_grad=False)
        self.b = nn.Parameter(b, requires_grad=False)

    def forward(self, x):
        """Evaluate every affine piece and return the per-sample maximum."""
        # x @ W^T is [N, K]; the bias must be added along the *piece* axis
        # (last dim). The previous `self.b[:, None]` reshaped it to [K, 1],
        # which lined it up with the sample axis instead: an error for most
        # N, and silently wrong values when N == K.
        z = x @ self.W.t() + self.b               # [N, K]
        return z.max(dim=1, keepdim=True).values  # [N, 1]

# ----------------------------
# Wrapper to match your "combined_model" name
# ----------------------------
class DeepSetWithInverseScaling(nn.Module):
    """Chain a predictor with an ``InverseScaler`` so outputs leave scaled space.

    The class name is kept as-is so downstream code does not need to change.

    Args:
        original_model: module whose output lives in the scaled target space.
        scaler_params: mapping with ``'feature_range'``, ``'data_min_'`` and
            ``'data_range_'``; only element ``[0]`` of the last two is used.
    """
    def __init__(self, original_model, scaler_params):
        super().__init__()
        self.original_model = original_model
        self.inverse_scaler = InverseScaler(
            scaler_params['feature_range'],
            scaler_params['data_min_'][0],
            scaler_params['data_range_'][0],
        )

    def forward(self, x):
        """Run the wrapped model, then undo the target scaling on its output."""
        # Inner call: prediction in scaled target space.
        # Outer call: mapped back to the original target units (eV, per the
        # original notebook comment).
        return self.inverse_scaler(self.original_model(x))

# ----------------------------
# Element tables: OH_elements list and per-element feature dict p_dict_oh
# ----------------------------
oh_elements_path = '/Users/justint/Library/CloudStorage/OneDrive-Personal/Desktop/Academic Stuff/Arias Research/Materials_NN/Paper_Materials_2025/training/OH_elements.json'
with open(oh_elements_path, 'r') as file:
    OH_elements = json.load(file)
total_N_elements = len(OH_elements)

p_dict_oh_path = '/Users/justint/Library/CloudStorage/OneDrive-Personal/Desktop/Academic Stuff/Arias Research/Materials_NN/Paper_Materials_2025/training/p_dict_oh.json'
with open(p_dict_oh_path, 'r') as file:
    p_dict_oh = json.load(file)

# Feature dimension = length of one element's feature vector (hydrogen's here)
input_dim = len(p_dict_oh['H'])

# ----------------------------
# Max-affine weights
# Expected file layout: torch.save({'W': tensor[K,d], 'b': tensor[K]}, path)
# ----------------------------
max_affine_path = '/Users/justint/Library/CloudStorage/OneDrive-Personal/Desktop/Academic Stuff/Arias Research/Materials_NN/Paper_Materials_2025/training/batch_3/maxaffine_K7_0_025eV_oneB.pt'  # <-- set this
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
ckpt = torch.load(max_affine_path, map_location=device)

# Cast to float32 and place on the selected device
W = ckpt['W'].to(torch.float32).to(device)   # [K, d]
b = ckpt['b'].to(torch.float32).to(device)   # [K]

# Dimension sanity checks: rank of W, rank of b, then feature width
assert W.dim() == 2, "Checkpoint must contain W[K,d] and b[K]"
assert b.dim() == 1, "Checkpoint must contain W[K,d] and b[K]"
assert W.shape[1] == input_dim, f"Feature dim mismatch: W has d={W.shape[1]} but input_dim={input_dim}"

# ----------------------------
# Assemble the predictor and its inverse-scaling wrapper
# ----------------------------
# Max-affine predictor working in the scaled target space; .eval() returns
# the module itself, so it can be chained after .to(device).
model = MaxAffine(W, b).to(device).eval()

# Only these three scaler attributes are needed to invert the scaling
scaler_params_inverse = dict(
    feature_range=scaler_y.feature_range,
    data_min_=scaler_y.data_min_,
    data_range_=scaler_y.data_range_,
)

# Published under the SAME NAME the rest of the code expects
combined_model = DeepSetWithInverseScaling(model, scaler_params_inverse).to(device)
# Kept as the final expression so the notebook displays the module structure
combined_model.eval()


Using device: cpu


DeepSetWithInverseScaling(
  (original_model): MaxAffine()
  (inverse_scaler): InverseScaler()
)

In [2]:
"""
Periodic Table Highlighter (matplotlib)
--------------------------------------
Now supports MULTIPLE groups (list of lists), each highlighted with its own color.

Changes requested:
✓ Different colors per group
✓ Darker outlines and element text
✓ Larger element names
✓ Remove unused empty squares above transition metals
✓ Remove "Lanthanides" / "Actinides" labels
✓ No title/legend; tight crop that fits the table nicely with minimal padding

CLI examples:
    # Single group (backward compatible)
    python periodic_highlighter.py "H,Be,B,Al,Si,Zn,Ga,Ge,In,Sn,Sb,Tl,Pb,Bi" -o one_group.png

    # Multiple groups (semicolon-separated groups, comma-separated symbols)
    python periodic_highlighter.py -g "H,Be; B,C,N,O; Zn,Ga,Ge,In,Sn,Sb,Tl,Pb,Bi" -o multi_groups.png \
        --colors "#fde68a,#fecaca,#bfdbfe" --fontsize 20 --dpi 300

Programmatic example:
    from periodic_highlighter import draw_periodic_groups
    groups = [["H","Be"], ["B","C","N","O"], ["Zn","Ga","Ge","In","Sn","Sb","Tl","Pb","Bi"]]
    draw_periodic_groups(groups, out_path="multi.png")
"""
from __future__ import annotations
import argparse
from typing import Iterable, Dict, Tuple, List, Set, Optional
import colorsys
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# ------------------------------
# Periodic table layout (18 cols)
# ------------------------------
# Grid of element symbols, one inner list per drawn row and 18 columns each.
# An empty string marks a cell with no element; the drawing code skips those.
# La and Ac sit in the main block (column index 2); the remaining f-block
# elements are listed in two separate rows below a blank spacer row.
PT_LAYOUT: List[List[str]] = [
    # Period 1
    ["H",  "",  "",  "",  "",  "",  "",  "",  "",  "",  "",  "",  "",  "",  "",  "",  "",  "He"],
    # Period 2
    ["Li", "Be", "",  "",  "",  "",  "",  "",  "",  "",  "",  "",  "B",  "C",  "N",  "O",  "F",  "Ne"],
    # Period 3
    ["Na", "Mg", "",  "",  "",  "",  "",  "",  "",  "",  "",  "",  "Al", "Si", "P",  "S",  "Cl", "Ar"],
    # Period 4
    ["K",  "Ca", "Sc", "Ti", "V",  "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr"],
    # Period 5
    ["Rb", "Sr", "Y",  "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I",  "Xe"],
    # Period 6 (Ce-Lu are in the lanthanide row further down)
    ["Cs", "Ba", "La", "Hf", "Ta", "W",  "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn"],
    # Period 7 (Th-Lr are in the actinide row further down)
    ["Fr", "Ra", "Ac", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og"],
    # spacer (visual gap between main block and f-block rows)
    ["" for _ in range(18)],
    # Lanthanides (period 6 f-block) starting under group 4 (index 3)
    ["", "", "", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", ""],
    # Actinides (period 7 f-block), same column alignment as the lanthanide row
    ["", "", "", "Th", "Pa", "U",  "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr", ""],
]

# Lookup table: element symbol -> (row, col) grid position; blank cells are skipped
POS: Dict[str, Tuple[int, int]] = {
    symbol: (row_idx, col_idx)
    for row_idx, layout_row in enumerate(PT_LAYOUT)
    for col_idx, symbol in enumerate(layout_row)
    if symbol
}

# Every symbol that appears somewhere in the layout
ALL_SYMBOLS: Set[str] = set(POS)

# ------------------------------
# Drawing functions
# ------------------------------

def draw_periodic_groups(groups: List[Iterable[str]], out_path: str = "periodic_highlight.png",
                         dpi: int = 240,
                         facecolor: str = "white",
                         base_color: str = "#f8fafc",
                         line_color: str = "#111827",
                         colors: Optional[List[str]] = None,
                         text_color: str = "#0b0f15",
                         fontsize: int = 24) -> None:
    """Render a periodic table and highlight *multiple* groups of elements.

    Args:
        groups: list of iterables of element symbols. Each inner list gets its own color.
            Entries that are empty or not strings are ignored, as are symbols
            that do not appear in the layout.
        out_path: output filepath (PNG/JPG/JPEG/PDF/SVG). Any other extension
            falls back to writing "periodic_highlight.png".
        dpi: image DPI for rasters
        facecolor, base_color, line_color, text_color: colors
        colors: list of hex colors for groups; when missing or shorter than the
            number of groups, the remaining colors come from `_make_palette`.
        fontsize: element symbol font size

    Returns:
        None. The figure is written to disk and then closed.

    Notes:
        - A symbol listed in several groups takes the color of the last group.
        - Only cells that hold an element are drawn (no placeholder squares).
        - No Lanthanides/Actinides labels. No title. Tight crop with padding.
    """
    # Normalize groups: keep non-empty strings only, stripped of whitespace
    norm_groups: List[List[str]] = [
        [s.strip() for s in g if s and isinstance(s, str)]
        for g in groups
    ]

    palette = _make_palette(len(norm_groups), colors)

    # Map element -> color (later groups override earlier on conflicts).
    # This mapping used to be built twice by two identical loops; once is enough.
    sym2color: Dict[str, str] = {}
    for gi, g in enumerate(norm_groups):
        col = palette[gi % len(palette)]
        for s in g:
            if s in ALL_SYMBOLS:
                sym2color[s] = col

    # Figure/axes: fill canvas
    fig = plt.figure(figsize=(20, 12), dpi=dpi, facecolor=facecolor)
    try:
        ax = fig.add_axes([0.005, 0.005, 0.99, 0.99])
        ax.set_facecolor(facecolor)
        ax.set_aspect('equal', adjustable='box')
        ax.axis('off')

        w, h = 1.0, 1.0

        # Bounds based only on occupied cells
        occupied: List[Tuple[int, int]] = [
            (r, c) for r, row in enumerate(PT_LAYOUT) for c, sym in enumerate(row) if sym
        ]
        min_r = min(r for r, _ in occupied)
        max_r = max(r for r, _ in occupied)
        min_c = min(c for _, c in occupied)
        max_c = max(c for _, c in occupied)
        pad = 0.2
        ax.set_xlim(min_c - pad, max_c + 1 + pad)
        ax.set_ylim(min_r - pad, max_r + 1 + pad)
        # Row 0 must be at the top, so flip the y axis
        ax.invert_yaxis()

        # Draw element cells only
        for r, row in enumerate(PT_LAYOUT):
            for c, sym in enumerate(row):
                if not sym:
                    continue
                fc = sym2color.get(sym, base_color)
                rect = Rectangle((c, r), w, h, facecolor=fc, edgecolor=line_color, linewidth=1.6)
                ax.add_patch(rect)
                ax.text(c + 0.5, r + 0.54, sym,
                        ha='center', va='center', fontsize=fontsize, color=text_color,
                        family='DejaVu Sans', fontweight='bold')

        # Save tightly cropped
        ext = out_path.split('.')[-1].lower()
        save_kwargs = dict(bbox_inches='tight', pad_inches=0.05, facecolor=facecolor)
        if ext in {"png", "jpg", "jpeg", "pdf", "svg"}:
            fig.savefig(out_path, **save_kwargs)
        else:
            fig.savefig("periodic_highlight.png", **save_kwargs)
    finally:
        # Release the figure even when drawing or saving raises, so repeated
        # calls in a long-lived kernel do not accumulate open figures.
        plt.close(fig)


def draw_periodic(highlight: Iterable[str], out_path: str = "periodic_highlight.png", **kwargs) -> None:
    """Single-group convenience wrapper around ``draw_periodic_groups``.

    ``highlight`` is materialized into a list and passed as the only group;
    every extra keyword argument is forwarded unchanged.

    Example: draw_periodic(["H","Be"], out_path="one_group.png")
    """
    only_group = list(highlight)
    draw_periodic_groups([only_group], out_path=out_path, **kwargs)


# ------------------------------
# Palette helpers
# ------------------------------

def _rgb_to_hex(rgb: Tuple[float, float, float]) -> str:
    r, g, b = [int(round(255*x)) for x in rgb]
    return f"#{r:02x}{g:02x}{b:02x}"


def _evenly_spaced_hues(n: int, s: float = 0.65, v: float = 0.90, hue_offset: float = 0.0) -> List[str]:
    """Return ``n`` hex colors whose hues are evenly spaced around the HSV wheel.

    Args:
        n: number of colors to generate.
        s: HSV saturation shared by all colors.
        v: HSV value (brightness) shared by all colors.
        hue_offset: hue of the first color, as a fraction of the wheel.

    A hue that lands within 0.04 of 1/6 (the yellow region) is shifted by
    +0.06 before conversion.
    """
    hex_colors: List[str] = []
    divisor = max(n, 1)
    for step in range(n):
        hue = (hue_offset + step / divisor) % 1.0
        # nudge hues close to 1/6 (about 60 degrees, yellow) for legibility
        if abs(hue - 1/6) < 0.04:
            hue += 0.06
        hex_colors.append(_rgb_to_hex(colorsys.hsv_to_rgb(hue, s, v)))
    return hex_colors


def _make_palette(n: int, user_colors: Optional[List[str]] = None) -> List[str]:
    """Return a list of n distinct, colorblind‑aware colors.
    Priority:
      1) If user provided colors, use them (and extend if short).
      2) Use Matplotlib's 'tab10' (10 very distinct), then 'tab20' (evens then odds).
      3) If still short, generate evenly spaced HSV hues (skipping pale yellow band).
    """
    base: List[str] = []

    if user_colors:
        base.extend(user_colors)

    try:
        import matplotlib.pyplot as _plt
        tab10 = _plt.get_cmap('tab10')
        tab20 = _plt.get_cmap('tab20')
        # tab10 first (colorblind‑friendly)
        base.extend([_rgb_to_hex(tab10(i)[:3]) for i in range(10)])
        # pick every other from tab20 to maximize separation, then the rest
        base.extend([_rgb_to_hex(tab20(i)[:3]) for i in range(0, 20, 2)])
        base.extend([_rgb_to_hex(tab20(i)[:3]) for i in range(1, 20, 2)])
    except Exception:
        # Fallback if colormaps not available
        pass

    if len(base) < n:
        base.extend(_evenly_spaced_hues(n - len(base)))

    return base[:n]


In [3]:
# Pairwise margins between affine pieces, shape [K, K, d]:
#   mat[i, j, e] = (W[i, e] + b[i]) - (W[j, e] + b[j])
# (if b holds one shared value, the bias terms simply cancel).
mat = W.unsqueeze(1) - W.unsqueeze(0) + b.unsqueeze(1).unsqueeze(2) - b.unsqueeze(0).unsqueeze(2)
# Keep only the first 92 feature columns.
# NOTE(review): presumably these are the one-hot element slots, ordered like OH_elements -- confirm.
mat = mat[:, :, :92]

# For each piece i, build a 0/1 vector over the kept columns: 1 where piece i
# has a strictly positive margin over EVERY other piece j != i.
active_elements_onehot = []
for i in range(mat.shape[0]):
    row = mat[i]                                      # [K, 92] margins of piece i vs. all pieces
    # Drop the self-comparison row. torch.cat accepts an empty slice on either
    # side, so the first and last piece need no special case. It also returns
    # a new tensor, so `mat` is never modified (the earlier in-place
    # binarization wrote through views into mat's first and last rows).
    other = torch.cat((row[:i], row[i+1:]), dim=0)    # [K-1, 92]
    wins = (other > 0).to(mat.dtype)                  # 1.0 where piece i beats piece j
    active_elements_onehot.append(torch.prod(wins, dim=0))

# Translate each 0/1 vector into the list of element symbols it selects
active_elements = [
    [OH_elements[j] for j in torch.nonzero(onehot, as_tuple=True)[0].tolist()]
    for onehot in active_elements_onehot
]

# One highlight color per affine piece (seven entries)
rainbow = [
    "#d94c4c",  # red
    "#63c965",  # green
    "#7465d8",  # indigo
    "#b45de9",  # violet
    "#f6db4a",  # yellow
    "#469be8",  # blue
    "#f49b25",  # orange
]

draw_periodic_groups(active_elements, out_path="active_elements_K7.png", colors=rainbow)