In [1]:
import numpy as np
import pandas as pd
import os
import torch
from collections import Counter
from torch.utils.data import Dataset, DataLoader
# from Bio import SeqIO # Uncomment if you want to inspect GFF3 files (e.g., ensembl_annotation.gff3) later with Biopython

# --- STEP 1: Mount Google Drive ---
from google.colab import drive
drive.mount('/content/gdrive')

# --- Configuration ---
google_drive_project_path = '/content/gdrive/MyDrive/DnARnAProject/'
data_dir = os.path.join(google_drive_project_path, 'data/')

print(f"Attempting to access data in: {data_dir}")

# --- Verify the data directory exists ---
if not os.path.isdir(data_dir):
    print(f"Error: The directory '{data_dir}' does not exist.")
    print("Please check your Google Drive path and folder names for typos.")
    print("Listing contents of your project folder for debugging:")
    try:
        print(os.listdir(google_drive_project_path))
    except FileNotFoundError:
        print(f"Project folder '{google_drive_project_path}' not found either. Check the full path.")


# --- 1. Inspect data.npz ---
print("--- Inspecting data.npz ---")
try:
    data_npz_path = os.path.join(data_dir, 'data.npz')
    data_npz = np.load(data_npz_path)

    print(f"Keys in data.npz: {list(data_npz.keys())}")

    if 'sequence' in data_npz:
        seq_array = data_npz['sequence']
        print(f"Sequence array shape: {seq_array.shape}")
        print(f"Sequence array dtype: {seq_array.dtype}")
        print(f"First 100 sequence values: {seq_array[:100]}")
        unique_seq_values = np.unique(seq_array[:10000])
        print(f"Unique values in sequence (sample): {unique_seq_values}")
        counts = dict(Counter(seq_array))
        print("Base counts in entire sequence array:")
        print(counts)
        print("Unique values in first 1000 bases:", np.unique(seq_array[:1000]))
        print("Unique values in first 10000 bases:", np.unique(seq_array[:10000]))


    if 'expressed_plus' in data_npz:
        expr_plus_array = data_npz['expressed_plus']
        print(f"Expression_plus array shape: {expr_plus_array.shape}")
        print(f"Expression_plus array dtype: {expr_plus_array.dtype}")
        print(f"First 10 expression_plus values: {expr_plus_array[:10]}")
        unique_expr_values = np.unique(expr_plus_array[:10000])
        print(f"Unique values in expression_plus (sample): {unique_expr_values}")

    if 'expressed_minus' in data_npz:
        expr_minus_array = data_npz['expressed_minus']
        print(f"Expression_minus array shape: {expr_minus_array.shape}")
        print(f"Expression_minus array dtype: {expr_minus_array.dtype}")
        print(f"First 10 expression_minus values: {expr_minus_array[:10]}")
        unique_expr_values = np.unique(expr_minus_array[:10000])
        print(f"Unique values in expression_minus (sample): {unique_expr_values}")

except FileNotFoundError:
    print(f"Error: {data_npz_path} not found. This typically means the path is incorrect after mounting Drive.")
    print("Double-check the folder names in your Google Drive for typos.")
except Exception as e:
    print(f"An error occurred while loading data.npz: {e}")

# --- 2. Inspect regions.parquet ---
print("\n--- Inspecting regions.parquet ---")
try:
    regions_parquet_path = os.path.join(data_dir, 'regions.parquet')
    regions_df = pd.read_parquet(regions_parquet_path)

    print(f"Regions DataFrame shape: {regions_df.shape}")
    print("Regions DataFrame head:")
    print(regions_df.head())
    print("\nRegions DataFrame info:")
    regions_df.info()

    print(f"\nUnique strands: {regions_df['strand'].unique()}")

except FileNotFoundError:
    print(f"Error: {regions_parquet_path} not found. This typically means the path is incorrect after mounting Drive.")
