In [1]:
from pymatgen.core.periodic_table import Element
from pymatgen.analysis.local_env import CrystalNN
from pymatgen.core.structure import Structure
from pymatgen.io.xcrysden import XSF
import networkx as nx
import numpy as np
import plotly.graph_objs as go
import math
import copy

# 从文件中读取晶体结构
# 替换为您的文件路径和文件名
dir_="D:\\project\\谢琎老师锂电池\\第二次：Ti阻止了Li-Ni混排\\计算结果\\第三次：氧空位\\数据3-返稿意见\\结构\\combined_xsf_files\\"
str_name='LiNiO2_331_NCM_513_Ti_num_1_0'
xsf_file = f'{str_name}.xsf'

def read_xsf(file_path):
    with open(file_path, 'r') as file:
        xsf_content = file.read()
    xsf = XSF.from_string(xsf_content)
    return xsf.structure

crystal_orginal = read_xsf(dir_ + xsf_file)
# Create a deep copy of the crystal structure
crystal = copy.deepcopy(crystal_orginal)

def create_original_index_list(crystal):
    original_indices = []
    for i, site in enumerate(crystal):
        if str(site.specie) != 'O':  # 只为非氧原子创建映射
            original_indices.append(i)
    return original_indices

#index projection from original to delete
original_index_map = create_original_index_list(crystal)

def map_old_indices_to_new(original_index_map, crystal):
    new_index_map = {}
    for new_index, old_index in enumerate(original_index_map):
        new_index_map[old_index] = new_index
    return new_index_map

#index projection from delete to original
new_index_map = map_old_indices_to_new(original_index_map, crystal)


# Remove all oxygen sites from the copy
crystal_metal = copy.deepcopy(crystal_orginal)
crystal_metal.remove_species(["O"])

# 使用pymatgen的CrystalNN找到临近的原子
crystal_nn = CrystalNN()
neighbors = [crystal_nn.get_nn_info(crystal, n=i) for i, _ in enumerate(crystal.sites)]

neighbors_metal = [crystal_nn.get_nn_info(crystal_metal, n=i) for i, _ in enumerate(crystal_metal.sites)]


# 构建图
G = nx.Graph()
for i, site in enumerate(crystal.sites):
    G.add_node(i, element=site.specie.symbol)
    for neighbor in neighbors[i]:
        neighbor_index = neighbor['site_index']
        if not G.has_edge(i, neighbor_index):
            G.add_edge(i, neighbor_index)

for i, site in enumerate(crystal_metal.sites):
    for neighbor in neighbors_metal[i]:
        li_ni_edge=0
        neighbor_index = neighbor['site_index']
        if ((crystal_metal[i].specie==Element("Li")) and (crystal_metal[neighbor_index].specie==Element("Ni"))) or ((crystal_metal[i].specie==Element("Ni")) and (crystal_metal[neighbor_index].specie==Element("Li"))):
            li_ni_edge=1
        if not G.has_edge(original_index_map[i], original_index_map[neighbor_index]) and li_ni_edge==1:
            G.add_edge(original_index_map[i], original_index_map[neighbor_index])

color_map = {
    'Mn': 'rgba(128, 0, 128, 1)',  # 紫色
    'Ti': 'rgba(135, 206, 235, 1)', # 天蓝色
    'Co': 'rgba(0, 0, 139, 1)',     # 深蓝色
    'Ni': 'rgba(128, 128, 128, 1)', # 灰色
    'Li': 'rgba(0, 128, 0, 1)',     # 绿色
    'O': 'rgba(255, 0, 0, 1)'       # 红色
}

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


In [2]:

index_new=40
print(original_index_map[index_new])
print(new_index_map[original_index_map[index_new]])


94
40


In [3]:
def get_nearest_neighbors_of_type(atom_index, element_type):
    """
    找到给定原子索引的最近邻原子，这些原子的类型为指定类型，数量限制为n个。
    已放置的原子将从考虑中剔除。

    :param atom_index: 当前原子的索引。
    :param element_type: 要找的原子类型（例如 'O' 表示氧原子）。
    :param placed_atoms: 已经放置的原子索引的集合。
    :param n: 最多返回的邻近原子数量。
    :return: 指定类型的最近邻原子的索引列表。
    """
    n=6
    neighbors = crystal_nn.get_nn_info(crystal, n=atom_index)
    selected_neighbors = []
    if element_type!='Non-O-Li':
        for neighbor in neighbors:
            neighbor_index = neighbor['site_index']
            neighbor_element = crystal[neighbor_index].specie

            if neighbor_element == element_type and neighbor_index not in [key for key,v in pos.items()]:
                selected_neighbors.append(neighbor_index)
                if len(selected_neighbors) >= n:
                    break
    else:
        for neighbor in neighbors:
            neighbor_index = neighbor['site_index']
            neighbor_element = crystal[neighbor_index].specie

            if neighbor_element != Element('Li') and neighbor_element != Element('O') and neighbor_index not in [key for key,v in pos.items()]:
                selected_neighbors.append(neighbor_index)
                if len(selected_neighbors) >= n:
                    break    

    return selected_neighbors


