In [5]:
import csv
import gzip
import os
import re
import shutil
import sys
import tarfile
import tempfile
import warnings
from ftplib import FTP, error_perm

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
from IPython.display import Image, display
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict

In [6]:
# Load the Anthropic API key needed by the Claude model used below.
# An ANTHROPIC_API_KEY already exported in the environment takes precedence; otherwise the key
# is read from a per-machine text file. Set CLAUDE_API_KEY_FILE to point at a different file so
# the notebook is not tied to the original authors' home directories. Never print the key.
claude_key = os.environ.get("ANTHROPIC_API_KEY", "").strip()
if not claude_key:
    if sys.platform.startswith("win"):
        default_key_file = "C:/Users/David/.claude_api.txt"
    else:
        default_key_file = "/Users/tatarakis/.api-keys/tatarakis-test-key.txt"
    key_file = os.environ.get("CLAUDE_API_KEY_FILE", default_key_file)
    with open(key_file) as f:
        claude_key = f.read().strip()

os.environ['ANTHROPIC_API_KEY'] = claude_key

# Initial Test

In [8]:
from langchain.agents import create_agent

def get_weather(city: str) -> str:
    """Get weather for a given city."""
    # The docstring above is presumably surfaced to the model as the tool description,
    # so it is intentionally left unchanged; this stub always reports sunshine.
    return "It's always sunny in {}!".format(city)

# Smoke test: a one-tool agent confirms the API key and model id work before the real pipeline is built.
agent = create_agent(
    model="claude-sonnet-4-5-20250929",
    system_prompt="You are a helpful assistant",
    tools=[get_weather],
)

# Invoke the agent; the returned state (the message history) is this cell's displayed output.
smoke_test_input = {
    "messages": [
        {"role": "user", "content": "what is the weather in sf"},
    ]
}
agent.invoke(smoke_test_input)

{'messages': [HumanMessage(content='what is the weather in sf', additional_kwargs={}, response_metadata={}, id='f8b6e9bc-f6a9-434b-8f8e-82c72a7c373e'),
  AIMessage(content=[{'id': 'toolu_018siAXHhZCV5dUKSpKe6Zqp', 'input': {'city': 'San Francisco'}, 'name': 'get_weather', 'type': 'tool_use'}], additional_kwargs={}, response_metadata={'id': 'msg_01GgAKCXgECBaRLfMPNt3LFC', 'model': 'claude-sonnet-4-5-20250929', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'cache_creation': {'ephemeral_1h_input_tokens': 0, 'ephemeral_5m_input_tokens': 0}, 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, 'input_tokens': 569, 'output_tokens': 54, 'server_tool_use': None, 'service_tier': 'standard'}, 'model_name': 'claude-sonnet-4-5-20250929', 'model_provider': 'anthropic'}, id='lc_run--019b8ac9-08e6-7200-88b7-536301492fb4-0', tool_calls=[{'name': 'get_weather', 'args': {'city': 'San Francisco'}, 'id': 'toolu_018siAXHhZCV5dUKSpKe6Zqp', 'type': 'tool_call'}], invalid_tool_calls=[

# State Structure

Each node (task) in the agent's graph needs to read its inputs and record its outputs. The graph's `State` class declares every input and output the agent may need, so that all nodes share one state object that they can read and update as they run.

# Tools

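Next we define the set of tools the agent can call: listing and downloading GEO supplementary files over FTP, inspecting and unpacking tar archives, renaming and reformatting files into the 10x Genomics layout, and building/saving AnnData objects.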

In [None]:
def get_geo_ftp_path(accession: str) -> str:
    """
    Return the FTP directory holding supplementary files for a GEO accession (GSE or GSM).

    GEO groups accessions into parent folders named after the accession with its last three
    digits replaced by "nnn" (e.g. GSE123456 -> GSE123nnn; accessions below 1000 -> GSEnnn).

    The accession is stripped and upper-cased first, so values such as " gse123456" (common in
    LLM tool calls) are accepted.

    Raises:
        ValueError: if the accession is not "GSE"/"GSM" followed only by digits.
    """
    accession = accession.strip().upper()
    match = re.fullmatch(r"(GSE|GSM)(\d+)", accession)
    if match is None:
        raise ValueError(f"Only GSE or GSM supported (expected e.g. 'GSE123456', got {accession!r})")

    prefix, number = match.groups()
    chunk = prefix + number[:-3] + "nnn"
    # series live under /geo/series, samples under /geo/samples
    subdir = "series" if prefix == "GSE" else "samples"
    return f"/geo/{subdir}/{chunk}/{accession}/suppl/"

GEO_FTP_HOST = "ftp.ncbi.nlm.nih.gov"


def _cd_to_geo_dir(ftp: FTP, accession: str) -> str:
    """
    Change `ftp`'s working directory to the accession's suppl/ folder (shared by the GEO tools).

    If suppl/ does not exist, fall back to the accession's parent folder and emit a warning.
    Returns the directory actually entered.

    Raises:
        FileNotFoundError: if neither directory exists on the server.
    """
    path = get_geo_ftp_path(accession)
    try:
        ftp.cwd(path)
        return path
    except error_perm:
        # 5xx reply (e.g. "550 no such directory"); other network errors propagate unchanged
        pass

    parent = re.sub(r"suppl/$", "", path)
    try:
        ftp.cwd(parent)
    except error_perm as exc:
        raise FileNotFoundError(f"Could not find FTP path: {parent}") from exc
    warnings.warn(f"No supplementary files for: {accession}")
    return parent


def list_geo_files(accession: str) -> list[str]:
    """
    List the file names in a GEO accession's supplementary directory on the NCBI FTP server.

    Falls back to listing the accession's parent directory (with a warning) when there is no
    suppl/ folder.

    Raises:
        FileNotFoundError: if the accession's directory does not exist on the server.
    """
    # the context manager sends QUIT (or closes the socket) even when cwd/nlst raise
    with FTP(GEO_FTP_HOST) as ftp:
        ftp.login()
        _cd_to_geo_dir(ftp, accession)
        return ftp.nlst()


def download_geo_supp_file(accession: str, file_name: str, output_dir: str):
    """
    Download one supplementary file of a GEO accession into `output_dir`.

    `output_dir` is created if needed. Returns the local path of the downloaded file.

    Raises:
        FileNotFoundError: if the accession's directory does not exist on the server.
        ftplib errors: if the transfer fails; no partial local file is left behind.
    """
    with FTP(GEO_FTP_HOST) as ftp:
        ftp.login()
        _cd_to_geo_dir(ftp, accession)

        os.makedirs(output_dir, exist_ok=True)
        local_file_path = os.path.join(output_dir, file_name)
        try:
            with open(local_file_path, "wb") as f:
                ftp.retrbinary(f"RETR {file_name}", f.write)
        except Exception:
            # don't leave an empty/truncated file that a later step could mistake for a complete download
            if os.path.exists(local_file_path):
                os.remove(local_file_path)
            raise
    return local_file_path

def list_tar_contents(file_name: str):
    """Print and return the member names of a (possibly compressed) tar archive, in archive order."""
    # "r:*" lets tarfile detect gzip/bz2/xz compression transparently
    with tarfile.open(file_name, "r:*") as archive:
        names = archive.getnames()
    for name in names:
        print(name)
    return names

# adding a tool to batch view contents of tar files
def batch_list_tar_contents(directory: str):
    """Map every .tar / .tar.gz / .tgz archive in `directory` to its list of member names."""
    tar_suffixes = (".tar", ".tar.gz", ".tgz")
    return {
        entry: list_tar_contents(os.path.join(directory, entry))
        for entry in os.listdir(directory)
        if entry.endswith(tar_suffixes)
    }

def unpack_tar_file(tar_file_path: str, output_dir: str):
    """
    Extract a (possibly compressed) tar archive into `output_dir` and return its member names.

    Archives come from the internet, so members that would be written outside `output_dir`
    (absolute paths or ".." components) are rejected before anything is extracted. On Pythons
    that provide tarfile extraction filters, the "data" filter is also applied, which
    additionally blocks unsafe links and special files.

    Raises:
        tarfile.TarError: if the archive is unreadable or contains an unsafe member.
    """
    dest_root = os.path.realpath(output_dir)
    with tarfile.open(tar_file_path, "r") as tar:
        members = tar.getmembers()
        for member in members:
            target = os.path.realpath(os.path.join(dest_root, member.name))
            if os.path.commonpath([dest_root, target]) != dest_root:
                raise tarfile.TarError(
                    f"Refusing to extract {member.name!r}: it would be written outside {output_dir}"
                )
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=output_dir, filter="data")
        else:
            tar.extractall(path=output_dir)
        return [member.name for member in members]

# adding a tool for batch unpacking tar files
def batch_unpack_tar_files(directory: str):
    """Extract every .tar / .tar.gz / .tgz archive in `directory` into that same directory.

    Returns the concatenated member names of all extracted archives.
    """
    tar_suffixes = (".tar", ".tar.gz", ".tgz")
    # snapshot the archive list first; extraction adds new entries to the directory
    archives = [name for name in os.listdir(directory) if name.endswith(tar_suffixes)]
    extracted = []
    for archive in archives:
        extracted += unpack_tar_file(os.path.join(directory, archive), directory)
    return extracted

# adding a tool to let the agent look at the directory
def list_directory(directory: str) -> list:
    """List all files and folders (names only, not full paths) in a directory."""
    entries = os.listdir(directory)
    return entries

def build_anndata(counts_directory: str, sample_name: str, outdir: str):
    """
    Read a 10x-style counts directory into AnnData and save it as `<outdir>/adatas/<sample_name>.h5ad`.

    Every cell is tagged with `sample_name` in `obs["sample_name"]`.

    Returns the path of the saved .h5ad file.

    Raises:
        FileNotFoundError: if the .h5ad file is not present after writing.
    """
    adata = sc.read_10x_mtx(counts_directory)
    if adata.n_vars == 0:
        # read_10x_mtx keeps only "Gene Expression" features by default, so a wrong
        # feature-type column in features.tsv silently yields an empty matrix
        warnings.warn(f"AnnData for {sample_name} has no features; check the features file in {counts_directory}")

    adata.obs["sample_name"] = sample_name

    adata_dir = os.path.join(outdir, "adatas")
    os.makedirs(adata_dir, exist_ok=True)
    saved_file_path = os.path.join(adata_dir, f"{sample_name}.h5ad")
    adata.write_h5ad(saved_file_path)

    # confirm the write actually produced a file before reporting success
    if not os.path.exists(saved_file_path):
        raise FileNotFoundError(f"Failed to save anndata object at: {saved_file_path}")
    print(f"Anndata object successfully saved at: {saved_file_path}")

    return saved_file_path

# Rename files according to 10x Genomics conventions
def _pick_single_candidate(role: str, names: list, directory: str):
    """Return the single file chosen for `role`, or None when there is no candidate.

    When several files match, CSV files are dropped in favour of a non-CSV file. This handles
    the case where convert_csv_to_tsv wrote a .tsv next to its untouched .csv source.

    Raises:
        ValueError: if the choice is still ambiguous.
    """
    if not names:
        return None
    if len(names) == 1:
        return names[0]
    non_csv = [name for name in names if not name.lower().endswith((".csv", ".csv.gz"))]
    if len(non_csv) == 1:
        return non_csv[0]
    raise ValueError(
        f"Found several candidate {role} files in {directory}: {names}. "
        "Remove or move the extra files and try again."
    )


def rename_geo_files(directory: str):
    """
    Rename the matrix / features / barcodes files in `directory` to the 10x Genomics names
    matrix.mtx, features.tsv and barcodes.tsv (a trailing ".gz" is preserved).

    Files are classified by case-insensitive keywords in their names, checked in this order:
    "mtx" -> matrix; "gene"/"feature"/"symbol" -> features; "barcode"/"cell" -> barcodes.
    Only regular files are considered, so an extracted folder such as
    "filtered_feature_bc_matrix" is never renamed.

    Returns:
        (directory, rename_map) where rename_map maps each old file name to its new name.

    Raises:
        ValueError: if a role has no candidate file, or several that cannot be disambiguated.
    """
    candidates = {"matrix": [], "features": [], "barcodes": []}
    for entry in sorted(os.listdir(directory)):
        if not os.path.isfile(os.path.join(directory, entry)):
            continue
        lowered = entry.lower()
        if "mtx" in lowered:
            candidates["matrix"].append(entry)
        elif any(k in lowered for k in ("gene", "feature", "symbol")):
            candidates["features"].append(entry)
        elif any(k in lowered for k in ("barcode", "cell")):
            candidates["barcodes"].append(entry)

    matrix = _pick_single_candidate("matrix", candidates["matrix"], directory)
    features = _pick_single_candidate("features", candidates["features"], directory)
    barcodes = _pick_single_candidate("barcodes", candidates["barcodes"], directory)

    # Safety check
    if not (matrix and features and barcodes):
        raise ValueError(
            f"Could not find all required files in {directory}. "
            f"Found matrix={matrix}, features={features}, barcodes={barcodes}"
        )

    rename_map = {}
    for old, base in ((matrix, "matrix.mtx"), (features, "features.tsv"), (barcodes, "barcodes.tsv")):
        # keep the gzip suffix so downstream readers still decompress the file
        rename_map[old] = base + ".gz" if old.endswith(".gz") else base

    # apply the new names (shutil.move; a no-op when a file already has its target name)
    for old, new in rename_map.items():
        src = os.path.join(directory, old)
        dst = os.path.join(directory, new)
        shutil.move(src, dst)
        print(f"Renamed {src} → {dst}")

    return directory, rename_map

# structure 10x directory
def structure_10x_directory(directory):
    """Move the canonically named 10x files in `directory` into a `10x_counts` subfolder.

    Only files named exactly matrix.mtx / features.tsv / barcodes.tsv (optionally .gz) are moved.
    Returns the path of the `10x_counts` folder, which is created if needed.
    """
    tenx_names = {
        "matrix.mtx", "matrix.mtx.gz",
        "features.tsv", "features.tsv.gz",
        "barcodes.tsv", "barcodes.tsv.gz",
    }
    counts_directory = os.path.join(directory, "10x_counts")
    os.makedirs(counts_directory, exist_ok=True)

    for name in os.listdir(directory):
        if name not in tenx_names:
            continue
        shutil.move(os.path.join(directory, name), os.path.join(counts_directory, name))

    return counts_directory

# convert csv files to tsv files
def convert_csv_to_tsv(file_path):
    """
    Write a tab-separated copy of a .csv or .csv.gz file next to the original.

    Parsing uses the csv module, so quoted fields (e.g. R's write.csv output, or values that
    contain commas) are unquoted correctly instead of being split on every comma.
    The source file is left in place.

    Returns:
        The path of the new .tsv / .tsv.gz file, or `file_path` unchanged if it is not a CSV.
    """
    if file_path.endswith(".csv.gz"):
        tsv_file_path = file_path[:-7] + ".tsv.gz"
        opener = gzip.open
    elif file_path.endswith(".csv"):
        tsv_file_path = file_path[:-4] + ".tsv"
        opener = open
    else:
        return file_path  # no conversion needed

    # newline="" is required by the csv module so it controls line endings itself
    with opener(file_path, "rt", newline="") as csv_file, opener(tsv_file_path, "wt", newline="") as tsv_file:
        writer = csv.writer(tsv_file, delimiter="\t", lineterminator="\n")
        writer.writerows(csv.reader(csv_file))
    return tsv_file_path

# a simple function to get the dimensions of the counts matrix to help with reformatting input files
def get_matrix_dimensions(matrix_file_path: str) -> tuple:
    """Return (rows, cols, nnz) from the size line of a MatrixMarket file (optionally gzipped).

    Lines starting with "%" (the banner and comments) are skipped. The first remaining line
    with exactly three whitespace-separated fields is parsed as integers.

    Raises:
        ValueError: if no such line exists (or its fields are not integers).
    """
    open_fn = gzip.open if matrix_file_path.endswith(".gz") else open

    with open_fn(matrix_file_path, "rt") as handle:
        for raw in handle:
            stripped = raw.strip()
            if stripped.startswith("%"):
                continue
            fields = stripped.split()
            if len(fields) != 3:
                continue
            n_rows, n_cols, n_nonzero = (int(value) for value in fields)
            return n_rows, n_cols, n_nonzero
    raise ValueError("Could not determine matrix dimensions from file.")

# check the formatting of the features file and reformat if necessary.
# likely problems are only a single column or a header row when there shouldn't be one.
def format_features_file(features_file_path: str, matrix_dimensions: tuple, feature_type: str = "Gene Expression"):
    """
    Rewrite a features/genes file in place into the 3-column, header-less 10x layout.

    Args:
        features_file_path: .tsv or .tsv.gz features file (tab-separated).
        matrix_dimensions: (rows, cols, nnz) as returned by get_matrix_dimensions; rows must
            equal the number of features.
        feature_type: value used to fill a missing third column. scanpy's read_10x_mtx keeps
            only features whose type is "Gene Expression" by default (gex_only=True), so any
            other placeholder would drop every feature.

    Behaviour:
        - A file with exactly one more line than the matrix has rows is assumed to have a
          header, which is dropped.
        - One column: duplicated as column 2, and `feature_type` is added as column 3.
        - Two columns: `feature_type` is added as column 3.
        - More than three columns: truncated to the first three.

    Returns:
        `features_file_path`.
    """
    opener = gzip.open if features_file_path.endswith(".gz") else open
    with opener(features_file_path, "rt") as f:
        lines = f.readlines()

    n_features = matrix_dimensions[0]
    has_header = len(lines) == n_features + 1
    data_lines = lines[1:] if has_header else lines
    if len(data_lines) != n_features:
        warnings.warn(
            f"{features_file_path} has {len(data_lines)} feature rows but the matrix has {n_features} rows"
        )

    rows = [line.strip().split("\t") for line in data_lines]
    features_df = pd.DataFrame(rows)
    if features_df.shape[1] == 1:
        features_df[1] = features_df.iloc[:, 0]
        features_df[2] = feature_type
    elif features_df.shape[1] == 2:
        features_df[2] = feature_type
    features_df = features_df.iloc[:, :3]

    with opener(features_file_path, "wt") as f:
        features_df.to_csv(f, sep="\t", index=False, header=False)

    return features_file_path

def format_barcodes_file(barcodes_file_path: str, matrix_dimensions: tuple):
    """
    Rewrite a barcodes file in place as a single header-less column of cell barcodes.

    Args:
        barcodes_file_path: .tsv or .tsv.gz barcodes file (tab-separated).
        matrix_dimensions: (rows, cols, nnz) as returned by get_matrix_dimensions; cols must
            equal the number of barcodes.

    Behaviour:
        - A file with exactly one more line than the matrix has columns is assumed to have a
          header, which is dropped.
        - Only the first column is kept.
        - A 10x-style barcode (16-20 bases plus "-N") is extracted from each value, which strips
          any sample prefix/suffix. Values that contain no such barcode are kept unchanged
          rather than being blanked out.

    Returns:
        `barcodes_file_path`.
    """
    opener = gzip.open if barcodes_file_path.endswith(".gz") else open
    with opener(barcodes_file_path, "rt") as f:
        lines = f.readlines()

    n_barcodes = matrix_dimensions[1]
    has_header = len(lines) == n_barcodes + 1
    data_lines = lines[1:] if has_header else lines
    if len(data_lines) != n_barcodes:
        warnings.warn(
            f"{barcodes_file_path} has {len(data_lines)} barcode rows but the matrix has {n_barcodes} columns"
        )

    # first tab-separated field of each line; an explicit object Series also handles an empty file
    barcodes = pd.Series([line.strip().split("\t")[0] for line in data_lines], dtype=object)

    tenx_pattern = r"([ACGTN]{16,20}-\d+)"
    extracted = barcodes.str.extract(tenx_pattern, expand=False)
    cleaned = extracted.fillna(barcodes)

    with opener(barcodes_file_path, "wt") as f:
        cleaned.to_csv(f, sep="\t", index=False, header=False)

    return barcodes_file_path
    


In [None]:
class State(TypedDict):
    """Graph state shared by every node: the possible inputs and outputs of the GEO-to-AnnData agent."""
    query: str  # presumably the user's natural-language request — confirm against the nodes that read it
    geo_accession: str  # GEO accession to process; the FTP helpers accept GSE... or GSM...
    files_url: list[str]  # presumably the supplementary file names/URLs listed on GEO FTP — TODO confirm
    download_path: str  # presumably the local directory that downloads are written to — confirm
    counts_path: str  # presumably the 10x-style counts directory (e.g. from structure_10x_directory) — confirm
    anndata: object | None  # in-memory AnnData object, or None before one is built (assumed)
    anndata_path: str  # presumably the saved .h5ad path returned by build_anndata — confirm