Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Commit

Permalink
ENH - major code cleanup for readability, speed up of computations with
Browse files Browse the repository at this point in the history
many observations, allowing for a low-dim sqrtm representation of the kernel
  • Loading branch information
schoffelen committed Nov 8, 2019
1 parent 7fa4991 commit 89deb68
Show file tree
Hide file tree
Showing 3 changed files with 441 additions and 300 deletions.
40 changes: 27 additions & 13 deletions external/dmlt/+dml/svm.m
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@

C % regularization parameter

kernel = 'linear'; % type of kernel

Ktrain % precomputed kernel for training data
Ktest % precomputed kernel for test data

distance = false; % return distance from decision boundary instead of class label

native = false; % uses native Bioinformatics toolbox SVM implementation if true

native = false; % uses native Bioinformatics toolbox SVM implementation if true
issqrtmK = false;
end

methods
Expand All @@ -42,6 +43,8 @@

function obj = train(obj,X,Y)

opts = {'verb' -1};

% handle multiple datasets
if iscell(X)
obj = dml.ndata('method',obj);
Expand All @@ -50,34 +53,45 @@
end

if obj.native
obj.Ktrain = svmtrain(X,Y);
obj.Ktrain = fitcsvm(X,Y);
return
end

if obj.restart || isempty(obj.dual)

obj.Ktrain = compKernel(X,X,'linear');
obj.X = X;
obj.Ktrain = compKernel(X,X,obj.kernel,obj.Ktrain);
obj.Ktest = [];

if isempty(obj.C)
obj.C = .1*(mean(diag(obj.Ktrain))-mean(obj.Ktrain(:)));
if obj.issqrtmK
% the actual kernel will be K*K'
diagK = sum(obj.Ktrain.^2,2);
meanK = mean(reshape(obj.Ktrain*obj.Ktrain',[],1));
else
diagK = diag(obj.Ktrain);
meanK = mean(obj.Ktrain(:));
end
obj.C = .1*(mean(diagK)-meanK);
if obj.verbose, fprintf('using default C=%.2f\n',obj.C); end
end

obj.X = X;

obj.Ktest = [];

obj.dual = l2svm_cg(obj.Ktrain,2*(Y-1)-1,obj.C,'verb',-1);
if obj.issqrtmK
opts = cat(2, opts, {'issqrtmK' true});
end
obj.dual = l2svm_cg(obj.Ktrain,2*(Y-1)-1,obj.C, opts);

else

% facilitate warm restarts
obj.dual = l2svm_cg(obj.Ktrain,2*(Y-1)-1,obj.C,'verb',-1); %,'alphab',obj.dual);
obj.dual = l2svm_cg(obj.Ktrain,2*(Y-1)-1,obj.C, opts); %,'alphab',obj.dual);

end

obj.primal = 0;
for j=1:size(X,1), obj.primal = obj.primal + obj.dual(j)*X(j,:); end
for j = 1:size(X,1)
obj.primal = obj.primal + obj.dual(j)*X(j,:);
end

end

Expand Down
106 changes: 61 additions & 45 deletions external/dmlt/external/svm/compKernel.m
Original file line number Diff line number Diff line change
Expand Up @@ -33,58 +33,74 @@
% implied


if (nargin < 3) % check correct number of arguments
error('Insufficient arguments'); return;
if nargin < 3 % check correct number of arguments
error('Insufficient arguments');
end
if ( isinteger(X) ) X=single(X); end % need floats for products!
if isinteger(X)
X = single(X);
end % need floats for products!
% empty Z means use X
isgram=false;
dir =1; % Assume row vectors in X
if ( isempty(Z) ) Z=X; isgram=true; elseif ( isinteger(Z) ) Z=single(Z); end;

if ( isstr(kerType) )
switch lower(kerType)

case {'linear','nlinear'}; % linear
K = linKer(X,Z,dir);

case {'poly','npoly'}; % polynomial
if(numel(varargin)<2) varargin{2}=1;end;
K = (linKer(X,Z,dir)+varargin{2}).^varargin{1};

case {'rbf','nrbf'}; % Radial basis function, a.k.a. gaussian
K = sqDist(X,Z,dir); % pairwise distance
K = exp( - .5 * K / varargin{1} ) ;
if(isgram) K=.5*(K+K'); end % to avoid rounding error problems....

case 'precomp'; % Given in the input
if ( all( size(varargin) == [size(X1,1) size(X2,1)] ) )
K=varargin{1};
else
error('Kernel matrix does not match input dimensions');
end

otherwise;
if ( exist(kerType)>1 ) % String which specifies function on the path
K=feval(kerType,X,Z,varargin{:});
else
error(['Unrecognised kernel type : ' kerType]);
end
end
isgram = false;
dir = 1; % Assume row vectors in X
if isempty(Z)
Z = X;
isgram = true;
elseif isinteger(Z)
Z = single(Z);
end

elseif ( isa(kerType,'function_handle') ) % function handle
% use this handle
K=feval(kerType,X,Z,varargin{:});
if ischar(kerType)
switch lower(kerType)

case {'linear','nlinear'} % linear
K = linKer(X,Z,dir);

case {'poly','npoly'} % polynomial
if numel(varargin)<2
varargin{2}=1;
end
K = (linKer(X,Z,dir)+varargin{2}).^varargin{1};

case {'rbf','nrbf'} % Radial basis function, a.k.a. gaussian
K = sqDist(X,Z,dir); % pairwise distance
K = exp( - .5 * K / varargin{1} );
if(isgram)
K = .5*(K+K');
end % to avoid rounding error problems....

case 'precomp' % Given in the input
% if ~all( size(varargin{1}) == [size(X,1) size(Z,1)] )
% warning('Kernel matrix does not match input dimensions');
% end
K = varargin{1};

otherwise
if exist(kerType)>1 % String which specifies function on the path
K=feval(kerType,X,Z,varargin{:});
else
error(['Unrecognised kernel type : ' kerType]);
end
end

elseif isa(kerType,'function_handle') % function handle
% use this handle
K = feval(kerType,X,Z,varargin{:});
else
error('Unknown kernel type');return
error('Unknown kernel type');
end

if ( isequal(kerType(1),'n') ) % normalise computed kernel
if ( isgram ) % Need to compute K(X,X)_i,i and K(Z,Z)_j,j
Nr = diag(K); Nc = diag(K);
if isequal(kerType(1),'n') % normalise computed kernel
if isgram % Need to compute K(X,X)_i,i and K(Z,Z)_j,j
Nr = diag(K); Nc = diag(K);
else
for i=1:size(X,1); Nr(i,1)=compKernel(X(i,:),[],kerType,varargin{:});end;
for i=1:size(Z,1); Nc(i,1)=compKernel(Z(i,:),[],kerType,varargin{:});end;
Nr = zeros(size(X,1),1);
for i=1:size(X,1)
Nr(i,1)=compKernel(X(i,:),[],kerType,varargin{:});
end
Nc = zeros(size(Z,1),1);
for i=1:size(Z,1)
Nc(i,1)=compKernel(Z(i,:),[],kerType,varargin{:});
end
end
K = repop(repop(K,sqrt(Nr),'./'),sqrt(Nc)','./');
end;
Loading

0 comments on commit 89deb68

Please sign in to comment.