# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt # Import for plotting

# Assuming your modified model is in 'model_transf.py'
from model_transf import DNASequenceClassifier
# Assuming your data loader is in 'data_loader.py' or similar file structure.
# For this example, I'll copy the necessary parts of your data loader here,
# but in practice, you should import it.
import numpy as np
import pandas as pd
import os
from torch.utils.data import Dataset # Import Dataset class

# --- Your Data Loader Code (as provided previously) ---
# This part is copied directly from your original data loader.
# In a real project, you'd usually import this.
def get_data_dir():
    """
    Locate the directory holding the project's data files for the current environment.

    The default is a 'data/' folder under the current working directory (local
    runs and cluster jobs). When `google.colab` can be imported, Google Drive is
    mounted at '/content/gdrive' and the 'data/' folder inside
    'MyDrive/DnARnAProject/' is used instead.

    Returns:
        str: Path of the data directory (ends with a trailing slash).

    Raises:
        FileNotFoundError: If the chosen directory does not exist.
    """
    # Local / cluster default: <current working directory>/data/
    project_root = os.getcwd()
    resolved_dir = os.path.join(project_root, 'data/')

    try:
        # Importing google.colab is used as the Colab detector; an ImportError
        # anywhere in this block falls back to the local/cluster path.
        from google.colab import drive
        drive.mount('/content/gdrive')
        colab_root = '/content/gdrive/MyDrive/DnARnAProject/'
        resolved_dir = os.path.join(colab_root, 'data/')
        print("Detected Google Colab environment. Using Google Drive path.")
    except ImportError:
        print("Not in Google Colab. Using local/cluster path.")

    # The data is expected to be present already; the directory is never created here.
    if os.path.isdir(resolved_dir):
        return resolved_dir

    raise FileNotFoundError(
        f"Error: The data directory '{resolved_dir}' does not exist. "
        "Please ensure your data is located correctly for your environment."
    )


