### Search your Google Drive knowledge base with fully local processing.

In [None]:
%pip install -qU "langchain==0.3.27" "langchain-core<1.0.0,>=0.3.78" "langchain-text-splitters<1.0.0,>=0.3.9" langchain_ollama langchain_chroma langchain_community google-auth google-auth-oauthlib google-auth-httplib2 google-api-python-client python-docx


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
import json
import logging
import os
import re
import sys
import hashlib
from pathlib import Path
from enum import StrEnum
from typing import Iterable, Optional

import gradio as gr
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain.storage import InMemoryStore
from langchain_chroma import Chroma
from langchain_community.document_loaders import TextLoader
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
from google.auth.transport.requests import Request
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload
from googleapiclient.errors import HttpError
from docx import Document as DocxDocument

In [None]:
# Application-wide logger, set to DEBUG so download and retrieval diagnostics are emitted.
logger = logging.getLogger('drive_sage')
logger.setLevel(logging.DEBUG)

# Only attach the stdout handler when the logger has none yet, so executing this
# code again does not add a second handler (which would duplicate every line).
if not logger.handlers:
    handler = logging.StreamHandler(sys.stdout)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

In [None]:
# OAuth scope requested from Google: read-only access to Drive.
SCOPES = ['https://www.googleapis.com/auth/drive.readonly']
# All application state lives under a hidden folder in the current working directory.
APP_ROOT = Path.cwd()
DATA_DIR = APP_ROOT / '.drive_sage'
DOWNLOAD_DIR = DATA_DIR / 'downloads'  # local copies of downloaded Drive files
VECTORSTORE_DIR = DATA_DIR / 'chroma'  # created below; the Chroma store in this file is built without a persist directory
TOKEN_PATH = DATA_DIR / 'token.json'  # cached OAuth credentials
MANIFEST_PATH = DATA_DIR / 'manifest.json'  # per-file sync bookkeeping
# OAuth client secret downloaded from Google Cloud Console; the file name is project-specific.
CLIENT_SECRET_FILE = APP_ROOT / 'client_secret_202216035337-4qson0c08g71u8uuihv6v46arv64nhvg.apps.googleusercontent.com.json'

# Make sure the working directories exist before anything tries to write to them.
for path in (DATA_DIR, DOWNLOAD_DIR, VECTORSTORE_DIR):
    path.mkdir(parents=True, exist_ok=True)

# Selectable file types. Each entry carries the UI label, the local extension(s)
# and the Drive MIME types used to build the listing query.
# NOTE(review): '.doc' can be selected here, but load_documents only handles
# '.txt', '.md' and '.docx', so '.doc' files will fail at indexing time.
FILE_TYPE_OPTIONS = {
    'txt': {
        'label': '.txt - Plain text',
        'extensions': ['.txt'],
        'mime_types': ['text/plain'],
    },
    'md': {
        'label': '.md - Markdown',
        'extensions': ['.md'],
        'mime_types': ['text/markdown', 'text/plain'],
    },
    'docx': {
        'label': '.docx - Word (OpenXML)',
        'extensions': ['.docx'],
        'mime_types': ['application/vnd.openxmlformats-officedocument.wordprocessingml.document'],
    },
    'doc': {
        'label': '.doc - Word 97-2003',
        'extensions': ['.doc'],
        'mime_types': ['application/msword', 'application/vnd.ms-word.document.macroenabled.12'],
    },
    'gdoc': {
        'label': 'Google Docs (exported)',
        'extensions': ['.docx'],
        'mime_types': ['application/vnd.google-apps.document'],
    },
}

# Reverse lookup from UI label back to the FILE_TYPE_OPTIONS key.
FILE_TYPE_LABEL_TO_KEY = {config['label']: key for key, config in FILE_TYPE_OPTIONS.items()}
# Types (and their labels) that are pre-selected when the caller does not choose any.
DEFAULT_FILE_TYPE_KEYS = ['txt', 'md', 'docx', 'doc', 'gdoc']
DEFAULT_FILE_TYPE_LABELS = [FILE_TYPE_OPTIONS[key]['label'] for key in DEFAULT_FILE_TYPE_KEYS]

# MIME type -> first configured extension of the owning file type.
# A MIME type listed under several file types keeps the extension of the last
# one iterated, so 'text/plain' ends up mapped to '.md' (listed under both
# 'txt' and 'md').
MIME_TYPE_TO_EXTENSION = {}
for key, config in FILE_TYPE_OPTIONS.items():
    extension = config['extensions'][0]
    for mime in config['mime_types']:
        MIME_TYPE_TO_EXTENSION[mime] = extension

# Native Google formats cannot be fetched with get_media; they are exported.
# Maps the native MIME type to (export MIME type, local file extension).
GOOGLE_EXPORT_FORMATS = {
    'application/vnd.google-apps.document': (
        'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
        '.docx'
    ),
}

