In [5]:
import numpy as np
import pandas as pd
import os
import torch
from torch.utils.data import Dataset, DataLoader
# from Bio import SeqIO # Uncomment if you want to inspect GFF3

# --- STEP 1: Mount Google Drive ---
# This will open a pop-up. You'll need to authorize Google Drive access.
from google.colab import drive
drive.mount('/content/gdrive')

# --- Configuration ---
# Adjust this path based on your Google Drive structure
# After mounting, your Google Drive root is typically /content/gdrive/MyDrive/
# Then, navigate to your project folder and the 'data' subfolder.

# IMPORTANT: Replace 'Joint Modelling of DnA and RnA PROJECT' with the exact name
# of your project folder in Google Drive if it's different.
google_drive_project_path = '/content/gdrive/MyDrive/Joint Modelling of DnA and RnA PROJECT/'
data_dir = os.path.join(google_drive_project_path, 'data/')

print(f"Attempting to access data in: {data_dir}")

# --- Verify the data directory exists ---
if not os.path.isdir(data_dir):
    print(f"Error: The directory '{data_dir}' does not exist.")
    print("Please check your Google Drive path and folder names.")
    print("Listing contents of your project folder for debugging:")
    try:
        # List contents of the project folder to help debug
        print(os.listdir(google_drive_project_path))
    except FileNotFoundError:
        print(f"Project folder '{google_drive_project_path}' not found either. Check the full path.")
    exit() # Exit the script if the data directory isn't found


# --- 1. Inspect data.npz ---
print("--- Inspecting data.npz ---")
try:
    data_npz_path = os.path.join(data_dir, 'data.npz')
    data_npz = np.load(data_npz_path)

    print(f"Keys in data.npz: {list(data_npz.keys())}")

    if 'sequence' in data_npz:
        seq_array = data_npz['sequence']
        print(f"Sequence array shape: {seq_array.shape}")
        print(f"Sequence array dtype: {seq_array.dtype}")
        print(f"First 10 sequence values: {seq_array[:10]}")
        unique_seq_values = np.unique(seq_array[:10000])
        print(f"Unique values in sequence (sample): {unique_seq_values}")

    if 'expression_plus' in data_npz:
        expr_plus_array = data_npz['expression_plus']
        print(f"Expression_plus array shape: {expr_plus_array.shape}")
        print(f"Expression_plus array dtype: {expr_plus_array.dtype}")
        print(f"First 10 expression_plus values: {expr_plus_array[:10]}")
        unique_expr_values = np.unique(expr_plus_array[:10000])
        print(f"Unique values in expression_plus (sample): {unique_expr_values}")

    if 'expression_minus' in data_npz:
        expr_minus_array = data_npz['expression_minus']
        print(f"Expression_minus array shape: {expr_minus_array.shape}")
        print(f"Expression_minus array dtype: {expr_minus_array.dtype}")
        print(f"First 10 expression_minus values: {expr_minus_array[:10]}")
        unique_expr_values = np.unique(expr_minus_array[:10000])
        print(f"Unique values in expression_minus (sample): {unique_expr_values}")

except FileNotFoundError:
    print(f"Error: {data_npz_path} not found. This typically means the path is incorrect after mounting Drive.")
    print("Double-check the folder names in your Google Drive for typos.")
except Exception as e:
    print(f"An error occurred while loading data.npz: {e}")

# --- 2. Inspect regions.parquet ---
print("\n--- Inspecting regions.parquet ---")
try:
    regions_parquet_path = os.path.join(data_dir, 'regions.parquet')
    regions_df = pd.read_parquet(regions_parquet_path)

    print(f"Regions DataFrame shape: {regions_df.shape}")
    print("Regions DataFrame head:")
    print(regions_df.head())
    print("\nRegions DataFrame info:")
    regions_df.info()

    print(f"\nUnique strands: {regions_df['strand'].unique()}")

except FileNotFoundError:
    print(f"Error: {regions_parquet_path} not found. This typically means the path is incorrect after mounting Drive.")
except Exception as e:
    print(f"An error occurred while loading regions.parquet: {e}")


# --- 3. Inspect ensembl_annotation.gff3 (Optional) ---
print("\n--- Inspecting ensembl_annotation.gff3 (Optional) ---")
try:
    gff3_path = os.path.join(data_dir, 'ensembl_annotation.gff3')
    with open(gff3_path, "r") as handle:
        gff_lines = [next(handle) for _ in range(10)]
        print("First 10 lines of GFF3 file:")
        for line in gff_lines:
            print(line.strip())

