In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

# Assuming your modified model is in 'model_transf.py'
from model_transf import DNASequenceClassifier
# Assuming your data loader is in 'data_loader.py' or similar file structure.
# For this example, I'll copy the necessary parts of your data loader here,
# but in practice, you should import it.
import numpy as np
import pandas as pd
import os
from torch.utils.data import Dataset # Import Dataset class

# --- Your Data Loader Code (as provided previously) ---
# This part is copied directly from your original data loader.
# In a real project, you'd usually import this.
def get_data_dir():
    """Return the path of the project's ``data/`` directory.

    The default location is ``<current working directory>/data/``. When the
    ``google.colab`` package can be imported, Google Drive is mounted at
    ``/content/gdrive`` and the ``data/`` folder of the Drive project
    directory is used instead.

    Returns:
        str: Path of the data directory, ending with a trailing slash.

    Raises:
        FileNotFoundError: If the selected directory does not exist.
    """
    # Local/cluster default, relative to wherever the process was started.
    data_directory = os.path.join(os.getcwd(), 'data/')

    try:
        # The import only succeeds inside Google Colab.
        from google.colab import drive
        drive.mount('/content/gdrive')
        data_directory = os.path.join('/content/gdrive/MyDrive/DnARnAProject/', 'data/')
        print("Detected Google Colab environment. Using Google Drive path.")
    except ImportError:
        print("Not in Google Colab. Using local/cluster path.")

    if os.path.isdir(data_directory):
        return data_directory

    # The data is expected to be present already, so fail loudly instead of
    # creating an empty directory.
    raise FileNotFoundError(f"Error: The data directory '{data_directory}' does not exist. "
                            "Please ensure your data is located correctly for your environment.")


