In [12]:
using Flux, Zygote
using Flux.Data: DataLoader
using Statistics: mean
using JLD2
include("preprocessing.jl")

reconstruct_midi_file (generic function with 1 method)

In [13]:
function prepare_dataloader(input_dir::String, train_ratio::Float64=0.8, batch_size::Int=32)
    inputs = []
    outputs = []
    println("Reading data from $(input_dir)...")
    for csv in readdir(input_dir)
        df = CSV.read(joinpath(input_dir, csv), DataFrame)
        if size(df)[1] < 500
            continue
        end
        push!(inputs, Matrix{Float32}(df[1:500, 1:end-1])')
        push!(outputs, sum(df[1:500, end]))
    end
    
    # Find the length of the longest input array
    max_length = maximum(size(input, 1) for input in inputs)
    println("Padding inputs to length $(max_length)...")
    # Pad input arrays with zeros to match the longest array's length
    padded_inputs = []
    for input in inputs
        rows_to_pad = max_length - size(input, 1)
        padded_input = vcat(input, zeros(Float32, rows_to_pad, size(input, 2)))
        push!(padded_inputs, padded_input)
    end

    println("Batching data...")
    # Split the data into train and validation sets
    num_train = Int(round(length(padded_inputs) * train_ratio))
    train_inputs = padded_inputs[1:num_train]
    train_outputs = outputs[1:num_train]
    val_inputs = padded_inputs[num_train+1:end]
    val_outputs = outputs[num_train+1:end]

    train_data = Flux.Data.DataLoader((train_inputs, train_outputs), batchsize=batch_size, shuffle=true)
    val_data = Flux.Data.DataLoader((val_inputs, val_outputs), batchsize=batch_size, shuffle=true)
    println("Done preparing data.")

    return train_data, val_data
end


prepare_dataloader (generic function with 3 methods)

In [14]:
batch_size = 32

train_set, val_set = prepare_dataloader("assets/anomalous")

Reading data from assets/anomalous...
Padding inputs to length 4...
Batching data...
Done preparing data.


(DataLoader(::Tuple{Vector{Any}, Vector{Any}}, shuffle=true, batchsize=32), DataLoader(::Tuple{Vector{Any}, Vector{Any}}, shuffle=true, batchsize=32))

In [25]:
# Many-to-one RNN architecture
model = Flux.Recur(Flux.LSTM(4, 64), 500)

loss(x, y) = Flux.mse(model(x), y)

loss (generic function with 1 method)

In [19]:
opt = Flux.setup(Adam(1e-2), model)
ps = Flux.params(model) # Get the model's parameters

Params([Float32[0.10702626 0.01696822 -0.065254584 0.092855886; -0.10495285 0.13590126 -0.06654843 -0.14827868; … ; 0.08450505 0.07369942 -0.12488393 -0.028589867; -0.09073787 -0.04979955 -0.10373106 0.12580813], Float32[0.009323896 -0.10380827 … -0.05286085 0.027610043; 0.083879985 -0.090319835 … 0.029768014 -0.09390886; … ; 0.0499807 0.08628668 … 0.10984822 0.08031799; -0.0045832857 -0.09732817 … -0.031048324 0.010956693], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.009871284 0.060647424 … 0.010794526 -0.03783303], Float32[0.0]])

In [26]:
num_epochs = 100


for epoch in 1:num_epochs
    for (x, y) in train_set
        val, grads = Flux.withgradient(model) do m
            ŷ = m(x)
            loss_val = Flux.Losses.mse(ŷ, y)
        end

        if !isfinite(val)
            @warn "loss is $val on item $epoch" epoch
            continue
        end
        if length(grads) > 0
            Flux.Optimise.update!(opt, ps, grads[1])
        else
            @warn "no gradients on item $epoch" epoch
        end
    end
    # Optionally, you can evaluate the model on the validation set and print the validation loss here
end

LoadError: DimensionMismatch: new dimensions (500, 4, 32) must be consistent with array size 32

In [None]:
# Test the model on the validation set
val_predictions = [model(reshape(x, size(x, 1), 1, size(x, 2))) for x in X_val]
val_loss = mean(loss_function(ŷ, y) for (ŷ, y) in zip(val_predictions, y_val))
println("Validation Loss: $val_loss")

In [None]:
function preprocess_new_midi(csv_file::String)
    df = CSV.read(csv_file, DataFrame)
    if size(df)[1] < 500
        println("MIDI file is too short (< 500 rows).")
        return nothing
    end
    input = Matrix{Float32}(df[1:500,Not(:anomalies)])
    if sum(df.anomalies[1:500]) > 0
        println("MIDI file contains $(sum(df.anomalies[1:500])) anomalies.")
    end
    return input
end


In [None]:
display(model)

In [None]:
# Preprocess the new MIDI file
new_midi_file = "assets/anomalous/liz_rhap15_0.3.csv"
new_midi_input = preprocess_new_midi(new_midi_file)

# Check if the preprocessing was successful (the file had at least 500 rows)
if new_midi_input !== nothing
    # Reshape the input array to match the model's input shape
    new_midi_input = reshape(new_midi_input, size(new_midi_input, 1), 1, size(new_midi_input, 2))
    println(new_midi_input)    
    # Predict the number of errors in the new MIDI file
    num_errors = model(new_midi_input)

    println("Predicted number of errors: ", num_errors[1])
else
    println("Prediction cannot be performed due to insufficient data.")
end