def get_nearest_neighbors_metal(atom_index):
    """
    找到给定原子索引的最近邻原子，这些原子的类型为指定类型，数量限制为n个。
    已放置的原子将从考虑中剔除。

    :param atom_index: 当前原子的索引。
    :param element_type: 要找的原子类型（例如 'O' 表示氧原子）。
    :param placed_atoms: 已经放置的原子索引的集合。
    :param n: 最多返回的邻近原子数量。
    :return: 指定类型的最近邻原子的索引列表。
    """
    n=6
    neighbors = crystal_nn.get_nn_info(crystal_metal, n=new_index_map[atom_index])
    selected_neighbors = []
    for neighbor in neighbors:
        neighbor_index = neighbor['site_index']
        print(neighbor_index)
        neighbor_element = crystal_metal[neighbor_index].specie

        if neighbor_element != Element('Li') and neighbor_element != Element('O') and neighbor_index not in [key for key,v in pos.items()]:
            print("OR:",original_index_map[neighbor_index])
            selected_neighbors.append(original_index_map[neighbor_index])
            if len(selected_neighbors) >= n:
                break    

    return selected_neighbors

def get_nearest_neighbors_li(atom_index):
    """
    找到给定原子索引的最近邻原子，这些原子的类型为指定类型，数量限制为n个。
    已放置的原子将从考虑中剔除。

    :param atom_index: 当前原子的索引。
    :param element_type: 要找的原子类型（例如 'O' 表示氧原子）。
    :param placed_atoms: 已经放置的原子索引的集合。
    :param n: 最多返回的邻近原子数量。
    :return: 指定类型的最近邻原子的索引列表。
    """
    n=6
    neighbors = crystal_nn.get_nn_info(crystal_metal, n=new_index_map[atom_index])
    selected_neighbors = []
    for neighbor in neighbors:
        neighbor_index = neighbor['site_index']
        print(neighbor_index)
        neighbor_element = crystal_metal[neighbor_index].specie

        if neighbor_element == Element('Li') and neighbor_index not in [key for key,v in pos.items()]:
            print("OR:",original_index_map[neighbor_index])
            selected_neighbors.append(original_index_map[neighbor_index])
            if len(selected_neighbors) >= n:
                break    

    return selected_neighbors

def place_atoms_in_layer(base_distance, atom_indices):
    """
    根据给定的基础距离和原子索引列表，计算并放置这些原子的位置。
    同时更新链接列表。

    :param base_distance: 原子到中心的基础距离。
    :param atom_indices: 要放置的原子索引列表。
    :param links: 存储原子间链接的列表。
    :param pos: 存储原子位置的字典。
    """
    num_atoms = len(atom_indices)
    for i, atom_index in enumerate(atom_indices):
        angle = 2 * math.pi * i / num_atoms
        x = base_distance * math.cos(angle)
        y = base_distance * math.sin(angle)
        pos[atom_index] = (x, y)

        
