# In [11]:
import os
import random
import numpy as np
import scipy.io as spio
import plotly.graph_objects as go

def generateSmithArray(inputCov,inputLab):
    """Expand per-trial Smith labels into one label per covariance row.

    Parameters
    ----------
    inputCov : sequence of 2-D arrays
        One entry per trial. Entries whose shape is exactly ``(0, 0)`` are
        skipped; every other entry contributes ``np.shape(arr)[0]`` rows.
    inputLab : sequence
        Per-trial label, indexed by the same position as ``inputCov``.

    Returns
    -------
    numpy.ndarray
        1-D array holding ``inputLab[tr]`` repeated once per row of
        ``inputCov[tr]``, in trial order. When no trial contributes rows an
        empty float64 array is returned (previously a plain ``[]`` list).
    """
    # Seed with an empty float array so dtype promotion matches the old
    # behaviour of concatenating onto ``[]``.
    pieces = [np.empty(0)]
    for tr, arr in enumerate(inputCov):
        if np.shape(arr) != (0, 0):
            pieces.append([inputLab[tr]] * np.shape(arr)[0])
    # A single concatenate at the end instead of one per trial avoids
    # re-copying the growing array on every iteration (O(n^2)).
    return np.concatenate(pieces, axis=0)

def generateSmithColorArray(inputSmith):
    """Bin Smith values into four colour bands (RGB on a 0-1 scale).

    Bands, checked from the top down:
      value >= 0.75 -> (51, 34, 136)/255
      value >= 0.5  -> (68, 170, 153)/255
      value >= 0.25 -> (221, 204, 119)/255
      otherwise     -> (170, 68, 153)/255 (also any value for which every
                       comparison is False, e.g. NaN)

    Returns a numpy array with one RGB row per input value.
    """
    bands = (
        (0.75, (51/255, 34/255, 136/255)),
        (0.5, (68/255, 170/255, 153/255)),
        (0.25, (221/255, 204/255, 119/255)),
    )
    fallback = (170/255, 68/255, 153/255)

    colours = []
    for value in inputSmith:
        chosen = fallback
        for cutoff, rgb in bands:
            if value >= cutoff:
                chosen = rgb
                break
        colours.append(list(chosen))
    return np.asarray(colours)

def getTrialColorArray(inputLengthFileName, rng=None):
    """Assign one random colour per non-empty trial, repeated for each row.

    Parameters
    ----------
    inputLengthFileName : str
        Path to a ``.mat`` file containing ``allCovMat``; ``allCovMat[0]`` is
        iterated as the per-trial arrays.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of the random colours. ``None`` (default) uses the global
        ``np.random`` state, as before.

    Returns
    -------
    list of tuple
        One ``(r, g, b)`` tuple per row of every trial whose shape is not
        ``(0, 0)``. All rows of a trial share the same tuple. Each channel is
        drawn uniformly from ``[0, 255)`` (0-255 scale).
    """
    uniform = np.random.uniform if rng is None else rng.uniform
    smat = spio.loadmat(inputLengthFileName)
    covMat = smat.get('allCovMat')[0]
    colorArr = []
    for arr in covMat:
        if np.shape(arr) == (0, 0):
            continue
        # Blue previously used an upper bound of 250 while red and green used
        # 255; all three channels now share the same range.
        currColorTuple = (uniform(0, 255), uniform(0, 255), uniform(0, 255))
        colorArr.extend([currColorTuple] * np.shape(arr)[0])
    return colorArr

