Permalink
Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
287 lines (262 sloc) 8.04 KB
--- svm-train.c.orig 2012-07-15 07:12:58.000000000 -0400
+++ svm-train.c 2014-05-13 13:50:56.895044086 -0400
@@ -11,7 +11,7 @@
void exit_with_help()
{
printf(
- "Usage: svm-train [options] training_set_file [model_file]\n"
+ "Usage: svm-train [options] training_set_file weight_file [model_file]\n"
"options:\n"
"-s svm_type : set type of SVM (default 0)\n"
" 0 -- C-SVC (multi-class classification)\n"
@@ -48,10 +48,11 @@
exit(1);
}
-void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name);
-void read_problem(const char *filename);
+void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name, char *weights_file_name);
+void read_problem(const char *filename, const char *weights_filename);
void do_cross_validation();
+float *deep_weights; // set by read_problem
struct svm_parameter param; // set by parse_command_line
struct svm_problem prob; // set by read_problem
struct svm_model *model;
@@ -84,10 +85,11 @@
{
char input_file_name[1024];
char model_file_name[1024];
+ char weights_file_name[1024];
const char *error_msg;
- parse_command_line(argc, argv, input_file_name, model_file_name);
- read_problem(input_file_name);
+ parse_command_line(argc, argv, input_file_name, model_file_name, weights_file_name);
+ read_problem(input_file_name, weights_file_name);
error_msg = svm_check_parameter(&prob,&param);
if(error_msg)
@@ -102,7 +104,7 @@
}
else
{
- model = svm_train(&prob,&param);
+ model = svm_train(&prob,&param,deep_weights);
if(svm_save_model(model_file_name,model))
{
fprintf(stderr, "can't save model to file %s\n", model_file_name);
@@ -115,6 +117,7 @@
free(prob.x);
free(x_space);
free(line);
+ free(deep_weights);
return 0;
}
@@ -158,7 +161,7 @@
free(target);
}
-void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name)
+void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name, char *weights_file_name)
{
int i;
void (*print_func)(const char*) = NULL; // default printing to stdout
@@ -260,6 +263,9 @@
strcpy(input_file_name, argv[i]);
+ i++;
+ strcpy(weights_file_name, argv[i]);
+
if(i<argc-1)
strcpy(model_file_name,argv[i+1]);
else
@@ -275,12 +281,16 @@
// read in a problem (in svmlight format)
-void read_problem(const char *filename)
+void read_problem(const char *filename, const char *weights_filename)
{
int elements, max_index, inst_max_index, i, j;
FILE *fp = fopen(filename,"r");
char *endptr;
char *idx, *val, *label;
+
+ FILE *fi;
+ char weight_line[50];
+ float weight;
if(fp == NULL)
{
@@ -376,4 +386,18 @@
}
fclose(fp);
+
+ // read in the weights
+ deep_weights = Malloc(float,prob.l);
+ fi = fopen(weights_filename, "r");
+
+ i = 0;
+ while (fgets(weight_line,20,fi) != NULL)
+ {
+ sscanf(weight_line,"%f", &weight);
+ deep_weights[i] = weight;
+ i++;
+ }
+
+ fclose(fi);
}
--- svm.cpp.orig 2012-10-01 21:41:17.000000000 -0400
+++ svm.cpp 2014-05-13 13:50:45.687043818 -0400
@@ -9,6 +9,9 @@
#include <locale.h>
#include "svm.h"
int libsvm_version = LIBSVM_VERSION;
+
+const float *weights;
+
typedef float Qfloat;
typedef signed char schar;
#ifndef min
@@ -58,6 +61,8 @@
static void info(const char *fmt,...) {}
#endif
+// #define DEBUG
+
//
// Kernel Cache
//
@@ -516,6 +521,11 @@
this->eps = eps;
unshrink = false;
+ #ifdef DEBUG
+ printf("Cp: %f, Cn: %f\n", Cp, Cn);
+ printf("l: %d\n", l);
+ #endif
+
// initialize alpha_status
{
alpha_status = new char[l];
@@ -543,7 +553,10 @@
}
for(i=0;i<l;i++)
if(!is_lower_bound(i))
- {
+ {
+ #ifdef DEBUG
+ printf("call to get_Q, i: %d, l: %l\n", i, l);
+ #endif
const Qfloat *Q_i = Q.get_Q(i,l);
double alpha_i = alpha[i];
int j;
@@ -589,13 +602,27 @@
++iter;
// update alpha[i] and alpha[j], handle bounds carefully
-
+
+ #ifdef DEBUG
+ printf("optimization, call to get_Q, i: %d, active_size: %d\n", i, active_size);
+ #endif
+
const Qfloat *Q_i = Q.get_Q(i,active_size);
+
+ #ifdef DEBUG
+ printf("optimization, call to get_Q, j: %d, active_size: %d\n", j, active_size);
+ #endif
+
const Qfloat *Q_j = Q.get_Q(j,active_size);
double C_i = get_C(i);
double C_j = get_C(j);
+ // printf("in solver: C_i: %f, C_j: %f\n", C_i, C_j);
+ #ifdef DEBUG
+ printf("alpha[i]: %f, alpha[j]: %f\n", alpha[i], alpha[j]);
+ #endif
+
double old_alpha_i = alpha[i];
double old_alpha_j = alpha[j];
@@ -1268,22 +1295,32 @@
public:
SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_)
:Kernel(prob.l, prob.x, param)
- {
+ {
clone(y,y_,prob.l);
cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
QD = new double[prob.l];
- for(int i=0;i<prob.l;i++)
- QD[i] = (this->*kernel_function)(i,i);
+ this->C = param.C;
+ for(int i=0;i<prob.l;i++) {
+ QD[i] = (this->*kernel_function)(i,i)+weights[i]/C;
+ // printf("constructor: %d, %f\n", i, QD[i]);
+ }
}
Qfloat *get_Q(int i, int len) const
{
Qfloat *data;
int start, j;
+
if((start = cache->get_data(i,&data,len)) < len)
{
- for(j=start;j<len;j++)
+ for(j=start;j<len;j++) {
data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j));
+ // printf("%d, %d, %f\n", i, j, data[j]);
+ }
+ if(i >= start && i < len) {
+ data[i] += weights[i]/C;
+ // printf("--- %f\n", data[i]);
+ }
}
return data;
}
@@ -1308,6 +1345,7 @@
delete[] QD;
}
private:
+ double C;
schar *y;
Cache *cache;
double *QD;
@@ -1456,7 +1494,7 @@
Solver s;
s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y,
- alpha, Cp, Cn, param->eps, si, param->shrinking);
+ alpha, INF, INF, param->eps, si, param->shrinking);
double sum_alpha=0;
for(i=0;i<l;i++)
@@ -1958,7 +1996,7 @@
subparam.weight_label[1]=-1;
subparam.weight[0]=Cp;
subparam.weight[1]=Cn;
- struct svm_model *submodel = svm_train(&subprob,&subparam);
+ struct svm_model *submodel = svm_train(&subprob,&subparam,NULL);
for(j=begin;j<end;j++)
{
svm_predict_values(submodel,prob->x[perm[j]],&(dec_values[perm[j]]));
@@ -2071,12 +2109,15 @@
//
// Interface functions
//
-svm_model *svm_train(const svm_problem *prob, const svm_parameter *param)
+svm_model *svm_train(const svm_problem *prob, const svm_parameter *param, const float *deep_weights)
{
svm_model *model = Malloc(svm_model,1);
model->param = *param;
model->free_sv = 0; // XXX
+ if (deep_weights != NULL)
+ weights = deep_weights;
+
if(param->svm_type == ONE_CLASS ||
param->svm_type == EPSILON_SVR ||
param->svm_type == NU_SVR)
@@ -2413,7 +2454,7 @@
subprob.y[k] = prob->y[perm[j]];
++k;
}
- struct svm_model *submodel = svm_train(&subprob,param);
+ struct svm_model *submodel = svm_train(&subprob,param,NULL);
if(param->probability &&
(param->svm_type == C_SVC || param->svm_type == NU_SVC))
{
--- svm.h.orig 2012-11-16 09:43:53.000000000 -0500
+++ svm.h 2014-05-13 13:50:52.471043980 -0400
@@ -71,7 +71,7 @@
/* 0 if svm_model is created by svm_train */
};
-struct svm_model *svm_train(const struct svm_problem *prob, const struct svm_parameter *param);
+struct svm_model *svm_train(const struct svm_problem *prob, const struct svm_parameter *param, const float *deep_weights);
void svm_cross_validation(const struct svm_problem *prob, const struct svm_parameter *param, int nr_fold, double *target);
int svm_save_model(const char *model_file_name, const struct svm_model *model);