'''       
# 中心原子和第一层

center_atom_vesta=16

center_atom_index = center_atom_vesta-1
pos = {center_atom_index:(0,0)}
#print("*"*20,"center atom","*"*20,)
#print(pos)

distances = [crystal.get_distance(center_atom_index, i) for i, _ in enumerate(crystal.sites)]
first_layer_oxygen_atoms = get_nearest_neighbors_of_type(center_atom_index, Element('O'))
place_atoms_in_layer(1, first_layer_oxygen_atoms)


# 1.2层和1.4层
layer_1_2_li_atoms=[]
layer_1_4_other_atoms=[]
#for oxygen_atom in first_layer_oxygen_atoms:
    #layer_1_2_li_atoms.extend(get_nearest_neighbors_of_type(oxygen_atom, Element('Li')))
    #layer_1_4_other_atoms.extend(get_nearest_neighbors_of_type(oxygen_atom, 'Non-O-Li'))
layer_1_4_other_atoms.extend(get_nearest_neighbors_metal(center_atom_index)) #只考虑最近邻的六个TM原子
layer_1_2_li_atoms.extend(get_nearest_neighbors_li(center_atom_index))
#layer_1_2_li_atoms=list(set(layer_1_2_li_atoms))
#layer_1_4_other_atoms=list(set(layer_1_4_other_atoms))
#print("AA",layer_1_4_other_atoms)
place_atoms_in_layer(1.2, layer_1_2_li_atoms)  # 第1.2层距离为1.2
place_atoms_in_layer(1.4, layer_1_4_other_atoms)  # 第1.4层距离为1.4

second_layer_oxygen_atoms = set()
for atom_index in [key for key,v in pos.items()]:
    if str(crystal[atom_index].specie) != Element('O'):  # 选择非氧原子
        nearest_oxygen_atoms = get_nearest_neighbors_of_type(atom_index, Element('O'))
        second_layer_oxygen_atoms.update(nearest_oxygen_atoms)

#print("*"*20,"1st layer atom","*"*20,)
#print(pos)
        
# 更新第二层原子的位置
second_layer_oxygen_atoms = second_layer_oxygen_atoms - set([key for key,v in pos.items()])
place_atoms_in_layer(2, list(second_layer_oxygen_atoms))  # 假设第二层距离为2

# 计算第2.2层和第2.4层原子位置
layer_2_2_li_atoms=[]
layer_2_4_other_atoms=[]
for oxygen_atom in second_layer_oxygen_atoms:
    layer_2_2_li_atoms.extend(get_nearest_neighbors_of_type(oxygen_atom, Element('Li')))
    layer_2_4_other_atoms.extend(get_nearest_neighbors_of_type(oxygen_atom, 'Non-O-Li'))
layer_2_2_li_atoms=list(set(layer_2_2_li_atoms))
layer_2_4_other_atoms=list(set(layer_2_4_other_atoms))
place_atoms_in_layer(2.2, layer_2_2_li_atoms)  # 第2.2层距离为2.2
place_atoms_in_layer(2.4, layer_2_4_other_atoms)  # 第2.4层距离为2.4

#print("*"*20,"2nd layer atom","*"*20,)
#print(pos)
    
# 放置剩余原子
# 在这里，pos字典包含了所有原子的位置，links列表包含了原子之间的链接


remaining_atoms = set(range(len(crystal.sites))) - set([key for key,v in pos.items()])

place_atoms_in_layer(3, list(remaining_atoms))  # 假设最外层距离为3

'''

# 中心原子和第一层

center_atom_vesta=108

center_atom_index = center_atom_vesta-1
pos = {center_atom_index:(0,0)}
#print("*"*20,"center atom","*"*20,)
#print(pos)

distances = [crystal.get_distance(center_atom_index, i) for i, _ in enumerate(crystal.sites)]
first_layer_oxygen_atoms = get_nearest_neighbors_of_type(center_atom_index, Element('O'))
place_atoms_in_layer(0.4, first_layer_oxygen_atoms)


# 1.2层和1.4层
layer_1_2_li_atoms=[]
layer_1_2_li_atoms.extend(get_nearest_neighbors_metal(center_atom_index)) #只考虑最近邻的六个TM原子
layer_1_2_li_atoms.extend(get_nearest_neighbors_li(center_atom_index))
place_atoms_in_layer(1.2, layer_1_2_li_atoms)  # 第1.2层距离为1.2

second_layer_oxygen_atoms = set()
for atom_index in [key for key,v in pos.items()]:
    if str(crystal[atom_index].specie) != Element('O'):  # 选择非氧原子
        nearest_oxygen_atoms = get_nearest_neighbors_of_type(atom_index, Element('O'))
        second_layer_oxygen_atoms.update(nearest_oxygen_atoms)

#print("*"*20,"1st layer atom","*"*20,)
#print(pos)
        
# 更新第二层原子的位置
second_layer_oxygen_atoms = second_layer_oxygen_atoms - set([key for key,v in pos.items()])
place_atoms_in_layer(1.8, list(second_layer_oxygen_atoms))  # 假设第二层距离为2

# 计算第2.2层和第2.4层原子位置
layer_2_2_li_atoms=[]
for oxygen_atom in second_layer_oxygen_atoms:
    layer_2_2_li_atoms.extend(get_nearest_neighbors_of_type(oxygen_atom, Element('Li')))
    layer_2_2_li_atoms.extend(get_nearest_neighbors_of_type(oxygen_atom, 'Non-O-Li'))
layer_2_2_li_atoms=list(set(layer_2_2_li_atoms))
place_atoms_in_layer(2.4, layer_2_2_li_atoms)  # 第2.2层距离为2.2

