In [1]:
from neo.core import SpikeTrain
from tqdm import tqdm
import elephant as eph
import imageio
import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import os
import quantities as q
from subprocess import call
import re 

# Turn off Jedi-based completion in the IPython completer.
%config Completer.use_jedi = False

# Set the default matplotlib font size to 15 for every figure created below.
mpl.rcParams.update({'font.size': 15})

# Community Dynamics Animation (Supplementary Video)

In [2]:
def eval_spikes(graphno):
    """Return binned, module-averaged activity traces for one network.

    For network ``graphno`` the module assignment is read from
    ``../modules/networks/matrix_{graphno}_modules.csv`` and the spike events of
    five perturbation seeds are read from ``../data/30LN/``.  For every replicate
    the five recordings of each neuron are concatenated back to back (each one
    shifted by a further 7000 ms), binned in 50 ms bins and averaged over the
    neurons belonging to each module.

    Parameters
    ----------
    graphno : int
        Network index used to build the input file names.

    Returns
    -------
    numpy.ndarray
        Array indexed as ``[replicate, module, time bin]``; modules appear in
        the order of ``np.unique`` of the module assignment.

    Notes
    -----
    * The adjacency matrix CSV was previously loaded here but never used, so it
      is no longer read.
    * Assumes each events file is indexable as ``[replicate, neuron]`` with 5
      replicates and 30 neurons -- TODO confirm against the simulation output.
    """
    n_neurons = 30
    n_reps = 5
    segment_ms = 7000  # offset between consecutive perturbation recordings
    pert_seeds = [59428, 13674, 84932, 72957, 85036]
    t_stop = segment_ms * len(pert_seeds)  # 35000 ms in total

    module = np.loadtxt(f'../modules/networks/matrix_{graphno}_modules.csv')
    module_ids = np.unique(module)

    # NOTE(review): allow_pickle=True executes pickled objects on load; only use
    # this with event files produced by trusted code.
    all_events = np.array(
        [np.load(f"../data/30LN/LN30_events_{graphno}_{pertseed}.npy", allow_pickle=True)
         for pertseed in pert_seeds],
        dtype=object,
    )

    axs = []
    for rep in range(n_reps):
        # Keep the spike trains in a plain list: wrapping trains of different
        # lengths in np.array() builds a ragged array, which recent NumPy rejects.
        spike_trains = [
            SpikeTrain(
                np.concatenate([all_events[seg, rep, neuron] + seg * segment_ms
                                for seg in range(len(pert_seeds))]) * q.ms,
                t_stop=t_stop * q.ms,
            )
            for neuron in range(n_neurons)
        ]
        # Convert to a dense (neuron x bin) array once per replicate.
        binned = eph.conversion.BinnedSpikeTrain(spike_trains, bin_size=50 * q.ms).to_array()
        axs.append([np.mean(binned[module == mod_id, :], axis=0) for mod_id in module_ids])
    return np.array(axs)
def order_by_variability(axs):
    """Reorder modules from most to least variable and label them by original index.

    Variability is the standard deviation of each module's trace taken jointly
    over axis 0 and axis 2 of ``axs``.

    Parameters
    ----------
    axs : numpy.ndarray
        3-D array whose axis 1 enumerates modules.

    Returns
    -------
    tuple
        ``(reordered, labels)`` where ``reordered`` is ``axs`` with axis 1 sorted
        by decreasing variability and ``labels[k]`` is ``"M<n>"`` with ``n`` the
        1-based original index of the module now at position ``k``.
    """
    # Compute the ordering once, on the ORIGINAL array, and use it for both the
    # data and the labels.  Re-running argsort on the already reordered array
    # (as before) just yields a reversed range and mislabels the modules.
    order = np.argsort(axs.std(axis=(0, 2)))[::-1]
    labels = [f"M{int(mod) + 1}" for mod in order]
    return axs[:, order, :], labels

