# Running LLM inference with Spring AI & Ollama
This notebook implement the basic text-to-text generation using [Spring AI](https://spring.io/projects/spring-ai) and Ollama.
You need to install Ollama in your machine in order to run.

Free feel to contribute to add more use cases.


## Install dependencies

In [51]:
// load version variables
%use @file[resources/version.json](currentDir=".")

In [52]:
USE {
    repositories {
        maven { url = "https://repo.spring.io/milestone" }
        mavenCentral()
    }
    dependencies {
        implementation("org.springframework.ai:spring-ai-core:$springAiVersion")
        implementation("org.springframework.ai:spring-ai-ollama:$springAiVersion")

        implementation("io.projectreactor:reactor-core:$reactorVersion")
        implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core-jvm:$kotlinCoroutineVersion")
        implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactor:$kotlinCoroutineVersion")
        implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:$kotlinCoroutineVersion")
        implementation("com.fasterxml.jackson.module:jackson-module-kotlin:$jacksonKotlinModule")
    }
    import(
        "kotlinx.coroutines.*",
        "kotlinx.coroutines.flow.*",
        "kotlinx.coroutines.reactor.*",
        "kotlinx.coroutines.reactive.*",
    )
}

// list the library, if the dependencies doesn't show up, run again and restart the kernel
notebook.currentClasspath.flatMap { it.split("/|\\\\".toRegex()).takeLast(1) }.sorted().joinToString("\n")

HdrHistogram-2.2.2.jar
LatencyUtils-2.0.3.jar
ST4-4.3.4.jar
annotations-13.0.jar
annotations-23.0.0.jar
antlr-runtime-3.5.3.jar
antlr4-runtime-4.13.1.jar
api-0.12.0-363.jar
classmate-1.5.1.jar
context-propagation-1.1.2.jar
jackson-annotations-2.18.2.jar
jackson-core-2.18.2.jar
jackson-databind-2.18.2.jar
jackson-datatype-jsr310-2.16.1.jar
jackson-module-jsonSchema-2.17.2.jar
jackson-module-kotlin-2.18.2.jar
jsonschema-generator-4.35.0.jar
jsonschema-module-jackson-4.35.0.jar
jsonschema-module-swagger-2-4.35.0.jar
jtokkit-1.1.0.jar
jul-to-slf4j-2.0.16.jar
kotlin-reflect-1.8.10.jar
kotlin-reflect-1.9.23.jar
kotlin-script-runtime-1.9.23.jar
kotlin-stdlib-1.9.23.jar
kotlin-stdlib-2.0.0.jar
kotlinx-coroutines-core-jvm-1.9.0.jar
kotlinx-coroutines-reactive-1.9.0.jar
kotlinx-coroutines-reactor-1.9.0.jar
kotlinx-serialization-core-jvm-1.6.3.jar
kotlinx-serialization-json-jvm-1.6.3.jar
lib-0.12.0-363.jar
log4j-api-2.23.1.jar
log4j-to-slf4j-2.23.1.jar
logback-classic-1.5.12.jar
logback-core-1.5.

## Function to measure model performance
A sample implementation using stream to measure key metrics
- time to first token
- input token process rate/s
- output token process rate/s

In [53]:
import org.springframework.ai.chat.model.ChatModel
import org.springframework.ai.chat.model.ChatResponse
import org.springframework.ai.chat.prompt.Prompt
import reactor.core.publisher.Flux
import kotlin.time.measureTimedValue

data class ModelPerformance(
    val timeToFirstTokenInMills: Double,
    val totalTimeInMills: Double,
    val promptTokens: Long,
    val generationTokens: Long,
    val inputTokenRatePerSec: Double,
    val outputTokenRatePerSec: Double
)

fun <T> measureModelPerformance(block: suspend () -> T): ModelPerformance {
    var ttft: Long?

    println("=== evaluating model performance ===")
    val timedResp = measureTimedValue {
        runBlocking {
            val startTime = System.nanoTime()
            ttft = System.nanoTime() - startTime
            val runnable = block()
            if (runnable is Flow<*>) {
                runnable
                    .onStart {
                        ttft = System.nanoTime() - startTime
                    }
                    .last()
            } else {
                runnable!!
            }
        }
    }
    println("\n=== model performance ===")
    val resp = timedResp.value

    if (resp is ChatResponse) {
        val totalTime = timedResp.duration.inWholeMilliseconds
        val timeToFirstTokenInMills = ttft!! / 1_000_000.0
        return ModelPerformance(
            timeToFirstTokenInMills = timeToFirstTokenInMills,
            totalTimeInMills = timedResp.duration.inWholeMilliseconds.toDouble(),
            promptTokens = resp.metadata.usage.promptTokens,
            generationTokens = resp.metadata.usage.generationTokens,
            inputTokenRatePerSec = resp.metadata.usage.promptTokens * 1000.0 / timeToFirstTokenInMills,
            outputTokenRatePerSec = resp.metadata.usage.generationTokens * 1000.0 / (totalTime - timeToFirstTokenInMills)
        )
    }
    throw IllegalArgumentException("not support type")
}

## Create Ollama Gemma2 2B INT4 model
You need approximately 2GB GPU VRAM to run `gemma2:2b` locally.

In [54]:
import org.springframework.ai.chat.prompt.Prompt
import org.springframework.ai.ollama.OllamaChatModel
import org.springframework.ai.ollama.api.OllamaApi
import org.springframework.ai.ollama.api.OllamaOptions

val model = OllamaChatModel.builder()
    .ollamaApi(OllamaApi("http://localhost:11434"))
    .defaultOptions(
        OllamaOptions.builder()
            .model("llama3.2:1b")
            .numCtx(1024)
            .temperature(0.0)
            .build()
    ).build()

measureModelPerformance {
    model.stream(Prompt("tell me 5 jokes"))
        .asFlow()
        .onEach {
            print("${it.result?.output?.text}")
        }
}


=== evaluating model performance ===
Here are five jokes for you:

1. Why did the scarecrow win an award? Because he was outstanding in his field.

2. What do you call a fake noodle? An impasta.

3. Why did the bicycle fall over? Because it was two-tired.

4. I told my wife she was drawing her eyebrows too high. She looked surprised.

5. Why don't scientists trust atoms? Because they make up everything.
=== model performance ===


ModelPerformance(timeToFirstTokenInMills=0.273209, totalTimeInMills=1118.0, promptTokens=30, generationTokens=91, inputTokenRatePerSec=109806.04592088841, outputTokenRatePerSec=81.41524452373979)

### Using `ChatClient` interface
An universal interface for multiple models in Spring AI. It is the preferred way to chat with LLM.


In [55]:
import org.springframework.ai.chat.client.ChatClient

val chatModel = ChatClient.builder(model)
    .defaultSystem("You are Peter, a helpful AI assistant.")
    .build()

measureModelPerformance {
    chatModel.prompt()
        .user("Tell me 5 jokes.")
        .call()
        .chatResponse()
        .also {
            print(it?.result?.output?.content)
        }
}


=== evaluating model performance ===
Here are five jokes for you:

1. Why did the scarecrow win an award? Because he was outstanding in his field.
2. What do you call a fake noodle? An impasta.
3. Why did the bicycle fall over? Because it was two-tired.
4. I told my wife she was drawing her eyebrows too high. She looked surprised.
5. Why don't scientists trust atoms? Because they make up everything.

I hope these jokes made you smile!
=== model performance ===


ModelPerformance(timeToFirstTokenInMills=4.2E-5, totalTimeInMills=743.0, promptTokens=40, generationTokens=99, inputTokenRatePerSec=9.523809523809525E8, outputTokenRatePerSec=133.24361453059464)

### Chat with Tools
Extend the LLM capability to lookup weathers via Function calling.

- Define a `FunctionCallback` in Spring AI
- Link it with the `ChatClient`

In [56]:
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.kotlin.jsonMapper
import com.fasterxml.jackson.module.kotlin.kotlinModule
import org.springframework.ai.chat.client.ChatClient
import org.springframework.ai.model.function.FunctionCallback
import org.springframework.ai.model.function.DefaultFunctionCallbackBuilder
import reactor.core.scheduler.Schedulers
import java.util.function.Function;

data class GetWeatherRequest(
    val date: String,
    val city: String,
    val country: String
)

data class GetWeatherResponse(
    val date: String,
    val city: String,
    val country: String,
    val unit: String,
    val tempature: String
)

val getWeatherAPI =
    DefaultFunctionCallbackBuilder().function("GetWeatherTool", Function<GetWeatherRequest, GetWeatherResponse> { req ->
        println("=== calling weather API with date=${req.date}, city=${req.city}, country=${req.country} ===")
        // mock response
        GetWeatherResponse(
            date = req.date,
            city = req.city,
            country = req.country,
            unit = "celsius",
            tempature = "24.0"
        )
    })
        .description("Get weather tool with city, country, and date")
        .inputType(GetWeatherRequest::class.java)
        .schemaType(FunctionCallback.SchemaType.JSON_SCHEMA)
        .objectMapper(jsonMapper {
            addModule(kotlinModule())
        })
        .build()


val chatModel = ChatClient.builder(model)
    .defaultSystem("You are Peter, a helpful AI assistant.")
    .defaultFunctions(getWeatherAPI)
    .build()

measureModelPerformance {
    chatModel.prompt("What's today wether in Seattle on 2025-01-01?")
        .call()
        .chatResponse()
        .also {
            print(it?.result?.output?.content)
        }
}


=== evaluating model performance ===
=== calling weather API with date=2025-01-01, city=Seattle, country=USA ===
The weather on January 1st, 2025 in Seattle is expected to be 24°C (75°F).
=== model performance ===


ModelPerformance(timeToFirstTokenInMills=4.2E-5, totalTimeInMills=466.0, promptTokens=354, generationTokens=65, inputTokenRatePerSec=8.428571428571429E9, outputTokenRatePerSec=139.48499111238118)