In [None]:
!pip install matplotlib scanpy

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
from sklearn.datasets import load_boston, load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, mean_absolute_error
from collections import Counter
import numpy as np
import pandas as pd
np.set_printoptions(precision=3)
import jax
import jax.numpy as jnp
from typing import Tuple, Union

In [None]:
import scanpy as sc

from warnings import filterwarnings
# Silences every Python warning raised from here on.
filterwarnings('ignore')

sc.settings.verbosity = 3
# Prints the versions of scanpy and its main dependencies (see output below).
sc.logging.print_header()
sc.settings.set_figure_params(dpi=80, facecolor='white')

scanpy==1.9.1 anndata==0.8.0 umap==0.5.3 numpy==1.21.6 scipy==1.7.3 pandas==1.3.5 scikit-learn==1.0.2 statsmodels==0.12.2 pynndescent==0.5.8


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# 1.1 Datasets

In [None]:
# Classification dataset: consumed by the Random Forest / Gradient Boosting classifier cells.
X_class, y_class = load_breast_cancer(return_X_y=True)
X_train_c, X_test_c, y_train_c, y_test_c = train_test_split(X_class, y_class, random_state=0)

# Regression dataset: consumed by the Random Forest / Gradient Boosting regressor cells.
# NOTE(review): load_boston is deprecated in recent scikit-learn releases — confirm it
# is still available in the installed version (the header output shows 1.0.2).
X_reg, y_reg = load_boston(return_X_y=True)
X_train_r, X_test_r, y_train_r, y_test_r = train_test_split(X_reg, y_reg, random_state=0)

In [None]:
# -nc: skip the download if the file already exists (no repeated 1.2G fetch, no '.1' duplicate).
!wget -nc https://saturn-public-data.s3.us-east-2.amazonaws.com/cancer-immunotherapy-challenge/data/sc_training.h5ad

--2023-01-09 20:37:19--  https://saturn-public-data.s3.us-east-2.amazonaws.com/cancer-immunotherapy-challenge/data/sc_training.h5ad
Resolving saturn-public-data.s3.us-east-2.amazonaws.com (saturn-public-data.s3.us-east-2.amazonaws.com)... 52.219.94.50
Connecting to saturn-public-data.s3.us-east-2.amazonaws.com (saturn-public-data.s3.us-east-2.amazonaws.com)|52.219.94.50|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1291266853 (1.2G) [binary/octet-stream]
Saving to: ‘sc_training.h5ad.1’


2023-01-09 20:37:54 (36.2 MB/s) - ‘sc_training.h5ad.1’ saved [1291266853/1291266853]



In [None]:
# -nc: skip the download if the file already exists (avoids '.1' duplicates on re-run).
!wget -nc https://saturn-public-data.s3.us-east-2.amazonaws.com/cancer-immunotherapy-challenge/code/sc_training_visualization.ipynb

--2023-01-09 20:37:54--  https://saturn-public-data.s3.us-east-2.amazonaws.com/cancer-immunotherapy-challenge/code/sc_training_visualization.ipynb
Resolving saturn-public-data.s3.us-east-2.amazonaws.com (saturn-public-data.s3.us-east-2.amazonaws.com)... 52.219.110.194
Connecting to saturn-public-data.s3.us-east-2.amazonaws.com (saturn-public-data.s3.us-east-2.amazonaws.com)|52.219.110.194|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 446576 (436K) [binary/octet-stream]
Saving to: ‘sc_training_visualization.ipynb.1’


2023-01-09 20:37:54 (1.39 MB/s) - ‘sc_training_visualization.ipynb.1’ saved [446576/446576]



In [None]:
# -nc: skip the download if the file already exists (avoids '.1' duplicates on re-run).
!wget -nc https://saturn-public-data.s3.us-east-2.amazonaws.com/cancer-immunotherapy-challenge/data/clone_information.csv

--2023-01-09 20:37:55--  https://saturn-public-data.s3.us-east-2.amazonaws.com/cancer-immunotherapy-challenge/data/clone_information.csv
Resolving saturn-public-data.s3.us-east-2.amazonaws.com (saturn-public-data.s3.us-east-2.amazonaws.com)... 52.219.110.194
Connecting to saturn-public-data.s3.us-east-2.amazonaws.com (saturn-public-data.s3.us-east-2.amazonaws.com)|52.219.110.194|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1542596 (1.5M) [text/csv]
Saving to: ‘clone_information.csv.1’


2023-01-09 20:37:55 (3.22 MB/s) - ‘clone_information.csv.1’ saved [1542596/1542596]



In [None]:
# -nc: skip the download if the file already exists (avoids '.1' duplicates on re-run).
!wget -nc https://saturn-public-data.s3.us-east-2.amazonaws.com/cancer-immunotherapy-challenge/data/guide_abundance.csv

--2023-01-09 20:37:56--  https://saturn-public-data.s3.us-east-2.amazonaws.com/cancer-immunotherapy-challenge/data/guide_abundance.csv
Resolving saturn-public-data.s3.us-east-2.amazonaws.com (saturn-public-data.s3.us-east-2.amazonaws.com)... 52.219.110.194
Connecting to saturn-public-data.s3.us-east-2.amazonaws.com (saturn-public-data.s3.us-east-2.amazonaws.com)|52.219.110.194|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4452 (4.3K) [text/csv]
Saving to: ‘guide_abundance.csv.1’


2023-01-09 20:37:56 (310 MB/s) - ‘guide_abundance.csv.1’ saved [4452/4452]



In [None]:
# -nc: skip the download if the file already exists (avoids '.1' duplicates on re-run).
!wget -nc https://saturn-public-data.s3.us-east-2.amazonaws.com/cancer-immunotherapy-challenge/data/scRNA_ATAC.h5

--2023-01-09 20:37:56--  https://saturn-public-data.s3.us-east-2.amazonaws.com/cancer-immunotherapy-challenge/data/scRNA_ATAC.h5
Resolving saturn-public-data.s3.us-east-2.amazonaws.com (saturn-public-data.s3.us-east-2.amazonaws.com)... 52.219.110.194
Connecting to saturn-public-data.s3.us-east-2.amazonaws.com (saturn-public-data.s3.us-east-2.amazonaws.com)|52.219.110.194|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 87606852 (84M) [application/x-hdf5]
Saving to: ‘scRNA_ATAC.h5.1’


2023-01-09 20:37:59 (28.8 MB/s) - ‘scRNA_ATAC.h5.1’ saved [87606852/87606852]



