Master in notebook

In [3]:
import pickle
import os
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem.Draw import rdMolDraw2D

import numpy as np
from functools import reduce
import operator
import matplotlib.pyplot as plt
import warnings

def load_data(pkl_path):
    with open(pkl_path, 'rb') as file:
        data = pickle.load(file)
    return data['results_dict_bl_ZINC']

def process_probabilities(smile, probabilities):
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        raise ValueError("Invalid SMILES string")

    num_atoms = mol.GetNumAtoms()
    atom_probabilities = []
    current_atom_prob = None
    non_atom_probs = []

    for char, prob in zip(smile, probabilities):
        if char.isalpha() or char == '[' or char == ']':  # Start of an atom symbol
            if current_atom_prob is not None:
                if non_atom_probs:
                    # Average the non-atom probabilities with the current and previous atom
                    avg_prob = (current_atom_prob + sum(non_atom_probs) + prob) / (len(non_atom_probs) + 2)
                    atom_probabilities.append(avg_prob)
                    non_atom_probs = []
                else:
                    atom_probabilities.append(current_atom_prob)
            current_atom_prob = prob
        else:  # Non-atom character
            non_atom_probs.append(prob)

    # Handle the last atom
    if current_atom_prob is not None:
        if non_atom_probs:
            # Average the non-atom probabilities with the last atom
            avg_prob = (current_atom_prob + sum(non_atom_probs)) / (len(non_atom_probs) + 1)
            atom_probabilities.append(avg_prob)
        else:
            atom_probabilities.append(current_atom_prob)

    # Ensure we have the correct number of probabilities
    if len(atom_probabilities) > num_atoms:
        atom_probabilities = atom_probabilities[:num_atoms]
    if len(atom_probabilities) != num_atoms:
        warnings.warn(f"Number of processed probabilities ({len(atom_probabilities)}) "
                      f"does not match number of atoms in molecule ({num_atoms}). "
                      "This may lead to incorrect coloring. "
                      "For now probabilities of 1 are added to keep the code working")
        atom_probabilities.extend([1.0] * (num_atoms - len(atom_probabilities)))

    return atom_probabilities

def get_color(prob):
    green_value = prob * 0.9
    red_value = 1.0 - prob * 0.40
    blue_value = 0.7
    return (float((red_value*255)/255), float((green_value*255)/255), float((blue_value*255)/255))
import xml.etree.ElementTree as ET
import re


def visualize_molecules(pkl_path, output_folder, num_molecules=5, generations_per_molecule=3):
    result = load_data(pkl_path)
    keys_list = list(result.keys())
    
    os.makedirs(output_folder, exist_ok=True)
    
    for index in range(min(num_molecules, len(keys_list))):
        target_smi = keys_list[index]
        
        # Target molecule
        smi = target_smi
        prob_list = [1 for _ in smi]
        label = "Target"
        svg = generate_colored_molecule(smi, prob_list, label)
        
        with open(os.path.join(output_folder, f"molecule_{index}_target.svg"), "w", encoding='utf-8') as f:
            f.write(svg)
        
        # Generated molecules
        for i in range(1, generations_per_molecule + 1):
            smi = result[target_smi][0][i-1][0]
            prob_list = result[target_smi][0][i-1][3].tolist()
            HSQC_err = result[target_smi][0][i-1][-2][0]
            COSY_err = result[target_smi][0][i-1][-2][1]
            tani_sim = result[target_smi][0][i-1][-3]
            prob_product = reduce(operator.mul, prob_list, 1)
            
            label = f"Generation: {i}\nProbability: {prob_product:.2e}\nHSQC Error: {HSQC_err:.3f}\nCOSY Error: {COSY_err:.3f}\nTanimoto Sim: {tani_sim:.3f}"
            
            svg = generate_colored_molecule(smi, prob_list, label)
            
            with open(os.path.join(output_folder, f"molecule_{index}_generation_{i}.svg"), "w", encoding='utf-8') as f:
                f.write(svg)

    print(f"Generated SVGs for {num_molecules} molecules with {generations_per_molecule} generations each.")
    
    
import re

def generate_colored_molecule(smile, probabilities, label):
    molecule = Chem.MolFromSmiles(smile)
    atom_probabilities = process_probabilities(smile, probabilities)
    
    # Increase the SVG size more significantly
    d = rdMolDraw2D.MolDraw2DSVG(300, 300)  # Keep molecule size at 300x300
    d.drawOptions().useBWAtomPalette()
    highlight_atom_colors = {i: get_color(prob) for i, prob in enumerate(atom_probabilities)}

    d.DrawMolecule(molecule, highlightAtoms=list(highlight_atom_colors.keys()), highlightAtomColors=highlight_atom_colors)
    d.FinishDrawing()
    
    svg = d.GetDrawingText()
    
    # Parse the SVG
    try:
        root = ET.fromstring(svg)
    except ET.ParseError:
        # If parsing fails, try to clean up the SVG
        svg = re.sub(r'</?svg[^>]*>', '', svg)  # Remove svg tags
        svg = f'<svg width="300" height="500">{svg}</svg>'  # Wrap in new svg tags with increased height
        root = ET.fromstring(svg)
    
    # Set or adjust the viewBox
    root.set('viewBox', '0 0 300 500')  # Set viewBox to match new dimensions
    root.set('width', '300')
    root.set('height', '500')
    
    # Add a white background rectangle
    background = ET.SubElement(root, 'rect', {
        'width': '300',
        'height': '500',
        'fill': 'white'
    })
    
    # Move the background to the back
    root.insert(0, background)
    
    # Add label
    text_element = ET.SubElement(root, 'text', {
        'x': '10',
        'y': '320',  # Start labels below the molecule
        'font-family': 'sans-serif',
        'font-size': '14px',  # Increased font size
        'fill': 'black'
    })
    
    y = 320
    for line in label.split('\n'):
        tspan = ET.SubElement(text_element, 'tspan', {
            'x': '10',
            'y': str(y)
        })
        tspan.text = line
        y += 20  # Increased line spacing
    
    # Convert back to string
    svg_str = ET.tostring(root, encoding='unicode')
    
    return svg_str



In [11]:

# Call this function to display the generated SVGs in your notebook
#display_molecules(output_folder)
pkl_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/_ISAK/Runfolder/pkl_files/20240731_223447_IC_exp_data.pkl'
output_folder = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20241022_Explainability_Model'
#visualize_molecules(pkl_path, output_folder)


Generated SVGs for 5 molecules with 3 generations each.


In [4]:
import xml.etree.ElementTree as ET
import re
# To display SVGs in the notebook
from IPython.display import SVG, display

def display_molecules(output_folder, num_molecules=5, generations_per_molecule=3):
    for index in range(num_molecules):
        display(SVG(filename=os.path.join(output_folder, f"molecule_{index}_target.svg")))
        for i in range(1, generations_per_molecule + 1):
            display(SVG(filename=os.path.join(output_folder, f"molecule_{index}_generation_{i}.svg")))


In [6]:
import pickle
import os
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem.Draw import rdMolDraw2D

import numpy as np
from functools import reduce
import operator
import matplotlib.pyplot as plt
import warnings

def load_data(pkl_path):
    with open(pkl_path, 'rb') as file:
        data = pickle.load(file)
    return data['results_dict_bl_ZINC']

def process_probabilities(smile, probabilities):
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        raise ValueError("Invalid SMILES string")

    num_atoms = mol.GetNumAtoms()
    atom_probabilities = []
    current_atom_prob = None
    non_atom_probs = []

    for char, prob in zip(smile, probabilities):
        if char.isalpha() or char == '[' or char == ']':  # Start of an atom symbol
            if current_atom_prob is not None:
                if non_atom_probs:
                    # Average the non-atom probabilities with the current and previous atom
                    avg_prob = (current_atom_prob + sum(non_atom_probs) + prob) / (len(non_atom_probs) + 2)
                    atom_probabilities.append(avg_prob)
                    non_atom_probs = []
                else:
                    atom_probabilities.append(current_atom_prob)
            current_atom_prob = prob
        else:  # Non-atom character
            non_atom_probs.append(prob)

    # Handle the last atom
    if current_atom_prob is not None:
        if non_atom_probs:
            # Average the non-atom probabilities with the last atom
            avg_prob = (current_atom_prob + sum(non_atom_probs)) / (len(non_atom_probs) + 1)
            atom_probabilities.append(avg_prob)
        else:
            atom_probabilities.append(current_atom_prob)

    # Ensure we have the correct number of probabilities
    if len(atom_probabilities) > num_atoms:
        atom_probabilities = atom_probabilities[:num_atoms]
    if len(atom_probabilities) != num_atoms:
        warnings.warn(f"Number of processed probabilities ({len(atom_probabilities)}) "
                      f"does not match number of atoms in molecule ({num_atoms}). "
                      "This may lead to incorrect coloring. "
                      "For now probabilities of 1 are added to keep the code working")
        atom_probabilities.extend([1.0] * (num_atoms - len(atom_probabilities)))

    return atom_probabilities

def get_color(prob):
    green_value = prob * 0.9
    red_value = 1.0 - prob * 0.40
    blue_value = 0.7
    return (float((red_value*255)/255), float((green_value*255)/255), float((blue_value*255)/255))

import xml.etree.ElementTree as ET
import re

import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from PIL import Image, ImageDraw, ImageFont
import matplotlib.font_manager as fm

def create_label_image(label, width=300, height=120):
    # Create a white image
    image = Image.new('RGB', (width, height), color='white')
    draw = ImageDraw.Draw(image)
    
    # Try to use a common sans-serif font
    try:
        # Look for Arial or Helvetica
        font_path = fm.findfont(fm.FontProperties(family='DejaVu Sans'))
        font = ImageFont.truetype(font_path, 14)  # Increased font size to 14
    except:
        # If not found, use default
        font = ImageFont.load_default()

    # Draw the text
    y_text = 5
    for line in label.split('\n'):
        draw.text((5, y_text), line, fill='black', font=font)
        y_text += 24  # Increased line spacing

    return image


def html_to_svg(html_string, width=600, height=30):
    # Extract the colored spans from the HTML string
    spans = re.findall(r'<span style="background-color: ([^;]+);[^>]+>([^<]+)</span>', html_string)
    
    # Create the SVG string
    svg = f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">'
    svg += '<style>text { font-family: monospace; font-size: 20px; }</style>'
    
    x = 5  # Starting x position
    for color, char in spans:
        svg += f'<rect x="{x}" y="5" width="20" height="20" fill="{color}"/>'
        svg += f'<text x="{x+5}" y="22" fill="black">{char}</text>'
        x += 20  # Move to the next character position
    
    svg += '</svg>'
    return svg

def generate_colored_html(smile, probabilities):
    html_string = '<span style="font-family: monospace;">'
    
    for char, prob in zip(smile, probabilities):
        color = get_color(prob)
        css_color = f'rgb({int(color[0]*255)}, {int(color[1]*255)}, {int(color[2]*255)})'
        html_string += f'<span style="background-color: {css_color}; padding: 1px; display: inline-block;">{char}</span>'
    
    html_string += '</span>'
    return html_string

def visualize_molecules(pkl_path, output_folder, num_molecules=5, generations_per_molecule=3):
    result = load_data(pkl_path)
    keys_list = list(result.keys())
    
    os.makedirs(output_folder, exist_ok=True)
    
    for index in range(min(num_molecules, len(keys_list))):
        target_smi = keys_list[index]
        
        # Target molecule
        smi = target_smi
        prob_list = [1 for _ in smi]
        label = "Target"
        svg = generate_colored_molecule(smi, prob_list, "")  # Empty label for SVG
        
        svg_filename = os.path.join(output_folder, f"molecule_{index}_target.svg")
        png_filename = os.path.join(output_folder, f"molecule_{index}_target_label.png")
        str_svg_filename = os.path.join(output_folder, f"molecule_{index}_target_str.svg")
        
        with open(svg_filename, "w", encoding='utf-8') as f:
            f.write(svg)
        
        # Create and save label image
        label_image = create_label_image(label)
        label_image.save(png_filename)
        
        # Generate and save colored SMILES string as SVG
        colored_html = generate_colored_html(smi, prob_list)
        colored_svg = html_to_svg(colored_html, width=len(smi)*20+10)  # Adjust width based on SMILES length
        with open(str_svg_filename, "w", encoding='utf-8') as f:
            f.write(colored_svg)
        
        # Generated molecules
        for i in range(1, generations_per_molecule + 1):
            smi = result[target_smi][0][i-1][0]
            prob_list = result[target_smi][0][i-1][3].tolist()
            HSQC_err = result[target_smi][0][i-1][-2][0]
            COSY_err = result[target_smi][0][i-1][-2][1]
            tani_sim = result[target_smi][0][i-1][-3]
            prob_product = reduce(operator.mul, prob_list, 1)
            
            label = f"Generation: {i}\nProbability: {prob_product:.2e}\nHSQC Error: {HSQC_err:.3f}\nCOSY Error: {COSY_err:.3f}\nTanimoto Sim: {tani_sim:.3f}"
            
            svg = generate_colored_molecule(smi, prob_list, "")  # Empty label for SVG
            
            svg_filename = os.path.join(output_folder, f"molecule_{index}_generation_{i}.svg")
            png_filename = os.path.join(output_folder, f"molecule_{index}_generation_{i}_label.png")
            str_svg_filename = os.path.join(output_folder, f"molecule_{index}_generation_{i}_str.svg")
            
            with open(svg_filename, "w", encoding='utf-8') as f:
                f.write(svg)
            
            # Create and save label image
            label_image = create_label_image(label)
            label_image.save(png_filename)
            
            # Generate and save colored SMILES string as SVG
            colored_html = generate_colored_html(smi, prob_list)
            colored_svg = html_to_svg(colored_html, width=len(smi)*20+10)  # Adjust width based on SMILES length
            with open(str_svg_filename, "w", encoding='utf-8') as f:
                f.write(colored_svg)

    print(f"Generated SVGs, label PNGs, and colored SMILES SVGs for {num_molecules} molecules with {generations_per_molecule} generations each.")
    
    
def generate_colored_molecule(smile, probabilities, label):
    molecule = Chem.MolFromSmiles(smile)
    atom_probabilities = process_probabilities(smile, probabilities)
    
    # Increase the SVG size more significantly
    d = rdMolDraw2D.MolDraw2DSVG(300, 500)  # Increase height to accommodate labels
    d.drawOptions().useBWAtomPalette()
    highlight_atom_colors = {i: get_color(prob) for i, prob in enumerate(atom_probabilities)}

    d.DrawMolecule(molecule, highlightAtoms=list(highlight_atom_colors.keys()), highlightAtomColors=highlight_atom_colors)
    d.FinishDrawing()
    
    svg = d.GetDrawingText()
    
    # Parse the SVG
    try:
        root = ET.fromstring(svg)
    except ET.ParseError:
        # If parsing fails, try to clean up the SVG
        svg = re.sub(r'</?svg[^>]*>', '', svg)  # Remove svg tags
        svg = f'<svg width="300" height="500">{svg}</svg>'  # Wrap in new svg tags with increased height
        root = ET.fromstring(svg)
    
    # Set or adjust the viewBox
    root.set('viewBox', '0 0 300 500')  # Set viewBox to match new dimensions
    root.set('width', '300')
    root.set('height', '500')
    
    # Add a white background rectangle
    background = ET.SubElement(root, 'rect', {
        'width': '300',
        'height': '500',
        'fill': 'white'
    })
    
    # Move the background to the back
    root.insert(0, background)
    
    # Add label
    text_element = ET.SubElement(root, 'text', {
        'x': '10',
        'y': '320',  # Start labels below the molecule
        'font-family': 'sans-serif',
        'font-size': '14px',  # Increased font size
        'fill': 'black'
    })
    
    y = 340  # Start labels a bit lower
    for line in label.split('\n'):
        tspan = ET.SubElement(text_element, 'tspan', {
            'x': '10',
            'y': str(y)
        })
        tspan.text = line
        y += 20  # Increased line spacing
    
    # Convert back to string
    svg_str = ET.tostring(root, encoding='unicode')
    
    return svg_str

# Call this function to display the generated SVGs in your notebook
# display_molecules(output_folder)
pkl_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/_ISAK/Runfolder/pkl_files/20240731_223447_IC_exp_data.pkl'
output_folder = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20241022_Explainability_Model_2'
visualize_molecules(pkl_path, output_folder)


Generated SVGs, label PNGs, and colored SMILES SVGs for 5 molecules with 3 generations each.


### Version with alignment 
- Working - just replace for pkl path from the following folder for the exp, ACD and Sim data

- /projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240724_IC_batched_MMT_a1

In [30]:
import pickle
import os
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem.Draw import rdMolDraw2D

import numpy as np
from functools import reduce
import operator
import matplotlib.pyplot as plt
import warnings
import xml.etree.ElementTree as ET
import re

# ... [Keep all the existing import statements and helper functions] ...

def get_color(prob):
    green_value = prob * 0.9
    red_value = 1.0 - prob * 0.40
    blue_value = 0.7
    return (float((red_value*255)/255), float((green_value*255)/255), float((blue_value*255)/255))
import xml.etree.ElementTree as ET
import re


def generate_colored_molecule(smile, probabilities, label):
    # Canonicalize SMILES
    mol = Chem.MolFromSmiles(smile)
    canonical_smiles = Chem.CanonSmiles(smile)
    molecule = Chem.MolFromSmiles(canonical_smiles)
    
    # Recompute probabilities for the canonical SMILES
    canonical_probabilities = process_probabilities(canonical_smiles, probabilities)
    
    # Generate 2D coordinates
    rdDepictor.Compute2DCoords(molecule)
    
    # Increase the SVG size more significantly
    d = rdMolDraw2D.MolDraw2DSVG(300, 500)  # Increase height to accommodate labels
    d.drawOptions().useBWAtomPalette()
    highlight_atom_colors = {i: get_color(prob) for i, prob in enumerate(canonical_probabilities)}

    d.DrawMolecule(molecule, highlightAtoms=list(highlight_atom_colors.keys()), highlightAtomColors=highlight_atom_colors)
    d.FinishDrawing()
    
    svg = d.GetDrawingText()
    
    # Parse the SVG
    try:
        root = ET.fromstring(svg)
    except ET.ParseError:
        # If parsing fails, try to clean up the SVG
        svg = re.sub(r'</?svg[^>]*>', '', svg)  # Remove svg tags
        svg = f'<svg width="300" height="500">{svg}</svg>'  # Wrap in new svg tags with increased height
        root = ET.fromstring(svg)
    
    # Set or adjust the viewBox
    root.set('viewBox', '0 0 300 500')  # Set viewBox to match new dimensions
    root.set('width', '300')
    root.set('height', '500')
    
    # Add a white background rectangle
    background = ET.SubElement(root, 'rect', {
        'width': '300',
        'height': '500',
        'fill': 'white'
    })
    
    # Move the background to the back
    root.insert(0, background)
    
    # Add label
    text_element = ET.SubElement(root, 'text', {
        'x': '10',
        'y': '320',  # Start labels below the molecule
        'font-family': 'sans-serif',
        'font-size': '14px',  # Increased font size
        'fill': 'black'
    })
    
    y = 340  # Start labels a bit lower
    for line in label.split('\n'):
        tspan = ET.SubElement(text_element, 'tspan', {
            'x': '10',
            'y': str(y)
        })
        tspan.text = line
        y += 20  # Increased line spacing
    
    # Convert back to string
    svg_str = ET.tostring(root, encoding='unicode')
    
    return svg_str, canonical_smiles


def generate_colored_html(smile, probabilities):
    html_string = '<span style="font-family: monospace;">'
    
    for char, prob in zip(smile, probabilities):
        color = get_color(prob)
        css_color = f'rgb({int(color[0]*255)}, {int(color[1]*255)}, {int(color[2]*255)})'
        html_string += f'<span style="background-color: {css_color}; padding: 1px; display: inline-block;">{char}</span>'
    
    html_string += '</span>'
    return html_string



def html_to_svg(html_string, width=600, height=30):
    # Extract the colored spans from the HTML string
    spans = re.findall(r'<span style="background-color: ([^;]+);[^>]+>([^<]+)</span>', html_string)
    
    # Create the SVG string
    svg = f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">'
    svg += '<style>text { font-family: monospace; font-size: 20px; }</style>'
    
    x = 5  # Starting x position
    for color, char in spans:
        svg += f'<rect x="{x}" y="5" width="20" height="20" fill="{color}"/>'
        svg += f'<text x="{x+5}" y="22" fill="black">{char}</text>'
        x += 20  # Move to the next character position
    
    svg += '</svg>'
    return svg

def visualize_molecules(pkl_path, output_folder, num_molecules=5, generations_per_molecule=8):
    result = load_data(pkl_path)
    keys_list = list(result.keys())
    
    os.makedirs(output_folder, exist_ok=True)
    
    for index in range(min(num_molecules, len(keys_list))):
        target_smi = keys_list[index]
        
        # Target molecule
        smi = target_smi
        prob_list = [1 for _ in smi]
        label = "Target"
        #import IPython; IPython.embed();
        
        svg, canonical_smi = generate_colored_molecule(smi, prob_list, "")  # Empty label for SVG
        
        svg_filename = os.path.join(output_folder, f"molecule_{index}_target.svg")
        png_filename = os.path.join(output_folder, f"molecule_{index}_target_label.png")
        str_svg_filename = os.path.join(output_folder, f"molecule_{index}_target_str.svg")
        
        with open(svg_filename, "w", encoding='utf-8') as f:
            f.write(svg)
        
        # Create and save label image
        label_image = create_label_image(label)
        label_image.save(png_filename)
        
        # Generate and save colored SMILES string as SVG

        colored_html = generate_colored_html(canonical_smi, prob_list)
        colored_svg = html_to_svg(colored_html, width=len(canonical_smi)*20+10)  # Adjust width based on SMILES length
        with open(str_svg_filename, "w", encoding='utf-8') as f:
            f.write(colored_svg)
        
        # Generated molecules
        for i in range(1, generations_per_molecule + 1):
            smi = result[target_smi][0][i-1][0]
            #import IPython; IPython.embed();
            
            prob_list = result[target_smi][0][i-1][3].tolist()
            HSQC_err = result[target_smi][0][i-1][5][0]
            COSY_err = result[target_smi][0][i-1][5][1]
            tani_sim = result[target_smi][0][i-1][4]
            prob_product = reduce(operator.mul, prob_list, 1)
            
            label = f"Generation: {i}\nProbability: {prob_product:.2e}\nHSQC Error: {HSQC_err:.3f}\nCOSY Error: {COSY_err:.3f}\nTanimoto Sim: {tani_sim:.3f}"

            svg, canonical_smi = generate_colored_molecule(smi, prob_list, "")  # Empty label for SVG
            
            svg_filename = os.path.join(output_folder, f"molecule_{index}_generation_{i}.svg")
            png_filename = os.path.join(output_folder, f"molecule_{index}_generation_{i}_label.png")
            str_svg_filename = os.path.join(output_folder, f"molecule_{index}_generation_{i}_str.svg")
            
            with open(svg_filename, "w", encoding='utf-8') as f:
                f.write(svg)
            
            # Create and save label image
            label_image = create_label_image(label)
            label_image.save(png_filename)
            
            # Generate and save colored SMILES string as SVG
            colored_html = generate_colored_html(canonical_smi, prob_list)
            colored_svg = html_to_svg(colored_html, width=len(canonical_smi)*20+10)  # Adjust width based on SMILES length
            with open(str_svg_filename, "w", encoding='utf-8') as f:
                f.write(colored_svg)

    print(f"Generated SVGs, label PNGs, and colored SMILES SVGs for {num_molecules} molecules with {generations_per_molecule} generations each.")

    
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import rdDepictor
import rdkit.Chem.Draw as rdDraw
import rdkit.Chem.AllChem as AllChem
import os
import pickle

