In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from textwrap import wrap
plt.style.use(['Solarize_Light2'])  # global matplotlib theme applied to every figure below

In [2]:
# Paper totals table read from totals.csv; the first CSV column becomes the row index.
data = pd.read_csv(filepath_or_buffer='totals.csv', index_col=0)

In [3]:
# Data extrapolation: expand every row of `data` into several sub-steps so the
# animation advances smoothly. The rows of `new_data` come out in reverse order
# relative to `data` (its last row becomes the first row here).

genes = ['abca7', 'apoe', 'app', 'bin1', 'clu', 'picalm', 'presenilin-1',
         'presinilin-2', 'sorl-1']

STEPS_PER_ROW = 8  # sub-steps each original row expands into (presumably one row per year)


def _upsample_column(values, spread, steps=STEPS_PER_ROW):
    """Expand one column of `data` into `steps` points per row, walking it in reverse.

    Parameters
    ----------
    values : sequence
        The column's values in `data` row order.
    spread : bool
        True for the per-gene paper-count columns: each value is split into
        `steps` equal parts so the parts add back up to the original value.
        False for the other columns: each value is reached by a linear ramp
        starting from the value of the following row in `values`.
    steps : int
        Number of output points per input row.

    Returns
    -------
    list
        The last element of `values` once, then `steps` points for every
        other element, from last to first. A missing value yields
        ``steps - 1`` blanks (None) followed by the value itself.
    """
    out = []
    last = len(values) - 1
    for i in range(last, -1, -1):
        cur = values[i]
        if i == last:
            out.append(cur)
        elif pd.isnull(cur):
            out.extend([None] * (steps - 1))
            out.append(cur)
        elif spread:
            out.extend([cur / steps] * steps)
        elif pd.isnull(values[i + 1]):
            # Nothing to ramp from: interpolating against a missing value would
            # turn all `steps` points into NaN and drop this real observation,
            # so treat the step like a gap that ends in `cur`.
            out.extend([None] * (steps - 1))
            out.append(cur)
        else:
            prev = values[i + 1]
            delta = (cur - prev) / steps
            out.extend(prev + k * delta for k in range(1, steps + 1))
    return out


new_data = pd.DataFrame(
    {col: _upsample_column(data.loc[:, col].values, spread=col in genes)
     for col in data.columns},
    columns=data.columns,
)

In [4]:
# Data from additional papers. Row k describes genes[k] (run() indexes these
# lists with the position of each gene in `genes`):
#   (protein type, odds ratio, prevalence)
gene_type, odds_ratio, prevalence = map(list, zip(*[
    ('Transport protein',     2.03, 0.1),     # abca7
    ('Metabolism protein',    3.68, 0.22),    # apoe
    ('Membrane protein',      5.7,  0.001),   # app
    ('Adaptor protein',       1.2,  0.29),    # bin1
    ('Chaperone protein',     1.16, 0.38),    # clu
    ('Assembly protein',      0.88, 0.35),    # picalm
    ('Transmembrane protein', 6.1,  0.04),    # presenilin-1
    ('Transmembrane protein', 3.8,  0.0005),  # presinilin-2
    ('Transmembrane protein', 0.74, 0.1),     # sorl-1
]))

In [None]:
import matplotlib.animation as animation

def data_gen(i=0):
    """Yield growing prefixes of `new_data`, each one row longer than the last.

    The first frame is ``new_data.iloc[:i + 1]`` and the last is the whole
    frame; each yielded prefix is one animation frame handed to `run`.
    """
    stop = i
    while stop < len(new_data):
        stop += 1
        yield new_data.iloc[:stop]

# Main axes plus a small inset key in the upper left: a colour bar for the
# relative-risk colour scale and a row of growing circles for the
# relative-prevalence size scale.
fig, ax = plt.subplots(figsize=(17, 8), constrained_layout=True)
low_y = 0.73  # figure-fraction height of the colour bar; the circle key sits below it

ax1 = fig.add_axes([0.08, low_y, 0.15, 0.02])
cb1 = mpl.colorbar.ColorbarBase(ax1, cmap='Reds', orientation='horizontal')
cb1.set_ticks([0, 1])
cb1.ax.tick_params(length=0)
cb1.ax.set_xticklabels(['low', 'high'])
cb1.set_label('Relative risk of gene', fontsize=10, labelpad=-40, color='grey')

# Circle key: positions are converted from figure fractions to display
# (pixel) coordinates, and transform=None keeps the circles in those pixel
# coordinates. Nine evenly spaced circles whose radius grows by 1.2 each step.
key_y = low_y - 0.09
start_coords = fig.transFigure.transform((0.085, key_y))
end_coords = fig.transFigure.transform((0.23, key_y))
coord_step = (end_coords[0] - start_coords[0]) / 8
size_circles = [
    mpl.patches.Circle(xy=(start_coords[0] + k * coord_step, start_coords[1]),
                       radius=(k + 1) * 1.2, transform=None,
                       facecolor='black', alpha=0.75)
    for k in range(9)
]

# Captions for the circle key. run() calls ax.clear() on every frame and then
# draws its own copies of these captions and of the axis labels.
for cap_x, cap_y, caption, valign in (
        (0.155, low_y - 0.07, 'Relative prevalence of gene', 'bottom'),
        (0.08, low_y - 0.12, 'low', 'top'),
        (0.23, low_y - 0.12, 'high', 'top')):
    ax.text(cap_x, cap_y, caption,
            verticalalignment=valign, horizontalalignment='center',
            transform=fig.transFigure, fontsize=10, color='grey')

