forked from VowpalWabbit/vowpal_wabbit
-
Notifications
You must be signed in to change notification settings - Fork 3
/
active_interactor.cc
139 lines (129 loc) · 3.39 KB
/
active_interactor.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved. Released under a BSD (revised)
license as described in the file LICENSE.
*/
#include <iostream>
#include <string>
#include <cstring>
#include <cerrno>
#include <cstdlib>
#ifdef _WIN32
#include <WinSock2.h>
#else
#include <sys/types.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
#endif
using namespace std;
/*
 Opens a blocking IPv4 TCP connection to host:port and returns the connected
 socket descriptor. The caller owns the descriptor and must close() it.
 On any failure a diagnostic is written to cerr and std::exception is thrown;
 no descriptor is leaked on the failure paths.
*/
int open_socket(const char* host, unsigned short port)
{
  hostent* he = gethostbyname(host);
  if (he == nullptr)
  {
    // gethostbyname reports its failure reason through h_errno, not errno.
    cerr << "gethostbyname(" << host << "): " << hstrerror(h_errno) << endl;
    throw exception();
  }

  int sd = socket(PF_INET, SOCK_STREAM, 0);
  if (sd == -1)
  {
    cerr << "socket: " << strerror(errno) << endl;
    throw exception();
  }

  sockaddr_in far_end{};  // zero-fill everything, including sin_zero
  far_end.sin_family = AF_INET;
  far_end.sin_port = htons(port);
  // Copy the first resolved address instead of type-punning through in_addr*.
  memcpy(&far_end.sin_addr, he->h_addr_list[0], sizeof(far_end.sin_addr));

  if (connect(sd, reinterpret_cast<sockaddr*>(&far_end), sizeof(far_end)) == -1)
  {
    const int err = errno;  // save before close()/iostreams can overwrite it
    close(sd);              // do not leak the descriptor on failure
    cerr << "connect(" << host << ':' << port << "): " << strerror(err) << endl;
    throw exception();
  }
  return sd;
}
/*
 Receives from socket s into buf, stopping when any of these happens:
   - n bytes have been stored;
   - the last byte of a received chunk is '\n' (one reply line);
   - the peer closes the connection (recv returns 0);
   - recv fails.
 Never writes more than n bytes into buf and does not NUL-terminate it.
 Returns the number of bytes stored (0..n), or -1 if recv reported an error.
*/
int recvall(int s, char* buf, int n){
  int total = 0;
  while (total < n) {
    // Ask only for the space that is still free in buf.
    auto got = recv(s, buf + total, n - total, 0);
    if (got < 0)
      return -1;  // surface the error so callers' `ret < 0` checks work
    if (got == 0)
      break;      // peer closed the connection
    total += static_cast<int>(got);
    if (buf[total - 1] == '\n')
      break;
  }
  return total;
}
/*
 Command-line client that replays labeled examples from stdin against a
 server (presumably a VW instance running in active-learning mode; confirm).

 Usage: active_interactor [host] [port]
   host defaults to "localhost". port defaults to 26542; any value outside
   1025..65535 (including unparsable input, which atoi maps to 0) also falls
   back to 26542.

 Protocol as implemented here:
   1. Send a size_t 0 as a handshake.
   2. For each stdin line, drop the first space-delimited token (presumably
      the label) and send the remainder.
   3. Read one reply line. Its second space-separated token becomes the tag
      ("'empty" if missing). The rest of the line up to '\n' becomes the query
      token (presumably an importance weight). An empty query token means the
      example was not queried.
   4. If it was queried, insert "<query token> <tag> " before the first '|'
      of the original line, send that line, then read and discard one reply.
 Prints the number of queries and returns 0 when stdin is exhausted. Network
 failures throw std::exception, which is uncaught and terminates the program.
*/
int main(int argc, char* argv[]){
  const char* host = "localhost";
  unsigned short port = 26542;
  if (argc > 1) {
    host = argv[1];
  }
  if (argc > 2) {
    // Validate before narrowing so e.g. 70000 is not silently wrapped.
    const int requested = atoi(argv[2]);
    if (requested > 1024 && requested <= 65535) {
      port = static_cast<unsigned short>(requested);
    }
  }

  int s = open_socket(host, port);

  size_t id = 0;
  if (send(s, &id, sizeof(id), 0) < 0) {
    cerr << "Could not perform handshake!" << endl;
    throw exception();
  }

  char buf[256];
  int queries = 0;
  string line;
  while (getline(cin, line)) {
    line.append("\n");

    // Everything after the first space is the unlabeled example.
    const size_t first_space = line.find(' ');
    if (first_space == string::npos) {
      cerr << "Skipping line without a space-separated label" << endl;
      continue;
    }
    const char* unlabeled = line.c_str() + first_space + 1;
    const size_t unlabeled_len = line.size() - (first_space + 1);
    if (send(s, unlabeled, unlabeled_len, 0) < 0) {
      cerr << "Could not send unlabeled data!" << endl;
      throw exception();
    }

    // Reserve one byte so the terminator below stays inside buf.
    int ret = recvall(s, buf, static_cast<int>(sizeof(buf) - 1));
    if (ret <= 0) {  // 0 means the server closed the connection
      cerr << "Could not receive queries!" << endl;
      throw exception();
    }
    buf[ret] = '\0';

    char* toks = buf;
    strsep(&toks, " ");                      // discard the first token
    char* ttag = strsep(&toks, " ");
    const string tag = ttag ? string(ttag) : string("'empty");
    char* itok = strsep(&toks, "\n");
    if (itok == nullptr || itok[0] == '\0') {
      continue;                              // example was not queried
    }

    const size_t bar = line.find('|');
    if (bar == string::npos) {
      cerr << "Skipping queried line without a '|' feature separator" << endl;
      continue;
    }
    queries += 1;
    line.replace(bar, 1, string(itok) + " " + tag + " |");

    if (send(s, line.c_str(), line.size(), 0) < 0) {
      cerr << "Could not send labeled data!" << endl;
      throw exception();
    }
    ret = recvall(s, buf, static_cast<int>(sizeof(buf) - 1));
    if (ret <= 0) {
      cerr << "Could not receive predictions!" << endl;
      throw exception();
    }
  }
  close(s);
  cout << "Went through the data by doing " << queries << " queries" << endl;
  return 0;
}