# Retrieval hits whose distance score is above this value are discarded
# (see retrieve_context). Overridable through the environment.
SIMILARITY_DISTANCE_MAX = float(os.getenv('DRIVE_SAGE_DISTANCE_MAX', '1.2'))
# Upper bound on the characters of each chunk copied into the prompt context.
MAX_CONTEXT_SNIPPET_CHARS = 1200
# Read size, in bytes, used when hashing files.
HASH_BLOCK_SIZE = 65536
# Ollama model names for embeddings and chat; both overridable through the environment.
EMBED_MODEL = os.getenv('DRIVE_SAGE_EMBED_MODEL', 'nomic-embed-text')
CHAT_MODEL = os.getenv('DRIVE_SAGE_CHAT_MODEL', 'llama3.1:latest')

# Stylesheet passed to gr.Blocks: sizes the chat column (elem_id 'chat-column')
# to 80% of the viewport height and stretches the chatbot (elem_id 'chat-output')
# to fill it.
CUSTOM_CSS = """
#chat-column {
    height: 80vh;
}
#chat-column > div {
    height: 100%;
}
#chat-column .gradio-chatbot,
#chat-column .gradio-chat-interface,
#chat-column .gradio-chatinterface {
    height: 100%;
}
#chat-output {
    height: 100%;
}
#chat-output .overflow-y-auto {
    max-height: 100% !important;
}
#chat-output .h-full {
    height: 100% !important;
}
"""

In [None]:
def build_drive_service():
    """Return an authorized Google Drive v3 client.

    Credentials are resolved in this order:

    1. the token cached at ``TOKEN_PATH`` (deleted if it cannot be parsed);
    2. a refresh of that token when it is expired and has a refresh token;
    3. an interactive OAuth flow using ``CLIENT_SECRET_FILE``.

    Whenever step 2 or 3 produced the credentials they are written back to
    ``TOKEN_PATH``.

    Raises:
        FileNotFoundError: the interactive flow is required but
            ``CLIENT_SECRET_FILE`` does not exist.
    """
    creds = None
    if TOKEN_PATH.exists():
        try:
            creds = Credentials.from_authorized_user_file(str(TOKEN_PATH), SCOPES)
        except Exception as exc:
            # An unreadable token is useless; drop it and fall through to re-auth.
            logger.warning('Failed to load cached credentials: %s', exc)
            TOKEN_PATH.unlink(missing_ok=True)
            creds = None

    # Fast path: the cached token is still good, nothing to refresh or store.
    if creds and creds.valid:
        return build('drive', 'v3', credentials=creds)

    if creds and creds.expired and creds.refresh_token:
        try:
            creds.refresh(Request())
        except Exception as exc:
            logger.warning('Refreshing credentials failed: %s', exc)
            creds = None

    if not creds or not creds.valid:
        if not CLIENT_SECRET_FILE.exists():
            # Report the path that is actually checked; the file name is
            # project-specific, so a generic name would send users astray.
            raise FileNotFoundError(
                f'OAuth client secret file not found at {CLIENT_SECRET_FILE}. '
                'Download it from Google Cloud Console and place it next to this notebook.'
            )
        flow = InstalledAppFlow.from_client_secrets_file(str(CLIENT_SECRET_FILE), SCOPES)
        creds = flow.run_local_server(port=0)

    # Persist refreshed or newly granted credentials for the next run.
    with TOKEN_PATH.open('w', encoding='utf-8') as token_file:
        token_file.write(creds.to_json())

    return build('drive', 'v3', credentials=creds)

In [None]:
def load_manifest() -> dict:
    """Load the sync manifest from ``MANIFEST_PATH``.

    Returns:
        A mapping of file id to entry dict. Entries stored as bare values
        are upgraded to ``{'modified': str(value)}``. An empty dict is
        returned when the manifest is missing, is not valid UTF-8 JSON, or
        does not contain a JSON object at the top level.
    """
    if not MANIFEST_PATH.exists():
        return {}

    try:
        with MANIFEST_PATH.open('r', encoding='utf-8') as fp:
            raw = json.load(fp)
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors, so
        # malformed JSON and invalid byte sequences are handled the same way.
        logger.warning('Manifest file is corrupted; resetting cache.')
        return {}

    if not isinstance(raw, dict):
        logger.warning('Manifest file does not contain a JSON object; resetting cache.')
        return {}

    return {
        file_id: entry if isinstance(entry, dict) else {'modified': str(entry)}
        for file_id, entry in raw.items()
    }

def save_manifest(manifest: dict) -> None:
    """Write *manifest* to ``MANIFEST_PATH`` as indented JSON.

    The data is written to a sibling temporary file and then moved over the
    real manifest, so a failure while serialising or writing leaves the
    previous manifest intact instead of a truncated file.
    """
    tmp_path = MANIFEST_PATH.with_name(MANIFEST_PATH.name + '.tmp')
    with tmp_path.open('w', encoding='utf-8') as fp:
        json.dump(manifest, fp, indent=2)
    tmp_path.replace(MANIFEST_PATH)

In [None]:
class Metadata(StrEnum):
    """Names of the metadata fields attached to loaded documents and their chunks."""

    # Identifier of a document; chunks get '<parent id>::chunk-NNNN'.
    ID = 'id'
    # Path of the local file the content was read from.
    SOURCE = 'source'
    # On a chunk: identifier of the document it was split from.
    PARENT_ID = 'parent_id'
    # File suffix that selected the loader (e.g. '.md').
    FILE_TYPE = 'file_type'
    # Display name of the document.
    TITLE = 'title'
    # Version marker supplied by the caller when indexing.
    MODIFIED = 'modified'

def metadata_key(key: Metadata) -> str:
    """Return the string value of *key* for use as a metadata dict key."""
    field_name = key.value
    return field_name

# Embedding model served by Ollama; the name comes from EMBED_MODEL.
embeddings = OllamaEmbeddings(model=EMBED_MODEL)

# No persist directory is passed, so the collection is not tied to VECTORSTORE_DIR.
try:
    vectorstore = Chroma(
        collection_name='drive_sage',
        embedding_function=embeddings,
    )
except Exception as exc:
    logger.exception('Failed to initialize in-memory Chroma vector store')
    raise RuntimeError('Unable to initialize Chroma vector store without persistence.') from exc

# Holds the full parent documents keyed by document id (chunks go to the vector store).
docstore = InMemoryStore()
# Chat model served by Ollama; the name comes from CHAT_MODEL.
model = ChatOllama(model=CHAT_MODEL)

# Splitter applied to every document: 1000-character chunks with 150 characters
# of overlap, trying paragraph, line and word separators in that order.
DEFAULT_TEXT_SPLITTER = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=150,
    separators=['\n\n', '\n', ' ', '']
)
# Markdown is first cut at level 1-3 headers (headers are kept in the text)
# before the character splitter runs; see chunk_documents.
MARKDOWN_HEADERS = [('#', 'Header 1'), ('##', 'Header 2'), ('###', 'Header 3')]
MARKDOWN_SPLITTER = MarkdownHeaderTextSplitter(headers_to_split_on=MARKDOWN_HEADERS, strip_headers=False)

In [None]:
def safe_filename(name: str, max_length: int = 120) -> str:
    """Turn *name* into a conservative file name.

    Every character outside ``A-Z a-z 0-9 . _ -`` becomes ``_``, leading and
    trailing dots/underscores are removed, an empty result falls back to
    ``'untitled'``, and the outcome is cut to *max_length* characters.
    """
    replaced = re.sub(r'[^A-Za-z0-9._-]', '_', name)
    trimmed = replaced.strip('._')
    if not trimmed:
        trimmed = 'untitled'
    return trimmed[:max_length]

def determine_extension(metadata: dict) -> str:
    """Pick the local file extension for a Drive file described by *metadata*.

    Resolution order:

    1. Native Google formats listed in ``GOOGLE_EXPORT_FORMATS`` always use
       the export extension. Their titles are free text, so a dot in the
       title (for example a date) is not a real file suffix.
    2. Otherwise the lower-cased suffix of ``metadata['name']``, if any.
    3. Otherwise the extension mapped to the MIME type, defaulting to
       ``'.txt'``.
    """
    mime_type = metadata.get('mimeType', '')
    if mime_type in GOOGLE_EXPORT_FORMATS:
        return GOOGLE_EXPORT_FORMATS[mime_type][1]

    name = metadata.get('name')
    name_suffix = Path(name).suffix if name else ''
    if name_suffix:
        return name_suffix.lower()

    return MIME_TYPE_TO_EXTENSION.get(mime_type, '.txt')

def cached_file_path(metadata: dict) -> Path:
    """Build the download location for a Drive file.

    The file name is ``<sanitised name stem>_<file id><extension>`` inside
    ``DOWNLOAD_DIR``; a missing id becomes ``'unknown'`` and a missing name
    falls back to the id.
    """
    drive_id = metadata.get('id', 'unknown')
    suffix = determine_extension(metadata)
    raw_stem = Path(metadata.get('name', drive_id)).stem
    file_name = f'{safe_filename(raw_stem)}_{drive_id}{suffix}'
    return DOWNLOAD_DIR / file_name

def hash_file(path: Path) -> str:
    """Return the hexadecimal SHA-1 digest of the file at *path*.

    The file is consumed in ``HASH_BLOCK_SIZE``-byte reads so it never has
    to fit in memory at once.
    """
    hasher = hashlib.sha1()
    with path.open('rb') as stream:
        # iter() with a sentinel stops at the first empty read, i.e. end of file.
        for block in iter(lambda: stream.read(HASH_BLOCK_SIZE), b''):
            hasher.update(block)
    return hasher.hexdigest()

def manifest_version(entry: dict | str | None) -> Optional[str]:
    if entry is None:
        return None
    if isinstance(entry, str):
        return entry
    if isinstance(entry, dict):
        return entry.get('modified')
    return None

def update_manifest_entry(manifest: dict, *, file_id: str, modified: str, path: Path, mime_type: str, name: str) -> None:
    """Insert or overwrite the manifest record for *file_id* in place.

    The record stores the version marker, the local path as a string, the
    MIME type, the display name and the lower-cased suffix of *path*.
    """
    record = dict(
        modified=modified,
        path=str(path),
        mimeType=mime_type,
        name=name,
        file_type=Path(path).suffix.lower(),
    )
    manifest[file_id] = record

In [None]:
def list_drive_text_files(service, folder_id: Optional[str], allowed_mime_types: list[str], limit: Optional[int]) -> list[dict]:
    """List non-trashed Drive files of the given MIME types, newest first.

    Args:
        service: Drive v3 client (as returned by ``build_drive_service``).
        folder_id: When set, only direct children of this folder are listed;
            otherwise no parent filter is applied.
        allowed_mime_types: MIME types to match. An empty list means every
            MIME type known to ``MIME_TYPE_TO_EXTENSION``.
        limit: Maximum number of files to return; ``None`` or ``0`` means
            no limit.

    Returns:
        File resources with ``id``, ``name``, ``mimeType`` and
        ``modifiedTime`` fields.

    Raises:
        RuntimeError: the Drive API reported an ``HttpError``.
    """
    mime_types = allowed_mime_types or list(MIME_TYPE_TO_EXTENSION.keys())
    mime_clause = ' or '.join(f"mimeType = '{mime}'" for mime in mime_types)
    query_parts = ['trashed = false', f'({mime_clause})']
    if folder_id:
        # folder_id comes from user input: escape backslashes and single
        # quotes so it cannot terminate the quoted literal in the query.
        escaped_folder = folder_id.replace('\\', '\\\\').replace("'", "\\'")
        query_parts.append(f"'{escaped_folder}' in parents")
    query = ' and '.join(query_parts)

    files: list[dict] = []
    page_token: Optional[str] = None

    while True:
        # Never ask for more than is still needed to reach the limit.
        page_size = min(100, limit - len(files)) if limit else 100
        if page_size <= 0:
            break
        try:
            response = service.files().list(
                q=query,
                spaces='drive',
                fields='nextPageToken, files(id, name, mimeType, modifiedTime)',
                orderBy='modifiedTime desc',
                pageToken=page_token,
                pageSize=page_size,
            ).execute()
        except HttpError as exc:
            raise RuntimeError(f'Google Drive API error: {exc}') from exc

        files.extend(response.get('files', []))
        if limit and len(files) >= limit:
            return files[:limit]
        page_token = response.get('nextPageToken')
        if not page_token:
            break
    return files

def download_drive_file(service, metadata: dict, manifest: dict) -> Path:
    """Download (or export) one Drive file into the cache and record it.

    Native Google formats listed in ``GOOGLE_EXPORT_FORMATS`` are exported
    and stored with the export extension; everything else is fetched as-is.
    After the transfer completes the manifest entry for the file is updated
    in place.

    Returns:
        The path the file was written to.
    """
    file_id = metadata['id']
    mime_type = metadata.get('mimeType', '')
    target = cached_file_path(metadata)

    export_format = GOOGLE_EXPORT_FORMATS.get(mime_type)
    if export_format is None:
        request = service.files().get_media(fileId=file_id)
    else:
        export_mime, export_suffix = export_format
        # The exported payload decides the suffix, not the document title.
        if target.suffix.lower() != export_suffix:
            target = target.with_suffix(export_suffix)
        request = service.files().export_media(fileId=file_id, mimeType=export_mime)

    logger.debug('Downloading %s (%s) -> %s', metadata.get('name', file_id), file_id, target)
    with target.open('wb') as sink:
        downloader = MediaIoBaseDownload(sink, request)
        while True:
            status, finished = downloader.next_chunk()
            if status:
                logger.debug('Download progress %.0f%%', status.progress() * 100)
            if finished:
                break

    update_manifest_entry(
        manifest,
        file_id=file_id,
        modified=metadata.get('modifiedTime', ''),
        path=target,
        mime_type=mime_type,
        name=metadata.get('name', target.name),
    )
    return target

In [None]:
def extract_docx_text(path: Path) -> str:
    """Return the text of the non-blank paragraphs of a .docx file, one per line."""
    parsed = DocxDocument(str(path))
    stripped_paragraphs = (paragraph.text.strip() for paragraph in parsed.paragraphs)
    return '\n'.join(text for text in stripped_paragraphs if text)