def generate_noncolored_molecule_svgs(pkl_path, output_folder):
    # Create output directory if it doesn't exist
    os.makedirs(output_folder, exist_ok=True)
    
    # Load pickle file
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)
    
            
        # Get SMILES from trg_conv_SMI_list
        smiles_list = data['results_dict_ZINC_greedy_bl'].get('trg_conv_SMI_list', [])
        
        for smile_idx, smiles in enumerate(smiles_list):
            #try:
                # Create molecule from SMILES
                mol = Chem.MolFromSmiles(smiles)
                if mol is None:
                    print(f"Could not create molecule from SMILES: {smiles}")
                    continue
                
                # Generate 2D coordinates with the same parameters as colored version
                rdDepictor.Compute2DCoords(mol)
                
                # Create drawer with the same dimensions as in the colored version
                d = rdDraw.rdMolDraw2D.MolDraw2DSVG(300, 500)
                
                # Draw molecule
                d.drawOptions().useBWAtomPalette()
                d.DrawMolecule(mol)
                d.FinishDrawing()
                
                # Get SVG text
                svg = d.GetDrawingText()
                
                # Add white background to match the style of colored versions
                svg = svg.replace('<svg ', '<svg style="background-color: white" ', 1)
                
                # Save SVG file
                output_path = os.path.join(output_folder, f"molecule_{idx}_target_noncolored.svg")
                with open(output_path, 'w') as f:
                    f.write(svg)
                
                print(f"Generated SVG for molecule {idx} at {output_path}")
                
            #except Exception as e:
            #    print(f"Error processing SMILES {smiles}: {str(e)}")
    print(output_path)
    print("Finished generating non-colored SVGs")

# Example usage:
# generate_noncolored_molecule_svgs(pkl_path, output_folder)    
    
    
# The rest of the code remains the same
pkl_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240724_IC_batched_MMT_a1/20240731_115505_exp_sim_data_11_1.pkl'
#pkl_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240724_IC_batched_MMT_a1/20240731_104738_exp_sim_data_9_1.pkl'
#pkl_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240724_IC_batched_MMT_a1/20240731_100214_exp_sim_data_8_0.pkl'
#pkl_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240724_IC_batched_MMT_a1/20240731_084643_exp_sim_data_6_0.pkl'
#pkl_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240724_IC_batched_MMT_a1/20240731_082338_exp_sim_data_5_2.pkl'
pkl_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240724_IC_batched_MMT_a1/20240731_092426_exp_sim_data_7_0.pkl"
output_folder = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20241022_Explainability_Model_aligned_mol_7_0'
visualize_molecules(pkl_path, output_folder)
generate_noncolored_molecule_svgs(pkl_path, output_folder)


Generated SVGs, label PNGs, and colored SMILES SVGs for 5 molecules with 8 generations each.
Generated SVG for molecule 8 at /projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20241022_Explainability_Model_aligned_mol_7_0/molecule_8_target_noncolored.svg
/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20241022_Explainability_Model_aligned_mol_7_0/molecule_8_target_noncolored.svg
Finished generating non-colored SVGs


### Test functions

In [33]:
pkl_path = '/projects/cc/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240724_IC_batched_MMT_a1/20240731_084151_sim_sim_data_6_0.pkl'


In [34]:
result = load_data(pkl_path)


In [35]:
result

