In [None]:
import numpy as np
import plotly.graph_objects as go
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.spatial import KDTree

# === Read input data ===
# Expected layout of the input file:
#   line 1: a b c              lattice lengths (Å)
#   line 2: alpha beta gamma   lattice angles (degrees)
#   line 3: threshold          duplicate-removal threshold (Å)
#   line 4+: atom records, parsed further below
file_path = 'input_CaF2.txt'  # Update with actual file name

# Explicit encoding so parsing does not depend on the platform's locale default.
with open(file_path, 'r', encoding='utf-8') as file:
    lines = file.readlines()

if len(lines) < 3:
    raise ValueError(
        f"{file_path}: expected at least 3 header lines "
        f"(lattice lengths, lattice angles, duplicate threshold), got {len(lines)}"
    )

a, b, c = map(float, lines[0].split())  # Lattice lengths (Å)
alpha, beta, gamma = map(float, lines[1].split())  # Angles (degrees)
calc_tresh = float(lines[2].strip())  # Threshold for duplicate removal (Å)

# === Convert angles to radians ===
alpha_rad = np.radians(alpha)
beta_rad = np.radians(beta)
gamma_rad = np.radians(gamma)

# === Construct lattice matrix for a general unit cell ===
# Convention: lattice vector a lies along x, b lies in the xy-plane,
# and c completes the right-handed cell.
cos_alpha, cos_beta, cos_gamma = np.cos(alpha_rad), np.cos(beta_rad), np.cos(gamma_rad)
sin_gamma = np.sin(gamma_rad)

# The cell volume is a*b*c*sqrt(volume_sq). For angle combinations that cannot
# form a real cell, volume_sq is <= 0. Taking its square root (or dividing by a
# zero sin_gamma) would silently yield NaN and corrupt every coordinate.
volume_sq = (1 - cos_alpha**2 - cos_beta**2 - cos_gamma**2
             + 2 * cos_alpha * cos_beta * cos_gamma)
if np.isclose(sin_gamma, 0.0) or volume_sq <= 0.0:
    raise ValueError(
        f"Lattice angles alpha={alpha}, beta={beta}, gamma={gamma} (degrees) "
        "do not describe a valid, non-degenerate unit cell."
    )

# Volume correction for third lattice vector
volume_factor = np.sqrt(volume_sq)

# Lattice vectors as columns of the matrix
lattice_matrix = np.array([
    [a, b * cos_gamma, c * cos_beta],
    [0, b * sin_gamma, c * (cos_alpha - cos_beta * cos_gamma) / sin_gamma],
    [0, 0, c * volume_factor / sin_gamma]
])

# === Read atoms and convert fractional to Cartesian coordinates ===
# Each atom line: <symbol> <x> <y> <z> <charge>, where x, y, z are fractional
# coordinates; any further columns are ignored.
atoms = []
for line_number, line in enumerate(lines[3:], start=4):
    fields = line.split()
    if not fields:
        continue  # tolerate blank lines, e.g. a trailing newline at end of file
    if len(fields) < 5:
        raise ValueError(
            f"{file_path}, line {line_number}: expected "
            f"'<symbol> <x> <y> <z> <charge>', got {line.strip()!r}"
        )
    symbol = fields[0]
    x, y, z, charge = map(float, fields[1:5])
    # Only atoms inside the closed unit cell [0, 1]^3 are kept.
    if 0 <= x <= 1 and 0 <= y <= 1 and 0 <= z <= 1:
        # Coordinates within 1e-6 of a cell face are pulled to 1e-6 inside it.
        # NOTE(review): clipping does not wrap periodic images (x=0 and x=1 remain
        # distinct positions one lattice vector apart); confirm this is intended.
        fractional = np.clip([x, y, z], 1e-6, 1 - 1e-6)
        cartesian = lattice_matrix @ fractional
        atoms.append({'symbol': symbol, 'position': cartesian, 'charge': charge})