class GenomeExpressionDataset(Dataset):
    """Dataset of windowed genome sequences with per-strand expression labels.

    Reads ``data.npz`` (arrays ``sequence``, ``expressed_plus`` and
    ``expressed_minus``) and ``regions.parquet`` (columns ``offset``,
    ``window_size`` and ``strand``) from ``data_dir``. Each item is the
    one-hot encoding of ``sequence[offset : offset + window_size]`` (reverse
    complemented for non-``'+'`` strands) paired with the expression value
    stored at ``offset`` for that strand.
    """

    # Arrays that must be present in data.npz.
    _REQUIRED_NPZ_KEYS = ('sequence', 'expressed_plus', 'expressed_minus')

    def __init__(self, data_dir):
        """Load the sequence/expression arrays and the region table.

        Args:
            data_dir: Directory containing ``data.npz`` and ``regions.parquet``.

        Raises:
            RuntimeError: If either file cannot be read, or if ``data.npz``
                lacks one of the required arrays.
        """
        self.data_dir = data_dir

        self.data_npz_path = os.path.join(data_dir, 'data.npz')
        self.regions_parquet_path = os.path.join(data_dir, 'regions.parquet')

        missing_key = None
        available_keys = []
        try:
            # NOTE(review): allow_pickle=True lets a crafted archive execute
            # arbitrary code on load; kept for compatibility with existing
            # data files -- only load archives from a trusted source.
            # The ``with`` block guarantees the archive is closed on every
            # path, including when a required array is missing.
            with np.load(self.data_npz_path, allow_pickle=True) as npz:
                available_keys = list(npz.files)
                missing_key = next(
                    (key for key in self._REQUIRED_NPZ_KEYS if key not in available_keys),
                    None,
                )
                if missing_key is None:
                    # Indexing materialises each array in memory, so they
                    # stay valid after the archive is closed.
                    self.sequence_data = npz['sequence']
                    self.expression_plus_data = npz['expressed_plus']
                    self.expression_minus_data = npz['expressed_minus']
                # Retained for backward compatibility; the handle is closed
                # as soon as the ``with`` block exits.
                self.data_npz = npz
        except Exception as e:
            raise RuntimeError(f"Could not load data from {self.data_npz_path}. Make sure the file exists and is not corrupted: {e}") from e

        # Raised outside the try block so it is not re-wrapped by the
        # generic handler above.
        if missing_key is not None:
            raise RuntimeError(f"KeyError: Key '{missing_key}' not found in {self.data_npz_path}. "
                               f"Available keys: {available_keys}. "
                               "Please check your .npz file structure.")

        try:
            self.regions_df = pd.read_parquet(self.regions_parquet_path)
        except Exception as e:
            raise RuntimeError(f"Could not load regions from {self.regions_parquet_path}. Make sure the file exists and is not corrupted: {e}") from e

        self.num_nucleotides = 5 # A, C, G, T, N (mapped to 0, 1, 2, 3, 4)
        # Index i holds the complement of nucleotide code i: A<->T, C<->G, N->N.
        self.complement_map = np.array([3, 2, 1, 0, 4], dtype=np.uint8)

    def __len__(self):
        """Return the number of regions (one sample per row of the region table)."""
        return len(self.regions_df)

    def _one_hot_encode(self, sequence_segment):
        """One-hot encode a 1-D sequence of integer nucleotide codes.

        Args:
            sequence_segment: 1-D array-like of codes in
                ``[0, self.num_nucleotides)``.

        Returns:
            torch.Tensor: float32 tensor of shape
            ``(len(sequence_segment), self.num_nucleotides)``.
        """
        one_hot_tensor = torch.zeros(len(sequence_segment), self.num_nucleotides, dtype=torch.float32)
        # scatter_ requires int64 indices with the same rank as the target.
        indices = torch.tensor(sequence_segment).unsqueeze(1).long()
        one_hot_tensor.scatter_(1, indices, 1)
        return one_hot_tensor

    def __getitem__(self, idx):
        """Return ``(encoded_sequence, expression_label)`` for region ``idx``.

        Args:
            idx: Integer position of the region (a scalar tensor is accepted).

        Returns:
            tuple: The one-hot encoded window as a float32 tensor of shape
            ``(window_size, num_nucleotides)`` and the expression label as a
            float32 scalar tensor.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        region_info = self.regions_df.iloc[idx]

        # Coerce to plain ints so slicing also works when the table stores
        # these columns with a non-integer dtype.
        offset = int(region_info['offset'])
        window_size = int(region_info['window_size'])
        strand = region_info['strand']

        sequence_segment = self.sequence_data[offset : offset + window_size].copy()

        if strand == '+':
            encoded_sequence = self._one_hot_encode(sequence_segment)
            expression_label = self.expression_plus_data[offset]
        else: # strand == '-'
            # Complement every base, then reverse; copy() removes the negative
            # stride that torch.tensor cannot consume.
            reverse_complemented_sequence = self.complement_map[sequence_segment][::-1].copy()
            encoded_sequence = self._one_hot_encode(reverse_complemented_sequence)
            expression_label = self.expression_minus_data[offset]

        # Ensure expression_label is float for BCEWithLogitsLoss
        expression_label = torch.tensor(expression_label, dtype=torch.float32)

        return encoded_sequence, expression_label

# --- End of Data Loader Code ---


# --- Hyperparameters for the DNA sequence classifier ---
# Starting values only; tune them for the dataset at hand.
NUM_NUCLEOTIDES = 5 # One-hot width per position: A, C, G, T, N (matches the dataset's num_nucleotides)
D_MODEL = 128      # Dimension of model (e.g., 128, 256, 512)
NUM_HEADS = 8      # Number of attention heads (should divide D_MODEL)
NUM_LAYERS = 3     # Number of encoder layers
D_FF = 512         # Dimension of feed-forward network (here 4 * D_MODEL)
# Maximum sequence length handed to the model as max_seq_length.
# 200 is only a fallback: the dataset setup further down overwrites
# MAX_SEQ_LENGTH with the length of the first sample whenever the dataset
# is non-empty.
MAX_SEQ_LENGTH = 200
DROPOUT = 0.1      # Passed to the model's `dropout` argument

BATCH_SIZE = 64    # Samples per DataLoader batch
LEARNING_RATE = 0.0001 # Learning rate given to the Adam optimizer
NUM_EPOCHS = 50 # Full passes over the training split; more may be needed

# --- Device Configuration ---
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# --- Initialize Dataset and DataLoader ---
data_dir = get_data_dir()
full_dataset = GenomeExpressionDataset(data_dir)
num_samples = len(full_dataset)

# Take the sequence length from the first sample so the model's
# max_seq_length matches the data. This relies on every region sharing the
# same window_size.
if num_samples == 0:
    print("Warning: Dataset is empty, cannot determine MAX_SEQ_LENGTH dynamically. Using default.")
else:
    sample_seq, _ = full_dataset[0]
    MAX_SEQ_LENGTH = sample_seq.size(0) # sequence_length (window_size)
    print(f"Detected sequence length (window_size): {MAX_SEQ_LENGTH}")

# 80/20 train/validation split; the remainder after truncation goes to validation.
train_size = int(0.8 * num_samples)
val_size = num_samples - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Options shared by both loaders; only the training loader shuffles.
loader_options = {"batch_size": BATCH_SIZE, "num_workers": 4}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)


# --- Initialize Model, Loss, and Optimizer ---
# Collect the constructor arguments in one place, then build the model and
# move it to the selected device.
model_kwargs = {
    "input_features": NUM_NUCLEOTIDES,
    "d_model": D_MODEL,
    "num_heads": NUM_HEADS,
    "num_layers": NUM_LAYERS,
    "d_ff": D_FF,
    "max_seq_length": MAX_SEQ_LENGTH,
    "dropout": DROPOUT,
}
model = DNASequenceClassifier(**model_kwargs)
model = model.to(device)

# BCEWithLogitsLoss applies the sigmoid internally (numerically more stable
# than a separate sigmoid followed by BCELoss), so it is fed raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    betas=(0.9, 0.98),
    eps=1e-9,
)

# --- Training Loop ---
def _run_epoch(loader, training):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and
    ``device``.

    Args:
        loader: Iterable yielding ``(sequences, labels)`` batches.
        training: When true, the model is put in train mode and the optimizer
            is stepped on every batch; otherwise the model is put in eval
            mode and no gradients are tracked.

    Returns:
        tuple[float, float]: Loss averaged over samples, and the fraction of
        samples whose thresholded prediction equals the label. Both are NaN
        when the loader yields no samples.
    """
    model.train(training)  # train(False) is equivalent to eval()
    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.set_grad_enabled(training):
        for batch_idx, (sequences, labels) in enumerate(loader):
            sequences = sequences.to(device)
            # Add a trailing dimension so labels are (batch_size, 1);
            # BCEWithLogitsLoss needs logits and targets of the same shape.
            labels = labels.to(device).unsqueeze(1)

            if training:
                optimizer.zero_grad()
            logits = model(sequences)
            loss = criterion(logits, labels)
            if training:
                loss.backward()
                optimizer.step()

            # criterion returns the batch mean; weight it by the batch size
            # so a shorter final batch does not skew the epoch average.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size

            # Accuracy for monitoring: sigmoid, then threshold at 0.5.
            predictions = (torch.sigmoid(logits) > 0.5).float()
            correct += (predictions == labels).sum().item()
            seen += batch_size

            if training and (batch_idx + 1) % 100 == 0:
                print(f"  Batch {batch_idx + 1}/{len(loader)}, Loss: {loss.item():.4f}")

    if seen == 0:
        # Empty loader: report NaN instead of raising ZeroDivisionError.
        return float("nan"), float("nan")
    return loss_sum / seen, correct / seen


print("\nStarting training...")
for epoch in range(NUM_EPOCHS):
    avg_train_loss, train_accuracy = _run_epoch(train_loader, training=True)
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}:")
    print(f"  Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")

    # --- Validation Loop ---
    avg_val_loss, val_accuracy = _run_epoch(val_loader, training=False)
    print(f"  Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

print("\nTraining complete!")

# Save the trained model (optional)
# torch.save(model.state_dict(), "dna_transformer_classifier.pth")
# print("Model saved to dna_transformer_classifier.pth")