In [1]:
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import datetime
import warnings

warnings.filterwarnings("ignore")

### Constants & Options

In [144]:
BULK_PATH = "input/2dRNA/group1/bulk_RawCounts.tsv"
SC_DIR_PATH = "input/2dRNA/group1/"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pd.set_option('display.max_columns', None)
pd.set_option('display.width', 1000)
pd.set_option('display.max_colwidth', 100)
np.set_printoptions(linewidth=120)
np.set_printoptions(precision=4, suppress=True)

### 1. Data Loading
Load necessary files to DataFrames, see info/stats

In [4]:
bulk_df = pd.read_csv(BULK_PATH, sep="\t")

print("B Matrix (Tissue GEPs) Sample:\n")
print(bulk_df.iloc[:, :6].head(5))
print("\n----------------------------------------------")
print(f"\nB DIMENSIONS: rows (genes) = {bulk_df.shape[0]}, columns (patients) = {bulk_df.shape[1]}")

B Matrix (Tissue GEPs) Sample:

             gene_id  gene_symbol  CANUCK1057-BAL-LB3B  CANUCK1047-BAL-LB5  RESP1024-BAL-LB5  CANUCK1060-BAL-RB4
0  ENSG00000290825.1      DDX11L2                    2                   0                 0                   1
1  ENSG00000223972.6      DDX11L1                    0                   0                 0                   0
2  ENSG00000227232.6       WASH7P                   89                  81                47                 101
3  ENSG00000278267.1    MIR6859-1                   23                  12                 7                  14
4  ENSG00000243485.5  MIR1302-2HG                    0                   0                 0                   0

----------------------------------------------

B DIMENSIONS: rows (genes) = 63187, columns (patients) = 34


In [5]:
sc_path = SC_DIR_PATH + "scRNA_CT1_top200_RawCounts.tsv"
sc_df = pd.read_csv(sc_path, sep="\t")

print("S Matrix (Cell GEPs) Sample:\n")
print(sc_df.iloc[:, :12].head(5))
print("\n----------------------------------------------")
print(f"\nS DIMENSIONS: rows (patients x cells) = {sc_df.shape[0]}, columns (genes) = {sc_df.shape[1]}")

S Matrix (Cell GEPs) Sample:

                cell_id patient_id  TUBA1A  SPA17  ACTG1  TSTD1  H1-0  NQO1  ATP5IF1  DNPH1  NEDD9  ALDH1A1
0  AAACCCACAATACGAA-1_1   BAL-RB-2       0      0      1      1     0     0        0      0      0        0
1  AAACGAACACGCTATA-1_1   BAL-RB-2      81      1     64      2     0     1        6      1      0       15
2  AACAACCCAAACTCGT-1_1   BAL-RB-2       4      0    106      0     0     0        4      1      0       13
3  AACACACCAAATTGGA-1_1   BAL-RB-2       0      0      1      0     0     0        0      0      0        0
4  AACAGGGGTCGTACTA-1_1   BAL-RB-2       0      0     16      0     0     0        1      0      2        0

----------------------------------------------

S DIMENSIONS: rows (patients x cells) = 241924, columns (genes) = 1013


In [6]:
sc_metadata_path = SC_DIR_PATH + "scRNA_CT1_top200_Metadata.tsv"
sc_metadata_df = pd.read_csv(sc_metadata_path, sep="\t")

print("S Metadata Matrix Sample:\n")
print(sc_metadata_df.head(5))
print("\n----------------------------------------------")
print("S Metadata Info:\n")
sc_metadata_df.info()
print("----------------------------------------------")
print(f"\nS METADATA DIMENSIONS: rows (patients x cells) = {sc_metadata_df.shape[0]}, columns (metadata) = {sc_metadata_df.shape[1]}\n")

S Metadata Matrix Sample:

                cell_id patient_id  patient_age patient_sex  cell_type_1                cell_type_2                cell_type_3          cell_type_4                      data_source deconv_cluster
