In [1]:
import torch
import numpy as np
import scanpy as sc
import pickle
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
import os

In [4]:
# Make sure the output directory for model artifacts exists before anything else runs.
os.makedirs('./models/dann/', exist_ok=True)

# 1. Data Loading and Preparation
print("Loading and preparing data...")
# NOTE(review): unpickling can execute arbitrary code embedded in the file;
# only load this pickle if it comes from a trusted source.
with open('./src/data/dann/all_cell_data.pkl', mode='rb') as pickle_file:
    adata = pickle.load(pickle_file)

Loading and preparing data...


In [5]:
# Extract gene expression data and domain labels
X = adata.X  # Gene expression matrix
domain_labels = adata.obs['tech'].astype('category').cat.codes  # Domain labels (10x vs SS2)

# Standardize the data.
# with_mean=False scales each gene by its standard deviation without centering;
# StandardScaler refuses to center sparse input, so this is the only mode that
# works if adata.X is sparse.
# NOTE(review): the scaler is fitted on all cells before the train/validation
# split, so validation statistics leak into the scaling — confirm this is acceptable.
scaler = StandardScaler(with_mean=False)
X_scaled = scaler.fit_transform(X)

# A sparse input yields a sparse output, which torch.tensor() cannot convert.
# Densify first (duck-typed on `.toarray` so no extra import is needed);
# dense input passes through unchanged.
X_dense = X_scaled.toarray() if hasattr(X_scaled, 'toarray') else np.asarray(X_scaled)

# Convert to PyTorch tensors
X_tensor = torch.tensor(X_dense, dtype=torch.float32)
domain_tensor = torch.tensor(domain_labels.values, dtype=torch.long)

# Split into train and validation sets, keeping the domain proportions equal
# in both splits (stratify) and the split reproducible (random_state).
X_train, X_val, domain_train, domain_val = train_test_split(
    X_tensor, domain_tensor, test_size=0.2, random_state=42, stratify=domain_tensor
)

ValueError: Cannot center sparse matrices: pass `with_mean=False` instead. See docstring for motivation and alternatives.

In [None]:
# 2. Model Architecture
class GradientReversalLayer(torch.autograd.Function):
    """Autograd function that is the identity going forward and flips gradients going back.

    forward(x, alpha) returns a view of ``x`` with the same shape.
    backward multiplies the incoming gradient by ``-alpha``; ``alpha`` itself
    receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Remember the scaling factor; it is only needed on the backward pass.
        ctx.alpha = alpha
        return x.view(x.size())

    @staticmethod
    def backward(ctx, grad_output):
        # Scale, then negate: the upstream layers see the opposite gradient.
        reversed_grad = -(grad_output * ctx.alpha)
        return reversed_grad, None

class Encoder(nn.Module):
    """Two-layer MLP that maps an expression vector to a non-negative embedding.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of the hidden layer.
        latent_dim: Width of the returned embedding. Defaults to 128, the
            value that was previously hard-coded.
        dropout_p: Dropout probability applied after the hidden layer.
            Defaults to 0.3, the value that was previously hard-coded.

    The layer attribute names (``fc1``, ``fc2``, ``dropout``) are unchanged, so
    state dicts saved from the earlier definition still load with the defaults.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim=128, dropout_p=0.3):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, latent_dim)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Return the ReLU-activated embedding of ``x`` (last dim = ``latent_dim``)."""
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return F.relu(self.fc2(x))

In [None]:
class DomainClassifier(nn.Module):
    """Domain head that sits behind a gradient reversal layer.

    Args:
        input_dim: Width of the embedding fed into the head.
        num_domains: Number of domain classes to predict. Defaults to 2
            (10x vs SS2), the value that was previously hard-coded. Set it to
            the number of distinct domain codes when there are more than two,
            otherwise the NLL loss will reject out-of-range targets.
        hidden_dim: Width of the hidden layer (default 256, as before).
        dropout_p: Dropout probability after the hidden layer (default 0.3, as before).
    """

    def __init__(self, input_dim, num_domains=2, hidden_dim=256, dropout_p=0.3):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_domains)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x, alpha=1.0):
        """Return per-domain log-probabilities for ``x``.

        ``alpha`` scales the reversed gradient that flows back into whatever
        produced ``x``; the forward values are unaffected by it.
        """
        # Identity on the way forward, gradient multiplied by -alpha on the way back.
        x = GradientReversalLayer.apply(x, alpha)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        # log_softmax output pairs with F.nll_loss.
        return F.log_softmax(self.fc2(x), dim=1)

In [None]:
# 3. Dataset and Training Setup
class SingleCellDataset(Dataset):
    """Pairs each row of ``X`` with the domain label stored at the same index."""

    def __init__(self, X, domain_labels):
        # Both containers are kept as given; nothing is copied or converted.
        self.X, self.domain_labels = X, domain_labels

    def __len__(self):
        # The dataset size is defined by the feature container alone.
        return len(self.X)

    def __getitem__(self, idx):
        features = self.X[idx]
        domain = self.domain_labels[idx]
        return features, domain

In [None]:
# Hyperparameters
input_dim = X_tensor.shape[1]
hidden_dim = 512
lambda_domain = 1.0
num_epochs = 100
batch_size = 64
learning_rate = 0.001
seed = 42

# Seed torch's global RNG before anything stochastic is created, so weight
# initialisation, dropout masks and DataLoader shuffling repeat across runs.
torch.manual_seed(seed)

# Create datasets and dataloaders
train_dataset = SingleCellDataset(X_train, domain_train)
val_dataset = SingleCellDataset(X_val, domain_val)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Initialize models.
# The classifier's input width is read from the encoder's final layer rather
# than repeating the number, so the two cannot drift apart.
encoder = Encoder(input_dim, hidden_dim)
domain_classifier = DomainClassifier(encoder.fc2.out_features)

# Optimizer: a single Adam instance updates both networks.
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(domain_classifier.parameters()),
    lr=learning_rate
)

# Learning rate scheduler: halve the LR after 5 epochs without a decrease in
# the monitored value.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

In [None]:
# 4. Training Loop
def _dann_batch_loss(encoder, domain_classifier, batch_X, batch_domain, lambda_domain):
    """Compute the scalar objective for one batch (shared by training and validation).

    The domain term is ADDED. The adversarial sign flip is expected to come
    from the gradient reversal layer inside ``domain_classifier``: the head
    minimises its NLL while the encoder receives the reversed gradient.
    Subtracting the term here as well would cancel the reversal and train the
    head to maximise its own loss, which is unbounded below.
    """
    encoded = encoder(batch_X)
    domain_preds = domain_classifier(encoded)

    # domain_preds are log-probabilities, so NLL is the matching loss.
    domain_loss = F.nll_loss(domain_preds, batch_domain)

    # NOTE(review): placeholder kept from the original "example contrastive
    # loss". (encoded - encoded) is identically zero, so this contributes
    # neither value nor gradient — replace with a real task/contrastive term.
    contrastive_loss = torch.mean((encoded - encoded) ** 2)

    return contrastive_loss + lambda_domain * domain_loss


def train_epoch(encoder, domain_classifier, train_loader, optimizer, lambda_domain):
    """Run one optimisation pass over ``train_loader``.

    Args:
        encoder: Module mapping a feature batch to embeddings.
        domain_classifier: Module mapping embeddings to domain log-probabilities.
        train_loader: Iterable of ``(batch_X, batch_domain)`` pairs supporting ``len()``.
        optimizer: Optimizer holding the parameters of both modules.
        lambda_domain: Weight applied to the domain loss.

    Returns:
        The mean per-batch loss as a Python float.
    """
    encoder.train()
    domain_classifier.train()
    total_loss = 0

    for batch_X, batch_domain in train_loader:
        loss = _dann_batch_loss(
            encoder, domain_classifier, batch_X, batch_domain, lambda_domain
        )

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)

def validate(encoder, domain_classifier, val_loader, lambda_domain):
    """Evaluate the same objective as ``train_epoch`` without updating weights.

    Both modules are switched to eval mode and gradients are disabled.

    Returns:
        The mean per-batch loss over ``val_loader`` as a Python float.
    """
    encoder.eval()
    domain_classifier.eval()
    total_loss = 0

    with torch.no_grad():
        for batch_X, batch_domain in val_loader:
            loss = _dann_batch_loss(
                encoder, domain_classifier, batch_X, batch_domain, lambda_domain
            )
            total_loss += loss.item()

    return total_loss / len(val_loader)

In [None]:
print("Starting training...")
train_losses, val_losses = [], []

for epoch in range(num_epochs):
    # One optimisation pass over the training set, then a gradient-free evaluation.
    train_loss = train_epoch(encoder, domain_classifier, train_loader, optimizer, lambda_domain)
    val_loss = validate(encoder, domain_classifier, val_loader, lambda_domain)

    # The plateau scheduler is driven by the validation loss.
    scheduler.step(val_loss)

    # Keep the per-epoch history for the loss curves plotted afterwards.
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    print(
        f'Epoch {epoch+1}/{num_epochs}:\n'
        f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}\n'
        + '-' * 50
    )

# 5. Visualization and Analysis
print("Generating visualizations...")

# Loss history for both splits on one set of axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses, label='Training Loss')
ax.plot(val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Losses')
ax.legend()
fig.savefig('./models/training_curves.png')
plt.close(fig)

# Embed every cell with the trained encoder (eval mode, no gradient tracking).
encoder.eval()
with torch.no_grad():
    embeddings = encoder(X_tensor).numpy()

# Scatter of the first two embedding coordinates, coloured by domain code.
fig, ax = plt.subplots(figsize=(10, 6))
sns.scatterplot(x=embeddings[:, 0], y=embeddings[:, 1], hue=domain_labels, ax=ax)
ax.set_title('DANN Embeddings by Domain')
fig.savefig('./models/embeddings.png')
plt.close(fig)

# 6. Save Model
print("Saving models...")
# NOTE(review): './models/dann/' is created at the top of the notebook, but the
# artifacts below are written directly into './models/' — confirm the intended location.
for module, weights_path in (
    (encoder, './models/dann_encoder.pth'),
    (domain_classifier, './models/dann_domain_classifier.pth'),
):
    torch.save(module.state_dict(), weights_path)

print("Training complete!")