def generateSmithLabels(fullRawName,justFileName):
    """Build the Smith-band colours and per-trial colours for one session.

    The dataset is chosen by the first marker found in ``fullRawName``:
    ``'/Spa_mPFC_Sal/'``, ``'/Spa_mPFC_Mus/'``, ``'/Spatial/'``, ``'/Cue/'``,
    ``'/HPC_Only_Mus'``, ``'/HPC_Only_Sal'``, ``'/HPC_OFC'`` (checked in that
    order). ``justFileName`` is then looked up in that dataset's fixed
    Smith-curve directory and covariance directory.

    Parameters
    ----------
    fullRawName : str
        Full path of the raw embedding file; only used to pick the dataset.
    justFileName : str
        Bare file name, joined onto the dataset directories.

    Returns
    -------
    tuple
        ``(smithColors, trialColors)``:
        ``generateSmithColorArray(generateSmithArray(allCovMat, lowSmith))``
        and ``getTrialColorArray(<covariance file>)``.

    Raises
    ------
    ValueError
        If ``fullRawName`` contains none of the dataset markers (previously
        this surfaced as an ``UnboundLocalError``).
    """
    # (marker, Smith-curve directory, covariance directory, Smith variable).
    # Smith variable: 'graCut' -> first column, 'lowGra' -> flattened,
    # None -> no Smith lookup (it is commented out for this dataset); every
    # trial gets label 1.0.
    datasets = (
        ('/Spa_mPFC_Sal/',
         '/home/aditya/Extracted_Data/06_Consol_Smith/mPFC/Saline/',
         '/home/aditya/Extracted_Data/11_Spike_Cross_Corr/05_SSASC_Cov_Theta_Spa/mPFC/Saline/',
         'graCut'),
        ('/Spa_mPFC_Mus/',
         '/home/aditya/Extracted_Data/06_Consol_Smith/mPFC/Muscimol/',
         '/home/aditya/Extracted_Data/11_Spike_Cross_Corr/05_SSASC_Cov_Theta_Spa/mPFC/Muscimol/',
         'graCut'),
        ('/Spatial/',
         '/home/aditya/Extracted_Data/11_Spike_Cross_Corr/05_SSASC_Cov_Theta_Spa/Striatum/',
         '/home/aditya/Extracted_Data/11_Spike_Cross_Corr/05_SSASC_Cov_Theta_Spa/Striatum/',
         'lowGra'),
        ('/Cue/',
         None,
         '/home/aditya/Extracted_Data/11_Spike_Cross_Corr/05_SSASC_Cov_Theta_Cue/Striatum/',
         None),
        ('/HPC_Only_Mus',
         '/home/aditya/Extracted_Data/06_Consol_Smith/OFC_LFR/HPC_Only/Muscimol/',
         '/home/aditya/Extracted_Data/11_Spike_Cross_Corr/05_SSASC_Cov_Theta_Spa/OFC_LFR/HPC_Only/Muscimol/',
         'graCut'),
        ('/HPC_Only_Sal',
         '/home/aditya/Extracted_Data/06_Consol_Smith/OFC_LFR/HPC_Only/Saline/',
         '/home/aditya/Extracted_Data/11_Spike_Cross_Corr/05_SSASC_Cov_Theta_Spa/OFC_LFR/HPC_Only/Saline/',
         'graCut'),
        ('/HPC_OFC',
         '/home/aditya/Extracted_Data/06_Consol_Smith/OFC_LFR/HPC_OFC/',
         '/home/aditya/Extracted_Data/11_Spike_Cross_Corr/05_SSASC_Cov_Theta_Spa/OFC_LFR/HPC_OFC/',
         'graCut'),
    )
    for marker, smithDir, covDir, smithKey in datasets:
        if marker in fullRawName:
            break
    else:
        raise ValueError(f'No Smith/covariance dataset matches {fullRawName!r}')

    currLenFile = os.path.join(covDir, justFileName)
    covData = spio.loadmat(currLenFile)
    allCovMat = covData.get('allCovMat')[0]
    if smithKey is None:
        lowSmith = np.ones(len(allCovMat))
    else:
        currSmithFile = os.path.join(smithDir, justFileName)
        # '/Spatial/' keeps its Smith values in the covariance file itself;
        # reuse the already-loaded data instead of reading it twice.
        if currSmithFile == currLenFile:
            smithData = covData
        else:
            smithData = spio.loadmat(currSmithFile)
        smithVals = smithData.get(smithKey)
        lowSmith = smithVals.flatten() if smithKey == 'lowGra' else smithVals[:, 0]

    fileSmithArray = generateSmithArray(allCovMat, lowSmith)
    return generateSmithColorArray(fileSmithArray), getTrialColorArray(currLenFile)

def generateJourneyColors(starts,goals):
    """Map each trial's start/goal arm pair to an RGB colour (0-1 scale).

    The journey code for trial ``i`` is ``f'{starts[i][0]}{goals[i][0]}'``.
    Codes ``'NE'``, ``'NW'``, ``'SE'`` and ``'SW'`` get a colour. Any other
    code is skipped without a placeholder.

    NOTE(review): skipping unknown codes makes the result shorter than
    ``starts``, which would shift colours relative to per-trial rows used
    elsewhere. Confirm the data never contains other codes.

    Returns a numpy array with one RGB row per recognised trial.
    """
    palette = {
        'NE': (51/255, 34/255, 136/255),
        'NW': (68/255, 170/255, 153/255),
        'SE': (221/255, 204/255, 119/255),
        'SW': (170/255, 68/255, 153/255),
    }
    colours = []
    for idx, startEl in enumerate(starts):
        rgb = palette.get(f'{startEl[0]}{goals[idx][0]}')
        if rgb is not None:
            colours.append(list(rgb))
    return np.asarray(colours)