#print("*"*20,"2nd layer atom","*"*20,)
#print(pos)
    
# 放置剩余原子
# 在这里，pos字典包含了所有原子的位置，links列表包含了原子之间的链接


th_layer_oxygen_atoms = set()
for atom_index in [key for key,v in pos.items()]:
    if str(crystal[atom_index].specie) != Element('O'):  # 选择非氧原子
        nearest_oxygen_atoms = get_nearest_neighbors_of_type(atom_index, Element('O'))
        th_layer_oxygen_atoms.update(nearest_oxygen_atoms)

#print("*"*20,"1st layer atom","*"*20,)
#print(pos)
        
# 更新第二层原子的位置
th_layer_oxygen_atoms = th_layer_oxygen_atoms - set([key for key,v in pos.items()]) - second_layer_oxygen_atoms
place_atoms_in_layer(3, list(th_layer_oxygen_atoms))  # 假设第二层距离为2

# 计算第2.2层和第2.4层原子位置
layer_3_2_li_atoms=[]
for oxygen_atom in th_layer_oxygen_atoms:
    layer_3_2_li_atoms.extend(get_nearest_neighbors_of_type(oxygen_atom, Element('Li')))
    layer_3_2_li_atoms.extend(get_nearest_neighbors_of_type(oxygen_atom, 'Non-O-Li'))
layer_3_2_li_atoms=list(set(layer_3_2_li_atoms))
place_atoms_in_layer(3.6, layer_3_2_li_atoms)  # 第2.2层距离为2.2


# 绘制链接
# ...

# 注意：实际代码实现需要具体到您的数据结构和可用函数
#print("*"*20,"final layer atom","*"*20,)
#for i,v in pos.items():
#    print(i,v)
#print(len(pos),len(pos[0]))

0
19
18
1
24
4
50
OR: 104
49
OR: 103
28
OR: 82
35
OR: 89
33
OR: 87
34
OR: 88
0
OR: 0
19
OR: 19
18
OR: 18
1
OR: 1
24
OR: 24
4
OR: 4


In [4]:
import plotly.graph_objs as go
import plotly.offline as py_offline
import plotly.io as pio


In [5]:
#pip install -U kaleido

In [12]:

resolution=[20,12]

# 使用plotly绘制图形
x_nodes = [pos[i][0] for i in G.nodes()]
y_nodes = [pos[i][1] for i in G.nodes()]

#center_atom_index

#师兄配色
color_map = {
    'Mn': 'rgba(253,207,158 , 1)',  # 紫色
    'Ti': 'rgba(184,168,207, 1)', # 天蓝色
    'Co': 'rgba(239,164,132 , 1)',     # 深蓝色
    'Ni': 'rgba(182,118,108, 1)', # 灰色
    'Li': 'rgba(78,101,155, 1)',     # 绿色
    'O': 'rgba(178,101,155, 1)'       # 红色   
}

size_map = {
    'Mn': 40,  
    'Ti': 40, 
    'Co': 40,     
    'Ni': 40, 
    'Li': 40,     
    'O': 28     
}

reso=10

#239,164,132
size_map2 = {
    'Mn': 40*reso,  
    'Ti': 40*reso, 
    'Co': 40*reso,     
    'Ni': 40*reso, 
    'Li': 40*reso,     
    'O': 28*reso     
}



# Initialize edge trace
edge_traces = []
for edge in G.edges():
    if {'Li', 'Ni'} == {G.nodes[edge[0]]['element'], G.nodes[edge[1]]['element']}:
        pass
    
    x0, y0 = pos[edge[0]]
    x1, y1 = pos[edge[1]]
    edge_x = [x0, x1, None]
    edge_y = [y0, y1, None]

    # Check if edge connects Li and Ni
    #if {'Li', 'Ni'} == {G.nodes[edge[0]]['element'], G.nodes[edge[1]]['element']}:
    #    edge_color='red' # Red for Li-Ni edges
    #else:
    #    edge_color='grey'  # Gray for other edges
    edge_color=None
    # Create an edge trace
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color=edge_color),
        hoverinfo='none',
        mode='lines')

    edge_traces.append(edge_trace)
edge_traces = []    
    
