# Running LLM inference with Spring AI & Ollama
This notebook implement the basic text-to-text generation using [Spring AI](https://spring.io/projects/spring-ai) and OpenAI.
You need to an OpenAI API key to run. Rename `openaikey.example.json` to `openaikey.secret.json` and update the OpenAI key

Free feel to contribute to add more use cases.

## Install dependencies

In [11]:
// load version variables
%use @file[resources/version.json](currentDir=".")

In [12]:
USE {
    repositories {
        maven { url = "https://repo.spring.io/milestone" }
        mavenCentral()
    }
    dependencies {
        implementation("org.springframework.ai:spring-ai-core:$springAiVersion")
        implementation("org.springframework.ai:spring-ai-ollama:$springAiVersion")
        implementation("org.springframework.ai:spring-ai-openai:$springAiVersion")

        implementation("io.projectreactor:reactor-core:$reactorVersion")
        implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core-jvm:$kotlinCoroutineVersion")
        implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactor:$kotlinCoroutineVersion")
        implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:$kotlinCoroutineVersion")
        implementation("com.fasterxml.jackson.module:jackson-module-kotlin:$jacksonKotlinModule")
    }
    import(
        "kotlinx.coroutines.*",
        "kotlinx.coroutines.flow.*",
        "kotlinx.coroutines.reactor.*",
        "kotlinx.coroutines.reactive.*",
    )
}
// list the library, if the dependencies doesn't show up, run again and restart the kernel
notebook.currentClasspath.flatMap { it.split("/|\\\\".toRegex()).takeLast(1) }.sorted().joinToString("\n")

HdrHistogram-2.2.2.jar
LatencyUtils-2.0.3.jar
ST4-4.3.4.jar
annotations-13.0.jar
annotations-23.0.0.jar
antlr-runtime-3.5.3.jar
antlr4-runtime-4.13.1.jar
api-0.12.0-363.jar
classmate-1.5.1.jar
commons-lang3-3.11.jar
context-propagation-1.1.2.jar
groovy-4.0.16.jar
groovy-json-4.0.16.jar
jackson-annotations-2.17.2.jar
jackson-annotations-2.18.2.jar
jackson-core-2.17.2.jar
jackson-core-2.18.2.jar
jackson-databind-2.16.1.jar
jackson-databind-2.18.2.jar
jackson-datatype-jsr310-2.16.1.jar
jackson-module-jsonSchema-2.17.2.jar
jackson-module-kotlin-2.18.2.jar
json-path-5.4.0.jar
jsonschema-generator-4.35.0.jar
jsonschema-module-jackson-4.35.0.jar
jsonschema-module-swagger-2-4.35.0.jar
jtokkit-1.1.0.jar
jul-to-slf4j-2.0.16.jar
kotlin-reflect-1.8.10.jar
kotlin-reflect-1.9.23.jar
kotlin-script-runtime-1.9.23.jar
kotlin-stdlib-1.9.23.jar
kotlin-stdlib-2.0.0.jar
kotlinx-coroutines-core-jvm-1.9.0.jar
kotlinx-coroutines-reactive-1.9.0.jar
kotlinx-coroutines-reactor-1.9.0.jar
kotlinx-serialization-cor

## Function to measure model performance
A sample implementation using stream to measure key metrics
- time to first token
- input token process rate/s
- output token process rate/s

In [13]:
import org.springframework.ai.chat.model.ChatModel
import org.springframework.ai.chat.model.ChatResponse
import org.springframework.ai.chat.prompt.Prompt
import reactor.core.publisher.Flux
import kotlin.time.measureTimedValue

data class ModelPerformance(
    val timeToFirstTokenInMills: Double,
    val totalTimeInMills: Double,
    val promptTokens: Long,
    val generationTokens: Long,
    val inputTokenRatePerSec: Double,
    val outputTokenRatePerSec: Double
)

fun <T> measureModelPerformance(block: suspend () -> T): ModelPerformance {
    var ttft: Long?

    println("=== evaluating model performance ===")
    val timedResp = measureTimedValue {
        runBlocking {
            val startTime = System.nanoTime()
            ttft = System.nanoTime() - startTime
            val runnable = block()
            if (runnable is Flow<*>) {
                runnable
                    .onStart {
                        ttft = System.nanoTime() - startTime
                    }
                    .last()
            } else {
                runnable!!
            }
        }
    }
    println("\n=== model performance ===")
    val resp = timedResp.value

    if (resp is ChatResponse) {
        val totalTime = timedResp.duration.inWholeMilliseconds
        val timeToFirstTokenInMills = ttft!! / 1_000_000.0
        return ModelPerformance(
            timeToFirstTokenInMills = timeToFirstTokenInMills,
            totalTimeInMills = timedResp.duration.inWholeMilliseconds.toDouble(),
            promptTokens = resp.metadata.usage.promptTokens,
            generationTokens = resp.metadata.usage.generationTokens,
            inputTokenRatePerSec = resp.metadata.usage.promptTokens * 1000.0 / timeToFirstTokenInMills,
            outputTokenRatePerSec = resp.metadata.usage.generationTokens * 1000.0 / (totalTime - timeToFirstTokenInMills)
        )
    }
    throw IllegalArgumentException("not support type")
}

## Create GPT4o-mini chat model

### Load OpenAI Key into Kotlin Notebook
Rename `openaikey.example.json` to `openaikey.secret.json` and update the OpenAI key


In [14]:
// Load openaikey.json into `openAiKey`
%use @file[resources/openaikey.secret.json](currentDir=".")

In [15]:
import org.springframework.ai.chat.prompt.Prompt
import org.springframework.ai.openai.OpenAiChatModel
import org.springframework.ai.openai.OpenAiChatOptions
import org.springframework.ai.openai.api.OpenAiApi

val model = OpenAiChatModel(
    OpenAiApi(openAiKey),
    OpenAiChatOptions.builder()
        .streamUsage(true)
        .model(OpenAiApi.ChatModel.GPT_4_O_MINI)
        .temperature(0.7)
        .build()
)

measureModelPerformance {
    model.call(Prompt("tell me 5 jokes"))?.also {
        print(it.result?.output?.content)
    }
}

=== evaluating model performance ===
Sure! Here are five jokes for you:

1. Why don't scientists trust atoms?
   Because they make up everything!

2. What did the ocean say to the beach?
   Nothing, it just waved!

3. Why did the scarecrow win an award?
   Because he was outstanding in his field!

4. How does a penguin build its house?
   Igloos it together!

5. Why did the bicycle fall over?
   Because it was two-tired!

Hope these made you smile!
=== model performance ===


ModelPerformance(timeToFirstTokenInMills=8.3E-5, totalTimeInMills=2014.0, promptTokens=12, generationTokens=103, inputTokenRatePerSec=1.4457831325301206E8, outputTokenRatePerSec=51.14200806593181)

### Using

In [16]:
import org.springframework.ai.chat.client.ChatClient

val chatModel = ChatClient.builder(model)
    .defaultSystem("You are Peter, a helpful AI assistant.")
    .build()

measureModelPerformance {
    chatModel.prompt()
        .user("Tell me 5 jokes.")
        .call()
        .chatResponse()
        .also {
            print(it?.result?.output?.content)
        }
}

=== evaluating model performance ===
Sure! Here are five jokes for you:

1. Why don’t skeletons fight each other?  
   Because they don’t have the guts!

2. What do you call fake spaghetti?  
   An impasta!

3. Why did the scarecrow win an award?  
   Because he was outstanding in his field!

4. How does a penguin build its house?  
   Igloos it together!

5. Why was the math book sad?  
   Because it had too many problems!

Hope these made you smile!
=== model performance ===


ModelPerformance(timeToFirstTokenInMills=4.2E-5, totalTimeInMills=5607.0, promptTokens=26, generationTokens=109, inputTokenRatePerSec=6.190476190476191E8, outputTokenRatePerSec=19.439985877738433)

In [17]:
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.kotlin.jsonMapper
import com.fasterxml.jackson.module.kotlin.kotlinModule
import org.springframework.ai.chat.client.ChatClient
import org.springframework.ai.model.function.FunctionCallback
import org.springframework.ai.model.function.DefaultFunctionCallbackBuilder
import reactor.core.scheduler.Schedulers
import java.util.function.Function;

data class GetWeatherRequest(
    val date: String,
    val city: String,
    val country: String
)

data class GetWeatherResponse(
    val date: String,
    val city: String,
    val country: String,
    val unit: String,
    val tempature: String
)

val getWeatherAPI =
    DefaultFunctionCallbackBuilder().function("GetWeatherTool", Function<GetWeatherRequest, GetWeatherResponse> { req ->
        println("=== calling weather API with date=${req.date}, city=${req.city}, country=${req.country} ===")
        // mock response
        GetWeatherResponse(
            date = req.date,
            city = req.city,
            country = req.country,
            unit = "celsius",
            tempature = "24.0"
        )
    })
        .description("Get weather tool with city, country, and date")
        .inputType(GetWeatherRequest::class.java)
        .schemaType(FunctionCallback.SchemaType.JSON_SCHEMA)
        .objectMapper(jsonMapper {
            addModule(kotlinModule())
        })
        .build()


val chatModel = ChatClient.builder(model)
    .defaultSystem("You are Peter, a helpful AI assistant.")
    .defaultFunctions(getWeatherAPI)
    .build()

measureModelPerformance {
    chatModel.prompt("What's today wether in Seattle on 2025-01-01?")
        .call()
        .chatResponse()
        .also {
            print(it?.result?.output?.content)
        }
}

=== evaluating model performance ===
=== calling weather API with date=2025-01-01, city=Seattle, country=US ===
The weather in Seattle on January 1, 2025, is expected to be 24.0°C.
=== model performance ===


ModelPerformance(timeToFirstTokenInMills=8.3E-5, totalTimeInMills=1782.0, promptTokens=229, generationTokens=54, inputTokenRatePerSec=2.7590361445783134E9, outputTokenRatePerSec=30.30303171445097)