In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import KDTree
import plotly.graph_objects as go
import plotly.express as px

# Read the input file.
# Layout: line 1 -> a b c (Å); line 2 -> alpha beta gamma (degrees);
# line 3 -> duplicate-removal threshold (Å); the remaining lines are atoms.
file_path = 'input_CaF2.txt'  # Update with actual file name

with open(file_path, 'r') as file:
    lines = file.readlines()

a, b, c = [float(token) for token in lines[0].split()]  # Lattice parameters (Å)
alpha, beta, gamma = [float(token) for token in lines[1].split()]  # Unit cell angles (degrees)
calc_tresh = float(lines[2].strip())  # Threshold for duplicate removal (Å)

# Convert angles from degrees to radians
alpha_r, beta_r, gamma_r = np.radians([alpha, beta, gamma])

# Build lattice vectors for a triclinic cell: a along x, b in the xy-plane,
# c with a positive z component. Rows of `lattice_vectors` are a, b, c, so
# Cartesian = fractional @ lattice_vectors.
ax = a
ay = 0
az = 0

bx = b * np.cos(gamma_r)
by = b * np.sin(gamma_r)
bz = 0

cx = c * np.cos(beta_r)
cy = c * (np.cos(alpha_r) - np.cos(beta_r) * np.cos(gamma_r)) / np.sin(gamma_r)
cz_squared = c**2 - cx**2 - cy**2
if not cz_squared > 0:
    # Angles that cannot close a cell (or gamma of 0/180 degrees, which divides
    # by sin(gamma) == 0) would otherwise give NaN/inf coordinates silently.
    raise ValueError(
        f"Unit-cell angles alpha={alpha}, beta={beta}, gamma={gamma} (degrees) "
        "do not describe a valid cell."
    )
cz = np.sqrt(cz_squared)

lattice_vectors = np.array([[ax, ay, az], [bx, by, bz], [cx, cy, cz]])

# Read atoms and apply lattice transformation.
# Each atom line is: symbol x y z charge, with x, y, z fractional coordinates.
atoms = []
atom_symbols = []
for line in lines[3:]:
    fields = line.split()
    if not fields or fields[0].startswith('#'):
        continue  # Skip blank lines (e.g. a trailing empty line) and comment lines
    symbol = fields[0]
    x, y, z, charge = map(float, fields[1:5])
    frac_coords = np.array([x, y, z])
    # Only sites inside the half-open cell [0, 1)^3 are kept; a coordinate of
    # exactly 1 (or outside the cell) is skipped. NOTE(review): presumably such
    # sites are periodic images of listed ones - confirm the input convention.
    if np.all((0 <= frac_coords) & (frac_coords < 1)):
        cart_coords = frac_coords @ lattice_vectors  # Fractional to Cartesian
        atoms.append([*cart_coords, charge])
        atom_symbols.append(symbol)

if not atoms:
    raise ValueError(
        f"No atoms with fractional coordinates in [0, 1) were found in {file_path}."
    )

atoms = np.array(atoms)
atom_symbols = np.array(atom_symbols)
# The first accepted atom ("Ion 0") is the centre of the Madelung spheres below.
original_atom_0_position = atoms[0, :3]  # Store reference position
original_atom_0_charge = atoms[0, 3]  # Store reference charge

# Function to remove duplicates using KDTree
def remove_duplicates(atoms, symbols, threshold=0.1):
    """Drop atoms that lie within ``threshold`` Å of an earlier atom.

    Parameters
    ----------
    atoms : array_like, shape (N, 4)
        Rows of Cartesian x, y, z (Å) followed by the charge.
    symbols : array_like, shape (N,)
        Species label for each row of ``atoms``.
    threshold : float, optional
        Two sites at a distance <= ``threshold`` (Å) are treated as the same
        site; the one with the lower row index is kept.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Surviving rows of ``atoms`` and the matching ``symbols``, in their
        original order.
    """
    atoms = np.asarray(atoms)
    symbols = np.asarray(symbols)
    if len(atoms) == 0:
        return atoms, symbols

    tree = KDTree(atoms[:, :3])
    # Index pairs (i, j) with i < j and |r_i - r_j| <= threshold.
    pairs = tree.query_pairs(r=threshold, output_type='ndarray')

    keep = np.ones(len(atoms), dtype=bool)
    if len(pairs):
        # Visit pairs in ascending (i, j) order. Any pair (k, i) with k < i that
        # could remove row i is processed before row i acts as the first
        # element, so keep[i] is already final when it is consulted.
        order = np.lexsort((pairs[:, 1], pairs[:, 0]))
        for i, j in pairs[order]:
            if keep[i]:
                keep[j] = False

    return atoms[keep], symbols[keep]

