In [1]:
import numpy as np
import pandas as pd
import os
import torch
from collections import Counter
from torch.utils.data import Dataset, DataLoader
# from Bio import SeqIO # Uncomment if you want to inspect GFF3 files (e.g., ensembl_annotation.gff3) later with Biopython

# --- STEP 1: Mount Google Drive ---
# Mounting makes the Drive contents visible under /content/gdrive.
# Colab shows an authorization prompt the first time this runs; pick the
# Google account that owns the project folder and grant access.
from google.colab import drive
drive.mount('/content/gdrive')

# --- Configuration ---
# `google_drive_project_path` is the project folder inside the mounted Drive;
# it is the direct parent of the 'data' folder.
google_drive_project_path = '/content/gdrive/MyDrive/DnARnAProject/'

# `data_dir` is the folder that the rest of the notebook reads
# data.npz, regions.parquet, etc. from.
data_dir = os.path.join(google_drive_project_path, 'data/')

print(f"Attempting to access data in: {data_dir}")


def _explain_missing_data_dir(missing_dir, project_dir):
    """Print troubleshooting hints for a data directory that was not found.

    Args:
        missing_dir (str): The data directory that failed the existence check.
        project_dir (str): Its parent project folder; its contents are listed
            so a misspelled or misplaced folder is easy to spot.
    """
    print(
        f"Error: The directory '{missing_dir}' does not exist.",
        "Please check your Google Drive path and folder names for typos.",
        "Listing contents of your project folder for debugging:",
        sep="\n",
    )
    try:
        project_entries = os.listdir(project_dir)
    except FileNotFoundError:
        # Even the parent project folder is absent.
        print(f"Project folder '{project_dir}' not found either. Check the full path.")
    else:
        print(project_entries)


# --- Verify the data directory exists ---
# Only report the problem; exit() is deliberately not called so the Colab
# runtime keeps going. Later code that relies on `data_dir` will likely fail
# if the path is wrong.
if not os.path.isdir(data_dir):
    _explain_missing_data_dir(data_dir, google_drive_project_path)


# --- 1. Inspect data.npz ---
# Loads 'data.npz' from `data_dir` and prints a summary of the arrays it holds:
# the encoded sequence plus the per-strand expression tracks.
def count_values(values):
    """Count how often each distinct value occurs in an array.

    The counting is vectorised with np.unique, so it stays fast on arrays
    with millions of elements (collections.Counter would iterate over every
    element as a separate Python object).

    Args:
        values: Any array-like accepted by np.asarray.

    Returns:
        dict: Maps each distinct value (as a plain Python scalar, in sorted
        order) to its number of occurrences. Empty input gives an empty dict.
    """
    distinct, occurrences = np.unique(np.asarray(values), return_counts=True)
    return dict(zip(distinct.tolist(), occurrences.tolist()))


print("--- Inspecting data.npz ---")
try:
    data_npz_path = os.path.join(data_dir, 'data.npz')
    # np.load() on an .npz archive returns an NpzFile; each array is read when
    # it is indexed by key. The handle is intentionally left open here because
    # `data_npz` stays available to the rest of the notebook.
    data_npz = np.load(data_npz_path)

    print(f"Keys in data.npz: {list(data_npz.keys())}") # Names of the arrays stored in the archive.

    # 'sequence': print shape, dtype, a sample of values and value counts.
    if 'sequence' in data_npz:
        seq_array = data_npz['sequence']
        print(f"Sequence array shape: {seq_array.shape}")
        print(f"Sequence array dtype: {seq_array.dtype}")
        print(f"First 100 sequence values: {seq_array[:100]}")
        # Distinct codes in a leading sample, to get a feel for the encoding.
        # NOTE(review): presumably A=0, C=1, G=2, T=3, N=4 - confirm against the data source.
        unique_seq_values = np.unique(seq_array[:10000])
        print(f"Unique values in sequence (sample): {unique_seq_values}")
        counts = count_values(seq_array)
        print("Base counts in entire sequence array:")
        print(counts)
        print("Unique values in first 1000 bases:", np.unique(seq_array[:1000]))
        print("Unique values in first 10000 bases:", unique_seq_values)


    # 'expressed_plus': expression track for the plus strand.
    if 'expressed_plus' in data_npz:
        expr_plus_array = data_npz['expressed_plus']
        print(f"Expression_plus array shape: {expr_plus_array.shape}")
        print(f"Expression_plus array dtype: {expr_plus_array.dtype}")
        print(f"First 10 expression_plus values: {expr_plus_array[:10]}")
        unique_expr_values = np.unique(expr_plus_array[:10000])
        print(f"Unique values in expression_plus (sample): {unique_expr_values}")

    # 'expressed_minus': expression track for the minus strand.
    if 'expressed_minus' in data_npz:
        expr_minus_array = data_npz['expressed_minus']
        print(f"Expression_minus array shape: {expr_minus_array.shape}")
        print(f"Expression_minus array dtype: {expr_minus_array.dtype}")
        print(f"First 10 expression_minus values: {expr_minus_array[:10]}")
        unique_expr_values = np.unique(expr_minus_array[:10000])
        print(f"Unique values in expression_minus (sample): {unique_expr_values}")