0  AAACCCACAATACGAA-1_1   BAL-RB-2           32      Female   Epithelial                        NaN                 Epithelial           Epithelial  Post-covid respiratory symptoms     Epithelial
1  AAACGAACACGCTATA-1_1   BAL-RB-2           32      Female  Macrophages  Alveolar_Macrophage_CSF1R  Alveolar_Macrophage_CSF1R  Alveolar_macrophage  Post-covid respiratory symptoms    Macrophages
2  AACAACCCAAACTCGT-1_1   BAL-RB-2           32      Female  Macrophages           Macrophage_CCL18           Macrophage_CCL18  Alveolar_macrophage  Post-covid respiratory symptoms    Macrophages
3  AACACACCAAATTGGA-1_1   BAL-RB-2           32      Female  Macrophages           Macrophage_CCL18           Macrophage_CCL18  Alveolar_macrophage  Post-covid respiratory symptoms    Macro

### 2. Data Processing
Process bulk and single-cell data to generate training samples.

In [129]:
def process_B(bulk: pd.DataFrame, sc: pd.DataFrame, sc_metadata: pd.DataFrame):
    # Filter B to keep only common genes with S
    bulk_genes_all = bulk["gene_symbol"].str.strip().str.lower()
    common_genes = set(bulk_genes_all).intersection(sc.columns[2:].str.strip().str.lower())
    filtered_bulk = bulk[bulk_genes_all.isin(common_genes)].drop_duplicates(subset="gene_symbol", keep="first")
    filtered_bulk_vals = filtered_bulk.iloc[:, 2:]  # drop gene_id and gene_symbol cols
    
    # Normalize and convert to np array
    B = np.log1p(filtered_bulk_vals.values.T)
    print(f"B dims (patients x genes): {B.shape}")

    # Assert that patient IDs in S match B
    sc_patient_ids = sc_metadata['patient_id'].unique()
    bulk_patient_ids = filtered_bulk_vals.columns
    if not all(i in bulk_patient_ids for i in sc_patient_ids):
        raise ValueError("Patient IDs in S do not match B. Check mapping.")

    return B, bulk_patient_ids


# def process_S(sc_metadata: pd.DataFrame, patient_ids: np.ndarray, n_aug=15, sample_fraction=0.5):
#     """
#     Derive stratified augmented C matrices from S metadata.

#     Args:
#         sc_metadata (pd.DataFrame): Single-cell metadata containing patient and cell type information.
#         patient_ids (np.ndarray): Array of patient IDs in bulk data.
#         n_aug (int): Number of augmentations per patient.
#         sample_fraction (float): Fraction of cells to sample for each augmentation (e.g., 0.9 for 90%).

#     Returns:
#         np.ndarray: Augmented C matrices (patients x n_augs x cell types).
#         list: Successfully processed patient IDs.
#     """
#     ct_labels = sc_metadata["cell_type_1"].dropna().unique()
#     C_augs = []  # Augmentations for all patients
#     processed_patients = []  # Successfully processed patient IDs

#     for pid in patient_ids:
#         print(f"Augmenting Patient {pid}")
#         patient_cells = sc_metadata[sc_metadata["patient_id"] == pid]
#         if patient_cells.empty:
#             print(f"  Skipping (no cells)")
#             continue

#         patient_augs = []

#         for aug_idx in range(n_aug):
#             # Stratify sampling: random sample by cell type
#             strat_sample = []
#             for ct in ct_labels:
#                 ct_cells = patient_cells[patient_cells["cell_type_1"] == ct]
#                 if not ct_cells.empty:
#                     # Ensure sample size is not larger than available cells
#                     n_sample = max(1, min(len(ct_cells), int(len(ct_cells) * sample_fraction)))
#                     sampled_cells = ct_cells.sample(n=n_sample, replace=False, random_state=aug_idx)
#                     strat_sample.append(sampled_cells)

#             # Calculate cell type fractions for this augmentation
#             sampled_cells = pd.concat(strat_sample) if strat_sample else pd.DataFrame(columns=["cell_type_1"])
#             ct_fractions = sampled_cells["cell_type_1"].value_counts(normalize=True)
#             all_ct_fractions = {ct: ct_fractions.get(ct, 0.0) for ct in ct_labels}
#             patient_augs.append(list(all_ct_fractions.values()))

#         C_augs.append(patient_augs)
#         processed_patients.append(pid)