year_lo = np.min(new_data.year)
year_hi = np.max(new_data.year)
ax.set_xlim(year_lo, year_hi)
ax.set_ylabel('Number of published papers')
ax.set_xlabel('Year')
fig.suptitle("Alzheimer's research", fontsize=16)


def run(gen_data, y_cap=4000):
    """Redraw the main axes `ax` for one animation frame.

    Parameters
    ----------
    gen_data : pandas.DataFrame
        The rows revealed so far (a prefix of `new_data`, as yielded by
        `data_gen`). Reads the ``year``, ``amyloid``, ``tau`` and
        ``cholinergic`` columns plus one column per name in `genes`.
    y_cap : float, optional
        Ceiling for the y axis. Gene dots are clamped to it, and when the
        autoscaled y range would exceed it the axis is cut at `y_cap` and the
        tick at the cap is labelled e.g. ``'4000+'``.

    Returns
    -------
    None
        Everything is drawn on the module-level `ax` / `fig` as a side effect.
    """
    cur_data = gen_data
    # ax.clear() resets the view limits too; keep the x range that was in
    # place before the clear so the year axis stays fixed across frames
    # instead of autoscaling to the rows drawn so far.
    xlim = ax.get_xlim()
    ax.clear()
    ax.set_xlim(xlim)

    amy_handler, = ax.plot(cur_data.year, cur_data.amyloid,
                           label=r'Amyloid hypothesis - Buildup of extracellular amyloid beta ($A_\beta$) deposits directly causes the disease')
    tau_handler, = ax.plot(cur_data.year, cur_data.tau,
                           label=r'Tau hypothesis - Hyperphosphorylation of tau leads to neurofibrillary tangles which ultimately cause the disease')
    choli_handler, = ax.plot(cur_data.year, cur_data.cholinergic,
                             label=r'Cholinergic hypothesis - The disease is caused by reduced acetylcholine synthesis')

    # Dot colour encodes odds ratio and dot size encodes prevalence, each
    # relative to the largest value across all genes.
    max_odds = np.max(odds_ratio)
    max_prev = np.max(prevalence)

    dot_handles = []
    for i, gene in enumerate(genes):
        total = np.sum(cur_data[gene])
        if total == 0:
            continue  # no papers for this gene in the revealed rows yet
        y = min(total, y_cap)
        dot_size = 10 + prevalence[i] / max_prev * 500
        perc_odds = odds_ratio[i] / max_odds
        label = '\n'.join(wrap('%s(%d) - %s' % (gene, total, gene_type[i]), 30))
        # Place the dot at the first year that has data for this gene.
        first_year = cur_data.loc[cur_data[gene].notnull(), 'year'].iloc[0]
        dot = ax.scatter(first_year, y, s=dot_size,
                         color=mpl.cm.Reds(perc_odds), zorder=10,
                         alpha=0.8, edgecolor='black', label=label)
        dot_handles.append(dot)

    ymin, ymax = ax.get_ylim()
    if ymax > y_cap:
        ax.set_ylim(ymin, y_cap)
        # Label ticks at draw time so the tick sitting at the cap reads e.g.
        # '4000+'. Tick labels captured before ax.clear() belong to the
        # previous frame's ticks and need not line up with this frame's.
        cap_label = '%g+' % y_cap
        ax.yaxis.set_major_formatter(mpl.ticker.FuncFormatter(
            lambda value, pos: cap_label if value >= y_cap else '%g' % value))

    # The size-key circles and captions are children of `ax`, so ax.clear()
    # removed them; put them back.
    for item in size_circles:
        ax.add_patch(item)
    ax.text(0.155, low_y-0.07, 'Relative prevalence of gene',
            verticalalignment='bottom', horizontalalignment='center',
            transform=fig.transFigure, fontsize=10, color='grey')
    ax.text(0.08, low_y - 0.12, 'low', verticalalignment='top', horizontalalignment='center',
            transform=fig.transFigure, fontsize=10, color='grey')
    ax.text(0.23, low_y - 0.12, 'high', verticalalignment='top', horizontalalignment='center',
            transform=fig.transFigure, fontsize=10, color='grey')

    ax.set_ylabel('Number of published papers')
    ax.set_xlabel('Year')

    # Two legends on one axes: the gene legend must be added as an artist,
    # otherwise the second ax.legend() call would replace it.
    second_legend = ax.legend(handles=dot_handles,
                              bbox_to_anchor=(1.01, -0.01, 0.4, 1), loc='upper left', mode='expand',
                              borderaxespad=0., title='Gene(total papers)      ',
                              title_fontsize=13, labelspacing=2, handletextpad=2)
    ax.add_artist(second_legend)

    ax.legend(handles=[amy_handler, tau_handler, choli_handler],
              bbox_to_anchor=(0.,1.02, 1.41, 1.02), loc='lower left',
              mode='expand', borderaxespad=0., fontsize=11, title="Hypotheses",
              title_fontsize=13)


# Feed every prefix produced by data_gen() through run() and write the result
# as a GIF. `interval` is the live-playback delay; the GIF's frame rate is set
# by the writer's fps.
ani = animation.FuncAnimation(
    fig,
    run,
    frames=data_gen(),
    interval=200,
    blit=False,
    repeat=False,
    save_count=10000,
)
ani.save("movie.gif", writer=animation.PillowWriter(fps=24))