In [None]:
import pandas as pd
from rfphate import RFPHATE
from phate import PHATE
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import umap
import sys
sys.path.append('../..')

In [None]:
# Load the RNA-seq inputs from disk: X is passed to the embedding methods
# below as the feature table, and y supplies the labels ("Class" column,
# parsed as a pandas categorical).
X = pd.read_csv(
    filepath_or_buffer='../data/rnaseq/data.csv',
)
y = pd.read_csv(
    filepath_or_buffer='../data/rnaseq/labels.csv',
    dtype=dict(Class="category"),
)

In [None]:
# Simulate increasing sparsity: for each rate, zero out (approximately) that
# percentage of the entries of X, chosen independently and uniformly at random.
dropout_rates = [0, 25, 50, 75]  # percent of entries set to zero
n_rows, n_cols = X.shape

# Seeded generator so the dropout masks -- and every figure derived from them --
# are identical under Restart Kernel -> Run All. (The embeddings below already
# use random_state=42; the masks previously came from NumPy's unseeded global
# state, which made the whole figure non-reproducible.)
dropout_rng = np.random.default_rng(42)

for rate in dropout_rates:
    if rate > 0:
        # Each entry is dropped with probability rate/100.
        mask = dropout_rng.random((n_rows, n_cols)) < (rate / 100)
        # Work on a fresh ndarray copy so X itself is never modified.
        X_dropout = X.to_numpy(copy=True)
        X_dropout[mask] = 0
    else:
        # NOTE: the 0% variant stays a DataFrame copy, while the corrupted
        # variants are plain ndarrays (unchanged from the original behavior).
        X_dropout = X.copy()
    # Published as X_00, X_25, X_50, X_75; the plotting cell looks these
    # names up via globals().
    globals()[f'X_{rate:02d}'] = X_dropout


In [None]:
def _draw_embedding(ax, emb, labels, show_legend=False):
    """Scatter a 2-D embedding on ``ax``, coloured by ``labels``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    emb : array whose first two columns are used as the x/y coordinates.
    labels : per-point values forwarded to seaborn as ``hue``.
    show_legend : whether seaborn should draw a legend on this panel.

    Axis labels and ticks are cleared afterwards; callers add any row/column
    labels they want once this returns.
    """
    sns.scatterplot(
        x=emb[:, 0], y=emb[:, 1], hue=labels,
        palette='Paired', s=50, ax=ax, edgecolor=None,
        legend=show_legend,
    )
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_xticks([])
    ax.set_yticks([])


oob_scores = []  # RF-PHATE out-of-bag score per dropout rate, in dropout_rates order
class_labels = y['Class']
n_rates = len(dropout_rates)

# Resolve the X_00 / X_25 / ... variables created by the dropout cell once,
# instead of repeating the globals() lookup at every use.
datasets = {rate: globals()[f'X_{rate:02d}'] for rate in dropout_rates}

# One row per method, one column per dropout rate. squeeze=False keeps `axes`
# two-dimensional even if dropout_rates has a single entry.
fig, axes = plt.subplots(
    3, n_rates, figsize=(6 * n_rates, 18),
    squeeze=False,
    gridspec_kw={'hspace': 0, 'wspace': 0},
)

# First row: RF-PHATE (supervised), titled with the forest's OOB score.
for idx, rate in enumerate(dropout_rates):
    rfphate_op = RFPHATE(prediction_type='classification', random_state=42, oob_score=True)
    # Fit on the label column only: passing the whole `y` frame would hand any
    # additional columns it may contain to the forest as extra targets.
    rfphate_op.fit(datasets[rate], class_labels)
    emb = rfphate_op.transform(datasets[rate])
    oob_scores.append(rfphate_op.oob_score_)
    ax = axes[0, idx]
    # The legend is drawn once, on the first panel only.
    _draw_embedding(ax, emb, class_labels, show_legend=(idx == 0))
    ax.set_title(f'OOB: {rfphate_op.oob_score_:.3f}', fontsize = 20)
axes[0, 0].set_ylabel("RF-PHATE", fontsize=20)

# Second row: PHATE (unsupervised).
for idx, rate in enumerate(dropout_rates):
    phate_op = PHATE(random_state=42)
    phate_op.fit(datasets[rate])
    emb = phate_op.transform(datasets[rate])
    _draw_embedding(axes[1, idx], emb, class_labels)
axes[1, 0].set_ylabel("PHATE", fontsize=20)

# Third row: UMAP (unsupervised); the dropout rate is shown as the x-axis
# label of the bottom row so it serves as the column caption.
for idx, rate in enumerate(dropout_rates):
    umap_op = umap.UMAP(random_state=42)
    emb = umap_op.fit_transform(datasets[rate])
    ax = axes[2, idx]
    _draw_embedding(ax, emb, class_labels)
    ax.set_xlabel(f'{rate}% Dropout', fontsize = 20)
axes[2, 0].set_ylabel("UMAP", fontsize=20)

# tight_layout re-derives the subplot spacing so titles and labels fit;
# pad=0.0 keeps the panels as close together as that allows.
fig.tight_layout(pad=0.0)
fig.savefig('rnaseq-vis.pdf', bbox_inches="tight")
plt.show()