#     C_augs = np.array(C_augs)
#     # Flatten to 2D so each row is an augmentation for a specific patient
#     C_flat = C_augs.reshape(-1, C_augs.shape[2])
#     print(f"C dims ((patients * n_augs) x CTs): {C_flat.shape}")
#     print(f"Processed patients: {len(processed_patients)}")

#     return C_flat, processed_patients


def process_S(sc_metadata: pd.DataFrame, patient_ids: np.ndarray, n_aug=15, sample_fraction=0.9):
    """
    Derive augmented C matrices with random sampling (not stratified).

    Args:
        sc_metadata (pd.DataFrame): Single-cell metadata containing patient and cell type information.
        patient_ids (np.ndarray): Array of patient IDs in bulk data.
        n_aug (int): Number of augmentations per patient.
        sample_fraction (float): Fraction of cells to sample for each augmentation (e.g., 0.9 for 90%).

    Returns:
        np.ndarray: Augmented C matrices (patients x n_augs x cell types).
        list: Successfully processed patient IDs.
    """
    ct_labels = sc_metadata["cell_type_1"].dropna().unique()
    C_augs = []  # Augmentations for all patients
    processed_patients = []  # Successfully processed patient IDs

    for pid in patient_ids:
        print(f"Augmenting Patient {pid}")
        patient_cells = sc_metadata[sc_metadata["patient_id"] == pid]
        if patient_cells.empty:
            print(f"  Skipping (no cells)")
            continue

        patient_augs = []

        for _ in range(n_aug):
            # Randomly sample a subset of all cells
            n_sample = max(1, int(len(patient_cells) * sample_fraction))
            sampled_cells = patient_cells.sample(n=n_sample, replace=False, random_state=None)

            # Calculate cell type fractions for this augmentation
            ct_fractions = sampled_cells["cell_type_1"].value_counts(normalize=True)
            all_ct_fractions = {ct: ct_fractions.get(ct, 0.0) for ct in ct_labels}
            patient_augs.append(list(all_ct_fractions.values()))

        C_augs.append(patient_augs)
        processed_patients.append(pid)

    C_augs = np.array(C_augs)
    # Flatten to 2D so each row is an augmentation for a specific patient
    C_flat = C_augs.reshape(-1, C_augs.shape[2])
    print(f"C dims ((patients * n_augs) x CTs): {C_flat.shape}")
    print(f"Processed patients: {len(processed_patients)}")

    return C_flat, processed_patients


In [137]:
n_aug = 30
B, patient_ids = process_B(bulk_df, sc_df, sc_metadata_df)
C_flat, processed_patient_ids = process_S(sc_metadata_df, patient_ids, n_aug)

B_filtered = B[np.isin(patient_ids, processed_patient_ids)]
print(f"Filtered B dims (patients x genes): {B_filtered.shape}")

B_aug = np.repeat(B_filtered, n_aug, axis=0)  # Repeat B for each augmentation
print(f"Flattened B dims ((patients * n_augs) x genes): {B_aug.shape}")

X_train, X_test, Y_train, Y_test = train_test_split(B_aug, C_flat, test_size=0.2, random_state=42)

train_dataset = TensorDataset(
    torch.tensor(X_train, dtype=torch.float32),
    torch.tensor(Y_train, dtype=torch.float32),
)
test_dataset = TensorDataset(
    torch.tensor(X_test, dtype=torch.float32),
    torch.tensor(Y_test, dtype=torch.float32),
)

B dims (patients x genes): (32, 1009)
Augmenting Patient CANUCK1057-BAL-LB3B
Augmenting Patient CANUCK1047-BAL-LB5
Augmenting Patient RESP1024-BAL-LB5
Augmenting Patient CANUCK1060-BAL-RB4
  Skipping (no cells)
Augmenting Patient RESPONSE1030-BAL-RB5
Augmenting Patient CANUCK1012-BAL
Augmenting Patient RESPONSE1094-BAL-LB5
Augmenting Patient VAPE1007-BAL-LB5
Augmenting Patient RESP1022-BAL-LB1
Augmenting Patient BAL-RB-2
Augmenting Patient VAPE1009-BAL-RB5
Augmenting Patient RESP1023-BAL-LB5
Augmenting Patient RESP1019-BAL-RB9
Augmenting Patient RESP1038-BAL-RB5
Augmenting Patient RESP1001-BAL-LB
Augmenting Patient Response1014-BAL
Augmenting Patient RESP1036-BAL-LB5
Augmenting Patient VAPE1010-BAL-RB5
Augmenting Patient RESPONSE1040-BAL-LB5
Augmenting Patient CANUCK1043-BAL-RB5
Augmenting Patient CANUCK1039-BAL-RB6
  Skipping (no cells)
