# NB01: Data Extraction -- AlphaEarth Embeddings + Environment Labels

**Requires BERDL JupyterHub** (`get_spark_session()` is only available there)

This notebook extracts two data layers, joins them, and summarizes how well they overlap:
1. AlphaEarth 64-dim environmental embeddings (83K genomes)
2. NCBI environment metadata (EAV format, pivoted)
3. Coverage statistics

**Outputs** (saved to `../data/`):
- `alphaearth_with_env.csv` -- embeddings + pivoted env labels
- `coverage_stats.csv` -- overlap statistics
- `ncbi_env_attribute_counts.csv` -- which attributes exist and their population rates
- `isolation_source_raw_counts.csv` -- raw value counts for harmonization

In [None]:
import os
import pandas as pd

# get_spark_session() is never imported in this notebook; it is expected to be
# supplied by the BERDL JupyterHub environment.
spark = get_spark_session()

# Every CSV this notebook writes goes under DATA_DIR (created if missing).
DATA_DIR = '../data'
os.makedirs(DATA_DIR, exist_ok=True)

print('Spark session initialized')
print('Output directory: ' + os.path.abspath(DATA_DIR))

## 1. Extract AlphaEarth Embeddings

The `alphaearth_embeddings_all_years` table has 83,287 rows -- small enough to collect entirely.

In [None]:
# Pull the entire AlphaEarth table to the driver as a pandas DataFrame
ae_query = """
    SELECT *
    FROM kbase_ke_pangenome.alphaearth_embeddings_all_years
"""
ae_df = spark.sql(ae_query).toPandas()

# An embedding column is any column named "A" followed only by digits.
n_emb_dims = sum(1 for c in ae_df.columns if c.startswith("A") and c[1:].isdigit())

print(f'AlphaEarth embeddings: {len(ae_df):,} genomes')
print(f'Columns: {list(ae_df.columns)}')
print(f'\nEmbedding columns: A00 through A63 ({n_emb_dims} dimensions)')

print('\nLat/lon coverage:')
for coord_col in ('cleaned_lat', 'cleaned_lon'):
    print(f'  {coord_col} non-null: {ae_df[coord_col].notna().sum():,}')

year_values = ae_df["cleaned_year"]
print(f'\nYear range: {year_values.min()} - {year_values.max()}')

# Report cardinality only for the taxonomic ranks actually present in the table.
print('\nTaxonomy coverage:')
taxon_ranks = ('domain', 'phylum', 'class', 'order', 'family', 'genus', 'species')
for col in (rank for rank in taxon_ranks if rank in ae_df.columns):
    print(f'  {col}: {ae_df[col].nunique():,} unique values')

In [None]:
# Preview the first three rows to eyeball the column layout and sample values
ae_df.head(n=3)

## 2. Inventory NCBI Environment Attributes

Before pivoting, check which `harmonized_name` values exist and how many genomes have each.

In [None]:
# Inventory of harmonized_name values over the whole ncbi_env table (no
# AlphaEarth filter): total rows and distinct accessions per attribute,
# ordered by the distinct-accession count.
attr_inventory_sql = """
    SELECT harmonized_name,
           COUNT(*) as n_rows,
           COUNT(DISTINCT accession) as n_genomes
    FROM kbase_ke_pangenome.ncbi_env
    GROUP BY harmonized_name
    ORDER BY n_genomes DESC
"""
attr_counts = spark.sql(attr_inventory_sql).toPandas()

n_attrs = len(attr_counts)
print(f'NCBI env attributes: {n_attrs} distinct harmonized_name values')
print('\nTop 30 attributes by genome count:')
top_attrs = attr_counts.head(30)
print(top_attrs.to_string(index=False))

In [None]:
# Write the attribute inventory to DATA_DIR without the pandas index column
attr_counts_path = os.path.join(DATA_DIR, 'ncbi_env_attribute_counts.csv')
attr_counts.to_csv(attr_counts_path, index=False)
print(f'Saved ncbi_env_attribute_counts.csv ({len(attr_counts)} attributes)')

## 3. Pivot NCBI Environment Labels for AlphaEarth Genomes

Join `ncbi_env` to the AlphaEarth genomes by matching `ncbi_env.accession` to `ncbi_biosample_accession_id`, then pivot the key attributes into columns (one row per accession, one column per attribute).

In [None]:
# Collect the distinct, non-null biosample IDs from the AlphaEarth table; these
# are the keys used to look up rows in ncbi_env.
biosample_col = ae_df['ncbi_biosample_accession_id']
biosample_ids = biosample_col.dropna().unique().tolist()