In [3]:
def animate3d(axs,graphno,nreps=5,labels=("M1","M2","M3")):
    """Render one animated GIF per replicate with a rotating 3-D activity trajectory.

    The top panel shows the last 10 samples of the first three modules
    (``axs[rep][0..2]``) as a trajectory in 3-D; the bottom panel shows the
    traces of all modules over time with shaded perturbation markers.  Frames
    are written as PNGs into ``__animationcache__/`` and then assembled into
    ``Videos/movie_{graphno}_{rep}.gif`` at 20 fps.

    Parameters
    ----------
    axs : numpy.ndarray
        Activity indexed as ``[replicate, module, time bin]``; needs at least
        three modules.  Bin ``i`` is displayed at time ``i/20`` seconds.
    graphno : int
        Network index, used only in the output file name.
    nreps : int, optional
        Number of replicates (leading entries of ``axs``) to render.
    labels : sequence of str, optional
        Module names.  The first three label the 3-D axes; entries are also
        used for the legend of the bottom panel so both panels agree when the
        modules were reordered.  Modules without an entry fall back to
        ``"M<position>"``.
    """
    cache_dir = '__animationcache__'
    # Create the working/output folders so a fresh checkout can run this.
    os.makedirs(cache_dir, exist_ok=True)
    os.makedirs('Videos', exist_ok=True)

    n_modules = axs.shape[1]
    track_labels = [labels[m] if m < len(labels) else f'M{m+1}' for m in range(n_modules)]

    # Marker times in seconds: five markers one second apart, starting 1.5 s
    # into each 7 s segment.
    perts = [7*i+1.5+j for i in range(5) for j in range(5)]

    for rep in range(nreps):
        print(f"Replicate {rep+1}")

        # Drop frames left behind by an interrupted run so they cannot leak
        # into this replicate's GIF.
        for stale in os.listdir(cache_dir):
            os.remove(os.path.join(cache_dir, stale))

        lims = (axs[rep][0].min(),axs[rep][0].max(),
                axs[rep][1].min(),axs[rep][1].max(),
                axs[rep][2].min(),axs[rep][2].max())
        ymax = 1.2*np.max([lims[1],lims[3],lims[5]])

        fig = plt.figure(figsize=(5,7))
        G = gridspec.GridSpec(5, 1)

        # --- top panel: hand-drawn axes (arrows + labels) in an empty 3-D box ---
        ax1 = fig.add_subplot(G[0:4], projection='3d')
        ax1.text(-0.1, -0.1, -0.1, "0", color='k',fontdict={"fontsize":10})
        ax1.quiver(0,0,0,lims[1],0,0,arrow_length_ratio=0.1,color='k')
        ax1.text(lims[1]+0.05, -0.05, -0.05, labels[0] , color='k',fontdict={"fontsize":10})
        ax1.quiver(0,0,0,0,lims[3],0,arrow_length_ratio=0.1,color='k')
        ax1.text(-0.05,lims[3]+0.05,-0.05, labels[1], color='k',fontdict={"fontsize":10})
        ax1.quiver(0,0,0,0,0,lims[5],arrow_length_ratio=0.1,color='k')
        ax1.text(-0.05,-0.05,lims[5]+0.05, labels[2], color='k',fontdict={"fontsize":10})
        ax1.set_xlim(lims[0],lims[1])
        ax1.set_ylim(lims[2],lims[3])
        ax1.set_zlim(lims[4],lims[5])
        ax1.set_axis_off()
        for axis in (ax1.xaxis, ax1.yaxis, ax1.zaxis):
            axis.set_pane_color((1.0, 1.0, 1.0, 0.0))
            axis._axinfo["grid"]['color'] =  (1,1,1,0)

        line = ax1.plot3D([0], [0], [0], 'gray',linewidth=1)
        dot = ax1.plot3D([0], [0], [0],marker='o',color='k')

        tim = ax1.text2D(0.4,0.05,"t=0.0s",transform=ax1.transAxes,fontsize=10)

        # --- bottom panel: one growing trace and one end marker per module ---
        ax2 = fig.add_subplot(G[4])
        tracks = []
        endpoint = []
        for m in range(n_modules):
            colour = plt.cm.inferno(m/n_modules)
            # Placeholder data; replaced before the first frame is saved.
            tracks.append(ax2.plot([0.0],[axs[rep][m][0]],"-",color=colour,linewidth=0.5,label=track_labels[m]))
            endpoint.append(ax2.plot(0,0,"o",color=colour))
        ax2.legend(loc='upper center', bbox_to_anchor=(0.5, 1.05),
                   ncol=n_modules, fontsize=10, frameon=False)
        ax2.set_xlim((0,35))
        ax2.set_ylim((0,ymax))
        ax2.set_xlabel("Time (in s)")
        ax2.set_ylabel("Mean Activity")
        ax2.spines['top'].set_visible(False)
        ax2.spines['right'].set_visible(False)

        pert_counter = 0

        plt.subplots_adjust(wspace=0.0,hspace=0.0)

        for i in tqdm(range(10,axs.shape[2])):
            tim.set_text(f"t = {i/20:0.3f}s")
            tim.set_color("black")
            if pert_counter<len(perts) and perts[pert_counter]<i/20:
                # A marker time was just passed: shade it and flash the clock red.
                ax2.fill_betweenx((0,ymax),perts[pert_counter],perts[pert_counter]+0.1,color='k',alpha=0.1)
                pert_counter+=1
                tim.set_text(f"t = {i/20:0.3f} s")
                tim.set_color("red")
            xline = axs[rep][0][i-10:i]
            yline = axs[rep][1][i-10:i]
            zline = axs[rep][2][i-10:i]
            line[0].set_data_3d(xline, yline, zline)
            # Recent Matplotlib requires sequences (not scalars) as line data.
            dot[0].set_data_3d([xline[-1]], [yline[-1]], [zline[-1]])
            ax1.view_init(elev=20, azim=i*0.3)
            for j in range(n_modules):
                tracks[j][0].set_data(np.arange(i)/20, axs[rep][j][:i])
                endpoint[j][0].set_data([i/20], [axs[rep][j][i]])
            fig.savefig(os.path.join(cache_dir, f"{i}.png"), dpi=90)
        plt.close(fig)

        # Assemble the frames in numeric order (file names are "<bin index>.png").
        frame_files = [name for name in os.listdir(cache_dir) if name.endswith('.png')]
        frame_files.sort(key=lambda name: int(re.sub(r'\D', '', name)))
        images = [imageio.imread(os.path.join(cache_dir, name)) for name in tqdm(frame_files)]
        imageio.mimsave(f'Videos/movie_{graphno}_{rep}.gif', images,fps=20)

        for name in os.listdir(cache_dir):
            os.remove(os.path.join(cache_dir, name))

