In [1]:
import json
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import wasserstein_distance
%matplotlib widget

warnings.filterwarnings("ignore")

In [2]:
# Descriptor tables for the ChEMBL training set and for each GAN variant's output,
# keyed by the variable name the rest of the notebook uses for them.
property_csv_paths = {
    'train_data': 'mol-RaHingeGAN/ChEMBL/ChEMBL25_property_smi.csv',
    'RaHingeGAN': 'mol-RaHingeGAN/ChEMBL/100_property_smi.csv',
    'WGANGP': 'mol-wgan-gp/ChEMBL/3000_property_smi.csv',
    'WGANdiv': 'mol-wgan-div/ChEMBL/3000_property_smi.csv',
}
property_tables = {name: pd.read_csv(csv_path) for name, csv_path in property_csv_paths.items()}

# Row count of every source table.
for name, table in property_tables.items():
    print(f'The number of {name} is {len(table)}')

# Upper-case alias so the QED column matches the other descriptor names
# ('MolWt', 'SA', ...) used when plotting; the original 'qed' column is kept.
for table in property_tables.values():
    table['QED'] = table['qed']

train_data = property_tables['train_data']
RaHingeGAN = property_tables['RaHingeGAN']
WGANGP = property_tables['WGANGP']
WGANdiv = property_tables['WGANdiv']


# Overlay kernel density estimates (KDEs) of six molecular descriptors for the
# training set and the three GAN outputs: one panel per descriptor in a 3x2 grid,
# with a single shared legend above the grid.
legends = ['Training set', 'RaHingeGAN', 'WGAN-GP', 'WGAN-div']
# (sic) spelling kept in case later cells reference this name.
mol_discriptors = ['MolWt', 'MolLogP', 'BertzCT', 'TPSA', 'QED', 'SA']
datasets = [train_data, RaHingeGAN, WGANGP, WGANdiv]  # same order as `legends`

FONT_FAMILY = 'Times New Roman'
AXIS_LABEL_FONT = {'family': FONT_FAMILY, 'style': 'normal', 'weight': 'semibold',
                   'color': 'black', 'size': 22}
LEGEND_FONT = {'family': FONT_FAMILY, 'size': 18, 'weight': 'bold'}
FIG_DIR = Path('fig/ChEMBL')
# The PDF is vector output. For the PNG, 1000 dpi on a 20x20 in figure would
# rasterise a 20000x20000 px image (~1.6 GB RGBA buffer); 300 dpi is print quality.
PNG_DPI = 300

plt.close()  # drop the figure left over from a previous run of this cell
fig, axes = plt.subplots(3, 2, figsize=(20, 20))
fig.subplots_adjust(wspace=0.3, hspace=0.3)  # spacing between panels
ax_all = list(axes.flat)

for ax, descriptor in zip(ax_all, mol_discriptors):
    for label, df in zip(legends, datasets):
        # `sns.distplot(hist=False, kde=True, kde_kws={"shade": True})` is deprecated
        # since seaborn 0.11; kdeplot(fill=True) draws the same filled KDE curve.
        sns.kdeplot(df[descriptor], fill=True, linewidth=3, label=label, ax=ax)

    # Ticks: inward, thick, with bold serif labels.
    ax.tick_params(axis='both', labelsize=16, width=2, length=6, direction='in')
    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_fontname(FONT_FAMILY)
        tick_label.set_fontweight('bold')

    # Thicker frame on all four sides.
    for spine in ax.spines.values():
        spine.set_linewidth(2)

    ax.set_xlabel(descriptor, **AXIS_LABEL_FONT)
    ax.set_ylabel('Density', **AXIS_LABEL_FONT)

# A single legend, anchored to the first panel so it sits centred above both
# columns, instead of drawing one per panel and removing five of them.
ax_all[0].legend(prop=LEGEND_FONT, loc='lower center', ncol=4, bbox_to_anchor=(1.1, 1.1))

# savefig raises FileNotFoundError if the output directory does not exist yet.
FIG_DIR.mkdir(parents=True, exist_ok=True)
fig.savefig(FIG_DIR / 'Kernel_Density_Curve_ChEMBL.pdf')
fig.savefig(FIG_DIR / 'Kernel_Density_Curve_ChEMBL.png', dpi=PNG_DPI)

The number of train_data is 100000
The number of RaHingeGAN is 160198
The number of WGANGP is 168126
The number of WGANdiv is 169839


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …