Skip to content

Commit

Permalink
dolly : disable interactive_port on Windows (#339)
Browse files Browse the repository at this point in the history
  • Loading branch information
iboB committed Jul 4, 2023
1 parent 965568d commit b2b6de8
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions examples/dolly-v2/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@
#include <string>
#include <vector>

#if !defined(_WIN32)
#define DOLLY_INTERACTIVE_PORT
#endif

#if defined(DOLLY_INTERACTIVE_PORT)
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>
#endif

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
Expand Down Expand Up @@ -775,6 +781,7 @@ std::string execute_prompt(
return output;
}

#if defined(DOLLY_INTERACTIVE_PORT)
int setup_port(const int port) {
int sockfd = socket(AF_INET, SOCK_STREAM, 0);
if (sockfd < 0) {
Expand Down Expand Up @@ -818,6 +825,7 @@ std::string read_from_port(int sockfd, int clientfd) {
}
return std::string("");
}
#endif

int main(int argc, char ** argv) {
ggml_time_init();
Expand Down Expand Up @@ -865,6 +873,7 @@ int main(int argc, char ** argv) {
test_gpt_tokenizer(vocab, params.token_test);
}

#if defined(DOLLY_INTERACTIVE_PORT)
int sockfd;
if (params.interactive_port != -1) {
sockfd = setup_port(params.interactive_port);
Expand All @@ -874,17 +883,21 @@ int main(int argc, char ** argv) {
fprintf(stdout, "Model is ready on port %i\n", params.interactive_port);
fflush(stdout);
}
#endif

if (params.interactive or params.interactive_port != -1) {
if (params.interactive || params.interactive_port != -1) {
while (true) {
std::string prompt_input;
#if defined(DOLLY_INTERACTIVE_PORT)
int clientfd;
if (params.interactive_port != -1) {
sockaddr_in clientaddr;
socklen_t clientaddrlen = sizeof(clientaddr);
clientfd = accept(sockfd, (struct sockaddr *)&clientaddr, &clientaddrlen);
clientfd = accept(sockfd, (struct sockaddr *)&clientaddr, &clientaddrlen);
prompt_input = read_from_port(sockfd, clientfd);
} else {
} else
#endif
{
printf("Please enter your quesiton:\n>");
fflush(stdout);

Expand All @@ -899,6 +912,7 @@ int main(int argc, char ** argv) {
// call the model
const std::string response = execute_prompt(model, vocab, prompt, params, rng, t_load_us, t_sample_us, t_predict_us, mem_per_token, n_past, true);

#if defined(DOLLY_INTERACTIVE_PORT)
if (params.interactive_port != -1) {
if (write(clientfd, response.c_str(), response.size()) < 0) {
std::cerr << "Failed to write to client\n";
Expand All @@ -907,8 +921,9 @@ int main(int argc, char ** argv) {
if (close(clientfd) < 0) {
std::cerr << "Failed to close client socket\n";
}
}
else {
} else
#endif
{
printf("%s\n\n", response.c_str());
}
fflush(stdout);
Expand Down Expand Up @@ -936,9 +951,11 @@ int main(int argc, char ** argv) {

ggml_free(model.ctx);

#if defined(DOLLY_INTERACTIVE_PORT)
if (params.interactive_port != -1 && close(sockfd) < 0) {
std::cerr << "Failed to close server socket\n";
}
#endif

return 0;
}

0 comments on commit b2b6de8

Please sign in to comment.