In [None]:
import torch
import pandas as pd
from torch.utils.data import Dataset
import os
from torch_geometric.data import Data
import pickle
import numpy as np
import pytorch_lightning as pl
from pathlib import Path
from torch_geometric.data import DataLoader
from torch_geometric.data import Batch
from dBandDiff.data_utils import preprocess, preprocess_tensors, add_scaled_lattice_prop, get_scaler_from_data_list
import math, copy
import random
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from typing import Any, Dict, Optional, Sequence
from torch_scatter import scatter
from torch_geometric.utils import to_dense_adj, dense_to_sparse
from tqdm import tqdm
from dBandDiff.diff_utils import d_log_p_wrapped_normal
import hydra
import omegaconf
from torch_scatter.composite import scatter_softmax
from dBandDiff.data_utils import (EPSILON, cart_to_frac_coords, mard, lengths_angles_to_volume, 
                        lattice_params_to_matrix_torch, frac_to_cart_coords, min_distance_sqr_pbc)
from dBandDiff.crystal_family import CrystalFamily
from dBandDiff.diff_utils import d_log_p_wrapped_normal
from dBandDiff.diff_utils import BetaScheduler, SigmaScheduler
from copy import deepcopy as dc
from dBandDiff.diffusion import model



# Upper bound on atomic numbers. Not referenced in the cells visible here —
# presumably consumed by code outside this view (TODO confirm it is still needed).
MAX_ATOMIC_NUM = 118

def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (the default generator plus the current and all CUDA devices),
    then switches cuDNN to deterministic mode with benchmarking turned off.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic cuDNN kernels, no auto-tuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Fix the seed once for the whole notebook run.
set_seed(456)

In [None]:
import time
import argparse
import torch

from tqdm import tqdm
from torch.optim import Adam
from pathlib import Path
from types import SimpleNamespace
from torch_geometric.data import Data, Batch, DataLoader
from torch.utils.data import Dataset
from dBandDiff.eval_utils import load_model, lattices_to_params_shape, get_crystals_list, recommand_step_lr
from pymatgen.core.structure import Structure
from pymatgen.core.lattice import Lattice
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
from pyxtal.symmetry import Group
import chemparse
import numpy as np
from p_tqdm import p_map
import pdb
import os





def diffusion(loader, model, step_lr):
    """Run the model's sampler on every batch and collect the trajectories.

    Args:
        loader: Iterable yielding batches. When CUDA is available each batch
            must expose a ``.cuda()`` method.
        model: Object exposing ``sample(batch, step_lr=...)`` that returns an
            ``(outputs, traj)`` pair.
        step_lr: Value forwarded unchanged as the ``step_lr`` keyword of
            ``model.sample``.

    Returns:
        list: The ``traj`` element returned by ``model.sample`` for each
        batch, in loader order. The ``outputs`` element is discarded.
    """
    use_cuda = torch.cuda.is_available()
    traj_list = []

    for batch in loader:
        if use_cuda:
            # Keep the returned object rather than relying on in-place moves.
            batch = batch.cuda()
        _, traj = model.sample(batch, step_lr=step_lr)
        traj_list.append(traj)

    return traj_list



class SampleDataset(Dataset):
    """Dataset of ``num_graph`` identical conditioning records for sampling.

    Every item carries the same atom count, space-group number, d-band
    center, symmetry operations and anchor indices; the dataset only
    controls how many samples are requested.
    """

    def __init__(self, num_graph, num_atoms, spg_number, d_band_center, ops, anchor_index):
        """Store one copy of each conditioning value per requested graph.

        Args:
            num_graph: Number of items the dataset reports and serves.
            num_atoms: Atom count stored on every item.
            spg_number: Space-group number stored on every item.
            d_band_center: Target value stored on every item.
            ops: Symmetry operations, convertible with ``torch.Tensor``.
                NOTE(review): ``__getitem__`` slices ``[:, :3, :3]``, so this
                is presumably shaped (n_ops, 4, 4) — confirm against
                ``get_symmetry_info``.
            anchor_index: Integer indices, convertible with
                ``torch.LongTensor``.
        """
        super().__init__()
        self.num_graph = num_graph
        # The per-item lists all reference the same underlying objects.
        self.num_atoms = [num_atoms] * num_graph
        self.spg_number = [spg_number] * num_graph
        self.opss = [ops] * num_graph
        self.anchors = [anchor_index] * num_graph
        self.d_band_center = [d_band_center] * num_graph

    def __len__(self) -> int:
        """Return the number of graphs requested at construction."""
        return self.num_graph

    def __getitem__(self, index):
        """Build the ``Data`` record for position ``index``.

        Returns:
            Data: carries ``d_band_center`` and ``spg_number`` as (1, 1)
            float tensors, ``num_atoms`` / ``num_nodes`` as the stored atom
            count, ``ops`` as a float tensor, ``anchor_index`` as a long
            tensor and ``ops_inv`` as the pseudo-inverse of
            ``ops[:, :3, :3]``.
        """
        num_atom = self.num_atoms[index]
        spg_number = self.spg_number[index]
        d_band_center = self.d_band_center[index]
        ops = self.opss[index]
        anchor_index = self.anchors[index]

        # Use the per-item value stored on the dataset. The previous code
        # read an undefined name `num_atoms`, which only resolved when a
        # notebook-level global of that name happened to exist.
        data = Data(
            d_band_center=torch.Tensor([d_band_center]).view(1, -1),
            num_atoms=num_atom,
            num_nodes=num_atom,
        )

        data.spg_number = torch.Tensor([spg_number]).view(1, -1)
        data.ops = torch.Tensor(ops)
        data.anchor_index = torch.LongTensor(anchor_index)
        # Pseudo-inverse of the leading 3x3 block of every operation.
        data.ops_inv = torch.linalg.pinv(data.ops[:, :3, :3])
        return data



        