# Remove duplicate sites from the unit-cell basis before it is replicated.
atoms, atom_symbols = remove_duplicates(
    atoms,
    atom_symbols,
    threshold=calc_tresh,
)

# Function to expand lattice
def expand_lattice(expanding_factor):
    """Replicate the unit-cell basis over translations -n..n along each lattice vector.

    Reads the module-level ``atoms`` (rows of Cartesian x, y, z, charge),
    ``atom_symbols``, ``lattice_vectors`` (rows a, b, c) and ``calc_tresh``.

    Parameters
    ----------
    expanding_factor : int
        n >= 0; (2n + 1)**3 integer translations are applied.

    Returns
    -------
    (expanded_atoms, expanded_symbols)
        The replicated rows (translation-major, basis atom minor) after
        ``remove_duplicates`` with ``threshold=calc_tresh``.

    Raises
    ------
    ValueError
        If ``expanding_factor`` is negative.
    """
    if expanding_factor < 0:
        raise ValueError(f"expanding_factor must be >= 0, got {expanding_factor}.")

    steps = np.arange(-expanding_factor, expanding_factor + 1)
    # Integer translations (i, j, k) with i varying slowest.
    shifts = np.array(np.meshgrid(steps, steps, steps, indexing='ij')).reshape(3, -1).T
    shift_vectors = shifts @ lattice_vectors  # Shifts in Cartesian, shape (n_shifts, 3)
    n_shifts = len(shift_vectors)

    # Broadcast every translation onto every basis atom in one step instead of
    # looping over up to (2n + 1)**3 translations in Python; row order matches
    # "for each shift, all basis atoms".
    positions = (atoms[np.newaxis, :, :3] + shift_vectors[:, np.newaxis, :]).reshape(-1, 3)
    charges = np.tile(atoms[:, 3], n_shifts)
    expanded_atoms = np.column_stack((positions, charges))
    expanded_symbols = np.tile(atom_symbols, n_shifts)
    return remove_duplicates(expanded_atoms, expanded_symbols, threshold=calc_tresh)