# Report genomes and unique IDs separately: the ID list is de-duplicated, so its
# length undercounts genomes whenever several genomes share a biosample.
print(f'AlphaEarth genomes with biosample IDs: {biosample_col.notna().sum():,}')
print(f'Unique biosample IDs: {len(biosample_ids):,}')

# Register the IDs as a temp view so the filter happens as a Spark join.
# The schema is given explicitly (DDL string) instead of being inferred from the
# data: inference fails on an empty list, and an explicit string type keeps the
# join key type independent of the values. str() guards against non-string IDs,
# which the explicit schema would otherwise reject.
biosample_sdf = spark.createDataFrame(
    [(str(b),) for b in biosample_ids],
    schema='accession string',
)
biosample_sdf.createOrReplaceTempView('ae_biosamples')

In [None]:
# Pivot key environment attributes for AlphaEarth genomes.
# ncbi_env is in EAV form (accession, harmonized_name, content); the query below
# turns each attribute in TARGET_ATTRS into its own column, one row per accession.
# Target attributes based on common NCBI BioSample fields
TARGET_ATTRS = [
    'isolation_source',
    'geo_loc_name',
    'env_broad_scale',
    'env_local_scale',
    'env_medium',
    'host',
    'collection_date',
    'lat_lon',
    'depth',
    'altitude',
    'temp',
]

# Both the SELECT pivot columns and the WHERE filter are generated from
# TARGET_ATTRS, so editing the list can never leave the two out of sync.
# If an accession has several rows for one attribute, MAX keeps a single value
# (the greatest under SQL ordering). Aliases are backtick-quoted so attribute
# names that collide with SQL keywords stay valid column names.
pivot_select = ',\n           '.join(
    f"MAX(CASE WHEN ne.harmonized_name = '{attr}' THEN ne.content END) as `{attr}`"
    for attr in TARGET_ATTRS
)
attr_in_clause = "', '".join(TARGET_ATTRS)

env_pivot = spark.sql(f"""
    SELECT ne.accession,
           {pivot_select}
    FROM kbase_ke_pangenome.ncbi_env ne
    JOIN ae_biosamples ab ON ne.accession = ab.accession
    WHERE ne.harmonized_name IN ('{attr_in_clause}')
    GROUP BY ne.accession
""").toPandas()

# NOTE: env_pivot has one row per accession that carries at least one target
# attribute; the printed label says "genomes" for continuity with the other cells.
n_pivoted = len(env_pivot)
print(f'Environment labels pivoted: {n_pivoted:,} genomes')
print(f'\nAttribute population rates (of {n_pivoted:,} genomes with any env data):')
# Iterate over TARGET_ATTRS by name rather than relying on column position.
for col in TARGET_ATTRS:
    n = env_pivot[col].notna().sum()
    pct = 100 * n / n_pivoted if n_pivoted > 0 else 0
    print(f'  {col}: {n:,} ({pct:.1f}%)')

## 4. Join Embeddings with Environment Labels

In [None]:
# Left-join the pivoted env labels onto the AlphaEarth embeddings so every
# genome is kept, whether or not it has any env data.
# - validate='many_to_one': env_pivot is grouped by accession, so each key must
#   appear at most once on the right; this guards against silent row duplication.
# - indicator: records per row whether a matching env_pivot row was found.
merged = ae_df.merge(
    env_pivot,
    left_on='ncbi_biosample_accession_id',
    right_on='accession',
    how='left',
    validate='many_to_one',
    indicator='_env_merge',
)

# "Has env labels" means the genome matched a row in env_pivot. Using
# isolation_source alone would miscount genomes that have other env attributes
# but no isolation_source as unlabeled.
has_env_labels = merged['_env_merge'] == 'both'

print(f'Merged dataset: {len(merged):,} genomes')
print(f'  With env labels: {has_env_labels.sum():,}')
print(f'  Without env labels: {(~has_env_labels).sum():,}')

# Drop the merge bookkeeping column and the right-hand join key, which
# duplicates ncbi_biosample_accession_id (errors='ignore' covers the case where
# pandas suffixed 'accession' because ae_df already had a column of that name).
merged = merged.drop(columns=['_env_merge', 'accession'], errors='ignore')

## 5. Coverage Statistics

In [None]:
# Per-genome boolean flags: True where the attribute is populated.
# lat/lon counts only when both coordinates are present.
has_latlon = merged['cleaned_lat'].notna() & merged['cleaned_lon'].notna()
has_isolation = merged['isolation_source'].notna()
has_env_broad = merged['env_broad_scale'].notna()
has_env_local = merged['env_local_scale'].notna()
has_env_medium = merged['env_medium'].notna()
has_host = merged['host'].notna()
has_geo_loc = merged['geo_loc_name'].notna()