except FileNotFoundError:
    print(f"Error: {gff3_path} not found. This file is optional for the DataLoader.")
except Exception as e:
    print(f"An error occurred while inspecting ensembl_annotation.gff3: {e}")


# --- Custom PyTorch DataLoader Class (same as before) ---
class GenomeExpressionDataset(Dataset):
    """
    Custom Dataset for loading DNA sequence and expression data.
    """
    def __init__(self, data_dir):
        self.data_dir = data_dir

        self.data_npz_path = os.path.join(data_dir, 'data.npz')
        try:
            self.data_npz = np.load(self.data_npz_path)
            self.sequence_data = self.data_npz['sequence']
            self.expression_plus_data = self.data_npz['expression_plus']
            self.expression_minus_data = self.data_npz['expression_minus']
        except Exception as e:
            raise RuntimeError(f"Could not load data from {self.data_npz_path}: {e}")

        self.regions_parquet_path = os.path.join(data_dir, 'regions.parquet')
        try:
            self.regions_df = pd.read_parquet(self.regions_parquet_path)
        except Exception as e:
            raise RuntimeError(f"Could not load regions from {self.regions_parquet_path}: {e}")

        self.num_nucleotides = 5

    def __len__(self):
        return len(self.regions_df)

    def _one_hot_encode(self, sequence_segment):
        one_hot_tensor = torch.zeros(len(sequence_segment), self.num_nucleotides, dtype=torch.float32)
        one_hot_tensor.scatter_(1, torch.tensor(sequence_segment).unsqueeze(1).long(), 1)
        return one_hot_tensor

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        region_info = self.regions_df.iloc[idx]

        offset = region_info['offset']
        window_size = region_info['window_size']
        strand = region_info['strand']

        sequence_segment = self.sequence_data[offset : offset + window_size]
        encoded_sequence = self._one_hot_encode(sequence_segment)

        if strand == '+':
            expression_label = self.expression_plus_data[offset]
        else:
            expression_label = self.expression_minus_data[offset]

        expression_label = torch.tensor(expression_label, dtype=torch.long)

        return encoded_sequence, expression_label, region_info.to_dict()


# --- Example Usage of the DataLoader ---
print("\n--- Example Usage of the DataLoader ---")
try:
    # Use the data_dir defined after mounting Drive
    dataset = GenomeExpressionDataset(data_dir=data_dir)
    print(f"Dataset contains {len(dataset)} regions.")

    batch_size = 4
    # Using num_workers=0 in Colab is often safer initially to debug
    # multiprocessing issues, then set to >0 for performance.
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    for i, (sequences, labels, metadata) in enumerate(data_loader):
        print(f"\nBatch {i+1}:")
        print(f"Sequences batch shape: {sequences.shape} (Batch, SeqLength, Channels)")
        print(f"Labels batch shape: {labels.shape} (Batch)")
        print(f"Labels (first batch): {labels}")
        print(f"Metadata for first sample in batch: {metadata['contig'][0]}, strand: {metadata['strand'][0]}, offset: {metadata['offset'][0]}")

        if i == 0:
            break

except RuntimeError as e:
    print(f"Error during DataLoader usage: {e}")
except Exception as e:
    print(f"An unexpected error occurred during DataLoader usage: {e}")

Mounted at /content/gdrive
Attempting to access data in: /content/gdrive/MyDrive/Joint Modelling of DnA and RnA PROJECT/data/
--- Inspecting data.npz ---
Keys in data.npz: ['sequence', 'expressed_plus', 'expressed_minus']
Sequence array shape: (12157105,)
Sequence array dtype: uint8
First 10 sequence values: [1 1 0 1 0 1 1 0 1 0]
Unique values in sequence (sample): [0 1 2 3]

--- Inspecting regions.parquet ---
Regions DataFrame shape: (15705825, 6)
Regions DataFrame head:
  contig strand  start  offset  window_size  num_expressed
0   chrI      +   7655    7655         2048            205
1   chrI      +   7656    7656         2048            206
2   chrI      +   7657    7657         2048            207
3   chrI      +   7658    7658         2048            208
4   chrI      +   7659    7659         2048            209

Regions DataFrame info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 15705825 entries, 0 to 15705824
Data columns (total 6 columns):
 #   Column         Dtype 
---