In [2]:
import lmdb
import pickle
import numpy as np

# LMDB 파일 경로
lmdb_path = "/home/minkyu/MinCatFlow/dataset/val_id/dataset.lmdb"

# 특정 index 설정
target_index = 24399  # 원하는 index로 변경

# LMDB 열기
env = lmdb.open(
    lmdb_path,
    subdir=False,
    readonly=True,
    lock=False,
    readahead=True,
    meminit=False,
    max_readers=1,
)

# 특정 index의 데이터 가져오기
with env.begin() as txn:
    # 키를 문자열로 변환 (index를 키로 사용)
    key_str = str(target_index)
    key_bytes = key_str.encode("ascii")
    
    # 해당 키의 값 가져오기
    value = txn.get(key_bytes)
    
    if value is None:
        print(f"Index {target_index}에 해당하는 데이터를 찾을 수 없습니다.")
    else:
        # pickle로 역직렬화
        data_dict = pickle.loads(value)
        
        print(f"=== Index {target_index}의 모든 Value ===\n")
        
        # 모든 키-값 쌍 출력
        for key, value in data_dict.items():
            print(f"'{key}':")
            print(f"  Type: {type(value)}")
            
            if isinstance(value, np.ndarray):
                print(f"  Shape: {value.shape}")
                print(f"  Dtype: {value.dtype}")
                if value.size < 50:  # 작은 배열은 전체 출력
                    print(f"  Value:\n{value}")
                else:
                    print(f"  Value (first 20 elements):\n{value.flat[:20]}")
                    print(f"  ... (total {value.size} elements)")
            elif hasattr(value, '__len__') and not isinstance(value, str):
                try:
                    print(f"  Length: {len(value)}")
                    if len(value) < 20:
                        print(f"  Value: {value}")
                    else:
                        print(f"  Value (first 10 items): {list(value)[:10]}")
                        print(f"  ... (total {len(value)} items)")
                except:
                    print(f"  Value: {value}")
            elif hasattr(value, 'positions'):  # ASE.Atoms 객체인 경우
                print(f"  Number of atoms: {len(value)}")
                print(f"  Positions shape: {value.positions.shape}")
                print(f"  Numbers (atomic numbers): {value.numbers}")
                print(f"  Cell shape: {value.cell.shape}")
                if hasattr(value, 'get_chemical_symbols'):
                    symbols = value.get_chemical_symbols()
                    print(f"  Chemical symbols (first 20): {symbols[:20]}")
                    if len(symbols) > 20:
                        print(f"  ... (total {len(symbols)} atoms)")
            else:
                print(f"  Value: {value}")
            print()

env.close()


=== Index 24399의 모든 Value ===

'sid':
  Type: <class 'int'>
  Value: 1022903

