In [1]:
function y = psi(x)
    %PSI   Digamma function.
    % PSI(X) returns digamma(x) = d log(gamma(x)) / dx
    % If X is a matrix, returns the digamma function evaluated at each element.
    %
    % Special cases handled below:
    %   X == NaN or -Inf  -> NaN
    %   X == 0            -> -Inf
    %   X < 0             -> reflection formula (recurses on 1 - X > 1)
    %   0 < X <= 1e-6     -> short series around 0
    %   otherwise         -> upward recurrence followed by the asymptotic
    %                        (de Moivre) expansion

    % Reference:
    %
    %    J Bernardo,
    %    Psi ( Digamma ) Function,
    %    Algorithm AS 103,
    %    Applied Statistics,
    %    Volume 25, Number 3, pages 315-317, 1976.
    %
    % From http://www.psc.edu/~burkardt/src/dirichlet/dirichlet.f

    large = 9.5;                        % threshold for the asymptotic expansion
    d1 = -0.5772156649015328606065121;  % digamma(1)
    d2 = pi^2/6;
    small = 1e-6;                       % threshold for the near-zero series
    % Coefficients of the asymptotic expansion in powers of 1/x^2.
    s3 = 1/12;
    s4 = 1/120;
    s5 = 1/252;
    s6 = 1/240;
    s7 = 1/132;

    % Initialize
    y = zeros(size(x));

    % illegal arguments
    i = find(x == -Inf | isnan(x));
    if ~isempty(i)
        % Overwriting x with NaN keeps these entries out of every branch
        % below, because all comparisons against NaN are false.
        x(i) = NaN;
        y(i) = NaN;
    end

    % Negative values
    i = find(x < 0);
    if ~isempty(i)
        % Use the reflection formula (Jeffrey 11.1.6):
        % digamma(-x) = digamma(x+1) + pi*cot(pi*x)
        % The recursive call must target this function itself; the argument
        % 1 - x(i) is > 1, so the recursion terminates after one level.
        y(i) = psi(-x(i)+1) + pi*cot(-pi*x(i));
        % This is related to the identity
        % digamma(-x) = digamma(x+1) - digamma(z) + digamma(1-z)
        % where z is the fractional part of x
        % For example:
        % digamma(-3.1) = 1/3.1 + 1/2.1 + 1/1.1 + 1/0.1 + digamma(1-0.1)
        %               = digamma(4.1) - digamma(0.1) + digamma(1-0.1)
        % Then we use
        % digamma(1-z) - digamma(z) = pi*cot(pi*z)
    end

    i = find(x == 0);
    if ~isempty(i)
        y(i) = -Inf;
    end

    %  Use approximation if argument <= small.
    i = find(x > 0 & x <= small);
    if ~isempty(i)
        y(i) = y(i) + d1 - 1 ./ x(i) + d2*x(i);
    end

    %  Reduce to digamma(X + N) where (X + N) >= large.
    %  Each pass applies digamma(x) = digamma(x + 1) - 1/x to the entries
    %  that are still below the threshold.
    while(1)
        i = find(x > small & x < large);
        if isempty(i)
            break
        end
        y(i) = y(i) - 1 ./ x(i);
        x(i) = x(i) + 1;
    end

    %  Use de Moivre's expansion if argument >= large.
    % In maple: asympt(Psi(x), x);
    i = find(x >= large);
    if ~isempty(i)
        r = 1 ./ x(i);
        y(i) = y(i) + log(x(i)) - 0.5 * r;
        r = r .* r;
        y(i) = y(i) - r .* ( s3 - r .* ( s4 - r .* (s5 - r .* (s6 - r .* s7))));
    end
end

function ld = logdet(Sigma)
    % LOGDET  Log-determinant of a matrix via its Cholesky factor.
    % With Sigma = R' * R from chol, det(Sigma) = prod(diag(R))^2, so the
    % log-determinant is twice the sum of the logs of R's diagonal.
    % chol raises an error if Sigma is not positive definite.
    cholDiag = diag(chol(Sigma));
    ld = 2 * sum(log(cholDiag));
endfunction

function printvar(name, var)
    % PRINTVAR  Debug helper: display a name and a preview of a variable.
    %   name  label to display (echoed as-is)
    %   var   value to preview
    % For a column vector the first and last (up to) three entries are
    % shown; for a matrix the four (up to) 3x3 corners are shown; anything
    % else (scalars, row vectors, empties) is displayed in full.
    % The preview windows are clamped to the actual dimensions so inputs
    % with fewer than three rows or columns no longer cause an index error.
    N = size(var, 1);
    M = size(var, 2);
    name
    head_rows = 1:min(3, N);
    tail_rows = max(1, N - 2):N;
    if N > 1 && M == 1
        var(head_rows)
        var(tail_rows)
    else
        if N > 1 && M > 1
            head_cols = 1:min(3, M);
            tail_cols = max(1, M - 2):M;
            var(head_rows, head_cols)
            var(head_rows, tail_cols)
            var(tail_rows, head_cols)
            var(tail_rows, tail_cols)
        else
            var
        endif
    endif
endfunction

In [2]:
function state = bemkl_supervised_classification_variational_train(Km, y, parameters, use_init_vars, init_vars, verbose)
    % Iterative training routine that alternates closed-form updates of
    % lambda, a, G, gamma, omega, (b, e) and f, optionally tracking a
    % lower bound after each sweep.
    %
    % Inputs:
    %   Km             D x N x P array (sizes are read from dimensions 1..3).
    %   y              N x 1 labels; only their sign is used (y > 0 / y < 0).
    %   parameters     struct; fields read here: seed, sigma_g, margin,
    %                  alpha_lambda, beta_lambda, alpha_gamma, beta_gamma,
    %                  alpha_omega, beta_omega, progress, iteration.
    %   use_init_vars  if true, a.mu, G.mu and f.mu are taken from init_vars
    %                  instead of being drawn at random.
    %   init_vars      struct with fields a_mu, G_mu, f_mu (each is transposed
    %                  on load); only read when use_init_vars is true.
    %   verbose        if true (and parameters.progress == 1), every term of
    %                  the lower bound is displayed and the function then
    %                  aborts via error("stop") at the end of the first
    %                  bound evaluation.
    %
    % Output:
    %   state  struct with fields lambda, a, gamma, omega, be, parameters,
    %          total_time, and bounds (bounds only when
    %          parameters.progress == 1).
    %
    % Note: the locals gamma, lower and upper are ordinary variables here and
    % hide the functions of the same names inside this function.

    % Seed both generators (legacy 'state' syntax).
    rand('state', parameters.seed); %#ok<RAND>
    randn('state', parameters.seed); %#ok<RAND>

    D = size(Km, 1);
    N = size(Km, 2);
    P = size(Km, 3);
    sigma_g = parameters.sigma_g;

    log2pi = log(2 * pi);

    % The bare string expressions below have no trailing semicolon, so they
    % are echoed to the output as a progress message.
    if use_init_vars
        "loading init_vars"
        a.mu = init_vars.a_mu';
        G.mu = init_vars.G_mu';
        f.mu = init_vars.f_mu';
    else
        "calculating init_vars"
        a.mu = randn(D, 1);
        % Random magnitudes of at least `margin`, signed like the labels.
        G.mu = (abs(randn(P, N)) + parameters.margin) .* sign(repmat(y', P, 1));
        f.mu = (abs(randn(N, 1)) + parameters.margin) .* sign(y);
    endif
    % The alpha (shape) values are set once here and never updated in the
    % loop; only the beta values change.
    lambda.alpha = (parameters.alpha_lambda + 0.5) * ones(D, 1);
    lambda.beta = parameters.beta_lambda * ones(D, 1);
    a.sigma = eye(D, D);
    G.sigma = eye(P, P);
    gamma.alpha = (parameters.alpha_gamma + 0.5);
    gamma.beta = parameters.beta_gamma;
    omega.alpha = (parameters.alpha_omega + 0.5) * ones(P, 1);
    omega.beta = parameters.beta_omega * ones(P, 1);
    % be stacks the scalar b (index 1) on top of the P-vector e (2:P+1).
    be.mu = [0; ones(P, 1)];
    be.sigma = eye(P + 1, P + 1);
    f.sigma = ones(N, 1);
    
    % Sum over the third dimension of Km(:,:,m) * Km(:,:,m)'; constant
    % across iterations, so it is computed once.
    KmKm = zeros(D, D);
    for m = 1:P
        KmKm = KmKm + Km(:, :, m) * Km(:, :, m)';
    endfor
    % Flatten to D x (N*P): the P slices are laid side by side.
    Km = reshape(Km, [D, N * P]);

    % Per-sample interval [lower, upper] used in the f update:
    % y > 0 -> [margin, 1e40], y < 0 -> [-1e40, -margin].
    % +/-1e40 acts as a stand-in for an unbounded side.
    lower = -1e40 * ones(N, 1);
    lower(y > 0) = +parameters.margin;
    upper = +1e40 * ones(N, 1);
    upper(y < 0) = -parameters.margin;

    if parameters.progress == 1
        bounds = zeros(parameters.iteration, 1);
    endif
    
    % Second-moment style quantities (mean outer product plus covariance)
    % reused by the updates. G.sigma is a single P x P matrix, so it is
    % counted N times in GtimesGT.
    atimesaT.mu = a.mu * a.mu' + a.sigma;
    GtimesGT.mu = G.mu * G.mu' + N * G.sigma;
    btimesbT.mu = be.mu(1)^2 + be.sigma(1, 1);
    etimeseT.mu = be.mu(2:P + 1) * be.mu(2:P + 1)' + be.sigma(2:P + 1, 2:P + 1);
    etimesb.mu = be.mu(2:P + 1) * be.mu(1) + be.sigma(2:P + 1, 1);
    % reshape(G.mu', N * P, 1) orders G to match the flattened Km columns.
    KmtimesGT.mu = Km * reshape(G.mu', N * P, 1);


    init_time = time();
    % Each update below reads the quantities refreshed by the previous
    % ones, so the statement order matters.
    for iter = 1:parameters.iteration
        %%%% update lambda
        lambda.beta = 1 ./ (1 / parameters.beta_lambda + 0.5 * diag(atimesaT.mu));
        %%%% update a
        % "M \ eye" forms the explicit inverse of M.
        a.sigma = (diag(lambda.alpha .* lambda.beta) + KmKm / sigma_g^2) \ eye(D, D);
        a.mu = a.sigma * KmtimesGT.mu / sigma_g^2;
        atimesaT.mu = a.mu * a.mu' + a.sigma;
        %%%% update G
        G.sigma = (eye(P, P) / sigma_g^2 + etimeseT.mu) \ eye(P, P);
        G.mu = G.sigma * (reshape(a.mu' * Km, [N, P])' / sigma_g^2 + be.mu(2:P + 1) * f.mu' - repmat(etimesb.mu, 1, N));
        GtimesGT.mu = G.mu * G.mu' + N * G.sigma;
        KmtimesGT.mu = Km * reshape(G.mu', N * P, 1);
        %%%% update gamma
        gamma.beta = 1 / (1 / parameters.beta_gamma + 0.5 * btimesbT.mu);
        %%%% update omega
        omega.beta = 1 ./ (1 / parameters.beta_omega + 0.5 * diag(etimeseT.mu));
        %%%% update b and e
        be.sigma = [gamma.alpha * gamma.beta + N, sum(G.mu, 2)'; sum(G.mu, 2), diag(omega.alpha .* omega.beta) + GtimesGT.mu] \ eye(P + 1, P + 1);
        be.mu = be.sigma * ([ones(1, N); G.mu] * f.mu);
        btimesbT.mu = be.mu(1)^2 + be.sigma(1, 1);
        etimeseT.mu = be.mu(2:P + 1) * be.mu(2:P + 1)' + be.sigma(2:P + 1, 2:P + 1);
        etimesb.mu = be.mu(2:P + 1) * be.mu(1) + be.sigma(2:P + 1, 1);
        %%%% update f
        % output = b + e' * G per sample. f.mu and f.sigma then follow the
        % mean/variance expressions of a unit-variance normal centred at
        % `output` and restricted to [lower, upper].
        output = [ones(1, N); G.mu]' * be.mu;
        alpha_norm = lower - output;
        beta_norm = upper - output;
        normalization = normcdf(beta_norm) - normcdf(alpha_norm);
        % Guard against division by zero (and log(0) in q(f) below).
        normalization(normalization == 0) = 1;
        f.mu = output + (normpdf(alpha_norm) - normpdf(beta_norm)) ./ normalization;
        f.sigma = 1 + (alpha_norm .* normpdf(alpha_norm) - beta_norm .* normpdf(beta_norm)) ./ normalization - (normpdf(alpha_norm) - normpdf(beta_norm)).^2 ./ normalization.^2;

        % Lower bound: the p.* terms are added to lb and the q.* terms are
        % subtracted. With verbose set, each term is echoed as it is computed.
        if parameters.progress == 1
            lb = 0;

            %%%% p(lambda)
            p.lambda = sum(
                      (parameters.alpha_lambda - 1) * (psi(lambda.alpha) + log(lambda.beta))
                    - gammaln(parameters.alpha_lambda)
                    - lambda.alpha .* lambda.beta / parameters.beta_lambda
                    - parameters.alpha_lambda * log(parameters.beta_lambda)
                  );
            lb = lb + p.lambda;
            if verbose
                "p(lambda)", p.lambda
            endif
            %%%% p(a | lambda)
            p.a_lambda = (
                        - 0.5 * sum(lambda.alpha .* lambda.beta .* diag(atimesaT.mu))
                        - 0.5 * D * log2pi
                        + 0.5 * sum(psi(lambda.alpha) + log(lambda.beta))
                        );
            lb = lb + p.a_lambda;
            if verbose
                "p(a_lambda)", p.a_lambda
            endif
            
            %%%% p(G | a, Km)
            p.G_a_Km = (
                  - 0.5 * sigma_g^-2 * sum(diag(GtimesGT.mu))
                  + sigma_g^-2 * a.mu' * KmtimesGT.mu
                  - 0.5 * sigma_g^-2 * sum(sum(KmKm .* atimesaT.mu))
                  - 0.5 * N * P * (log2pi + 2 * log(sigma_g))
            );
            lb = lb + p.G_a_Km;
            if verbose
                "p(G_a_Km)", p.G_a_Km
            endif

            %%%% p(gamma)
            p.gamma = (
                    (parameters.alpha_gamma - 1) * (psi(gamma.alpha) + log(gamma.beta))
                  - gamma.alpha * gamma.beta / parameters.beta_gamma
                  - gammaln(parameters.alpha_gamma)
                  - parameters.alpha_gamma * log(parameters.beta_gamma)
            );
            lb = lb + p.gamma;
            if verbose
                "p(gamma)", p.gamma
            endif

            %%%% p(b | gamma)
            p.b_gamma = (
                  - 0.5 * gamma.alpha * gamma.beta * btimesbT.mu
                  - 0.5 * (log2pi - (psi(gamma.alpha) + log(gamma.beta)))
            );
            lb = lb + p.b_gamma;
            if verbose
                  gamma.alpha * gamma.beta
                  btimesbT.mu
                "p(b_gamma)", p.b_gamma
            endif

            %%%% p(omega)
            p.omega = sum(
                      (parameters.alpha_omega - 1) * (psi(omega.alpha) + log(omega.beta))
                    - omega.alpha .* omega.beta / parameters.beta_omega
                    - gammaln(parameters.alpha_omega)
                    - parameters.alpha_omega * log(parameters.beta_omega)
                  );
            lb = lb + p.omega;
            if verbose
                "p(omega)", p.omega
            endif

            %%%% p(e | omega)
            p.e_omega = (
                  - 0.5 * sum(omega.alpha .* omega.beta .* diag(etimeseT.mu))
                  - 0.5 * P * log2pi
                  + 0.5 * sum(psi(omega.alpha) + log(omega.beta))
            );
            lb = lb + p.e_omega;
            if verbose
                "p(e_omega)", p.e_omega
            endif

            %%%% p(f | b, e, G)
            p.f_b_e_G = (
                  - 0.5 * (f.mu' * f.mu + sum(f.sigma))
                  + f.mu' * (G.mu' * be.mu(2:P + 1))
                  + sum(be.mu(1) * f.mu)
                  - 0.5 * sum(sum(etimeseT.mu .* GtimesGT.mu))
                  - sum(G.mu' * etimesb.mu)
                  - 0.5 * N * btimesbT.mu
                  - 0.5 * N * log2pi
            );
            lb = lb + p.f_b_e_G;
            if verbose
                "p(f_b_e_G)", p.f_b_e_G
            endif

            %%%% q(lambda)
            q.lambda = sum(
                    - lambda.alpha
                    - log(lambda.beta)
                    - gammaln(lambda.alpha)
                    - (1 - lambda.alpha) .* psi(lambda.alpha)
            );
            lb = lb - q.lambda;
            if verbose
                "q(lambda)", q.lambda
            endif

            %%%% q(a)
            q.a = (
                - 0.5 * D * (log2pi + 1)
                - 0.5 * logdet(a.sigma)
            );
            lb = lb - q.a;
            if verbose
                "q(a)", q.a
            endif

            %%%% q(G)
            q.G = (
                - 0.5 * N * P * (log2pi + 1)
                - 0.5 * N * logdet(G.sigma)
            );
            lb = lb - q.G;
            if verbose
                "q(G)", q.G
            endif

            %%%% q(gamma)
            q.gamma = (
                - gamma.alpha
                - log(gamma.beta)
                - gammaln(gamma.alpha)
                - (1 - gamma.alpha) * psi(gamma.alpha)
            );
            lb = lb - q.gamma;
            if verbose
                "q(gamma)", q.gamma
            endif

            %%%% q(omega)
            q.omega = sum(
                - omega.alpha
                - log(omega.beta)
                - gammaln(omega.alpha)
                - (1 - omega.alpha) .* psi(omega.alpha)
            );
            lb = lb - q.omega;
            if verbose
                "q(omega)", q.omega
            endif

            %%%% q(b, e)
            q.b_e = (
                - 0.5 * (P + 1) * (log2pi + 1)
                - 0.5 * logdet(be.sigma)
            );
            lb = lb - q.b_e;
            if verbose
                "q(b_e)", q.b_e
            endif

            %%%% q(f)
            q.f = sum(
                - 0.5 * (log2pi + f.sigma)
                - log(normalization)
            );
            lb = lb - q.f;
            if verbose
                "q(f)", q.f
                % NOTE(review): this aborts the whole function after the
                % first bound evaluation whenever verbose is set; it looks
                % like a debugging stop - confirm it is intentional.
                error("stop")
            endif

            bounds(iter) = lb;
        endif
        % mod(iter, 1) is 0 for every integer iter, so one dot is printed
        % per iteration.
        if mod(iter, 1) == 0
            fprintf(1, '.');
        endif
        % Every 10th iteration: print the counter (and the current bound
        % when it is being tracked) and end the line.
        if mod(iter, 10) == 0
            fprintf(1, ' %5d', iter);
            if parameters.progress==1
                fprintf(1, ' %.4f', bounds(iter));
            endif
            fprintf('\n')
        endif
    endfor

    % Elapsed time of the loop, in the units returned by time().
    state.total_time = time() - init_time;
    if parameters.progress==1
        fprintf(1, ' %.4f', state.total_time);
    endif
    fprintf('\n')

    % G and f are not stored in the returned state.
    state.lambda = lambda;
    state.a = a;
    state.gamma = gamma;
    state.omega = omega;
    state.be = be;
    if parameters.progress == 1
        state.bounds = bounds;
    endif
    state.parameters = parameters;
endfunction

In [3]:
function prediction = bemkl_supervised_classification_variational_test(Km, state)
    % Compute predictions for new samples from a trained state.
    %
    % Inputs:
    %   Km     D x N x P array; slice m is multiplied by state.a.mu' and
    %          state.a.sigma, so D must match their size.
    %   state  struct with fields a.mu, a.sigma, be.mu, be.sigma,
    %          parameters.sigma_g, parameters.margin and, optionally, bounds.
    %
    % Output:
    %   prediction  struct with fields
    %     G.mu, G.sigma   P x N
    %     f.mu, f.sigma   N x 1
    %     a, be           copies of the corresponding state fields
    %     bounds          state.bounds if present, otherwise []
    %     p               N x 1, pos ./ (pos + neg) as computed below
    N = size(Km, 2);
    P = size(Km, 3);

    prediction.G.mu = zeros(P, N);
    prediction.G.sigma = zeros(P, N);
    for m = 1:P
        prediction.G.mu(m, :) = state.a.mu' * Km(:, :, m);
        prediction.G.sigma(m, :) = state.parameters.sigma_g^2 + diag(Km(:, :, m)' * state.a.sigma * Km(:, :, m));
    end

    % Design matrix with a leading row of ones that pairs with be.mu(1).
    design = [ones(1, N); prediction.G.mu];
    prediction.f.mu = design' * state.be.mu;
    prediction.f.sigma = 1 + diag(design' * state.be.sigma * design);
    prediction.a.mu = state.a.mu;
    prediction.a.sigma = state.a.sigma;
    prediction.be.mu = state.be.mu;
    prediction.be.sigma = state.be.sigma;
    % A state may carry no bounds field (the trainer in this file sets it
    % only when parameters.progress == 1); fall back to an empty array
    % instead of failing on the missing field.
    if isfield(state, 'bounds')
        prediction.bounds = state.bounds;
    else
        prediction.bounds = [];
    end

    % NOTE(review): f.sigma is 1 + a quadratic form (variance-like) and is
    % used directly as the scale here - confirm whether sqrt is intended.
    pos = 1 - normcdf((+state.parameters.margin - prediction.f.mu) ./ prediction.f.sigma);
    neg = normcdf((-state.parameters.margin - prediction.f.mu) ./ prediction.f.sigma);
    prediction.p = pos ./ (pos + neg);
end

In [4]:
% Read every variable stored in Km.mat into one struct, then unpack it.
Km = load('Km.mat');
init_vars = Km.init_vars;
% Kernel arrays are used as stored.
Km_train = Km.Km_train;
Km_test = Km.Km_test;
% Label vectors are transposed and converted to double.
y_train = double(ctranspose(Km.y_train));
y_test = double(ctranspose(Km.y_test));

In [5]:
% Echo the dimensions of the loaded arrays (no semicolons, so each result
% is printed).
size(Km_train)
size(y_train)
size(Km_test)
size(y_test)

ans =

   478   478   130

ans =

   478     1

ans =

   478   205   130

ans =

   205     1



In [6]:
% Echo the type name typeinfo reports for each loaded array.
typeinfo(Km_train)
typeinfo(y_train)
typeinfo(Km_test)
typeinfo(y_test)

ans = matrix
ans = matrix
ans = matrix
ans = matrix


In [7]:
% Build the algorithm parameters in a single call (field order is kept).
%
%   alpha_lambda, beta_lambda : gamma prior hyperparameters for the sample weights
%   alpha_gamma,  beta_gamma  : gamma prior hyperparameters for the bias
%   alpha_omega,  beta_omega  : gamma prior hyperparameters for the kernel weights
%   iteration                 : number of iterations
%   margin                    : margin parameter
%   progress                  : 1 to calculate and store the lower bound values
%   seed                      : seed for the random number generator used to
%                               initialize random variables
%   sigma_g                   : standard deviation of intermediate representations
%
%%% IMPORTANT %%%
% For gamma priors, you can experiment with three different (alpha, beta) values
% (1, 1) => default priors
% (1e-10, 1e+10) => good for obtaining sparsity
% (1e-10, 1e-10) => good for small sample size problems
parameters = struct( ...
    'alpha_lambda', 1, ...
    'beta_lambda', 1, ...
    'alpha_gamma', 1, ...
    'beta_gamma', 1, ...
    'alpha_omega', 1, ...
    'beta_omega', 1, ...
    'iteration', 200, ...
    'margin', 1, ...
    'progress', 1, ...
    'seed', 1606, ...
    'sigma_g', 0.1);

% Training inputs:
%   Ktrain - Ntra x Ntra x P array of similarity values between training samples
%   ytrain - Ntra x 1 vector of class labels (contains only -1s and +1s)
Ktrain = Km_train;
ytrain = y_train;

% Train, starting from the stored initial values (use_init_vars = true)
% with verbose output switched off.
state = bemkl_supervised_classification_variational_train(Ktrain, ytrain, parameters, true, init_vars, false);

% Display the kernel weights
%display(state.be.mu(2:end));

% Test input: Ntra x Ntest x P array of similarity values between training
% and test samples.
Ktest = Km_test;

% Predict on the test kernels.
prediction = bemkl_supervised_classification_variational_test(Ktest, state);

% Display the predicted probabilities
%display(prediction.p);

ans = loading init_vars
..........    10 -1428.2337
..........    20 -1340.5488
..........    30 -1317.0624
..........    40 -1302.2127
..........    50 -1285.7225
..........    60 -1259.3099
..........    70 -1234.3284
..........    80 -1211.4699
..........    90 -1200.7572
..........   100 -1194.3797
..........   110 -1190.0772
..........   120 -1184.2967
..........   130 -1173.7283
..........   140 -1169.3529
..........   150 -1165.7853
..........   160 -1159.1401
..........   170 -1153.6763
..........   180 -1149.7780
..........   190 -1146.1310
..........   200 -1144.6308
 16.1027


In [8]:
% Write the prediction struct to prediction.mat; the '-6' option selects
% the file format (see the Octave save documentation).
save('-6', 'prediction.mat', "prediction")