# Contrastive Learning

### Add the src folder to the path

In [14]:
import sys
import os

# The notebook runs from a sub-directory of the project; its parent holds the
# importable ``src`` package directory.
root_path = os.path.split(os.getcwd())[0]
src_path = os.path.join(root_path, "src")

# Prepend (rather than append) so the local sources shadow any installed copy.
sys.path[:0] = [src_path]

In [15]:
# Enable IPython's autoreload so edits to modules under ``src`` are picked up
# without restarting the kernel (mode 2 reloads all modules).
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Read in the config file

In [16]:
import yaml

# Experiment settings (Neptune project/tags plus the "simulation",
# "model_settings" and "optimizer" sections) come from config.yaml in the
# notebook's working directory. safe_load avoids constructing arbitrary objects.
with open("./config.yaml") as config_file:
    config = yaml.safe_load(config_file)

config

{'neptune_project': 'fedorgrab/slide-seq-contrastive',
 'experiment_name': 'New Settings: 170',
 'experiment_tags': ['CL', 'Final', 'New Settings'],
 'simulation': {'MAX_EPOCHS': 2000,
  'BATCH_SIZE': 128,
  'CROP_SIZE': 224,
  'INPUT_SIZE': 224,
  'N_ELEMENT_MIN': 1000,
  'N_CROPS_TEST': 1600,
  'PIXEL_SIZE': 4,
  'NUMBER_OF_CHANNELS': 9,
  'RANDOM_SEED': 1},
 'model_settings': {'BACKBONE_TYPE': 'resnet18',
  'BACKBONE_NUM_FTRS': 128,
  'PROJECTION_OUT_DIM': 128,
  'INPUT_CHANNELS': 9},
 'optimizer': {'LEARNING_RATE': 0.001,
  'IS_SCHEDULED': True,
  'SCHEDULER_STEP_SIZE': 3,
  'SCHEDULER_GAMMA': 0.995}}

### Common imports and random seeds

In [17]:
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import torchvision
import torch
import pytorch_lightning as pl
import tissue_purifier as tp

matplotlib.rcParams["figure.dpi"] = 200
torch.manual_seed(config['simulation']['RANDOM_SEED'])
np.random.seed(config['simulation']['RANDOM_SEED'])

### Read in all the CSV files

In [18]:
# Folder with one CSV per tissue: three "wt_*" files and three "sick_*" files.
data_folder = "../slide-seq-data"


def _read_tissue_csv(file_name):
    """Load a single tissue table from ``data_folder``."""
    return pd.read_csv(os.path.join(data_folder, file_name))


df_wt1, df_wt2, df_wt3 = [_read_tissue_csv(f"wt_{k}.csv") for k in (1, 2, 3)]
df_dis1, df_dis2, df_dis3 = [_read_tissue_csv(f"sick_{k}.csv") for k in (1, 2, 3)]

### Use all the tissues to create the trainloader

In [19]:
from tissue_purifier.data_utils.helpers import define_trainloader

# Six tissues in a fixed order: three "wt" (label 0), then three "dis" (label 1).
all_df = [df_wt1, df_wt2, df_wt3, df_dis1, df_dis2, df_dis3]
labels_sparse_images = [0] * 3 + [1] * 3
names_sparse_images = [f"wt{k}" for k in (1, 2, 3)] + [f"dis{k}" for k in (1, 2, 3)]

# Rasterise each cell table ("x", "y" coordinates, "cell_type" category) into a
# SparseImage at the configured pixel size.
sparse_images = [
    tp.data_utils.SparseImage.from_panda(
        tissue_df,
        x="x",
        y="y",
        category="cell_type",
        pixel_size=config["simulation"]["PIXEL_SIZE"],
        padding=10,
    )
    for tissue_df in all_df
]

# simclr_output=True presumably makes the loader yield augmented view pairs for
# the contrastive objective (check define_trainloader).
trainloader = define_trainloader(
    sparse_images,
    labels_sparse_images,
    names_sparse_images,
    config,
    simclr_output=True,
)

number of elements ---> 31659
The dense shape of the image is -> torch.Size([9, 1168, 1168])
number of elements ---> 33059
The dense shape of the image is -> torch.Size([9, 1170, 845])
number of elements ---> 39206
The dense shape of the image is -> torch.Size([9, 1169, 1170])
number of elements ---> 27194
The dense shape of the image is -> torch.Size([9, 1166, 1170])
number of elements ---> 42776
The dense shape of the image is -> torch.Size([9, 1170, 1170])
number of elements ---> 33441
The dense shape of the image is -> torch.Size([9, 1154, 1155])


### Create the model, optimizer and loss 

In [20]:
import lightly
from tissue_purifier.model_utils.encoder import Encoder