# 生成节点标签
node_colors = []
node_sizes = []
node_sizes2 = []
node_labels = []
node_labels_show = []
for node in G.nodes():
    element = G.nodes[node]['element']
    # 直接根据元素符号分配颜色
    node_colors.append(color_map[element])
    node_sizes.append(size_map[element])
    node_sizes2.append(size_map2[element])
    # 获取连接的原子及其索引，并按原子距离中心原子的远近排序
    connected_atoms = [(G.nodes[neighbor]['element'], neighbor) for neighbor in G.neighbors(node)]
    # 按照与中心原子的距离排序
    connected_atoms_sorted = sorted(connected_atoms, key=lambda x: distances[x[1]])
    # 格式化标签字符串（显示时索引+1）
    connected_atoms_labels = ['{}#{}'.format(atom[0], atom[1] + 1) for atom in connected_atoms_sorted]
    label = f"{element}#{node + 1}: coor_num: {len(connected_atoms)} -- {'.'.join(connected_atoms_labels)}"
    label_show = f"{element}-{node + 1}"
    node_labels.append(label)
    node_labels_show.append(label_show)

# 生成边标签
edge_labels = []
for edge in G.edges():
    element1, element2 = G.nodes[edge[0]]['element'], G.nodes[edge[1]]['element']
    label = f"{element1}#{edge[0]+1}-{element2}-{edge[1]+1}"
    edge_labels.append(label)
    #color1, color2 = color_map[element1], color_map[element2]
    # 生成过渡颜色
    #edge_colors.append(f'rgba({(color1[0]+color2[0])//2}, {(color1[1]+color2[1])//2}, {(color1[2]+color2[2])//2}, 0.5)')

# 更新节点和边的trace以包含标签
'''
node_trace = go.Scatter(
    x=x_nodes, y=y_nodes,
    mode='markers+text',
    hoverinfo='text',
    hovertext=node_labels,
    textposition="top center",
    marker=dict(
        color=node_colors,
        size=node_sizes,
        line_width=2))
'''

node_trace = go.Scatter(
    x=x_nodes, y=y_nodes,
    mode='markers+text',
    hoverinfo='text',
    hovertext=node_labels,
    #text=node_labels_show,
    textposition="top center",
    marker=dict(
        color=node_colors,
        size=node_sizes,
        line_width=2))

node_trace2 = go.Scatter(
    x=x_nodes, y=y_nodes,
    mode='markers+text',
    hoverinfo='text',
    hovertext=node_labels,
    #text=node_labels_show,
    textposition="top center",
    marker=dict(
        color=node_colors,
        size=node_sizes2,
        line_width=2))

# Update edge trace to use edge colors
'''
edge_trace = go.Scatter(
    x=edge_x, y=edge_y,
    line=dict(width=0.5, color=edge_colors),  # Use edge_colors here
    hoverinfo='text',
    hovertext=edge_labels,
    mode='lines')
'''
# 创建图形对象并添加边和节点
fig = go.Figure(data=edge_traces + [node_trace],
                layout=go.Layout(
                    title=f'<br>Network graph of crystal structure centered with {crystal[center_atom_index].specie.symbol}#{center_atom_index+1}',
                    showlegend=False,
                    hovermode='closest',
                    margin=dict(b=20,l=5,r=5,t=40),
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))

fig2 = go.Figure(data=edge_traces + [node_trace2],
                layout=go.Layout(
                    title=f'<br>Network graph of crystal structure centered with {crystal[center_atom_index].specie.symbol}#{center_atom_index+1}',
                    showlegend=False,
                    hovermode='closest',
                    margin=dict(b=20,l=5,r=5,t=40),
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))

# 显示图形
#str_name=xsf_file.split(".xsf")[0].split("out_relax_")[1]
str_name=xsf_file.split(".xsf")[0]
fig.write_html(f'{dir_}{str_name}_{crystal[center_atom_index].specie.symbol}#{center_atom_index+1}_neighbor.html')
pio.write_image(fig2, f'{dir_}{str_name}_{crystal[center_atom_index].specie.symbol}#{center_atom_index+1}_neighbor.png', width=resolution[0]*100*reso, height=resolution[1]*100*reso)
fig.show()


In [72]:
#pip show kaleido

SyntaxError: invalid syntax (Temp/ipykernel_9940/1399984334.py, line 1)

In [76]:
#pip show plotly

Name: plotly
Version: 5.21.0
Summary: An open-source, interactive data visualization library for Python
Home-page: https://plotly.com/python/
Author: Chris P
Author-email: chris@plot.ly
License: MIT
Location: e:\anconda_installation_file\lib\site-packages
Requires: packaging, tenacity
Required-by: pymatgen
Note: you may need to restart the kernel to use updated packages.


In [34]:
#TM atoms test

