Skip to content

Commit

Permalink
Minor sqlx query tweaks and README update.
Browse files Browse the repository at this point in the history
  • Loading branch information
emkshv committed Jul 26, 2023
1 parent d4d70c5 commit 88f5d44
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 57 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

This file was deleted.

This file was deleted.

1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,3 @@ chrono = "0.4.26"
async-openai = "0.12.1"
async-trait = "0.1.72"
clap = { version = "4.3.19", features = ["derive"] }
dotenvy = "0.15.7"
14 changes: 10 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,20 @@ Built with:

### Creating a Telegram bot and obtaining a token

Go to [BotFather](https://telegram.me/BotFather) and enter `/newbot`. Fill in the description and save the token to the `TELEGRAM_TOKEN` environment variable. Also set the commands for the nice autocomplete: enter `/setcommands`, select your bot, and then paste:
Go to [BotFather](https://telegram.me/BotFather) and enter `/newbot`. Fill in the description and save the token to the `TELEGRAM_TOKEN` environment variable. To define the commands for the autocomplete: enter `/setcommands`, select your bot, and then paste:

```
new - Clear the current context and start a new chat.
get_behavior - Display the current system message that define's the bot's behavior.
get_behavior - Display the current system message that defines the bot's behavior.
set_behavior - Set the new system message for defining the bot's behavior.
get_model - Get the current completion model.
set_model - Set the completion model for your bot.
version - Display the current version
version - Display the current version.
```

### Running using Docker

Make sure you have [Docker](https://docs.docker.com/get-docker/) & [Docker Compose](https://docs.docker.com/compose/install/) or [OrbStack](https://orbstack.dev/) installed.
Make sure you have [Docker](https://docs.docker.com/get-docker/) & [Docker Compose](https://docs.docker.com/compose/install/). On desktop, you can use [Docker Desktop](https://docker.com/products/docker-desktop/) or [OrbStack](https://orbstack.dev/).

The Docker Compose file expects your environment variables to be loaded:

Expand All @@ -48,3 +48,9 @@ docker-compose up
* Edit `.envrc` to set environment variables
* Load environment variables from `.envrc` using [direnv](https://direnv.net/), or `source .envrc`-equivalent in your shell.
* Now you can compile with `cargo build`

### Building from source

* Install Rust via [RustUp](https://rustup.rs/)
* Clone the repository
* Run `cargo build --release`
4 changes: 2 additions & 2 deletions src/bot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ async fn handle_any(e: Event, state: State<RunningBotState>) -> Result<Action, a
.await
.context("Failed to get the current chat thread")?;

let _new_chat_message = chat_message::insert_new_message(
let _new_chat_message_id = chat_message::insert_new_message(
&db,
&message_content,
message.chat.id,
Expand Down Expand Up @@ -243,7 +243,7 @@ async fn handle_any(e: Event, state: State<RunningBotState>) -> Result<Action, a

match maybe_answer.await {
Ok(content) => {
let _new_chat_message = chat_message::insert_new_message(
let _new_chat_message_id = chat_message::insert_new_message(
&db,
&content,
message.chat.id,
Expand Down
20 changes: 9 additions & 11 deletions src/db/chat_bot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,18 @@ pub async fn set_chat_bot_mock_model(
) -> Result<ChatBot> {
let completion_model_string = completion_model.as_str();

sqlx::query!(
r#"UPDATE chat_bots SET mock_model = ?1 WHERE id = ?2;"#,
completion_model_string,
id
let chat_bot = sqlx::query_as::<_, ChatBot>(
"UPDATE chat_bots SET mock_model = ?1 WHERE id = ?2 RETURNING *;",
)
.bind(completion_model_string)
.bind(id)
.fetch_one(db_conn)
.await
.context(format!(
"Couldn't update the chat bot's Mock completion model with {}",
completion_model_string
))?;

let chat_bot = get_by_id(db_conn, id).await?;

Ok(chat_bot)
}

Expand All @@ -94,13 +92,13 @@ pub async fn set_chat_bot_openai_model(
id: i64,
completion_model: OpenAICompletionModel,
) -> Result<ChatBot> {
let completion_model_string = completion_model.as_str().to_string();
let completion_model_string = completion_model.as_str();

sqlx::query!(
r#"UPDATE chat_bots SET openai_model = ?1 WHERE id = ?2;"#,
completion_model_string,
id,
sqlx::query_as::<_, ChatBot>(
"UPDATE chat_bots SET openai_model = ?1 WHERE id = ?2 RETURNING *;",
)
.bind(completion_model_string)
.bind(id)
.fetch_one(db_conn)
.await
.context(format!(
Expand Down
40 changes: 28 additions & 12 deletions src/db/chat_message.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use sqlx::{FromRow, Pool, Sqlite};
extern crate rand;
use anyhow::Context;
use rand::Rng;

#[derive(Clone, FromRow, Debug)]
Expand All @@ -12,27 +13,38 @@ pub struct ChatMessage {
pub inserted_at: chrono::DateTime<chrono::Utc>,
}

#[derive(Clone, FromRow, Debug)]
pub struct NewChatMessage {
pub id: i64,
pub content: String,
pub chat_id: i64,
pub chat_thread_id: i64,
pub user_role: String,
}

pub async fn insert_new_message(
db_conn: &Pool<Sqlite>,
content: &String,
chat_id: i64,
chat_thread_id: i64,
user_role: &str,
) -> anyhow::Result<ChatMessage> {
) -> anyhow::Result<i64> {
let new_id: i64 = rand::thread_rng().gen_range(1..i64::MAX);

let chat_message = sqlx::query_as::<_, ChatMessage>(
"INSERT INTO chat_messages (id, content, chat_id, chat_thread_id, user_role) VALUES(?1, ?2, ?3, ?4, ?5) RETURNING *",
let _chat_message = sqlx::query_as!(
NewChatMessage,
r#"INSERT INTO chat_messages (id, content, chat_id, chat_thread_id, user_role) VALUES(?1, ?2, ?3, ?4, ?5)"#,
new_id,
content,
chat_id,
chat_thread_id,
user_role
)
.bind(new_id)
.bind(content)
.bind(chat_id)
.bind(chat_thread_id)
.bind(user_role)
.fetch_one(db_conn)
.await?;
.execute(db_conn)
.await
.context("Failed to create a chat message")?;

Ok(chat_message)
Ok(new_id)
}

pub async fn get_chat_thread_messages(
Expand All @@ -44,7 +56,11 @@ pub async fn get_chat_thread_messages(
)
.bind(chat_thread_id)
.fetch_all(db_conn)
.await?;
.await
.context(format!(
"Failed to ge tthe chat messages for id {}",
chat_thread_id
))?;

Ok(chat_messages)
}
2 changes: 0 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ mod llm;

#[tokio::main(flavor = "current_thread")]
async fn main() -> anyhow::Result<()> {
dotenvy::dotenv()?;

let config = config::create_config();

println!(
Expand Down

0 comments on commit 88f5d44

Please sign in to comment.