# Embedding and Imputation on Ramani et al. scHi-C

## Notes

This tutorial uses the new Higashi API, which wraps all Higashi functions in the `Higashi()` class.
The old API will still be supported and maintained.
Please check the changelog for the current status of the migration from the old API to the new one.

## Preparation

### Download input files
Download the demo data (Ramani et al.) from the following link:
https://drive.google.com/drive/folders/1S0KOMAj60MxQP6mgPV1OKjn_J-lVpzKM?usp=sharing

The dataset contains 620 cells from the ML1/ML3 library of the Ramani et al. dataset.

Change the file path in the corresponding JSON file according to the location of the downloaded files.
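
The path update can also be scripted instead of edited by hand. The sketch below is a minimal, generic helper that overwrites existing top-level keys of a JSON file. The function name `update_config_paths` is hypothetical (it is not part of Higashi), and the key names and paths in the usage comment are assumptions about the config layout, so check them against the downloaded `config_ramani.JSON` before use.

```python
import json
from pathlib import Path


def update_config_paths(config_path, new_values):
    """Overwrite existing top-level keys of a JSON config file in place.

    Hypothetical helper, not part of Higashi.
    """
    config_path = Path(config_path)
    with config_path.open() as f:
        config = json.load(f)

    # Refuse to add keys that are not already present, to catch typos early.
    unknown = sorted(k for k in new_values if k not in config)
    if unknown:
        raise KeyError(f"keys not found in {config_path.name}: {unknown}")

    config.update(new_values)
    with config_path.open("w") as f:
        json.dump(config, f, indent=4)
    return config


# Example call (key names and paths are assumptions, adjust to your config):
# update_config_paths("./Data/config_ramani.JSON",
#                     {"data_dir": "/path/to/ramani/data",
#                      "temp_dir": "/path/to/ramani/temp"})
```

Rejecting unknown keys is a design choice: a misspelled key would otherwise be silently added while the real setting keeps pointing at the old location.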

### Install Higashi

1. Install PyTorch (`pytorch>=1.8.0`), with CUDA support when available.
2. `conda install -c ruochiz higashi`

Although Higashi installs PyTorch when needed, there is no guarantee that it will install the correct version with CUDA support, so it is recommended to install PyTorch separately before Higashi.
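
After installation, a quick check can confirm that the installed PyTorch meets the version requirement and can see a GPU. The following is a minimal sketch: `meets_minimum` and `describe_torch` are hypothetical helper names, and the only PyTorch features used are `torch.__version__` and `torch.cuda.is_available()`.

```python
import importlib.util


def meets_minimum(version, minimum=(1, 8)):
    """Return True if a version string such as '1.13.1+cu117' is >= minimum."""
    core = version.split("+")[0]  # drop local build tags like '+cu117'
    major, minor = core.split(".")[:2]
    return (int(major), int(minor)) >= minimum


def describe_torch():
    """Report whether PyTorch is installed, recent enough, and CUDA-enabled."""
    if importlib.util.find_spec("torch") is None:
        return {"installed": False}
    import torch
    return {
        "installed": True,
        "version": torch.__version__,
        "version_ok": meets_minimum(torch.__version__),
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    print(describe_torch())
```

If `cuda_available` is `False` on a machine with a GPU, the installed build is probably CPU-only, which is the situation the installation note warns about.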

## Start running Higashi

### 1. Import package, set the path to the configuration JSON file

In [1]:
cd ..

/Users/fengcong/Desktop/0825-course/Guided Study/Code/Higashi


In [2]:
from higashi.Higashi_wrapper import *
config = "./Data/config_ramani.JSON"
higashi_model = Higashi(config)

### 2. Process data for higashi model

In [None]:
# higashi_model.process_data()


### 3. Prep the higashi model for training and imputation & Stage 1 training

In [14]:
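def save_model_from_checkpoint():
    # Load the stage 1 checkpoint and save the full model object next to it.
    import os
    import torch  # imported explicitly instead of relying on the wildcard import

    sd = torch.load(os.path.expanduser("./Temp/model/model.chkpt_stage1"))
    # strict=False tolerates keys that do not match between checkpoint and model
    higashi_model.higashi_model.load_state_dict(sd['model_link'], strict=False)
    torch.save(higashi_model.higashi_model, os.path.expanduser("./Temp/model/model.chkpt_stage1_model"))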


In [3]:
higashi_model.prep_model()
# Stage 1 training
# higashi_model.train_for_embeddings()

cpu_num 8
training on data from: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']
total_sparsity_cell 0.025731658257764963
no contractive loss
batch_size 256
Node type num [620 250 244 199 192 181 172 160 147 142 136 136 134 116 108 103  91  82
  79  60  64  49  52 156] [ 620  870 1114 1313 1505 1686 1858 2018 2165 2307 2443 2579 2713 2829
 2937 3040 3131 3213 3292 3352 3416 3465 3517 3673]
start making attribute


  0%|          | 0/300 [00:00<?, ?it/s]

loss 0.6227015852928162 loss best 0.6080297827720642 epochs 122

initializing data generator


  0%|          | 0/23 [00:00<?, ?it/s]

initializing data generator


  0%|          | 0/23 [00:00<?, ?it/s]
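
The `Node type num` line in the log above prints two arrays, and the second is the running total of the first. The sketch below checks this with the printed numbers. Reading the first entry as the 620 cells and the remaining 23 entries as one count per training chromosome is an interpretation of the log, not something the output states.

```python
from itertools import accumulate

# Values copied from the "Node type num" line of the log.
node_counts = [620, 250, 244, 199, 192, 181, 172, 160, 147, 142, 136, 136,
               134, 116, 108, 103, 91, 82, 79, 60, 64, 49, 52, 156]
logged_offsets = [620, 870, 1114, 1313, 1505, 1686, 1858, 2018, 2165, 2307,
                  2443, 2579, 2713, 2829, 2937, 3040, 3131, 3213, 3292, 3352,
                  3416, 3465, 3517, 3673]

# Running total of the counts, i.e. the end index of each node type.
offsets = list(accumulate(node_counts))

print(offsets == logged_offsets)
print(len(node_counts) - 1)  # number of entries after the first one
```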

In [15]:
# save_model_from_checkpoint()

In [4]:
higashi_model.train_for_phasing()
higashi_model.phase()

  0%|          | 0/23 [00:00<?, ?it/s]

  0%|          | 0/23 [00:00<?, ?it/s]

[ Epoch 0 of 30 ]


 - (Training) :   0%|          | 0/1000 [00:00<?, ?it/s]

TypeError: '<=' not supported between instances of 'float' and 'NoneType'

### 4. Stage 2 training and imputation without neighbor information

In [None]:
higashi_model.train_for_imputation_nbr_0()
higashi_model.impute_no_nbr()

### 5. Stage 3 training and imputation with neighbor information

In [None]:
higashi_model.train_for_imputation_with_nbr()
higashi_model.impute_with_nbr()

### 6. Visualizing embedding results

In [None]:
# Visualize embedding results
cell_embeddings = higashi_model.fetch_cell_embeddings()
print(cell_embeddings.shape)

from umap import UMAP
from sklearn.decomposition import PCA
import seaborn as sns
import matplotlib.pyplot as plt

cell_type = higashi_model.label_info['cell type']
fig = plt.figure(figsize=(14, 5))
ax = plt.subplot(1, 2, 1)
vec = PCA(n_components=2).fit_transform(cell_embeddings)
sns.scatterplot(x=vec[:, 0], y=vec[:, 1], hue=cell_type, ax=ax, s=6, linewidth=0)
handles, labels = ax.get_legend_handles_labels()
labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
ax.legend(handles=handles, labels=labels, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=1)
ax = plt.subplot(1, 2, 2)
vec = UMAP(n_components=2).fit_transform(cell_embeddings)
sns.scatterplot(x=vec[:, 0], y=vec[:, 1], hue=cell_type, ax=ax, s=6, linewidth=0)
handles, labels = ax.get_legend_handles_labels()
labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
ax.legend(handles=handles, labels=labels, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=1)
plt.tight_layout()
plt.show()
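
The legend code above sorts the legend entries alphabetically by label with a zip, sort, unzip idiom. The sketch below isolates that step, using plain strings in place of matplotlib handle objects. The label and handle values are made up for illustration.

```python
# Stand-ins for what ax.get_legend_handles_labels() returns.
# The names below are made-up placeholders, not labels from the dataset.
handles = ["handle_c", "handle_a", "handle_b"]
labels = ["type_c", "type_a", "type_b"]

# Pair each label with its handle, sort the pairs by label, then unzip.
pairs = sorted(zip(labels, handles), key=lambda t: t[0])
sorted_labels, sorted_handles = zip(*pairs)

print(sorted_labels)
print(sorted_handles)
```

Sorting the pairs rather than the two lists separately keeps each handle attached to its own label. Note that unpacking `zip(*pairs)` raises a `ValueError` when there are no legend entries at all.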



### 7. Visualizing imputation results

In [None]:
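import numpy as np  # explicit import instead of relying on the wildcard import

count = 0
fig = plt.figure(figsize=(6, 2*5))
# Pick 5 random cell IDs out of the 620 cells (duplicates are possible)
for id_ in np.random.randint(0, 620, 5):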
    ori, nbr0, nbr5 = higashi_model.fetch_map("chr3", id_)
    count += 1
    ax = plt.subplot(5, 3, count * 3 - 2)
    ax.imshow(ori.toarray(), cmap='Reds', vmin=0.0, vmax=np.quantile(ori.data, 0.6))
    ax.set_xticks([], [])
    ax.set_yticks([], [])
    if count == 1:
        ax.set_title("raw")

    ax = plt.subplot(5, 3, count * 3 - 1)
    ax.imshow(nbr0.toarray(), cmap='Reds', vmin=0.0, vmax=np.quantile(nbr0.data, 0.95))
    ax.set_xticks([], [])
    ax.set_yticks([], [])
    if count == 1:
        ax.set_title("higashi, k=0")

    ax = plt.subplot(5, 3, count * 3)
    ax.imshow(nbr5.toarray(), cmap='Reds', vmin=0.0, vmax=np.quantile(nbr5.data, 0.95))
    ax.set_xticks([], [])
    ax.set_yticks([], [])
    if count == 1:
        ax.set_title("higashi, k=5")

plt.tight_layout()
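
`np.random.randint(0, 620, 5)` draws with replacement, so the same cell can appear in more than one row, and the selection changes on every run. If distinct, reproducible cells are preferred, one option is to sample without replacement from a seeded generator. The sketch below uses only the standard library. `pick_cell_ids` and `panel_indices` are hypothetical helpers, and the subplot index arithmetic mirrors the loop above.

```python
import random


def pick_cell_ids(n_cells, n_pick, seed=0):
    """Return n_pick distinct cell IDs in [0, n_cells), reproducible via seed."""
    rng = random.Random(seed)
    return rng.sample(range(n_cells), n_pick)


def panel_indices(row, n_cols=3):
    """1-based subplot indices for one row of an n_cols-wide grid (row starts at 1)."""
    first = (row - 1) * n_cols + 1
    return list(range(first, first + n_cols))


cell_ids = pick_cell_ids(620, 5)
for row, cell_id in enumerate(cell_ids, start=1):
    raw_idx, nbr0_idx, nbr5_idx = panel_indices(row)
    print(cell_id, raw_idx, nbr0_idx, nbr5_idx)
```

The returned IDs could replace the `np.random.randint` call in the loop, and each triple from `panel_indices` corresponds to `count * 3 - 2`, `count * 3 - 1`, and `count * 3`.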