# Self-Supervised Learning

This notebook demonstrates how to train a self-supervised model on a collection of anndata objects and produce a "complete" checkpoint file. \
Here we train the model only briefly, for demonstration purposes. \
To reproduce the results in the paper, run the scripts with the provided config.yaml file. \
See the documentation for details.

## Common imports

In [1]:
# TODO remove this once the notebook is stable
%load_ext autoreload
%autoreload 2

In [5]:
import tissue_purifier as tp

from tissue_purifier.models import Barlow
# Calling the constructor without arguments fails because all hyperparameters are required
a = tp.models.Barlow()

TypeError: __init__() missing 13 required positional arguments: 'backbone_type', 'image_in_ch', 'head_hidden_chs', 'head_out_ch', 'lambda_off_diagonal', 'optimizer_type', 'warm_up_epochs', 'warm_down_epochs', 'max_epochs', 'min_learning_rate', 'max_learning_rate', 'min_weight_decay', and 'max_weight_decay'

In [4]:
help(tp.models)

Help on package tissue_purifier.models in tissue_purifier:

NAME
    tissue_purifier.models

PACKAGE CONTENTS
    _optim_scheduler
    classifier_regressor (package)
    logger
    patch_analyzer (package)
    ssl_models (package)

CLASSES
    pytorch_lightning.loggers.neptune.NeptuneLogger(pytorch_lightning.loggers.base.LightningLoggerBase)
        tissue_purifier.models.logger.NeptuneLoggerCkpt
    
    class NeptuneLoggerCkpt(pytorch_lightning.loggers.neptune.NeptuneLogger)
     |  NeptuneLoggerCkpt(**kargs)
     |  
     |  Thin wrapper around the Neptune Logger with the after_save_checkpoint specified
     |  
     |  Method resolution order:
     |      NeptuneLoggerCkpt
     |      pytorch_lightning.loggers.neptune.NeptuneLogger
     |      pytorch_lightning.loggers.base.LightningLoggerBase
     |      abc.ABC
     |      builtins.object
     |  
     |  Methods defined here:
     |  
     |  __init__(self, **kargs)
     |      Initialize self.  See help(type(self)) for accurate

In [None]:
import numpy
import torch
import seaborn
import tarfile
import os
import matplotlib
import matplotlib.pyplot as plt
from anndata import read_h5ad

# import tissue purifier
import tissue_purifier as tp

### Download and untar the example dataset

In [None]:
import tissue_purifier.io

bucket_name = "ld-data-bucket"
data_source_path = "tissue-purifier/slideseq_testis_anndata_h5ad.tar.gz"
data_destination_path = "./slideseq_testis_anndata_h5ad.tar.gz"
data_destination_folder = "./testis_anndata"

# download data from google bucket (uncomment on the first run:
# the archive must exist locally before it can be untarred)
# tp.io.download_from_bucket(bucket_name, data_source_path, data_destination_path)

# untar the data
with tarfile.open(data_destination_path, "r:gz") as fp:
    fp.extractall(path=data_destination_folder)

# Make a list of all the h5ad files in the data_destination_folder
fname_list = []
for f in os.listdir(data_destination_folder):
    if f.endswith('.h5ad'):
        fname_list.append(f)
print(fname_list)
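
`os.listdir` returns entries in arbitrary order, so the order of `fname_list` (and of the plot panels that follow) can differ between machines. A minimal sketch of a deterministic listing, using only the standard library. The helper name `list_h5ad_files` and the demo file names are made up for illustration and are not part of tissue_purifier:

```python
import os
import tempfile


def list_h5ad_files(folder):
    """Return the sorted names of the '.h5ad' files found directly inside `folder`."""
    return sorted(f for f in os.listdir(folder) if f.endswith(".h5ad"))


# Demonstration on a throw-away folder with made-up file names
with tempfile.TemporaryDirectory() as tmp_dir:
    for name in ["b.h5ad", "a.h5ad", "notes.txt"]:
        open(os.path.join(tmp_dir, name), "w").close()
    demo_list = list_h5ad_files(tmp_dir)

print(demo_list)
```

In this notebook the call would be `list_h5ad_files(data_destination_folder)`.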

At this point we have a folder with six 'h5ad' files corresponding to different tissues.

### Visualize the six anndata as a sanity check

In [None]:
# read in all the anndata

anndata_list = []
for fname in fname_list:
    anndata = read_h5ad(os.path.join(data_destination_folder, fname))
    print("Loaded {}".format(fname))
    anndata_list.append(anndata)

Each anndata contains the gene expression of ~20K genes for ~30K cells. \
In addition, each cell has three annotations: the 'x' and 'y' coordinates and the 'cell_type'.

In [None]:
anndata_list[0]

Plot the cell_types in space

In [None]:
ncols=3
nrows=2
fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(6*ncols,6*nrows))

# Define a consistent mapping from cell_type to color for all the plots
cell_types = numpy.asarray(anndata_list[0].obs['cell_type'].values)
unique_cell_types = numpy.unique(cell_types)

n = -1
for r in range(nrows):
    for c in range(ncols):
        n += 1
        anndata_tmp = anndata_list[n]
        cell_types = numpy.asarray(anndata_tmp.obs['cell_type'].values)
        x = numpy.asarray(anndata_tmp.obs['x'].values)
        y = numpy.asarray(anndata_tmp.obs['y'].values) 
        seaborn.scatterplot(x=x, y=y, hue=cell_types, ax=axes[r,c], size=numpy.ones_like(x), sizes=(10, 10), hue_order=unique_cell_types) 
        _ = axes[r,c].set_title(fname_list[n])
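
The shared ordering above is derived from the first sample only, so a cell type that appears only in a later sample would be missing from `unique_cell_types`. A minimal sketch of building the ordering from all samples instead. The helper name `shared_category_order` and the labels are hypothetical, chosen only for illustration:

```python
def shared_category_order(label_lists):
    """Return a sorted list of every category seen in any of the samples."""
    categories = set()
    for labels in label_lists:
        categories.update(labels)
    return sorted(categories)


# Made-up labels for three samples; "Sertoli" is missing from the first one
demo_labels = [
    ["ES", "RS", "ES"],
    ["RS", "Sertoli"],
    ["ES", "Sertoli", "SPG"],
]
demo_order = shared_category_order(demo_labels)
print(demo_order)
```

With the variables of this notebook, the input would be the `obs['cell_type']` column of each element of `anndata_list`.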

Plot the cell_type counts in all the tissues 

In [None]:
x_labels_rotation = 90
fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(6*ncols,6*nrows))
n = -1
for r in range(nrows):
    for c in range(ncols):
        n += 1
        counts = anndata_list[n].obs["cell_type"].value_counts(sort=False)
        x = numpy.asarray(counts.index)
        y = counts.to_numpy()
        _ = seaborn.barplot(x=x, y=y, ax=axes[r,c])
        x_labels_raw = axes[r,c].get_xticklabels()
        axes[r,c].set_xticklabels(labels=x_labels_raw, rotation=x_labels_rotation)
        _ = axes[r,c].set_title(fname_list[n])

### Instantiate the DataModule (i.e. define train/test/val dataloader)

Here we use the default parameters for the datamodule and only define:
1. The mapping from cell_type to channels in the image. Each channel will represent the density of a specific cell_type. In some situations it may make sense to map multiple cell types to the same channel. For example, CD4+ and CD4- cells might be mapped to the same channel.
2. The folder with the anndata h5ad files
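
To make the first point concrete, here is a minimal sketch of a many-to-one mapping and of how such a mapping turns a list of cells into a multi-channel count image. This is not the rasterization used by tissue_purifier. The helper names, the grid size and the cell coordinates are assumptions made for illustration:

```python
def build_categories_to_channels(groups):
    """Map each cell type to a channel index; cell types in the same group share a channel."""
    mapping = {}
    for channel, group in enumerate(groups):
        for cell_type in group:
            mapping[cell_type] = channel
    return mapping


def count_image(cells, mapping, n_rows, n_cols):
    """Count cells per (channel, row, col) bin; each cell is a (row, col, cell_type) tuple."""
    n_channels = max(mapping.values()) + 1
    image = [[[0] * n_cols for _ in range(n_rows)] for _ in range(n_channels)]
    for row, col, cell_type in cells:
        image[mapping[cell_type]][row][col] += 1
    return image


# CD4+ and CD4- share channel 0, B cells get channel 1 (made-up example)
demo_mapping = build_categories_to_channels([["CD4+", "CD4-"], ["B cell"]])
demo_cells = [(0, 0, "CD4+"), (0, 0, "CD4-"), (1, 1, "B cell")]
demo_image = count_image(demo_cells, demo_mapping, n_rows=2, n_cols=2)
print(demo_mapping)
print(demo_image)
```

A dictionary built this way has the same form as the `categories_to_channels` dictionary used in this notebook: cell-type names as keys, channel indices as values.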

In [None]:
from tissue_purifier.data import AnndataFolderDM

n_unique_cell_types = len(unique_cell_types)
categories_to_channels = dict(zip(unique_cell_types, range(n_unique_cell_types)))
print(categories_to_channels)

config_dm = tp.data.AnndataFolderDM.get_default_params()  # get the default parameters
config_dm["root"] = data_destination_folder  # specify the folder with the anndata h5ad files
config_dm["categories_to_channels"] = categories_to_channels  # specify the mapping between cell_types and channels

dm = tp.data.AnndataFolderDM(**config_dm)

config_dm

### Instantiate the Model

Here we use the Barlow model, but the same approach works for Dino, Vae and Simclr.

We use the default parameters and only change the number of input channels of the image.

In [None]:
from tissue_purifier.models import Barlow
# from tissue_purifier.models import Simclr
# from tissue_purifier.models import Dino
# from tissue_purifier.models import Vae

config_model = tp.models.Barlow.get_default_params()  # get the default parameters
config_model['image_in_ch'] = dm.ch_in  # specify the number of input channels consistently with datamodule
config_model

# DO NOT DO THIS
# model = tp.models.Barlow(**config_model)  
# This will work but the resulting checkpoint will not include the configuration of the datamodule and the 
# resulting ckpt file will be "incomplete".

# DO THIS INSTEAD
config_model.update(config_dm)  # merge the datamodule configuration into the model configuration
model = tp.models.Barlow(**config_model)  
# Now the checkpoint contains the full information to reproduce the simulation.
# To reproduce the results in the paper you need to use the config.yaml file provided.
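
`dict.update` silently overwrites any key that is present in both dictionaries. Whether the model and datamodule configurations share keys is not stated here, so the following is only a defensive sketch: a merge that raises if a shared key carries conflicting values. The helper name `merge_configs` and the values are hypothetical. Only the key names are taken from this notebook:

```python
def merge_configs(config_model, config_dm):
    """Return a new dict with both configurations; raise if a shared key has conflicting values."""
    shared_keys = config_model.keys() & config_dm.keys()
    clashes = sorted(k for k in shared_keys if config_model[k] != config_dm[k])
    if clashes:
        raise ValueError("conflicting values for keys: {}".format(clashes))
    merged = dict(config_model)  # copy, so the inputs are left untouched
    merged.update(config_dm)
    return merged


# Made-up values for illustration
demo_model = {"image_in_ch": 9, "max_epochs": 1}
demo_dm = {"root": "./testis_anndata", "categories_to_channels": {"ES": 0}}
demo_merged = merge_configs(demo_model, demo_dm)
print(sorted(demo_merged))
```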

### Train the model and save the final checkpoint

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from tissue_purifier.models import NeptuneLoggerCkpt

In [None]:
pl.seed_everything(seed=0, workers=True)

We use Neptune to log our results. \
Here Neptune runs in 'offline' mode and the logs are written to the local disk. \
Neptune can also run in 'async' mode, in which case the results are saved to a remote database with a graphical interface. \
For the 'async' option to work you need to sign up for a free account and provide the correct project and api_key. \
See https://docs.neptune.ai/ for more info.

In [None]:
pl_neptune_logger = NeptuneLoggerCkpt(
    api_key="ANONYMOUS",  # replace with your own
    project='cellarium/tissue-purifier', # replace with your own
    run=None,  # if None a new run will be logged. If the run is provided the result will be appended to existing run  
    log_model_checkpoints=True, 
    mode="offline",  # "async"
    tags=["test"],
    fail_on_exception=True,  
)

# Save the checkpoint periodically during training
ckpt_train = ModelCheckpoint(
    save_weights_only=False,
    save_on_train_epoch_end=True,
    save_last=True,
    every_n_epochs=3,
)
    
# Define the trainer
pl_trainer = Trainer(
    weights_save_path="saved_ckpt",
    callbacks=[ckpt_train],
    gpus=torch.cuda.device_count(),  # number of gpu cards on a single machine to use
    check_val_every_n_epoch=10,
    num_sanity_val_steps=0,
    max_epochs=1,  # run for a single epoch; use config_model["max_epochs"] for a full run
    logger=pl_neptune_logger,
    log_every_n_steps=100,
    sync_batchnorm=True,
)

In [None]:
pl_trainer.fit(model=model, datamodule=dm)

### Save the final checkpoint

Since the model was instantiated using a dictionary containing all the parameters (both for the model and the datamodule), the checkpoint is complete. \
A single checkpoint file contains all the information needed to reproduce the simulation.

In [None]:
pl_trainer.save_checkpoint("ckpt_barlow.pt")  

### Visualize the crops used for training as a sanity check

In [None]:
from tissue_purifier.plots import show_raw_all_channels, show_raw_one_channel

In [None]:
train_loader = dm.train_dataloader()  # get the train_dataloader from the datamodule
batch = next(iter(train_loader))  # get one batch from the dataloader
list_sp_imgs, list_labels, list_metadata = batch  # batch consists of 3 lists: sparse_images, labels, metadata

In [None]:
n_examples = 5  # number of distinct crops
n_augmentations = 3  # apply the random data augmentation this many times

all_imgs = []
for n in range(n_augmentations):
    imgs_tmp = dm.trsfm_train_global(list_sp_imgs[:n_examples])  # apply the data augmentations
    all_imgs.append(imgs_tmp)
    
imgs_train = torch.cat(all_imgs, dim=0)
print("imgs_train.shape ->", imgs_train.shape)

In [None]:
# Each column is a different patch
# Each row is a different instance of the random data-augmentation

titles = []
for r in range(n_augmentations):
    for c in range(n_examples):
        titles.append("crop = {}, augmentation = {}".format(c,r))

train_all_ch_fig = show_raw_all_channels(imgs_train, 
                                         cmap="viridis", 
                                         n_col=n_examples, 
                                         figsize=(4*n_examples, 4*n_augmentations), 
                                         sup_title="Train crops, all channels",
                                         titles=titles)
train_all_ch_fig
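
The titles only line up with the images if both follow the same ordering. Concatenating `n_augmentations` batches of `n_examples` crops along the first dimension places crop `c` under augmentation `r` at index `r * n_examples + c`. The sketch below checks this with plain Python lists standing in for the tensors. It assumes that the plotting function fills the grid row by row, which is consistent with the comments in the cell above but is not verified here:

```python
def flat_index(augmentation, crop, n_examples):
    """Position of (augmentation, crop) after stacking the augmented batches one after another."""
    return augmentation * n_examples + crop


n_examples, n_augmentations = 5, 3

stacked = []
for r in range(n_augmentations):
    batch = [(c, r) for c in range(n_examples)]  # stand-in for the augmented crops
    stacked.extend(batch)  # same ordering as concatenating along the first dimension

# Every (crop, augmentation) pair sits where the title loop expects it
layout_ok = all(
    stacked[flat_index(r, c, n_examples)] == (c, r)
    for r in range(n_augmentations)
    for c in range(n_examples)
)
print(layout_ok)
```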

In [None]:
train_one_ch_fig = show_raw_one_channel(imgs_train[0], 
                                        n_col=3,  
                                        sup_title="One crop used for training",
                                        titles=list(unique_cell_types))
train_one_ch_fig