Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/gdrive_text_embedding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ def gdrive_text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope:
"""
credential_path = os.environ["GOOGLE_SERVICE_ACCOUNT_CREDENTIAL"]
root_folder_ids = os.environ["GOOGLE_DRIVE_ROOT_FOLDER_IDS"].split(",")

data_scope["documents"] = flow_builder.add_source(
cocoindex.sources.GoogleDrive(
service_account_credential_path=credential_path,
root_folder_ids=root_folder_ids),
root_folder_ids=root_folder_ids,
recent_changes_poll_interval=datetime.timedelta(seconds=10)),
refresh_options=cocoindex.SourceRefreshOptions(
refresh_interval=datetime.timedelta(minutes=1)))

Expand Down
2 changes: 2 additions & 0 deletions python/cocoindex/sources.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""All builtin sources."""
from . import op
import datetime

class LocalFile(op.SourceSpec):
"""Import data from local file system."""
Expand All @@ -26,3 +27,4 @@ class GoogleDrive(op.SourceSpec):
service_account_credential_path: str
root_folder_ids: list[str]
binary: bool = False
recent_changes_poll_interval: datetime.timedelta | None = None
50 changes: 26 additions & 24 deletions src/execution/live_updater.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,33 +85,35 @@ async fn update_source(
let mut futs: Vec<BoxFuture<'_, Result<()>>> = Vec::new();

// Deal with change streams.
if let (true, Some(change_stream)) = (options.live_mode, import_op.executor.change_stream()) {
let pool = pool.clone();
let source_update_stats = source_update_stats.clone();
futs.push(
async move {
let mut change_stream = change_stream;
while let Some(change) = change_stream.next().await {
source_context
.process_change(change, &pool, &source_update_stats)
.map(tokio::spawn);
if options.live_mode {
if let Some(change_stream) = import_op.executor.change_stream().await? {
let pool = pool.clone();
let source_update_stats = source_update_stats.clone();
futs.push(
async move {
let mut change_stream = change_stream;
while let Some(change) = change_stream.next().await {
source_context
.process_change(change, &pool, &source_update_stats)
.map(tokio::spawn);
}
Ok(())
}
Ok(())
}
.boxed(),
);
futs.push(
async move {
let mut interval = tokio::time::interval(REPORT_INTERVAL);
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
interval.tick().await;
loop {
.boxed(),
);
futs.push(
async move {
let mut interval = tokio::time::interval(REPORT_INTERVAL);
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
interval.tick().await;
report_stats();
loop {
interval.tick().await;
report_stats();
}
}
}
.boxed(),
);
.boxed(),
);
}
}

// The main update loop.
Expand Down
4 changes: 2 additions & 2 deletions src/ops/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ pub trait SourceExecutor: Send + Sync {
// Get the value for the given key.
async fn get_value(&self, key: &KeyValue) -> Result<Option<FieldValues>>;

fn change_stream<'a>(&'a self) -> Option<BoxStream<'a, SourceChange>> {
None
async fn change_stream(&self) -> Result<Option<BoxStream<'async_trait, SourceChange>>> {
Ok(None)
}
}

Expand Down
135 changes: 126 additions & 9 deletions src/ops/sources/google_drive.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
use std::{
collections::{HashMap, HashSet},
sync::{Arc, LazyLock},
};

use async_stream::try_stream;
use chrono::Duration;
use google_drive3::{
api::{File, Scope},
yup_oauth2::{read_service_account_key, ServiceAccountAuthenticator},
Expand All @@ -12,7 +7,6 @@ use google_drive3::{
use http_body_util::BodyExt;
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::connect::HttpConnector;
use log::{trace, warn};

use crate::base::field_attrs;
use crate::ops::sdk::*;
Expand Down Expand Up @@ -75,12 +69,14 @@ pub struct Spec {
service_account_credential_path: String,
binary: bool,
root_folder_ids: Vec<String>,
recent_changes_poll_interval: Option<std::time::Duration>,
}

struct Executor {
drive_hub: DriveHub<HttpsConnector<HttpConnector>>,
binary: bool,
root_folder_ids: Vec<Arc<str>>,
root_folder_ids: IndexSet<Arc<str>>,
recent_updates_poll_interval: Option<std::time::Duration>,
}

impl Executor {
Expand All @@ -106,6 +102,7 @@ impl Executor {
drive_hub,
binary: spec.binary,
root_folder_ids: spec.root_folder_ids.into_iter().map(Arc::from).collect(),
recent_updates_poll_interval: spec.recent_changes_poll_interval,
})
}
}
Expand All @@ -122,6 +119,7 @@ fn escape_string(s: &str) -> String {
escaped
}

const CUTOFF_TIME_BUFFER: Duration = Duration::seconds(1);
impl Executor {
fn visit_file(
&self,
Expand Down Expand Up @@ -151,7 +149,6 @@ impl Executor {
ordinal: file.modified_time.map(|t| t.try_into()).transpose()?,
})
} else {
trace!("Skipping file with unsupported mime type: id={id}, mime_type={mime_type}, name={:?}", file.name);
None
};
Ok(result)
Expand All @@ -175,9 +172,101 @@ impl Executor {
list_call = list_call.page_token(next_page_token);
}
let (_, files) = list_call.doit().await?;
*next_page_token = files.next_page_token;
let file_iter = files.files.into_iter().flat_map(|file| file.into_iter());
Ok(file_iter)
}

/// Picks the cutoff timestamp to use for the next poll.
///
/// The result is never later than `list_start_time - CUTOFF_TIME_BUFFER`. When a most recent
/// modification time is known and is earlier than that bound, it is used instead.
fn make_cutoff_time(
    most_recent_modified_time: Option<DateTime<Utc>>,
    list_start_time: DateTime<Utc>,
) -> DateTime<Utc> {
    // Latest instant we are willing to use as a cutoff for this listing.
    let latest_safe = list_start_time - CUTOFF_TIME_BUFFER;
    match most_recent_modified_time {
        Some(newest) if newest < latest_safe => newest,
        _ => latest_safe,
    }
}

/// Lists files modified after `cutoff_time` and turns the covered ones into change events.
///
/// Files are requested newest-first (`modifiedTime desc`), so the scan stops at the first
/// file whose modification time is at or before `cutoff_time`. Every newer file accepted by
/// `is_file_covered` is reported as an `Upsert` carrying no value.
///
/// On success `cutoff_time` is replaced with the result of `make_cutoff_time`. If any request
/// fails, the error is returned and `cutoff_time` is left unchanged.
///
/// # Errors
///
/// Returns an error when a Drive request fails, when a timestamp cannot be converted into an
/// ordinal, or when a listed file has no id.
async fn get_recent_updates(
    &self,
    cutoff_time: &mut DateTime<Utc>,
) -> Result<Vec<SourceChange>> {
    // The first page is small; later pages (requested only when the first one did not reach
    // the cutoff) are larger.
    let mut page_size: i32 = 10;
    let mut next_page_token: Option<String> = None;
    let mut changes = Vec::new();
    let mut most_recent_modified_time = None;
    let start_time = Utc::now();
    'paginate: loop {
        let mut list_call = self
            .drive_hub
            .files()
            .list()
            .add_scope(Scope::Readonly)
            // `nextPageToken` must be part of the field mask: with a partial-response mask
            // the API omits every top-level field that is not listed, which would make
            // pagination stop after the first page.
            .param(
                "fields",
                "nextPageToken,files(id,modifiedTime,parents,trashed)",
            )
            .order_by("modifiedTime desc")
            .page_size(page_size);
        if let Some(token) = next_page_token.as_deref() {
            list_call = list_call.page_token(token);
        }
        let (_, files) = list_call.doit().await?;
        for file in files.files.into_iter().flat_map(|files| files.into_iter()) {
            // A missing modification time falls back to the default timestamp.
            let modified_time = file.modified_time.unwrap_or_default();
            // Results are ordered newest-first, so the first file seen is the most recent.
            if most_recent_modified_time.is_none() {
                most_recent_modified_time = Some(modified_time);
            }
            // Everything from here on is at or before the cutoff; stop scanning.
            if modified_time <= *cutoff_time {
                break 'paginate;
            }
            if self.is_file_covered(&file).await? {
                changes.push(SourceChange {
                    ordinal: Some(modified_time.try_into()?),
                    key: KeyValue::Str(Arc::from(
                        file.id.ok_or_else(|| anyhow!("File has no id"))?,
                    )),
                    value: SourceValueChange::Upsert(None),
                });
            }
        }
        next_page_token = files.next_page_token;
        if next_page_token.is_none() {
            break;
        }
        // Request larger pages from the second page onwards.
        page_size = 100;
    }
    *cutoff_time = Self::make_cutoff_time(most_recent_modified_time, start_time);
    Ok(changes)
}