except FileNotFoundError:
    print(f"Error: {data_npz_path} not found. This typically means the path is incorrect after mounting Drive.")
    print("Double-check the folder names in your Google Drive for typos.")
except Exception as e:
    # Best-effort inspection: report the problem and let the notebook continue.
    print(f"An error occurred while loading data.npz: {e}")

# --- 2. Inspect regions.parquet ---
# Reads 'regions.parquet' from `data_dir` into a DataFrame and prints an
# overview: shape, first rows, column summary and the distinct strand values.
print("\n--- Inspecting regions.parquet ---")
try:
    regions_parquet_path = os.path.join(data_dir, 'regions.parquet')
    regions_df = pd.read_parquet(regions_parquet_path)

    # Row/column counts first, then a peek at the leading rows.
    print(f"Regions DataFrame shape: {regions_df.shape}")
    print("Regions DataFrame head:", regions_df.head(), sep="\n")

    # .info() writes its own report (dtypes, non-null counts) to stdout.
    print("\nRegions DataFrame info:")
    regions_df.info()

    # Distinct values of the 'strand' column.
    print(f"\nUnique strands: {regions_df['strand'].unique()}")

except Exception as load_error:
    # A missing file gets a path-specific hint; anything else is reported as-is.
    if isinstance(load_error, FileNotFoundError):
        print(f"Error: {regions_parquet_path} not found. This typically means the path is incorrect after mounting Drive.")
    else:
        print(f"An error occurred while loading regions.parquet: {load_error}")


# --- 3. Inspect ensembl_annotation.gff3 (Optional) ---
# Prints the leading lines of 'ensembl_annotation.gff3' to give a glimpse of
# its format. The file is not needed by the Dataset class defined in this
# notebook; it is only inspected here.
def read_head_lines(path, max_lines=10):
    """Return up to `max_lines` leading lines of a text file.

    Unlike calling next() a fixed number of times, this stops cleanly when the
    file has fewer lines than requested instead of raising StopIteration.

    Args:
        path (str): Path of the text file to read.
        max_lines (int): Maximum number of lines to return.

    Returns:
        list[str]: The leading lines, each still carrying its line terminator.

    Raises:
        FileNotFoundError: If `path` does not exist.
    """
    with open(path, "r") as handle:
        # zip() stops as soon as either the counter or the file is exhausted.
        return [line for _, line in zip(range(max_lines), handle)]


print("\n--- Inspecting ensembl_annotation.gff3 (Optional) ---")
try:
    gff3_path = os.path.join(data_dir, 'ensembl_annotation.gff3')
    gff_lines = read_head_lines(gff3_path, 10)
    print("First 10 lines of GFF3 file:")
    for line in gff_lines:
        print(line.strip()) # .strip() removes leading/trailing whitespace including newlines.

except FileNotFoundError:
    print(f"Error: {gff3_path} not found. This file is optional for the DataLoader.")
except Exception as e:
    # Best-effort inspection: report the problem and let the notebook continue.
    print(f"An error occurred while inspecting ensembl_annotation.gff3: {e}")