def animate2d(axs,graphno,nreps=5,labels=("M1","M2")):
    """Render one animated GIF per replicate with a 2-D activity trajectory.

    The top panel shows the last 10 samples of the first two modules
    (``axs[rep][0]`` against ``axs[rep][1]``); the bottom panel shows the traces
    of all modules over time with shaded perturbation markers.  Frames are
    written as PNGs into ``__animationcache__/`` and then assembled into
    ``Videos/movie_{graphno}_{rep}.gif`` at 20 fps.

    Parameters
    ----------
    axs : numpy.ndarray
        Activity indexed as ``[replicate, module, time bin]``; needs at least
        two modules.  Bin ``i`` is displayed at time ``i/20`` seconds.
    graphno : int
        Network index, used only in the output file name.
    nreps : int, optional
        Number of replicates (leading entries of ``axs``) to render.
    labels : sequence of str, optional
        Module names.  The first two label the trajectory axes; entries are
        also used for the legend of the bottom panel.  Modules without an entry
        fall back to ``"M<position>"``.
    """
    cache_dir = '__animationcache__'
    # Create the working/output folders so a fresh checkout can run this.
    os.makedirs(cache_dir, exist_ok=True)
    os.makedirs('Videos', exist_ok=True)

    n_modules = axs.shape[1]
    track_labels = [labels[m] if m < len(labels) else f'M{m+1}' for m in range(n_modules)]

    # Marker times in seconds: five markers one second apart, starting 1.5 s
    # into each 7 s segment.
    perts = [7*i+1.5+j for i in range(5) for j in range(5)]

    for rep in range(nreps):
        print(f"Replicate {rep+1}")

        # Drop frames left behind by an interrupted run so they cannot leak
        # into this replicate's GIF.
        for stale in os.listdir(cache_dir):
            os.remove(os.path.join(cache_dir, stale))

        lims = (axs[rep][0].min(),axs[rep][0].max(),
                axs[rep][1].min(),axs[rep][1].max())
        ymax = 1.2*np.max([lims[1],lims[3]])

        fig = plt.figure(figsize=(5,7))
        G = gridspec.GridSpec(5, 1)

        # --- top panel: trajectory in the plane of the first two modules ---
        ax1 = fig.add_subplot(G[0:4])
        ax1.set_xlim(lims[0],lims[1])
        ax1.set_ylim(lims[2],lims[3])
        ax1.set_xticks([lims[0],lims[1]])
        ax1.set_xticklabels(["0",labels[0]])
        ax1.set_yticks([lims[2],lims[3]])
        ax1.set_yticklabels(["",labels[1]])
        line = ax1.plot([0], [0], 'gray',linewidth=1)
        dot = ax1.plot([0], [0], marker='o',color='k')

        tim = ax1.text(0.4,0.05,"t=0.0s",transform=ax1.transAxes,fontsize=10)
        ax1.spines['right'].set_visible(False)
        ax1.spines['top'].set_visible(False)

        # --- bottom panel: one growing trace and one end marker per module ---
        ax2 = fig.add_subplot(G[4])
        tracks = []
        endpoint = []
        for m in range(n_modules):
            colour = plt.cm.inferno(m/n_modules)
            # Placeholder data; replaced before the first frame is saved.
            tracks.append(ax2.plot([0.0],[axs[rep][m][0]],"-",color=colour,linewidth=0.5,label=track_labels[m]))
            endpoint.append(ax2.plot(0,0,"o",color=colour))
        ax2.legend(loc='upper center', bbox_to_anchor=(0.5, 1.05),
                   ncol=n_modules, fontsize=10, frameon=False)
        ax2.set_xlim((0,35))
        ax2.set_ylim((0,ymax))
        ax2.set_xlabel("Time (in s)")
        ax2.set_ylabel("Mean Activity")
        ax2.spines['top'].set_visible(False)
        ax2.spines['right'].set_visible(False)

        pert_counter = 0

        plt.subplots_adjust(wspace=0.0,hspace=0.5)

        for i in tqdm(range(10,axs.shape[2])):
            tim.set_text(f"t = {i/20:0.3f}s")
            tim.set_color("black")
            if pert_counter<len(perts) and perts[pert_counter]<i/20:
                # A marker time was just passed: shade it and flash the clock red.
                ax2.fill_betweenx((0,ymax),perts[pert_counter],perts[pert_counter]+0.1,color='k',alpha=0.1)
                pert_counter+=1
                tim.set_text(f"t = {i/20:0.3f} s")
                tim.set_color("red")
            xline = axs[rep][0][i-10:i]
            yline = axs[rep][1][i-10:i]
            line[0].set_data(xline, yline)
            # Recent Matplotlib requires sequences (not scalars) as line data.
            dot[0].set_data([xline[-1]], [yline[-1]])
            for j in range(n_modules):
                tracks[j][0].set_data(np.arange(i)/20, axs[rep][j][:i])
                endpoint[j][0].set_data([i/20], [axs[rep][j][i]])
            fig.savefig(os.path.join(cache_dir, f"{i}.png"), dpi=90)
        plt.close(fig)

        # Assemble the frames in numeric order (file names are "<bin index>.png").
        frame_files = [name for name in os.listdir(cache_dir) if name.endswith('.png')]
        frame_files.sort(key=lambda name: int(re.sub(r'\D', '', name)))
        images = [imageio.imread(os.path.join(cache_dir, name)) for name in tqdm(frame_files)]
        imageio.mimsave(f'Videos/movie_{graphno}_{rep}.gif', images,fps=20)

        for name in os.listdir(cache_dir):
            os.remove(os.path.join(cache_dir, name))