except Exception as e:
    print(f"An error occurred while loading regions.parquet: {e}")


# --- 3. Inspect ensembl_annotation.gff3 (Optional) ---
print("\n--- Inspecting ensembl_annotation.gff3 (Optional) ---")
try:
    gff3_path = os.path.join(data_dir, 'ensembl_annotation.gff3')
    with open(gff3_path, "r") as handle:
        gff_lines = [next(handle) for _ in range(10)]
        print("First 10 lines of GFF3 file:")
        for line in gff_lines:
            print(line.strip())

except FileNotFoundError:
    print(f"Error: {gff3_path} not found. This file is optional for the DataLoader.")
except Exception as e:
    print(f"An error occurred while inspecting ensembl_annotation.gff3: {e}")


# --- Custom PyTorch DataLoader Class (modified) ---
class GenomeExpressionDataset(Dataset):
    """
    Custom Dataset for loading DNA sequence and expression data for genomic regions.
    It loads data from pre-processed .npz and .parquet files.
    Handles reverse complement for '-' strand sequences and reverses expression labels.
    """
    def __init__(self, data_dir):
        """
        Initializes the dataset by loading the full sequence and expression arrays
        and the DataFrame of genomic regions.

        Args:
            data_dir (str): The path to the directory containing 'data.npz' and 'regions.parquet'.
        """
        self.data_dir = data_dir

        self.data_npz_path = os.path.join(data_dir, 'data.npz')
        self.regions_parquet_path = os.path.join(data_dir, 'regions.parquet')

        try:
            self.data_npz = np.load(self.data_npz_path)
            self.sequence_data = self.data_npz['sequence']
            self.expression_plus_data = self.data_npz['expressed_plus']
            self.expression_minus_data = self.data_npz['expressed_minus']
        except Exception as e:
            raise RuntimeError(f"Could not load data from {self.data_npz_path}. Make sure the file exists and is not corrupted: {e}")

        try:
            self.regions_df = pd.read_parquet(self.regions_parquet_path)
        except Exception as e:
            raise RuntimeError(f"Could not load regions from {self.regions_parquet_path}. Make sure the file exists and is not corrupted: {e}")

        self.num_nucleotides = 5 # A, C, G, T, N (mapped to 0, 1, 2, 3, 4)

        # Define the complement mapping for integer-encoded bases
        # Assuming A=0, C=1, G=2, T=3, N=4
        # Complement: A<->T, C<->G, N<->N
        # 0<->3, 1<->2, 4<->4
        self.complement_map = np.array([3, 2, 1, 0, 4], dtype=np.uint8)


    def __len__(self):
        return len(self.regions_df)

    def _one_hot_encode(self, sequence_segment):
        """
        Converts a sequence segment (array of integer encodings) into a one-hot encoded tensor.
        """
        one_hot_tensor = torch.zeros(len(sequence_segment), self.num_nucleotides, dtype=torch.float32)
        one_hot_tensor.scatter_(1, torch.tensor(sequence_segment).unsqueeze(1).long(), 1)
        return one_hot_tensor

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        region_info = self.regions_df.iloc[idx]

        offset = region_info['offset']
        window_size = region_info['window_size']
        strand = region_info['strand']

        # Extract sequence segment
        sequence_segment = self.sequence_data[offset : offset + window_size].copy() # .copy() to avoid modifying original array

        # Extract expression segment (for the whole window, even if we only use the center label)
        # We extract the window for expression data to allow for potential future
        # modeling of expression across the window, or just for consistency in handling.
        # For now, we only care about the label at `offset`.
        expression_segment_plus = self.expression_plus_data[offset : offset + window_size].copy()
        expression_segment_minus = self.expression_minus_data[offset : offset + window_size].copy()


        if strand == '+':
            # For the forward strand, sequence and expression are used as is.
            encoded_sequence = self._one_hot_encode(sequence_segment)
            expression_label = expression_segment_plus[0] # Use label at the start of the window
        else: # strand == '-'
            # 1. Reverse complement the sequence
            # First, complement the bases, then reverse the order
            reverse_complemented_sequence = self.complement_map[sequence_segment][::-1].copy()
            encoded_sequence = self._one_hot_encode(reverse_complemented_sequence)

            # 2. Reverse the expression segment and take the appropriate label
            # If the sequence is reversed, the corresponding expression values
            # should also be considered in reverse order to maintain alignment.
            # The label for the '-' strand region is associated with the 'offset'
            # when read from the original chromosome (forward strand coordinates).
            # When we flip the sequence, the *first* base of the reverse complement
            # corresponds to the *last* base of the original segment.
            # So, if expression_segment_minus was [val_at_offset, ..., val_at_offset+window_size-1],
            # and we reverse it, the label would be the last element of the original segment.
            # However, the problem statement often means that the label for the region
            # (which is defined by its start 'offset') should be taken from the
            # 'expressed_minus' array at that *same* offset.
            # Let's stick to using the `offset` index for the label as it's defined
            # in your `regions.parquet` and the original implementation.
            expression_label = expression_segment_minus[0] # Use label at the start of the window from the minus strand data


        expression_label = torch.tensor(expression_label, dtype=torch.long)

        return encoded_sequence, expression_label, region_info.to_dict()


# --- Instantiating and trying out the class ---
print("\n--- Instantiating and Testing GenomeExpressionDataset ---")
try:
    my_dataset = GenomeExpressionDataset(data_dir=data_dir)
    print(f"Successfully instantiated GenomeExpressionDataset.")
    print(f"Total number of samples (regions) in the dataset: {len(my_dataset)}")

    print("\n--- Accessing individual samples ---")
    # Test a sample from the '+' strand
    sample_plus_idx = regions_df[regions_df['strand'] == '+'].index[0]
    sequence_plus, label_plus, metadata_plus = my_dataset[sample_plus_idx]
    print(f"Sample at index {sample_plus_idx} (strand +):")
    print(f"  Sequence shape: {sequence_plus.shape} (One-hot encoded)")
    print(f"  Label: {label_plus.item()}")
    print(f"  Metadata: {metadata_plus}")

    # Test a sample from the '-' strand
    sample_minus_idx = regions_df[regions_df['strand'] == '-'].index[0]
    sequence_minus, label_minus, metadata_minus = my_dataset[sample_minus_idx]
    print(f"\nSample at index {sample_minus_idx} (strand -):")
    print(f"  Sequence shape: {sequence_minus.shape} (One-hot encoded)")
    print(f"  Label: {label_minus.item()}")
    print(f"  Metadata: {metadata_minus}")

    # You can add checks here to confirm reverse complement if you want,
    # e.g., by converting sequence_plus back to integer array and comparing
    # its reverse complement to the integer array of sequence_minus (if they were designed to be complements).

    # 3. Use it with PyTorch's DataLoader
    print("\n--- Using DataLoader to get batches ---")
    batch_size = 8
    data_loader = DataLoader(my_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    for batch_idx, (sequences, labels, metadata) in enumerate(data_loader):
        print(f"\n--- Batch {batch_idx + 1} ---")
        print(f"  Batch of sequences shape: {sequences.shape}")
        print(f"  Batch of labels shape: {labels.shape}")
        print(f"  Labels in this batch: {labels.numpy()}")
        print(f"  Metadata for first sample in batch:")
        print(f"    Contig: {metadata['contig'][0]}, Strand: {metadata['strand'][0]}, Offset: {metadata['offset'][0]}")

        if batch_idx >= 1:
            break

except RuntimeError as e:
    print(f"\nError during dataset instantiation or usage: {e}")
    print("Please ensure your Google Drive path is correct and data files are accessible.")
except Exception as e:
    print(f"\nAn unexpected error occurred: {e}")

ModuleNotFoundError: No module named 'google.colab'