For *read_h5ad()*, see [scanpy documentation](https://scanpy.readthedocs.io/en/stable/generated/scanpy.read_h5ad.html)

In [None]:
# Load the downloaded single-cell training file from the working directory.
adata = sc.read_h5ad('./sc_training.h5ad')
adata

AnnData object with n_obs × n_vars = 28697 × 15077
    obs: 'gRNA_maxID', 'state', 'condition', 'lane'
    layers: 'rawcounts'

For *adata* object type, see [AnnData documentation](https://anndata.readthedocs.io/en/stable/generated/anndata.AnnData.html#anndata.AnnData)

In [None]:
# Per-observation annotations: 'gRNA_maxID', 'state', 'condition' and 'lane'.
adata.obs

Unnamed: 0,gRNA_maxID,state,condition,lane
053l1_AAACCTGAGATGTCGG-1,ONE-NON-GENE-SITE-7,terminal exhausted,Unperturbed,lane1
053l1_AAACCTGAGCAACGGT-1,Tox2-3,effector,Tox2,lane1
053l1_AAACCTGAGTACGACG-1,Tpt1-2,effector,Tpt1,lane1
053l1_AAACCTGAGTCGTTTG-1,Tox2-3,terminal exhausted,Tox2,lane1
053l1_AAACCTGAGTGAAGAG-1,Tcf7-2,effector,Tcf7,lane1
...,...,...,...,...
053l4_TTTGTCATCAGGTTCA-1,Tox2-3,other,Tox2,lane4
053l4_TTTGTCATCAGTGTTG-1,Dvl2-3,cycling,Dvl2,lane4
053l4_TTTGTCATCCTCGCAT-1,Zeb2-2,cycling,Zeb2,lane4
053l4_TTTGTCATCTTCAACT-1,Sox4-3,cycling,Sox4,lane4


# 1.2 Preprocess data

In [None]:
# adata.X is kept in its sparse form; only a 10x10 corner is densified for display.
expression_data = adata.X
expression_data_sub = expression_data[:10, :10]
expression_data_sub.toarray()

array([[0.512, 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   ,
        0.   ],
       [0.484, 0.484, 0.809, 0.   , 0.   , 0.   , 0.   , 0.   , 0.484,
        0.   ],
       [0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.694, 0.   , 0.   ,
        0.   ],
       [0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   ,
        0.   ],
       [0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   ,
        0.   ],
       [0.   , 1.089, 0.686, 0.686, 0.   , 0.   , 0.   , 0.   , 0.   ,
        0.   ],
       [0.   , 0.193, 0.493, 0.   , 0.   , 0.   , 0.193, 0.   , 0.724,
        0.   ],
       [0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   ,
        0.   ],
       [0.292, 0.292, 0.292, 0.292, 0.292, 0.   , 0.   , 0.   , 0.292,
        0.   ],
       [0.412, 0.704, 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   ,
        0.   ]], dtype=float32)

Each row of the full expression matrix represents a single cell (28,697 in total) and each column represents one of the 15,077 genes. The elements are the expression levels; only the top-left 10 × 10 corner is shown above.

In [None]:
genes = adata.var.index.tolist()
# Converted to CSC before the column selections below.
expression_data_csc = expression_data.tocsc()

test_genes = ['Ets1', 'Fosb', 'Mafk', 'Stat3']
validation_genes = ['Aqr', 'Bach2', 'Bhlhe40']

# Held-out genes that are absent from the dataset are skipped silently by the `in` filter.
test_indices = [genes.index(x) for x in test_genes if x in genes]
validation_indices = [genes.index(x) for x in validation_genes if x in genes]

X_test = expression_data_csc[:, test_indices].tocsr()
X_val = expression_data_csc[:, validation_indices].tocsr()

# Index bookkeeping is done with NumPy: the result is only used to index a SciPy
# sparse matrix, so there is no reason to route it through JAX arrays.
all_indices = np.arange(expression_data_csc.shape[1])
rm_indices = sorted(test_indices + validation_indices)
keep_indices = np.setdiff1d(all_indices, rm_indices)

# Everything that is neither a test nor a validation gene becomes the training matrix.
expression_data_csc = expression_data_csc[:, keep_indices]
X_train = expression_data_csc.tocsr()
X_train, X_test, X_val

(<28697x15070 sparse matrix of type '<class 'numpy.float32'>'
 	with 80404869 stored elements in Compressed Sparse Row format>,
 <28697x4 sparse matrix of type '<class 'numpy.float32'>'
 	with 55392 stored elements in Compressed Sparse Row format>,
 <28697x3 sparse matrix of type '<class 'numpy.float32'>'
 	with 36958 stored elements in Compressed Sparse Row format>)

Now we have our train, test and validation samples. 



In [None]:
obs = adata.obs
# Order of the entries in each ground-truth vector; states outside this list are counted as 'other'.
states = ['progenitor', 'effector', 'terminal exhausted', 'cycling', 'other']
# Maps gene name -> list of state proportions (filled in by the next cell).
ground_truth = {}

In [None]:
# Tally the cell states of every perturbation in ONE pass over the cells, instead of
# re-filtering `obs` (and iterating its rows) once for each gene in `genes`.
state_counts = {}
for condition, state in zip(obs['condition'], obs['state']):
    bucket = state if state in states else 'other'  # unknown states fold into 'other'
    state_counts.setdefault(condition, Counter())[bucket] += 1

# Emit the vectors in `genes` order so the dictionary (and any frame built from it)
# keeps the same ordering; genes that never appear as a condition get no entry.
for gene in genes:
    counts = state_counts.get(gene)
    if counts is None:
        continue
    total_count = sum(counts.values())
    # Proportions follow the order of `states`, so every vector sums to 1.
    ground_truth[gene] = [counts[state] / total_count for state in states]
len(ground_truth)

64

Now we have a dictionary that stores the ground truth, i.e. the cell-state frequencies, for each gene. Note that each list in this dictionary sums to 1, as required. We have 64 ground-truth vectors rather than 66 because, of the 66 knockout experiments used for training, 2 did not pass quality control (see: [challenge 1 document](https://drive.google.com/file/d/1rR5oIhETmyVu6Uo5BsWjIokdBZctKzCE/view)).

In [None]:
# Shallow copy so the frame is built from its own dictionary object.
y_data_dict = dict(ground_truth)
# One row per gene (dictionary key), one column per cell state.
y_data = pd.DataFrame.from_dict(
    y_data_dict,
    orient='index',
    columns=['progenitor', 'effector', 'terminal exhausted', 'cycling', 'other'],
)
y_data

# TODO: Make a train-test split for the y data


Unnamed: 0,progenitor,effector,terminal exhausted,cycling,other
Stat4,0.300000,0.300000,0.200000,0.166667,0.033333
Sp140,0.515152,0.121212,0.121212,0.242424,0.000000
Sp100,0.024476,0.288462,0.281469,0.391608,0.013986
Zeb2,0.017483,0.115385,0.282051,0.439394,0.145688
Nr4a2,0.197059,0.079412,0.300000,0.410294,0.013235
...,...,...,...,...,...
Crem,0.034166,0.425756,0.214192,0.308804,0.017083
Egr1,0.274678,0.278970,0.231760,0.214592,0.000000
Nr3c1,0.050667,0.090667,0.312000,0.533333,0.013333
Rela,0.228261,0.293478,0.141304,0.282609,0.054348


# 2. Random Forest

In [None]:
class RandomForest:
    """Ensemble of scikit-learn decision trees, each fit on its own bootstrap sample.

    Args:
        n_estimators: Number of trees in the ensemble.
        subsample: Fraction of the training rows drawn (with replacement) for each tree.
        regression: If True the trees are DecisionTreeRegressor instances and
            predictions are averaged; otherwise DecisionTreeClassifier instances
            are used and predictions are combined by majority vote.
        **kwargs: Forwarded unchanged to every tree constructor.
    """

    def __init__(self, n_estimators:int=100, subsample:float=0.1, regression:bool=True, **kwargs)->None:
        self.n_estimators = n_estimators
        self.subsample = subsample
        # Remembered so predict() can pick the right aggregation rule.
        self.regression = regression

        tree_cls = DecisionTreeRegressor if regression else DecisionTreeClassifier
        self.estimators = [tree_cls(**kwargs) for _ in range(self.n_estimators)]

    def bootstrap_sample(self, X:np.array, y:np.array)->Tuple[np.array,...]:
        """Draw a random subset of rows (with replacement) from X and y.

        At least one row is always drawn, so small datasets combined with a small
        `subsample` fraction do not produce an empty training set.
        """
        n_samples = X.shape[0]
        n_draws = max(1, int(self.subsample * n_samples))
        idxs = np.random.choice(n_samples, size=n_draws, replace=True)
        return X[idxs], y[idxs]

    def fit(self, X:np.array, y:np.array)->None:
        """Fit every tree on an independent bootstrap sample of (X, y)."""
        for estimator in self.estimators:
            X_sample, y_sample = self.bootstrap_sample(X, y)
            estimator.fit(X_sample, y_sample)

    def predict(self, X:np.array)->np.array:
        """Combine the per-tree predictions for X.

        Regression: mean over trees (a majority vote over continuous values would
        essentially return a single tree's output). Classification: most common label.
        """
        # Shape: (n_estimators, n_samples[, n_outputs])
        preds = np.array([estimator.predict(X) for estimator in self.estimators])
        if self.regression:
            return preds.mean(axis=0)
        # Regroup so each row holds all the votes cast for one sample.
        preds = np.swapaxes(preds, 0, 1)
        return np.array([self._most_common_pred(pred) for pred in preds])

    def _most_common_pred(self, y:np.array)->Union[float, int]:
        """Return the value that occurs most often in y."""
        return Counter(y).most_common(1)[0][0]

## 2.1. Random Forest Classifier

In [None]:
# Classification mode: the ensemble is built from DecisionTreeClassifier trees.
rfc = RandomForest(regression=False)

In [None]:
# Requires the classification split (X_train_c, y_train_c) from the "1.1 Datasets" cell to be defined.
rfc.fit(X_train_c, y_train_c)

In [None]:
# Accuracy on the held-out classification split, rounded to 3 decimals.
print(accuracy_score(rfc.predict(X_test_c), y_test_c).round(3))

## 2.2. Random Forest Regressor

In [None]:
# Regression mode: the ensemble is built from DecisionTreeRegressor trees.
rfr = RandomForest(regression=True)

In [None]:
# Requires the regression split (X_train_r, y_train_r) from the "1.1 Datasets" cell to be defined.
rfr.fit(X_train_r, y_train_r)

In [None]:
# Mean absolute error on the held-out regression split, rounded to 3 decimals.
print(mean_absolute_error(y_test_r, rfr.predict(X_test_r)).round(3))

___

# 4. Gradient Boosting

In [None]:
# Loss functions
def MSE(y_true:jnp.array, y_pred:jnp.array):
    """Squared-error loss between targets and predictions.

    Despite the name, the value is the SUM of the squared residuals: the final
    jnp.mean is applied to an already-reduced scalar, so it does not divide by
    the number of samples.
    """
    residual = y_true - y_pred
    total_squared_error = jnp.square(residual).sum()
    return jnp.mean(total_squared_error)

def CrossEntropy(y_true:jnp.array, y_proba:jnp.array):
    """Binary cross-entropy summed over all elements.

    Probabilities are clipped to [1e-5, 1 - 1e-5] first so neither logarithm
    is evaluated at 0.
    """
    eps = 1e-5
    clipped = jnp.clip(y_proba, eps, 1 - eps)
    positive_term = y_true * jnp.log(clipped)
    negative_term = (1 - y_true) * jnp.log(1 - clipped)
    return jnp.sum(-positive_term - negative_term)

In [None]:
class GradientBoosting:
    """Gradient boosting with regression trees fit to JAX-computed loss gradients.

    Args:
        n_estimators: Number of boosting rounds (one tree per round).
        learning_rate: Step size applied to every tree's output.
        regression: If True the MSE loss is used and predict() returns
            continuous values; otherwise CrossEntropy is used and predict()
            returns 0/1 labels.
        **kwargs: Forwarded unchanged to every DecisionTreeRegressor.
    """

    def __init__(self, n_estimators:int=100, learning_rate:float=.1, regression:bool=True, **kwargs):
        self.n_estimators = n_estimators
        self.learning_rate = learning_rate
        self.regression = regression
        self.loss = MSE if self.regression else CrossEntropy
        # Constant starting prediction; replaced by mean(y) in fit().
        self.initial_prediction_ = 0.0

        # Trees are always regressors: they are fit to gradients, not to labels.
        self.estimators = [DecisionTreeRegressor(**kwargs) for _ in range(self.n_estimators)]

    def fit(self, X:np.array, y:np.array):
        """Fit the trees sequentially, each on the loss gradient at the current prediction."""
        y = np.asarray(y)
        # Remember the starting point so predict() can begin from the same value.
        self.initial_prediction_ = float(np.mean(y))
        y_pred = np.full(np.shape(y), self.initial_prediction_)

        y_true32 = y.astype(np.float32)
        # Build the gradient function once instead of once per boosting round.
        grad_fn = jax.grad(self.loss, argnums=1)

        for estimator in self.estimators:
            gradient = np.asarray(grad_fn(y_true32, y_pred.astype(np.float32)))
            estimator.fit(X, gradient)
            # Step against the gradient approximated by the tree.
            y_pred -= (self.learning_rate * estimator.predict(X))

    def predict(self, X:np.array):
        """Predict for X by replaying the boosting updates.

        Regression: the accumulated updates are added to the initial prediction
        stored by fit(), so the output matches the trajectory followed during
        training. Classification: the accumulated updates alone (starting from
        zero) are passed through a sigmoid and thresholded at 0.5.
        """
        score = None
        for estimator in self.estimators:
            step = self.learning_rate * np.asarray(estimator.predict(X))
            if score is None:
                # Shape follows the tree output, so multi-output targets also work.
                score = np.zeros(step.shape, dtype=np.float32)
            score -= step
        if score is None:
            score = np.zeros(X.shape[0], dtype=np.float32)

        if not self.regression:
            # NOTE(review): fit() feeds the running prediction to the loss as a
            # probability, while this rule treats the zero-based score as a logit —
            # kept as in the original; confirm the intended decision rule.
            return np.where(1/(1 + np.exp(-score))>.5, 1, 0)
        return score + np.float32(self.initial_prediction_)

## 4.1 Gradient Boosting Regressor

In [None]:
# Regression mode: the boosting loss is MSE.
gbr = GradientBoosting(regression=True)

In [None]:
# Requires the regression split (X_train_r, y_train_r) from the "1.1 Datasets" cell to be defined.
gbr.fit(X_train_r, y_train_r)

In [None]:
# Mean absolute error on the held-out regression split, rounded to 3 decimals.
print(mean_absolute_error(gbr.predict(X_test_r), y_test_r).round(3))

## 4.2 Gradient Boosting Classifier

In [None]:
# Classification mode: the boosting loss is CrossEntropy.
gbc = GradientBoosting(regression=False)

In [None]:
# Requires the classification split (X_train_c, y_train_c) from the "1.1 Datasets" cell to be defined.
gbc.fit(X_train_c, y_train_c)

In [None]:
# Accuracy on the held-out classification split, rounded to 3 decimals.
print(accuracy_score(gbc.predict(X_test_c), y_test_c).round(3))