In [1]:
import torch
import pandas as pd
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
from matplotlib import colors
from tqdm import tqdm
import os
from sklearn import metrics
from torch.optim.lr_scheduler import LambdaLR
import shap
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
import time
import random
import json
import gc
from efficient_kan import KAN
import collections.abc as c

# Apply seaborn's "darkgrid" theme to all figures drawn in this notebook.
sns.set_style(style="darkgrid")

In [2]:

# Provisional CPU device; a later cell re-selects CUDA when it is available.
device = torch.device("cpu")

In [3]:
# SMARTS pattern for each functional group. The insertion order defines the
# position of each group in ``func_group_names``, which is used as the label index.
smarts = {
    'alkane': '[CX4;H0,H1,H2,H4]',
    'methyl': '[CH3]',
    'alkene': '[CX3]=[CX3]',
    'alkyne': '[CX2]#C',
    'alcohols': '[#6][OX2H]',
    'amines': '[NX3;H2,H1;!$(NC=O)]',
    'nitriles': '[NX1]#[CX2]',
    'aromatics': '[$([cX3](:*):*),$([cX2+](:*):*)]',
    'alkyl halides': '[#6][F,Cl,Br,I]',
    'esters': '[#6][CX3](=O)[OX2H0][#6]',
    'ketones': '[#6][CX3](=O)[#6]',
    'aldehydes': '[CX3H1](=O)[#6]',
    'carboxylic acids': '[CX3](=O)[OX2H1]',
    'ether': '[OD2]([#6])[#6]',
    'acyl halides': '[CX3](=[OX1])[F,Cl,Br,I]',
    'amides': '[NX3][CX3](=[OX1])[#6]',
    'nitro': '[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]',
    'heterocyclic': '[!#6;!R0]',
    'aryl chlorides': '[Cl][c]',
    'carboxylic esters': '[CX3;$([R0][#6]),$([H1R0])](=[OX1])[OX2][#6;!$(C=[O,N,S])]',
    'alkyl aryl ethers': '[OX2](c)[CX4;!$(C([OX2])[O,S,#7,#15,F,Cl,Br,I])]',
    'phenols': '[OX2H][c]',
}

# Group names in label-index order (0 = 'alkane', 1 = 'methyl', ...).
func_group_names = pd.Series(list(smarts))
func_group_names

0                alkane
1                methyl
2                alkene
3                alkyne
4              alcohols
5                amines
6              nitriles
7             aromatics
8         alkyl halides
9                esters
10              ketones
11            aldehydes
12     carboxylic acids
13                ether
14         acyl halides
15               amides
16                nitro
17         heterocyclic
18       aryl chlorides
19    carboxylic esters
20    alkyl aryl ethers
21              phenols
dtype: object

## Paths to directories

In [4]:
# Input/output locations. Each one can be overridden with an environment
# variable, so the notebook can run on another machine without editing this
# cell. The defaults are the original author's Windows paths.
preprocDirPath = os.environ.get(
    'FTIR_PREPROC_DIR', 'D:\\SHAPE\\1VII\\functional-cnn-main\\ALL\\preprocessed_data')
funcGroupDirPath = os.environ.get(
    'FTIR_FUNC_GROUP_DIR', 'D:\\SHAPE\\1VII\\functional-cnn-main\\ALL\\functional_groups')
save_dir = os.path.join(
    os.environ.get('FTIR_SAVE_ROOT', 'C:\\Users\\tomek_\\Desktop\\saved_computations'),
    "same_func_group_predictionsTALLKAN1")
model_selection = os.environ.get('FTIR_MODEL_SELECTION', "nowy_KAN25VI")
# Previously written as "C:\\Users\Tomek_\\..." -- "\T" is an invalid escape
# sequence (SyntaxWarning on Python 3.12+). The same path is now spelled with
# properly escaped backslashes.
trainedModelDirPath = os.environ.get(
    'FTIR_TRAINED_MODEL_DIR', "C:\\Users\\Tomek_\\Desktop\\trained_models\\")

trained_models_path = os.path.join(trainedModelDirPath, model_selection)

In [5]:
class FunctionalGroupsDataset(Dataset):
    '''
    PyTorch compatible dataset of functional groups files.

    The dataset is class-balanced:
    - Positives are the file names listed in ``funcGroupDirPath/<func_group>``.
    - Negatives are drawn from the preprocessed files that are *not* listed for
      that group. Equal numbers of each are kept.

    Spectra and labels are always read from ``preprocDirPath``.

    Parameters
    ----------

    func_group : str
        Name of the functional group to retrieve data from. It must appear in
        ``func_group_names`` and have a sub-directory in ``funcGroupDirPath``.

    convert_to : str, default None
        Converts all data to specified data type. Accepted values
        (case-insensitive) are 'absorbance'/'absorbancja' and
        'transmittance'/'transmitancja'. Any other value raises ``ValueError``
        when an item is accessed. ``None`` returns the stored values unchanged.

    random_state : int, default None
        Seed for the class-balancing subsampling. ``None`` uses numpy's global
        random state (the original behaviour).

    Raises
    ------
    ValueError
        If ``func_group`` is not a known group name, or it has no directory in
        ``funcGroupDirPath``.
    '''
    # Number of points in every stored spectrum.
    SPECTRUM_LENGTH = 3106
    # Accepted spellings of ``convert_to`` (English / Polish).
    _ABSORBANCE_TARGETS = ('absorbance', 'absorbancja')
    _TRANSMITTANCE_TARGETS = ('transmittance', 'transmitancja')
    # Lower-cased 'spectraType' values that are already absorbance.
    _ABSORBANCE_TYPES = ('converted_to_absorbance', 'absorbance')

    def __init__(self, func_group: str, convert_to: str = None, random_state: int = None) -> None:
        self.convert_to = convert_to
        matches = np.where(func_group_names.values == func_group)[0]
        if len(matches) == 0:
            # Previously this surfaced as a bare IndexError from ``[0][0]``.
            raise ValueError(f'{func_group} is not a known functional group name.')
        # Row of the per-file 'funcGroups' column that holds this group's label.
        self.func_group_number = matches[0]
        self.main_dir = os.path.join('..', 'ALL')
        self.func_group = func_group
        rng = np.random.RandomState(random_state) if random_state is not None else None

        # NIST IDs (file names) that passed preprocessing
        preprocessed_data_dir = pd.DataFrame(os.listdir(preprocDirPath))

        if func_group not in os.listdir(funcGroupDirPath):
            raise ValueError(f'{func_group} is not present in our database.')
        # All NIST IDs listed for this functional group (positives)
        func_group_data_dir = pd.DataFrame(os.listdir(os.path.join(funcGroupDirPath, func_group)))

        # Preprocessed NIST IDs that are NOT listed for this group: the
        # "left_only" rows of the outer merge. Negatives are drawn from this pool.
        to_sample = pd.merge(preprocessed_data_dir, func_group_data_dir, on=[0, 0], how='outer', indicator=True).query('_merge=="left_only"')[0]

        # Balance the classes: keep as many negatives as positives (or vice versa).
        if len(to_sample) < len(func_group_data_dir):
            func_group_data_dir = func_group_data_dir.sample(len(to_sample), random_state=rng)
            self.data = pd.concat([to_sample.sample(len(to_sample), random_state=rng), func_group_data_dir], axis=0)
        else:
            self.data = pd.concat([to_sample.sample(len(func_group_data_dir), random_state=rng), func_group_data_dir], axis=0)

    def __len__(self):
        return len(self.data)

    def _read_file(self, index) -> pd.DataFrame:
        '''Read the preprocessed CSV of the ``index``-th NIST ID.'''
        return pd.read_csv(os.path.join(preprocDirPath, self.data.iloc[index, 0]))

    def _label_from_file(self, file: pd.DataFrame) -> torch.Tensor:
        '''Return this group's 0-d float label from a preprocessed CSV.'''
        # nan_to_num guards against NaN labels seen in some files.
        # NOTE(review): requires_grad=True on data/labels is unnecessary and
        # fails for integer labels; kept to preserve the existing behaviour.
        return torch.nan_to_num(torch.tensor(file['funcGroups'].values[self.func_group_number], requires_grad=True)).to(torch.float)

    def __getitem__(self, index):
        '''Return ``(spectrum of shape (1, 1, 3106), 0-d float label)``.'''
        file = self._read_file(index)
        spectra_type = str(file['spectraType'][0]).lower()
        values = file['y'].values

        if self.convert_to:
            target = self.convert_to.lower()
            if target in self._ABSORBANCE_TARGETS:
                needs_flip = spectra_type not in self._ABSORBANCE_TYPES
            elif target in self._TRANSMITTANCE_TARGETS:
                # Previously 'transmittance' behaved like 'absorbance'.
                needs_flip = spectra_type in self._ABSORBANCE_TYPES
            else:
                raise ValueError(f'Cant convert to {self.convert_to}.')
            if needs_flip:
                # The notebook's conversion between the two representations.
                values = np.abs(1 - values)

        spectra = torch.nan_to_num(torch.tensor(values, requires_grad=True)).to(torch.float)
        spectra = spectra.reshape(1, 1, self.SPECTRUM_LENGTH)
        return spectra, self._label_from_file(file)

    def get_nist_id(self, index):
        '''Return the file name (NIST ID) of the ``index``-th sample.'''
        return self.data.iloc[index, 0]

    def get_func_groups(self, index):
        '''Return only the label of the ``index``-th sample.'''
        return self._label_from_file(self._read_file(index))

    def sample(self, n):
        '''
        Draw ``n`` distinct random samples.

        Returns ``(spectra of shape (n, 1, 1, 3106), labels of shape (n,))``.
        '''
        indexes = random.sample(range(len(self.data)), n)
        # Load every item once; previously each file was read twice.
        items = [self[idx] for idx in indexes]
        spectra = torch.cat([spectrum for spectrum, _ in items]).reshape(n, 1, 1, self.SPECTRUM_LENGTH)
        labels = torch.tensor([label for _, label in items])
        return spectra, labels

In [6]:
# Series of preprocessed file names, named 0 so it can be merged on column 0.
preprocessed_datadir = pd.Series(os.listdir(preprocDirPath)).rename(0)

In [7]:
# For each functional group: a Series (named 0) of the file names listed in its directory.
functional_groups_datadirs = {
    group_name: pd.Series(os.listdir(os.path.join(funcGroupDirPath, group_name)), name=0)
    for group_name in func_group_names
}

In [8]:
# Preprocessed files NOT listed for 'alkane' (the "left_only" side of an outer merge).
to_sample = (
    pd.merge(
        preprocessed_datadir,
        functional_groups_datadirs['alkane'],
        on=[0, 0],
        how='outer',
        indicator=True,
    )
    .query('_merge=="left_only"')[0]
)

In [9]:
# Counts of spectrum types ('spectraType' column values), filled below.
spectra_types = dict()

In [10]:
# Count how many preprocessed spectra are stored as each spectrum type.
# Only the first data row's 'spectraType' is used, so read just that row
# (nrows=1) instead of parsing every full CSV.
for file in preprocessed_datadir:
    spectrum_type = pd.read_csv(os.path.join(preprocDirPath, file), nrows=1)['spectraType'][0]
    spectra_types[spectrum_type] = spectra_types.get(spectrum_type, 0) + 1

In [11]:
# Display the spectrum-type counts.
spectra_types

{'TRANSMITTANCE': 2955, 'ABSORBANCE': 5834}

In [None]:
# For each functional group, count the spectrum types of the files in that
# group's directory. The results are stored in ``spectra_types[<group name>]``
# as nested dicts, alongside the global counts.
# NOTE(review): mixing int counts and per-group dicts in one dict is fragile,
# and re-running this cell double-counts. It is kept for compatibility.
for func_group in func_group_names:
    group_files = functional_groups_datadirs[func_group]
    if len(group_files) == 0:
        # Groups without files get no entry, as before.
        continue
    group_counts = spectra_types.setdefault(func_group, {})
    for file in group_files:
        # Only row 0 of 'spectraType' is used, so nrows=1 avoids parsing the whole file.
        spectrum_type = pd.read_csv(os.path.join(funcGroupDirPath, func_group, file), nrows=1)['spectraType'][0]
        group_counts[spectrum_type] = group_counts.get(spectrum_type, 0) + 1
            
      

In [None]:
# Builds per-functional-group datasets and dataloaders. Results are looked up
# first by functional group name, then by split ('training' or 'test').

def createHashmaps(test_ratio: float = 0.3,
                    batch_size: int = 128,
                    skip: tuple = ('carboxylic esters',),
                    seed: int = None):

    '''
    Creates hashmaps containing FTIR data.

    Parameters
    ----------

    test_ratio : float, default 0.3
        Ratio of test dataset.

    batch_size : int, default 128
        Size of the batch.

    skip : tuple of str, default ('carboxylic esters',)
        Functional-group directories to leave out.

    seed : int, default None
        Seed for the train/test split generator. ``None`` keeps the default
        seed of a freshly constructed ``torch.Generator`` (the original behaviour).

    Returns
    -------
    (func_groups_data, func_groups_datasets, func_groups_dataloaders)
        ``func_groups_data`` is always empty and is kept only for backward
        compatibility. The other two map a group name to
        ``{'training': ..., 'test': ...}`` (Subsets and DataLoaders respectively).
        The training loader shuffles; the test loader does not.
    '''

    func_groups_data, func_groups_datasets, func_groups_dataloaders = {}, {}, {}
    for data_directory in os.listdir(funcGroupDirPath):
        print(data_directory)
        if data_directory in skip:
            continue
        dataset = FunctionalGroupsDataset(data_directory, convert_to='absorbance')

        generator = torch.Generator()
        if seed is not None:
            generator.manual_seed(seed)
        training_dataset, test_dataset = random_split(dataset, [1 - test_ratio, test_ratio], generator)

        func_groups_datasets[data_directory] = {'training': training_dataset, 'test': test_dataset}

        func_groups_dataloaders[data_directory] = {
            'training': DataLoader(training_dataset, batch_size=batch_size, shuffle=True),
            'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False),
        }

    return func_groups_data, func_groups_datasets, func_groups_dataloaders

In [None]:
# Build every group's datasets/dataloaders with a 70/30 split and batches of 128.
(func_groups_data,
 func_groups_datasets,
 func_groups_dataloaders) = createHashmaps(test_ratio=0.3, batch_size=128)

In [None]:
# Prefer the first CUDA GPU when one is available; otherwise use the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Show the device selected above.
device

In [None]:
class ChemicyBad(torch.nn.Module):
    '''
    1-D CNN feature extractor followed by a KAN classification head.

    Architecture:
    - Three stages of Conv1d(kernel 5, stride 1, 10 channels) -> tanh ->
      MaxPool1d(3, stride 3).
    - The output is flattened and fed to ``KAN([n_features] + kan_layers)``.
    - A sigmoid is applied to the KAN output.

    Parameters
    ----------
    kan_layers : list
        Widths of the KAN layers after the flattened CNN features. The last
        entry is the number of outputs.
    input_length : int, default 3106
        Number of spectrum points, used to size the KAN input. 3106 gives the
        1130 features the pretrained models were built with.
    '''

    def __init__(self, kan_layers: list, input_length: int = 3106) -> None:
        super().__init__()

        # Summary metrics filled in by ``train``.
        self.tr_accuracy = None
        self.te_accuracy = None
        self.tr_loss = None
        self.te_loss = None

        kernel_size_1 = 5
        stride_conv_1 = 1

        stride_pool_1 = 3
        filter_size_1 = 3

        self.conv_1 = torch.nn.Conv1d(1, 10, kernel_size=kernel_size_1, stride=stride_conv_1)
        self.pool_1 = torch.nn.MaxPool1d(filter_size_1, stride_pool_1)
        self.conv_2 = torch.nn.Conv1d(10, 10, kernel_size=kernel_size_1, stride=stride_conv_1)
        self.pool_2 = torch.nn.MaxPool1d(filter_size_1, stride_pool_1)
        self.conv_3 = torch.nn.Conv1d(10, 10, kernel_size=kernel_size_1, stride=stride_conv_1)
        self.pool_3 = torch.nn.MaxPool1d(filter_size_1, stride_pool_1)

        # Derive the flattened feature count (channels x pooled length) from a
        # dummy pass instead of hard-coding 1130 (the value for 3106 points).
        with torch.no_grad():
            n_features = self._features(torch.zeros(1, 1, input_length)).numel()

        # No ``.to(device)`` here: callers move the whole model at once.
        # Moving only the head left the conv layers on a different device.
        self.kan = KAN([n_features] + kan_layers)

    def _features(self, x):
        '''Apply the three conv -> tanh -> pool stages to a (batch, 1, length) tensor.'''
        out = self.pool_1(torch.tanh(self.conv_1(x)))
        out = self.pool_2(torch.tanh(self.conv_2(out)))
        return self.pool_3(torch.tanh(self.conv_3(out)))

    def forward(self, x):
        '''
        Classify a batch of spectra.

        ``x`` has shape (batch, channels, h, w); h and w are merged into one
        spectrum axis (the notebook feeds (batch, 1, 1, 3106)). Returns sigmoid
        outputs with the singleton axis 1 squeezed -- presumably
        (batch, kan_layers[-1]) if KAN keeps the leading dimensions (TODO confirm).
        '''
        bs, c, h, w = x.shape
        x = x.reshape(bs, c, h * w)
        out = self._features(x)
        out = out.reshape(x.shape[0], 1, out.shape[2] * out.shape[1])
        out = self.kan(out)
        return torch.sigmoid(out).squeeze(1)

    
def _run_epoch(model, loader, lossFn, epoch, disable_verbose, optimizer=None):
    '''
    Run one pass over ``loader``. Trains when ``optimizer`` is given; otherwise only evaluates.

    Returns ``(mean batch loss, recall)``:
    - recall = correctly predicted positives / positives seen (the progress
      bar labels it "accuracy").
    - Either value is NaN when undefined (empty loader, no positives).

    The f1 shown in the progress bar is computed over the batches seen so far.
    '''
    total_loss = 0.
    true_pos = 0.
    positives = 0.
    recall = float('nan')
    seen_targets, seen_preds = [], []
    with tqdm(loader, disable=disable_verbose) as bar:
        for y_batch, func_group_batch in bar:
            y_batch, func_group_batch = y_batch.to(device), func_group_batch.to(device)

            if optimizer is not None:
                optimizer.zero_grad()
            # squeeze(-1) instead of squeeze(): a final batch of one sample
            # must keep its batch axis, or the loss shapes no longer match.
            output = model(y_batch).squeeze(-1)
            loss = lossFn(output, func_group_batch)
            if optimizer is not None:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()

            predictions = torch.round(output).detach()
            targets = func_group_batch.detach()
            true_pos += ((predictions == targets) * targets).sum().item()
            positives += targets.sum().item()
            recall = true_pos / positives if positives else float('nan')

            seen_targets.append(targets.cpu().numpy())
            seen_preds.append(predictions.cpu().numpy())
            f1 = metrics.f1_score(np.concatenate(seen_targets), np.concatenate(seen_preds), zero_division=0)

            postfix = dict(epoch=epoch, loss=total_loss, accuracy=recall, f1=f1)
            if optimizer is not None:
                postfix['lr'] = optimizer.param_groups[0]['lr']
            bar.set_postfix(**postfix)

    n_batches = len(loader)
    mean_loss = total_loss / n_batches if n_batches else float('nan')
    return mean_loss, recall


def train(num_epochs, lossFn, optimizer, group: str, weight_decay, learning_rate, kan_layers, lambda_lr, seed=42, save: bool = False, plot: bool = False, disable_verbose: bool = False, iter=0, dirpath="C:\\Users\\tomek\\Desktop\\cc\\res"):
    '''
    Train a fresh ChemicyBad model on ``func_groups_dataloaders[group]``.

    Parameters
    ----------
    num_epochs : int
        Number of training epochs.
    lossFn : callable
        ``lossFn(output, target)`` -> loss tensor, e.g. ``torch.nn.BCELoss()``.
    optimizer : type
        Optimizer class, called as ``optimizer(params, lr=..., weight_decay=...)``.
    group : str
        Key into ``func_groups_dataloaders``.
    weight_decay, learning_rate : float
        Passed to the optimizer.
    kan_layers : list
        KAN widths after the CNN features (see ChemicyBad).
    lambda_lr : float
        Per-epoch multiplicative decay: lr = learning_rate * lambda_lr ** epoch.
    seed, save
        Accepted but currently unused (seeding and saving are disabled).
    plot : bool
        Plot the train/test loss curves and save the figure in ``dirpath``.
    disable_verbose : bool
        Hide the tqdm progress bars.
    iter : int
        Suffix of the saved learning-curve file name.
    dirpath : str
        Directory for the learning-curve PNG.

    Returns
    -------
    The trained model, with these attributes set (each rounded to 2 decimals,
    NaN when undefined):
    - ``tr_accuracy`` / ``te_accuracy``: last-epoch recall of the positive class.
    - ``tr_loss`` / ``te_loss``: last-epoch mean batch loss.
    '''
    train_losses, test_losses = [], []

    model = ChemicyBad(kan_layers).to(device)

    #torch.manual_seed(seed)

    optimizer = optimizer(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: lambda_lr ** epoch)

    loaders = func_groups_dataloaders[group]
    train_acc = test_acc = float('nan')
    for epoch in range(num_epochs):
        model.train()
        train_loss, train_acc = _run_epoch(model, loaders['training'], lossFn, epoch, disable_verbose, optimizer)
        train_losses.append(train_loss)

        model.eval()
        with torch.no_grad():
            test_loss, test_acc = _run_epoch(model, loaders['test'], lossFn, epoch, disable_verbose)
        test_losses.append(test_loss)

        scheduler.step()

    if plot:
        plt.plot(train_losses, label='training loss')
        plt.plot(test_losses, label='test loss')
        plt.legend()
        plt.title(f'{group}')
        plt.savefig(os.path.join(dirpath, "Learning_progress_" + str(iter) + ".png"))
        plt.show()

    model.tr_accuracy = np.round(train_acc, 2)
    model.te_accuracy = np.round(test_acc, 2)
    model.tr_loss = np.round(train_losses[-1], 2) if train_losses else float('nan')
    model.te_loss = np.round(test_losses[-1], 2) if test_losses else float('nan')

    return model

# Loading the pretrained CNN-KAN models

In [None]:
# Load one pretrained CNN-KAN model per functional group. The dictionary key is
# the part of the file name before its first '_' (expected '<group>_...').
# SECURITY: weights_only=False unpickles arbitrary objects -- only load model
# files from a trusted source.
loaded_models = {}
print(trained_models_path)
for model_path in os.listdir(trained_models_path):
    model_name = model_path.split('_')[0]
    # map_location='cpu' lets models saved on a GPU load on CPU-only machines.
    # Every later use runs the models on CPU tensors.
    loaded_models[model_name] = torch.load(
        os.path.join(trained_models_path, model_path),
        map_location='cpu',
        weights_only=False,
    )

In [None]:
# Build one SHAP GradientExplainer per functional group. The background sample
# is the first N_BACKGROUND spectra of one shuffled pass over that group's
# *training* loader.
N_BACKGROUND = 1000

explainer_models: dict[str, shap.GradientExplainer] = {}

for func_group in func_groups_dataloaders.keys():
    background_batches, n_collected = [], 0
    # Stop once enough spectra are collected; every batch reads CSV files, and
    # the old code loaded the whole training set only to keep the first 1000.
    for spectra, _ in func_groups_dataloaders[func_group]['training']:
        background_batches.append(spectra)
        n_collected += spectra.shape[0]
        if n_collected >= N_BACKGROUND:
            break
    background = torch.cat(background_batches)[:N_BACKGROUND]
    explainer_models[func_group] = shap.GradientExplainer(
        loaded_models[func_group].cpu(),  # Module.cpu() also moves the stored model in place
        background,
        batch_size=128,
    )

In [None]:
def get_shap_values(func_group: str, target_func_group: str, n_explain: int = 101, nsamples: int = 100) -> tuple[np.ndarray, torch.Tensor, torch.Tensor]:
    '''
    Returns SHAP values, functional groups and full spectra of a specific functional group.

    Parameters
    ----------

    func_group : str
        Name of the functional group whose explainer is used.

    target_func_group : str
        Name of the functional group whose test split is explained.

    n_explain : int, default 101
        Number of leading test spectra to compute SHAP values for.

    nsamples : int, default 100
        ``nsamples`` passed to ``GradientExplainer.shap_values``.

    Returns
    -------
    (shap values of the first ``n_explain`` test spectra,
     labels of the whole test split, spectra of the whole test split),
    all in test-loader order (the test loader does not shuffle).

    Raises
    ------
    ValueError
        If the test loader yields no batches.
    '''
    spectra_batches, label_batches = [], []
    for spectra, func_groups in func_groups_dataloaders[target_func_group]['test']:
        spectra_batches.append(spectra)
        label_batches.append(func_groups)
    if not spectra_batches:
        raise ValueError(f'No test data for {target_func_group}.')

    # Concatenate once instead of re-growing the tensor on every batch.
    full_test_spectra = torch.cat(spectra_batches)
    full_func_groups = torch.cat(label_batches)

    shap_values = explainer_models[func_group].shap_values(full_test_spectra[0:n_explain, :, :, :], nsamples=nsamples)
    return shap_values, full_func_groups, full_test_spectra

In [None]:
# Per-group SHAP results. They are filled either by computing SHAP values or
# by loading them from the pickle cache.
explainer_shaps: dict[str, np.array] = dict()
shap_func_groups: dict[str, torch.tensor] = dict()
full_test_spectras: dict[str, torch.tensor] = dict()



In [None]:
# Compute SHAP values for every functional group, or load them from the pickle
# cache in ``save_dir`` when all three cache files already exist.
# NOTE(review): a cache is tied to the exact test splits it was computed on.
# The datasets are re-sampled on each kernel start, so downstream code should
# read spectra/labels from these cached dicts, not from ``func_groups_datasets``.
# SECURITY: pickle.load can execute arbitrary code -- only load caches you created.
import pickle

_cache_paths = {
    'explainer_shaps': os.path.join(save_dir, "explainer_shaps.pkl"),
    'shap_func_groups': os.path.join(save_dir, "shap_func_groups.pkl"),
    'full_test_spectras': os.path.join(save_dir, "full_test_spectras.pkl"),
}

if not all(os.path.isfile(path) for path in _cache_paths.values()):
    # Calculate SHAP values for each functional group
    for func_group in func_groups_dataloaders.keys():
        print(func_group)
        explainer_shaps[func_group], shap_func_groups[func_group], full_test_spectras[func_group] = get_shap_values(func_group, func_group)

    # Create the cache directory if needed. Previously os.listdir(save_dir)
    # raised FileNotFoundError before anything was computed.
    os.makedirs(save_dir, exist_ok=True)
    for _cache_name, _cache_obj in (('explainer_shaps', explainer_shaps),
                                    ('shap_func_groups', shap_func_groups),
                                    ('full_test_spectras', full_test_spectras)):
        with open(_cache_paths[_cache_name], "wb") as f:
            pickle.dump(_cache_obj, f)
else:
    with open(_cache_paths['explainer_shaps'], "rb") as f:
        explainer_shaps = pickle.load(f)
    with open(_cache_paths['shap_func_groups'], "rb") as f:
        shap_func_groups = pickle.load(f)
    with open(_cache_paths['full_test_spectras'], "rb") as f:
        full_test_spectras = pickle.load(f)

In [None]:
def plot_shap_values(idx, functional_group, shap_func_groups, shap_values: dict[str, np.array], save: bool, shap_limit: float, save_dir: str = 'C:\\Users\\Tomek_\\Desktop\\rys1') -> None:
    '''
    Plot test spectrum ``idx`` of ``functional_group`` with each point coloured by its SHAP value.

    Parameters
    ----------
    idx : int
        Position of the spectrum in the SHAP-explained test subset.
    functional_group : str
        Group whose model, cached spectra and SHAP values are used.
    shap_func_groups : dict[str, tensor]
        True labels, in the same order as the SHAP values (shown in the title).
    shap_values : dict[str, array]
        Per-group SHAP values; entry ``idx`` is flattened to one value per point.
    save : bool
        Write a 300-dpi PNG to ``save_dir/fig<idx>.png``.
    shap_limit : float
        Symmetric colour-scale limit around 0.
    save_dir : str
        Output directory used when ``save`` is True.
    '''
    x = range(670, 3776)  # wavenumbers, one per spectrum point (3106 points)

    # Use the cached test spectra the SHAP values were computed on. The live
    # datasets are re-sampled on every kernel start and may not line up with
    # SHAP values loaded from the pickle cache.
    spectrum = full_test_spectras[functional_group][idx]
    y = spectrum.flatten().detach().numpy()

    with torch.no_grad():
        model_prediction = loaded_models[functional_group].forward(spectrum.reshape(1, 1, 1, 3106))[0][0]

    plt.plot(x, y, alpha=0.2, c='grey')
    plt.scatter(
        x,
        y,
        c=shap_values[functional_group][idx].flatten(),
        cmap='bwr',
        marker='D',
        s=4,
        norm=colors.TwoSlopeNorm(vcenter=0, vmin=-shap_limit, vmax=shap_limit)
    )
    plt.title(f'SHAP values for {"" if model_prediction >= 0.5 else "non"}   {functional_group}  object\nActual group: {shap_func_groups[functional_group][idx].item()}  \n\nModel prediction: {model_prediction:.2f}  ')

    plt.ylabel("Normalized absorbance")
    plt.xlabel(r"Wavenumber [cm$^{-1}$]")
    plt.colorbar()
    plt.tight_layout()  # after colorbar() so the layout accounts for it
    if save:
        plt.savefig(os.path.join(save_dir, 'fig' + str(idx) + '.png'), format='png', dpi=300)
    plt.show()

In [None]:
# SHAP explanation of test spectrum no. 52 for the nitriles model
# (colour scale clipped at +/-0.002, figure not saved).
plot_shap_values(
    idx=52,
    functional_group='nitriles',
    shap_func_groups=shap_func_groups,
    shap_values=explainer_shaps,
    save=False,
    shap_limit=0.002,
)

In [None]:
# Sum the SHAP values of one spectrum over the characteristic IR absorption
# regions of its functional group, and scatter-plot those regions.

# Spectra are sampled at integer wavenumbers starting at 670 cm^-1, so the
# array index of wavenumber w is ``w - SPECTRUM_START_CM1``.
SPECTRUM_START_CM1 = 670

# Characteristic regions, as half-open wavenumber intervals [start, stop) in cm^-1.
CHARACTERISTIC_REGIONS_CM1 = {
    'alkane': [(1340, 1395), (1430, 1480), (2850, 2990)],
    'alkene': [(685, 995), (1600, 1680), (3000, 3100)],
    'alkyne': [(2100, 2250), (3200, 3310)],
    'aromatics': [(680, 900), (1440, 1620), (3000, 3100)],
    'alcohols': [(1000, 1300), (3200, 3650)],
    'amines': [(3250, 3550)],
    'nitriles': [(2200, 2280)],
    'aldehydes': [(1680, 1715), (1720, 1740), (2700, 2900)],
    'ketones': [(1650, 1700), (1705, 1750)],
    'esters': [(1000, 1300), (1715, 1730), (1735, 1765)],
    'carboxylic acids': [(1000, 1300), (1680, 1725), (2500, 3200)],
    'amides': [(1630, 1700), (3150, 3500)],
    'nitro': [(1300, 1390), (1490, 1570)],
    'phenols': [(3200, 3700)],
    'methyl': [(1365, 1395), (1430, 1470), (2860, 2880), (2950, 2970)],
}


def _region_indices(functional_group):
    '''Return the spectrum array indices covered by ``functional_group``'s regions, in table order.'''
    if functional_group not in CHARACTERISTIC_REGIONS_CM1:
        raise ValueError(f'No characteristic regions defined for {functional_group!r}.')
    return np.concatenate([np.arange(start, stop) for start, stop in CHARACTERISTIC_REGIONS_CM1[functional_group]]) - SPECTRUM_START_CM1


def sumator(idx, functional_group, shap_func_groups, shap_values: dict[str, np.array], save: bool, shap_limit: float) -> float:
    '''
    Scatter-plot the characteristic regions of test spectrum ``idx``, coloured by
    SHAP value, and return the sum of the SHAP values in those regions.

    Parameters
    ----------
    idx : int
        Position of the spectrum in the SHAP-explained test subset.
    functional_group : str
        Key of ``CHARACTERISTIC_REGIONS_CM1``.
    shap_func_groups : dict
        Unused; kept for call compatibility.
    shap_values : dict[str, array]
        Per-group SHAP values; entry ``idx`` is flattened to one value per point.
    save : bool
        Unused; kept for call compatibility (nothing is written to disk).
    shap_limit : float
        Symmetric colour-scale limit around 0.

    Raises
    ------
    ValueError
        For a group without characteristic regions. Previously this crashed
        with UnboundLocalError.

    Notes
    -----
    - Draws on the current matplotlib figure without showing it, so repeated
      calls overlay.
    - 'ketones' used to crash: two numpy arrays of different lengths were added
      instead of concatenated.
    '''
    indices = _region_indices(functional_group)
    wavenumbers = indices + SPECTRUM_START_CM1

    # Cached test spectra: the ones the SHAP values were computed on.
    spectrum = full_test_spectras[functional_group][idx].flatten().detach().cpu().numpy()
    y = spectrum[indices]
    region_shap = np.asarray(shap_values[functional_group][idx]).flatten()[indices]

    plt.plot(wavenumbers, y, alpha=0.2, c='grey')
    plt.scatter(
        wavenumbers,
        y,
        c=region_shap,
        cmap='bwr',
        marker='D',
        s=4,
        norm=colors.TwoSlopeNorm(vcenter=0, vmin=-shap_limit, vmax=shap_limit)
    )
    return np.sum(region_shap)

In [None]:
def calcMean(fg,posneg,onlyCorrectDec,n_samples=100):
    """Sum and average the ``sumator`` SHAP contribution over selected test spectra.

    Each of the first ``n_samples`` test spectra of group ``fg`` is selected when
    its label matches ``posneg`` (1.0 for 'positive', 0.0 for 'negative') and, if
    ``onlyCorrectDec`` is true, the model also decides that way (prediction
    >= 0.5 for positives, < 0.5 for negatives). For every selected spectrum
    ``sumator`` is called; it draws a SHAP scatter and returns ``np.sum`` of the
    SHAP values in group-specific index windows. The total ('sumka'), the mean
    ('srednia') and the number of selected spectra ('iter') are printed.

    Args:
        fg: functional-group key into the global models / datasets / SHAP dicts.
        posneg: 'positive' or 'negative'.
        onlyCorrectDec: if true, keep only spectra the model classifies correctly.
        n_samples: number of leading test spectra to scan (default 100).

    Raises:
        ValueError: if ``posneg`` is neither 'positive' nor 'negative'.
    """
    if posneg == 'positive':
        target_label = 1.0
    elif posneg == 'negative':
        target_label = 0.0
    else:
        raise ValueError(f"posneg must be 'positive' or 'negative', got {posneg!r}")

    sumka = 0.0
    count = 0  # number of selected spectra (not named `iter`, which is a builtin)
    for idx in range(n_samples):
        z = shap_func_groups[fg][idx].item()
        if z != target_label:
            continue
        if onlyCorrectDec:
            # The prediction is only needed for this filter; no autograd graph required.
            with torch.no_grad():
                model_prediction = loaded_models[fg].forward(
                    func_groups_datasets[fg]['test'][idx][0].reshape(1, 1, 1, 3106)
                )[0][0].item()
            if target_label == 1.0:
                correct_decision = model_prediction >= 0.5
            else:
                correct_decision = model_prediction < 0.5
            if not correct_decision:
                continue
        count += 1
        # NOTE(review): meaning of the last two sumator arguments (True, 0.002) is defined
        # where sumator is declared; presumably a flag and the colour-scale limit.
        wklad = sumator(idx, fg, shap_func_groups, explainer_shaps, True, 0.002)
        sumka = sumka + wklad

    print('sumka', sumka)
    # Avoid ZeroDivisionError when no spectrum matched the selection.
    srednia = sumka / count if count else float('nan')
    print('srednia', srednia)
    print('iter', count)
 

In [None]:
plt.clf()
plt.show()
import random

# Functional group to analyse. Other groups available in this notebook:
# 'alcohols', 'aldehydes', 'alkane', 'alkyne', 'amides', 'amines', 'aromatics',
# 'carboxylic acids', 'methyl', 'phenols', 'ketones', 'nitriles', 'esters', 'nitro'.
fg = 'alkene'

# Report the SHAP sums/means first for all positive, then for all negative spectra.
for header, subset in (('All positive', 'positive'), ('All negative', 'negative')):
    print(header)
    calcMean(fg, subset, False)




In [None]:
## Accuracy (in %) of the model for group `fg` on the first N_ACC_SAMPLES test spectra.
print('accuracy of CNN-KAN model (in%)  for given subset of spectra for ',fg)
import math
N_ACC_SAMPLES = 100
correct = 0
# Inference only: no autograd graph is needed.
with torch.no_grad():
    for idx in range(N_ACC_SAMPLES):
        z = shap_func_groups[fg][idx].item()
        model_prediction = loaded_models[fg].forward(func_groups_datasets[fg]['test'][idx][0].reshape(1,1,1,3106))[0][0].item()
        # Same >= 0.5 decision rule as calcMean and the plot titles
        # (np.round would map exactly 0.5 to 0, as it rounds half to even).
        predicted = 1.0 if model_prediction >= 0.5 else 0.0
        if abs(predicted - z) < 0.01:
            correct += 1
# Report a real percentage so the value stays correct if N_ACC_SAMPLES changes.
print(100.0 * correct / N_ACC_SAMPLES)

In [None]:
def calcToPlot1(functional_group, n_samples=100):
    """Accumulate class-signed SHAP sums over the first ``n_samples`` test spectra.

    Spectra labelled 1.0 in ``shap_func_groups[functional_group]`` are added into
    S_+; all other spectra are subtracted into S_- (so S_- is minus the sum of
    their SHAP values). Reads the notebook globals ``shap_func_groups`` and
    ``explainer_shaps``.

    Args:
        functional_group: key into ``shap_func_groups`` / ``explainer_shaps``.
        n_samples: number of leading test spectra to include (default 100).

    Returns:
        ``[x, S_plus, S_minus]`` where ``x = range(670, 3776)`` (wavenumber axis)
        and each sum is a flattened array. A class with no spectra yields zeros of
        length ``len(x)``, so ``S_plus + S_minus`` is always defined.
    """
    x = range(670, 3776)
    sumkaPlus = None
    sumkaMinus = None
    for idx in range(n_samples):
        shap_row = explainer_shaps[functional_group][idx].flatten()
        if shap_func_groups[functional_group][idx].item() == 1.0:
            # Out-of-place arithmetic: never writes back into explainer_shaps,
            # even if flatten() returns a view (as it can for torch tensors).
            sumkaPlus = shap_row if sumkaPlus is None else sumkaPlus + shap_row
        else:
            sumkaMinus = -1 * shap_row if sumkaMinus is None else sumkaMinus - shap_row
    if sumkaPlus is None:
        sumkaPlus = np.zeros(len(x))
    if sumkaMinus is None:
        sumkaMinus = np.zeros(len(x))
    return [x, sumkaPlus, sumkaMinus]

In [None]:
#plot S_+, S_- and S = S_+ + S_- for the selected functional groups (article Fig. 5)

plt.clf()
plt.show()
import random
func_gr=['nitriles','nitro','phenols','ketones','carboxylic acids']

labelFontSize=16
labelBigFontSize=18
labelAxisSize=14

# Directory for the exported figure. Set EXPLAINER_FIG_DIR to redirect it instead
# of editing a per-user absolute path; the default keeps the original location.
FIG5_DIR = os.environ.get('EXPLAINER_FIG_DIR', 'C:\\Users\\Tomek_\\Desktop\\explainerArticle')

# One [wavenumbers, S_+, S_-] triple per functional group (computed by calcToPlot1).
res = [calcToPlot1(group) for group in func_gr]

fig, axs = plt.subplots(nrows=len(func_gr), ncols=1, figsize=(9, 15),constrained_layout=True)

# The placeholder axes only provide the row grid: remove them and put one subfigure
# into each grid cell so every row gets its own title above three panels.
for ax in axs:
    ax.remove()
gridspec = axs[0].get_subplotspec().get_gridspec()
subfigs = [fig.add_subfigure(gs) for gs in gridspec]

# Panels per row, left to right: S_+ (label 1.0), S_- (other labels), S = S_+ + S_-.
panel_colors = ('salmon', 'cornflowerblue', 'mediumseagreen')
for row, subfig in enumerate(subfigs):
    subfig.suptitle(func_gr[row].capitalize(),fontsize=labelBigFontSize)
    row_axes = subfig.subplots(nrows=1, ncols=3,sharey='row')
    wavenumbers, s_plus, s_minus = res[row]
    panels = (s_plus, s_minus, s_plus + s_minus)
    for ax, series, color in zip(row_axes, panels, panel_colors):
        ax.plot(wavenumbers, series, color=color)
        ax.tick_params(labelsize=labelAxisSize)


fig.text(0.5, -0.025, 'Wavenumber [cm$^{-1}$]', ha='center',fontsize=labelFontSize)
fig.text(-0.025, 0.5, 'Regions of spectra important for detection of given functional group [arb.units]', va='center', rotation='vertical',fontsize=labelFontSize)

#Export plot as png file into FIG5_DIR
fig.savefig(os.path.join(FIG5_DIR, 'pp.png'),bbox_inches='tight')

In [None]:
#variant of the SHAP-per-spectrum plot without a visible colorbar (by default), with aspect ratio and fonts tuned for article figures
def plot_shap_valuesNV(idx, functional_group, shap_func_groups, shap_values: dict[str, np.ndarray], save: bool, shap_limit: float,
                       save_dir: str = 'C:\\Users\\Tomek_\\Desktop\\rysKetonesNW1', show_colorbar: bool = False) -> None:
    """Plot one test spectrum with each point coloured by its SHAP value.

    Args:
        idx: index into ``func_groups_datasets[functional_group]['test']`` and
            ``shap_values[functional_group]``.
        functional_group: key selecting the dataset, model and SHAP values.
        shap_func_groups: per-group labels; ``shap_func_groups[functional_group][idx].item()``
            is shown as the actual group in the title.
        shap_values: per-group SHAP values; entry ``idx`` is flattened and used as
            the per-point colour.
        save: if True, write the figure as ``fig<idx>.png`` (300 dpi) into ``save_dir``.
        shap_limit: limits ``[-shap_limit, shap_limit]`` of the colour scale,
            which is centred at 0 (``TwoSlopeNorm``).
        save_dir: output directory used when ``save`` is True (defaults to the
            original per-user location; pass another path on other machines).
        show_colorbar: draw a horizontal colorbar with five ticks across the range.
    """
    wavenumbers = range(670, 3776)
    # Fetch the test item once and reuse it for both plotting and prediction.
    spectrum = func_groups_datasets[functional_group]['test'][idx][0]
    absorbance = spectrum.flatten().detach().numpy()

    labelFontSize = 16
    labelAxisSize = 14
    labelColorBarFontSize = 14

    with torch.no_grad():
        model_prediction = loaded_models[functional_group].forward(spectrum.reshape(1, 1, 1, 3106))[0][0].item()

    # Grey line for the spectrum itself, SHAP-coloured markers on top of it.
    plt.plot(
        wavenumbers,
        absorbance,
        alpha=0.55,
        c='dimgrey'
    )
    plt.scatter(
        wavenumbers,
        absorbance,
        c=shap_values[functional_group][idx].flatten(),
        cmap='bwr',
        marker='D',
        s=4,
        norm=colors.TwoSlopeNorm(vcenter=0, vmin=-shap_limit, vmax=shap_limit)
    )
    plt.title(f'SHAP values for {"" if model_prediction >= 0.5 else "non"}   {functional_group}  object\nActual group: {shap_func_groups[functional_group][idx].item()}  \n\nModel prediction: {model_prediction:.2f}  ', fontsize=labelFontSize)
    plt.tick_params(labelsize=labelAxisSize)
    plt.tight_layout()
    if show_colorbar:
        colbar = plt.colorbar(orientation='horizontal')
        colbar.ax.tick_params(labelsize=labelColorBarFontSize)
        # Five evenly spaced ticks across the symmetric colour range.
        colbar.set_ticks(np.linspace(-shap_limit, shap_limit, 5))
    plt.xticks(fontsize=labelAxisSize)
    plt.yticks(fontsize=labelAxisSize)
    if save:
        plt.savefig(os.path.join(save_dir, 'fig' + str(idx) + '.png'), format='png', dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
import random
fg='ketones'
N_RANDOM_PLOTS = 1   # how many randomly chosen test spectra to plot
N_ANALYSED = 100     # the other analysis cells only use test indices 0..99
for _ in range(N_RANDOM_PLOTS):
    # randrange(N) draws from 0..N-1; randint(0, 100) includes 100, which is
    # one past the last index the other cells analyse.
    i = random.randrange(N_ANALYSED)
    plot_shap_valuesNV(i,
                       fg,
                       shap_func_groups,
                       explainer_shaps,
                       False,
                       0.002)