def generateTrialLocColors(fullRawName,justFileName):
    """Colour each trial position of a session by its location code.

    The point-cloud file is located via the first dataset marker found in
    ``fullRawName``. Its ``trialPositions`` variable is flattened and coloured
    entry by entry.

    Parameters
    ----------
    fullRawName : str
        Full path of the raw embedding file; only used to pick the dataset.
    justFileName : str
        Bare file name, joined onto the dataset's point-cloud directory.

    Returns
    -------
    numpy.ndarray
        One RGB row (0-1 scale) per entry of ``trialPositions``. Entries whose
        ``[0]`` element equals ``'S'``, ``'C'`` or ``'G'`` get
        (51, 34, 136), (68, 170, 153) or (221, 204, 119) over 255. Anything
        else gets (170, 68, 153)/255.

    Raises
    ------
    ValueError
        If ``fullRawName`` contains none of the dataset markers (previously
        this surfaced as an ``UnboundLocalError``).
    """
    # (marker, point-cloud directory); checked in order, first match wins.
    # NOTE(review): the two HPC_Only markers have no leading '/' but a
    # trailing '/', unlike the other markers. Kept as-is; confirm intended.
    datasets = (
        ('/Spa_mPFC_Sal/', '/home/aditya/Neural_Data_Structure/mPFC/01_Point_Cloud_TSNE_3D/Spa_mPFC_Sal/'),
        ('/Spa_mPFC_Mus/', '/home/aditya/Neural_Data_Structure/mPFC/01_Point_Cloud_TSNE_3D/Spa_mPFC_Mus/'),
        ('/Spatial/', '/home/aditya/Neural_Data_Structure/mPFC/01_Point_Cloud_TSNE_3D/Spatial/'),
        ('/Cue/', '/home/aditya/Neural_Data_Structure/mPFC/01_Point_Cloud_TSNE_3D/Cue/'),
        ('/HPC_OFC', '/home/aditya/Neural_Data_Structure/OFC/01_Point_Cloud_TSNE_3D/HPC_OFC/'),
        ('HPC_Only_Sal/', '/home/aditya/Neural_Data_Structure/OFC/01_Point_Cloud_TSNE_3D/HPC_Only_Sal/'),
        ('HPC_Only_Mus/', '/home/aditya/Neural_Data_Structure/OFC/01_Point_Cloud_TSNE_3D/HPC_Only_Mus/'),
    )
    for marker, pointCloudDir in datasets:
        if marker in fullRawName:
            break
    else:
        raise ValueError(f'No point-cloud dataset matches {fullRawName!r}')

    ymat = spio.loadmat(os.path.join(pointCloudDir, justFileName))
    trialPositions = ymat.get('trialPositions').flatten()
    colours = []
    for pos in trialPositions:
        code = pos[0]
        if code == 'S':
            colours.append([51/255, 34/255, 136/255])
        elif code == 'C':
            colours.append([68/255, 170/255, 153/255])
        elif code == 'G':
            colours.append([221/255, 204/255, 119/255])
        else:
            colours.append([170/255, 68/255, 153/255])
    return np.asarray(colours)

def loadValuesFromFile(fileOfInterest,singleFileName):
    """Load one embedding file together with its colour arrays.

    Returns
    -------
    tuple
        ``(journeyColors, trialLocColors, smithColors, trialColors,
        allEmbed, hpcEmbed, pfcEmbed)``. The four colour arrays are each cut
        to the number of rows of ``allMatrix``. The three embeddings are the
        ``allMatrix``, ``hpcMatrix`` and ``pfcMatrix`` variables of
        ``fileOfInterest``, returned unchanged.
    """
    rawData = spio.loadmat(fileOfInterest)
    startArms = rawData.get('startArm').flatten()
    goalArms = rawData.get('goalArm').flatten()
    journeyColors = generateJourneyColors(startArms, goalArms)
    trialLocColors = generateTrialLocColors(fileOfInterest, singleFileName)
    smithColors, trialColors = generateSmithLabels(fileOfInterest, singleFileName)

    embeddings = tuple(rawData.get(key) for key in ('allMatrix', 'hpcMatrix', 'pfcMatrix'))
    nRows = np.shape(embeddings[0])[0]
    trimmedColors = tuple(
        colours[:nRows]
        for colours in (journeyColors, trialLocColors, smithColors, trialColors)
    )
    return trimmedColors + embeddings

