In [None]:
using Einsum
using Random
using LinearAlgebra

In [None]:
"""
    ×₃(T, v)

Mode-3 tensor–vector product: contract the third index of the 3-tensor `T`
with the vector `v`,

    M[i, j] = Σₖ T[i, j, k] * v[k],

returning a `size(T, 1) × size(T, 2)` matrix. Also usable infix: `T ×₃ v`.

Accepts any numeric arrays (not only `Float64`); the result element type
follows the usual promotion of `eltype(T)` and `eltype(v)`.

# Throws
- `DimensionMismatch` if `size(T, 3) != length(v)`.
"""
function ×₃(T::AbstractArray{<:Number, 3}, v::AbstractVector{<:Number})
    r, n, p = size(T)
    p == length(v) || throw(DimensionMismatch(
        "third dimension of T ($p) does not match length of v ($(length(v)))"))
    # Flatten the first two indices so the whole contraction is a single
    # (r*n)×p matrix times p-vector product, then restore the r×n shape.
    return reshape(reshape(T, r * n, p) * v, r, n)
end

In [None]:
"""
    mu_gnmf(X, T; maxiter=800, tol=5e-4, λ=0, ϵ=1e-8, verbose=false,
            rng=Random.default_rng())

Fit `X ≈ A * (T ×₃ b)` by multiplicative updates, where the core tensor `T`
(size r×n×p) is fixed and the nonnegative factors `A` (m×r) and `b`
(length p) are learned from a random nonnegative initialization.

Each iteration first updates the entries of `b` one at a time (each entry
sees the already-updated earlier entries), then updates `A`.

# Arguments
- `X`: m×n data matrix.
- `T`: r×n×p tensor; `size(T, 2)` must equal `size(X, 2)`.

# Keywords
- `maxiter`: maximum number of iterations.
- `tol`: stop once the relative residual `norm(X - A*(T×₃b)) / norm(X)` is ≤ `tol`.
- `λ`: regularization weight added to the update denominators (`λ*b[q]`, `λ*A`).
- `ϵ`: small constant added to denominators to avoid division by zero.
- `verbose`: if `true`, print the iteration counter every iteration.
- `rng`: random number generator used for the initial `A` and `b`.

# Returns
`(A, b, i)`, where `i` is the number of iterations performed.

# Throws
- `DimensionMismatch` if `size(X, 2) != size(T, 2)`.
"""
function mu_gnmf(X, T; maxiter=800, tol=5e-4, λ=0, ϵ=1e-8, verbose=false,
                 rng=Random.default_rng())
    m, n = size(X)
    r, N, p = size(T)
    n == N || throw(DimensionMismatch(
        "second dimension of X ($n) does not match second dimension of T ($N)"))

    # Random nonnegative initialization
    A = abs.(randn(rng, m, r))
    b = abs.(randn(rng, p))
    B = T ×₃ b        # r×n matrix, kept in sync with b throughout
    normX = norm(X)
    i = 0

    while (norm(X - A * B) / normX > tol) && (i < maxiter)
        i += 1
        verbose && @show i

        # Sequential update of b; B is refreshed after every entry so the
        # next entry's update uses the current b.
        for q in 1:p
            ATq = A * view(T, :, :, q)
            # tr(T_q' A' X) == ⟨A T_q, X⟩ and tr(T_q' A' A B) == ⟨A T_q, A B⟩
            # (Frobenius inner products), without forming n×n products.
            b[q] *= dot(ATq, X) / (dot(ATq, A * B) + ϵ + λ * b[q])
            B = T ×₃ b
        end

        # Multiplicative update of A using the B from the final b above.
        A .*= (X * B') ./ (A * B * B' .+ ϵ .+ λ .* A)
    end

    return (A, b, i)
end

In [None]:
"""
    envelope(t, delay; ϵ=0)

Exponentially decaying envelope that switches on at `delay`: returns
`exp(-(t - delay))` where `t ≥ delay` and `ϵ` where `t < delay`.
Works element-wise on arrays `t` and on scalar `t`.

`ϵ` defaults to `0` (a hard zero before `delay`). It can be set to a small
positive floor such as `1e-4` to keep the envelope strictly positive, e.g.
to prevent divide-by-zero errors downstream.
"""
function envelope(t, delay; ϵ=0)
    # For t < delay, exp(delay - t) may overflow to Inf, but `false` is a
    # strong zero in Julia (false * Inf == 0.0), so the mask still gives 0.
    return @. exp(delay - t) * (t >= delay) + ϵ * (t < delay)