In [None]:
# Render the per-replicate GIFs for networks 1..10, choosing the animation
# by the number of modules in each network.
for graph_id in range(1, 11):
    axs = eval_spikes(graph_id)
    n_modules = axs.shape[1]
    if n_modules <= 2:
        # Two modules (or fewer): planar trajectory.
        animate2d(axs, graph_id)
    elif n_modules == 3:
        # Exactly three modules: use them as the 3-D axes in their given order.
        animate3d(axs, graph_id)
    else:
        # More than three modules: put the most variable ones first so they
        # become the 3-D axes.
        axs, labels = order_by_variability(axs)
        animate3d(axs, graph_id, labels=labels)

In [5]:
# For each network, tile its per-replicate GIFs side by side into
# Videos/{graphno}_out.gif (truncated to the shortest replicate).
for graphno in range(1,11):
    # Match exactly "movie_<graphno>_<rep>.gif".  A plain substring test on
    # "_<graphno>_" would also pick up e.g. "set_1_2.gif" from an earlier run.
    movie_pattern = re.compile(rf'movie_{graphno}_(\d+)\.gif')
    # os.listdir() order is arbitrary; sort by replicate so panels are stable.
    rep_files = sorted(
        (name for name in os.listdir("Videos/") if movie_pattern.fullmatch(name)),
        key=lambda name: int(movie_pattern.fullmatch(name).group(1)),
    )
    if not rep_files:
        raise FileNotFoundError(f"No replicate GIFs found in Videos/ for network {graphno}")

    gifs = []
    new_gif = None
    try:
        for name in rep_files:
            gifs.append(imageio.get_reader(f'Videos/{name}'))
        new_gif = imageio.get_writer(f'Videos/{graphno}_out.gif',fps=20)
        for frame_number in tqdm(range(np.min([gif.get_length() for gif in gifs]))):
            datas = [gif.get_next_data() for gif in gifs]
            new_gif.append_data(np.hstack(datas))
    finally:
        # Always release the file handles, even if a frame fails to decode.
        for gif in gifs:
            gif.close()
        if new_gif is not None:
            new_gif.close()

