In [1]:
import numpy as np
import pandas as pd
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler
from dataset import Dataset
from model import scGREAT
from train_val import train,validate


data_dir = 'hESC500'
# data_dir = 'mESC500'
expression_data_path = data_dir + '/BL--ExpressionData.csv'
biovect_e_path       = data_dir + '/biovect768.npy'
train_data_path      = data_dir + '/Train_set.csv'
val_data_path        = data_dir + '/Validation_set.csv'
test_data_path       = data_dir + '/Test_set.csv'
expression_data = np.array(pd.read_csv(expression_data_path,index_col=0,header=0))

# Data Preprocessing
standard = StandardScaler()
scaled_df = standard.fit_transform(expression_data.T)
expression_data = scaled_df.T
expression_data_shape = expression_data.shape 

train_dataset = Dataset(train_data_path, expression_data)
val_dataset = Dataset(val_data_path, expression_data)
test_dataset = Dataset(test_data_path, expression_data)



In [2]:
# SANITY ORIGINAL MODEL
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=32,
                                               shuffle=True,
                                               drop_last=False,
                                               num_workers=0)
print(f"train_loader has {len(train_loader)} batches")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for idx, (gene_pair_index, expr_embedding, label) in enumerate(train_loader):
        label = label.to(device)
        gene_pair_index = gene_pair_index.to(device)
        expr_embedding = expr_embedding.to(torch.float32)
        expr_embedding = expr_embedding.to(device)
        break

print(gene_pair_index.shape, expr_embedding.shape, label.shape)

T = scGREAT(expression_data_shape,768,2,4,biovect_e_path)
predicted_label = T(gene_pair_index, expr_embedding)
print(predicted_label.shape, predicted_label)

train_loader has 647 batches
torch.Size([32, 2]) torch.Size([32, 2, 758]) torch.Size([32])
torch.Size([32, 1]) tensor([[0.4856],
        [0.5310],
        [0.5134],
        [0.5054],
        [0.5034],
        [0.5394],
        [0.5406],
        [0.5385],
        [0.4284],
        [0.4570],
        [0.4467],
        [0.5221],
        [0.4799],
        [0.4964],
        [0.5045],
        [0.5748],
        [0.5023],
        [0.5693],
        [0.5390],
        [0.4825],
        [0.5336],
        [0.5211],
        [0.5392],
        [0.5054],
        [0.5352],
        [0.5461],
        [0.5105],
        [0.4701],
        [0.4989],
        [0.5508],
        [0.5140],
        [0.5477]], grad_fn=<SigmoidBackward0>)


In [3]:
import torch
import numpy as np
import pandas as pd

class GeneInteractionDataset(torch.utils.data.Dataset):
    def __init__(self, data_path, expression_data):
        # Load labeled gene pairs
        data = pd.read_csv(data_path, index_col=0, header=0)

        self.labeled_pairs = data[['TF', 'Target']].values
        self.labels = torch.tensor(data['Label'].values.astype(np.float32))  # Convert labels to float32
        self.expression_data = torch.tensor(expression_data, dtype=torch.float32)  # [num_genes, embedding_dim]
        self.num_genes = self.expression_data.shape[0]

    def __len__(self):
        return 1  # Since the expression data is the same, we can return a single sample

    def __getitem__(self, idx):
        # Gene embeddings
        expr_embedding = self.expression_data  # [num_genes, embedding_dim]

        # Initialize label matrix (all -1) and mask (all 0)
        label_matrix = -torch.ones((self.num_genes, self.num_genes), dtype=torch.float32)  
        mask = torch.zeros((self.num_genes, self.num_genes), dtype=torch.float32)

        # Fill in the labels and mask for labeled pairs
        for (i, j), label in zip(self.labeled_pairs, self.labels):
            label_matrix[i, j] = label  # Assign the label (1 or 0)
            mask[i, j] = 1.0            # Mark as labeled

        return expr_embedding, label_matrix, mask


In [4]:
interaction_train_ds = GeneInteractionDataset(train_data_path, expression_data)
interaction_val_ds = GeneInteractionDataset(val_data_path, expression_data)
interaction_test_ds = GeneInteractionDataset(test_data_path, expression_data)

interaction_train_loader = torch.utils.data.DataLoader(
    dataset=interaction_train_ds,
    batch_size=1,
    shuffle=True,
    num_workers=0
)
interaction_val_loader = torch.utils.data.DataLoader(
    dataset=interaction_val_ds,
    batch_size=1,
    shuffle=False,
    num_workers=0
)
interaction_test_loader = torch.utils.data.DataLoader(
    dataset=interaction_test_ds,
    batch_size=1,
    shuffle=False,
    num_workers=0
)

In [6]:
import sys
sys.path.append('/Users/factored/Dev/GEARS/gears')
from gene_transformer import *
class GrnTransformer(nn.Module):
    """Input is a tensor of shape (batch_size, num_genes, embed_dim) and output is a tensor of shape (batch_size, num_genes)."""

    def __init__(
        self,
        num_genes, # Different from GeneExpressionTransformer from GEARS
        embed_dim=64,
        num_heads=8,
        hidden_dim=256,
        num_layers=1,
        group_size=512,
        dropout=0.0,
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                TransformerLayer(embed_dim, num_heads, hidden_dim, group_size, dropout)
                for _ in range(num_layers)
            ]
        )
        self.norm = RMSNorm(embed_dim)
        
        # Components different from GeneExpressionTransformer from GEARS
        # embed input vectors of size 758 to size 768. TODO: 758 is hardcoded here
        self.proj = nn.Linear(758, embed_dim)
        self.grn_pred_head = nn.Linear(hidden_dim, num_genes)


    def forward(self, x):
        x = self.proj(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return self.grn_pred_head(x)

In [14]:
# SANITY CHECK NEW MODEL
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for expr_embedding, label_matrix, mask in interaction_train_loader:
    expr_embedding = expr_embedding.to(device)  # [batch_size, num_genes_per_sample, embedding_dim]
    label_matrix = label_matrix.to(device)      # [batch_size, num_genes_per_sample, num_genes_per_sample]
    mask = mask.to(device)     
    break

print(len(interaction_train_ds), len(interaction_train_loader))
print(expr_embedding.shape, label_matrix.shape, mask.shape)

model = GrnTransformer(
    num_genes=interaction_train_ds.expression_data.shape[0],
    embed_dim=768,
    num_heads=4,
    hidden_dim=768,
    num_layers=2,
    group_size=512,
    dropout=0.0,
)
predicted_label = model(expr_embedding) 
print(expr_embedding.shape, predicted_label.shape)

1 1
torch.Size([1, 910, 758]) torch.Size([1, 910, 910]) torch.Size([1, 910, 910])
torch.Size([1, 910, 758]) torch.Size([1, 910, 910])


In [15]:
# NEW MODEL TRAINING LOOP
from sklearn.metrics import roc_auc_score, average_precision_score

def Evaluation(y_pred, y_true, mask):
    # Detach tensors and move to CPU
    y_pred = y_pred.detach().cpu()
    y_true = y_true.detach().cpu()
    mask = mask.detach().cpu()

    # Flatten the tensors
    y_pred = y_pred.view(-1)
    y_true = y_true.view(-1)
    mask = mask.view(-1)

    # Apply mask to select labeled pairs
    y_pred = y_pred[mask > 0]
    y_true = y_true[mask > 0]

    # Convert to numpy arrays
    y_pred = y_pred.numpy()
    y_true = y_true.numpy()

    try:
        # Compute evaluation metrics
        AUROC = roc_auc_score(y_true=y_true, y_score=y_pred)
        AUPRC = average_precision_score(y_true=y_true, y_score=y_pred)
    except Exception as e:
        AUROC = 0
        AUPRC = 0

    return AUROC, AUPRC

def train(model, dataloader, loss_func, optimizer, epoch, scheduler, args):
    model.train()
    log_interval = 200
    total_loss = 0

    for idx, (expr_embedding, label_matrix, mask) in enumerate(dataloader):
        # Move tensors to the appropriate device
        expr_embedding = expr_embedding.to(torch.float32).to(device)  # [batch_size, num_genes, embedding_dim]
        label_matrix = label_matrix.to(torch.float32).to(device)      # [batch_size, num_genes, num_genes]
        mask = mask.to(torch.float32).to(device)                      # [batch_size, num_genes, num_genes]

        optimizer.zero_grad()

        # Forward pass through the model
        predicted_output = model(expr_embedding)  # Should output [batch_size, num_genes, num_genes]

        # Compute the element-wise loss
        loss = loss_func(predicted_output, label_matrix)

        # Apply the mask to ignore unlabeled gene pairs
        masked_loss = loss * mask

        # Compute the mean loss over labeled pairs
        final_loss = masked_loss.sum() / mask.sum()
        total_loss += final_loss.item()

        # Backpropagation
        final_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

        if args and args.scheduler_flag:
            scheduler.step()

        # Logging and evaluation
        if idx % log_interval == 0:
            AUROC, AUPRC = Evaluation(y_pred=predicted_output, y_true=label_matrix, mask=mask)
            print(
                "| epoch {:3d} | {:5d} /{:5d} batches |Train loss {:8.3f} | AUROC {:8.3f} | AUPRC {:8.3f}".format(
                    epoch, idx, len(dataloader), final_loss.item(), AUROC, AUPRC
                )
            )

    print("| epoch {:3d} | total_loss {:8.3f}".format(epoch, total_loss))


# for epoch in range(50):
#     train(model, interaction_train_loader, loss_fn, optimizer, epoch, None, None)
#     AUC_val,AUPR_val = validate(T,val_loader,loss_func)
#     print('-' * 100)
#     print('| end of epoch {:3d} |valid AUROC {:8.3f} | valid AUPRC {:8.3f}'.format(epoch,AUC_val,AUPR_val))
#     print('-' * 100)

def validate(model, dataloader, loss_func):
    model.eval()
    total_loss = 0
    pre_list = []
    lb_list = []
    mask_list = []

    with torch.no_grad():
        for idx, (expr_embedding, label_matrix, mask) in enumerate(dataloader):
            expr_embedding = expr_embedding.to(torch.float32).to(device)  # [batch_size, num_genes, embedding_dim]
            label_matrix = label_matrix.to(torch.float32).to(device)      # [batch_size, num_genes, num_genes]
            mask = mask.to(torch.float32).to(device)                      # [batch_size, num_genes, num_genes]

            # Forward pass
            predicted_output = model(expr_embedding)  # Should output [batch_size, num_genes, num_genes]

            # Compute loss (optional)
            loss = loss_func(predicted_output, label_matrix)
            masked_loss = loss * mask
            final_loss = masked_loss.sum() / mask.sum()
            total_loss += final_loss.item()

            # Collect predictions and labels
            pre_list.append(predicted_output.detach().cpu())
            lb_list.append(label_matrix.detach().cpu())
            mask_list.append(mask.detach().cpu())

        # Concatenate all batches
        y_pred = torch.cat(pre_list, dim=0)  # [total_samples, num_genes, num_genes]
        y_true = torch.cat(lb_list, dim=0)
        mask = torch.cat(mask_list, dim=0)

        # Flatten tensors
        y_pred = y_pred.view(-1)
        y_true = y_true.view(-1)
        mask = mask.view(-1)

        # Apply mask to select labeled pairs
        y_pred = y_pred[mask > 0]
        y_true = y_true[mask > 0]

        # Convert to numpy arrays
        y_pred = y_pred.numpy()
        y_true = y_true.numpy()

        try:
            # Compute evaluation metrics
            AUROC = roc_auc_score(y_true=y_true, y_score=y_pred)
            AUPRC = average_precision_score(y_true=y_true, y_score=y_pred)
        except Exception as e:
            AUROC = 0
            AUPRC = 0

    return AUROC, AUPRC


EPOCHS = 647 # The same number of steps as 1 epoch in the original model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none')

for epoch in range(1, EPOCHS + 1):
    train(model, interaction_train_loader, loss_fn, optimizer, epoch, None, None)
    AUC_val, AUPR_val = validate(model, interaction_val_loader, loss_fn)
    print('-' * 100)
    print('| end of epoch {:3d} | valid AUROC {:8.3f} | valid AUPRC {:8.3f}'.format(epoch, AUC_val, AUPR_val))
    print('-' * 100)
    AUC_test, AUPR_test = validate(model, interaction_test_loader, loss_fn)
    print('| end of epoch {:3d} | test  AUROC {:8.3f} | test  AUPRC {:8.3f}'.format(epoch, AUC_test, AUPR_test))
    print('-' * 100)

    if AUC_val < 0.501:
        print("AUC_val < 0.501 !!")
        break

| epoch   1 |     0 /    1 batches |Train loss    0.738 | AUROC    0.492 | AUPRC    0.144
| epoch   1 | total_loss    0.738
----------------------------------------------------------------------------------------------------
| end of epoch   1 | valid AUROC    0.530 | valid AUPRC    0.155
----------------------------------------------------------------------------------------------------
| end of epoch   1 | test  AUROC    0.500 | test  AUPRC    0.147
----------------------------------------------------------------------------------------------------
| epoch   2 |     0 /    1 batches |Train loss    0.736 | AUROC    0.495 | AUPRC    0.145
| epoch   2 | total_loss    0.736
----------------------------------------------------------------------------------------------------
| end of epoch   2 | valid AUROC    0.530 | valid AUPRC    0.155
----------------------------------------------------------------------------------------------------
| end of epoch   2 | test  AUROC    0.500 | test  AU

In [None]:

# for epoch in range(50):
#     for expr_embedding, label_matrix, mask in interaction_train_loader:
#         expr_embedding = expr_embedding.to(device)  # [batch_size, num_genes, embedding_dim]
#         label_matrix = label_matrix.to(device)      # [batch_size, num_genes, num_genes]
#         mask = mask.to(device)                      # [batch_size, num_genes, num_genes]

#         # Forward pass
#         outputs = model(expr_embedding)  # Should output [batch_size, num_genes, num_genes]

#         # Compute the loss
#         loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none')
#         loss = loss_fn(outputs, label_matrix)
#         masked_loss = loss * mask
#         final_loss = masked_loss.sum() / mask.sum()

#         # Backpropagation and optimization steps
#         optimizer.zero_grad()
#         final_loss.backward()
#         optimizer.step()
#         print(f"Epoch {epoch}, Loss: {final_loss.item()}")