In [10]:
def get_nearest_neighbors_of_type(atom_index, element_type):
    """
    找到给定原子索引的最近邻原子，这些原子的类型为指定类型，数量限制为n个。
    已放置的原子将从考虑中剔除。

    :param atom_index: 当前原子的索引。
    :param element_type: 要找的原子类型（例如 'O' 表示氧原子）。
    :param placed_atoms: 已经放置的原子索引的集合。
    :param n: 最多返回的邻近原子数量。
    :return: 指定类型的最近邻原子的索引列表。
    """
    n=6
    neighbors = crystal_nn.get_nn_info(crystal, n=atom_index)
    selected_neighbors = []
    if element_type!='Non-O-Li':
        for neighbor in neighbors:
            neighbor_index = neighbor['site_index']
            neighbor_element = crystal[neighbor_index].specie

            if neighbor_element == element_type and neighbor_index not in [key for key,v in pos.items()]:
                selected_neighbors.append(neighbor_index)
                if len(selected_neighbors) >= n:
                    break
    else:
        for neighbor in neighbors:
            neighbor_index = neighbor['site_index']
            neighbor_element = crystal[neighbor_index].specie

            if neighbor_element != Element('Li') and neighbor_element != Element('O') and neighbor_index not in [key for key,v in pos.items()]:
                selected_neighbors.append(neighbor_index)
                if len(selected_neighbors) >= n:
                    break    

    return selected_neighbors


def get_nearest_neighbors_metal(atom_index):
    """
    找到给定原子索引的最近邻原子，这些原子的类型为指定类型，数量限制为n个。
    已放置的原子将从考虑中剔除。

    :param atom_index: 当前原子的索引。
    :param element_type: 要找的原子类型（例如 'O' 表示氧原子）。
    :param placed_atoms: 已经放置的原子索引的集合。
    :param n: 最多返回的邻近原子数量。
    :return: 指定类型的最近邻原子的索引列表。
    """
    n=6
    neighbors = crystal_nn.get_nn_info(crystal_metal, n=new_index_map[atom_index])
    selected_neighbors = []
    for neighbor in neighbors:
        neighbor_index = neighbor['site_index']
        print(neighbor_index)
        neighbor_element = crystal_metal[neighbor_index].specie

        if neighbor_element != Element('Li') and neighbor_element != Element('O') and neighbor_index not in [key for key,v in pos.items()]:
            print("OR:",original_index_map[neighbor_index])
            selected_neighbors.append(original_index_map[neighbor_index])
            if len(selected_neighbors) >= n:
                break    

    return selected_neighbors

def get_nearest_neighbors_li(atom_index):
    """
    找到给定原子索引的最近邻原子，这些原子的类型为指定类型，数量限制为n个。
    已放置的原子将从考虑中剔除。

    :param atom_index: 当前原子的索引。
    :param element_type: 要找的原子类型（例如 'O' 表示氧原子）。
    :param placed_atoms: 已经放置的原子索引的集合。
    :param n: 最多返回的邻近原子数量。
    :return: 指定类型的最近邻原子的索引列表。
    """
    n=6
    neighbors = crystal_nn.get_nn_info(crystal_metal, n=new_index_map[atom_index])
    selected_neighbors = []
    for neighbor in neighbors:
        neighbor_index = neighbor['site_index']
        print(neighbor_index)
        neighbor_element = crystal_metal[neighbor_index].specie

        if neighbor_element == Element('Li') and neighbor_index not in [key for key,v in pos.items()]:
            print("OR:",original_index_map[neighbor_index])
            selected_neighbors.append(original_index_map[neighbor_index])
            if len(selected_neighbors) >= n:
                break    

    return selected_neighbors

def place_atoms_in_layer(base_distance, atom_indices):
    """
    根据给定的基础距离和原子索引列表，计算并放置这些原子的位置。
    同时更新链接列表。

    :param base_distance: 原子到中心的基础距离。
    :param atom_indices: 要放置的原子索引列表。
    :param links: 存储原子间链接的列表。
    :param pos: 存储原子位置的字典。
    """
    num_atoms = len(atom_indices)
    for i, atom_index in enumerate(atom_indices):
        angle = 2 * math.pi * i / num_atoms
        x = base_distance * math.cos(angle)
        y = base_distance * math.sin(angle)
        pos[atom_index] = (x, y)

# 中心原子和第一层

center_atom_vesta=38

center_atom_index = center_atom_vesta-1
pos = {center_atom_index:(0,0)}
#print("*"*20,"center atom","*"*20,)
#print(pos)

distances = [crystal.get_distance(center_atom_index, i) for i, _ in enumerate(crystal.sites)]
first_layer_oxygen_atoms = get_nearest_neighbors_of_type(center_atom_index, Element('O'))
place_atoms_in_layer(1, first_layer_oxygen_atoms)


