In [1]:
import json
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import wasserstein_distance
%matplotlib widget

warnings.filterwarnings("ignore")

In [2]:
# Descriptor columns to plot, one violin panel each. The (misspelled) name is
# kept unchanged because later cells may reference it.
mol_discriptors = ['MolWt', 'MolLogP', 'BertzCT', 'TPSA', 'QED', 'SA']

# Load the ChEMBL25 training set and the molecules sampled by each GAN variant.
# NOTE(review): every CSV is assumed to carry the descriptor columns listed above
# (with QED stored as lower-case 'qed'); the assertion below checks this.
train_data = pd.read_csv('mol-RaHingeGAN/ChEMBL/ChEMBL25_property_smi.csv')
RaHingeGAN = pd.read_csv('mol-RaHingeGAN/ChEMBL/100_property_smi.csv')
WGANGP = pd.read_csv('mol-wgan-gp/ChEMBL/3000_property_smi.csv')
WGANdiv = pd.read_csv('mol-wgan-div/ChEMBL/3000_property_smi.csv')

# (name used in the size printout, frame, label shown on the violin x-axis)
datasets = [
    ('train_data', train_data, 'Training set'),
    ('RaHingeGAN', RaHingeGAN, 'RaHingeGAN'),
    ('WGANGP', WGANGP, 'WGAN-GP'),
    ('WGANdiv', WGANdiv, 'WGAN-div'),
]

for var_name, frame, set_label in datasets:
    print('The number of {} is {}'.format(var_name, len(frame)))

    # Upper-case copy so the QED panel is titled "QED"; the original 'qed'
    # column is kept in case later cells reference it.
    frame['QED'] = frame['qed']
    # Group label that becomes the x-axis category in every violin plot.
    frame['SetType'] = set_label

    missing_cols = sorted(set(mol_discriptors) - set(frame.columns))
    assert not missing_cols, '{} is missing descriptor columns: {}'.format(var_name, missing_cols)

# ignore_index=True gives all_data a fresh 0..N-1 index; a plain concat would
# repeat each source frame's own 0..n-1 labels, making label lookups ambiguous.
all_data = pd.concat([frame for _, frame, _ in datasets], axis=0, ignore_index=True)

assert len(all_data) == sum(len(frame) for _, frame, _ in datasets)

# Grid layout: two panels per row, as many rows as the descriptors need
# (3 rows for the 6 descriptors above).
N_COLS = 2
n_rows = -(-len(mol_discriptors) // N_COLS)  # ceiling division

# Panel-title font; identical for every panel, so built once outside the loop.
font_title = {'family': 'Times New Roman',
              'style': 'normal',
              'weight': 'bold',
              'color': 'black',
              'size': 26,
              }

FIG_DIR = Path('fig/ChEMBL')
# 20 in x 1000 dpi gives a ~20000 x 20000 px PNG, which is slow and memory-hungry
# to render; lower this if a smaller raster is acceptable (the PDF is vector).
PNG_DPI = 1000

plt.close()  # close the figure left by a previous run of this cell
sns.set_style('darkgrid')  # must precede subplots(): the style applies to axes created afterwards
fig, axes = plt.subplots(n_rows, N_COLS, figsize=(20, 20), squeeze=False)
fig.subplots_adjust(wspace=0.3, hspace=0.3)  # spacing between panels

for ax, descriptor in zip(axes.flat, mol_discriptors):
    # NOTE(review): seaborn >= 0.13 deprecates `palette` without `hue` (the
    # warning is hidden by the global warnings filter); consider
    # hue="SetType", legend=False once the seaborn version is pinned.
    sns.violinplot(x='SetType',
                   y=descriptor,
                   data=all_data,
                   palette='RdBu',
                   saturation=1,
                   ax=ax)

    ax.set_title(descriptor, fontdict=font_title, verticalalignment='bottom')

    # Inward ticks with large bold serif labels on both axes.
    ax.tick_params(axis='both', labelsize=20, width=2, length=6, direction='in')
    plt.setp(ax.get_xticklabels() + ax.get_yticklabels(),
             fontname='Times New Roman', fontweight='bold')

    # The title names the descriptor and the x tick labels name the sets.
    ax.set(xlabel='', ylabel='')

# Hide grid cells left over when the descriptor count does not fill the grid.
for unused_ax in axes.flat[len(mol_discriptors):]:
    unused_ax.set_visible(False)

# savefig does not create missing directories, so make sure the target exists.
FIG_DIR.mkdir(parents=True, exist_ok=True)
fig.savefig(FIG_DIR / 'Violin_Plot_ChEMBL.pdf')
fig.savefig(FIG_DIR / 'Violin_Plot_ChEMBL.png', dpi=PNG_DPI)

The number of train_data is 100000
The number of RaHingeGAN is 160198
The number of WGANGP is 168126
The number of WGANdiv is 169839


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …