In [1]:
from IPython.display import Image, display
from ftplib import FTP
import numpy as np
import re
import warnings
import tempfile
import os
import tarfile
import scanpy as sc
import gzip
import shutil
import pandas as pd
import anndata as ad
import sys

In [2]:
# get the Claude API key from local text file
# check if we're on MacOS or Windows and read appropriate file
if sys.platform.startswith("win"):
    with open("C:/Users/David/.claude_api.txt") as f:
        claude_key = f.read().strip()
else:
    with open("/Users/tatarakis/.api-keys/tatarakis-test-key.txt") as f:
        claude_key = f.read().strip()

os.environ['ANTHROPIC_API_KEY'] = claude_key

In [3]:
from langchain.chat_models import init_chat_model

model = init_chat_model(
    "claude-sonnet-4-5-20250929",
    timeout=120,
    max_tokens=10000
)

  from pydantic.v1.fields import FieldInfo as FieldInfoV1


# Initial Test

In [4]:
from langchain.agents import create_agent

def get_weather(city: str) -> str:
    """Get weather for a given city."""
    return f"It's always sunny in {city}!"

agent = create_agent(
    model=model,
    tools=[get_weather],
    system_prompt="You are a helpful assistant",
)

# Run the agent
agent.invoke(
    {"messages": [{"role": "user", "content": "what is the weather in sf"}]}
)

{'messages': [HumanMessage(content='what is the weather in sf', additional_kwargs={}, response_metadata={}, id='92ec6e34-9fd8-4277-9034-a1a8fe05af4b'),
  AIMessage(content=[{'id': 'toolu_012gSyNivdQdzeTWJ73PdfX9', 'input': {'city': 'San Francisco'}, 'name': 'get_weather', 'type': 'tool_use'}], additional_kwargs={}, response_metadata={'id': 'msg_01Euj9uQNxmUwB23At5CwYCy', 'model': 'claude-sonnet-4-5-20250929', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'cache_creation': {'ephemeral_1h_input_tokens': 0, 'ephemeral_5m_input_tokens': 0}, 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, 'input_tokens': 569, 'output_tokens': 54, 'server_tool_use': None, 'service_tier': 'standard'}, 'model_name': 'claude-sonnet-4-5-20250929', 'model_provider': 'anthropic'}, id='lc_run--019b9498-2542-7c60-976f-f6c0f24f2b7c-0', tool_calls=[{'name': 'get_weather', 'args': {'city': 'San Francisco'}, 'id': 'toolu_012gSyNivdQdzeTWJ73PdfX9', 'type': 'tool_call'}], invalid_tool_calls=[

# Tools

We need to define the set of tools the agent can use.

In [5]:
from langchain.tools import tool

################
# HELPER TOOLS #
################

# These functions are not called by the agent directly, but are used internally by other tools.
def get_geo_ftp_path(accession: str) -> str:
    """
    Return the FTP directory for a GEO accession (GSE or GSM).
    """
    prefix = accession[:3]     # GSE or GSM
    number = accession[3:]
    chunk = prefix + number[:-3] + "nnn"

    if prefix == "GSE":
        return f"/geo/series/{chunk}/{accession}/suppl/"
    elif prefix == "GSM":
        return f"/geo/samples/{chunk}/{accession}/suppl/"
    else:
        raise ValueError("Only GSE or GSM supported")

def list_tar_contents_helper(file_name: str):
    """List contents of a tar file."""
    contents = []
    with tarfile.open(file_name, "r:*") as tar:
        for member in tar.getmembers():
            contents.append(member.name)
    return contents

def unpack_tar_file_helper(tar_file_path: str, output_dir: str):
    """Unpack a tar file to a specified directory."""
    with tarfile.open(tar_file_path, "r:*") as tar:
        tar.extractall(path=output_dir)
        return [member.name for member in tar.getmembers()]

#########
# TOOLS #
#########

# These are the functions that the agent can call.
@tool
def list_geo_files(accession: str) -> list:
    """
    List supplementary files for a GEO accession (GSE or GSM).
    Returns a list of available files.
    """
    ftp = FTP("ftp.ncbi.nlm.nih.gov")
    ftp.login()

    path = get_geo_ftp_path(accession)
    try:
        ftp.cwd(path)
    except:
        try:
            path = re.sub(r"suppl/$", "", path)
            ftp.cwd(path)
            warnings.warn(f"No supplementary files for: {accession}")
        except:
            ftp.quit()
            raise FileNotFoundError(f"Could not find FTP path: {path}")

    files = ftp.nlst()
    ftp.quit()
    return files


@tool
def download_geo_supp_file(accession: str, file_name: str, output_dir: str) -> str:
    """
    Download a supplementary file for a GEO accession (GSE or GSM).
    Creates the output directory if it doesn't exist.
    Returns the path to the downloaded file.
    """
    os.makedirs(output_dir, exist_ok=True)
    
    ftp = FTP("ftp.ncbi.nlm.nih.gov")
    ftp.login()
    
    path = get_geo_ftp_path(accession)
    try:
        ftp.cwd(path)
    except:
        try:
            path = re.sub(r"suppl/$", "", path)
            ftp.cwd(path)
            warnings.warn(f"No supplementary files for: {accession}")
        except:
            ftp.quit()
            raise FileNotFoundError(f"Could not find FTP path: {path}")

    local_file_path = os.path.join(output_dir, file_name)
    with open(local_file_path, "wb") as f:
        try:
            ftp.retrbinary(f"RETR {file_name}", f.write)
        except Exception as e:
            ftp.quit()
            raise e
    ftp.quit()
    
    print(f"Downloaded: {local_file_path}")
    return local_file_path

# Let's the agent list any directory it needs to.
@tool
def list_directory(directory: str) -> list:
    """
    List all files and folders in a directory.
    """
    return os.listdir(directory)

# This tool lists contents of all tar files in a directory.
@tool
def batch_list_tar_contents(directory: str) -> dict:
    """
    List contents of all tar files in a directory.
    Returns a dict mapping tar filenames to their contents.
    """
    tar_contents = {}
    for file in os.listdir(directory):
        if file.endswith(".tar") or file.endswith(".tar.gz") or file.endswith(".tgz"):
            tar_file_path = os.path.join(directory, file)
            contents = list_tar_contents_helper(tar_file_path)
            tar_contents[file] = contents
    return tar_contents

# This tool unpacks all tar files in a directory.
@tool
def batch_unpack_tar_files(directory: str) -> list:
    """
    Unpack all tar files in a directory.
    Returns list of unpacked file names.
    """
    unpacked_files = []
    for file in os.listdir(directory):
        if file.endswith(".tar") or file.endswith(".tar.gz") or file.endswith(".tgz"):
            tar_file_path = os.path.join(directory, file)
            unpacked = unpack_tar_file_helper(tar_file_path, directory)
            unpacked_files.extend(unpacked)
            print(f"Unpacked {file}: {len(unpacked)} files")
    return unpacked_files

# This tool renames files in a directory according to 10x Genomics conventions.
@tool
def rename_geo_files(directory: str) -> dict:
    """
    Rename files in a directory according to 10x Genomics conventions.
    Returns the rename mapping.
    """
    files = os.listdir(directory)
    
    matrix = None
    features = None
    barcodes = None

    for f in files:
        n = f.lower()
        # Matrix file
        if "mtx" in n:
            matrix = f
            continue
        # Features file
        if any(x in n for x in ["gene", "feature", "symbol"]):
            features = f
            continue
        # Barcodes file
        if any(x in n for x in ["barcode", "cell"]):
            barcodes = f
            continue

    if not (matrix and features and barcodes):
        raise ValueError(
            f"Could not find all required files in {directory}. "
            f"Found matrix={matrix}, features={features}, barcodes={barcodes}"
        )

    rename_map = {
        matrix: "matrix.mtx",
        features: "features.tsv",
        barcodes: "barcodes.tsv"
    }

    for key, value in list(rename_map.items()):
        if key.endswith(".gz"):
            rename_map[key] = value + ".gz"

    for old, new in rename_map.items():
        src = os.path.join(directory, old)
        dst = os.path.join(directory, new)
        shutil.move(src, dst)
        print(f"Renamed {old} → {new}")

    return rename_map

# This tool structures a directory to conform to 10x Genomics file organization.
@tool
def structure_10x_directory(directory: str) -> str:
    """
    Structure a directory to conform to 10x Genomics file organization.
    Creates a '10x_counts' subdirectory and moves relevant files there.
    Returns the path to the 10x_counts directory.
    """
    counts_directory = os.path.join(directory, "10x_counts")
    os.makedirs(counts_directory, exist_ok=True)

    for file in os.listdir(directory):
        if file in ["matrix.mtx", "matrix.mtx.gz", "features.tsv", "features.tsv.gz", "barcodes.tsv", "barcodes.tsv.gz"]:
            src = os.path.join(directory, file)
            dst = os.path.join(counts_directory, file)
            shutil.move(src, dst)
            print(f"Moved {file} to 10x_counts/")
    
    return counts_directory

# This tool converts CSV files to TSV files (which scanpy expects).
@tool
def convert_csv_to_tsv(file_path: str) -> str:
    """
    Convert CSV files to TSV files, handling both uncompressed and gzipped files.
    Returns the path to the TSV file.
    """
    if file_path.endswith(".csv"):
        tsv_file_path = file_path[:-4] + ".tsv"
        with open(file_path, "r") as csv_file, open(tsv_file_path, "w") as tsv_file:
            for line in csv_file:
                tsv_file.write(line.replace(",", "\t"))
        return tsv_file_path

    if file_path.endswith(".csv.gz"):
        tsv_file_path = file_path[:-7] + ".tsv.gz"
        with gzip.open(file_path, "rt") as csv_file, gzip.open(tsv_file_path, "wt") as tsv_file:
            for line in csv_file:
                tsv_file.write(line.replace(",", "\t"))
        return tsv_file_path

    return file_path

# Gets the dimensions of a sparse matrix from a Matrix Market file.
@tool
def get_matrix_dimensions(matrix_file_path: str) -> tuple:
    """
    Get the dimensions of a sparse matrix from a Matrix Market file.
    Returns (rows, cols, nnz).
    """
    opener = gzip.open if matrix_file_path.endswith(".gz") else open

    with opener(matrix_file_path, "rt") as f:
        for line in f:
            line = line.strip()
            if line.startswith("%") or line.startswith("%%"):
                continue
            parts = line.split()
            if len(parts) == 3:
                rows, cols, nnz = map(int, parts)
                return rows, cols, nnz
    raise ValueError("Could not determine matrix dimensions from file.")

# This tool checks and reformats features file to conform to 10x Genomics conventions.
@tool
def format_features_file(features_file_path: str, matrix_dimensions: tuple) -> str:
    """
    Check and reformat features file to conform to 10x Genomics conventions.
    Returns the path to the reformatted file.
    """
    if features_file_path.endswith(".gz"):
        with gzip.open(features_file_path, "rt") as f:
            lines = f.readlines()
    else:
        with open(features_file_path, "r") as f:
            lines = f.readlines()
    
    nrows = len(lines)
    has_header = (nrows == matrix_dimensions[0] + 1)
    
    if has_header:
        rows = [line.strip().split("\t") for line in lines[1:]]
    else:
        rows = [line.strip().split("\t") for line in lines]

    features_df = pd.DataFrame(rows)
    if features_df.shape[1] == 1:
        features_df[1] = features_df.iloc[:, 0]
        features_df[2] = "gene"
    elif features_df.shape[1] == 2:
        features_df[2] = "gene"
    features_df = features_df.iloc[:, :3]
    
    if features_file_path.endswith(".gz"):
        with gzip.open(features_file_path, "wt") as f:
            features_df.to_csv(f, sep="\t", index=False, header=False)
    else:
        with open(features_file_path, "w") as f:
            features_df.to_csv(f, sep="\t", index=False, header=False)

    return features_file_path

# This tool checks and reformats barcodes file to conform to 10x Genomics conventions.
@tool
def format_barcodes_file(barcodes_file_path: str, matrix_dimensions: tuple) -> str:
    """
    Check and reformat barcodes file to conform to 10x Genomics conventions.
    Returns the path to the reformatted file.
    """
    if barcodes_file_path.endswith(".gz"):
        with gzip.open(barcodes_file_path, "rt") as f:
            lines = f.readlines()
    else:
        with open(barcodes_file_path, "r") as f:
            lines = f.readlines()
    
    nrows = len(lines)
    has_header = (nrows == matrix_dimensions[1] + 1)
    
    if has_header:
        rows = [line.strip().split("\t") for line in lines[1:]]
    else:
        rows = [line.strip().split("\t") for line in lines]

    barcodes_df = pd.DataFrame(rows)
    barcodes_df = barcodes_df.iloc[:, :1]
    
    tenx_pattern = r"([ACGTN]{16,20}-\d+)"
    barcodes_df[barcodes_df.columns[0]] = barcodes_df[barcodes_df.columns[0]].str.extract(tenx_pattern)

    if barcodes_file_path.endswith(".gz"):
        with gzip.open(barcodes_file_path, "wt") as f:
            barcodes_df.to_csv(f, sep="\t", index=False, header=False)
    else:
        with open(barcodes_file_path, "w") as f:
            barcodes_df.to_csv(f, sep="\t", index=False, header=False)

    return barcodes_file_path

# This tool builds an anndata object from 10x formatted counts matrix.
@tool
def build_anndata(counts_directory: str, sample_name: str, output_dir: str) -> str:
    """
    Build an anndata object from 10x formatted counts matrix.
    Saves to output_dir/adatas/sample_name.h5ad.
    Returns the path to the saved file.
    """
    adata = sc.read_10x_mtx(counts_directory)
    adata.obs["sample_name"] = sample_name
    
    adata_dir = os.path.join(output_dir, "adatas")
    os.makedirs(adata_dir, exist_ok=True)
    
    saved_file_path = os.path.join(adata_dir, f"{sample_name}.h5ad")
    adata.write_h5ad(saved_file_path)

    if os.path.exists(saved_file_path):
        print(f"AnnData object successfully saved at: {saved_file_path}")
    else:
        raise FileNotFoundError(f"Failed to save AnnData object at: {saved_file_path}")
    
    return saved_file_path

In [6]:
from langchain.agents import create_agent

agent = create_agent(
    model=model,
    tools=[
        list_geo_files,
        download_geo_supp_file,
        list_directory,
        batch_list_tar_contents,
        batch_unpack_tar_files,
        rename_geo_files,
        structure_10x_directory,
        convert_csv_to_tsv,
        get_matrix_dimensions,
        format_features_file,
        format_barcodes_file,
        build_anndata
    ],
    system_prompt="""You are a helpful assistant specialized in processing GEO single-cell RNA-seq datasets.

When processing a GEO dataset:
1. List available files with list_geo_files
2. Download relevant files using download_geo_supp_file (creates directories automatically)
3. Use list_directory to see what was downloaded
4. Extract tar archives with batch_unpack_tar_files
5. Use list_directory to examine extracted files
6. Rename files to 10x conventions with rename_geo_files
7. Convert CSV to TSV if needed
8. Structure the directory with structure_10x_directory
9. Check matrix dimensions and fix features/barcodes if needed
10. Build the AnnData object with build_anndata

Always use the paths returned by tools in subsequent tool calls.
"""
)

In [7]:
test = agent.invoke(
    {"messages": [{
        "role": "user", 
        "content": """Tell me what you are designed to do and what you need."""
    }]}
)

In [8]:
print(test["messages"][-1].content)

I am a specialized assistant designed to process GEO (Gene Expression Omnibus) single-cell RNA-seq datasets and convert them into AnnData objects that can be used for downstream analysis.

## What I can do:

1. **Download GEO data**: Retrieve supplementary files from GEO accessions (GSE or GSM)
2. **Extract archives**: Unpack tar files to access the data
3. **File management**: List, rename, and organize files according to 10x Genomics conventions
4. **Format conversion**: Convert CSV files to TSV format when needed
5. **Data validation**: Check matrix dimensions and ensure features/barcodes files are properly formatted
6. **Build AnnData objects**: Create standardized .h5ad files from count matrices

## What I need from you:

To process a dataset, I need:

- **GEO Accession ID** (e.g., GSE123456) - the dataset you want to process
- **Output directory** (optional) - where you want the processed data saved (I can use a default if not specified)
- **Sample name** (optional for some steps

In [9]:
# Set working directory for the agent to use
WORKING_DIR = "/Users/tatarakis/data/langchain_geo_test_data"

# Run the agent. Pass the working directory in the prompt
result = agent.invoke({
    "messages": [{
        "role": "user", 
        "content": f"""Download counts data for GSE209912, extract the 10x Genomics counts matrix, 
reformat it as needed, and build an AnnData object from it.

Use {WORKING_DIR} as the base directory for all file operations."""
    }]
})

Downloaded: /Users/tatarakis/data/langchain_geo_test_data/GSE209912_barcodes.csv.gz
Downloaded: /Users/tatarakis/data/langchain_geo_test_data/GSE209912_symbols.csv.gz
Downloaded: /Users/tatarakis/data/langchain_geo_test_data/GSE209912_counts.mtx.gz
Renamed GSE209912_counts.mtx.gz → matrix.mtx.gz
Renamed GSE209912_symbols.tsv.gz → features.tsv.gz
Renamed GSE209912_barcodes.csv.gz → barcodes.tsv.gz
Moved features.tsv.gz to 10x_counts/
Moved barcodes.tsv.gz to 10x_counts/
Moved matrix.mtx.gz to 10x_counts/


  utils.warn_names_duplicates("obs")


AnnData object successfully saved at: /Users/tatarakis/data/langchain_geo_test_data/adatas/GSE209912.h5ad


# Batch Test

In [10]:
test_accessions = [
    "GSE174188",
    "GSE209912",
    "GSE188367",
    "GSE136103"
]

In [None]:
for chunk in agent.stream({
    "messages": [{
        "role": "user", 
        "content": f"""Download counts data for assession GSE188367, keep an eye out for tar archives and extract as needed. Extract the 10x Genomics counts matrix, 
    reformat as needed, and build an AnnData object from it.

    Use {WORKING_DIR} as the base directory for all file operations."""
        }]
}):

    # chunk is a dict with node name as key
    for node_name, node_output in chunk.items():
        print(f"\n{'='*50}")
        print(f"Node: {node_name}")
        print(f"{'='*50}")
        
        if "messages" in node_output:
            for msg in node_output["messages"]:
                print(f"Type: {type(msg).__name__}")
                if hasattr(msg, "tool_calls") and msg.tool_calls:
                    for tc in msg.tool_calls:
                        print(f"  Tool: {tc['name']}")
                        print(f"  Args: {tc['args']}")
                elif hasattr(msg, "content"):
                    print(f"  Content: {msg.content}")


#     "messages": [{
#         "role": "user", 
#         "content": f"""Download counts data for each accession in {test_accessions}, keep an eye out for tar archives and extract as needed. Extract the 10x Genomics counts matrix, 
# reformat as needed, and build an AnnData object from it.

# Use {WORKING_DIR} as the base directory for all file operations."""
#     }]
# })


Node: model
Type: AIMessage
  Tool: list_geo_files
  Args: {'accession': 'GSE188367'}

Node: tools
Type: ToolMessage
  Content: ["GSE188367_atac_tf_counts.tar.gz", "filelist.txt", "GSE188367_RAW.tar"]...

Node: model
Type: AIMessage
  Tool: download_geo_supp_file
  Args: {'accession': 'GSE188367', 'file_name': 'GSE188367_RAW.tar', 'output_dir': '/Users/tatarakis/data/langchain_geo_test_data'}
  Tool: download_geo_supp_file
  Args: {'accession': 'GSE188367', 'file_name': 'GSE188367_atac_tf_counts.tar.gz', 'output_dir': '/Users/tatarakis/data/langchain_geo_test_data'}
Downloaded: /Users/tatarakis/data/langchain_geo_test_data/GSE188367_atac_tf_counts.tar.gz

Node: tools
Type: ToolMessage
  Content: /Users/tatarakis/data/langchain_geo_test_data/GSE188367_atac_tf_counts.tar.gz...
Downloaded: /Users/tatarakis/data/langchain_geo_test_data/GSE188367_RAW.tar

Node: tools
Type: ToolMessage
  Content: /Users/tatarakis/data/langchain_geo_test_data/GSE188367_RAW.tar...

Node: model
Type: AIMessage