# === Remove duplicate atoms using KDTree ===
def remove_duplicates(atom_list, threshold=0.1):
    """Return the atoms of *atom_list* with near-coincident positions merged.

    Atoms are visited in list order. An atom is kept unless it lies within
    ``threshold`` of an atom that has already been kept, so the first atom of
    each cluster wins and later ones are dropped.

    Args:
        atom_list: sequence of dicts, each with a ``'position'`` entry holding a
            3-component Cartesian coordinate.
        threshold: merge distance, in the same units as the positions.

    Returns:
        A new list with the retained atom dicts (the same objects, not copies).
    """
    if not atom_list:
        return []

    # Round first so positions that differ only by floating-point noise still
    # coincide even when threshold is 0.
    positions = np.round(
        np.asarray([atom['position'] for atom in atom_list], dtype=float),
        decimals=6,
    )
    tree = KDTree(positions)

    is_duplicate = np.zeros(len(atom_list), dtype=bool)
    unique_atoms = []
    for i, pos in enumerate(positions):
        if is_duplicate[i]:
            continue
        unique_atoms.append(atom_list[i])
        # Mark every later atom close to this kept one. Earlier indices were
        # already decided when they were visited.
        for j in tree.query_ball_point(pos, r=threshold):
            if j > i:
                is_duplicate[j] = True
    return unique_atoms

# Drop duplicate atoms, passing the duplicate-removal threshold read from the input file.
atoms = remove_duplicates(atom_list=atoms, threshold=calc_tresh)

# === Compute dipole moment ===
def compute_dipole_moment(atoms):
    """Compute the point-charge dipole moment ``sum_i q_i * r_i`` of *atoms*.

    Positions are taken relative to the Cartesian origin. If the charges do not
    sum to zero, the result depends on that choice of origin.

    Args:
        atoms: sequence of dicts with ``'charge'`` (number) and ``'position'``
            (3-component Cartesian coordinate, Å) entries. All atoms are
            assumed to already lie inside the unit cell.

    Returns:
        Tuple ``(dipole_vector, dipole_magnitude, atoms)``:
        - ``dipole_vector`` is a length-3 array in Debye (a zero vector for
          empty input);
        - ``dipole_magnitude`` is its Euclidean norm;
        - ``atoms`` is the input sequence, returned unchanged.
    """
    # Conversion factor from e·Å to Debye used by this script (1 e·Å ≈ 4.803 D).
    e_angstrom_to_debye = 4.803

    contributions = [atom['charge'] * np.asarray(atom['position'], dtype=float)
                     for atom in atoms]
    # np.sum over an empty list collapses to a 0-d scalar; keep the result 3-D.
    dipole_moment = np.sum(contributions, axis=0) if contributions else np.zeros(3)

    dipole_moment_Debye = dipole_moment * e_angstrom_to_debye
    dipole_magnitude_Debye = np.linalg.norm(dipole_moment_Debye)

    return dipole_moment_Debye, dipole_magnitude_Debye, atoms

dipole_vector, dipole_magnitude, unit_cell_atoms = compute_dipole_moment(atoms=atoms)

# Report the net dipole: the full vector first, then its magnitude to 6 decimals.
print(f"Net Dipole Moment (Debye): {dipole_vector}")
print(f"Magnitude of Dipole Moment: {dipole_magnitude:.6f} D")

# === Plotting function ===
def plot_unit_cell(atoms, dipole_vector=None, lattice=None):
    """Show an interactive 3D Plotly view of the atoms, the cell box and the dipole.

    Args:
        atoms: list of dicts with ``'symbol'`` and ``'position'`` (Cartesian, Å).
            Atoms are coloured per element symbol.
        dipole_vector: optional length-3 dipole, drawn as a red segment from the
            cell centre and rescaled to 20% of the shortest lattice vector. It is
            not drawn when it is zero or non-finite, because it has no direction.
        lattice: optional 3x3 matrix whose columns are the lattice vectors.
            Defaults to the module-level ``lattice_matrix``.
    """
    if lattice is None:
        lattice = lattice_matrix
    lattice = np.asarray(lattice, dtype=float)

    symbols = [atom['symbol'] for atom in atoms]
    # reshape keeps a (0, 3) array for an empty atom list so column slicing works.
    positions = np.array([atom['position'] for atom in atoms], dtype=float).reshape(-1, 3)

    unique_symbols = sorted(set(symbols))
    palette = sns.color_palette("hls", len(unique_symbols)) if unique_symbols else []
    symbol_to_color = {}
    for sym, rgb in zip(unique_symbols, palette):
        # Convert to plain Python ints. With NumPy >= 2, a tuple of NumPy integers
        # formats as "(np.int64(255), ...)", which Plotly rejects as a colour.
        red, green, blue = (int(channel * 255) for channel in rgb)
        symbol_to_color[sym] = f'rgb({red}, {green}, {blue})'
    colors = [symbol_to_color[s] for s in symbols]

    fig = go.Figure()

    # Plot atoms
    fig.add_trace(go.Scatter3d(
        x=positions[:, 0], y=positions[:, 1], z=positions[:, 2],
        mode='markers',
        marker=dict(size=6, color=colors, line=dict(width=1, color='black')),
        text=[f"{s} ({p[0]:.2f}, {p[1]:.2f}, {p[2]:.2f})" for s, p in zip(symbols, positions)],
        hoverinfo='text'
    ))

    # Plot dipole vector
    if dipole_vector is not None:
        dipole = np.asarray(dipole_vector, dtype=float)
        dipole_norm = np.linalg.norm(dipole)
        # Rescaling a zero dipole would divide by zero and yield NaN endpoints.
        if np.isfinite(dipole_norm) and dipole_norm > 0:
            cell_center = 0.5 * (lattice @ np.ones(3))
            # Column norms are the lattice-vector lengths (a, b, c).
            shortest_edge = np.linalg.norm(lattice, axis=0).min()
            scale = 0.2 * shortest_edge / dipole_norm
            end = cell_center + dipole * scale
            fig.add_trace(go.Scatter3d(
                x=[cell_center[0], end[0]],
                y=[cell_center[1], end[1]],
                z=[cell_center[2], end[2]],
                mode='lines+markers',
                line=dict(color='red', width=5),
                marker=dict(size=2),
                name='Dipole Moment'
            ))

    # Plot unit cell box
    corners_frac = np.array([
        [0, 0, 0],
        [1, 0, 0],
        [1, 1, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 1],
        [1, 1, 1],
        [0, 1, 1]
    ])
    # Row-wise fractional -> Cartesian conversion for all 8 corners at once.
    corners_cart = corners_frac @ lattice.T

    edge_indices = [
        (0, 1), (1, 2), (2, 3), (3, 0),  # Bottom face
        (4, 5), (5, 6), (6, 7), (7, 4),  # Top face
        (0, 4), (1, 5), (2, 6), (3, 7)   # Vertical edges
    ]

    for i1, i2 in edge_indices:
        fig.add_trace(go.Scatter3d(
            x=[corners_cart[i1, 0], corners_cart[i2, 0]],
            y=[corners_cart[i1, 1], corners_cart[i2, 1]],
            z=[corners_cart[i1, 2], corners_cart[i2, 2]],
            mode='lines',
            line=dict(color='gray', width=2),
            showlegend=False
        ))

    fig.update_layout(
        scene=dict(
            xaxis_title='X (Å)',
            yaxis_title='Y (Å)',
            zaxis_title='Z (Å)',
            aspectmode='data'
        ),
        margin=dict(l=0, r=0, b=0, t=40),
        title="Unit Cell Visualization"
    )

    fig.show()

# === Call plot ===
# Render the atoms returned by compute_dipole_moment together with the dipole arrow.
plot_unit_cell(atoms=unit_cell_atoms, dipole_vector=dipole_vector)