In [6]:
import seaborn as sns
import matplotlib.pyplot as plt 
import numpy as np 
import pandas as pd
import warnings
import json
from scipy.stats import wasserstein_distance
%matplotlib widget

# Silence every Python warning for the rest of the kernel session.
# NOTE(review): this blanket filter also hides deprecation and runtime
# warnings raised by the libraries imported above; consider narrowing it to
# the specific categories that are known to be noisy.
warnings.filterwarnings("ignore")

### 1. 获取数据

In [7]:
# Epochs whose metric files are read: 100, 200, ..., 3000.
epochs = list(range(100, 3000+1, 100))

# Parsed JSON content of `metrics/<epoch>_moses_smi.json`, keyed by epoch.
moses_record = {}
for epoch in epochs:
    report_path = f'metrics/{epoch}_moses_smi.json'
    with open(report_path, 'r') as report_file:
        moses_record[epoch] = json.load(report_file)

# Regroup the per-epoch reports into one list per metric name. Values are
# appended in `epochs` order, so each list lines up with `epochs` provided
# the metric is present in every file.
metrics = {}
for epoch in epochs:
    for metric_name, metric_value in moses_record[epoch].items():
        metrics.setdefault(metric_name, []).append(metric_value)

In [8]:
# Normalise the metric names in `metrics` (in place): discard the entries that
# are not plotted and give the remaining ones their display names.
#
# The cell is safe to re-run: keys that were already dropped or renamed by a
# previous execution are skipped instead of raising KeyError.

# Entries removed from `metrics` altogether.
dropped_keys = ('SNN/TestSF', 'Frag/TestSF', 'Scaf/TestSF', 'IntDiv2')

# Original key -> display key. The declaration order is the order in which the
# renamed entries are re-inserted into `metrics`.
# NOTE(review): the '*/Test' metrics are relabelled '*/Train' — presumably the
# reference set used for them was the training set; confirm.
renamed_keys = {
    'SNN/Test': 'SNN/Train',
    'Frag/Test': 'Frag/Train',
    'Scaf/Test': 'Scaf/Train',
    'valid': 'Valid',
    'unique@140000': 'Unique',
    'weight': 'MolWt',
}

for key in dropped_keys:
    metrics.pop(key, None)

for old_key, new_key in renamed_keys.items():
    if old_key in metrics:
        metrics[new_key] = metrics.pop(old_key)
    elif new_key not in metrics:
        # Neither the source nor the renamed entry exists: the input data is
        # missing this metric, which is a real error rather than a re-run.
        raise KeyError(old_key)

In [9]:
# Display the metric names currently held in `metrics`.
metrics.keys()

dict_keys(['IntDiv', 'Filters', 'logP', 'SA', 'QED', 'Novelty', 'SNN/Train', 'Frag/Train', 'Scaf/Train', 'Valid', 'Unique', 'MolWt'])

In [10]:
# Persist the regrouped metrics (metric name -> list of per-epoch values) as
# JSON in `metrics.json`, overwriting any existing file.
with open('metrics.json', 'w') as metrics_file:
    metrics_file.write(json.dumps(metrics))

In [11]:
# Metric names grouped four at a time; each group is drawn as one 2x2 figure
# by the plotting cells, in the order listed here.

# Group plotted by the "metrics 1" cell.
metrics_1 = [
    'Valid',
    'Unique',
    'Novelty',
    'Filters',
]

# Group plotted by the "metrics 2" cell.
metrics_2 = [
    'Frag/Train',
    'Scaf/Train',
    'SNN/Train',
    'IntDiv',
]

# Group plotted by the "metrics 3" cell.
metrics_3 = [
    'logP',
    'SA',
    'QED',
    'MolWt',
]

In [12]:
# metrics 1
# One 2x2 figure with a panel per metric in `metrics_1`; each panel plots the
# metric's values against `epochs`.

plt.close()  # close the currently active figure, if any, before making a new one

# Text styles shared by every panel (built once, not once per panel).
font_title = {'family': 'Times New Roman',
              'style': 'normal',
              'weight': 'bold',
              'color': 'black',
              'size': 17}
x_label = {'family': 'Times New Roman',
           'style': 'normal',
           'weight': 'normal',
           'color': 'black',
           'size': 12}

fig = plt.figure(figsize=(10, 8))
fig.subplots_adjust(wspace=0.2, hspace=0.4)  # spacing between the panels

