In [1]:
# This program is designed to generate heat maps for T2 from a folder of .csv files.

# This cell contains the imports the program needs.

# Used for math operations:
import numpy as np
import math

# Used for plotting data:
import matplotlib.pyplot as plt
import seaborn as sns

# Used to read files:
import pandas as pd

In [2]:
'''This cell builds the paths of the .csv files expected in a folder, given the computer address of the folder.
We need separate functions for Windows devices and Mac/Linux devices because they use different path
separators (a backslash versus a forward slash). Given the address of the folder, getFile returns the path of
the file in the folder associated with a given n and s. getFiles iterates over a list 'nVals' of desired n values
as strings and a list 'sVals' of desired s values as strings, and returns a list of the corresponding file paths
(all s values for the first n, then all s values for the next n, and so on).'''

def getFile(address: str, n: str, s: str) -> str:
    """Return the Mac/Linux path of the .csv file for one (n, s) pair.

    The file is expected at ``<address>/<n>_<s>.csv``. All arguments must be
    strings; anything else raises TypeError.
    """
    parts = (address, '/', n, '_', s, '.csv')
    return ''.join(parts)


def getFiles(address: str, nVals: list[str], sVals: list[str]) -> list[str]:
    """Return the Mac/Linux path for every (n, s) combination.

    Paths are ordered n-major: every s value for nVals[0], then every s value
    for nVals[1], and so on.
    """
    paths = []
    for nValue in nVals:
        for sValue in sVals:
            paths.append(getFile(address, nValue, sValue))
    return paths

def getFileWindows(address: str, n: str, s: str) -> str:
    """Return the Windows path of the .csv file for one (n, s) pair.

    The file is expected at ``<address>\\<n>_<s>.csv``. All arguments must be
    strings; anything else raises TypeError.
    """
    parts = (address, '\\', n, '_', s, '.csv')
    return ''.join(parts)


def getFilesWindows(address: str, nVals: list[str], sVals: list[str]) -> list[str]:
    """Windows counterpart of getFiles: one backslash-separated path per (n, s) pair.

    Paths are ordered n-major: every s value for nVals[0], then every s value
    for nVals[1], and so on.
    """
    result = []
    for nValue in nVals:
        result.extend(getFileWindows(address, nValue, sValue) for sValue in sVals)
    return result

In [3]:
'''This cell retrieves the data frames from the files in the folder. getDataFrame obtains the data
frame of a given file, and getDataFrames iterates over all files obtained by getFiles and returns a list of
the corresponding data frames. getTrialData then uses this list to reshape the 'datum' column of each
data frame into a square grid (a list of rows), and also returns the names of any files that could not be read.'''


def getDataFrame(file: str) -> pd.core.frame.DataFrame or str:
    colnames = ['i', 'j', 'datum']
    try: 
        fileData = pd.read_csv(file, names=colnames)
        return fileData
    except: 
        return file


def getDataFrames(files: list[str]) -> 'list[pd.DataFrame | str]':
    """Read every file in ``files`` with getDataFrame, preserving order.

    Each element of the result corresponds to the file at the same position
    in ``files``. It is either the DataFrame read from that file or, if the
    file could not be read, its path.
    """
    # Note: the previous annotation `list[DataFrame or str]` evaluated to
    # `list[DataFrame]`, hiding the str case; the string form states both.
    return [getDataFrame(file) for file in files]


def _toSquareGrid(values: list) -> list[list]:
    """Split a flat row-major list into a list of rows of length isqrt(len(values))."""
    # math.isqrt is exact for any length. int(np.sqrt(...)) can round up for
    # very large non-square lengths and then index past the end of the list.
    # NOTE(review): if len(values) is not a perfect square, the trailing
    # len(values) - side*side values are dropped silently (same as before).
    side = math.isqrt(len(values))
    return [values[row * side:(row + 1) * side] for row in range(side)]


