<a href="https://colab.research.google.com/github/cortalo/DiffCkt/blob/master/DiffCkt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import re
import zipfile
import torch
import warnings
import numpy as np
from torch.utils.data import Dataset, DataLoader

class CircuitDataset(Dataset):
    """Map-style PyTorch Dataset over pre-built circuit graph samples.

    Every item is a dict with the keys:
        - 'X': node features
        - 'E': edge features
        - 'Y': circuit performance
        - 'node_mask': node masks
    """

    def __init__(self, data_list):
        """
        Args:
            data_list: List[dict]; each dict carries NumPy arrays:
                - X: [22, 21] node features (float32)
                - E: [22, 22, 25] edge features (float32)
                - Y: [13] circuit performance (float32)
                - node_mask: [22] node masks (bool)
        """
        self.data = data_list

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors.

        Each array is converted with ``torch.from_numpy`` and cast:
        X, E and Y to float32, node_mask to bool.
        """
        record = self.data[idx]
        field_dtypes = (
            ('X', torch.float32),
            ('E', torch.float32),
            ('Y', torch.float32),
            ('node_mask', torch.bool),
        )
        return {
            name: torch.from_numpy(record[name]).to(dtype)
            for name, dtype in field_dtypes
        }


def unzip_with_progress(zip_path, target_dir, *, verbose=False):
    """Extract a ZIP archive into ``target_dir``.

    Failures are reported on stdout and swallowed so that a batch caller
    can carry on with the remaining archives.

    Args:
        zip_path: Path to source ZIP file
        target_dir: Destination directory for extracted files
        verbose: If True, print the name of each archive member as it is
            extracted. Defaults to False, which extracts silently.

    Returns:
        bool: True if extraction succeeded, False otherwise
    """
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            if verbose:
                for member in zip_ref.infolist():
                    print(f"  extracting {member.filename}")
                    zip_ref.extract(member, target_dir)
            else:
                zip_ref.extractall(target_dir)
    except Exception as e:
        # Deliberate best-effort boundary: corrupt, encrypted or unreadable
        # archives are reported and the caller's loop continues.
        print(f"Failed to extract {os.path.basename(zip_path)}: {e}")
        return False
    return True


def unzip_all_data_chunks(source_dir=None, target_dir=None):
    """Extract every ``data_chunks_<N>.zip`` archive found in ``source_dir``.

    Archives are processed in ascending numeric order of ``<N>`` so runs
    are reproducible (``os.listdir`` returns entries in arbitrary order).

    Args:
        source_dir: Directory searched for archives. Defaults to the
            current working directory.
        target_dir: Extraction destination. Defaults to
            ``<source_dir>/unzipped_data``; created if missing.
    """
    print("Starting data chunk extraction of data_chunks_*.zip files...")
    pattern = re.compile(r'^data_chunks_(\d+)\.zip$')
    if source_dir is None:
        source_dir = os.getcwd()
    if target_dir is None:
        target_dir = os.path.join(source_dir, "unzipped_data")

    os.makedirs(target_dir, exist_ok=True)
    print(f"Extraction directory: {target_dir}")

    # Pair each archive with its chunk number so sorting is numeric
    # (data_chunks_2 before data_chunks_10), not lexicographic.
    numbered = []
    for name in os.listdir(source_dir):
        match = pattern.match(name)
        if match:
            numbered.append((int(match.group(1)), name))
    zip_files = [name for _, name in sorted(numbered)]

    if not zip_files:
        print("No data_chunks_*.zip files found")
        return

    print(f"Found {len(zip_files)} files to extract")

    success_count = 0
    for i, zip_file in enumerate(zip_files, 1):
        print(f"Processing file {i}/{len(zip_files)}: {zip_file}")
        zip_path = os.path.join(source_dir, zip_file)
        if unzip_with_progress(zip_path, target_dir):
            success_count += 1

    print(f"\nExtraction complete: {success_count}/{len(zip_files)} files processed\n")

def load_data(dir_path='unzipped_data'):
    """Load and concatenate all circuit data chunks from ``dir_path``.

    Every regular file whose name contains a run of digits is treated as a
    chunk. Chunks are loaded in ascending order of the first such number,
    so ``data_chunk_2.pt`` precedes ``data_chunk_10.pt``. Entries with no
    digits and sub-directories are skipped rather than crashing the sort or
    being handed to ``torch.load``.

    Each chunk is passed to ``torch.load`` and its contents are appended to
    one list. NOTE(review): a chunk is presumably a list of sample dicts in
    the format CircuitDataset expects; confirm against the chunk writer.

    Args:
        dir_path: The directory path of unzipped data, default to 'unzipped_data'
    Returns:
        CircuitDataset: Initialized dataset object
    """
    chunk_number = re.compile(r'\d+')
    numbered = []
    for name in os.listdir(dir_path):
        match = chunk_number.search(name)
        if match and os.path.isfile(os.path.join(dir_path, name)):
            numbered.append((int(match.group()), name))
    # Tie-break on the file name so the order never depends on os.listdir.
    files = [name for _, name in sorted(numbered)]

    print(f"Loading {len(files)} data chunks...")

    data_list = []
    # Suppress FutureWarnings only while loading, instead of changing the
    # process-wide warning filters for everything that runs afterwards.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=FutureWarning)
        for i, file in enumerate(files, 1):
            if i % 100 == 1:
                print(f"Loading file {i}/{len(files)}: {file}")
            file_path = os.path.join(dir_path, file)
            # SECURITY: weights_only=False unpickles arbitrary objects and can
            # execute code; only load chunk files from a trusted source.
            data_chunk = torch.load(file_path, weights_only=False)
            data_list.extend(data_chunk)

    circuit_dataset = CircuitDataset(data_list)
    print("Circuit dataset is created and returned")
    return circuit_dataset


if __name__ == "__main__":
    # One-time setup: on the first run, call unzip_all_data_chunks() here to
    # extract the data_chunks_*.zip archives; it is not needed afterwards.
    print("finished")

finished


In [None]:
# Load every extracted chunk from ./unzipped_data into one CircuitDataset.
data = load_data(dir_path='unzipped_data')

Loading 1143 data chunks...
Loading file 1/1143: data_chunk_0.pt
Loading file 101/1143: data_chunk_100.pt
Loading file 201/1143: data_chunk_200.pt
Loading file 301/1143: data_chunk_300.pt
Loading file 401/1143: data_chunk_400.pt
Loading file 501/1143: data_chunk_500.pt
Loading file 601/1143: data_chunk_600.pt
Loading file 701/1143: data_chunk_700.pt
Loading file 801/1143: data_chunk_800.pt
Loading file 901/1143: data_chunk_900.pt
Loading file 1001/1143: data_chunk_1000.pt
Loading file 1101/1143: data_chunk_1100.pt
Circuit dataset is created and returned


In [None]:
# Inspect the first sample. A notebook only echoes a cell's last expression,
# so a bare `data.__getitem__(0)` followed by another statement is discarded.
sample = data[0]
for key, value in sample.items():
    print(f"{key}: shape={tuple(value.shape)}, dtype={value.dtype}")
print("hello")