def load_documents(
    path: Path,
    *,
    source_id: Optional[str] = None,
    file_type: Optional[str] = None,
    modified: Optional[str] = None,
    display_name: Optional[str] = None,
) -> list[Document]:
    """Read one local file into LangChain documents with normalised metadata.

    Args:
        path: File to read.
        source_id: Stable identifier stored under the ``id`` metadata key.
        file_type: Suffix that selects the loader (``'.txt'``, ``'.md'`` or
            ``'.docx'``); a missing leading dot is tolerated. Defaults to
            the suffix of *path*, then ``'.txt'``.
        modified: Version marker stored under the ``modified`` metadata key.
        display_name: Title stored in metadata; defaults to the file name.

    Returns:
        Documents with stripped, non-empty content. May be empty.

    Raises:
        ValueError: the file type is unsupported or the file cannot be
            read or decoded as UTF-8 text.
    """
    suffix = (file_type or path.suffix or '.txt').lower()
    if not suffix.startswith('.'):
        suffix = f'.{suffix}'

    try:
        if suffix in {'.txt', '.md'}:
            documents = TextLoader(str(path), encoding='utf-8').load()
        elif suffix == '.docx':
            documents = [Document(page_content=extract_docx_text(path), metadata={'source': str(path)})]
        else:
            raise ValueError(f'Unsupported file type: {suffix}')
    except (UnicodeDecodeError, RuntimeError) as exc:
        # TextLoader reports read/decode failures as RuntimeError with the
        # original error attached as the cause; surface that cause to the user.
        detail = exc.__cause__ or exc
        raise ValueError(f'Failed to read {path}: {detail}') from exc

    base_metadata = {
        metadata_key(Metadata.SOURCE): str(path),
        metadata_key(Metadata.FILE_TYPE): suffix,
        metadata_key(Metadata.TITLE): display_name or path.name,
    }
    if source_id:
        base_metadata[metadata_key(Metadata.ID)] = source_id
    if modified:
        base_metadata[metadata_key(Metadata.MODIFIED)] = modified

    cleaned: list[Document] = []
    for doc in documents:
        content = doc.page_content.strip()
        if not content:
            continue
        # Values computed here take precedence over whatever the loader set.
        doc.page_content = content
        doc.metadata = {**doc.metadata, **base_metadata}
        cleaned.append(doc)
    return cleaned

def preprocess(documents: Iterable[Document]) -> list[Document]:
    """Return the documents that have non-empty ``page_content``, in order."""
    kept: list[Document] = []
    for document in documents:
        if document.page_content:
            kept.append(document)
    return kept

def chunk_documents(doc: Document) -> list[Document]:
    """Split one document into chunks tagged with parent and chunk ids.

    Markdown documents (``file_type`` of ``'.md'``) are first split on
    headers, then every piece goes through the default character splitter.
    Each chunk receives ``parent_id`` (the document id) and an ``id`` of the
    form ``<parent id>::chunk-NNNN``; ``source`` and ``title`` are copied
    from the parent when the chunk lacks them.

    Raises:
        ValueError: the document has no ``id`` in its metadata.
    """
    id_key = metadata_key(Metadata.ID)
    parent_id = doc.metadata.get(id_key)
    if not parent_id:
        raise ValueError('Document is missing a stable identifier for chunking.')

    is_markdown = doc.metadata.get(metadata_key(Metadata.FILE_TYPE)) == '.md'
    if not is_markdown:
        seed_docs = [doc]
    else:
        seed_docs = []
        for section in MARKDOWN_SPLITTER.split_text(doc.page_content):
            # Header metadata from the section is layered over the parent's metadata.
            combined = {**doc.metadata, **section.metadata}
            seed_docs.append(Document(page_content=section.page_content, metadata=combined))

    chunks = DEFAULT_TEXT_SPLITTER.split_documents(seed_docs)
    position = 0
    for chunk in chunks:
        chunk.metadata[metadata_key(Metadata.PARENT_ID)] = parent_id
        chunk.metadata[id_key] = f'{parent_id}::chunk-{position:04d}'
        for inherited in (Metadata.SOURCE, Metadata.TITLE):
            field = metadata_key(inherited)
            chunk.metadata.setdefault(field, doc.metadata.get(field))
        position += 1
    return chunks

