In [1]:
import os

import numpy as np
from matplotlib import pyplot as plt

In [2]:
def create_dir(dir_path):
    """Create ``dir_path`` (and any missing parent directories) if needed.

    Existing directories are left untouched, so repeated calls are safe.
    An empty path denotes the current working directory and is a no-op.

    Args:
        dir_path: Directory to create (str or path-like).

    Raises:
        FileExistsError: if ``dir_path`` exists but is not a directory.
    """
    if not dir_path:
        return
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and, unlike
    # mkdir, also creates intermediate directories (e.g. "results/run1").
    os.makedirs(dir_path, exist_ok=True)
    
def _annotate_heatmap_cells(ax, data):
    """Write the value of every cell of the 2-D array ``data`` onto ``ax``."""
    if data.size == 0:
        return
    # Any floating dtype (float16/32/64) keeps three decimals; others are rounded.
    fmt_str = '{:.3f}' if np.issubdtype(data.dtype, np.floating) else '{:.0f}'
    # Midpoint of the value range: with the Blues colormap, cells above it are
    # dark, so their text is drawn in white. nan* ignores missing cells.
    thresh = (np.nanmax(data) + np.nanmin(data)) / 2.
    n_rows, n_cols = data.shape[0], data.shape[1]
    for row in range(n_rows):
        for col in range(n_cols):
            value = data[row, col]
            ax.text(col, row, fmt_str.format(value),
                    ha="center", va="center", fontsize=8,
                    color="white" if value > thresh else "black")


def plot_heatmaps(maps, x_labels, y_labels, x_title, y_title, subfolder, filename,
                  plot_title="", subplot_titles=None, annotate=False, fmt='png'):
    """Draw one heatmap per matrix in a single row, then save, show and close it.

    Args:
        maps: Sequence of 2-D array-likes; one subplot (with its own colorbar)
            is drawn per entry.
        x_labels: Tick labels for the columns of every map.
        y_labels: Tick labels for the rows of every map.
        x_title: X-axis label used on every subplot.
        y_title: Y-axis label used on every subplot.
        subfolder: Output directory; created if missing.
        filename: Output file name without extension.
        plot_title: Figure-level title.
        subplot_titles: Optional per-subplot titles, indexed like ``maps``.
        annotate: If True, print each cell's value inside the cell.
        fmt: File extension passed into the saved file name.

    The figure is written to ``<subfolder>/<filename>.<fmt>``.
    """
    fig = plt.figure()
    n_maps = len(maps)
    for idx, data in enumerate(maps):
        # Accept plain nested lists as well as arrays.
        data = np.asarray(data)
        ax = fig.add_subplot(1, n_maps, idx + 1)
        ax.set_title(subplot_titles[idx] if subplot_titles is not None else "")
        ax.set_xticks(np.arange(len(x_labels)))
        ax.set_yticks(np.arange(len(y_labels)))
        ax.set_xticklabels(x_labels)
        ax.set_yticklabels(y_labels)
        ax.set_xlabel(x_title)
        ax.set_ylabel(y_title)

        im = ax.imshow(data, cmap=plt.cm.Blues)
        # Vertical x tick labels so long labels do not overlap.
        plt.setp(ax.xaxis.get_majorticklabels(), rotation=90)
        fig.colorbar(im, ax=ax)

        if annotate:
            _annotate_heatmap_cells(ax, data)

    fig.suptitle(plot_title)

    create_dir(str(subfolder))
    # Save this specific figure rather than whatever pyplot considers current.
    fig.savefig('{}/{}.{}'.format(subfolder, filename, fmt), bbox_inches='tight')
    plt.show()
    plt.close(fig)

