In [None]:
# Figure 1a: Simulated intensity-normalized spectra
import os
import sys; sys.path.append('../../')
from utils.dataset import MyDataset
from utils.tokenizer import Tokenizer
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FixedLocator

max_length = 23
DATA_DIR = '../../data/'

# Tokenizer built from the saved vocabulary; used to build the dataset and to decode targets.
tokenizer = Tokenizer().load_vocab(os.path.join(DATA_DIR, 'vocab.json'))

# Seed numpy's global RNG before building the dataset so the panel selection is reproducible.
np.random.seed(7)
test_dataset = MyDataset(os.path.join(DATA_DIR, 'split.csv'), 'test', tokenizer, max_length)

# Four distinct test-set samples, one per panel.
indices = np.random.choice(len(test_dataset), 4, replace=False)

# Tick positions are in spectrum-index units; labels display index * 2.
# NOTE(review): the factor 2 presumably reflects the spectral sampling step — confirm.
x_ticks = np.arange(0, 2001, 500)
x_labels = [f"{int(round(value * 2))}" for value in x_ticks]

fig, axes = plt.subplots(1, 4, figsize=(11, 2.5))

for panel_no, (ax, idx) in enumerate(zip(axes.ravel(), indices)):
    spectrum, smiles = test_dataset[idx]
    intensities = spectrum.squeeze()
    ax.plot(range(len(intensities)), intensities)

    # Title each panel with the decoded target SMILES (also echoed to stdout).
    decoded_smiles = tokenizer.decode(smiles)
    ax.set_title(f'{decoded_smiles}')
    print(decoded_smiles)

    if panel_no > 0:
        # Only the left-most panel carries axis labels and y tick marks.
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.yaxis.set_major_locator(FixedLocator([]))
    else:
        ax.set_xlabel('Wavenumber')
        ax.set_ylabel('Intensity')

    # Both axes are flipped for every panel.
    ax.invert_yaxis()
    ax.invert_xaxis()
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_labels)

plt.tight_layout()
plt.subplots_adjust(top=0.88)
plt.show()


In [None]:
# Figure 3a: Functional groups in dataset
from tqdm import tqdm
from collections import Counter
import pandas as pd
import matplotlib.pyplot as plt
from rdkit import Chem

# Functional groups as SMARTS patterns; each key becomes a bar label in the figure.
functional_groups = {
    "Alcohol": "[OX2H][CX4;!$(C([OX2H])[O,S,#7,#15])]",
    "Aldehyde": "[CX3H1](=O)[#6]",
    "Carboxylic Acid": "[CX3](=O)[OX2H1]",
    "Ester": "[#6][CX3](=O)[OX2H0][#6]",
    "Ether": "[OD2]([#6])[#6]",
    "Ketone": "[#6][CX3](=O)[#6]",
    "Alkene": "[CX3]=[CX3]",
    "Alkyne": "[$([CX2]#C)]",
    "Benzene": "c1ccccc1",
    "Primary Amine": "[NX3;H2;!$(NC=[!#6]);!$(NC#[!#6])][#6]",
    "Amide": "[NX3][CX3](=[OX1])[#6]",
    "Fluorine": "[#6][F]",
    "Chlorine": "[#6][Cl]",
    "Iodine": "[#6][I]",
    "Bromine": "[#6][Br]",
    "Sulfide": "[#16X2H0]"
}

# Compile every SMARTS once, not once per molecule. Chem.MolFromSmarts returns None for an
# unparsable pattern, which would otherwise surface later as an opaque HasSubstructMatch error.
fg_patterns = {}
for fg, smarts in functional_groups.items():
    pattern = Chem.MolFromSmarts(smarts)
    if pattern is None:
        raise ValueError(f"Invalid SMARTS for functional group {fg!r}: {smarts}")
    fg_patterns[fg] = pattern

fg_counts = Counter()
df = pd.read_csv("../../data/split.csv")
smiles_list = df['smiles'].tolist()
total_smiles = len(smiles_list)

for smiles in tqdm(smiles_list):
    mol = Chem.MolFromSmiles(smiles)
    # Unparsable SMILES are skipped here but still count toward total_smiles.
    if mol:
        for fg, pattern in fg_patterns.items():
            if mol.HasSubstructMatch(pattern):
                fg_counts[fg] += 1

