In [22]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import os
from pathlib import Path
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
# Import torch.cuda.amp for Automatic Mixed Precision
from torch.cuda.amp import autocast, GradScaler

# Assuming your modified model is in 'model_transf.py'
# Make sure you have the DNASequenceClassifier in model_transf.py
# (The model_transf.py file contents were provided in a previous turn when we modified model.py)
from model_transf import DNASequenceClassifier

# --- Your Data Loader Code (as provided previously) ---
# This part is copied directly from your original data loader.
# In a real project, you'd usually import this.
def get_data_dir():
    """
    Determines the correct data directory based on the execution environment.
    This function makes the code portable between Colab, local PyCharm, and cluster.
    """
    # local_or_cluster_project_path = os.getcwd()  # Adjust this for your specific setup if n
    # data_directory = os.path.join(local_or_cluster_project_path, 'data/')
    # data_directory = Path(local_or_cluster_project_path).parent.parent / "data"
    cwd = os.getcwd()
    data_directory = os.path.join(cwd, 'data/')

    try:
        from google.colab import drive
        drive.mount('/content/gdrive')
        google_drive_project_path = '/content/gdrive/MyDrive/DnARnAProject/'
        data_directory = os.path.join(google_drive_project_path, 'data/')
        print("Detected Google Colab environment. Using Google Drive path.")
    except ImportError:
        print("Not in Google Colab. Using local/cluster path.")

    if not os.path.isdir(data_directory):
        raise FileNotFoundError(f"Error: The data directory '{data_directory}' does not exist. "
                                "Please ensure your data is located correctly for your environment.")
    return data_directory


class GenomeExpressionDataset(Dataset):
    """
    Custom Dataset for loading DNA sequence and expression data for genomic regions.
    It loads data from pre-processed .npz and .parquet files.
    Handles reverse complement for '-' strand sequences and uses appropriate expression labels.
    """
    def __init__(self, data_dir):
        """
        Initializes the dataset by loading the full sequence and expression arrays
        and the DataFrame of genomic regions.

        Args:
            data_dir (str): The path to the directory containing 'data.npz' and 'regions.parquet'.
        """
        self.data_dir = data_dir

        self.data_npz_path = os.path.join(data_dir, 'data.npz')
        self.regions_parquet_path = os.path.join(data_dir, 'regions.parquet')

        try:
            self.data_npz = np.load(self.data_npz_path, allow_pickle=True)
            self.sequence_data = self.data_npz['sequence']
            self.expression_plus_data = self.data_npz['expressed_plus']
            self.expression_minus_data = self.data_npz['expressed_minus']
            self.data_npz.close()
        except KeyError as e:
            available_keys = list(np.load(self.data_npz_path).keys()) if os.path.exists(self.data_npz_path) else "File not found during key check."
            raise RuntimeError(f"KeyError: Key '{e}' not found in {self.data_npz_path}. "
                               f"Available keys: {available_keys}. "
                               "Please check your .npz file structure.")
        except Exception as e:
            raise RuntimeError(f"Could not load data from {self.data_npz_path}. Make sure the file exists and is not corrupted: {e}")

        try:
            self.regions_df = pd.read_parquet(self.regions_parquet_path)
        except Exception as e:
            raise RuntimeError(f"Could not load regions from {self.regions_parquet_path}. Make sure the file exists and is not corrupted: {e}")

        self.num_nucleotides = 5 # A, C, G, T, N (mapped to 0, 1, 2, 3, 4)
        self.complement_map = np.array([3, 2, 1, 0, 4], dtype=np.uint8)


    def __len__(self):
        return len(self.regions_df)

    def _one_hot_encode(self, sequence_segment):
        """
        Converts a sequence segment (array of integer encodings) into a one-hot encoded tensor.
        """
        one_hot_tensor = torch.zeros(len(sequence_segment), self.num_nucleotides, dtype=torch.float32)
        one_hot_tensor.scatter_(1, torch.tensor(sequence_segment).unsqueeze(1).long(), 1)
        return one_hot_tensor

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        region_info = self.regions_df.iloc[idx]

        offset = region_info['offset']
        window_size = region_info['window_size']
        strand = region_info['strand']

        sequence_segment = self.sequence_data[offset : offset + window_size].copy()

        if strand == '+':
            encoded_sequence = self._one_hot_encode(sequence_segment)
            expression_label = self.expression_plus_data[offset]
        else: # strand == '-'
            reverse_complemented_sequence = self.complement_map[sequence_segment][::-1].copy()
            encoded_sequence = self._one_hot_encode(reverse_complemented_sequence)
            expression_label = self.expression_minus_data[offset]

        expression_label = torch.tensor(expression_label, dtype=torch.float32)

        return encoded_sequence, expression_label

# --- End of Data Loader Code ---


# --- Hyperparameters for your DNA Sequence Classifier ---
NUM_NUCLEOTIDES = 5 # A, C, G, T, N (from your data loader)
# Reduced D_MODEL and NUM_HEADS to reduce memory footprint
D_MODEL = 64       # Reduced from 128
NUM_HEADS = 4      # Reduced from 8, must divide D_MODEL
NUM_LAYERS = 3
D_FF = 256         # Adjusted D_FF to be 4*D_MODEL (previously 2*D_MODEL or 4*D_MODEL relative to old D_MODEL)
MAX_SEQ_LENGTH = 2048 # Will be updated dynamically based on your data's window_size
DROPOUT = 0.1

# Significantly reduce BATCH_SIZE to combat OutOfMemoryError
BATCH_SIZE = 8 # Drastically reduced from 64
LEARNING_RATE = 0.0001
NUM_EPOCHS = 50

