Skip to content

Commit

Permalink
Use task::block_in_place
Browse files Browse the repository at this point in the history
Use `task::block_in_place` with allows a number of `'static` lifetimes
to be removed.

This now works outside a task spawn due to
tokio-rs/tokio#2410 with the restriction that a
threaded runtime is used.

Fixes mehcode#9.

Signed-off-by: Joe Grund <jgrund@whamcloud.io>
  • Loading branch information
jgrund committed May 20, 2020
1 parent 44577ef commit 3a5ec82
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 34 deletions.
72 changes: 72 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
name: ci
on:
pull_request:
push:
branches:
- master

jobs:
check:
name: Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions-rs/toolchain@v1
with:
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: check
args: --locked

rustfmt:
name: Format
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions-rs/toolchain@v1
with:
toolchain: stable
override: true
- run: rustup component add rustfmt
- uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check

test:
name: Test Suite
runs-on: ubuntu-latest
services:
postgres:
image: postgres
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
ports:
- 5432:5432
# needed because the postgres container does not provide a healthcheck
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
steps:
- name: Checkout sources
uses: actions/checkout@v1

- name: Install PostgreSQL client
run: sudo apt-get -yqq install libpq-dev

- name: Install toolchain
uses: actions-rs/toolchain@v1
with:
toolchain: stable
override: true

- name: Run cargo test
uses: actions-rs/cargo@v1
with:
command: test
args: --locked
env:
POSTGRES_HOST: postgres
POSTGRES_PORT: ${{ job.services.postgres.ports[5432] }}
16 changes: 8 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ license = "MIT/Apache-2.0"
categories = ["asynchronous", "database"]

[dependencies]
async-trait = "0.1.21"
diesel = { version = "1.4.3", features = [ "r2d2" ] }
futures = "0.3.1"
r2d2 = "0.8.7"
tokio = { version = "0.2.4", features = [ "blocking" ] }
async-trait = "0.1.31"
diesel = { version = "1.4.4", default-features = false, features = [ "r2d2" ] }
futures = { version = "0.3.5", default-features = false }
r2d2 = "0.8.8"
tokio = { version = ">=0.2.21", default-features = false, features = [ "blocking", "rt-threaded" ] }

