In [None]:
"""
ShapeNet .npz Model Downloader from Google Cloud Storage

This notebook downloads selected .npz ShapeNet point-cloud files from a Google
Cloud Storage bucket, filtered by category (top-level bucket folder) and by
point-cloud resolution (encoded in file names such as pc_2048.npz).

Prerequisites:
- Google Cloud SDK installed, with Application Default Credentials configured
  (e.g. `gcloud auth application-default login`)
- Permission to list and read objects in the GCS bucket
- Required Python packages: google-cloud-storage, tqdm, numpy

Usage:
    Run the cells top to bottom. If the notebook is exported to a script
    (e.g. shapenet_downloader.py), run it with `python shapenet_downloader.py`.
"""

# ShapeNet .npz Model Downloader from Google Cloud Storage

This script downloads specific .npz ShapeNet models from a Google Cloud Storage bucket.

Choose which ShapeNet categories (synset IDs) and which point-cloud resolutions to download by editing
`CATEGORIES_TO_EXTRACT` and `RESOLUTION_TO_EXTRACT` in the configuration cell below; set either to `None` to take everything.

## Prerequisites
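- Google Cloud SDK installed, with Application Default Credentials configured (e.g. `gcloud auth application-default login`)
- Permission to list and read objects in the GCS bucket
- Required Python packages installed: `google-cloud-storage`, `tqdm`, `numpy`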

In [None]:
import os
import numpy as np
from google.cloud import storage
from tqdm import tqdm
from pathlib import Path

In [None]:
# Configuration settings — adjust these for your own GCP setup.
BUCKET_NAME = "adlr2025-pointclouds"  # GCS bucket holding the point clouds
PROJECT_ID = "adlr-2025"  # GCP project used by the storage client
# Download root; the bucket's folder layout is mirrored underneath it.
LOCAL_DOWNLOAD_PATH = "../src/data/shapenet_extracted"

# Ensure the download root exists (no-op when it is already there).
Path(LOCAL_DOWNLOAD_PATH).mkdir(exist_ok=True, parents=True)

print(f"Bucket: {BUCKET_NAME}\nLocal download path: {LOCAL_DOWNLOAD_PATH}")

In [None]:
# Category filtering - Set to None to extract all categories, or list ShapeNet
# synset IDs (the top-level folder names in the bucket) to extract.
# CATEGORIES_TO_EXTRACT = None                                    # Extract all categories
# CATEGORIES_TO_EXTRACT = ["02691156", "02958343", "03001627"]    # airplane, car, chair
CATEGORIES_TO_EXTRACT = ["02946921", "02880940", "03085013"]  # ShapeNetCore synsets: can, bowl, keyboard

# Compare against None (not truthiness) so an empty list is reported the way the
# listing function treats it: an empty list scans no category at all.
if CATEGORIES_TO_EXTRACT is not None:
    print(f"Categories to extract: {len(CATEGORIES_TO_EXTRACT)} categories")
    for cat in CATEGORIES_TO_EXTRACT:
        print(f"  - {cat}")
else:
    print("Extracting all categories")

# Resolution filtering - Set to None to extract all resolutions, or list the
# point counts N to keep (files are named pc_<N>.npz).
# RESOLUTION_TO_EXTRACT = None             # Extract all resolutions
RESOLUTION_TO_EXTRACT = [2048]       # Only specific resolutions

# Same None-vs-empty distinction as above: [] selects no files.
if RESOLUTION_TO_EXTRACT is not None:
    print(f"Resolutions to extract: {RESOLUTION_TO_EXTRACT}")
else:
    print("Extracting all resolutions")

In [None]:
def setup_gcs_client():
    """
    Create a Google Cloud Storage client and a handle to ``BUCKET_NAME``.

    ``client.bucket(...)`` only builds a local Bucket object and makes no HTTP
    request. A missing or unreadable bucket is therefore NOT detected here; it
    surfaces on the first list/download call instead.

    Returns:
        Tuple ``(client, bucket)``, or ``(None, None)`` if creating the client
        failed (e.g. no credentials found). The error is printed, not raised.
    """
    try:
        client = storage.Client(project=PROJECT_ID)
        bucket = client.bucket(BUCKET_NAME)
        print(f"✓ Created Google Cloud Storage client for bucket '{BUCKET_NAME}'")
        return client, bucket
    except Exception as e:
        print(f"✗ Error connecting to GCS: {e}")
        # Client libraries read Application Default Credentials; plain
        # `gcloud auth login` only authorizes the gcloud CLI itself.
        print(
            "Make sure Application Default Credentials are configured, e.g. run "
            "'gcloud auth application-default login' or set GOOGLE_APPLICATION_CREDENTIALS "
            "to a service account key file"
        )
        return None, None

# Initialize client and bucket
client, bucket = setup_gcs_client()

In [None]:
def _npz_blob_info(blob, category):
    """Build the summary dict recorded for one listed .npz blob."""
    return {
        'name': blob.name,
        'category': category,
        'size_mb': round(blob.size / (1024 * 1024), 2),
        'updated': blob.updated.strftime('%Y-%m-%d %H:%M:%S'),
    }


def list_npz_files_in_categories(bucket, categories_to_extract=None, max_files_per_category=50):
    """
    List .npz files in the given category folders of a GCS bucket.

    Args:
        bucket: GCS bucket object (anything exposing ``list_blobs(prefix=...)``).
        categories_to_extract: List of top-level folder names to scan, or None to
            scan the whole bucket. An empty list scans nothing.
        max_files_per_category: Maximum number of .npz files collected per
            category folder. When scanning the whole bucket (categories None) the
            scan instead stops after ``max_files_per_category * 20`` files in
            total. None disables the limit.

    Returns:
        List of dicts with keys 'name', 'category', 'size_mb', 'updated'.
        For a whole-bucket scan, 'category' is the first path component, or
        'root' for files at the bucket root.

    Note:
        The limit counts every .npz file before any resolution filtering, so if
        a model folder holds several pc_<N>.npz files, each one counts.
    """
    npz_files = []

    if categories_to_extract is None:
        # If no categories specified, scan entire bucket
        print("Scanning entire bucket for .npz files...")
        total_limit = None if max_files_per_category is None else max_files_per_category * 20

        for blob in bucket.list_blobs():
            if not blob.name.endswith('.npz'):
                continue
            category = blob.name.split('/')[0] if '/' in blob.name else 'root'
            npz_files.append(_npz_blob_info(blob, category))
            if total_limit is not None and len(npz_files) >= total_limit:
                break
        return npz_files

    # Scan specific category folders
    print(f"Scanning {len(categories_to_extract)} category folders for .npz files...")

    for category in categories_to_extract:
        print(f"  Scanning category: {category}")
        category_files = 0
        limit_reached = False

        # Trailing slash so "0288" does not also match "02880940/..."
        for blob in bucket.list_blobs(prefix=f"{category}/"):
            if not blob.name.endswith('.npz'):
                continue
            npz_files.append(_npz_blob_info(blob, category))
            category_files += 1
            if max_files_per_category is not None and category_files >= max_files_per_category:
                limit_reached = True
                break

        if limit_reached:
            print(f"    Found {category_files} files (limit reached)")
        else:
            print(f"    Found {category_files} files")

    return npz_files

# Collect candidate .npz files (category-filtered) and print a short overview.
if not bucket:
    available_files = []
    print("Cannot list files - GCS client not initialized")
else:
    available_files = list_npz_files_in_categories(
        bucket, CATEGORIES_TO_EXTRACT, max_files_per_category=1000
    )
    print(f"\nFound {len(available_files)} .npz files:")

    # Tally files per category; dict insertion order keeps first-seen order.
    category_counts = {}
    for category in (entry['category'] for entry in available_files):
        category_counts[category] = category_counts.get(category, 0) + 1

    print("\nFiles by category:")
    for category, count in category_counts.items():
        print(f"  {category}: {count} files")

    # Preview only the head of the list to keep the output short.
    preview = available_files[:10]
    print("\nFirst 10 files:")
    for idx, entry in enumerate(preview, start=1):
        print(f"{idx:2d}. {entry['name']} ({entry['size_mb']} MB) [{entry['category']}]")

    remaining = len(available_files) - len(preview)
    if remaining > 0:
        print(f"... and {remaining} more files")

In [None]:
# Filter files for download (by the resolution encoded in the file name)
def _resolution_from_name(file_name):
    """Return N for a file named ``pc_<N>.npz`` with integer N, else 'unknown'."""
    if file_name.startswith('pc_') and file_name.endswith('.npz'):
        try:
            return int(file_name[3:-4])  # strip 'pc_' and '.npz'
        except ValueError:
            pass
    return 'unknown'


def filter_files_by_resolution(available_files, resolution_to_extract=None):
    """
    Filter files by resolution based on filename patterns like pc_2048.npz, pc_10240.npz

    Every entry of ``available_files`` is annotated in place with:
      * 'reason': 'all_resolutions', 'resolution_match' or 'resolution_filtered'
      * 'resolution': the matched resolution value, the integer parsed from a
        ``pc_<N>.npz`` name, or 'unknown' when the name does not follow it.

    Args:
        available_files: List of file info dicts (each needs a 'name' key).
        resolution_to_extract: Resolutions to keep, or None for all files.
            An empty list keeps nothing.

    Returns:
        List of the (same, annotated) dicts selected for download.
    """
    # Wanted filename -> the resolution value as supplied by the caller.
    # setdefault keeps the first value if two entries render to the same name.
    wanted_names = None
    if resolution_to_extract is not None:
        wanted_names = {}
        for resolution in resolution_to_extract:
            wanted_names.setdefault(f"pc_{resolution}.npz", resolution)

    files_to_download = []
    for file_info in available_files:
        file_name = os.path.basename(file_info['name'])  # Get just the filename

        if wanted_names is None:
            file_info['reason'] = 'all_resolutions'
            file_info['resolution'] = _resolution_from_name(file_name)
        elif file_name in wanted_names:
            file_info['reason'] = 'resolution_match'
            file_info['resolution'] = wanted_names[file_name]
        else:
            # Record the resolution anyway, for reporting skipped files.
            file_info['reason'] = 'resolution_filtered'
            file_info['resolution'] = _resolution_from_name(file_name)
            continue

        files_to_download.append(file_info)

    return files_to_download

# Apply the resolution filter, then list what will be fetched and its total size.
files_to_download = filter_files_by_resolution(available_files, RESOLUTION_TO_EXTRACT)

print(f"\nSelected {len(files_to_download)} files for download:")
for file_info in files_to_download:
    print(f"- {file_info['name']} ({file_info['size_mb']} MB) [{file_info['reason']}]")
total_size_mb = sum(entry['size_mb'] for entry in files_to_download)

print(f"\nTotal download size: {total_size_mb:.2f} MB")

In [None]:
def download_file(bucket, blob_name, local_path, overwrite=False):
    """
    Download a single file from GCS bucket, preserving folder structure

    The data is first written to ``<target>.part`` and moved into place only
    after a complete download. A failed or interrupted transfer therefore never
    leaves a truncated file at the final path. Such a file would otherwise be
    skipped as "already exists" on every later run.

    Args:
        bucket: GCS bucket object
        blob_name: Full path of the blob to download (e.g., "category/model.npz")
        local_path: Local directory path
        overwrite: Whether to overwrite existing files

    Returns:
        Tuple of (success: bool, local_file_path: str). A file skipped because
        it already exists counts as success. On error: (False, None).
    """
    # Create local file path preserving the folder structure.
    # NOTE(review): blob names are trusted; a name containing '..' could write
    # outside local_path.
    local_file_path = os.path.join(local_path, blob_name)

    # Create subdirectories if they don't exist
    Path(os.path.dirname(local_file_path)).mkdir(parents=True, exist_ok=True)

    # Check if file already exists
    if os.path.exists(local_file_path) and not overwrite:
        print(f"⏭️  Skipping {blob_name} (already exists)")
        return True, local_file_path

    tmp_file_path = local_file_path + ".part"
    try:
        blob = bucket.blob(blob_name)
        blob.download_to_filename(tmp_file_path)
        # Atomic rename: the final path only ever holds a complete file.
        os.replace(tmp_file_path, local_file_path)
        return True, local_file_path
    except Exception as e:
        print(f"❌ Error downloading {blob_name}: {e}")
        return False, None
    finally:
        # Remove the partial temp file if the download or rename did not finish
        # (this also runs on KeyboardInterrupt).
        if os.path.exists(tmp_file_path):
            try:
                os.remove(tmp_file_path)
            except OSError:
                pass  # best effort; a stale .part file is overwritten next time

def download_files_batch(bucket, files_to_download, local_path, overwrite=False):
    """
    Download multiple files with progress tracking

    Args:
        bucket: GCS bucket object
        files_to_download: List of file information dictionaries (need 'name'
            and 'size_mb' keys)
        local_path: Local directory path
        overwrite: Whether to overwrite existing files

    Returns:
        Tuple of (successful_downloads: list, failed_downloads: list).
        ``successful_downloads`` holds dicts with 'blob_name', 'local_path' and
        'size_mb'. It also includes files skipped because they already existed,
        since ``download_file`` reports those as successes.
        ``failed_downloads`` holds blob names.
    """
    successful_downloads = []
    failed_downloads = []

    print(f"\nStarting download of {len(files_to_download)} files...")
    print(f"Download directory: {local_path}")

    for file_info in tqdm(files_to_download, desc="Downloading"):
        blob_name = file_info['name']
        # Check beforehand so a file skipped as "already exists" is not
        # announced as freshly downloaded.
        already_present = not overwrite and os.path.exists(os.path.join(local_path, blob_name))
        success, local_file_path = download_file(bucket, blob_name, local_path, overwrite)

        if success:
            successful_downloads.append({
                'blob_name': blob_name,
                'local_path': local_file_path,
                'size_mb': file_info['size_mb']
            })
            if not already_present:
                # tqdm.write prints above the progress bar instead of breaking it.
                tqdm.write(f"✅ Downloaded: {blob_name}")
        else:
            failed_downloads.append(blob_name)

    return successful_downloads, failed_downloads

In [None]:
def execute_download():
    """
    Download every entry of the global ``files_to_download`` into
    ``LOCAL_DOWNLOAD_PATH`` and print a summary.

    Files that already exist locally are kept (``overwrite=False``).

    Returns:
        ``(successful_downloads, failed_downloads)`` from ``download_files_batch``,
        or ``([], [])`` when there is no bucket handle or nothing to download.
    """
    # Guard clauses: nothing to do without a bucket handle or a work list.
    if not bucket:
        print("Cannot download - GCS client not initialized")
        return [], []
    if not files_to_download:
        print("No files selected for download.")
        return [], []

    results = download_files_batch(
        bucket,
        files_to_download,
        LOCAL_DOWNLOAD_PATH,
        overwrite=False,  # flip to True to re-fetch files that already exist locally
    )
    successful_downloads, failed_downloads = results

    rule = '=' * 50
    print(f"\n{rule}")
    print("DOWNLOAD SUMMARY")
    print(rule)
    print(f"✅ Successful downloads: {len(successful_downloads)}")
    print(f"❌ Failed downloads: {len(failed_downloads)}")

    if successful_downloads:
        total_downloaded_mb = sum(item['size_mb'] for item in successful_downloads)
        print(f"📁 Total downloaded: {total_downloaded_mb:.2f} MB")
        print(f"📂 Download location: {LOCAL_DOWNLOAD_PATH}")

    if failed_downloads:
        print("\nFailed downloads:")
        print("\n".join(f"  - {name}" for name in failed_downloads))

    return successful_downloads, failed_downloads

# Runs as soon as this cell executes, without any confirmation prompt. The last
# cell of the notebook asks before calling execute_download() again; files
# already on disk are skipped then.
successful_downloads, failed_downloads = execute_download()

In [None]:
def verify_npz_files(download_path):
    """
    Verify that downloaded .npz files can be loaded

    The search is recursive. ``download_file`` mirrors the bucket layout
    (``<download_path>/<category>/.../<name>.npz``), so a top-level-only
    ``glob('*.npz')`` would miss every downloaded file.

    Args:
        download_path: Root directory containing the downloaded .npz files

    Returns:
        Tuple of (verified_files: list, corrupted_files: list).
        ``verified_files`` holds dicts with:
          * 'filename': path relative to ``download_path``, so that
            ``os.path.join(download_path, filename)`` locates the file
          * 'path': full path of the file
          * 'size_mb', 'keys', 'num_arrays'
        ``corrupted_files`` holds the relative paths that failed to load.
    """
    root = Path(download_path)
    # Sorted for a deterministic report.
    npz_files = sorted(root.rglob('*.npz'))

    print(f"\nVerifying {len(npz_files)} downloaded .npz files...")

    verified_files = []
    corrupted_files = []

    for npz_file in npz_files:
        # Relative paths keep same-named files (e.g. many pc_2048.npz in
        # different model folders) distinguishable.
        rel_name = str(npz_file.relative_to(root))
        try:
            # np.load defaults to allow_pickle=False, so untrusted pickles are refused.
            with np.load(npz_file) as data:
                keys = list(data.keys())
                file_info = {
                    'filename': rel_name,
                    'path': str(npz_file),
                    'size_mb': round(npz_file.stat().st_size / (1024 * 1024), 2),
                    'keys': keys,
                    'num_arrays': len(keys)
                }
                verified_files.append(file_info)
                print(f"✅ {rel_name}: {len(keys)} arrays - {keys[:3]}{'...' if len(keys) > 3 else ''}")
        except Exception as e:
            corrupted_files.append(rel_name)
            print(f"❌ {rel_name}: Error loading - {e}")

    print("\nVerification complete:")
    print(f"✅ Valid files: {len(verified_files)}")
    print(f"❌ Corrupted files: {len(corrupted_files)}")

    return verified_files, corrupted_files

# Check that every .npz under the download root (if present) can be opened.
if not os.path.exists(LOCAL_DOWNLOAD_PATH):
    verified_files, corrupted_files = [], []
    print(f"Download path {LOCAL_DOWNLOAD_PATH} does not exist.")
else:
    verified_files, corrupted_files = verify_npz_files(LOCAL_DOWNLOAD_PATH)

In [None]:
def inspect_sample_npz(file_path, max_arrays=5):
    """
    Print a human-readable summary of one .npz archive.

    Prints the file name and size and the number of stored arrays. For up to
    ``max_arrays`` of the arrays it prints shape, dtype and in-memory size.
    Arrays with fewer than 20 elements also get their first values shown.
    Any error while loading or reading is printed rather than raised.

    Args:
        file_path: Path to the .npz file
        max_arrays: How many arrays to describe before summarising the rest
    """
    bytes_per_mb = 1024 * 1024
    try:
        with np.load(file_path) as archive:
            array_names = list(archive.keys())
            print(f"\nInspecting: {os.path.basename(file_path)}")
            print(f"File size: {os.path.getsize(file_path) / bytes_per_mb:.2f} MB")
            print(f"Number of arrays: {len(array_names)}")
            print("\nArray details:")

            for name in array_names[:max_arrays]:
                values = archive[name]
                print(f"  {name}: shape={values.shape}, dtype={values.dtype}, size={values.nbytes/bytes_per_mb:.2f}MB")
                # Tiny arrays are cheap to show directly.
                if values.size < 20:
                    print(f"    Sample values: {values.flatten()[:10]}")

            hidden = len(array_names) - max_arrays
            if hidden > 0:
                print(f"  ... and {hidden} more arrays")

    except Exception as e:
        print(f"Error inspecting {file_path}: {e}")

# Show the contents of the first file that passed verification, if there is one.
if not verified_files:
    print("No verified files available for inspection.")
else:
    first_verified = verified_files[0]['filename']
    inspect_sample_npz(os.path.join(LOCAL_DOWNLOAD_PATH, first_verified))

In [None]:
# Main entry point: print the configuration, ask for confirmation, then
# download and verify.
# NOTE: the download cell above already calls execute_download() when the
# notebook runs top to bottom, so the files are usually on disk by the time
# this prompt appears. Confirming again only fetches files that are still
# missing, because execute_download() keeps existing files (overwrite=False).
print("ShapeNet NPZ Downloader")
print("=" * 50)

# Configuration summary
print(f"Bucket: {BUCKET_NAME}")
print(f"Project: {PROJECT_ID}")
print(f"Download path: {LOCAL_DOWNLOAD_PATH}")
print(f"Files to download: {len(files_to_download)}")

if files_to_download:
    # Confirm download; accept "y"/"yes" in any case and ignore stray whitespace.
    response = input(f"\nProceed with downloading {len(files_to_download)} files? (y/n): ")
    if response.strip().lower() in ('y', 'yes'):
        successful, failed = execute_download()

        if successful:
            verify_npz_files(LOCAL_DOWNLOAD_PATH)
            # Do not claim full success when some files failed.
            if failed:
                print(f"\n⚠️ Download finished with {len(failed)} failed file(s).")
            else:
                print("\n✅ Download completed successfully!")
            print(f"📁 {len(successful)} files downloaded to {LOCAL_DOWNLOAD_PATH}")
        else:
            print("❌ No files were downloaded successfully.")
    else:
        print("Download cancelled.")
else:
    print("No files selected for download.")