In [71]:
using Flux, Zygote
using Flux.Data: DataLoader
using Statistics: mean
using JLD2
include("preprocessing.jl")

reconstruct_midi_file (generic function with 1 method)

In [72]:
function prepare_dataloader(input_dir::String, train_ratio::Float64=0.8, batch_size::Int=32)
    inputs = []
    outputs = []
    println("Reading data from $(input_dir)...")
    for csv in readdir(input_dir)
        df = CSV.read(joinpath(input_dir, csv), DataFrame)
        if size(df)[1] < 500
            continue
        end
        push!(inputs, Matrix{Float32}(df[1:500, 1:end-1])')
        push!(outputs, sum(df[1:500, end]))
    end
    
    # Find the length of the longest input array
    max_length = maximum(size(input, 1) for input in inputs)
    println("Padding inputs to length $(max_length)...")
    # Pad input arrays with zeros to match the longest array's length
    padded_inputs = []
    for input in inputs
        rows_to_pad = max_length - size(input, 1)
        padded_input = vcat(input, zeros(Float32, rows_to_pad, size(input, 2)))
        push!(padded_inputs, padded_input)
    end

    println("Batching data...")
    # Split the data into train and validation sets
    num_train = Int(round(length(padded_inputs) * train_ratio))
    train_inputs = padded_inputs[1:num_train]
    train_outputs = outputs[1:num_train]
    val_inputs = padded_inputs[num_train+1:end]
    val_outputs = outputs[num_train+1:end]

    train_data = Flux.Data.DataLoader((train_inputs, train_outputs), batchsize=batch_size, shuffle=true, partial=false)
    val_data = Flux.Data.DataLoader((val_inputs, val_outputs), batchsize=batch_size, shuffle=true, partial=false)
    println("Done preparing data.")

    return train_data, val_data
end


prepare_dataloader (generic function with 3 methods)

In [73]:
batch_size = 32

train_set, val_set = prepare_dataloader("assets/anomalous")

Reading data from assets/anomalous...
Padding inputs to length 4...
Batching data...
Done preparing data.


(DataLoader(::Tuple{Vector{Any}, Vector{Any}}, shuffle=true, batchsize=32, partial=false), DataLoader(::Tuple{Vector{Any}, Vector{Any}}, shuffle=true, batchsize=32, partial=false))

In [76]:
# Many-to-one RNN architecture
model = Flux.Chain(
    Flux.LSTM(4, 64),
    x -> x[:, :, end],  # Select the hidden state at the last time step
    Dense(64, 1)
)

loss(x, y) = Flux.mse(model(x), y)

loss (generic function with 1 method)

In [77]:
opt = Flux.setup(Adam(1e-2), model)
ps = Flux.params(model) # Get the model's parameters

Params([Float32[0.097776175 0.124119595 -0.13322908 -0.019199725; -0.051209584 0.08477816 -0.10397716 0.049686965; … ; -0.038245685 0.016786586 0.071300104 -0.10400492; -0.098428905 0.0021999406 -0.022911513 0.031021858], Float32[0.04454969 0.09986814 … 0.06935859 0.054235328; -0.12886257 -0.11384021 … 0.068782404 0.045261636; … ; -0.072473325 0.1240657 … -0.07712732 0.11565524; 0.033414874 0.053624082 … -0.05203304 -0.12691651], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.17766395 0.10723281 … 0.23744194 0.24607086], Float32[0.0]])

In [125]:
num_epochs = 10


for epoch in 1:num_epochs
    for (x, y) in train_set
        lstm_input = cat(x..., dims=3)  # Concatenate along the third dimension
        lstm_input = permutedims(lstm_input, (1, 3, 2))  # Transpose dimensions to (sequence_length, batch_size, n_features)
        y = reshape(y, (length(y), 1)) # Reshape y to match model output shape

        val, grads = Flux.withgradient(model) do m
            ŷ = m(lstm_input)
            ŷ = reshape(ŷ, (size(ŷ, 2), 1))
            loss_val = Flux.Losses.mse(ŷ, y)
        end
    end
    # Optionally, you can evaluate the model on the validation set and print the validation loss here
end

In [128]:
function evaluate_model(model, val_set)
    total_loss = 0.0
    num_batches = 0

    for (x, y) in val_set
        # Preprocess input data
        lstm_input = cat(x..., dims=3)  # Concatenate along the third dimension
        lstm_input = permutedims(lstm_input, (1, 3, 2))  # Transpose dimensions to (sequence_length, batch_size, n_features)

        # Expand the dimensions of the target data
        y = reshape(y, (length(y), 1)) # Reshape y to match model output shape

        # Compute predictions and loss
        ŷ = model(lstm_input)
        ŷ = reshape(ŷ, (size(ŷ, 2), 1))
        batch_loss = Flux.Losses.mse(ŷ, y)

        total_loss += batch_loss
        num_batches += 1
    end

    avg_loss = total_loss / num_batches
    return avg_loss
end


evaluate_model (generic function with 1 method)

In [85]:
@save "trained_model.jld2" model

In [129]:
evaluate_model(model, val_set)

LoadError: UndefVarError: m not defined

In [114]:
loaded_model = nothing

In [115]:
@load "trained_model.jld2" loaded_model

LoadError: KeyError: key "loaded_model" not found

In [120]:
function preprocess_new_midi(csv_file::String)
    df = CSV.read(csv_file, DataFrame)
    if size(df)[1] < 500
        println("MIDI file is too short (< 500 rows).")
        return nothing
    end
    input = Matrix{Float32}(df[1:500, Not(:anomalies)])
    if sum(df.anomalies[1:500]) > 0
        println("MIDI file contains $(sum(df.anomalies[1:500])) anomalies.")
    end
    println((size(input, 1), 1, size(input, 2)))

    # Transpose the input data and reshape it for LSTM
    lstm_input = transpose(input)
    lstm_input = reshape(lstm_input, (size(lstm_input, 1), size(lstm_input, 2), 1))
    
    return lstm_input
end


preprocess_new_midi (generic function with 1 method)

In [121]:
preprocessed_input = preprocess_new_midi("assets/anomalous/[ajin_op]_yoru_wa_nemureru_kai_-__flumpool__fonzi_m__0.1.csv")
if preprocessed_input !== nothing
    ŷ = model(preprocessed_input)
    # Process the output ŷ as needed
end

MIDI file contains 82 anomalies.
(500, 1, 4)


LoadError: DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 500 and 32

In [None]:
# Preprocess the new MIDI file
new_midi_file = "assets/anomalous/liz_rhap15_0.3.csv"
new_midi_input = preprocess_new_midi(new_midi_file)

# Check if the preprocessing was successful (the file had at least 500 rows)
if new_midi_input !== nothing
    # Reshape the input array to match the model's input shape
    new_midi_input = reshape(new_midi_input, size(new_midi_input, 1), 1, size(new_midi_input, 2))
    println(new_midi_input)    
    # Predict the number of errors in the new MIDI file
    num_errors = model(new_midi_input)

    println("Predicted number of errors: ", num_errors[1])
else
    println("Prediction cannot be performed due to insufficient data.")
end