# Simple test with real data

In [None]:
import csv
import itertools as itt

import matplotlib.pyplot as plt
import numpy as np
import lengthPriors

import dynamicComputation as dc
import readers

In [None]:
from defaultPriors import arcPriorTempo, arcPriorLoud, lengthPriorParamsTempo as lengthPriorParams

# Loading some data

In [None]:
timingsData = readers.readAllMazurkaTimingsAndSeg(
    timingPath="data/beat_time",
    segPath="data/deaf_structure_tempo",
)
dynData = readers.readAllMazurkaDataAndSeg(
    timingPath="data/beat_dyn",
    segPath="data/deaf_structure_loudness",
)

# Combine every tempo record with every dynamics record and keep the pairs whose
# interpret matches. Each entry is (piece, interpret, tempo, tempoSeg, dyn, dynSeg);
# the piece field comes from the tempo record.
# NOTE(review): pieces are not compared, only interprets. If an interpret appears
# in several pieces, tempo and dynamics from different pieces get paired. Confirm
# each interpret is unique, or also match on the piece identifier.
fullData = [
    (piece, interpret, tempo, tempoSeg, dyn, dynSeg)
    for (piece, interpret, tempo, tempoSeg), (_dynPiece, dynInterpret, dyn, dynSeg)
    in itt.product(timingsData, dynData)
    if interpret == dynInterpret
]

In [None]:
# One arc prior per feature, in the same order as the (tempo, dynamics) pairs
# fed to the model (presumably matched by position; confirm in dynamicComputation).
arcPrior = [
    arcPriorTempo,
    arcPriorLoud,
]

# Running the model on the first piece/interpret pair

In [None]:
# Unpack the first paired (tempo, dynamics) record.
(piece, interpret, tempo, tempoSeg, dyn, dynSeg) = fullData[0]

# Characters 16-19 of the piece identifier serve as a short label for printing
# and for the output file name.
# NOTE(review): assumes `piece` has a fixed-length prefix (e.g. a data path); confirm.
piece_formatted = piece[16:20]
print(piece_formatted, interpret)

# Pair each tempo value with a dynamics value, skipping the first dynamics sample.
# zip truncates to the shorter of the two series.
# NOTE(review): the one-sample offset presumably aligns the two beat grids;
# confirm against the data files.
sampleData = list(zip(tempo, dyn[1:]))
segs = (tempoSeg, dynSeg)

# One tatum index per sample.
tatums = list(range(len(sampleData)))

lengthPrior = lengthPriors.NormalLengthPrior(
    lengthPriorParams['mean'],
    lengthPriorParams['stddev'],
    range(len(sampleData)),
    lengthPriorParams['maxLength'],
)

posteriorMarginals = dc.runAlphaBeta(sampleData, arcPrior, lengthPrior)

# Split the pairs back into plot series. These get their own names so the raw
# `tempo`/`dyn` data unpacked above is not shadowed by the truncated versions.
plotTempo, plotDyn = zip(*sampleData)

# Left axis: tempo. Right (twin) axis: posterior marginals, dynamics and the
# reference segment boundaries. Explicit Axes methods are used so each call
# targets the intended axis instead of pyplot's implicit "current axes".
fig, ax1 = plt.subplots()
ax1.plot(tatums, plotTempo, color="r")  # Tempo input data
ax1.set_ylim(0, 300)

ax2 = ax1.twinx()
ax2.plot(tatums, posteriorMarginals, 'k')  # Posterior Marginals
ax2.set_ylim(0, 1)
ax2.plot(tatums, plotDyn, color="b")  # Dyn input data
ax2.vlines(segs[0], ymin=0, ymax=1, colors="r", linestyles='dotted')  # Tempo seg
ax2.vlines(segs[1], ymin=0, ymax=1, colors="b", linestyles='dotted')  # Dyn seg
plt.show()

In [None]:
import os

# Write one row per sample: its index and the posterior marginal computed above.
# Create the output directory so the cell also works on a fresh checkout.
os.makedirs("output", exist_ok=True)
# The csv module requires files opened with newline=''; otherwise rows are
# double-spaced on Windows.
with open(f"output/{piece_formatted}_{interpret}_pm.csv", 'w', newline='') as outfile:
    csvWriter = csv.writer(outfile)
    csvWriter.writerow(["Beat count", "Posterior Marginal"])
    csvWriter.writerows(enumerate(posteriorMarginals))