/// Reports whether `file` is one of the configured root folders or sits beneath one.
///
/// Trashed files are never covered. Otherwise the chain of first parents is followed upwards,
/// one Drive `get` request per step, until a configured root folder id is reached (`true`) or
/// an entry has no parent (`false`).
///
/// # Errors
///
/// Returns an error when `file` has no id or when a Drive request fails.
async fn is_file_covered(&self, file: &File) -> Result<bool> {
    if matches!(file.trashed, Some(true)) {
        return Ok(false);
    }
    let start_id = file.id.as_ref().ok_or_else(|| anyhow!("File has no id"))?;
    // Borrow the id we were given; ids fetched along the way are owned.
    let mut current = Cow::Borrowed(start_id);
    loop {
        if self.root_folder_ids.contains(current.as_str()) {
            return Ok(true);
        }
        let (_, fetched) = self
            .drive_hub
            .files()
            .get(&current)
            .add_scope(Scope::Readonly)
            .param("fields", "parents")
            .doit()
            .await?;
        // Only the first listed parent is followed.
        match fetched
            .parents
            .and_then(|parents| parents.into_iter().next())
        {
            Some(parent_id) => current = Cow::Owned(parent_id),
            None => return Ok(false),
        }
    }
}
}

trait ResultExt<T> {
Expand Down Expand Up @@ -311,6 +400,34 @@ impl SourceExecutor for Executor {
};
Ok(value)
}