Augmenting Patient CANUCK1031-BAL-LB5
Augmenting Patient RESP1020-BAL-LB5
Augmenting Patient CANUCK1035-BAL-RB5
Augmenting Patient RESPONSE1076-BAL-LB

In [134]:
print("C sample:")
print(C_flat[10:20, :])

C sample:
[[0.0708 0.8299 0.0153 0.0289 0.0075 0.0063 0.0019 0.0078 0.0048 0.0002 0.0266]
 [0.068  0.8331 0.0149 0.0274 0.0075 0.006  0.002  0.008  0.0048 0.0002 0.0281]
 [0.0702 0.8305 0.0158 0.0285 0.0073 0.0061 0.0019 0.0067 0.0048 0.0002 0.0279]
 [0.0691 0.8323 0.0151 0.0279 0.0071 0.0058 0.002  0.0078 0.005  0.0002 0.0276]
 [0.0702 0.8301 0.0155 0.0285 0.0073 0.0056 0.0019 0.0084 0.005  0.0002 0.0274]
 [0.0706 0.8318 0.0155 0.0278 0.0063 0.0063 0.002  0.0076 0.005  0.0002 0.0268]
 [0.071  0.8282 0.0166 0.0287 0.0078 0.0054 0.0017 0.0076 0.0047 0.0002 0.0281]
 [0.0704 0.8307 0.016  0.0279 0.0075 0.0061 0.002  0.0073 0.005  0.0002 0.0268]
 [0.0693 0.8312 0.0149 0.0292 0.0076 0.0061 0.0017 0.0082 0.0043 0.0002 0.0272]
 [0.0708 0.829  0.0156 0.0287 0.0075 0.0061 0.0019 0.0076 0.0054 0.0002 0.0272]]


### 3. Model Training
Define model architecture/parameters, run training, eval

In [145]:
class Model2dRNA(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Model2dRNA, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 1000),
            nn.ReLU(),
            nn.BatchNorm1d(1000),
            nn.Dropout(0.3),
            nn.Linear(1000, 500),
            nn.ReLU(),
            nn.BatchNorm1d(500),
            nn.Dropout(0.3),
            nn.Linear(500, 100),
            nn.ReLU(),
            nn.BatchNorm1d(100),
            nn.Linear(100, output_dim),
        )

    def forward(self, x):
        return self.model(x)


def train_model(model: Model2dRNA, train_set: TensorDataset, test_set: TensorDataset, epochs: int, batch_size: int):
    model.to(DEVICE)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    for e in range(epochs):
        model.train()
        epoch_loss = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for X_val, y_val in test_loader:
                X_val, y_val = X_val.to(DEVICE), y_val.to(DEVICE)
                val_outputs = model(X_val)
                val_loss += criterion(val_outputs, y_val).item()
        print(f"Epoch {e+1}/{epochs}, Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}")


def save_model(model: Model2dRNA, X_test, Y_test):
    os.makedirs("output", exist_ok=True)
    dtnum = str(datetime.datetime.now().strftime("%Y%m%d_%H%M"))
    model_dir = os.path.join("output", "2dRNA", dtnum)
    os.makedirs(model_dir, exist_ok=True)

    model_path = os.path.join(model_dir, "model.pth")
    torch.save(model.state_dict(), model_path)
    print(f"Saved model to {model_path}")

    # Save predictions and true fractions
    X_test = torch.tensor(X_test, dtype=torch.float32)
    Y_test = torch.tensor(Y_test, dtype=torch.float32)
    model.eval()
    with torch.no_grad():
        predictions = model(X_test.to(DEVICE)).cpu().numpy()
    preds_file = os.path.join(model_dir, "pred_fractions.csv")
    true_fractions_file = os.path.join(model_dir, "true_fractions.csv")
    np.savetxt(preds_file, predictions, delimiter=",")
    np.savetxt(true_fractions_file, Y_test.numpy(), delimiter=",")
    print(f"Saved predictions to {preds_file}")
    print(f"Saved true fractions to {true_fractions_file}")


