In [1]:
import seaborn as sns
import matplotlib.pyplot as plt 
import numpy as np 
import pandas as pd
import warnings
import json
from scipy.stats import wasserstein_distance
%matplotlib widget

# Silence every Python warning for the rest of the kernel session.
# NOTE(review): this also hides deprecation notices raised by the plotting
# libraries, so API removals will not be announced before they break a cell.
warnings.filterwarnings("ignore")

### 1. 获取数据

In [3]:
# Read the two property tables: the training set and the generated set.
train_data = pd.read_csv('ChEMBL25_property_smi.csv')
gen_data = pd.read_csv('3000_property_smi.csv')

# Report how many rows each table holds.
for message, frame in (
    ('The number of train_data is {}', train_data),
    ('The number of gen_data is {}', gen_data),
):
    print(message.format(len(frame)))

The number of train_data is 100000
The number of gen_data is 143420


In [4]:
# Expose the lowercase 'qed' column under the upper-case name 'QED' in both
# tables so it can be addressed with the same key the plots use as axis label.
for frame in (train_data, gen_data):
    frame['QED'] = frame['qed']

In [10]:
# Kernel-density comparison of the training set and the generated set.
# One panel per molecular descriptor; each legend entry also carries the
# Wasserstein distance between that data set and the training set (so the
# training-set entry is always 0). The figure is written to PDF and PNG.
plt.close()

legends = ['Training set', 'Generating set']
mol_discriptors = ['MolWt', 'MolLogP', 'BertzCT', 'TPSA', 'QED', 'SA']

# Descriptors whose x-axis is clipped to a fixed window.
x_limits = {'MolLogP': (-2, 10), 'QED': (-0.1, 1.5)}

# Shared typography (defined once instead of on every loop iteration).
font_family = 'Times New Roman'
axis_label_font = {
    'family': font_family,
    'style': 'normal',
    'weight': 'semibold',
    'color': 'black',
    'size': 22,
}

# Two panels per row; the row count follows the number of descriptors.
n_cols = 2
n_rows = (len(mol_discriptors) + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 20), squeeze=False)
fig.subplots_adjust(wspace=0.3, hspace=0.3)  # spacing between panels
panels = list(axes.flat)

for ax, descriptor in zip(panels, mol_discriptors):
    reference = train_data[descriptor]

    for legend_name, df in zip(legends, [train_data, gen_data]):
        dist = wasserstein_distance(df[descriptor], reference)
        # kdeplot(fill=True) replaces the deprecated
        # distplot(hist=False, kde=True, kde_kws={"shade": True, ...}).
        sns.kdeplot(
            x=df[descriptor],
            fill=True,
            linewidth=3,
            label="{0} ({1:0.2g})".format(legend_name, dist),
            ax=ax,
        )

    if descriptor in x_limits:
        ax.set_xlim(*x_limits[descriptor])

    # Legend
    ax.legend(prop={'family': font_family, 'size': 16}, loc='upper right')

    # Tick marks and tick labels
    ax.tick_params(axis='both', labelsize=16, width=2, length=6, direction='in')
    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_fontname(font_family)
        tick_label.set_fontweight('bold')

    # Axis frame thickness
    for side in ('bottom', 'left', 'right', 'top'):
        ax.spines[side].set_linewidth(2)

    # Axis labels
    ax.set_xlabel(descriptor, fontdict=axis_label_font)
    ax.set_ylabel("Density", fontdict=axis_label_font)

# Hide any unused panel when the descriptor count does not fill the grid.
for unused_ax in panels[len(mol_discriptors):]:
    unused_ax.set_visible(False)

# Save.
# NOTE(review): 20x20 in at 1000 dpi is a 20000x20000 px PNG; lower the dpi
# if the export is too slow or runs out of memory.
fig.savefig('Kernel_Density_Curve_ChEMBL.pdf')
fig.savefig('Kernel_Density_Curve_ChEMBL.png', dpi=1000)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …