# Heatmap

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob

In [2]:
def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", **kwargs):
    """
    Create a heatmap from a 2D array and two lists of labels.

    Arguments:
        data       : A 2D array-like of shape (N, M): numpy array, pandas
                     DataFrame or nested list. Only the first two
                     dimensions are used for tick placement, so (N, M, 3/4)
                     RGB(A) images work too.
        row_labels : A list or array of length N with the labels
                     for the rows
        col_labels : A list or array of length M with the labels
                     for the columns
    Optional arguments:
        ax         : A matplotlib.axes.Axes instance to which the heatmap
                     is plotted. If not provided, use current axes or
                     create a new one.
        cbar_kw    : A dictionary with arguments to
                     :meth:`matplotlib.Figure.colorbar`. Defaults to no
                     extra arguments.
        cbarlabel  : The label for the colorbar
    All other arguments are directly passed on to the imshow call.

    Returns:
        (im, cbar) : The AxesImage created by ``imshow`` and the Colorbar
                     attached to it, so callers can restyle either one.
    """
    if ax is None:
        ax = plt.gca()
    # Use None as the default to avoid a shared mutable dict across calls.
    if cbar_kw is None:
        cbar_kw = {}

    n_rows, n_cols = np.shape(data)[:2]

    # Plot the heatmap
    im = ax.imshow(data, **kwargs)

    # Create colorbar; rotate its label so it reads bottom-to-top.
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # Show one major tick per cell and label it with the respective entry.
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Turn spines off and draw a white grid on the minor ticks, which sit
    # on the cell boundaries (half a cell away from each major tick).
    for spine in ax.spines.values():
        spine.set_visible(False)

    ax.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-')
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar

In [3]:
# Element symbols whose rows are selected (in this order) from every input
# table before plotting; grouped here by periodic-table group.
main_elem = (
    ['H', 'Li', 'Na', 'K', 'Rb', 'Cs']   # group 1
    + ['Be', 'Mg', 'Ca', 'Sr', 'Ba']     # group 2
    + ['B', 'Al', 'Ga', 'In', 'Tl']      # group 13
    + ['C', 'Si', 'Ge', 'Sn', 'Pb']      # group 14
    + ['N', 'P', 'As', 'Sb', 'Bi']       # group 15
    + ['O', 'S', 'Se', 'Te']             # group 16
    + ['Cl', 'Br', 'I']                  # group 17
)

In [11]:
def get_heatmap(file_name):
    """
    Render the table in *file_name* as a heatmap PNG.

    The CSV's first column is used as the index. Only the rows listed in
    ``main_elem`` are kept, in that order; a KeyError is raised by
    ``DataFrame.loc`` if any of them is missing. Rows are labelled with the
    index values, and columns by position 0..M-1.

    Files whose name contains 'NMF' use the sequential "Blues" colormap.
    All others use the diverging "RdBu" map.

    The output name is *file_name* with 'atom2vec' replaced by 'Heatmap'.
    A trailing '.csv' is swapped for '.png'; otherwise '.png' is appended.
    The image is saved at 100 dpi.
    """
    data = pd.read_csv(file_name, index_col=0)
    data = data.loc[main_elem]

    # Presumably NMF outputs are non-negative, so a one-sided colormap fits
    # them better than a diverging one. TODO confirm with the data producer.
    cmap = "Blues" if 'NMF' in file_name else "RdBu"

    out_name = file_name.replace('atom2vec', 'Heatmap')
    # Only touch the extension; str.replace('csv', 'png') would also rewrite
    # any "csv" elsewhere in the name.
    if out_name.endswith('.csv'):
        out_name = out_name[:-len('.csv')] + '.png'
    else:
        out_name += '.png'

    fig, ax = plt.subplots(figsize=(10, 20))
    try:
        heatmap(data, list(data.index), list(range(len(data.columns))),
                ax=ax, cmap=cmap)
        fig.tight_layout()
        # Save this specific figure rather than whatever pyplot considers current.
        fig.savefig(out_name, dpi=100)
    finally:
        # Always release the figure, even if plotting fails.
        plt.close(fig)

# Hierarchical clustering

In [5]:
from scipy.cluster.hierarchy import dendrogram, linkage

In [9]:
def get_HC(file_name):
    """
    Draw a hierarchical-clustering dendrogram for the rows of *file_name*.

    The CSV's first column is used as the index. ``atom_env_matrix.csv`` is
    transposed first, because the elements selected below must end up on
    the index and that file stores them along its columns. Only the rows in
    ``main_elem`` are kept, in that order. They are clustered with average
    linkage on cosine distance and drawn as a right-oriented dendrogram.

    Output name:
        * ``atom_env_matrix.csv`` keeps its stem.
        * Every other file has 'atom2vec' replaced by 'HC'.
    In both cases a trailing '.csv' is swapped for '.png'; otherwise '.png'
    is appended.
    """
    is_env_matrix = file_name == 'atom_env_matrix.csv'

    data = pd.read_csv(file_name, index_col=0)
    if is_env_matrix:
        data = data.transpose()
    data = data.loc[main_elem]

    Z = linkage(data, method='average', metric='cosine')

    out_name = file_name if is_env_matrix else file_name.replace('atom2vec', 'HC')
    # Only touch the extension; str.replace('csv', 'png') would also rewrite
    # any "csv" elsewhere in the name.
    if out_name.endswith('.csv'):
        out_name = out_name[:-len('.csv')] + '.png'
    else:
        out_name += '.png'

    fig = plt.figure(figsize=(10, 30))
    try:
        ax = fig.add_subplot()
        dendrogram(Z, orientation='right', leaf_font_size=12,
                   labels=list(data.index), ax=ax)
        fig.savefig(out_name)
    finally:
        # Always release the figure, even if plotting fails.
        plt.close(fig)

---

In [12]:
# Plot every CSV in the working directory except mat_energy.csv. Filtering
# (rather than list.remove) does not raise when that file is absent. Sorting
# gives a reproducible order, since glob's order is arbitrary.
file_list = sorted(f for f in glob.glob('*.csv') if f != 'mat_energy.csv')
for file_name in file_list:
    print(file_name)
    get_HC(file_name)
    # Files with 'atom_env' in the name get only a dendrogram, no heatmap.
    if 'atom_env' not in file_name:
        get_heatmap(file_name)

atom2vec_NMF20.csv
atom2vec_AE1024.csv
atom2vec_AE20.csv
atom_env_matrix.csv
atom2vec_SVD20.csv