# --- Custom PyTorch DataLoader Class (same as before) ---
# This is your core data loading class for PyTorch, inheriting from torch.utils.data.Dataset.
# It defines how individual data samples (sequence segments and expression labels) are loaded.
class GenomeExpressionDataset(Dataset):
    """
    Map-style Dataset over genomic regions.

    Each sample is a one-hot encoded window of the integer-coded 'sequence'
    array together with the expression value found at the window's start
    offset on the region's strand. Arrays come from 'data.npz' and the region
    table from 'regions.parquet', both located in `data_dir`.

    The arrays are read fully into memory and the archive is closed again, so
    the dataset holds no open file handle and can be pickled (needed when
    DataLoader worker processes are started by spawning).
    """
    def __init__(self, data_dir):
        """
        Loads the sequence/expression arrays and the regions DataFrame.

        Args:
            data_dir (str): Directory containing 'data.npz' and 'regions.parquet'.

        Raises:
            RuntimeError: If either file cannot be loaded or 'data.npz' lacks
                one of the keys 'sequence', 'expressed_plus', 'expressed_minus'.
                The original exception is attached as the cause.
        """
        self.data_dir = data_dir

        # Full paths to the two data files.
        self.data_npz_path = os.path.join(data_dir, 'data.npz')
        self.regions_parquet_path = os.path.join(data_dir, 'regions.parquet')

        # Read the three arrays inside a `with` block so the archive's file
        # handle is released as soon as the data is in memory.
        try:
            with np.load(self.data_npz_path) as npz_file:
                self.sequence_data = npz_file['sequence']                 # integer-coded bases
                self.expression_plus_data = npz_file['expressed_plus']    # labels used for '+' regions
                self.expression_minus_data = npz_file['expressed_minus']  # labels used for all other regions
        except Exception as e:
            raise RuntimeError(f"Could not load data from {self.data_npz_path}. Make sure the file exists and is not corrupted: {e}") from e

        # Kept for backward compatibility with code that indexes `data_npz`
        # by key; it is now a plain dict of the loaded arrays rather than an
        # open NpzFile.
        self.data_npz = {
            'sequence': self.sequence_data,
            'expressed_plus': self.expression_plus_data,
            'expressed_minus': self.expression_minus_data,
        }

        # Region table: one row per sample.
        try:
            self.regions_df = pd.read_parquet(self.regions_parquet_path)
        except Exception as e:
            raise RuntimeError(f"Could not load regions from {self.regions_parquet_path}. Make sure the file exists and is not corrupted: {e}") from e

        # Width of the one-hot encoding, i.e. number of distinct base codes
        # supported (codes 0..4).
        self.num_nucleotides = 5

    def __len__(self):
        """
        Returns the number of samples, i.e. the number of rows in the regions DataFrame.
        """
        return len(self.regions_df)

    def _one_hot_encode(self, sequence_segment):
        """
        Converts integer base codes into a one-hot float tensor.

        Example: [0, 1, 2] -> [[1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0]]

        Args:
            sequence_segment: 1D array-like of integer codes, each in
                [0, num_nucleotides).

        Returns:
            torch.Tensor: float32 tensor of shape (len(sequence_segment), num_nucleotides).
        """
        base_codes = torch.as_tensor(sequence_segment, dtype=torch.long)
        one_hot_tensor = torch.zeros(len(sequence_segment), self.num_nucleotides, dtype=torch.float32)
        # scatter_ needs an index tensor of the same rank as the target, hence
        # the extra trailing dimension.
        one_hot_tensor.scatter_(1, base_codes.unsqueeze(1), 1.0)
        return one_hot_tensor

    def __getitem__(self, idx):
        """
        Retrieves one sample by index.

        Args:
            idx (int or torch.Tensor): Row position of the region in the regions DataFrame.

        Returns:
            tuple: A tuple containing:
                - encoded_sequence (torch.Tensor): One-hot encoding of
                  sequence[offset : offset + window_size].
                - expression_label (torch.Tensor): 0-dim long tensor holding the
                  expression value at `offset`, taken from the plus-strand array
                  when the region's strand is '+', otherwise from the minus-strand array.
                - region_info (dict): The region's row as a dictionary.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        region_info = self.regions_df.iloc[idx]

        # Plain ints keep slicing/indexing independent of the column dtype.
        offset = int(region_info['offset'])
        window_size = int(region_info['window_size'])
        strand = region_info['strand']

        sequence_segment = self.sequence_data[offset : offset + window_size]
        encoded_sequence = self._one_hot_encode(sequence_segment)

        if strand == '+':
            expression_label = self.expression_plus_data[offset]
        else:
            expression_label = self.expression_minus_data[offset]

        expression_label = torch.tensor(int(expression_label), dtype=torch.long)

        return encoded_sequence, expression_label, region_info.to_dict()


# --- Instantiating and trying out the class ---
# Smoke test: build the dataset from `data_dir`, fetch two individual samples,
# then pull two shuffled batches through a DataLoader.
print("\n--- Instantiating and Testing GenomeExpressionDataset ---")
try:
    # 1. Instantiate the dataset from the 'data' folder.
    my_dataset = GenomeExpressionDataset(data_dir=data_dir)
    print("Successfully instantiated GenomeExpressionDataset.")
    print(f"Total number of samples (regions) in the dataset: {len(my_dataset)}")

    # 2. Access single samples by index, like a list.
    print("\n--- Accessing individual samples ---")
    sample_index_0 = 0 # Get the first sample
    sequence_0, label_0, metadata_0 = my_dataset[sample_index_0]
    print(f"Sample at index {sample_index_0}:")
    print(f"  Sequence shape: {sequence_0.shape} (One-hot encoded)")
    print(f"  Label: {label_0.item()}") # .item() gets the scalar value from a 0-dim tensor
    print(f"  Metadata: {metadata_0}")

    # A second sample further into the table. The index is clamped to the last
    # valid one so a dataset with 500 or fewer regions does not abort the demo
    # with an IndexError before the DataLoader part runs.
    sample_index_500 = min(500, len(my_dataset) - 1)
    sequence_500, label_500, metadata_500 = my_dataset[sample_index_500]
    print(f"\nSample at index {sample_index_500}:")
    print(f"  Sequence shape: {sequence_500.shape} (One-hot encoded)")
    print(f"  Label: {label_500.item()}")
    print(f"  Metadata: {metadata_500}")

    # 3. Use it with PyTorch's DataLoader, which iterates the dataset in batches.
    print("\n--- Using DataLoader to get batches ---")
    batch_size = 8 # Define your desired batch size
    # num_workers=0 loads data in the main process, which is the simplest
    # setting for debugging; raise it for parallel loading.
    data_loader = DataLoader(my_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # Iterate through a few batches. The default collate function turns the
    # per-sample metadata dicts into one dict of batched values, hence the
    # metadata[key][0] indexing below.
    for batch_idx, (sequences, labels, metadata) in enumerate(data_loader):
        print(f"\n--- Batch {batch_idx + 1} ---")
        print(f"  Batch of sequences shape: {sequences.shape}") # (batch_size, window_size, num_nucleotides)
        print(f"  Batch of labels shape: {labels.shape}")       # (batch_size,)
        print(f"  Labels in this batch: {labels.numpy()}") # .numpy() to see the values easily
        print(f"  Metadata for first sample in batch:")
        print(f"    Contig: {metadata['contig'][0]}, Strand: {metadata['strand'][0]}, Offset: {metadata['offset'][0]}")

        if batch_idx >= 1: # Print only the first 2 batches for brevity
            break

except RuntimeError as e:
    print(f"\nError during dataset instantiation or usage: {e}")
    print("Please ensure your Google Drive path is correct and data files are accessible.")
except Exception as e:
    # Best-effort demo: report anything else and let the notebook continue.
    print(f"\nAn unexpected error occurred: {e}")

Mounted at /content/gdrive
Attempting to access data in: /content/gdrive/MyDrive/DnARnAProject/data/
--- Inspecting data.npz ---
Keys in data.npz: ['sequence', 'expressed_plus', 'expressed_minus']
Sequence array shape: (12157105,)
Sequence array dtype: uint8
First 100 sequence values: [1 1 0 1 0 1 1 0 1 0 1 1 1 0 1 0 1 0 1 1 1 0 1 0 1 0 1 1 0 1 0 1 1 0 1 0 1
 0 1 1 0 1 0 1 1 0 1 0 1 1 1 0 1 0 1 0 1 0 1 0 1 0 3 1 1 3 0 0 1 0 1 3 0 1
 1 1 3 0 0 1 0 1 0 2 1 1 1 3 0 0 3 1 3 0 0 1 1 1 3 2]
Unique values in sequence (sample): [0 1 2 3]
Base counts in entire sequence array:
{np.uint8(1): 2320576, np.uint8(0): 3766349, np.uint8(3): 3753080, np.uint8(2): 2317100}
Unique values in first 1000 bases: [0 1 2 3]
Unique values in first 10000 bases: [0 1 2 3]
Expression_plus array shape: (12157105,)
Expression_plus array dtype: uint8
First 10 expression_plus values: [0 0 0 0 0 0 0 0 0 0]
Unique values in expression_plus (sample): [0 1]
Expression_minus array shape: (12157105,)
Expression_minus array d