class GenomeExpressionDataset(Dataset):
    """
    Dataset of genomic windows paired with a strand-specific expression label.

    The constructor reads three arrays ('sequence', 'expressed_plus',
    'expressed_minus') from 'data.npz' and a table of regions from
    'regions.parquet'. Each row of the regions table must provide the columns
    'offset', 'window_size' and 'strand'.

    Items are `(one_hot_sequence, label)` pairs:
      * `one_hot_sequence` is a float32 tensor of shape (window_size, 5).
      * `label` is a float32 scalar tensor taken from the plus- or minus-strand
        expression array at the window's `offset`.

    For rows whose strand is not '+', the window is reverse-complemented before
    encoding.
    """

    # Arrays that must be present in 'data.npz'.
    _REQUIRED_NPZ_KEYS = ('sequence', 'expressed_plus', 'expressed_minus')

    def __init__(self, data_dir):
        """
        Load the sequence/expression arrays and the regions table into memory.

        Args:
            data_dir (str): The path to the directory containing 'data.npz' and 'regions.parquet'.

        Raises:
            RuntimeError: If 'data.npz' cannot be read, if it lacks one of the
                required arrays, or if 'regions.parquet' cannot be read.
        """
        self.data_dir = data_dir

        self.data_npz_path = os.path.join(data_dir, 'data.npz')
        self.regions_parquet_path = os.path.join(data_dir, 'regions.parquet')

        try:
            # NOTE(security): allow_pickle=True lets NumPy unpickle object arrays,
            # which can execute arbitrary code. Only load trusted data files.
            # The context manager guarantees the archive is closed on every path.
            with np.load(self.data_npz_path, allow_pickle=True) as archive:
                available_keys = list(archive.files)
                missing_keys = [key for key in self._REQUIRED_NPZ_KEYS
                                if key not in available_keys]
                if not missing_keys:
                    # Indexing materialises each array, so they stay valid after close.
                    self.sequence_data = archive['sequence']
                    self.expression_plus_data = archive['expressed_plus']
                    self.expression_minus_data = archive['expressed_minus']
                # Kept only for backward compatibility; the handle is closed on exit.
                self.data_npz = archive
        except Exception as e:
            raise RuntimeError(f"Could not load data from {self.data_npz_path}. Make sure the file exists and is not corrupted: {e}") from e

        if missing_keys:
            # Raised outside the try block so it is not re-wrapped by the generic handler.
            raise RuntimeError(f"KeyError: Key(s) {missing_keys} not found in {self.data_npz_path}. "
                               f"Available keys: {available_keys}. "
                               "Please check your .npz file structure.")

        try:
            self.regions_df = pd.read_parquet(self.regions_parquet_path)
        except Exception as e:
            raise RuntimeError(f"Could not load regions from {self.regions_parquet_path}. Make sure the file exists and is not corrupted: {e}") from e

        # Width of the one-hot encoding.
        # Assumes the integer encoding A=0, C=1, G=2, T=3, N=4 — confirm against preprocessing.
        self.num_nucleotides = 5

        # Lookup table: complement_map[base] is the complementary base under the
        # encoding assumed above (A<->T, C<->G, N<->N, i.e. 0<->3, 1<->2, 4<->4).
        self.complement_map = np.array([3, 2, 1, 0, 4], dtype=np.uint8)

    def __len__(self):
        """Return the number of regions (rows of the regions table)."""
        return len(self.regions_df)

    def _one_hot_encode(self, sequence_segment):
        """
        Convert a 1-D array of integer base codes into a one-hot tensor.

        Args:
            sequence_segment: Contiguous 1-D array of integer codes, each in
                [0, num_nucleotides).

        Returns:
            torch.Tensor: float32 tensor of shape (len(sequence_segment), num_nucleotides).
        """
        indices = torch.tensor(sequence_segment, dtype=torch.long).unsqueeze(1)
        one_hot_tensor = torch.zeros(len(sequence_segment), self.num_nucleotides, dtype=torch.float32)
        one_hot_tensor.scatter_(1, indices, 1)
        return one_hot_tensor

    def __getitem__(self, idx):
        """
        Return the encoded window and expression label for region `idx`.

        Args:
            idx: Integer position of the region (a tensor index is converted
                with `tolist()` first).

        Returns:
            tuple[torch.Tensor, torch.Tensor]: (one-hot sequence, float32 label).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        region_info = self.regions_df.iloc[idx]

        offset = region_info['offset']
        window_size = region_info['window_size']
        strand = region_info['strand']

        # Copy so the window never aliases the full sequence array.
        sequence_segment = self.sequence_data[offset : offset + window_size].copy()

        if strand == '+':
            encoded_sequence = self._one_hot_encode(sequence_segment)
            # Label is read at the window's offset from the plus-strand array.
            expression_label = self.expression_plus_data[offset]
        else:
            # Any strand other than '+' is treated as the minus strand.
            # Complement each base, then reverse; .copy() removes the negative
            # stride left by [::-1] so the array can be turned into a tensor.
            reverse_complemented_sequence = self.complement_map[sequence_segment][::-1].copy()
            encoded_sequence = self._one_hot_encode(reverse_complemented_sequence)
            # Label is read at the window's offset from the minus-strand array.
            expression_label = self.expression_minus_data[offset]

        # BCEWithLogitsLoss needs float targets.
        expression_label = torch.tensor(expression_label, dtype=torch.float32)

        return encoded_sequence, expression_label

# --- End of Data Loader Code ---


# --- Hyperparameters for your DNA Sequence Classifier ---
# All of these are starting points and should be tuned.

# Model architecture
NUM_NUCLEOTIDES = 5   # width of the one-hot input (A, C, G, T, N)
D_MODEL = 128         # model dimension (e.g. 128, 256, 512)
NUM_HEADS = 8         # attention heads; must divide D_MODEL
NUM_LAYERS = 3        # encoder layers
D_FF = 4 * D_MODEL    # feed-forward dimension (typically 2x or 4x D_MODEL) -> 512
DROPOUT = 0.1

# Maximum sequence length handed to the model. It has to match the window size
# used when the data was preprocessed; 200 is only a placeholder default.
MAX_SEQ_LENGTH = 200

# Optimisation
BATCH_SIZE = 64
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50       # may need to be considerably larger

# --- Device Configuration ---
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# --- Initialize Dataset and DataLoader ---
data_dir = get_data_dir()
full_dataset = GenomeExpressionDataset(data_dir)

# Override MAX_SEQ_LENGTH with the length of the first sample.
# This relies on every region sharing the same window_size.
if len(full_dataset) == 0:
    print("Warning: Dataset is empty, cannot determine MAX_SEQ_LENGTH dynamically. Using default.")
else:
    sample_seq, _ = full_dataset[0]
    MAX_SEQ_LENGTH = sample_seq.shape[0]  # first dimension is the sequence length
    print(f"Detected sequence length (window_size): {MAX_SEQ_LENGTH}")


# --- Dataset Split: Train, Validation, Test ---
# Proportions of the full dataset assigned to each split.
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1
if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-8:
    raise ValueError("train_ratio, val_ratio and test_ratio must sum to 1.")

# Train and validation sizes are truncated; the test split takes the remainder
# so that every sample is assigned to exactly one split.
total_size = len(full_dataset)
train_size = int(train_ratio * total_size)
val_size = int(val_ratio * total_size)
test_size = total_size - train_size - val_size

if min(train_size, val_size, test_size) == 0:
    print("Warning: at least one dataset split is empty; the dataset may be too small for these ratios.")

# Seed the split so the same samples land in train/val/test on every run.
# Without this the held-out test set changes between runs and results are
# not comparable.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

print(f"Dataset split: Train={len(train_dataset)}, Validation={len(val_dataset)}, Test={len(test_dataset)}")


# --- Initialize Model, Loss, and Optimizer ---
# Build the classifier from the hyperparameters, then move it to the chosen device.
model = DNASequenceClassifier(
    input_features=NUM_NUCLEOTIDES,
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    d_ff=D_FF,
    max_seq_length=MAX_SEQ_LENGTH,
    dropout=DROPOUT,
)
model = model.to(device)

# Adam over all model parameters.
optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    betas=(0.9, 0.98),
    eps=1e-9,
)

# Binary classification on raw logits: BCEWithLogitsLoss fuses the sigmoid with
# the binary cross-entropy for numerical stability.
criterion = nn.BCEWithLogitsLoss()


# --- Lists to store metrics for plotting ---
# One entry is appended to each list per epoch.
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []


def _run_epoch(loader, training, log_every=0):
    """
    Run one full pass over `loader` and return (mean loss, accuracy).

    Uses the module-level `model`, `criterion`, `optimizer` and `device`.

    Args:
        loader: Iterable of (sequences, labels) batches that also supports `len()`.
        training (bool): When True the model is put in train mode and the
            optimizer is stepped after every batch. When False the model is put
            in eval mode and gradient tracking is disabled.
        log_every (int): Print the batch loss every `log_every` batches;
            0 disables per-batch logging.

    Returns:
        tuple[float, float]: Loss averaged over samples and the fraction of
        correct predictions. Both are NaN when the loader yields no samples.
    """
    model.train(training)  # train(False) is equivalent to eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    num_batches = len(loader)

    with torch.set_grad_enabled(training):
        for batch_idx, (sequences, labels) in enumerate(loader):
            sequences = sequences.to(device)
            # Add a trailing dimension so the targets line up with the logits,
            # which are expected to be (batch_size, 1).
            labels = labels.to(device).unsqueeze(1)

            if training:
                optimizer.zero_grad()
            logits = model(sequences)
            loss = criterion(logits, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            batch_loss = loss.item()
            # The criterion returns the batch mean; weight it by the batch size
            # so a smaller final batch does not skew the epoch average.
            loss_sum += batch_loss * batch_size

            # Sigmoid + 0.5 threshold turns logits into hard 0/1 predictions.
            predictions = (torch.sigmoid(logits) > 0.5).float()
            correct += (predictions == labels).sum().item()
            seen += batch_size

            if log_every and (batch_idx + 1) % log_every == 0:
                print(f"  Batch {batch_idx + 1}/{num_batches}, Loss: {batch_loss:.4f}")

    if seen == 0:
        # Empty split (e.g. a very small dataset): report NaN instead of dividing by zero.
        return float("nan"), float("nan")
    return loss_sum / seen, correct / seen


# --- Training Loop ---
print("\nStarting training...")
for epoch in range(NUM_EPOCHS):
    avg_train_loss, train_accuracy = _run_epoch(train_loader, training=True, log_every=100)
    train_losses.append(avg_train_loss)
    train_accuracies.append(train_accuracy)

    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}:")
    print(f"  Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")

    # --- Validation Loop ---
    avg_val_loss, val_accuracy = _run_epoch(val_loader, training=False)
    val_losses.append(avg_val_loss)
    val_accuracies.append(val_accuracy)

    print(f"  Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

print("\nTraining complete!")


# --- Test Evaluation ---
# Evaluated once, after training, on the held-out test split.
print("\nStarting test evaluation...")
avg_test_loss, test_accuracy = _run_epoch(test_loader, training=False)
print(f"Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")


# --- Plotting Training and Validation Metrics ---
# One x-value per epoch, starting at 1.
epochs_range = range(1, NUM_EPOCHS + 1)

plt.figure(figsize=(12, 5))

# (metric name, training series, validation series) for each side-by-side panel.
panels = (
    ('Loss', train_losses, val_losses),
    ('Accuracy', train_accuracies, val_accuracies),
)
for position, (metric, train_series, val_series) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)  # 1 row, 2 columns, panel number `position`
    plt.plot(epochs_range, train_series, label=f'Training {metric}')
    plt.plot(epochs_range, val_series, label=f'Validation {metric}')
    plt.title(f'Training and Validation {metric}')
    plt.xlabel('Epoch')
    plt.ylabel(metric)
    plt.legend()
    plt.grid(True)

plt.tight_layout()  # keep the two panels from overlapping
plt.show()


# Save the trained model (optional)
# torch.save(model.state_dict(), "dna_transformer_classifier.pth")
# print("Model saved to dna_transformer_classifier.pth")