def getTrialData(files: list[str]) -> tuple[list[list[list[float]]], list[str]]:
    """Load each file and reshape its 'datum' column into a square grid.

    Parameters
    ----------
    files : list[str]
        Paths of the CSV files, e.g. as produced by getFiles/getFilesWindows.

    Returns
    -------
    trialData : list of grids
        One grid per readable file, in the same order as ``files``. Each grid
        is a list of n rows of n values. The grid is filled row by row from
        the 'datum' column in file order, with n = floor(sqrt(number of rows)).
    badFiles : list[str]
        Paths of the files that getDataFrame could not read.
    """
    trialData = []
    badFiles = []
    for dataFrame in getDataFrames(files):
        # getDataFrame returns the path (a str) instead of a DataFrame on failure.
        if isinstance(dataFrame, str):
            badFiles.append(dataFrame)
        else:
            trialData.append(_toSquareGrid(list(dataFrame['datum'])))
    return trialData, badFiles

In [4]:
# This cell creates a heat map for each specified s value, for a fixed n.

def graphHeatMap(n: str, trialData: list[list[list[float]]], sVals: list[str], nVals: list[str]) -> None:
    """Draw one figure holding a heat map for every s value at a fixed n.

    Parameters
    ----------
    n : str
        The n value to plot; must be an element of ``nVals``.
    trialData : list of grids
        Grids in the order produced by getTrialData(getFiles(address, nVals, sVals))
        when there are no bad files: all s values for nVals[0], then all s
        values for nVals[1], and so on.
    sVals, nVals : list[str]
        The s and n values that were used to build the file list.

    Raises
    ------
    ValueError
        If ``n`` is not in ``nVals`` (raised by ``list.index``).
    """
    index = nVals.index(n)  # Position of this n's run of grids in trialData.
    if not sVals:
        return  # Nothing to draw; avoids creating a zero-height figure.
    nCols = 3
    nRows = math.ceil(len(sVals) / nCols)
    # Apply the style *before* creating the figure. Previously the first
    # figure was created under whatever style was active, and only later
    # figures picked up 'fivethirtyeight'.
    plt.style.use('fivethirtyeight')
    # NOTE(review): each grid cell is 10 in wide but 30 in tall; with an equal
    # aspect this leaves vertical whitespace. Kept as-is; confirm intent.
    fig = plt.figure(figsize=(30, 30 * nRows))
    fig.suptitle("N = {}".format(n), fontsize=15)  # Label the figure with its n value.
    offset = index * len(sVals)
    for i, s in enumerate(sVals):
        ax = fig.add_subplot(nRows, nCols, 1 + i)
        # Draw into the axis just created instead of relying on pyplot's current axes.
        sns.heatmap(trialData[offset + i], cmap='coolwarm', cbar=False, ax=ax)
        ax.set_title('s = {}'.format(s))
        ax.set_aspect('equal')
        ax.axis('off')

In [5]:
# This cell creates all of the heat maps for all s and n specified. It runs graphHeatMap for each n specified.

def graphHeatMaps(trialData: list[list[list[float]]], sVals: list[str], nVals: list[str],
                  badFiles: 'list[str] | None' = None) -> None:
    """Draw one heat-map figure per n value, or report the files that failed.

    Parameters
    ----------
    trialData : list of grids
        Grids as returned by getTrialData.
    sVals, nVals : list[str]
        The s and n values that were used to build the file list.
    badFiles : list[str] or None, optional
        Paths that getTrialData could not read. ``None`` (the default) or an
        empty list means every file was read, and plotting proceeds.

    Notes
    -----
    The old signature read ``badFiles = list[str]``, which made the type alias
    the default *value*. That alias is truthy, so calling without badFiles
    always took the error branch. The default is now ``None``.
    """
    if badFiles:
        print("The heat maps could not be generated because the following files could not be located:")
        for badFile in badFiles:
            print(badFile)
        print("Please remove the problematic s and n values from your inputs and try again.")
        return
    for n in nVals:
        graphHeatMap(n, trialData, sVals, nVals)

In [6]:
# %matplotlib qt

# address = '/Users/alexanderneuschotz/Downloads/Torus'

# nList = ['3', '4']
# sListPadded = ['0.000000', '0.250000', '0.500000']


# files = getFiles(address, nList, sListPadded)
# trialData, badFiles = getTrialData(files)
# graphHeatMaps(trialData, sListPadded, nList, badFiles)