In [None]:
using Flux, Zygote
using Flux.Data: DataLoader
using Statistics: mean
using JLD2
include("preprocessing.jl")

In [None]:
using Flux.Data: DataLoader

"""
    prepare_dataloader(input_dir, train_ratio=0.8, batch_size=32)

Build training and validation `DataLoader`s from the CSV files in `input_dir`.

Every `.csv` file with at least 500 rows contributes one sample: the first 500 rows
of all columns except the last become a `features × 500` `Float32` matrix, and the
sum of the last column over those rows becomes the (`Float32`) target. Shorter files
and directory entries that are not `.csv` files are skipped.

# Arguments
- `input_dir`: directory that holds the CSV files.
- `train_ratio`: fraction of the samples, taken in `readdir` order, that goes to the
  training set. Must lie in `[0, 1]`.
- `batch_size`: batch size used by both loaders.

# Returns
`(train_data, val_data)`, two `DataLoader`s over `(inputs, targets)`.

# Throws
- `ArgumentError` if `train_ratio` is outside `[0, 1]` or no usable CSV file exists.
"""
function prepare_dataloader(
    input_dir::AbstractString, train_ratio::Real=0.8, batch_size::Int=32
)
    0 <= train_ratio <= 1 ||
        throw(ArgumentError("train_ratio must be in [0, 1], got $train_ratio"))

    inputs = Matrix{Float32}[]
    outputs = Float32[]

    for name in readdir(input_dir)
        # Directories often contain non-CSV entries (checkpoints, hidden files).
        endswith(lowercase(name), ".csv") || continue
        df = CSV.read(joinpath(input_dir, name), DataFrame)
        size(df, 1) < 500 && continue
        # Stored as features × time, hence the transpose of the row-per-step table.
        push!(inputs, permutedims(Matrix{Float32}(df[1:500, 1:end-1])))
        push!(outputs, Float32(sum(df[1:500, end])))
    end

    isempty(inputs) && throw(
        ArgumentError("no CSV file with at least 500 rows found in \"$input_dir\"")
    )

    # Every sample already has 500 columns; only the number of feature rows can
    # differ between files, so zero-pad the rows up to the largest feature count.
    max_rows = maximum(m -> size(m, 1), inputs)
    padded_inputs = map(inputs) do m
        vcat(m, zeros(Float32, max_rows - size(m, 1), size(m, 2)))
    end

    # Split the data into train and validation sets (no shuffling before the split).
    num_train = round(Int, length(padded_inputs) * train_ratio)
    train_inputs = padded_inputs[1:num_train]
    train_outputs = outputs[1:num_train]
    val_inputs = padded_inputs[num_train+1:end]
    val_outputs = outputs[num_train+1:end]

    train_data = Flux.Data.DataLoader(
        (train_inputs, train_outputs); batchsize=batch_size, shuffle=true
    )
    # Validation metrics do not depend on order; keep it deterministic.
    val_data = Flux.Data.DataLoader(
        (val_inputs, val_outputs); batchsize=batch_size, shuffle=false
    )

    return train_data, val_data
end


In [11]:
# Load the anomalous dataset into a training and a validation set.
# NOTE(review): `prepare_data` is not defined in the visible cells — presumably it
# comes from preprocessing.jl, or `prepare_dataloader` was intended. Confirm.
train_set, val_set = prepare_data("assets/anomalous")

In [None]:
using Random: shuffle
using DataFrames: nrow

"""
    split_data(inputs, outputs, train_frac=0.8)

Pair `inputs` with `outputs`, shuffle the pairs, and split them in two.

The first `floor(train_frac * n)` shuffled pairs form the training part and the
remaining pairs form the validation part. The size of the paired collection is
printed to standard output.

# Returns
`(train_data, val_data)`, two vectors of `(input, output)` tuples.
"""
function split_data(inputs, outputs, train_frac=0.8)
    paired = collect(zip(inputs, outputs))
    shuffled = shuffle(paired)
    println(size(shuffled))

    total = first(size(shuffled))
    cutoff = floor(Int, train_frac * total)

    return shuffled[1:cutoff], shuffled[cutoff+1:end]
end

# Shuffle the (input, target) pairs and split them 80/20 (the default `train_frac`).
# NOTE(review): `inputs` and `targets` are not defined in the visible cells —
# confirm where they come from before running this cell.
train_data, val_data = split_data(inputs, targets)