100%|██████████| 688/688 [02:05<00:00,  5.49it/s]
100%|██████████| 688/688 [01:57<00:00,  5.86it/s]
100%|██████████| 688/688 [02:04<00:00,  5.51it/s]
100%|██████████| 688/688 [02:20<00:00,  4.90it/s]
100%|██████████| 688/688 [02:14<00:00,  5.10it/s]
100%|██████████| 688/688 [02:03<00:00,  5.59it/s]
100%|██████████| 688/688 [02:06<00:00,  5.45it/s]
100%|██████████| 688/688 [02:16<00:00,  5.04it/s]
100%|██████████| 688/688 [02:12<00:00,  5.19it/s]
100%|██████████| 688/688 [01:57<00:00,  5.88it/s]


In [6]:
# Stack the tiled GIFs of consecutive networks (1+2, 3+4, ...) vertically into
# Videos/set_{graphno}_{graphno+1}.gif (truncated to the shorter of the two).
for graphno in range(1,11,2):
    # Name the two inputs explicitly instead of slicing a sorted directory
    # listing, so a missing or extra "*_out*" file cannot shift the pairing.
    pair = [f'Videos/{graphno}_out.gif', f'Videos/{graphno+1}_out.gif']

    gifs = []
    new_gif = None
    try:
        for path in pair:
            gifs.append(imageio.get_reader(path))
        new_gif = imageio.get_writer(f'Videos/set_{graphno}_{graphno+1}.gif',fps=20)
        for frame_number in tqdm(range(np.min([gif.get_length() for gif in gifs]))):
            datas = [gif.get_next_data() for gif in gifs]
            new_gif.append_data(np.vstack(datas))
    finally:
        # Always release the file handles, even if a frame fails to decode.
        for gif in gifs:
            gif.close()
        if new_gif is not None:
            new_gif.close()

100%|██████████| 688/688 [04:28<00:00,  2.56it/s]
100%|██████████| 688/688 [03:45<00:00,  3.05it/s]
100%|██████████| 688/688 [03:28<00:00,  3.30it/s]
100%|██████████| 688/688 [04:18<00:00,  2.67it/s]
100%|██████████| 688/688 [04:10<00:00,  2.75it/s]


In [9]:
# Convert every stacked GIF to MP4 and write the ffmpeg concat list.
# The following cells expect the working directory to be Videos/; only change
# into it if we are not already there so this cell can be re-run safely.
if os.path.basename(os.getcwd()) != "Videos":
    os.chdir("Videos/")
with open('vid_list.txt','w') as f:
    # The caption clip comes first in the concatenated video.
    f.write("file 'supplementary_video_caption.mp4'\n")
    for graphno in tqdm(range(1,11,2)):
        clip = f'set_{graphno}_{graphno+1}'
        # '-y' overwrites an existing MP4; without it ffmpeg asks for
        # confirmation and fails when the cell is re-run non-interactively.
        status = call(['ffmpeg.exe','-y','-i',f'{clip}.gif',f'{clip}.mp4'])
        if status != 0:
            raise RuntimeError(f"ffmpeg failed (exit code {status}) while converting {clip}.gif")
        f.write(f"file '{clip}.mp4'\n")

100%|██████████| 5/5 [02:25<00:00, 29.08s/it]


1

In [16]:
# Concatenate the clips listed in vid_list.txt without re-encoding; the cell
# output is ffmpeg's exit code (0 on success).  '-y' overwrites an existing
# output file so the cell can be re-run without ffmpeg stopping to ask.
call(['ffmpeg.exe','-y','-f','concat','-safe','0','-i','vid_list.txt','-c','copy','../supplementary_video1.mp4'])

0

In [19]:
# Move every GIF in the current directory into .archive/.
os.makedirs('.archive', exist_ok=True)  # the folder may not exist on a fresh run
for gif_name in [name for name in os.listdir() if name.endswith(".gif")]:
    # os.replace overwrites an older archived copy on every platform, whereas
    # os.rename fails on Windows when the destination already exists.
    os.replace(gif_name, os.path.join('.archive', gif_name))