def eval_model(model: Model2dRNA, X_test, Y_test):
    print("\nEvaluating model on Y_test:")
    X_test = torch.tensor(X_test, dtype=torch.float32)
    Y_test = torch.tensor(Y_test, dtype=torch.float32)
    model.eval()
    with torch.no_grad():
        Y_pred = model(X_test.to(DEVICE)).cpu()

    target_min = Y_test.min()
    target_max = Y_test.max()
    target_mean = Y_test.mean()

    mae = torch.mean(torch.abs(Y_pred - Y_test)).item()
    rmse = torch.sqrt(torch.mean((Y_pred - Y_test) ** 2)).item()
    cosine = torch.nn.functional.cosine_similarity(Y_pred, Y_test, dim=1).mean().item()

    mae_pct_range = (mae / (target_max - target_min)) * 100
    mae_pct_mean = (mae / target_mean) * 100
    rmse_pct_range = (rmse / (target_max - target_min)) * 100
    rmse_pct_mean = (rmse / target_mean) * 100

    print(f" - Target value range: [{target_min:.4f}, {target_max:.4f}]")
    print(f" - Target value average: {target_mean:.4f}")
    print(f" - MAE: {mae:.4f}")
    print(f" - MAE as percentage of range: {mae_pct_range:.2f}%")
    print(f" - MAE as percentage of average: {mae_pct_mean:.2f}%")
    print(f" - RMSE: {rmse:.4f}")
    print(f" - RMSE as percentage of range: {rmse_pct_range:.2f}%")
    print(f" - RMSE as percentage of average: {rmse_pct_mean:.2f}%")
    print(f" - Cosine similarity: {cosine:.4f}")


In [150]:
input_dim = X_train.shape[1]
output_dim = Y_train.shape[1]
model = Model2dRNA(input_dim, output_dim)
epochs = 150
batch_size = 32

saved_model_path = None # "output/2dRNA/20241229_1515/model.pth"
if saved_model_path and os.path.exists(saved_model_path):
    model.load_state_dict(torch.load(saved_model_path))
    print(f"Loaded model from {saved_model_path}")
else:
    print("Training model...")
    train_model(model, train_dataset, test_dataset, epochs, batch_size)
    print("Training complete!")
    save_model(model, X_test, Y_test)

eval_model(model, X_test, Y_test)

Training model...
Epoch 1/150, Loss: 4.3714, Val Loss: 4.4960
Epoch 2/150, Loss: 1.5722, Val Loss: 0.4129
Epoch 3/150, Loss: 0.9700, Val Loss: 0.1413
Epoch 4/150, Loss: 0.5277, Val Loss: 0.0542
Epoch 5/150, Loss: 0.2799, Val Loss: 0.0141
Epoch 6/150, Loss: 0.1602, Val Loss: 0.0063
Epoch 7/150, Loss: 0.1303, Val Loss: 0.0070
Epoch 8/150, Loss: 0.1045, Val Loss: 0.0059
Epoch 9/150, Loss: 0.1007, Val Loss: 0.0051
Epoch 10/150, Loss: 0.0876, Val Loss: 0.0035
Epoch 11/150, Loss: 0.0789, Val Loss: 0.0046
Epoch 12/150, Loss: 0.0702, Val Loss: 0.0031
Epoch 13/150, Loss: 0.0637, Val Loss: 0.0035
Epoch 14/150, Loss: 0.0598, Val Loss: 0.0024
Epoch 15/150, Loss: 0.0594, Val Loss: 0.0031
Epoch 16/150, Loss: 0.0531, Val Loss: 0.0038
Epoch 17/150, Loss: 0.0484, Val Loss: 0.0019
Epoch 18/150, Loss: 0.0415, Val Loss: 0.0018
Epoch 19/150, Loss: 0.0392, Val Loss: 0.0020
Epoch 20/150, Loss: 0.0408, Val Loss: 0.0022
Epoch 21/150, Loss: 0.0378, Val Loss: 0.0023
Epoch 22/150, Loss: 0.0370, Val Loss: 0.0016
E