# --- Device Configuration ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Initialize Dataset and DataLoader ---
# data_dir = get_data_dir()
cwd = os.getcwd()
data_dir = os.path.join(cwd, 'data/')
full_dataset = GenomeExpressionDataset(data_dir)

if len(full_dataset) > 0:
    sample_seq, _ = full_dataset[0]
    MAX_SEQ_LENGTH = sample_seq.shape[0] # sequence_length (window_size)
    print(f"Detected sequence length (window_size): {MAX_SEQ_LENGTH}")
else:
    print("Warning: Dataset is empty, cannot determine MAX_SEQ_LENGTH dynamically. Using default.")


# --- Dataset Split: Train, Validation, Test ---
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1

total_size = len(full_dataset)
train_size = int(train_ratio * total_size)
val_size = int(val_ratio * total_size)
test_size = total_size - train_size - val_size # Ensure all samples are accounted for

train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])

# Reduced num_workers if you still get warnings or freezes on Colab
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Dataset split: Train={len(train_dataset)}, Validation={len(val_dataset)}, Test={len(test_dataset)}")


# --- Initialize Model, Loss, and Optimizer ---
model = DNASequenceClassifier(
    input_features=NUM_NUCLEOTIDES,
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    d_ff=D_FF,
    max_seq_length=MAX_SEQ_LENGTH,
    dropout=DROPOUT
).to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.98), eps=1e-9)

# Initialize GradScaler for Automatic Mixed Precision
scaler = GradScaler()

# --- Lists to store metrics for plotting ---
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []


# --- Training Loop ---
print("\nStarting training...")
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    # Clear GPU cache before each epoch to free up memory
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    for batch_idx, (sequences, labels) in enumerate(train_loader):
        sequences = sequences.to(device)
        labels = labels.to(device).unsqueeze(1)

        optimizer.zero_grad()

        # Use autocast for Automatic Mixed Precision
        with autocast():
            logits = model(sequences)
            loss = criterion(logits, labels)

        # Scale the loss and call backward()
        scaler.scale(loss).backward()

        # Unscale gradients and step optimizer
        scaler.step(optimizer)
        scaler.update() # Update the scaler for the next iteration

        total_loss += loss.item()

        # Calculate accuracy for monitoring (still using float32 for metrics)
        predictions = (torch.sigmoid(logits) > 0.5).float()
        correct_predictions += (predictions == labels).sum().item()
        total_samples += labels.size(0)

        if (batch_idx + 1) % 100 == 0:
            print(f"  Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}")

    avg_train_loss = total_loss / len(train_loader)
    train_accuracy = correct_predictions / total_samples
    train_losses.append(avg_train_loss)
    train_accuracies.append(train_accuracy)

    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}:")
    print(f"  Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")

    # --- Validation Loop ---
    model.eval()
    val_loss = 0
    val_correct_predictions = 0
    val_total_samples = 0
    with torch.no_grad():
        if device.type == 'cuda':
            torch.cuda.empty_cache() # Clear cache before validation as well
        for sequences, labels in val_loader:
            sequences = sequences.to(device)
            labels = labels.to(device).unsqueeze(1)

            with autocast(): # Use autocast for validation too
                logits = model(sequences)
                loss = criterion(logits, labels)

            val_loss += loss.item()

            predictions = (torch.sigmoid(logits) > 0.5).float()
            val_correct_predictions += (predictions == labels).sum().item()
            val_total_samples += labels.size(0)

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = val_correct_predictions / val_total_samples
    val_losses.append(avg_val_loss)
    val_accuracies.append(val_accuracy)

    print(f"  Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

print("\nTraining complete!")


# --- Test Evaluation ---
print("\nStarting test evaluation...")
model.eval() # Set model to evaluation mode
test_loss = 0
test_correct_predictions = 0
test_total_samples = 0
with torch.no_grad(): # Disable gradient calculations
    if device.type == 'cuda':
        torch.cuda.empty_cache() # Clear cache before test evaluation
    for sequences, labels in test_loader:
        sequences = sequences.to(device)
        labels = labels.to(device).unsqueeze(1)

        with autocast(): # Use autocast for test evaluation too
            logits = model(sequences)
            loss = criterion(logits, labels)

        test_loss += loss.item()

        predictions = (torch.sigmoid(logits) > 0.5).float()
        test_correct_predictions += (predictions == labels).sum().item()
        test_total_samples += labels.size(0)

avg_test_loss = test_loss / len(test_loader)
test_accuracy = test_correct_predictions / test_total_samples
print(f"Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")


# --- Plotting Training and Validation Metrics ---
epochs_range = range(1, NUM_EPOCHS + 1)

plt.figure(figsize=(12, 5))

# Plot Loss
plt.subplot(1, 2, 1) # 1 row, 2 columns, first plot
plt.plot(epochs_range, train_losses, label='Training Loss')
plt.plot(epochs_range, val_losses, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

# Plot Accuracy
plt.subplot(1, 2, 2) # 1 row, 2 columns, second plot
plt.plot(epochs_range, train_accuracies, label='Training Accuracy')
plt.plot(epochs_range, val_accuracies, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

plt.tight_layout() # Adjusts plot parameters for a tight layout
plt.show() # Display the plots


# Save the trained model (optional)
# torch.save(model.state_dict(), "dna_transformer_classifier.pth")
# print("Model saved to dna_transformer_classifier.pth")


Using device: cpu
Detected sequence length (window_size): 2048
Dataset split: Train=12564660, Validation=1570582, Test=1570583

Starting training...


  scaler = GradScaler()
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/homebrew/anaconda3/envs/ml-base/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/opt/homebrew/anaconda3/envs/ml-base/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'GenomeExpressionDataset' on <module '__main__' (built-in)>


KeyboardInterrupt: 