# 1.2层和1.4层
layer_1_2_li_atoms=[]
layer_1_4_other_atoms=[]
#for oxygen_atom in first_layer_oxygen_atoms:
    #layer_1_2_li_atoms.extend(get_nearest_neighbors_of_type(oxygen_atom, Element('Li')))
    #layer_1_4_other_atoms.extend(get_nearest_neighbors_of_type(oxygen_atom, 'Non-O-Li'))
layer_1_4_other_atoms.extend(get_nearest_neighbors_metal(center_atom_index)) #只考虑最近邻的六个TM原子
layer_1_2_li_atoms.extend(get_nearest_neighbors_li(center_atom_index))
#layer_1_2_li_atoms=list(set(layer_1_2_li_atoms))
#layer_1_4_other_atoms=list(set(layer_1_4_other_atoms))
#print("AA",layer_1_4_other_atoms)
place_atoms_in_layer(1.2, layer_1_2_li_atoms)  # 第1.2层距离为1.2
place_atoms_in_layer(1.4, layer_1_4_other_atoms)  # 第1.4层距离为1.4

second_layer_oxygen_atoms = set()
for atom_index in [key for key,v in pos.items()]:
    if str(crystal[atom_index].specie) != Element('O'):  # 选择非氧原子
        nearest_oxygen_atoms = get_nearest_neighbors_of_type(atom_index, Element('O'))
        second_layer_oxygen_atoms.update(nearest_oxygen_atoms)

#print("*"*20,"1st layer atom","*"*20,)
#print(pos)
        
# 更新第二层原子的位置
second_layer_oxygen_atoms = second_layer_oxygen_atoms - set([key for key,v in pos.items()])
place_atoms_in_layer(2, list(second_layer_oxygen_atoms))  # 假设第二层距离为2

# 计算第2.2层和第2.4层原子位置
layer_2_2_li_atoms=[]
layer_2_4_other_atoms=[]
for oxygen_atom in second_layer_oxygen_atoms:
    layer_2_2_li_atoms.extend(get_nearest_neighbors_of_type(oxygen_atom, Element('Li')))
    layer_2_4_other_atoms.extend(get_nearest_neighbors_of_type(oxygen_atom, 'Non-O-Li'))
layer_2_2_li_atoms=list(set(layer_2_2_li_atoms))
layer_2_4_other_atoms=list(set(layer_2_4_other_atoms))
place_atoms_in_layer(2.2, layer_2_2_li_atoms)  # 第2.2层距离为2.2
place_atoms_in_layer(2.4, layer_2_4_other_atoms)  # 第2.4层距离为2.4

#print("*"*20,"2nd layer atom","*"*20,)
#print(pos)
    
# 放置剩余原子
# 在这里，pos字典包含了所有原子的位置，links列表包含了原子之间的链接


remaining_atoms = set(range(len(crystal.sites))) - set([key for key,v in pos.items()])

place_atoms_in_layer(3, list(remaining_atoms))  # 假设最外层距离为3



# 绘制链接
# ...

# 注意：实际代码实现需要具体到您的数据结构和可用函数
#print("*"*20,"final layer atom","*"*20,)
#for i,v in pos.items():
#    print(i,v)
#print(len(pos),len(pos[0]))

9
OR: 9
12
OR: 12
25
OR: 25
6
OR: 6
3
OR: 3
8
OR: 8
9
12
25
6
3
8


In [1]:
# 使用plotly绘制图形
x_nodes = [pos[i][0] for i in G.nodes()]
y_nodes = [pos[i][1] for i in G.nodes()]


# Initialize edge trace
edge_traces = []
for edge in G.edges():
    x0, y0 = pos[edge[0]]
    x1, y1 = pos[edge[1]]
    edge_x = [x0, x1, None]
    edge_y = [y0, y1, None]

    # Check if edge connects Li and Ni
    if {'Li', 'Ni'} == {G.nodes[edge[0]]['element'], G.nodes[edge[1]]['element']}:
        edge_color='red'  # Red for Li-Ni edges
    else:
        edge_color='grey'  # Gray for other edges
    # Create an edge trace
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color=edge_color),
        hoverinfo='none',
        mode='lines')

    edge_traces.append(edge_trace)
    
    