'primitive_slab':
  Type: <class 'pymatgen.io.ase.MSONAtoms'>
  Length: 32
  Value (first 10 items): [Atom('Re', [np.float64(1.1086056232452401), np.float64(5.760484059651693), np.float64(6.788795381784439)], index=0), Atom('Re', [np.float64(1.1086056232452406), np.float64(8.3207000096639), np.float64(3.168104559183122)], index=1), Atom('Re', [np.float64(1.1086056232452393), np.float64(0.6400540669759126), np.float64(3.1681045591831216)], index=2), Atom('Re', [np.float64(1.1086056232452397), np.float64(3.200269063313803), np.float64(10.409486383199692)], index=3), Atom('Re', [np.float64(1.1086056232452397), np.float64(3.200269063313803), np.float64(4.978450208902361)], index=4), Atom('Re', [np.float64(1.1086056232452401), np.float64(5.760484059651693), np.float64(1.357758909463883)], index=5), Atom('Re', [np.float64(1.1086056232452406), np.float64(8.3207000096639), np.float64(8.599140793085098)], index=6), At

In [None]:
from ase.atoms import Atoms
import numpy as np
from scripts.assemble import assemble
from pymatgen.io.ase import AseAtomsAdaptor
import math 
from ase.build import niggli_reduce
from pymatgen.core.lattice import Lattice
from ase.cell import Cell

adaptor = AseAtomsAdaptor()

# 위에서 추출한 데이터를 사용하여 atoms 객체 생성
# data_dict는 위 셀에서 이미 로드되어 있다고 가정
# 1. Primitive slab atoms 객체 (이미 atoms 객체이므로 그대로 사용 가능)
primitive_slab_atoms = data_dict["primitive_slab"]
ads_atomic_numbers = data_dict["ads_atomic_numbers"]  # [7, 1] (N, H)
ads_pos = data_dict["ads_pos"]  # (2, 3) 좌표
supercell_matrix = data_dict["supercell_matrix"]

n_slab = data_dict["n_slab"]
n_vac = data_dict["n_vac"]
scaling_factor = (n_slab + n_vac) / n_slab

# prim_slab_struct = adaptor.get_structure(primitive_slab_atoms)

# sc_matrix = supercell_matrix

# print(f"Supercell matrix:\n{sc_matrix}")

# recon_struct = prim_slab_struct.copy()
# recon_struct.make_supercell(sc_matrix, to_unit_cell=False)

# recon_tight_slab = adaptor.get_atoms(recon_struct)

# recon_slab = recon_tight_slab.copy()
# recon_cell = recon_slab.get_cell()
# recon_cell[2] = recon_cell[2] * scaling_factor
# recon_slab.set_cell(recon_cell)

# adsorbate = Atoms(
#     positions=ads_pos,
#     numbers=ads_atomic_numbers,
#     cell=recon_cell,
#     pbc=True,
# )

# recon_system = recon_slab + adsorbate

recon_system = assemble(
    primitive_slab_atoms.get_positions(),
    ads_pos,
    primitive_slab_atoms.get_cell_lengths_and_angles(),
    supercell_matrix,
    scaling_factor,
    primitive_slab_atoms.get_atomic_numbers(),
    ads_atomic_numbers,
)

print("recon_system lattice (matrix form):\n", recon_system.get_cell())

from ase.visualize import view

view(recon_system, viewer='x3d')


In [None]:
import lmdb
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

# LMDB 파일 경로
lmdb_path = "/home/minkyu/EfficientCatGen/dataset/train/dataset.lmdb"
output_dir = "/home/minkyu/EfficientCatGen/analysis_plots"
os.makedirs(output_dir, exist_ok=True)

# 모든 ads_pos 수집
all_ads_pos = []  # 모든 원자 좌표를 펼쳐서 저장
all_ads_pos_per_sample = []  # 샘플별로 저장 (outlier 탐지용)
sample_indices = []

env = lmdb.open(
    lmdb_path,
    subdir=False,
    readonly=True,
    lock=False,
    readahead=True,
    meminit=False,
    max_readers=1,
)

with env.begin() as txn:
    # 전체 데이터 개수 확인
    length = txn.stat()["entries"]
    print(f"Total samples in dataset: {length}")
    
    for idx in tqdm(range(length), desc="Loading ads_pos"):
        key_bytes = str(idx).encode("ascii")
        value = txn.get(key_bytes)
        
        if value is None:
            continue
            
        data_dict = pickle.loads(value)
        ads_pos = data_dict.get("ads_pos", None)
        
        if ads_pos is not None and len(ads_pos) > 0:
            all_ads_pos.append(ads_pos)
            all_ads_pos_per_sample.append({
                "idx": idx,
                "ads_pos": ads_pos,
                "mean": np.mean(ads_pos, axis=0),
                "std": np.std(ads_pos),
                "min": np.min(ads_pos, axis=0),
                "max": np.max(ads_pos, axis=0),
                "n_atoms": len(ads_pos),
            })
            sample_indices.append(idx)

env.close()

# 모든 좌표를 하나의 배열로 합치기
all_coords = np.vstack(all_ads_pos)
print(f"\nTotal adsorbate atoms: {len(all_coords)}")
print(f"Total samples with adsorbate: {len(all_ads_pos_per_sample)}")


In [None]:
# ========================
# 1. 기본 통계 분석
# ========================
print("=" * 60)
print("1. Basic Statistics of ads_pos (all atoms)")
print("=" * 60)

for i, axis in enumerate(["X", "Y", "Z"]):
    coords = all_coords[:, i]
    print(f"\n{axis}-axis:")
    print(f"  Min: {np.min(coords):.4f}")
    print(f"  Max: {np.max(coords):.4f}")
    print(f"  Mean: {np.mean(coords):.4f}")
    print(f"  Std: {np.std(coords):.4f}")
    print(f"  Median: {np.median(coords):.4f}")
    
    # Percentiles
    p1, p5, p25, p75, p95, p99 = np.percentile(coords, [1, 5, 25, 75, 95, 99])
    print(f"  Percentiles: 1%={p1:.2f}, 5%={p5:.2f}, 25%={p25:.2f}, 75%={p75:.2f}, 95%={p95:.2f}, 99%={p99:.2f}")

# ========================
# 2. Outlier 탐지 (IQR 방법)
# ========================
print("\n" + "=" * 60)
print("2. Outlier Detection (IQR method, 1.5*IQR)")
print("=" * 60)

outlier_samples = []
for i, axis in enumerate(["X", "Y", "Z"]):
    coords = all_coords[:, i]
    Q1 = np.percentile(coords, 25)
    Q3 = np.percentile(coords, 75)
    IQR = Q3 - Q1
    lower_bound = Q1 - 1.5 * IQR
    upper_bound = Q3 + 1.5 * IQR
    
    outliers = (coords < lower_bound) | (coords > upper_bound)
    n_outliers = np.sum(outliers)
    
    print(f"\n{axis}-axis:")
    print(f"  IQR: {IQR:.4f}")
    print(f"  Lower bound: {lower_bound:.4f}")
    print(f"  Upper bound: {upper_bound:.4f}")
    print(f"  Number of outliers: {n_outliers} ({100*n_outliers/len(coords):.2f}%)")
    
    if n_outliers > 0:
        outlier_values = coords[outliers]
        print(f"  Outlier range: [{np.min(outlier_values):.2f}, {np.max(outlier_values):.2f}]")

# ========================
# 3. 샘플별 통계 분석 (centroid 기준)
# ========================
print("\n" + "=" * 60)
print("3. Per-sample Statistics (centroid of each adsorbate)")
print("=" * 60)

centroids = np.array([s["mean"] for s in all_ads_pos_per_sample])
for i, axis in enumerate(["X", "Y", "Z"]):
    coords = centroids[:, i]
    print(f"\n{axis}-axis (centroids):")
    print(f"  Min: {np.min(coords):.4f}, Max: {np.max(coords):.4f}")
    print(f"  Mean: {np.mean(coords):.4f}, Std: {np.std(coords):.4f}")


In [None]:
# ========================
# 4. 2D Projection Plots (XY, YZ, ZX)
# ========================
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# XY plane
ax = axes[0]
ax.scatter(all_coords[:, 0], all_coords[:, 1], alpha=0.1, s=1, c='blue')
ax.set_xlabel('X (Å)', fontsize=12)
ax.set_ylabel('Y (Å)', fontsize=12)
ax.set_title('XY Plane Projection', fontsize=14)
ax.set_aspect('equal')
ax.grid(True, alpha=0.3)

# YZ plane
ax = axes[1]
ax.scatter(all_coords[:, 1], all_coords[:, 2], alpha=0.1, s=1, c='green')
ax.set_xlabel('Y (Å)', fontsize=12)
ax.set_ylabel('Z (Å)', fontsize=12)
ax.set_title('YZ Plane Projection', fontsize=14)
ax.set_aspect('equal')
ax.grid(True, alpha=0.3)

# ZX plane
ax = axes[2]
ax.scatter(all_coords[:, 2], all_coords[:, 0], alpha=0.1, s=1, c='red')
ax.set_xlabel('Z (Å)', fontsize=12)
ax.set_ylabel('X (Å)', fontsize=12)
ax.set_title('ZX Plane Projection', fontsize=14)
ax.set_aspect('equal')
ax.grid(True, alpha=0.3)

plt.suptitle('ads_pos 2D Projections (All Atoms)', fontsize=16, y=1.02)
plt.tight_layout()
plt.savefig(f"{output_dir}/ads_pos_2d_projections.png", dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved: {output_dir}/ads_pos_2d_projections.png")


In [None]:
# ========================
# 5. Histograms for X, Y, Z coordinates
# ========================
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

colors = ['blue', 'green', 'red']
labels = ['X', 'Y', 'Z']

for i, (ax, color, label) in enumerate(zip(axes, colors, labels)):
    coords = all_coords[:, i]
    
    ax.hist(coords, bins=100, color=color, alpha=0.7, edgecolor='black', linewidth=0.5)
    ax.axvline(np.mean(coords), color='black', linestyle='--', linewidth=2, label=f'Mean: {np.mean(coords):.2f}')
    ax.axvline(np.median(coords), color='orange', linestyle='-.', linewidth=2, label=f'Median: {np.median(coords):.2f}')
    
    ax.set_xlabel(f'{label} (Å)', fontsize=12)
    ax.set_ylabel('Count', fontsize=12)
    ax.set_title(f'{label}-coordinate Distribution', fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('ads_pos Coordinate Distributions', fontsize=16, y=1.02)
plt.tight_layout()
plt.savefig(f"{output_dir}/ads_pos_histograms.png", dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved: {output_dir}/ads_pos_histograms.png")


In [None]:
# ========================
# 6. 2D Density Heatmaps (더 나은 시각화)
# ========================
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# XY plane heatmap
ax = axes[0]
h = ax.hist2d(all_coords[:, 0], all_coords[:, 1], bins=100, cmap='viridis', cmin=1)
plt.colorbar(h[3], ax=ax, label='Count')
ax.set_xlabel('X (Å)', fontsize=12)
ax.set_ylabel('Y (Å)', fontsize=12)
ax.set_title('XY Plane Density', fontsize=14)

# YZ plane heatmap
ax = axes[1]
h = ax.hist2d(all_coords[:, 1], all_coords[:, 2], bins=100, cmap='viridis', cmin=1)
plt.colorbar(h[3], ax=ax, label='Count')
ax.set_xlabel('Y (Å)', fontsize=12)
ax.set_ylabel('Z (Å)', fontsize=12)
ax.set_title('YZ Plane Density', fontsize=14)

# ZX plane heatmap
ax = axes[2]
h = ax.hist2d(all_coords[:, 2], all_coords[:, 0], bins=100, cmap='viridis', cmin=1)
plt.colorbar(h[3], ax=ax, label='Count')
ax.set_xlabel('Z (Å)', fontsize=12)
ax.set_ylabel('X (Å)', fontsize=12)
ax.set_title('ZX Plane Density', fontsize=14)

plt.suptitle('ads_pos 2D Density Heatmaps', fontsize=16, y=1.02)
plt.tight_layout()
plt.savefig(f"{output_dir}/ads_pos_2d_heatmaps.png", dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved: {output_dir}/ads_pos_2d_heatmaps.png")


In [None]:
# ========================
# 7. Outlier 샘플 식별 (극단적인 좌표를 가진 샘플)
# ========================
print("=" * 60)
print("7. Identifying Outlier Samples")
print("=" * 60)

# 전체 데이터에서 통계 계산
global_mean = np.mean(all_coords, axis=0)
global_std = np.std(all_coords, axis=0)

print(f"\nGlobal mean: X={global_mean[0]:.2f}, Y={global_mean[1]:.2f}, Z={global_mean[2]:.2f}")
print(f"Global std:  X={global_std[0]:.2f}, Y={global_std[1]:.2f}, Z={global_std[2]:.2f}")

# Z-score 기반 outlier 탐지 (|z| > 3)
outlier_threshold = 3.0
outlier_samples_info = []

for sample in all_ads_pos_per_sample:
    ads_pos = sample["ads_pos"]
    z_scores = np.abs((ads_pos - global_mean) / (global_std + 1e-8))
    max_z = np.max(z_scores)
    
    if max_z > outlier_threshold:
        outlier_samples_info.append({
            "idx": sample["idx"],
            "max_z_score": max_z,
            "mean": sample["mean"],
            "min": sample["min"],
            "max": sample["max"],
            "n_atoms": sample["n_atoms"],
        })

# Z-score로 정렬
outlier_samples_info.sort(key=lambda x: x["max_z_score"], reverse=True)

print(f"\nFound {len(outlier_samples_info)} samples with |z-score| > {outlier_threshold}")
print(f"\nTop 20 outlier samples:")
print("-" * 100)
print(f"{'Idx':<8} {'Max Z':<10} {'N_atoms':<8} {'Mean X':<12} {'Mean Y':<12} {'Mean Z':<12}")
print("-" * 100)

for info in outlier_samples_info[:20]:
    print(f"{info['idx']:<8} {info['max_z_score']:<10.2f} {info['n_atoms']:<8} "
          f"{info['mean'][0]:<12.2f} {info['mean'][1]:<12.2f} {info['mean'][2]:<12.2f}")


In [None]:
# ========================
# 8. Centroid Distribution (샘플별 adsorbate 중심 위치)
# ========================
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

centroids = np.array([s["mean"] for s in all_ads_pos_per_sample])

# XY plane
ax = axes[0]
ax.scatter(centroids[:, 0], centroids[:, 1], alpha=0.3, s=5, c='blue')
ax.set_xlabel('X (Å)', fontsize=12)
ax.set_ylabel('Y (Å)', fontsize=12)
ax.set_title('Centroid XY Projection', fontsize=14)
ax.grid(True, alpha=0.3)

# YZ plane
ax = axes[1]
ax.scatter(centroids[:, 1], centroids[:, 2], alpha=0.3, s=5, c='green')
ax.set_xlabel('Y (Å)', fontsize=12)
ax.set_ylabel('Z (Å)', fontsize=12)
ax.set_title('Centroid YZ Projection', fontsize=14)
ax.grid(True, alpha=0.3)

# ZX plane
ax = axes[2]
ax.scatter(centroids[:, 2], centroids[:, 0], alpha=0.3, s=5, c='red')
ax.set_xlabel('Z (Å)', fontsize=12)
ax.set_ylabel('X (Å)', fontsize=12)
ax.set_title('Centroid ZX Projection', fontsize=14)
ax.grid(True, alpha=0.3)

plt.suptitle('Adsorbate Centroid Distributions (per sample)', fontsize=16, y=1.02)
plt.tight_layout()
plt.savefig(f"{output_dir}/ads_pos_centroid_projections.png", dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved: {output_dir}/ads_pos_centroid_projections.png")


In [None]:
# ========================
# 9. 좌표 범위 분석 (학습 관점)
# ========================
print("=" * 60)
print("9. Coordinate Range Analysis (Training Perspective)")
print("=" * 60)

# 좌표 범위
coord_range = np.max(all_coords, axis=0) - np.min(all_coords, axis=0)
print(f"\nCoordinate Ranges:")
print(f"  X range: {coord_range[0]:.2f} Å ({np.min(all_coords[:,0]):.2f} to {np.max(all_coords[:,0]):.2f})")
print(f"  Y range: {coord_range[1]:.2f} Å ({np.min(all_coords[:,1]):.2f} to {np.max(all_coords[:,1]):.2f})")
print(f"  Z range: {coord_range[2]:.2f} Å ({np.min(all_coords[:,2]):.2f} to {np.max(all_coords[:,2]):.2f})")

# 샘플 내 adsorbate 크기 (internal spread)
internal_stds = np.array([s["std"] for s in all_ads_pos_per_sample])
print(f"\nAdsorbate Internal Spread (std per sample):")
print(f"  Mean: {np.mean(internal_stds):.4f} Å")
print(f"  Std: {np.std(internal_stds):.4f} Å")
print(f"  Min: {np.min(internal_stds):.4f} Å")
print(f"  Max: {np.max(internal_stds):.4f} Å")

# 문제점 분석
print("\n" + "=" * 60)
print("10. Potential Issues for Learning")
print("=" * 60)

print(f"""
1. 좌표 범위가 매우 넓음:
   - X: {coord_range[0]:.1f} Å, Y: {coord_range[1]:.1f} Å, Z: {coord_range[2]:.1f} Å
   - 이렇게 넓은 범위를 학습하려면 모델이 큰 값들을 처리해야 함

2. Standardization 부재:
   - 현재 ads_pos는 mean/std standardization 없이 raw Cartesian 좌표 사용
   - 좌표가 centering되어 있지 않음 (mean ≠ 0)

3. 권장 사항:
   - ads_pos에도 standardization 적용 검토
   - 또는 adsorbate centroid를 기준으로 centering 후 학습
   
4. Global Statistics (for standardization):
   - Mean: X={global_mean[0]:.4f}, Y={global_mean[1]:.4f}, Z={global_mean[2]:.4f}
   - Std:  X={global_std[0]:.4f}, Y={global_std[1]:.4f}, Z={global_std[2]:.4f}
""")

# 통계를 파일로 저장
stats_dict = {
    "global_mean": global_mean.tolist(),
    "global_std": global_std.tolist(),
    "coord_range": coord_range.tolist(),
    "total_atoms": len(all_coords),
    "total_samples": len(all_ads_pos_per_sample),
}

import json
with open(f"{output_dir}/ads_pos_statistics.json", "w") as f:
    json.dump(stats_dict, f, indent=2)
print(f"\nSaved statistics to: {output_dir}/ads_pos_statistics.json")


In [None]:
# ========================
# Train dataset structural validity 검사
# ========================
import lmdb
import pickle
import numpy as np
from tqdm import tqdm
from ase import Atoms
from scripts.assemble import assemble

def structural_validity(atoms: Atoms) -> dict:
    """Check structural validity of an Atoms object and return detailed results."""
    results = {"valid": True, "vol_ok": True, "dist_ok": True, "width_ok": True}
    
    # 1. Check cell volume
    try:
        vol = float(atoms.get_volume())
        results["vol_ok"] = vol >= 0.1
        results["volume"] = vol
    except Exception:
        results["vol_ok"] = False
        results["volume"] = None

    # 2. Check atom clash
    try:
        if len(atoms) > 1:
            dists = atoms.get_all_distances()
            min_dist = np.min(dists[np.nonzero(dists)])
            results["dist_ok"] = min_dist >= 0.5
            results["min_dist"] = min_dist
        else:
            results["dist_ok"] = True
            results["min_dist"] = None
    except Exception:
        results["dist_ok"] = False
        results["min_dist"] = None
    
    # 3. Check min width of cell (a, b >= 8.0 Å)
    min_ab = 8.0
    a_length = np.linalg.norm(atoms.cell[0])
    b_length = np.linalg.norm(atoms.cell[1])
    results["a_length"] = a_length
    results["b_length"] = b_length
    results["width_ok"] = (a_length >= min_ab) and (b_length >= min_ab)

    # 4. Overall validity
    results["valid"] = results["vol_ok"] and results["dist_ok"] and results["width_ok"]
    
    return results

# LMDB 파일 경로
lmdb_path = "/home/minkyu/EfficientCatGen/dataset/train/dataset.lmdb"

# 통계 수집
total_samples = 0
valid_samples = 0
invalid_reasons = {"vol": 0, "dist": 0, "width": 0}
failed_samples = []

env = lmdb.open(
    lmdb_path,
    subdir=False,
    readonly=True,
    lock=False,
    readahead=True,
    meminit=False,
    max_readers=1,
)

with env.begin() as txn:
    length = txn.stat()["entries"]
    print(f"Total samples in dataset: {length}")
    
    for idx in tqdm(range(length), desc="Checking structural validity"):
        key_bytes = str(idx).encode("ascii")
        value = txn.get(key_bytes)
        
        if value is None:
            continue
        
        total_samples += 1
        
        try:
            data_dict = pickle.loads(value)
            
            # Extract data
            primitive_slab = data_dict["primitive_slab"]
            supercell_matrix = data_dict["supercell_matrix"]
            n_slab = data_dict["n_slab"]
            n_vac = data_dict["n_vac"]
            ads_atomic_numbers = data_dict.get("ads_atomic_numbers", np.array([]))
            ads_pos = data_dict.get("ads_pos", np.array([]).reshape(0, 3))
            
            # Calculate scaling factor
            scaling_factor = (n_slab + n_vac) / n_slab
            
            # Get primitive slab info
            prim_positions = primitive_slab.get_positions()
            prim_numbers = primitive_slab.get_atomic_numbers()
            lattice_params = primitive_slab.cell.cellpar()  # (a, b, c, alpha, beta, gamma)
            
            # Use assemble function from scripts/assemble.py
            recon_system = assemble(
                generated_prim_slab_coords=prim_positions,
                generated_ads_coords=ads_pos,
                generated_lattice=lattice_params,
                generated_supercell_matrix=supercell_matrix,
                generated_scaling_factor=scaling_factor,
                prim_slab_atom_types=prim_numbers,
                ads_atom_types=ads_atomic_numbers,
            )
            
            # Check structural validity
            validity = structural_validity(recon_system)
            
            if validity["valid"]:
                valid_samples += 1
            else:
                if not validity["vol_ok"]:
                    invalid_reasons["vol"] += 1
                if not validity["dist_ok"]:
                    invalid_reasons["dist"] += 1
                if not validity["width_ok"]:
                    invalid_reasons["width"] += 1
                
                # Store first few failed samples for inspection
                if len(failed_samples) < 10:
                    failed_samples.append({
                        "idx": idx,
                        "validity": validity,
                    })
                    
        except Exception as e:
            if len(failed_samples) < 10:
                failed_samples.append({"idx": idx, "error": str(e)})

env.close()

# 결과 출력
print(f"\n{'='*60}")
print("Structural Validity Results")
print(f"{'='*60}")
print(f"\nTotal samples: {total_samples}")
print(f"Valid samples: {valid_samples}")
print(f"Invalid samples: {total_samples - valid_samples}")
print(f"\nValidity rate: {100 * valid_samples / total_samples:.2f}%")

print(f"\n--- Invalid Reasons (can overlap) ---")
print(f"  Volume < 0.1: {invalid_reasons['vol']} ({100 * invalid_reasons['vol'] / total_samples:.2f}%)")
print(f"  Min distance < 0.5 Å: {invalid_reasons['dist']} ({100 * invalid_reasons['dist'] / total_samples:.2f}%)")
print(f"  Cell width (a or b) < 8.0 Å: {invalid_reasons['width']} ({100 * invalid_reasons['width'] / total_samples:.2f}%)")

if failed_samples:
    print(f"\n--- Sample Failed Cases (first {len(failed_samples)}) ---")
    for sample in failed_samples:
        print(f"  Index {sample['idx']}: {sample.get('validity', sample.get('error', 'Unknown'))}")


In [None]:
# ========================
# prim_slab_coords 와 ads_pos 의 mean/std 계산
# (Standardization을 위한 통계)
# ========================
import lmdb
import pickle
import numpy as np
from tqdm import tqdm

# LMDB 파일 경로
lmdb_path = "/home/minkyu/EfficientCatGen/dataset/train/dataset.lmdb"

# 모든 좌표 수집
all_prim_slab_coords = []
all_ads_coords = []

env = lmdb.open(
    lmdb_path,
    subdir=False,
    readonly=True,
    lock=False,
    readahead=True,
    meminit=False,
    max_readers=1,
)

with env.begin() as txn:
    length = txn.stat()["entries"]
    print(f"Total samples in dataset: {length}")
    
    for idx in tqdm(range(length), desc="Loading coordinates"):
        key_bytes = str(idx).encode("ascii")
        value = txn.get(key_bytes)
        
        if value is None:
            continue
            
        data_dict = pickle.loads(value)
        
        # Primitive slab coordinates
        prim_slab = data_dict.get("primitive_slab", None)
        if prim_slab is not None:
            # MSONAtoms에서 positions 추출
            positions = prim_slab.get_positions()  # (N, 3)
            all_prim_slab_coords.append(positions)
        
        # Adsorbate coordinates
        ads_pos = data_dict.get("ads_pos", None)
        if ads_pos is not None and len(ads_pos) > 0:
            all_ads_coords.append(ads_pos)

env.close()

# 모든 좌표를 하나의 배열로 합치기
all_prim_slab_flat = np.vstack(all_prim_slab_coords)
all_ads_flat = np.vstack(all_ads_coords)

print(f"\n{'='*60}")
print("Statistics for Standardization")
print(f"{'='*60}")

# ========================
# prim_slab_coords 통계
# ========================
print(f"\n[prim_slab_coords]")
print(f"  Total atoms: {len(all_prim_slab_flat)}")
print(f"  Total samples: {len(all_prim_slab_coords)}")

prim_slab_mean = np.mean(all_prim_slab_flat, axis=0)
prim_slab_std = np.std(all_prim_slab_flat, axis=0)
prim_slab_global_std = np.std(all_prim_slab_flat)  # 전체 std (scalar)

print(f"\n  Per-axis Mean: [{prim_slab_mean[0]:.6f}, {prim_slab_mean[1]:.6f}, {prim_slab_mean[2]:.6f}]")
print(f"  Per-axis Std:  [{prim_slab_std[0]:.6f}, {prim_slab_std[1]:.6f}, {prim_slab_std[2]:.6f}]")
print(f"  Global Std (scalar): {prim_slab_global_std:.6f}")

# ========================
# ads_pos 통계
# ========================
print(f"\n[ads_pos]")
print(f"  Total atoms: {len(all_ads_flat)}")
print(f"  Total samples: {len(all_ads_coords)}")

ads_mean = np.mean(all_ads_flat, axis=0)
ads_std = np.std(all_ads_flat, axis=0)
ads_global_std = np.std(all_ads_flat)  # 전체 std (scalar)

print(f"\n  Per-axis Mean: [{ads_mean[0]:.6f}, {ads_mean[1]:.6f}, {ads_mean[2]:.6f}]")
print(f"  Per-axis Std:  [{ads_std[0]:.6f}, {ads_std[1]:.6f}, {ads_std[2]:.6f}]")
print(f"  Global Std (scalar): {ads_global_std:.6f}")

# ========================
# Config 형식으로 출력 (복사용)
# ========================
print(f"\n{'='*60}")
print("Config Format (copy-paste ready)")
print(f"{'='*60}")

print(f"""
# prim_slab_coords standardization
prim_slab_coord_mean: [{prim_slab_mean[0]:.6f}, {prim_slab_mean[1]:.6f}, {prim_slab_mean[2]:.6f}]
prim_slab_coord_std: [{prim_slab_std[0]:.6f}, {prim_slab_std[1]:.6f}, {prim_slab_std[2]:.6f}]
# or use global std (recommended for isotropic scaling):
prim_slab_coord_global_std: {prim_slab_global_std:.6f}

# ads_pos standardization
ads_coord_mean: [{ads_mean[0]:.6f}, {ads_mean[1]:.6f}, {ads_mean[2]:.6f}]
ads_coord_std: [{ads_std[0]:.6f}, {ads_std[1]:.6f}, {ads_std[2]:.6f}]
# or use global std (recommended for isotropic scaling):
ads_coord_global_std: {ads_global_std:.6f}
""")


In [None]:
import lmdb
import pickle
import numpy as np

# LMDB 파일 경로
lmdb_path = "/home/minkyu/EfficientCatGen/is2res_train_val_test_lmdbs/data/is2re/all/val_id/data.lmdb"

# 특정 index 설정
target_index = 0  # 원하는 index로 변경

# LMDB 열기
env = lmdb.open(
    lmdb_path,
    subdir=False,
    readonly=True,
    lock=False,
    readahead=True,
    meminit=False,
    max_readers=1,
)

# 특정 index의 데이터 가져오기
with env.begin() as txn:
    # 키를 문자열로 변환 (index를 키로 사용)
    key_str = str(target_index)
    key_bytes = key_str.encode("ascii")
    
    # 해당 키의 값 가져오기
    value = txn.get(key_bytes)
    
    if value is None:
        print(f"Index {target_index}에 해당하는 데이터를 찾을 수 없습니다.")
    else:
        # pickle로 역직렬화
        data_dict = pickle.loads(value)
        
        print(f"=== Index {target_index}의 모든 Value ===\n")
        
        # 모든 키-값 쌍 출력
        for key, value in data_dict.items():
            print(f"'{key}':")
            print(f"  Type: {type(value)}")
            
            if isinstance(value, np.ndarray):
                print(f"  Shape: {value.shape}")
                print(f"  Dtype: {value.dtype}")
                if value.size < 50:  # 작은 배열은 전체 출력
                    print(f"  Value:\n{value}")
                else:
                    print(f"  Value (first 20 elements):\n{value.flat[:20]}")
                    print(f"  ... (total {value.size} elements)")
            elif hasattr(value, '__len__') and not isinstance(value, str):
                try:
                    print(f"  Length: {len(value)}")
                    if len(value) < 20:
                        print(f"  Value: {value}")
                    else:
                        print(f"  Value (first 10 items): {list(value)[:10]}")
                        print(f"  ... (total {len(value)} items)")
                except:
                    print(f"  Value: {value}")
            elif hasattr(value, 'positions'):  # ASE.Atoms 객체인 경우
                print(f"  Number of atoms: {len(value)}")
                print(f"  Positions shape: {value.positions.shape}")
                print(f"  Numbers (atomic numbers): {value.numbers}")
                print(f"  Cell shape: {value.cell.shape}")
                if hasattr(value, 'get_chemical_symbols'):
                    symbols = value.get_chemical_symbols()
                    print(f"  Chemical symbols (first 20): {symbols[:20]}")
                    if len(symbols) > 20:
                        print(f"  ... (total {len(symbols)} atoms)")
            else:
                print(f"  Value: {value}")
            print()

env.close()

In [None]:
import pickle
import numpy as np

# adsorbates.pkl 파일 경로
pkl_path = "/home/minkyu/EfficientCatGen/adsorbates.pkl"

# pickle 파일 로드
with open(pkl_path, "rb") as f:
    adsorbates_data = pickle.load(f)

print(f"=== adsorbates.pkl 파일 내용 ===\n")
print(f"데이터 타입: {type(adsorbates_data)}\n")

# 데이터 타입에 따라 다르게 처리
if isinstance(adsorbates_data, dict):
    print(f"딕셔너리 키 개수: {len(adsorbates_data)}")
    print(f"키 목록: {list(adsorbates_data.keys())}\n")
    
    for key, value in adsorbates_data.items():
        print(f"'{key}':")
        print(f"  Type: {type(value)}")
        
        if isinstance(value, np.ndarray):
            print(f"  Shape: {value.shape}")
            print(f"  Dtype: {value.dtype}")
            if value.size < 50:
                print(f"  Value:\n{value}")
            else:
                print(f"  Value (first 20 elements):\n{value.flat[:20]}")
                print(f"  ... (total {value.size} elements)")
        elif isinstance(value, (list, tuple)):
            print(f"  Length: {len(value)}")
            if len(value) < 20:
                print(f"  Value: {value}")
            else:
                print(f"  Value (first 10 items): {list(value)[:10]}")
                print(f"  ... (total {len(value)} items)")
        else:
            print(f"  Value: {value}")
        print()

elif isinstance(adsorbates_data, (list, tuple)):
    print(f"리스트/튜플 길이: {len(adsorbates_data)}\n")
    for i, item in enumerate(adsorbates_data[:10]):  # 처음 10개만 출력
        print(f"[{i}]:")
        print(f"  Type: {type(item)}")
        if isinstance(item, np.ndarray):
            print(f"  Shape: {item.shape}")
            print(f"  Dtype: {item.dtype}")
            if item.size < 20:
                print(f"  Value:\n{item}")
            else:
                print(f"  Value (first 10 elements): {item.flat[:10]}")
        else:
            print(f"  Value: {item}")
        print()
    if len(adsorbates_data) > 10:
        print(f"... (total {len(adsorbates_data)} items)\n")

elif isinstance(adsorbates_data, np.ndarray):
    print(f"배열 Shape: {adsorbates_data.shape}")
    print(f"Dtype: {adsorbates_data.dtype}")
    if adsorbates_data.size < 50:
        print(f"\n전체 값:\n{adsorbates_data}")
    else:
        print(f"\n처음 20개 요소:\n{adsorbates_data.flat[:20]}")
        print(f"... (total {adsorbates_data.size} elements)")

else:
    print(f"값: {adsorbates_data}")


In [None]:
import pickle
import sys

file_path = 'oc20_data_mapping.pkl'

try:
    # 바이너리 읽기 모드로 파일 열기
    with open(file_path, 'rb') as f:
        print(f"Loading {file_path}...")
        data = pickle.load(f)

    # 데이터 타입 확인
    print(f"Data type: {type(data)}")

    # 데이터가 딕셔너리인 경우
    if isinstance(data, dict):
        if len(data) > 0:
            # 전체를 순회하지 않고 첫 번째 키만 가져옴 (메모리 효율 및 속도)
            first_key = next(iter(data))
            first_value = data[first_key]
            
            print("\n--- First Item ---")
            print(f"Key: {first_key}")
            print(f"Value: {first_value}")
        else:
            print("The dictionary is empty.")
            
    # 데이터가 리스트인 경우 (참고용)
    elif isinstance(data, list):
        if len(data) > 0:
            print("\n--- First Item (List) ---")
            print(data[0])
        else:
            print("The list is empty.")
            
    else:
        print("The data is not a dictionary or list.")
        # 데이터의 앞부분만 살짝 출력 (문자열 변환 후 슬라이싱)
        print(f"Preview: {str(data)[:200]}")

except FileNotFoundError:
    print(f"Error: The file '{file_path}' was not found.")
except Exception as e:
    print(f"An error occurred: {e}")
