Skip to content

Commit

Permalink
Text Generation working with the HF hosted and managed endpoints.
Browse files Browse the repository at this point in the history
  • Loading branch information
honnuanand committed Dec 20, 2023
1 parent 000ad39 commit f8c027f
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 5 deletions.
6 changes: 6 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@
<artifactId>spring-web</artifactId>
<version>6.1.2</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-configuration-processor</artifactId>
<optional>true</optional>
</dependency>
</dependencies>

<build>
Expand All @@ -67,6 +72,7 @@
</execution>
</executions>
</plugin>

</plugins>
</build>

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.clue2solve.huggingface.inference.spring.cloud.starter.autoconfigure;

import io.clue2solve.huggingface.inference.spring.cloud.starter.config.HuggingFaceProperties;
import io.clue2solve.huggingface.inference.spring.cloud.starter.service.HFInferenceService;
import io.clue2solve.huggingface.inference.spring.cloud.starter.service.impl.HFTextGenerationService;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;

@AutoConfiguration
@EnableConfigurationProperties({ HuggingFaceProperties.class })
public class HFServiceAutoconfig {

@Bean
public HFInferenceService hfInferenceService(HuggingFaceProperties properties) {
return new HFTextGenerationService(properties);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package io.clue2solve.huggingface.inference.spring.cloud.starter.config;

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;

@ConfigurationProperties(prefix = "huggingface")
public record HuggingFaceProperties(String apiToken, String modelName) {

}
Original file line number Diff line number Diff line change
@@ -1,26 +1,33 @@
package io.clue2solve.huggingface.inference.spring.cloud.starter.service.impl;

import com.fasterxml.jackson.core.JsonProcessingException;
import io.clue2solve.huggingface.inference.spring.cloud.starter.config.HuggingFaceProperties;
import io.clue2solve.huggingface.inference.spring.cloud.starter.service.HFInferenceService;
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.http.*;
import org.springframework.web.client.RestTemplate;

import java.util.Collections;

public class HFInferenceEndpointService implements HFInferenceService {
public class HFTextGenerationService implements HFInferenceService {

private final HuggingFaceProperties properties;

public HFTextGenerationService(HuggingFaceProperties properties) {
this.properties = properties;
}

@Override
public String invoke(String prompt) throws JsonProcessingException {

RestTemplate restTemplate = new RestTemplateBuilder().build();

String url = "https://api-inference.huggingface.co/models/bert-base-uncased";
String url = "https://api-inference.huggingface.co/models/" + properties.modelName();

HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
headers.setBearerAuth("YOUR_HF_API_TOKEN"); // Replace with your actual token
headers.setBearerAuth(properties.apiToken()); // Replace with your actual token

String requestBody = "{\"inputs\":\"" + prompt + "\"}";

Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
io.clue2solve.spring.cloud.starter.huggingface.inference.autoconfigure.RestTemplateAutoConfiguration
io.clue2solve.spring.cloud.starter.huggingface.inference.service.HFModels
io.clue2solve.spring.cloud.starter.huggingface.inference.service.HFModels
io.clue2solve.huggingface.inference.spring.cloud.starter.autoconfigure.HuggingFaceInferenceAutoConfiguration
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

import io.clue2solve.spring.cloud.starter.huggingface.inference.autoconfigure.RestTemplateAutoConfiguration;
import io.clue2solve.huggingface.inference.spring.cloud.starter.autoconfigure.RestTemplateAutoConfiguration;
import io.clue2solve.huggingface.inference.spring.cloud.starter.service.HFModels;

import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package io.clue2solve.spring.cloud.starter.huggingface.inference.starter.service.impl;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.clue2solve.huggingface.inference.spring.cloud.starter.autoconfigure.HFServiceAutoconfig;
import io.clue2solve.huggingface.inference.spring.cloud.starter.service.HFInferenceService;
import io.clue2solve.spring.cloud.starter.huggingface.inference.starter.TestInit;

import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;

import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

@SpringBootTest(classes = { TestInit.class, HFServiceAutoconfig.class })
public class HFTextGenerationServiceTest {

Logger logger = LoggerFactory.getLogger(HFTextGenerationServiceTest.class);

@Autowired
private HFInferenceService hfInferenceService;

@Test
public void HFInferenceEndpointServiceTestInvoke() throws Exception {
String response = hfInferenceService.invoke("the world or technology is heading towards");

// Assert that the response is not null
assertNotNull(response);

logger.info(response);
// // Parse the response into a JsonNode
ObjectMapper mapper = new ObjectMapper();
JsonNode node = mapper.readTree(response);

// // Pretty print the JsonNode
String prettyJson = mapper.writerWithDefaultPrettyPrinter().writeValueAsString(node);
logger.info("Response: {}", prettyJson);

// // Assert that the JsonNode is an array
assertTrue(node.isArray());
}

}
3 changes: 3 additions & 0 deletions src/test/resources/application.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
huggingface.api-token=YOUR_HUGGINGFACE_API_TOKEN
huggingface.model-name=MODEL_NAME ( like 'gpt2' )

0 comments on commit f8c027f

Please sign in to comment.