<a href="https://colab.research.google.com/github/marcvonrohr/DeepLearning/blob/main/meta_learning_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import json
import os
import random
import shutil
import subprocess
import time

from google.colab import drive

#################################################################
#  STEP 2.1: PREPARE LOCAL VM
#################################################################

# --- 1. Mount Google Drive ---
print("Connecting Google Drive...")
drive.mount('/content/drive')
print("...Google Drive connected.")

# --- 2. Define Key Paths ---
GDRIVE_ROOT = '/content/drive/MyDrive/'
PROJECT_DIR = os.path.join(GDRIVE_ROOT, 'Deep Learning')
DATASETS_ROOT_DIR = os.path.join(PROJECT_DIR, 'datasets')
INAT_ROOT_DIR = os.path.join(DATASETS_ROOT_DIR, 'inaturalist')

# Source: The COMPRESSED archives
ARCHIVES_DIR_ON_DRIVE = os.path.join(INAT_ROOT_DIR, 'archives')

# Target: The LOCAL VM fast disk
LOCAL_DATA_ROOT = '/content/data'
# This is the final path your PyTorch code will use:
FINAL_DATA_PATH = os.path.join(LOCAL_DATA_ROOT, 'inaturalist_unpacked')

# Define source/destination paths
TAR_FILES = {
    "2021_train_mini": {
        "src": os.path.join(ARCHIVES_DIR_ON_DRIVE, '2021_train_mini.tar.gz'),
        "dest_tar": os.path.join(LOCAL_DATA_ROOT, '2021_train_mini.tar.gz'),
        "check_unpacked": os.path.join(FINAL_DATA_PATH, '2021_train_mini')
    },
    "2021_valid": {
        "src": os.path.join(ARCHIVES_DIR_ON_DRIVE, '2021_valid.tar.gz'),
        "dest_tar": os.path.join(LOCAL_DATA_ROOT, '2021_valid.tar.gz'),
        "check_unpacked": os.path.join(FINAL_DATA_PATH, '2021_valid')
    }
}

# --- 3. Create Local Directories on VM ---
os.makedirs(LOCAL_DATA_ROOT, exist_ok=True)
os.makedirs(FINAL_DATA_PATH, exist_ok=True)
print(f"Local data directory created at: {FINAL_DATA_PATH}")

# --- 4. Copy, Unpack, and Clean up for each file ---
for name, paths in TAR_FILES.items():
    print(f"\n--- Processing {name} ---")

    if os.path.exists(paths["check_unpacked"]):
        print(f"'{name}' is already unpacked in local VM. Skipping.")
        continue

    # 4a. Copy .tar.gz from Drive to local VM
    print(f"Copying '{name}.tar.gz' from Drive to local VM...")
    start_time = time.time()
    !cp "{paths['src']}" "{paths['dest_tar']}"
    print(f"...Copy complete. Took {time.time() - start_time:.2f} seconds.")

    # 4b. Unpack the file on the local VM
    print(f"Unpacking '{name}.tar.gz' locally...")
    start_time = time.time()
    !tar -xzf "{paths['dest_tar']}" -C "{FINAL_DATA_PATH}"
    print(f"...Unpacking complete. Took {time.time() - start_time:.2f} seconds.")

    # 4c. Delete the local .tar.gz file to save VM space
    print(f"Deleting local tarball '{paths['dest_tar']}'...")
    !rm "{paths['dest_tar']}"
    print("...Local tarball deleted.")

# --- 5. Verify and Set Path for Training ---
print("\n--- Final Data Setup Verification ---")
print(f"Dataset is ready for training at: {FINAL_DATA_PATH}")
!ls -lh "{FINAL_DATA_PATH}"
print("\nLocal VM Disk Space Usage:")
!df -h

Mounted at /content/drive


In [None]:
#################################################################
#  STEP 2.2: SCIENTIFIC DATA PARTITIONING
#################################################################
print("\n--- STEP 2.2: Loading/Creating Scientific Class Partition ---")

# --- 6. Define Paths for Partition File ---
# Helper files are kept in a 'project_meta' folder inside the project
# directory on Google Drive. Both paths are derived first, then the folder
# is created if it does not exist yet.
META_DIR_ON_DRIVE = os.path.join(PROJECT_DIR, 'project_meta')
PARTITION_FILE_PATH = os.path.join(META_DIR_ON_DRIVE, 'inat_class_split.json')

os.makedirs(META_DIR_ON_DRIVE, exist_ok=True)
print(f"Looking for partition file at: {PARTITION_FILE_PATH}")