async fn change_stream(&self) -> Result<Option<BoxStream<'async_trait, SourceChange>>> {
let poll_interval = if let Some(poll_interval) = self.recent_updates_poll_interval {
poll_interval
} else {
return Ok(None);
};
let mut cutoff_time = Utc::now() - CUTOFF_TIME_BUFFER;
let mut interval = tokio::time::interval(poll_interval);
interval.tick().await;
let stream = stream! {
loop {
interval.tick().await;
let changes = self.get_recent_updates(&mut cutoff_time).await;
match changes {
Ok(changes) => {
for change in changes {
yield change;
}
}
Err(e) => {
error!("Error getting recent updates: {e}");
}
}
}
};
Ok(Some(stream.boxed()))
}
}

pub struct Factory;
Expand Down
5 changes: 4 additions & 1 deletion src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

pub(crate) use anyhow::Result;
pub(crate) use async_trait::async_trait;
pub(crate) use chrono::{DateTime, Utc};
pub(crate) use futures::{future::BoxFuture, prelude::*, stream::BoxStream};
pub(crate) use futures::{FutureExt, StreamExt};
pub(crate) use indexmap::{IndexMap, IndexSet};
pub(crate) use itertools::Itertools;
pub(crate) use serde::{Deserialize, Serialize};
pub(crate) use std::borrow::Cow;
pub(crate) use std::collections::{BTreeMap, HashMap};
pub(crate) use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
pub(crate) use std::sync::{Arc, LazyLock, Mutex, OnceLock, RwLock, Weak};

pub(crate) use crate::base::{schema, spec, value};
Expand All @@ -20,4 +22,5 @@ pub(crate) use crate::service::error::ApiError;
pub(crate) use crate::{api_bail, api_error};

pub(crate) use anyhow::{anyhow, bail};
pub(crate) use async_stream::{stream, try_stream};
pub(crate) use log::{debug, error, info, trace, warn};