def plot_lines(lines, errors, x_labels, x_vals, plot_title, x_title, y_title, linestyles,
               colors, labels, subfolder, filename, fmt='png', separate_legend=False):
    """Plot several series with error bars on one axis, then save, show and close.

    Args:
        lines: Sequence of y-value sequences, one per series.
        errors: Per-series values passed as ``yerr`` to ``Axes.errorbar``, indexed
            like ``lines``; pass ``None`` to draw no error bars at all.
        x_labels: Tick labels placed at ``x_vals``.
        x_vals: Shared x positions for every series (also the tick positions).
        plot_title, x_title, y_title: Axis title and axis labels.
        linestyles, colors, labels: Per-series style and legend label, indexed
            like ``lines``.
        subfolder: Output directory; created if missing.
        filename: Output file name without extension.
        fmt: File extension used for the saved file name(s).
        separate_legend: If True, the legend is written to its own file
            ``<filename>_legend.<fmt>`` instead of beside the plot.

    The plot is written to ``<subfolder>/<filename>.<fmt>``.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_title(plot_title)
    ax.set_xlabel(x_title)
    ax.set_ylabel(y_title)
    ax.set_xticks(x_vals)
    ax.set_xticklabels(x_labels)

    if errors is None:
        errors = [None] * len(lines)
    # Index the per-series sequences (rather than zip) so a too-short
    # sequence still raises instead of silently dropping series.
    for i, y_vals in enumerate(lines):
        ax.errorbar(x_vals, y_vals, yerr=errors[i], color=colors[i],
                    label=labels[i], linestyle=linestyles[i], marker='o')

    create_dir(str(subfolder))

    fig_legend = None
    if separate_legend:
        fig_legend = plt.figure(figsize=(3, 3))
        handles, legend_labels = ax.get_legend_handles_labels()
        # `loc` must be passed by keyword: positional `loc` is deprecated or
        # rejected by newer Matplotlib versions.
        fig_legend.legend(handles, legend_labels, loc='center', ncol=1)
        fig_legend.savefig('{}/{}_legend.{}'.format(subfolder, filename, fmt),
                           bbox_inches='tight')
    else:
        # Shrink the axes to make room for a legend to the right of the plot.
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    fig.savefig('{}/{}.{}'.format(subfolder, filename, fmt), bbox_inches='tight')
    plt.show()
    # Close both figures explicitly; plt.close() alone would only close the
    # current one (the legend figure), leaking the main plot.
    plt.close(fig)
    if fig_legend is not None:
        plt.close(fig_legend)

In [1]:
def plot_scatter(all_x_vals, all_y_vals, colors, shapes, annotations, annotation_positions, plot_title,
                 x_title, y_title, subfolder, filename, labels, separate_legend=False, fmt='png'):
    """Scatter several point groups with optional arrow annotations; save, show, close.

    Args:
        all_x_vals, all_y_vals: Sequences of x / y value sequences, one per group.
        colors, shapes, labels: Per-group colour, marker and legend label,
            indexed like ``all_x_vals``.
        annotations: One entry per point, in the order obtained by concatenating
            the groups; empty-string entries are not annotated. If the first
            entry is numeric, every entry is converted with ``str``.
        annotation_positions: Text position for each annotation, indexed like
            ``annotations`` (only read for non-empty entries).
        plot_title, x_title, y_title: Axis title and axis labels.
        subfolder: Output directory; created if missing.
        filename: Output file name without extension.
        separate_legend: If True, the legend is written to its own file
            ``<filename>_legend.<fmt>`` instead of beside the plot.
        fmt: File extension used for the saved file name(s).

    The plot is written to ``<subfolder>/<filename>.<fmt>``.
    """
    # Guard against empty annotations; np.number also covers NumPy integers.
    if len(annotations) > 0 and isinstance(annotations[0], (int, float, np.number)):
        annotations = [str(a) for a in annotations]

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_title(plot_title)
    ax.set_xlabel(x_title)
    ax.set_ylabel(y_title)
    for i, (x_vals, y_vals) in enumerate(zip(all_x_vals, all_y_vals)):
        ax.scatter(x_vals, y_vals, s=64, c=colors[i], marker=shapes[i], label=labels[i])

    # Annotations address points by their index in the concatenation of all groups.
    flat_x = [x for x_vals in all_x_vals for x in x_vals]
    flat_y = [y for y_vals in all_y_vals for y in y_vals]
    arrowprops = dict(arrowstyle="<-")
    for i, txt in enumerate(annotations):
        if txt != '':
            ax.annotate(txt, xytext=annotation_positions[i], xy=(flat_x[i], flat_y[i]),
                        arrowprops=arrowprops)

    create_dir(str(subfolder))

    fig_legend = None
    if separate_legend:
        fig_legend = plt.figure(figsize=(3, 3))
        handles, legend_labels = ax.get_legend_handles_labels()
        # `loc` must be passed by keyword: positional `loc` is deprecated or
        # rejected by newer Matplotlib versions.
        fig_legend.legend(handles, legend_labels, loc='center', ncol=1)
        fig_legend.savefig('{}/{}_legend.{}'.format(subfolder, filename, fmt),
                           bbox_inches='tight')
    else:
        # Shrink the axes to make room for a legend to the right of the plot.
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    fig.savefig('{}/{}.{}'.format(subfolder, filename, fmt), bbox_inches='tight')
    plt.show()
    # Close both figures explicitly; plt.close() alone would only close the
    # current one (the legend figure), leaking the main plot.
    plt.close(fig)
    if fig_legend is not None:
        plt.close(fig_legend)