In [None]:
def sync_drive_and_index(folder_id=None, selected_types=None, file_limit=None, _state: bool = False, progress=gr.Progress(track_tqdm=False)):
    """Sync matching Google Drive files into the local cache and index them.

    Generator intended as a Gradio event handler. Every ``yield`` is a
    ``(log_text, sync_ready)`` pair: ``log_text`` is the accumulated log so
    far and ``sync_ready`` is ``True`` only on the final message of a run
    that completed (including a run that found no matching files).

    Args:
        folder_id: Drive folder to restrict the listing to; blank means no
            folder filter.
        selected_types: UI labels or ``FILE_TYPE_OPTIONS`` keys to sync;
            ``None`` selects the defaults.
        file_limit: Maximum number of files to consider; ``None``, ``''``
            or ``0`` means no limit. Invalid values are reported and ignored.
        _state: Current value of the sync state input; not used.
        progress: Gradio progress tracker.

    Files whose cached copy matches the remote version are not downloaded
    again. They are, however, indexed from the cache when the in-memory
    docstore does not hold them yet: the manifest and downloads survive a
    restart while the index does not, so skipping them outright would leave
    the index empty while reporting the Drive as synced.
    """
    folder = (folder_id or '').strip() or None

    selections = selected_types if selected_types is not None else DEFAULT_FILE_TYPE_LABELS
    if not isinstance(selections, (list, tuple)):
        selections = [selections]

    # Accept UI labels as well as raw keys; unknown entries are ignored.
    chosen_keys: list[str] = []
    for item in selections:
        key = FILE_TYPE_LABEL_TO_KEY.get(item, item)
        if key in FILE_TYPE_OPTIONS:
            chosen_keys.append(key)

    if not chosen_keys:
        yield 'Select at least one file type before syncing.', False
        return

    allowed_mime_types = sorted({mime for key in chosen_keys for mime in FILE_TYPE_OPTIONS[key]['mime_types']})

    limit: Optional[int] = None
    limit_warning: Optional[str] = None
    if file_limit not in (None, '', 0):
        try:
            parsed_limit = int(file_limit)
            if parsed_limit <= 0:
                raise ValueError
            limit = parsed_limit
        except (TypeError, ValueError):
            limit_warning = 'File limit must be a positive integer. Syncing all matching files instead.'

    log_lines: list[str] = []

    def push(message: str) -> str:
        """Append *message* to the run log and return the whole log."""
        log_lines.append(message)
        return '\n'.join(log_lines)

    if limit_warning:
        logger.warning(limit_warning)
        yield push(limit_warning), False

    progress(0, 'Authorizing Google Drive access...')
    yield push('Authorizing Google Drive access...'), False

    try:
        service = build_drive_service()
    except FileNotFoundError as exc:
        error_msg = f'Error: {exc}'
        logger.error(error_msg)
        yield push(error_msg), False
        return
    except Exception as exc:
        logger.exception('Drive authorization failed')
        error_msg = f'Error authenticating with Google Drive: {exc}'
        yield push(error_msg), False
        return

    list_message = 'Listing documents' + (f' (limit {limit})' if limit else '') + '...'
    progress(0, list_message)
    yield push(list_message), False

    try:
        files = list_drive_text_files(service, folder, allowed_mime_types, limit)
    except Exception as exc:
        logger.exception('Listing Drive files failed')
        error_msg = f'Error listing Google Drive files: {exc}'
        yield push(error_msg), False
        return

    total = len(files)
    if total == 0:
        info = 'No documents matching the selected types were found in Google Drive.'
        yield push(info), True
        return

    manifest = load_manifest()
    downloaded_count = 0
    reindexed_count = 0

    for index, metadata in enumerate(files, start=1):
        file_id = metadata['id']
        name = metadata.get('name', file_id)
        remote_version = metadata.get('modifiedTime', '')
        manifest_entry = manifest.get(file_id)
        cache_path = cached_file_path(metadata)
        # Prefer the path recorded at download time over the recomputed one.
        if isinstance(manifest_entry, dict) and manifest_entry.get('path'):
            cache_path = Path(manifest_entry['path'])
        cached_version = manifest_version(manifest_entry)

        if cached_version == remote_version and cache_path.exists():
            if docstore.mget([file_id])[0] is not None:
                message = f"{index}/{total} Skipping cached file: {name} -> {cache_path}"
                progress(index / total, message)
                yield push(message), False
                continue

            # Cached on disk but absent from the in-memory index: index the
            # local copy instead of downloading it again.
            reindex_message = f"{index}/{total} Re-indexing cached file: {name} -> {cache_path}"
            progress(index / total, reindex_message)
            yield push(reindex_message), False
            try:
                index_document(
                    cache_path,
                    source_id=file_id,
                    file_type=cache_path.suffix,
                    modified=remote_version,
                    display_name=name,
                    manifest=manifest,
                )
                reindexed_count += 1
            except Exception as exc:
                error_message = f"{index}/{total} Failed to index cached file {name}: {exc}"
                logger.exception(error_message)
                yield push(error_message), False
            continue

        download_message = f"{index}/{total} Downloading {name} -> {cache_path}"
        progress(max((index - 0.5) / total, 0), download_message)
        yield push(download_message), False

        try:
            downloaded_path = download_drive_file(service, metadata, manifest)
            index_message = f"{index}/{total} Indexing {downloaded_path.name}"
            progress(index / total, index_message)
            yield push(index_message), False
            index_document(
                downloaded_path,
                source_id=file_id,
                file_type=downloaded_path.suffix,
                modified=remote_version,
                display_name=name,
                manifest=manifest,
            )
            downloaded_count += 1
        except Exception as exc:
            # One bad file must not abort the whole sync; report and move on.
            error_message = f"{index}/{total} Failed to sync {name}: {exc}"
            logger.exception(error_message)
            progress(index / total, error_message)
            yield push(error_message), False

    if downloaded_count > 0:
        save_manifest(manifest)
        summary = f'Indexed {downloaded_count} new document(s) from Google Drive.'
        if reindexed_count > 0:
            summary += f' Re-indexed {reindexed_count} cached document(s).'
    elif reindexed_count > 0:
        summary = f'Re-indexed {reindexed_count} cached document(s); no new downloads were needed.'
    else:
        summary = 'Google Drive is already in sync.'

    progress(1, summary)
    yield push(summary), True

## RAG Pipeline

In [None]:
def persist_vectorstore(_store) -> None:
    """Do nothing: the store is used in-memory and is not persisted between sessions."""
    return None


def index_document(
    file_path: Path | str,
    *,
    source_id: Optional[str] = None,
    file_type: Optional[str] = None,
    modified: Optional[str] = None,
    display_name: Optional[str] = None,
    manifest: Optional[dict] = None,
) -> tuple[str, int]:
    """Load, chunk and index one local file.

    Args:
        file_path: File to index; ``~`` is expanded and the path resolved.
        source_id: Stable identifier of the document. When omitted the id is
            ``local::<sha1 of the file>``.
        file_type: Suffix that selects the loader; defaults to the path's suffix.
        modified: Version marker stored in the document metadata.
        display_name: Title stored in the document metadata.
        manifest: When given and no *source_id* was supplied, an entry for
            the file is recorded in it (the caller is responsible for saving).

    Returns:
        ``(document id, number of chunks added)``. The count is ``0`` when
        the file has no readable content.
    """
    path = Path(file_path).expanduser().resolve()

    # Hash the file once: the digest is both the id of a file without a
    # source_id and its version marker in the manifest.
    content_hash = None if source_id else hash_file(path)
    resolved_id = source_id or f'local::{content_hash}'

    documents = preprocess(
        load_documents(
            path,
            source_id=resolved_id,
            file_type=file_type,
            modified=modified,
            display_name=display_name,
        )
    )
    if not documents:
        logger.warning('No readable content found in %s; skipping.', path)
        return resolved_id, 0

    id_key = metadata_key(Metadata.ID)
    total_chunks = 0
    for doc in documents:
        doc_id = doc.metadata.get(id_key, resolved_id)
        doc.metadata[id_key] = doc_id
        # Drop chunks from an earlier version of this document before adding new ones.
        vectorstore.delete(where={metadata_key(Metadata.PARENT_ID): doc_id})
        chunks = chunk_documents(doc)
        if not chunks:
            continue
        vectorstore.add_documents(chunks)
        docstore.mset([(doc_id, doc)])
        total_chunks += len(chunks)

    persist_vectorstore(vectorstore)
    if manifest is not None and not source_id:
        update_manifest_entry(
            manifest,
            file_id=resolved_id,
            modified=content_hash,
            path=path,
            mime_type=file_type or path.suffix or '.txt',
            name=display_name or path.name,
        )
    return resolved_id, total_chunks

### LLM Interaction

In [None]:
def retrieve_context(query: str, *, top_k: int = 8, distance_threshold: Optional[float] = SIMILARITY_DISTANCE_MAX):
    """Return the indexed chunks that best match *query*.

    Args:
        query: Text to search for.
        top_k: Number of candidates requested from the vector store.
        distance_threshold: Candidates whose score is greater than this are
            dropped; ``None`` disables the filter.

    Returns:
        A list of ``(chunk, score)`` tuples in the order the vector store
        returned them, or an empty list when nothing passes the filter.
    """
    results_with_scores = vectorstore.similarity_search_with_score(query, k=top_k)
    logger.info('Matching records: %d', len(results_with_scores))

    source_key = metadata_key(Metadata.SOURCE)
    filtered: list[tuple[Document, float]] = []
    for doc, score in results_with_scores:
        if score is None:
            continue
        score_value = float(score)
        logger.debug('Retrieved doc source=%s distance=%s', doc.metadata.get(source_key), score_value)
        if distance_threshold is not None and score_value > distance_threshold:
            logger.debug(
                'Skipping %s with distance %.4f (above threshold %.4f)',
                doc.metadata.get(source_key),
                score_value,
                distance_threshold,
            )
            continue
        filtered.append((doc, score_value))

    if not filtered:
        return []

    # The parent lookups below only feed debug output, so skip them entirely
    # when debug logging is off.
    if logger.isEnabledFor(logging.DEBUG):
        parent_key = metadata_key(Metadata.PARENT_ID)
        for doc, score_value in filtered:
            parent_id = doc.metadata.get(parent_key)
            if not parent_id:
                continue
            parent_doc = docstore.mget([parent_id])[0]
            if parent_doc and parent_doc.page_content:
                logger.debug(
                    'Parent preview (%s | %.3f): %s',
                    doc.metadata.get(source_key, 'unknown'),
                    score_value,
                    parent_doc.page_content[:400].replace('\n', ' '),
                )

    return filtered


def build_prompt_sections(relevant_docs: list[tuple[Document, float]]) -> str:
    """Render retrieved chunks as numbered context sections for the prompt.

    Each section lists the chunk's source, its distance score to three
    decimals and at most ``MAX_CONTEXT_SNIPPET_CHARS`` characters of its
    stripped content. Sections are separated by a blank line.
    """
    source_key = metadata_key(Metadata.SOURCE)
    rendered: list[str] = []
    position = 0
    for doc, score in relevant_docs:
        position += 1
        origin = doc.metadata.get(source_key, 'unknown')
        excerpt = doc.page_content.strip()[:MAX_CONTEXT_SNIPPET_CHARS]
        rendered.append(f'[{position}] Source: {origin}\nDistance: {score:.3f}\nContent:\n{excerpt}')
    return '\n\n'.join(rendered)


def ask(message, history):
    """Answer *message* from retrieved context, streaming the reply.

    Yields the reply accumulated so far after each streamed piece, so the
    last yielded value is the complete answer. When retrieval finds nothing
    a single fixed fallback message is yielded instead. *history* is
    accepted but not used.
    """
    matches = retrieve_context(message)
    if not matches:
        yield "I don't have enough information in the synced documents to answer that yet. Please sync additional files or adjust the filters."
        return

    context = build_prompt_sections(matches)
    prompt = f'''
    You are a retrieval-augmented assistant. Use ONLY the facts provided in the context to answer the user.
    If the context does not contain the answer, reply exactly: "I don't have enough information in the synced documents to answer that yet. Please sync additional files."
    
    Context:\n{context}
    '''

    conversation = [('system', prompt), ('user', message)]

    answer_so_far = ''
    for piece in model.stream(conversation):
        answer_so_far += piece.content or ''
        # Hold back until there is something to show.
        if answer_so_far:
            yield answer_so_far

## Gradio UI

In [None]:
def chat(message, history, sync_ready):
    """Gradio chat handler: index an uploaded file and/or answer a question.

    Args:
        message: Multimodal message dict with ``'text'`` and ``'files'``
            entries, or ``None``.
        history: Chat history, forwarded to ``ask``.
        sync_ready: Whether a Drive sync has completed in this session.

    Yields:
        Status or error text for uploads, a hint when nothing has been
        indexed yet, or the progressively growing answer from ``ask``.
        Nothing is yielded for an empty message.
    """
    if message is None:
        return

    text_input = message.get('text') or ''
    files_uploaded = message.get('files') or []
    latest_file_path = Path(files_uploaded[-1]) if files_uploaded else None
    if latest_file_path:
        try:
            manifest = load_manifest()
            doc_id, chunk_count = index_document(
                latest_file_path,
                file_type=latest_file_path.suffix,
                display_name=latest_file_path.name,
                manifest=manifest,
            )
            save_manifest(manifest)
        except Exception as exc:
            # Unsupported or unreadable uploads should produce a chat reply,
            # not an unhandled error in the UI.
            logger.exception('Failed to index upload %s', latest_file_path)
            yield f'Could not index {latest_file_path.name}: {exc}'
            return
        logger.info('Indexed upload %s as %s with %s chunk(s)', latest_file_path, doc_id, chunk_count)
        if not text_input:
            yield f'Indexed document from upload ({chunk_count} chunk(s)).'
            return

    if not text_input:
        return

    if not sync_ready and not files_uploaded:
        yield 'Sync Google Drive before chatting or upload a document first.'
        return

    yield from ask(text_input, history)

# Page layout: chat on the left, Google Drive sync controls on the right.
title = "Drive Sage"
with gr.Blocks(title=title, fill_height=True, css=CUSTOM_CSS) as ui:
    gr.Markdown(f'# {title}')
    gr.Markdown('## Search your Google Drive knowledge base with fully local processing.')
    # Set by the sync handler; chat() uses it to decide whether answering makes sense yet.
    sync_state = gr.State(False)

    with gr.Row():
        with gr.Column(scale=3, elem_id='chat-column'):
            gr.ChatInterface(
                fn=chat,
                chatbot=gr.Chatbot(height='80vh', elem_id='chat-output'),
                type='messages',
                textbox=gr.MultimodalTextbox(
                    # '.docx' is included because load_documents can read it.
                    file_types=['text', '.txt', '.md', '.docx'],
                    autofocus=True,
                    elem_id='chat-input',
                ),
                additional_inputs=[sync_state],
            )
        with gr.Column(scale=2, min_width=320):
            gr.Markdown('### Google Drive Sync')
            drive_folder = gr.Textbox(
                label='Folder ID (optional)',
                # A blank folder applies no parent filter to the Drive query,
                # so the search is not limited to the root folder.
                placeholder='Leave blank to search all accessible Drive files',
            )
            file_types = gr.CheckboxGroup(
                label='File types to sync',
                choices=[config['label'] for config in FILE_TYPE_OPTIONS.values()],
                value=DEFAULT_FILE_TYPE_LABELS,
            )
            file_limit = gr.Number(
                label='Max files to sync (leave blank for all)',
                value=20,
            )
            sync_btn = gr.Button('Sync Google Drive')
            sync_status = gr.Markdown('No sync performed yet.')

            # The handler is a generator; each yield updates the status text and the sync flag.
            sync_btn.click(
                sync_drive_and_index,
                inputs=[drive_folder, file_types, file_limit, sync_state],
                outputs=[sync_status, sync_state],
            )

ui.launch(debug=True)