In [None]:
# Many-to-one RNN architecture
# The input is reshaped to a fixed 4 × 1 × 500 array before the LSTM, so every sample
# must contain exactly 4 * 500 values.
# NOTE(review): the 4 × 1 × 500 layout is presumably (features, batch, time) — confirm
# against the installed Flux version's recurrent-layer input convention.
model = Flux.Chain(
    x -> reshape(x, 4, 1, 500),
    Flux.LSTM(4, 64),
    x -> x[:, :, end],  # Select the hidden state at the last time step
    Dense(64, 1)
)

# Mean squared error between the prediction of the global `model` and the target `y`.
loss(x, y) = Flux.mse(model(x), y)


In [None]:
# Optimiser state for `model`, built with `Flux.setup` from Adam with step size 1e-2.
opt = Flux.setup(Adam(1e-2), model)
# NOTE(review): `Flux.params` belongs to the implicit-parameter API, while `opt` above
# is explicit-style state; the two are not meant to be combined in one update call.
ps = Flux.params(model) # Get the model's parameters

In [None]:
# Train `model` for `num_epochs` passes over `train_data`.
#
# `train_data` is expected to be a vector of `(input, target)` tuples (what
# `split_data` returns), so each batch from the loader is itself a collection of such
# tuples. The model reshapes its input to a single 4 × 1 × 500 sequence, so samples
# are fed one at a time and the optimiser takes one step per sample.
num_epochs = 100
train_loader = DataLoader(train_data, batchsize=32, shuffle=true);
println(train_loader)

for epoch in 1:num_epochs
    for batch in train_loader
        for (x, y) in batch
            # Sequences are independent: clear the recurrent state left by the
            # previous sample before running this one.
            Flux.reset!(model)

            val, grads = Flux.withgradient(model) do m
                ŷ = m(x)
                Flux.Losses.mse(ŷ, y)
            end

            if !isfinite(val)
                @warn "loss is $val on item $epoch" epoch
                continue
            end
            if grads[1] === nothing
                @warn "no gradients on item $epoch" epoch
            else
                # `opt` comes from `Flux.setup`, so it must be updated together with
                # the model itself rather than with an implicit `Params` collection.
                Flux.update!(opt, model, grads[1])
            end
        end
    end
    # Optionally, you can evaluate the model on the validation set and print the validation loss here
end

In [None]:
# Test the model on the validation set
# NOTE(review): `X_val`, `y_val` and `loss_function` are not defined in the visible
# cells (the notebook defines `val_data` and `loss`) — confirm they exist.
val_predictions = map(X_val) do x
    # Each validation sequence is independent, so start from a fresh recurrent state
    # instead of the one left behind by training or by the previous sample.
    Flux.reset!(model)
    model(reshape(x, size(x, 1), 1, size(x, 2)))
end
val_loss = mean(loss_function(ŷ, y) for (ŷ, y) in zip(val_predictions, y_val))
println("Validation Loss: $val_loss")

In [None]:
"""
    preprocess_new_midi(csv_file::String)

Read `csv_file` and return its first 500 rows as a `Float32` matrix.

The `anomalies` column is excluded from the result, which keeps the table layout
(one row per CSV row). When the first 500 rows contain anomalies, their total is
printed. Returns `nothing`, after printing a message, when the file has fewer than
500 rows.
"""
function preprocess_new_midi(csv_file::String)
    table = CSV.read(csv_file, DataFrame)

    # Guard clause: too little data to build a fixed-length sample.
    if size(table, 1) < 500
        println("MIDI file is too short (< 500 rows).")
        return nothing
    end

    features = Matrix{Float32}(table[1:500, Not(:anomalies)])

    anomaly_total = sum(table.anomalies[1:500])
    anomaly_total > 0 && println("MIDI file contains $(anomaly_total) anomalies.")

    return features
end


In [None]:
# Show the model so its layers can be inspected in the notebook output.
display(model)

In [None]:
# Preprocess the new MIDI file
new_midi_file = "assets/anomalous/liz_rhap15_0.3.csv"
new_midi_input = preprocess_new_midi(new_midi_file)

# Check if the preprocessing was successful (the file had at least 500 rows)
if new_midi_input !== nothing
    # `preprocess_new_midi` returns one row per time step (time × features), while
    # the training samples are stored as features × time. Transpose first: reshaping
    # the untransposed matrix would scramble values across time steps and features.
    features_by_time = permutedims(new_midi_input)
    # Reshape the input array to match the model's input shape
    new_midi_input = reshape(
        features_by_time, size(features_by_time, 1), 1, size(features_by_time, 2)
    )
    println(new_midi_input)
    # Start from a clean recurrent state so earlier sequences do not leak in.
    Flux.reset!(model)
    # Predict the number of errors in the new MIDI file
    num_errors = model(new_midi_input)

    println("Predicted number of errors: ", num_errors[1])
else
    println("Prediction cannot be performed due to insufficient data.")
end