# Percentage of all CSV rows containing each group. Groups that never matched are absent
# from fg_counts and therefore from the plot.
fg_percentage = {fg: (count / total_smiles) * 100 for fg, count in fg_counts.items()}
sorted_groups = sorted(fg_percentage, key=fg_percentage.get, reverse=True)
sorted_percentages = [fg_percentage[group] for group in sorted_groups]

plt.figure(figsize=(11, 6))
plt.bar(sorted_groups, sorted_percentages, color = 'gray')

plt.ylabel('Percentage (%)', fontsize=18)
plt.xticks(rotation=45, ha="right", fontsize=18)
plt.yticks(fontsize=16)
# Keep the 0-50 % axis by default, but grow it so no bar above 50 % is clipped.
plt.ylim(0, max(50, max(sorted_percentages, default=0) * 1.05))

for side in ('top', 'right'):
    plt.gca().spines[side].set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# Figure 3b: SMILES string lengths in dataset
import pandas as pd
import matplotlib.pyplot as plt

data = pd.read_csv("../../data/split.csv")
# Character length of each SMILES string (multi-character atoms like Cl/Br count as 2).
data['smiles_length'] = data['smiles'].apply(len)

# Integer-width bins with edges 1, 2, ..., longest + 1. numpy closes the LAST bin on the
# right, so the edges must extend one past the longest length. Otherwise the two largest
# lengths share a bar, and anything beyond the final edge is silently dropped.
longest = int(data['smiles_length'].max())
bins = range(1, longest + 2)

# Plot histogram; align='left' centres each bar on its integer length.
plt.figure(figsize=(8, 6))
plt.hist(data['smiles_length'], bins=bins, edgecolor = 'black', align='left')
plt.xlabel('Length of SMILES String', fontsize = 28)
plt.ylabel('Frequency', fontsize = 28)

ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Label every fifth length.
ticks = [tick for tick in bins if tick % 5 == 0]
plt.xticks(ticks)
plt.tick_params(axis='both', which='major', labelsize=24)
plt.tight_layout()
plt.show()

In [None]:
# Figure 3c: Molecular properties in dataset
import pandas as pd
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import Descriptors, Lipinski

df = pd.read_csv("../../data/split.csv")
smiles_list = df['smiles'].tolist()

# Parse every SMILES once and keep only molecules RDKit could build. Each property list
# below is index-aligned with `mols`.
mols = [mol for mol in map(Chem.MolFromSmiles, smiles_list) if mol]

molweight = [Descriptors.MolWt(mol) for mol in mols]
logp = [Descriptors.MolLogP(mol) for mol in mols]
tpsa = [Descriptors.TPSA(mol) for mol in mols]
numHDonors = [Lipinski.NumHDonors(mol) for mol in mols]
numHAcceptors = [Lipinski.NumHAcceptors(mol) for mol in mols]

# One row of 50-bin histograms: molecular weight, LogP, TPSA (y label on the first only).
fig1, ax1 = plt.subplots(nrows=1, ncols=3, figsize=(12, 3))
for panel, values, label in zip(ax1, (molweight, logp, tpsa), ('MolWT', 'LogP', 'TPSA')):
    panel.hist(values, bins=50, color='lightgreen', edgecolor='black')
    panel.set_xlabel(label, fontsize = 14)
    for side in ('top', 'right'):
        panel.spines[side].set_visible(False)
    panel.tick_params(axis='both', which='major', labelsize=12)
ax1[0].set_ylabel('Frequency', fontsize = 14)

plt.tight_layout()
plt.show()

# Horizontal box plots of H-bond donor / acceptor counts, outliers hidden, medians in red.
fig, ax = plt.subplots(figsize=(5, 4))
data = [numHDonors, numHAcceptors]
bp = ax.boxplot(data, showfliers=False, patch_artist=True, vert=False)
for box in bp['boxes']:
    box.set_facecolor('lightgreen')
for median_line in bp['medians']:
    median_line.set_color('red')

ax.set_yticks([1, 2])
ax.set_xticks(list(range(6)))

# Generic label size first; the y tick labels are then set explicitly at 16 pt.
ax.tick_params(axis='both', which='major', labelsize=18)
ax.set_yticklabels(['HBD', 'HBA'], fontsize=16)
ax.set_xlabel('Number', fontsize=20)

for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

plt.tight_layout()
plt.show()