{'CCN(CC)CCNC(=O)c1cc(Br)c(N)cc1OC': [[['CCN(CC)CCNC(=O)c1cc(Br)c(N)cc1OC',
    2.760256767272949,
    45.997840881347656,
    tensor([0.9973, 0.9895, 0.9895, 0.9999, 0.9952, 0.9925, 0.9966, 0.9990, 0.9855,
            0.9772, 0.9993, 0.9990, 0.9991, 1.0000, 0.9992, 1.0000, 0.9906, 0.9955,
            0.9984, 0.0464, 1.0000, 0.9974, 0.9576, 0.9407, 0.9564, 0.9992, 0.9567,
            1.0000, 0.9579, 0.9966, 0.9980], device='cuda:0'),
    1.0,
    [0.0019983669366186213, 0.0016715491873415774]],
   ['CCN(CC)CCNC(=O)c1cc(N)c(Br)cc1OC',
    2.8225786685943604,
    46.96356964111328,
    tensor([0.9973, 0.9895, 0.9895, 0.9999, 0.9952, 0.9925, 0.9966, 0.9990, 0.9855,
            0.9772, 0.9993, 0.9990, 0.9991, 1.0000, 0.9992, 1.0000, 0.9906, 0.9955,
            0.9984, 0.8611, 0.8570, 0.9979, 0.9771, 0.0177, 1.0000, 0.9996, 0.9518,
            1.0000, 0.9553, 0.9967, 0.9982], device='cuda:0'),
    0.9069767441860465,
    [0.02128619997786351, 0.006526293427044074]],
   ['CCN(CC)CCNC(=O)c1cc

In [17]:
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem import rdDepictor
import os
from PIL import Image, ImageDraw, ImageFont
import matplotlib.font_manager as fm

def create_dummy_molecules():
    smiles_list = [
        'CCc1ccc(CC(=O)Nc2ccccc2)cc1',
        'CCc1ccc(CC(=O)Nc2cccc(Cl)c2)cc1',
        'CCc1ccc(CC(=O)Nc2ccc(F)cc2)cc1'
    ]
    mols = [Chem.MolFromSmiles(smiles) for smiles in smiles_list]
    return mols

def align_molecules(mols):
    aligned_mols = []
    for mol in mols:
        # Generate a canonical SMILES
        can_smiles = Chem.CanonSmiles(Chem.MolToSmiles(mol))
        # Create a new molecule from the canonical SMILES
        aligned_mol = Chem.MolFromSmiles(can_smiles)
        # Generate 2D coordinates
        rdDepictor.Compute2DCoords(aligned_mol)
        aligned_mols.append(aligned_mol)
    return aligned_mols

def create_label_image(label, width=300, height=50):
    image = Image.new('RGB', (width, height), color='white')
    draw = ImageDraw.Draw(image)
    
    try:
        font_path = fm.findfont(fm.FontProperties(family='DejaVu Sans'))
        font = ImageFont.truetype(font_path, 14)
    except:
        font = ImageFont.load_default()

    draw.text((5, 5), label, fill='black', font=font)

    return image

def save_molecule_images(mols, output_folder, aligned=False):
    os.makedirs(output_folder, exist_ok=True)
    
    for i, mol in enumerate(mols):
        # Generate and save molecule image
        img = Draw.MolToImage(mol, size=(300, 300))
        img_filename = os.path.join(output_folder, f"molecule_{i}_{'aligned' if aligned else 'original'}.png")
        img.save(img_filename)
        
        # Create and save label image
        label = f"{'Aligned' if aligned else 'Original'} Molecule {i+1}"
        label_image = create_label_image(label)
        label_filename = os.path.join(output_folder, f"molecule_{i}_{'aligned' if aligned else 'original'}_label.png")
        label_image.save(label_filename)

# Main execution
output_folder = '/projects/cc/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20241014_Explainability_Model'
dummy_mols = create_dummy_molecules()
save_molecule_images(dummy_mols, output_folder, aligned=False)

aligned_mols = align_molecules(dummy_mols)
save_molecule_images(aligned_mols, output_folder, aligned=True)

print(f"Images saved in the '{output_folder}' directory.")

Images saved in the '/projects/cc/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20241014_Explainability_Model' directory.


In [15]:
!pwd

/projects/cc/knlr326/1_NMR_project/2_Notebooks/MultiModalTransformer