# Plot structure using Plotly
def plot_structure(atoms, symbols, title="Structure Visualization", draw_unit_cell=False):
    """Show an interactive 3D scatter plot of the atoms, one trace per species.

    Parameters
    ----------
    atoms : array_like, shape (N, >=3)
        Rows whose first three columns are Cartesian x, y, z (Å).
    symbols : array_like, shape (N,)
        Species label per row; each distinct label gets its own colour and trace.
    title : str
        Figure title.
    draw_unit_cell : bool
        If True, outline the cell spanned by the module-level ``lattice_vectors``
        (rows a, b, c) starting at the origin.
    """
    positions = np.asarray(atoms)[:, :3]
    # asarray so that a plain list of labels also gives an element-wise mask below.
    symbols = np.asarray(symbols)
    # Sorted so the species -> colour assignment is identical on every run;
    # iteration order of a set of strings changes with hash randomisation.
    unique_symbols = sorted(set(symbols.tolist()))
    colors = px.colors.qualitative.Plotly
    color_map = {sym: colors[i % len(colors)] for i, sym in enumerate(unique_symbols)}
    fig = go.Figure()

    # Plot each species separately
    for sym in unique_symbols:
        mask = (symbols == sym)
        fig.add_trace(go.Scatter3d(
            x=positions[mask, 0],
            y=positions[mask, 1],
            z=positions[mask, 2],
            mode='markers',
            marker=dict(size=4, color=color_map[sym], opacity=0.7),
            name=sym
        ))

    # Draw unit cell box if requested
    if draw_unit_cell:
        a_vec, b_vec, c_vec = lattice_vectors
        corners = np.array([
            np.zeros(3),
            a_vec,
            b_vec,
            c_vec,
            a_vec + b_vec,
            a_vec + c_vec,
            b_vec + c_vec,
            a_vec + b_vec + c_vec
        ])

        edge_indices = [
            (0, 1), (0, 2), (0, 3),
            (1, 4), (1, 5),
            (2, 4), (2, 6),
            (3, 5), (3, 6),
            (4, 7), (5, 7), (6, 7)
        ]

        # All 12 edges go into a single trace; a None entry breaks the line
        # between consecutive edges.
        edge_x, edge_y, edge_z = [], [], []
        for i, j in edge_indices:
            edge_x.extend([corners[i][0], corners[j][0], None])
            edge_y.extend([corners[i][1], corners[j][1], None])
            edge_z.extend([corners[i][2], corners[j][2], None])

        fig.add_trace(go.Scatter3d(
            x=edge_x,
            y=edge_y,
            z=edge_z,
            mode='lines',
            line=dict(color='black', width=2),
            showlegend=False
        ))

    fig.update_layout(
        scene=dict(
            xaxis_title='X (Å)',
            yaxis_title='Y (Å)',
            zaxis_title='Z (Å)',
            aspectmode='data'
        ),
        title=title,
        margin=dict(l=0, r=0, b=0, t=40)
    )
    fig.show()

# Compute Madelung constant and charge density within a sphere centered on Ion 0
def calculate_madelung_and_charge_density(expanded_atoms, radius, center=None, tol=1e-6):
    """Return the Madelung sum plus net charge and charge density inside a sphere.

    Parameters
    ----------
    expanded_atoms : array_like, shape (N, 4)
        Rows of Cartesian x, y, z (Å) followed by the charge (e).
    radius : float
        Sphere radius (Å); must be positive.
    center : array_like of 3 floats, optional
        Sphere centre (Å). Defaults to the module-level
        ``original_atom_0_position`` (Ion 0).
    tol : float, optional
        Distance slack (Å). Atoms up to ``radius + tol`` from the centre count
        as inside, so atoms lying on the sphere surface (including an ion
        placed exactly at ``radius``) are not dropped by round-off. Atoms within
        ``tol`` of the centre are treated as the central ion and are left out
        of the Madelung sum.

    Returns
    -------
    (madelung_constant, net_charge, charge_density)
        ``madelung_constant = -sum(q_j / (r_j / r_0))`` over the non-central
        ions in the sphere, with ``r_0`` the smallest such distance.
        ``net_charge`` is the total charge in the sphere (central ion included),
        and ``charge_density`` is ``net_charge`` divided by the sphere volume.

    Raises
    ------
    ValueError
        If ``radius`` is not positive, the sphere holds no atoms, or it holds
        no atom besides the central ion.
    """
    if center is None:
        center = original_atom_0_position
    if not radius > 0:
        raise ValueError(f"Sphere radius must be positive, got {radius!r}.")

    data = np.asarray(expanded_atoms, dtype=float)
    positions = data[:, :3]
    charges = data[:, 3]

    distances = np.linalg.norm(positions - np.asarray(center, dtype=float), axis=1)

    in_sphere = distances <= radius + tol
    sphere_charges = charges[in_sphere]
    sphere_distances = distances[in_sphere]

    if sphere_distances.size == 0:
        raise ValueError(f"No valid atoms found in sphere of radius {radius:.3f} Å!")

    net_charge = np.sum(sphere_charges)
    sphere_volume = (4/3) * np.pi * (radius**3)
    charge_density = net_charge / sphere_volume

    # Drop the central ion by distance rather than exact coordinate equality, so
    # a copy that differs only by round-off cannot enter the sum with r ~ 0.
    is_other = sphere_distances > tol
    madelung_charges = sphere_charges[is_other]
    madelung_distances = sphere_distances[is_other]

    if madelung_distances.size == 0:
        raise ValueError("No atoms left to compute Madelung constant after removing Ion 0!")

    r_0 = np.min(madelung_distances)  # Nearest-neighbour distance
    madelung_constant = -np.sum(madelung_charges / (madelung_distances / r_0))

    return madelung_constant, net_charge, charge_density