[dev-dependencies]
diesel = { version = "1.4.3", features = [ "postgres", "uuidv07" ] }
uuid = { version = "0.7.4", features = [ "v4" ] }
tokio = { version = "0.2.4", default-features = false, features = [ "full" ] }
diesel = { version = "1.4.4", default-features = false, features = [ "postgres", "uuidv07" ] }
uuid = { version = "0.8.1", features = [ "v4" ] }
tokio = { version = ">=0.2.21", default-features = false, features = [ "full" ] }
46 changes: 20 additions & 26 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,10 @@ where
async fn batch_execute_async(&self, query: &str) -> AsyncResult<()> {
let self_ = self.clone();
let query = query.to_string();
task::spawn_blocking(move || {
task::block_in_place(move || {
let conn = self_.get().map_err(AsyncError::Checkout)?;
conn.batch_execute(&query).map_err(AsyncError::Error)
})
.await
.expect("task has panicked")
}
}

Expand All @@ -89,13 +87,13 @@ where
{
async fn run<R, Func>(&self, f: Func) -> AsyncResult<R>
where
R: 'static + Send,
Func: 'static + FnOnce(&Conn) -> QueryResult<R> + Send;
R: Send,
Func: FnOnce(&Conn) -> QueryResult<R> + Send;

async fn transaction<R, Func>(&self, f: Func) -> AsyncResult<R>
where
R: 'static + Send,
Func: 'static + FnOnce(&Conn) -> QueryResult<R> + Send;
R: Send,
Func: FnOnce(&Conn) -> QueryResult<R> + Send;
}

#[async_trait]
Expand All @@ -106,31 +104,27 @@ where
#[inline]
async fn run<R, Func>(&self, f: Func) -> AsyncResult<R>
where
R: 'static + Send,
Func: 'static + FnOnce(&Conn) -> QueryResult<R> + Send,
R: Send,
Func: FnOnce(&Conn) -> QueryResult<R> + Send,
{
let self_ = self.clone();
task::spawn_blocking(move || {
task::block_in_place(move || {
let conn = self_.get().map_err(AsyncError::Checkout)?;
f(&*conn).map_err(AsyncError::Error)
})
.await
.expect("task has panicked")
}

#[inline]
async fn transaction<R, Func>(&self, f: Func) -> AsyncResult<R>
where
R: 'static + Send,
Func: 'static + FnOnce(&Conn) -> QueryResult<R> + Send,
R: Send,
Func: FnOnce(&Conn) -> QueryResult<R> + Send,
{
let self_ = self.clone();
task::spawn_blocking(move || {
task::block_in_place(move || {
let conn = self_.get().map_err(AsyncError::Checkout)?;
conn.transaction(|| f(&*conn)).map_err(AsyncError::Error)
})
.await
.expect("task has panicked")
}
}

Expand All @@ -145,30 +139,30 @@ where

async fn load_async<U>(self, asc: &AsyncConn) -> AsyncResult<Vec<U>>
where
U: 'static + Send,
U: Send,
Self: LoadQuery<Conn, U>;

async fn get_result_async<U>(self, asc: &AsyncConn) -> AsyncResult<U>
where
U: 'static + Send,
U: Send,
Self: LoadQuery<Conn, U>;

async fn get_results_async<U>(self, asc: &AsyncConn) -> AsyncResult<Vec<U>>
where
U: 'static + Send,
U: Send,
Self: LoadQuery<Conn, U>;

async fn first_async<U>(self, asc: &AsyncConn) -> AsyncResult<U>
where
U: 'static + Send,
U: Send,
Self: LimitDsl,
Limit<Self>: LoadQuery<Conn, U>;
}

#[async_trait]
impl<T, Conn> AsyncRunQueryDsl<Conn, Pool<ConnectionManager<Conn>>> for T
where
T: 'static + Send + RunQueryDsl<Conn>,
T: Send + RunQueryDsl<Conn>,
Conn: 'static + Connection,
{
async fn execute_async(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<usize>
Expand All @@ -180,31 +174,31 @@ where

async fn load_async<U>(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<Vec<U>>
where
U: 'static + Send,
U: Send,
Self: LoadQuery<Conn, U>,
{
asc.run(|conn| self.load(&*conn)).await
}

async fn get_result_async<U>(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<U>
where
U: 'static + Send,
U: Send,
Self: LoadQuery<Conn, U>,
{
asc.run(|conn| self.get_result(&*conn)).await
}

async fn get_results_async<U>(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<Vec<U>>
where
U: 'static + Send,
U: Send,
Self: LoadQuery<Conn, U>,
{
asc.run(|conn| self.get_results(&*conn)).await
}

async fn first_async<U>(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<U>
where
U: 'static + Send,
U: Send,
Self: LimitDsl,
Limit<Self>: LoadQuery<Conn, U>,
{
Expand Down
1 change: 1 addition & 0 deletions tests/create_users.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE table users (id uuid);
47 changes: 47 additions & 0 deletions tests/integration_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#[macro_use]
extern crate diesel;

use diesel::{
prelude::*,
r2d2::{ConnectionManager, Pool},
sql_query
};
use std::{env, error::Error};
use tokio_diesel::*;
use uuid::Uuid;

// Schema
table! {
users (id) {
id -> Uuid,
}
}




#[tokio::test(threaded_scheduler)]
async fn test_db_ops() -> Result<(), Box<dyn Error>> {
let hostname = env::var("POSTGRES_HOST").unwrap_or_else(|_| "localhost".into());

let manager =
ConnectionManager::<PgConnection>::new(&format!("postgres://postgres@{}/tokio_diesel__test", hostname));
let pool = Pool::builder().build(manager)?;

let _ = sql_query(include_str!("./create_users.sql")).execute_async(&pool).await;

// Add
println!("add a user");
diesel::insert_into(users::table)
.values(users::id.eq(Uuid::new_v4()))
.execute_async(&pool)
.await?;

// Count
let num_users: i64 = users::table.count().get_result_async(&pool).await?;
println!("now there are {:?} users", num_users);

assert!(num_users > 0);

Ok(())
}

0 comments on commit 3a5ec82

Please sign in to comment.