In [None]:
# --- 7. Load or Create the Partition ---
# Loads the class partition from PARTITION_FILE_PATH if it exists; otherwise
# builds it from the dataset metadata, shuffles the class IDs with a fixed
# seed, splits them 60/20/20 into base/val/novel, and saves the result.
class_split = {}
RANDOM_SEED = 42 # Guarantees the shuffle is always the same
_REQUIRED_SPLITS = ('c_base', 'c_val', 'c_novel')

if os.path.exists(PARTITION_FILE_PATH):
    # Load the existing file from GDrive
    print("Found existing partition file. Loading...")
    with open(PARTITION_FILE_PATH, 'r') as f:
        class_split = json.load(f)
    # Ensure the values are plain lists (just in case)
    class_split = {k: list(v) for k, v in class_split.items()}

    # Fail here with a clear message rather than with a KeyError further on.
    missing_splits = [k for k in _REQUIRED_SPLITS if k not in class_split]
    if missing_splits:
        raise ValueError(
            f"Partition file {PARTITION_FILE_PATH} is missing splits: {missing_splits}"
        )

else:
    # Create the partition for the first time
    print("No partition file found. Creating new partition...")

    # 7a. Load the dataset's metadata file from the local VM
    # This file contains the list of all 10,000 categories
    metadata_file = os.path.join(FINAL_DATA_PATH, '2021_train_mini', '2021_train_mini.json')

    if not os.path.exists(metadata_file):
        print(f"ERROR: Metadata file not found at {metadata_file}")
        # Stop execution if something went wrong in Step 2.1
        raise FileNotFoundError(f"Metadata file not found: {metadata_file}")

    with open(metadata_file, 'r') as f:
        metadata = json.load(f)

    # 7b. Get all category IDs (0-9999)
    # NOTE(review): this assumes the 'categories' list is ordered by ID and
    # that IDs are exactly 0..num_classes-1 -- confirm against the metadata.
    num_classes = len(metadata['categories'])
    if num_classes == 0:
        # Saving an empty partition would be reloaded on every later run.
        raise ValueError(f"No categories found in metadata file: {metadata_file}")
    if num_classes != 10000:
        print(f"WARNING: Expected 10,000 classes, but found {num_classes}.")
        print("Splitting proportionally (60% base / 20% val / 20% novel).")

    all_class_ids = list(range(num_classes))

    # 7c. Shuffle the class list reproducibly
    # A dedicated Random instance yields the same order as seeding the global
    # generator with RANDOM_SEED, but leaves the global generator untouched,
    # so its state no longer depends on whether the partition file existed.
    print(f"Shuffling {num_classes} class IDs with random seed {RANDOM_SEED}...")
    random.Random(RANDOM_SEED).shuffle(all_class_ids)

    # 7d. Split into C_base, C_val, C_novel
    # Cut points are 60% and 80% of the class count: 6000 and 8000 for the
    # expected 10,000 classes, and still non-degenerate for other counts.
    base_end = num_classes * 6 // 10
    val_end = num_classes * 8 // 10
    c_base_ids = all_class_ids[:base_end]
    c_val_ids = all_class_ids[base_end:val_end]
    c_novel_ids = all_class_ids[val_end:]

    class_split = {
        'c_base': sorted(c_base_ids), # Sort for easier inspection
        'c_val': sorted(c_val_ids),
        'c_novel': sorted(c_novel_ids)
    }

    # 7e. Save the new partition file to Google Drive
    print(f"Saving new partition file to: {PARTITION_FILE_PATH}")
    with open(PARTITION_FILE_PATH, 'w') as f:
        json.dump(class_split, f, indent=4)

In [None]:
# --- 8. Verification ---
# Print the size of each split, then the size of every pairwise
# intersection between splits.
print("\n--- Partitioning Complete ---")
for _label, _split in (
    ("Total C_base classes:  ", 'c_base'),
    ("Total C_val classes:   ", 'c_val'),
    ("Total C_novel classes: ", 'c_novel'),
):
    print(f"{_label}{len(class_split[_split])}")

# Check for overlaps (should be 0)
_base_set = set(class_split['c_base'])
_val_set = set(class_split['c_val'])
_novel_set = set(class_split['c_novel'])

base_val_overlap = _base_set.intersection(_val_set)
base_novel_overlap = _base_set.intersection(_novel_set)
val_novel_overlap = _val_set.intersection(_novel_set)

for _label, _overlap in (
    ("Overlap (Base-Val):    ", base_val_overlap),
    ("Overlap (Base-Novel):  ", base_novel_overlap),
    ("Overlap (Val-Novel):   ", val_novel_overlap),
):
    print(f"{_label}{len(_overlap)}")

print("\nReady to proceed with Data Loader.")