expansion_steps = list(range(1, 31))
madelung_values = []
charge_densities = []

# Perpendicular spacing between opposite cell faces for each lattice vector:
# d_a = V / |b x c|, d_b = V / |c x a|, d_c = V / |a x b|.
# Moving a distance R changes a fractional coordinate by at most R / d, so a
# sphere of radius R around an atom of the reference cell is fully covered by
# translations -n..n along every lattice vector once n >= R / min(d).
# For orthogonal cells min(d) == min(a, b, c), so n equals the step index i;
# skewed (e.g. hexagonal or triclinic) cells need more than i cells per side,
# otherwise atoms near the sphere surface would be missing.
cell_volume = abs(np.linalg.det(lattice_vectors))
face_area_vectors = np.cross(
    np.roll(lattice_vectors, -1, axis=0),  # rows: b, c, a
    np.roll(lattice_vectors, -2, axis=0),  # rows: c, a, b
)
plane_spacings = cell_volume / np.linalg.norm(face_area_vectors, axis=1)
min_plane_spacing = plane_spacings.min()

for i in expansion_steps:
    radius = i * min(a, b, c)
    # The 1e-9 slack keeps round-off in radius / spacing from adding a needless shell.
    n_cells = max(i, int(np.ceil(radius / min_plane_spacing - 1e-9)))
    expanded_atoms, expanded_symbols = expand_lattice(n_cells)

    if i == 1:
        plot_structure(expanded_atoms, expanded_symbols, title="Expanded Structure (Step 1)", draw_unit_cell=True)

    madelung_constant, net_charge, charge_density = calculate_madelung_and_charge_density(expanded_atoms, radius)

    # Add an ion on the sphere surface (+x from Ion 0) carrying the opposite of
    # the sphere's net charge, then recompute the Madelung sum with it included.
    artificial_ion_position = original_atom_0_position + np.array([radius, 0, 0])
    artificial_ion_charge = -net_charge

    expanded_atoms = np.vstack([expanded_atoms, [*artificial_ion_position, artificial_ion_charge]])

    madelung_constant, _, _ = calculate_madelung_and_charge_density(expanded_atoms, radius)

    madelung_values.append(madelung_constant)
    charge_densities.append(charge_density)

    print(f"Expansion {i}: Radius = {radius:.3f} Å | Number of ions = {len(expanded_atoms):.0f} | Madelung = {madelung_constant:.6e} | Net Charge = {net_charge:.3f} | Charge Density = {charge_density:.6e} | Artificial Ion at {artificial_ion_position} with Charge = {artificial_ion_charge:.3f}")

if len(madelung_values) >= 5:
    last_five_avg_madelung = np.mean(madelung_values[-5:])
    last_five_std_madelung = np.std(madelung_values[-5:])
    print(f"\nAverage Madelung Constant over last 5 expansions: {last_five_avg_madelung:.4f}")
    print(f"Standard Deviation of Madelung Constant over last 5 expansions: {last_five_std_madelung:.4f}")
else:
    print("\nNot enough expansion cycles to compute statistics (requires at least 5).")

plt.figure(figsize=(3.25, 2.5))
plt.plot(expansion_steps, madelung_values, marker='o', linestyle='-', color='b')
plt.xlabel("Expansion Step (i)")
plt.ylabel("Madelung Constant")
plt.grid(True)
plt.savefig(file_path.removesuffix('.txt')+"_madelung_vs_expansion_sphere.svg", format="svg")
plt.show()

plt.figure(figsize=(3.25, 2.5))
plt.plot(expansion_steps, charge_densities, marker='s', linestyle='-', color='r')
plt.xlabel("Expansion Step (i)")
plt.ylabel("Charge Density (e/Å³)")
plt.grid(True)
plt.savefig(file_path.removesuffix('.txt')+"_charge_density_vs_expansion_sphere.svg", format="svg")
plt.show()