In [None]:
import torch
import pickle
import plotly.graph_objs as go
from scipy.stats import ttest_ind

In [None]:
def activity_average_sim(datas, max_activity, start=0, end=-1, window=100):
    """
    Average an activity trace over consecutive, non-overlapping windows.

    The selected slice ``datas[start:end]`` is split into chunks of ``window``
    consecutive samples. The mean and the (unbiased) standard deviation of each
    chunk are scaled by ``max_activity``. The bounds are returned in the order
    the plotting code needs to draw a closed "mean +/- std" band: the lower
    bound reversed and the upper bound in forward order.

    Parameters:
    - datas (list or torch.Tensor): The input activity values.
    - max_activity (float or torch.Tensor): Scale factor applied to the mean and std.
    - start (int): Index of the first sample to consider (default: 0).
    - end (int or None): Index one past the last sample to consider. ``-1`` (the
      default) or ``None`` means "up to and including the last sample". Note that
      this differs from plain Python slicing, where ``[:-1]`` would drop the last
      sample.
    - window (int): Number of consecutive samples averaged together (default: 100).
      The length of the selected slice must be a multiple of ``window``.

    Returns:
    - average (torch.Tensor): The per-window mean activity, scaled.
    - lower_bound (list): ``average - std`` for each window, in reversed order.
    - upper_bound (list): ``average + std`` for each window, in forward order.
      When ``window == 1`` the spread is zero, so both bounds equal ``average``.

    Raises:
    - RuntimeError: If the selected slice length is not a multiple of ``window``
      (raised by ``torch.Tensor.reshape``).
    """
    if end is None or end == -1:
        # -1 is a sentinel for "to the end", not Python's "drop the last sample".
        end = None
    # Convert once. Slicing must honour `start` in every case; the old
    # end == -1 path silently ignored it.
    values = torch.as_tensor(datas, dtype=torch.float32)[start:end]
    # Each row of `windows` holds `window` consecutive samples.
    windows = values.reshape(-1, window)
    average = windows.mean(dim=1) * max_activity
    if window == 1:
        # The unbiased std of a single sample is undefined (torch returns NaN
        # and warns); a one-sample window simply has no spread.
        std = torch.zeros_like(average)
    else:
        std = windows.std(dim=1) * max_activity

    return average, list(average - std)[::-1], list(average + std)


In [None]:

def _legacy_global(name):
    """Return the notebook-level global `name`, which older `plot` calls relied on."""
    try:
        return globals()[name]
    except KeyError:
        raise TypeError(
            "plot() needs {0}=... when no global {0!r} is defined".format(name)
        ) from None


def _activity_traces(datas, max_activity, start=0, end=-1, window=100):
    """Build a mean line and a shaded mean +/- std band for P2, P1 and P1 + P2.

    Returns ``(traces, full_average)``. ``traces`` is the list of
    ``go.Scatter`` objects, one (line, band) pair per series.
    ``full_average`` is the windowed average of the summed P1 + P2 activity.
    """
    colors = [(255, 0, 0), (0, 167, 159), (0, 0, 0)]
    # 'Activity_full' is not read from `datas`: it is the sum of both populations.
    full = torch.Tensor(datas['Activity_P1']) + torch.Tensor(datas['Activity_P2'])
    series = [
        ('Activity_P2', datas['Activity_P2']),
        ('Activity_P1', datas['Activity_P1']),
        ('Activity_full', full),
    ]

    traces = []
    full_average = None
    for (key, data), color in zip(series, colors):
        y, y_lower, y_upper = activity_average_sim(data, max_activity, start, end, window)
        if key == 'Activity_full':
            full_average = y
        x = list(torch.arange(len(y)))
        line = go.Scatter(
            x=x,
            y=y,
            line=dict(width=1, color='rgb{0}'.format(color)),
            mode='lines',
            name='{0}'.format(key),
            showlegend=False,
        )
        # Closed outline: upper bound left-to-right, then lower bound right-to-left
        # (activity_average_sim already returns the lower bound reversed).
        band = go.Scatter(
            x=x + x[::-1],
            y=(y_upper + y_lower),
            fill='tozerox',
            fillcolor='rgba({0},{1},{2},0.2)'.format(*color),
            line=dict(color='rgba(255,255,255,0)'),
            showlegend=False,
        )
        traces += [line, band]
    return traces, full_average