end

In [None]:
# Problem dimensions
J = 2316    # number of frequency bins
K = 2       # number of sources
L = 88      # number of pitches
N = 10      # number of harmonics

# Index ranges over frequencies, pitches and harmonics
j, l, n = 1:J, 1:L, 1:N

In [None]:
# STFT parameters
w = 250              # window width (samples)
hop = w ÷ 2 - 1      # number of samples to hop over between frames
window = hann(w)
nyquist = sample_rate ÷ 2   # Nyquist frequency (half the sample rate)
T = maximum(t)              # last time point of the grid `t` (its duration if t starts at 0)

# Notes to be played: an open-voiced major triad (root, fifth, major tenth)
# with frequency ratio 2:3:5
notes = [100, 150, 250] # Hz

# Matrix sizes: rows are frequency bins, columns are time frames.
# Deliberately not named `N`/`n`: the earlier cell uses `N` for the number of
# harmonics and `n = 1:N` for the harmonic range, which the later
# spectral-template cell (`h = n`, `b1 = 1/n`, ...) still needs.
m, nframes = size(stft(t, window, hop))
r = length(notes)
freqs = range(0, nyquist, m)
times = range(0, T, nframes)

In [None]:
# Harmonic multipliers: the n-th harmonic sits at h[n] = n times the fundamental.
# NOTE(review): assumes `n` still holds the harmonic range 1:N at this point.
h = n
# Pitch of the l-th piano key: 12 keys per octave, key 49 is 440 Hz.
f = 440 .* 2 .^ ((l .- 49) ./ 12)
# Log-spaced frequency grid: 240 bins per octave (5 cents apart),
# bin 961 is 440 Hz and bin 1 is 27.5 Hz.
ν = 440 .* 2 .^ ((j .- 961) ./ 240)
# Two harmonic weight profiles, decaying as 1/n and 1/n².
b1 = 1 ./ n
b2 = 1 ./ n .^ 2

In [None]:
"""
    close_in_cents(f, g; tol=5)

Return `true` when the frequencies `f` and `g` lie strictly within `tol`
cents of each other (1200 cents per octave, i.e. per factor of 2).
"""
function close_in_cents(f, g; tol=5)
    cents = 1200 * log2(f / g)   # signed interval between g and f, in cents
    return abs(cents) < tol
end

# D[l, j, n] is true when the n-th harmonic of pitch l (h[n] * f[l]) lies
# within `tol` cents (close_in_cents default: 5) of grid frequency ν[j].
# Here l, j, n are Einsum index symbols; their ranges come from the sizes of
# f, ν and h, not from the globals of the same names.
# NOTE(review): h must be the harmonic range 1:N; if the global `n` was
# reassigned to an integer before `h = n`, h is a scalar and D presumably
# collapses to a single (wrong) harmonic slice — confirm.
# NOTE(review): the element type of D is chosen by Einsum's `:=` — confirm
# whether it is Bool or a promoted numeric type.
@einsum D[l,j,n] := close_in_cents(h[n]*f[l], ν[j])

In [None]:
# Reproducible synthetic data: one random delay per note, then the STFT
# magnitude of the combined signal and of each note on its own.
Random.seed!(314)
# Delays drawn uniformly from [0, T - 0.5); presumably leaves each note at
# least 0.5 time units before the end of the grid — confirm.
delays = map(_ -> rand() * (T - 0.5), notes)
y = arp_chord(t, notes, delays)
Y = abs.(stft(y, window, hop))
Xs = map(zip(notes, delays)) do (note, delay)
    abs.(stft(make_note(t, note; delay=delay), window, hop))
end

In [None]:

# Synthetic recovery check for mu_gnmf: build X = A * (T ×₃ b) from known
# nonnegative factors and a binary core tensor T, then refit A and b.
Random.seed!(314)   # make the synthetic experiment reproducible
m, n, r, p = (100, 100, 10, 5)
A = abs.(randn((m, r)))
b = abs.(randn((p,)))
T = Float64.(abs.(randn((r, n, p))) .>= 1)   # binary core: 1 where |z| ≥ 1, else 0
X = A * (T ×₃ b)
(AA, bb, i) = mu_gnmf(X, T)
# A and b are only identifiable up to a scale exchange (A*c, b/c give the
# same X), so judge the fit by the relative reconstruction error.
@show i norm(X - AA * (T ×₃ bb)) / norm(X)