for position, metric_name in enumerate(metrics_1, start=1):
    ax = fig.add_subplot(2, 2, position)
    ax.plot(epochs, metrics[metric_name], '-b', marker='.')

    # Panel title: the metric name.
    ax.set_title(metric_name, fontdict=font_title, verticalalignment='bottom')

    # Tick marks point inwards; tick labels use the same font family as the
    # title and axis label.
    ax.tick_params(axis='both', labelsize=12, width=1, length=4, direction='in')
    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_fontname('Times New Roman')

    # Axis label.
    ax.set_xlabel("Epochs", fontdict=x_label)

    # Grid lines.
    ax.grid(linewidth=0.5, linestyle='--')

# Saving is disabled for this figure; uncomment to export it.
# fig.savefig('metrics1_Valid_Unique_Novelty_Filter_ChEMBL.pdf')
# fig.savefig('metrics1_Valid_Unique_Novelty_Filter_ChEMBL.png', dpi=1000)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [13]:
# metrics 2
# One 2x2 figure with a panel per metric in `metrics_2`; each panel plots the
# metric's values against `epochs`.

plt.close()  # close the currently active figure, if any, before making a new one

# Text styles shared by every panel (built once, not once per panel).
font_title = {'family': 'Times New Roman',
              'style': 'normal',
              'weight': 'bold',
              'color': 'black',
              'size': 17}
x_label = {'family': 'Times New Roman',
           'style': 'normal',
           'weight': 'normal',
           'color': 'black',
           'size': 12}

fig = plt.figure(figsize=(10, 8))
fig.subplots_adjust(wspace=0.2, hspace=0.4)  # spacing between the panels

for position, metric_name in enumerate(metrics_2, start=1):
    ax = fig.add_subplot(2, 2, position)
    ax.plot(epochs, metrics[metric_name], '-b', marker='.')

    # Panel title: the metric name.
    ax.set_title(metric_name, fontdict=font_title, verticalalignment='bottom')

    # Tick marks point inwards; tick labels use the same font family as the
    # title and axis label.
    ax.tick_params(axis='both', labelsize=12, width=1, length=4, direction='in')
    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_fontname('Times New Roman')

    # Axis label.
    ax.set_xlabel("Epochs", fontdict=x_label)

    # Grid lines.
    ax.grid(linewidth=0.5, linestyle='--')

# Saving is disabled for this figure; uncomment to export it.
# fig.savefig('metrics2_Frag_Scaf_SNN_IntDiv_ChEMBL.pdf')
# fig.savefig('metrics2_Frag_Scaf_SNN_IntDiv_ChEMBL.png', dpi=1000)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [22]:
# metrics 3
# One 2x2 figure with a panel per metric in `metrics_3`; each panel plots the
# metric's values against `epochs`. The finished figure is written to disk as
# both PDF and PNG.

plt.close()  # close the currently active figure, if any, before making a new one

# Text styles shared by every panel (built once, not once per panel).
font_title = {'family': 'Times New Roman',
              'style': 'normal',
              'weight': 'bold',
              'color': 'black',
              'size': 17}
x_label = {'family': 'Times New Roman',
           'style': 'normal',
           'weight': 'normal',
           'color': 'black',
           'size': 12}

fig = plt.figure(figsize=(10, 8))
fig.subplots_adjust(wspace=0.2, hspace=0.4)  # spacing between the panels

for position, metric_name in enumerate(metrics_3, start=1):
    ax = fig.add_subplot(2, 2, position)
    ax.plot(epochs, metrics[metric_name], '-b', marker='.')

    # Panel title: the metric name.
    ax.set_title(metric_name, fontdict=font_title, verticalalignment='bottom')

    # Tick marks point inwards; tick labels use the same font family as the
    # title and axis label.
    ax.tick_params(axis='both', labelsize=12, width=1, length=4, direction='in')
    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_fontname('Times New Roman')

    # Axis label.
    ax.set_xlabel("Epochs", fontdict=x_label)

    # Grid lines.
    ax.grid(linewidth=0.5, linestyle='--')

# Save the figure. At figsize (10, 8) inches, dpi=1000 yields a
# 10000 x 8000 pixel PNG.
fig.savefig('metrics3_logP_SA_QED_MolWt_ChEMBL.pdf')
fig.savefig('metrics3_logP_SA_QED_MolWt_ChEMBL.png', dpi=1000)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …