# Plotting predictions along a solvent sequence

The code in this section was used to plot the model's 'Template' and 'Scratch' predictions against the true values, showing the range and standard deviation of the predictions as well as the standard deviation of each property.

In [None]:
import copy
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


In [None]:
# Load the scratch-model test predictions, the z-score normalisation statistics,
# the ground-truth test values and the template-based predictions.
with open('predict_properties/test_predictions.json', 'r', encoding='utf-8') as res_file:
    results_dict = json.load(res_file)

with open('data/normalisation_stats.json', 'r', encoding='utf-8') as norm_file:
    norm_dict = json.load(norm_file)
print(norm_dict)
base_values = pd.read_csv('data/test_values.csv')

# Per-property standard deviation taken from the normalisation statistics; used
# as the ground-truth uncertainty band in the comparison plots.
std_vals = {key: vals['std'] for key, vals in norm_dict.items()}

with open('predict_properties/template_preds.json', 'r', encoding='utf-8') as template_file:
    temp_dict = json.load(template_file)

In [None]:
def avg_predictions(dictionary):
    """Pool every prediction of each (SMILES, property) pair into one mean.

    Args:
        dictionary (dict): ``{smiles: {property: {index: [predictions, ...]}}}``.

    Returns:
        dict: ``{smiles: {property: mean}}``. The mean is taken over all
        predictions across every index; it is 0 when there are none.
    """
    averaged = {}
    for smiles_key, per_property in dictionary.items():
        smiles_means = {}
        for prop_name, per_index in per_property.items():
            pooled_sum = sum(sum(preds) for preds in per_index.values())
            pooled_count = sum(len(preds) for preds in per_index.values())
            # An empty property falls back to 0 instead of dividing by zero.
            smiles_means[prop_name] = pooled_sum / pooled_count if pooled_count else 0
        averaged[smiles_key] = smiles_means
    return averaged
    
# Collapse each {index: [predictions]} mapping in the test predictions into a
# single mean value per (SMILES, property).
predictions_averaged = avg_predictions(results_dict)

In [None]:
# Build {smiles: {property: true_value}} from the ground-truth test table.
# The SMILES string is read from the column at position 1; only the columns
# after it are treated as properties.
# NOTE(review): position 0 appears to be an index/ID column in
# data/test_values.csv — confirm against the file.
SMILES_COLUMN_POSITION = 1
true_values_dict = {}

# Column names of the property columns (everything after the SMILES column).
property_columns = base_values.columns[SMILES_COLUMN_POSITION + 1:]

# itertuples(index=False) yields plain tuples without the DataFrame index.
for row in base_values.itertuples(index=False):
    smiles = row[SMILES_COLUMN_POSITION]
    # setdefault keeps properties already stored for a repeated SMILES row.
    smiles_props = true_values_dict.setdefault(smiles, {})
    # Skip the SMILES column itself so it is not stored (and later
    # "unnormalised") as if it were a property.
    for prop_name, value in zip(property_columns, row[SMILES_COLUMN_POSITION + 1:]):
        smiles_props[prop_name] = value

In [None]:
def _process_data(property_data):
    """Helper function to extract and sort plotting data from a dictionary."""
    n_values, means, std_devs, mins, maxs = [], [], [], [], []
    # Sort keys numerically to ensure the plot is drawn in the correct order
    sorted_ns = sorted(map(int, property_data.keys()))
    for n in sorted_ns:
        n_str = str(n)
        values = property_data[n_str]
        n_values.append(n)
        means.append(np.mean(values))
        std_devs.append(np.std(values))
        mins.append(min(values))
        maxs.append(max(values))
    return n_values, means, std_devs, mins, maxs

def compare_predictions_by_n(dict_1, dict_2, ground_truth=None, ground_truth_std=None,
                             label1='Dataset 1', label2='Dataset 2', output_dir='plots'):
    """
    Create plots comparing predictions for each property from two datasets.

    For every property present in either dataset, a side-by-side figure is saved:
    1. Left: mean prediction (with std-dev error bars) and min/max band at each
       position n, with the ground-truth value (and its ± std band) overlaid.
    2. Right: mean residual (mean prediction - ground truth) at each n.

    Args:
        dict_1 (dict): ``{property: {n (str): [predictions]}}`` for the first dataset.
        dict_2 (dict): Same structure for the second dataset.
        ground_truth (dict | None): Property name -> ground-truth value.
        ground_truth_std (dict | None): Property name -> standard deviation drawn
            as a band around the ground truth on the left plot.
        label1 (str): Legend label for the first dataset.
        label2 (str): Legend label for the second dataset.
        output_dir (str): Directory the PDFs are written to (created if missing).

    Side effects:
        Sets ``plt.rcParams['font.size']`` globally, writes
        ``<output_dir>/<property>_comparison.pdf`` for each property and prints
        the saved path.
    """
    plt.rcParams['font.size'] = 18
    os.makedirs(output_dir, exist_ok=True)

    # (data, legend label, colour, marker): both datasets are drawn the same way.
    series = (
        (dict_1, label1, '#dacb00', 'o'),
        (dict_2, label2, '#010055', 's'),
    )
    gt_colour = '#019480'

    # Sorted so figures are produced (and reported) in a deterministic order.
    for property_name in sorted(set(dict_1) | set(dict_2)):
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(22, 8), sharex=True)
        gt_value = ground_truth.get(property_name) if ground_truth else None
        all_n_values = set()

        for data, label, colour, marker in series:
            if property_name not in data:
                continue
            n_values, means, std_devs, mins, maxs = _process_data(data[property_name])
            all_n_values.update(n_values)

            # Predictions WITH standard deviation and min/max range (left).
            ax1.errorbar(n_values, means, yerr=std_devs, fmt=f'-{marker}',
                         capsize=5, capthick=2, label=f'{label} Mean ± Std Dev',
                         color=colour, alpha=0.8)
            ax1.fill_between(n_values, mins, maxs, alpha=0.4,
                             label=f'{label} Min/Max Range', color=colour)

            # Mean residuals WITHOUT standard deviation (right), only if a
            # ground truth exists to compare against.
            if gt_value is not None:
                residuals = np.asarray(means) - gt_value
                ax2.plot(n_values, residuals, f'-{marker}',
                         label=f'{label} Mean Residual', color=colour, alpha=0.8)

        if gt_value is not None:
            ax1.axhline(gt_value, linestyle='--', color=gt_colour, linewidth=3,
                        label=f'Ground Truth ({gt_value:.3f})', alpha=0.8)
            ax2.axhline(0, linestyle='--', color=gt_colour, linewidth=3,
                        label='Zero Error', alpha=0.8)

            # Ground-truth uncertainty band only on the predictions plot.
            if ground_truth_std and property_name in ground_truth_std:
                gt_std = ground_truth_std[property_name]
                ax1.axhspan(gt_value - gt_std, gt_value + gt_std, color=gt_colour, alpha=0.1,
                            label=f'Ground Truth ± Std ({gt_std:.3f})')

        fig.suptitle(f'Prediction Comparison for {property_name}', fontsize=25)

        ax1.set_title('Predictions vs. Ground Truth')
        ax1.set_ylabel('Prediction Value')
        ax2.set_title('Residuals (Prediction - Ground Truth)')
        ax2.set_ylabel('Residual Value')
        for ax in (ax1, ax2):
            ax.grid(True, linestyle='--', alpha=0.6)
            # Without ground truth the residual axis has no labelled artists;
            # calling legend() there would only emit a warning.
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                ax.legend()

        if all_n_values:
            # sharex=True propagates these ticks to ax2 as well.
            ax1.set_xticks(sorted(all_n_values))
            ax1.set_xlabel('Prediction Position (n)')
            ax2.set_xlabel('Prediction Position (n)')

        plt.tight_layout(rect=[0, 0.03, 1, 1])

        plot_filename = os.path.join(output_dir, f'{property_name}_comparison.pdf')
        plt.savefig(plot_filename, bbox_inches='tight', dpi=300)
        plt.close(fig)

        print(f'Saved comparison plot for {property_name} to {plot_filename}')

In [None]:
def unnormalize_dict_with_lists(data_dict, stats):
    """
    Undo z-score normalisation in a three-level nested prediction dictionary.

    Each leaf value ``z`` becomes ``z * std + mean`` using the statistics of its
    property. Leaves may be single numbers or lists of numbers.

    Args:
        data_dict (dict): Normalised values, format
            ``{smiles: {property: {index: value or [values]}}}``.
        stats (dict): ``{property: {'mean': value, 'std': value}}``.

    Returns:
        dict: A deep copy of ``data_dict`` with every leaf unnormalised. For
        properties without statistics, a warning is printed and their
        ``{index: values}`` mapping is left exactly as it was in the input.
        The input is not modified.
    """
    unnormalized_data = copy.deepcopy(data_dict)

    for properties in unnormalized_data.values():
        for prop, indices in properties.items():
            try:
                mean = stats[prop]['mean']
                std = stats[prop]['std']
            except KeyError:
                print(f"Warning: Statistics not found for property '{prop}'. Skipping.")
                # Leave the deep-copied {index: values} mapping untouched so the
                # original values and structure are kept.
                continue

            for index, values in indices.items():
                if isinstance(values, list):
                    indices[index] = [(item * std) + mean for item in values]
                else:
                    # Reverse z-score for a single value.
                    indices[index] = (values * std) + mean

    return unnormalized_data

In [None]:
def unnormalize_dict(data_dict, stats):
    """
    Reverse z-score normalisation for a two-level ``{smiles: {property: value}}`` dict.

    Each value ``z`` becomes ``z * std + mean`` using its property's statistics.
    If a property has no statistics, a warning is printed and its value is left
    unchanged.

    Args:
        data_dict (dict): ``{smiles: {property: normalised value}}``.
        stats (dict): ``{property: {'mean': value, 'std': value}}``.

    Returns:
        dict: A deep copy of ``data_dict`` with values unnormalised; the input
        is not modified.
    """
    restored = copy.deepcopy(data_dict)

    for smiles_key in restored:
        prop_map = restored[smiles_key]
        for prop_name in prop_map:
            z_value = prop_map[prop_name]
            try:
                mu = stats[prop_name]['mean']
                sigma = stats[prop_name]['std']
            except KeyError:
                print(f"Warning: Statistics not found for property '{prop_name}'. Skipping unnormalization for this property.")
                continue
            # x = z * std + mean
            prop_map[prop_name] = (z_value * sigma) + mu

    return restored

In [None]:
# Map the averaged predictions and the ground-truth values back from z-scores
# to the original property scale (properties without statistics are left
# unchanged and a warning is printed for each).
prediction_avg_unnorm = unnormalize_dict(predictions_averaged, norm_dict)
true_values_dict_unorm = unnormalize_dict(true_values_dict, norm_dict)
print(true_values_dict_unorm.keys())

## Edit `compare_predictions_by_n` to change the colours and plot parameters

In [None]:
# Unnormalise both per-position prediction sets, then plot them for one solvent.
temp_unorm = unnormalize_dict_with_lists(temp_dict, norm_dict)
results_unnorm = unnormalize_dict_with_lists(results_dict, norm_dict)

# Solvent to plot (acetophenone), defined once so the three lookups below
# always refer to the same molecule.
target_smiles = 'CC(=O)c1ccccc1'
compare_predictions_by_n(
    temp_unorm[target_smiles],
    results_unnorm[target_smiles],
    ground_truth=true_values_dict_unorm[target_smiles],
    ground_truth_std=std_vals,
    label1='template predictions',
    label2='scratch predictions',
)

# Word Embedding Plots
Plot the word-token embeddings, labelling and circling selected points for illustration.

In [None]:
import pandas as pd
import torch
import torch.nn as nn

from model.decoder import MultiModalRegressionTransformer

# Rebuild the architecture and load the checkpoint onto the CPU.
# NOTE(review): the positional hyper-parameters must match the checkpoint;
# see model.decoder for their meaning.
model_path = 'val_loss0.1074_DPR_0.1_MP_0.3_DM_64_TL_5_heads_16.pth'
model = MultiModalRegressionTransformer(384, 26, 64, 28, 6, 16, 5, 0)
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()

# `.data` gives the raw weight tensor of the property embedding layer
# without autograd tracking.
embedding_layer = model.embeddings_module.property_embedding
embeddings = embedding_layer.weight.data
transposed_data = embeddings.T

print(f"Shape of embeddings: {embeddings.shape}")
print(f"First 5 embedding vectors:\n {embeddings[:5]}")

# Column labels for the DataFrame below; assumed to be in the same order as
# the embedding rows.
WORD_TOKENS = ['alkane', 'aromatic', 'halohydrocarbon', 'ether', 'ketone', 'ester', 'nitrile', 'amine', 'amide', 'misc_N_compound', 'carboxylic_acid', 'monohydric_alcohol' , 'polyhydric_alcohol', 'other','ET30', 'alpha', 'beta', 'pi_star', 'SA', 'SB', 'SP', 'SdP', 'N_mol_cm3', 'n', 'fn', 'delta']

numpy_array = embeddings.numpy()
transposed = embeddings.T

# Transposed so that each column holds one token's embedding vector.
df = pd.DataFrame(transposed)
print(df.shape)
df.columns = WORD_TOKENS

In [None]:
import pandas as pd 
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from matplotlib import cm
import re # Import the regular expressions module

def pca_scatter_plot(df, n_components=2, highlight_groups=None, group_labels=None, misc_labels=None,
                     save_name='word_token_embeddings'):
    """
    Reduce each column of ``df`` with PCA and scatter-plot PC1 against PC2.

    Each column is treated as one sample (e.g. one token's embedding vector).
    Samples are standardised feature-wise, projected with PCA, plotted, and
    the figure is saved as ``<save_name>.pdf`` and shown.

    Args:
        df: DataFrame whose columns are the items to plot and whose rows are
            their features (for the word-token embeddings: 384 rows x 26 columns).
        n_components: Number of PCA components (must be >= 2; PC1/PC2 are drawn).
        highlight_groups: list of lists of column names; each group is circled
            in its own colour.
        group_labels: one name per group, drawn above its circle
            (defaults to "Group k").
        misc_labels: column names that get a text label but no circle.
        save_name: file stem of the saved PDF.

    Returns:
        tuple: ``(pca_result, pca)`` -- the (n_columns, n_components) projection
        and the fitted ``PCA`` object.

    Raises:
        ValueError: If ``n_components`` is less than 2.
    """
    if n_components < 2:
        raise ValueError('n_components must be at least 2 to draw a PC1/PC2 scatter plot')

    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['font.sans-serif'] = ['Latin Modern Sans', 'DejaVu Sans', 'Arial']
    plt.rcParams['font.size'] = 20

    # Transpose so each row is one column of df, then standardise and project.
    scaled_data = StandardScaler().fit_transform(df.T)
    pca = PCA(n_components=n_components)
    pca_result = pca.fit_transform(scaled_data)

    fig, ax = plt.subplots(figsize=(14, 6))
    ax.scatter(pca_result[:, 0], pca_result[:, 1], color="#7F51FF", alpha=0.7, s=600)

    # Row index in pca_result for each column name (first occurrence wins).
    column_position = {}
    for i, name in enumerate(df.columns):
        column_position.setdefault(name, i)

    # Every point belonging to a circled group or to misc_labels gets a label.
    labels_to_show = set(misc_labels or [])
    for group in highlight_groups or []:
        labels_to_show.update(group)

    for i, col_name in enumerate(df.columns):
        if col_name in labels_to_show:
            ax.annotate(col_name, (pca_result[i, 0], pca_result[i, 1]),
                        xytext=(5, 5), textcoords='offset points',
                        fontsize=15, fontweight='bold', alpha=0.8)

    if highlight_groups:
        colormap = plt.cm.Dark2
        color_values = np.linspace(0, 1, len(highlight_groups))

        for g_idx, group in enumerate(highlight_groups):
            coords = np.array([pca_result[column_position[label], :2]
                               for label in group if label in column_position])
            if coords.size == 0:
                continue

            centroid = coords.mean(axis=0)
            mean_dist = np.linalg.norm(coords - centroid, axis=1).mean()
            radius = mean_dist * 2.2
            if radius == 0:
                # A single-point group would otherwise get an invisible circle;
                # fall back to a small fraction of the PC1 span.
                radius = 0.05 * np.ptp(pca_result[:, 0]) or 1.0

            color = colormap(color_values[g_idx] * 0.6)
            ax.add_patch(plt.Circle(centroid, radius, color=color,
                                    fill=False, linewidth=3, linestyle='--'))

            if group_labels and g_idx < len(group_labels):
                circle_label = group_labels[g_idx]
            else:
                circle_label = f"Group {g_idx+1}"

            # Place the label just above the top of the circle.
            ax.annotate(
                circle_label,
                xy=(centroid[0], centroid[1] + radius),
                ha='center', va='bottom',
                fontsize=18, fontweight='bold',
                color=color
            )

    ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
    ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')

    # Passing line properties (e.g. alpha) together with False makes Matplotlib
    # warn and turn the grid ON, so disable it without extra kwargs.
    ax.grid(False)
    plt.tight_layout()
    plt.savefig(f'{save_name}.pdf', bbox_inches='tight')
    plt.show()

    print(f"Explained variance ratio: {pca.explained_variance_ratio_}")
    print(f"Total variance explained: {pca.explained_variance_ratio_.sum():.2%}")

    return pca_result, pca


# Example usage: circle two groups of solvent descriptors and label a few others.
highlight_labels = [['n', 'fn', 'SP'], ['ET30', 'SA', 'alpha']]
circle_names = ['Polarisability', 'Polarity']  # <-- names for each circle
# Labels must match WORD_TOKENS exactly (case-sensitive): the token is 'SdP'.
misc_labels = ['alkane', 'beta', 'other', 'SdP', 'carboxylic_acid', 'amine', 'delta', 'aromatic']  # <-- labels without circles
pca_result, pca_model = pca_scatter_plot(df, n_components=2, highlight_groups=highlight_labels,
                                         group_labels=circle_names, misc_labels=misc_labels)

# Attention Plots 
Plot attention scores for a selected solvent sequence to compare the attention of the pre-trained and fine-tuned models.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import MultipleLocator
import os
import numpy as np

def plot_attention_comparison(attention_weights_dict, layers_to_compare=None, sample_idx=0, head_idx=0,
                              token_type_ids=None, save_path=None, figsize=(16, 6),
                              title='Fine-tuned Model Attention Scores', tick_step=4):
    """
    Plot one attention head from several layers side by side with one shared colorbar.

    All panels use the same colour scale (global min/max over the plotted
    matrices), so the single colorbar on the rightmost panel applies to all.

    Args:
        attention_weights_dict (dict): Layer name -> attention weights, indexed
            as ``weights[sample_idx, head_idx]`` to give a 2-D query x key matrix.
            NOTE(review): presumably (batch, heads, seq, seq) -- confirm with
            save_attention_weights.
        layers_to_compare (list[str] | None): Layers to plot, in order; defaults
            to all layers sorted by name. Unknown names are reported and skipped.
        sample_idx (int): Batch index to plot.
        head_idx (int): Attention head to plot.
        token_type_ids: Unused; kept for backward compatibility.
        save_path (str | None): If given, the figure is also saved there.
        figsize (tuple): Base (width, height); 2 extra inches of width are
            added for the colorbar.
        title (str): Figure title.
        tick_step (int): Spacing between labelled query/key positions.
    """
    plt.rcParams['font.size'] = 16

    if layers_to_compare is None:
        layers_to_compare = sorted(attention_weights_dict.keys())

    # Filter unknown layers first so no empty subplot is created for them and
    # the colorbar lands on the last panel actually drawn.
    valid_layers = []
    for layer_name in layers_to_compare:
        if layer_name in attention_weights_dict:
            valid_layers.append(layer_name)
        else:
            print(f"Layer {layer_name} not found. Skipping.")

    num_layers = len(valid_layers)
    if num_layers == 0:
        print("No valid layers to plot.")
        return

    matrices = [attention_weights_dict[name][sample_idx, head_idx] for name in valid_layers]
    # Shared colour limits so the single colorbar is valid for every panel.
    vmin = min(float(m.min()) for m in matrices)
    vmax = max(float(m.max()) for m in matrices)

    fig, axes = plt.subplots(1, num_layers, figsize=(figsize[0] + 2, figsize[1]),
                             constrained_layout=True, squeeze=False)
    axes = axes[0]

    for i, (ax, layer_name, attn_matrix) in enumerate(zip(axes, valid_layers, matrices)):
        is_last = i == num_layers - 1
        sns.heatmap(attn_matrix,
                    cmap='viridis',
                    vmin=vmin,
                    vmax=vmax,
                    cbar=is_last,
                    square=True,
                    ax=ax,
                    cbar_kws={'label': 'Attention Weight'} if is_last else None)

        seq_len = attn_matrix.shape[0]
        positions = np.arange(0, seq_len, tick_step)
        # seaborn draws cell k between k and k+1; centre each tick on its cell.
        ax.set_xticks(positions + 0.5)
        ax.set_yticks(positions + 0.5)
        ax.set_xticklabels(positions)
        ax.set_yticklabels(positions)

        ax.set_xlabel('Key Position')
        if i == 0:
            ax.set_ylabel('Query Position')

        # e.g. 'layer_0' -> 'Layer 0'
        ax.set_title(layer_name.replace('_', ' ').title())

    plt.suptitle(title, fontsize=24, y=1.08)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Comparison plot saved to {save_path}")

    plt.show()


In [None]:
from figures.plot_attention import MultiModalRegressionTransformerWithWeights
from model.decoder import MultiModalRegressionTransformer
import torch

# Constructor arguments shared by both checkpoints.
# NOTE(review): positional meaning is defined in model.decoder -- confirm they
# match how each checkpoint was trained.
MODEL_ARGS = (384, 26, 64, 28, 6, 16, 5, 0)

# map_location='cpu' (as in the embedding cell) so checkpoints saved from GPU
# tensors also load on a CPU-only machine.
original_model = MultiModalRegressionTransformer(*MODEL_ARGS)
original_model.load_state_dict(
    torch.load('val_loss0.1074_DPR_0.1_MP_0.3_DM_64_TL_5_heads_16.pth', map_location='cpu'))

FP_model = MultiModalRegressionTransformer(*MODEL_ARGS)
FP_model.load_state_dict(torch.load('FP_model.pt', map_location='cpu'))

In [None]:
from torch.utils.data import DataLoader

from figures.plot_attention import create_modified_model_from_original, save_attention_weights
from model.collate import create_collate_fn
from model.config import COLUMN_DICT, MAX_SEQUENCE_LENGTH, TOKEN_TYPE_VOCAB
from model.dataset import load_dataset

# Wrap both loaded models so their attention weights can be captured.
# NOTE(review): 'all' presumably selects every layer -- confirm in
# figures.plot_attention.
modified_model = create_modified_model_from_original(original_model, 'all')
fp_mod = create_modified_model_from_original(FP_model, 'all')

In [None]:
# Load the training set and wrap it in an unshuffled, one-sample-per-batch
# DataLoader so a given sample index always refers to the same row.
data_path = 'data/train_set.csv'
dataset, chemberta_dimension = load_dataset(data_path, COLUMN_DICT, MAX_SEQUENCE_LENGTH)

configured_collate_fn = create_collate_fn(TOKEN_TYPE_VOCAB, 0)

dataloader = DataLoader(dataset, batch_size=1, shuffle=False,
                        collate_fn=configured_collate_fn)


In [None]:
# Pick one sample from the unshuffled loader to visualise.
SAMPLE_INDEX = 10
sample_input = None
for index, batch in enumerate(dataloader):
    if index == SAMPLE_INDEX:
        sample_input = batch
        # Stop here instead of collating every remaining batch.
        break
if sample_input is None:
    raise IndexError(f'dataloader has fewer than {SAMPLE_INDEX + 1} samples')

sample_input['token_type_vocab'] = TOKEN_TYPE_VOCAB
print(sample_input.keys())

# attention_weights come from the pre-trained model, fp_weight from the
# fine-tuned one.
attention_weights = save_attention_weights(modified_model, sample_input)
fp_weight = save_attention_weights(fp_mod, sample_input)
head = 3
layers = ['layer_0', 'layer_1', 'layer_4']
# NOTE(review): plot_attention_comparison's default title reads
# 'Fine-tuned Model Attention Scores' for both figures, including the
# pre-trained one below.
plot_attention_comparison(attention_weights, layers_to_compare=layers, head_idx=head)
plot_attention_comparison(fp_weight, layers_to_compare=layers, head_idx=head)