# Backbone + projection head, sized from the "model_settings" config section.
model = tp.model_utils.helpers.define_model(
    backbone_type=config["model_settings"]["BACKBONE_TYPE"],
    number_of_channels=config["model_settings"]["INPUT_CHANNELS"],
    num_of_filters=config["model_settings"]["BACKBONE_NUM_FTRS"],
    projection_out_dim=config["model_settings"]["PROJECTION_OUT_DIM"],
)

# Optimizer and step scheduler from the "optimizer" config section.
# TODO(review): config["optimizer"]["IS_SCHEDULED"] is never read here; the
# scheduler is always created and passed on — confirm that is intended.
optimizer, scheduler = tp.model_utils.helpers.define_optimizer_and_scheduler(
    model=model,
    num_epochs=config["simulation"]["MAX_EPOCHS"],
    learning_rate=config["optimizer"]["LEARNING_RATE"],
    scheduler_step_size=config["optimizer"]["SCHEDULER_STEP_SIZE"],
    scheduler_gamma=config["optimizer"]["SCHEDULER_GAMMA"],
)

criterion = tp.loss_utils.helpers.define_contrastive_loss()

# Bundle everything needed for training into the Encoder wrapper.
encoder = Encoder(model, criterion, optimizer, trainloader, scheduler)

In [12]:
#model

## Train the model on GPU if available

In [None]:
# Use one GPU when CUDA is available, otherwise run on CPU (0 GPUs).
gpus = 1 if torch.cuda.is_available() else 0
model_folder = "../trained_model"
os.makedirs(model_folder, exist_ok=True)

encoder.train_embedding(
    gpus=gpus,
    progress_bar_refresh_rate=0,
    # There is no MAX_EPOCHS global in this notebook; read it from the config
    # (the same value the optimizer/scheduler were built with).
    max_epochs=config["simulation"]["MAX_EPOCHS"],
    log_every_n_steps=1,
)

# Save only the trained weights (state_dict) so they can be reloaded into a
# freshly defined model.
torch.save(model.state_dict(), os.path.join(model_folder, "simclr_model.pt"))

## Evaluate the model

In [21]:
from tissue_purifier.data_utils.helpers import define_testloader

# Evaluation loader over the same six tissues used for training.
testloader = define_testloader(
    sparse_images,
    labels_sparse_images,
    names_sparse_images,
    config,
)

# Embed the test crops with the backbone, on the device of the first sparse
# image. to_numpy=False presumably keeps torch tensors (check Encoder).
# NOTE: this step is slow (~22 s per batch in the recorded run).
embedding_device = sparse_images[0].device
embeddings, labels, fnames = encoder.embed_by_backbone(
    testloader,
    device=embedding_device,
    to_numpy=False,
)

Compute efficiency: 0.35:   8%|▊         | 1/13 [00:22<04:27, 22.26s/it]


KeyboardInterrupt: 

### Check the embeddings

In [22]:
# Sanity-check the embedding matrix, then show nearest-neighbour examples
# (5 neighbours each) drawn from the test loader. The loader built above is
# named ``testloader``; ``test_dataloader`` does not exist.
print("embeddings.shape -->", embeddings.shape)
tp.plot_utils.plot_knn_examples(embeddings, testloader, figsize=(5, 8), n_neighbors=5)

NameError: name 'embeddings' is not defined

In [23]:
# Reduce the backbone embeddings with UMAP for plotting.
umap_embeddings = tp.evaluation_utils.get_umap(
    embeddings=embeddings,
)

NameError: name 'embeddings' is not defined

In [None]:
# UMAP scatter coloured by the binary tissue label
# (presumably 0 = wt, 1 = dis, matching labels_sparse_images).
tp.plot_utils.umap_binary_label(
    umap_embedded=umap_embeddings,
    labels=labels,
)

In [None]:
# Per-crop Moran's statistic (presumably Moran's I spatial autocorrelation),
# computed from the raw tissue tables. The loader is ``testloader``, and there
# are no PIXEL_SIZE / CROP_SIZE globals, so those values are read from the config.
morans = tp.evaluation_utils.get_morans(
    testloader,
    all_df=all_df,
    pixel_size=config["simulation"]["PIXEL_SIZE"],
    crop_size=config["simulation"]["CROP_SIZE"],
)
tp.plot_utils.umap(umap_embedded=umap_embeddings, colors=morans)

In [None]:
# PCA projection of the embeddings, coloured by the Moran's values.
pca_emb = tp.evaluation_utils.get_pca(
    embeddings
)
tp.plot_utils.pca(
    pca_emb,
    morans,
)

In [None]:
# t-SNE projection of the embeddings, coloured by the Moran's values.
tsne_emb = tp.evaluation_utils.get_tsne(
    embeddings
)
tp.plot_utils.tsne(
    tsne_emb,
    morans,
)

In [None]:
# Export the embeddings with "labels" and "morans" as per-point metadata for an
# embedding projector (check create_projector for the output format/location).
# The loader built above is ``testloader``; ``test_dataloader`` does not exist.
tp.evaluation_utils.create_projector(
    testloader, embeddings, {"labels": labels, "morans": morans}
)