def generateValueColorBlocks(inputValues,inputColors):
    """Split 3-D points into four groups by the first channel of their colour.

    Row ``i`` of ``inputValues`` is grouped by comparing ``inputColors[i][0]``
    exactly against 51/255, 68/255, 221/255 and 170/255. Rows whose colour
    matches none of them are dropped.

    Parameters
    ----------
    inputValues : array-like
        Indexable by row; each row must reshape to exactly 3 values.
    inputColors : iterable
        Colour per row; only element ``[0]`` of each colour is inspected.

    Returns
    -------
    tuple of numpy.ndarray
        ``(redBlock, blueBlock, greenBlock, purpleBlock)`` for first-channel
        values 51/255, 68/255, 221/255 and 170/255 respectively. Each is
        float64 with shape ``(k, 3)`` (``k`` may be 0), rows in input order.
    """
    # Exact equality is intentional: the colour arrays are built from these
    # same literal fractions, so matching values compare bit-for-bit.
    channelKeys = (51/255, 68/255, 221/255, 170/255)
    rowsByKey = {key: [] for key in channelKeys}
    for row, colorEl in enumerate(inputColors):
        bucket = rowsByKey.get(colorEl[0])
        if bucket is not None:
            bucket.append(np.reshape(inputValues[row], 3))
    # Stack each group once instead of concatenating per row, which copied
    # the growing array every time (O(n^2)).
    return tuple(np.array(rowsByKey[key], dtype=float).reshape(-1, 3) for key in channelKeys)

def generateSingleFigure(figType,values,colorVals,titleName,savHTML):
    """Render a 3-D scatter of ``values`` and write it to ``savHTML``.

    Parameters
    ----------
    figType : str
        ``'byTrial'`` plots every point in one trace coloured by ``colorVals``
        directly. ``'byJourney'``, ``'byTrialLoc'`` and ``'bySmith'`` split the
        points into four named legend traces via ``generateValueColorBlocks``.
    values : numpy.ndarray
        Point coordinates; columns 0, 1 and 2 are used as x, y and z.
    colorVals : array-like
        Per-point colours, aligned with the rows of ``values``.
    titleName : str
        Figure title.
    savHTML : str
        Output path passed to ``fig.write_html``.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure that was written (previously ``None``).

    Raises
    ------
    ValueError
        For an unrecognised ``figType``. Previously an empty figure was
        written silently.
    """
    # figType -> (block index, legend name, marker colour), in trace order.
    # Block indices follow generateValueColorBlocks' return order: first
    # colour channel 0 -> 51/255, 1 -> 68/255, 2 -> 221/255, 3 -> 170/255.
    traceSpecs = {
        'byJourney': (
            (0, 'NE', 'rgb(51,34,136)'),
            (1, 'NW', 'rgb(68,170,153)'),
            (2, 'SE', 'rgb(221,204,119)'),
            (3, 'SW', 'rgb(170,68,153)'),
        ),
        'byTrialLoc': (
            (3, 'WP', 'purple'),
            (0, 'Start', 'red'),
            (1, 'Choice', 'blue'),
            (2, 'Goal', 'green'),
        ),
        'bySmith': (
            (0, 'Smith ≥ .75', 'rgb(51,34,136)'),
            (1, '.5 ≤ Smith < .75', 'rgb(68,170,153)'),
            (2, '.25 ≤ Smith < .5', 'rgb(221,204,119)'),
            (3, 'Smith < .25', 'rgb(170,68,153)'),
        ),
    }
    if figType != 'byTrial' and figType not in traceSpecs:
        raise ValueError(f'Unknown figType {figType!r}')

    def scatter(points, **extra):
        # One 3-D marker trace from the first three columns of ``points``.
        return go.Scatter3d(x=points[:, 0], y=points[:, 1], z=points[:, 2],
                            mode='markers', **extra)

    if figType == 'byTrial':
        fig = go.Figure(data=[scatter(values, marker=dict(size=3, color=colorVals))])
    else:
        blocks = generateValueColorBlocks(values, colorVals)
        fig = go.Figure()
        for blockIdx, legendName, markerColor in traceSpecs[figType]:
            fig.add_trace(scatter(blocks[blockIdx], name=legendName,
                                  marker=dict(size=3, color=markerColor)))
    fig.update_layout(title=titleName)
    fig.update_layout(scene=dict(xaxis_title='Axis 1', yaxis_title='Axis 2', zaxis_title='Axis 3'))
    fig.write_html(savHTML)
    return fig

def buildAllFiguresFromSingleFile(currRawFile,currFileName,savPath1,savPath2,savPath3,savPath4,
                                  figureKinds=('byTrialLoc',)):
    """Render the selected figure families for one embedding file.

    For each family in ``figureKinds``, three HTML figures are written from
    the three embeddings returned by ``loadValuesFromFile``:

    - the conjoint embedding, with title ``'Conjoint'`` and sub-directory ``All/``;
    - the HPC embedding, with ``'HPC'`` and ``HPC/``;
    - the second-region embedding.

    The second region is labelled ``'Striatum'`` when ``savPath1`` contains
    ``'Spa_mPFC'`` and ``'mPFC'`` otherwise. That label is used as both the
    title and the sub-directory.

    NOTE(review): that label mapping looks inverted relative to the
    ``Spa_mPFC`` name. It is preserved exactly; confirm before trusting the
    titles.

    Parameters
    ----------
    currRawFile : str
        Full path of the embedding ``.mat`` file.
    currFileName : str
        Bare file name. Its last 12 characters are replaced by ``'.html'`` to
        form the output name. NOTE(review): assumes every input name ends in
        a fixed 12-character suffix; confirm.
    savPath1, savPath2, savPath3, savPath4 : str
        Output prefixes for the ``'byJourney'``, ``'byTrialLoc'``,
        ``'bySmith'`` and ``'byTrial'`` families respectively. The
        sub-directory is appended directly, so each should end with a
        path separator.
    figureKinds : iterable of str, optional
        Families to render, in order. Defaults to ``('byTrialLoc',)``, the
        only family the previous version rendered (the others were
        commented out).

    Returns
    -------
    str
        Always ``'Done'``.

    Raises
    ------
    ValueError
        If ``figureKinds`` names an unknown family. This is checked before
        any file is loaded.
    """
    figureKinds = tuple(figureKinds)
    knownKinds = ('byJourney', 'byTrialLoc', 'bySmith', 'byTrial')
    unknown = [kind for kind in figureKinds if kind not in knownKinds]
    if unknown:
        raise ValueError(f'Unknown figure kinds {unknown!r}; expected any of {knownKinds!r}')

    jCol,tlCo,sCol,trCo,allV,hpcV,pfcV = loadValuesFromFile(currRawFile,currFileName)
    htmlEnd = currFileName[:-12]+'.html'
    secondRegion = 'Striatum' if 'Spa_mPFC' in savPath1 else 'mPFC'
    # (title, output sub-directory, embedding), in the original render order.
    panels = (
        ('Conjoint', 'All', allV),
        ('HPC', 'HPC', hpcV),
        (secondRegion, secondRegion, pfcV),
    )
    # family -> (per-point colours, output prefix)
    families = {
        'byJourney': (jCol, savPath1),
        'byTrialLoc': (tlCo, savPath2),
        'bySmith': (sCol, savPath3),
        'byTrial': (trCo, savPath4),
    }
    for kind in figureKinds:
        colours, savRoot = families[kind]
        for title, subDir, embedding in panels:
            generateSingleFigure(kind, embedding, colours, title, f'{savRoot}{subDir}/{htmlEnd}')
    return 'Done'

def handleFromSeveralPaths(pathsRaw,pathsSav1,pathsSav2,pathsSav3,pathsSav4):
    """Build figures for every ``.mat`` file under each raw directory.

    ``pathsRaw[z]`` is walked recursively with ``os.walk``. Each file whose
    name ends in ``.mat`` is passed to ``buildAllFiguresFromSingleFile``
    together with ``pathsSav1[z]`` .. ``pathsSav4[z]``. Each raw directory is
    printed once it has been processed.

    Returns
    -------
    str
        Always ``'Done'``.
    """
    for z, singleRawP in enumerate(pathsRaw):
        for aRoot, _aDirs, aFiles in os.walk(singleRawP):
            for aFile in aFiles:
                if not aFile.endswith('.mat'):
                    continue
                # Join with the directory os.walk is currently visiting, not
                # the top-level root; otherwise files found in
                # subdirectories resolve to non-existent paths.
                currentRaw = os.path.join(aRoot, aFile)
                buildAllFiguresFromSingleFile(currentRaw, aFile, pathsSav1[z], pathsSav2[z],
                                              pathsSav3[z], pathsSav4[z])
        print(singleRawP)
    return 'Done'

# One entry per dataset, kept in parallel. rawPs[i] is the directory walked
# for .mat files; savP1[i]..savP4[i] are that dataset's output prefixes for
# the journey, trial-location, Smith and per-trial figure families. The
# lists are empty here, so running this cell only prints 'Done'.
rawPs, savP1, savP2, savP3, savP4 = [], [], [], [], []

doneVar = handleFromSeveralPaths(rawPs, savP1, savP2, savP3, savP4)
print(doneVar)


# Output of the cell above:
# /home/aditya/Neural_Data_Structure/mPFC/01_Point_Cloud_MDS_3D/Spatial/
# Done