def main(num_graph, num_atoms, spg_number, d_band_center, ops, anchor_index,
         model_path="model_weights.pth", batch_size=1, step_lr=1e-5):
    """Load the checkpoint into the imported ``model`` and sample structures.

    Args:
        num_graph: Number of samples to request from ``SampleDataset``.
        num_atoms: Atom count forwarded to ``SampleDataset``.
        spg_number: Space-group number forwarded to ``SampleDataset``.
        d_band_center: Target value forwarded to ``SampleDataset``.
        ops: Symmetry operations forwarded to ``SampleDataset``.
        anchor_index: Anchor indices forwarded to ``SampleDataset``.
        model_path: Path of the state-dict file to load.
        batch_size: Batch size of the sampling ``DataLoader``.
        step_lr: Value forwarded to ``diffusion`` / ``model.sample``.

    Returns:
        The list returned by ``diffusion`` (one trajectory per batch), or
        ``None`` when ``model_path`` does not exist — callers must handle the
        ``None`` case before indexing the result.
    """
    if not os.path.exists(model_path):
        print(f"Model file not found at {model_path}")
        return None

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # torch.load unpickles the file: only load checkpoints from a trusted source.
    # NOTE(review): the weights are re-read from disk on every call, and the
    # model is not explicitly switched to eval mode here — confirm that
    # model.sample does not depend on it.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)

    test_set = SampleDataset(num_graph, num_atoms, spg_number, d_band_center, ops, anchor_index)
    test_loader = DataLoader(test_set, batch_size=batch_size)

    return diffusion(test_loader, model, step_lr)



In [None]:
import os
from pymatgen.core.structure import Structure
from pymatgen.core.lattice import Lattice
from pymatgen.io.cif import CifWriter
from ase import io
from ase.visualize import view
import warnings
import csv,random

warnings.filterwarnings("ignore")

def save_traj_as_cif(trajs, save_folder, step=999):
    """Write one step of a sampling trajectory to ``save_folder`` as a CIF file.

    Args:
        trajs: Mapping with tensor entries ``'atom_types'``,
            ``'all_frac_coords'`` and ``'all_lattices'``, each indexed by
            trajectory step along the first axis.
            NOTE(review): ``'atom_types'`` is indexed by step just like the
            other two entries — confirm it is stored per step rather than
            once per structure.
        save_folder: Output directory; created if it does not exist.
        step: Index of the trajectory step to export. The default of 999
            keeps the previously hard-coded behaviour (presumably the final
            step of a 1000-step sampler — TODO confirm).

    The file is written as ``structure_<step>.cif`` inside ``save_folder``.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_folder, exist_ok=True)

    # Select the step first so only that slice is copied to host memory.
    atom_types = trajs['atom_types'][step].cpu().numpy()
    frac_coords = trajs['all_frac_coords'][step].cpu().numpy()
    lattice = Lattice(trajs['all_lattices'][step].cpu().numpy())

    structure = Structure(lattice, atom_types.tolist(), frac_coords.tolist())

    # Create the CIF file
    cif_file_path = os.path.join(save_folder, f"structure_{step}.cif")
    CifWriter(structure).write_file(cif_file_path)

In [None]:
from dBandDiff.get_symmetry import get_symmetry_info
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer


# Start a fresh results table: the header is written once, rows are appended below.
with open('generate_samples.csv', mode="w", newline="") as csv_file:
    writer = csv.writer(csv_file)
    writer.writerow([
        "ID", "spg_symbol", "spg_number", "d_band_center"
    ])


# os.listdir returns entries in arbitrary order; sorting keeps the pairing of
# seeded random draws with templates reproducible across runs and machines.
for file in sorted(os.listdir('template')):

    cif_path = os.path.join('template', file)
    if not os.path.isfile(cif_path):
        # Skip sub-directories such as Jupyter's .ipynb_checkpoints.
        continue

    crystal, c, sym_info = get_symmetry_info(cif_path=cif_path, tol=0.01)
    print(sym_info['spacegroup'])
    num_atoms = len(crystal)
    spg_number = sym_info['spacegroup']
    ops = sym_info['wyckoff_ops']
    anchor_index = sym_info['anchors']
    # Number of independent samples drawn per template (one main() call each).
    num_graph = 20
    sga = SpacegroupAnalyzer(crystal, symprec=0.01)
    spg_symbol = sga.get_space_group_symbol()
    spg_number2 = sga.get_space_group_number()
    # Only sample when pymatgen agrees with the space group reported by
    # get_symmetry_info; templates that disagree are skipped without a message.
    if spg_number2 == spg_number:
        print('spg_number correct')
        for i in range(num_graph):
            d = random.uniform(-3, 0)
            print(d)
            traj_list = main(num_graph=1, num_atoms=num_atoms, spg_number=spg_number,
                             d_band_center=d, ops=ops, anchor_index=anchor_index)
            trajs = traj_list[0]
            # NOTE(review): the folder and the CSV ID depend only on the space
            # group and sample index, so two templates sharing a space group
            # overwrite each other's CIF files and produce duplicate IDs.
            save_folder = f'generate/spg_number{spg_number}/cif_files{i}'

            save_traj_as_cif(trajs, save_folder)
            # Append row by row so finished samples survive an interrupted run.
            with open('generate_samples.csv', mode="a", newline="") as csv_file:
                writer = csv.writer(csv_file)
                writer.writerow([
                    f"cif_files{i}", spg_symbol, spg_number, d
                ])
            