# Single source of truth for the coverage table: report label -> flag series.
# Dict insertion order fixes the row order of the table.
coverage_flags = {
    'lat/lon': has_latlon,
    'isolation_source': has_isolation,
    'env_broad_scale': has_env_broad,
    'env_local_scale': has_env_local,
    'env_medium': has_env_medium,
    'host': has_host,
    'geo_loc_name': has_geo_loc,
}

coverage = pd.DataFrame({
    'attribute': list(coverage_flags),
    'n_genomes': [flag.sum() for flag in coverage_flags.values()],
    # The mean of a boolean series is the fraction of True values.
    'pct_of_alphaearth': [100 * flag.mean() for flag in coverage_flags.values()],
})

# Report the actual row count instead of a hard-coded "83K" that goes stale
# whenever the source table changes.
print(f'Coverage of AlphaEarth genomes ({len(merged):,} total):')
print(coverage.to_string(index=False))

# Intersection counts for UpSet plot
# Store per-genome boolean flags as columns; assign() builds a new frame, so
# re-running this cell is idempotent.
merged = merged.assign(
    has_latlon=has_latlon,
    has_isolation_source=has_isolation,
    has_env_broad_scale=has_env_broad,
    has_env_local_scale=has_env_local,
    has_env_medium=has_env_medium,
    has_host=has_host,
    has_geo_loc_name=has_geo_loc,
)

In [None]:
# Write the coverage summary table to DATA_DIR without the pandas index column
coverage_path = os.path.join(DATA_DIR, 'coverage_stats.csv')
coverage.to_csv(coverage_path, index=False)
print('Saved coverage_stats.csv')

## 6. Raw Isolation Source Value Counts

Save the raw `isolation_source` values for harmonization in NB02.

In [None]:
# Frequency table of the raw (unharmonized) isolation_source strings, most
# common first, with explicit column names independent of the pandas version.
iso_counts = (
    merged['isolation_source']
    .dropna()
    .value_counts()
    .rename_axis('isolation_source')
    .reset_index(name='count')
)

print(f'Unique isolation_source values: {len(iso_counts):,}')
print('\nTop 30:')
print(iso_counts.head(30).to_string(index=False))

iso_counts_path = os.path.join(DATA_DIR, 'isolation_source_raw_counts.csv')
iso_counts.to_csv(iso_counts_path, index=False)
print('\nSaved isolation_source_raw_counts.csv')

## 7. Save Merged Dataset

In [None]:
# Write the merged embeddings + env labels (+ coverage flags) table to CSV,
# then report its shape and on-disk size.
out_path = os.path.join(DATA_DIR, 'alphaearth_with_env.csv')
merged.to_csv(out_path, index=False)

n_rows, n_cols = len(merged), len(merged.columns)
size_mb = os.path.getsize(out_path) / 1e6
print('Saved alphaearth_with_env.csv')
print(f'  Rows: {n_rows:,}')
print(f'  Columns: {n_cols}')
print(f'  File size: {size_mb:.1f} MB')

## 8. Quick Sanity Checks

In [None]:
# Summary statistics across the 64 embedding columns A00..A63
emb_cols = ['A%02d' % i for i in range(64)]
emb_values = merged[emb_cols]
emb_stats = emb_values.describe().T  # one row per embedding dimension

overall_min = emb_stats["min"].min()
overall_max = emb_stats["max"].max()
print('Embedding dimension summary (A00-A63):')
print(f'  Value range: [{overall_min:.3f}, {overall_max:.3f}]')
print(f'  Mean of means: {emb_stats["mean"].mean():.3f}')
print(f'  Mean of stds: {emb_stats["std"].mean():.3f}')
print(f'  Any NaN in embeddings: {emb_values.isna().any().any()}')

# Geographic extent of the cleaned coordinates
lat, lon = merged["cleaned_lat"], merged["cleaned_lon"]
print(f'\nLat range: [{lat.min():.2f}, {lat.max():.2f}]')
print(f'Lon range: [{lon.min():.2f}, {lon.max():.2f}]')

# Ten most frequent phyla
print('\nTop 10 phyla:')
top_phyla = merged['phylum'].value_counts().head(10)
print(top_phyla.to_string())

In [None]:
# Final report: every CSV currently in DATA_DIR, alphabetically, with its size
print('\n=== Data extraction complete ===')
print(f'Output files in {os.path.abspath(DATA_DIR)}:')
csv_names = sorted(name for name in os.listdir(DATA_DIR) if name.endswith('.csv'))
for csv_name in csv_names:
    size_mb = os.path.getsize(os.path.join(DATA_DIR, csv_name)) / 1e6
    print(f'  {csv_name} ({size_mb:.1f} MB)')