In [12]:
import os
import random
import numpy as np
import scipy.io as spio
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.graph_objs import Layout

def generateLocIndexArray(hpcM,pfcM,sArm,gArm):
    iSW = []
    iSE = []
    iNW = []
    iNE = []
    for a,el in enumerate(sArm):
        currJourn = f'{sArm[a]}{gArm[a]}'
        if currJourn == 'SW':
            iSW.append(a)
        if currJourn == 'SE':
            iSE.append(a)
        if currJourn == 'NW':
            iNW.append(a)
        if currJourn == 'NE':
            iNE.append(a)
    swHPC = np.take(hpcM,iSW,axis=0)
    seHPC = np.take(hpcM,iSE,axis=0)
    nwHPC = np.take(hpcM,iNW,axis=0)
    neHPC = np.take(hpcM,iNE,axis=0)
    swPFC = np.take(pfcM,iSW,axis=0)
    sePFC = np.take(pfcM,iSE,axis=0)
    nwPFC = np.take(pfcM,iNW,axis=0)
    nePFC = np.take(pfcM,iNE,axis=0)
    return swHPC,seHPC,nwHPC,neHPC,swPFC,sePFC,nwPFC,nePFC

def generateIndexLocArray(hpcM,pfcM,locI):
    wpI = []
    stI = []
    chI = []
    goI = []
    for ts,sLoc in enumerate(locI):
        if 'P' in sLoc:
            wpI.append(ts)
        if 'S' in sLoc:
            stI.append(ts)
        if 'C' in sLoc:
            chI.append(ts)
        if 'G' in sLoc:
            goI.append(ts)
    wpHPC = np.take(hpcM,wpI,axis=0)
    stHPC = np.take(hpcM,stI,axis=0)
    chHPC = np.take(hpcM,chI,axis=0)
    goHPC = np.take(hpcM,goI,axis=0)
    wpPFC = np.take(pfcM,wpI,axis=0)
    stPFC = np.take(pfcM,stI,axis=0)
    chPFC = np.take(pfcM,chI,axis=0)
    goPFC = np.take(pfcM,goI,axis=0)
    return wpHPC,stHPC,chHPC,goHPC,wpPFC,stPFC,chPFC,goPFC

def getValues(inputMDS,inputTSNE):
    mmat = spio.loadmat(inputMDS)
    tmat = spio.loadmat(inputTSNE)
    hpcMatrixMDS = mmat.get('hpcMatrix')
    pfcMatrixMDS = mmat.get('pfcMatrix')
    posIndex = tmat.get('trialPositions').flatten()
    startArm = mmat.get('startArm').flatten()
    goalArm = mmat.get('goalArm').flatten()
    return generateIndexLocArray(hpcMatrixMDS,pfcMatrixMDS,posIndex)

def generateSingleFigure(pBlock,rBlock,bBlock,gBlock,savName):
    fig = go.Figure()
    fig.add_trace(go.Scatter3d(x=rBlock[:,0], y=rBlock[:,1], z=rBlock[:,2], name='SE',
                          mode='markers',marker=dict(size=3,color='#DDCC77')))
    fig.add_trace(go.Scatter3d(x=bBlock[:,0], y=bBlock[:,1], z=bBlock[:,2], name='NE',
                          mode='markers',marker=dict(size=3,color='#44AA99')))
    fig.add_trace(go.Scatter3d(x=gBlock[:,0], y=gBlock[:,1], z=gBlock[:,2], name='NW',
                          mode='markers',marker=dict(size=3,color='#332288')))
    fig.add_trace(go.Scatter3d(x=pBlock[:,0], y=pBlock[:,1], z=pBlock[:,2], name='SW',
                          mode='markers',marker=dict(size=3,color='#AA4499')))
    fig.update_layout(scene = dict(xaxis = dict(backgroundcolor="#F5F5F5",gridcolor="#EAEAEA",showbackground=True,
                         zerolinecolor="#EAEAEA",),
                    yaxis = dict(backgroundcolor="#F5F5F5",gridcolor="#EAEAEA",showbackground=True,
                        zerolinecolor="#EAEAEA"),
                    zaxis = dict(backgroundcolor="#F5F5F5",gridcolor="#EAEAEA",showbackground=True,
                        zerolinecolor="#EAEAEA",),),
                  )
    fig.write_html(savName)
    return None

def generateOtherSingleFigure(rot1,rot2,pBlock,rBlock,bBlock,gBlock,savName):
    fig = plt.figure(figsize=(20,20))
    ax01 = fig.add_subplot(2,2,1,projection='3d')
    ax01.view_init(rot1,rot2)
    
    ax01.scatter3D(pBlock[:,0],pBlock[:,1],pBlock[:,2],color='purple',s=35,alpha=1,zorder=-1)
    ax01.scatter3D(rBlock[:,0],rBlock[:,1],rBlock[:,2],color='red',s=35,alpha=1,zorder=1)
    ax01.scatter3D(bBlock[:,0],bBlock[:,1],bBlock[:,2],color='blue',s=35,alpha=1,zorder=3)
    ax01.scatter3D(gBlock[:,0],gBlock[:,1],gBlock[:,2],color='green',s=35,alpha=1,zorder=2)
    ax01.set_xticklabels('',fontsize=16)
    ax01.set_yticklabels('',fontsize=16)
    ax01.set_zticklabels('',fontsize=16)
    ax01.set_xticks([])
    ax01.set_yticks([])
    ax01.set_zticks([])
    #plt.show()
    fig.savefig(savName,bbox_inches='tight')
    plt.close()
    return

def generateFigure(rawFileMDS,rawFileTSNE,savFigFile1,savFigFile2):
    hWP,hST,hCH,hGO,pWP,pST,pCH,pGO = getValues(rawFileMDS,rawFileTSNE)
    figHPC = generateOtherSingleFigure(15,130,hWP,hST,hCH,hGO,savFigFile1)
    figPFC = generateOtherSingleFigure(15,130,pWP,pST,pCH,pGO,savFigFile2)
    return None

rawMDS = ''
rawTSNE = ''
savFile1 = ''
savFile2 = ''
figGen = generateFigure(rawMDS,rawTSNE,savFile1,savFile2)
print('Done')


Done