# 生成节点标签
node_colors = []
node_labels = []
for node in G.nodes():
    element = G.nodes[node]['element']
    # 直接根据元素符号分配颜色
    node_colors.append(color_map[element])
    # 获取连接的原子及其索引，并按原子距离中心原子的远近排序
    connected_atoms = [(G.nodes[neighbor]['element'], neighbor) for neighbor in G.neighbors(node)]
    # 按照与中心原子的距离排序
    connected_atoms_sorted = sorted(connected_atoms, key=lambda x: distances[x[1]])
    # 格式化标签字符串（显示时索引+1）
    connected_atoms_labels = ['{}#{}'.format(atom[0], atom[1] + 1) for atom in connected_atoms_sorted]
    label = f"{element}#{node + 1}: coor_num: {len(connected_atoms)} -- {'.'.join(connected_atoms_labels)}"
    node_labels.append(label)

# 生成边标签
edge_labels = []
for edge in G.edges():
    element1, element2 = G.nodes[edge[0]]['element'], G.nodes[edge[1]]['element']
    label = f"{element1}#{edge[0]+1}-{element2}-{edge[1]+1}"
    edge_labels.append(label)
    #color1, color2 = color_map[element1], color_map[element2]
    # 生成过渡颜色
    #edge_colors.append(f'rgba({(color1[0]+color2[0])//2}, {(color1[1]+color2[1])//2}, {(color1[2]+color2[2])//2}, 0.5)')

# 更新节点和边的trace以包含标签
node_trace = go.Scatter(
    x=x_nodes, y=y_nodes,
    mode='markers+text',
    hoverinfo='text',
    hovertext=node_labels,
    textposition="top center",
    marker=dict(
        color=node_colors,
        size=10,
        line_width=2))

# Update edge trace to use edge colors
'''
edge_trace = go.Scatter(
    x=edge_x, y=edge_y,
    line=dict(width=0.5, color=edge_colors),  # Use edge_colors here
    hoverinfo='text',
    hovertext=edge_labels,
    mode='lines')
'''
# 创建图形对象并添加边和节点
fig = go.Figure(data=edge_traces + [node_trace],
                layout=go.Layout(
                    title=f'<br>Network graph of crystal structure centered with {crystal[center_atom_index].specie.symbol}#{center_atom_index+1}',
                    showlegend=False,
                    hovermode='closest',
                    margin=dict(b=20,l=5,r=5,t=40),
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))

# 显示图形
str_name=xsf_file.split(".xsf")[0].split("out_relax_")[1]
fig.write_html(f'{dir_}{str_name}_{crystal[center_atom_index].specie.symbol}#{center_atom_index+1}_neighbor.html')
fig.show()


NameError: name 'G' is not defined

定义R-GCN模型
R-GCN（关系图卷积网络）是一种用于处理图数据的神经网络，特别适合处理有多种类型边的图。在您的案例中，有两种类型的边：'original'边和'li_ni_edge'边。

In [None]:
import torch
import torch.nn as nn
from torch_geometric.nn import RGCNConv

class RGCN(nn.Module):
    def __init__(self, in_channels, out_channels, num_relations):
        super(RGCN, self).__init__()
        # 定义R-GCN层
        self.conv1 = RGCNConv(in_channels, 16, num_relations)
        self.conv2 = RGCNConv(16, out_channels, num_relations)

    def forward(self, x, edge_index, edge_type):
        # 通过R-GCN层传递节点特征
        x = self.conv1(x, edge_index, edge_type)
        x = torch.relu(x)
        x = self.conv2(x, edge_index, edge_type)
        return x

# 实例化模型
model = RGCN(in_channels=节点特征维度, out_channels=输出特征维度, num_relations=2) # 这里有两种类型的边


准备数据
您需要将图数据转换为适合PyTorch Geometric的格式。这包括节点特征矩阵、边索引列表和边类型。

In [None]:
import torch
from torch_geometric.data import Data

# 假设G是您的图
# 将图数据转换为PyTorch Geometric格式

# 节点特征矩阵
x = torch.tensor(节点特征矩阵, dtype=torch.float)

# 边索引
edge_index = torch.tensor([list(zip(*G.edges()))], dtype=torch.long)

# 边类型（'original'为0, 'li_ni_edge'为1）
edge_type = torch.tensor([0 if G[u][v]['edge_type'] == 'original' else 1 for u, v in G.edges()], dtype=torch.long)

# 创建PyTorch Geometric数据对象
data = Data(x=x, edge_index=edge_index, edge_type=edge_type)


训练模型
最后，您需要定义一个训练循环来训练R-GCN模型。

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()  # 假设您是在做回归任务

for epoch in range(总的训练轮数):
    model.train()
    optimizer.zero_grad()
    
    out = model(data.x, data.edge_index, data.edge_type)
    
    # 计算损失
    loss = criterion(out, 目标值)
    
    loss.backward()
    optimizer.step()

    print(f'Epoch {epoch}, Loss: {loss.item()}')