def plot(datas, max_activity, wide=False, type_=None, n_moves=None):
    """
    Plot P2, P1 and P1+P2 activity (mean +/- std bands), show it and save it as PDFs.

    Two figures are built, written with ``fig.write_image`` and shown:
    an overview in which every point averages 100 consecutive samples, and a
    zoom on overview points 496-505 at full (one-sample) resolution.

    Parameters:
    - datas (dict): Must provide 'Activity_P1' and 'Activity_P2' sequences of
      equal length. The overview requires their length to be a multiple of 100.
    - max_activity (float or torch.Tensor): Scale factor passed to ``activity_average_sim``.
    - wide (bool): Optional. If True, output filenames are prefixed with 'wide'. Default is False.
    - type_ (str, optional): Connectivity label used in the output filenames.
      If omitted, the global ``type_`` is used, for compatibility with the driver loop.
    - n_moves (int, optional): Move count used in the output filenames.
      If omitted, the global ``n_moves`` is used.

    Returns:
    - y_toreturn (torch.Tensor): The last 500 windowed (100-sample) averages of
      the summed P1 + P2 activity, scaled by ``max_activity``.

    Raises:
    - TypeError: If ``type_``/``n_moves`` are neither passed nor defined globally.
    """
    prefix = 'wide' if wide else ''
    if type_ is None:
        type_ = _legacy_global('type_')
    if n_moves is None:
        n_moves = _legacy_global('n_moves')

    font = dict(family="sans-serif", size=20, color="black")
    transparent = 'rgba(0,0,0,0)'
    window = 100  # samples per overview point (activity_average_sim's default)

    # --- Overview: 100-sample averages over the whole run -------------------
    overview_traces, full_average = _activity_traces(datas, max_activity)
    overview = go.Figure(
        data=overview_traces,
        layout=go.Layout(
            xaxis=dict(
                gridcolor='rgb(255,255,255)',
                showgrid=True,
                showline=True, linewidth=2, linecolor='black',
                showticklabels=True,
                tickcolor='rgb(0,0,0)',
                zeroline=True,
            ),
            yaxis=dict(
                gridcolor='rgb(255,255,255)',
                showgrid=True,
                showline=True, linewidth=2, linecolor='black',
                rangemode="tozero",
                showticklabels=True,
                zeroline=True,
            ),
        ),
    )
    overview.update_layout(
        font=font,
        paper_bgcolor=transparent,
        plot_bgcolor=transparent,
        # Dashed marker at overview point 500 (sample 50000); presumably the
        # moment the stimulus changes -- TODO confirm against the simulation.
        shapes=[dict(
            type='line',
            line=dict(color='black', dash='dash'),
            yref='paper', y0=0, y1=1,
            xref='x', x0=500, x1=500,
        )],
        showlegend=True,
    )
    overview.update_yaxes(range=[0, .6])
    overview.write_image(prefix + "activity_behaviour_{0}_moves_{1}.pdf".format(type_, n_moves))
    overview.show()

    # --- Zoom: overview points 496..505 at one-sample resolution -------------
    start = 496 * window
    end = 506 * window
    zoom_traces, _ = _activity_traces(datas, max_activity, start, end, 1)
    hidden_axis = dict(
        gridcolor='rgb(255,255,255)',
        showgrid=True,
        showline=False,
        showticklabels=False,
        tickcolor='rgb(0,0,0)',
        zeroline=False,
    )
    zoom = go.Figure(
        data=zoom_traces,
        layout=go.Layout(xaxis=dict(hidden_axis), yaxis=dict(hidden_axis)),
    )
    zoom.update_layout(
        margin=go.layout.Margin(l=0, r=0, b=0, t=0),
        font=font,
        paper_bgcolor=transparent,
        plot_bgcolor=transparent,
        # One dashed line every `window` samples, i.e. at each overview point.
        shapes=[
            dict(
                type='line',
                line=dict(color='black', dash='dash'),
                yref='paper', y0=0, y1=1,
                xref='x', x0=window * step, x1=window * step,
            )
            for step in range((end - start) // window)
        ],
        showlegend=False,
    )
    zoom.write_image(prefix + "activity_behaviour_{0}_moves_{1}_zoomed.pdf".format(type_, n_moves))
    zoom.show()
    return full_average[-500:]

In [None]:
# Load every saved SpikeSuM run, plot it, and keep the tail of the summed
# activity that `plot` returns for the statistical comparison in the next cell.
final_values = {}
for type_ in ['rand', 'onehot']:
    # Scale factor handed to `plot`. The original note on the one-hot formula
    # reads "Expectation of maximum neuron active at the same time".
    max_activity = (
        .1 * 2 * 128 / 16 * torch.tanh(torch.tensor([1]))
        if type_ == 'onehot'
        else 0.1 * 0.1 * 2 * 512 / 16 * torch.tanh(torch.log(torch.cosh(torch.tensor([1])))) * 3
    )
    for n_moves in [1, 2]:
        # `plot` may also read the globals `type_` and `n_moves` to name its output files.
        print(type_, n_moves)
        filename = '../results/pickle/SpikeSuM_info_moves_{0}_{1}.pkl'.format(n_moves, type_)
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted result files.
        with open(filename, 'rb') as handle:
            datas = pickle.load(handle)
        result_key = type_ + 'moves_{0}'.format(n_moves)
        final_values[result_key] = [plot(datas, max_activity)]

In [None]:
# For each connectivity type, compare the final summed activity of the
# 1-move and 2-move runs: print both means and the scipy `ttest_ind`
# p-value (default settings).
for report_label, one_move_key, two_moves_key in (
    ('final activities random connexions:', 'randmoves_1', 'randmoves_2'),
    ('final activities one hot connexions:', 'onehotmoves_1', 'onehotmoves_2'),
):
    one_move = final_values[one_move_key][0]
    two_moves = final_values[two_moves_key][0]
    print(report_label, torch.mean(one_move).item(), torch.mean(two_moves).item(),
          ' p value: ', ttest_ind(one_move, two_moves).pvalue)