# MultiModalSpectralTransformer


In [None]:
# Shell escape: run nvidia-smi to check which GPUs are visible on this node.
!nvidia-smi

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Pastel color palette hex codes (one color per bar in the demo plot below)
pastel_colors = [
    '#A1C8F3',  # light blue/periwinkle
    '#FFB381',  # salmon/peach
    '#8BE5A0',  # mint green
    '#FF9D9A',  # coral pink
    '#D1B9FE',  # lavender
    '#DEBA9A',  # beige/tan
    '#FCAEE3',  # pink
    '#CFCECE',  # light gray
    '#FEFDA2',  # pale yellow
    '#B8F1EF',  # light cyan/aqua
]

# Create sample data. A seeded generator keeps the demo figure identical on every re-run.
demo_rng = np.random.default_rng(seed=0)
categories = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J']
values = demo_rng.integers(10, 50, size=len(categories))  # integers in [10, 50)

# Create the bar plot on an explicit Figure/Axes pair
fig, ax = plt.subplots(figsize=(12, 6))
bars = ax.bar(categories, values)

# Set each bar's color according to the pastel palette (zip stops at the shorter list)
for bar, color in zip(bars, pastel_colors):
    bar.set_color(color)

# Customize the plot
ax.set_title('Sample Bar Plot with Pastel Colors', pad=20, size=14)
ax.set_xlabel('Categories', labelpad=10)
ax.set_ylabel('Values', labelpad=10)
ax.grid(axis='y', linestyle='--', alpha=0.7)

# Add value labels centered on top of each bar
for bar in bars:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width() / 2., height, f'{int(height)}',
            ha='center', va='bottom')

fig.tight_layout()
plt.show()

### Config

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
# Core libraries
import json
import os
import random
import glob
import pickle
from datetime import datetime
import tempfile
import copy
import statistics

# Data processing and scientific computing
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

# Machine learning and data visualization
import matplotlib.pyplot as plt
import umap
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# PyTorch for deep learning
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

# RDKit for cheminformatics
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Draw, MolFromSmiles, MolToSmiles
from rdkit.Chem import Descriptors

# tqdm for progress bars
from tqdm.autonotebook import tqdm

# Weights & Biases for experiment tracking
import wandb

# Miscellaneous
from argparse import Namespace
from IPython.display import HTML, SVG

# Setting up environment: report how many CUDA devices this kernel can see.
# A bare expression is only echoed when it is the last line of a cell, so print it.
print(f"Visible CUDA devices: {torch.cuda.device_count()}")
# Authenticate with Weights & Biases for experiment tracking.
wandb.login()
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Descriptors import MolWt
from rdkit import DataStructs
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
from sklearn.manifold import TSNE
# Keep the notebook-aware progress bar imported earlier in this cell; a plain
# `from tqdm import tqdm` here would silently replace it with the console version.
from tqdm.autonotebook import tqdm


In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.profiler import SimpleProfiler, AdvancedProfiler
from pytorch_lightning.loggers import WandbLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from torch.utils.data.distributed import DistributedSampler
import os
import random

import pytorch_lightning as pl
from pytorch_lightning.loggers.wandb import WandbLogger

In [None]:
import utils_MMT.clip_functions_v15_4 as cf #
import utils_MMT.MT_functions_v15_4 as mtf # is different compared to V14_1
import utils_MMT.validate_generate_MMT_v15_4 as vgmmt #
import utils_MMT.run_batch_gen_val_MMT_v15_4 as rbgvm #
import utils_MMT.clustering_visualization_v15_4 as cv #
import utils_MMT.plotting_v15_4 as pt #
import utils_MMT.execution_function_v15_4 as ex #
import utils_MMT.train_test_functions_pl_v15_4 as ttf
import utils_MMT.ir_simulation_v15_4 as irs
import utils_MMT.helper_functions_pl_v15_4 as hf
import utils_MMT.mmt_result_test_functions_15_4 as mrtf
###
import utils_MMT.experiment_function_v15_4 as exp_func

In [None]:

def load_json_dics(base_dir='.'):
    """Load the four JSON vocabulary files used by the models.

    Reads ``itos.json``, ``stoi.json``, ``stoi_MF.json`` and ``itos_MF.json``
    from ``base_dir`` (defaults to the current working directory).

    Args:
        base_dir: directory containing the four JSON files.

    Returns:
        Tuple ``(itos, stoi, stoi_MF, itos_MF)`` of the parsed JSON objects,
        in that order.

    Raises:
        FileNotFoundError: if any of the files is missing.
    """
    def _read(file_name):
        # Read one JSON file from base_dir.
        with open(os.path.join(base_dir, file_name), 'r') as f:
            return json.load(f)

    # Order matters: callers unpack as itos, stoi, stoi_MF, itos_MF.
    return tuple(_read(name) for name in ('itos.json', 'stoi.json', 'stoi_MF.json', 'itos_MF.json'))
    
# Load the four JSON vocabularies (itos, stoi, stoi_MF, itos_MF) from the working directory.
itos, stoi, stoi_MF, itos_MF = load_json_dics()
# Random number (as a string) recorded in the run configuration under "ran_num".
rand_num = f"{random.randint(1, 10_000_000)}"

In [None]:
# Settings for the chemprop-IR prediction step. They are saved to JSON, reloaded and
# handed to irs.modify_predict_args. Every value is a one-element list because
# parse_arguments() keeps only element [0] of each entry.
IR_MODELS_DATA_DIR = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MultiModalSpectralTransformer/chemprop-IR/ir_models_data"

IR_config_dict = {
    # GPU index: 0 when at least one CUDA device is visible, otherwise None (default).
    # Always exactly one element, so parse_arguments() also works on CPU-only machines.
    "gpu": [0 if torch.cuda.device_count() > 0 else None],
    "test_path": [f"{IR_MODELS_DATA_DIR}/solvation_example/solvation_spectra.csv"],  # Default value is None
    "use_compound_names": [False],  # Default is False
    "preds_path": [f"{IR_MODELS_DATA_DIR}/ir_preds_test_2.csv"],  # Default value is None
    "checkpoint_dir": [f"{IR_MODELS_DATA_DIR}/experiment_model/model_files"],  # Default value is None
    "spectra_type": ["experimental"],  # ["experimental", "simulated"] Default value is None
    "spectra_type_nr": [0],  # 0-4 Default value is None
    "checkpoint_path": [None],  # Default value is None
    "batch_size": [64],  # Default is 50
    "no_cuda": [False],  # Default is False
    "features_generator": [None],  # Default value is None, should be one of the available features generators
    "features_path": [None],  # Default value is None
    "max_data_size": [100],  # Default value is None
    "ensemble_variance": [False],  # Default is False
    "ensemble_variance_conv": [0.0],  # Default is 0.0
}

In [None]:
# Main run configuration. Every value is a one-element list: parse_arguments()
# keeps only element [0] of each entry when building the config Namespace.
hyperparameters = {
    # General project information
    "project": ["Improv_Cycle_v1"],  # Name of the project for wandb monitoring
    "ran_num":[rand_num], # random number string (rand_num) drawn when the vocabularies are loaded
    "device": ["cuda"], # device on which training takes place
    "gpu_num":[1], # number of GPUs for training with pytorch lightning
    "num_workers":[4], # presumably the DataLoader worker count. NOTE(review): an earlier note said this "needs to stay 1 otherwise code crashes" (ToDo), yet it is 4 — confirm
    "data_type":["sgnn"], #["sgnn", "exp", "acd", "real", "inference"], Different data types to select
    "execution_type":["validate_MMT"], #[ "plot_similarities", "simulate_real", "test_performance", "SMI_generation_MMT", "SMI_generation_MF", "data_generation", "transformer_training","transformer_improvement", "clip_training", "clip_improvement", "validate_MMT"] # which execution mode to run
    "syn_data_simulated": [False],  # For the improvement cycle: flag showing whether data has been simulated or not.
    "training_type":["clip"], #["clip","transformer"] # different networks to select for training

    # Encoding dicts
    "itos_path":["/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MultiModalSpectralTransformer/itos.json"],
    "stoi_path":["/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MultiModalSpectralTransformer/stoi.json"],
    "itos_MF_path":["/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MultiModalSpectralTransformer/itos_MF.json"],
    "stoi_MF_path":["/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MultiModalSpectralTransformer/stoi_MF.json"],

    ### Data settings
    "input_dim_1H":[2], # Input dimensions of the 1H data
    "input_dim_13C": [1], # Input dimensions of the 13C data
    "input_dim_HSQC": [2], # Input dimensions of the HSQC data
    "input_dim_COSY": [2],  # Input dimensions of the COSY data
    "input_dim_IR": [1000],  # Input dimensions of the IR data
    "MF_vocab_size": [len(stoi_MF)],  # New, size of the vocabulary for molecular formulas
    "MS_vocab_size": [len(stoi)],  # New, size of the stoi vocabulary (same value as in_size/out_size). NOTE(review): an earlier comment said "molecular formulas", which describes MF_vocab_size — confirm
    "tr_te_split":[0.9], # Train-Test split
    "padding_points_number":[64], # Padding number for the embedding layer into the network
    "data_size": [1000], # number of datapoints for the training 3975764/1797828
    "test_size": [10], # number of datapoints for testing (an earlier comment said "for the training", likely copy-paste — confirm)
    "model_save_dir": ["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/test"], # Folder where networks are saved
    "ML_dump_folder": ["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/dump"], # a folder where intermediate files for the SGNN network are generated
    "model_save_interval": [10000], # seconds passed until next model is saved

    # Option 1 SGNN
    "use_real_data":[False], #[True, False]
    "ref_data_type":["1H"], #["1H","13C","HSQC","COSY","IR"]
    "csv_train_path": ['/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_1H_comb_train_V8.csv'], # To keep a reference of the compounds that it was trained on
    "csv_1H_path_SGNN": ['/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_1H_comb_train_V8.csv'],
    "csv_13C_path_SGNN": ['/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_13C.csv'],
    "csv_HSQC_path_SGNN": ['/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_HSQC.csv'],
    "csv_COSY_path_SGNN": ['/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_COSY.csv'],
    "csv_IR_MF_path": [''],     #571124
    "csv_path_val": ['/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_1H_comb_test_V8.csv'], #63459
    #"IR_data_folder": ["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/IR_spectra_NN"],
    "IR_data_folder": [""],

    # "pickle_file_path": ["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_1H_comb_V8_938756.pkl"],
    "pickle_file_path": ["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_1H_comb_test_V8_355655.pkl"],

    "dl_mode": ['val'], #["val","train"]
    "isomericSmiles": [False], # whether stereochemistry is considered or not

    # Option 2 exp
    #"exp_path": ['/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/9_ZINC_250k/missing_ZINC_files.csv'], #63459

    # Option 3 ACD
    #"csv_path_1H_ACD": ['/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/9_ZINC_250k/1H_ZINC_XL_v3.csv'],
    #"data_folder_HSQC_ACD": ["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/9_ZINC_250k/zinc250k"],
    # Option 4 real
    "comparision_number": [1000],  # How many of the training examples it should be compared with in a t-SNE plot
    "vector_db": ['/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/smiles_fingerprints_train_4M_v1.csv'],
    "secret_csv_SMI_vectors": ['/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/25_Test_Improvement_cycle/1_Test_ZINC_250/test_32_zinc250_vec_db.csv'],
    "secret_csv_SMI_targets": ['/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/25_Test_Improvement_cycle/1_Test_ZINC_250/test_32_zinc250.csv'],
    "secret_csv_SMI_sim_searched": ['/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/25_Test_Improvement_cycle/1_Test_ZINC_250/test_32_zinc250.csv'],
    "csv_SMI_targets": ['/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/25_Test_Improvement_cycle/1_Test_ZINC_250/test_32_zinc250_single_target_919.csv'],
    "csv_1H_path_REAL": [''],
    "csv_13C_path_REAL": [''],
    "csv_HSQC_path_REAL": [''],
    "csv_COSY_path_REAL": [''],
    #"pkl_path_HSQC_real": [""],

    #### Transformer Settings ####
    # Training and model settings
    "training_mode":["1H_13C_HSQC_COSY_IR_MF_MW"], # Modalities selected for training (underscore-separated list)
    "blank_percentage":[0.0], # percentage of spectra that are blanked out during training for better generalizability of the network to various datatypes
    "batch_size":[64], # training batch size. NOTE(review): an earlier note said it "needs to be the same as number of GPUs", which does not match gpu_num=1 — confirm
    "num_epochs": [10], # number of epochs for training
    "lr_pretraining": [1e-4], # Pretraining learning rate
    "lr_finetuning": [5e-5], # Finetuning learning rate
    "load_model": [True], # if model should be loaded from path
    "checkpoint_path":["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8_MMT_MW_Drop/MultimodalTransformer_time_1704760608.6927748_Loss_0.137.ckpt"], #V8
    "save_model": [True], # if model should be saved

    # Model architecture
    "in_size": [len(stoi)],
    "hidden_size": [128],
    "out_size": [len(stoi)],
    "num_encoder_layers": [6], #8
    "num_decoder_layers": [6], #8
    "num_heads": [16], #8  ### number of attention heads
    "forward_expansion": [4], #4
    "max_len": [128], # maximum length of the generated sequence
    "drop_out": [0.1],
    "fingerprint_size": [512], # Dimensions of encoder output for CLIP contrastive training
    "gen_SMI_sequence":[True], # If the model generates a sequence with the SMILES current model for evaluation
    "sampling_method":["mix"], # ["multinomial", "greedy", "mix"]; an earlier note also listed "weight_mol_weight"
    "training_setup":["pretraining"], # ["pretraining","finetuning"]
    "smi_randomizer":[False], # if smiles are randomized or canonical during training

    ### SGNN Feedback
    "sgnn_feedback":[False], # if SGNN generates 1H and 13C spectrum on the fly on the generated smiles -> requires "gen_SMI_sequence":[True]
    "matching":["HungDist"], #["MinSum","EucDist","HungDist"], # HSQC point matching technique used
    "padding":["NN"], # ["Zero","Trunc","NN"], # HSQC padding technique used -> TODO: add publication reference
    # Weight feedback
    "train_weight_min":[None], # Calculate on the fly - Used for the weight loss calculation for scaling
    "train_weight_max":[None], # Calculate on the fly - Used for the weight loss calculation for scaling
    # Training Loss Weighting options
    "weight_validity": [0.0], # up to 1
    "weight_SMI": [1.0], # up to 1
    "weight_FP": [0.0], # up to 1
    "weight_MW": [0], # up to 100
    "weight_sgnn": [0.0], # up to 10
    "weight_tanimoto": [0.0], # up to 1
    "change_loss_weights":[False], # if selected the weights get adjusted along the training
    "increment":[0.01], # increment by which the weights get adjusted during training -> TODO
    "batch_frequency":[10000], # Frequency of how often they get adjusted -> TODO

    ### For Validation
    "beam_size": [1],
    "multinom_runs": [1],
    "temperature":[1],
    "gen_len":[64],
    "pkl_save_folder":["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/pkl_save_folder"],

    ### Molformer options
    "MF_max_trails":[500],
    "MF_tanimoto_filter":[0.1],
    "MF_filter_higher":[1], # False = 0 True = 1
    "MF_delta_weight":[5],
    "MF_generations":[30],
    "MF_model_path":["/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MultiModalSpectralTransformer/deep-molecular-optimization/experiments/trained/Alessandro_big/weights_pubchem_with_counts_and_rank_sanitized.ckpt"],
    "MF_vocab":["/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MultiModalSpectralTransformer/deep-molecular-optimization/experiments/trained/Alessandro_big/vocab_new.pkl"],
    "MF_csv_source_folder_location":["/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MultiModalSpectralTransformer/deep-molecular-optimization/data/MMP"],
    "MF_csv_source_file_name":["test_selection_2"],
    "MF_methods":["MMP"], #["MMP", "scaffold", "MMP_scaffold"],
    "max_scaffold_generations":[10], #

    ### MMT batch generation
    "MMT_batch":[32], # how big is the batch of copies of the same inputs that is processed by MMT
    "MMT_generations":[4], # number of valid generated molecules. NOTE(review): an earlier note said this must be a multiple of MMT_batch, but 4 is not a multiple of 32 — confirm which constraint applies
    #------------------------
    "n_samples":[10], # number of molecules that should be processed for data generation - needs to be smaller than dataloader size
    "gen_mol_csv_folder_path":["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/24_SGNN_gen_folder_2"], # presumably the folder for CSV files of generated molecules — confirm (the earlier comment was copied from n_samples)

    ### Fine-tuning improvement options
    "train_data_blend":[0], # how many additional molecules should be added to the new dataset from the original training dataset
    "train_data_blend_CLIP":[1000], # how many additional molecules should be added to the new dataset from the original training dataset (CLIP)

    ### Data generation SGNN -> 1H, 13C, HSQC, COSY
    "SGNN_gen_folder_path":["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/24_SGNN_gen_folder_2/dump_2"],
    "SGNN_csv_gen_smi":["/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MultiModalSpectralTransformer/deep-molecular-optimization/data/MMP/test_selection_1.csv"],
    "SGNN_size_filter":[550],
    "SGNN_csv_save_folder":["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/24_SGNN_gen_folder_2"],
    "IR_save_folder":["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/24_SGNN_gen_folder_2/IR_data"],

    #################################################
    #### LEGACY parameters for other experiments ####
    #################################################
    #### CLIP Settings ####
    ### ChemBerta
    "model_version":["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v1/Chemberta_source"],   # Source of pretrained chemberta from paper
    "CB_model_path":["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v1/Large_300_15.pth"], # path to pretrained Chemberta model
    "num_class":[1024], #
    "num_linear_layers":[0], # number of linear layers in architecture before num_class output
    "use_dropout":[True],
    "use_relu":[False],
    "loss_fn":["BCEWithLogitsLoss"], #"MSELoss", "BCELoss",
    "CB_embedding": [1024], #1024
    # PCA
    "fp_dim_reduction":[False], #True
    "pca_components":[300],
    #"CB_model_name": ["Large_300_15"],

    ### Multimodal Transformer
    "MT_model_path":["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8_MMT_MW2_Drop/MultimodalTransformer_time_1706856620.3718672_Loss_0.202.pth"],  # path to pretrained Multimodal Transformer model
    #"MT_model_name": ["SpectrumBERT_PCA_large_3.6"],
    "MT_embedding": [512], #512
    ### Projection Head
    "projection_dim": [512],
    "dropout": [0.1],

    #CLIP
    # Dataloader settings
    "similarity_threshold":[0.6], # Filter that selects only molecules with a Tanimoto similarity higher than this number
    "max_search_size":[10000], # Size of the data that will be searched to find the similar molecules  # 100000
    "weight_delta":[50], # Filter to molecules within +/- this molecular-weight delta
    "CLIP_batch_size":[128],  #,64,128,256 ### batch size for the CLIP training
    "CLIP_NUM_EPOCHS": [10],    # Number of training epochs

    ### Train parameters
    ### CLIP Model
    "CLIP_temperature": [1],
    #"CB_projection_lr": [1e-3], # projection head learning rate for Chemberta
    "MT_projection_lr": [1e-3], # projection head learning rate for Multimodal Transformer
    "CB_lr": [1e-4], # Chemberta Learning Rate
    "MT_lr": [1e-5], # Multimodal Transformer Learning Rate
    "weight_decay": [1e-3], # Weight decay for projection heads -> TODO why just on those
    "patience": [1],   # not integrated yet
    "factor": [0.8],   # not integrated yet
    "CLIP_continue_training":[True],
    "CLIP_model_path":["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8_modalities_CLIP_1_dot_product/MultimodalCLIP_Epoch_9_Loss0.096.ckpt"],
    "CLIP_model_save_dir":["/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/test_CLIP"],

    ### BLIP Model
    "BLIP_temperature": [1],
    "Qformer_lr":[1e-4],
    "Qformer_CB_lr":[1e-4],
    "Qformer_MT_lr":[1e-4],
    "BLIP_continue_training":[True],
    # "BLIP_model_path":["/projects/cc/se_users/knlr326//knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/test_BLIP_1M/model_BLIP-epoch=03-loss=2.54_v0.ckpt"],
    # "BLIP_model_save_dir":["/projects/cc/se_users/knlr326//knlr326//knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/test_BLIP_1M"],

    }



In [None]:

def save_config(config, path):
    """Serialize ``config`` as JSON into ``path``, replacing any existing file."""
    with open(path, 'w') as out_file:
        json.dump(config, out_file)

def load_config(path):
    """Parse and return the JSON config stored at ``path``.

    Returns ``None`` instead of raising when the file does not exist.
    """
    try:
        config_file = open(path, 'r')
    except FileNotFoundError:
        return None
    with config_file:
        return json.load(config_file)

def parse_arguments(hyperparameters):
    """Build an ``argparse.Namespace`` from a dict of value lists.

    Every entry in this notebook's config dicts is a list of values; only the
    first element of each list is used.

    Args:
        hyperparameters: mapping of option name -> non-empty sequence of values.

    Returns:
        Namespace with one attribute per key, set to that key's ``values[0]``.

    Raises:
        IndexError: if a value sequence is empty (the message names the key).
    """
    parsed_args = {}
    for key, values in hyperparameters.items():
        if len(values) == 0:
            # Same exception type as a bare values[0], but tells the user which entry is wrong.
            raise IndexError(f"Config entry {key!r} has an empty value list; expected at least one value.")
        parsed_args[key] = values[0]
    return Namespace(**parsed_args)


# Persist the chemprop-IR settings, reload them through JSON and hand the resulting
# Namespace to irs.modify_predict_args.
ir_config_path = './utils_MMT/ir_config_V8.json'
save_config(IR_config_dict, ir_config_path)
IR_config_dict = load_config(ir_config_path)
IR_config = parse_arguments(IR_config_dict)
irs.modify_predict_args(IR_config)


# Persist the main hyperparameters the same way and build the run config from the
# reloaded JSON, so `config` holds exactly what was written to disk.
config_path = './utils_MMT/config_V8.json'
save_config(hyperparameters, config_path)
config_dict = load_config(config_path)
config = parse_arguments(config_dict)

### 0.0 Testing models trained on different amounts of training data (models with IR as a data input)



##### Effect of training-data size and training time


In [None]:
"""
# V8 Raw 1M
#config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_1Mio/model-epoch=07-loss=0.10.ckpt"
#same batches as 20M mol seen
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_1Mio/model-epoch=19-loss=0.04.ckpt"


model_MMT = mrtf.load_MMT_model(config)
prob_dict_results_0a, results_dict_0a = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)
print(np.mean(results_dict_0a["tanimoto_sim"]))

import pickle
# Save the data to a file
#file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.0_prob_dict_results_1M_L_epoch_7.pkl'  # Replace with your desired path
file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.0_prob_dict_results_1M_L_20M_mol.pkl'  # Replace with your desired path
with open(file_prob_dict_path, 'wb') as file:
    pickle.dump(prob_dict_results_0a, file)
    

# Save the data to a file
#file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.1_results_dict_1M_L_epoch_7.pkl'  # Replace with your desired path
file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.1_results_dict_1M_L_20M_mol.pkl'  # Replace with your desired path
with open(file_results_dict_path, 'wb') as file:
    pickle.dump(results_dict_0a, file)"""

In [None]:
"""
# V8 Raw 2M
#config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_2Mio/model-epoch=07-loss=0.06.ckpt"
#same batches as 20M mol seen
config.checkpoint_path = "/projects/cc/se_users/knlr326//knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_2Mio/model-epoch=09-loss=0.04.ckpt"

model_MMT = mrtf.load_MMT_model(config)
prob_dict_results_0b, results_dict_0b = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)
print(np.mean(results_dict_0b["tanimoto_sim"]))


import pickle
# Save the data to a file
#file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.0_prob_dict_results_2M_L_epoch_7.pkl'  # Replace with your desired path
file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.0_prob_dict_results_2M_L_20M_mol.pkl'  # Replace with your desired path
with open(file_prob_dict_path, 'wb') as file:
    pickle.dump(prob_dict_results_0b, file)
    

# Save the data to a file
#file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.1_results_dict_2M_L_epoch_7.pkl'  # Replace with your desired path
file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.1_results_dict_2M_L_20M_mol.pkl'  # Replace with your desired path
with open(file_results_dict_path, 'wb') as file:
    pickle.dump(results_dict_0b, file)"""

In [None]:
"""
# V8 Raw 4M
#config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_4Mio/model-epoch=07-loss=0.02.ckpt"
#same batches as 20M mol seen
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_4Mio/model-epoch=04-loss=0.04.ckpt"

model_MMT = mrtf.load_MMT_model(config)
prob_dict_results_0b, results_dict_0b = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)

import pickle
# Save the data to a file
#file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.0_prob_dict_results_4M_L_epoch_7.pkl'  # Replace with your desired path
file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.0_prob_dict_results_4M_L_20M_mol.pkl'  # Replace with your desired path
with open(file_prob_dict_path, 'wb') as file:
    pickle.dump(prob_dict_results_0b, file)
    

# Save the data to a file
#file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.1_results_dict_4M_L_epoch_7.pkl'  # Replace with your desired path
file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.1_results_dict_4M_L_20M_mol.pkl'  # Replace with your desired path
with open(file_results_dict_path, 'wb') as file:
    pickle.dump(results_dict_0b, file)"""

#### Load saved results — epoch 7
- Models trained on different data sizes (0.1M, 1M, 2M and 4M molecules)



In [None]:
### Epoch 7 for all molecules seen

import os
import pickle

# Folder with the precomputed epoch-7 evaluation outputs (two pickles per model).
EPOCH_7_DIR = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240430_Epoch_7'


def load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    Security note: unpickling can execute arbitrary code, so only call this on
    trusted files.
    """
    with open(path, 'rb') as file:
        return pickle.load(file)


# size tag -> (prob_dict_results, results_dict), loaded in the order 4M, 2M, 1M.
# After the loop, file_prob_dict_path / file_results_dict_path point at the 1M files.
epoch_7_results = {}
for size_tag in ('4M', '2M', '1M'):
    file_prob_dict_path = os.path.join(EPOCH_7_DIR, f'0.0_prob_dict_results_{size_tag}_L_epoch_7.pkl')
    file_results_dict_path = os.path.join(EPOCH_7_DIR, f'0.1_results_dict_{size_tag}_L_epoch_7.pkl')
    epoch_7_results[size_tag] = (load_pickle(file_prob_dict_path), load_pickle(file_results_dict_path))

# Variable suffix convention used by the plotting cells: a = 4M, b = 2M, c = 1M.
prob_dict_results_0a, results_dict_0a = epoch_7_results['4M']
prob_dict_results_0b, results_dict_0b = epoch_7_results['2M']
prob_dict_results_0c, results_dict_0c = epoch_7_results['1M']
    



In [None]:
### add 0.1M to this one 
# Epoch-7 results of the model trained on 0.1M molecules (suffix "d").
file_prob_dict_path = ('/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/'
                       'precomputed_raw_data/20240430_Epoch_7/0.0_prob_dict_results_0.1M_L_epoch_7.pkl')
file_results_dict_path = ('/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/'
                          'precomputed_raw_data/20240430_Epoch_7/0.1_results_dict_0.1M_L_epoch_7.pkl')

# Load one after the other: per-sample probabilities first, then the results dict.
with open(file_prob_dict_path, 'rb') as file:
    prob_dict_results_0d = pickle.load(file)
with open(file_results_dict_path, 'rb') as file:
    results_dict_0d = pickle.load(file)
    

In [None]:

# Epoch-7 results ordered by training-set size: 1M (c), 2M (b), 4M (a).
# NOTE(review): the next cell rebuilds both lists including the 0.1M model.
prob_dict_results_2, results_dict_2 = (
    [prob_dict_results_0c, prob_dict_results_0b, prob_dict_results_0a],
    [results_dict_0c, results_dict_0b, results_dict_0a],
)

In [None]:

# Epoch-7 results ordered by training-set size: 0.1M (d), 1M (c), 2M (b), 4M (a).
prob_dict_results_2, results_dict_2 = (
    [prob_dict_results_0d, prob_dict_results_0c, prob_dict_results_0b, prob_dict_results_0a],
    [results_dict_0d, results_dict_0c, results_dict_0b, results_dict_0a],
)

##### Violin chart: probability of the correct SMILES

In [None]:
# Violin plot of the probability assigned to the correct SMILES, one violin per
# training-set size (suffixes: d = 0.1M, c = 1M, b = 2M, a = 4M).
labels = ['0.1M Epoch 7', '1M Epoch 7', '2M Epoch 7', '4M Epoch 7']
prob_dict_results_2 = [prob_dict_results_0d, prob_dict_results_0c, prob_dict_results_0b, prob_dict_results_0a]
if len(labels) != len(prob_dict_results_2):
    # Tick labels must line up one-to-one with the violins; fail before drawing anything.
    raise ValueError(f"{len(labels)} labels given for {len(prob_dict_results_2)} result sets")


def plot_metric_violin(data, tick_labels, title, ylabel, color='#A1C8F3', ylim=(0, 1), fontsize=22):
    """Draw one violin per dataset, annotate each with its mean, and return (fig, ax, parts).

    Args:
        data: sequence of 1-D numeric sequences, one per violin.
        tick_labels: x-axis label for each violin (same length as ``data``).
        title: axes title.
        ylabel: y-axis label.
        color: face color shared by all violin bodies.
        ylim: (bottom, top) limits of the y-axis.
        fontsize: font size for title, labels, ticks and mean annotations.

    Returns:
        The Figure, the Axes, and the dict returned by ``Axes.violinplot``.
    """
    fig, ax = plt.subplots(figsize=(8, 10))
    parts = ax.violinplot(data, showmeans=True, showmedians=False, showextrema=False)
    for body in parts['bodies']:
        body.set_facecolor(color)
        body.set_edgecolor('black')
        body.set_alpha(0.7)

    ax.set_title(title, fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    positions = np.arange(1, len(tick_labels) + 1)
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, rotation=45, ha='center', fontsize=fontsize)
    ax.tick_params(axis='both', which='major', labelsize=fontsize)

    # Write each violin's mean value at the height of that mean.
    for pos, series in zip(positions, data):
        mean_value = np.mean(series)
        ax.text(pos, mean_value, f'{mean_value:.2f}', ha='center', va='bottom', fontsize=fontsize)

    ax.grid(axis='y', linestyle='--', alpha=0.7)
    ax.set_ylim(*ylim)
    fig.tight_layout()
    return fig, ax, parts


# Per-model probabilities of the correct SMILES, with mean and sample standard deviation
# (statistics.stdev needs at least two values per model).
data_for_violin = [prob_dict["aggregated_corr_prob_multi"] for prob_dict in prob_dict_results_2]
mean_results = [np.mean(data) for data in data_for_violin]
std_results = [statistics.stdev(data) for data in data_for_violin]
mean_values = list(mean_results)

color = '#A1C8F3'  # single color for all violins
fig, ax, parts = plot_metric_violin(data_for_violin, labels, 'Correct SMILES Sample Probability',
                                    'Probability of Correct SMILES', color=color)

# Legend handle for the "MMST" series; kept available, but no legend is drawn on this figure.
lightseagreen_patch = plt.Line2D([0], [0], color='lightseagreen', lw=4, label='MMST')

save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/0.0_Violin_Epoch_7_v3.png'
fig.savefig(save_path, format='png')
plt.show()


##### Tanimoto Chart

In [None]:
# Three largest epoch-7 models (1M, 2M, 4M); the next cell replaces this list.
results_dict_2 = [
    results_dict_0c,
    results_dict_0b,
    results_dict_0a,
]


In [None]:
# Violin plot of the greedy-sampled Tanimoto similarity per epoch-7 model
# (0.1M, 1M, 2M, 4M training molecules). numpy/matplotlib come from the imports cell.
labels = ['0.1M Epoch 7', '1M Epoch 7', '2M Epoch 7', '4M Epoch 7']
results_dict_2 = [results_dict_0d, results_dict_0c, results_dict_0b, results_dict_0a]
# Catch labels and data drifting apart when models are added or removed.
assert len(labels) == len(results_dict_2), "labels and results_dict_2 must have the same length"

# One sample list per model; its mean is annotated on each violin
data_for_violin = [d["tanimoto_sim"] for d in results_dict_2]
mean_values = [np.mean(data) for data in data_for_violin]

fig, ax = plt.subplots(figsize=(8, 10))
parts = ax.violinplot(data_for_violin, showmeans=True, showmedians=False, showextrema=False)

# Same pastel blue for every violin
color = '#A1C8F3'
for pc in parts['bodies']:
    pc.set_facecolor(color)
    pc.set_edgecolor('black')
    pc.set_alpha(0.7)

ax.set_title('Greedy Sampled Average Tanimoto Similarity', fontsize=22)
ax.set_ylabel('Average Tanimoto Similarity', fontsize=22)
# violinplot places the violins at x = 1..N
ax.set_xticks(np.arange(1, len(labels) + 1))
ax.set_xticklabels(labels, rotation=45, ha='center', fontsize=22)
ax.tick_params(axis='both', which='major', labelsize=22)

# Legend handle matching the violin colour; the legend itself is currently disabled
lightseagreen_patch = plt.Line2D([0], [0], color=color, lw=4, label='MMST')
#ax.legend(handles=[lightseagreen_patch], loc='lower left', fontsize=22)

ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.set_ylim(0, 1)

for pos, mean_value in enumerate(mean_values, start=1):
    ax.text(pos, mean_value, f'{mean_value:.2f}', ha='center', va='bottom', fontsize=22)

plt.tight_layout()
save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/0.0_Tanimoto_Violin_plot_Epoch_7_v3.png'
# Create the output folder if needed so savefig does not fail on a fresh checkout
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, format='png')
plt.show()


##### Number of Invalid Molecules

In [None]:
# Bar chart of the number of invalid greedy-sampled SMILES per epoch-7 model, with the
# count inside each bar and the percentage of that model's generated SMILES above it.
labels = ['0.1M Epoch 7', '1M Epoch 7', '2M Epoch 7', '4M Epoch 7']
results_dict_2 = [results_dict_0d, results_dict_0c, results_dict_0b, results_dict_0a]
assert len(labels) == len(results_dict_2), "labels and results_dict_2 must have the same length"

# Invalid count per model (assumes "failed" lists the invalid generations) and the
# number of generated SMILES each percentage is relative to.
mean_results_2 = [len(d["failed"]) for d in results_dict_2]
entries_per_model = [len(d["gen_conv_SMI_list"]) for d in results_dict_2]
total_entries = entries_per_model[0]  # kept for later cells that may read it

color2 = '#A1C8F3'
bar_width = 0.35
positions = np.arange(len(labels))

fig, ax = plt.subplots(figsize=(8, 10))
bar1 = ax.bar(positions, mean_results_2, bar_width, color=color2, edgecolor='black', label='MMST')

# Leave headroom for the percentage labels; fall back to 1 when no model produced
# invalid SMILES so set_ylim does not receive identical limits.
max_count = max(mean_results_2)
y_top = max_count * 1.25 if max_count > 0 else 1
# Offset the percentage label by a fraction of the axis height; a fixed +2000 put the
# labels outside the axes whenever the counts were small.
label_offset = 0.03 * y_top

for bar, n_total in zip(bar1, entries_per_model):
    yval = bar.get_height()
    percentage = (yval / n_total) * 100 if n_total > 0 else 0
    x_center = bar.get_x() + bar.get_width() / 2
    # Absolute count, centred inside the bar
    ax.text(x_center, yval / 2, f'{yval:,.0f}',
            rotation=90, ha='center', va='center', fontsize=22, color='black')
    # Percentage of this model's generated SMILES, just above the bar
    ax.text(x_center, yval + label_offset, f'{percentage:.1f}%',
            rotation=90, ha='center', va='bottom', fontsize=22)

# Trailing spaces shift the centred title to the left
ax.set_title('Greedy Sampled Number of Invalid SMILES           ', fontsize=22)
ax.set_ylabel('Invalid Molecules', fontsize=22)
ax.set_xticks(positions)
ax.set_xticklabels(labels, rotation=45, ha='center', fontsize=22)
ax.tick_params(axis='both', which='major', labelsize=22)

ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.set_ylim(0, y_top)

# Legend currently disabled
#ax.legend(fontsize=22)

plt.tight_layout()
save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/0.0_Invalid_molecules_Epoch_7_v3.png'
# Create the output folder if needed so savefig does not fail on a fresh checkout
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, format='png')
plt.show()

#### Load saved data – networks trained on 20M molecules
- The 0.5M results look wrong, so they are left out of the plots below for now.

In [None]:
### Exactly 20M molecules seen
# Load the precomputed evaluation results for the networks trained on 20M molecules.
# There is one (prob_dict, results_dict) pair per training-set size, taken from the file names:
#   0a = 4M, 0b = 2M, 0c = 1M, 0d = 0.5M, 0e = 0.1M
# NOTE: every pair is now read from the same 20240508 run.
# Previously the 1M pair and the 2M results dict came from
# '20240330_20M_Molecules_experiment_FALSE' (the old, wrong results: without IR but trained with IR).
# That meant the 2M prob/results dicts came from different runs, and the 1M numbers disagreed
# with the model-size section, which loads the same '1M_L_20M_mol' files from 20240508.
# TODO(review): confirm 0.1_results_dict_2M_L_20M_mol.pkl exists in the 20240508 folder.
# pickle.load can execute arbitrary code; these paths must only point at self-generated files.
import pickle

RAW_DATA_DIR_20M = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240508_20M_Molecules'

# 4M training molecules
file_prob_dict_path = os.path.join(RAW_DATA_DIR_20M, '0.0_prob_dict_results_4M_L_20M_mol.pkl')
with open(file_prob_dict_path, 'rb') as fh:
    prob_dict_results_0a = pickle.load(fh)
file_results_dict_path = os.path.join(RAW_DATA_DIR_20M, '0.1_results_dict_4M_L_20M_mol.pkl')
with open(file_results_dict_path, 'rb') as fh:
    results_dict_0a = pickle.load(fh)

# 2M training molecules
file_prob_dict_path = os.path.join(RAW_DATA_DIR_20M, '0.0_prob_dict_results_2M_L_20M_mol.pkl')
with open(file_prob_dict_path, 'rb') as fh:
    prob_dict_results_0b = pickle.load(fh)
file_results_dict_path = os.path.join(RAW_DATA_DIR_20M, '0.1_results_dict_2M_L_20M_mol.pkl')
with open(file_results_dict_path, 'rb') as fh:
    results_dict_0b = pickle.load(fh)

# 1M training molecules
file_prob_dict_path = os.path.join(RAW_DATA_DIR_20M, '0.0_prob_dict_results_1M_L_20M_mol.pkl')
with open(file_prob_dict_path, 'rb') as fh:
    prob_dict_results_0c = pickle.load(fh)
file_results_dict_path = os.path.join(RAW_DATA_DIR_20M, '0.1_results_dict_1M_L_20M_mol.pkl')
with open(file_results_dict_path, 'rb') as fh:
    results_dict_0c = pickle.load(fh)

# 0.5M training molecules (results look wrong; excluded from the plots)
file_prob_dict_path = os.path.join(RAW_DATA_DIR_20M, '0.0_prob_dict_results_0.5M_L_20M_mol.pkl')
with open(file_prob_dict_path, 'rb') as fh:
    prob_dict_results_0d = pickle.load(fh)
file_results_dict_path = os.path.join(RAW_DATA_DIR_20M, '0.1_results_dict_0.5M_L_20M_mol.pkl')
with open(file_results_dict_path, 'rb') as fh:
    results_dict_0d = pickle.load(fh)

# 0.1M training molecules
file_prob_dict_path = os.path.join(RAW_DATA_DIR_20M, '0.0_prob_dict_results_0.1M_L_20M_mol.pkl')
with open(file_prob_dict_path, 'rb') as fh:
    prob_dict_results_0e = pickle.load(fh)
file_results_dict_path = os.path.join(RAW_DATA_DIR_20M, '0.1_results_dict_0.1M_L_20M_mol.pkl')
with open(file_results_dict_path, 'rb') as fh:
    results_dict_0e = pickle.load(fh)
    


In [None]:
# All five 20M-molecule models ordered by training-set size: 0.1M, 0.5M, 1M, 2M, 4M.
results_dict_2 = [results_dict_0e, results_dict_0d, results_dict_0c,
                  results_dict_0b, results_dict_0a]
prob_dict_results_2 = [prob_dict_results_0e, prob_dict_results_0d, prob_dict_results_0c,
                       prob_dict_results_0b, prob_dict_results_0a]

##### Violin Chart

In [None]:
# Violin plot of the probability assigned to the correct SMILES for the networks trained on
# 20M molecules. The 0.5M model (prob_dict_results_0d) is left out because its results look wrong.
labels = ['0.1M | 20M', '1M | 20M', '2M | 20M', '4M | 20M']
prob_dict_results_2 = [prob_dict_results_0e, prob_dict_results_0c, prob_dict_results_0b, prob_dict_results_0a]
# Catch labels and data drifting apart when models are added or removed.
assert len(labels) == len(prob_dict_results_2), "labels and prob_dict_results_2 must have the same length"

# Per-model samples and summary statistics
data_for_violin = [prob_dict["aggregated_corr_prob_multi"] for prob_dict in prob_dict_results_2]
mean_results = [np.mean(samples) for samples in data_for_violin]
# statistics.stdev raises on fewer than two samples; use NaN so the figure still renders
std_results = [statistics.stdev(samples) if len(samples) > 1 else float('nan') for samples in data_for_violin]
mean_values = list(mean_results)

fig, ax = plt.subplots(figsize=(8, 10))
parts = ax.violinplot(data_for_violin, showmeans=True, showmedians=False, showextrema=False)

# Same pastel blue for every violin
color = '#A1C8F3'
for pc in parts['bodies']:
    pc.set_facecolor(color)
    pc.set_edgecolor('black')
    pc.set_alpha(0.7)

ax.set_title('Correct SMILES Sample Probability', fontsize=22)
ax.set_ylabel('Probability of Correct SMILES', fontsize=22)
# violinplot places the violins at x = 1..N
ax.set_xticks(np.arange(1, len(labels) + 1))
ax.set_xticklabels(labels, rotation=45, ha='center', fontsize=22)
ax.tick_params(axis='both', which='major', labelsize=22)

# Annotate each violin with its mean
for pos, mean_value in enumerate(mean_values, start=1):
    ax.text(pos, mean_value, f'{mean_value:.2f}', ha='center', va='bottom', fontsize=22)

# Legend handle matching the violin colour; the legend itself is currently disabled
lightseagreen_patch = plt.Line2D([0], [0], color=color, lw=4, label='MMST')
#ax.legend(handles=[lightseagreen_patch], loc='upper left', fontsize=22)

ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.set_ylim(0, 1)

plt.tight_layout()
save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/0.1_Violin_20M_v3.png'
# Create the output folder if needed so savefig does not fail on a fresh checkout
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, format='png')
plt.show()


##### Tanimoto Plot

In [None]:
# Violin plot of the greedy-sampled Tanimoto similarity for the networks trained on 20M
# molecules. The 0.5M model (results_dict_0d) is left out because its results look wrong.
# numpy/matplotlib come from the imports cell.
labels = ['0.1M | 20M', '1M | 20M', '2M | 20M', '4M | 20M']
results_dict_2 = [results_dict_0e, results_dict_0c, results_dict_0b, results_dict_0a]
# Catch labels and data drifting apart when models are added or removed.
assert len(labels) == len(results_dict_2), "labels and results_dict_2 must have the same length"

# One sample list per model; its mean is annotated on each violin
data_for_violin = [d["tanimoto_sim"] for d in results_dict_2]
mean_values = [np.mean(data) for data in data_for_violin]

fig, ax = plt.subplots(figsize=(8, 10))
parts = ax.violinplot(data_for_violin, showmeans=True, showmedians=False, showextrema=False)

# Same pastel blue for every violin
color = '#A1C8F3'
for pc in parts['bodies']:
    pc.set_facecolor(color)
    pc.set_edgecolor('black')
    pc.set_alpha(0.7)

ax.set_title('Greedy Sampled Average Tanimoto Similarity', fontsize=22)
ax.set_ylabel('Average Tanimoto Similarity', fontsize=22)
# violinplot places the violins at x = 1..N
ax.set_xticks(np.arange(1, len(labels) + 1))
ax.set_xticklabels(labels, rotation=45, ha='center', fontsize=22)
ax.tick_params(axis='both', which='major', labelsize=22)

# Legend handle matching the violin colour; the legend itself is currently disabled
lightseagreen_patch = plt.Line2D([0], [0], color=color, lw=4, label='MMST')
#ax.legend(handles=[lightseagreen_patch], loc='lower left', fontsize=22)

ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.set_ylim(0, 1)

for pos, mean_value in enumerate(mean_values, start=1):
    ax.text(pos, mean_value, f'{mean_value:.2f}', ha='center', va='bottom', fontsize=22)

plt.tight_layout()
save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/0.1_Tanimoto_Violin_plot_20_M_v3.png'
# Create the output folder if needed so savefig does not fail on a fresh checkout
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, format='png')
plt.show()


##### Invalid Molecules

In [None]:
# Bar chart of the number of invalid greedy-sampled SMILES for the networks trained on 20M
# molecules. The count is shown inside each bar and the percentage of that model's generated
# SMILES above it. The 0.5M model (results_dict_0d) is left out because its results look wrong.
labels = ['0.1M | 20M', '1M | 20M', '2M | 20M', '4M | 20M']
results_dict_2 = [results_dict_0e, results_dict_0c, results_dict_0b, results_dict_0a]
assert len(labels) == len(results_dict_2), "labels and results_dict_2 must have the same length"

# Invalid count per model (assumes "failed" lists the invalid generations) and the
# number of generated SMILES each percentage is relative to.
mean_results_2 = [len(d["failed"]) for d in results_dict_2]
entries_per_model = [len(d["gen_conv_SMI_list"]) for d in results_dict_2]
total_entries = entries_per_model[0]  # kept for later cells that may read it

color2 = '#A1C8F3'
bar_width = 0.35
positions = np.arange(len(labels))

fig, ax = plt.subplots(figsize=(8, 10))
bar1 = ax.bar(positions, mean_results_2, bar_width, color=color2, edgecolor='black', label='MMST')

# Leave headroom for the percentage labels; fall back to 1 when no model produced
# invalid SMILES so set_ylim does not receive identical limits.
max_count = max(mean_results_2)
y_top = max_count * 1.2 if max_count > 0 else 1
# Offset the percentage label by a fraction of the axis height; a fixed +2000 put the
# labels outside the axes whenever the counts were small.
label_offset = 0.03 * y_top

for bar, n_total in zip(bar1, entries_per_model):
    yval = bar.get_height()
    percentage = (yval / n_total) * 100 if n_total > 0 else 0
    x_center = bar.get_x() + bar.get_width() / 2
    # Absolute count, centred inside the bar
    ax.text(x_center, yval / 2, f'{yval:,.0f}',
            rotation=90, ha='center', va='center', fontsize=22, color='black')
    # Percentage of this model's generated SMILES, just above the bar
    ax.text(x_center, yval + label_offset, f'{percentage:.1f}%',
            rotation=90, ha='center', va='bottom', fontsize=22)

# Trailing spaces shift the centred title to the left
ax.set_title('Greedy Sampled Number of Invalid SMILES         ', fontsize=22)
ax.set_ylabel('Number of Invalid Molecules', fontsize=22)
ax.set_xticks(positions)
ax.set_xticklabels(labels, rotation=45, ha='center', fontsize=22)
ax.tick_params(axis='both', which='major', labelsize=22)

ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.set_ylim(0, y_top)

# Legend currently disabled
#ax.legend(fontsize=22)

plt.tight_layout()
save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/0.1_Invalid_molecules_20M_v3.png'
# Create the output folder if needed so savefig does not fail on a fresh checkout
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, format='png')
plt.show()


#### Model Size & Training Time


In [None]:
"""
# V8 Raw 1M Large
config.num_encoder_layers = 6
config.num_decoder_layers = 6
config.num_heads = 16

# V8 Raw 1M
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_1Mio/model-epoch=19-loss=0.04.ckpt"

model_MMT = mrtf.load_MMT_model(config)
prob_dict_results_0a, results_dict_0a = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)

import pickle
# Save the data to a file
file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.0_prob_dict_results_1M_L_epoch_20.pkl'  # Replace with your desired path
with open(file_prob_dict_path, 'wb') as file:
    pickle.dump(prob_dict_results_0a, file)
    

# Save the data to a file
file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.1_results_dict_1M_Lepoch_20.pkl'  # Replace with your desired path
with open(file_results_dict_path, 'wb') as file:
    pickle.dump(results_dict_0a, file)"""

In [None]:
"""
# V8 Raw 1M Medium
config.num_encoder_layers = 3
config.num_decoder_layers = 3
config.num_heads = 8


### 20 M molecules seen
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_1Mio_medium/model-epoch=19-loss=0.09.ckpt"


model_MMT = mrtf.load_MMT_model(config)
prob_dict_results_0b, results_dict_0b = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)


import pickle
# Save the data to a file
#file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.0_prob_dict_results_1M_M_20M_mol.pkl'  # Replace with your desired path
with open(file_prob_dict_path, 'wb') as file:
    pickle.dump(prob_dict_results_0b, file)
    

# Save the data to a file
#file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.1_results_dict_1M_M_20M_mol.pkl'  # Replace with your desired path
with open(file_results_dict_path, 'wb') as file:
    pickle.dump(results_dict_0b, file)"""


In [None]:
"""
# V8 Raw 1M Small
config.num_encoder_layers = 2
config.num_decoder_layers = 2
config.num_heads = 4

### 20 M molecules seen
#config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_1Mio_small/model-epoch=19-loss=0.11.ckpt"

model_MMT = mrtf.load_MMT_model(config)
prob_dict_results_0c, results_dict_0c = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)

import pickle
# Save the data to a file
#file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.0_prob_dict_results_1M_S_20M_mol.pkl'  # Replace with your desired path
with open(file_prob_dict_path, 'wb') as file:
    pickle.dump(prob_dict_results_0c, file)
    

# Save the data to a file
#file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/0.1_results_dict_1M_S_20M_mol.pkl'  # Replace with your desired path
with open(file_results_dict_path, 'wb') as file:
    pickle.dump(results_dict_0c, file)"""

##### Load saved data - Model Size 20M

In [None]:
# Load the precomputed results for the 1M model at three sizes, all trained on 20M molecules:
#   0d = Large (L), 0e = Medium (M), 0f = Small (S)
# NOTE(review): this reuses the names prob_dict_results_0d/0e and results_dict_0d/0e, which the
# 20M-comparison loading cell uses for the 0.5M/0.1M models. Re-run that loading cell before
# re-plotting the 20M comparison.
# pickle.load can execute arbitrary code; only point these paths at self-generated files.
import pickle

RAW_DATA_ROOT = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data'
MAIN_RUN_DIR = os.path.join(RAW_DATA_ROOT, '20240508_20M_Molecules')
BACKUP_RUN_DIR = os.path.join(RAW_DATA_ROOT, '20240508_20M_Molecules_backup')

# Large
file_prob_dict_path = os.path.join(MAIN_RUN_DIR, '0.0_prob_dict_results_1M_L_20M_mol.pkl')
with open(file_prob_dict_path, 'rb') as fh:
    prob_dict_results_0d = pickle.load(fh)
file_results_dict_path = os.path.join(MAIN_RUN_DIR, '0.1_results_dict_1M_L_20M_mol.pkl')
with open(file_results_dict_path, 'rb') as fh:
    results_dict_0d = pickle.load(fh)

# Medium
file_prob_dict_path = os.path.join(BACKUP_RUN_DIR, '0.0_prob_dict_results_1M_M_20M_mol.pkl')
with open(file_prob_dict_path, 'rb') as fh:
    prob_dict_results_0e = pickle.load(fh)
file_results_dict_path = os.path.join(BACKUP_RUN_DIR, '0.1_results_dict_1M_M_20M_mol.pkl')
with open(file_results_dict_path, 'rb') as fh:
    results_dict_0e = pickle.load(fh)

# Small
file_prob_dict_path = os.path.join(BACKUP_RUN_DIR, '0.0_prob_dict_results_1M_S_20M_mol.pkl')
with open(file_prob_dict_path, 'rb') as fh:
    prob_dict_results_0f = pickle.load(fh)
file_results_dict_path = os.path.join(BACKUP_RUN_DIR, '0.1_results_dict_1M_S_20M_mol.pkl')
with open(file_results_dict_path, 'rb') as fh:
    results_dict_0f = pickle.load(fh)
    


In [None]:
# Model-size comparison ordered small -> large (S, M, L), matching the plot labels.
results_dict_2 = [results_dict_0f, results_dict_0e, results_dict_0d]
prob_dict_results_2 = [prob_dict_results_0f, prob_dict_results_0e, prob_dict_results_0d]
# Number of rows in the validation CSV
test_data_size = pd.read_csv(config.csv_path_val).shape[0]

##### Violin Chart


In [None]:
# Violin plot of the probability assigned to the correct SMILES for the 1M model at three
# sizes (S, M, L), all trained on 20M molecules.
labels = ['1M | 20M | S', '1M | 20M | M', '1M | 20M | L']
results_dict_2 = [results_dict_0f, results_dict_0e, results_dict_0d]
# The plot reads prob_dict_results_2, so set it here rather than relying on whichever cell
# last assigned it (the original cell only set results_dict_2).
prob_dict_results_2 = [prob_dict_results_0f, prob_dict_results_0e, prob_dict_results_0d]
assert len(labels) == len(prob_dict_results_2), "labels and prob_dict_results_2 must have the same length"

# Per-model samples and summary statistics
data_for_violin = [prob_dict["aggregated_corr_prob_multi"] for prob_dict in prob_dict_results_2]
mean_results = [np.mean(samples) for samples in data_for_violin]
# statistics.stdev raises on fewer than two samples; use NaN so the figure still renders
std_results = [statistics.stdev(samples) if len(samples) > 1 else float('nan') for samples in data_for_violin]
mean_values = list(mean_results)

fig, ax = plt.subplots(figsize=(8, 10))
parts = ax.violinplot(data_for_violin, showmeans=True, showmedians=False, showextrema=False)

# Same pastel blue for every violin
color = '#A1C8F3'
for pc in parts['bodies']:
    pc.set_facecolor(color)
    pc.set_edgecolor('black')
    pc.set_alpha(0.7)

ax.set_title('Correct SMILES Sample Probability', fontsize=22)
ax.set_ylabel('Probability of Correct SMILES', fontsize=22)
# violinplot places the violins at x = 1..N
ax.set_xticks(np.arange(1, len(labels) + 1))
ax.set_xticklabels(labels, rotation=45, ha='center', fontsize=22)
ax.tick_params(axis='both', which='major', labelsize=22)

# Annotate each violin with its mean
for pos, mean_value in enumerate(mean_values, start=1):
    ax.text(pos, mean_value, f'{mean_value:.2f}', ha='center', va='bottom', fontsize=22)

# Legend handle matching the violin colour; the legend itself is currently disabled
lightseagreen_patch = plt.Line2D([0], [0], color=color, lw=4, label='MMST')
#ax.legend(handles=[lightseagreen_patch], loc='upper left', fontsize=22)

ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.set_ylim(0, 1)

plt.tight_layout()
save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/0.3_Violin_Size_20_M_mol_v3.png'
# Create the output folder if needed so savefig does not fail on a fresh checkout
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, format='png')
plt.show()


##### Tanimoto Chart

In [None]:
# Violin plot of the greedy-sampled Tanimoto similarity for the 1M model at three sizes
# (S, M, L), all trained on 20M molecules. numpy/matplotlib come from the imports cell.
labels = ['1M | 20M | S', '1M | 20M | M', '1M | 20M | L']
results_dict_2 = [results_dict_0f, results_dict_0e, results_dict_0d]
# Catch labels and data drifting apart when models are added or removed.
assert len(labels) == len(results_dict_2), "labels and results_dict_2 must have the same length"

# One sample list per model; its mean is annotated on each violin
data_for_violin = [d["tanimoto_sim"] for d in results_dict_2]
mean_values = [np.mean(data) for data in data_for_violin]

fig, ax = plt.subplots(figsize=(8, 10))
parts = ax.violinplot(data_for_violin, showmeans=True, showmedians=False, showextrema=False)

# Same pastel blue for every violin
color = '#A1C8F3'
for pc in parts['bodies']:
    pc.set_facecolor(color)
    pc.set_edgecolor('black')
    pc.set_alpha(0.7)

ax.set_title('Greedy Sampled Average Tanimoto Similarity', fontsize=22)
ax.set_ylabel('Average Tanimoto Similarity', fontsize=22)
# violinplot places the violins at x = 1..N
ax.set_xticks(np.arange(1, len(labels) + 1))
ax.set_xticklabels(labels, rotation=45, ha='center', fontsize=22)
ax.tick_params(axis='both', which='major', labelsize=22)

# Legend handle matching the violin colour; the legend itself is currently disabled
lightseagreen_patch = plt.Line2D([0], [0], color=color, lw=4, label='MMST')
#ax.legend(handles=[lightseagreen_patch], loc='lower left', fontsize=22)

ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.set_ylim(0, 1)

for pos, mean_value in enumerate(mean_values, start=1):
    ax.text(pos, mean_value, f'{mean_value:.2f}', ha='center', va='bottom', fontsize=22)

plt.tight_layout()
save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/0.3_Tanimoto_Violin_plot_Size_20_M_mol_v3.png'
# Create the output folder if needed so savefig does not fail on a fresh checkout
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, format='png')
plt.show()


##### Number of Invalid Molecules

In [None]:
# Bar chart of the number of invalid greedy-sampled SMILES for the 1M model at three sizes
# (S, M, L). The count is shown inside each bar and the percentage of that model's generated
# SMILES above it.
labels = ['1M | 20M | S', '1M | 20M | M', '1M | 20M | L']
results_dict_2 = [results_dict_0f, results_dict_0e, results_dict_0d]
assert len(labels) == len(results_dict_2), "labels and results_dict_2 must have the same length"

# Invalid count per model (assumes "failed" lists the invalid generations) and the
# number of generated SMILES each percentage is relative to.
mean_results_2 = [len(d["failed"]) for d in results_dict_2]
entries_per_model = [len(d["gen_conv_SMI_list"]) for d in results_dict_2]
total_entries = entries_per_model[0]  # kept for later cells that may read it

color2 = '#A1C8F3'
bar_width = 0.35
positions = np.arange(len(labels))

fig, ax = plt.subplots(figsize=(8, 10))
bar1 = ax.bar(positions, mean_results_2, bar_width, color=color2, edgecolor='black', label='MMST')

# Leave headroom for the percentage labels; fall back to 1 when no model produced
# invalid SMILES so set_ylim does not receive identical limits.
max_count = max(mean_results_2)
y_top = max_count * 1.2 if max_count > 0 else 1
# Offset the percentage label by a fraction of the axis height; a fixed +2000 put the
# labels outside the axes whenever the counts were small.
label_offset = 0.03 * y_top

for bar, n_total in zip(bar1, entries_per_model):
    yval = bar.get_height()
    percentage = (yval / n_total) * 100 if n_total > 0 else 0
    x_center = bar.get_x() + bar.get_width() / 2
    # Absolute count, centred inside the bar
    ax.text(x_center, yval / 2, f'{yval:,.0f}',
            rotation=90, ha='center', va='center', fontsize=22, color='black')
    # Percentage of this model's generated SMILES, just above the bar
    ax.text(x_center, yval + label_offset, f'{percentage:.1f}%',
            rotation=90, ha='center', va='bottom', fontsize=22)

# Trailing spaces shift the centred title to the left
ax.set_title('Greedy Sampled Number of Invalid SMILES       ', fontsize=22)
ax.set_ylabel('Number of Invalid Molecules', fontsize=22)
ax.set_xticks(positions)
ax.set_xticklabels(labels, rotation=45, ha='center', fontsize=22)
ax.tick_params(axis='both', which='major', labelsize=22)

ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.set_ylim(0, y_top)

# Legend currently disabled
#ax.legend(fontsize=22)

plt.tight_layout()
save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/0.3_Invalid_molecules_Size_20_M_mol_v3.png'
# Create the output folder if needed so savefig does not fail on a fresh checkout
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, format='png')
plt.show()


### 1.0 Test different models for paper


#### Run calculations


In [None]:
"""
config.IR_data_folder="/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/IR_spectra_NN"
config.data_size = 500000 #int(1000*data_fraction)
config.IR_data_folder = ""
config.training_mode = "1H_13C_HSQC_COSY_IR_MF_MW" # it ignors IR because no folder provided and puts zeros in for IR
config.multinom_runs = 1
config.batch_size = 1024 #int(1000*data_fraction)

config.csv_path_val = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_1H_comb_test_V8.csv'
config.pickle_file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_1H_comb_test_V8_355655.pkl"

val_dataloader_multi = mrtf.load_data(config, stoi, stoi_MF, single=False, mode="val")
"""

In [None]:
"""

# V8i Raw 
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT/MultimodalTransformer_time_1702212447.1475492_Loss_0.056.ckpt"
model_MMT = mrtf.load_MMT_model(config)
prob_dict_results_1ai, results_dict_1ai = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)
print(np.mean(results_dict_1ai["tanimoto_sim"]))

# V8i Raw + MW
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_MW/MultimodalTransformer_time_1702852525.1205726_Loss_0.088.ckpt"
model_MMT = mrtf.load_MMT_model(config)
prob_dict_results_1bi, results_dict_1bi = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)
print(np.mean(results_dict_1bi["tanimoto_sim"]))


# V8i Raw + MW + Drop
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt"
model_MMT = mrtf.load_MMT_model(config)
prob_dict_results_1ci, results_dict_1ci = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)
print(np.mean(results_dict_1ci["tanimoto_sim"]))"""

In [None]:
"""
# V8i Raw 
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT/MultimodalTransformer_time_1702212447.1475492_Loss_0.056.ckpt"

model_MMT = mrtf.load_MMT_model(config)
prob_dict_results_1ai, results_dict_1ai = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)
print(np.mean(results_dict_1ai["tanimoto_sim"]))


import pickle
# Save the data to a file
file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/1.0_prob_dict_results_1ai.pkl'  # Replace with your desired path
with open(file_prob_dict_path, 'wb') as file:
    pickle.dump(prob_dict_results_1ai, file)
    

# Save the data to a file
file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/1.1_results_dict_1ai.pkl'  # Replace with your desired path
with open(file_results_dict_path, 'wb') as file:
    pickle.dump(results_dict_1ai, file)"""


In [None]:
"""

# V8i Raw + MW
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_MW/MultimodalTransformer_time_1702852525.1205726_Loss_0.088.ckpt"
model_MMT = mrtf.load_MMT_model(config)
prob_dict_results_1bi, results_dict_1bi = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)
print(np.mean(results_dict_1bi["tanimoto_sim"]))


import pickle
# Save the data to a file
file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/1.0_prob_dict_results_1bi.pkl'  # Replace with your desired path
with open(file_prob_dict_path, 'wb') as file:
    pickle.dump(prob_dict_results_1bi, file)
    

# Save the data to a file
file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/1.1_results_dict_1bi.pkl'  # Replace with your desired path
with open(file_results_dict_path, 'wb') as file:
    pickle.dump(results_dict_1bi, file)"""


In [None]:
"""

# V8i Raw + MW + Drop
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt"
model_MMT = mrtf.load_MMT_model(config)
prob_dict_results_1ci, results_dict_1ci = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)
print(np.mean(results_dict_1ci["tanimoto_sim"]))


import pickle
# Save the data to a file
file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/1.0_prob_dict_results_1ci.pkl'  # Replace with your desired path
with open(file_prob_dict_path, 'wb') as file:
    pickle.dump(prob_dict_results_1ci, file)
        

# Save the data to a file
file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/1.1_results_dict_1ci.pkl'  # Replace with your desired path
with open(file_results_dict_path, 'wb') as file:
    pickle.dump(results_dict_1ci, file)"""


#### Load calculated results


In [None]:
import pickle

# Load the precomputed probability dicts returned by mrtf.run_model_analysis for the
# three V8i training variants, instead of re-running model inference:
#   1ai: Raw    1bi: Raw + MW    1ci: Raw + MW + Dropout
# NOTE: pickle.load can execute arbitrary code; only load these self-generated files.

file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240508_1_Trainings_Experiments/1.0_prob_dict_results_1ai__base.pkl"
with open(file_path, 'rb') as fh:  # V8i Raw
    prob_dict_results_1ai = pickle.load(fh)

file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240508_1_Trainings_Experiments/1.0_prob_dict_results_1bi.pkl"
with open(file_path, 'rb') as fh:  # V8i Raw + MW
    prob_dict_results_1bi = pickle.load(fh)

file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240508_1_Trainings_Experiments/1.0_prob_dict_results_1ci.pkl"
with open(file_path, 'rb') as fh:  # V8i Raw + MW + Dropout
    prob_dict_results_1ci = pickle.load(fh)

In [None]:
import pickle

# Load the precomputed greedy-decoding results dicts (e.g. "tanimoto_sim", "failed",
# "trg_conv_SMI_list" are read further below) for the three V8i variants:
#   1ai: Raw    1bi: Raw + MW    1ci: Raw + MW + Dropout
# NOTE(review): the 1bi/1ci filenames end in "_.pkl" while the matching prob-dict files
# do not; confirm these are the intended files.
# NOTE: pickle.load can execute arbitrary code; only load these self-generated files.

file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240508_1_Trainings_Experiments/1.1_results_dict_1ai__base.pkl"
with open(file_path, 'rb') as fh:  # V8i Raw
    results_dict_1ai = pickle.load(fh)

file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240508_1_Trainings_Experiments/1.1_results_dict_1bi_.pkl"
with open(file_path, 'rb') as fh:  # V8i Raw + MW
    results_dict_1bi = pickle.load(fh)

file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240508_1_Trainings_Experiments/1.1_results_dict_1ci_.pkl"
with open(file_path, 'rb') as fh:  # V8i Raw + MW + Dropout
    results_dict_1ci = pickle.load(fh)
    


In [None]:
# Bundle the three V8i variants in plotting order: Raw, Raw + MW, Raw + MW + Dropout.
prob_dict_results_1 = list((prob_dict_results_1ai, prob_dict_results_1bi, prob_dict_results_1ci))
results_dict_1 = list((results_dict_1ai, results_dict_1bi, results_dict_1ci))


#### Correct Sample Probability

##### Bar Chart

In [None]:
# Same grouping as the earlier cell, in the order Raw, Raw + MW, Raw + MW + Dropout.
prob_dict_results_1 = [
    prob_dict_results_1ai,
    prob_dict_results_1bi,
    prob_dict_results_1ci,
]
results_dict_1 = [
    results_dict_1ai,
    results_dict_1bi,
    results_dict_1ci,
]


In [None]:
# Greedy results for the three V8i variants (Raw, Raw + MW, Raw + MW + Dropout).
results_dict_2 = [results_dict_1ai, results_dict_1bi, results_dict_1ci]
# Number of target SMILES in the Raw run; used as the denominator when invalid-SMILES
# counts are turned into percentages.
test_data_size = sum(1 for _ in results_dict_1ai["trg_conv_SMI_list"])

##### Violin Plot

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Violin plot of the probability the model assigns to the correct SMILES
# ("aggregated_corr_prob_multi") for the three V8i training variants.
prob_dict_results_2 = [prob_dict_results_1ai, prob_dict_results_1bi, prob_dict_results_1ci]
labels = ['Raw', 'Raw \nMW', 'Raw \nMW \nDropout']

data_for_violin = [d["aggregated_corr_prob_multi"] for d in prob_dict_results_2]
mean_values = [np.mean(d) for d in data_for_violin]
# violinplot's default x positions are 1..N; derive them instead of hard-coding [1, 2, 3].
positions = np.arange(1, len(data_for_violin) + 1)

fig, ax = plt.subplots(figsize=(8, 10))
parts = ax.violinplot(data_for_violin, positions=positions,
                      showmeans=True, showmedians=False, showextrema=False)
for pc in parts['bodies']:
    pc.set_facecolor("#A1C8F3")
    pc.set_edgecolor('black')
    pc.set_alpha(0.7)

ax.set_title('Correct SMILES Sample Probability', fontsize=22)
ax.set_ylabel('Probability of Correct SMILES', fontsize=22)
ax.set_xticks(positions)
ax.set_xticklabels(labels, ha='center', fontsize=22)
ax.tick_params(axis='y', labelsize=22)  # object API instead of the pyplot-state plt.yticks
ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.set_ylim(0, 1)

# Print each mean just above its mean line.
for pos, mean_value in zip(positions, mean_values):
    ax.text(pos, mean_value, f'{mean_value:.2f}', ha='center', va='bottom', fontsize=22)

fig.tight_layout()
save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/1.0_violin_plot_v3.png'
# Same export settings as the bar-chart figures (300 dpi, tight bounding box).
fig.savefig(save_path, format='png', dpi=300, bbox_inches='tight')
plt.show()


#### Greedy Sampled Tanimoto Similarity

##### Violin Plot

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Violin plot of the greedy-sampled Tanimoto similarity ("tanimoto_sim") for the three
# V8i training variants.
results_dict_2 = [results_dict_1ai, results_dict_1bi, results_dict_1ci]
labels = ['Raw', 'Raw \nMW', 'Raw \nMW \nDropout']

data_for_violin = [d["tanimoto_sim"] for d in results_dict_2]
mean_values = [np.mean(d) for d in data_for_violin]
# violinplot's default x positions are 1..N; derive them instead of hard-coding [1, 2, 3].
positions = np.arange(1, len(data_for_violin) + 1)

fig, ax = plt.subplots(figsize=(8, 10))
parts = ax.violinplot(data_for_violin, positions=positions,
                      showmeans=True, showmedians=False, showextrema=False)
for pc in parts['bodies']:
    pc.set_facecolor('#A1C8F3')
    pc.set_edgecolor('black')
    pc.set_alpha(0.7)

ax.set_title('Greedy Sampled Average Tanimoto Similarity', fontsize=22)
ax.set_ylabel('Average Tanimoto Similarity', fontsize=22)
ax.set_xticks(positions)
ax.set_xticklabels(labels, ha='center', fontsize=22)
ax.tick_params(axis='y', labelsize=22)  # object API instead of the pyplot-state plt.yticks
ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.set_ylim(0, 1)

# Print each mean just above its mean line.
for pos, mean_value in zip(positions, mean_values):
    ax.text(pos, mean_value, f'{mean_value:.2f}', ha='center', va='bottom', fontsize=22)

fig.tight_layout()
save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/1.0_Tanimoto_Similarity_v3.png'
# Same export settings as the bar-chart figures (300 dpi, tight bounding box).
fig.savefig(save_path, format='png', dpi=300, bbox_inches='tight')
plt.show()

#### Number of Invalid Molecules

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Number of invalid greedy-sampled SMILES per V8i variant (length of each results dict's
# "failed" list), with the share of the test set (test_data_size) written above each bar.
results_dict_2 = [results_dict_1ai, results_dict_1bi, results_dict_1ci]
mean_results_2 = [len(d["failed"]) for d in results_dict_2]
labels = ['Raw', 'Raw \nMW', 'Raw \nMW \nDropout']
color = '#A1C8F3'
bar_width = 0.5
positions = np.arange(len(labels))

fig, ax = plt.subplots(figsize=(8, 10))
# No error bars are drawn, so no capsize/yerr is passed.
bars = ax.bar(positions, mean_results_2, bar_width,
              color=color, edgecolor='black', label='MMST')

for bar in bars:
    yval = bar.get_height()
    x_center = bar.get_x() + bar.get_width() / 2
    percentage = (yval / test_data_size) * 100
    # Offset in points (not a fixed +500 in data units) so the label sits just above
    # the bar whatever the magnitude of the counts.
    ax.annotate(f'{percentage:.2f}%', xy=(x_center, yval), xytext=(0, 5),
                textcoords='offset points', rotation=90, ha='center', va='bottom', fontsize=22)
    # Absolute count in the middle of the bar.
    ax.text(x_center, yval / 2, f'{yval:,.0f}',
            rotation=90, ha='center', va='center', fontsize=22, color='black')

ax.set_title('Greedy Sampled Number of Invalid SMILES', fontsize=22)
ax.set_ylabel('Number of Invalid Molecules', fontsize=22)
ax.set_xticks(positions)
ax.set_xticklabels(labels, ha='center', fontsize=22)
ax.tick_params(axis='y', labelsize=22)
ax.grid(axis='y', linestyle='--', alpha=0.7)
# 20 % headroom for the rotated percentage labels; max(..., 1) avoids a singular
# (0, 0) y-range when no variant produced invalid SMILES.
ax.set_ylim(0, max(max(mean_results_2), 1) * 1.2)

fig.tight_layout()
save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/1.0_Invalid_Molecules_v3.png'
fig.savefig(save_path, format='png', dpi=300, bbox_inches='tight')
plt.show()

### 2.0 Ablation Study fine-tuned

#### Run Experiment

In [None]:
# Disabled cell (string literal, closed at the end): evaluates the fine-tuned ablation
# model trained without 1H ("13C_HSQC_COSY_IR_MF_MW") and pickles both outputs.
'''config.multinom_runs = 1
config.training_mode = "13C_HSQC_COSY_IR_MF_MW"
config.data_size = 489993
# No 1H
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_4Mio_no_H_3/model-epoch=00-loss=0.03.ckpt"
model_MMT = mrtf.load_MMT_model(config)
val_dataloader_multi = mrtf.load_data(config, stoi, stoi_MF, single=False, mode="val")
prob_dict_results_2_1a, results_dict_2_1a = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)
#print(np.mean(results_dict_2_1a["tanimoto_sim"]))


import pickle
# Save the probability dict
file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_prob_dict_results_2_1a.pkl'
with open(file_prob_dict_path, 'wb') as file:
    pickle.dump(prob_dict_results_2_1a, file)

# Save the results dict
file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_results_dict_2_1a.pkl'
with open(file_results_dict_path, 'wb') as file:
    pickle.dump(results_dict_2_1a, file)'''

In [None]:
# Disabled cell (string literal, closed at the end): evaluates the fine-tuned ablation
# model trained without 13C ("1H_HSQC_COSY_IR_MF_MW") and pickles both outputs.
'''config.multinom_runs = 1
config.training_mode = "1H_HSQC_COSY_IR_MF_MW"
config.data_size = 489993
# No 13C
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_4Mio_no_C/model-epoch=00-loss=0.03.ckpt"
model_MMT = mrtf.load_MMT_model(config)
val_dataloader_multi = mrtf.load_data(config, stoi, stoi_MF, single=False, mode="val")
prob_dict_results_2_1b, results_dict_2_1b = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)
#print(np.mean(results_dict_2_1b["tanimoto_sim"]))

import pickle
# Save the probability dict
file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_prob_dict_results_2_1b.pkl'
with open(file_prob_dict_path, 'wb') as file:
    pickle.dump(prob_dict_results_2_1b, file)

# Save the results dict
file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_results_dict_2_1b.pkl'
with open(file_results_dict_path, 'wb') as file:
    pickle.dump(results_dict_2_1b, file)'''

In [None]:
# Disabled cell (string literal, closed at the end): evaluates the fine-tuned ablation
# model trained without HSQC ("1H_13C_COSY_IR_MF_MW") and pickles both outputs.
'''config.multinom_runs = 1
config.training_mode = "1H_13C_COSY_IR_MF_MW"
config.data_size = 489993
# No HSQC
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_4Mio_no_HSQC/model-epoch=00-loss=0.03.ckpt"
model_MMT = mrtf.load_MMT_model(config)
val_dataloader_multi = mrtf.load_data(config, stoi, stoi_MF, single=False, mode="val")
prob_dict_results_2_1c, results_dict_2_1c = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)
#print(np.mean(results_dict_2_1c["tanimoto_sim"]))

import pickle
# Save the probability dict
file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_prob_dict_results_2_1c.pkl'
with open(file_prob_dict_path, 'wb') as file:
    pickle.dump(prob_dict_results_2_1c, file)

# Save the results dict
file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_results_dict_2_1c.pkl'
with open(file_results_dict_path, 'wb') as file:
    pickle.dump(results_dict_2_1c, file)'''

In [None]:
# Disabled cell (string literal, closed at the end): evaluates the fine-tuned ablation
# model trained without COSY ("1H_13C_HSQC_IR_MF_MW") and pickles both outputs.
'''config.multinom_runs = 1
config.training_mode = "1H_13C_HSQC_IR_MF_MW"
config.data_size = 489993
# No COSY
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_4Mio_no_COSY/model-epoch=00-loss=0.03.ckpt"
model_MMT = mrtf.load_MMT_model(config)
val_dataloader_multi = mrtf.load_data(config, stoi, stoi_MF, single=False, mode="val")
prob_dict_results_2_1d, results_dict_2_1d = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)
#print(np.mean(results_dict_2_1d["tanimoto_sim"]))

import pickle
# Save the probability dict
file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_prob_dict_results_2_1d.pkl'
with open(file_prob_dict_path, 'wb') as file:
    pickle.dump(prob_dict_results_2_1d, file)

# Save the results dict
file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_results_dict_2_1d.pkl'
with open(file_results_dict_path, 'wb') as file:
    pickle.dump(results_dict_2_1d, file)'''

In [None]:
# Disabled cell (string literal, closed at the end): evaluates the fine-tuned ablation
# model trained without IR ("1H_13C_HSQC_COSY_MF_MW") and pickles both outputs.
'''config.multinom_runs = 1
config.training_mode = "1H_13C_HSQC_COSY_MF_MW"
config.data_size = 489993
# No IR
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_4Mio_no_IR/model-epoch=00-loss=0.02.ckpt"
model_MMT = mrtf.load_MMT_model(config)
val_dataloader_multi = mrtf.load_data(config, stoi, stoi_MF, single=False, mode="val")
prob_dict_results_2_1e, results_dict_2_1e = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)
#print(np.mean(results_dict_2_1e["tanimoto_sim"]))

import pickle
# Save the probability dict
file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_prob_dict_results_2_1e.pkl'
with open(file_prob_dict_path, 'wb') as file:
    pickle.dump(prob_dict_results_2_1e, file)

# Save the results dict. The path must be assigned to the same variable that is opened;
# otherwise the write goes to a stale path left over from another cell.
file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_results_dict_2_1e.pkl'
with open(file_results_dict_path, 'wb') as file:
    pickle.dump(results_dict_2_1e, file)'''

#### Load data

In [None]:
import pickle

# Load the pickled greedy-decoding results dicts of the fine-tuned ablation models,
# each trained without one modality:
#   2_1a: no 1H   2_1b: no 13C   2_1c: no HSQC   2_1d: no COSY   2_1e: no IR
# NOTE: pickle.load can execute arbitrary code; only load these self-generated files.
file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_results_dict_2_1a.pkl'
with open(file_results_dict_path, 'rb') as fh:
    results_dict_2_1a = pickle.load(fh)

file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_results_dict_2_1b.pkl'
with open(file_results_dict_path, 'rb') as fh:
    results_dict_2_1b = pickle.load(fh)

file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_results_dict_2_1c.pkl'
with open(file_results_dict_path, 'rb') as fh:
    results_dict_2_1c = pickle.load(fh)

file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_results_dict_2_1d.pkl'
with open(file_results_dict_path, 'rb') as fh:
    results_dict_2_1d = pickle.load(fh)

file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_results_dict_2_1e.pkl'
with open(file_results_dict_path, 'rb') as fh:
    results_dict_2_1e = pickle.load(fh)


In [None]:
# Load the pickled correct-SMILES probability dicts for the five ablation models
# (2_1a no 1H, 2_1b no 13C, 2_1c no HSQC, 2_1d no COSY, 2_1e no IR), in that order.
_ablation_prob_paths = [
    '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_prob_dict_results_2_1a.pkl',
    '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_prob_dict_results_2_1b.pkl',
    '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_prob_dict_results_2_1c.pkl',
    '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_prob_dict_results_2_1d.pkl',
    '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240624_FT_Ablation_Study/2.1_prob_dict_results_2_1e.pkl',
]
_loaded_prob_dicts = []
for file_prob_dict_path in _ablation_prob_paths:
    with open(file_prob_dict_path, 'rb') as file:
        _loaded_prob_dicts.append(pickle.load(file))

(prob_dict_results_2_1a,
 prob_dict_results_2_1b,
 prob_dict_results_2_1c,
 prob_dict_results_2_1d,
 prob_dict_results_2_1e) = _loaded_prob_dicts
    

In [None]:
# Baseline model using all modalities ("2ai"): its probability dict and results dict
# sit directly under precomputed_raw_data, not in the ablation sub-folder.
file_prob_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/2.0_prob_dict_results_2ai.pkl'
file_results_dict_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/2.1_results_dict_2ai.pkl'

with open(file_prob_dict_path, 'rb') as file:
    prob_dict_results_2_1all = pickle.load(file)

with open(file_results_dict_path, 'rb') as file:
    results_dict_2_1all = pickle.load(file)

In [None]:
# Plotting order: all-modality baseline first, then one ablation per removed modality
# (1H, 13C, HSQC, COSY, IR).
results_dict_2 = [results_dict_2_1all] + [
    results_dict_2_1a, results_dict_2_1b, results_dict_2_1c,
    results_dict_2_1d, results_dict_2_1e,
]

In [None]:
# Same order as results_dict_2: baseline, then without 1H, 13C, HSQC, COSY, IR.
prob_dict_results_2 = [prob_dict_results_2_1all] + [
    prob_dict_results_2_1a, prob_dict_results_2_1b, prob_dict_results_2_1c,
    prob_dict_results_2_1d, prob_dict_results_2_1e,
]


#### Violin Plot

In [None]:

# Violin plot of the probability assigned to the correct SMILES
# ("aggregated_corr_prob_multi") for the all-modality model and each ablation.
data_for_violin = [d["aggregated_corr_prob_multi"] for d in prob_dict_results_2]
labels = ['All', 'W/O 1H', 'W/O 13C', 
          'W/O HSQC', 'W/O COSY', 'W/O IR']
# violinplot's default x positions are 1..N.
positions = np.arange(1, len(data_for_violin) + 1)

fig, ax = plt.subplots(figsize=(8, 10))
parts = ax.violinplot(data_for_violin, positions=positions,
                      showmeans=True, showmedians=False, showextrema=False)

# One uniform fill colour; a fixed-length colour list would break with more models.
for pc in parts['bodies']:
    pc.set_facecolor('#A1C8F3')
    pc.set_edgecolor('black')
    pc.set_alpha(0.7)

# Print each mean just above its mean line.
mean_values = [np.mean(data) for data in data_for_violin]
for pos, mean_value in zip(positions, mean_values):
    ax.text(pos, mean_value, f'{mean_value:.2f}', ha='center', va='bottom', fontsize=22)

ax.set_title('Averaged Correct SMILES Sample Probability', fontsize=22)
ax.set_ylabel('Probability of Correct SMILES', fontsize=22)
ax.set_xticks(positions)
ax.set_xticklabels(labels, rotation=45, ha='center', fontsize=22)
ax.tick_params(axis='both', which='major', labelsize=22)
ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.set_ylim(0, 1)  # probabilities

fig.tight_layout()
save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/2.1_Ablation_Violin_chart_v3.png'
# Same export settings as the bar-chart figures (300 dpi, tight bounding box).
fig.savefig(save_path, format='png', dpi=300, bbox_inches='tight')
plt.show()


#### Histogram

In [None]:
"""list_list = []
for prob_dict_results in prob_dict_results_2:
    list_list.append((prob_dict_results["aggregated_corr_prob_multi"]))
    print(len(prob_dict_results["aggregated_corr_prob_multi"]))

# Labels for each subplot
labels = ['All', 'W/O 1H', 'W/O 13C', 
          'W/O HSQC', 'W/O COSY', 'W/O IR']

# Creating a 2x3 grid of subplots
fig, axes = plt.subplots(3, 2, figsize=(15, 15))

# Flatten the axes array for easy iteration
axes = axes.ravel()

for i, ax in enumerate(axes):
    if i != 12:
        # Plot histogram for each sublist
        ax.hist(list_list[i], bins=20, color='#A1C8F3', edgecolor='black')
        ax.set_title(f'Histogram of {labels[i]}', fontsize=16)
        ax.set_xlabel('Value', fontsize=16)
        ax.set_ylabel('Frequency', fontsize=16)
        
plt.yticks(fontsize=16)
plt.xticks(fontsize=16)

# Adjust layout
# Adjust layout to prevent overlap
plt.tight_layout()
save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/2.1_Histogram_Ablation_v3.png'  # Replace with your desired path
plt.savefig(save_path, format='png')"""

#### Failed Molecules

In [None]:
import matplotlib.pyplot as plt
import statistics

# Invalid greedy-sampled molecules per ablation model: the length of each results
# dict's "failed" list (a count, despite the variable name).
mean_results_2 = [len(results["failed"]) for results in results_dict_2]

labels = ['All', 'W/O 1H', 'W/O 13C', 
          'W/O HSQC', 'W/O COSY', 'W/O IR']
color = '#A1C8F3'
x_slots = range(len(mean_results_2))

fig, ax = plt.subplots(figsize=(8, 10))
# Bars get a black outline straight from ax.bar.
bars = ax.bar(x_slots, mean_results_2, align='center', alpha=0.7, ecolor='black',
              capsize=10, edgecolor='black', color=color)

# Count written vertically in the middle of each bar.
for rect in bars:
    height = rect.get_height()
    ax.text(rect.get_x() + rect.get_width() / 2, height / 2, f'{height:,.0f}',
            rotation=90, ha='center', va='center', fontsize=22, color='black')

ax.set_title('Greedy Sampled Number of Invalid Molecules   ', fontsize=22)
ax.set_ylabel('Number of Invalid Molecules', fontsize=22)
ax.set_xticks(x_slots)
ax.set_xticklabels(labels, rotation=45, ha='center', fontsize=22)
ax.tick_params(axis='both', which='major', labelsize=22)
ax.grid(axis='y', linestyle='--', alpha=0.7)

plt.tight_layout()

save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/2.1_Ablation_Invalid_molecules_v3.png'  # Replace with your desired path
plt.savefig(save_path, format='png', dpi=300, bbox_inches='tight')
plt.show()

#### Tanimoto comparison

In [None]:
# Quick look at the baseline ("All") run's Tanimoto scores: how many there are and the
# first few values, instead of echoing the whole list into the notebook output.
_tanimoto_all = results_dict_2_1all.get("tanimoto_sim", [])
len(_tanimoto_all), list(_tanimoto_all)[:5]

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Tanimoto-similarity violins for the baseline and each ablation model.
# The baseline dict is read from "tanimoto_sim", the ablation dicts from
# "tanimoto_scores_all".
# NOTE(review): the two keys come from differently produced pickles; confirm they hold
# the same (greedy) quantity before comparing them in one plot.
results_dict_2 = [results_dict_2_1all, 
                  results_dict_2_1a, 
                  results_dict_2_1b, 
                  results_dict_2_1c, 
                  results_dict_2_1d, 
                  results_dict_2_1e]
labels = ['All', 'W/O 1H', 'W/O 13C', 
          'W/O HSQC', 'W/O COSY', 'W/O IR']

# Collect non-empty score lists; models missing the key or with no scores are skipped.
data_for_violin = []
filtered_labels = []
for d, label in zip(results_dict_2, labels):
    key = "tanimoto_sim" if label == 'All' else "tanimoto_scores_all"
    tanimoto_sim = d.get(key, [])
    if len(tanimoto_sim) > 0:
        data_for_violin.append(tanimoto_sim)
        filtered_labels.append(label)

if not data_for_violin:
    print("Error: All datasets are empty. Unable to create violin plot.")
else:
    fig, ax = plt.subplots(figsize=(8, 10))
    positions = np.arange(1, len(filtered_labels) + 1)
    parts = ax.violinplot(data_for_violin, positions=positions,
                          showmeans=True, showmedians=False, showextrema=False)

    for pc in parts['bodies']:
        pc.set_facecolor('#A1C8F3')
        pc.set_edgecolor('black')
        pc.set_alpha(0.7)

    ax.set_title('Average Greedy Sampled Tanimoto Similarity', fontsize=22)
    ax.set_ylabel('Tanimoto Similarity', fontsize=22)
    ax.set_xticks(positions)
    ax.set_xticklabels(filtered_labels, rotation=45, ha='center', fontsize=22)
    ax.tick_params(axis='both', which='major', labelsize=22)
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    ax.set_ylim(0, 1)

    # Print each mean just above its mean line.
    for pos, data in zip(positions, data_for_violin):
        mean_value = np.mean(data)
        ax.text(pos, mean_value, f'{mean_value:.2f}', ha='center', va='bottom', fontsize=22)

    fig.tight_layout()
    save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/2.1_Ablation_Tanimoto_v3.png'
    # Same export settings as the bar-chart figures (300 dpi, tight bounding box).
    fig.savefig(save_path, format='png', dpi=300, bbox_inches='tight')
    plt.show()

# Compact summary instead of dumping every individual score.
for label, data in zip(filtered_labels, data_for_violin):
    print(f"{label}: n={len(data)}, mean={np.mean(data):.3f}")


### 3.0 Percentage Calculations V8 and V8i with TRUE or FALSE

#### Generate a table with Greedy/Top 1/3/5/10 accuracy


#### 3.2 Load/process and plot data

In [None]:
# Standard library imports
import glob
import json
import os
import random
from collections import defaultdict
from functools import reduce
from argparse import Namespace

# Data processing and scientific computing
import numpy as np
import pandas as pd
from tqdm import tqdm
import statistics
import operator

# Visualization libraries
import matplotlib.pyplot as plt

# PyTorch for deep learning
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader

# RDKit for cheminformatics
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Draw, Descriptors, MolFromSmiles, MolToSmiles

# Weights & Biases for experiment tracking
import wandb

# Local utilities/modules
import utils_MMT.sgnn_code_pl_v15_4 as sc
import utils_MMT.train_test_functions_pl_v15_4 as ttf
import utils_MMT.helper_functions_pl_v15_4 as hf
import utils_MMT.validate_generate_MMT_v15_4 as vgmmt
import utils_MMT.run_batch_gen_val_MMT_v15_4 as rbgvm
from utils_MMT.dataloaders_pl_v15_4 import MultimodalData, collate_fn
from utils_MMT.models_MMT_v15_4 import MultimodalTransformer, TransformerMultiGPU
from utils_MMT.models_CLIP_v15_4 import CLIPMultiGPU  
from utils_MMT.models_BLIP_v15_4 import BLIPMultiGPU  




def worker_init_fn(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy's and Python's global RNGs.

    Both generators are seeded with ``torch.initial_seed()`` reduced modulo 2**32,
    because ``np.random.seed`` only accepts 32-bit seeds. ``worker_id`` is accepted to
    match the DataLoader callback signature and is not used.
    """
    base_seed = torch.initial_seed() % (2**32)
    np.random.seed(base_seed)
    random.seed(base_seed)

def prepare_HSQC_data_from_src_2(src_HSQC_list, scale_col0=10, scale_col1=200):
    """
    Filter and rescale a list of HSQC peak tensors.

    Each tensor is indexed as ``src[:, 0]`` / ``src[:, 1]`` (presumably shape
    (n_peaks, 2) with normalised 1H / 13C shifts -- confirm against the dataloader).
    Rows containing any zero (e.g. padding) are dropped, then column 0 is multiplied
    by ``scale_col0`` and column 1 by ``scale_col1``.

    Args:
        src_HSQC_list: iterable of 2-column torch tensors.
        scale_col0: multiplier for column 0 (default 10, the original value).
        scale_col1: multiplier for column 1 (default 200, the original value).

    Returns:
        list with one tensor per input: the rescaled (n_kept, 2) tensor, or an empty
        ``torch.tensor([])`` when no row survives, so output stays aligned with input.
    """
    processed_HSQC = []
    for src in src_HSQC_list:
        # Keep only rows where every element is non-zero; a row with any zero is dropped.
        filtered_src = src[(src != 0).all(dim=1)]
        if filtered_src.nelement() == 0:
            processed_HSQC.append(torch.tensor([]))
            continue
        processed_HSQC.append(torch.stack(
            [filtered_src[:, 0] * scale_col0, filtered_src[:, 1] * scale_col1], dim=1))
    return processed_HSQC

def prepare_COSY_data_from_src_2(src_COSY_list):
    """
    Clean and rescale a list of COSY peak tensors.

    For each 2-D tensor, drop every row that contains any zero, then return
    the first two columns each multiplied by 10.

    Args:
        src_COSY_list: iterable of 2-D tensors (rows = peaks).

    Returns:
        List of the same length as the input. Each entry is an (n, 2) tensor, or
        ``torch.tensor([])`` when no row survives the filter, so positions stay aligned.
    """
    processed_COSY = []
    for peaks in src_COSY_list:
        # Keep only rows where every element is non-zero.
        kept = peaks[(peaks != 0).all(dim=1)]
        if kept.nelement() == 0:
            processed_COSY.append(torch.tensor([]))
            continue
        processed_COSY.append(torch.stack([kept[:, 0] * 10, kept[:, 1] * 10], dim=1))
    return processed_COSY

def generate_df_for_HSQC_calculations(gen_conv_SMI_list, trg_conv_SMI_list, src_HSQC_list):
    """
    Keep only generated SMILES that RDKit can parse, together with their targets and HSQC data.

    Args:
        gen_conv_SMI_list: generated SMILES strings.
        trg_conv_SMI_list: target SMILES, aligned with the generated list.
        src_HSQC_list: source HSQC entries, aligned with the generated list.

    Returns:
        df_succ_smis: DataFrame with columns 'sample-id', 'SMILES', 'trg_SMILES' for parseable
            generations. 'sample-id' is "<position>_<random int 0..100000>", drawn from the
            global ``random`` state.
        src_HSQC_list_new: HSQC entries of the kept rows, in the same order.
        failed_list: ``[gen_smi, trg_smi]`` pairs whose generated SMILES failed to parse.
    """
    rows, kept_HSQC, failed_list = [], [], []
    triples = zip(gen_conv_SMI_list, trg_conv_SMI_list, src_HSQC_list)
    for position, (gen_smi, trg_smi, hsqc) in enumerate(triples):
        if Chem.MolFromSmiles(gen_smi) is None:
            failed_list.append([gen_smi, trg_smi])
            continue
        rows.append([f"{position}_{random.randint(0, 100000)}", gen_smi, trg_smi])
        kept_HSQC.append(hsqc)

    df_succ_smis = pd.DataFrame(rows, columns=['sample-id', 'SMILES', "trg_SMILES"])
    return df_succ_smis, kept_HSQC, failed_list

def generate_df_for_COSY_calculations(gen_conv_SMI_list, trg_conv_SMI_list, src_COSY_list):
    """
    Keep only generated SMILES that RDKit can parse, together with their targets and COSY data.

    Args:
        gen_conv_SMI_list: generated SMILES strings.
        trg_conv_SMI_list: target SMILES, aligned with the generated list.
        src_COSY_list: source COSY entries, aligned with the generated list.

    Returns:
        df_succ_smis: DataFrame with columns 'sample-id', 'SMILES', 'trg_SMILES' for parseable
            generations. 'sample-id' is "<position>_<random int 0..100000>", drawn from the
            global ``random`` state.
        src_COSY_list_new: COSY entries of the kept rows, in the same order.
        failed_list: ``[gen_smi, trg_smi]`` pairs whose generated SMILES failed to parse.
    """
    rows, kept_COSY, failed_list = [], [], []
    triples = zip(gen_conv_SMI_list, trg_conv_SMI_list, src_COSY_list)
    for position, (gen_smi, trg_smi, cosy) in enumerate(triples):
        if Chem.MolFromSmiles(gen_smi) is None:
            failed_list.append([gen_smi, trg_smi])
            continue
        rows.append([f"{position}_{random.randint(0, 100000)}", gen_smi, trg_smi])
        kept_COSY.append(cosy)

    df_succ_smis = pd.DataFrame(rows, columns=['sample-id', 'SMILES', "trg_SMILES"])
    return df_succ_smis, kept_COSY, failed_list


def calculate_corr_max_prob(config, model_MMT, val_dataloader, stoi, itos):
    """
    Teacher-forced per-sequence probability statistics over a validation dataloader.

    For every batch this runs ``vgmmt.run_model`` and ``predict_prop_correct_max_sequence_3``.
    For each sample it collects the per-token probability of the target token ("corr") and of
    the arg-max token ("max"). Collection stops before the first ``<EOS>`` in the target. Each
    per-sample list is then reduced to a product and a mean.

    Args:
        config: run configuration forwarded to the model helpers.
        model_MMT: the multimodal transformer.
        val_dataloader: iterable of batch dicts accepted by ``vgmmt.run_model``.
        stoi: token -> index mapping; must contain "<EOS>".
        itos: unused; kept for interface compatibility.

    Returns:
        dict with the lists "aggregated_corr_prob_multi", "aggregated_corr_prob_avg",
        "aggregated_max_prob_multi" and "aggregated_max_prob_avg", one entry per sample.
        An empty sequence contributes a product of 1 and a mean of 0.
    """
    # math.prod replaces functools.reduce, which is not among this cell's visible imports.
    import math

    aggregated_corr_prob_multi, aggregated_corr_prob_avg = [], []
    aggregated_max_prob_multi, aggregated_max_prob_avg = [], []
    eos_idx = stoi["<EOS>"]

    for data_dict in tqdm(val_dataloader):
        memory, src_padding_mask, trg_enc_SMI, fingerprint, src_HSQC, src_COSY = vgmmt.run_model(model_MMT, data_dict, config)
        trg_tensor, corr_token_prob, _, max_token_prob = predict_prop_correct_max_sequence_3(
            model_MMT, stoi, memory, src_padding_mask, trg_enc_SMI, config)

        # With a single-sample batch the helper squeezes the probabilities to (seq_len,).
        # Restore the (seq_len, 1) layout so the transpose below yields one row per sample.
        if corr_token_prob.dim() == 1:
            corr_token_prob = corr_token_prob.unsqueeze(1)
            max_token_prob = max_token_prob.unsqueeze(1)

        # After .T: one probability row per sample, aligned with the rows of trg_tensor.
        for corr_row, max_row, tokens in zip(corr_token_prob.T, max_token_prob.T, trg_tensor):
            seq_corr_probs, seq_max_probs = [], []
            for corr_prob, max_prob, token in zip(corr_row, max_row, tokens):
                if token == eos_idx:
                    break
                seq_corr_probs.append(corr_prob.item())
                seq_max_probs.append(max_prob.item())

            aggregated_corr_prob_multi.append(math.prod(seq_corr_probs))
            aggregated_corr_prob_avg.append(statistics.mean(seq_corr_probs) if seq_corr_probs else 0)
            aggregated_max_prob_multi.append(math.prod(seq_max_probs))
            aggregated_max_prob_avg.append(statistics.mean(seq_max_probs) if seq_max_probs else 0)

    return {
        "aggregated_corr_prob_multi": aggregated_corr_prob_multi,
        "aggregated_corr_prob_avg": aggregated_corr_prob_avg,
        "aggregated_max_prob_multi": aggregated_max_prob_multi,
        "aggregated_max_prob_avg": aggregated_max_prob_avg,
    }


def evaluate_greedy_2(model, stoi, itos, val_dataloader, config, randomize=False):
    """
    Greedy-decode every batch of ``val_dataloader`` and score the generations against the targets.

    Args:
        model: the multimodal transformer.
        stoi / itos: token <-> index mappings.
        val_dataloader: iterable of batch dicts; each must contain "trg_enc_SMI".
        config: run configuration forwarded to the model helpers.
        randomize: if truthy, shuffle the generated SMILES in place (global ``random`` state)
            after computing validity and before the Tanimoto comparison. This pairs
            generations with the wrong targets, presumably to form a random baseline.

    Returns:
        (results_dict, src_HSQC_list, src_COSY_list).
        results_dict holds the generated/target SMILES lists, the per-batch greedy token
        probabilities, the validity term from ``hf.get_validity_term``, the outputs of
        ``hf.calculate_tanimoto_similarity_2`` and the raw batch dicts.
        The source lists hold the HSQC/COSY entries returned by ``vgmmt.run_model``, in
        batch order.
    """
    gen_conv_SMI_list = []
    trg_conv_SMI_list = []
    src_HSQC_list = []
    src_COSY_list = []
    token_probs_list = []
    data_dict_list = []

    for i, data_dict in tqdm(enumerate(val_dataloader)):
        memory, src_padding_mask, trg_enc_SMI, fingerprint, src_HSQC, src_COSY = vgmmt.run_model(model,
                                                                                                 data_dict,
                                                                                                 config)
        greedy_tensor, greedy_token_prob = greedy_sequence_2(model, stoi, itos, memory, src_padding_mask, config)
        gen_conv_SMI = hf.tensor_to_smiles(greedy_tensor, itos)
        token_probs_list.append(greedy_token_prob)
        gen_conv_SMI_list.extend(gen_conv_SMI)

        # Transpose the encoded targets and drop the first position (presumably the <SOS> token).
        trg_SMI_input = data_dict["trg_enc_SMI"].transpose(0, 1)[1:, :]
        # extend() instead of `list = list + new`, which copied the whole list on every batch.
        trg_conv_SMI_list.extend(hf.tensor_to_smiles(trg_SMI_input, itos))
        src_HSQC_list.extend(src_HSQC)
        src_COSY_list.extend(src_COSY)
        data_dict_list.append(data_dict)

    # Validity is measured on the generations in their original order.
    validity_term = hf.get_validity_term(gen_conv_SMI_list)
    if randomize:
        random.shuffle(gen_conv_SMI_list)

    tanimoto_mean, tanimoto_std_dev, failed_pairs, tanimoto_scores_, tanimoto_scores_all, gen_conv_SMI_list_, trg_conv_SMI_list_, idx_list = hf.calculate_tanimoto_similarity_2(gen_conv_SMI_list, trg_conv_SMI_list)

    results_dict = {
            'gen_conv_SMI_list': gen_conv_SMI_list,
            'trg_conv_SMI_list': trg_conv_SMI_list,
            'gen_conv_SMI_list_': gen_conv_SMI_list_,
            'trg_conv_SMI_list_': trg_conv_SMI_list_,
            'idx_list': idx_list,
            'token_probs_list': token_probs_list,
            'validity_term': validity_term,
            'tanimoto_scores_': tanimoto_scores_,
            'tanimoto_scores_all': tanimoto_scores_all,
            'data_dict_list': data_dict_list,
            'failed': failed_pairs}
    return results_dict, src_HSQC_list, src_COSY_list



def predict_prop_correct_max_sequence_3(model, stoi, memory, src_padding_mask, trg_enc_SMI, config):
    """
    Teacher-forced decoding that records, at every step, the probability of the target token
    and of the arg-max token.

    The decoder is always fed the ground-truth prefix (``<SOS>`` followed by the target tokens),
    never its own predictions. The arg-max tokens are recorded in ``trg_tensor_max`` but are
    not fed back.

    Args:
        model: transformer exposing ``embed_trg``, ``pe_trg``, ``dropout2``, ``decoder``,
            ``fc_out`` and ``generate_square_subsequent_mask``. It is switched to eval mode.
        stoi: token -> index mapping; must contain "<SOS>".
        memory: encoder output; ``memory.size(1)`` is used as the batch size N.
        src_padding_mask: passed to the decoder as ``memory_key_padding_mask``.
        trg_enc_SMI: encoded targets, indexed as (N, L). They are transposed to (L, N) and the
            first position (presumably <SOS>) is dropped.
        config: needs ``device`` and ``temperature``.

    Returns:
        trg_tensor: (N, L-1) teacher-forced target ids, <SOS> removed.
        corr_token_prob: (L-1, N) probability assigned to the target token at each step.
        trg_tensor_max: (N, L-1) arg-max token id at each step.
        max_token_prob: (L-1, N) probability of the arg-max token.
        When N == 1 the probability tensors are squeezed to shape (L-1,).
    """
    model.eval()

    device = config.device
    N = memory.size(1)
    sos_idx = stoi["<SOS>"]
    trg_tensor = torch.full((1, N), sos_idx, dtype=torch.long, device=device)
    trg_tensor_max = torch.full((1, N), sos_idx, dtype=torch.long, device=device)
    corr_token_prob, max_token_prob = [], []

    # (L, N) targets without their first position: the tokens whose probabilities are scored.
    real_trg = trg_enc_SMI.transpose(0, 1)[1:, :]

    with torch.no_grad():
        for idx in range(real_trg.shape[0]):
            gen_seq_length, N = trg_tensor.shape
            gen_positions = torch.arange(gen_seq_length).unsqueeze(1).expand(gen_seq_length, N).to(device)
            embedding_gen = model.dropout2(model.embed_trg(trg_tensor) + model.pe_trg(gen_positions))
            gen_mask = model.generate_square_subsequent_mask(gen_seq_length).to(device)

            output = model.decoder(embedding_gen, memory, tgt_mask=gen_mask, memory_key_padding_mask=src_padding_mask)
            probabilities = F.softmax(model.fc_out(output) / config.temperature, dim=2)
            last_step = probabilities[-1]  # distribution over the vocabulary for the next token

            # Arg-max token and its probability; recorded only, never fed back.
            next_word = torch.argmax(last_step, dim=1)
            max_token_prob.append(last_step.gather(1, next_word.unsqueeze(-1)).squeeze())
            trg_tensor_max = torch.cat((trg_tensor_max, next_word.unsqueeze(0)), dim=0)

            # idx always indexes a valid row here. The former `idx <= len` guard was always true.
            target_word = real_trg[idx]
            corr_token_prob.append(last_step.gather(1, target_word.unsqueeze(-1)).squeeze())

            # Teacher forcing: extend the decoder input with the true token.
            trg_tensor = torch.cat((trg_tensor, target_word.unsqueeze(0)), dim=0)

    max_token_prob = torch.stack(max_token_prob)
    corr_token_prob = torch.stack(corr_token_prob)

    # (L, N) -> (N, L), then drop the leading <SOS> column.
    trg_tensor = trg_tensor.transpose(0, 1)[:, 1:]
    trg_tensor_max = trg_tensor_max.transpose(0, 1)[:, 1:]
    return trg_tensor, corr_token_prob, trg_tensor_max, max_token_prob


def run_model_analysis(config, model_MMT, val_dataloader, stoi, itos):
    """
    Run the full validation analysis for one model.

    Steps:
      1. Teacher-forced probability statistics via ``calculate_corr_max_prob``.
      2. Greedy decoding and Tanimoto scoring via ``evaluate_greedy_2``.
      3. For the RDKit-parseable generations, rescale the matching source HSQC/COSY data.
         Then run ``ttf.run_sgnn_sim_calculations_if_possible_return_spectra`` to obtain
         the HSQC/COSY similarity-error lists.

    Returns:
        (prob_dict_results, results_dict). results_dict is the dict returned by
        ``evaluate_greedy_2``, extended with "HSQC_sim_error_list", "COSY_sim_error_list",
        "batch_data", "df_succ_smis", "tensor_HSQC" and "tensor_COSY".
    """
    print("calculate_corr_max_prob")
    prob_dict_results = calculate_corr_max_prob(config, model_MMT, val_dataloader, stoi, itos)
    # Means over an empty result are undefined. Warn explicitly instead of hiding errors behind a
    # bare `except` around summary values that were never used.
    if not prob_dict_results["aggregated_corr_prob_avg"]:
        print("failed statistics")

    print("evaluate_greedy_2")
    results_dict, src_HSQC_list, src_COSY_list = evaluate_greedy_2(model_MMT, stoi, itos, val_dataloader, config, randomize=False)

    trg_conv_SMI_list = results_dict["trg_conv_SMI_list"]
    gen_conv_SMI_list = results_dict["gen_conv_SMI_list"]

    print("generate_df_for_HSQC_calculations")
    # Both helpers filter on the same generated SMILES, so the kept HSQC and COSY entries line up.
    # Only the second DataFrame (with its own random sample ids) is kept.
    _, src_HSQC_list, _ = generate_df_for_HSQC_calculations(gen_conv_SMI_list, trg_conv_SMI_list, src_HSQC_list)
    df_succ_smis, src_COSY_list, _ = generate_df_for_COSY_calculations(gen_conv_SMI_list, trg_conv_SMI_list, src_COSY_list)
    tensor_HSQC = prepare_HSQC_data_from_src_2(src_HSQC_list)
    tensor_COSY = prepare_COSY_data_from_src_2(src_COSY_list)

    print("run_sgnn_sim_calculations_if_possible_return_spectra")
    avg_sim_error_HSQC, avg_sim_error_COSY, HSQC_sim_error_list, COSY_sim_error_list, batch_data = ttf.run_sgnn_sim_calculations_if_possible_return_spectra(df_succ_smis, tensor_HSQC, tensor_COSY, vgmmt.sgnn_means_stds, config)

    results_dict["HSQC_sim_error_list"] = HSQC_sim_error_list
    results_dict["COSY_sim_error_list"] = COSY_sim_error_list
    results_dict["batch_data"] = batch_data
    results_dict["df_succ_smis"] = df_succ_smis
    results_dict["tensor_HSQC"] = tensor_HSQC
    results_dict["tensor_COSY"] = tensor_COSY
    return prob_dict_results, results_dict



def load_data(config, stoi, stoi_MF, single=True, mode="val"):
    """
    Build a non-shuffling DataLoader over ``MultimodalData`` for the requested split.

    No model is loaded here.

    Args:
        config: dataset configuration; ``config.batch_size`` is read when ``single`` is falsy.
        stoi / stoi_MF: token -> index mappings passed to ``MultimodalData``.
        single: if truthy, use batch size 1 with ``drop_last=True``. Otherwise use
            ``config.batch_size`` with ``drop_last=False``.
        mode: dataset split forwarded to ``MultimodalData`` (default "val").

    Returns:
        torch DataLoader using ``collate_fn`` and ``worker_init_fn``.
    """
    dataset = MultimodalData(config, stoi, stoi_MF, mode=mode)
    batch_size, drop_last = (1, True) if single else (config.batch_size, False)
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=False,
                      collate_fn=collate_fn,
                      drop_last=drop_last,
                      worker_init_fn=worker_init_fn)



def load_MMT_model(config):
    """
    Restore the Multimodal Transformer from ``config.checkpoint_path``.

    ``load_from_checkpoint`` is called on the class rather than on a freshly built instance.
    This assumes ``TransformerMultiGPU`` is a PyTorch Lightning module, where this method is a
    classmethod (TODO confirm). Recent Lightning releases reject calls on an instance, and
    building an instance first only creates a throwaway model.

    Returns:
        The wrapper's inner ``.model``, moved to ``config.device``.
    """
    multi_gpu_model = TransformerMultiGPU.load_from_checkpoint(config.checkpoint_path, config=config)
    # Use config.device, like the CLIP/BLIP loaders and the decoding helpers, instead of a
    # hard-coded "cuda". The helpers create their tensors on config.device.
    multi_gpu_model.model.to(config.device)
    return multi_gpu_model.model



def load_CLIP_model(config):
    """
    Restore the CLIP wrapper from ``config.CLIP_model_path`` and return its inner CLIP model.

    ``load_from_checkpoint`` is called on the class. This assumes ``CLIPMultiGPU`` is a PyTorch
    Lightning module with the classmethod form (TODO confirm). Calling it on a freshly built
    instance is rejected by recent Lightning releases and builds a throwaway model.

    Returns:
        ``CLIP_model`` attribute of the restored wrapper; the wrapper is moved to ``config.device``.
    """
    checkpoint_path = config.CLIP_model_path
    CLIP_model = CLIPMultiGPU.load_from_checkpoint(checkpoint_path=checkpoint_path, config=config)
    CLIP_model.to(config.device)
    return CLIP_model.CLIP_model



def load_BLIP_model(config):
    """
    Restore the BLIP wrapper from ``config.BLIP_model_path`` and return its inner BLIP model.

    ``load_from_checkpoint`` is called on the class. This assumes ``BLIPMultiGPU`` is a PyTorch
    Lightning module with the classmethod form (TODO confirm). Calling it on a freshly built
    instance is rejected by recent Lightning releases and builds a throwaway model.

    Returns:
        ``BLIP_model`` attribute of the restored wrapper; the wrapper is moved to ``config.device``.
    """
    checkpoint_path = config.BLIP_model_path
    BLIP_model = BLIPMultiGPU.load_from_checkpoint(checkpoint_path=checkpoint_path, config=config)
    BLIP_model.to(config.device)
    return BLIP_model.BLIP_model


def run_test_mns_performance_CLIP_3(config,
                                    model_MMT,
                                    model_CLIP,
                                    val_dataloader,
                                    stoi,
                                    itos,
                                    MW_filter):
    """
    Multinomial-sample candidate SMILES for every validation sample, then score and rank them.

    Sampling follows the same loop as ``run_multinomial_sampling``. Up to
    ``config.multinom_runs`` unique candidates are collected. ``config.temperature`` rises by
    0.1 after each round, and sampling gives up after 19 rounds. The candidates are then
    scored in three ways:
      - ``model_CLIP.inference`` (loss and dot similarity);
      - ``add_tanimoto_similarity`` against the target;
      - ``rbgvm.add_HSQC_COSY_error``.
    Finally they are sorted by descending element 4 of each row.

    Args:
        config: run configuration. ``config.temperature`` is modified while sampling and is
            always restored afterwards, also if sampling raises.
        model_MMT: generator transformer.
        model_CLIP: model providing ``inference(data_dict, smiles_list)``.
        val_dataloader: iterable of batch dicts. Only element 0 of "trg_enc_SMI" is decoded as
            the target, so presumably single-sample batches.
        stoi / itos: token <-> index mappings.
        MW_filter: if truthy, candidates are also filtered with ``filter_for_MW_2``.

    Returns:
        dict mapping each RDKit-parseable target SMILES to ``[sorted_rows, batch_data]``, or to
        ``[None, None]`` when no candidate survived filtering.
    """
    n_times = config.multinom_runs
    results_dict = {}
    temperature_orig = config.temperature
    for idx, data_dict in enumerate(val_dataloader):
        if idx % 10 == 0:
            print(idx)
        gen_conv_SMI_list, prob_list = [], []
        # 128 copies of the sample: presumably each sampling round yields up to 128 candidates.
        data_dict_dup = rbgvm.duplicate_dict(data_dict, 128)
        trg_enc_SMI = data_dict["trg_enc_SMI"][0]
        trg_conv_SMI = hf.tensor_to_smiles(trg_enc_SMI[1:], itos)
        # Skip samples whose target SMILES RDKit cannot parse.
        if Chem.MolFromSmiles(trg_conv_SMI) is None:
            continue
        memory, src_padding_mask, trg_enc_SMI, fingerprint, src_HSQC, src_COSY = vgmmt.run_model(model_MMT,
                                                                                                 data_dict_dup,
                                                                                                 config)
        counter = 1
        try:
            while len(gen_conv_SMI_list) < n_times:
                print(counter, len(gen_conv_SMI_list), config.temperature)
                # Give up after 19 sampling rounds.
                if counter % 20 == 0:
                    print(trg_conv_SMI)
                    break
                multinom_tensor, multinom_token_prob = multinomial_sequence_multi_2(model_MMT, memory, src_padding_mask, stoi, config)
                gen_conv_SMI, token_probs = hf.tensor_to_smiles_and_prob_2(multinom_tensor, multinom_token_prob, itos)
                gen_conv_SMI, token_probs = filter_probs_and_valid_smiles_and_canonicolize(gen_conv_SMI, token_probs)
                if MW_filter:
                    gen_conv_SMI, token_probs = filter_for_MW_2(trg_conv_SMI, gen_conv_SMI, token_probs)
                gen_conv_SMI_list.extend(gen_conv_SMI)
                prob_list.extend(token_probs)
                gen_conv_SMI_list, prob_list = deduplicate_smiles(gen_conv_SMI_list, prob_list)
                counter += 1
                # Raise the temperature so later rounds yield more distinct molecules.
                config.temperature = config.temperature + 0.1
        finally:
            # Restore even if sampling raised, so `config` is never left at a ramped temperature.
            config.temperature = temperature_orig
        gen_conv_SMI_list = gen_conv_SMI_list[:n_times]
        prob_list = prob_list[:n_times]
        data_dict_dup_CLIP = rbgvm.duplicate_dict(data_dict, len(gen_conv_SMI_list))
        if len(gen_conv_SMI_list) != 0:
            mean_loss, losses, logits, targets, dot_similarity = model_CLIP.inference(data_dict_dup_CLIP,
                                                                                      gen_conv_SMI_list)
            combined_list = [[smile, num.item(), dot_sim.item(), prob]
                             for smile, num, dot_sim, prob in zip(gen_conv_SMI_list, losses, dot_similarity, prob_list)]
            combined_list, failed_combined_list = add_tanimoto_similarity(trg_conv_SMI, combined_list)
            combined_list, batch_data = rbgvm.add_HSQC_COSY_error(config, combined_list, data_dict_dup, gen_conv_SMI_list, trg_conv_SMI, config.multinom_runs)

            # Row layout: SMILES = 0, losses = 1, dot_sim = 2, prob = 3, tanimoto = 4.
            sorted_list = sorted(combined_list, key=lambda x: -x[4])
            results_dict[trg_conv_SMI] = [sorted_list, batch_data]
        else:
            results_dict[trg_conv_SMI] = [None, None]

    return results_dict

def run_test_performance_CLIP_greedy_3(config,
                                       model_MMT,
                                       model_CLIP,
                                       val_dataloader,
                                       stoi,
                                       stoi_MF,
                                       itos,
                                       itos_MF):
    """
    Greedy-decode each validation sample and score it with CLIP, Tanimoto and HSQC error.

    Args:
        config: run configuration forwarded to the helpers.
        model_MMT: generator transformer.
        model_CLIP: model providing ``inference(data_dict, smiles)``.
        val_dataloader: iterable of batch dicts. Only element 0 of "trg_enc_SMI" is decoded as
            the target, so presumably single-sample batches.
        stoi / itos: token <-> index mappings.
        stoi_MF / itos_MF: unused; kept for interface compatibility.

    Returns:
        (results_dict, failed).
        results_dict maps target SMILES to
        ``[[gen_smiles, clip_loss, dot_similarity, tanimoto, sgnn_avg_sim_error], batch_data]``.
        failed lists ``[trg_smiles, gen_smiles, combined_list, batch_data]`` for samples whose
        Tanimoto similarity could not be computed (returned None).
    """
    results_dict = defaultdict(list)
    failed = []

    for idx, data_dict in enumerate(val_dataloader):
        if idx % 10 == 0:
            print(idx)
        data_dict_dup = rbgvm.duplicate_dict(data_dict, 1)

        # vgmmt.run_model is unpacked into six values at every other call site in this notebook.
        # The former five-name unpack here did not match that.
        memory, src_padding_mask, trg_enc_SMI, fingerprint, src_HSQC, src_COSY = vgmmt.run_model(model_MMT,
                                                                                                 data_dict_dup,
                                                                                                 config)
        trg_enc_SMI = data_dict["trg_enc_SMI"][0]
        trg_conv_SMI = hf.tensor_to_smiles(trg_enc_SMI[1:], itos)

        greedy_tensor, greedy_token_prob = greedy_sequence_2(model_MMT, stoi, itos, memory, src_padding_mask, config)
        gen_conv_SMI, token_probs = hf.tensor_to_smiles_and_prob(greedy_tensor.squeeze(1), greedy_token_prob, itos)
        tan_sim = try_calculate_tanimoto_from_two_smiles(trg_conv_SMI, gen_conv_SMI, 512, extra_info=False)
        mean_loss, losses, logits, targets, dot_similarity = model_CLIP.inference(data_dict_dup, gen_conv_SMI)

        sgnn_avg_sim_error, sim_error_list, batch_data = rbgvm.calculate_HSQC_error(config, data_dict_dup, gen_conv_SMI)
        combined_list = [gen_conv_SMI, losses.item(), dot_similarity.item(), tan_sim, sgnn_avg_sim_error]
        if tan_sim is None:
            failed.append([trg_conv_SMI, gen_conv_SMI, combined_list, batch_data])
        else:
            results_dict[trg_conv_SMI] = [combined_list, batch_data]

    return results_dict, failed

def predict_corr_max_performance_metric(trg_tensor, corr_token_prob, max_token_prob, stoi):
    """
    Summarise the teacher-forced token probabilities of the first sample in a batch.

    Parameters:
    trg_tensor (torch.Tensor): target ids indexed as (N, seq_len); only row 0 is used.
    corr_token_prob (torch.Tensor): probability of the target token per step, shape (seq_len,)
        or (seq_len, N). For the 2-D form only column 0 (the first sample) is used.
    max_token_prob (torch.Tensor): probability of the arg-max token per step, same layout.
    stoi (dict): token -> index mapping; must contain "<EOS>".

    Returns:
    tuple: (prop_dict, sample_dict).
        prop_dict has "prob_corr_multi", "prob_corr_avg", "prob_max_multi" and "prob_max_avg".
        These are the product and mean of the per-token probabilities before the first <EOS>.
        An empty sequence gives a product of 1 and a mean of 0.
        sample_dict holds the raw per-token lists "sample_prob_list_corr" and
        "sample_prob_list_max".
    """
    # math.prod replaces functools.reduce, which is not among this cell's visible imports.
    import math

    # A batch of N > 1 yields (seq_len, N) probabilities. Select the first sample's column so it
    # lines up with trg_tensor[0]. Iterating the 2-D tensor directly called .item() on
    # N-element rows and failed.
    if corr_token_prob.dim() > 1:
        corr_token_prob = corr_token_prob[:, 0]
    if max_token_prob.dim() > 1:
        max_token_prob = max_token_prob[:, 0]

    eos_idx = stoi["<EOS>"]
    seq_corr_probs, seq_max_probs = [], []
    for corr_prob, max_prob, token in zip(corr_token_prob, max_token_prob, trg_tensor[0]):
        if token == eos_idx:  # stop at the end of the sequence
            break
        seq_corr_probs.append(corr_prob.item())
        seq_max_probs.append(max_prob.item())

    sample_dict = {
        "sample_prob_list_corr": seq_corr_probs,
        "sample_prob_list_max": seq_max_probs
        }

    prop_dict = {
        "prob_corr_multi": math.prod(seq_corr_probs),
        "prob_corr_avg": statistics.mean(seq_corr_probs) if seq_corr_probs else 0,
        "prob_max_multi": math.prod(seq_max_probs),
        "prob_max_avg": statistics.mean(seq_max_probs) if seq_max_probs else 0
        }

    return prop_dict, sample_dict

## NOTE: predict_corr_max_performance_metric only looks at the first sample of a batch (trg_tensor[0]), so run calculate_corr_max_prob_2 with a batch size of 1.
def calculate_corr_max_prob_2(config, model, stoi, val_dataloader, gen_num):
    """
    Per-batch teacher-forced probability statistics, using ``predict_corr_max_performance_metric``.

    Args:
        config: run configuration forwarded to the model helpers.
        model: the multimodal transformer.
        stoi: token -> index mapping.
        val_dataloader: iterable of batch dicts accepted by ``vgmmt.run_model``.
        gen_num: unused; kept for interface compatibility.

    Returns:
        (prob_dict_results, sample_dict).
        prob_dict_results has the lists "aggregated_corr_prob_multi",
        "aggregated_corr_prob_avg", "aggregated_max_prob_multi" and "aggregated_max_prob_avg",
        one entry per batch.
        sample_dict holds the per-token lists of the LAST batch only, or ``{}`` when the
        dataloader is empty. It was previously unbound in that case.
    """
    aggregated_corr_prob_multi = []
    aggregated_corr_prob_avg = []
    aggregated_max_prob_multi = []
    aggregated_max_prob_avg = []
    sample_dict = {}

    for data_dict in tqdm(val_dataloader):
        memory, src_padding_mask, trg_enc_SMI, fingerprint, src_HSQC, src_COSY = vgmmt.run_model(model, data_dict, config)
        trg_tensor, corr_token_prob, trg_tensor_max, max_token_prob = predict_prop_correct_max_sequence_3(model, stoi, memory, src_padding_mask, trg_enc_SMI, config)

        prop_dict, sample_dict = predict_corr_max_performance_metric(trg_tensor, corr_token_prob, max_token_prob, stoi)
        aggregated_corr_prob_multi.append(prop_dict["prob_corr_multi"])
        aggregated_corr_prob_avg.append(prop_dict["prob_corr_avg"])
        aggregated_max_prob_multi.append(prop_dict["prob_max_multi"])
        aggregated_max_prob_avg.append(prop_dict["prob_max_avg"])

    prob_dict_results = {
        "aggregated_corr_prob_multi": aggregated_corr_prob_multi,
        "aggregated_corr_prob_avg": aggregated_corr_prob_avg,
        "aggregated_max_prob_multi": aggregated_max_prob_multi,
        "aggregated_max_prob_avg": aggregated_max_prob_avg,
    }
    return prob_dict_results, sample_dict


def run_test_performance_CLIP_3(config,
                                model_MMT,
                                val_dataloader,
                                stoi):
    """
    Summarise the teacher-forced probabilities from ``calculate_corr_max_prob_2``.

    Returns:
        dict with three ``[corr, max]`` pairs:
          "statistics_multiplication_avg": means of the per-batch probability products;
          "statistics_multiplication_sum": sums of the per-batch probability products;
          "statistics_avg_avg": means of the per-batch average probabilities.
        If a list is empty (e.g. an empty dataloader), "failed statistics" is printed and its
        mean is reported as NaN. Previously this case fell through to undefined names.
    """
    gen_num = 1  # forwarded to calculate_corr_max_prob_2, which does not use it
    prob_dict_results, _ = calculate_corr_max_prob_2(config, model_MMT, stoi, val_dataloader, gen_num)

    corr_multi = prob_dict_results["aggregated_corr_prob_multi"]
    corr_avg = prob_dict_results["aggregated_corr_prob_avg"]
    max_multi = prob_dict_results["aggregated_max_prob_multi"]
    max_avg = prob_dict_results["aggregated_max_prob_avg"]

    if not (corr_multi and corr_avg and max_multi and max_avg):
        print("failed statistics")

    def _mean_or_nan(values):
        # statistics.mean raises StatisticsError on an empty list.
        return statistics.mean(values) if values else float("nan")

    total_results = {}
    total_results["statistics_multiplication_avg"] = [_mean_or_nan(corr_multi),
                                                      _mean_or_nan(max_multi)]
    total_results["statistics_multiplication_sum"] = [sum(corr_multi),
                                                      sum(max_multi)]
    total_results["statistics_avg_avg"] = [_mean_or_nan(corr_avg),
                                           _mean_or_nan(max_avg)]
    return total_results


#__________________________________________________
### Sample up to config.multinom_runs unique molecules per target and compute their
### Tanimoto similarity to the target
def run_multinomial_sampling(config, model_MMT, val_dataloader, itos, stoi, MW_filter=False):
    """
    For every sample, draw up to ``config.multinom_runs`` unique valid SMILES by multinomial
    sampling and score them by Tanimoto similarity to the target.

    ``config.temperature`` starts at its current value and rises by 0.1 after every sampling
    round. Sampling gives up after 19 rounds. The temperature is restored afterwards, also if
    sampling raises.

    Args:
        config: run configuration (reads ``multinom_runs``; temporarily modifies ``temperature``).
        model_MMT: generator transformer.
        val_dataloader: iterable of batch dicts. Only element 0 of "trg_enc_SMI" is decoded as
            the target, so presumably single-sample batches.
        itos / stoi: token <-> index mappings.
        MW_filter: if truthy, candidates are also filtered with ``filter_for_MW_2``.

    Returns:
        (config, results_dict). results_dict[idx] is a one-element list holding a dict with
        the keys 'gen_conv_SMI_list', 'trg_conv_SMI_list', 'tanimoto_sim', 'tanimoto_mean',
        'tanimoto_std_dev', 'failed' and 'prob_list'.
    """
    n_times = config.multinom_runs
    results_dict = defaultdict(list)
    temperature_orig = config.temperature
    for idx, data_dict in tqdm(enumerate(val_dataloader)):
        if idx % 10 == 0:
            print(idx)

        gen_conv_SMI_list, prob_list = [], []
        data_dict_dup = rbgvm.duplicate_dict(data_dict, 128)
        trg_enc_SMI = data_dict["trg_enc_SMI"][0]
        trg_conv_SMI = hf.tensor_to_smiles(trg_enc_SMI[1:], itos)
        # vgmmt.run_model is unpacked into six values at every other call site in this notebook.
        # The former five-name unpack here did not match that.
        memory, src_padding_mask, trg_enc_SMI, fingerprint, src_HSQC, src_COSY = vgmmt.run_model(model_MMT,
                                                                                                 data_dict_dup,
                                                                                                 config)
        counter = 1
        try:
            while len(gen_conv_SMI_list) < n_times:
                # Give up after 19 sampling rounds.
                if counter % 20 == 0:
                    print(trg_conv_SMI)
                    break
                multinom_tensor, multinom_token_prob = multinomial_sequence_multi_2(model_MMT, memory, src_padding_mask, stoi, config)
                gen_conv_SMI, token_probs = hf.tensor_to_smiles_and_prob_2(multinom_tensor, multinom_token_prob, itos)
                gen_conv_SMI, token_probs = filter_probs_and_valid_smiles_and_canonicolize(gen_conv_SMI, token_probs)
                if MW_filter:
                    gen_conv_SMI, token_probs = filter_for_MW_2(trg_conv_SMI, gen_conv_SMI, token_probs)
                gen_conv_SMI_list.extend(gen_conv_SMI)
                prob_list.extend(token_probs)
                gen_conv_SMI_list, prob_list = deduplicate_smiles(gen_conv_SMI_list, prob_list)
                counter += 1
                # Raise the temperature so later rounds yield more distinct molecules.
                config.temperature = config.temperature + 0.1
        finally:
            # Restore even if sampling raised, so `config` is never left at a ramped temperature.
            config.temperature = temperature_orig

        gen_conv_SMI_list = gen_conv_SMI_list[:n_times]
        prob_list = prob_list[:n_times]
        trg_conv_SMI_list = [trg_conv_SMI for _ in range(len(gen_conv_SMI_list))]

        tanimoto_mean, tanimoto_std_dev, failed, tanimoto_list_all = vgmmt.calculate_tanimoto_similarity(gen_conv_SMI_list, trg_conv_SMI_list)
        results_dict[idx].append({
            'gen_conv_SMI_list': gen_conv_SMI_list,
            'trg_conv_SMI_list': trg_conv_SMI_list,
            'tanimoto_sim': tanimoto_list_all,
            'tanimoto_mean': tanimoto_mean,
            'tanimoto_std_dev': tanimoto_std_dev,
            'failed': failed,
            'prob_list': prob_list,
        })
    return config, results_dict



# Multinomial
def multinomial_sequence_multi_2(model, memory, src_padding_mask, stoi, config):
    """
    Autoregressively sample ``config.max_len`` tokens per sequence with temperature-scaled
    multinomial sampling.

    Sampling does not stop at <EOS>; all ``config.max_len`` steps are always taken.

    Args:
        model: transformer exposing ``embed_trg``, ``pe_trg``, ``dropout2``, ``decoder``,
            ``fc_out`` and ``generate_square_subsequent_mask``. It is switched to eval mode.
        memory: encoder output; ``memory.size(1)`` is used as the batch size N.
        src_padding_mask: passed to the decoder as ``memory_key_padding_mask``.
        stoi: token -> index mapping; must contain "<SOS>".
        config: needs ``device``, ``max_len`` and ``temperature``.

    Returns:
        multinom_tensor: (max_len, N) sampled token ids, with the leading <SOS> row removed.
        multinom_token_prob: (max_len, N) probability of each sampled token, aligned
            row-for-row with ``multinom_tensor``.
    """
    model.eval()
    N = memory.size(1)
    multinom_tensor = torch.full((1, N), stoi["<SOS>"], dtype=torch.long, device=config.device)
    multinom_token_prob = []

    with torch.no_grad():
        for _ in range(config.max_len):
            gen_seq_length = multinom_tensor.shape[0]
            gen_positions = (
                torch.arange(0, gen_seq_length)
                .unsqueeze(1)
                .expand(gen_seq_length, N)
                .to(config.device))

            embedding_gen = model.dropout2(model.embed_trg(multinom_tensor) + model.pe_trg(gen_positions))
            gen_mask = model.generate_square_subsequent_mask(gen_seq_length).to(config.device)
            output = model.decoder(embedding_gen, memory, tgt_mask=gen_mask, memory_key_padding_mask=src_padding_mask)
            probabilities = F.softmax(model.fc_out(output) / config.temperature, dim=2)

            last_step = probabilities[-1, :, :]
            next_word = torch.multinomial(last_step, 1)  # (N, 1)
            # squeeze(1) rather than squeeze(): keeps an (N,) vector even when N == 1.
            multinom_token_prob.append(last_step.gather(1, next_word).squeeze(1))
            multinom_tensor = torch.cat((multinom_tensor, next_word.squeeze(1).unsqueeze(0)), dim=0)

    multinom_token_prob = torch.stack(multinom_token_prob)

    # Only the token tensor carries a leading <SOS> row. The probability list starts at the first
    # sampled token, so slicing it as well would shift probabilities against tokens by one step.
    multinom_tensor = multinom_tensor[1:, :]
    return multinom_tensor, multinom_token_prob


def run_greedy_sampling(config, model_MMT, val_dataloader, itos, stoi):
    """Greedy-decode every validation batch and score it against its targets.

    Returns:
        (config, results_dict): ``config`` is returned unchanged. ``results_dict`` holds
        flat lists of generated and target SMILES, per-molecule token probabilities, the
        HSQC/COSY inputs, and the Tanimoto statistics from
        ``vgmmt.calculate_tanimoto_similarity``.
    """
    generated, targets, probabilities, hsqc_inputs, cosy_inputs = [], [], [], [], []

    for batch in val_dataloader:
        batch_targets = hf.tensor_to_smiles(batch["trg_enc_SMI"].T[1:], itos)
        memory, src_padding_mask, _trg, _fp, src_HSQC, src_COSY = vgmmt.run_model(model_MMT, batch, config)

        greedy_tensor, greedy_token_prob = greedy_sequence_2(model_MMT, stoi, itos, memory, src_padding_mask, config)
        batch_generated, batch_probs = hf.tensor_to_smiles_and_prob(greedy_tensor, greedy_token_prob.T, itos)

        generated.extend(batch_generated)
        targets.extend(batch_targets)
        probabilities.extend(batch_probs)
        hsqc_inputs.extend(src_HSQC)
        cosy_inputs.extend(src_COSY)

    tanimoto_mean, tanimoto_std_dev, failed, tanimoto_list_all = vgmmt.calculate_tanimoto_similarity(generated, targets)
    results_dict = {
        'gen_conv_SMI_list': generated,
        'trg_conv_SMI_list': targets,
        'tanimoto_sim': tanimoto_list_all,
        'tanimoto_mean': tanimoto_mean,
        'tanimoto_std_dev': tanimoto_std_dev,
        'failed': failed,
        'prob_list': probabilities,
        "src_HSQC_list": hsqc_inputs,
        "src_COSY_list": cosy_inputs,
    }
    return config, results_dict

    

def greedy_sequence_2(model, stoi, itos, memory, src_padding_mask, config):
    """
    Greedily decode one sequence per batch column.

    At each step the most probable token under ``softmax(logits / config.temperature)``
    is appended. Decoding stops after ``config.max_len`` steps, or earlier once every
    column produced token index 0.

    Returns:
        (tokens, token_probs): both of shape ``(steps, N)``. ``tokens`` excludes the
        leading ``<SOS>`` row and ``token_probs`` holds the probability of each chosen
        token. ``itos`` is accepted for interface compatibility and is not used.
    """
    model.eval()
    batch_size = memory.size(1)
    decoded = torch.full((1, batch_size), stoi["<SOS>"], dtype=torch.long, device=config.device)
    chosen_probs = []

    with torch.no_grad():
        for _ in range(config.max_len):
            length = decoded.size(0)
            positions = torch.arange(length, device=config.device).unsqueeze(1).expand(length, batch_size)
            embedded = model.embed_trg(decoded) + model.pe_trg(positions)
            if model.training:
                embedded = model.dropout2(embedded)
            mask = model.generate_square_subsequent_mask(length).to(config.device)

            hidden = model.decoder(embedded, memory, tgt_mask=mask, memory_key_padding_mask=src_padding_mask)
            step_dist = F.softmax(model.fc_out(hidden) / config.temperature, dim=2)[-1]  # (N, vocab)

            best_token = torch.argmax(step_dist, dim=1)  # (N,)
            # squeeze(1) keeps shape (N,) for every N, so the stack below is always (steps, N).
            chosen_probs.append(step_dist.gather(1, best_token.unsqueeze(-1)).squeeze(1))
            decoded = torch.cat((decoded, best_token.unsqueeze(0)), dim=0)

            if (best_token == 0).all():
                break

    return decoded[1:], torch.stack(chosen_probs)


def deduplicate_smiles(smiles_list, prob_list):
    """Drop repeated SMILES, keeping the first occurrence and its probability.

    Inputs are paired with ``zip``, so the shorter list bounds the result. First-seen
    order is preserved.

    Returns:
        (deduped_smiles, deduped_probs) as two lists.
    """
    first_seen = {}
    for smiles, prob in zip(smiles_list, prob_list):
        first_seen.setdefault(smiles, prob)
    return list(first_seen), list(first_seen.values())

# Function to filter valid SMILES
def filter_probs_and_valid_smiles_and_canonicolize(smiles_list, token_probs, canonical=True, isomericSmiles=False):
    """Keep only SMILES that RDKit can parse, re-serialised via ``Chem.MolToSmiles``.

    ``canonical`` and ``isomericSmiles`` are forwarded to ``Chem.MolToSmiles``. The
    probability entry of every rejected SMILES is dropped too, so both returned lists
    stay aligned.
    """
    kept_smiles, kept_probs = [], []
    for candidate, prob in zip(smiles_list, token_probs):
        mol = Chem.MolFromSmiles(candidate)
        if mol is None:
            continue
        kept_smiles.append(Chem.MolToSmiles(mol, canonical=canonical, doRandom=False, isomericSmiles=isomericSmiles))
        kept_probs.append(prob)
    return kept_smiles, kept_probs


# Function to calculate the rounded molecular weight
def calc_rounded_mw(smi):
    """Return the molecular weight of ``smi`` rounded to the nearest integer.

    Returns ``None`` when the SMILES cannot be parsed or the weight cannot be computed.
    """
    # Local import: this helper must not depend on a module-level `Descriptors` import.
    # If that name were missing, the old bare `except` turned the NameError into `None`
    # for every molecule and silently disabled MW filtering.
    from rdkit.Chem import Descriptors

    try:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            return None
        return round(Descriptors.MolWt(mol))
    except Exception:  # e.g. non-string input; unlike bare `except`, lets Ctrl-C through
        return None



def filter_for_MW(trg_conv_SMI, gen_conv_SMI):
    """Return the generated SMILES whose rounded molecular weight equals the target's.

    A candidate whose weight cannot be computed never matches. If the target's weight
    cannot be computed, nothing matches; this stops ``None == None`` from letting
    unparsable candidates through.
    """
    trg_mw = calc_rounded_mw(trg_conv_SMI)
    if trg_mw is None:
        return []
    return [smi for smi in gen_conv_SMI if calc_rounded_mw(smi) == trg_mw]


def filter_for_MW_2(trg_conv_SMI, gen_conv_SMI, prob_list):
    """Keep generated SMILES (and their aligned probabilities) matching the target's rounded MW.

    Each candidate's weight is computed once, not once per output list. If the target's
    weight cannot be computed, nothing matches, so unparsable candidates can never pass
    via ``None == None``.

    Returns:
        (filtered_gen_smis, filtered_prob_list)
    """
    trg_mw = calc_rounded_mw(trg_conv_SMI)
    if trg_mw is None:
        return [], []
    kept = [(smi, prob) for smi, prob in zip(gen_conv_SMI, prob_list) if calc_rounded_mw(smi) == trg_mw]
    filtered_gen_smis = [smi for smi, _ in kept]
    filtered_prob_list = [prob for _, prob in kept]
    return filtered_gen_smis, filtered_prob_list

    
def filter_dict_results_for_MW(results_dict):
    """Re-score every multinomial result after keeping only candidates with the target's MW.

    For each result entry, the target is ``result["trg_conv_SMI_list"][0]``. Generated
    SMILES whose rounded MW differs from it are dropped before
    ``vgmmt.calculate_tanimoto_similarity`` is called.

    Returns:
        (tanimoto_values, gen_trg_smi_lists): one Tanimoto list and one
        ``[filtered_gen_smis, repeated_target]`` pair per result entry, in iteration order.
    """
    tanimoto_values = []
    gen_trg_smi_lists = []
    for results in results_dict.values():
        for result in results:
            target = result["trg_conv_SMI_list"][0]
            target_mw = calc_rounded_mw(target)
            same_mw = [smi for smi in result["gen_conv_SMI_list"] if calc_rounded_mw(smi) == target_mw]
            repeated_target = [target] * len(same_mw)
            gen_trg_smi_lists.append([same_mw, repeated_target])
            _, _, _, tanimoto_list_all = vgmmt.calculate_tanimoto_similarity(same_mw, repeated_target)
            tanimoto_values.append(tanimoto_list_all)
    return tanimoto_values, gen_trg_smi_lists


def plot_hist_MN_sampling(config, results_dict):
    """Overlay one Tanimoto-similarity histogram per key of ``results_dict``.

    Each key's histogram pools ``result['tanimoto_sim']`` over all of its result
    entries. The bin count grows with the key (``100 + key * 3``), so keys are expected
    to be integers such as dataloader indices. ``config`` is unused.
    """
    num_idx = len(results_dict)

    plt.figure(figsize=(10, 6))
    for idx, results in results_dict.items():
        tanimoto_values = [val for result in results for val in result['tanimoto_sim']]
        plt.hist(tanimoto_values, bins=100+idx*3, alpha=0.7, label=f'Idx {idx}, Nr: {len(tanimoto_values)}')
    plt.xlabel('Tanimoto Similarity')
    plt.ylabel('Frequency')
    # Count the keys directly: the loop variable used previously is unbound for an
    # empty dict and miscounts when keys are not exactly 0..n-1.
    plt.title(f'Histogram of Tanimoto Similarity for {num_idx} Molecule generations')
    if num_idx:
        # A legend with no labelled artists only produces a warning.
        plt.legend()

    plt.tight_layout()
    plt.show()

    
def plot_hist_MN_sampling_all(config, results_dict):
    """Plot a single histogram of every Tanimoto similarity in ``results_dict``.

    Values from all result entries of all keys are pooled. ``config`` is unused.
    """
    num_idx = len(results_dict)
    tanimoto_values = [
        val
        for results in results_dict.values()
        for result in results
        for val in result['tanimoto_sim']
    ]
    # Position of the last entry. The bin rule and label previously used a name `idx`
    # that this function never defines, so they raised NameError or silently read a
    # global left over from another notebook cell.
    last_idx = max(num_idx - 1, 0)

    plt.figure(figsize=(10, 6))
    plt.hist(tanimoto_values, bins=100 + last_idx * 3, alpha=0.7, label=f'Idx {last_idx}, Nr: {len(tanimoto_values)}')
    plt.xlabel('Tanimoto Similarity')
    plt.ylabel('Frequency')
    plt.title(f'Histogram of Tanimoto Similarity for {num_idx} Molecule generations')
    plt.legend()

    plt.tight_layout()
    plt.show()
    
def plot_hist_MN_sampling_filtered(config, results_dict):
    """Plot Tanimoto histograms restricted to candidates sharing the target's rounded MW.

    Filtering and re-scoring are done by ``filter_dict_results_for_MW``. One 20-bin
    histogram is drawn per result entry. ``config`` is unused.

    Returns:
        The ``(tanimoto_values, gen_trg_smi_lists)`` pair from ``filter_dict_results_for_MW``.
    """
    tanimoto_values, gen_trg_smi_lists = filter_dict_results_for_MW(results_dict)

    plt.figure(figsize=(10, 6))
    for number, similarities in enumerate(tanimoto_values, start=1):
        plt.hist(similarities, bins=20, alpha=0.5, label=f'List {number}, Nr: {len(similarities)}')

    plt.xlabel('Tanimoto Similarity')
    plt.ylabel('Frequency')
    plt.title('Histogram of Tanimoto Similarity with same MW')
    plt.legend()
    plt.tight_layout()
    plt.show()

    return tanimoto_values, gen_trg_smi_lists


def calc_percentage_top_x_correct(results_dict, top_x):
    """Fraction of result entries whose first ``top_x`` candidates contain an exact match.

    An exact match is a Tanimoto similarity equal to 1. The return value is in
    ``[0, 1]``. Returns ``0.0`` when there are no result entries at all, instead of
    raising UnboundLocalError or ZeroDivisionError.
    """
    hits = 0
    total = 0
    for results in results_dict.values():
        for result in results:
            total += 1
            if 1 in result["tanimoto_sim"][:top_x]:
                hits += 1
    if total == 0:
        return 0.0
    return hits / total
    
def calc_percentage_top_x_correct_greedy(results_dict):
    """Fraction of greedy generations that exactly match their target (Tanimoto == 1).

    Returns ``0.0`` when ``results_dict["tanimoto_sim"]`` is empty, instead of raising
    ZeroDivisionError.
    """
    similarities = list(results_dict["tanimoto_sim"])
    if not similarities:
        return 0.0
    hits = sum(1 for sim in similarities if sim == 1)
    return hits / len(similarities)
    

def run_precentage_calculation(config, itos, stoi, stoi_MF, MW_filter):
    """Compute greedy and multinomial top-k exact-match rates on the validation set.

    Loads the model and the validation dataloaders, sets ``config.temperature = 1``,
    runs multinomial sampling (optionally MW-filtered), then greedy decoding.

    Returns:
        (percentage_collection, results_dict_mns_10, results_dict_greedy), where
        ``percentage_collection`` is ``[greedy, top1, top3, top5, top10]``.
    """
    model_MMT = load_MMT_model(config)
    val_dataloader = load_data(config, stoi, stoi_MF, single=True, mode="val")
    config.temperature = 1
    config, results_dict_mns_10 = run_multinomial_sampling(config, model_MMT, val_dataloader, itos, stoi, MW_filter=MW_filter)
    top_k_rates = {k: calc_percentage_top_x_correct(results_dict_mns_10, k) for k in (10, 5, 3, 1)}

    val_dataloader_multi = load_data(config, stoi, stoi_MF, single=False, mode="val")
    config, results_dict_greedy = run_greedy_sampling(config, model_MMT, val_dataloader_multi, itos, stoi)
    greedy_rate = calc_percentage_top_x_correct_greedy(results_dict_greedy)

    percentage_collection = [greedy_rate] + [top_k_rates[k] for k in (1, 3, 5, 10)]
    return percentage_collection, results_dict_mns_10, results_dict_greedy

    
def add_tanimoto_similarity(trg_conv_SMI, combined_list):
    """Append each candidate's Tanimoto similarity to the target, mutating items in place.

    Fingerprints are Morgan bit vectors (radius 2, 512 bits). ``item[0]`` of every entry
    must be a SMILES string. The entry list itself is extended with the similarity, or
    with ``None`` when the SMILES cannot be parsed.

    Returns:
        (scored_items, failed_items): the same item objects, split by parse success.
    """
    reference_fp = AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(trg_conv_SMI), 2, nBits=512)
    scored, unparsable = [], []
    for entry in combined_list:
        candidate = Chem.MolFromSmiles(entry[0])
        if not candidate:
            entry.append(None)
            unparsable.append(entry)
            continue
        candidate_fp = AllChem.GetMorganFingerprintAsBitVect(candidate, 2, nBits=512)
        entry.append(DataStructs.TanimotoSimilarity(reference_fp, candidate_fp))
        scored.append(entry)
    return scored, unparsable
#print(combined_list)


def try_calculate_tanimoto_from_two_smiles(smi1, smi2, nbits, extra_info = False):
    """Return the Tanimoto similarity of two SMILES strings, rounded to 4 decimals.

    Both molecules are encoded as radius-2 Morgan bit vectors of length ``nbits``.
    When ``extra_info`` is true, the two SMILES and the score are printed.
    Returns ``None`` if either SMILES cannot be parsed or fingerprinting fails.
    """
    try:
        mol1 = Chem.MolFromSmiles(smi1)
        mol2 = Chem.MolFromSmiles(smi2)
        if mol1 is None or mol2 is None:
            return None
        fp1 = AllChem.GetMorganFingerprintAsBitVect(mol1, 2, nBits=nbits)
        fp2 = AllChem.GetMorganFingerprintAsBitVect(mol2, 2, nBits=nbits)
        tan_sim = round(DataStructs.TanimotoSimilarity(fp1, fp2), 4)
    # `Exception` rather than a bare `except`, so KeyboardInterrupt can still stop long loops.
    except Exception:
        return None

    if extra_info:
        print(f"Smiles 1: {smi1} \n Target Smiles: {smi2} \nTanimoto score:{tan_sim}")
    return tan_sim

from collections import defaultdict

def calculate_tanimoto_of_all_compared_to_trg(trg_conv_SMI, gen_conv_SMI_list):
    """Tanimoto similarity of each parsable SMILES to the target (Morgan r=2, 512 bits).

    Unparsable SMILES are skipped. The result can therefore be shorter than
    ``gen_conv_SMI_list``, and is only index-aligned with it when every input parses.
    """
    target_fp = AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(trg_conv_SMI), 2, nBits=512)
    molecules = (Chem.MolFromSmiles(smi) for smi in gen_conv_SMI_list)
    return [
        DataStructs.TanimotoSimilarity(target_fp, AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=512))
        for mol in molecules
        if mol
    ]


def run_mns_hit_counter_experiment(config, model_MMT, val_dataloader, itos, stoi, MW_filter, max_runs):
    """For each validation molecule, sample repeatedly until the target is rediscovered.

    A molecule's rounds continue until one of these holds:
    - an exact match (Tanimoto 1.0) appears;
    - ``config.multinom_runs`` unique valid SMILES have been collected;
    - ``max_runs`` rounds have been spent.

    ``config.temperature`` rises by 0.1 per round to diversify samples. It is restored
    per molecule, even if sampling raises.

    Returns:
        one_finder_list: per molecule, the 1-based position of the first exact match in
            the deduplicated candidate list, or -5 if none was found.
        molecule_tani_comparison_lists: per molecule, ``[gen_smiles, target_smiles, tanimotos]``.
    """
    molecule_tani_comparison_lists = []
    one_finder_list = []
    n_times = config.multinom_runs

    for data_dict in val_dataloader:
        gen_conv_SMI_list, prob_list = [], []
        # Reset per molecule: if the sampling loop never runs, the previous molecule's
        # similarities must not be reported for this one.
        tani_list = []

        trg_enc_SMI = data_dict["trg_enc_SMI"][0]
        trg_conv_SMI = hf.tensor_to_smiles(trg_enc_SMI[1:], itos)

        # Replicate the input 128 times so each round samples 128 candidates in parallel.
        data_dict_dup = rbgvm.duplicate_dict(data_dict, 128)
        # Only memory and the padding mask are needed. Star-unpacking accepts both the
        # 5-value and the 6-value (with src_COSY) run_model returns seen in this notebook.
        memory, src_padding_mask, *_ = vgmmt.run_model(model_MMT, data_dict_dup, config)

        temp_orig = config.temperature
        try:
            counter = 0
            while len(gen_conv_SMI_list) < n_times and counter < max_runs:
                if counter % 10 == 0:
                    print(counter, len(gen_conv_SMI_list), config.temperature)

                multinom_tensor, multinom_token_prob = multinomial_sequence_multi_2(model_MMT, memory, src_padding_mask, stoi, config)
                gen_conv_SMI, token_probs = hf.tensor_to_smiles_and_prob_2(multinom_tensor, multinom_token_prob, itos)
                gen_conv_SMI, token_probs = filter_probs_and_valid_smiles_and_canonicolize(gen_conv_SMI, token_probs)
                if MW_filter:
                    gen_conv_SMI, token_probs = filter_for_MW_2(trg_conv_SMI, gen_conv_SMI, token_probs)
                gen_conv_SMI_list.extend(gen_conv_SMI)
                prob_list.extend(token_probs)
                gen_conv_SMI_list, prob_list = deduplicate_smiles(gen_conv_SMI_list, prob_list)

                tani_list = calculate_tanimoto_of_all_compared_to_trg(trg_conv_SMI, gen_conv_SMI_list)
                if 1 in tani_list:
                    break
                counter += 1
                # Sample hotter next round to obtain more diverse molecules.
                config.temperature = config.temperature + 0.1
        finally:
            config.temperature = temp_orig

        # +1 turns the 0-based index into a rank; -5 marks "target never found".
        one_finder_list.append(tani_list.index(1.0) + 1 if 1.0 in tani_list else -5)
        trg_SMI_list = [trg_conv_SMI] * len(gen_conv_SMI_list)
        molecule_tani_comparison_lists.append([gen_conv_SMI_list, trg_SMI_list, tani_list])
    return one_finder_list, molecule_tani_comparison_lists

def sel_data_slice_and_save_as_csv(config):
    """Save the first ``config.data_size`` rows of ``config.csv_path_val`` to a new CSV.

    The new file is written next to the source as ``<stem>_sel_<data_size><ext>``. Its
    path is stored in ``config.csv_SMI_targets``.

    Returns:
        The same ``config`` object, updated in place.
    """
    file_path = config.csv_path_val
    df_selected = pd.read_csv(file_path).head(config.data_size)

    # Only split off the final extension. str.replace('.csv', ...) would also rewrite
    # directory names containing ".csv". For a path without ".csv" it would return the
    # path unchanged and overwrite the source file.
    root, ext = os.path.splitext(file_path)
    new_file_path = f"{root}_sel_{config.data_size}{ext}"
    df_selected.to_csv(new_file_path, index=False)
    config.csv_SMI_targets = new_file_path
    return config



def filter_smiles(df, smi_list):
    """
    Return the SMILES from ``smi_list`` that do not occur in ``df['SMILES']``.

    Parameters:
    df (pandas.DataFrame): Frame whose 'SMILES' column lists the strings to exclude.
    smi_list (list): Candidate SMILES strings; order and duplicates are preserved.

    Returns:
    list: The candidates absent from ``df``.
    """
    excluded = set(df['SMILES'])
    return list(filter(lambda smiles: smiles not in excluded, smi_list))


import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import Draw

import matplotlib.pyplot as plt

def analyze_and_plot(results_dict, mode):
    """Histogram the rank of each target's first candidate under a chosen score.

    Only entries whose first item has Tanimoto 1 (``item[3] == 1``) are ranked. The
    rank is the 1-based position of that item's score among all items of the same
    entry. ``mode`` selects the score:

    * ``"HSQC_sim"``: ``item[4]``, sorted ascending (lowest value ranks first);
    * ``"dot_sim"``: ``item[2]``, sorted descending (highest value ranks first).

    Histogram bins span 1 up to the largest number of items in any entry.

    Returns:
        List of ``[key, value]`` for entries that are empty or whose first item is not an
        exact match.

    Raises:
        ValueError: if ``mode`` is not one of the two values above.
    """
    if mode == "HSQC_sim":
        score_pos, descending = 4, False
    elif mode == "dot_sim":
        score_pos, descending = 2, True
    else:
        # Previously an unknown mode left `similaritys` unbound (or stale from a prior key).
        raise ValueError(f"mode must be 'HSQC_sim' or 'dot_sim', got {mode!r}")

    ranks = []
    failed = []
    max_candidates = 0
    for key, value in results_dict.items():
        max_candidates = max(max_candidates, len(value))
        if not value or value[0][3] != 1:
            failed.append([key, value])
            continue
        scores = [item[score_pos] for item in value]
        sorted_scores = sorted(scores, reverse=descending)
        ranks.append(sorted_scores.index(scores[0]) + 1)  # +1: ranks are 1-based

    # Size the bins from the largest entry rather than whichever entry was iterated last.
    plt.hist(ranks, bins=range(1, max_candidates + 2), align='left')
    plt.xlabel('Rank')
    plt.ylabel('Frequency')
    plt.title(f'Rank Histogram of {mode} corresponding correct Molecule')
    plt.xticks(range(1, max_candidates + 1))
    plt.show()
    return failed
    
    
# Function to plot a molecule with additional data
def plot_molecule_with_data(smiles, cosine_sim, tanimoto, HSQC_error):
    """Draw ``smiles`` with its CLIP score, Tanimoto similarity and HSQC error in the title.

    Pass ``''`` (or ``None``) for a score to leave it blank, e.g. for the target molecule.
    Numeric zeros are displayed (as ``0.0``) rather than blanked.
    """
    def _fmt(value, ndigits):
        # Only '' / None mean "no value". A plain truthiness test would also hide
        # legitimate zeros, such as an HSQC error of exactly 0.
        if value is None or (isinstance(value, str) and not value):
            return ''
        return round(float(value), ndigits)

    mol = Chem.MolFromSmiles(smiles)
    fig, ax = plt.subplots()
    img = Draw.MolToImage(mol, size=(300, 300))
    ax.imshow(img)
    ax.axis('off')
    ax.set_title(f"CLIP: {_fmt(cosine_sim, 3)}, Tanimoto: {_fmt(tanimoto, 3)},  HSQC_err: {_fmt(HSQC_error, 4)}")
    plt.show()

    ### PLOT Molecules

def plot_CLIP_molecules(results_dict, investigate_number, stop_nr):
    """Plot the target at position ``investigate_number`` followed by its candidates.

    The key (target SMILES) is drawn first, without scores. Candidates from
    ``data_val[0]`` follow; each is expected to unpack as
    ``[smiles, _, cosine_sim, tanimoto, HSQC_error]``. Drawing stops after the
    candidate at index ``stop_nr``. Nothing is drawn if no entry sits at
    ``investigate_number``.
    """
    for position, (key_smiles, data_val) in enumerate(results_dict.items()):
        if position != investigate_number:
            continue
        plot_molecule_with_data(key_smiles, '', '', '')
        for rank, candidate in enumerate(data_val[0]):
            candidate_smiles, _, cosine_sim, tanimoto, HSQC_error = candidate
            plot_molecule_with_data(candidate_smiles, cosine_sim, tanimoto, HSQC_error)
            if rank == stop_nr:
                break
        break
                    
def generate_mns_list(config,
                      model_MMT,
                      val_dataloader,
                      stoi,
                      itos,
                      MW_filter):
    """Collect up to ``config.multinom_runs`` unique valid SMILES per validation molecule.

    Same procedure as ``run_multinomial_sampling``. Each round:
    - samples 128 candidates in parallel;
    - keeps RDKit-parsable SMILES, optionally only those matching the target's rounded MW;
    - deduplicates;
    - raises ``config.temperature`` by 0.1.

    A molecule is abandoned after 29 rounds, and its target is printed. The original
    temperature is restored per molecule, even if sampling raises.

    Returns:
        dict mapping target SMILES -> ``[gen_smiles, repeated_target, data_dict]``.
    """
    n_times = config.multinom_runs
    gen_dict = {}
    temperature_orig = config.temperature
    for idx, data_dict in enumerate(val_dataloader):
        if idx % 10 == 0:
            print(idx)
        gen_conv_SMI_list, prob_list = [], []
        data_dict_dup = rbgvm.duplicate_dict(data_dict, 128)
        trg_enc_SMI = data_dict["trg_enc_SMI"][0]
        # Decode with `hf`, the helper module this function already uses for sampled
        # sequences and that every other function here uses for targets (was `ttf`).
        trg_conv_SMI = hf.tensor_to_smiles(trg_enc_SMI[1:], itos)
        # Only memory and the padding mask are needed; accepts 5- or 6-value returns.
        memory, src_padding_mask, *_ = vgmmt.run_model(model_MMT, data_dict_dup, config)

        try:
            counter = 1
            while len(gen_conv_SMI_list) < n_times:
                print(counter, len(gen_conv_SMI_list), config.temperature)
                if counter % 30 == 0:
                    print(trg_conv_SMI)
                    break

                multinom_tensor, multinom_token_prob = multinomial_sequence_multi_2(model_MMT, memory, src_padding_mask, stoi, config)
                gen_conv_SMI, token_probs = hf.tensor_to_smiles_and_prob_2(multinom_tensor, multinom_token_prob, itos)
                gen_conv_SMI, token_probs = filter_probs_and_valid_smiles_and_canonicolize(gen_conv_SMI, token_probs)
                if MW_filter:
                    gen_conv_SMI, token_probs = filter_for_MW_2(trg_conv_SMI, gen_conv_SMI, token_probs)
                gen_conv_SMI_list.extend(gen_conv_SMI)
                prob_list.extend(token_probs)
                gen_conv_SMI_list, prob_list = deduplicate_smiles(gen_conv_SMI_list, prob_list)
                counter += 1
                # Sample hotter next round if not enough distinct molecules were found.
                config.temperature = config.temperature + 0.1
        finally:
            config.temperature = temperature_orig

        gen_conv_SMI_list = gen_conv_SMI_list[:n_times]
        trg_conv_SMI_list = [trg_conv_SMI] * len(gen_conv_SMI_list)
        gen_dict[trg_conv_SMI] = [gen_conv_SMI_list, trg_conv_SMI_list, data_dict]
    return gen_dict

def run_clip_similarity_check(config, model_CLIP, gen_dict):
    """Score every generated candidate with the CLIP model and attach Tanimoto similarities.

    For each target in ``gen_dict``, the stored ``data_dict`` is replicated once per
    candidate. It is then passed to ``model_CLIP.inference`` together with the
    candidate SMILES.

    Returns:
        dict mapping target SMILES -> ``[items]``, each item being
        ``[smiles, loss, dot_similarity, tanimoto]``. Unparsable candidates are dropped.
        ``config`` is unused.
    """
    results_dict = {}
    for trg_conv_SMI, (gen_conv_SMI_list, _, data_dict) in gen_dict.items():
        data_dict_dup = rbgvm.duplicate_dict(data_dict, len(gen_conv_SMI_list))
        _, losses, _, _, dot_similarity = model_CLIP.inference(data_dict_dup, gen_conv_SMI_list)

        scored = []
        for smiles, loss, dot_sim in zip(gen_conv_SMI_list, losses, dot_similarity):
            scored.append([smiles, loss.item(), dot_sim.item()])

        valid_items, _ = add_tanimoto_similarity(trg_conv_SMI, scored)
        results_dict[trg_conv_SMI] = [valid_items]
    return results_dict

def run_HSQC_similarity_check(config, gen_dict, results_dict):
    """Compute the HSQC error of every generated candidate via ``rbgvm.calculate_HSQC_error``.

    Returns:
        dict mapping target SMILES -> ``[gen_smiles, sim_error_list]``.
        ``results_dict`` is accepted but unused.
    """
    results_dict_HSQC = {}
    for trg_conv_SMI, (gen_conv_SMI_list, _, data_dict) in gen_dict.items():
        data_dict_dup = rbgvm.duplicate_dict(data_dict, len(gen_conv_SMI_list))
        # NOTE(review): the prepared tensor is never used. The call is kept only in case
        # prepare_HSQC_data_from_src has side effects; confirm and consider removing.
        rbgvm.prepare_HSQC_data_from_src([data_dict_dup["src_HSQC"]])

        _, sim_error_list, _ = rbgvm.calculate_HSQC_error(config, data_dict_dup, gen_conv_SMI_list)
        results_dict_HSQC[trg_conv_SMI] = [gen_conv_SMI_list, sim_error_list]
    return results_dict_HSQC

def combine_CLIP_HSQC_data(results_dict_CLIP, results_dict_HSQC):
    """Append each candidate's HSQC error to its CLIP item and sort the items by Tanimoto.

    Items in ``results_dict_CLIP[trg][0]`` are mutated in place. Each gets the HSQC error
    at the same index of ``results_dict_HSQC[trg][1]``, or ``-9`` when that value is
    missing.

    Returns:
        dict mapping target SMILES -> items sorted by Tanimoto (``item[3]``) descending;
        each item is ``[smiles, loss, dot_sim, tanimoto, hsqc_error]``.
    """
    final_results = {}
    for trg, data in results_dict_CLIP.items():
        items = data[0]
        for idx, item in enumerate(items):
            try:
                hsqc_error = results_dict_HSQC[trg][1][idx]
            # Only lookup failures fall back to the sentinel; a bare `except` would also
            # hide unrelated bugs and KeyboardInterrupt.
            except (KeyError, IndexError, TypeError):
                hsqc_error = -9  # sentinel: no HSQC error available for this candidate
            item.append(hsqc_error)
        final_results[trg] = sorted(items, key=lambda x: -x[3])
    return final_results


def filter_invalid_inputs(results_dict):
    """Drop entries whose first element is ``None`` (failed runs of run_test_mns_performance_CLIP_3).

    Returns:
        (filtered_dict, counter): the kept entries in their original order, and the
        number of entries removed.
    """
    filtered_dict = {key: value for key, value in results_dict.items() if value[0] is not None}
    return filtered_dict, len(results_dict) - len(filtered_dict)


In [None]:
from rdkit.Chem.rdMolDescriptors import CalcMolFormula

def calc_molecular_formula(smi):
    """Return the molecular formula of ``smi`` via ``CalcMolFormula``, or ``None`` if it cannot be parsed."""
    try:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            return None
        return CalcMolFormula(mol)
    except Exception:  # e.g. non-string input; unlike bare `except`, lets Ctrl-C through
        return None

def filter_for_MF_2(trg_conv_SMI, gen_conv_SMI, prob_list):
    """Keep generated SMILES (and their aligned probabilities) with the target's molecular formula.

    Each candidate's formula is computed once, not once per output list. If the target's
    formula cannot be computed, nothing matches, so unparsable candidates can never pass
    via ``None == None``.

    Returns:
        (filtered_gen_smis, filtered_prob_list)
    """
    trg_mf = calc_molecular_formula(trg_conv_SMI)
    if trg_mf is None:
        return [], []
    kept = [(smi, prob) for smi, prob in zip(gen_conv_SMI, prob_list) if calc_molecular_formula(smi) == trg_mf]
    filtered_gen_smis = [smi for smi, _ in kept]
    filtered_prob_list = [prob for _, prob in kept]
    return filtered_gen_smis, filtered_prob_list


def run_multinomial_sampling_v2(config, model_MMT, val_dataloader, itos, stoi, MW_filter=False, MF_filter=False):
    """Multinomial sampling with optional molecular-weight / molecular-formula filtering.

    For each validation molecule, rounds of 16 parallel samples are drawn until
    ``config.multinom_runs`` unique valid SMILES are collected. A molecule is abandoned
    after 79 rounds, and its target is printed. Each round:
    - keeps RDKit-parsable SMILES;
    - optionally keeps only those matching the target's rounded MW (``MW_filter``)
      and/or molecular formula (``MF_filter``);
    - raises ``config.temperature`` by 0.1.

    The temperature is reset before the next molecule, even if sampling raises.

    Returns:
        (config, results_dict): ``results_dict[idx]`` holds one dict per molecule with the
        generated/target SMILES, token probabilities and Tanimoto statistics.
    """
    n_times = config.multinom_runs
    results_dict = defaultdict(list)
    temperature_orig = config.temperature

    for idx, data_dict in tqdm(enumerate(val_dataloader)):
        if idx % 10 == 0:
            print(idx)

        gen_conv_SMI_list, prob_list = [], []
        data_dict_dup = rbgvm.duplicate_dict(data_dict, 16)

        trg_enc_SMI = data_dict["trg_enc_SMI"][0]
        trg_conv_SMI = hf.tensor_to_smiles(trg_enc_SMI[1:], itos)
        memory, src_padding_mask, *_ = vgmmt.run_model(model_MMT, data_dict_dup, config)

        counter = 1
        removed = 0  # candidates rejected by the MW/MF filters (diagnostic only)
        try:
            while len(gen_conv_SMI_list) < n_times:
                if counter % 80 == 0:
                    print(trg_conv_SMI)
                    break
                multinom_tensor, multinom_token_prob = multinomial_sequence_multi_2(model_MMT, memory, src_padding_mask, stoi, config)
                gen_conv_SMI, token_probs = hf.tensor_to_smiles_and_prob_2(multinom_tensor, multinom_token_prob, itos)

                # Filter valid SMILES and canonicalize
                gen_conv_SMI, token_probs = filter_probs_and_valid_smiles_and_canonicolize(gen_conv_SMI, token_probs)

                # Measure each filter against its own input. Comparing the MF result with
                # the pre-MW count would count MW rejections twice.
                if MW_filter:
                    before = len(gen_conv_SMI)
                    gen_conv_SMI, token_probs = filter_for_MW_2(trg_conv_SMI, gen_conv_SMI, token_probs)
                    removed += before - len(gen_conv_SMI)
                if MF_filter:
                    before = len(gen_conv_SMI)
                    gen_conv_SMI, token_probs = filter_for_MF_2(trg_conv_SMI, gen_conv_SMI, token_probs)
                    removed += before - len(gen_conv_SMI)

                gen_conv_SMI_list.extend(gen_conv_SMI)
                prob_list.extend(token_probs)
                gen_conv_SMI_list, prob_list = deduplicate_smiles(gen_conv_SMI_list, prob_list)
                counter += 1
                # Sample hotter next round if not enough distinct molecules were found.
                config.temperature = config.temperature + 0.1
        finally:
            config.temperature = temperature_orig

        gen_conv_SMI_list = gen_conv_SMI_list[:n_times]
        prob_list = prob_list[:n_times]
        trg_conv_SMI_list = [trg_conv_SMI] * len(gen_conv_SMI_list)

        tanimoto_mean, tanimoto_std_dev, failed, tanimoto_list_all = vgmmt.calculate_tanimoto_similarity(gen_conv_SMI_list, trg_conv_SMI_list)
        results_dict[idx].append({
            'gen_conv_SMI_list': gen_conv_SMI_list,
            'trg_conv_SMI_list': trg_conv_SMI_list,
            'tanimoto_sim': tanimoto_list_all,
            'tanimoto_mean': tanimoto_mean,
            'tanimoto_std_dev': tanimoto_std_dev,
            'failed': failed,
            'prob_list': prob_list,
        })
    return config, results_dict

def run_precentage_calculation_v2(config, itos, stoi, stoi_MF, MW_filter=False, MF_filter=False):
    """Greedy and multinomial top-k exact-match rates with optional MW/MF filtering.

    ``MW_filter``/``MF_filter`` apply only to multinomial sampling. Greedy decoding is
    always run unfiltered. ``config.temperature`` is set to 1 before sampling.

    Returns:
        (percentage_collection, results_dict_mns_10, results_dict_greedy), where
        ``percentage_collection`` is ``[greedy, top1, top3, top5, top10]``.
    """
    model_MMT = load_MMT_model(config)
    val_dataloader = load_data(config, stoi, stoi_MF, single=True, mode="val")
    config.temperature = 1
    config, results_dict_mns_10 = run_multinomial_sampling_v2(config, model_MMT, val_dataloader, itos, stoi, MW_filter=MW_filter, MF_filter=MF_filter)
    top_k_rates = {k: calc_percentage_top_x_correct(results_dict_mns_10, k) for k in (10, 5, 3, 1)}

    val_dataloader_multi = load_data(config, stoi, stoi_MF, single=False, mode="val")
    config, results_dict_greedy = run_greedy_sampling_v2(config, model_MMT, val_dataloader_multi, itos, stoi, MW_filter=False, MF_filter=False)
    greedy_rate = calc_percentage_top_x_correct_greedy(results_dict_greedy)

    percentage_collection = [greedy_rate] + [top_k_rates[k] for k in (1, 3, 5, 10)]
    return percentage_collection, results_dict_mns_10, results_dict_greedy

def run_greedy_sampling_v2(config, model_MMT, val_dataloader, itos, stoi, MW_filter=False, MF_filter=False):
    """Greedy-decode every validation batch and score the output against the targets.

    Args:
        config: run configuration, passed through to the model helpers.
        model_MMT: the loaded multimodal transformer.
        val_dataloader: iterable of batch dicts containing at least ``"trg_enc_SMI"``.
        itos, stoi: SMILES token vocabulary maps.
        MW_filter, MF_filter: accepted for signature parity with the multinomial
            sampler but currently IGNORED. The validity/MW/MF filtering step below
            is disabled.

    Returns:
        ``(config, results_dict)``. ``results_dict`` holds the flat generated and
        target SMILES lists, the token probabilities, and the HSQC/COSY source
        entries. It also holds the per-molecule Tanimoto similarities, their
        mean/std, and the ``failed`` value from
        ``vgmmt.calculate_tanimoto_similarity``.
    """
    gen_conv_SMI_list, trg_conv_SMI_list, prob_list, src_HSQC_list, src_COSY_list = [], [], [], [], []
    for data_dict in val_dataloader:
        trg_enc_SMI = data_dict["trg_enc_SMI"]
        # Skip the first row (presumably the start token) before decoding targets to SMILES.
        trg_conv_SMI = hf.tensor_to_smiles(trg_enc_SMI.T[1:], itos)
        memory, src_padding_mask, trg_enc_SMI, fingerprint, src_HSQC, src_COSY = vgmmt.run_model(model_MMT,
                                                                   data_dict,
                                                                   config)

        greedy_tensor, greedy_token_prob = greedy_sequence_2(model_MMT, stoi, itos, memory, src_padding_mask, config)

        gen_conv_SMI, token_probs = hf.tensor_to_smiles_and_prob(greedy_tensor, greedy_token_prob.T, itos)

        # NOTE(review): SMILES validation/canonicalisation and the MW/MF filters are
        # disabled. They were previously wrapped in a no-op string literal, so
        # MW_filter and MF_filter have no effect. To re-enable:
        #   gen_conv_SMI, token_probs = filter_probs_and_valid_smiles_and_canonicolize(gen_conv_SMI, token_probs)
        #   if MW_filter:
        #       gen_conv_SMI, token_probs = filter_for_MW_2(trg_conv_SMI[0], gen_conv_SMI, token_probs)
        #   if MF_filter:
        #       gen_conv_SMI, token_probs = filter_for_MF_2(trg_conv_SMI[0], gen_conv_SMI, token_probs)

        gen_conv_SMI_list.extend(gen_conv_SMI)
        trg_conv_SMI_list.extend(trg_conv_SMI)
        prob_list.extend(token_probs)
        src_HSQC_list.extend(src_HSQC)
        src_COSY_list.extend(src_COSY)

    tanimoto_mean, tanimoto_std_dev, failed, tanimoto_list_all = vgmmt.calculate_tanimoto_similarity(gen_conv_SMI_list, trg_conv_SMI_list)
    results_dict = {
        'gen_conv_SMI_list': gen_conv_SMI_list,
        'trg_conv_SMI_list': trg_conv_SMI_list,
        'tanimoto_sim': tanimoto_list_all,
        'tanimoto_mean': tanimoto_mean,
        'tanimoto_std_dev': tanimoto_std_dev,
        'failed': failed,
        'prob_list': prob_list,
        "src_HSQC_list": src_HSQC_list,
        "src_COSY_list": src_COSY_list,
        }

    return config, results_dict


In [None]:
# Experiment 3 inputs: the ZINC270M V8 test split (1H file, also used as
# csv_path_val), the 13C/HSQC/COSY CSVs and the cached pickle, per the file names.
# Presumably consumed by load_data(config, ...) inside
# run_precentage_calculation_v2 — confirm.
config.csv_1H_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_1H_comb_test_V8.csv'
config.csv_13C_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_13C.csv'    
config.csv_HSQC_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_HSQC.csv'    
config.csv_COSY_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_COSY.csv'
config.csv_path_val = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_1H_comb_test_V8.csv'
config.pickle_file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_1H_comb_test_V8_355655.pkl"

In [None]:
import pickle

# Experiment 3, no filters (MW_filter=False, MF_filter=False), V8i_MMT_Drop4 checkpoint.
config.training_mode = "1H_13C_HSQC_COSY_IR_MF_MW"
config.IR_data_folder="/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/IR_spectra_NN"
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt"
config.temperature = 1
config.multinom_runs = 10
config.data_size = 10000
MW_filter = False
MF_filter = False

percentage_collection_FALSE_FALSE, results_dict_mns_10_FALSE_FALSE, results_dict_greedy_FALSE_FALSE = run_precentage_calculation_v2(
    config, itos, stoi, stoi_MF, MW_filter, MF_filter)

# Write each output to the shared Experiment 3 folder (greedy, multinomial, percentages).
experiment_dir = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240923_Experiment_3_v2/'
for file_name, payload in (
    ('3.0.1_results_dict_greedy_FALSE_FALSE_10000.pkl', results_dict_greedy_FALSE_FALSE),
    ('3.0.2_results_dict_mns_10_FALSE_FALSE_10000.pkl', results_dict_mns_10_FALSE_FALSE),
    ('3.0.3_percentage_collection_FALSE_FALSE_10000.pkl', percentage_collection_FALSE_FALSE),
):
    file_path_c = experiment_dir + file_name
    with open(file_path_c, 'wb') as file:
        pickle.dump(payload, file)


In [None]:
import pickle

# Experiment 3, MW filter only (MW_filter=True, MF_filter=False), V8i_MMT_Drop4 checkpoint.
config.training_mode = "1H_13C_HSQC_COSY_IR_MF_MW"
config.IR_data_folder="/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/IR_spectra_NN"
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt"
config.temperature = 1
config.multinom_runs = 10
config.data_size = 10000
MW_filter = True
MF_filter = False

percentage_collection_TRUE_FALSE, results_dict_mns_10_TRUE_FALSE, results_dict_greedy_TRUE_FALSE = run_precentage_calculation_v2(
    config, itos, stoi, stoi_MF, MW_filter, MF_filter)

# Write each output to the shared Experiment 3 folder (greedy, multinomial, percentages).
experiment_dir = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240923_Experiment_3_v2/'
for file_name, payload in (
    ('3.0.1_results_dict_greedy_TRUE_FALSE_10000.pkl', results_dict_greedy_TRUE_FALSE),
    ('3.0.2_results_dict_mns_10_TRUE_FALSE_10000.pkl', results_dict_mns_10_TRUE_FALSE),
    ('3.0.3_percentage_collection_TRUE_FALSE_10000.pkl', percentage_collection_TRUE_FALSE),
):
    file_path_c = experiment_dir + file_name
    with open(file_path_c, 'wb') as file:
        pickle.dump(payload, file)


In [None]:
# Experiment 3, MF filter only (MW_filter=False, MF_filter=True), V8i_MMT_Drop4 checkpoint.
import pickle

config.training_mode = "1H_13C_HSQC_COSY_IR_MF_MW"
config.IR_data_folder="/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/IR_spectra_NN"
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt"
config.temperature = 1
config.multinom_runs = 10
config.data_size = 10000
MW_filter = False
MF_filter = True

percentage_collection_FALSE_TRUE, results_dict_mns_10_FALSE_TRUE, results_dict_greedy_FALSE_TRUE = run_precentage_calculation_v2(
    config, itos, stoi, stoi_MF, MW_filter, MF_filter)

# Write each output to the shared Experiment 3 folder (greedy, multinomial, percentages).
experiment_dir = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240923_Experiment_3_v2/'
for file_name, payload in (
    ('3.0.1_results_dict_greedy_FALSE_TRUE_10000.pkl', results_dict_greedy_FALSE_TRUE),
    ('3.0.2_results_dict_mns_10_FALSE_TRUE_10000.pkl', results_dict_mns_10_FALSE_TRUE),
    ('3.0.3_percentage_collection_FALSE_TRUE_10000.pkl', percentage_collection_FALSE_TRUE),
):
    file_path_c = experiment_dir + file_name
    with open(file_path_c, 'wb') as file:
        pickle.dump(payload, file)


In [None]:
# Experiment 3, MW + MF filters (MW_filter=True, MF_filter=True), V8i_MMT_Drop4 checkpoint.
import pickle

config.training_mode = "1H_13C_HSQC_COSY_IR_MF_MW"
config.IR_data_folder="/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/IR_spectra_NN"
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt"
config.temperature = 1
config.multinom_runs = 10
config.data_size = 10000
MW_filter = True
MF_filter = True

percentage_collection_TRUE_TRUE, results_dict_mns_10_TRUE_TRUE, results_dict_greedy_TRUE_TRUE = run_precentage_calculation_v2(
    config, itos, stoi, stoi_MF, MW_filter, MF_filter)

# Write all three outputs. The previous version left the multinomial-results
# `with` block empty, which was a SyntaxError, so 3.0.2_..._TRUE_TRUE was never
# written although the plotting cell loads it.
experiment_dir = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240923_Experiment_3_v2/'
for file_name, payload in (
    ('3.0.1_results_dict_greedy_TRUE_TRUE_10000.pkl', results_dict_greedy_TRUE_TRUE),
    ('3.0.2_results_dict_mns_10_TRUE_TRUE_10000.pkl', results_dict_mns_10_TRUE_TRUE),
    ('3.0.3_percentage_collection_TRUE_TRUE_10000.pkl', percentage_collection_TRUE_TRUE),
):
    file_path_c = experiment_dir + file_name
    with open(file_path_c, 'wb') as file:
        pickle.dump(payload, file)


##### Plot Experiment 3 results

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle



def calc_percentage_and_count_top_x_correct_greedy(results_dict):
    """Share of greedy predictions whose Tanimoto similarity to the target equals 1.

    Args:
        results_dict: mapping with a ``"tanimoto_sim"`` iterable, one value per molecule.

    Returns:
        ``(percentage, hits, total)``. ``percentage`` is 0 when there are no molecules.
    """
    similarities = list(results_dict["tanimoto_sim"])
    hits = sum(1 for sim in similarities if sim == 1)
    n_total = len(similarities)
    if n_total <= 0:
        return 0, hits, n_total
    return (hits / n_total) * 100, hits, n_total


def calc_percentage_and_count_top_x_correct(results_dict, top_x):
    """Share of molecules with an exact match (similarity 1) among their first ``top_x`` samples.

    Args:
        results_dict: mapping of index -> list of result dicts. Each result has a
            ``"tanimoto_sim"`` sequence ordered by sample.
        top_x: number of leading samples to inspect per result.

    Returns:
        ``(percentage, hits, total)``. ``percentage`` is 0 when there are no results.
    """
    hits = 0
    n_total = 0
    for run_results in results_dict.values():
        for result in run_results:
            n_total += 1
            if 1 in result["tanimoto_sim"][:top_x]:
                hits += 1
    if n_total <= 0:
        return 0, hits, n_total
    return (hits / n_total) * 100, hits, n_total

def prepare_data(results_dict_FALSE_FALSE, results_dict_FALSE_TRUE, 
                 results_dict_TRUE_FALSE, results_dict_TRUE_TRUE, 
                 results_dict_greedy):
    """Collect the (percentage, count) pairs plotted for each filter combination.

    Args:
        results_dict_FALSE_FALSE: multinomial-sampling results, no filter.
        results_dict_FALSE_TRUE: multinomial-sampling results, MF filter only.
        results_dict_TRUE_FALSE: multinomial-sampling results, MW filter only.
        results_dict_TRUE_TRUE: multinomial-sampling results, MW + MF filters.
        results_dict_greedy: greedy-decoding results. The same greedy entry is used
            for every filter combination.

    Returns:
        ``(no_filter, mf_filter, mw_filter, mw_mf_filter)``. Each is a list of five
        ``(percentage, count)`` tuples: greedy, top-1, top-3, top-5, top-10.
    """
    # The greedy entry does not depend on the filter combination, so compute it
    # once rather than once per combination.
    percentage_greedy, count_greedy, _ = calc_percentage_and_count_top_x_correct_greedy(results_dict_greedy)
    greedy_entry = (percentage_greedy, count_greedy)

    def calc_percentages(results_dict):
        """Greedy entry followed by the top-1/3/5/10 (percentage, count) pairs of ``results_dict``."""
        entries = [greedy_entry]
        for top_x in (1, 3, 5, 10):
            percentage, count, _ = calc_percentage_and_count_top_x_correct(results_dict, top_x)
            entries.append((percentage, count))
        return entries

    no_filter = calc_percentages(results_dict_FALSE_FALSE)
    mf_filter = calc_percentages(results_dict_FALSE_TRUE)
    mw_filter = calc_percentages(results_dict_TRUE_FALSE)
    mw_mf_filter = calc_percentages(results_dict_TRUE_TRUE)

    return no_filter, mf_filter, mw_filter, mw_mf_filter


def plot_results(no_filter, mf_filter, mw_filter, mw_mf_filter,
                 save_path='/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/performance_comparison_with_counts_all_filters_10000_v1a.png',
                 y_min=40):
    """Draw, save and show a grouped bar chart of accuracy per decoding setting and filter.

    Each group (Greedy, 1/3/5/10 samples) has one bar per filter. Each bar is
    annotated with its percentage above it and its hit count inside it.

    Args:
        no_filter, mf_filter, mw_filter: lists of ``(percentage, count)`` pairs in
            label order, as returned by ``prepare_data``.
        mw_mf_filter: accepted for interface compatibility; not plotted.
        save_path: file the figure is written to before it is shown.
        y_min: lower y-axis limit in percent. Bars are cropped below it.
    """
    labels = ['Greedy', '1 Sample', '3 Samples', '5 Samples', '10 Samples']
    x = np.arange(len(labels))
    width = 0.2  # bar width; leaves room for up to four bars per group
    colors = ['#A1C8F3', '#FFB381', '#8BE5A0', '#FF9D9A']

    fig, ax = plt.subplots(figsize=(15, 8))

    def autolabel(rects, data):
        """Annotate bars: rotated percentage above, count centred in the visible part."""
        for rect, (percentage, count) in zip(rects, data):
            height = rect.get_height()
            x_center = rect.get_x() + rect.get_width() / 2
            ax.text(x_center, height + 2, f'{percentage:.1f}%',
                    ha='center', va='bottom', fontsize=20, rotation=90)
            # Only the part above y_min is drawn. Clamp so that a bar shorter than
            # y_min does not put its count label below the axes.
            visible_top = max(height, y_min)
            ax.text(x_center, (y_min + visible_top) / 2, f'{count}',
                    ha='center', va='center', fontsize=20, rotation=90)

    # (data, legend label, offset in bar widths). The order fixes legend and colour order.
    series = [
        (no_filter, 'No Filter', -1),
        (mw_filter, 'MW Filter', 0),
        (mf_filter, 'MF Filter', 1),
    ]
    for color, (data, label, offset) in zip(colors, series):
        rects = ax.bar(x + offset * width, [d[0] for d in data], width, label=label, color=color)
        autolabel(rects, data)

    ax.set_ylabel('Percentage', fontsize=30)
    ax.set_title('Performance Comparison with Different Filters', fontsize=30)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=30)
    ax.legend(fontsize=30)
    ax.tick_params(axis='y', labelsize=30)
    ax.set_ylim(y_min, 100)

    fig.tight_layout()
    fig.savefig(save_path)
    plt.show()

In [None]:

# Reload the Experiment 3 results pickled by the run cells above and plot them.
file_path_base = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240923_Experiment_3_v2/'


def _load_results(file_name):
    """Unpickle ``file_name`` from ``file_path_base`` (locally produced files only)."""
    with open(file_path_base + file_name, 'rb') as file:
        return pickle.load(file)


results_dict_FALSE_FALSE = _load_results('3.0.2_results_dict_mns_10_FALSE_FALSE_10000.pkl')
results_dict_FALSE_TRUE = _load_results('3.0.2_results_dict_mns_10_FALSE_TRUE_10000.pkl')
results_dict_TRUE_FALSE = _load_results('3.0.2_results_dict_mns_10_TRUE_FALSE_10000.pkl')
results_dict_TRUE_TRUE = _load_results('3.0.2_results_dict_mns_10_TRUE_TRUE_10000.pkl')
# Greedy results are taken from the unfiltered run only.
results_dict_greedy = _load_results('3.0.1_results_dict_greedy_FALSE_FALSE_10000.pkl')

no_filter, mf_filter, mw_filter, mw_mf_filter = prepare_data(
    results_dict_FALSE_FALSE, results_dict_FALSE_TRUE,
    results_dict_TRUE_FALSE, results_dict_TRUE_TRUE,
    results_dict_greedy,
)
plot_results(no_filter, mf_filter, mw_filter, mw_mf_filter)

### 5.0. Similarity reduction plotting

#### 5.1 Tanimoto Similarity

In [None]:
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import umap
import matplotlib.pyplot as plt
from tqdm import tqdm

def main(zinc_sample_size=3000, pubchem_limit=100):
    """Project ZINC training fingerprints and each PubChem test set with t-SNE, PCA and UMAP.

    For every PubChem molecular-weight bin, the ZINC sample and the first
    ``pubchem_limit`` PubChem SMILES are fingerprinted with ``exp_fun`` helpers,
    stacked, reduced and plotted. This produces one figure per bin.

    Args:
        zinc_sample_size: number of ZINC training SMILES requested from
            ``exp_fun.load_data``.
        pubchem_limit: number of leading PubChem SMILES used per bin.
    """
    np.random.seed(42)  # Set random seed for reproducibility
    zinc_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_1H_comb_train_V8.csv'
    pubchem_paths = [
        '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_1H_V1_test_f_0_250_x1000.csv',
        '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_1H_V1_test_f_250_350_x1000.csv',
        '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000.csv'
    ]
    weight_ranges = ['0-250 Da', '250-350 Da', '350-500 Da']

    print("Loading ZINC data...")
    # Only the sampled subset is loaded. A previous unconditional full load of the
    # training CSV was discarded immediately.
    # NOTE(review): if load_data consumed np.random state, the sampled subset may
    # differ from runs made before this change — confirm.
    zinc_smiles = exp_fun.load_data(zinc_path, sample_size=zinc_sample_size)
    zinc_fp = exp_fun.calculate_fingerprints(zinc_smiles)

    for pubchem_path, weight_range in zip(pubchem_paths, weight_ranges):
        print(f"Processing PubChem data for {weight_range}...")
        pubchem_smiles = exp_fun.load_data(pubchem_path)
        # Qualified like the ZINC call above; a bare calculate_fingerprints is not defined here.
        pubchem_fp = exp_fun.calculate_fingerprints(pubchem_smiles[:pubchem_limit])

        combined_fp = np.vstack((zinc_fp, pubchem_fp))
        labels = np.array(['ZINC'] * len(zinc_fp) + ['PubChem'] * len(pubchem_fp))

        print("Performing dimensionality reduction...")
        tsne_result, pca_result, umap_result = exp_fun.perform_dimensionality_reduction(combined_fp)

        print("Plotting results...")
        exp_fun.plot_tsne_umap_pca([tsne_result, pca_result, umap_result], labels, 
                     f"ZINC vs PubChem ({weight_range})", 
                     ['t-SNE', 'PCA', 'UMAP'])

    print("All plots generated successfully!")


In [None]:
# Run the ZINC-vs-PubChem fingerprint projection defined above: one t-SNE/PCA/UMAP
# figure per PubChem molecular-weight bin.
main()
    

#### 5.2 Vector Similarity

In [None]:
import pandas as pd
import ast  # For safely evaluating strings containing Python literals
from sklearn.neighbors import NearestNeighbors
import numpy as np
from rdkit import Chem
from rdkit.Chem import Draw
import matplotlib.pyplot as plt
import math
import pandas as pd
import torch
import ast
from tqdm import tqdm
import time
import numpy as np
from sklearn.neighbors import NearestNeighbors
import random
from tqdm import tqdm
from sklearn.neighbors import RadiusNeighborsRegressor
import numpy as np
from scipy.spatial.distance import cosine
from sklearn.neighbors import NearestNeighbors
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.DataStructs import FingerprintSimilarity
from rdkit.DataStructs import TanimotoSimilarity
import numpy as np


##### 4.2.1 PubChem vectorization

In [None]:
# PubChem test-split configuration; the 250-350 Da split is active.
# NOTE(review): when run top-to-bottom, every attribute set here is overwritten by
# the next cell (350-500 Da "Test Data"), so this cell has no lasting effect.

# Alternative: 350-500 Da split (disabled).
#config.csv_1H_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000.csv'
#config.csv_13C_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_13C_V1_test_350_500_x1000.csv'    
#config.csv_HSQC_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_HSQC_V1_test_350_500_x1000.csv'    
#config.csv_COSY_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_COSY_V1_test_350_500_x1000.csv'   
#config.csv_path_val = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000.csv'
#config.pickle_file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000_185242.pkl"

# Alternative: 0-250 Da split (disabled).
#config.csv_1H_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_1H_V1_test_f_0_250_x1000.csv'
#config.csv_13C_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_13C_V1_test_0_250_x1000.csv'
#config.csv_HSQC_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_HSQC_V1_test_0_250_x1000.csv'    
#config.csv_COSY_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_COSY_V1_test_0_250_x1000.csv'   
#config.csv_path_val = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_1H_V1_test_f_0_250_x1000.csv'
#config.pickle_file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_1H_V1_test_f_0_250_x1000_933335.pkl"


# Active: 250-350 Da split.
config.csv_1H_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_1H_V1_test_f_250_350_x1000.csv'
config.csv_13C_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_13C_V1_test_250_350_x1000.csv'    
config.csv_HSQC_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_HSQC_V1_test_250_350_x1000.csv'    
config.csv_COSY_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_COSY_V1_test_250_350_x1000.csv'   
config.csv_path_val = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_1H_V1_test_f_250_350_x1000.csv'
config.pickle_file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_1H_V1_test_f_250_350_x1000_285005.pkl"

# V8 Raw: PubChem IR spectra folder
config.IR_data_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/IR_data"


In [None]:
# Test Data
config.csv_1H_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000.csv'
config.csv_13C_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_13C_V1_test_350_500_x1000.csv'    
config.csv_HSQC_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_HSQC_V1_test_350_500_x1000.csv'    
config.csv_COSY_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_COSY_V1_test_350_500_x1000.csv'   
config.csv_path_val = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000.csv'
config.pickle_file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000_185242.pkl"

# V8 Raw 
config.IR_data_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/IR_data"
config.data_size = 1000

config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt"
config.vector_db = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000_Vectors.csv'

In [None]:
# Embed the dataset configured in the previous cell (csv_*_path_SGNN, IR_data_folder,
# checkpoint_path, data_size) with exp_func.vectorize_db. The next cell reads
# config.vector_db back as a CSV, so the vectors are presumably written there.
# The meaning of the "db"/"all" arguments is not visible here — TODO document.
config = exp_func.vectorize_db(config, stoi, stoi_MF, "db", "all")

In [None]:
import ast
import os
import pickle

import pandas as pd
import torch
from tqdm import tqdm

vector_db = config.vector_db
# Load the vector CSV (presumably written by exp_func.vectorize_db in the previous cell).
df = pd.read_csv(vector_db)

# 'Fingerprints' holds the text form of a Python list. Parse it and convert it to a
# float32 tensor in a single pass. ast.literal_eval accepts only literals, unlike eval.
tqdm.pandas()
df['Fingerprints'] = df['Fingerprints'].progress_apply(
    lambda text: torch.tensor(ast.literal_eval(text), dtype=torch.float32)
)

# Swap only the file extension for '_v2.pkl'. str.replace('.csv', ...) would leave a
# path without '.csv' unchanged, so the pickle would overwrite the source file. It
# would also rewrite any '.csv' appearing earlier in the path.
pickle_file = os.path.splitext(vector_db)[0] + '_v2.pkl'
with open(pickle_file, 'wb') as f:
    pickle.dump(df, f)

print(f"Data saved to {pickle_file}")


In [None]:

# PubChem 250-350 Da split: inputs for exp_func.vectorize_db in the next cell.
# The vectors are presumably written to config.vector_db (set at the end of this cell).
config.csv_1H_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_1H_V1_test_f_250_350_x1000.csv'
config.csv_13C_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_13C_V1_test_250_350_x1000.csv'    
config.csv_HSQC_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_HSQC_V1_test_250_350_x1000.csv'    
config.csv_COSY_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_COSY_V1_test_250_350_x1000.csv'   
config.csv_path_val = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_1H_V1_test_f_250_350_x1000.csv'
config.pickle_file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_1H_V1_test_f_250_350_x1000_285005.pkl"

# V8 Raw: PubChem IR spectra folder
config.IR_data_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/IR_data"
config.data_size = 1000  # presumably caps the number of molecules processed — confirm

# Model used for embedding, and the output CSV for the vectors.
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt"
config.vector_db = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/ML_NMR_2M_XL_1H_V1_test_f_250_350_x1000_Vectors.csv'

In [None]:
# Embed the dataset configured in the previous cell (csv_*_path_SGNN, IR_data_folder,
# checkpoint_path, data_size) with exp_func.vectorize_db. The next cell reads
# config.vector_db back as a CSV, so the vectors are presumably written there.
# The meaning of the "db"/"all" arguments is not visible here — TODO document.
config = exp_func.vectorize_db(config, stoi, stoi_MF, "db", "all")

In [None]:
import ast
import os
import pickle

import pandas as pd
import torch
from tqdm import tqdm

vector_db = config.vector_db
# Load the vector CSV (presumably written by exp_func.vectorize_db in the previous cell).
df = pd.read_csv(vector_db)

# 'Fingerprints' holds the text form of a Python list. Parse it and convert it to a
# float32 tensor in a single pass. ast.literal_eval accepts only literals, unlike eval.
tqdm.pandas()
df['Fingerprints'] = df['Fingerprints'].progress_apply(
    lambda text: torch.tensor(ast.literal_eval(text), dtype=torch.float32)
)

# Swap only the file extension for '_v2.pkl'. str.replace('.csv', ...) would leave a
# path without '.csv' unchanged, so the pickle would overwrite the source file. It
# would also rewrite any '.csv' appearing earlier in the path.
pickle_file = os.path.splitext(vector_db)[0] + '_v2.pkl'
with open(pickle_file, 'wb') as f:
    pickle.dump(df, f)

print(f"Data saved to {pickle_file}")


In [None]:
# PubChem 0-250 Da split: inputs for exp_func.vectorize_db in the next cell.
# The vectors are presumably written to config.vector_db (set at the end of this cell).
config.csv_1H_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_1H_V1_test_f_0_250_x1000.csv'
config.csv_13C_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_13C_V1_test_0_250_x1000.csv'
config.csv_HSQC_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_HSQC_V1_test_0_250_x1000.csv'    
config.csv_COSY_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_COSY_V1_test_0_250_x1000.csv'   
config.csv_path_val = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_1H_V1_test_f_0_250_x1000.csv'
config.pickle_file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_1H_V1_test_f_0_250_x1000_933335.pkl"

# V8 Raw: PubChem IR spectra folder
config.IR_data_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/IR_data"
config.data_size = 1000  # presumably caps the number of molecules processed — confirm

# Model used for embedding, and the output CSV for the vectors.
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt"
config.vector_db = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/ML_NMR_2M_XL_1H_V1_test_f_0_250_x1000_Vectors.csv'

In [None]:
# Embed the dataset configured in the previous cell (csv_*_path_SGNN, IR_data_folder,
# checkpoint_path, data_size) with exp_func.vectorize_db. The next cell reads
# config.vector_db back as a CSV, so the vectors are presumably written there.
# The meaning of the "db"/"all" arguments is not visible here — TODO document.
config = exp_func.vectorize_db(config, stoi, stoi_MF, "db", "all")

In [None]:
import ast
import os
import pickle

import pandas as pd
import torch
from tqdm import tqdm

vector_db = config.vector_db
# Load the vector CSV (presumably written by exp_func.vectorize_db in the previous cell).
df = pd.read_csv(vector_db)

# 'Fingerprints' holds the text form of a Python list. Parse it and convert it to a
# float32 tensor in a single pass. ast.literal_eval accepts only literals, unlike eval.
tqdm.pandas()
df['Fingerprints'] = df['Fingerprints'].progress_apply(
    lambda text: torch.tensor(ast.literal_eval(text), dtype=torch.float32)
)

# Swap only the file extension for '_v2.pkl'. str.replace('.csv', ...) would leave a
# path without '.csv' unchanged, so the pickle would overwrite the source file. It
# would also rewrite any '.csv' appearing earlier in the path.
pickle_file = os.path.splitext(vector_db)[0] + '_v2.pkl'
with open(pickle_file, 'wb') as f:
    pickle.dump(df, f)

print(f"Data saved to {pickle_file}")


##### UMAP, t-SNE and PCA projections for Sets 1-3

In [None]:
import pandas as pd
import pickle
import torch
import numpy as np
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import umap
import matplotlib.pyplot as plt
from tqdm import tqdm




In [None]:
# Show the current working directory (IPython automagic for %pwd). This is a
# debugging leftover; nothing below uses its output.
pwd

In [None]:
# PubChem test-set embeddings, one pickle per molecular-weight bin.
test_files = [
    '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/ML_NMR_2M_XL_1H_V1_test_f_0_250_x1000_Vectors_v2.pkl',
    '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/ML_NMR_2M_XL_1H_V1_test_f_250_350_x1000_Vectors_v2.pkl',
    '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000_Vectors_v2.pkl'
]

# Display names for test_files, in the same order (per the file names):
# Set 1 = 0-250 Da, Set 2 = 250-350 Da, Set 3 = 350-500 Da.
weight_ranges = ["Set 1", "Set 2", "Set 3"]
n_test_vectors = 300  # leading test vectors per set projected alongside the training sample

# Load a 3000-vector sample of the ZINC training embeddings.
print("Loading training data...")
train_vectors, train_smiles = exp_func.load_pickle_data('/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/vector_db_train_5Mod_4M_v2_v2.pkl', sample_size=3000)

for test_file, weight_range in zip(test_files, weight_ranges):
    print(f"Processing test data for {weight_range}...")
    test_vectors, test_smiles = exp_func.load_pickle_data(test_file)

    # Use only the first n_test_vectors test vectors
    test_vectors_sample = test_vectors[:n_test_vectors]

    combined_vectors = np.vstack((train_vectors, test_vectors_sample))
    labels = ['Train'] * len(train_vectors) + ['Test'] * len(test_vectors_sample)

    print(f"Shape of combined vectors: {combined_vectors.shape}")
    print(f"Number of labels: {len(labels)}")

    print("Performing dimensionality reduction...")
    tsne_result, pca_result, umap_result = exp_func.perform_dimensionality_reduction(combined_vectors)

    print("Plotting results...")
    exp_func.plot_tsne_umap_pca_train_test([tsne_result, pca_result, umap_result], labels, 
                 f"Train vs Test Vectors ({weight_range})", 
                 ['t-SNE', 'PCA', 'UMAP'])

print("All plots generated successfully!")

##### 5.2.1 ZINC 1000 vectorization

In [None]:
# SMILES for the 1000-molecule ZINC test subset (per the file name and section
# header). The next cell uses it as config.SGNN_csv_gen_smi.
SGNN_smi_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/ML_NMR_5M_XL_1H_comb_test_V8_1000.csv"

In [None]:
# Generate simulated spectra for the SMILES in config.SGNN_csv_gen_smi.
# IR_config must already be defined in the kernel. gen_sim_aug_data returns an
# updated config. The copy cell below reads its csv_*_path_SGNN and IR_data_folder
# fields, which presumably point at the generated files — confirm.
config.SGNN_csv_gen_smi = SGNN_smi_path
config.SGNN_gen_folder_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/46_Project_3_Data/test"
config = ex.gen_sim_aug_data(config, IR_config)
# Validate on the freshly generated 1H file.
config.csv_path_val = config.csv_1H_path_SGNN
#config = ex.filter_invalid_criteria(config.csv_1H_path_SGNN)

In [None]:
import shutil
import os

# Generated spectra files, as recorded in config by the simulation cell above.
source_csvs = [
    config.csv_1H_path_SGNN,
    config.csv_13C_path_SGNN,
    config.csv_HSQC_path_SGNN,
    config.csv_COSY_path_SGNN,
]
source_IR_folder = config.IR_data_folder

# Archive folder for the ZINC 1000 data (created if missing).
destination_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_1000_data/test"
os.makedirs(destination_folder, exist_ok=True)

for csv_path in source_csvs:
    shutil.copy(csv_path, destination_folder)

# IR spectra are copied file-by-file, with sub-directories skipped, into the same
# flat folder as the CSVs.
# NOTE(review): the ZINC 1000 config cell below reads from ZINC_1000_data/ and
# ZINC_1000_data/IR_Data, not from this test/ folder — confirm the files are moved by hand.
for entry in os.listdir(source_IR_folder):
    entry_path = os.path.join(source_IR_folder, entry)
    if os.path.isfile(entry_path):
        shutil.copy(entry_path, destination_folder)

print(f"All files have been copied to {destination_folder}")


In [None]:
# Test Data
config.csv_1H_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_1000_data/data_1H_791960.csv'
config.csv_13C_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_1000_data/data_13C_791960.csv'    
config.csv_HSQC_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_1000_data/data_HSQC_791960.csv'    
config.csv_COSY_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_1000_data/data_COSY_791960.csv'   
config.csv_path_val = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_1000_data/data_1H_791960.csv'
config.pickle_file_path = ""

# V8 Raw 
config.IR_data_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_1000_data/IR_Data"
config.data_size = 1000

config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt"
config.vector_db = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_1000_data/ZINC_1000_vectors.csv'

In [None]:
# Embed the dataset configured in the previous cell (csv_*_path_SGNN, IR_data_folder,
# checkpoint_path, data_size) with exp_func.vectorize_db. The next cell reads
# config.vector_db back as a CSV, so the vectors are presumably written there.
# The meaning of the "db"/"all" arguments is not visible here — TODO document.
config = exp_func.vectorize_db(config, stoi, stoi_MF, "db", "all")

In [None]:
import ast
import os
import pickle

import pandas as pd
import torch
from tqdm import tqdm

# Convert the vector-DB CSV at config.vector_db into a pickled DataFrame whose
# 'Fingerprints' column holds float32 torch tensors instead of strings.
vector_db = config.vector_db
df = pd.read_csv(vector_db)

tqdm.pandas()
# Each cell is presumably a stringified Python list of numbers; ast.literal_eval
# parses it without executing code, and the tensor is built in the same pass.
df['Fingerprints'] = df['Fingerprints'].progress_apply(
    lambda s: torch.tensor(ast.literal_eval(s), dtype=torch.float32)
)

# "<name>.csv" -> "<name>_v2.pkl". The output name is derived from the extension only.
# str.replace('.csv', ...) would also rewrite ".csv" elsewhere in the path. For a
# path without ".csv" it would return the input path, so the pickle would overwrite
# the source CSV.
pickle_file = os.path.splitext(vector_db)[0] + '_v2.pkl'
with open(pickle_file, 'wb') as f:
    pickle.dump(df, f)

print(f"Data saved to {pickle_file}")


##### 5.2.1 Vectorization of the 34 simulated molecules

In [None]:
# Input CSV of SMILES for the simulated 34-molecule set (ACD 1H reference file).
SGNN_smi_path = (
    "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/"
    "37_Richard_ACD_sim_data/ACD_1H_with_SN_filtered_v3.csv"
)

In [None]:
# Simulate spectra for the SMILES in SGNN_smi_path, writing into the test folder.
config.SGNN_csv_gen_smi = SGNN_smi_path
config.SGNN_gen_folder_path = (
    "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/"
    "46_Project_3_Data/test"
)
config = ex.gen_sim_aug_data(config, IR_config)

# Validate on the 1H CSV produced by the generation step.
config.csv_path_val = config.csv_1H_path_SGNN
# Optional filtering step (disabled):
# config = ex.filter_invalid_criteria(config.csv_1H_path_SGNN)

In [None]:
import shutil
import os

# Collect the simulation outputs currently referenced by `config` into one folder.
source_csv_1H = config.csv_1H_path_SGNN
source_csv_13C = config.csv_13C_path_SGNN
source_csv_HSQC = config.csv_HSQC_path_SGNN
source_csv_COSY = config.csv_COSY_path_SGNN
source_IR_folder = config.IR_data_folder

# Target folder (created if missing; existing files with the same name are overwritten).
destination_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/37_Richard_ACD_sim_data/SGNN_simulations"
os.makedirs(destination_folder, exist_ok=True)

# The four NMR CSVs...
for nmr_csv in (source_csv_1H, source_csv_13C, source_csv_HSQC, source_csv_COSY):
    shutil.copy(nmr_csv, destination_folder)

# ...plus every regular file in the IR folder (sub-directories are skipped).
ir_candidates = (os.path.join(source_IR_folder, entry) for entry in os.listdir(source_IR_folder))
for ir_file in ir_candidates:
    if os.path.isfile(ir_file):
        shutil.copy(ir_file, destination_folder)

print(f"All files have been copied to {destination_folder}")


In [None]:
# Simulated ("real_sim") data for the 34-molecule set, collected in SGNN_simulations.
_sgnn_sim_dir = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/37_Richard_ACD_sim_data/SGNN_simulations'

# One spectra CSV per NMR modality, all from run 812119.
for _modality in ("1H", "13C", "HSQC", "COSY"):
    setattr(config, f"csv_{_modality}_path_SGNN", f"{_sgnn_sim_dir}/data_{_modality}_812119.csv")
# Validation reads the 1H file.
config.csv_path_val = config.csv_1H_path_SGNN

config.IR_data_folder = f"{_sgnn_sim_dir}/IR_spectra"
config.pickle_file_path = ""
config.data_size = 34

config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt"
config.vector_db = f"{_sgnn_sim_dir}/real_sim_34_vectors.csv"

In [None]:
# Vectorize the simulated 34-molecule spectra currently referenced by `config`
# (csv_*_path_SGNN and IR_data_folder). The next cell reads the result from
# config.vector_db — presumably written by this call; TODO confirm in exp_func.
config = exp_func.vectorize_db(
    config,
    stoi,
    stoi_MF,
    "db",
    "all",
)

In [None]:
import ast
import os
import pickle

import pandas as pd
import torch
from tqdm import tqdm

# Convert the vector-DB CSV at config.vector_db into a pickled DataFrame whose
# 'Fingerprints' column holds float32 torch tensors instead of strings.
vector_db = config.vector_db
df = pd.read_csv(vector_db)

tqdm.pandas()
# Each cell is presumably a stringified Python list of numbers; ast.literal_eval
# parses it without executing code, and the tensor is built in the same pass.
df['Fingerprints'] = df['Fingerprints'].progress_apply(
    lambda s: torch.tensor(ast.literal_eval(s), dtype=torch.float32)
)

# "<name>.csv" -> "<name>_v2.pkl". The output name is derived from the extension only.
# str.replace('.csv', ...) would also rewrite ".csv" elsewhere in the path. For a
# path without ".csv" it would return the input path, so the pickle would overwrite
# the source CSV.
pickle_file = os.path.splitext(vector_db)[0] + '_v2.pkl'
with open(pickle_file, 'wb') as f:
    pickle.dump(df, f)

print(f"Data saved to {pickle_file}")


##### 5.2.2 Vectorization of the 34 experimental molecules

In [None]:
# Experimental data for the 34-molecule set.
base_path_exp = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/36_Richard_43_dataset/experimenal_data/"

# One spectra CSV per NMR modality.
for _modality in ("1H", "13C", "HSQC", "COSY"):
    setattr(config, f"csv_{_modality}_path_exp", f"{base_path_exp}real_{_modality}_with_AZ_SMILES_v3.csv")
config.IR_data_folder_exp = f"{base_path_exp}IR_data"
# Validation reads the experimental 1H file.
config.csv_path_val = config.csv_1H_path_exp
# NOTE(review): the other vectorization sections point vectorize_db at the
# csv_*_path_SGNN / IR_data_folder attributes. This cell only sets *_exp attributes.
# Confirm exp_func.vectorize_db reads them. Otherwise it will vectorize whatever
# *_SGNN paths are still set on `config`.

config.pickle_file_path = ""
config.data_size = 34

config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt"
config.vector_db = f"{base_path_exp}exp_34_vectors.csv"

In [None]:
# Vectorize the experimental 34-molecule set. The next cell reads the result from
# config.vector_db — presumably written by this call; TODO confirm in exp_func.
config = exp_func.vectorize_db(
    config,
    stoi,
    stoi_MF,
    "db",
    "all",
)

In [None]:
import ast
import os
import pickle

import pandas as pd
import torch
from tqdm import tqdm

# Convert the vector-DB CSV at config.vector_db into a pickled DataFrame whose
# 'Fingerprints' column holds float32 torch tensors instead of strings.
vector_db = config.vector_db
df = pd.read_csv(vector_db)

tqdm.pandas()
# Each cell is presumably a stringified Python list of numbers; ast.literal_eval
# parses it without executing code, and the tensor is built in the same pass.
df['Fingerprints'] = df['Fingerprints'].progress_apply(
    lambda s: torch.tensor(ast.literal_eval(s), dtype=torch.float32)
)

# "<name>.csv" -> "<name>_v2.pkl". The output name is derived from the extension only.
# str.replace('.csv', ...) would also rewrite ".csv" elsewhere in the path. For a
# path without ".csv" it would return the input path, so the pickle would overwrite
# the source CSV.
pickle_file = os.path.splitext(vector_db)[0] + '_v2.pkl'
with open(pickle_file, 'wb') as f:
    pickle.dump(df, f)

print(f"Data saved to {pickle_file}")


#### 5.3 Plot experimental and simulated vectors against training data (t-SNE, UMAP, PCA)

In [None]:
# Folder the t-SNE / PCA / UMAP figures are written to (adjust as needed).
save_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Figures"

# Training reference: a sample (sample_size=3000) of the ZINC training vector DB.
print("Loading training data...")
train_vectors, train_smiles = exp_func.load_pickle_data(
    '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/vector_db_train_5Mod_4M_v2_v2.pkl',
    sample_size=3000,
)

# Experimental vectors for the 34-molecule set.
print("Loading experimental data...")
exp_test_file = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/36_Richard_43_dataset/experimenal_data/exp_34_vectors_v2.pkl"
exp_vectors, exp_smiles = exp_func.load_pickle_data(exp_test_file)

# Stack both sets row-wise and tag every row with where it came from.
combined_vectors = np.vstack((train_vectors, exp_vectors))
labels = ['Train'] * len(train_vectors) + ['Experimental'] * len(exp_vectors)
print(f"Shape of combined vectors: {combined_vectors.shape}")
print(f"Number of labels: {len(labels)}")

print("Performing dimensionality reduction...")
tsne_result, pca_result, umap_result = exp_func.perform_dimensionality_reduction(combined_vectors)

# One panel per reduction method, saved into save_folder.
exp_func.plot_tsne_umap_pca_train_test_folder(
    [tsne_result, pca_result, umap_result],
    labels,
    "Train vs Experimental Vectors",
    ['t-SNE', 'PCA', 'UMAP'],
    save_folder,
)
print("All plots generated successfully!")

In [None]:
# Compare the simulated 34-molecule vectors (real_sim_34) with the training sample.
# NOTE: reuses `train_vectors` and `save_folder` from the previous cell.
print("Loading simulated data...")
sim_test_file = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/37_Richard_ACD_sim_data/SGNN_simulations/real_sim_34_vectors_v2.pkl"

sim_vectors, sim_smiles = exp_func.load_pickle_data(sim_test_file)

# Stack both sets row-wise and tag every row with where it came from.
combined_vectors = np.vstack((train_vectors, sim_vectors))
labels = ['Train'] * len(train_vectors) + ['Simulated'] * len(sim_vectors)

print(f"Shape of combined vectors: {combined_vectors.shape}")
print(f"Number of labels: {len(labels)}")

print("Performing dimensionality reduction...")
tsne_result, pca_result, umap_result = exp_func.perform_dimensionality_reduction(combined_vectors)

print("Plotting results...")
# One panel per reduction method, saved into save_folder.
exp_func.plot_tsne_umap_pca_train_test_folder(
    [tsne_result, pca_result, umap_result],
    labels,
    "Train vs Simulated Vectors",
    ['t-SNE', 'PCA', 'UMAP'],
    save_folder,
)

print("All plots generated successfully!")

#### Load MMT PubChem 100 Prediction 

#### 5.4 1000 Calculations

##### PC and ZINC Latent space comparison for correct and incorrect molecules
- Calculations were run with the scripts in: /projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MultiModalTransformer/scripts

**TODO:** fix vectors.

In [None]:

def main(data_configs, train_data_path, output_folder, ranking_method):
    """Run exp_func.process_dataset on each dataset config against a training-vector sample.

    Args:
        data_configs: list of dicts with 'weight_range', 'pkl_folder' and 'file_path'
            keys; each dict is passed unchanged to exp_func.process_dataset.
        train_data_path: pickled training vector DB, loaded with
            exp_func.load_pickle_data(..., sample_size=3000).
        output_folder: created if missing, then passed through to process_dataset.
        ranking_method: passed through to process_dataset (e.g. 'HSQC & COSY').
    """
    os.makedirs(output_folder, exist_ok=True)

    print("Loading training data...")
    train_vectors, train_smiles = exp_func.load_pickle_data(train_data_path, sample_size=3000)

    for data_config in data_configs:
        exp_func.process_dataset(data_config, train_vectors, train_smiles, output_folder, ranking_method)

if __name__ == "__main__":
    data_configs = [
        {
            'weight_range': 'PC_0-250',
            'pkl_folder': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/PC_0_250',
            'file_path': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/sim_mol_0_250/experiment_results_0_250.pkl'
        },
        {
            'weight_range': 'PC_250-350',
            'pkl_folder': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/PC_250_350',
            'file_path': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/sim_mol_250_350/experiment_results_250_350.pkl'
        },
        {
            'weight_range': 'PC_350-500',
            'pkl_folder': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/PC_350_500',
            'file_path': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/sim_mol_350_500/experiment_results_350_500.pkl'
        },
        {
            'weight_range': 'ZINC_250-350',
            'pkl_folder': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_250_350',
            # Single ".pkl" extension: the CSV->pickle conversion of ZINC_1000_vectors.csv
            # writes ZINC_1000_vectors_v2.pkl (the path was previously "..._v2.pkl.pkl").
            'file_path': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_1000_data/ZINC_1000_vectors_v2.pkl'
        }
    ]

    train_data_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/vector_db_train_5Mod_4M_v2_v2.pkl'
    output_folder = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/1.0_Experiment_baseline'
    ranking_method = 'HSQC & COSY'

    main(data_configs, train_data_path, output_folder, ranking_method)

In [None]:
# Peek at the raw experiment results for the PubChem 0-250 Da range.
file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/sim_mol_0_250/experiment_results_0_250.pkl"
with open(file_path, 'rb') as handle:
    df = pickle.load(handle)
df

In [None]:
# Removed a stray bare `a` expression left over from interactive debugging. Nothing in
# this part of the notebook defines `a`, so the cell raised NameError on a fresh kernel.

#### 5.5 Histograms for NN

**TODO:** fix vectors.

In [None]:
import os
import pickle
import numpy as np
import pandas as pd
import random
import string

def main_csv_generation(data_configs, output_folder, ranking_method):
    """Write positive/negative example CSVs for each dataset config.

    For every config, calls exp_func.process_dataset_and_save_csv and prints the two
    CSV paths it returns.

    Args:
        data_configs: list of dicts with at least a 'weight_range' key (used for the
            progress message); each dict is passed unchanged to the exp_func helper.
        output_folder: created if missing, then passed through to the helper.
        ranking_method: passed through to the helper (e.g. 'HSQC & COSY').
    """
    os.makedirs(output_folder, exist_ok=True)

    for data_config in data_configs:
        print(f"Processing {data_config['weight_range']}...")
        pos_csv_path, neg_csv_path = exp_func.process_dataset_and_save_csv(data_config, output_folder, ranking_method)
        print(f"Positive examples saved to: {pos_csv_path}")
        print(f"Negative examples saved to: {neg_csv_path}")

if __name__ == "__main__":
    all_data_configs = [
        {
            'weight_range': 'PC_0-250',
            'pkl_folder': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/PC_0_250',
            'file_path': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/ML_NMR_2M_XL_1H_V1_test_f_0_250_x1000_Vectors_v2.pkl'
        },
        {
            'weight_range': 'PC_250-350',
            'pkl_folder': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/PC_250_350',
            # Fixed duplicated "knlr326//knlr326" path segment.
            'file_path': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/ML_NMR_2M_XL_1H_V1_test_f_250_350_x1000_Vectors_v2.pkl'
        },
        {
            'weight_range': 'PC_350-500',
            'pkl_folder': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/PC_350_500',
            'file_path': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000_Vectors_v2.pkl'
        },
        {
            'weight_range': 'ZINC_250-350',
            'pkl_folder': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_250_350',
            'file_path': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_1000_data/ZINC_1000_vectors_v2.pkl'
        },
        {
            'weight_range': 'ZINC_250-350_4000',
            'pkl_folder': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_250_350_4000',
            'file_path': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_1000_data/ZINC_1000_vectors_v2.pkl'
        },
    ]

    # Only the 4000-sample ZINC run is selected; widen this filter to process more ranges.
    data_configs = [cfg for cfg in all_data_configs if cfg['weight_range'] == 'ZINC_250-350_4000']
    output_folder = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples_2'
    ranking_method = 'HSQC & COSY'

    # Disabled: uncomment to regenerate the example CSVs.
    #main_csv_generation(data_configs, output_folder, ranking_method)


#### After fine-tuning on 100 negative compounds

In [None]:
import os
import pickle
import numpy as np
import pandas as pd
import random
import string

def main_csv_generation(data_configs, output_folder, ranking_method):
    """Write positive/negative example CSVs for each dataset config.

    For every config, calls exp_func.process_dataset_and_save_csv and prints the two
    CSV paths it returns.

    Args:
        data_configs: list of dicts with at least a 'weight_range' key (used for the
            progress message); each dict is passed unchanged to the exp_func helper.
        output_folder: created if missing, then passed through to the helper.
        ranking_method: passed through to the helper (e.g. 'HSQC & COSY').
    """
    os.makedirs(output_folder, exist_ok=True)

    for data_config in data_configs:
        print(f"Processing {data_config['weight_range']}...")
        pos_csv_path, neg_csv_path = exp_func.process_dataset_and_save_csv(data_config, output_folder, ranking_method)
        print(f"Positive examples saved to: {pos_csv_path}")
        print(f"Negative examples saved to: {neg_csv_path}")

if __name__ == "__main__":
    data_configs = [
        {
            'weight_range': 'PC_0-250_neg_100',
            'pkl_folder': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/PC_0_250',
            'file_path': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/ML_NMR_2M_XL_1H_V1_test_f_0_250_x1000_Vectors_v2.pkl'
        },
        {
            'weight_range': 'PC_250-350_neg_100',
            'pkl_folder': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/PC_250_350',
            # Fixed duplicated "knlr326//knlr326" path segment.
            'file_path': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/ML_NMR_2M_XL_1H_V1_test_f_250_350_x1000_Vectors_v2.pkl'
        },
        {
            'weight_range': 'PC_350-500_neg_100',
            'pkl_folder': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/PC_350_500',
            'file_path': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/PubChem_vectors/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000_Vectors_v2.pkl'
        },
        {
            'weight_range': 'ZINC_250-350_neg_100',
            'pkl_folder': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_250_350',
            'file_path': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_1000_data/ZINC_1000_vectors_v2.pkl'
        },
    ]

    output_folder = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples_2'
    ranking_method = 'HSQC & COSY'

    main_csv_generation(data_configs, output_folder, ranking_method)


In [None]:
import pandas as pd
import os

# Rename "sample_id" -> "sample-id" in every positive/negative example CSV.
csv_examples_dir = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples"
example_sets = ("PC_0-250", "PC_250-350", "PC_350-500", "ZINC_250-350")
file_paths = [
    f"{csv_examples_dir}/{set_name}_{polarity}.csv"
    for set_name in example_sets
    for polarity in ("negative", "positive")
]

# Missing files are reported and skipped rather than aborting the loop.
for file_path in file_paths:
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        continue
    exp_func.rename_column_in_csv(file_path, "sample_id", "sample-id")

print("All files processed.")

In [None]:
# Candidate SMILES CSVs to simulate; the active one is selected by key below.
_csv_examples_2_dir = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples_2"
sgnn_gen_smi_candidates = {
    name: f"{_csv_examples_2_dir}/{name}.csv"
    for name in (
        "PC_0-250_negative", "PC_0-250_positive",
        "PC_250-350_negative", "PC_250-350_positive",
        "PC_350-500_negative", "PC_350-500_positive",
        "ZINC_250-350_negative",
    )
}
config.SGNN_csv_gen_smi = sgnn_gen_smi_candidates["ZINC_250-350_negative"]
# Alternative from the older CSV_examples folder:
# config.SGNN_csv_gen_smi = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples/ZINC_250-350_positive.csv"

config.SGNN_gen_folder_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/46_Project_3_Data/test"
config = ex.gen_sim_aug_data(config, IR_config)
# Validate on the 1H CSV produced by the generation step.
config.csv_path_val = config.csv_1H_path_SGNN
# Optional filtering step (disabled):
# config = ex.filter_invalid_criteria(config.csv_1H_path_SGNN)

In [None]:
import ast
import os
import pickle

import pandas as pd
import torch
from tqdm import tqdm

config.pickle_file_path = ""
config.data_size = 1000

config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt"
# "<name>.csv" -> "<name>vector.csv" (no separator), matching the *vector_v2.pkl
# file names used further down.
config.vector_db = os.path.splitext(config.SGNN_csv_gen_smi)[0] + "vector.csv"
# Module-qualified, like every other vectorization cell in this notebook.
config = exp_func.vectorize_db(config, stoi, stoi_MF, "db", "all")

vector_db = config.vector_db
df = pd.read_csv(vector_db)

tqdm.pandas()
# Each cell is presumably a stringified Python list of numbers; ast.literal_eval
# parses it without executing code, and the tensor is built in the same pass.
df['Fingerprints'] = df['Fingerprints'].progress_apply(
    lambda s: torch.tensor(ast.literal_eval(s), dtype=torch.float32)
)

# "<name>.csv" -> "<name>_v2.pkl". The output name is derived from the extension only,
# so a path without ".csv" can never resolve to the source CSV and overwrite it.
pickle_file = os.path.splitext(vector_db)[0] + '_v2.pkl'
with open(pickle_file, 'wb') as f:
    pickle.dump(df, f)

print(f"Data saved to {pickle_file}")


In [None]:
# Pickled vector DBs for the positive/negative example sets (CSV_examples folder).
_csv_examples_dir = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples"

PC_0_250_pkl_pos = f"{_csv_examples_dir}/PC_0-250_positivevector_v2.pkl"
PC_250_350_pkl_pos = f"{_csv_examples_dir}/PC_250-350_positivevector_v2.pkl"
PC_350_500_pkl_pos = f"{_csv_examples_dir}/PC_350-500_positivevector_v2.pkl"
ZINC_250_350_pkl_pos = f"{_csv_examples_dir}/ZINC_250-350_positivevector_v2.pkl"

PC_0_250_pkl_neg = f"{_csv_examples_dir}/PC_0-250_negativevector_v2.pkl"
PC_250_350_pkl_neg = f"{_csv_examples_dir}/PC_250-350_negativevector_v2.pkl"
PC_350_500_pkl_neg = f"{_csv_examples_dir}/PC_350-500_negativevector_v2.pkl"
ZINC_250_350_pkl_neg = f"{_csv_examples_dir}/ZINC_250-350_negativevector_v2.pkl"

# Unused here; kept for reference:
#big_vector_db = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/vector_db_train_5Mod_4M_v2_v2.pkl"
#output_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples"
#database_nr = 3975764


##### Chart for Anna's paper: PubChem and ZINC results

In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os

# File paths
PC_0_250_pos = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples_2/PC_0-250_positive.csv"
PC_250_350_pos = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples_2/PC_250-350_positive.csv"
PC_350_500_pos = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples_2/PC_350-500_positive.csv"
ZINC_250_350_pos = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples_2/ZINC_250-350_positive.csv"
PC_0_250_neg = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples_2/PC_0-250_negative.csv"
PC_250_350_neg = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples_2/PC_250-350_negative.csv"
PC_350_500_neg = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples_2/PC_350-500_negative.csv"
ZINC_250_350_neg = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples_2/ZINC_250-350_negative.csv"

# Save location for the figure
save_dir = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2"

def load_pickle(file_path):
    """Load one of the example tables for the chart and return it as a DataFrame.

    Despite the name, the inputs in this section are CSV files, so this uses
    pandas.read_csv (an earlier version unpickled the file instead).
    """
    table = pd.read_csv(file_path)
    return table

def process_data(pos_df, neg_df, total_compounds=1000):
    """Return the Correct / Incorrect / Failed split of one dataset as percentages.

    Each row of ``pos_df`` counts as one correctly predicted compound and each row of
    ``neg_df`` as one incorrect compound. Every remaining compound out of
    ``total_compounds`` is counted as failed.

    Args:
        pos_df: table of correct predictions (only its length is used).
        neg_df: table of incorrect predictions (only its length is used).
        total_compounds: number of compounds submitted for the dataset. Defaults to
            1000, the size used for every range in this notebook.

    Returns:
        dict mapping 'Correct', 'Incorrect' and 'Failed' to percentages of
        ``total_compounds``.

    Raises:
        ValueError: if ``total_compounds`` is not positive. Also raised if the two
            tables together hold more rows than ``total_compounds``, which would
            otherwise yield a negative 'Failed' percentage.
    """
    if total_compounds <= 0:
        raise ValueError(f"total_compounds must be positive, got {total_compounds}")
    pos_count = len(pos_df)
    neg_count = len(neg_df)
    failed_count = total_compounds - (pos_count + neg_count)
    if failed_count < 0:
        raise ValueError(
            f"{pos_count} positive + {neg_count} negative rows exceed "
            f"total_compounds={total_compounds}"
        )

    return {
        'Correct': (pos_count / total_compounds) * 100,
        'Incorrect': (neg_count / total_compounds) * 100,
        'Failed': (failed_count / total_compounds) * 100
    }

# Read each dataset's positive and negative example tables (pos first, then neg).
PC_0_250_pos_df, PC_0_250_neg_df = pd.read_csv(PC_0_250_pos), pd.read_csv(PC_0_250_neg)
PC_250_350_pos_df, PC_250_350_neg_df = pd.read_csv(PC_250_350_pos), pd.read_csv(PC_250_350_neg)
PC_350_500_pos_df, PC_350_500_neg_df = pd.read_csv(PC_350_500_pos), pd.read_csv(PC_350_500_neg)
ZINC_250_350_pos_df, ZINC_250_350_neg_df = pd.read_csv(ZINC_250_350_pos), pd.read_csv(ZINC_250_350_neg)

# Percentages per dataset.
data_0_250 = process_data(PC_0_250_pos_df, PC_0_250_neg_df)
data_250_350 = process_data(PC_250_350_pos_df, PC_250_350_neg_df)
data_350_500 = process_data(PC_350_500_pos_df, PC_350_500_neg_df)
data_ZINC_250_350 = process_data(ZINC_250_350_pos_df, ZINC_250_350_neg_df)

# Reshape to category -> [value per dataset], in the bar order used by the chart.
plot_data = {
    outcome: [per_set[outcome] for per_set in (data_0_250, data_250_350, data_350_500, data_ZINC_250_350)]
    for outcome in ('Correct', 'Incorrect', 'Failed')
}

def create_stacked_bar_chart(data, categories, title, save_path,
                             xticklabels=('Set 1', 'Set 2', 'Set 3', 'ZINC')):
    """Draw a stacked percentage bar chart (one bar per dataset), save it and show it.

    Args:
        data: mapping category -> sequence of percentages, one value per dataset.
            Every category in ``categories`` must have the same number of values.
        categories: keys of ``data`` to stack, bottom to top.
        title: figure title.
        save_path: output image path (written at 300 dpi with a tight bounding box).
        xticklabels: one x-axis label per dataset. The default matches the four
            datasets plotted in this notebook (PC 0-250 Da, PC 250-350 Da,
            PC 350-500 Da, ZINC 250-350 Da).

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if ``categories`` is empty, if the categories have different
            numbers of values, or if the number of labels differs from the number
            of datasets.
    """
    fontsize = 22
    colors = [
    '#8BE5A0',  # mint green
    '#A1C8F3',  # light blue/periwinkle
    '#FFB381',  # salmon/peach
    '#FF9D9A',  # coral pink
    '#D1B9FE',  # lavender
    '#DEBA9A',  # beige/tan
    '#FCAEE3',  # pink
    '#CFCECE',  # light gray
    '#FEFDA2',  # pale yellow
    '#B8F1EF',  # light cyan/aqua
]
    if not categories:
        raise ValueError("categories must not be empty")
    # The number of bars comes from the data instead of a hard-coded 4.
    n_sets = len(data[categories[0]])
    if any(len(data[category]) != n_sets for category in categories):
        raise ValueError("every category must have one value per dataset")
    xticklabels = list(xticklabels)
    if len(xticklabels) != n_sets:
        raise ValueError(f"expected {n_sets} x tick labels, got {len(xticklabels)}")

    fig, ax = plt.subplots(figsize=(8, 8))

    positions = np.arange(n_sets)
    bottom = np.zeros(n_sets)
    for i, category in enumerate(categories):
        values = np.asarray(data[category], dtype=float)
        # Cycle the palette so more than len(colors) categories still get a color.
        ax.bar(positions, values, bottom=bottom, label=category,
               color=colors[i % len(colors)], edgecolor='black')
        bottom += values

    ax.set_title(title, fontsize=fontsize, pad=20)
    ax.set_xticks(positions)
    ax.set_xticklabels(xticklabels, fontsize=fontsize)
    ax.set_ylabel('Percentage', fontsize=fontsize)
    ax.set_ylim(0, 100)

    ax.tick_params(axis='y', labelsize=fontsize)

    ax.legend(loc='lower right', fontsize=fontsize - 4)

    # Label each stacked segment with its percentage, centred inside the segment.
    for container in ax.containers:
        ax.bar_label(container, fmt='%.1f%%', label_type='center', fontsize=fontsize)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    return fig

categories = ['Correct', 'Incorrect', 'Failed']

# Create save path
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
save_path = os.path.join(save_dir, 'PubChem_ZINC_Dataset_Results_v2.png')

fig = create_stacked_bar_chart(plot_data, categories, 'PubChem and ZINC Dataset Results', save_path)

# Optional diagnostics: set SHOW_DETAILS = True to print the per-range percentages and
# a summary of every loaded table. This used to be a bare triple-quoted string. As the
# last expression of the cell, that string was echoed verbatim as the cell's output.
SHOW_DETAILS = False
if SHOW_DETAILS:
    # Print the percentages for each range
    for range_name, data in zip(['PC 0-250', 'PC 250-350', 'PC 350-500', 'ZINC 250-350'],
                                [data_0_250, data_250_350, data_350_500, data_ZINC_250_350]):
        print(f"\nPercentages for {range_name} Da range:")
        for category, percentage in data.items():
            print(f"{category}: {percentage:.1f}%")

    # Print DataFrame information
    for name, df in [
        ("PC_0_250_pos", PC_0_250_pos_df),
        ("PC_0_250_neg", PC_0_250_neg_df),
        ("PC_250_350_pos", PC_250_350_pos_df),
        ("PC_250_350_neg", PC_250_350_neg_df),
        ("PC_350_500_pos", PC_350_500_pos_df),
        ("PC_350_500_neg", PC_350_500_neg_df),
        ("ZINC_250_350_pos", ZINC_250_350_pos_df),
        ("ZINC_250_350_neg", ZINC_250_350_neg_df)
    ]:
        print(f"\n{name} DataFrame:")
        df.info()  # .info() prints itself and returns None, so it is not wrapped in print()
        print(df.head())

    print(f"\nFigure saved to: {save_path}")

##### Save 100 negative molecules per dataset

In [None]:
import pickle
import pandas as pd
import os

# File paths
negative_files = {
    'PC_0_250': "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples/PC_0-250_negativevector_v2.pkl",
    'PC_250_350': "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples/PC_250-350_negativevector_v2.pkl",
    'PC_350_500': "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples/PC_350-500_negativevector_v2.pkl",
    'ZINC_250_350': "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples/ZINC_250-350_negativevector_v2.pkl"
}

def load_pickle(file_path):
    """Unpickle and return the object stored at ``file_path``.

    Note: this rebinds ``load_pickle``, which the chart section defines as a CSV
    reader. Only use on trusted files, since unpickling can execute arbitrary code.
    """
    with open(file_path, 'rb') as handle:
        obj = pickle.load(handle)
    return obj

def sample_and_save(file_path, sample_size=100):
    """Randomly sample rows of a pickled DataFrame and save them next to it as CSV.

    The output is ``<basename without extension>_<sample_size>_neg.csv`` in the same
    directory; with the default this is ``..._100_neg.csv``. Rows are drawn with a
    fixed seed (random_state=42). If the table has fewer than ``sample_size`` rows,
    all rows are kept, shuffled, instead of letting ``DataFrame.sample`` raise.

    Args:
        file_path: path to the pickled DataFrame, read with ``load_pickle``.
        sample_size: number of rows to draw (default 100).

    Returns:
        Path of the CSV that was written.
    """
    df = load_pickle(file_path)

    # DataFrame.sample(n=...) without replacement raises when n > len(df).
    n_rows = min(sample_size, len(df))
    sampled_df = df.sample(n=n_rows, random_state=42)

    # Build the output name from the extension so it always matches sample_size.
    dir_path, file_name = os.path.split(file_path)
    stem = os.path.splitext(file_name)[0]
    new_file_path = os.path.join(dir_path, f"{stem}_{sample_size}_neg.csv")

    sampled_df.to_csv(new_file_path, index=False)

    print(f"Saved {n_rows} samples to: {new_file_path}")
    return new_file_path

# Process each negative file
for name, file_path in negative_files.items():
    print(f"Processing {name}...")
    sample_and_save(file_path)

print("All files processed and saved.")

In [None]:
import pandas as pd
import os

# File paths
# The sampled negative-example CSVs produced by sample_and_save.
_sampled_dir = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples"
csv_files = [
    f"{_sampled_dir}/{stem}_negativevector_v2_100_neg.csv"
    for stem in ("PC_0-250", "PC_250-350", "PC_350-500", "ZINC_250-350")
]

def add_sample_id(file_path):
    """Put a 'sample-id' column (NEG00001, NEG00002, ...) first in a CSV, rewriting it in place.

    An existing 'sample-id' column is replaced and moved to the front; the other
    columns keep their order. The head of the updated table is printed as a check.
    """
    table = pd.read_csv(file_path)

    ids = [f'NEG{n:05d}' for n in range(1, len(table) + 1)]
    table = table.drop(columns='sample-id', errors='ignore')
    table.insert(0, 'sample-id', ids)

    table.to_csv(file_path, index=False)

    print(f"Added sample-id column to: {file_path}")
    print("First 5 rows of the updated file:")
    print(table.head())
    print("\n")

# Process each CSV file
for file_path in csv_files:
    add_sample_id(file_path)

print("All files processed and updated with sample-id column.")

In [None]:
# Inspect one sampled negative CSV; 'sample-id' should now be the first column.
pd.read_csv(
    "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/"
    "Experiment_baseline_PC_ZINC/CSV_examples/ZINC_250-350_negativevector_v2_100_neg.csv"
)

#### 5.6 Use 100 molecules from each failed category for the improvement cycle

In [None]:
"""import os
from typing import List, Dict, Any, Union, Tuple
import pandas as pd
from datetime import datetime
import pickle
import re
import tempfile


def split_dataset(config, chunk_size: int) -> List[pd.DataFrame]:
    df = pd.read_csv(config.SGNN_csv_gen_smi)
    return [df[i:i+chunk_size] for i in range(0, len(df), chunk_size)]


def create_chunk_folder(config, idx: int) -> str:
    base_dir = config.model_save_dir
    current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")
    chunk_folder_name = f"chunk_{idx:03d}_{current_datetime}"
    chunk_folder_path = os.path.join(base_dir, chunk_folder_name)
    
    os.makedirs(chunk_folder_path, exist_ok=True)
    print(f"Created folder for chunk {idx}: {chunk_folder_path}")
    
    return chunk_folder_path

def test_pretrained_model_on_sim_data_before(config, IR_config, stoi, itos, stoi_MF, itos_MF, chunk, idx):
    MW_filter, greedy_full = True, False
    
    print("prepare_data")
    config = prepare_data(config, chunk)
    print("generate_simulated_data")
    config = generate_simulated_data(config, IR_config)

    print("load_model_and_data")
    model_MMT, val_dataloader, val_dataloader_multi = load_model_and_data(config, stoi, stoi_MF)

    print("run_model_analysis")
    prob_dict_results_1c_, results_dict_1c_ = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)

    results = test_model_performance(config, model_MMT, val_dataloader, val_dataloader_multi, stoi, itos, stoi_MF, itos_MF)

    save_results_before(results, config, idx)

    return config

def prepare_data(config: Any, chunk: pd.DataFrame) -> Any:
    chunk_csv_path = os.path.join(config.pkl_save_folder, "SGNN_csv_gen_smi.csv")
    chunk.to_csv(chunk_csv_path)
    config.SGNN_csv_gen_smi = chunk_csv_path 
    config.data_size = len(chunk)
    return config

def generate_simulated_data(config: Any, IR_config: Any) -> Any:
    config.execution_type = "data_generation"
    if config.execution_type == "data_generation":
        print("\033[1m\033[31mThis is: data_generation\033[0m")
        #import IPython; IPython.embed();

        config = ex.gen_sim_aug_data(config, IR_config)
        backup_config_paths(config)
    return config

def backup_config_paths(config: Any) -> None:
    config.csv_1H_path_SGNN_backup = copy.deepcopy(config.csv_1H_path_SGNN)
    config.csv_13C_path_SGNN_backup = copy.deepcopy(config.csv_13C_path_SGNN)
    config.csv_HSQC_path_SGNN_backup = copy.deepcopy(config.csv_HSQC_path_SGNN)
    config.csv_COSY_path_SGNN_backup = copy.deepcopy(config.csv_COSY_path_SGNN)
    config.IR_data_folder_backup = copy.deepcopy(config.IR_data_folder)

def save_results_before(results: Dict[str, Any], config: Any, idx: int) -> None:
    variables_to_save = {
        'avg_tani_bl_ZINC': results['avg_tani_bl_ZINC_'],
        'results_dict_greedy_bl_ZINC': results.get('results_dict_greedy_bl_ZINC_'),
        'failed_bl_ZINC': results.get('failed_bl_ZINC_'),
        'avg_tani_greedy_bl_ZINC': results['avg_tani_greedy_bl_ZINC_'],
        'results_dict_ZINC_greedy_bl': results.get('results_dict_ZINC_greedy_bl_'),
        'total_results_bl_ZINC': results['total_results_bl_ZINC_'],
        'corr_sampleing_prob_bl_ZINC': results['corr_sampleing_prob_bl_ZINC_'],
        'results_dict_bl_ZINC': results['results_dict_bl_ZINC_'],
    }
    save_data_with_datetime_index(variables_to_save, config.pkl_save_folder, "before_sim_data", idx)

def create_run_folder(chunk_folder, idx):
    current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_folder_name = f"run_{idx}_{current_datetime}"
    run_folder_path = os.path.join(chunk_folder, run_folder_name)
    
    os.makedirs(run_folder_path, exist_ok=True)
    print(f"Created folder for run {idx}: {run_folder_path}")
    
    return run_folder_path

def fine_tune_model_aug_mol(config, stoi, stoi_MF, chunk, idx):
    #import IPython; IPython.embed();
    config, all_gen_smis, aug_mol_df = generate_augmented_molecules_from_aug_mol(config, chunk, idx)
    
    config.parent_model_save_dir = config.model_save_dir
    config.model_save_dir = config.current_run_folder 
    
    if config.execution_type == "transformer_improvement":
        print("\033[1m\033[31mThis is: transformer_improvement, sim_data_gen == TRUE\033[0m")
        config.training_setup = "pretraining"
        mtf.run_MMT(config, stoi, stoi_MF)
    
    config.model_save_dir = config.parent_model_save_dir
    #config = ex.update_model_path(config)

    return config, aug_mol_df, all_gen_smis


def generate_augmented_molecules_from_aug_mol(config, chunk, idx):
    #import IPython; IPython.embed();

    ############# THis is just relevant for the augmented molecules #############
    #chunk.rename(columns={'SMILES': 'SMILES_orig', 'SMILES_regio_isomers': 'SMILES'}, inplace=True)
    #############################################################################
    
    script_dir = os.getcwd()
    
    base_path = os.path.abspath(os.path.join(script_dir, 'deep-molecular-optimization'))

    csv_file_path = f'{base_path}/data/MMP/test_selection_2.csv'
    chunk.to_csv(csv_file_path, index=False)
    print(f"CSV file '{csv_file_path}' created successfully.")

    config.data_size = len(chunk)
    config.n_samples = config.data_size

    config, results_dict_MF = generate_smiles_mf(config)

    combined_list_MF = process_generated_smiles(results_dict_MF, config)

    all_gen_smis = filter_and_combine_smiles(combined_list_MF)

    aug_mol_df = create_augmented_dataframe(all_gen_smis)

    config, final_df = ex.blend_aug_with_train_data(config, aug_mol_df)

    config = ex.gen_sim_aug_data(config, IR_config)
    config.execution_type = "transformer_improvement"

    return config, all_gen_smis, aug_mol_df


def fine_tune_model(config, stoi, stoi_MF, chunk, idx):
    '''
    Fine-tune the model on a chunk of data.

    This docstring uses single quotes on purpose: the whole cell is wrapped in a
    triple-double-quoted string, and a nested triple-double-quote would end that
    string early, making the cell fail with an IndentationError.

    Returns (config, aug_mol_df, all_gen_smis).
    '''
    # generate_augmented_molecules returns (config, all_gen_smis, aug_mol_df);
    # unpack in that order so the two values are not swapped.
    config, all_gen_smis, aug_mol_df = generate_augmented_molecules(config, chunk, idx)

    config.parent_model_save_dir = config.model_save_dir
    # NOTE(review): create_model_save_dir is not defined in this cell; confirm where it lives.
    new_model_save_dir = create_model_save_dir(config.parent_model_save_dir, idx)
    config.model_save_dir = new_model_save_dir

    try:
        # Fine-tune the model
        if config.execution_type == "transformer_improvement":
            print("\033[1m\033[31mThis is: transformer_improvement, sim_data_gen == TRUE\033[0m")
            config.training_setup = "pretraining"
            mtf.run_MMT(config, stoi, stoi_MF)
    finally:
        # Restore the parent save dir even if training raises.
        config.model_save_dir = config.parent_model_save_dir

    return config, aug_mol_df, all_gen_smis

def generate_augmented_molecules(config, chunk, idx):
    #import IPython; IPython.embed();
    script_dir = os.getcwd()
    
    base_path = os.path.abspath(os.path.join(script_dir, 'deep-molecular-optimization'))

    csv_file_path = f'{base_path}/data/MMP/test_selection_2.csv'
    chunk.to_csv(csv_file_path, index=False)
    print(f"CSV file '{csv_file_path}' created successfully.")

    config.data_size = len(chunk)
    config.n_samples = config.data_size

    config, results_dict_MF = generate_smiles_mf(config)

    combined_list_MF = process_generated_smiles(results_dict_MF, config)

    all_gen_smis = filter_and_combine_smiles(combined_list_MF)

    aug_mol_df = create_augmented_dataframe(all_gen_smis)

    config, final_df = ex.blend_aug_with_train_data(config, aug_mol_df)

    config = ex.gen_sim_aug_data(config, IR_config)
    config.execution_type = "transformer_improvement"

    return config, all_gen_smis, aug_mol_df


def generate_smiles_mf(config):
    print("\033[1m\033[31mThis is: SMI_generation_MF\033[0m")
    return ex.SMI_generation_MF(config, stoi, stoi_MF, itos, itos_MF)

def process_generated_smiles(results_dict_MF, config):
    results_dict_MF = {key: value for key, value in results_dict_MF.items() if not hf.contains_only_nan(value)}
    for key, value in results_dict_MF.items():
        results_dict_MF[key] = hf.remove_nan_from_list(value)

    combined_list_MF, _, _, _ = cv.plot_cluster_MF(results_dict_MF, config)
    return combined_list_MF

def filter_and_combine_smiles(combined_list_MF):
    print("\033[1m\033[31mThis is: combine_MMT_MF\033[0m")
    all_gen_smis = combined_list_MF
    all_gen_smis = [smiles for smiles in all_gen_smis if smiles != 'NAN']

    val_data = pd.read_csv(config.csv_path_val)
    all_gen_smis = mrtf.filter_smiles(val_data, all_gen_smis)
    return all_gen_smis

def create_augmented_dataframe(all_gen_smis):
    length_of_list = len(all_gen_smis)
    random_number_strings = [f"GT_{str(i).zfill(7)}" for i in range(1, length_of_list + 1)]
    return pd.DataFrame({'SMILES': all_gen_smis, 'sample-id': random_number_strings})

def setup_data_paths(config):
    base_path_acd = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/37_Richard_ACD_sim_data/"
    config.csv_1H_path_ACD = f"{base_path_acd}ACD_1H_with_SN_filtered_v3.csv"
    config.csv_13C_path_ACD = f"{base_path_acd}ACD_13C_with_SN_filtered_v3.csv"
    config.csv_HSQC_path_ACD = f"{base_path_acd}ACD_HSQC_with_SN_filtered_v3.csv"
    config.csv_COSY_path_ACD = f"{base_path_acd}ACD_COSY_with_SN_filtered_v3.csv"
    config.IR_data_folder_ACD = f"{base_path_acd}IR_spectra"
    
    base_path_exp = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/36_Richard_43_dataset/experimenal_data/"
    config.csv_1H_path_exp = f"{base_path_exp}real_1H_with_AZ_SMILES_v3.csv"
    config.csv_13C_path_exp = f"{base_path_exp}real_13C_with_AZ_SMILES_v3.csv"
    config.csv_HSQC_path_exp = f"{base_path_exp}real_HSQC_with_AZ_SMILES_v3.csv"
    config.csv_COSY_path_exp = f"{base_path_exp}real_COSY_with_AZ_SMILES_v3.csv"
    config.IR_data_folder_exp = f"{base_path_exp}IR_data"
    return config

def test_model_on_neg_dataset(config, IR_config, stoi, itos, stoi_MF, itos_MF, chunk, composite_idx, aug_mol_df, all_gen_smis):
    checkpoint_path_backup = config.checkpoint_path    
    config.pickle_file_path = ""
    config.training_mode = "1H_13C_HSQC_COSY_IR_MF_MW"
    config = test_on_data(config, IR_config, stoi, itos, stoi_MF, itos_MF, chunk, composite_idx, "sim", aug_mol_df, all_gen_smis)
    config.checkpoint_path = checkpoint_path_backup
    return config

def test_on_data(config, IR_config, stoi, itos, stoi_MF, itos_MF, chunk, composite_idx, data_type, aug_mol_df, all_gen_smis):
    if data_type == 'sim':
        restore_backup_configs(config)
    #else:
    #    sample_ids = chunk['sample-id'].tolist()
    #    process_spectrum_data(config, sample_ids, data_type)
    #import IPython; IPython.embed();

    update_config_settings(config)
    last_checkpoint = get_last_checkpoint(config.current_run_folder)
    config.checkpoint_path = last_checkpoint
    
    model_MMT, val_dataloader, val_dataloader_multi = load_model_and_data(config, stoi, stoi_MF)
    
    prob_dict_results_1c_, results_dict_1c_ = mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)

    results = test_model_performance(config, model_MMT, val_dataloader, val_dataloader_multi,
                                     stoi, itos, stoi_MF, itos_MF)
    
    if data_type == 'sim':
        results['aug_mol_df'] = aug_mol_df
        results['all_gen_smis'] = all_gen_smis
    
    save_results_acd_exp(results, config, data_type, composite_idx)
    return config

def restore_backup_configs(config):
    config.csv_1H_path_SGNN = config.csv_1H_path_SGNN_backup
    config.csv_13C_path_SGNN = config.csv_13C_path_SGNN_backup
    config.csv_HSQC_path_SGNN = config.csv_HSQC_path_SGNN_backup
    config.csv_COSY_path_SGNN = config.csv_COSY_path_SGNN_backup
    config.IR_data_folder = config.IR_data_folder_backup 
    config.csv_path_val = config.csv_1H_path_SGNN_backup
    config.pickle_file_path = ""
    

def process_spectrum_data(config: Any, sample_ids: List[str], data_type: str) -> None:
    spectrum_types = ['1H', '13C', 'HSQC', 'COSY']
    for spectrum in spectrum_types:
        csv_path = getattr(config, f'csv_{spectrum}_path_{data_type}')
        df_data = pd.read_csv(csv_path)
        df_data['sample-id'] = df_data['AZ_Number']
        data = select_relevant_samples(df_data, sample_ids)
        dummy_path, config = save_and_update_config(config, data_type, spectrum, data)
        print(f"Saved {spectrum} data to: {dummy_path}")
    if data_type == "ACD" or data_type == "sim":
        config.IR_data_folder = config.IR_data_folder_backup 
    elif  data_type == "exp":
        config.IR_data_folder = config.IR_data_folder_exp 

    
    
def select_relevant_samples(df: pd.DataFrame, sample_ids: List[str]) -> pd.DataFrame:
    return df[df['sample-id'].isin(sample_ids)]

def save_and_update_config(config, data_type: str, spectrum_type: str, data: pd.DataFrame) -> Tuple[str, Any]:
    temp_dir = tempfile.mkdtemp()
    dummy_path = os.path.join(temp_dir, f"{data_type}_{spectrum_type}_selected_samples.csv")
    
    data.to_csv(dummy_path, index=False)
    
    config_key = f'csv_{spectrum_type}_path_SGNN'
    setattr(config, config_key, dummy_path)
    
    return dummy_path, config

def update_config_settings(config: Any) -> None:
    config.csv_path_val = config.csv_1H_path_SGNN
    config.pickle_file_path = ""

def get_last_checkpoint(model_folder: str) -> str:
    checkpoints = [f for f in os.listdir(model_folder) if f.endswith('.ckpt')]
    if not checkpoints:
        raise ValueError(f"No checkpoints found in {model_folder}")
    
    last_checkpoint = max(checkpoints, key=lambda x: os.path.getmtime(os.path.join(model_folder, x)))
    return os.path.join(model_folder, last_checkpoint)

def load_model_and_data(config: Any, stoi: Dict, stoi_MF: Dict) -> Tuple[Any, Any, Any]:
    #import IPython; IPython.embed();

    val_dataloader = mrtf.load_data(config, stoi, stoi_MF, single=True, mode="val")
    val_dataloader_multi = mrtf.load_data(config, stoi, stoi_MF, single=False, mode="val")
    model_MMT = mrtf.load_MMT_model(config)
    return model_MMT, val_dataloader, val_dataloader_multi

def test_model_performance(config: Any, model_MMT: Any, val_dataloader: Any, val_dataloader_multi: Any, 
                           stoi: Dict, itos: Dict, stoi_MF: Dict, itos_MF: Dict) -> Dict[str, Any]:
    print("\033[1m\033[31mThis is: test_performance\033[0m")
    
    MW_filter = True
    greedy_full = False
    
    model_CLIP = mrtf.load_CLIP_model(config)
    
    results = {}
    
    results['results_dict_bl_ZINC_'] = mrtf.run_test_mns_performance_CLIP_3(
        config, model_MMT, model_CLIP, val_dataloader, stoi, itos, MW_filter)
    results['results_dict_bl_ZINC_'], counter = mrtf.filter_invalid_inputs(results['results_dict_bl_ZINC_'])

    results['avg_tani_bl_ZINC_'], html_plot = rbgvm.plot_hist_of_results(results['results_dict_bl_ZINC_'])

    if greedy_full:
        results['results_dict_greedy_bl_ZINC_'], results['failed_bl_ZINC_'] = mrtf.run_test_performance_CLIP_greedy_3(
            config, stoi, stoi_MF, itos, itos_MF)
        results['avg_tani_greedy_bl_ZINC_'], html_plot_greedy = rbgvm.plot_hist_of_results_greedy(
            results['results_dict_greedy_bl_ZINC_'])
    else:
        config, results['results_dict_ZINC_greedy_bl_'] = mrtf.run_greedy_sampling(
            config, model_MMT, val_dataloader_multi, itos, stoi)
        results['avg_tani_greedy_bl_ZINC_'] = results['results_dict_ZINC_greedy_bl_']["tanimoto_mean"]

    results['total_results_bl_ZINC_'] = mrtf.run_test_performance_CLIP_3(
        config, model_MMT, val_dataloader, stoi)
    results['corr_sampleing_prob_bl_ZINC_'] = results['total_results_bl_ZINC_']["statistics_multiplication_avg"][0]

    print("avg_tani, avg_tani_greedy, corr_sampleing_prob'")
    print(results['avg_tani_bl_ZINC_'], results['avg_tani_greedy_bl_ZINC_'], results['corr_sampleing_prob_bl_ZINC_'])
    print("Greedy tanimoto results")
    rbgvm.plot_hist_of_results_greedy_new(results['results_dict_ZINC_greedy_bl_'])

    return results

def save_results_acd_exp(results: Dict[str, Any], config: Any, data_type: str, composite_idx: str) -> None:
    variables_to_save = {
        'avg_tani_bl_ZINC': results['avg_tani_bl_ZINC_'],
        'results_dict_greedy_bl_ZINC': results.get('results_dict_greedy_bl_ZINC_'),
        'failed_bl_ZINC': results.get('failed_bl_ZINC_'),
        'avg_tani_greedy_bl_ZINC': results['avg_tani_greedy_bl_ZINC_'],
        'results_dict_ZINC_greedy_bl': results.get('results_dict_ZINC_greedy_bl_'),
        'total_results_bl_ZINC': results['total_results_bl_ZINC_'],
        'corr_sampleing_prob_bl_ZINC': results['corr_sampleing_prob_bl_ZINC_'],
        'results_dict_bl_ZINC': results['results_dict_bl_ZINC_'],
        'checkpoint_path': config.checkpoint_path,
    }
    
    if data_type == 'sim':
        variables_to_save['aug_mol_df'] = results.get('aug_mol_df')
        variables_to_save['all_gen_smis'] = results.get('all_gen_smis')
    
    save_data_with_datetime_index(
        variables_to_save, 
        config.pkl_save_folder, 
        f"{data_type}_sim_data", 
        composite_idx
    )

def save_data_with_datetime_index(data: Any, base_folder: str, name: str, idx: Union[int, str]) -> None:
    current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{current_datetime}_{name}_{idx}.pkl"
    os.makedirs(base_folder, exist_ok=True)
    file_path = os.path.join(base_folder, filename)

    
    with open(file_path, 'wb') as f:
        pickle.dump(data, f)
    
    print(f"Data saved to: {file_path}")"""

In [None]:
import utils_MMT.improvement_cycle_neg_examples_v15_4 as icne
import shutil

In [None]:
# Disabled (wrapped in a string, never executed): earlier version of main_IC_neg.
# The live main_IC_neg in the next cell adds skip/resume logic via should_skip_chunk.
"""
def main_IC_neg(chunk_size, config, IR_config, stoi, itos, stoi_MF, itos_MF, num_training_runs=3):
    chunks = icne.split_dataset(config, chunk_size)
    config.model_save_dir = config.pkl_save_folder
    model_save_dir_backup = config.model_save_dir
    original_checkpoint_path = config.checkpoint_path  # Store the original checkpoint path

    for chunk_idx, chunk in enumerate(chunks):
        print(f"Processing chunk {chunk_idx+1} of {len(chunks)}")
        
        chunk_folder = icne.create_chunk_folder(config, chunk_idx)
        config.current_chunk_folder = chunk_folder
            
        config.blank_percentage = 0
        config = icne.test_pretrained_model_on_sim_data_before(config, IR_config, stoi, itos, stoi_MF, itos_MF, chunk, f"{chunk_idx}_{0}")
        print(config.csv_1H_path_SGNN)
        for run_idx in range(num_training_runs):
            print(f"Starting training run {run_idx+1} of {num_training_runs}")
            
            run_folder = icne.create_run_folder(config.current_chunk_folder, f"{chunk_idx}_{run_idx}")
            config.current_run_folder = run_folder
            config.model_save_dir = run_folder

            config.blank_percentage = 50
            config, aug_mol_df, all_gen_smis = icne.fine_tune_model_aug_mol(config, IR_config, stoi, stoi_MF, chunk, f"{chunk_idx}_{run_idx}")
            #import IPython; IPython.embed();

            ### Retrun the labelling of the smiles to test it on the correct one
            #chunk.rename(columns={'SMILES': 'SMILES_regio_isomers', 'SMILES_orig': 'SMILES'}, inplace=True)

            config.blank_percentage = 0
            config = icne.setup_data_paths(config)
            config = icne.test_model_on_neg_dataset(config, IR_config, stoi, itos, stoi_MF, itos_MF, chunk, f"{chunk_idx}_{run_idx}", aug_mol_df, all_gen_smis)

            config.checkpoint_path = original_checkpoint_path
        
        print(f"Chunk {chunk_idx+1} completed. All training runs finished.")
        config.model_save_dir = model_save_dir_backup"""


In [None]:
def should_skip_chunk(base_path, chunk_idx, min_files=5):
    """
    Check whether a chunk has already been processed, ignoring the timestamp
    suffix in its folder name.

    A chunk counts as processed when ANY folder under ``base_path`` whose name
    starts with ``chunk_{chunk_idx:03d}_`` contains at least one run sub-folder
    holding more than ``min_files`` regular files. All matching folders are
    checked because reruns can leave several timestamped folders for the same
    chunk, and ``os.listdir`` returns them in arbitrary order.

    Args:
        base_path (str): Base directory containing the chunk folders.
        chunk_idx (int): Index of the chunk to check.
        min_files (int): A run folder must contain strictly more files than this
            to count as complete (default 5).

    Returns:
        bool: True if the chunk should be skipped (already processed), False
        otherwise, including when the directories cannot be read (the error
        is printed).
    """
    import os

    # Format the chunk prefix to match (ensuring 3 digits with leading zeros)
    chunk_prefix = f"chunk_{chunk_idx:03d}_"

    def _subdirs(path):
        # Sorted for deterministic traversal order.
        return sorted(d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d)))

    try:
        matching_chunks = [d for d in _subdirs(base_path) if d.startswith(chunk_prefix)]
        for chunk_folder in matching_chunks:
            chunk_path = os.path.join(base_path, chunk_folder)
            for run_folder in _subdirs(chunk_path):
                run_path = os.path.join(chunk_path, run_folder)
                n_files = sum(
                    1 for f in os.listdir(run_path) if os.path.isfile(os.path.join(run_path, f))
                )
                if n_files > min_files:
                    return True
        return False

    except OSError as e:
        print(f"Error checking chunk {chunk_idx}: {str(e)}")
        return False


def _remove_incomplete_chunk_folder(base_path, chunk_idx):
    """Delete the first leftover ``chunk_{chunk_idx:03d}_*`` folder under ``base_path``, if any.

    Best effort: errors are printed, not raised, so a failed cleanup does not stop the run.
    """
    chunk_prefix = f"chunk_{chunk_idx:03d}_"
    try:
        all_dirs = [d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d))]
        matching_chunks = [d for d in all_dirs if d.startswith(chunk_prefix)]
        if matching_chunks:
            # NOTE(review): only the first match (os.listdir order) is removed; if several
            # timestamped folders exist for this chunk the others are left in place.
            chunk_to_delete = os.path.join(base_path, matching_chunks[0])
            print(f"Removing incomplete chunk folder: {matching_chunks[0]}")
            shutil.rmtree(chunk_to_delete)
    except Exception as e:
        print(f"Error while trying to delete incomplete chunk folder: {str(e)}")


def main_IC_neg(chunk_size, config, IR_config, stoi, itos, stoi_MF, itos_MF, num_training_runs=3):
    """
    Run the improvement cycle on the negative examples, chunk by chunk.

    The steps are:
    1. Split ``config.SGNN_csv_gen_smi`` with ``icne.split_dataset``.
    2. For each chunk, skip it if ``should_skip_chunk`` reports it as already
       processed. Otherwise remove a leftover folder for it, create a chunk
       folder and evaluate the pretrained model.
    3. Then run ``num_training_runs`` cycles of fine-tuning
       (``blank_percentage`` 50) followed by evaluation (``blank_percentage`` 0),
       each cycle in its own run folder.

    ``config`` is mutated. ``config.model_save_dir`` is set to
    ``config.pkl_save_folder`` and restored to it after every chunk.
    ``config.checkpoint_path`` is reset to its value at call time after every
    run. Both restores also happen when a step raises, so a rerun does not pick
    up a fine-tuned checkpoint as its starting point.
    """
    chunks = icne.split_dataset(config, chunk_size)
    config.model_save_dir = config.pkl_save_folder
    model_save_dir_backup = config.model_save_dir
    original_checkpoint_path = config.checkpoint_path  # Store the original checkpoint path

    for chunk_idx, chunk in enumerate(chunks):
        print(f"Processing chunk {chunk_idx+1} of {len(chunks)}")

        if should_skip_chunk(config.pkl_save_folder, chunk_idx):
            print(f"Skipping chunk {chunk_idx+1} as it appears to be already processed")
            continue

        # Either the chunk doesn't exist yet or it is incomplete: clear leftovers first.
        _remove_incomplete_chunk_folder(config.pkl_save_folder, chunk_idx)

        try:
            chunk_folder = icne.create_chunk_folder(config, chunk_idx)
            config.current_chunk_folder = chunk_folder

            # Baseline evaluation of the pretrained model on this chunk.
            config.blank_percentage = 0
            config = icne.test_pretrained_model_on_sim_data_before(
                config, IR_config, stoi, itos, stoi_MF, itos_MF, chunk, f"{chunk_idx}_0")
            print(config.csv_1H_path_SGNN)

            for run_idx in range(num_training_runs):
                print(f"Starting training run {run_idx+1} of {num_training_runs}")
                run_id = f"{chunk_idx}_{run_idx}"

                run_folder = icne.create_run_folder(config.current_chunk_folder, run_id)
                config.current_run_folder = run_folder
                config.model_save_dir = run_folder

                try:
                    config.blank_percentage = 50
                    config, aug_mol_df, all_gen_smis = icne.fine_tune_model_aug_mol(
                        config, IR_config, stoi, stoi_MF, chunk, run_id)

                    config.blank_percentage = 0
                    config = icne.setup_data_paths(config)
                    config = icne.test_model_on_neg_dataset(
                        config, IR_config, stoi, itos, stoi_MF, itos_MF, chunk, run_id,
                        aug_mol_df, all_gen_smis)
                finally:
                    config.checkpoint_path = original_checkpoint_path

            print(f"Chunk {chunk_idx+1} completed. All training runs finished.")
        finally:
            config.model_save_dir = model_save_dir_backup


In [None]:
### RUN
# Improvement-cycle run on the PC 350-500 negative examples.
# Alternative starting checkpoint:
#   /projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8_MMT_MW2_Drop/MultimodalTransformer_time_1706856620.371885_Loss_0.202.ckpt
# Alternative SMILES source:
#   /projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/37_Richard_ACD_sim_data/ACD_1H_with_SN_filtered_v3.csv
run_settings = {
    "checkpoint_path": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt",
    "pkl_save_folder": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/IC_of_neg_examples/PC_350_500_neg",
    "SGNN_csv_gen_smi": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/CSV_examples/PC_350-500_negativevector_v2_100_neg.csv",
    "MF_generations": 30,  # previously 50
    "MF_delta_weight": 100,
    "max_scaffold_generations": 300,
    "blank_percentage": 50,
    "weight_MW": 100,
    "lr_pretraining": 3e-4,
    "tr_te_split": 0.9,
    "batch_size": 64,
    "num_epochs": 30,
    "temperature": 1,
    "multinom_runs": 20,
    "train_data_blend": 0,
}
for setting_name, setting_value in run_settings.items():
    setattr(config, setting_name, setting_value)

chunk_size = 1

main_IC_neg(chunk_size, config, IR_config, stoi, itos, stoi_MF, itos_MF, num_training_runs=1)

In [None]:
import os

def analyze_chunks(base_path, min_files=5):
    """
    Analyze all chunk folders to find those that don't meet requirements.

    A chunk folder is a sub-directory of ``base_path`` whose name starts with
    ``chunk_``. It counts as complete when at least one of its run sub-folders
    holds more than ``min_files`` regular files (the same >5-file check
    ``should_skip_chunk`` applies by default).

    Args:
        base_path (str): Path to the directory containing chunk folders.
        min_files (int): Per-run-folder file-count threshold (default 5).

    Returns:
        list[str] | None: Sorted names of incomplete chunk folders, or None if
        the directories could not be read (the error is printed).
    """
    incomplete_chunks = []
    try:
        # Sorted so the report order is deterministic.
        chunk_dirs = sorted(
            d for d in os.listdir(base_path)
            if d.startswith('chunk_') and os.path.isdir(os.path.join(base_path, d))
        )

        for chunk_dir in chunk_dirs:
            chunk_path = os.path.join(base_path, chunk_dir)
            run_paths = [
                os.path.join(chunk_path, d) for d in os.listdir(chunk_path)
                if os.path.isdir(os.path.join(chunk_path, d))
            ]
            # any() stops at the first run folder that passes the threshold.
            has_sufficient_files = any(
                sum(1 for f in os.listdir(run_path) if os.path.isfile(os.path.join(run_path, f))) > min_files
                for run_path in run_paths
            )
            if not has_sufficient_files:
                incomplete_chunks.append(chunk_dir)

        print(f"\nTotal chunks analyzed: {len(chunk_dirs)}")
        print(f"Number of incomplete chunks: {len(incomplete_chunks)}")
        print("\nIncomplete chunks:")
        for chunk in incomplete_chunks:
            print(f"- {chunk}")

    except OSError as e:
        print(f"Error during analysis: {str(e)}")
        return None

    return incomplete_chunks

# Run the analysis on the ZINC 250-350 improvement-cycle output folder
base_path = (
    "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/"
    "Experiment_baseline_PC_ZINC/IC_of_neg_examples/ZINC_250_350_neg"
)
analyze_chunks(base_path);

Delete all files with "before" in their name (case-insensitive) from the ZINC 250-350 run folder

In [None]:
import os

def delete_before_files(base_path, pattern='before', dry_run=False):
    """
    Delete all files whose name contains ``pattern`` from ``base_path`` and its subdirectories.

    Matching is case-insensitive (both the pattern and the file name are lower-cased).

    Args:
        base_path (str): Path to the directory to clean.
        pattern (str): Substring to look for in file names (default 'before').
        dry_run (bool): If True, only print what would be deleted and delete nothing.

    Returns:
        int: Number of files deleted (or, with ``dry_run``, that would be deleted).
    """
    needle = pattern.lower()
    count = 0

    def _report_walk_error(err):
        # os.walk silently ignores directory-listing errors (e.g. a wrong
        # base_path) unless onerror is given; report them instead.
        print(f"Error walking through directory: {str(err)}")

    for root, _dirs, files in os.walk(base_path, onerror=_report_walk_error):
        for file in files:
            if needle not in file.lower():
                continue
            file_path = os.path.join(root, file)
            if dry_run:
                print(f"Would delete: {file_path}")
                count += 1
                continue
            try:
                os.remove(file_path)
                print(f"Deleted: {file_path}")
                count += 1
            except OSError as e:
                print(f"Error deleting {file_path}: {str(e)}")

    label = "Total files that would be deleted" if dry_run else "Total files deleted"
    print(f"\n{label}: {count}")
    return count

# Run the deletion (destructive: removes every matching file below this folder)
base_path = (
    "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/"
    "Experiment_baseline_PC_ZINC/IC_of_neg_examples/ZINC_250_350_neg"
)
delete_before_files(base_path);

##### Calculate positive/negative results after fine-tuning (FT)

In [None]:
# One entry per negative-example set: label plus the improvement-cycle output folder.
data_configs = [
    {
        'weight_range': weight_range,
        'pkl_folder': '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/'
                      f'Experiment_baseline_PC_ZINC/IC_of_neg_examples/{folder}',
    }
    for weight_range, folder in (
        ('PC_0-250_neg_100_IC', 'PC_0_250_neg_2'),
        ('PC_250-350_neg_100_IC', 'PC_250_350_neg'),
        ('PC_350-500_neg_100_IC', 'PC_350_500_neg'),
        ('ZINC_250-350_neg_100_IC', 'ZINC_250_350_neg'),
    )
]

    
# Write the per-set result CSVs into one folder, using the 'HSQC & COSY' ranking method.
output_folder = (
    '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/'
    'Experiment_baseline_PC_ZINC/IC_of_neg_examples/CSV_examples_2'
)
ranking_method = 'HSQC & COSY'

main_csv_generation(data_configs, output_folder, ranking_method)


##### Plot the correct, incorrect and failed fractions after the improvement cycle (IC)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Folder holding the *_IC_positive.csv / *_IC_negative.csv files
# (same folder as output_folder in the CSV-generation cell above).
base_path = (
    "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/"
    "Experiment_baseline_PC_ZINC/IC_of_neg_examples/CSV_examples_2"
)

def load_and_process_data(weight_range, total=100, base_dir=None):
    """
    Compute the correct / incorrect / failed percentages for one molecule set.

    Reads ``{stem}_neg_100_IC_positive.csv`` and ``{stem}_neg_100_IC_negative.csv``
    from ``base_dir``. The stem is ``weight_range`` itself when it starts with
    "ZINC" (e.g. "ZINC_250-350"); otherwise it is ``PC_{weight_range}``
    (e.g. "PC_0-250"). Each CSV row counts as one molecule. Molecules that
    appear in neither file are counted as failed.

    Args:
        weight_range (str): "0-250", "250-350", "350-500" or "ZINC_250-350".
        total (int): Number of molecules in the set (default 100).
        base_dir (str | None): Folder with the CSVs; None uses the global ``base_path``.

    Returns:
        dict: {'Correct': %, 'Incorrect': %, 'Failed': %} as percentages of ``total``.

    Raises:
        ValueError: If the two CSVs together contain more rows than ``total``
            (which would otherwise yield a negative "Failed" percentage).
    """
    folder = base_path if base_dir is None else base_dir
    stem = weight_range if weight_range.startswith("ZINC") else f"PC_{weight_range}"

    pos_df = pd.read_csv(f"{folder}/{stem}_neg_100_IC_positive.csv")
    neg_df = pd.read_csv(f"{folder}/{stem}_neg_100_IC_negative.csv")

    correct = len(pos_df)
    incorrect = len(neg_df)
    failed = total - (correct + incorrect)
    if failed < 0:
        raise ValueError(
            f"{weight_range}: {correct} correct + {incorrect} incorrect exceeds total={total}")

    return {
        'Correct': (correct/total) * 100,
        'Incorrect': (incorrect/total) * 100,
        'Failed': (failed/total) * 100
    }

# Process data for each set (three PubChem weight ranges, then ZINC)
data_0_250, data_250_350, data_350_500, data_zinc = (
    load_and_process_data(weight_range)
    for weight_range in ("0-250", "250-350", "350-500", "ZINC_250-350")
)

# Prepare plotting data: one list per outcome, in the set order above
plot_data = {
    outcome: [
        set_result[outcome]
        for set_result in (data_0_250, data_250_350, data_350_500, data_zinc)
    ]
    for outcome in ('Correct', 'Incorrect', 'Failed')
}

def create_stacked_bar_chart(data, categories, title, save_path, xticklabels=None, colors=None):
    """Draw a stacked percentage bar chart (one bar per data set), save and show it.

    Args:
        data: Mapping ``category -> sequence of percentages``. All sequences
            must have the same length (one entry per bar).
        categories: Non-empty list of keys of ``data``, stacked bottom to top.
        title: Axes title.
        save_path: Output image path. Written at 300 dpi with a tight bounding box.
        xticklabels: One label per bar. Defaults to
            ``['Set 1', 'Set 2', 'Set 3', 'ZINC']``.
        colors: Optional mapping ``category -> colour`` that overrides the
            positional default palette.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if the number of tick labels differs from the number of bars.
    """
    fontsize = 22
    # Default colours are assigned by position in `categories`. For the
    # ['Correct', 'Incorrect', 'Failed'] order used in this notebook this gives
    # Correct = mint green, Incorrect = light blue, Failed = salmon.
    palette = [
        '#8BE5A0',  # mint green (1st category)
        '#A1C8F3',  # light blue (2nd category)
        '#FFB381',  # salmon     (3rd category)
    ]
    color_overrides = colors or {}
    if xticklabels is None:
        xticklabels = ['Set 1', 'Set 2', 'Set 3', 'ZINC']

    n_bars = len(data[categories[0]])
    if len(xticklabels) != n_bars:
        raise ValueError(f"expected {n_bars} tick labels, got {len(xticklabels)}")
    positions = np.arange(n_bars)

    fig, ax = plt.subplots(figsize=(8, 7))
    bottom = np.zeros(n_bars)
    for i, category in enumerate(categories):
        values = np.asarray(data[category], dtype=float)
        ax.bar(positions, values, bottom=bottom, label=category,
               color=color_overrides.get(category, palette[i % len(palette)]),
               edgecolor='black')
        bottom = bottom + values

    ax.set_title(title, fontsize=fontsize, pad=20)
    ax.set_xticks(positions)
    ax.set_xticklabels(xticklabels, fontsize=fontsize)
    ax.set_ylabel('Percentage', fontsize=fontsize)
    ax.set_ylim(0, 100)
    ax.tick_params(axis='y', labelsize=fontsize)
    ax.legend(loc='lower right', fontsize=fontsize)

    # Label every stacked segment with its percentage.
    for container in ax.containers:
        ax.bar_label(container, fmt='%.1f%%', label_type='center', fontsize=fontsize)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    return fig

# Create and save the plot.
# "IC" is the improvement cycle (cf. the IC_of_neg_examples data folder and
# section 6.0), so the figure title spells that out.
save_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/IC_Results.png"
categories = ['Correct', 'Incorrect', 'Failed']
fig = create_stacked_bar_chart(plot_data, categories, 'Improvement Cycle Results', save_path)

# Print the plotted percentages for verification.
for i, dataset in enumerate(['Set 1 (0-250)', 'Set 2 (250-350)', 'Set 3 (350-500)', 'ZINC']):
    print(f"\n{dataset}:")
    for category in categories:
        print(f"{category}: {plot_data[category][i]:.1f}%")

##### Plot Top 1, 3, 5, 10 Accuracy of IC

In [None]:
## PC 0-250: top-k accuracy after the improvement cycle (IC) on the negative examples
import numpy as np
import matplotlib.pyplot as plt

ranking_method = 'HSQC & COSY'
pkl_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/IC_of_neg_examples/PC_0_250_neg_2"
all_rankings = exp_func.process_pkl_files_baseline(pkl_folder, ranking_method)

all_rankings, removed_smiles = exp_func.deduplicate_smiles_from_ranking(all_rankings)
all_rankings, filtered_out_rankings = exp_func.filter_rankings_by_molecular_formula(all_rankings)
accuracies = exp_func.calculate_top_k_accuracy(all_rankings)

# Paper-figure font sizes. These are global rcParams, so later cells inherit them.
plt.rcParams.update({
    'font.size': 22,
    'axes.titlesize': 22,
    'axes.labelsize': 22,
    'xtick.labelsize': 22,
    'ytick.labelsize': 22,
})

# One tick label per returned accuracy, so labels and bars always line up.
# NOTE(review): assumes calculate_top_k_accuracy reports k = 1, 3, 5, 10(, 20) in this order.
labels = ['Top 1', 'Top 3', 'Top 5', 'Top 10', 'Top 20'][:len(accuracies)]
all_labels = labels

# Accuracies are fractions of the 100-molecule set. Round rather than truncate:
# int(0.29 * 100) == 28 because 0.29 * 100 == 28.999999999999996.
total_molecules = 100
molecules = [int(round(acc * total_molecules)) for acc in accuracies]

# Pastel palette (only the first len(molecules) colours are used)
colors = ['#A1C8F3',  # light blue/periwinkle
    '#FFB381',  # salmon/peach
    '#8BE5A0',  # mint green
    '#FF9D9A',  # coral pink
    '#D1B9FE',  # lavender
    '#DEBA9A',  # beige/tan
    '#FCAEE3',  # pink
    '#CFCECE',  # light gray
    '#FEFDA2',  # pale yellow
    '#B8F1EF',  # light cyan/aqua
         ]

fig, ax = plt.subplots(figsize=(8, 7))
bars = ax.bar(range(len(molecules)), molecules, color=colors, edgecolor='black')

ax.set_xticks(range(len(all_labels)))
ax.set_xticklabels(all_labels, fontsize=22)
ax.set_ylabel('Number of Molecules', fontsize=22)
ax.set_title(f'Prediction Accuracy of Set 1 (Molecules: {total_molecules})', fontsize=22, pad=20)

# Percentage above each bar, molecule count centred inside it.
for bar, acc in zip(bars, accuracies):
    x_mid = bar.get_x() + bar.get_width() / 2.
    height = bar.get_height()
    ax.text(x_mid, height + 2, f'{acc * 100:.1f}%', ha='center', va='bottom', fontsize=22)
    ax.text(x_mid, height / 2, f'{int(height)}', ha='center', va='center', fontsize=22, color='black')

ax.grid(True, axis='y', linestyle='--', alpha=0.7)
# Head-room above 100 for the percentage labels.
ax.set_ylim(0, 105)
fig.tight_layout()

save_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/PC_0_250_neg_IC_accuracy_plot.png"
fig.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close(fig)
print(f"Plot saved to: {save_path}")

In [None]:
## PC 250-350: top-k accuracy after the improvement cycle (IC) on the negative examples
import numpy as np
import matplotlib.pyplot as plt

ranking_method = 'HSQC & COSY'
pkl_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/IC_of_neg_examples/PC_250_350_neg"
all_rankings = exp_func.process_pkl_files_baseline(pkl_folder, ranking_method)

all_rankings, removed_smiles = exp_func.deduplicate_smiles_from_ranking(all_rankings)
all_rankings, filtered_out_rankings = exp_func.filter_rankings_by_molecular_formula(all_rankings)
accuracies = exp_func.calculate_top_k_accuracy(all_rankings)

# Paper-figure font sizes. These are global rcParams, so later cells inherit them.
plt.rcParams.update({
    'font.size': 22,
    'axes.titlesize': 22,
    'axes.labelsize': 22,
    'xtick.labelsize': 22,
    'ytick.labelsize': 22,
})

# One tick label per returned accuracy, so labels and bars always line up.
# NOTE(review): assumes calculate_top_k_accuracy reports k = 1, 3, 5, 10(, 20) in this order.
labels = ['Top 1', 'Top 3', 'Top 5', 'Top 10', 'Top 20'][:len(accuracies)]
all_labels = labels

# Accuracies are fractions of the 100-molecule set. Round rather than truncate:
# int(0.29 * 100) == 28 because 0.29 * 100 == 28.999999999999996.
total_molecules = 100
molecules = [int(round(acc * total_molecules)) for acc in accuracies]

# Pastel palette (only the first len(molecules) colours are used)
colors = ['#A1C8F3',  # light blue/periwinkle
    '#FFB381',  # salmon/peach
    '#8BE5A0',  # mint green
    '#FF9D9A',  # coral pink
    '#D1B9FE',  # lavender
    '#DEBA9A',  # beige/tan
    '#FCAEE3',  # pink
    '#CFCECE',  # light gray
    '#FEFDA2',  # pale yellow
    '#B8F1EF',  # light cyan/aqua
         ]

fig, ax = plt.subplots(figsize=(8, 7))
bars = ax.bar(range(len(molecules)), molecules, color=colors, edgecolor='black')

ax.set_xticks(range(len(all_labels)))
ax.set_xticklabels(all_labels, fontsize=22)
ax.set_ylabel('Number of Molecules', fontsize=22)
ax.set_title(f'Prediction Accuracy of Set 2 (Molecules: {total_molecules})', fontsize=22, pad=20)

# Percentage above each bar, molecule count centred inside it.
for bar, acc in zip(bars, accuracies):
    x_mid = bar.get_x() + bar.get_width() / 2.
    height = bar.get_height()
    ax.text(x_mid, height + 2, f'{acc * 100:.1f}%', ha='center', va='bottom', fontsize=22)
    ax.text(x_mid, height / 2, f'{int(height)}', ha='center', va='center', fontsize=22, color='black')

ax.grid(True, axis='y', linestyle='--', alpha=0.7)
# Head-room above 100 for the percentage labels.
ax.set_ylim(0, 105)
fig.tight_layout()

save_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/PC_250_350_neg_IC_accuracy_plot.png"
fig.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close(fig)
print(f"Plot saved to: {save_path}")

In [None]:
## PC 350-500: top-k accuracy after the improvement cycle (IC) on the negative examples
import numpy as np
import matplotlib.pyplot as plt

ranking_method = 'HSQC & COSY'
pkl_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/IC_of_neg_examples/PC_350_500_neg"
all_rankings = exp_func.process_pkl_files_baseline(pkl_folder, ranking_method)

all_rankings, removed_smiles = exp_func.deduplicate_smiles_from_ranking(all_rankings)
all_rankings, filtered_out_rankings = exp_func.filter_rankings_by_molecular_formula(all_rankings)
accuracies = exp_func.calculate_top_k_accuracy(all_rankings)

# Paper-figure font sizes. These are global rcParams, so later cells inherit them.
plt.rcParams.update({
    'font.size': 22,
    'axes.titlesize': 22,
    'axes.labelsize': 22,
    'xtick.labelsize': 22,
    'ytick.labelsize': 22,
})

# One tick label per returned accuracy, so labels and bars always line up.
# NOTE(review): assumes calculate_top_k_accuracy reports k = 1, 3, 5, 10(, 20) in this order.
labels = ['Top 1', 'Top 3', 'Top 5', 'Top 10', 'Top 20'][:len(accuracies)]
all_labels = labels

# Accuracies are fractions of the 100-molecule set. Round rather than truncate:
# int(0.29 * 100) == 28 because 0.29 * 100 == 28.999999999999996.
total_molecules = 100
molecules = [int(round(acc * total_molecules)) for acc in accuracies]

# Pastel palette (only the first len(molecules) colours are used)
colors = ['#A1C8F3',  # light blue/periwinkle
    '#FFB381',  # salmon/peach
    '#8BE5A0',  # mint green
    '#FF9D9A',  # coral pink
    '#D1B9FE',  # lavender
    '#DEBA9A',  # beige/tan
    '#FCAEE3',  # pink
    '#CFCECE',  # light gray
    '#FEFDA2',  # pale yellow
    '#B8F1EF',  # light cyan/aqua
         ]

fig, ax = plt.subplots(figsize=(8, 7))
bars = ax.bar(range(len(molecules)), molecules, color=colors, edgecolor='black')

ax.set_xticks(range(len(all_labels)))
ax.set_xticklabels(all_labels, fontsize=22)
ax.set_ylabel('Number of Molecules', fontsize=22)
ax.set_title(f'Prediction Accuracy of Set 3 (Molecules: {total_molecules})', fontsize=22, pad=20)

# Percentage above each bar, molecule count centred inside it.
for bar, acc in zip(bars, accuracies):
    x_mid = bar.get_x() + bar.get_width() / 2.
    height = bar.get_height()
    ax.text(x_mid, height + 2, f'{acc * 100:.1f}%', ha='center', va='bottom', fontsize=22)
    ax.text(x_mid, height / 2, f'{int(height)}', ha='center', va='center', fontsize=22, color='black')

ax.grid(True, axis='y', linestyle='--', alpha=0.7)
# Head-room above 100 for the percentage labels.
ax.set_ylim(0, 105)
fig.tight_layout()

save_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/PC_350_500_neg_IC_accuracy_plot.png"
fig.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close(fig)
print(f"Plot saved to: {save_path}")

In [None]:
## ZINC 250-350: top-k accuracy after the improvement cycle (IC) on the negative examples
import numpy as np
import matplotlib.pyplot as plt

ranking_method = 'HSQC & COSY'
pkl_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/IC_of_neg_examples/ZINC_250_350_neg"
all_rankings = exp_func.process_pkl_files_baseline(pkl_folder, ranking_method)

all_rankings, removed_smiles = exp_func.deduplicate_smiles_from_ranking(all_rankings)
all_rankings, filtered_out_rankings = exp_func.filter_rankings_by_molecular_formula(all_rankings)
accuracies = exp_func.calculate_top_k_accuracy(all_rankings)

# Paper-figure font sizes. These are global rcParams, so later cells inherit them.
plt.rcParams.update({
    'font.size': 22,
    'axes.titlesize': 22,
    'axes.labelsize': 22,
    'xtick.labelsize': 22,
    'ytick.labelsize': 22,
})

# One tick label per returned accuracy, so labels and bars always line up.
# NOTE(review): assumes calculate_top_k_accuracy reports k = 1, 3, 5, 10(, 20) in this order.
labels = ['Top 1', 'Top 3', 'Top 5', 'Top 10', 'Top 20'][:len(accuracies)]
all_labels = labels

# Accuracies are fractions of the 100-molecule set. Round rather than truncate:
# int(0.29 * 100) == 28 because 0.29 * 100 == 28.999999999999996.
total_molecules = 100
molecules = [int(round(acc * total_molecules)) for acc in accuracies]

# Pastel palette (only the first len(molecules) colours are used)
colors = ['#A1C8F3',  # light blue/periwinkle
    '#FFB381',  # salmon/peach
    '#8BE5A0',  # mint green
    '#FF9D9A',  # coral pink
    '#D1B9FE',  # lavender
    '#DEBA9A',  # beige/tan
    '#FCAEE3',  # pink
    '#CFCECE',  # light gray
    '#FEFDA2',  # pale yellow
    '#B8F1EF',  # light cyan/aqua
         ]

fig, ax = plt.subplots(figsize=(8, 7))
bars = ax.bar(range(len(molecules)), molecules, color=colors, edgecolor='black')

ax.set_xticks(range(len(all_labels)))
ax.set_xticklabels(all_labels, fontsize=22)
ax.set_ylabel('Number of Molecules', fontsize=22)
ax.set_title(f'Prediction Accuracy of ZINC (Molecules: {total_molecules})', fontsize=22, pad=20)

# Percentage above each bar, molecule count centred inside it.
for bar, acc in zip(bars, accuracies):
    x_mid = bar.get_x() + bar.get_width() / 2.
    height = bar.get_height()
    ax.text(x_mid, height + 2, f'{acc * 100:.1f}%', ha='center', va='bottom', fontsize=22)
    ax.text(x_mid, height / 2, f'{int(height)}', ha='center', va='center', fontsize=22, color='black')

ax.grid(True, axis='y', linestyle='--', alpha=0.7)
# Head-room above 100 for the percentage labels.
ax.set_ylim(0, 105)
fig.tight_layout()

save_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/ZINC_250_350_neg_IC_accuracy_plot.png"
fig.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close(fig)
print(f"Plot saved to: {save_path}")

#### 5.7 Bar chart of ESI 7 HSQC Matching results ZINC 100

In [None]:

# V8 raw model: baseline evaluation on the ZINC test set (all NMR modalities + IR).
#config.IR_data_folder = "/projects/cc/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/IR_data"
_nmr_root = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ'
_zinc_root = f'{_nmr_root}/15_ZINC270M'
_zinc_val = f'{_zinc_root}/val_data_all_modalities'

# Training CSV and the per-modality ZINC test files
config.csv_train_path = f'{_zinc_root}/ML_NMR_5M_XL_1H_comb_train_V8.csv'
config.csv_1H_path_SGNN = f'{_zinc_val}/ML_NMR_1H_combined_ZINC_test_10x100.csv'
config.csv_13C_path_SGNN = f'{_zinc_val}/ML_NMR_5M_XL_13C_test_10x100.csv'
config.csv_HSQC_path_SGNN = f'{_zinc_val}/ML_NMR_5M_XL_HSQC_test_10x100.csv'
config.csv_COSY_path_SGNN = f'{_zinc_val}/ML_NMR_5M_XL_COSY_test_10x100.csv'
config.csv_path_val = f'{_zinc_val}/ML_NMR_1H_combined_ZINC_test_10x100.csv'
config.pickle_file_path = f'{_zinc_val}/ML_NMR_1H_combined_ZINC_test_10x100_909434.pkl'
config.IR_data_folder = f'{_zinc_root}/IR_spectra_NN'

# MMT checkpoint (V8, MW2, dropout)
config.checkpoint_path = f'{_nmr_root}/1_old_models/models_v2/V8_MMT_MW2_Drop/MultimodalTransformer_time_1706856620.371885_Loss_0.202.ckpt'

# Sampling settings
config.multinom_runs = 10
#config.multinom_runs = 3
config.temperature = 1
greedy_full = False  # False -> the next cell uses mrtf.run_greedy_sampling
MW_filter = True
config.data_size = 100  # number of test molecules to evaluate

# Drop the path helpers so they do not linger in the notebook namespace.
del _nmr_root, _zinc_root, _zinc_val

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
config.execution_type = "test_performance"
if config.execution_type == "test_performance":
    print("\033[1m\033[31mThis is: test_performance\033[0m")
    # config.csv_path_val = config.csv_SMI_targets  #this already got updated in simulate_syn_data
    
    model_MMT = mrtf.load_MMT_model(config)
    model_CLIP = mrtf.load_CLIP_model(config)
    #model_BLIP = mrtf.load_BLIP_model(config)
    #model_MMT = model_CLIP.MT_model
    val_dataloader = mrtf.load_data(config, stoi, stoi_MF, single=True, mode="val")
    val_dataloader_multi = mrtf.load_data(config, stoi, stoi_MF, single=False, mode="val")

    
    results_dict_bl_ZINC = mrtf.run_test_mns_performance_CLIP_3(config,  
                                                        model_MMT,
                                                        model_CLIP,
                                                        val_dataloader,
                                                        stoi, 
                                                        itos,
                                                        MW_filter)
    results_dict_bl_ZINC, counter = mrtf.filter_invalid_inputs(results_dict_bl_ZINC)
    
    avg_tani_bl_ZINC, html_plot = rbgvm.plot_hist_of_results(results_dict_bl_ZINC)
    
    # Slow because also just takes one at the time
    if greedy_full == True:
        results_dict_greedy_bl_ZINC, failed_bl_ZINC = mrtf.run_test_performance_CLIP_greedy_3(config,  
                                                                stoi, 
                                                                stoi_MF, 
                                                                itos, 
                                                                itos_MF)

        avg_tani_greedy_bl_ZINC, html_plot_greedy = rbgvm.plot_hist_of_results_greedy(results_dict_greedy_bl_ZINC)

    else: 
        config, results_dict_ZINC_greedy_bl = mrtf.run_greedy_sampling(config, model_MMT, val_dataloader_multi, itos, stoi)
        avg_tani_greedy_bl_ZINC = results_dict_ZINC_greedy_bl["tanimoto_mean"]
    
    total_results_bl_ZINC = mrtf.run_test_performance_CLIP_3(config, 
                                                        model_MMT, 
                                                        val_dataloader,
                                                        stoi)
    
    corr_sampleing_prob_bl_ZINC = total_results_bl_ZINC["statistics_multiplication_avg"][0]
    print("avg_tani, avg_tani_greedy, corr_sampleing_prob'")
    print(avg_tani_bl_ZINC, avg_tani_greedy_bl_ZINC, corr_sampleing_prob_bl_ZINC)       

    

In [None]:

# Persist the baseline results so the plots below can be regenerated without
# re-running the models. Only the greedy outputs of the branch that actually ran
# are stored; the other greedy slots stay None.
variables_to_save = {
    'avg_tani_bl_ZINC': avg_tani_bl_ZINC,
    'results_dict_greedy_bl_ZINC': None,
    'failed_bl_ZINC': None,
    'avg_tani_greedy_bl_ZINC': avg_tani_greedy_bl_ZINC,
    'results_dict_ZINC_greedy_bl': None,
    'total_results_bl_ZINC': total_results_bl_ZINC,
    'corr_sampleing_prob_bl_ZINC': corr_sampleing_prob_bl_ZINC,
    'results_dict_bl_ZINC': results_dict_bl_ZINC,
}
if greedy_full:
    variables_to_save['results_dict_greedy_bl_ZINC'] = results_dict_greedy_bl_ZINC
    variables_to_save['failed_bl_ZINC'] = failed_bl_ZINC
else:
    variables_to_save['results_dict_ZINC_greedy_bl'] = results_dict_ZINC_greedy_bl

with open('/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240522_ZINC_MNS_HSQC_matching/7.0_imp_cyc_all_100_10_after.pkl', 'wb') as pkl_out:
    pickle.dump(variables_to_save, pkl_out)


In [None]:
file_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240522_ZINC_MNS_HSQC_matching/7.0_imp_cyc_all_100_10_after.pkl'
# Reload the variables_to_save dict written by the save cell above (same path).
# pickle.load can execute arbitrary code, so only load files this notebook produced.
with open(file_path, 'rb') as pkl_in:
    variables_to_save = pickle.load(pkl_in)

In [None]:
# MNS results from the reloaded pickle; analysed by check_lowest_error_correct below.
results_dict_bl_ZINC = variables_to_save["results_dict_bl_ZINC"]

In [None]:
# Function to check if the lowest error correctly identifies the correct molecule
def check_lowest_error_correct(results):
    """Check whether the lowest-error candidate is the correct (or closest) molecule.

    ``results`` maps keys to lists of sublists. Each sublist is one group of
    scored candidates whose items are lists of at least 5 entries, where
    ``item[3]`` is the similarity to the target (1 means the exact molecule)
    and ``item[4]`` is the candidate's error. Entries of any other shape are
    skipped.

    For each sublist:
      * If it contains the correct molecule (similarity == 1), it is a case
        *with* the correct answer. It counts as a correct identification when
        (the best copy of) the correct molecule has the lowest error.
      * Otherwise it is a case *without* the correct answer. It counts as a
        "closest" identification when the most similar candidate also has
        the lowest error. Sublists with no numeric error are never a hit.

    Returns:
        tuple ``(correct_identifications, total_cases_with_correct_answer,
        closest_identifications, total_cases_without_correct_answer,
        incorrect_identification_keys)``. The last element holds the key once
        per sublist whose correct molecule was not ranked lowest.
    """
    import numbers  # numbers.Real also matches NumPy scalars such as np.float32

    correct_identifications = 0
    total_cases_with_correct_answer = 0
    total_cases_without_correct_answer = 0
    closest_identifications = 0
    incorrect_identification_keys = []

    for key, value in results.items():
        # `not value` must be checked before value[0]: an empty list used to raise IndexError.
        if not isinstance(value, list) or not value or not isinstance(value[0], list):
            continue  # Skip non-list / empty items

        for sublist in value:
            if not isinstance(sublist, list):
                continue  # Skip non-list sublists

            correct_error = None
            lowest_error = float('inf')
            highest_similarity = 0
            closest_error = float('inf')

            for item in sublist:
                if not isinstance(item, list) or len(item) < 5:
                    continue  # Skip invalid items

                similarity, error = item[3], item[4]
                sim_is_num = isinstance(similarity, numbers.Real)
                err_is_num = isinstance(error, numbers.Real)

                if sim_is_num and similarity == 1:
                    # If the correct molecule appears more than once, keep its best score.
                    err = float(error)
                    correct_error = err if correct_error is None else min(correct_error, err)
                if err_is_num and float(error) < lowest_error:
                    lowest_error = float(error)
                if sim_is_num and similarity > highest_similarity:
                    highest_similarity = similarity
                    closest_error = float(error)

            if correct_error is not None:
                total_cases_with_correct_answer += 1
                if correct_error == lowest_error:
                    correct_identifications += 1
                else:
                    incorrect_identification_keys.append(key)
            else:
                total_cases_without_correct_answer += 1
                # With no valid candidates both sentinels stay inf, and inf == inf
                # would wrongly count as a hit, so require a finite lowest error.
                if lowest_error != float('inf') and closest_error == lowest_error:
                    closest_identifications += 1

    return (correct_identifications, total_cases_with_correct_answer,
            closest_identifications, total_cases_without_correct_answer,
            incorrect_identification_keys)

# Evaluate the reloaded MNS results.
(correct_identifications,
 total_cases_with_correct_answer,
 closest_identifications,
 total_cases_without_correct_answer,
 incorrect_identification_keys) = check_lowest_error_correct(results_dict_bl_ZINC)

# Hit rate within each group of cases (0 when the group is empty).
accuracy_with_correct = (correct_identifications / total_cases_with_correct_answer
                         if total_cases_with_correct_answer > 0 else 0)
accuracy_without_correct = (closest_identifications / total_cases_without_correct_answer
                            if total_cases_without_correct_answer > 0 else 0)

# Report (one line per metric, followed by the keys that were missed)
print("\n".join([
    f"Correct identifications: {correct_identifications}",
    f"Total cases with correct answer: {total_cases_with_correct_answer}",
    f"Accuracy with correct answer: {accuracy_with_correct:.2%}",
    f"Closest identifications: {closest_identifications}",
    f"Total cases without correct answer: {total_cases_without_correct_answer}",
    f"Accuracy without correct answer: {accuracy_without_correct:.2%}",
    f"Keys of incorrect identifications: {incorrect_identification_keys}",
]))


In [None]:
### old with 3 bars
# Superseded 3-bar version of the identification-accuracy plot (MNS with / without
# correct answer + greedy), kept for reference only. It is a plain string literal,
# so none of it runs. The trailing ";" stops Jupyter from echoing the whole string
# as this cell's output.
"""import matplotlib.pyplot as plt

tani_list = variables_to_save["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
greedy_correct = tani_list.count(1)
total_cases_greedy = len(tani_list)
accuracy_greedy = greedy_correct / total_cases_greedy

# Bar chart for identification accuracies
accuracies = {
    'MNS With \nCorrect Answer': accuracy_with_correct,
    'MNS Without \nCorrect Answer': accuracy_without_correct,
    'Greedy \nSampling': accuracy_greedy
}

# Ratios for annotation
ratios = {
    'With Correct Answer': f'{correct_identifications}/{total_cases_with_correct_answer}',
    'Without Correct Answer': f'{closest_identifications}/{total_cases_without_correct_answer}',
    'Greedy Sampling': f'{greedy_correct}/{total_cases_greedy}'
}

fig, ax = plt.subplots()
bars = ax.bar(accuracies.keys(), accuracies.values(), color=['#87CEEB', '#FFA07A', '#90EE90'])  # LightBlue, LightSalmon, LightGreen
ax.set_ylabel('Accuracy')
ax.set_title('Identification Accuracies')
ax.set_ylim(0, 1)

# Annotate bars with accuracy percentages on top of the bar
for bar, (label, ratio) in zip(bars, ratios.items()):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width() / 2, height,
            f'{height:.2%}',
            ha='center', va='bottom', fontsize=14)

# Annotate bars with ratios inside the bar
for bar, (label, ratio) in zip(bars, ratios.items()):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width() / 2, height / 2,
            f'{ratio}',
            ha='center', va='center', fontsize=14, color='black')

plt.tight_layout()

# Define the file path for saving the plot
output_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/7.0_identification_accuracies_new.png'

# Save the plot
plt.savefig(output_path, bbox_inches='tight')
plt.show()""";

In [None]:
import matplotlib.pyplot as plt

# Greedy decoding counts as a hit when the Tanimoto similarity to the target is exactly 1.
tani_list = variables_to_save["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
greedy_correct = tani_list.count(1)
total_cases_greedy = len(tani_list)
# Guard against an empty result list instead of raising ZeroDivisionError.
accuracy_greedy = greedy_correct / total_cases_greedy if total_cases_greedy > 0 else 0

# NOTE: the MNS accuracy is measured only over cases whose candidates contained the
# correct molecule (see check_lowest_error_correct), while the greedy accuracy is
# over all greedy cases. The hits/cases labels show both denominators.
accuracies = {
    'MNS \nSampling': accuracy_with_correct,
    'Greedy \nSampling': accuracy_greedy
}

# "hits/cases" annotations, in the same order as `accuracies`
ratios = {
    'MNS': f'{correct_identifications}/{total_cases_with_correct_answer}',
    'Greedy Sampling': f'{greedy_correct}/{total_cases_greedy}'
}

fig, ax = plt.subplots()
bars = ax.bar(list(accuracies.keys()), list(accuracies.values()),
              color=['#A1C8F3', '#FFB381'])  # light blue, salmon/peach

ax.set_ylabel('Accuracy')
ax.set_title('Identification Accuracies')
ax.set_ylim(0, 1)

# Accuracy (%) on top of each bar, hits/cases ratio inside it.
for bar, ratio in zip(bars, ratios.values()):
    x_mid = bar.get_x() + bar.get_width() / 2
    height = bar.get_height()
    ax.text(x_mid, height, f'{height:.2%}', ha='center', va='bottom', fontsize=22)
    ax.text(x_mid, height / 2, f'{ratio}', ha='center', va='center', fontsize=22, color='black')

fig.tight_layout()

output_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/7.0_identification_accuracies_new.png'
# 300 dpi to match the other paper figures saved in this notebook.
fig.savefig(output_path, dpi=300, bbox_inches='tight')
plt.show()

#### 5.8 ZINC 4000 MNS Experiment

In [None]:
# ZINC 250-350, 4000-molecule baseline run: rank candidates with the HSQC & COSY
# criterion and compute top-k accuracy. The exp_func helpers presumably drop
# duplicate SMILES and candidates with a mismatching molecular formula (per their
# names) -- see exp_func for the exact rules.
ranking_method = 'HSQC & COSY'
pkl_folder= "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_250_350_4000"
all_rankings = exp_func.process_pkl_files_baseline(pkl_folder, ranking_method)

all_rankings, removed_smiles = exp_func.deduplicate_smiles_from_ranking(all_rankings)
all_rankings, filtered_out_rankings = exp_func.filter_rankings_by_molecular_formula(all_rankings)
accuracies = exp_func.calculate_top_k_accuracy(all_rankings)


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Paper-figure font sizes. These are global rcParams, so later cells inherit them.
plt.rcParams.update({
    'font.size': 22,
    'axes.titlesize': 22,
    'axes.labelsize': 22,
    'xtick.labelsize': 22,
    'ytick.labelsize': 22,
})

# 3991 of the 4000 submitted molecules were processed; the rest are shown as "Failed".
total_molecules = 4000
processed_molecules = 3991
failed_molecules = total_molecules - processed_molecules

# One tick label per returned accuracy (+ "Failed"), so the Failed label can never
# land on a top-k bar. NOTE(review): assumes k = 1, 3, 5, 10(, 20) in this order.
labels = ['Top 1', 'Top 3', 'Top 5', 'Top 10', 'Top 20'][:len(accuracies)]
# Accuracies are fractions of the processed molecules. Round rather than truncate
# so float error (e.g. 0.29 * 100 == 28.999...) does not drop a molecule.
molecules = [int(round(acc * processed_molecules)) for acc in accuracies]
all_molecules = molecules + [failed_molecules]
all_labels = labels + ['Failed']

colors = ['#FF9999', '#66B2FF', '#99FF99', '#FFCC99', '#FF99CC']

fig, ax = plt.subplots(figsize=(12, 6))
bars = ax.bar(range(len(all_molecules)), all_molecules, color=colors, edgecolor='black')

ax.set_xticks(range(len(all_labels)))
ax.set_xticklabels(all_labels, fontsize=22)
ax.set_ylabel('Number of Molecules', fontsize=22)
ax.set_title('Prediction Accuracy by Top-K (Total Molecules: 4000)', fontsize=22, pad=20)

for i, bar in enumerate(bars):
    x_mid = bar.get_x() + bar.get_width() / 2.
    height = bar.get_height()
    if i < len(accuracies):
        # Top-k bar: percentage above, count centred inside.
        percentage, pct_y, count_y = accuracies[i] * 100, height + 50, height / 2
    else:
        # Failed bar (share of all submitted molecules) is too short to hold its
        # count, so both labels go above it.
        percentage, pct_y, count_y = height / total_molecules * 100, height + 200, height + 100
    ax.text(x_mid, pct_y, f'{percentage:.1f}%', ha='center', va='bottom', fontsize=22)
    ax.text(x_mid, count_y, f'{int(height)}', ha='center', va='center', fontsize=22, color='black')

ax.grid(True, axis='y', linestyle='--', alpha=0.7)
# Head-room for the labels above the tallest bar.
ax.set_ylim(0, max(all_molecules) + 400)
fig.tight_layout()

save_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/ZINC4000_accuracy_plot.png"
fig.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close(fig)

print(f"Plot saved to: {save_path}")

### 6.0 Test Improvement cycle 

#### 6.1 Fine-tune on all the molecules I want to investigate

In [None]:
# Improvement-cycle experiment on the ZINC test set (tracked in W&B under config.project).
_nmr_root = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ'
_zinc_root = f'{_nmr_root}/15_ZINC270M'
_zinc_val = f'{_zinc_root}/val_data_all_modalities'

config.project = "Improv_Cycle_v3"  # Name of the project for wandb monitoring

# Training CSV and the per-modality ZINC test files
config.csv_train_path = f'{_zinc_root}/ML_NMR_5M_XL_1H_comb_train_V8.csv'
config.csv_1H_path_SGNN = f'{_zinc_val}/ML_NMR_1H_combined_ZINC_test_10x100.csv'
config.csv_13C_path_SGNN = f'{_zinc_val}/ML_NMR_5M_XL_13C_test_10x100.csv'
config.csv_HSQC_path_SGNN = f'{_zinc_val}/ML_NMR_5M_XL_HSQC_test_10x100.csv'
config.csv_COSY_path_SGNN = f'{_zinc_val}/ML_NMR_5M_XL_COSY_test_10x100.csv'
config.csv_path_val = f'{_zinc_val}/ML_NMR_1H_combined_ZINC_test_10x100.csv'
config.pickle_file_path = f'{_zinc_val}/ML_NMR_1H_combined_ZINC_test_10x100_909434.pkl'

# Alternative: PubChem 350-500 test set
#config.csv_1H_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000.csv'
#config.csv_13C_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_13C_V1_test_350_500_x1000.csv'
#config.csv_HSQC_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_HSQC_V1_test_350_500_x1000.csv'
#config.csv_COSY_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_COSY_V1_test_350_500_x1000.csv'
#config.csv_path_val = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000.csv'
#config.pickle_file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000_185242.pkl"

# IR spectra (V8 raw)
config.IR_data_folder = f'{_zinc_root}/IR_spectra_NN'
#config.IR_data_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/IR_data"
#config.IR_data_folder = ""

# Checkpoint: V8 / V8i raw, MW, dropout
config.checkpoint_path = f'{_nmr_root}/1_old_models/models_v2/V8_MMT_MW2_Drop/MultimodalTransformer_time_1706856620.371885_Loss_0.202.ckpt'
#config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8_MMT/MultimodalTransformer_time_1702145565.246352_Loss_0.048.ckpt"
#config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt"

# Evaluation / sampling settings
config.data_size = 100  # number of test molecules to evaluate
#config.data_size = 4
config.execution_type = "test_performance"
config.multinom_runs = 1
#config.multinom_runs = 3
config.temperature = 1
greedy_full = False
MW_filter = True

# Molecular-formula / scaffold generation settings
config.MF_generations = 50
config.MF_delta_weight = 20
config.max_scaffold_generations = 30

# Drop the path helpers so they do not linger in the notebook namespace.
del _nmr_root, _zinc_root, _zinc_val


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
if config.execution_type == "test_performance":
    print("\033[1m\033[31mThis is: test_performance\033[0m")
    # config.csv_path_val = config.csv_SMI_targets  #this already got updated in simulate_syn_data
    
    model_MMT = mrtf.load_MMT_model(config)
    model_CLIP = mrtf.load_CLIP_model(config)
    #model_BLIP = mrtf.load_BLIP_model(config)
    #model_MMT = model_CLIP.MT_model
    val_dataloader = mrtf.load_data(config, stoi, stoi_MF, single=True, mode="val")
    val_dataloader_multi = mrtf.load_data(config, stoi, stoi_MF, single=False, mode="val")

    
    results_dict_bl_ZINC = mrtf.run_test_mns_performance_CLIP_3(config,  
                                                        model_MMT,
                                                        model_CLIP,
                                                        val_dataloader,
                                                        stoi, 
                                                        itos,
                                                        MW_filter)
    results_dict_bl_ZINC, counter = mrtf.filter_invalid_inputs(results_dict_bl_ZINC)
    
    avg_tani_bl_ZINC, html_plot = rbgvm.plot_hist_of_results(results_dict_bl_ZINC)
    
    # Slow because also just takes one at the time
    if greedy_full == True:
        results_dict_greedy_bl_ZINC, failed_bl_ZINC = mrtf.run_test_performance_CLIP_greedy_3(config,  
                                                                stoi, 
                                                                stoi_MF, 
                                                                itos, 
                                                                itos_MF)

        avg_tani_greedy_bl_ZINC, html_plot_greedy = rbgvm.plot_hist_of_results_greedy(results_dict_greedy_bl_ZINC)

    else: 
        config, results_dict_ZINC_greedy_bl = mrtf.run_greedy_sampling(config, model_MMT, val_dataloader_multi, itos, stoi)
        avg_tani_greedy_bl_ZINC = results_dict_ZINC_greedy_bl["tanimoto_mean"]
    
    total_results_bl_ZINC = mrtf.run_test_performance_CLIP_3(config, 
                                                        model_MMT, 
                                                        val_dataloader,
                                                        stoi)
    
    corr_sampleing_prob_bl_ZINC = total_results_bl_ZINC["statistics_multiplication_avg"][0]
    print("avg_tani, avg_tani_greedy, corr_sampleing_prob'")
    print(avg_tani_bl_ZINC, avg_tani_greedy_bl_ZINC, corr_sampleing_prob_bl_ZINC)       

    

In [None]:
# Select ALL evaluated target SMILES (the keys of results_dict_bl_ZINC) for the "all"
# improvement cycle. No success/failure filter is applied here.
filtered_results = list(results_dict_bl_ZINC)
filtered_results

In [None]:

# Export the selected target SMILES, with sample ids, to the deep-molecular-optimization
# MMP data folder.
data = {
    "SMILES": filtered_results,
    # Fixed-width, zero-padded ids: SOURCE_000001 ... SOURCE_000100. The former
    # f"SOURCE_00000{i+1}" prefix made the id width grow once i+1 reached 10 and 100.
    "sample-id": [f"SOURCE_{i + 1:06d}" for i in range(len(filtered_results))],
}

df = pd.DataFrame(data)

# Save DataFrame to CSV
csv_file_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/deep-molecular-optimization/data/MMP/test_selection_2.csv'
df.to_csv(csv_file_path, index=False)
# Keep copies of the current validation paths; later cells reassign config.csv_path_val.
csv_path_val_backup = config.csv_path_val
pickle_file_path_backup = config.pickle_file_path

print(f"CSV file '{csv_file_path}' created successfully.")

In [None]:
# Pipeline stage: generate candidate molecules with ex.SMI_generation_MF, drop NaN results,
# plot the generated set, then advance execution_type to "combine_MMT_MF".
# NOTE(review): execution_type is set on the next line, so the `if` below is always taken.
config.execution_type = "SMI_generation_MF"

if config.execution_type == "SMI_generation_MF":
    # n_samples mirrors the evaluation size from the config cell.
    config.n_samples = config.data_size
    #if config.execution_type == "SMI_generation_MF":
    print("\033[1m\033[31mThis is: SMI_generation_MF\033[0m")
    # Presumably returns a path to a filtered copy of the validation csv - confirm in ex.filter_invalid_criteria.
    config.csv_path_val  = ex.filter_invalid_criteria(config.csv_path_val)#, config.csv_path_val)
    config, results_dict_MF = ex.SMI_generation_MF(config, stoi, stoi_MF, itos, itos_MF)
    #mode = "val"

    # Drop entries whose value contains only NaN, then strip NaN items from the remaining lists.
    results_dict_MF = {key: value for key, value in results_dict_MF.items() if not hf.contains_only_nan(value)}
    for key, value in results_dict_MF.items():
        results_dict_MF[key] = hf.remove_nan_from_list(value)

    # t-SNE / UMAP / PCA cluster plots; combined_list_MF is used below as a flat list of SMILES.
    combined_list_MF, html_TSNE, html_UMAP, html_PCA = cv.plot_cluster_MF(results_dict_MF, config)
    max_num = 10  # number of molecules drawn by plot_molecules_from_list
    html_plot = pt.plot_molecules_from_list(combined_list_MF, max_num)
    config.execution_type = "combine_MMT_MF"
    #import IPython; IPython.embed();
    print(config.data_size)


In [None]:
# Pipeline stage: merge generated SMILES into aug_mol_df (columns SMILES / sample-id),
# then advance execution_type to "blend_prev_train_data".
# combined_list_MMT is empty here, so only the MF generations contribute.
config.execution_type = "combine_MMT_MF"
combined_list_MMT = []
if config.execution_type == "combine_MMT_MF":
    print("\033[1m\033[31mThis is: combine_MMT_MF\033[0m")
    all_gen_smis = combined_list_MMT + combined_list_MF
    # Remove the 'NAN' placeholder strings from both lists.
    combined_list_MF = [smiles for smiles in combined_list_MF if smiles != 'NAN']
    all_gen_smis = [smiles for smiles in all_gen_smis if smiles != 'NAN']
    
    #filter out potential hits from the real test_set
    val_data = pd.read_csv(config.csv_path_val)
    all_gen_smis = mrtf.filter_smiles(val_data, all_gen_smis)
    length_of_list = len(all_gen_smis)   
    # Despite the name these ids are sequential, zero-padded to 7 digits: GT_0000001, GT_0000002, ...
    random_number_strings = [f"GT_{str(i).zfill(7)}" for i in range(1, length_of_list + 1)]
    aug_mol_df = pd.DataFrame({'SMILES': all_gen_smis, 'sample-id': random_number_strings})
    config.execution_type = "blend_prev_train_data"


In [None]:
# Pipeline stage: blend aug_mol_df with previous training data via ex.blend_aug_with_train_data,
# then advance execution_type to "data_generation".
# NOTE(review): train_data_blend = 0 presumably means no previous training data is mixed in - confirm.
config.train_data_blend = 0
config.execution_type = "blend_prev_train_data"
if config.execution_type == "blend_prev_train_data":
    print("\033[1m\033[31mThis is: blend_prev_train_data\033[0m")
    config, final_df = ex.blend_aug_with_train_data(config, aug_mol_df)
    config.execution_type = "data_generation"
    #import IPython; IPython.embed();


In [None]:
# Pipeline stage: ex.gen_sim_aug_data (presumably simulates spectra for the augmented
# molecules - confirm) returns an updated config; then flag the simulated data as ready and
# advance execution_type to "transformer_improvement".
config.execution_type = "data_generation"
if config.execution_type == "data_generation":
    #config.csv_SMI_targets = config.csv_1H_path_SGNN
    print("\033[1m\033[31mThis is: data_generation\033[0m")
    config = ex.gen_sim_aug_data(config, IR_config)
    config.execution_type = "transformer_improvement"
    sim_data_gen = True
    #import IPython; IPython.embed();

In [None]:

# Persist the evaluation of the model tested above, together with the generated molecules,
# so the paper figures can be rebuilt without re-running the pipeline.
# Greedy entries depend on greedy_full: only the branch that actually ran is stored; the other is None.
variables_to_save = dict(
    avg_tani_bl_ZINC=avg_tani_bl_ZINC,
    results_dict_greedy_bl_ZINC=(results_dict_greedy_bl_ZINC if greedy_full else None),
    failed_bl_ZINC=(failed_bl_ZINC if greedy_full else None),
    avg_tani_greedy_bl_ZINC=avg_tani_greedy_bl_ZINC,
    results_dict_ZINC_greedy_bl=(None if greedy_full else results_dict_ZINC_greedy_bl),
    total_results_bl_ZINC=total_results_bl_ZINC,
    corr_sampleing_prob_bl_ZINC=corr_sampleing_prob_bl_ZINC,
    filtered_results=filtered_results,
    all_gen_smis=all_gen_smis,
    aug_mol_df=aug_mol_df,
    results_dict_bl_ZINC=results_dict_bl_ZINC,
)

with open('/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/20240516_Improvment_Cycle_v2/6.2_imp_cyc_all_100_50.pkl', 'wb') as f:
    pickle.dump(variables_to_save, f)


In [None]:
# Disabled cleanup: after '!' the shell sees "#rm ...", which is a shell comment, so nothing is deleted.
!#rm -r '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/24_SGNN_gen_folder_2/dump_2_64364'

In [None]:
# Disabled: simulate IR spectra for the targets and point config.IR_data_folder at the result.
# config.csv_SMI_targets = config.csv_1H_path_SGNN
# data_IR = irs.run_IR_simulation(config, IR_config, "target")
# config.IR_data_folder = data_IR


In [None]:
# Fine-tuning settings for the improvement cycle (consumed by mtf.run_MMT in the next cell).
config.blank_percentage = 0
config.weight_MW = 0
config.lr_pretraining = 1e-4
config.tr_te_split = 0.9  # presumably the train fraction of the train/test split - confirm
config.model_save_dir = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/imp_cyc_all_100_50_v2" # Folder where networks are saved
# Use every row of the current 1H csv (presumably the augmented data produced by the
# data_generation stage - verify that config.csv_1H_path_SGNN was updated there).
data_size = len(pd.read_csv(config.csv_1H_path_SGNN))
config.data_size = data_size
config.gpu_num = 1
config.batch_size = 64
config.num_epochs = 50

In [None]:
# Pipeline stage: train the MMT on the simulated data with training_setup = "pretraining",
# then advance execution_type to "update_model".
# NOTE(review): both flags are set right here, so the `if` below is always taken.
config.execution_type = "transformer_improvement"
sim_data_gen = True

if config.execution_type == "transformer_improvement" and sim_data_gen == True:
    print("\033[1m\033[31mThis is: transformer_improvement, sim_data_gen == TRUE\033[0m")
    config.training_setup = "pretraining"
    mtf.run_MMT(config, stoi, stoi_MF)

    # config.execution_type = "clip_improvement"
    config.execution_type = "update_model"
    # finish_while = True
    #import IPython; IPython.embed();


In [None]:
# Reset the data paths to the ZINC 10x100 validation set before evaluating the fine-tuned
# model; earlier stages reassign some of them (e.g. config.csv_path_val via ex.filter_invalid_criteria).
config.csv_train_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/ML_NMR_5M_XL_1H_comb_train_V8.csv'
config.csv_1H_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/val_data_all_modalities/ML_NMR_1H_combined_ZINC_test_10x100.csv'
config.csv_13C_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/val_data_all_modalities/ML_NMR_5M_XL_13C_test_10x100.csv'
config.csv_HSQC_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/val_data_all_modalities/ML_NMR_5M_XL_HSQC_test_10x100.csv'
config.csv_COSY_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/val_data_all_modalities/ML_NMR_5M_XL_COSY_test_10x100.csv'
config.csv_path_val = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/val_data_all_modalities/ML_NMR_1H_combined_ZINC_test_10x100.csv'
config.pickle_file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/val_data_all_modalities/ML_NMR_1H_combined_ZINC_test_10x100_909434.pkl"

# Alternative validation set (PubChem val_data_350_500_x1000) - disabled.
#config.csv_1H_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000.csv'
#config.csv_13C_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_13C_V1_test_350_500_x1000.csv'
#config.csv_HSQC_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_HSQC_V1_test_350_500_x1000.csv'
#config.csv_COSY_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_COSY_V1_test_350_500_x1000.csv'
#config.csv_path_val = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000.csv'
#config.pickle_file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000_185242.pkl"

# V8 Raw
#config.IR_data_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/IR_data"

#config.IR_data_folder = ""
config.IR_data_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/15_ZINC270M/IR_spectra_NN"
#config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8_MMT/MultimodalTransformer_time_1702145565.246352_Loss_0.048.ckpt"

# V8 / 8Vi Raw MW Drop
# This is the pre-fine-tuning checkpoint; the next cell calls ex.update_model_path(config),
# presumably to switch to the newly trained model - confirm.
config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8_MMT_MW2_Drop/MultimodalTransformer_time_1706856620.371885_Loss_0.202.ckpt"
#config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt"

config.data_size = 100 # config.test_size # why would I do that? 
#config.data_size = 4 # config.test_size # why would I do that? 
config.execution_type = "test_performance"
config.multinom_runs = 1
#config.multinom_runs = 3
config.temperature = 1
greedy_full = False
MW_filter = True
config.MF_generations = 50


In [None]:
# Presumably points config.checkpoint_path at the model just trained into config.model_save_dir - confirm in ex.update_model_path.
config = ex.update_model_path(config)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
if config.execution_type == "test_performance":
    print("\033[1m\033[31mThis is: test_performance\033[0m")
    # config.csv_path_val = config.csv_SMI_targets  #this already got updated in simulate_syn_data
    
    model_MMT = mrtf.load_MMT_model(config)
    model_CLIP = mrtf.load_CLIP_model(config)
    #model_BLIP = mrtf.load_BLIP_model(config)
    #model_MMT = model_CLIP.MT_model
    val_dataloader = mrtf.load_data(config, stoi, stoi_MF, single=True, mode="val")
    val_dataloader_multi = mrtf.load_data(config, stoi, stoi_MF, single=False, mode="val")

    
    results_dict_bl_ZINC = mrtf.run_test_mns_performance_CLIP_3(config,  
                                                        model_MMT,
                                                        model_CLIP,
                                                        val_dataloader,
                                                        stoi, 
                                                        itos,
                                                        MW_filter)
    results_dict_bl_ZINC, counter = mrtf.filter_invalid_inputs(results_dict_bl_ZINC)
    
    avg_tani_bl_ZINC, html_plot = rbgvm.plot_hist_of_results(results_dict_bl_ZINC)
    
    # Slow because also just takes one at the time
    if greedy_full == True:
        results_dict_greedy_bl_ZINC, failed_bl_ZINC = mrtf.run_test_performance_CLIP_greedy_3(config,  
                                                                stoi, 
                                                                stoi_MF, 
                                                                itos, 
                                                                itos_MF)

        avg_tani_greedy_bl_ZINC, html_plot_greedy = rbgvm.plot_hist_of_results_greedy(results_dict_greedy_bl_ZINC)

    else: 
        config, results_dict_ZINC_greedy_bl = mrtf.run_greedy_sampling(config, model_MMT, val_dataloader_multi, itos, stoi)
        avg_tani_greedy_bl_ZINC = results_dict_ZINC_greedy_bl["tanimoto_mean"]
    
    total_results_bl_ZINC = mrtf.run_test_performance_CLIP_3(config, 
                                                        model_MMT, 
                                                        val_dataloader,
                                                        stoi)
    
    corr_sampleing_prob_bl_ZINC = total_results_bl_ZINC["statistics_multiplication_avg"][0]
    print("avg_tani, avg_tani_greedy, corr_sampleing_prob'")
    print(avg_tani_bl_ZINC, avg_tani_greedy_bl_ZINC, corr_sampleing_prob_bl_ZINC)       

    

In [None]:
# Select ALL evaluated target SMILES (the keys of results_dict_bl_ZINC).
# No success/failure filter is applied here.
filtered_results = list(results_dict_bl_ZINC)
# filtered_results

In [None]:

# Persist the evaluation of the model tested above to the "_after" pickle.
# Greedy entries depend on greedy_full: only the branch that actually ran is stored; the other is None.
variables_to_save = dict(
    avg_tani_bl_ZINC=avg_tani_bl_ZINC,
    results_dict_greedy_bl_ZINC=(results_dict_greedy_bl_ZINC if greedy_full else None),
    failed_bl_ZINC=(failed_bl_ZINC if greedy_full else None),
    avg_tani_greedy_bl_ZINC=avg_tani_greedy_bl_ZINC,
    results_dict_ZINC_greedy_bl=(None if greedy_full else results_dict_ZINC_greedy_bl),
    total_results_bl_ZINC=total_results_bl_ZINC,
    corr_sampleing_prob_bl_ZINC=corr_sampleing_prob_bl_ZINC,
    filtered_results=filtered_results,
    results_dict_bl_ZINC=results_dict_bl_ZINC,
)

with open('/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/20240516_Improvment_Cycle_v2/6.2_imp_cyc_all_100_50_after.pkl', 'wb') as f:
    pickle.dump(variables_to_save, f)


In [None]:

# Settings for the next test_performance run, with config.multinom_runs = 3
# (its results are saved to the ..._after_MNS_3.pkl file).
config.data_size = 100 # config.test_size # why would I do that? 
#config.data_size = 4 # config.test_size # why would I do that? 
config.execution_type = "test_performance"
# Only the effective value is kept: the former `multinom_runs = 1` line was immediately overwritten.
config.multinom_runs = 3
config.temperature = 1
greedy_full = False
MW_filter = True



In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
if config.execution_type == "test_performance":
    print("\033[1m\033[31mThis is: test_performance\033[0m")
    # config.csv_path_val = config.csv_SMI_targets  #this already got updated in simulate_syn_data
    
    model_MMT = mrtf.load_MMT_model(config)
    model_CLIP = mrtf.load_CLIP_model(config)
    #model_BLIP = mrtf.load_BLIP_model(config)
    #model_MMT = model_CLIP.MT_model
    val_dataloader = mrtf.load_data(config, stoi, stoi_MF, single=True, mode="val")
    val_dataloader_multi = mrtf.load_data(config, stoi, stoi_MF, single=False, mode="val")

    
    results_dict_bl_ZINC = mrtf.run_test_mns_performance_CLIP_3(config,  
                                                        model_MMT,
                                                        model_CLIP,
                                                        val_dataloader,
                                                        stoi, 
                                                        itos,
                                                        MW_filter)
    
    results_dict_bl_ZINC, counter = mrtf.filter_invalid_inputs(results_dict_bl_ZINC)
    
    avg_tani_bl_ZINC, html_plot = rbgvm.plot_hist_of_results(results_dict_bl_ZINC)
    
    # Slow because also just takes one at the time
    if greedy_full == True:
        results_dict_greedy_bl_ZINC, failed_bl_ZINC = mrtf.run_test_performance_CLIP_greedy_3(config,  
                                                                stoi, 
                                                                stoi_MF, 
                                                                itos, 
                                                                itos_MF)

        avg_tani_greedy_bl_ZINC, html_plot_greedy = rbgvm.plot_hist_of_results_greedy(results_dict_greedy_bl_ZINC)

    else: 
        config, results_dict_ZINC_greedy_bl = mrtf.run_greedy_sampling(config, model_MMT, val_dataloader_multi, itos, stoi)
        avg_tani_greedy_bl_ZINC = results_dict_ZINC_greedy_bl["tanimoto_mean"]
    
    total_results_bl_ZINC = mrtf.run_test_performance_CLIP_3(config, 
                                                        model_MMT, 
                                                        val_dataloader,
                                                        stoi)
    
    corr_sampleing_prob_bl_ZINC = total_results_bl_ZINC["statistics_multiplication_avg"][0]
    print("avg_tani, avg_tani_greedy, corr_sampleing_prob'")
    print(avg_tani_bl_ZINC, avg_tani_greedy_bl_ZINC, corr_sampleing_prob_bl_ZINC)       

    

In [None]:
# Keep only the targets whose first candidate record has value[0][0][-2] != 1,
# i.e. the samples the model did not get right
# (assumes that field is the similarity score of the top candidate - TODO confirm).
filtered_results = [
    target
    for target, ranked in results_dict_bl_ZINC.items()
    if ranked[0][0][-2] != 1
]
#filtered_results

In [None]:

# Save variables to a pickle file
variables_to_save = {
    'avg_tani_bl_ZINC': avg_tani_bl_ZINC,
    'results_dict_greedy_bl_ZINC': results_dict_greedy_bl_ZINC if greedy_full else None,
    'failed_bl_ZINC': failed_bl_ZINC if greedy_full else None,
    'avg_tani_greedy_bl_ZINC': avg_tani_greedy_bl_ZINC,
    'results_dict_ZINC_greedy_bl': results_dict_ZINC_greedy_bl if not greedy_full else None,
    'total_results_bl_ZINC': total_results_bl_ZINC,
    'corr_sampleing_prob_bl_ZINC': corr_sampleing_prob_bl_ZINC,
    'filtered_results': filtered_results,
    'results_dict_bl_ZINC': results_dict_bl_ZINC,
}

with open('/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/20240516_Improvment_Cycle_v2/6.2_imp_cyc_all_100_50_after_MNS_3.pkl', 'wb') as f:
    pickle.dump(variables_to_save, f)


In [None]:
# Switch evaluation to a PubChem validation split. Only the val_data_250_350_x1000 block is
# active; the 350_500 and 0_250 alternatives are kept commented out.
# NOTE: the output filename of the save cell for this run must name the split chosen here.
#config.csv_1H_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000.csv'
#config.csv_13C_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_13C_V1_test_350_500_x1000.csv'
#config.csv_HSQC_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_HSQC_V1_test_350_500_x1000.csv'
#config.csv_COSY_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_COSY_V1_test_350_500_x1000.csv'
#config.csv_path_val = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000.csv'
#config.pickle_file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_350_500_x1000/ML_NMR_2M_XL_1H_V1_test_f_350_500_x1000_185242.pkl"

#config.csv_1H_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_1H_V1_test_f_0_250_x1000.csv'
#config.csv_13C_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_13C_V1_test_0_250_x1000.csv'
#config.csv_HSQC_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_HSQC_V1_test_0_250_x1000.csv'
#config.csv_COSY_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_COSY_V1_test_0_250_x1000.csv'
#config.csv_path_val = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_1H_V1_test_f_0_250_x1000.csv'
#config.pickle_file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_0_250_x1000/ML_NMR_2M_XL_1H_V1_test_f_0_250_x1000_933335.pkl"


# Active split: val_data_250_350_x1000.
config.csv_1H_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_1H_V1_test_f_250_350_x1000.csv'
config.csv_13C_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_13C_V1_test_250_350_x1000.csv'
config.csv_HSQC_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_HSQC_V1_test_250_350_x1000.csv'
config.csv_COSY_path_SGNN = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_COSY_V1_test_250_350_x1000.csv'
config.csv_path_val = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_1H_V1_test_f_250_350_x1000.csv'
config.pickle_file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/val_data_250_350_x1000/ML_NMR_2M_XL_1H_V1_test_f_250_350_x1000_285005.pkl"

# V8 Raw
# PubChem IR spectra folder; config.checkpoint_path is left as previously set.
config.IR_data_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/26_PubChem_dataset/IR_data"


In [None]:

# Settings for the PubChem test_performance run (single multinomial run).
config.data_size = 100 # config.test_size # why would I do that? 
#config.data_size = 4 # config.test_size # why would I do that? 
config.execution_type = "test_performance"
# The identical `multinom_runs = 1` line was duplicated; one assignment is enough.
config.multinom_runs = 1
config.temperature = 1
greedy_full = False
MW_filter = True



In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
if config.execution_type == "test_performance":
    print("\033[1m\033[31mThis is: test_performance\033[0m")
    # config.csv_path_val = config.csv_SMI_targets  #this already got updated in simulate_syn_data
    
    model_MMT = mrtf.load_MMT_model(config)
    model_CLIP = mrtf.load_CLIP_model(config)
    #model_BLIP = mrtf.load_BLIP_model(config)
    #model_MMT = model_CLIP.MT_model
    val_dataloader = mrtf.load_data(config, stoi, stoi_MF, single=True, mode="val")
    val_dataloader_multi = mrtf.load_data(config, stoi, stoi_MF, single=False, mode="val")

    
    results_dict_bl_ZINC = mrtf.run_test_mns_performance_CLIP_3(config,  
                                                        model_MMT,
                                                        model_CLIP,
                                                        val_dataloader,
                                                        stoi, 
                                                        itos,
                                                        MW_filter)
    results_dict_bl_ZINC, counter = mrtf.filter_invalid_inputs(results_dict_bl_ZINC)
    
    avg_tani_bl_ZINC, html_plot = rbgvm.plot_hist_of_results(results_dict_bl_ZINC)
    
    # Slow because also just takes one at the time
    if greedy_full == True:
        results_dict_greedy_bl_ZINC, failed_bl_ZINC = mrtf.run_test_performance_CLIP_greedy_3(config,  
                                                                stoi, 
                                                                stoi_MF, 
                                                                itos, 
                                                                itos_MF)

        avg_tani_greedy_bl_ZINC, html_plot_greedy = rbgvm.plot_hist_of_results_greedy(results_dict_greedy_bl_ZINC)

    else: 
        config, results_dict_ZINC_greedy_bl = mrtf.run_greedy_sampling(config, model_MMT, val_dataloader_multi, itos, stoi)
        avg_tani_greedy_bl_ZINC = results_dict_ZINC_greedy_bl["tanimoto_mean"]
    
    total_results_bl_ZINC = mrtf.run_test_performance_CLIP_3(config, 
                                                        model_MMT, 
                                                        val_dataloader,
                                                        stoi)
    
    corr_sampleing_prob_bl_ZINC = total_results_bl_ZINC["statistics_multiplication_avg"][0]
    print("avg_tani, avg_tani_greedy, corr_sampleing_prob'")
    print(avg_tani_bl_ZINC, avg_tani_greedy_bl_ZINC, corr_sampleing_prob_bl_ZINC)       

    

In [None]:

# Recompute the not-exactly-solved targets from the CURRENT (PubChem) results. Without this,
# the filtered_results left over from the earlier ZINC run would be saved alongside them.
filtered_results = [key for key, value in results_dict_bl_ZINC.items() if value[0][0][-2] != 1]

# Split tag used in the output filename. It must match the split loaded into config:
# the file used to be named ..._0_250.pkl while the 250_350 split was active, which
# mislabeled the results and overwrote the 0_250 file.
dataset_tag = "250_350"
assert f"val_data_{dataset_tag}_x1000" in config.csv_path_val, (
    f"dataset_tag {dataset_tag!r} does not match config.csv_path_val={config.csv_path_val!r}"
)

# Save variables to a pickle file
# Greedy entries depend on greedy_full: only the branch that actually ran is stored; the other is None.
variables_to_save = {
    'avg_tani_bl_ZINC': avg_tani_bl_ZINC,
    'results_dict_greedy_bl_ZINC': results_dict_greedy_bl_ZINC if greedy_full else None,
    'failed_bl_ZINC': failed_bl_ZINC if greedy_full else None,
    'avg_tani_greedy_bl_ZINC': avg_tani_greedy_bl_ZINC,
    'results_dict_ZINC_greedy_bl': results_dict_ZINC_greedy_bl if not greedy_full else None,
    'total_results_bl_ZINC': total_results_bl_ZINC,
    'corr_sampleing_prob_bl_ZINC': corr_sampleing_prob_bl_ZINC,
    'filtered_results': filtered_results,
    'results_dict_bl_ZINC': results_dict_bl_ZINC,
}

with open(f'/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/20240520_Improvment_Cycle_PC_v2/6.2_PC_FT_Model_{dataset_tag}.pkl', 'wb') as f:
    pickle.dump(variables_to_save, f)


#### 6.2 Load precomputed improvement-cycle results

In [None]:
def mns_tani_extraction(data: dict, index: int = 3) -> list:
    """
    Collect one field per entry of a results dictionary.

    For each value in ``data`` that is truthy and whose first element is also truthy,
    takes ``value[0][0][index]``: the item at ``index`` of the first record inside the
    value's first element. With the default ``index=3`` this is the 4th item. Entries whose
    value is ``None``/empty, or whose first element is empty, are skipped.

    Parameters:
    data (dict): Mapping whose values are nested sequences indexable as ``value[0][0][index]``.
    index (int): Position inside ``value[0][0]`` to extract. Defaults to 3 (original behaviour).

    Returns:
    list: The extracted items, in the dictionary's iteration order.
    """
    return [
        value[0][0][index]
        for value in data.values()
        # Skip None/empty values and values whose first element is empty.
        if value and value[0]
    ]

In [None]:
# Disabled: the 6.1 improvement-cycle loaders (100_10 ... 100_100) are kept inside this
# unused string literal, so none of these pickles are loaded.
'''
# 100_10
import pickle

file_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/20240517_Improvment_Cycle_v3/6.1_imp_cyc_100_10.pkl'
# Loading results_list_b
with open(file_path, 'rb') as file:
    imp_cyc_all_100_10_before = pickle.load(file)
    
file_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/20240517_Improvment_Cycle_v3/6.1_imp_cyc_100_10_after.pkl'
# Loading results_list_b
with open(file_path, 'rb') as file:
    imp_cyc_all_100_10_after = pickle.load(file)
    
file_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/20240517_Improvment_Cycle_v3/6.1_imp_cyc_100_10_after_MNS_3.pkl'
# Loading results_list_b
with open(file_path, 'rb') as file:
    imp_cyc_all_100_10_after_mns = pickle.load(file)


# 100_30
import pickle

file_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/20240517_Improvment_Cycle_v3/6.1_imp_cyc_100_30.pkl'
# Loading results_list_b
with open(file_path, 'rb') as file:
    imp_cyc_all_100_30_before = pickle.load(file)
    
file_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/20240517_Improvment_Cycle_v3/6.1_imp_cyc_100_30_after.pkl'
# Loading results_list_b
with open(file_path, 'rb') as file:
    imp_cyc_all_100_30_after = pickle.load(file)
    
file_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/20240517_Improvment_Cycle_v3/6.1_imp_cyc_100_30_after_MNS_3.pkl'
# Loading results_list_b
with open(file_path, 'rb') as file:
    imp_cyc_all_100_30_after_mns = pickle.load(file) 
        
    
# 100_50
import pickle

file_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/20240517_Improvment_Cycle_v3/6.1_imp_cyc_100_50.pkl'
# Loading results_list_b
with open(file_path, 'rb') as file:
    imp_cyc_all_100_50_before = pickle.load(file)
    
file_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/20240517_Improvment_Cycle_v3/6.1_imp_cyc_100_50_after.pkl'
# Loading results_list_b
with open(file_path, 'rb') as file:
    imp_cyc_all_100_50_after = pickle.load(file)
    
file_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/20240517_Improvment_Cycle_v3/6.1_imp_cyc_100_50_after_MNS_3.pkl'
# Loading results_list_b
with open(file_path, 'rb') as file:
    imp_cyc_all_100_50_after_mns = pickle.load(file) 
    
# 100_100
import pickle

file_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/20240517_Improvment_Cycle_v3/6.1_imp_cyc_100_100.pkl'
# Loading results_list_b
with open(file_path, 'rb') as file:
    imp_cyc_all_100_100_before = pickle.load(file)
    
file_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/20240517_Improvment_Cycle_v3/6.1_imp_cyc_100_100_after.pkl'
# Loading results_list_b
with open(file_path, 'rb') as file:
    imp_cyc_all_100_100_after = pickle.load(file)
    
file_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/precomputed_raw_data/20240517_Improvment_Cycle_v3/6.1_imp_cyc_100_100_after_MNS_3.pkl'
# Loading results_list_b
with open(file_path, 'rb') as file:
    imp_cyc_all_100_100_after_mns = pickle.load(file)     
'''

##### Load results of the ZINC data experiment

In [None]:
# Improvement-cycle (IC) results for the ZINC 250-350 Da experiment.
# For each IC size N (10, 30, 50, 100 analogues) three pickles are read:
#   6.2_imp_cyc_all_100_N*.pkl            -> model before the IC
#   6.2_imp_cyc_all_100_N_after*.pkl      -> greedy decoding after the IC
#   6.2_imp_cyc_all_100_N_after_MNS_3*.pkl -> multinomial sampling ("MNS: 3") after the IC
import pickle

zinc_ic_dir = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240517_Improvment_Cycle_v3/'


def _load_zinc_ic(file_name, directory=zinc_ic_dir):
    """Return the object unpickled from ``directory + file_name``."""
    with open(directory + file_name, 'rb') as handle:
        return pickle.load(handle)


# IC-10
imp_cyc_all_100_10_before = _load_zinc_ic('6.2_imp_cyc_all_100_10.pkl')
imp_cyc_all_100_10_after = _load_zinc_ic('6.2_imp_cyc_all_100_10_after.pkl')
imp_cyc_all_100_10_after_mns = _load_zinc_ic('6.2_imp_cyc_all_100_10_after_MNS_3.pkl')

# IC-30
imp_cyc_all_100_30_before = _load_zinc_ic('6.2_imp_cyc_all_100_30.pkl')
imp_cyc_all_100_30_after = _load_zinc_ic('6.2_imp_cyc_all_100_30_after.pkl')
imp_cyc_all_100_30_after_mns = _load_zinc_ic('6.2_imp_cyc_all_100_30_after_MNS_3.pkl')

# IC-50
imp_cyc_all_100_50_before = _load_zinc_ic('6.2_imp_cyc_all_100_50.pkl')
imp_cyc_all_100_50_after = _load_zinc_ic('6.2_imp_cyc_all_100_50_after.pkl')
imp_cyc_all_100_50_after_mns = _load_zinc_ic('6.2_imp_cyc_all_100_50_after_MNS_3.pkl')

# IC-100 (re-run, "_v2" files)
imp_cyc_all_100_100_before = _load_zinc_ic('6.2_imp_cyc_all_100_100_v2.pkl')
# NOTE(review): unlike every other IC size, the greedy "after" result is read from
# the same *_after_MNS_3_v2.pkl file as the MNS result below, not from a separate
# *_after*.pkl. Confirm this is intentional (e.g. no greedy-only v2 file exists).
imp_cyc_all_100_100_after = _load_zinc_ic('6.2_imp_cyc_all_100_100_after_MNS_3_v2.pkl')
imp_cyc_all_100_100_after_mns = _load_zinc_ic('6.2_imp_cyc_all_100_100_after_MNS_3_v2.pkl')

In [None]:
# Greedy-decoding Tanimoto similarities of the model before each improvement cycle.
tani_before_all_100_10 = imp_cyc_all_100_10_before["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_before_all_100_30 = imp_cyc_all_100_30_before["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_before_all_100_50 = imp_cyc_all_100_50_before["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_before_all_100_100 = imp_cyc_all_100_100_before["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]

# The same quantity after each improvement cycle.
tani_after_all_100_10 = imp_cyc_all_100_10_after["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_after_all_100_30 = imp_cyc_all_100_30_after["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_after_all_100_50 = imp_cyc_all_100_50_after["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_after_all_100_100 = imp_cyc_all_100_100_after["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]

# Multinomial-sampling results after each cycle, reduced to Tanimoto similarities
# by mns_tani_extraction (defined elsewhere in this notebook).
tani_after_all_100_10_mns = mns_tani_extraction(imp_cyc_all_100_10_after_mns["results_dict_bl_ZINC"])
tani_after_all_100_30_mns = mns_tani_extraction(imp_cyc_all_100_30_after_mns["results_dict_bl_ZINC"])
tani_after_all_100_50_mns = mns_tani_extraction(imp_cyc_all_100_50_after_mns["results_dict_bl_ZINC"])
tani_after_all_100_100_mns = mns_tani_extraction(imp_cyc_all_100_100_after_mns["results_dict_bl_ZINC"])

In [None]:
# Inspect the fields of the greedy-decoding results dict (the tani_* cells read its "tanimoto_sim" entry).
imp_cyc_all_100_10_before["results_dict_ZINC_greedy_bl"].keys()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

fontsize = 26
# (Tanimoto similarities, description) in group-major order: for every IC size,
# the greedy result before the IC, the greedy result after it, and the
# multinomial-sampling (MNS: 3) result after it. Only the data is plotted; the
# descriptions document the ordering.
tani_data = [
    (tani_before_all_100_10, 'Greedy Tanimoto Before FT10'), (tani_after_all_100_10, 'Greedy Tanimoto After FT10'), (tani_after_all_100_10_mns, '3 Multinomial Sampling After FT10'),
    (tani_before_all_100_30, 'Greedy Tanimoto Before FT30'), (tani_after_all_100_30, 'Greedy Tanimoto After FT30'), (tani_after_all_100_30_mns, '3 Multinomial Sampling After FT30'),
    (tani_before_all_100_50, 'Greedy Tanimoto Before FT50'), (tani_after_all_100_50, 'Greedy Tanimoto After FT50'), (tani_after_all_100_50_mns, '3 Multinomial Sampling After FT50'),
    # The FT100 group must use its own "before" run, not the FT50 one.
    (tani_before_all_100_100, 'Greedy Tanimoto Before FT100'), (tani_after_all_100_100, 'Greedy Tanimoto After FT100'), (tani_after_all_100_100_mns, '3 Multinomial Sampling After FT100'),
]

# Count perfect matches (Tanimoto similarity == 1) per condition.
numerical_data = [np.sum(np.array(data) == 1.0) for data, _ in tani_data]

# Labels for each group and for each condition inside a group
group_labels = ['MMST IC-10', 'MMST IC-30', 'MMST IC-50', 'MMST IC-100']
condition_labels = ['Before IC', 'After IC', 'MNS: 3']
n_conditions = len(condition_labels)

# One pastel colour per condition: light blue, peach, mint green
colors = ['#A1C8F3', '#FFB381', '#8BE5A0']

fig, ax = plt.subplots(figsize=(17, 9))

bar_width = 0.25
group_spacing = 0.05  # gap between neighbouring bars inside one group
bar_step = bar_width + group_spacing

# positions[i][g] is the x position of condition i's bar in group g.
num_groups = len(group_labels)
indices = np.arange(num_groups)
positions = [indices + i * bar_step for i in range(n_conditions)]

# numerical_data is group-major, so every n_conditions-th entry belongs to one condition.
for i in range(n_conditions):
    ax.bar(positions[i], numerical_data[i::n_conditions],
           width=bar_width, color=colors[i], edgecolor='black',
           label=condition_labels[i])

ax.set_ylabel('Number of Perfect Matches (Tanimoto = 1)', fontsize=fontsize)
ax.set_title('Tanimoto Matches Across Different Conditions: ZINC 250-350 Da', fontsize=fontsize)
ax.set_xlabel('Number of Molecule Analogues for Improvement Cycle', fontsize=fontsize)

# Centre each group label under the middle of its bars (first bar at offset 0,
# last at (n_conditions - 1) * bar_step).
group_centers = indices + (n_conditions - 1) / 2 * bar_step
ax.set_xticks(group_centers)
ax.set_xticklabels(group_labels, fontsize=fontsize)

# Value labels on top of each bar
for i, v in enumerate(numerical_data):
    ax.text(positions[i % n_conditions][i // n_conditions], v, str(v), ha='center', va='bottom', fontsize=fontsize)

ax.legend(fontsize=fontsize, loc='upper left')
ax.tick_params(axis='y', labelsize=fontsize)
# 10 % headroom for the value labels; the floor of 1 avoids a zero-height axis if every count is 0.
ax.set_ylim(0, max(max(numerical_data), 1) * 1.1)

plt.tight_layout()

save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/6.1.2_ZINC_Bar_Chart_Greedy_MNS_Tanimoto_v4.png'
plt.savefig(save_path, format='png', dpi=300)  # Save as PNG with 300 dpi

plt.show()

In [None]:
# Expose selected entries of the IC-10 "before" pickle as top-level names
# for interactive inspection.
corr_sampleing_prob_bl_ZINC = imp_cyc_all_100_10_before['corr_sampleing_prob_bl_ZINC']

results_dict_ZINC_greedy_bl = imp_cyc_all_100_10_before['results_dict_ZINC_greedy_bl']

results_dict_greedy_bl_ZINC = imp_cyc_all_100_10_before['results_dict_greedy_bl_ZINC']

# Other keys that were once read here and are currently disabled (their presence
# in the pickle is unverified): 'avg_tani_bl_ZINC', 'failed_bl_ZINC',
# 'avg_tani_greedy_bl_ZINC', 'total_results_bl_ZINC', 'filtered_results'.

In [None]:
# List the top-level keys stored in the IC-10 "before" pickle.
imp_cyc_all_100_10_before.keys()

In [None]:
# Averaged correct-sample probability before each improvement cycle ...
corr_sp_before_all_100_10 = imp_cyc_all_100_10_before["corr_sampleing_prob_bl_ZINC"]
corr_sp_before_all_100_30 = imp_cyc_all_100_30_before["corr_sampleing_prob_bl_ZINC"]
corr_sp_before_all_100_50 = imp_cyc_all_100_50_before["corr_sampleing_prob_bl_ZINC"]
corr_sp_before_all_100_100 = imp_cyc_all_100_100_before["corr_sampleing_prob_bl_ZINC"]

# ... and after it.
corr_sp_after_all_100_10 = imp_cyc_all_100_10_after["corr_sampleing_prob_bl_ZINC"]
corr_sp_after_all_100_30 = imp_cyc_all_100_30_after["corr_sampleing_prob_bl_ZINC"]
corr_sp_after_all_100_50 = imp_cyc_all_100_50_after["corr_sampleing_prob_bl_ZINC"]
corr_sp_after_all_100_100 = imp_cyc_all_100_100_after["corr_sampleing_prob_bl_ZINC"]



In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Defined here so the cell does not depend on `fontsize` leaking from another cell.
fontsize = 26

# Averaged correct-sample probability of the base model and of each IC model.
values = [
    corr_sp_before_all_100_10,  # Before (same for all)
    corr_sp_after_all_100_10,   # After FT10
    corr_sp_after_all_100_30,   # After FT30
    corr_sp_after_all_100_50,   # After FT50
    corr_sp_after_all_100_100   # After FT100
]

labels = [
    'MMST Model',
    'IC-10',
    'IC-30',
    'IC-50',
    'IC-100'
]

# Pastel palette; only the first len(values) colours are used.
colors = [
    '#A1C8F3',  # light blue/periwinkle
    '#FFB381',  # salmon/peach
    '#8BE5A0',  # mint green
    '#FF9D9A',  # coral pink
    '#D1B9FE',  # lavender
    '#DEBA9A',  # beige/tan
    '#FCAEE3',  # pink
    '#CFCECE',  # light gray
    '#FEFDA2',  # pale yellow
    '#B8F1EF',  # light cyan/aqua
]

fig, ax = plt.subplots(figsize=(17, 9))
bars = ax.bar(labels, values, color=colors[:len(values)], edgecolor='black')

# Write each value at half the bar's height.
for bar, value in zip(bars, values):
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() / 2, f'{value:.2f}', ha='center', va='center', fontsize=fontsize, color='black')

ax.set_title('Averaged Correct Sample Probability: ZINC 250-350 Da', fontsize=fontsize)
ax.set_ylabel('Averaged Correct Sample Probability', fontsize=fontsize)
ax.set_xlabel('Number of Molecule Analogues for Improvement Cycle', fontsize=fontsize)

# Resize the existing categorical tick labels in place. Calling set_xticklabels()
# without set_xticks() triggers matplotlib's FixedFormatter warning.
ax.tick_params(axis='x', labelsize=fontsize, labelrotation=0)
ax.tick_params(axis='y', labelsize=fontsize)

output_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/6.2.2_ZINC_correct_sample_prob_comparison_v3.png"
fig.savefig(output_path, bbox_inches='tight')

plt.show()


##### Chemical spaces

In [None]:
# Re-derive the ZINC Tanimoto lists (same values as in the bar-chart section)
# so this section can be run on its own after the loading cell.
tani_before_all_100_10 = imp_cyc_all_100_10_before["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_before_all_100_30 = imp_cyc_all_100_30_before["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_before_all_100_50 = imp_cyc_all_100_50_before["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_before_all_100_100 = imp_cyc_all_100_100_before["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]

tani_after_all_100_10 = imp_cyc_all_100_10_after["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_after_all_100_30 = imp_cyc_all_100_30_after["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_after_all_100_50 = imp_cyc_all_100_50_after["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_after_all_100_100 = imp_cyc_all_100_100_after["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]

# MNS results go through mns_tani_extraction (defined elsewhere in this notebook).
tani_after_all_100_10_mns = mns_tani_extraction(imp_cyc_all_100_10_after_mns["results_dict_bl_ZINC"])
tani_after_all_100_30_mns = mns_tani_extraction(imp_cyc_all_100_30_after_mns["results_dict_bl_ZINC"])
tani_after_all_100_50_mns = mns_tani_extraction(imp_cyc_all_100_50_after_mns["results_dict_bl_ZINC"])
tani_after_all_100_100_mns = mns_tani_extraction(imp_cyc_all_100_100_after_mns["results_dict_bl_ZINC"])

In [None]:
# SMILES keys of the IC-10 "before" run whose first result entry has 1 in its
# second-to-last field -- presumably an exact-match flag (TODO confirm against the
# code that writes results_dict_bl_ZINC). The t-SNE plot labels this set
# "Molecules Correct before" despite the "_false" in the variable name.
results_dict_bl_ZINC_before = imp_cyc_all_100_10_before["results_dict_bl_ZINC"]
filtered_results_false_before = [
    smiles
    for smiles in results_dict_bl_ZINC_before
    if results_dict_bl_ZINC_before[smiles][0][0][-2] == 1
]
len(filtered_results_false_before)

In [None]:
# Same selection as for the "before" run, applied to the IC-10 "after" run
# (second-to-last field of the first entry == 1; presumed exact-match flag, TODO confirm).
results_dict_bl_ZINC_after = imp_cyc_all_100_10_after["results_dict_bl_ZINC"]
filtered_results_false_after = [
    smiles
    for smiles in results_dict_bl_ZINC_after
    if results_dict_bl_ZINC_after[smiles][0][0][-2] == 1
]
len(filtered_results_false_after)

In [None]:
# SMILES stored under "all_gen_smis" in the IC-10 "before" pickle -- presumably the analogues
# generated for the improvement cycle; plotted as "Generated molecules" in the chemical-space figure.
all_gen_smis = imp_cyc_all_100_10_before["all_gen_smis"]

In [None]:

# Convert SMILES to fingerprints
def smiles_to_fps(smiles_list, radius=2, n_bits=512):
    """Convert SMILES strings to Morgan bit-vector fingerprints.

    Args:
        smiles_list: iterable of SMILES strings.
        radius: Morgan radius (default 2, as used throughout this section).
        n_bits: fingerprint length (default 512).

    Returns:
        2-D array with one row of ``n_bits`` bits per SMILES that RDKit can
        parse. Unparseable SMILES are skipped, so the row count can be smaller
        than ``len(smiles_list)``. If nothing parses (or the input is empty) the
        result has shape ``(0, n_bits)``, so it can still be combined with
        ``np.vstack``. A plain ``np.array([])`` would be 1-D and break the stack.
    """
    fps = []
    for smiles in tqdm(smiles_list):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue  # unparseable SMILES are dropped
        fps.append(AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits))
    if not fps:
        return np.zeros((0, n_bits), dtype=int)
    return np.array(fps)

fps1 = smiles_to_fps(filtered_results_false_before)  # molecules correct before the IC
fps2 = smiles_to_fps(all_gen_smis)                    # generated molecules
fps3 = smiles_to_fps(filtered_results_false_after)    # molecules correct after the IC

# Stack the three fingerprint sets. plot_2D slices the embeddings by this
# fps1 -> fps2 -> fps3 row order.
all_fps = np.vstack([fps1, fps2, fps3])

# Dimensionality Reduction: t-SNE
tsne = TSNE(n_components=2, random_state=0)
X_tsne = tsne.fit_transform(all_fps)

# Dimensionality Reduction: UMAP (seeded like t-SNE so the embedding is reproducible)
umap_model = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=0)
X_umap = umap_model.fit_transform(all_fps)

# Dimensionality Reduction: PCA
pca = PCA(n_components=2)
X_pca = pca.fit_transform(all_fps)


# Plotting function with the legend below the graph
def plot_2D(X, title, label1='Molecules Correct before', label2='Generated molecules', label3='Molecules Correct after',
            output_path="/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/nmr_project/1_Dataexploration/2_paper_code/Experiments_SLURM/20.0_SLURM_MasterTransformer/Figures_Paper_2/6.1.2_t-SNE_FT10_v1.png"):
    """Scatter a 2-D embedding of the stacked fingerprints, one marker style per set.

    Rows of ``X`` must follow the ``all_fps`` order: the first ``len(fps1)``
    rows, then ``len(fps2)`` rows, then the rest (fps3). The global ``fps1`` and
    ``fps2`` are read to find those boundaries.

    Args:
        X: 2-D embedding, at least two columns.
        title: figure title.
        label1, label2, label3: legend labels for the fps1, fps2 and fps3 rows.
        output_path: file the figure is saved to. The default is the t-SNE
            figure path used so far. Pass a different path for UMAP/PCA plots,
            otherwise they overwrite it.
    """
    n1, n2 = len(fps1), len(fps2)
    correct_before = X[:n1]
    generated = X[n1:n1 + n2]
    correct_after = X[n1 + n2:]

    plt.figure(figsize=(10, 8))
    # Generated molecules first (faint) so the correct sets are drawn on top.
    plt.scatter(generated[:, 0], generated[:, 1], c='g', marker='^', label=label2, alpha=0.1)
    plt.scatter(correct_before[:, 0], correct_before[:, 1], c='b', marker='o', label=label1, alpha=1, s=50)
    plt.scatter(correct_after[:, 0], correct_after[:, 1], c='r', marker='s', label=label3, alpha=1, s=10)
    plt.title(title, fontsize=20)
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=3, fontsize=14)
    plt.savefig(output_path, bbox_inches='tight')
    plt.show()

# Only the t-SNE embedding is plotted (and saved by plot_2D); the UMAP and PCA
# calls are kept disabled for reference.
plot_2D(X=X_tsne, title='t-SNE Plot: FT10')

#plot_2D(X_umap, 'UMAP Plot')
#plot_2D(X_pca, 'PCA Plot')

#### PubChem

##### Load PubChem 0-250 v1

In [None]:
# Improvement-cycle (IC) results for the PubChem 0-250 Da experiment: for each
# IC size N a "before" pickle, a greedy "_after" pickle and a multinomial-
# sampling "_after_MNS_3" pickle.
import pickle

pc_0_250_ic_dir = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240520_Improvment_Cycle_PC_0_250/'


def _load_pc_0_250_ic(file_name, directory=pc_0_250_ic_dir):
    """Return the object unpickled from ``directory + file_name``."""
    with open(directory + file_name, 'rb') as handle:
        return pickle.load(handle)


# IC-10
imp_cyc_all_100_10_before = _load_pc_0_250_ic('6.2_imp_cyc_all_100_10.pkl')
imp_cyc_all_100_10_after = _load_pc_0_250_ic('6.2_imp_cyc_all_100_10_after.pkl')
imp_cyc_all_100_10_after_mns = _load_pc_0_250_ic('6.2_imp_cyc_all_100_10_after_MNS_3.pkl')

# IC-30
imp_cyc_all_100_30_before = _load_pc_0_250_ic('6.2_imp_cyc_all_100_30.pkl')
imp_cyc_all_100_30_after = _load_pc_0_250_ic('6.2_imp_cyc_all_100_30_after.pkl')
imp_cyc_all_100_30_after_mns = _load_pc_0_250_ic('6.2_imp_cyc_all_100_30_after_MNS_3.pkl')

# IC-50
imp_cyc_all_100_50_before = _load_pc_0_250_ic('6.2_imp_cyc_all_100_50.pkl')
imp_cyc_all_100_50_after = _load_pc_0_250_ic('6.2_imp_cyc_all_100_50_after.pkl')
imp_cyc_all_100_50_after_mns = _load_pc_0_250_ic('6.2_imp_cyc_all_100_50_after_MNS_3.pkl')

# IC-100
imp_cyc_all_100_100_before = _load_pc_0_250_ic('6.2_imp_cyc_all_100_100.pkl')
imp_cyc_all_100_100_after = _load_pc_0_250_ic('6.2_imp_cyc_all_100_100_after.pkl')
imp_cyc_all_100_100_after_mns = _load_pc_0_250_ic('6.2_imp_cyc_all_100_100_after_MNS_3.pkl')

In [None]:
    
# Evaluation results of the model fine-tuned on PubChem 0-250 Da, used as the
# "PubChem FT" / "PC FT" reference bars below. An earlier assignment pointing into
# 20240520_Improvment_Cycle_PC_v2/ was overwritten before use and has been dropped;
# the path below is the one that was always loaded.
# NOTE(review): this 0-250 file lives in the PC_250_350 directory -- confirm it is the intended file.
import pickle

file_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240520_Improvment_Cycle_PC_250_350/6.2_PC_FT_Model_0_250.pkl'
with open(file_path, 'rb') as file:
    PC_0_250_FT = pickle.load(file)

In [None]:
# Greedy-decoding Tanimoto similarities before each improvement cycle (PubChem 0-250 Da) ...
tani_before_all_100_10 = imp_cyc_all_100_10_before["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_before_all_100_30 = imp_cyc_all_100_30_before["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_before_all_100_50 = imp_cyc_all_100_50_before["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_before_all_100_100 = imp_cyc_all_100_100_before["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]

# ... after it ...
tani_after_all_100_10 = imp_cyc_all_100_10_after["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_after_all_100_30 = imp_cyc_all_100_30_after["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_after_all_100_50 = imp_cyc_all_100_50_after["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_after_all_100_100 = imp_cyc_all_100_100_after["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]

# ... with multinomial sampling after it (via mns_tani_extraction, defined elsewhere) ...
tani_after_all_100_10_mns = mns_tani_extraction(imp_cyc_all_100_10_after_mns["results_dict_bl_ZINC"])
tani_after_all_100_30_mns = mns_tani_extraction(imp_cyc_all_100_30_after_mns["results_dict_bl_ZINC"])
tani_after_all_100_50_mns = mns_tani_extraction(imp_cyc_all_100_50_after_mns["results_dict_bl_ZINC"])
tani_after_all_100_100_mns = mns_tani_extraction(imp_cyc_all_100_100_after_mns["results_dict_bl_ZINC"])

# ... and for the PubChem fine-tuned reference model.
tani_PC_0_250_FT = PC_0_250_FT["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Defined here so the cell does not depend on `fontsize` leaking from another cell.
fontsize = 26

# (Tanimoto similarities, description) in group-major order: for every IC size,
# greedy before the IC, greedy after the IC, and multinomial sampling (MNS) after it.
tani_data = [
    (tani_before_all_100_10, 'FT10 Before'), (tani_after_all_100_10, 'FT10 After'), (tani_after_all_100_10_mns, 'FT10 MNS'),
    (tani_before_all_100_30, 'FT30 Before'), (tani_after_all_100_30, 'FT30 After'), (tani_after_all_100_30_mns, 'FT30 MNS'),
    (tani_before_all_100_50, 'FT50 Before'), (tani_after_all_100_50, 'FT50 After'), (tani_after_all_100_50_mns, 'FT50 MNS'),
    (tani_before_all_100_100, 'FT100 Before'), (tani_after_all_100_100, 'FT100 After'), (tani_after_all_100_100_mns, 'FT100 MNS')
]

# The PubChem fine-tuned model is one reference run, repeated once per IC group.
tani_PC_0_250_FT_list = [tani_PC_0_250_FT for _ in range(4)]

# Count perfect matches (Tanimoto = 1) for each condition
perfect_matches = [np.sum(np.array(data) == 1) for data, _ in tani_data]
perfect_matches_pc = [np.sum(np.array(data) == 1) for data in tani_PC_0_250_FT_list]

# Group-major order with the PubChem FT count first in each group:
# [PC FT, Before, After, MNS] for IC-10, then IC-30, ...
all_perfect_matches = []
for i in range(4):  # For each FT group (10, 30, 50, 100)
    all_perfect_matches.extend([perfect_matches_pc[i]] + perfect_matches[i*3:(i+1)*3])

# Colors for the bars (the first len(condition_labels) are used)
colors = [
    '#A1C8F3',  # light blue/periwinkle
    '#FFB381',  # salmon/peach
    '#8BE5A0',  # mint green
    '#FF9D9A',  # coral pink
    '#D1B9FE',  # lavender
    '#DEBA9A',  # beige/tan
    '#FCAEE3',  # pink
    '#CFCECE',  # light gray
    '#FEFDA2',  # pale yellow
    '#B8F1EF',  # light cyan/aqua
]

group_labels = ['MMST IC-10', 'MMST IC-30', 'MMST IC-50', 'MMST IC-100']
condition_labels = ['PubChem FT', 'Before IC', 'After IC', 'MNS: 3']
n_conditions = len(condition_labels)

fig, ax = plt.subplots(figsize=(17, 9))

bar_width = 0.2
group_spacing = 0.03  # gap between neighbouring bars inside one group
bar_step = bar_width + group_spacing

# positions[i][g] is the x position of condition i's bar in group g.
num_groups = len(group_labels)
indices = np.arange(num_groups)
positions = [indices + i * bar_step for i in range(n_conditions)]

for i in range(n_conditions):  # PubChem FT, Before, After, MNS
    ax.bar(positions[i], all_perfect_matches[i::n_conditions],
           width=bar_width, color=colors[i], edgecolor='black',
           label=condition_labels[i])

ax.set_ylabel('Number of Perfect Matches (Tanimoto = 1)', fontsize=fontsize)
ax.set_title('Tanimoto Matches Across Different Conditions: PubChem 0-250 Da', fontsize=fontsize)
ax.set_xlabel('Number of Molecule Analogues for Improvement Cycle', fontsize=fontsize)

# Centre each group label under the middle of its bars (first bar at offset 0,
# last at (n_conditions - 1) * bar_step).
group_centers = indices + (n_conditions - 1) / 2 * bar_step
ax.set_xticks(group_centers)
ax.set_xticklabels(group_labels, fontsize=fontsize)

# Value labels on top of each bar
for i, v in enumerate(all_perfect_matches):
    ax.text(positions[i % n_conditions][i // n_conditions], v, str(v), ha='center', va='bottom', fontsize=fontsize)

ax.legend(fontsize=fontsize, loc='upper left')
ax.tick_params(axis='y', labelsize=fontsize)

plt.tight_layout()
save_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/6.2.2_PC_Bar_Chart_Greedy_MNS_Tanimoto_0_250_v4.png'  # Update this path as needed

# Save before plt.show(): show() releases the figure, and a later savefig would write an empty image.
plt.savefig(save_path, format='png', dpi=300)  # Save as PNG with 300 dpi

plt.show()

In [None]:
# Averaged correct-sample probability before each improvement cycle (PubChem 0-250 Da) ...
corr_sp_before_all_100_10 = imp_cyc_all_100_10_before["corr_sampleing_prob_bl_ZINC"]
corr_sp_before_all_100_30 = imp_cyc_all_100_30_before["corr_sampleing_prob_bl_ZINC"]
corr_sp_before_all_100_50 = imp_cyc_all_100_50_before["corr_sampleing_prob_bl_ZINC"]
corr_sp_before_all_100_100 = imp_cyc_all_100_100_before["corr_sampleing_prob_bl_ZINC"]

# ... after it ...
corr_sp_after_all_100_10 = imp_cyc_all_100_10_after["corr_sampleing_prob_bl_ZINC"]
corr_sp_after_all_100_30 = imp_cyc_all_100_30_after["corr_sampleing_prob_bl_ZINC"]
corr_sp_after_all_100_50 = imp_cyc_all_100_50_after["corr_sampleing_prob_bl_ZINC"]
corr_sp_after_all_100_100 = imp_cyc_all_100_100_after["corr_sampleing_prob_bl_ZINC"]

# ... and for the PubChem fine-tuned reference model.
corr_PC_0_250_FT = PC_0_250_FT["corr_sampleing_prob_bl_ZINC"]


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Defined here so the cell does not depend on `fontsize` leaking from another cell.
fontsize = 26

# Averaged correct-sample probability: base model, PubChem FT model, then each IC model.
values = [
    corr_sp_before_all_100_10,  # Before (same for all)
    corr_PC_0_250_FT,           # PC FT
    corr_sp_after_all_100_10,   # After FT10
    corr_sp_after_all_100_30,   # After FT30
    corr_sp_after_all_100_50,   # After FT50
    corr_sp_after_all_100_100   # After FT100
]

labels = [
    'MMST Model',
    'PC FT',
    'IC-10',
    'IC-30',
    'IC-50',
    'IC-100'
]

# Pastel palette; only the first len(values) colours are used.
colors = [
    '#A1C8F3',  # light blue/periwinkle
    '#B8F1EF',  # light cyan/aqua
    '#FFB381',  # salmon/peach
    '#8BE5A0',  # mint green
    '#FF9D9A',  # coral pink
    '#D1B9FE',  # lavender
    '#DEBA9A',  # beige/tan
    '#FCAEE3',  # pink
    '#CFCECE',  # light gray
    '#FEFDA2',  # pale yellow
]

fig, ax = plt.subplots(figsize=(17, 9))
bars = ax.bar(labels, values, color=colors[:len(values)], edgecolor='black')

# Write each value at half the bar's height.
for bar, value in zip(bars, values):
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() / 2, f'{value:.2f}', ha='center', va='center', fontsize=fontsize, color='black')

ax.set_title('Averaged Correct Sample Probability: PubChem 0-250 Da', fontsize=fontsize)
ax.set_ylabel('Averaged Correct Sample Probability', fontsize=fontsize)
ax.set_xlabel('Number of Molecule Analogues for Improvement Cycle', fontsize=fontsize)

# Resize the existing categorical tick labels in place. Calling set_xticklabels()
# without set_xticks() triggers matplotlib's FixedFormatter warning.
ax.tick_params(axis='x', labelsize=fontsize, labelrotation=0)
ax.tick_params(axis='y', labelsize=21)

output_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/6.2.2_PC_correct_sample_prob_comparison_0_250_v3.png"
fig.savefig(output_path, bbox_inches='tight')

plt.show()


##### Load PubChem 250-350 v2

In [None]:
# Improvement-cycle (IC) results for the PubChem 250-350 Da experiment: for each
# IC size N a "before" pickle, a greedy "_after" pickle and a multinomial-
# sampling "_after_MNS_3" pickle.
import pickle

pc_250_350_ic_dir = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240520_Improvment_Cycle_PC_250_350/'


def _load_pc_250_350_ic(file_name, directory=pc_250_350_ic_dir):
    """Return the object unpickled from ``directory + file_name``."""
    with open(directory + file_name, 'rb') as handle:
        return pickle.load(handle)


# IC-10
imp_cyc_all_100_10_before = _load_pc_250_350_ic('6.2_imp_cyc_all_100_10.pkl')
imp_cyc_all_100_10_after = _load_pc_250_350_ic('6.2_imp_cyc_all_100_10_after.pkl')
imp_cyc_all_100_10_after_mns = _load_pc_250_350_ic('6.2_imp_cyc_all_100_10_after_MNS_3.pkl')

# IC-30
imp_cyc_all_100_30_before = _load_pc_250_350_ic('6.2_imp_cyc_all_100_30.pkl')
imp_cyc_all_100_30_after = _load_pc_250_350_ic('6.2_imp_cyc_all_100_30_after.pkl')
imp_cyc_all_100_30_after_mns = _load_pc_250_350_ic('6.2_imp_cyc_all_100_30_after_MNS_3.pkl')

# IC-50
imp_cyc_all_100_50_before = _load_pc_250_350_ic('6.2_imp_cyc_all_100_50.pkl')
imp_cyc_all_100_50_after = _load_pc_250_350_ic('6.2_imp_cyc_all_100_50_after.pkl')
imp_cyc_all_100_50_after_mns = _load_pc_250_350_ic('6.2_imp_cyc_all_100_50_after_MNS_3.pkl')

# IC-100
imp_cyc_all_100_100_before = _load_pc_250_350_ic('6.2_imp_cyc_all_100_100.pkl')
imp_cyc_all_100_100_after = _load_pc_250_350_ic('6.2_imp_cyc_all_100_100_after.pkl')
imp_cyc_all_100_100_after_mns = _load_pc_250_350_ic('6.2_imp_cyc_all_100_100_after_MNS_3.pkl')

In [None]:
    
# Evaluation results of the model fine-tuned on PubChem 250-350 Da; its greedy
# Tanimoto similarities are the "PubChem FT" reference in the next cells.
file_path = (
    '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/'
    'precomputed_raw_data/20240520_Improvment_Cycle_PC_250_350/6.2_PC_FT_Model_250_350.pkl'
)
with open(file_path, 'rb') as pc_ft_file:
    PC_250_350_FT = pickle.load(pc_ft_file)

In [None]:
# Greedy-decoding Tanimoto similarities before each improvement cycle (PubChem 250-350 Da) ...
tani_before_all_100_10 = imp_cyc_all_100_10_before["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_before_all_100_30 = imp_cyc_all_100_30_before["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_before_all_100_50 = imp_cyc_all_100_50_before["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_before_all_100_100 = imp_cyc_all_100_100_before["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]

# ... after it ...
tani_after_all_100_10 = imp_cyc_all_100_10_after["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_after_all_100_30 = imp_cyc_all_100_30_after["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_after_all_100_50 = imp_cyc_all_100_50_after["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
tani_after_all_100_100 = imp_cyc_all_100_100_after["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]

# ... with multinomial sampling after it (via mns_tani_extraction, defined elsewhere) ...
tani_after_all_100_10_mns = mns_tani_extraction(imp_cyc_all_100_10_after_mns["results_dict_bl_ZINC"])
tani_after_all_100_30_mns = mns_tani_extraction(imp_cyc_all_100_30_after_mns["results_dict_bl_ZINC"])
tani_after_all_100_50_mns = mns_tani_extraction(imp_cyc_all_100_50_after_mns["results_dict_bl_ZINC"])
tani_after_all_100_100_mns = mns_tani_extraction(imp_cyc_all_100_100_after_mns["results_dict_bl_ZINC"])

# ... and for the PubChem fine-tuned reference model.
tani_PC_250_350_FT = PC_250_350_FT["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]


In [None]:
import numpy as np
import matplotlib.pyplot as plt

fontsize = 26

# (data, label) per improvement-cycle (IC) condition, ordered group-major:
# for each analogue count (10, 30, 50, 100) -> Before, After, MNS.
tani_data = [
    (tani_before_all_100_10, 'FT10 Before'), (tani_after_all_100_10, 'FT10 After'), (tani_after_all_100_10_mns, 'FT10 MNS'),
    (tani_before_all_100_30, 'FT30 Before'), (tani_after_all_100_30, 'FT30 After'), (tani_after_all_100_30_mns, 'FT30 MNS'),
    (tani_before_all_100_50, 'FT50 Before'), (tani_after_all_100_50, 'FT50 After'), (tani_after_all_100_50_mns, 'FT50 MNS'),
    (tani_before_all_100_100, 'FT100 Before'), (tani_after_all_100_100, 'FT100 After'), (tani_after_all_100_100_mns, 'FT100 MNS')
]

# Group (x-axis) and condition (legend) labels
group_labels = ['MMST IC-10', 'MMST IC-30', 'MMST IC-50', 'MMST IC-100']
condition_labels = ['PubChem FT', 'Before IC', 'After IC', 'MNS: 3']
num_groups = len(group_labels)
n_conditions = len(condition_labels)
per_group = n_conditions - 1  # conditions that come from tani_data (all but PubChem FT)
assert len(tani_data) == num_groups * per_group, "tani_data and group_labels are out of sync"

# Bar colours, one per condition (only the first n_conditions entries are used)
colors = [
    '#A1C8F3',  # light blue/periwinkle
    '#FFB381',  # salmon/peach
    '#8BE5A0',  # mint green
    '#FF9D9A',  # coral pink
    '#D1B9FE',  # lavender
    '#DEBA9A',  # beige/tan
    '#FCAEE3',  # pink
    '#CFCECE',  # light gray
    '#FEFDA2',  # pale yellow
    '#B8F1EF',  # light cyan/aqua
]

# The PubChem fine-tuned model is one run, shown once in every group
tani_PC_250_350_FT_list = [tani_PC_250_350_FT for _ in range(num_groups)]

# Count perfect matches (Tanimoto == 1) for each condition
perfect_matches = [np.sum(np.array(data) == 1) for data, _ in tani_data]
perfect_matches_pc = [np.sum(np.array(data) == 1) for data in tani_PC_250_350_FT_list]

# Flatten group-major: [PubChem FT, Before, After, MNS] for each group
all_perfect_matches = []
for g in range(num_groups):
    all_perfect_matches.extend([perfect_matches_pc[g]] + perfect_matches[g * per_group:(g + 1) * per_group])

fig, ax = plt.subplots(figsize=(17, 9))

bar_width = 0.2
group_spacing = 0.03  # gap between neighbouring bars *inside* a group

# x-position of each condition's bar in every group (bars are centre-aligned)
indices = np.arange(num_groups)
positions = [indices + c * (bar_width + group_spacing) for c in range(n_conditions)]

# One bar series per condition: PubChem FT, Before, After, MNS
for c in range(n_conditions):
    ax.bar(positions[c], all_perfect_matches[c::n_conditions],
           width=bar_width, color=colors[c], edgecolor='black',
           label=condition_labels[c])

ax.set_ylabel('Number of Perfect Matches (Tanimoto = 1)', fontsize=fontsize)
ax.set_xlabel('Number of Molecule Analogues for Improvement Cycle', fontsize=fontsize)
ax.set_title('Tanimoto Matches Across Different Conditions: PubChem 250-350 Da', fontsize=fontsize)

# Put each tick at the true group centre: the mean of its bar centres,
# i.e. indices + 1.5 * (bar_width + group_spacing). The previous formula
# used 0.5 * group_spacing and left every tick 0.03 off-centre.
group_centers = np.mean(positions, axis=0)
ax.set_xticks(group_centers)
ax.set_xticklabels(group_labels, fontsize=fontsize)

ax.tick_params(axis='y', labelsize=fontsize)

# Value labels on top of each bar; entry i belongs to condition i % n_conditions of group i // n_conditions
for i, v in enumerate(all_perfect_matches):
    ax.text(positions[i % n_conditions][i // n_conditions], v, str(v), ha='center', va='bottom', fontsize=fontsize)

ax.legend(fontsize=22, loc='upper left')

plt.tight_layout()

output_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/6.2.2_PC_Bar_Chart_Greedy_MNS_Tanimoto_250_350_v3.png"
fig.savefig(output_path, dpi=300, bbox_inches='tight')

plt.show()

In [None]:
# Averaged correct-sample probability of each PubChem 250-350 Da run.
# The key is spelled "sampleing" in the stored result pickles, so it must stay as-is.
_corr_key = "corr_sampleing_prob_bl_ZINC"

corr_sp_before_all_100_10 = imp_cyc_all_100_10_before[_corr_key]
corr_sp_after_all_100_10 = imp_cyc_all_100_10_after[_corr_key]
corr_sp_before_all_100_30 = imp_cyc_all_100_30_before[_corr_key]
corr_sp_after_all_100_30 = imp_cyc_all_100_30_after[_corr_key]
corr_sp_before_all_100_50 = imp_cyc_all_100_50_before[_corr_key]
corr_sp_after_all_100_50 = imp_cyc_all_100_50_after[_corr_key]
corr_sp_before_all_100_100 = imp_cyc_all_100_100_before[_corr_key]
corr_sp_after_all_100_100 = imp_cyc_all_100_100_after[_corr_key]
corr_PC_250_350_FT = PC_250_350_FT[_corr_key]


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# One averaged correct-sample probability per model. The "Before" value is
# shown once because it is the same for all improvement-cycle (IC) runs.
values = [
    corr_sp_before_all_100_10,  # Before (same for all)
    corr_PC_250_350_FT,         # PC FT
    corr_sp_after_all_100_10,   # After FT10
    corr_sp_after_all_100_30,   # After FT30
    corr_sp_after_all_100_50,   # After FT50
    corr_sp_after_all_100_100   # After FT100
]

# Bar (category) labels, in the same order as `values`
labels = [
    'MMST Model',
    'PC FT',
    'IC-10',
    'IC-30',
    'IC-50',
    'IC-100'
]

# Bar colours (extra entries beyond len(values) are ignored)
colors = [
    '#A1C8F3',  # light blue/periwinkle
    '#B8F1EF',  # light cyan/aqua
    '#FFB381',  # salmon/peach
    '#8BE5A0',  # mint green
    '#FF9D9A',  # coral pink
    '#D1B9FE',  # lavender
    '#DEBA9A',  # beige/tan
    '#FCAEE3',  # pink
    '#CFCECE',  # light gray
    '#FEFDA2',  # pale yellow
]

fig, ax = plt.subplots(figsize=(17, 9))
bars = ax.bar(labels, values, color=colors, edgecolor='black')

# Write each value at mid-height inside its bar
for bar, value in zip(bars, values):
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() / 2, f'{value:.2f}', ha='center', va='center', fontsize=fontsize, color='black')

ax.set_title('Averaged Correct Sample Probability: PubChem 250-350 Da', fontsize=fontsize)
ax.set_ylabel('Averaged Correct Sample Probability', fontsize=fontsize)
ax.set_xlabel('Number of Molecule Analogues for Improvement Cycle', fontsize=fontsize)

# The categorical axis already carries `labels`; only restyle the tick labels.
# Calling set_xticklabels() without fixed ticks raises matplotlib's FixedFormatter warning.
ax.tick_params(axis='x', labelsize=fontsize, labelrotation=0)
ax.tick_params(axis='y', labelsize=fontsize)

output_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/6.2.2_PC_correct_sample_prob_comparison_250_350_v3.png"
fig.savefig(output_path, bbox_inches='tight')

plt.show()


##### Load PubChem 350-500 Da improvement-cycle results

In [None]:
# Load the PubChem 350-500 Da improvement-cycle (IC) results. For each analogue
# count (10, 30, 50, 100) three pickles are read: before the IC, after the IC,
# and after the IC with MNS: 3.
import pickle

_IC_DIR_350_500 = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240517_Improvment_Cycle_PC_350_500'


def _load_ic_350_500(file_name):
    """Unpickle one result file from the 350-500 Da IC directory.

    pickle.load can execute arbitrary code, so only point this at trusted
    result files.
    """
    path = f'{_IC_DIR_350_500}/{file_name}'
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# 100_10
imp_cyc_all_100_10_before = _load_ic_350_500('6.2_imp_cyc_all_100_10.pkl')
imp_cyc_all_100_10_after = _load_ic_350_500('6.2_imp_cyc_all_100_10_after.pkl')
# Must be the 100_10 MNS file. This line used to read
# 6.2_imp_cyc_all_100_30_after_MNS_3.pkl, so the "FT10 MNS" bar duplicated the FT30 run.
imp_cyc_all_100_10_after_mns = _load_ic_350_500('6.2_imp_cyc_all_100_10_after_MNS_3.pkl')

# 100_30
imp_cyc_all_100_30_before = _load_ic_350_500('6.2_imp_cyc_all_100_30.pkl')
imp_cyc_all_100_30_after = _load_ic_350_500('6.2_imp_cyc_all_100_30_after.pkl')
imp_cyc_all_100_30_after_mns = _load_ic_350_500('6.2_imp_cyc_all_100_30_after_MNS_3.pkl')

# 100_50
imp_cyc_all_100_50_before = _load_ic_350_500('6.2_imp_cyc_all_100_50.pkl')
imp_cyc_all_100_50_after = _load_ic_350_500('6.2_imp_cyc_all_100_50_after.pkl')
imp_cyc_all_100_50_after_mns = _load_ic_350_500('6.2_imp_cyc_all_100_50_after_MNS_3.pkl')

# 100_100
imp_cyc_all_100_100_before = _load_ic_350_500('6.2_imp_cyc_all_100_100.pkl')
imp_cyc_all_100_100_after = _load_ic_350_500('6.2_imp_cyc_all_100_100_after.pkl')
imp_cyc_all_100_100_after_mns = _load_ic_350_500('6.2_imp_cyc_all_100_100_after_MNS_3.pkl')

In [None]:
    
# Results of the model fine-tuned on PubChem 350-500 Da ("PubChem FT" / "PC FT"
# baseline in the 350-500 comparison plots below).
# NOTE(review): this file is read from the *250-350* directory
# (20240520_Improvment_Cycle_PC_250_350). Confirm that is where the 350-500
# fine-tuned results were saved.
_pc_ft_src_dir = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240520_Improvment_Cycle_PC_250_350'
file_path = _pc_ft_src_dir + '/6.2_PC_FT_Model_350_500.pkl'
with open(file_path, 'rb') as pkl_handle:
    PC_350_500_FT = pickle.load(pkl_handle)

In [None]:
# 100_10 second fine-tuning step: results of a second improvement cycle, stored
# in 20240522_PC_Imp_Cycle_2_from_350_500.
_IC2_DIR_350_500 = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240522_PC_Imp_Cycle_2_from_350_500'

# Greedy "after" results (the file name indicates the epoch-41 checkpoint).
# NOTE: keep the two paths distinct. An earlier version assigned this path and
# then immediately overwrote it with the MNS_3 path, so both variables held the
# same MNS_3 results.
file_path = _IC2_DIR_350_500 + '/6.2_imp_cyc_all_100_10_after_epoch41.pkl'
with open(file_path, 'rb') as file:
    imp_cyc_all_100_10_after_2 = pickle.load(file)

# MNS: 3 results after the second improvement cycle
file_path = _IC2_DIR_350_500 + '/6.2_imp_cyc_all_100_10_after_MNS_3.pkl'
with open(file_path, 'rb') as file:
    imp_cyc_all_100_10_after_mns_2 = pickle.load(file)


In [None]:
# Tanimoto similarities for the PubChem 350-500 Da improvement-cycle (IC) runs,
# including the second-IC (suffix _2) results for 10 analogues.
# Greedy runs: results_dict_ZINC_greedy_bl["tanimoto_sim"]; MNS: 3 runs are
# passed through mns_tani_extraction (defined elsewhere in this notebook).
_greedy_key = "results_dict_ZINC_greedy_bl"
_mns_key = "results_dict_bl_ZINC"

# 10 analogues (first IC)
tani_before_all_100_10 = imp_cyc_all_100_10_before[_greedy_key]["tanimoto_sim"]
tani_after_all_100_10 = imp_cyc_all_100_10_after[_greedy_key]["tanimoto_sim"]
tani_after_all_100_10_mns = mns_tani_extraction(imp_cyc_all_100_10_after_mns[_mns_key])

# 10 analogues (second IC)
tani_after_all_100_10_2 = imp_cyc_all_100_10_after_2[_greedy_key]["tanimoto_sim"]
tani_after_all_100_10_mns_2 = mns_tani_extraction(imp_cyc_all_100_10_after_mns_2[_mns_key])

# 30 analogues
tani_before_all_100_30 = imp_cyc_all_100_30_before[_greedy_key]["tanimoto_sim"]
tani_after_all_100_30 = imp_cyc_all_100_30_after[_greedy_key]["tanimoto_sim"]
tani_after_all_100_30_mns = mns_tani_extraction(imp_cyc_all_100_30_after_mns[_mns_key])
# 50 analogues
tani_before_all_100_50 = imp_cyc_all_100_50_before[_greedy_key]["tanimoto_sim"]
tani_after_all_100_50 = imp_cyc_all_100_50_after[_greedy_key]["tanimoto_sim"]
tani_after_all_100_50_mns = mns_tani_extraction(imp_cyc_all_100_50_after_mns[_mns_key])
# 100 analogues
tani_before_all_100_100 = imp_cyc_all_100_100_before[_greedy_key]["tanimoto_sim"]
tani_after_all_100_100 = imp_cyc_all_100_100_after[_greedy_key]["tanimoto_sim"]
tani_after_all_100_100_mns = mns_tani_extraction(imp_cyc_all_100_100_after_mns[_mns_key])
# PubChem fine-tuned baseline
tani_PC_350_500_FT = PC_350_500_FT[_greedy_key]["tanimoto_sim"]


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# (data, label) per improvement-cycle (IC) condition, ordered group-major:
# for each analogue count (10, 30, 50, 100) -> Before, After, MNS.
tani_data = [
    (tani_before_all_100_10, 'FT10 Before'), (tani_after_all_100_10, 'FT10 After'), (tani_after_all_100_10_mns, 'FT10 MNS'),
    #(tani_before_all_100_10, 'FT10 2x Before'), (tani_after_all_100_10_2, 'FT10 2x After'), (tani_after_all_100_10_mns_2, 'FT10 2x MNS'),
    (tani_before_all_100_30, 'FT30 Before'), (tani_after_all_100_30, 'FT30 After'), (tani_after_all_100_30_mns, 'FT30 MNS'),
    (tani_before_all_100_50, 'FT50 Before'), (tani_after_all_100_50, 'FT50 After'), (tani_after_all_100_50_mns, 'FT50 MNS'),
    (tani_before_all_100_100, 'FT100 Before'), (tani_after_all_100_100, 'FT100 After'), (tani_after_all_100_100_mns, 'FT100 MNS')
]

# Group (x-axis) and condition (legend) labels.
# To plot the FT10 2x row, uncomment it above AND add 'FT10 2x' here.
group_labels = ['MMST IC-10',  'MMST IC-30', 'MMST IC-50', 'MMST IC-100'] #'FT10 2x',
condition_labels = ['PubChem FT', 'Before IC', 'After IC', 'MNS: 3']
num_groups = len(group_labels)
n_conditions = len(condition_labels)
per_group = n_conditions - 1  # conditions that come from tani_data (all but PubChem FT)
assert len(tani_data) == num_groups * per_group, "tani_data and group_labels are out of sync"

# Bar colours, one per condition (only the first n_conditions entries are used)
colors = [
    '#A1C8F3',  # light blue/periwinkle
    '#FFB381',  # salmon/peach
    '#8BE5A0',  # mint green
    '#FF9D9A',  # coral pink
    '#D1B9FE',  # lavender
    '#DEBA9A',  # beige/tan
    '#FCAEE3',  # pink
    '#CFCECE',  # light gray
    '#FEFDA2',  # pale yellow
]

# The PubChem fine-tuned model is one run, shown once in every group
# (sized from group_labels; it was previously hard-coded to 5 entries for 4 groups).
tani_PC_350_500_FT_list = [tani_PC_350_500_FT for _ in range(num_groups)]

# Count perfect matches (Tanimoto == 1) for each condition
perfect_matches = [np.sum(np.array(data) == 1) for data, _ in tani_data]
perfect_matches_pc = [np.sum(np.array(data) == 1) for data in tani_PC_350_500_FT_list]

# Flatten group-major: [PubChem FT, Before, After, MNS] for each group
all_perfect_matches = []
for g in range(num_groups):
    all_perfect_matches.extend([perfect_matches_pc[g]] + perfect_matches[g * per_group:(g + 1) * per_group])

fig, ax = plt.subplots(figsize=(17, 9))

bar_width = 0.2
group_spacing = 0.03  # gap between neighbouring bars *inside* a group

# x-position of each condition's bar in every group (bars are centre-aligned)
indices = np.arange(num_groups)
positions = [indices + c * (bar_width + group_spacing) for c in range(n_conditions)]

# One bar series per condition: PubChem FT, Before, After, MNS
for c in range(n_conditions):
    ax.bar(positions[c], all_perfect_matches[c::n_conditions],
           width=bar_width, color=colors[c], edgecolor='black',
           label=condition_labels[c])

ax.set_ylabel('Number of Perfect Matches (Tanimoto = 1)', fontsize=fontsize)
ax.set_xlabel('Number of Molecule Analogues for Improvement Cycle', fontsize=fontsize)
ax.set_title('Tanimoto Matches Across Different Conditions: PubChem 350-500 Da', fontsize=fontsize)

# Put each tick at the true group centre (mean of its bar centres). The previous
# formula used 0.5 * group_spacing instead of 1.5 * group_spacing, leaving ticks off-centre.
group_centers = np.mean(positions, axis=0)
ax.set_xticks(group_centers)
ax.set_xticklabels(group_labels, fontsize=fontsize)

ax.tick_params(axis='y', labelsize=fontsize)

# Value labels on top of each bar; entry i belongs to condition i % n_conditions of group i // n_conditions
for i, v in enumerate(all_perfect_matches):
    ax.text(positions[i % n_conditions][i // n_conditions], v, str(v), ha='center', va='bottom', fontsize=fontsize)

ax.legend(fontsize=fontsize, loc='upper left')

plt.tight_layout()

output_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/6.2.2_PC_Bar_Chart_Greedy_MNS_Tanimoto_350_500_v3.png"
fig.savefig(output_path, dpi=300, bbox_inches='tight')

plt.show()

In [None]:
# Disabled scratch cell. The triple-quoted string is not executed as code
# (Jupyter only echoes it as this cell's output). Together with the commented
# lines below, it records result keys that were read from
# imp_cyc_all_100_10_before and PC_350_500_FT.
# Extract the variables
"""
corr_sampleing_prob_bl_ZINC = imp_cyc_all_100_10_before['corr_sampleing_prob_bl_ZINC']
results_dict_ZINC_greedy_bl = imp_cyc_all_100_10_before['results_dict_ZINC_greedy_bl']
results_dict_greedy_bl_ZINC = imp_cyc_all_100_10_before['results_dict_greedy_bl_ZINC']
tani_PC_350_500_FT = PC_350_500_FT["results_dict_ZINC_greedy_bl"]["tanimoto_sim"]
"""

#avg_tani_bl_ZINC = imp_cyc_all_100_10_before['avg_tani_bl_ZINC']
#failed_bl_ZINC = imp_cyc_all_100_10_before['failed_bl_ZINC']
#avg_tani_greedy_bl_ZINC = imp_cyc_all_100_10_before['avg_tani_greedy_bl_ZINC']
#total_results_bl_ZINC = imp_cyc_all_100_10_before['total_results_bl_ZINC']
#filtered_results = imp_cyc_all_100_10_before['filtered_results']

In [None]:
# Averaged correct-sample probability of each PubChem 350-500 Da run, including
# the second-IC (suffix _2) 10-analogue run.
# The key is spelled "sampleing" in the stored result pickles, so it must stay as-is.
_corr_key = "corr_sampleing_prob_bl_ZINC"

corr_sp_before_all_100_10 = imp_cyc_all_100_10_before[_corr_key]
corr_sp_after_all_100_10 = imp_cyc_all_100_10_after[_corr_key]
corr_sp_after_all_100_10_2 = imp_cyc_all_100_10_after_2[_corr_key]
corr_sp_before_all_100_30 = imp_cyc_all_100_30_before[_corr_key]
corr_sp_after_all_100_30 = imp_cyc_all_100_30_after[_corr_key]
corr_sp_before_all_100_50 = imp_cyc_all_100_50_before[_corr_key]
corr_sp_after_all_100_50 = imp_cyc_all_100_50_after[_corr_key]
corr_sp_before_all_100_100 = imp_cyc_all_100_100_before[_corr_key]
corr_sp_after_all_100_100 = imp_cyc_all_100_100_after[_corr_key]
corr_PC_350_500_FT = PC_350_500_FT[_corr_key]


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# One averaged correct-sample probability per model. The "Before" value is
# shown once because it is the same for all improvement-cycle (IC) runs.
values = [
    corr_sp_before_all_100_10,  # Before (same for all)
    corr_PC_350_500_FT,         # PC FT
    corr_sp_after_all_100_10,   # After FT10
  #  corr_sp_after_all_100_10_2,  # After 2x FT10
    corr_sp_after_all_100_30,   # After FT30
    corr_sp_after_all_100_50,   # After FT50
    corr_sp_after_all_100_100   # After FT100
]

# Bar (category) labels, in the same order as `values`
labels = [
    'MMST Model',
    'PC FT',
    'IC-10',
    #'MMTi Imp-Cycle: 2x FT10',
    'IC-30',
    'IC-50',
    'IC-100'
]

# Bar colours (extra entries beyond len(values) are ignored)
colors = [
    '#A1C8F3',  # light blue/periwinkle
    '#B8F1EF',  # light cyan/aqua
    '#FFB381',  # salmon/peach
    '#8BE5A0',  # mint green
    '#FF9D9A',  # coral pink
    '#D1B9FE',  # lavender
    '#DEBA9A',  # beige/tan
    '#FCAEE3',  # pink
    '#CFCECE',  # light gray
    '#FEFDA2',  # pale yellow
]

fig, ax = plt.subplots(figsize=(17, 9))
bars = ax.bar(labels, values, color=colors, edgecolor='black')

# Write each value at mid-height inside its bar
for bar, value in zip(bars, values):
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() / 2, f'{value:.2f}', ha='center', va='center', fontsize=fontsize, color='black')

ax.set_title('Averaged Correct Sample Probability: PubChem 350-500 Da', fontsize=fontsize)
ax.set_ylabel('Averaged Correct Sample Probability', fontsize=21)  # kept smaller than `fontsize` as in the original layout
ax.set_xlabel('Number of Molecule Analogues for Improvement Cycle', fontsize=fontsize)

# The categorical axis already carries `labels`; only restyle the tick labels.
# Calling set_xticklabels() without fixed ticks raises matplotlib's FixedFormatter warning.
ax.tick_params(axis='x', labelsize=fontsize, labelrotation=0)
ax.tick_params(axis='y', labelsize=fontsize)
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x:.2f}'))

output_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/6.2.2_PC_correct_sample_prob_comparison_350_500_v3.png"
fig.savefig(output_path, bbox_inches='tight')

plt.show()


#### PubChem: second round of improvement-cycle training

In [None]:
# PubChem 0-250 Da, 10 analogues: model before the improvement cycle (IC),
# after one IC, and after a second IC.
import pickle

_dir_ic_0_250 = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240520_Improvment_Cycle_PC_0_250'
_dir_ic2_src_for_0_250 = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240523_PC_Imp_Cycle_2_from_250_350'

file_path = _dir_ic_0_250 + '/6.2_imp_cyc_all_100_10.pkl'
with open(file_path, 'rb') as fh:
    imp_cyc_all_100_10_before_0_250 = pickle.load(fh)

file_path = _dir_ic_0_250 + '/6.2_imp_cyc_all_100_10_after.pkl'
with open(file_path, 'rb') as fh:
    imp_cyc_all_100_10_after_0_250 = pickle.load(fh)

# maybe need to take the first one
# NOTE(review): this is the exact file that the 250-350 cell loads into
# imp_cyc_all_100_10_after_2_250_350, so the 0-250 and 250-350 "After 2x IC"
# bars come from identical data. Confirm the intended 0-250 second-round file.
file_path = _dir_ic2_src_for_0_250 + '/6.2_imp_cyc_all_100_10_after.pkl'
with open(file_path, 'rb') as fh:
    imp_cyc_all_100_10_after_2_0_250 = pickle.load(fh)


In [None]:
# PubChem 250-350 Da, 10 analogues: model before the improvement cycle (IC),
# after one IC, and after a second IC.
import pickle

_dir_ic_250_350 = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240520_Improvment_Cycle_PC_250_350'
_dir_ic2_250_350 = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240523_PC_Imp_Cycle_2_from_250_350'

file_path = _dir_ic_250_350 + '/6.2_imp_cyc_all_100_10.pkl'
with open(file_path, 'rb') as fh:
    imp_cyc_all_100_10_before_250_350 = pickle.load(fh)

file_path = _dir_ic_250_350 + '/6.2_imp_cyc_all_100_10_after.pkl'
with open(file_path, 'rb') as fh:
    imp_cyc_all_100_10_after_250_350 = pickle.load(fh)

file_path = _dir_ic2_250_350 + '/6.2_imp_cyc_all_100_10_after.pkl'
with open(file_path, 'rb') as fh:
    imp_cyc_all_100_10_after_2_250_350 = pickle.load(fh)

In [None]:
# PubChem 350-500 Da, 10 analogues: model before the improvement cycle (IC),
# after one IC, and after a second IC.
import pickle

_dir_ic_350_500 = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240517_Improvment_Cycle_PC_350_500'
_dir_ic2_350_500 = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240522_PC_Imp_Cycle_2_from_350_500'

file_path = _dir_ic_350_500 + '/6.2_imp_cyc_all_100_10.pkl'
with open(file_path, 'rb') as fh:
    imp_cyc_all_100_10_before_350_500 = pickle.load(fh)

file_path = _dir_ic_350_500 + '/6.2_imp_cyc_all_100_10_after.pkl'
with open(file_path, 'rb') as fh:
    imp_cyc_all_100_10_after_350_500 = pickle.load(fh)

file_path = _dir_ic2_350_500 + '/6.2_imp_cyc_all_100_10_after_v3_second_round.pkl'
with open(file_path, 'rb') as fh:
    imp_cyc_all_100_10_after_2_350_500 = pickle.load(fh)

In [None]:

# Greedy Tanimoto similarities (10 analogues) for each PubChem weight range:
# before the IC, after one IC, and after a second IC.
_greedy_key = "results_dict_ZINC_greedy_bl"

# 0-250 Da
tani_before_0_250 = imp_cyc_all_100_10_before_0_250[_greedy_key]["tanimoto_sim"]
tani_after_0_250 = imp_cyc_all_100_10_after_0_250[_greedy_key]["tanimoto_sim"]
tani_after_2_0_250 = imp_cyc_all_100_10_after_2_0_250[_greedy_key]["tanimoto_sim"]
# 250-350 Da
tani_before_250_350 = imp_cyc_all_100_10_before_250_350[_greedy_key]["tanimoto_sim"]
tani_after_250_350 = imp_cyc_all_100_10_after_250_350[_greedy_key]["tanimoto_sim"]
tani_after_2_250_350 = imp_cyc_all_100_10_after_2_250_350[_greedy_key]["tanimoto_sim"]
# 350-500 Da
tani_before_350_500 = imp_cyc_all_100_10_before_350_500[_greedy_key]["tanimoto_sim"]
tani_after_350_500 = imp_cyc_all_100_10_after_350_500[_greedy_key]["tanimoto_sim"]
tani_after_2_350_500 = imp_cyc_all_100_10_after_2_350_500[_greedy_key]["tanimoto_sim"]


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# (data, label) per condition, ordered group-major:
# for each weight range -> Before IC, After IC, After 2x IC.
tani_data = [
    (tani_before_0_250, 'FT0-250 Before'), (tani_after_0_250, 'FT0-250 After'), (tani_after_2_0_250, 'FT0-250 2nd Round'),
    (tani_before_250_350, 'FT250-350 Before'), (tani_after_250_350, 'FT250-350 After'), (tani_after_2_250_350, 'FT250-350 2nd Round'),
    (tani_before_350_500, 'FT350-500 Before'), (tani_after_350_500, 'FT350-500 After'), (tani_after_2_350_500, 'FT350-500 2nd Round')
]

# Count perfect matches (Tanimoto == 1) for each condition
perfect_matches = [np.sum(np.array(data) == 1) for data, _ in tani_data]

# Bar colours, one per condition (only the first n_conditions entries are used)
colors = [
    '#A1C8F3',  # light blue/periwinkle
    '#FFB381',  # salmon/peach
    '#8BE5A0',  # mint green
    '#FF9D9A',  # coral pink
    '#D1B9FE',  # lavender
    '#DEBA9A',  # beige/tan
    '#FCAEE3',  # pink
    '#CFCECE',  # light gray
    '#FEFDA2',  # pale yellow
    '#B8F1EF',  # light cyan/aqua
]

# Group (x-axis) and condition (legend) labels
group_labels = ['PubChem: 0-250', 'PubChem: 250-350', 'PubChem: 350-500']
condition_labels = ['Before IC', 'After IC', 'After 2x IC']
num_groups = len(group_labels)
n_conditions = len(condition_labels)
assert len(perfect_matches) == num_groups * n_conditions, "tani_data and labels are out of sync"

fig, ax = plt.subplots(figsize=(17, 9))

bar_width = 0.25
group_spacing = 0.05  # gap between neighbouring bars *inside* a group

# x-position of each condition's bar in every group (bars are centre-aligned)
indices = np.arange(num_groups)
positions = [indices + c * (bar_width + group_spacing) for c in range(n_conditions)]

# One bar series per condition: Before, After, 2nd Round
for c in range(n_conditions):
    ax.bar(positions[c], perfect_matches[c::n_conditions],
           width=bar_width, color=colors[c], edgecolor='black',
           label=condition_labels[c])

ax.set_ylabel('Number of Perfect Matches (Tanimoto = 1)', fontsize=fontsize)
ax.set_xlabel('Molecular Weight Ranges', fontsize=fontsize)
ax.set_title('Tanimoto Matches Across Different Conditions: PubChem Data', fontsize=fontsize)

# Put each tick at the true group centre: the mean of its bar centres,
# indices + (bar_width + group_spacing). The old `indices + bar_width`
# ignored group_spacing and left the ticks 0.05 off-centre.
group_centers = np.mean(positions, axis=0)
ax.set_xticks(group_centers)
ax.set_xticklabels(group_labels, fontsize=fontsize)

ax.tick_params(axis='y', labelsize=fontsize)

# Value labels on top of each bar; entry i belongs to condition i % n_conditions of group i // n_conditions
for i, v in enumerate(perfect_matches):
    ax.text(positions[i % n_conditions][i // n_conditions], v, str(v), ha='center', va='bottom', fontsize=fontsize)

ax.legend(fontsize=fontsize, loc='upper left')

plt.tight_layout()

output_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/6.2.2_PC_Bar_Chart_Greedy_second_round_v3.png"
fig.savefig(output_path, dpi=300, bbox_inches='tight')

plt.show()

In [None]:

# Correct-sample probabilities (10 analogues) for each PubChem weight range:
# before the IC, after one IC, and after a second IC.
# The key is spelled "sampleing" in the stored result pickles, so it must stay as-is.
_corr_key = "corr_sampleing_prob_bl_ZINC"

# 0-250 Da
corr_sp_before_0_250 = imp_cyc_all_100_10_before_0_250[_corr_key]
corr_sp_after_0_250 = imp_cyc_all_100_10_after_0_250[_corr_key]
corr_sp_after_2_0_250 = imp_cyc_all_100_10_after_2_0_250[_corr_key]
# 250-350 Da
corr_sp_before_250_350 = imp_cyc_all_100_10_before_250_350[_corr_key]
corr_sp_after_250_350 = imp_cyc_all_100_10_after_250_350[_corr_key]
corr_sp_after_2_250_350 = imp_cyc_all_100_10_after_2_250_350[_corr_key]
# 350-500 Da
corr_sp_before_350_500 = imp_cyc_all_100_10_before_350_500[_corr_key]
corr_sp_after_350_500 = imp_cyc_all_100_10_after_350_500[_corr_key]
corr_sp_after_2_350_500 = imp_cyc_all_100_10_after_2_350_500[_corr_key]


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Mean correct-sample probability, ordered group-major:
# for each weight range -> Before IC, After IC, After 2x IC.
values = [
    np.mean(corr_sp_before_0_250), np.mean(corr_sp_after_0_250), np.mean(corr_sp_after_2_0_250),
    np.mean(corr_sp_before_250_350), np.mean(corr_sp_after_250_350), np.mean(corr_sp_after_2_250_350),
    np.mean(corr_sp_before_350_500), np.mean(corr_sp_after_350_500), np.mean(corr_sp_after_2_350_500)
]

ic_condition_labels = ['Before IC', 'After IC', 'After 2x IC']
mw_range_labels = ['PubChem 0-250', 'PubChem 250-350', 'PubChem 350-500']
n_conditions = len(ic_condition_labels)
assert len(values) == n_conditions * len(mw_range_labels), "values and labels are out of sync"

# Bar colours, one per condition
colors = [
    '#A1C8F3',  # light blue/periwinkle
    '#FFB381',  # salmon/peach
    '#8BE5A0',  # mint green
    '#FF9D9A',  # coral pink
]

fig, ax = plt.subplots(figsize=(17, 9))

bar_width = 0.25
group_spacing = 0.05  # gap between neighbouring bars *inside* a group

# x-position of each condition's bar in every group (bars are centre-aligned)
indices = np.arange(len(mw_range_labels))
positions = [indices + c * (bar_width + group_spacing) for c in range(n_conditions)]

# One bar series per condition: Before, After, 2nd Round
for c in range(n_conditions):
    condition_values = values[c::n_conditions]
    bars = ax.bar(positions[c], condition_values,
                  width=bar_width, color=colors[c], edgecolor='black',
                  label=ic_condition_labels[c])

    # Write each value at mid-height inside its bar
    for bar, value in zip(bars, condition_values):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() / 2,
                f'{value:.2f}', ha='center', va='center', fontsize=fontsize, color='black')

ax.set_ylabel('Averaged Correct Sample Probability', fontsize=fontsize)
ax.set_title('Averaged Correct Sample Probability: PubChem Data', fontsize=fontsize)
ax.set_xlabel('Molecular Weight Ranges', fontsize=fontsize)

# Put each tick at the true group centre: the mean of its bar centres,
# indices + (bar_width + group_spacing). The old `indices + bar_width`
# ignored group_spacing and left the ticks 0.05 off-centre.
group_centers = np.mean(positions, axis=0)
ax.set_xticks(group_centers)
ax.set_xticklabels(mw_range_labels, fontsize=fontsize)

# y tick labels with two decimals, at the shared font size
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x:.2f}'))
ax.tick_params(axis='y', labelsize=fontsize)

ax.legend(fontsize=fontsize, loc='upper left')

plt.tight_layout()

output_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/6.2.2_PC_correct_sample_prob_comparison_second_ro_v3.png"
fig.savefig(output_path, dpi=300, bbox_inches='tight')

plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Mean correct-sample probability for every (weight range, IC stage) pair,
# drawn as one flat row of bars (same data as the grouped chart above).
values = [
    np.mean(corr_sp_before_0_250),  # Before (0-250)
    np.mean(corr_sp_after_0_250),   # After 0-250
    np.mean(corr_sp_after_2_0_250), # 2nd Round After 0-250
    np.mean(corr_sp_before_250_350),  # Before (250-350)
    np.mean(corr_sp_after_250_350),   # After 250-350
    np.mean(corr_sp_after_2_250_350), # 2nd Round After 250-350
    np.mean(corr_sp_before_350_500),  # Before (350-500)
    np.mean(corr_sp_after_350_500),   # After 350-500
    np.mean(corr_sp_after_2_350_500)  # 2nd Round After 350-500
]

# Bar (category) labels, in the same order as `values`
labels = [
    'Before IC 0-250',
    'After IC 0-250',
    'After 2x IC 0-250',
    'Before IC 250-350',
    'After IC 250-350',
    'After 2x IC 250-350',
    'Before IC 350-500',
    'After IC 350-500',
    'After 2x IC 350-500'
]

# Bar colours (extra entries beyond len(values) are ignored)
colors = [
    '#A1C8F3',  # light blue/periwinkle
    '#FFB381',  # salmon/peach
    '#8BE5A0',  # mint green
    '#FF9D9A',  # coral pink
    '#D1B9FE',  # lavender
    '#DEBA9A',  # beige/tan
    '#FCAEE3',  # pink
    '#CFCECE',  # light gray
    '#FEFDA2',  # pale yellow
    '#B8F1EF',  # light cyan/aqua
]

fig, ax = plt.subplots(figsize=(14, 7))
bars = ax.bar(labels, values, color=colors, edgecolor='black')

# Write each value at mid-height inside its bar
for bar, value in zip(bars, values):
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() / 2, f'{value:.2f}', ha='center', va='center', fontsize=fontsize, color='black')

ax.set_title('Averaged Correct Sample Probability', fontsize=fontsize)
ax.set_ylabel('Averaged Correct Sample Probability', fontsize=fontsize)

# Categorical bars sit at x = 0..n-1. Pin those ticks first so that
# set_xticklabels() does not trigger matplotlib's FixedFormatter warning.
ax.set_xticks(np.arange(len(labels)))
ax.set_xticklabels(labels, fontsize=fontsize, rotation=45, ha='right')

ax.tick_params(axis='y', labelsize=fontsize)

# Use a distinct file name. This cell previously wrote to the same path as the
# grouped chart above, silently overwriting that figure.
output_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/6.2.2_PC_correct_sample_prob_comparison_second_ro_flat_v3.png"
fig.savefig(output_path, bbox_inches='tight')

plt.show()


### 7.0 Experimental vs Simulated Data

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def plot_comparison_bars(data_dict, colors, figsize=(12, 6), save_path=None, metric_names=None):
    """
    Create a grouped bar plot comparing top-k accuracies across categories.

    Accuracies are shown as percentages, with a rotated value label above every bar and
    the legend placed outside the axes on the right.

    Parameters
    ----------
    data_dict : dict
        Category name -> sequence of accuracies on a 0-1 scale, one per metric and the
        same length for every category. If empty, built-in dummy values are plotted.
    colors : list
        Bar colour per metric; reused cyclically if there are more metrics than colours.
    figsize : tuple
        Figure size as (width, height).
    save_path : str, optional
        Full path where the figure is saved (300 dpi).
    metric_names : list of str, optional
        Legend label per metric; defaults to ['top-1', 'top-3', 'top-10']. Metrics beyond
        the supplied names are labelled 'metric-<n>'.

    Raises
    ------
    ValueError
        If ``colors`` is empty or the categories do not all have the same, non-zero,
        number of metrics.
    """
    # Use dummy data when nothing was supplied (handy for layout previews).
    if not data_dict:
        data_dict = {
            'Simulated': [0.18, 0.20, 0.22],
            'IC on Simulated target': [0.32, 0.33, 0.34],
            'IC on Simulated (analogue)': [0.28, 0.30, 0.32],
            'Experimental': [0.02, 0.03, 0.04],
            'IC on Experimental (target)': [0.12, 0.22, 0.25],
            'IC on Experimental (analogue)': [0.10, 0.20, 0.25],
        }
    if metric_names is None:
        metric_names = ['top-1', 'top-3', 'top-10']
    if not colors:
        raise ValueError("colors must contain at least one colour")

    categories = list(data_dict.keys())
    n_categories = len(categories)
    n_metrics = len(data_dict[categories[0]])
    # Validate before creating a figure so a bad input does not leave an orphan figure.
    if n_metrics == 0 or any(len(data_dict[cat]) != n_metrics for cat in categories):
        raise ValueError("every category in data_dict needs the same, non-zero number of metrics")

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)

    # Bars of one category share total_width, centred on the category's tick.
    total_width = 0.8
    width = total_width / n_metrics
    x = np.arange(n_categories)

    # Fix the y-range first so labels placed above bars are positioned against it.
    ax.set_ylim(0, 1.1)

    for i in range(n_metrics):
        values = [data_dict[cat][i] for cat in categories]
        offset = width * i - (total_width / 2) + (width / 2)
        label = metric_names[i] if i < len(metric_names) else f'metric-{i + 1}'
        bars = ax.bar(x + offset, values, width, label=label,
                      color=colors[i % len(colors)], edgecolor='black')

        # Percentage label just above each bar.
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2,
                    height + 0.03,
                    f'{height*100:.0f}%',
                    ha='center',
                    va='bottom',
                    rotation=90,
                    fontsize=10)

    ax.set_xlabel('')
    ax.set_ylabel('Accuracy (%)', fontsize=14)
    ax.set_title('Comparison Across Categories', fontsize=16, pad=20)

    ax.set_xticks(x)
    ax.set_xticklabels(categories, rotation=45, ha='right', fontsize=12)

    # Y ticks at 0%, 20%, ..., 100%.
    ax.tick_params(axis='y', labelsize=12)
    y_ticks = np.arange(0, 1.2, 0.2)
    ax.set_yticks(y_ticks)
    ax.set_yticklabels([f'{tick*100:.0f}%' for tick in y_ticks])

    ax.grid(axis='y', linestyle='--', alpha=0.7)
    ax.legend(fontsize=12, bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
        print(f"Plot saved to: {save_path}")

    plt.show()

# Demo configuration for plot_comparison_bars: one colour per metric plus an output file.
# The demo call stays disabled; uncomment it to render the built-in dummy data.
if __name__ == "__main__":
    colors = ['#A1C8F3', '#B8F1EF', '#FFB381']  # periwinkle (top-1), aqua (top-3), peach (top-10)

    save_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/7.0_IC_comparison_exp_sim_aug_plot.png"

    # plot_comparison_bars({}, colors, save_path=save_path)

In [None]:
import numpy as np
import matplotlib.pyplot as plt


def plot_comparison_bars(data_dict, colors, figsize=(12, 6), save_path=None,
                         metric_names=None, inside_label_min_height=0.1):
    """
    Create a grouped bar plot of top-k accuracies per category, shown as percentages.

    Value labels are drawn rotated inside each bar. Bars shorter than
    ``inside_label_min_height`` (e.g. the near-zero BL Experimental bars) get their label
    just above the bar instead, where it stays readable. Category names containing
    "(target)" or "(analogue)" are split onto two tick-label lines.
    Font sizes: y-axis label and tick labels 18; title, legend and value labels 20.

    Parameters
    ----------
    data_dict : dict
        Category name -> sequence of accuracies on a 0-1 scale, one per metric and the
        same length for every category. If empty, built-in example values are plotted.
    colors : list
        Bar colour per metric; reused cyclically if there are more metrics than colours.
    figsize : tuple
        Figure size as (width, height).
    save_path : str, optional
        If given, the figure is also written there at 300 dpi.
    metric_names : list of str, optional
        Legend label per metric; defaults to ['top-1', 'top-3', 'top-10']. Metrics beyond
        the supplied names are labelled 'metric-<n>'.
    inside_label_min_height : float
        Minimum bar height (0-1 scale) for the value label to be placed inside the bar.

    Raises
    ------
    ValueError
        If ``colors`` is empty or the categories do not all have the same, non-zero,
        number of metrics.
    """
    # Fall back to example values when no data is supplied.
    if not data_dict:
        data_dict = {
            'BL Simulated (target)': [0.576, 0.576, 0.606],
            'IC Simulated (target)': [0.971, 1.0, 1.0],
            'IC Simulated (analogue)': [0.529, 0.559, 0.706],
            'BL Experimental (target)': [0.0, 0.0, 0.029],
            'IC Experimental (target)': [0.312, 0.562, 0.812],
            'IC Experimental (analogue)': [0.118, 0.382, 0.441]
        }
    if metric_names is None:
        metric_names = ['top-1', 'top-3', 'top-10']
    if not colors:
        raise ValueError("colors must contain at least one colour")

    categories = list(data_dict.keys())
    n_categories = len(categories)
    n_metrics = len(data_dict[categories[0]])
    # Validate before creating a figure so a bad input does not leave an orphan figure.
    if n_metrics == 0 or any(len(data_dict[cat]) != n_metrics for cat in categories):
        raise ValueError("every category in data_dict needs the same, non-zero number of metrics")

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)

    # Bars of one category share total_width, centred on the category's tick.
    total_width = 0.8
    width = total_width / n_metrics
    x = np.arange(n_categories)

    # Fix the y-range first so labels are positioned against it.
    ax.set_ylim(0, 1.1)

    for i in range(n_metrics):
        values = [data_dict[cat][i] for cat in categories]
        offset = width * i - (total_width / 2) + (width / 2)
        label = metric_names[i] if i < len(metric_names) else f'metric-{i + 1}'
        bars = ax.bar(x + offset, values, width, label=label,
                      color=colors[i % len(colors)], edgecolor='black')

        for bar in bars:
            height = bar.get_height()
            centre_x = bar.get_x() + bar.get_width() / 2
            if height < inside_label_min_height:
                # Too short to hold a rotated label: place it just above the bar.
                ax.text(centre_x,
                        height + 0.03,
                        f'{height*100:.0f}%',
                        ha='center',
                        va='bottom',
                        rotation=90,
                        fontsize=20)
            else:
                # Centre the label vertically inside the bar.
                ax.text(centre_x,
                        height / 2,
                        f'{height*100:.0f}%',
                        ha='center',
                        va='center',
                        rotation=90,
                        fontsize=20)

    ax.set_xlabel('')
    ax.set_ylabel('Accuracy (%)', fontsize=18)
    ax.set_title('Comparison Across Categories', fontsize=20, pad=20)

    def split_category(category):
        """Move a trailing '(target)' / '(analogue)' qualifier onto a second line."""
        if '(target)' in category or '(analogue)' in category:
            main_part = category.split(' (')[0]
            suffix = f'({category.split("(")[1]}'
            return f'{main_part}\n{suffix}'
        return category

    ax.set_xticks(x)
    ax.set_xticklabels([split_category(cat) for cat in categories],
                       ha='center',
                       fontsize=18)

    # Y ticks at 0%, 20%, ..., 100%.
    ax.tick_params(axis='y', labelsize=18)
    y_ticks = np.arange(0, 1.2, 0.2)
    ax.set_yticks(y_ticks)
    ax.set_yticklabels([f'{tick*100:.0f}%' for tick in y_ticks])

    ax.grid(axis='y', linestyle='--', alpha=0.7)
    ax.legend(fontsize=20, loc='upper right')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
        print(f"Plot saved to: {save_path}")

    plt.show()

    
# Demo configuration: bar colours (top-1, top-3, top-10 in that order) and an output path.
# The demo call is left commented out; enable it to plot the built-in example values.
if __name__ == "__main__":
    colors = [
        '#A1C8F3',  # periwinkle -> top-1
        '#B8F1EF',  # aqua       -> top-3
        '#FFB381',  # peach      -> top-10
    ]
    save_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/7.0_IC_comparison_exp_sim_aug_plot.png"
    # plot_comparison_bars({}, colors, save_path=save_path)

In [None]:
# Number of metrics stored per category (length of the first value in `data`).
len([*data.values()][0])

In [None]:
# Inspect `data` (category -> accuracies).
# NOTE(review): in this section `data` is only assigned further down (by
# process_experimental_data / reorder_data_dict), so on a fresh kernel this cell raises
# NameError unless `data` already exists from earlier in the notebook.
data

#### Extracting the data for plotting

In [None]:
import os
import pickle
from collections import defaultdict
from rdkit import Chem



def process_pkl_files_new(folder_path, file_type, ranking_method):
    """
    Collect the generated candidates from every matching ``.pkl`` file in a folder and rank them.

    Parameters
    ----------
    folder_path : str
        Directory scanned (non-recursively) for ``.pkl`` files.
    file_type : str
        Substring a file name must contain to be included (e.g. "exp_sim_data").
    ranking_method : str
        Passed on to ``rank_all_molecules`` ('HSQC', 'COSY' or 'HSQC & COSY').

    Returns
    -------
    The result of ``rank_all_molecules``: target SMILES -> list of
    (trg_smi, gen_smi, tanimoto, rank, error1, error2) tuples.

    Raises
    ------
    ValueError
        If a candidate's error entry is neither the scalar sentinel 9 nor indexable as a pair.
    """
    # Sorted so that, when the same generated SMILES appears in several files, the copy kept
    # by the de-duplication in rank_all_molecules does not depend on os.listdir order.
    pkl_files = [
        os.path.join(folder_path, name)
        for name in sorted(os.listdir(folder_path))
        if name.endswith('.pkl') and file_type in name
    ]
    all_rankings = defaultdict(list)
    for file_path in pkl_files:
        file_data = load_data(file_path)

        for trg_smi, value_list in file_data.items():
            # Only the first element of each value is read; its items are candidate records
            # with the generated SMILES at [0], Tanimoto at [4] and the error(s) at [5].
            for sublist in value_list[0]:
                gen_smi = sublist[0]
                tanimoto = sublist[4]
                errors = sublist[5]
                # A scalar 9 marks a candidate whose errors could not be computed; expand it
                # to a pair. np.ndim guards against ambiguous array == 9 comparisons.
                if np.ndim(errors) == 0 and errors == 9:
                    errors = [9, 9]
                try:
                    error1, error2 = errors[0], errors[1]
                except (TypeError, IndexError, KeyError) as exc:
                    # Fail loudly with context instead of dropping into an interactive shell.
                    raise ValueError(
                        f"Unexpected error entry {errors!r} for target {trg_smi!r} in {file_path!r}"
                    ) from exc
                # Rank placeholder 0; the real rank is assigned by rank_all_molecules.
                all_rankings[trg_smi].append((trg_smi, gen_smi, tanimoto, 0, error1, error2))

    return rank_all_molecules(all_rankings, ranking_method)


import numpy as np
from collections import defaultdict


def rank_all_molecules(all_rankings, ranking_method):
    """
    Re-rank each target's candidates by spectral error, dropping repeated generated SMILES.

    Parameters
    ----------
    all_rankings : mapping
        Target SMILES -> iterable of tuples laid out as
        (trg_smi, gen_smi, tanimoto, rank, error1, error2); only indices 1, 2, 4 and 5 are read.
    ranking_method : str
        'HSQC'        -> sort by error1;
        'COSY'        -> sort by error2;
        'HSQC & COSY' -> sort by the mean of the two per-error ranks.

    Returns
    -------
    defaultdict(list)
        Target SMILES -> list of (trg_smi, gen_smi, tanimoto, rank, error1, error2) in ranked
        order, with rank starting at 0. Targets without candidates are omitted.

    Raises
    ------
    ValueError
        For an unknown ``ranking_method`` (only raised for targets that have candidates).
    """
    new_rankings = defaultdict(list)

    for trg_smi, molecule_list in all_rankings.items():
        molecule_data = []
        seen_gen_smiles = set()

        for molecule in molecule_list:
            gen_smi = molecule[1]
            # Keep only the first occurrence of each generated SMILES string.
            if gen_smi in seen_gen_smiles:
                continue
            seen_gen_smiles.add(gen_smi)
            molecule_data.append((gen_smi, molecule[4], molecule[5], molecule[2]))

        if not molecule_data:
            continue

        # Size the string field to the longest SMILES: a fixed 'U100' silently truncated
        # longer SMILES, breaking later string comparisons and formula/canonical checks.
        smi_width = max(1, max(len(str(row[0])) for row in molecule_data))
        molecule_array = np.array(molecule_data, dtype=[('gen_smi', f'U{smi_width}'),
                                                        ('error1', float), ('error2', float),
                                                        ('tanimoto', float)])

        # argsort(order=...) breaks ties using the remaining fields in dtype order.
        if ranking_method == 'HSQC & COSY':
            rank1 = molecule_array.argsort(order='error1')
            rank2 = molecule_array.argsort(order='error2')
            # Indexing arange by argsort(perm) inverts the permutation, giving each
            # molecule's rank under each error.
            average_ranks = (np.arange(len(rank1))[np.argsort(rank1)] +
                             np.arange(len(rank2))[np.argsort(rank2)]) / 2
            sorted_indices = average_ranks.argsort()
        elif ranking_method == 'HSQC':
            sorted_indices = molecule_array.argsort(order='error1')
        elif ranking_method == 'COSY':
            sorted_indices = molecule_array.argsort(order='error2')
        else:
            raise ValueError("Invalid ranking method. Choose 'HSQC & COSY', 'HSQC', or 'COSY'.")

        for new_rank, i in enumerate(sorted_indices):
            new_rankings[trg_smi].append((trg_smi,
                                          molecule_array['gen_smi'][i],
                                          molecule_array['tanimoto'][i],
                                          new_rank,
                                          molecule_array['error1'][i],
                                          molecule_array['error2'][i]))

    return new_rankings

def load_data(file_path):
    """
    Unpickle a results file and return its "results_dict_bl_ZINC" entry.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(file_path, "rb") as pickle_file:
        return pickle.load(pickle_file)["results_dict_bl_ZINC"]

def _ic_top_k_accuracies(folder_path, file_type, ranking_method):
    """Rank, de-duplicate and formula-filter one results folder; return rounded top-1/3/10 accuracies."""
    rankings = process_pkl_files_new(folder_path, file_type, ranking_method)
    rankings, _ = exp_func.deduplicate_smiles_from_ranking(rankings)
    rankings, _ = exp_func.filter_rankings_by_molecular_formula(rankings)
    accuracies = exp_func.calculate_top_k_accuracy(rankings, k_range=[1, 3, 10])
    return [round(acc, 3) for acc in accuracies][:3]


def _baseline_top_k_accuracies(base_pkl_paths, file_type, ranking_method):
    """Score the baseline pickle registered for ``file_type``; return rounded top-1/3/10 accuracies."""
    input_dict = {file_type: base_pkl_paths[file_type]}
    rankings = exp_func.process_pkl_files_BL(input_dict, ranking_method)
    accuracies = exp_func.calculate_top_k_accuracy_BL(rankings[file_type], k_range=[1, 3, 10])
    return [round(acc, 3) for acc in accuracies][:3]


def process_experimental_data(base_folder_path, ic_folder_path, analogue_folder_path,
                              ranking_method="HSQC"):
    """
    Compute top-1/3/10 accuracies for baseline, IC and analogue runs on simulated and
    experimental data, keyed by the category names used in the comparison plot.

    Parameters
    ----------
    base_folder_path : dict
        Baseline pickle path per data type, i.e. {"exp_sim_data": path, "sim_sim_data": path}.
        For backward compatibility a non-dict value (e.g. "") falls back to the
        notebook-level ``base_pkl_dicts``.
    ic_folder_path : str
        Folder with the IC result pickles (target set).
    analogue_folder_path : str
        Folder with the analogue result pickles.
    ranking_method : str
        Ranking used for candidates ('HSQC', 'COSY' or 'HSQC & COSY'); default 'HSQC'.

    Returns
    -------
    dict
        Category name -> [top-1, top-3, top-10] accuracies rounded to 3 decimals, inserted in
        the order IC target, BL target, IC analogue (experimental before simulated).
    """
    # Use the argument itself (it used to be ignored in favour of the global); fall back to
    # the global only for callers that pass a non-dict placeholder.
    if isinstance(base_folder_path, dict):
        base_pkl_paths = base_folder_path
    else:
        base_pkl_paths = base_pkl_dicts

    file_labels = {"exp_sim_data": "Experimental", "sim_sim_data": "Simulated"}
    plot_data = {}

    for file_type, label in file_labels.items():
        plot_data[f'IC {label} (target)'] = _ic_top_k_accuracies(
            ic_folder_path, file_type, ranking_method)

    for file_type, label in file_labels.items():
        plot_data[f'BL {label} (target)'] = _baseline_top_k_accuracies(
            base_pkl_paths, file_type, ranking_method)

    for file_type, label in file_labels.items():
        plot_data[f'IC {label} (analogue)'] = _ic_top_k_accuracies(
            analogue_folder_path, file_type, ranking_method)

    return plot_data

def reorder_data_dict(data_dict):
    """
    Return a copy of ``data_dict`` with simulated categories first, then experimental ones.

    Known categories follow the fixed order below. Any other keys are kept and appended
    afterwards in their original order, rather than being silently dropped.

    Parameters
    ----------
    data_dict : dict
        Category name -> accuracies.

    Returns
    -------
    dict
        New dict with the same items, reordered.
    """
    desired_order = [
        'BL Simulated (target)',
        'IC Simulated (target)',
        'IC Simulated (analogue)',
        'BL Experimental (target)',
        'IC Experimental (target)',
        'IC Experimental (analogue)'
    ]

    ordered_dict = {key: data_dict[key] for key in desired_order if key in data_dict}
    # Preserve categories outside the known order instead of losing them.
    for key, value in data_dict.items():
        if key not in ordered_dict:
            ordered_dict[key] = value

    return ordered_dict

# The de-duplication, molecular-formula filtering and top-k accuracy helpers used above come from the external `exp_func` module.

In [None]:
# Inputs: IC result folder, baseline pickle per data type, and analogue result folder.
base_folder = ""  # not used below
ic_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240724_IC_batched_MMT_a1/"
base_pkl_dicts = {
    "sim_sim_data": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240724_IC_batched_MMT_a1_baseline/8.0_sim_real_data_before_FT_MMTi_v0.pkl",
    "exp_sim_data": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240724_IC_batched_MMT_a1_baseline/8.3_real_data_before_FT_MMTi_v0.pkl",
}
analogue_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Sim_ACD_Exp_aug/test_on_aug_data_34_v3/"

# Category -> [top-1, top-3, top-10] accuracies, then put into plotting order.
data = process_experimental_data(base_pkl_dicts, ic_folder, analogue_folder)
data = reorder_data_dict(data)

# One colour per metric (top-1, top-3, top-10).
colors = ['#A1C8F3', '#FFB381', '#8BE5A0']  # periwinkle, peach, mint
save_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/7.0_IC_comparison_exp_sim_aug_plot.png"

plot_comparison_bars(
    data_dict=data,
    colors=colors,
    save_path=save_path,
    figsize=(12, 6),
)

In [None]:
# Inspect the reordered accuracy table: category -> [top-1, top-3, top-10] accuracies.
data

#### Extract correct top-1 and top-3 molecules

In [None]:
def _canonical_smiles_or_self(smiles):
    """Return RDKit's canonical SMILES for ``smiles``, or ``smiles`` unchanged if it does not parse."""
    mol = Chem.MolFromSmiles(smiles)
    return Chem.MolToSmiles(mol) if mol is not None else smiles


def extract_correct_smiles_analogues(analogue_folder_path, file_type="exp_sim_data", ranking_method="HSQC"):
    """
    Find targets whose correct structure is ranked first, or within the first three candidates.

    A candidate counts as correct when its canonical SMILES equals the target's canonical
    SMILES, so the same molecule written differently still matches. If a SMILES cannot be
    parsed, its raw string is compared instead.

    Parameters
    ----------
    analogue_folder_path : str
        Folder containing the result pickles to analyse.
    file_type : str
        Type of file to process ('exp_sim_data' or 'sim_sim_data').
    ranking_method : str
        Method used for ranking ('HSQC', 'COSY', or 'HSQC & COSY').

    Returns
    -------
    tuple
        (correct_top1_pairs, correct_top3_pairs). Each is a list of
        (target_smiles, generated_smiles, position) tuples, where position (1-3) is where the
        correct candidate was found. Top-1 entries always have position 1.
    """
    analogue_rankings = process_pkl_files_new(analogue_folder_path, file_type, ranking_method)
    analogue_rankings, _ = exp_func.deduplicate_smiles_from_ranking(analogue_rankings)
    analogue_rankings, _ = exp_func.filter_rankings_by_molecular_formula(analogue_rankings)

    correct_top1_pairs = []
    correct_top3_pairs = []

    for target_smi, rankings in analogue_rankings.items():
        target_canonical = _canonical_smiles_or_self(target_smi)
        # Entries are (trg_smi, gen_smi, tanimoto, rank, ...); order by rank (index 3).
        top3 = sorted(rankings, key=lambda entry: entry[3])[:3]
        for position, entry in enumerate(top3, start=1):
            gen_smi = entry[1]
            if _canonical_smiles_or_self(gen_smi) != target_canonical:
                continue
            correct_top3_pairs.append((target_smi, gen_smi, position))
            if position == 1:
                correct_top1_pairs.append((target_smi, gen_smi, 1))

    return correct_top1_pairs, correct_top3_pairs

def save_correct_smiles_to_file(smiles_pairs, output_path, include_rank=False):
    """
    Write the target SMILES of correct matches to a text file, one per line after a header.

    Parameters
    ----------
    smiles_pairs : list
        (target_smiles, generated_smiles, rank) tuples; only the target SMILES (and,
        optionally, the rank) is written.
    output_path : str
        Destination file; it is overwritten.
    include_rank : bool
        If True, adds a tab-separated "Rank" column with the position (1-3) of the match.
    """
    header = "SMILES\tRank\n" if include_rank else "SMILES\n"
    with open(output_path, 'w') as out_file:
        out_file.write(header)
        for target_smiles, _generated, position in smiles_pairs:
            if include_rank:
                out_file.write(f"{target_smiles}\t{position}\n")
            else:
                out_file.write(f"{target_smiles}\n")


In [None]:
### with Analogues
# Correct top-1 / top-3 hits for the experimental analogue test set, saved to text files.
analogue_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Sim_ACD_Exp_aug/test_on_aug_data_34_v3/"

correct_top1_pairs, correct_top3_pairs = extract_correct_smiles_analogues(
    analogue_folder, file_type="exp_sim_data", ranking_method="HSQC",
)

# Top-1 file: SMILES only. Top-3 file: SMILES plus the position (1-3) of the match.
save_correct_smiles_to_file(correct_top1_pairs, "correct_top1_experimental_analogues.txt", include_rank=False)
save_correct_smiles_to_file(correct_top3_pairs, "correct_top3_experimental_analogues.txt", include_rank=True)

print(f"Number of correct top 1 matches: {len(correct_top1_pairs)}")
print(f"Number of correct top 3 matches: {len(correct_top3_pairs)}")
print("\nBreakdown of top 3 matches by position:")
for rank in [1, 2, 3]:
    count = len([pos for _, _, pos in correct_top3_pairs if pos == rank])
    print(f"Position {rank}: {count} matches")


In [None]:
# Show the working directory; the relative output files written by the neighbouring cells
# land here. os.getcwd() works even when IPython automagic (bare `pwd`) is disabled.
import os
os.getcwd()

In [None]:

# Correct top-1 / top-3 hits for the IC run on the experimental target set.
# A dedicated variable avoids overwriting `analogue_folder` with the IC path.
ic_target_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240724_IC_batched_MMT_a1/"

correct_top1_pairs, correct_top3_pairs = extract_correct_smiles_analogues(
    ic_target_folder,
    file_type="exp_sim_data",
    ranking_method="HSQC"
)

# "targets" file names: the analogue cell above writes
# correct_top{1,3}_experimental_analogues.txt, and reusing those names here overwrote its results.
# The top-1 file lists SMILES only; the top-3 file also records the matching position (1-3).
save_correct_smiles_to_file(correct_top1_pairs, "correct_top1_experimental_targets.txt", include_rank=False)
save_correct_smiles_to_file(correct_top3_pairs, "correct_top3_experimental_targets.txt", include_rank=True)

print(f"Number of correct top 1 matches: {len(correct_top1_pairs)}")
print(f"Number of correct top 3 matches: {len(correct_top3_pairs)}")
print("\nBreakdown of top 3 matches by position:")
for rank in [1, 2, 3]:
    count = sum(1 for _, _, r in correct_top3_pairs if r == rank)
    print(f"Position {rank}: {count} matches")


In [None]:
# Render one example structure for visual inspection (this SMILES appears to be molindone).
example_mol = Chem.MolFromSmiles("CCc1c(C)[nH]c2c1C(=O)C(CN1CCOCC1)CC2")
example_mol

### 7.0 Simulated, ACD, Experimental results

#### Plotting

In [None]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from rdkit import Chem
from rdkit.Chem import AllChem, Draw, Descriptors
import pandas as pd
import random
import string
import rdkit
from rdkit import Chem

# Seven-colour palette (red, olive, green, teal, blue, purple, pink) read by
# plot_top_k_accuracy below.
colors = [
    '#E57373',
    '#A2A37E',
    '#64B689',
    '#4DA6A9',
    '#5C95CC',
    '#9574D0',
    '#EB6CC2',
]

def load_data(file_path):
    """Read a pickled results file and return its ``results_dict_bl_ZINC`` mapping.

    Note: pickle.load can run arbitrary code, so only pass trusted files.
    """
    key = "results_dict_bl_ZINC"
    with open(file_path, mode="rb") as fh:
        contents = pickle.load(fh)
    return contents[key]


def process_pkl_files_new(folder_path, file_type, ranking_method):
    """
    Collect the generated candidates from every matching ``.pkl`` file in a folder and rank them.

    Parameters
    ----------
    folder_path : str
        Directory scanned (non-recursively) for ``.pkl`` files.
    file_type : str
        Substring a file name must contain to be included (e.g. "exp_sim_data").
    ranking_method : str
        Passed on to ``rank_all_molecules`` ('HSQC', 'COSY' or 'HSQC & COSY').

    Returns
    -------
    The result of ``rank_all_molecules``: target SMILES -> list of
    (trg_smi, gen_smi, tanimoto, rank, error1, error2) tuples.

    Raises
    ------
    ValueError
        If a candidate's error entry is neither the scalar sentinel 9 nor indexable as a pair.
    """
    # Sorted so that, when the same generated SMILES appears in several files, the copy kept
    # by the de-duplication in rank_all_molecules does not depend on os.listdir order.
    pkl_files = [
        os.path.join(folder_path, name)
        for name in sorted(os.listdir(folder_path))
        if name.endswith('.pkl') and file_type in name
    ]
    all_rankings = defaultdict(list)
    for file_path in pkl_files:
        file_data = load_data(file_path)

        for trg_smi, value_list in file_data.items():
            # Only the first element of each value is read; its items are candidate records
            # with the generated SMILES at [0], Tanimoto at [4] and the error(s) at [5].
            for sublist in value_list[0]:
                gen_smi = sublist[0]
                tanimoto = sublist[4]
                errors = sublist[5]
                # A scalar 9 marks a candidate whose errors could not be computed; expand it
                # to a pair. np.ndim guards against ambiguous array == 9 comparisons.
                if np.ndim(errors) == 0 and errors == 9:
                    errors = [9, 9]
                try:
                    error1, error2 = errors[0], errors[1]
                except (TypeError, IndexError, KeyError) as exc:
                    # Fail loudly with context instead of dropping into an interactive shell.
                    raise ValueError(
                        f"Unexpected error entry {errors!r} for target {trg_smi!r} in {file_path!r}"
                    ) from exc
                # Rank placeholder 0; the real rank is assigned by rank_all_molecules.
                all_rankings[trg_smi].append((trg_smi, gen_smi, tanimoto, 0, error1, error2))

    return rank_all_molecules(all_rankings, ranking_method)


def calculate_top_k_accuracy(all_rankings, k_range=(1, 3, 5, 10, 20)):
    """
    Fraction of targets whose correct structure appears within the first k candidates.

    A candidate counts as correct when its Tanimoto similarity (index 2) equals 1.
    Candidate lists are assumed to already be in rank order.

    Parameters
    ----------
    all_rankings : mapping
        Target SMILES -> ordered list of candidate tuples.
    k_range : iterable of int
        The k values to evaluate.

    Returns
    -------
    list of float
        One accuracy per k in ``k_range``, followed by the fraction of targets with a correct
        candidate anywhere in their list. All zeros when ``all_rankings`` is empty.
    """
    k_values = list(k_range)
    total_molecules = len(all_rankings)
    if total_molecules == 0:
        # No targets -> no hits; avoids ZeroDivisionError.
        return [0.0] * (len(k_values) + 1)

    accuracies = []
    for k in k_values:
        correct_count = sum(
            any(molecule[2] == 1 for molecule in rankings[:k])
            for rankings in all_rankings.values()
        )
        accuracies.append(correct_count / total_molecules)

    # Final entry: correct candidate anywhere in the list ("Total").
    correct_anywhere = sum(
        any(molecule[2] == 1 for molecule in rankings)
        for rankings in all_rankings.values()
    )
    accuracies.append(correct_anywhere / total_molecules)
    return accuracies


def count_molecules_with_sim_rank_one(all_rankings):
    """
    Count targets with at least one candidate whose Tanimoto similarity (index 2) is 1.0.

    Parameters
    ----------
    all_rankings : mapping
        Target SMILES -> list of candidate tuples.

    Returns
    -------
    int
    """
    return sum(
        any(molecule[2] == 1.0 for molecule in rankings)
        for rankings in all_rankings.values()
    )

def plot_top_k_accuracy(accuracies, data_type, model_type, ranking_method, sim_rank_one_count,
                        save_path=None, total_samples=34, k_values=None):
    """
    Draw (and optionally save) a bar chart of top-k accuracies plus a final "Total" bar.

    Each bar shows its accuracy above it and the implied number of correct samples inside it.
    The figure is closed rather than shown, so the function can be called in a loop.

    Parameters
    ----------
    accuracies : sequence of float
        One accuracy (0-1) per k in ``k_values``, followed by the overall ("Total") accuracy,
        e.g. the output of ``calculate_top_k_accuracy``.
    data_type, ranking_method : str
        Used in the title.
    model_type, sim_rank_one_count :
        Not used; accepted for call compatibility.
    save_path : str, optional
        Where to save the figure (300 dpi).
    total_samples : int
        Number of targets the accuracies are fractions of; used to recover counts.
    k_values : list of int, optional
        The k of each top-k bar; defaults to [1, 3, 5, 10, 20].

    Raises
    ------
    ValueError
        If ``accuracies`` does not have ``len(k_values) + 1`` entries.
    """
    fontsize = 30
    if k_values is None:
        k_values = [1, 3, 5, 10, 20]
    labels = [f'Top {k}' for k in k_values] + ['Total']
    if len(accuracies) != len(labels):
        raise ValueError(
            f"expected {len(labels)} accuracies (one per k plus 'Total'), got {len(accuracies)}")

    fig, ax = plt.subplots(figsize=(12, 8))
    positions = list(range(len(labels)))
    bars = ax.bar(positions, accuracies, color=colors[:len(labels)])

    ax.set_ylabel('Accuracy', fontsize=fontsize)
    ax.set_title(f'{data_type} Data and {ranking_method} Ranking', fontsize=fontsize + 4)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, fontsize=fontsize, rotation=0, ha='center')
    ax.tick_params(axis='y', labelsize=fontsize)
    ax.set_ylim(0, 1.1)

    for bar in bars:
        height = bar.get_height()
        # round(), not int(): floating-point products such as 0.29 * 100 give
        # 28.999999999999996, which int() would truncate to 28.
        correct_samples = round(height * total_samples)
        x_centre = bar.get_x() + bar.get_width() / 2.

        ax.text(x_centre, height, f'{height:.2f}',
                ha='center', va='bottom', fontsize=fontsize)
        ax.text(x_centre, height / 2, f'{correct_samples}',
                ha='center', va='center', fontsize=fontsize, color='black')

    ax.grid(axis='y', linestyle='--', alpha=0.7)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")

    plt.close(fig)
    
"""

def get_molecular_formula(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return Chem.rdMolDescriptors.CalcMolFormula(mol)

def filter_rankings_by_molecular_formula(all_rankings):
    filtered_rankings = defaultdict(list)
    filtered_out_rankings = defaultdict(list)

    for key, rankings in all_rankings.items():
        filtered_rankings_for_key = []
        filtered_out_rankings_for_key = []
        
        for ranking in rankings:
            if len(ranking) >= 2:
                smiles1 = ranking[0]
                smiles2 = ranking[1]
                formula1 = get_molecular_formula(smiles1)
                formula2 = get_molecular_formula(smiles2)
                
                if formula1 is not None and formula2 is not None and formula1 == formula2:
                    filtered_rankings_for_key.append(ranking)
                else:
                    filtered_out_rankings_for_key.append(ranking)
        
        if filtered_rankings_for_key:
            filtered_rankings[key] = filtered_rankings_for_key
        
        if filtered_out_rankings_for_key:
            filtered_out_rankings[key] = filtered_out_rankings_for_key

    return filtered_rankings, filtered_out_rankings


def deduplicate_smiles_from_ranking(all_rankings):
    def canonicalize(smiles):
        mol = Chem.MolFromSmiles(smiles)
        return Chem.MolToSmiles(mol) if mol else smiles
        #return Chem.MolToSmiles(mol, isomericSmiles=False, canonical=True) if mol else smiles

    deduplicated_rankings = {}
    removed_smiles = {}

    for key, rankings in all_rankings.items():
        seen_canonical = set()
        deduplicated_rankings[key] = []
        removed_smiles[key] = []

        for ranking in rankings:
            canonical_smiles = canonicalize(ranking[1])
            if canonical_smiles not in seen_canonical:
                seen_canonical.add(canonical_smiles)
                deduplicated_rankings[key].append(ranking)
            else:
                removed_smiles[key].append(ranking[1])

    return deduplicated_rankings, removed_smiles"""



def main_normal(folder_path="/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240724_IC_batched_MMT_a1/",
                ranking_methods=("HSQC",),
                file_types=("exp_sim_data", "sim_sim_data")):
    """
    Rank, filter and plot top-k accuracy for every (ranking method, data type) combination.

    Parameters
    ----------
    folder_path : str
        Folder with the result pickles (defaults to the IC batched MMT run).
    ranking_methods : iterable of str
        Any of 'HSQC', 'COSY', 'HSQC & COSY'.
    file_types : iterable of str
        Data types to process; each must be a key of the data-type map below.

    Returns
    -------
    tuple
        (all_rankings_list, filtered_out): for each processed combination, in processing
        order, the rankings kept by the molecular-formula filter and the rankings it removed.
    """
    model_type = "MMT"
    data_type_map = {
        "exp_sim_data": "Experimental",
        "sim_sim_data": "Our Simulated",
    }
    filtered_out = []
    all_rankings_list = []
    for ranking_method in ranking_methods:
        for file_type in file_types:
            all_rankings = process_pkl_files_new(folder_path, file_type, ranking_method)
            all_rankings, _removed_smiles = exp_func.deduplicate_smiles_from_ranking(all_rankings)
            # exp_func's formula filter, as used by the other pipelines in this notebook
            # (the bare name was not defined as live code here).
            all_rankings, filtered_out_rankings = exp_func.filter_rankings_by_molecular_formula(all_rankings)
            filtered_out.append(filtered_out_rankings)
            all_rankings_list.append(all_rankings)

            accuracies = calculate_top_k_accuracy(all_rankings)
            # Targets with at least one exact (Tanimoto == 1) candidate anywhere in their list.
            sim_rank_one_count = sum(
                any(molecule[2] == 1.0 for molecule in rankings)
                for rankings in all_rankings.values()
            )

            data_type = data_type_map[file_type]
            save_path = f"/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/top_k_accuracy_MMTi_{file_type}_{ranking_method}_va_FINAL.png"

            plot_top_k_accuracy(accuracies, data_type, model_type, ranking_method, sim_rank_one_count,
                                save_path, total_samples=len(all_rankings))

            print(f"Completed plot for {ranking_method} - {data_type}")
    return all_rankings_list, filtered_out


In [None]:
import os
import pickle
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem, Draw, Descriptors
from IPython.display import display, SVG
from collections import defaultdict

def load_data(file_path):
    """Return the "results_dict_bl_ZINC" object stored in the pickle at ``file_path``.

    Warning: unpickling untrusted data can execute arbitrary code.
    """
    fh = open(file_path, 'rb')
    try:
        payload = pickle.load(fh)
    finally:
        fh.close()
    return payload["results_dict_bl_ZINC"]


def rank_molecules_in_file(file_data, ranking_method):
    """Flatten every candidate in a results file and rank them in ONE global list.

    Args:
        file_data: mapping ``trg_smi -> value_list``. Only ``value_list[0]`` is read; each of
            its records contributes ``[0]`` the generated SMILES, ``[4]`` the Tanimoto
            similarity and ``[5]`` an ``(error1, error2)`` pair, or the scalar sentinel ``9``
            when the errors could not be calculated.
        ranking_method: 'HSQC & COSY' (average of the two per-error ranks), 'HSQC'
            (sort by error1) or 'COSY' (sort by error2).

    Returns:
        List of ``(trg_smi, gen_smi, tanimoto, rank, error1, error2)`` tuples. Candidates of
        all targets are ranked together (not per target). Returns ``[]`` when there are no
        candidates.

    Raises:
        ValueError: for an unknown ``ranking_method``.
    """
    if ranking_method not in ('HSQC & COSY', 'HSQC', 'COSY'):
        raise ValueError("Invalid ranking method. Choose 'HSQC & COSY', 'HSQC', or 'COSY'.")

    molecule_data = []
    for trg_smi, value_list in file_data.items():
        for sublist in value_list[0]:
            gen_smi = sublist[0]
            tanimoto = sublist[4]
            errors = sublist[5]
            # Scalar 9 = HSQC/COSY error could not be calculated: expand to a pair so the
            # candidate sorts after those with smaller errors. np.isscalar avoids the
            # ambiguous truth value of `array == 9` when errors is an array.
            if np.isscalar(errors) and errors == 9:
                errors = [9, 9]
            molecule_data.append((trg_smi, gen_smi, errors[0], errors[1], tanimoto))

    if not molecule_data:
        return []

    # Size the string fields from the data: a fixed 'U100' silently truncated longer SMILES.
    trg_width = max(1, max(len(str(row[0])) for row in molecule_data))
    gen_width = max(1, max(len(str(row[1])) for row in molecule_data))
    molecule_array = np.array(molecule_data, dtype=[('trg_smi', f'U{trg_width}'),
                                                    ('gen_smi', f'U{gen_width}'),
                                                    ('error1', float), ('error2', float),
                                                    ('tanimoto', float)])

    if ranking_method == 'HSQC & COSY':
        rank1 = molecule_array.argsort(order='error1')
        rank2 = molecule_array.argsort(order='error2')
        # argsort of an argsort gives each candidate's position in that ordering.
        average_ranks = (np.arange(len(rank1))[np.argsort(rank1)] +
                         np.arange(len(rank2))[np.argsort(rank2)]) / 2
        sorted_indices = average_ranks.argsort()
    elif ranking_method == 'HSQC':
        sorted_indices = molecule_array.argsort(order='error1')
    else:  # 'COSY'
        sorted_indices = molecule_array.argsort(order='error2')

    return [(molecule_array['trg_smi'][i],
             molecule_array['gen_smi'][i],
             molecule_array['tanimoto'][i],
             new_rank,  # position in the sorted order is the rank
             molecule_array['error1'][i],
             molecule_array['error2'][i])
            for new_rank, i in enumerate(sorted_indices)]

import numpy as np
from collections import defaultdict


def rank_all_molecules(all_rankings, ranking_method):
    """Re-rank each target's candidates separately after dropping duplicate generated SMILES.

    Args:
        all_rankings: mapping ``trg_smi -> list of molecule tuples``. Only positions ``[1]``
            gen_smi, ``[2]`` tanimoto, ``[4]`` error1 and ``[5]`` error2 are read.
        ranking_method: 'HSQC & COSY', 'HSQC' or 'COSY'.

    Returns:
        ``defaultdict(list)`` mapping ``trg_smi -> [(trg_smi, gen_smi, tanimoto, new_rank,
        error1, error2), ...]`` in new-rank order. Targets with no candidates are omitted.

    Raises:
        ValueError: for an unknown ``ranking_method``.
    """
    if ranking_method not in ('HSQC & COSY', 'HSQC', 'COSY'):
        raise ValueError("Invalid ranking method. Choose 'HSQC & COSY', 'HSQC', or 'COSY'.")

    new_rankings = defaultdict(list)

    for trg_smi, molecule_list in all_rankings.items():
        molecule_data = []
        seen_gen_smiles = set()

        for molecule in molecule_list:
            gen_smi = molecule[1]
            if gen_smi in seen_gen_smiles:
                continue  # keep only the first occurrence in list order
            seen_gen_smiles.add(gen_smi)
            molecule_data.append((gen_smi, molecule[4], molecule[5], molecule[2]))

        if not molecule_data:
            continue

        # Size the SMILES field from the data: a fixed 'U100' silently truncated longer SMILES.
        gen_width = max(1, max(len(str(row[0])) for row in molecule_data))
        molecule_array = np.array(molecule_data, dtype=[('gen_smi', f'U{gen_width}'),
                                                        ('error1', float), ('error2', float),
                                                        ('tanimoto', float)])

        if ranking_method == 'HSQC & COSY':
            rank1 = molecule_array.argsort(order='error1')
            rank2 = molecule_array.argsort(order='error2')
            # argsort of an argsort gives each candidate's position in that ordering.
            average_ranks = (np.arange(len(rank1))[np.argsort(rank1)] +
                             np.arange(len(rank2))[np.argsort(rank2)]) / 2
            sorted_indices = average_ranks.argsort()
        elif ranking_method == 'HSQC':
            sorted_indices = molecule_array.argsort(order='error1')
        else:  # 'COSY'
            sorted_indices = molecule_array.argsort(order='error2')

        for new_rank, i in enumerate(sorted_indices):
            new_rankings[trg_smi].append((trg_smi,
                                          molecule_array['gen_smi'][i],
                                          molecule_array['tanimoto'][i],
                                          new_rank,
                                          molecule_array['error1'][i],
                                          molecule_array['error2'][i]))

    return new_rankings


def prepare_mol(smiles):
    """Parse ``smiles`` into an RDKit molecule and compute 2D coordinates for drawing.

    Raises:
        ValueError: if RDKit cannot parse ``smiles`` (``MolFromSmiles`` returns None);
            previously the None was passed on to ``Compute2DCoords``, which failed obscurely.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {smiles!r}")
    AllChem.Compute2DCoords(mol)
    return mol

def get_mol_weight(mol):
    """Return RDKit's ``Descriptors.ExactMolWt`` for ``mol``, rounded to 2 decimal places."""
    exact_weight = Descriptors.ExactMolWt(mol)
    return round(exact_weight, 2)


In [None]:
# Run the fine-tuned-model ranking + top-k plotting pipeline (main_normal is defined in an
# earlier cell). Returns one deduplicated/filtered ranking dict per (method, data type).
# NOTE(review): filtered_out is currently always [] because the append inside main_normal
# is commented out.
all_rankings_list, filtered_out = main_normal()


#### Plot baseline performance

In [None]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# Bar colours used by plot_top_k_accuracy (read there as a module-level global).
# The default plot draws six bars (five "Top k" bars + "Total"), so the first six are used.
colors = [
    '#E57373', '#A2A37E', '#64B689', '#4DA6A9', '#5C95CC', '#9574D0', '#EB6CC2'
]

def load_data(file_path):
    """Return the ``"results_dict_bl_ZINC"`` entry of a pickled results file.

    Same contract as the earlier ``load_data`` in this notebook (this cell redefines it).
    Raises KeyError if the key is missing.
    """
    # SECURITY: pickle can execute arbitrary code - only load trusted result files.
    with open(file_path, mode='rb') as pkl_file:
        loaded = pickle.load(pkl_file)
    bl_results = loaded["results_dict_bl_ZINC"]
    return bl_results

def rank_molecules(file_data, ranking_method):
    """Rank each target's generated candidates separately (used for the baseline results).

    Args:
        file_data: mapping ``trg_smi -> value_list``. Only ``value_list[0]`` is read; each
            record contributes ``[0]`` gen_smi, ``[4]`` tanimoto and ``[5]`` an
            ``(error1, error2)`` pair.
        ranking_method: 'HSQC & COSY', 'HSQC' or 'COSY'.

    Returns:
        dict ``trg_smi -> [(trg_smi, gen_smi, tanimoto, rank, error1, error2), ...]``.
        Records that cannot be unpacked are reported and skipped. This includes the
        scalar error sentinel ``9``; unlike rank_molecules_in_file, which ranks such
        records last, this function drops them. Targets left with no valid records are
        omitted.

    Raises:
        ValueError: for an unknown ``ranking_method``.
    """
    if ranking_method not in ('HSQC & COSY', 'HSQC', 'COSY'):
        raise ValueError("Invalid ranking method. Choose 'HSQC & COSY', 'HSQC', or 'COSY'.")

    all_rankings = {}
    for trg_smi, value_list in file_data.items():
        molecule_data = []
        for sublist in value_list[0]:
            try:
                gen_smi = sublist[0]
                tanimoto = sublist[4]
                errors = sublist[5]
                molecule_data.append((gen_smi, errors[0], errors[1], tanimoto))
            except Exception:
                # Catch Exception (not a bare except) so KeyboardInterrupt still stops the cell.
                print(f"Error processing sublist for {trg_smi}: {sublist}")
                continue

        if not molecule_data:
            continue

        # Size the SMILES field from the data: a fixed 'U100' silently truncated longer SMILES.
        gen_width = max(1, max(len(str(row[0])) for row in molecule_data))
        molecule_array = np.array(molecule_data, dtype=[('gen_smi', f'U{gen_width}'),
                                                        ('error1', float), ('error2', float),
                                                        ('tanimoto', float)])

        if ranking_method == 'HSQC & COSY':
            rank1 = molecule_array.argsort(order='error1')
            rank2 = molecule_array.argsort(order='error2')
            # argsort of an argsort gives each candidate's position in that ordering.
            average_ranks = (np.arange(len(rank1))[np.argsort(rank1)] +
                             np.arange(len(rank2))[np.argsort(rank2)]) / 2
            sorted_indices = average_ranks.argsort()
        elif ranking_method == 'HSQC':
            sorted_indices = molecule_array.argsort(order='error1')
        else:  # 'COSY'
            sorted_indices = molecule_array.argsort(order='error2')

        all_rankings[trg_smi] = [(trg_smi,
                                  molecule_array['gen_smi'][i],
                                  molecule_array['tanimoto'][i],
                                  new_rank,  # position in the sorted order is the rank
                                  molecule_array['error1'][i],
                                  molecule_array['error2'][i])
                                 for new_rank, i in enumerate(sorted_indices)]

    return all_rankings
"""
def process_pkl_files(file_paths_dict, ranking_method):
    all_rankings = {}
    
    for data_type, file_path in file_paths_dict.items():
        file_data = load_data(file_path)
        ranked_molecules = rank_molecules_in_file(file_data, ranking_method)
        all_rankings[data_type] = defaultdict(list)
        for molecule in ranked_molecules:
            trg_smi = molecule[0]
            all_rankings[data_type][trg_smi].append(molecule)
    
        # Sort rankings for each target SMILES
        for trg_smi in all_rankings[data_type]:
            all_rankings[data_type][trg_smi].sort(key=lambda x: x[3])  # Sort by rank
    
    return all_rankings
    
def calculate_top_k_accuracy(all_rankings, k_range=[1, 3, 5, 10, 20]):
    accuracies = []
    total_molecules = len(all_rankings)
    for k in k_range:
        correct_count = sum(
            any(molecule[2] == 1.0 for molecule in rankings[:k])
            for rankings in all_rankings.values()
        )
        accuracy = correct_count / total_molecules
        accuracies.append(accuracy)
    return accuracies
    
def process_pkl_files_BL(file_paths_dict, ranking_method):
    all_rankings = {}
    for data_type, file_path in file_paths_dict.items():
        file_data = load_data_results(file_path)
        all_rankings[data_type] = rank_molecules(file_data, ranking_method)
        
    return all_rankings
"""



def plot_top_k_accuracy(accuracies, data_type, model_type, ranking_method, sim_rank_one_count, save_path=None, total_samples=34):
    fontsize = 30
    k_values = [1, 3, 5, 10, 20]
    labels = [f'Top {k}' for k in k_values] + ['Total']
    
    plt.figure(figsize=(12, 8))  # Slightly wider to accommodate longer labels
    bars = plt.bar(range(len(labels)), accuracies + [sim_rank_one_count / total_samples], color=colors[:len(labels)])
    
    plt.ylabel('Accuracy', fontsize=fontsize)
    plt.title(f'{data_type} Data and {ranking_method} Ranking', fontsize=fontsize+4)
    
    plt.xticks(range(len(labels)), labels, fontsize=fontsize, rotation=0, ha='center')
    plt.yticks(fontsize=fontsize)
    plt.ylim(0, 1.1)
    
    for bar in bars:
        height = bar.get_height()
        correct_samples = int(height * total_samples)
        
        plt.text(bar.get_x() + bar.get_width()/2., height,
                 f'{height:.2f}',
                 ha='center', va='bottom', fontsize=fontsize)
        
        plt.text(bar.get_x() + bar.get_width()/2., height/2,
                 f'{correct_samples}',
                 ha='center', va='center', fontsize=fontsize, color='black')
    
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")
    
    plt.close()
    
def main_BL(file_paths_dict=None, ranking_methods=None, model_type="MMT", output_dir=None):
    """Plot baseline (pre-fine-tuning) top-k accuracy for every ranking method and data source.

    Args:
        file_paths_dict: ``data_type -> pickle path``. Defaults to the MMTi a1 baseline
            pickles for "sim_sim_data", "ACD_sim_data" and "exp_sim_data".
        ranking_methods: defaults to ["COSY", "HSQC", "HSQC & COSY"].
        model_type: forwarded to plot_top_k_accuracy.
        output_dir: folder for the PNGs; defaults to the Figures_Paper_2 folder.

    One PNG per (ranking_method, data_type) is written to ``output_dir``.

    NOTE(review): relies on process_pkl_files_BL, calculate_top_k_accuracy and
    count_molecules_with_sim_rank_one from earlier cells - the copies in this cell are inside
    a disabled string literal.
    """
    figures_dir = ("/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/"
                   "___FIGURES_PAPERS/Figures_Paper_2")
    baseline_dir = f"{figures_dir}/precomputed_raw_data/20240724_IC_batched_MMT_a1_baseline"

    if output_dir is None:
        output_dir = figures_dir
    if ranking_methods is None:
        ranking_methods = ["COSY", "HSQC", "HSQC & COSY"]
    if file_paths_dict is None:
        file_paths_dict = {
            "sim_sim_data": f"{baseline_dir}/8.0_sim_real_data_before_FT_MMTi_v0.pkl",
            "ACD_sim_data": f"{baseline_dir}/8.2_simACD_real_data_before_FT_MMTi_v0.pkl",
            "exp_sim_data": f"{baseline_dir}/8.3_real_data_before_FT_MMTi_v0.pkl",
        }

    data_type_map = {
        "sim_sim_data": "Our Simulated",
        "ACD_sim_data": "ACD Simulated",
        "exp_sim_data": "Experimental",
    }

    for ranking_method in ranking_methods:
        all_rankings = process_pkl_files_BL(file_paths_dict, ranking_method)

        for data_type, rankings in all_rankings.items():
            accuracies = calculate_top_k_accuracy(rankings)
            sim_rank_one_count = count_molecules_with_sim_rank_one(rankings)
            label = data_type_map.get(data_type, data_type)

            save_path = f"{output_dir}/top_k_accuracy_MMTi_baseline_{data_type}_{ranking_method}_FINAL.png"
            plot_top_k_accuracy(accuracies, label, model_type, ranking_method, sim_rank_one_count,
                                save_path, total_samples=len(rankings))

            print(f"Completed plot for {ranking_method} - {label}")



In [None]:
# Generate the baseline top-k accuracy figures (one PNG per ranking method x data source).
main_BL()

#### Plot all molecules

In [None]:
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import Descriptors
from math import ceil
import os
from IPython.display import display, SVG

def plot_molecules_from_df(df, smiles_col, id_col, mols_per_row=5, output_folder='molecule_images'):
    """
    Plot molecules from a DataFrame with SMILES and sample ID columns and save images to a folder.
    
    Args:
    df (pd.DataFrame): DataFrame containing SMILES and sample ID columns
    smiles_col (str): Name of the column containing SMILES strings
    id_col (str): Name of the column containing sample IDs
    mols_per_row (int): Number of molecules to display per row (default: 5)
    output_folder (str): Name of the folder to save images (default: 'molecule_images')
    """
    mols = []
    legends = []
    
    # Create output folder if it doesn't exist
    os.makedirs(output_folder, exist_ok=True)
    
    for _, row in df.iterrows():
        mol = Chem.MolFromSmiles(row[smiles_col])
        if mol is not None:
            mols.append(mol)
            mol_weight = Descriptors.ExactMolWt(mol)
            legend = f"{row[id_col]} | MW: {mol_weight:.2f}"
            legends.append(legend)
    
    n_rows = ceil(len(mols) / mols_per_row)
    
    for i in range(n_rows):
        start_idx = i * mols_per_row
        end_idx = min((i + 1) * mols_per_row, len(mols))
        
        img = Draw.MolsToGridImage(
            mols[start_idx:end_idx],
            molsPerRow=mols_per_row,
            subImgSize=(300, 300),
            legends=[legends[j] for j in range(start_idx, end_idx)],
            useSVG=True
        )
        
        # Save the SVG image
        filename = f'molecules_row_aug_{i+1}.svg'
        filepath = os.path.join(output_folder, filename)
        with open(filepath, 'w') as f:
            f.write(img.data)
        
        print(f"Saved {filename}")
        
        # Display the image (optional, comment out if not needed)
        display(SVG(img.data))

# Example usage (toy DataFrame):
#df = pd.DataFrame({
#     'SMILES': ['CCO', 'CC(=O)O', 'c1ccccc1', 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C', 'CCN(CC)CC'],
#     'Sample_ID': ['S1', 'S2', 'S3', 'S4', 'S5']
# })
# NOTE(review): plot_molecules_from_df reads img.data from Draw.MolsToGridImage(useSVG=True);
# this IPythonConsole import presumably provides that (it patches RDKit's notebook drawing) -
# confirm before removing it.
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.drawOptions.addAtomIndices = False 

# Regio-isomer-augmented ACD set (the commented line is the original, non-augmented set).
# NOTE(review): rebinds the global `df`; later cells read their own CSV into `df` again.
#df = pd.read_csv("/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/37_Richard_ACD_sim_data/ACD_1H_with_SN_filtered_v3.csv")
df = pd.read_csv("/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/37_Richard_ACD_sim_data/ACD_1H_with_SN_filtered_v3_regio_aug.csv")

# Same folder as the original 34 molecules; this function names its files 'molecules_row_aug_<n>.svg'.
output_folder = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/experimental_mol_34"

plot_molecules_from_df(df, 'SMILES_regio_isomers', 'sample-id', output_folder=output_folder)
#plot_molecules_from_df(df, 'SMILES', 'sample-id', output_folder=output_folder)


In [None]:
import pandas as pd
from rdkit import Chem

def find_duplicate_smiles(df, smiles_col, id_col):
    """
    Find duplicate SMILES in a DataFrame after canonicalization.
    
    Args:
    df (pd.DataFrame): DataFrame containing SMILES and sample ID columns
    smiles_col (str): Name of the column containing SMILES strings
    id_col (str): Name of the column containing sample IDs
    
    Returns:
    pd.DataFrame: A DataFrame containing duplicate SMILES, their IDs, and count
    """
    # Create a dictionary to store canonical SMILES and their corresponding IDs
    canonical_smiles_dict = {}
    
    for _, row in df.iterrows():
        smiles = row[smiles_col]
        mol_id = row[id_col]
        
        # Generate canonical SMILES
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            canonical_smiles = Chem.MolToSmiles(mol, canonical=True)
            
            # Add to dictionary
            if canonical_smiles in canonical_smiles_dict:
                canonical_smiles_dict[canonical_smiles].append((mol_id, smiles))
            else:
                canonical_smiles_dict[canonical_smiles] = [(mol_id, smiles)]
    
    # Filter for duplicates and create a list of results
    duplicates = []
    for canonical_smiles, id_smiles_list in canonical_smiles_dict.items():
        if len(id_smiles_list) > 1:
            for mol_id, original_smiles in id_smiles_list:
                duplicates.append({
                    'Canonical_SMILES': canonical_smiles,
                    'Original_SMILES': original_smiles,
                    'ID': mol_id,
                    'Duplicate_Count': len(id_smiles_list)
                })
    
    # Create a DataFrame from the list of duplicates
    duplicates_df = pd.DataFrame(duplicates)
    
    # Sort by Duplicate_Count (descending) and Canonical_SMILES
    duplicates_df = duplicates_df.sort_values(['Duplicate_Count', 'Canonical_SMILES'], ascending=[False, True])
    
    return duplicates_df

# Report structures that appear more than once (after canonicalisation) in the original
# (non-augmented) ACD 1H set. NOTE(review): rebinds the global `df` set in the previous cell.
df = pd.read_csv("/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/37_Richard_ACD_sim_data/ACD_1H_with_SN_filtered_v3.csv")
duplicate_smiles = find_duplicate_smiles(df, 'SMILES', 'sample-id')
print(duplicate_smiles)
print(f"Total number of duplicate entries: {len(duplicate_smiles)}")
print(f"Number of unique molecules with duplicates: {duplicate_smiles['Canonical_SMILES'].nunique()}")

### 7.0 Improvement Cycle on Similar Molecules

In [None]:
import os
from typing import List, Dict, Any, Union, Tuple
import pandas as pd
from datetime import datetime
import pickle
import re
import tempfile


def split_dataset(config, chunk_size: int) -> List[pd.DataFrame]:
    """Read ``config.SGNN_csv_gen_smi`` and split it into consecutive row chunks.

    Each chunk has ``chunk_size`` rows except possibly the last one.

    Raises:
        ValueError: if ``chunk_size`` < 1. A negative size previously returned [] silently.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    df = pd.read_csv(config.SGNN_csv_gen_smi)
    return [df.iloc[start:start + chunk_size] for start in range(0, len(df), chunk_size)]


def create_chunk_folder(config, idx: int) -> str:
    """Create ``<config.model_save_dir>/chunk_<idx:03d>_<YYYYmmdd_HHMMSS>`` and return its path.

    An already-existing folder is not an error.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    chunk_folder_path = os.path.join(config.model_save_dir, f"chunk_{idx:03d}_{timestamp}")
    os.makedirs(chunk_folder_path, exist_ok=True)
    print(f"Created folder for chunk {idx}: {chunk_folder_path}")
    return chunk_folder_path

def test_pretrained_model_on_sim_data_before(config, IR_config, stoi, itos, stoi_MF, itos_MF, chunk, idx):
    """Evaluate the current model on freshly simulated spectra for ``chunk`` before fine-tuning.

    Steps: write the chunk CSV (prepare_data), simulate spectra (generate_simulated_data),
    load model + validation loaders, run the analysis and performance tests, and save the
    metrics with save_results_before under index ``idx``.

    Returns:
        The updated ``config``.
    """
    # MW filtering / greedy mode are controlled by test_model_performance.
    print("prepare_data")
    config = prepare_data(config, chunk)
    print("generate_simulated_data")
    config = generate_simulated_data(config, IR_config)

    print("load_model_and_data")
    model_MMT, val_dataloader, val_dataloader_multi = load_model_and_data(config, stoi, stoi_MF)

    print("run_model_analysis")
    # The returned probability/result dicts are not used here; the call is kept as-is.
    mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)

    results = test_model_performance(config, model_MMT, val_dataloader, val_dataloader_multi,
                                     stoi, itos, stoi_MF, itos_MF)

    save_results_before(results, config, idx)

    return config

def prepare_data(config: Any, chunk: pd.DataFrame) -> Any:
    """Save ``chunk`` as ``<config.pkl_save_folder>/SGNN_csv_gen_smi.csv`` and point config at it.

    Sets ``config.SGNN_csv_gen_smi`` and ``config.data_size``; returns the same config object.
    NOTE(review): the DataFrame index is written too (to_csv default), so re-reading the CSV
    yields an extra 'Unnamed: 0' column - confirm downstream readers expect this.
    """
    target_csv = os.path.join(config.pkl_save_folder, "SGNN_csv_gen_smi.csv")
    chunk.to_csv(target_csv)
    config.SGNN_csv_gen_smi, config.data_size = target_csv, len(chunk)
    return config

def generate_simulated_data(config: Any, IR_config: Any) -> Any:
    """Simulate spectra for the current data via ``ex.gen_sim_aug_data`` and back up the paths.

    Sets ``config.execution_type = "data_generation"``, runs the simulation and then records
    the resulting spectrum CSV/IR paths with ``backup_config_paths``.

    Returns:
        The config returned by ``ex.gen_sim_aug_data``.
    """
    # The previous `if config.execution_type == "data_generation"` check directly followed
    # this assignment and was always true, so it is dropped.
    config.execution_type = "data_generation"
    print("\033[1m\033[31mThis is: data_generation\033[0m")
    config = ex.gen_sim_aug_data(config, IR_config)
    backup_config_paths(config)
    return config

def backup_config_paths(config: Any) -> None:
    """Store deep copies of the current SGNN spectrum CSV paths and IR folder as ``*_backup``.

    restore_backup_configs later copies these back.
    """
    backup_pairs = (
        ("csv_1H_path_SGNN", "csv_1H_path_SGNN_backup"),
        ("csv_13C_path_SGNN", "csv_13C_path_SGNN_backup"),
        ("csv_HSQC_path_SGNN", "csv_HSQC_path_SGNN_backup"),
        ("csv_COSY_path_SGNN", "csv_COSY_path_SGNN_backup"),
        ("IR_data_folder", "IR_data_folder_backup"),
    )
    for live_attr, backup_attr in backup_pairs:
        setattr(config, backup_attr, copy.deepcopy(getattr(config, live_attr)))

def save_results_before(results: Dict[str, Any], config: Any, idx: int) -> None:
    """Persist the pre-fine-tuning metrics via save_data_with_datetime_index ("before_sim_data").

    Saved keys drop the trailing underscore used in ``results``. The greedy/failed entries
    are optional (stored as None when absent); all others are required and raise KeyError.
    """
    # (saved name, key in results, required?) in the saved dict's key order.
    field_specs = (
        ('avg_tani_bl_ZINC', 'avg_tani_bl_ZINC_', True),
        ('results_dict_greedy_bl_ZINC', 'results_dict_greedy_bl_ZINC_', False),
        ('failed_bl_ZINC', 'failed_bl_ZINC_', False),
        ('avg_tani_greedy_bl_ZINC', 'avg_tani_greedy_bl_ZINC_', True),
        ('results_dict_ZINC_greedy_bl', 'results_dict_ZINC_greedy_bl_', False),
        ('total_results_bl_ZINC', 'total_results_bl_ZINC_', True),
        ('corr_sampleing_prob_bl_ZINC', 'corr_sampleing_prob_bl_ZINC_', True),
        ('results_dict_bl_ZINC', 'results_dict_bl_ZINC_', True),
    )
    variables_to_save = {
        saved_name: (results[source_key] if required else results.get(source_key))
        for saved_name, source_key, required in field_specs
    }
    save_data_with_datetime_index(variables_to_save, config.pkl_save_folder, "before_sim_data", idx)

def create_run_folder(chunk_folder, idx):
    """Create ``<chunk_folder>/run_<idx>_<YYYYmmdd_HHMMSS>`` (ok if it exists) and return its path."""
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_folder_path = os.path.join(chunk_folder, f"run_{idx}_{stamp}")
    os.makedirs(run_folder_path, exist_ok=True)
    print(f"Created folder for run {idx}: {run_folder_path}")
    return run_folder_path

def fine_tune_model_aug_mol(config, stoi, stoi_MF, chunk, idx):
    """Generate regio-isomer-based augmented molecules and fine-tune the model on them.

    Training writes into ``config.current_run_folder`` (which must already be set). Afterwards
    ``config.model_save_dir`` is restored, also when training raises.

    Returns:
        ``(config, aug_mol_df, all_gen_smis)``.
    """
    config, all_gen_smis, aug_mol_df = generate_augmented_molecules_from_aug_mol(config, chunk, idx)

    config.parent_model_save_dir = config.model_save_dir
    config.model_save_dir = config.current_run_folder
    try:
        if config.execution_type == "transformer_improvement":
            print("\033[1m\033[31mThis is: transformer_improvement, sim_data_gen == TRUE\033[0m")
            config.training_setup = "pretraining"
            mtf.run_MMT(config, stoi, stoi_MF)
    finally:
        # Without finally, a failed training run left model_save_dir pointing at the run folder.
        config.model_save_dir = config.parent_model_save_dir

    return config, aug_mol_df, all_gen_smis


def generate_augmented_molecules_from_aug_mol(config, chunk, idx):
    """Like generate_augmented_molecules, but seeded from the regio-isomer SMILES of ``chunk``.

    ``chunk`` is modified IN PLACE: 'SMILES' -> 'SMILES_orig' and
    'SMILES_regio_isomers' -> 'SMILES'. The rename is skipped when it was already applied,
    which is detected by 'SMILES_orig' being present and 'SMILES_regio_isomers' absent.
    This makes re-running the cell safe; previously a second call renamed the regio column
    to 'SMILES_orig' as well.

    Returns:
        ``(config, all_gen_smis, aug_mol_df)`` - same order as generate_augmented_molecules.

    Raises:
        KeyError: if ``chunk`` has neither 'SMILES_regio_isomers' nor an already-renamed
            'SMILES_orig' column.
    """
    if 'SMILES_regio_isomers' in chunk.columns:
        chunk.rename(columns={'SMILES': 'SMILES_orig', 'SMILES_regio_isomers': 'SMILES'}, inplace=True)
    elif 'SMILES_orig' not in chunk.columns:
        raise KeyError("chunk needs a 'SMILES_regio_isomers' column for augmentation")

    # The remaining pipeline was a verbatim copy of generate_augmented_molecules; delegate to it.
    return generate_augmented_molecules(config, chunk, idx)


def fine_tune_model(config, stoi, stoi_MF, chunk, idx):
    """
    Fine-tune the model on augmented molecules generated from a chunk of data.

    Training writes into the folder returned by ``create_model_save_dir(parent, idx)``. That
    helper is defined elsewhere in the notebook - NOTE(review): confirm it exists. Afterwards
    ``config.model_save_dir`` is restored, also when training raises.

    Returns:
        ``(config, aug_mol_df, all_gen_smis)`` - same order as fine_tune_model_aug_mol.
    """
    # generate_augmented_molecules returns (config, all_gen_smis, aug_mol_df); the previous
    # unpacking swapped the two, so callers received the list and the DataFrame reversed.
    config, all_gen_smis, aug_mol_df = generate_augmented_molecules(config, chunk, idx)

    config.parent_model_save_dir = config.model_save_dir
    config.model_save_dir = create_model_save_dir(config.parent_model_save_dir, idx)

    try:
        if config.execution_type == "transformer_improvement":
            print("\033[1m\033[31mThis is: transformer_improvement, sim_data_gen == TRUE\033[0m")
            config.training_setup = "pretraining"
            mtf.run_MMT(config, stoi, stoi_MF)
    finally:
        config.model_save_dir = config.parent_model_save_dir

    return config, aug_mol_df, all_gen_smis

def generate_augmented_molecules(config, chunk, idx, ir_config=None):
    """Generate new molecules from ``chunk``, blend them into the training data and simulate spectra.

    Steps:
    - write ``chunk`` to ``<cwd>/deep-molecular-optimization/data/MMP/test_selection_2.csv``;
    - sample SMILES (generate_smiles_mf), then clean/cluster them (process_generated_smiles);
    - filter them (filter_and_combine_smiles) and wrap them in a DataFrame with GT_ ids;
    - blend with training data (``ex.blend_aug_with_train_data``);
    - simulate spectra (``ex.gen_sim_aug_data``).

    Args:
        config: run configuration; updated and returned.
        chunk (pd.DataFrame): molecules to augment from.
        idx: chunk index (not used here; kept for interface parity).
        ir_config: IR config for ``ex.gen_sim_aug_data``. Defaults to the notebook-global
            ``IR_config`` (the previous hidden dependency), for backward compatibility.

    Returns:
        ``(config, all_gen_smis, aug_mol_df)`` - note this order.
    """
    if ir_config is None:
        ir_config = IR_config  # notebook global (hidden state) - prefer passing ir_config

    base_path = os.path.abspath(os.path.join(os.getcwd(), 'deep-molecular-optimization'))
    csv_file_path = f'{base_path}/data/MMP/test_selection_2.csv'
    chunk.to_csv(csv_file_path, index=False)
    print(f"CSV file '{csv_file_path}' created successfully.")

    config.data_size = len(chunk)
    config.n_samples = config.data_size

    config, results_dict_MF = generate_smiles_mf(config)
    combined_list_MF = process_generated_smiles(results_dict_MF, config)
    all_gen_smis = filter_and_combine_smiles(combined_list_MF)
    aug_mol_df = create_augmented_dataframe(all_gen_smis)

    # Only the updated config is needed; the blended training frame is not used here.
    config, _ = ex.blend_aug_with_train_data(config, aug_mol_df)

    config = ex.gen_sim_aug_data(config, ir_config)
    config.execution_type = "transformer_improvement"

    return config, all_gen_smis, aug_mol_df


def generate_smiles_mf(config):
    """Sample candidate SMILES via ``ex.SMI_generation_MF`` and return its result unchanged.

    Callers unpack the result as ``(config, results_dict_MF)``.
    NOTE(review): stoi, stoi_MF, itos and itos_MF are read from notebook globals rather than
    passed in, so they must be defined before this is called.
    """
    print("\033[1m\033[31mThis is: SMI_generation_MF\033[0m")
    return ex.SMI_generation_MF(config, stoi, stoi_MF, itos, itos_MF)

def process_generated_smiles(results_dict_MF, config):
    """Drop all-NaN entries, strip NaNs from the rest, and return the first output of cv.plot_cluster_MF.

    The caller's dict is not modified (new dicts are built).
    """
    non_empty = {key: value for key, value in results_dict_MF.items()
                 if not hf.contains_only_nan(value)}
    cleaned = {key: hf.remove_nan_from_list(value) for key, value in non_empty.items()}

    combined_list_MF, _unused_1, _unused_2, _unused_3 = cv.plot_cluster_MF(cleaned, config)
    return combined_list_MF

def filter_and_combine_smiles(combined_list_MF, config=None):
    """Drop 'NAN' placeholders and filter generated SMILES against the validation set.

    Args:
        combined_list_MF: generated SMILES strings.
        config: configuration providing ``csv_path_val``. Defaults to the notebook-global
            ``config``, which the function previously read implicitly (hidden state); pass it
            explicitly where possible.

    Returns:
        The result of ``mrtf.filter_smiles(val_data, smiles)``.

    Raises:
        NameError: if ``config`` is not passed and no global ``config`` exists.
    """
    print("\033[1m\033[31mThis is: combine_MMT_MF\033[0m")
    if config is None:
        try:
            config = globals()["config"]
        except KeyError:
            raise NameError("filter_and_combine_smiles needs `config` (argument or notebook global)") from None

    all_gen_smis = [smiles for smiles in combined_list_MF if smiles != 'NAN']

    val_data = pd.read_csv(config.csv_path_val)
    return mrtf.filter_smiles(val_data, all_gen_smis)

def create_augmented_dataframe(all_gen_smis):
    """Wrap generated SMILES in a DataFrame with sequential ids ``GT_0000001``, ``GT_0000002``, ..."""
    sample_ids = [f"GT_{number:07d}" for number in range(1, len(all_gen_smis) + 1)]
    frame_columns = {'SMILES': all_gen_smis, 'sample-id': sample_ids}
    return pd.DataFrame(frame_columns)

def setup_data_paths(config):
    """Set the ACD-simulated and experimental spectrum CSV paths and IR folders on ``config``.

    Sets ``csv_<spectrum>_path_ACD`` / ``csv_<spectrum>_path_exp`` for 1H, 13C, HSQC and COSY,
    plus ``IR_data_folder_ACD`` / ``IR_data_folder_exp``. Returns the same config object.
    NOTE: 'experimenal_data' is spelled as in the path; presumably that is the on-disk name.
    """
    acd_root = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/37_Richard_ACD_sim_data/"
    exp_root = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/36_Richard_43_dataset/experimenal_data/"
    spectra = ("1H", "13C", "HSQC", "COSY")

    for spectrum in spectra:
        setattr(config, f"csv_{spectrum}_path_ACD", f"{acd_root}ACD_{spectrum}_with_SN_filtered_v3.csv")
    config.IR_data_folder_ACD = f"{acd_root}IR_spectra"

    for spectrum in spectra:
        setattr(config, f"csv_{spectrum}_path_exp", f"{exp_root}real_{spectrum}_with_AZ_SMILES_v3.csv")
    config.IR_data_folder_exp = f"{exp_root}IR_data"

    return config

def test_model_on_datasets(config, IR_config, stoi, itos, stoi_MF, itos_MF, chunk, composite_idx, aug_mol_df, all_gen_smis):
    """Evaluate the latest fine-tuned checkpoint on experimental, simulated and ACD data.

    Calls test_on_data for 'exp', 'sim' and 'ACD' (in that order) with the full multimodal
    training_mode. ``config.checkpoint_path`` is restored afterwards, also if an evaluation
    raises (test_on_data overwrites it with the run's last checkpoint).

    Returns:
        The updated ``config``.
    """
    checkpoint_path_backup = config.checkpoint_path
    try:
        for data_type in ('exp', 'sim', 'ACD'):
            print(f"Testing on {data_type} data")
            config.pickle_file_path = ""
            config.training_mode = "1H_13C_HSQC_COSY_IR_MF_MW"
            config = test_on_data(config, IR_config, stoi, itos, stoi_MF, itos_MF, chunk,
                                  composite_idx, data_type, aug_mol_df, all_gen_smis)
    finally:
        config.checkpoint_path = checkpoint_path_backup
    return config

def test_on_data(config, IR_config, stoi, itos, stoi_MF, itos_MF, chunk, composite_idx, data_type, aug_mol_df, all_gen_smis):
    """Evaluate the newest checkpoint of ``config.current_run_folder`` on one data source.

    'sim' restores the backed-up simulated-data paths. Any other data_type ('exp', 'ACD') writes
    per-spectrum CSVs restricted to ``chunk['sample-id']``. For 'sim', the augmentation
    outputs are attached to the results before saving with save_results_acd_exp.
    ``IR_config`` is not used here.

    Returns:
        The same ``config`` (checkpoint_path now points at the last checkpoint).
    """
    if data_type != 'sim':
        process_spectrum_data(config, chunk['sample-id'].tolist(), data_type)
    else:
        restore_backup_configs(config)

    update_config_settings(config)
    config.checkpoint_path = get_last_checkpoint(config.current_run_folder)

    model_MMT, val_dataloader, val_dataloader_multi = load_model_and_data(config, stoi, stoi_MF)

    # Outputs unused here; the call is kept as in the original pipeline.
    mrtf.run_model_analysis(config, model_MMT, val_dataloader_multi, stoi, itos)

    results = test_model_performance(config, model_MMT, val_dataloader, val_dataloader_multi,
                                     stoi, itos, stoi_MF, itos_MF)

    if data_type == 'sim':
        results['aug_mol_df'] = aug_mol_df
        results['all_gen_smis'] = all_gen_smis

    save_results_acd_exp(results, config, data_type, composite_idx)
    return config

def restore_backup_configs(config):
    """Point the SGNN spectrum paths and IR folder back at their ``*_backup`` copies.

    Also sets csv_path_val to the backed-up 1H CSV and clears pickle_file_path.
    """
    restore_pairs = (
        ("csv_1H_path_SGNN", "csv_1H_path_SGNN_backup"),
        ("csv_13C_path_SGNN", "csv_13C_path_SGNN_backup"),
        ("csv_HSQC_path_SGNN", "csv_HSQC_path_SGNN_backup"),
        ("csv_COSY_path_SGNN", "csv_COSY_path_SGNN_backup"),
        ("IR_data_folder", "IR_data_folder_backup"),
        ("csv_path_val", "csv_1H_path_SGNN_backup"),
    )
    for target_attr, source_attr in restore_pairs:
        setattr(config, target_attr, getattr(config, source_attr))
    config.pickle_file_path = ""

def process_spectrum_data(config: Any, sample_ids: List[str], data_type: str) -> None:
    """Write per-spectrum CSVs restricted to ``sample_ids`` and point the SGNN paths at them.

    For 1H, 13C, HSQC and COSY, this:
    - reads ``config.csv_<spectrum>_path_<data_type>``;
    - copies 'AZ_Number' into 'sample-id' and keeps only the requested samples;
    - saves them via save_and_update_config (which sets ``csv_<spectrum>_path_SGNN``).

    The IR folder is then switched: 'ACD'/'sim' -> IR_data_folder_backup,
    'exp' -> IR_data_folder_exp. Any other data_type leaves it unchanged.
    """
    ir_folder_source = {
        "ACD": "IR_data_folder_backup",
        "sim": "IR_data_folder_backup",
        "exp": "IR_data_folder_exp",
    }

    for spectrum in ('1H', '13C', 'HSQC', 'COSY'):
        source_csv = getattr(config, f'csv_{spectrum}_path_{data_type}')
        spectrum_df = pd.read_csv(source_csv)
        spectrum_df['sample-id'] = spectrum_df['AZ_Number']
        subset = select_relevant_samples(spectrum_df, sample_ids)
        saved_path, config = save_and_update_config(config, data_type, spectrum, subset)
        print(f"Saved {spectrum} data to: {saved_path}")

    if data_type in ir_folder_source:
        config.IR_data_folder = getattr(config, ir_folder_source[data_type])

    
    
def select_relevant_samples(df: pd.DataFrame, sample_ids: List[str]) -> pd.DataFrame:
    """Return the rows of ``df`` whose 'sample-id' is in ``sample_ids`` (original order and index kept)."""
    keep_mask = df['sample-id'].isin(sample_ids)
    return df.loc[keep_mask]

def save_and_update_config(config, data_type: str, spectrum_type: str, data: pd.DataFrame) -> Tuple[str, Any]:
    """Save ``data`` to a fresh temp dir and point ``config.csv_<spectrum_type>_path_SGNN`` at it.

    The file is ``<tmp>/<data_type>_<spectrum_type>_selected_samples.csv``, written without the index.
    NOTE(review): every call creates a new temporary directory that is never removed.

    Returns:
        ``(csv_path, config)``.
    """
    csv_path = os.path.join(tempfile.mkdtemp(), f"{data_type}_{spectrum_type}_selected_samples.csv")
    data.to_csv(csv_path, index=False)
    setattr(config, f'csv_{spectrum_type}_path_SGNN', csv_path)
    return csv_path, config

def update_config_settings(config: Any) -> None:
    """Validate on the current 1H SGNN CSV and clear any cached pickle path."""
    config.pickle_file_path = ""
    config.csv_path_val = config.csv_1H_path_SGNN

def get_last_checkpoint(model_folder: str) -> str:
    """Return the path of the most recently modified ``*.ckpt`` file directly in ``model_folder``.

    Raises:
        ValueError: if the folder contains no ``.ckpt`` files.
    """
    candidates = [name for name in os.listdir(model_folder) if name.endswith('.ckpt')]
    if len(candidates) == 0:
        raise ValueError(f"No checkpoints found in {model_folder}")

    def _modified_time(name):
        return os.path.getmtime(os.path.join(model_folder, name))

    newest = max(candidates, key=_modified_time)
    return os.path.join(model_folder, newest)

def load_model_and_data(config: Any, stoi: Dict, stoi_MF: Dict) -> Tuple[Any, Any, Any]:
    """Build the single- and multi-sample validation loaders, then load the MMT model.

    Returns:
        ``(model_MMT, val_dataloader, val_dataloader_multi)``.
    """
    loaders = [mrtf.load_data(config, stoi, stoi_MF, single=is_single, mode="val")
               for is_single in (True, False)]
    model_MMT = mrtf.load_MMT_model(config)
    return model_MMT, loaders[0], loaders[1]

def test_model_performance(config: Any, model_MMT: Any, val_dataloader: Any, val_dataloader_multi: Any,
                           stoi: Dict, itos: Dict, stoi_MF: Dict, itos_MF: Dict,
                           MW_filter: bool = True, greedy_full: bool = False) -> Dict[str, Any]:
    """Evaluate *model_MMT* on the validation loaders and collect the summary metrics.

    Runs the ``mrtf`` evaluation helpers (``run_test_mns_performance_CLIP_3``, a greedy
    sampling pass and ``run_test_performance_CLIP_3``), prints the headline numbers
    and draws the Tanimoto histograms via ``rbgvm``.

    Args:
        config: experiment config, passed through to the ``mrtf`` helpers.
        model_MMT: model under test.
        val_dataloader: validation loader (built with ``single=True`` in load_model_and_data).
        val_dataloader_multi: validation loader used for greedy sampling
            (``single=False`` in load_model_and_data).
        stoi, itos, stoi_MF, itos_MF: token vocabularies.
        MW_filter: forwarded to ``mrtf.run_test_mns_performance_CLIP_3``
            (previously hard-coded to ``True``).
        greedy_full: if True use ``mrtf.run_test_performance_CLIP_greedy_3`` instead of
            ``mrtf.run_greedy_sampling`` (previously hard-coded to ``False``).

    Returns:
        Dict with ``results_dict_bl_ZINC_``, ``avg_tani_bl_ZINC_``,
        ``avg_tani_greedy_bl_ZINC_``, ``total_results_bl_ZINC_`` and
        ``corr_sampleing_prob_bl_ZINC_``, plus ``results_dict_ZINC_greedy_bl_``
        (default path) or ``results_dict_greedy_bl_ZINC_`` / ``failed_bl_ZINC_``
        (``greedy_full=True``).
    """
    print("\033[1m\033[31mThis is: test_performance\033[0m")

    model_CLIP = mrtf.load_CLIP_model(config)

    results = {}

    results['results_dict_bl_ZINC_'] = mrtf.run_test_mns_performance_CLIP_3(
        config, model_MMT, model_CLIP, val_dataloader, stoi, itos, MW_filter)
    # The second return value of filter_invalid_inputs is not used here.
    results['results_dict_bl_ZINC_'], _invalid_counter = mrtf.filter_invalid_inputs(results['results_dict_bl_ZINC_'])

    # plot_hist_of_results also returns an HTML plot that this function does not keep.
    results['avg_tani_bl_ZINC_'], _html_plot = rbgvm.plot_hist_of_results(results['results_dict_bl_ZINC_'])

    if greedy_full:
        results['results_dict_greedy_bl_ZINC_'], results['failed_bl_ZINC_'] = mrtf.run_test_performance_CLIP_greedy_3(
            config, stoi, stoi_MF, itos, itos_MF)
        results['avg_tani_greedy_bl_ZINC_'], _html_plot_greedy = rbgvm.plot_hist_of_results_greedy(
            results['results_dict_greedy_bl_ZINC_'])
    else:
        # run_greedy_sampling hands back a config object; it is used for the remaining
        # calls in this function only (the caller's reference is not rebound).
        config, results['results_dict_ZINC_greedy_bl_'] = mrtf.run_greedy_sampling(
            config, model_MMT, val_dataloader_multi, itos, stoi)
        results['avg_tani_greedy_bl_ZINC_'] = results['results_dict_ZINC_greedy_bl_']["tanimoto_mean"]

    results['total_results_bl_ZINC_'] = mrtf.run_test_performance_CLIP_3(
        config, model_MMT, val_dataloader, stoi)
    results['corr_sampleing_prob_bl_ZINC_'] = results['total_results_bl_ZINC_']["statistics_multiplication_avg"][0]

    print("avg_tani, avg_tani_greedy, corr_sampleing_prob'")
    print(results['avg_tani_bl_ZINC_'], results['avg_tani_greedy_bl_ZINC_'], results['corr_sampleing_prob_bl_ZINC_'])

    # Only the default greedy path produces 'results_dict_ZINC_greedy_bl_'; with
    # greedy_full=True the unconditional lookup used to raise KeyError (the greedy
    # histogram for that path is already drawn above).
    if 'results_dict_ZINC_greedy_bl_' in results:
        print("Greedy tanimoto results")
        rbgvm.plot_hist_of_results_greedy_new(results['results_dict_ZINC_greedy_bl_'])

    return results

def save_results_acd_exp(results: Dict[str, Any], config: Any, data_type: str, composite_idx: str) -> None:
    """Pickle the evaluation summary for one dataset under ``config.pkl_save_folder``.

    Keys of *results* carry a trailing underscore; they are stored without it. Entries
    flagged optional below may be absent (stored as ``None``); required ones raise
    ``KeyError`` if missing. For ``data_type == 'sim'`` the augmentation outputs
    (``aug_mol_df``, ``all_gen_smis``) are stored as well. The file name stem is
    ``"<data_type>_sim_data"`` with *composite_idx* as index.
    """
    # (stored key, key in results, required?) -- order defines the saved dict's key order.
    field_table = (
        ('avg_tani_bl_ZINC', 'avg_tani_bl_ZINC_', True),
        ('results_dict_greedy_bl_ZINC', 'results_dict_greedy_bl_ZINC_', False),
        ('failed_bl_ZINC', 'failed_bl_ZINC_', False),
        ('avg_tani_greedy_bl_ZINC', 'avg_tani_greedy_bl_ZINC_', True),
        ('results_dict_ZINC_greedy_bl', 'results_dict_ZINC_greedy_bl_', False),
        ('total_results_bl_ZINC', 'total_results_bl_ZINC_', True),
        ('corr_sampleing_prob_bl_ZINC', 'corr_sampleing_prob_bl_ZINC_', True),
        ('results_dict_bl_ZINC', 'results_dict_bl_ZINC_', True),
    )
    payload = {
        stored_key: (results[source_key] if required else results.get(source_key))
        for stored_key, source_key, required in field_table
    }
    payload['checkpoint_path'] = config.checkpoint_path

    if data_type == 'sim':
        payload['aug_mol_df'] = results.get('aug_mol_df')
        payload['all_gen_smis'] = results.get('all_gen_smis')

    save_data_with_datetime_index(payload, config.pkl_save_folder, f"{data_type}_sim_data", composite_idx)

def save_data_with_datetime_index(data: Any, base_folder: str, name: str, idx: Union[int, str]) -> None:
    """Pickle *data* to ``<base_folder>/<YYYYmmdd_HHMMSS>_<name>_<idx>.pkl``.

    *base_folder* is created if needed. The timestamp has one-second resolution, so
    two saves with the same *name*/*idx* within one second overwrite each other.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs(base_folder, exist_ok=True)
    file_path = os.path.join(base_folder, f"{stamp}_{name}_{idx}.pkl")

    with open(file_path, 'wb') as handle:
        pickle.dump(data, handle)

    print(f"Data saved to: {file_path}")

In [None]:

def main_IC(chunk_size, config, IR_config, stoi, itos, stoi_MF, itos_MF, num_training_runs=3):
    """Chunked evaluate / fine-tune / re-evaluate loop.

    For each chunk returned by ``split_dataset(config, chunk_size)``:

    * ``test_pretrained_model_on_sim_data_before`` runs once with
      ``blank_percentage = 0`` (result index ``"<chunk>_0"``);
    * then, ``num_training_runs`` times, ``fine_tune_model_aug_mol`` runs with
      ``blank_percentage = 50``, followed by ``setup_data_paths`` and
      ``test_model_on_datasets`` with ``blank_percentage = 0``
      (index ``"<chunk>_<run>"``). Each run uses its own folder under the chunk
      folder as ``model_save_dir``.

    ``config`` is mutated in place. ``model_save_dir`` is first set to
    ``config.pkl_save_folder``. It and ``checkpoint_path`` are restored in
    ``finally`` blocks, so a run that raises no longer leaves the config
    pointing at a run folder or at a fine-tuned checkpoint.

    Args:
        chunk_size: forwarded to ``split_dataset``.
        config: experiment config object (mutated in place).
        IR_config, stoi, itos, stoi_MF, itos_MF: forwarded to the test helpers.
        num_training_runs: number of fine-tune/evaluate repeats per chunk.
    """
    chunks = split_dataset(config, chunk_size)
    config.model_save_dir = config.pkl_save_folder
    model_save_dir_backup = config.model_save_dir
    original_checkpoint_path = config.checkpoint_path  # restored after every run

    for chunk_idx, chunk in enumerate(chunks):
        print(f"Processing chunk {chunk_idx+1} of {len(chunks)}")

        chunk_folder = create_chunk_folder(config, chunk_idx)
        config.current_chunk_folder = chunk_folder

        try:
            config.blank_percentage = 0
            config = test_pretrained_model_on_sim_data_before(config, IR_config, stoi, itos, stoi_MF, itos_MF, chunk, f"{chunk_idx}_{0}")
            print(config.csv_1H_path_SGNN)

            for run_idx in range(num_training_runs):
                print(f"Starting training run {run_idx+1} of {num_training_runs}")
                run_id = f"{chunk_idx}_{run_idx}"

                run_folder = create_run_folder(config.current_chunk_folder, run_id)
                config.current_run_folder = run_folder
                config.model_save_dir = run_folder

                try:
                    config.blank_percentage = 50
                    config, aug_mol_df, all_gen_smis = fine_tune_model_aug_mol(config, stoi, stoi_MF, chunk, run_id)

                    # NOTE: to evaluate against the original (non regio-isomer) labels, the
                    # SMILES columns could be swapped back here, e.g.
                    # chunk.rename(columns={'SMILES': 'SMILES_regio_isomers', 'SMILES_orig': 'SMILES'}, inplace=True)

                    config.blank_percentage = 0
                    config = setup_data_paths(config)
                    config = test_model_on_datasets(config, IR_config, stoi, itos, stoi_MF, itos_MF, chunk, run_id, aug_mol_df, all_gen_smis)
                finally:
                    # Always go back to the starting checkpoint (presumably so every run
                    # fine-tunes from the same weights), even if this run failed.
                    config.checkpoint_path = original_checkpoint_path

            print(f"Chunk {chunk_idx+1} completed. All training runs finished.")
        finally:
            config.model_save_dir = model_save_dir_backup


In [None]:
### RUN a1 -- fine-tune/evaluate loop, results written to test_on_aug_data_34_v3
# Alternatives kept for reference:
# config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8_MMT_MW2_Drop/MultimodalTransformer_time_1706856620.371885_Loss_0.202.ckpt"
# config.model_save_dir = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/small_8_3"  (main_IC sets model_save_dir from pkl_save_folder)
# config.SGNN_csv_gen_smi = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/37_Richard_ACD_sim_data/ACD_1H_with_SN_filtered_v3.csv"

# Settings applied to `config` in this order before the run.
run_a1_settings = {
    "checkpoint_path": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt",
    "pkl_save_folder": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Sim_ACD_Exp_aug/test_on_aug_data_34_v3",
    "SGNN_csv_gen_smi": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Sim_ACD_Exp_aug/ACD_1H_with_SN_filtered_v3_regio_aug.csv",
    "MF_generations": 30,          # earlier setting: 50
    "MF_delta_weight": 100,
    "max_scaffold_generations": 300,
    "blank_percentage": 50,        # main_IC switches this between 0 and 50 itself
    "weight_MW": 100,
    "lr_pretraining": 3e-4,
    "tr_te_split": 0.9,
    "batch_size": 64,
    "num_epochs": 30,
    "temperature": 1,
    "multinom_runs": 20,
    "train_data_blend": 0,
}
for _attr, _value in run_a1_settings.items():
    setattr(config, _attr, _value)

chunk_size = 1

main_IC(chunk_size, config, IR_config, stoi, itos, stoi_MF, itos_MF, num_training_runs=3)

In [None]:
### RUN a1 (v4 variant: 300 MF generations, 10 epochs) -- results written to test_on_aug_data_34_v4
# Alternatives kept for reference:
# config.checkpoint_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8_MMT_MW2_Drop/MultimodalTransformer_time_1706856620.371885_Loss_0.202.ckpt"
# config.model_save_dir = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/small_8_3"  (main_IC sets model_save_dir from pkl_save_folder)
# config.SGNN_csv_gen_smi = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/37_Richard_ACD_sim_data/ACD_1H_with_SN_filtered_v3.csv"

# Settings applied to `config` in this order before the run.
run_a1_v4_settings = {
    "checkpoint_path": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/1_old_models/models_v2/V8i_MMT_Drop4/MultimodalTransformer_time_1710027004.1571195_Loss_0.112.ckpt",
    "pkl_save_folder": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Sim_ACD_Exp_aug/test_on_aug_data_34_v4",
    "SGNN_csv_gen_smi": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Sim_ACD_Exp_aug/ACD_1H_with_SN_filtered_v3_regio_aug.csv",
    "MF_generations": 300,         # earlier setting: 50
    "MF_delta_weight": 100,
    "max_scaffold_generations": 300,
    "blank_percentage": 50,        # main_IC switches this between 0 and 50 itself
    "weight_MW": 100,
    "lr_pretraining": 3e-4,
    "tr_te_split": 0.9,
    "batch_size": 64,
    "num_epochs": 10,              # earlier setting: 30
    "temperature": 1,
    "multinom_runs": 20,
    "train_data_blend": 0,
}
for _attr, _value in run_a1_v4_settings.items():
    setattr(config, _attr, _value)

chunk_size = 1

main_IC(chunk_size, config, IR_config, stoi, itos, stoi_MF, itos_MF, num_training_runs=3)

In [None]:
def main_aug(
    folder_path: str = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Sim_ACD_Exp_aug/test_on_aug_data_34_v3/",
    figure_dir: str = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2",
    figure_tag: str = "FINAL",
):
    """Compute and plot top-k accuracy for every ranking method / data type in *folder_path*.

    For each ranking method ("COSY", "HSQC", "HSQC & COSY") and each result file type
    ("sim_sim_data", "ACD_sim_data", "exp_sim_data"):

    1. the rankings are loaded with ``process_pkl_files_new``;
    2. they are de-duplicated with ``exp_func.deduplicate_smiles_from_ranking``;
    3. they are filtered with ``filter_rankings_by_molecular_formula``;
    4. ``plot_top_k_accuracy`` is called with save path
       ``<figure_dir>/similar_starting_top_k_accuracy_MMT_<file_type>_<method>_<figure_tag>.png``.

    Args:
        folder_path: folder with the result pickles (default: the v3 run).
            An earlier alternative was
            ".../___FIGURES_PAPERS/Figures_Paper_2/precomputed_raw_data/20240724_IC_batched_MMT_d1/".
        figure_dir: folder for the figure files.
        figure_tag: last component of the figure filename. Change it when analysing
            another run folder so earlier figures are not overwritten.

    Returns:
        ``(all_rankings_list, filtered_out)``: flat lists with one entry per
        (ranking method, file type) combination, in loop order.
    """
    model_type = "MMT"
    ranking_methods = ["COSY", "HSQC", "HSQC & COSY"]
    file_types = ["sim_sim_data", "ACD_sim_data", "exp_sim_data"]

    data_type_map = {
        "sim_sim_data": "Our Simulated",
        "ACD_sim_data": "ACD Simulated",
        "exp_sim_data": "Experimental"
    }
    filtered_out = []
    all_rankings_list = []
    for ranking_method in ranking_methods:
        for file_type in file_types:
            all_rankings = process_pkl_files_new(folder_path, file_type, ranking_method)
            # The list of removed duplicate SMILES is not used further.
            all_rankings, _removed_smiles = exp_func.deduplicate_smiles_from_ranking(all_rankings)
            all_rankings, filtered_out_rankings = filter_rankings_by_molecular_formula(all_rankings)
            filtered_out.append(filtered_out_rankings)
            all_rankings_list.append(all_rankings)

            accuracies = calculate_top_k_accuracy(all_rankings)
            sim_rank_one_count = count_molecules_with_sim_rank_one(all_rankings)

            data_type = data_type_map[file_type]
            save_path = os.path.join(
                figure_dir,
                f"similar_starting_top_k_accuracy_{model_type}_{file_type}_{ranking_method}_{figure_tag}.png",
            )

            plot_top_k_accuracy(accuracies, data_type, model_type, ranking_method, sim_rank_one_count, save_path, total_samples=len(all_rankings))

            print(f"Completed plot for {ranking_method} - {data_type}")
    return all_rankings_list, filtered_out



In [None]:
# Build the top-k accuracy plots and keep, per (ranking method, data type), the
# filtered rankings and the entries removed by the molecular-formula filter.
aug_outputs = main_aug()
all_rankings_list, filtered_out = aug_outputs

In [None]:
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw
from IPython.display import display

# Helper: exact molecular weight of a SMILES string (used for grid legends).
def calculate_weight(smiles):
    """Return the exact (monoisotopic) molecular weight of *smiles*.

    Returns ``float('nan')`` if *smiles* is not a string (e.g. a pandas NaN for a
    missing cell) or cannot be parsed by RDKit. Previously these inputs crashed
    inside ``ExactMolWt(None)``; now the legend simply shows ``nan``.
    """
    # ``rdkit.Chem.Descriptors`` is a submodule: ``from rdkit import Chem`` alone does
    # not guarantee ``Chem.Descriptors`` is loaded, so import it explicitly.
    from rdkit.Chem import Descriptors

    if not isinstance(smiles, str):
        return float('nan')
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return float('nan')
    return Descriptors.ExactMolWt(mol)

def plot_smiles_from_df(data_sorted, column_name):
    """Display the SMILES in *column_name* as RDKit grid images of up to 9 molecules each.

    Each legend reads ``<column_name>_<row position>_<calculate_weight value, 2 dp>``.
    """
    smiles_values = list(data_sorted[column_name])
    grid_size = 9
    for start in range(0, len(smiles_values), grid_size):
        batch_mols = []
        batch_legends = []
        for position, smi in enumerate(smiles_values[start:start + grid_size], start=start):
            mass = calculate_weight(smi)
            batch_legends.append(f"{column_name}_{position}_{mass:.2f}")
            batch_mols.append(Chem.MolFromSmiles(smi))
        display(Draw.MolsToGridImage(batch_mols, subImgSize=(250, 250), legends=batch_legends))

# Regio-isomer augmented ACD 1H table.
df = pd.read_csv("/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Sim_ACD_Exp_aug/ACD_1H_with_SN_filtered_v3_regio_aug.csv")

# Draw each SMILES column on its own: original structures first, then the regio-isomers.
for section_title, smiles_column in (
    ("Original SMILES:", 'SMILES'),
    ("\nRegio Isomer SMILES:", 'SMILES_regio_isomers'),
):
    print(section_title)
    plot_smiles_from_df(df, smiles_column)

# Draw two SMILES columns next to each other (left: column1, right: column2).
def plot_smiles_side_by_side(data_sorted, column1, column2):
    """Display row-aligned molecules from *column1* and *column2* in two-column grids.

    Rows are drawn in batches of four pairs (8 molecules per image). Legends are
    ``<column>_<row position>_<calculate_weight value, 2 dp>``.
    """
    pairs = list(zip(data_sorted[column1], data_sorted[column2]))
    pairs_per_grid = 4
    for start in range(0, len(pairs), pairs_per_grid):
        grid_mols = []
        grid_legends = []
        for position, (left_smi, right_smi) in enumerate(pairs[start:start + pairs_per_grid], start=start):
            left_weight = calculate_weight(left_smi)
            right_weight = calculate_weight(right_smi)
            grid_legends.extend([
                f"{column1}_{position}_{left_weight:.2f}",
                f"{column2}_{position}_{right_weight:.2f}",
            ])
            grid_mols.extend([Chem.MolFromSmiles(left_smi), Chem.MolFromSmiles(right_smi)])
        display(Draw.MolsToGridImage(grid_mols, molsPerRow=2, subImgSize=(250, 250), legends=grid_legends))

print("\nSide-by-side comparison:")
plot_smiles_side_by_side(df, column1='SMILES', column2='SMILES_regio_isomers')

In [None]:
# Peek at the first element of the first value in the saved results dict.
# NOTE(review): `results_dict_bl_ZINC` appears to be assigned only in a later cell
# (`results_dict_bl_ZINC = df['results_dict_bl_ZINC']`), so this cell likely fails
# on a fresh kernel -- run it after that cell.
first_entry = list(results_dict_bl_ZINC.values())[0]
first_entry[0]

In [None]:
# NOTE(review): needs `df` to be a loaded results pickle (assigned in the next cell).
# Directly after the CSV cell above, `df` is the CSV DataFrame, presumably without this key.
results_dict_bl_ZINC = df['results_dict_bl_ZINC']
first_candidates = list(results_dict_bl_ZINC.values())[0][0]
smiles_list = [candidate[0] for candidate in first_candidates]
smiles_list

In [None]:
# Saved evaluation results ("sim_sim_data", index 0_0) from the v1 augmentation run.
# Despite the name, `df` holds whatever object was pickled (presumably the dict
# written by save_results_acd_exp).
file_path = (
    "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Sim_ACD_Exp_aug/"
    "test_on_aug_data_34_v1/20240925_144233_sim_sim_data_0_0.pkl"
)
df = pd.read_pickle(file_path)

In [None]:
# Single candidate structure; the Mol object is the cell's output.
candidate_mol = Chem.MolFromSmiles('COc1ccc(CC(C)NCC(O)c2ccc(O)c(NC=O)c2)cc1')
candidate_mol

In [None]:
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw
from IPython.display import display

def plot_smiles_grid(smiles_list, molsPerRow=5, subImgSize=(200, 200)):
    """
    Display *smiles_list* as a single RDKit grid image labelled "Mol 1", "Mol 2", ...

    :param smiles_list: SMILES strings to draw (unparseable ones become None and are passed to RDKit as-is)
    :param molsPerRow: molecules per grid row (default: 5)
    :param subImgSize: pixel size of each molecule cell (default: (200, 200))
    """
    mols = []
    legends = []
    for number, smi in enumerate(smiles_list, start=1):
        mols.append(Chem.MolFromSmiles(smi))
        legends.append(f"Mol {number}")

    display(Draw.MolsToGridImage(mols, molsPerRow=molsPerRow, subImgSize=subImgSize, legends=legends))

# Saved evaluation results ("exp_sim_data", index 0_0) from the v1 augmentation run.
# `df` holds the unpickled object (presumably the dict written by save_results_acd_exp).
file_path = (
    "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Sim_ACD_Exp_aug/"
    "test_on_aug_data_34_v1/20240925_144730_exp_sim_data_0_0.pkl"
)
df = pd.read_pickle(file_path)

# First element of every item in the first entry's first value -> candidate SMILES.
results_dict_bl_ZINC = df['results_dict_bl_ZINC']
first_candidates = list(results_dict_bl_ZINC.values())[0][0]
smiles_list = [candidate[0] for candidate in first_candidates]

print(f"Number of SMILES: {len(smiles_list)}")

# Quick sanity check of the extracted strings.
print("\nFirst few SMILES:")
for smiles in smiles_list[:5]:
    print(smiles)

print("\nPlotting SMILES grid:")
plot_smiles_grid(smiles_list)

# Layout can be tuned, e.g.:
# plot_smiles_grid(smiles_list, molsPerRow=4, subImgSize=(150, 150))

In [None]:
# Candidate SMILES drawn by the following cell with plot_smiles_grid.
l = [
    "COc1ccccc1CC(C)NCC(O)c1ccc(N)c(NC=O)c1",
    "COc1ccccc1CC(CO)NCC(O)C(O)Cc1ccc(O)cc1",
    "CCN(CC)c1ccc(CC(C)NCC(O)c2ccc(O)c(NC=O)c2)cc1",
    "C=CNc1cc(O)ccc1C(O)CNC(C)Cc1ccccc1OC",
    "O=CNc1cccc(C=O)c1C(=O)C(O)CNC(CO)Cc1ccccc1O",
    "COc1ccc(CC(C)NCC(O)c2ccc(O)c(NC=O)c2)c(OC)c1",
    "COc1ccccc1CC(C=O)NCC(O)C(O)C1CCOc2ccc1cc2NC=O",
    "CCOc1ccccc1CC(C)NCC(O)c1ccc(O)c(NC=O)c1",
    "COc1ccccc1CC(C)NCC(OC)c1ccc(O)c(NC=O)c1",
    "CCOc1cc(C(C)NCC(O)c2ccc(O)c(NC=O)c2)ccc1OC",
    "COc1ccccc1CC(C)NCC(O)c1ccc(OC(C)=O)c(O)c1",
    "COc1ccccc1CC(C)NCC(O)c1ccc(NC=O)c(O)c1",
    "O=CNc1cc(C(O)CNc2cc(C=O)ccc2Cl)ccc1O",
    "O=CNc1cc2ccc(O)c(NC=O)c2cc1CCc1ccccc1O",
    "COc1ccccc1CC(C)NCC(O)c1cc(O)c(NC=O)cc1C",
    "CCN(CC(O)c1ccc(O)c(NC=O)c1)C(C)Cc1ccccc1OC",
    "CC(Cc1ccccc1OCC(C)(C)C)NCC(O)c1ccc(O)c(NC=O)c1",
    "COc1ccccc1CC(O)C(C)NCC(O)c1ccc(O)c(NC=O)c1",
    "COc1cc(C(C)NCC(O)c2ccc(O)c(NC=O)c2)ccc1O",
    "COc1ccc(CC(C)NCC(O)c2ccc(NC=O)c(O)c2)cc1F",
]

In [None]:
# Draw the candidate list `l` (five molecules per row by default).
plot_smiles_grid(smiles_list=l)


### Plot Guess Molecule in SVG

In [None]:
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
import os

# Guess molecule, exported as a standalone SVG figure.
smiles = "OC(CNC(C)Cc1ccc(OC)cc1)c1cc(ccc1O)NC=O"
mol = Chem.MolFromSmiles(smiles)
AllChem.Compute2DCoords(mol)  # adds 2D depiction coordinates to `mol` in place

# Render onto a 300x300 SVG canvas and take the finished SVG markup as a string.
d = Draw.MolDraw2DSVG(300, 300)
d.DrawMolecule(mol)
d.FinishDrawing()
svg = d.GetDrawingText()

# Write next to the other Paper-2 figures, creating the folder if needed.
save_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/___FIGURES_PAPERS/Figures_Paper_2/guess_molecule.svg"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, 'w') as f:
    f.write(svg)

print(f"SVG has been saved to: {save_path}")

### Further Analysis

In [None]:
import os
import pickle

# Folder with the ZINC baseline result pickles.
directory = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Experiment_baseline_PC_ZINC/ZINC_250_350_4000'

# Sorted so that "first" deterministically means the alphabetically first .pkl file.
pkl_files = sorted(name for name in os.listdir(directory) if name.endswith('.pkl'))
first_file_path = os.path.join(directory, pkl_files[0])

# NOTE: pickle.load can execute arbitrary code -- only load trusted result files.
with open(first_file_path, 'rb') as f:
    data = pickle.load(f)

print(f"Loaded file: {pkl_files[0]}")

# You can now examine the structure of 'data'